olm.cpp 19.2 KB
Newer Older
Mark Haines's avatar
Mark Haines committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/* Copyright 2015 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15
16
17
18
19
#include "olm/olm.hh"
#include "olm/session.hh"
#include "olm/account.hh"
#include "olm/base64.hh"
#include "olm/cipher.hh"
20
#include "olm/memory.hh"
21
22
23
24
25
26

#include <new>
#include <cstring>

namespace {

27
28
/* Convert an internal olm::Account pointer into the opaque C handle type.
 * Round-tripping through void * is exactly what reinterpret_cast does for
 * object pointers, spelled out explicitly. */
static OlmAccount * to_c(olm::Account * account) {
    void * const raw = account;
    return static_cast<OlmAccount *>(raw);
}

31
32
/* Convert an internal olm::Session pointer into the opaque C handle type. */
static OlmSession * to_c(olm::Session * session) {
    void * const raw = session;
    return static_cast<OlmSession *>(raw);
}

35
36
/* Recover the internal olm::Account behind an opaque C account handle. */
static olm::Account * from_c(OlmAccount * account) {
    void * const raw = account;
    return static_cast<olm::Account *>(raw);
}

39
40
/* Recover the internal olm::Session behind an opaque C session handle. */
static olm::Session * from_c(OlmSession * session) {
    void * const raw = session;
    return static_cast<olm::Session *>(raw);
}

/* View a caller-supplied mutable buffer as a sequence of bytes. */
static std::uint8_t * from_c(void * bytes) {
    return static_cast<std::uint8_t *>(bytes);
}

/* View a caller-supplied read-only buffer as a sequence of bytes. */
static std::uint8_t const * from_c(void const * bytes) {
    return static_cast<std::uint8_t const *>(bytes);
}

static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle";

53
static const olm::CipherAesSha256 PICKLE_CIPHER(
54
55
56
57
58
59
60
61
    CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
);

std::size_t enc_output_length(
    size_t raw_length
) {
    std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
    length += PICKLE_CIPHER.mac_length();
62
    return olm::encode_base64_length(length);
63
64
65
66
67
68
69
70
71
}


std::uint8_t * enc_output_pos(
    std::uint8_t * output,
    size_t raw_length
) {
    std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
    length += PICKLE_CIPHER.mac_length();
72
    return output + olm::encode_base64_length(length) - length;
73
74
75
76
77
78
79
80
81
82
}

std::size_t enc_output(
    std::uint8_t const * key, std::size_t key_length,
    std::uint8_t * output, size_t raw_length
) {
    std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length(
        raw_length
    );
    std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length();
83
    std::size_t base64_length = olm::encode_base64_length(length);
84
85
86
87
88
89
90
    std::uint8_t * raw_output = output + base64_length - length;
    PICKLE_CIPHER.encrypt(
        key, key_length,
        raw_output, raw_length,
        raw_output, ciphertext_length,
        raw_output, length
    );
91
    olm::encode_base64(raw_output, length, output);
92
93
94
95
96
97
    return raw_length;
}

std::size_t enc_input(
    std::uint8_t const * key, std::size_t key_length,
    std::uint8_t * input, size_t b64_length,
98
    olm::ErrorCode & last_error
99
) {
100
    std::size_t enc_length = olm::decode_base64_length(b64_length);
101
    if (enc_length == std::size_t(-1)) {
102
        last_error = olm::ErrorCode::INVALID_BASE64;
103
104
        return std::size_t(-1);
    }
105
    olm::decode_base64(input, b64_length, input);
106
107
108
109
110
111
112
113
    std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length();
    std::size_t result = PICKLE_CIPHER.decrypt(
        key, key_length,
        input, enc_length,
        input, raw_length,
        input, raw_length
    );
    if (result == std::size_t(-1)) {
114
        last_error = olm::ErrorCode::BAD_ACCOUNT_KEY;
115
116
117
118
119
120
121
122
    }
    return result;
}


std::size_t b64_output_length(
    size_t raw_length
) {
123
    return olm::encode_base64_length(raw_length);
124
125
126
127
128
129
}

std::uint8_t * b64_output_pos(
    std::uint8_t * output,
    size_t raw_length
) {
130
    return output + olm::encode_base64_length(raw_length) - raw_length;
131
132
133
134
135
}

