olm.cpp 16.6 KB
Newer Older
Mark Haines's avatar
Mark Haines committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/* Copyright 2015 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15
16
17
18
19
#include "olm/olm.hh"
#include "olm/session.hh"
#include "olm/account.hh"
#include "olm/base64.hh"
#include "olm/cipher.hh"
20
#include "olm/memory.hh"
21
22
23
24
25
26

#include <new>
#include <cstring>

namespace {

27
28
static OlmAccount * to_c(olm::Account * account) {
    return reinterpret_cast<OlmAccount *>(account);
29
30
}

31
32
static OlmSession * to_c(olm::Session * account) {
    return reinterpret_cast<OlmSession *>(account);
33
34
}

35
36
static olm::Account * from_c(OlmAccount * account) {
    return reinterpret_cast<olm::Account *>(account);
37
38
}

39
40
static olm::Session * from_c(OlmSession * account) {
    return reinterpret_cast<olm::Session *>(account);
41
42
43
44
45
46
47
48
49
50
51
52
}

static std::uint8_t * from_c(void * bytes) {
    return reinterpret_cast<std::uint8_t *>(bytes);
}

static std::uint8_t const * from_c(void const * bytes) {
    return reinterpret_cast<std::uint8_t const *>(bytes);
}

static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle";

53
static const olm::CipherAesSha256 PICKLE_CIPHER(
54
55
56
57
58
59
60
61
    CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
);

std::size_t enc_output_length(
    size_t raw_length
) {
    std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
    length += PICKLE_CIPHER.mac_length();
62
    return olm::encode_base64_length(length);
63
64
65
66
67
68
69
70
71
}


std::uint8_t * enc_output_pos(
    std::uint8_t * output,
    size_t raw_length
) {
    std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
    length += PICKLE_CIPHER.mac_length();
72
    return output + olm::encode_base64_length(length) - length;
73
74
75
76
77
78
79
80
81
82
}

std::size_t enc_output(
    std::uint8_t const * key, std::size_t key_length,
    std::uint8_t * output, size_t raw_length
) {
    std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length(
        raw_length
    );
    std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length();
83
    std::size_t base64_length = olm::encode_base64_length(length);
84
85
86
87
88
89
90
    std::uint8_t * raw_output = output + base64_length - length;
    PICKLE_CIPHER.encrypt(
        key, key_length,
        raw_output, raw_length,
        raw_output, ciphertext_length,
        raw_output, length
    );
91
    olm::encode_base64(raw_output, length, output);
92
93
94
95
96
97
    return raw_length;
}

std::size_t enc_input(
    std::uint8_t const * key, std::size_t key_length,
    std::uint8_t * input, size_t b64_length,
98
    olm::ErrorCode & last_error
99
) {
100
    std::size_t enc_length = olm::decode_base64_length(b64_length);
101
    if (enc_length == std::size_t(-1)) {
102
        last_error = olm::ErrorCode::INVALID_BASE64;
103
104
        return std::size_t(-1);
    }
105
    olm::decode_base64(input, b64_length, input);
106
107
108
109
110
111
112
113
    std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length();
    std::size_t result = PICKLE_CIPHER.decrypt(
        key, key_length,
        input, enc_length,
        input, raw_length,
        input, raw_length
    );
    if (result == std::size_t(-1)) {
114
        last_error = olm::ErrorCode::BAD_ACCOUNT_KEY;
115
116
117
118
119
120
121
122
    }
    return result;
}


std::size_t b64_output_length(
    size_t raw_length
) {
123
    return olm::encode_base64_length(raw_length);
124
125
126
127
128
129
}

std::uint8_t * b64_output_pos(
    std::uint8_t * output,
    size_t raw_length
) {
130
    return output + olm::encode_base64_length(raw_length) - raw_length;
131
132
133
134
135
}

std::size_t b64_output(
    std::uint8_t * output, size_t raw_length
) {
136
    std::size_t base64_length = olm::encode_base64_length(raw_length);
137
    std::uint8_t * raw_output = output + base64_length - raw_length;
138
    olm::encode_base64(raw_output, raw_length, output);
139
140
141
142
143
    return base64_length;
}

std::size_t b64_input(
    std::uint8_t * input, size_t b64_length,
144
    olm::ErrorCode & last_error
145
) {
146
    std::size_t raw_length = olm::decode_base64_length(b64_length);
147
    if (raw_length == std::size_t(-1)) {
148
        last_error = olm::ErrorCode::INVALID_BASE64;
149
150
        return std::size_t(-1);
    }
151
    olm::decode_base64(input, b64_length, input);
152
153
154
    return raw_length;
}

155
/* Human-readable names for the error codes. The *_last_error functions index
 * this table with the numeric value of last_error, so the entries must stay in
 * enum order. The array size is taken from the initialiser list. */
