/*
 * Copyright 2016 OpenMarket Ltd
 * Copyright 2016 Vector Creations Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "olm_outbound_group_session.h"

using namespace AndroidOlmSdk;
21
22
23
24
25
26
27
28

/**
 * Release the session allocation made by initializeOutboundGroupSessionMemory().<br>
 * This method MUST be called when java counter part account instance is done.
 *
 */
JNIEXPORT void OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(releaseSessionJni)(JNIEnv *env, jobject thiz)
{
29
    LOGD("## releaseSessionJni(): OutBound group session IN");
30

31
    OlmOutboundGroupSession* sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz);
32

33
34
35
36
37
38
39
    if (!sessionPtr)
    {
        LOGE(" ## releaseSessionJni(): failure - invalid outbound group session instance");
    }
    else
    {
        LOGD(" ## releaseSessionJni(): sessionPtr=%p",sessionPtr);
40

41
#ifdef ENABLE_JNI_LOG
42
43
        size_t retCode = olm_clear_outbound_group_session(sessionPtr);
        LOGD(" ## releaseSessionJni(): clear_outbound_group_session=%lu",static_cast<long unsigned int>(retCode));
44
#else
45
        olm_clear_outbound_group_session(sessionPtr);
46
#endif
47

48
49
50
51
        LOGD(" ## releaseSessionJni(): free IN");
        free(sessionPtr);
        LOGD(" ## releaseSessionJni(): free OUT");
    }
52
53
54
55
56
57
58
59
}

/**
* Initialize a new outbound group session and return it to JAVA side.<br>
* Since a C prt is returned as a jlong, special care will be taken
* to make the cast (OlmOutboundGroupSession* => jlong) platform independent.
* @return the initialized OlmOutboundGroupSession* instance if init succeed, NULL otherwise
**/
60
JNIEXPORT jlong OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(createNewSessionJni)(JNIEnv *env, jobject thiz)
61
62
{
    OlmOutboundGroupSession* sessionPtr = NULL;
63
64
    size_t sessionSize = 0;

65
    LOGD("## createNewSessionJni(): outbound group session IN");
66
    sessionSize = olm_outbound_group_session_size();
67

68
    if (0 == sessionSize)
69
    {
70
        LOGE(" ## createNewSessionJni(): failure - outbound group session size = 0");
71
    }
72
    else if (!(sessionPtr = (OlmOutboundGroupSession*)malloc(sessionSize)))
73
    {
74
75
        sessionPtr = olm_outbound_group_session(sessionPtr);
        LOGD(" ## createNewSessionJni(): success - outbound group session size=%lu",static_cast<long unsigned int>(sessionSize));
76
77
78
    }
    else
    {
79
        LOGE(" ## createNewSessionJni(): failure - outbound group session OOM");
80
81
82
83
84
85
    }

    return (jlong)(intptr_t)sessionPtr;
}

/**
86
 * Start a new outbound session.<br>
87
88
 * @return ERROR_CODE_OK if operation succeed, ERROR_CODE_KO otherwise
 */
89
JNIEXPORT jint OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(initOutboundGroupSessionJni)(JNIEnv *env, jobject thiz)
90
91
{
    jint retCode = ERROR_CODE_KO;
92
93

    LOGD("## initOutboundGroupSessionJni(): IN");
94

95
96
97
    OlmOutboundGroupSession *sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz);

    if (!sessionPtr)