std::size_t b64_output(
    std::uint8_t * output, size_t raw_length
) {
136
    std::size_t base64_length = olm::encode_base64_length(raw_length);
137
    std::uint8_t * raw_output = output + base64_length - raw_length;
138
    olm::encode_base64(raw_output, raw_length, output);
139
140
141
142
143
    return base64_length;
}

std::size_t b64_input(
    std::uint8_t * input, size_t b64_length,
144
    olm::ErrorCode & last_error
145
) {
146
    std::size_t raw_length = olm::decode_base64_length(b64_length);
147
    if (raw_length == std::size_t(-1)) {
148
        last_error = olm::ErrorCode::INVALID_BASE64;
149
150
        return std::size_t(-1);
    }
151
    olm::decode_base64(input, b64_length, input);
152
153
154
    return raw_length;
}

155
/* Names of the olm::ErrorCode values, indexed by the numeric value of the
 * code (see olm_account_last_error / olm_session_last_error). Entries must
 * stay in enum order. Both the pointers and the strings are immutable, and
 * the size is deduced from the entries. */
static const char * const ERRORS[] = {
    "SUCCESS",
    "NOT_ENOUGH_RANDOM",
    "OUTPUT_BUFFER_TOO_SMALL",
    "BAD_MESSAGE_VERSION",
    "BAD_MESSAGE_FORMAT",
    "BAD_MESSAGE_MAC",
    "BAD_MESSAGE_KEY_ID",
    "INVALID_BASE64",
    "BAD_ACCOUNT_KEY",
    "UNKNOWN_PICKLE_VERSION",
    "CORRUPTED_PICKLE",
};

} // namespace


extern "C" {


175
/* The value returned by this API's size_t functions to signal failure:
 * the all-ones size_t, i.e. std::size_t(-1). */
size_t olm_error() {
    std::size_t const error_sentinel = ~std::size_t(0);
    return error_sentinel;
}


180
const char * olm_account_last_error(
181
    OlmAccount * account
182
183
) {
    unsigned error = unsigned(from_c(account)->last_error);
184
185
    if (error < sizeof(ERRORS)) {
        return ERRORS[error];
186
187
188
189
190
191
    } else {
        return "UNKNOWN_ERROR";
    }
}


192
193
/* Name of the last error recorded on a session, or "UNKNOWN_ERROR" if the
 * stored code has no entry in ERRORS. */
const char * olm_session_last_error(
    OlmSession * session
) {
    unsigned error = unsigned(from_c(session)->last_error);
    /* sizeof(ERRORS) is a byte count, not an entry count; divide by the
     * element size so out-of-range codes never index past the table. */
    if (error < sizeof(ERRORS) / sizeof(ERRORS[0])) {
        return ERRORS[error];
    } else {
        return "UNKNOWN_ERROR";
    }
}


204
205
/* Bytes of memory the caller must provide to olm_account(). */
size_t olm_account_size() {
    std::size_t const bytes_needed = sizeof(olm::Account);
    return bytes_needed;
}


209
210
/* Bytes of memory the caller must provide to olm_session(). */
size_t olm_session_size() {
    std::size_t const bytes_needed = sizeof(olm::Session);
    return bytes_needed;
}


214
/* Construct a fresh account in caller-provided memory of at least
 * olm_account_size() bytes. The memory is wiped with olm::unset first. */
OlmAccount * olm_account(
    void * memory
) {
    olm::unset(memory, sizeof(olm::Account));
    olm::Account * const account = new(memory) olm::Account();
    return to_c(account);
}


222
/* Construct a fresh session in caller-provided memory of at least
 * olm_session_size() bytes. The memory is wiped with olm::unset first. */
OlmSession * olm_session(
    void * memory
) {
    olm::unset(memory, sizeof(olm::Session));
    olm::Session * const session = new(memory) olm::Session();
    return to_c(session);
}


230
size_t olm_clear_account(
231
    OlmAccount * account
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
) {
    /* Clear the memory backing the account  */
    olm::unset(account, sizeof(olm::Account));
    /* Initialise a fresh account object in case someone tries to use it */
    new(account) olm::Account();
    return sizeof(olm::Account);
}


/* Wipe a session's memory and leave a freshly constructed, empty session in
 * its place. Returns the number of bytes cleared. */
size_t olm_clear_session(
    OlmSession * session
) {
    std::size_t const session_bytes = sizeof(olm::Session);
    /* Clear the memory backing the session. */
    olm::unset(session, session_bytes);
    /* Re-initialise so any later use sees a valid (empty) object. */
    new(session) olm::Session();
    return session_bytes;
}


252
253
/* Size of the buffer olm_pickle_account() needs for this account. */
size_t olm_pickle_account_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    std::size_t const raw_length = pickle_length(object);
    return enc_output_length(raw_length);
}