const char * errors[] {
    "SUCCESS",
    "NOT_ENOUGH_RANDOM",
    "OUTPUT_BUFFER_TOO_SMALL",
    "BAD_MESSAGE_VERSION",
    "BAD_MESSAGE_FORMAT",
    "BAD_MESSAGE_MAC",
    "BAD_MESSAGE_KEY_ID",
    "INVALID_BASE64",
    "BAD_ACCOUNT_KEY",
    "UNKNOWN_PICKLE_VERSION",
    "CORRUPTED_PICKLE",
};

} // namespace


extern "C" {


175
/* The value that the olm_* functions in this file return to signal failure
 * (all bits set, i.e. std::size_t(-1)). */
size_t olm_error() {
    return ~std::size_t(0);
}


180
181
/* A null-terminated string naming the most recent error recorded on an
 * account. Codes outside the errors table map to "UNKNOWN_ERROR".
 *
 * NOTE(review): the parameter is declared OlmSession * and is kept that way
 * so the exported C signature does not change (presumably it should be
 * OlmAccount *; verify against the public header). The pointer is read as an
 * account, which is what the function is documented to receive. */
const char * olm_account_last_error(
    OlmSession * account
) {
    olm::Account * object = from_c(reinterpret_cast<OlmAccount *>(account));
    unsigned error = unsigned(object->last_error);
    /* Bound by the table size rather than a hard-coded count, so that every
     * named code (including the pickle errors) is reported. */
    if (error < sizeof(errors) / sizeof(errors[0])) {
        return errors[error];
    } else {
        return "UNKNOWN_ERROR";
    }
}


192
193
/* A null-terminated string naming the most recent error recorded on a
 * session. Codes outside the errors table map to "UNKNOWN_ERROR". */
const char * olm_session_last_error(
    OlmSession * session
) {
    unsigned error = unsigned(from_c(session)->last_error);
    /* Bound by the table size rather than a hard-coded count, so that every
     * named code (including the pickle errors) is reported. */
    if (error < sizeof(errors) / sizeof(errors[0])) {
        return errors[error];
    } else {
        return "UNKNOWN_ERROR";
    }
}


204
205
/* Number of bytes of memory olm_account() constructs an account into. */
size_t olm_account_size() {
    size_t const bytes = sizeof(olm::Account);
    return bytes;
}


209
210
/* Number of bytes of memory olm_session() constructs a session into. */
size_t olm_session_size() {
    size_t const bytes = sizeof(olm::Session);
    return bytes;
}


214
/* Construct an olm::Account in caller-provided memory of at least
 * olm_account_size() bytes. The memory is cleared with olm::unset first. */
OlmAccount * olm_account(
    void * memory
) {
    olm::unset(memory, sizeof(olm::Account));
    olm::Account * const account = new(memory) olm::Account();
    return to_c(account);
}


222
/* Construct an olm::Session in caller-provided memory of at least
 * olm_session_size() bytes. The memory is cleared with olm::unset first. */
OlmSession * olm_session(
    void * memory
) {
    olm::unset(memory, sizeof(olm::Session));
    olm::Session * const session = new(memory) olm::Session();
    return to_c(session);
}


230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
/* Wipe the memory backing an account, then default-construct a new
 * olm::Account there so the handle remains safe to use.
 * Returns the number of bytes wiped.
 * NOTE(review): the parameter is declared OlmSession *, but only its address
 * is used and it is treated as account memory throughout. */
size_t olm_clear_account(
    OlmSession * account
) {
    std::size_t const account_size = sizeof(olm::Account);
    olm::unset(account, account_size);
    new(account) olm::Account();
    return account_size;
}


/* Wipe the memory backing a session, then default-construct a new
 * olm::Session there so the handle remains safe to use.
 * Returns the number of bytes wiped. */
size_t olm_clear_session(
    OlmSession * session
) {
    std::size_t const session_size = sizeof(olm::Session);
    olm::unset(session, session_size);
    new(session) olm::Session();
    return session_size;
}


252
253
/* Size of the buffer olm_pickle_account() needs for this account. */
size_t olm_pickle_account_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    std::size_t const raw_length = pickle_length(object);
    return enc_output_length(raw_length);
}