98
    {
99
        LOGE(" ## initOutboundGroupSessionJni(): failure - invalid outbound group session instance");
100
101
102
103
104
    }
    else
    {
        // compute random buffer
        size_t randomLength = olm_init_outbound_group_session_random_length(sessionPtr);
105
106
        uint8_t *randomBuffPtr = NULL;

pedroGitt's avatar
pedroGitt committed
107
        LOGW(" ## initOutboundGroupSessionJni(): randomLength=%lu",static_cast<long unsigned int>(randomLength));
108
109

        if ((0 != randomLength) && !setRandomInBuffer(env, &randomBuffPtr, randomLength))
110
        {
111
            LOGE(" ## initOutboundGroupSessionJni(): failure - random buffer init");
112
113
114
        }
        else
        {
115
            if (0 == randomLength)
116
            {
117
                LOGW(" ## initOutboundGroupSessionJni(): random buffer is not required");
118
119
            }

120
            size_t sessionResult = olm_init_outbound_group_session(sessionPtr, randomBuffPtr, randomLength);
121
122

            if (sessionResult == olm_error()) {
123
                LOGE(" ## initOutboundGroupSessionJni(): failure - init outbound session creation  Msg=%s",(const char *)olm_outbound_group_session_last_error(sessionPtr));
124
125
126
127
            }
            else
            {
                retCode = ERROR_CODE_OK;
pedroGitt's avatar
pedroGitt committed
128
                LOGD(" ## initOutboundGroupSessionJni(): success - result=%lu", static_cast<long unsigned int>(sessionResult));
129
130
            }

131
132
            free(randomBuffPtr);
        }
133
    }
134
135
136
137
138
139
140

    return retCode;
}

/**
* Get a base64-encoded identifier for this outbound group session.
*/
141
JNIEXPORT jbyteArray OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(sessionIdentifierJni)(JNIEnv *env, jobject thiz)
142
143
144
{
    LOGD("## sessionIdentifierJni(): outbound group session IN");

145
146
147
    OlmOutboundGroupSession *sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz);
    jbyteArray returnValue = 0;
    
148
    if (!sessionPtr)
149
    {
150
        LOGE(" ## sessionIdentifierJni(): failure - invalid outbound group session instance");
151
152
153
    }
    else
    {
154
155
        // get the size to alloc
        size_t lengthSessionId = olm_outbound_group_session_id_length(sessionPtr);
pedroGitt's avatar
pedroGitt committed
156
        LOGD(" ## sessionIdentifierJni(): outbound group session lengthSessionId=%lu",static_cast<long unsigned int>(lengthSessionId));
157

158
159
160
        uint8_t *sessionIdPtr =  (uint8_t*)malloc((lengthSessionId+1)*sizeof(uint8_t));

        if (!sessionIdPtr)
161
        {
162
           LOGE(" ## sessionIdentifierJni(): failure - outbound identifier allocation OOM");
163
164
165
        }
        else
        {
166
            size_t result = olm_outbound_group_session_id(sessionPtr, sessionIdPtr, lengthSessionId);
167

168
169
            if (result == olm_error())
            {
pedroGitt's avatar
pedroGitt committed
170
                LOGE(" ## sessionIdentifierJni(): failure - outbound group session identifier failure Msg=%s",reinterpret_cast<const char*>(olm_outbound_group_session_last_error(sessionPtr)));
171
172
173
174
175
            }
            else
            {
                // update length
                sessionIdPtr[result] = static_cast<char>('\0');
176
177
178
179

                returnValue = env->NewByteArray(result);
                env->SetByteArrayRegion(returnValue, 0 , result, (jbyte*)sessionIdPtr);

pedroGitt's avatar
pedroGitt committed
180
                LOGD(" ## sessionIdentifierJni(): success - outbound group session identifier result=%lu sessionId=%s",static_cast<long unsigned int>(result), reinterpret_cast<char*>(sessionIdPtr));
181
            }
182

183
184
185
186
            // free alloc
            free(sessionIdPtr);
        }
    }
187

188
    return returnValue;
189
190
191
}


192
193
194
195
196
197
198
/**
* Get the current message index for this session.<br>
* Each message is sent with an increasing index, this
* method returns the index for the next message.
* @return current session index
*/
JNIEXPORT jint OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(messageIndexJni)(JNIEnv *env, jobject thiz)
199
{
200
201
    OlmOutboundGroupSession *sessionPtr = NULL;
    jint indexRetValue = 0;
202

203
    LOGD("## messageIndexJni(): IN");
204

205
    if (!(sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz)))