259
260
/* Size of the buffer olm_pickle_session() needs for this session. */
size_t olm_pickle_session_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    std::size_t const raw_length = pickle_length(object);
    return enc_output_length(raw_length);
}


266
267
/* Serialise an account, encrypt it with key and base64-encode it into
 * pickled. Fails with OUTPUT_BUFFER_TOO_SMALL if pickled_length is below
 * olm_pickle_account_length(). Returns enc_output's result. */
size_t olm_pickle_account(
    OlmAccount * account,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Account & object = *from_c(account);
    std::size_t const raw_length = pickle_length(object);
    if (pickled_length >= enc_output_length(raw_length)) {
        std::uint8_t * const buffer = from_c(pickled);
        /* Serialise into the tail, then encrypt + encode in place. */
        pickle(enc_output_pos(buffer, raw_length), object);
        return enc_output(from_c(key), key_length, buffer, raw_length);
    }
    object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
    return size_t(-1);
}


282
283
/* Serialise a session, encrypt it with key and base64-encode it into
 * pickled. Fails with OUTPUT_BUFFER_TOO_SMALL if pickled_length is below
 * olm_pickle_session_length(). Returns enc_output's result. */
size_t olm_pickle_session(
    OlmSession * session,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Session & object = *from_c(session);
    std::size_t const raw_length = pickle_length(object);
    if (pickled_length >= enc_output_length(raw_length)) {
        std::uint8_t * const buffer = from_c(pickled);
        /* Serialise into the tail, then encrypt + encode in place. */
        pickle(enc_output_pos(buffer, raw_length), object);
        return enc_output(from_c(key), key_length, buffer, raw_length);
    }
    object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
    return size_t(-1);
}


298
299
/* Decode, decrypt and deserialise an account from pickled (modified in
 * place). Returns pickled_length on success, std::size_t(-1) on failure with
 * last_error set (CORRUPTED_PICKLE if unpickle did not report a reason). */
size_t olm_unpickle_account(
    OlmAccount * account,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Account & object = *from_c(account);
    std::uint8_t * const start = from_c(pickled);
    std::size_t const raw_length = enc_input(
        from_c(key), key_length, start, pickled_length, object.last_error
    );
    if (raw_length == std::size_t(-1)) {
        /* enc_input has already recorded the reason in last_error. */
        return std::size_t(-1);
    }
    std::uint8_t * const expected_end = start + raw_length;
    /* unpickle is given a limit one byte past the data. A clean parse returns
     * exactly expected_end; stopping early returns a pointer before it, and
     * an error returns expected_end + 1. */
    auto const stopped_at = unpickle(start, expected_end + 1, object);
    bool const consumed_exactly = (stopped_at == expected_end);
    if (!consumed_exactly) {
        if (object.last_error == olm::ErrorCode::SUCCESS) {
            object.last_error = olm::ErrorCode::CORRUPTED_PICKLE;
        }
        return std::size_t(-1);
    }
    return pickled_length;
}


326
327
/* Decode, decrypt and deserialise a session from pickled (modified in
 * place). Returns pickled_length on success, std::size_t(-1) on failure with
 * last_error set (CORRUPTED_PICKLE if unpickle did not report a reason). */