259
260
/* Size of the buffer olm_pickle_session() needs for this session. */
size_t olm_pickle_session_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    std::size_t const raw_length = pickle_length(object);
    return enc_output_length(raw_length);
}


266
267
/* Serialise the account, encrypt it with key, and write the base64 text into
 * pickled. Fails with OUTPUT_BUFFER_TOO_SMALL when pickled_length is below
 * olm_pickle_account_length(). Otherwise returns the result of enc_output(). */
size_t olm_pickle_account(
    OlmAccount * account,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Account & object = *from_c(account);
    std::size_t const raw_length = pickle_length(object);
    std::size_t const needed = enc_output_length(raw_length);
    if (needed > pickled_length) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return size_t(-1);
    }
    std::uint8_t * const buffer = from_c(pickled);
    /* Serialise into the tail of the buffer; enc_output works in place. */
    pickle(enc_output_pos(buffer, raw_length), object);
    return enc_output(from_c(key), key_length, buffer, raw_length);
}


282
283
/* Serialise the session, encrypt it with key, and write the base64 text into
 * pickled. Fails with OUTPUT_BUFFER_TOO_SMALL when pickled_length is below
 * olm_pickle_session_length(). Otherwise returns the result of enc_output(). */
size_t olm_pickle_session(
    OlmSession * session,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Session & object = *from_c(session);
    std::size_t const raw_length = pickle_length(object);
    std::size_t const needed = enc_output_length(raw_length);
    if (needed > pickled_length) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return size_t(-1);
    }
    std::uint8_t * const buffer = from_c(pickled);
    /* Serialise into the tail of the buffer; enc_output works in place. */
    pickle(enc_output_pos(buffer, raw_length), object);
    return enc_output(from_c(key), key_length, buffer, raw_length);
}


298
299
/* Decrypt the base64 pickle in place with key and load it into the account.
 * Returns pickled_length on success. On failure returns std::size_t(-1); the
 * error set by enc_input is kept, and CORRUPTED_PICKLE is used when the
 * unpickler stops anywhere other than exactly at the end of the data without
 * recording an error of its own. */
size_t olm_unpickle_account(
    OlmAccount * account,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Account & object = *from_c(account);
    std::uint8_t * const pos = from_c(pickled);
    std::size_t const raw_length = enc_input(
        from_c(key), key_length, pos, pickled_length, object.last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    std::uint8_t * const end = pos + raw_length;
    /* The unpickler is given one byte of slack: it returns end on success, an
     * earlier pointer if it stops short, and end + 1 on error. */
    auto const stop = unpickle(pos, end + 1, object);
    bool const consumed_exactly = (stop == end);
    if (!consumed_exactly) {
        if (object.last_error == olm::ErrorCode::SUCCESS) {
            object.last_error = olm::ErrorCode::CORRUPTED_PICKLE;
        }
        return std::size_t(-1);
    }
    return pickled_length;
}


326
327
/* Decrypt the base64 pickle in place with key and load it into the session.
 * Returns pickled_length on success. On failure returns std::size_t(-1); the
 * error set by enc_input is kept, and CORRUPTED_PICKLE is used when the
 * unpickler stops anywhere other than exactly at the end of the data without
 * recording an error of its own. */
size_t olm_unpickle_session(
    OlmSession * session,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Session & object = *from_c(session);
    std::uint8_t * const pos = from_c(pickled);
    std::size_t raw_length = enc_input(
        from_c(key), key_length, pos, pickled_length, object.last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    /* end must be exactly one past the decrypted data. An extra +1 here made
     * every successful unpickle look like a failure and let the parser read
     * beyond the data. */
    std::uint8_t * const end = pos + raw_length;
    /* On success unpickle will return (pos + raw_length). If unpickling
     * terminates too soon then it will return a pointer before
     * (pos + raw_length). On error unpickle will return (pos + raw_length + 1).
     */
    if (end != unpickle(pos, end + 1, object)) {
        if (object.last_error == olm::ErrorCode::SUCCESS) {
            object.last_error = olm::ErrorCode::CORRUPTED_PICKLE;
        }
        return std::size_t(-1);
    }
    return pickled_length;
}


355
356
/* Number of random bytes olm_create_account() expects, as reported by
 * olm::Account::new_account_random_length(). */
size_t olm_create_account_random_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.new_account_random_length();
}


362
363
size_t olm_create_account(
    OlmAccount * account,
364
    void * random, size_t random_length
365
) {
366
367
368
    size_t result = from_c(account)->new_account(from_c(random), random_length);
    olm::unset(random, random_length);
    return result;
369
370
}

371

372
/* Buffer size needed by olm_account_identity_keys(). */
size_t olm_account_identity_keys_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.get_identity_json_length();
}