206
    {
207
        LOGE(" ## messageIndexJni(): failure - invalid outbound group session instance");
208
209
    }
    else
210
    {
211
        indexRetValue = static_cast<jint>(olm_outbound_group_session_message_index(sessionPtr));
212
    }
213

214
    LOGD(" ## messageIndexJni(): success - index=%d",indexRetValue);
215
216
217
218
219
220
221

    return indexRetValue;
}

/**
* Get the base64-encoded current ratchet key for this session.<br>
*/
222
JNIEXPORT jbyteArray OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(sessionKeyJni)(JNIEnv *env, jobject thiz)
223
224
{
    LOGD("## sessionKeyJni(): outbound group session IN");
225

226
    OlmOutboundGroupSession *sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz);
227
    jbyteArray returnValue = 0;
228

229
    if (!sessionPtr)
230
    {
231
        LOGE(" ## sessionKeyJni(): failure - invalid outbound group session instance");
232
233
234
    }
    else
    {
235
236
        // get the size to alloc
        size_t sessionKeyLength = olm_outbound_group_session_key_length(sessionPtr);
pedroGitt's avatar
pedroGitt committed
237
        LOGD(" ## sessionKeyJni(): sessionKeyLength=%lu",static_cast<long unsigned int>(sessionKeyLength));
238

239
240
241
        uint8_t *sessionKeyPtr = (uint8_t*)malloc((sessionKeyLength+1)*sizeof(uint8_t));

        if (!sessionKeyPtr)
242
        {
243
           LOGE(" ## sessionKeyJni(): failure - session key allocation OOM");
244
245
246
        }
        else
        {
247
            size_t result = olm_outbound_group_session_key(sessionPtr, sessionKeyPtr, sessionKeyLength);
248

249
250
            if (result == olm_error())
            {
251
                LOGE(" ## sessionKeyJni(): failure - session key failure Msg=%s",(const char *)olm_outbound_group_session_last_error(sessionPtr));
252
253
254
255
256
            }
            else
            {
                // update length
                sessionKeyPtr[result] = static_cast<char>('\0');
pedroGitt's avatar
pedroGitt committed
257
                LOGD(" ## sessionKeyJni(): success - outbound group session key result=%lu sessionKey=%s",static_cast<long unsigned int>(result), reinterpret_cast<char*>(sessionKeyPtr));
258
259
260

                returnValue = env->NewByteArray(result);
                env->SetByteArrayRegion(returnValue, 0 , result, (jbyte*)sessionKeyPtr);
261
            }
262

263
264
            // free alloc
            free(sessionKeyPtr);
265
266
267
        }
    }

268
    return returnValue;
269
270
}

271
JNIEXPORT jbyteArray OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(encryptMessageJni)(JNIEnv *env, jobject thiz, jbyteArray aClearMsgBuffer, jobject aErrorMsg)
272
{
273
274
    LOGD("## encryptMessageJni(): IN");

275
276
    jbyteArray encryptedMsgRet = 0;

277
    OlmOutboundGroupSession *sessionPtr = NULL;
278
    jbyte* clearMsgPtr = NULL;
279

280
281
282
    jclass errorMsgJClass = 0;
    jmethodID errorMsgMethodId = 0;

283
    if (!(sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz)))
284
    {
285
        LOGE(" ## encryptMessageJni(): failure - invalid outbound group session ptr=NULL");
286
    }
287
288
289
290
    else if (!aErrorMsg)
    {
        LOGE(" ## encryptMessageJni(): failure - invalid error output");
    }
291
    else if (!aClearMsgBuffer)
292
    {
293
        LOGE(" ## encryptMessageJni(): failure - invalid clear message");
294
    }