size_t olm_unpickle_session(
    OlmSession * session,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Session & object = *from_c(session);
    std::uint8_t * const start = from_c(pickled);
    std::size_t const raw_length = enc_input(
        from_c(key), key_length, start, pickled_length, object.last_error
    );
    if (raw_length == std::size_t(-1)) {
        /* enc_input has already recorded the reason in last_error. */
        return std::size_t(-1);
    }
    std::uint8_t * const expected_end = start + raw_length;
    /* unpickle is given a limit one byte past the data. A clean parse returns
     * exactly expected_end; stopping early returns a pointer before it, and
     * an error returns expected_end + 1. */
    auto const stopped_at = unpickle(start, expected_end + 1, object);
    bool const consumed_exactly = (stopped_at == expected_end);
    if (!consumed_exactly) {
        if (object.last_error == olm::ErrorCode::SUCCESS) {
            object.last_error = olm::ErrorCode::CORRUPTED_PICKLE;
        }
        return std::size_t(-1);
    }
    return pickled_length;
}


355
356
/* Number of random bytes olm_create_account() expects; forwards to
 * olm::Account::new_account_random_length(). */
size_t olm_create_account_random_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.new_account_random_length();
}


362
363
size_t olm_create_account(
    OlmAccount * account,
364
    void * random, size_t random_length
365
) {
366
367
368
    size_t result = from_c(account)->new_account(from_c(random), random_length);
    olm::unset(random, random_length);
    return result;
369
370
}

371

372
/* Size of the buffer olm_account_identity_keys() writes into; forwards to
 * olm::Account::get_identity_json_length(). */
size_t olm_account_identity_keys_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.get_identity_json_length();
}

378

379
380
/* Write the account's identity keys as JSON into identity_keys; forwards to
 * olm::Account::get_identity_json() and returns its result. */
size_t olm_account_identity_keys(
    OlmAccount * account,
    void * identity_keys, size_t identity_key_length
) {
    olm::Account & object = *from_c(account);
    std::uint8_t * const output = from_c(identity_keys);
    return object.get_identity_json(output, identity_key_length);
}


389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
/* Length of the base64 signature olm_account_sign() produces. */
size_t olm_account_signature_length(
    OlmAccount * account
) {
    std::size_t const raw_length = from_c(account)->signature_length();
    return b64_output_length(raw_length);
}


/* Sign message with the account's key and write the base64 signature into
 * signature. Fails with OUTPUT_BUFFER_TOO_SMALL if signature_length is below
 * olm_account_signature_length(). Returns the encoded length on success. */
size_t olm_account_sign(
    OlmAccount * account,
    void const * message, size_t message_length,
    void * signature, size_t signature_length
) {
    olm::Account & object = *from_c(account);
    std::size_t const raw_length = object.signature_length();
    if (signature_length < b64_output_length(raw_length)) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return std::size_t(-1);
    }
    std::uint8_t * const output = from_c(signature);
    /* Raw signature goes into the tail; b64_output then encodes in place. */
    object.sign(
        from_c(message), message_length,
        b64_output_pos(output, raw_length), raw_length
    );
    return b64_output(output, raw_length);
}


415
416
/* Size of the buffer olm_account_one_time_keys() writes into; forwards to
 * olm::Account::get_one_time_keys_json_length(). */
size_t olm_account_one_time_keys_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.get_one_time_keys_json_length();
}


422
423
/* Write the account's one-time keys as JSON into one_time_keys_json;
 * forwards to olm::Account::get_one_time_keys_json() and returns its result. */
size_t olm_account_one_time_keys(
    OlmAccount * account,
    void * one_time_keys_json, size_t one_time_key_json_length
) {
    olm::Account & object = *from_c(account);
    std::uint8_t * const output = from_c(one_time_keys_json);
    return object.get_one_time_keys_json(output, one_time_key_json_length);
}


432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
/* Forwards to olm::Account::mark_keys_as_published() and returns its result. */
size_t olm_account_mark_keys_as_published(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.mark_keys_as_published();
}


/* Forwards to olm::Account::max_number_of_one_time_keys(). */
size_t olm_account_max_number_of_one_time_keys(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.max_number_of_one_time_keys();
}


/* Random bytes needed by olm_account_generate_one_time_keys() for
 * number_of_keys keys; forwards to the account. */
size_t olm_account_generate_one_time_keys_random_length(
    OlmAccount * account,
    size_t number_of_keys
) {
    olm::Account & object = *from_c(account);
    return object.generate_one_time_keys_random_length(number_of_keys);
}