378

379
380
/* Write the account's identity keys as JSON into identity_keys.
 * Returns olm::Account::get_identity_json()'s result. */
size_t olm_account_identity_keys(
    OlmAccount * account,
    void * identity_keys, size_t identity_key_length
) {
    olm::Account & object = *from_c(account);
    std::uint8_t * const json_out = from_c(identity_keys);
    return object.get_identity_json(json_out, identity_key_length);
}


389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
/* Length of the base64 signature written by olm_account_sign(). */
size_t olm_account_signature_length(
    OlmAccount * account
) {
    std::size_t const raw_length = from_c(account)->signature_length();
    return b64_output_length(raw_length);
}


/* Sign message with the account and write the base64 signature into
 * signature. Fails with OUTPUT_BUFFER_TOO_SMALL when the buffer is shorter
 * than olm_account_signature_length(). Returns the base64 length on success. */
size_t olm_account_sign(
    OlmAccount * account,
    void const * message, size_t message_length,
    void * signature, size_t signature_length
) {
    olm::Account & object = *from_c(account);
    std::size_t const raw_length = object.signature_length();
    if (b64_output_length(raw_length) > signature_length) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return std::size_t(-1);
    }
    std::uint8_t * const output = from_c(signature);
    /* Sign into the tail of the buffer, then base64-encode in place. */
    object.sign(
        from_c(message), message_length,
        b64_output_pos(output, raw_length), raw_length
    );
    return b64_output(output, raw_length);
}


415
416
/* Buffer size needed by olm_account_one_time_keys(). */
size_t olm_account_one_time_keys_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.get_one_time_keys_json_length();
}


422
423
/* Write the account's one-time keys as JSON into one_time_keys_json.
 * Returns olm::Account::get_one_time_keys_json()'s result. */
size_t olm_account_one_time_keys(
    OlmAccount * account,
    void * one_time_keys_json, size_t one_time_key_json_length
) {
    olm::Account & object = *from_c(account);
    std::uint8_t * const json_out = from_c(one_time_keys_json);
    return object.get_one_time_keys_json(json_out, one_time_key_json_length);
}


432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
/* Forward to olm::Account::mark_keys_as_published() and return its result. */
size_t olm_account_mark_keys_as_published(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.mark_keys_as_published();
}


/* Forward to olm::Account::max_number_of_one_time_keys(). */
size_t olm_account_max_number_of_one_time_keys(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.max_number_of_one_time_keys();
}


/* Number of random bytes olm_account_generate_one_time_keys() expects for
 * number_of_keys keys. */
size_t olm_account_generate_one_time_keys_random_length(
    OlmAccount * account,
    size_t number_of_keys
) {
    olm::Account & object = *from_c(account);
    return object.generate_one_time_keys_random_length(number_of_keys);
}


size_t olm_account_generate_one_time_keys(
    OlmAccount * account,
    size_t number_of_keys,
457
    void * random, size_t random_length
458
) {
459
    size_t result = from_c(account)->generate_one_time_keys(
460
461
462
        number_of_keys,
        from_c(random), random_length
    );
463
464
    olm::unset(random, random_length);
    return result;
465
466
467
}


468
469
/* Number of random bytes olm_create_outbound_session() expects. */
size_t olm_create_outbound_session_random_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return object.new_outbound_session_random_length();
}

474

475
476
477
/* Start an outbound session to a peer identified by their base64 identity key
 * and base64 one-time key. Each key must decode to exactly 32 bytes;
 * otherwise the call fails with INVALID_BASE64 (the random buffer is left
 * untouched in that case). On the normal path the random buffer is wiped after
 * olm::Session::new_outbound_session(), whose result is returned. */