295
    else if (!(clearMsgPtr = env->GetByteArrayElements(aClearMsgBuffer, NULL)))
296
    {
297
        LOGE(" ## encryptMessageJni(): failure - clear message JNI allocation OOM");
298
    }
299
300
301
302
303
304
305
306
    else if (!(errorMsgJClass = env->GetObjectClass(aErrorMsg)))
    {
        LOGE(" ## encryptMessageJni(): failure - unable to get error class");
    }
    else if (!(errorMsgMethodId = env->GetMethodID(errorMsgJClass, "append", "(Ljava/lang/String;)Ljava/lang/StringBuffer;")))
    {
        LOGE(" ## encryptMessageJni(): failure - unable to get error method ID");
    }
307
308
    else
    {
309
        // get clear message length
310
        size_t clearMsgLength = (size_t)env->GetArrayLength(aClearMsgBuffer);
pedroGitt's avatar
pedroGitt committed
311
        LOGD(" ## encryptMessageJni(): clearMsgLength=%lu",static_cast<long unsigned int>(clearMsgLength));
312

313
314
        // compute max encrypted length
        size_t encryptedMsgLength = olm_group_encrypt_message_length(sessionPtr,clearMsgLength);
315
316
317
        uint8_t *encryptedMsgPtr = (uint8_t*)malloc((encryptedMsgLength+1)*sizeof(uint8_t));

        if (!encryptedMsgPtr)
318
        {
319
            LOGE(" ## encryptMessageJni(): failure - encryptedMsgPtr buffer OOM");
320
321
322
        }
        else
        {
pedroGitt's avatar
pedroGitt committed
323
            LOGD(" ## encryptMessageJni(): estimated encryptedMsgLength=%lu",static_cast<long unsigned int>(encryptedMsgLength));
324

325
            size_t encryptedLength = olm_group_encrypt(sessionPtr,
326
327
328
329
                                                       (uint8_t*)clearMsgPtr,
                                                       clearMsgLength,
                                                       encryptedMsgPtr,
                                                       encryptedMsgLength);
330
331


332
            if (encryptedLength == olm_error())
333
            {
334
335
336
337
338
339
340
341
342
                const char * errorMsgPtr = olm_outbound_group_session_last_error(sessionPtr);
                LOGE(" ## encryptMessageJni(): failure - olm_group_decrypt_max_plaintext_length Msg=%s",errorMsgPtr);

                jstring errorJstring = env->NewStringUTF(errorMsgPtr);

                if (errorJstring)
                {
                    env->CallObjectMethod(aErrorMsg, errorMsgMethodId, errorJstring);
                }
343
344
345
            }
            else
            {
346
                // update decrypted buffer size
347
                encryptedMsgPtr[encryptedLength] = static_cast<char>('\0');
348

pedroGitt's avatar
pedroGitt committed
349
                LOGD(" ## encryptMessageJni(): encrypted returnedLg=%lu plainTextMsgPtr=%s",static_cast<long unsigned int>(encryptedLength), reinterpret_cast<char*>(encryptedMsgPtr));
350
351
352

                encryptedMsgRet = env->NewByteArray(encryptedLength);
                env->SetByteArrayRegion(encryptedMsgRet, 0 , encryptedLength, (jbyte*)encryptedMsgPtr);
353
            }
354
355
356

            free(encryptedMsgPtr);
         }
357
    }
358
359

    // free alloc
360
    if (clearMsgPtr)
361
    {
362
        env->ReleaseByteArrayElements(aClearMsgBuffer, clearMsgPtr, JNI_ABORT);
363
364
    }

365
    return encryptedMsgRet;
366
367
368
}