size_t olm_account_generate_one_time_keys(
    OlmAccount * account,
    size_t number_of_keys,
457
    void * random, size_t random_length
458
) {
459
    size_t result = from_c(account)->generate_one_time_keys(
460
461
462
        number_of_keys,
        from_c(random), random_length
    );
463
464
    olm::unset(random, random_length);
    return result;
465
466
467
}


468
469
/* Random bytes olm_create_outbound_session() expects; forwards to
 * olm::Session::new_outbound_session_random_length(). */
size_t olm_create_outbound_session_random_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return object.new_outbound_session_random_length();
}

474

475
476
477
/* Start an outbound session to a peer identified by their base64 identity
 * key and one-time key. Both keys must decode to exactly 32 bytes, otherwise
 * INVALID_BASE64 is reported (the random buffer is then left untouched). On
 * the normal path the random buffer is wiped after use. */
size_t olm_create_outbound_session(
    OlmSession * session,
    OlmAccount * account,
    void const * their_identity_key, size_t their_identity_key_length,
    void const * their_one_time_key, size_t their_one_time_key_length,
    void * random, size_t random_length
) {
    olm::Session & object = *from_c(session);
    /* Decoded size of each key; must match Curve25519PublicKey::public_key. */
    std::size_t const expected_key_bytes = 32;
    if (!(olm::decode_base64_length(their_identity_key_length) == expected_key_bytes
            && olm::decode_base64_length(their_one_time_key_length) == expected_key_bytes)) {
        object.last_error = olm::ErrorCode::INVALID_BASE64;
        return std::size_t(-1);
    }

    olm::Curve25519PublicKey identity_key;
    olm::Curve25519PublicKey one_time_key;
    olm::decode_base64(
        from_c(their_identity_key), their_identity_key_length,
        identity_key.public_key
    );
    olm::decode_base64(
        from_c(their_one_time_key), their_one_time_key_length,
        one_time_key.public_key
    );

    std::size_t const result = object.new_outbound_session(
        *from_c(account), identity_key, one_time_key,
        from_c(random), random_length
    );
    olm::unset(random, random_length);
    return result;
}