size_t olm_create_outbound_session(
    OlmSession * session,
    OlmAccount * account,
    void const * their_identity_key, size_t their_identity_key_length,
    void const * their_one_time_key, size_t their_one_time_key_length,
    void * random, size_t random_length
) {
    olm::Session & object = *from_c(session);
    auto decodes_to_32_bytes = [](std::size_t encoded_length) {
        return olm::decode_base64_length(encoded_length) == 32;
    };
    if (!decodes_to_32_bytes(their_identity_key_length)
            || !decodes_to_32_bytes(their_one_time_key_length)) {
        object.last_error = olm::ErrorCode::INVALID_BASE64;
        return std::size_t(-1);
    }

    olm::Curve25519PublicKey identity_key;
    olm::decode_base64(
        from_c(their_identity_key), their_identity_key_length,
        identity_key.public_key
    );
    olm::Curve25519PublicKey one_time_key;
    olm::decode_base64(
        from_c(their_one_time_key), their_one_time_key_length,
        one_time_key.public_key
    );

    size_t const outcome = object.new_outbound_session(
        *from_c(account), identity_key, one_time_key,
        from_c(random), random_length
    );
    olm::unset(random, random_length);
    return outcome;
}


509
510
511
size_t olm_create_inbound_session(
    OlmSession * session,
    OlmAccount * account,
512
513
514
515
516
517
518
519
520
521
522
523
524
525
    void * one_time_key_message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(one_time_key_message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->new_inbound_session(
        *from_c(account), from_c(one_time_key_message), raw_length
    );
}


526
527
size_t olm_matches_inbound_session(
    OlmSession * session,
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
    void * one_time_key_message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(one_time_key_message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    bool matches = from_c(session)->matches_inbound_session(
        from_c(one_time_key_message), raw_length
    );
    return matches ? 1 : 0;
}


543
544
545
/* Remove from the account the one-time key recorded in the session's
 * bob_one_time_key. On failure (std::size_t(-1) from remove_key) the account's
 * last_error is set to BAD_MESSAGE_KEY_ID. */
size_t olm_remove_one_time_keys(
    OlmAccount * account,
    OlmSession * session
) {
    olm::Account & object = *from_c(account);
    size_t const outcome = object.remove_key(from_c(session)->bob_one_time_key);
    if (outcome == std::size_t(-1)) {
        object.last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
    }
    return outcome;
}


557
558
/* Type of the next message olm_encrypt() will produce, as a size_t. */
size_t olm_encrypt_message_type(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return static_cast<size_t>(object.encrypt_message_type());
}


564
565
/* Number of random bytes olm_encrypt() expects for the next message. */
size_t olm_encrypt_random_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return object.encrypt_random_length();
}


571
572
size_t olm_encrypt_message_length(
    OlmSession * session,
573
574
575
576
577
578
579
580
    size_t plaintext_length
) {
    return b64_output_length(
        from_c(session)->encrypt_message_length(plaintext_length)
    );
}


581
582
size_t olm_encrypt(
    OlmSession * session,
583
    void const * plaintext, size_t plaintext_length,
584
    void * random, size_t random_length,
585
586
587
588
589
    void * message, size_t message_length
) {
    std::size_t raw_length = from_c(session)->encrypt_message_length(
        plaintext_length
    );
590
    if (message_length < b64_output_length(raw_length)) {
591
        from_c(session)->last_error =
592
            olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
593
594
595
596
597
598
599
        return std::size_t(-1);
    }
    from_c(session)->encrypt(
        from_c(plaintext), plaintext_length,
        from_c(random), random_length,
        b64_output_pos(from_c(message), raw_length), raw_length
    );
600
    olm::unset(random, random_length);
601
602
603
604
    return b64_output(from_c(message), raw_length);
}


605
606
size_t olm_decrypt_max_plaintext_length(
    OlmSession * session,
607
608
609
610
611
612
613
614
615
616
    size_t message_type,
    void * message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->decrypt_max_plaintext_length(
617
        olm::MessageType(message_type), from_c(message), raw_length
618
619
620
621
    );
}


622
623
size_t olm_decrypt(
    OlmSession * session,
624
625
626
627
628
629
630
631
632
633
634
    size_t message_type,
    void * message, size_t message_length,
    void * plaintext, size_t max_plaintext_length
) {
    std::size_t raw_length = b64_input(
        from_c(message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->decrypt(
635
        olm::MessageType(message_type), from_c(message), raw_length,
636
637
638
639
640
        from_c(plaintext), max_plaintext_length
    );
}

}