369
370
371
372
373
374
/**
* Serialize and encrypt session instance into a base64 string.<br>
* @param aKey key used to encrypt the serialized session data
* @param[out] aErrorMsg error message set if operation failed
* @return a base64 string if operation succeed, null otherwise
**/
375
JNIEXPORT jstring OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(serializeDataWithKeyJni)(JNIEnv *env, jobject thiz, jbyteArray aKeyBuffer, jobject aErrorMsg)
376
377
378
379
380
{
    jstring pickledDataRetValue = 0;
    jclass errorMsgJClass = 0;
    jmethodID errorMsgMethodId = 0;
    jstring errorJstring = 0;
381
    jbyte* keyPtr = NULL;
382
383
384
385
    OlmOutboundGroupSession* sessionPtr = NULL;

    LOGD("## outbound group session serializeDataWithKeyJni(): IN");

386
    if (!(sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz)))
387
388
389
    {
        LOGE(" ## serializeDataWithKeyJni(): failure - invalid session ptr");
    }
390
    else if (!aKeyBuffer)
391
392
393
    {
        LOGE(" ## serializeDataWithKeyJni(): failure - invalid key");
    }
394
    else if (!aErrorMsg)
395
396
397
    {
        LOGE(" ## serializeDataWithKeyJni(): failure - invalid error object");
    }
398
    else if (!(errorMsgJClass = env->GetObjectClass(aErrorMsg)))
399
400
401
    {
        LOGE(" ## serializeDataWithKeyJni(): failure - unable to get error class");
    }
402
    else if (!(errorMsgMethodId = env->GetMethodID(errorMsgJClass, "append", "(Ljava/lang/String;)Ljava/lang/StringBuffer;")))
403
404
405
    {
        LOGE(" ## serializeDataWithKeyJni(): failure - unable to get error method ID");
    }
406
    else if (!(keyPtr = env->GetByteArrayElements(aKeyBuffer, 0)))
407
408
409
410
411
412
    {
        LOGE(" ## serializeDataWithKeyJni(): failure - keyPtr JNI allocation OOM");
    }
    else
    {
        size_t pickledLength = olm_pickle_outbound_group_session_length(sessionPtr);
413
        size_t keyLength = (size_t)env->GetArrayLength(aKeyBuffer);
pedroGitt's avatar
pedroGitt committed
414
        LOGD(" ## serializeDataWithKeyJni(): pickledLength=%lu keyLength=%lu",static_cast<long unsigned int>(pickledLength), static_cast<long unsigned int>(keyLength));
415
416
        LOGD(" ## serializeDataWithKeyJni(): key=%s",(char const *)keyPtr);

417
418
419
        void *pickledPtr = malloc((pickledLength+1)*sizeof(uint8_t));

        if(!pickledPtr)
420
421
422
423
424
425
426
427
428
429
        {
            LOGE(" ## serializeDataWithKeyJni(): failure - pickledPtr buffer OOM");
        }
        else
        {
            size_t result = olm_pickle_outbound_group_session(sessionPtr,
                                                             (void const *)keyPtr,
                                                              keyLength,
                                                              (void*)pickledPtr,
                                                              pickledLength);
430
            if (result == olm_error())
431
432
433
434
            {
                const char *errorMsgPtr = olm_outbound_group_session_last_error(sessionPtr);
                LOGE(" ## serializeDataWithKeyJni(): failure - olm_pickle_outbound_group_session() Msg=%s",errorMsgPtr);

435
                if (!(errorJstring = env->NewStringUTF(errorMsgPtr)))
436
437
438
439
440
441
442
443
444
                {
                    env->CallObjectMethod(aErrorMsg, errorMsgMethodId, errorJstring);
                }
            }
            else
            {
                // build success output
                (static_cast<char*>(pickledPtr))[pickledLength] = static_cast<char>('\0');
                pickledDataRetValue = env->NewStringUTF((const char*)pickledPtr);
pedroGitt's avatar
pedroGitt committed
445
                LOGD(" ## serializeDataWithKeyJni(): success - result=%lu pickled=%s", static_cast<long unsigned int>(result), static_cast<char*>(pickledPtr));
446
447
448
            }
        }

449
        free(pickledPtr);
450
451
    }