509
510
511
size_t olm_create_inbound_session(
    OlmSession * session,
    OlmAccount * account,
512
513
514
515
516
517
518
519
520
    void * one_time_key_message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(one_time_key_message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->new_inbound_session(
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
        *from_c(account), nullptr, from_c(one_time_key_message), raw_length
    );
}


/* Like olm_create_inbound_session, but also passes the sender's base64
 * identity key (which must decode to 32 bytes) to new_inbound_session. The
 * message buffer is decoded in place. */
size_t olm_create_inbound_session_from(
    OlmSession * session,
    OlmAccount * account,
    void const * their_identity_key, size_t their_identity_key_length,
    void * one_time_key_message, size_t message_length
) {
    olm::Session & object = *from_c(session);
    if (olm::decode_base64_length(their_identity_key_length) != 32) {
        object.last_error = olm::ErrorCode::INVALID_BASE64;
        return std::size_t(-1);
    }
    olm::Curve25519PublicKey identity_key;
    olm::decode_base64(
        from_c(their_identity_key), their_identity_key_length,
        identity_key.public_key
    );

    std::uint8_t * const message = from_c(one_time_key_message);
    std::size_t const raw_length = b64_input(
        message, message_length, object.last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return object.new_inbound_session(
        *from_c(account), &identity_key, message, raw_length
    );
}


555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
/* Length of the base64 session id written by olm_session_id(). */
size_t olm_session_id_length(
    OlmSession * session
) {
    std::size_t const raw_length = from_c(session)->session_id_length();
    return b64_output_length(raw_length);
}


/* Write the base64 session id into id. Fails with OUTPUT_BUFFER_TOO_SMALL if
 * id_length is below olm_session_id_length(); propagates a failure from
 * olm::Session::session_id. Returns the encoded length on success. */
size_t olm_session_id(
    OlmSession * session,
    void * id, size_t id_length
) {
    olm::Session & object = *from_c(session);
    std::size_t const raw_length = object.session_id_length();
    if (id_length < b64_output_length(raw_length)) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return std::size_t(-1);
    }
    std::uint8_t * const output = from_c(id);
    std::size_t const written = object.session_id(
        b64_output_pos(output, raw_length), raw_length
    );
    return written == std::size_t(-1)
        ? written
        : b64_output(output, raw_length);
}


582
583
size_t olm_matches_inbound_session(
    OlmSession * session,
584
585
586
587
588
589
590
591
592
    void * one_time_key_message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(one_time_key_message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    bool matches = from_c(session)->matches_inbound_session(
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
        nullptr, from_c(one_time_key_message), raw_length
    );
    return matches ? 1 : 0;
}


/* Like olm_matches_inbound_session, but also passes the sender's base64
 * identity key (which must decode to 32 bytes) to the match. */
size_t olm_matches_inbound_session_from(
    OlmSession * session,
    void const * their_identity_key, size_t their_identity_key_length,
    void * one_time_key_message, size_t message_length
) {
    olm::Session & object = *from_c(session);
    if (olm::decode_base64_length(their_identity_key_length) != 32) {
        object.last_error = olm::ErrorCode::INVALID_BASE64;
        return std::size_t(-1);
    }
    olm::Curve25519PublicKey identity_key;
    olm::decode_base64(
        from_c(their_identity_key), their_identity_key_length,
        identity_key.public_key
    );

    std::uint8_t * const message = from_c(one_time_key_message);
    std::size_t const raw_length = b64_input(
        message, message_length, object.last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    if (object.matches_inbound_session(&identity_key, message, raw_length)) {
        return 1;
    }
    return 0;
}


627
628
629
/* Remove from the account the one-time key the session recorded in
 * bob_one_time_key. Sets BAD_MESSAGE_KEY_ID on the account if remove_key
 * fails. */
size_t olm_remove_one_time_keys(
    OlmAccount * account,
    OlmSession * session
) {
    olm::Account & object = *from_c(account);
    std::size_t const removed =
        object.remove_key(from_c(session)->bob_one_time_key);
    if (removed == std::size_t(-1)) {
        object.last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
    }
    return removed;
}


641
642
/* Numeric type of the next message olm_encrypt() will produce. */
size_t olm_encrypt_message_type(
    OlmSession * session
) {
    auto const type = from_c(session)->encrypt_message_type();
    return size_t(type);
}


648
649
/* Random bytes olm_encrypt() expects; forwards to
 * olm::Session::encrypt_random_length(). */
size_t olm_encrypt_random_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return object.encrypt_random_length();
}


655
656
size_t olm_encrypt_message_length(
    OlmSession * session,
657
658
659
660
661
662
663
664
    size_t plaintext_length
) {
    return b64_output_length(
        from_c(session)->encrypt_message_length(plaintext_length)
    );
}


665
666
size_t olm_encrypt(
    OlmSession * session,
667
    void const * plaintext, size_t plaintext_length,
668
    void * random, size_t random_length,
669
670
671
672
673
    void * message, size_t message_length
) {
    std::size_t raw_length = from_c(session)->encrypt_message_length(
        plaintext_length
    );
674
    if (message_length < b64_output_length(raw_length)) {
675
        from_c(session)->last_error =
676
            olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
677
678
        return std::size_t(-1);
    }
679
    std::size_t result = from_c(session)->encrypt(
680
681
682
683
        from_c(plaintext), plaintext_length,
        from_c(random), random_length,
        b64_output_pos(from_c(message), raw_length), raw_length
    );
684
    olm::unset(random, random_length);
685
686
687
    if (result == std::size_t(-1)) {
        return result;
    }
688
689
690
691
    return b64_output(from_c(message), raw_length);
}


692
693
size_t olm_decrypt_max_plaintext_length(
    OlmSession * session,
694
695
696
697
698
699
700
701
702
703
    size_t message_type,
    void * message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->decrypt_max_plaintext_length(
704
        olm::MessageType(message_type), from_c(message), raw_length
705
706
707
708
    );
}


709
710
size_t olm_decrypt(
    OlmSession * session,
711
712
713
714
715
716
717
718
719
720
721
    size_t message_type,
    void * message, size_t message_length,
    void * plaintext, size_t max_plaintext_length
) {
    std::size_t raw_length = b64_input(
        from_c(message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->decrypt(
722
        olm::MessageType(message_type), from_c(message), raw_length,
723
724
725
726
727
        from_c(plaintext), max_plaintext_length
    );
}

}