452
453
    // free alloc
    if (keyPtr)
454
    {
455
        env->ReleaseByteArrayElements(aKeyBuffer, keyPtr, JNI_ABORT);
456
457
458
459
460
461
    }

    return pickledDataRetValue;
}


462
JNIEXPORT jstring OLM_OUTBOUND_GROUP_SESSION_FUNC_DEF(initWithSerializedDataJni)(JNIEnv *env, jobject thiz, jbyteArray aSerializedDataBuffer, jbyteArray aKeyBuffer)
463
464
465
{
    OlmOutboundGroupSession* sessionPtr = NULL;
    jstring errorMessageRetValue = 0;
466
467
    jbyte* keyPtr = NULL;
    jbyte* pickledPtr = NULL;
468
469
470

    LOGD("## initWithSerializedDataJni(): IN");

471
    if (!(sessionPtr = (OlmOutboundGroupSession*)getOutboundGroupSessionInstanceId(env,thiz)))
472
473
474
    {
        LOGE(" ## initWithSerializedDataJni(): failure - session failure OOM");
    }
475
    else if (!aKeyBuffer)
476
477
478
    {
        LOGE(" ## initWithSerializedDataJni(): failure - invalid key");
    }
479
    else if (!aSerializedDataBuffer)
480
481
482
    {
        LOGE(" ## initWithSerializedDataJni(): failure - serialized data");
    }
483
    else if (!(keyPtr = env->GetByteArrayElements(aKeyBuffer, 0)))
484
485
486
    {
        LOGE(" ## initWithSerializedDataJni(): failure - keyPtr JNI allocation OOM");
    }
487
    else if (!(pickledPtr = env->GetByteArrayElements(aSerializedDataBuffer, 0)))
488
489
490
491
492
    {
        LOGE(" ## initWithSerializedDataJni(): failure - pickledPtr JNI allocation OOM");
    }
    else
    {
493
494
        size_t pickledLength = (size_t)env->GetArrayLength(aSerializedDataBuffer);
        size_t keyLength = (size_t)env->GetArrayLength(aKeyBuffer);
pedroGitt's avatar
pedroGitt committed
495
        LOGD(" ## initWithSerializedDataJni(): pickledLength=%lu keyLength=%lu",static_cast<long unsigned int>(pickledLength), static_cast<long unsigned int>(keyLength));
496
497
498
499
500
501
502
503
        LOGD(" ## initWithSerializedDataJni(): key=%s",(char const *)keyPtr);
        LOGD(" ## initWithSerializedDataJni(): pickled=%s",(char const *)pickledPtr);

        size_t result = olm_unpickle_outbound_group_session(sessionPtr,
                                                            (void const *)keyPtr,
                                                            keyLength,
                                                            (void*)pickledPtr,
                                                            pickledLength);
504
        if (result == olm_error())
505
506
507
508
509
510
511
        {
            const char *errorMsgPtr = olm_outbound_group_session_last_error(sessionPtr);
            LOGE(" ## initWithSerializedDataJni(): failure - olm_unpickle_outbound_group_session() Msg=%s",errorMsgPtr);
            errorMessageRetValue = env->NewStringUTF(errorMsgPtr);
        }
        else
        {
pedroGitt's avatar
pedroGitt committed
512
            LOGD(" ## initWithSerializedDataJni(): success - result=%lu ", static_cast<long unsigned int>(result));
513
514
515
516
        }
    }

    // free alloc
517
    if (keyPtr)
518
    {
519
        env->ReleaseByteArrayElements(aKeyBuffer, keyPtr, JNI_ABORT);
520
521
    }

522
    if (pickledPtr)
523
    {
524
        env->ReleaseByteArrayElements(aSerializedDataBuffer, pickledPtr, JNI_ABORT);
525
526
527
528
    }

    return errorMessageRetValue;
}
529