olm.cpp 18.4 KB
Newer Older
Mark Haines's avatar
Mark Haines committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
/* Copyright 2015 OpenMarket Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
15
16
17
18
19
#include "olm/olm.hh"
#include "olm/session.hh"
#include "olm/account.hh"
#include "olm/base64.hh"
#include "olm/cipher.hh"
20
#include "olm/memory.hh"
21
22
23
24
25
26

#include <new>
#include <cstring>

namespace {

27
28
/* Reinterpret an olm::Account pointer as the OlmAccount pointer type used by
 * the extern "C" functions in this file. */
static OlmAccount * to_c(olm::Account * cpp_account) {
    OlmAccount * const handle = reinterpret_cast<OlmAccount *>(cpp_account);
    return handle;
}

31
32
/* Reinterpret an olm::Session pointer as the OlmSession pointer type used by
 * the extern "C" functions in this file. */
static OlmSession * to_c(olm::Session * cpp_session) {
    OlmSession * const handle = reinterpret_cast<OlmSession *>(cpp_session);
    return handle;
}

35
36
/* Reinterpret an OlmAccount pointer received through the C API as the
 * olm::Account it refers to. */
static olm::Account * from_c(OlmAccount * handle) {
    olm::Account * const cpp_account = reinterpret_cast<olm::Account *>(handle);
    return cpp_account;
}

39
40
/* Reinterpret an OlmSession pointer received through the C API as the
 * olm::Session it refers to. */
static olm::Session * from_c(OlmSession * handle) {
    olm::Session * const cpp_session = reinterpret_cast<olm::Session *>(handle);
    return cpp_session;
}

/* View an untyped, writable buffer from the C API as a byte pointer. */
static std::uint8_t * from_c(void * bytes) {
    return static_cast<std::uint8_t *>(bytes);
}

/* View an untyped, read-only buffer from the C API as a byte pointer. */
static std::uint8_t const * from_c(void const * bytes) {
    return static_cast<std::uint8_t const *>(bytes);
}

static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle";

53
static const olm::CipherAesSha256 PICKLE_CIPHER(
54
55
56
57
58
59
60
61
    CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
);

std::size_t enc_output_length(
    size_t raw_length
) {
    std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
    length += PICKLE_CIPHER.mac_length();
62
    return olm::encode_base64_length(length);
63
64
65
66
67
68
69
70
71
}


std::uint8_t * enc_output_pos(
    std::uint8_t * output,
    size_t raw_length
) {
    std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
    length += PICKLE_CIPHER.mac_length();
72
    return output + olm::encode_base64_length(length) - length;
73
74
75
76
77
78
79
80
81
82
}

std::size_t enc_output(
    std::uint8_t const * key, std::size_t key_length,
    std::uint8_t * output, size_t raw_length
) {
    std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length(
        raw_length
    );
    std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length();
83
    std::size_t base64_length = olm::encode_base64_length(length);
84
85
86
87
88
89
90
    std::uint8_t * raw_output = output + base64_length - length;
    PICKLE_CIPHER.encrypt(
        key, key_length,
        raw_output, raw_length,
        raw_output, ciphertext_length,
        raw_output, length
    );
91
    olm::encode_base64(raw_output, length, output);
92
93
94
95
96
97
    return raw_length;
}

std::size_t enc_input(
    std::uint8_t const * key, std::size_t key_length,
    std::uint8_t * input, size_t b64_length,
98
    olm::ErrorCode & last_error
99
) {
100
    std::size_t enc_length = olm::decode_base64_length(b64_length);
101
    if (enc_length == std::size_t(-1)) {
102
        last_error = olm::ErrorCode::INVALID_BASE64;
103
104
        return std::size_t(-1);
    }
105
    olm::decode_base64(input, b64_length, input);
106
107
108
109
110
111
112
113
    std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length();
    std::size_t result = PICKLE_CIPHER.decrypt(
        key, key_length,
        input, enc_length,
        input, raw_length,
        input, raw_length
    );
    if (result == std::size_t(-1)) {
114
        last_error = olm::ErrorCode::BAD_ACCOUNT_KEY;
115
116
117
118
119
120
121
122
    }
    return result;
}


std::size_t b64_output_length(
    size_t raw_length
) {
123
    return olm::encode_base64_length(raw_length);
124
125
126
127
128
129
}

std::uint8_t * b64_output_pos(
    std::uint8_t * output,
    size_t raw_length
) {
130
    return output + olm::encode_base64_length(raw_length) - raw_length;
131
132
133
134
135
}

std::size_t b64_output(
    std::uint8_t * output, size_t raw_length
) {
136
    std::size_t base64_length = olm::encode_base64_length(raw_length);
137
    std::uint8_t * raw_output = output + base64_length - raw_length;
138
    olm::encode_base64(raw_output, raw_length, output);
139
140
141
142
143
    return base64_length;
}

std::size_t b64_input(
    std::uint8_t * input, size_t b64_length,
144
    olm::ErrorCode & last_error
145
) {
146
    std::size_t raw_length = olm::decode_base64_length(b64_length);
147
    if (raw_length == std::size_t(-1)) {
148
        last_error = olm::ErrorCode::INVALID_BASE64;
149
150
        return std::size_t(-1);
    }
151
    olm::decode_base64(input, b64_length, input);
152
153
154
    return raw_length;
}

155
/* Printable names for error codes. The *_last_error functions index this
 * table with the numeric value of the object's last_error field. */
static const char * ERRORS[11] = {
    /*  0 */ "SUCCESS",
    /*  1 */ "NOT_ENOUGH_RANDOM",
    /*  2 */ "OUTPUT_BUFFER_TOO_SMALL",
    /*  3 */ "BAD_MESSAGE_VERSION",
    /*  4 */ "BAD_MESSAGE_FORMAT",
    /*  5 */ "BAD_MESSAGE_MAC",
    /*  6 */ "BAD_MESSAGE_KEY_ID",
    /*  7 */ "INVALID_BASE64",
    /*  8 */ "BAD_ACCOUNT_KEY",
    /*  9 */ "UNKNOWN_PICKLE_VERSION",
    /* 10 */ "CORRUPTED_PICKLE",
};

} // namespace


extern "C" {


175
/* The sentinel value the functions in this file return on failure:
 * size_t(-1). */
size_t olm_error() {
    return static_cast<size_t>(-1);
}


180
const char * olm_account_last_error(
181
    OlmAccount * account
182
183
) {
    unsigned error = unsigned(from_c(account)->last_error);
184
185
    if (error < sizeof(ERRORS)) {
        return ERRORS[error];
186
187
188
189
190
191
    } else {
        return "UNKNOWN_ERROR";
    }
}


192
193
/* Name of the last error recorded on `session`.
 *
 * Returns the matching entry of ERRORS, or "UNKNOWN_ERROR" when the stored
 * code has no entry in that table. */
const char * olm_session_last_error(
    OlmSession * session
) {
    /* Compare against the number of entries, not the byte size of the array,
     * otherwise out-of-range codes would index past the end of ERRORS. */
    const std::size_t error_count = sizeof(ERRORS) / sizeof(ERRORS[0]);
    unsigned error = unsigned(from_c(session)->last_error);
    if (error < error_count) {
        return ERRORS[error];
    } else {
        return "UNKNOWN_ERROR";
    }
}


204
205
/* Size in bytes of an olm::Account object, i.e. how much memory must be
 * passed to olm_account(). */
size_t olm_account_size() {
    const size_t bytes_needed = sizeof(olm::Account);
    return bytes_needed;
}


209
210
/* Size in bytes of an olm::Session object, i.e. how much memory must be
 * passed to olm_session(). */
size_t olm_session_size() {
    const size_t bytes_needed = sizeof(olm::Session);
    return bytes_needed;
}


214
/* Clear sizeof(olm::Account) bytes at `memory`, construct a default
 * olm::Account there and return it as a C handle. */
OlmAccount * olm_account(
    void * memory
) {
    olm::unset(memory, sizeof(olm::Account));
    olm::Account * const fresh = new(memory) olm::Account();
    return to_c(fresh);
}


222
/* Clear sizeof(olm::Session) bytes at `memory`, construct a default
 * olm::Session there and return it as a C handle. */
OlmSession * olm_session(
    void * memory
) {
    olm::unset(memory, sizeof(olm::Session));
    olm::Session * const fresh = new(memory) olm::Session();
    return to_c(fresh);
}


230
size_t olm_clear_account(
231
    OlmAccount * account
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
) {
    /* Clear the memory backing the account  */
    olm::unset(account, sizeof(olm::Account));
    /* Initialise a fresh account object in case someone tries to use it */
    new(account) olm::Account();
    return sizeof(olm::Account);
}


/* Wipe the memory behind `session` and leave a default-constructed session in
 * its place. Returns sizeof(olm::Session). */
size_t olm_clear_session(
    OlmSession * session
) {
    const size_t object_size = sizeof(olm::Session);
    /* Erase whatever the old object held. */
    olm::unset(session, object_size);
    /* Re-create a blank object so later use of the handle stays defined. */
    new(session) olm::Session();
    return object_size;
}


252
253
/* Buffer size needed by olm_pickle_account() for `account`: the encrypted,
 * base64-encoded size of its pickle. */
size_t olm_pickle_account_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    const std::size_t raw_length = pickle_length(object);
    return enc_output_length(raw_length);
}


259
260
/* Buffer size needed by olm_pickle_session() for `session`: the encrypted,
 * base64-encoded size of its pickle. */
size_t olm_pickle_session_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    const std::size_t raw_length = pickle_length(object);
    return enc_output_length(raw_length);
}


266
267
/* Serialise `account` into `pickled`, encrypted with `key` and base64
 * encoded.
 *
 * Fails with OUTPUT_BUFFER_TOO_SMALL (returning size_t(-1)) when
 * `pickled_length` is below enc_output_length() of the raw pickle; otherwise
 * returns whatever enc_output() returns. */
size_t olm_pickle_account(
    OlmAccount * account,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Account & object = *from_c(account);
    const std::size_t raw_length = pickle_length(object);
    const std::size_t needed_length = enc_output_length(raw_length);
    if (pickled_length < needed_length) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return static_cast<size_t>(-1);
    }
    std::uint8_t * const buffer = from_c(pickled);
    /* Write the raw pickle where enc_output expects to find it. */
    pickle(enc_output_pos(buffer, raw_length), object);
    return enc_output(from_c(key), key_length, buffer, raw_length);
}


282
283
/* Serialise `session` into `pickled`, encrypted with `key` and base64
 * encoded.
 *
 * Fails with OUTPUT_BUFFER_TOO_SMALL (returning size_t(-1)) when
 * `pickled_length` is below enc_output_length() of the raw pickle; otherwise
 * returns whatever enc_output() returns. */
size_t olm_pickle_session(
    OlmSession * session,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Session & object = *from_c(session);
    const std::size_t raw_length = pickle_length(object);
    const std::size_t needed_length = enc_output_length(raw_length);
    if (pickled_length < needed_length) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return static_cast<size_t>(-1);
    }
    std::uint8_t * const buffer = from_c(pickled);
    /* Write the raw pickle where enc_output expects to find it. */
    pickle(enc_output_pos(buffer, raw_length), object);
    return enc_output(from_c(key), key_length, buffer, raw_length);
}


298
299
/* Restore `account` from the encrypted base64 pickle in `pickled`. The
 * buffer is decoded and decrypted in place, so its contents are overwritten.
 *
 * Returns `pickled_length` on success. Returns size_t(-1) on failure, with
 * last_error set by enc_input(), by unpickle(), or to CORRUPTED_PICKLE when
 * unpickle() stopped at the wrong place without recording an error. */
size_t olm_unpickle_account(
    OlmAccount * account,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Account & object = *from_c(account);
    std::uint8_t * const pos = from_c(pickled);
    const std::size_t raw_length = enc_input(
        from_c(key), key_length, pos, pickled_length, object.last_error
    );
    if (raw_length == static_cast<std::size_t>(-1)) {
        return static_cast<std::size_t>(-1);
    }

    std::uint8_t * const end = pos + raw_length;
    /* unpickle is given one byte of slack past the data. A fully consumed
     * pickle ends exactly at `end`; stopping early ends before it, and an
     * error ends at `end + 1`. Only the exact match is a success. */
    if (unpickle(pos, end + 1, object) == end) {
        return pickled_length;
    }
    if (object.last_error == olm::ErrorCode::SUCCESS) {
        object.last_error = olm::ErrorCode::CORRUPTED_PICKLE;
    }
    return static_cast<std::size_t>(-1);
}


326
327
/* Restore `session` from the encrypted base64 pickle in `pickled`. The
 * buffer is decoded and decrypted in place, so its contents are overwritten.
 *
 * Returns `pickled_length` on success. Returns size_t(-1) on failure, with
 * last_error set by enc_input(), by unpickle(), or to CORRUPTED_PICKLE when
 * unpickle() stopped at the wrong place without recording an error. */
size_t olm_unpickle_session(
    OlmSession * session,
    void const * key, size_t key_length,
    void * pickled, size_t pickled_length
) {
    olm::Session & object = *from_c(session);
    std::uint8_t * const pos = from_c(pickled);
    const std::size_t raw_length = enc_input(
        from_c(key), key_length, pos, pickled_length, object.last_error
    );
    if (raw_length == static_cast<std::size_t>(-1)) {
        return static_cast<std::size_t>(-1);
    }

    std::uint8_t * const end = pos + raw_length;
    /* unpickle is given one byte of slack past the data. A fully consumed
     * pickle ends exactly at `end`; stopping early ends before it, and an
     * error ends at `end + 1`. Only the exact match is a success. */
    if (unpickle(pos, end + 1, object) == end) {
        return pickled_length;
    }
    if (object.last_error == olm::ErrorCode::SUCCESS) {
        object.last_error = olm::ErrorCode::CORRUPTED_PICKLE;
    }
    return static_cast<std::size_t>(-1);
}


355
356
/* Number of random bytes olm_create_account() needs, as reported by
 * Account::new_account_random_length(). */
size_t olm_create_account_random_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.new_account_random_length();
}


362
363
size_t olm_create_account(
    OlmAccount * account,
364
    void * random, size_t random_length
365
) {
366
367
368
    size_t result = from_c(account)->new_account(from_c(random), random_length);
    olm::unset(random, random_length);
    return result;
369
370
}

371

372
size_t olm_account_identity_keys_length(
373
374
375
    OlmAccount * account
) {
    return from_c(account)->get_identity_json_length();
376
377
}

378

379
380
size_t olm_account_identity_keys(
    OlmAccount * account,
381
382
    void * identity_keys, size_t identity_key_length
) {
383
384
385
    return from_c(account)->get_identity_json(
        from_c(identity_keys), identity_key_length
    );
386
387
388
}


389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
/* Buffer size needed by olm_account_sign(): the base64 length of
 * Account::signature_length() raw bytes. */
size_t olm_account_signature_length(
    OlmAccount * account
) {
    const std::size_t raw_length = from_c(account)->signature_length();
    return b64_output_length(raw_length);
}


/* Sign `message` with `account` and store the base64-encoded signature in
 * `signature`.
 *
 * Returns the base64 length written, or size_t(-1) with last_error set to
 * OUTPUT_BUFFER_TOO_SMALL when `signature_length` is too small. */
size_t olm_account_sign(
    OlmAccount * account,
    void const * message, size_t message_length,
    void * signature, size_t signature_length
) {
    olm::Account & object = *from_c(account);
    const std::size_t raw_length = object.signature_length();
    if (signature_length < b64_output_length(raw_length)) {
        object.last_error = olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
        return static_cast<std::size_t>(-1);
    }
    std::uint8_t * const buffer = from_c(signature);
    /* The raw signature goes at the tail so it can be encoded in place. */
    object.sign(
        from_c(message), message_length,
        b64_output_pos(buffer, raw_length), raw_length
    );
    return b64_output(buffer, raw_length);
}


415
416
/* Size of the one-time-keys JSON for `account`, as reported by
 * Account::get_one_time_keys_json_length(). */
size_t olm_account_one_time_keys_length(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.get_one_time_keys_json_length();
}


422
423
size_t olm_account_one_time_keys(
    OlmAccount * account,
424
    void * one_time_keys_json, size_t one_time_key_json_length
425
) {
426
427
428
    return from_c(account)->get_one_time_keys_json(
        from_c(one_time_keys_json), one_time_key_json_length
    );
429
430
431
}


432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
/* Forward to Account::mark_keys_as_published() and return its result. */
size_t olm_account_mark_keys_as_published(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.mark_keys_as_published();
}


/* Forward to Account::max_number_of_one_time_keys() and return its result. */
size_t olm_account_max_number_of_one_time_keys(
    OlmAccount * account
) {
    olm::Account & object = *from_c(account);
    return object.max_number_of_one_time_keys();
}


/* Number of random bytes olm_account_generate_one_time_keys() needs for
 * `number_of_keys` keys, as reported by
 * Account::generate_one_time_keys_random_length(). */
size_t olm_account_generate_one_time_keys_random_length(
    OlmAccount * account,
    size_t number_of_keys
) {
    olm::Account & object = *from_c(account);
    return object.generate_one_time_keys_random_length(number_of_keys);
}


size_t olm_account_generate_one_time_keys(
    OlmAccount * account,
    size_t number_of_keys,
457
    void * random, size_t random_length
458
) {
459
    size_t result = from_c(account)->generate_one_time_keys(
460
461
462
        number_of_keys,
        from_c(random), random_length
    );
463
464
    olm::unset(random, random_length);
    return result;
465
466
467
}


468
469
/* Number of random bytes olm_create_outbound_session() needs, as reported by
 * Session::new_outbound_session_random_length(). */
size_t olm_create_outbound_session_random_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return object.new_outbound_session_random_length();
}

474

475
476
477
/* Start an outbound session to the peer identified by the two base64 keys.
 *
 * Both keys must decode to exactly 32 bytes, otherwise last_error is set to
 * INVALID_BASE64 and size_t(-1) is returned (the random buffer is left
 * untouched on that path). Otherwise the keys are decoded, passed to
 * Session::new_outbound_session() together with the random bytes, the random
 * buffer is cleared with olm::unset(), and the session's result is returned. */
size_t olm_create_outbound_session(
    OlmSession * session,
    OlmAccount * account,
    void const * their_identity_key, size_t their_identity_key_length,
    void const * their_one_time_key, size_t their_one_time_key_length,
    void * random, size_t random_length
) {
    olm::Session & object = *from_c(session);
    if (!(olm::decode_base64_length(their_identity_key_length) == 32
            && olm::decode_base64_length(their_one_time_key_length) == 32)
    ) {
        object.last_error = olm::ErrorCode::INVALID_BASE64;
        return static_cast<std::size_t>(-1);
    }

    olm::Curve25519PublicKey identity_key;
    olm::Curve25519PublicKey one_time_key;
    olm::decode_base64(
        from_c(their_identity_key), their_identity_key_length,
        identity_key.public_key
    );
    olm::decode_base64(
        from_c(their_one_time_key), their_one_time_key_length,
        one_time_key.public_key
    );

    const size_t outcome = object.new_outbound_session(
        *from_c(account), identity_key, one_time_key,
        from_c(random), random_length
    );
    olm::unset(random, random_length);
    return outcome;
}


509
510
511
size_t olm_create_inbound_session(
    OlmSession * session,
    OlmAccount * account,
512
513
514
515
516
517
518
519
520
    void * one_time_key_message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(one_time_key_message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->new_inbound_session(
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
        *from_c(account), nullptr, from_c(one_time_key_message), raw_length
    );
}


/* Create an inbound session from a base64 pre-key message, passing the
 * expected sender identity key along to Session::new_inbound_session().
 *
 * The identity key must decode to exactly 32 bytes, otherwise last_error is
 * set to INVALID_BASE64 and size_t(-1) is returned. The message is
 * base64-decoded in place; a decoding failure also returns size_t(-1). */
size_t olm_create_inbound_session_from(
    OlmSession * session,
    OlmAccount * account,
    void const * their_identity_key, size_t their_identity_key_length,
    void * one_time_key_message, size_t message_length
) {
    olm::Session & object = *from_c(session);
    if (olm::decode_base64_length(their_identity_key_length) != 32) {
        object.last_error = olm::ErrorCode::INVALID_BASE64;
        return static_cast<std::size_t>(-1);
    }
    olm::Curve25519PublicKey identity_key;
    olm::decode_base64(
        from_c(their_identity_key), their_identity_key_length,
        identity_key.public_key
    );

    std::uint8_t * const message = from_c(one_time_key_message);
    const std::size_t raw_length = b64_input(
        message, message_length, object.last_error
    );
    if (raw_length == static_cast<std::size_t>(-1)) {
        return static_cast<std::size_t>(-1);
    }
    return object.new_inbound_session(
        *from_c(account), &identity_key, message, raw_length
    );
}


555
556
size_t olm_matches_inbound_session(
    OlmSession * session,
557
558
559
560
561
562
563
564
565
    void * one_time_key_message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(one_time_key_message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    bool matches = from_c(session)->matches_inbound_session(
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
        nullptr, from_c(one_time_key_message), raw_length
    );
    return matches ? 1 : 0;
}


/* Ask `session` whether a base64 pre-key message matches it, passing the
 * expected sender identity key to Session::matches_inbound_session().
 *
 * The identity key must decode to exactly 32 bytes, otherwise last_error is
 * set to INVALID_BASE64 and size_t(-1) is returned. The message is
 * base64-decoded in place. Returns 1 on a match, 0 otherwise, or size_t(-1)
 * when decoding fails. */
size_t olm_matches_inbound_session_from(
    OlmSession * session,
    void const * their_identity_key, size_t their_identity_key_length,
    void * one_time_key_message, size_t message_length
) {
    olm::Session & object = *from_c(session);
    if (olm::decode_base64_length(their_identity_key_length) != 32) {
        object.last_error = olm::ErrorCode::INVALID_BASE64;
        return static_cast<std::size_t>(-1);
    }
    olm::Curve25519PublicKey identity_key;
    olm::decode_base64(
        from_c(their_identity_key), their_identity_key_length,
        identity_key.public_key
    );

    std::uint8_t * const message = from_c(one_time_key_message);
    const std::size_t raw_length = b64_input(
        message, message_length, object.last_error
    );
    if (raw_length == static_cast<std::size_t>(-1)) {
        return static_cast<std::size_t>(-1);
    }
    if (object.matches_inbound_session(&identity_key, message, raw_length)) {
        return 1;
    }
    return 0;
}


600
601
602
/* Remove from `account` the one-time key stored in the session's
 * bob_one_time_key field, via Account::remove_key().
 *
 * Returns remove_key()'s result; when that is size_t(-1) the account's
 * last_error is set to BAD_MESSAGE_KEY_ID. */
size_t olm_remove_one_time_keys(
    OlmAccount * account,
    OlmSession * session
) {
    olm::Account & owner = *from_c(account);
    const size_t outcome = owner.remove_key(from_c(session)->bob_one_time_key);
    if (outcome == static_cast<std::size_t>(-1)) {
        owner.last_error = olm::ErrorCode::BAD_MESSAGE_KEY_ID;
    }
    return outcome;
}


614
615
/* Numeric value of Session::encrypt_message_type() for `session`. */
size_t olm_encrypt_message_type(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return static_cast<size_t>(object.encrypt_message_type());
}


621
622
/* Number of random bytes olm_encrypt() needs, as reported by
 * Session::encrypt_random_length(). */
size_t olm_encrypt_random_length(
    OlmSession * session
) {
    olm::Session & object = *from_c(session);
    return object.encrypt_random_length();
}


628
629
size_t olm_encrypt_message_length(
    OlmSession * session,
630
631
632
633
634
635
636
637
    size_t plaintext_length
) {
    return b64_output_length(
        from_c(session)->encrypt_message_length(plaintext_length)
    );
}


638
639
size_t olm_encrypt(
    OlmSession * session,
640
    void const * plaintext, size_t plaintext_length,
641
    void * random, size_t random_length,
642
643
644
645
646
    void * message, size_t message_length
) {
    std::size_t raw_length = from_c(session)->encrypt_message_length(
        plaintext_length
    );
647
    if (message_length < b64_output_length(raw_length)) {
648
        from_c(session)->last_error =
649
            olm::ErrorCode::OUTPUT_BUFFER_TOO_SMALL;
650
651
652
653
654
655
656
        return std::size_t(-1);
    }
    from_c(session)->encrypt(
        from_c(plaintext), plaintext_length,
        from_c(random), random_length,
        b64_output_pos(from_c(message), raw_length), raw_length
    );
657
    olm::unset(random, random_length);
658
659
660
661
    return b64_output(from_c(message), raw_length);
}


662
663
size_t olm_decrypt_max_plaintext_length(
    OlmSession * session,
664
665
666
667
668
669
670
671
672
673
    size_t message_type,
    void * message, size_t message_length
) {
    std::size_t raw_length = b64_input(
        from_c(message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->decrypt_max_plaintext_length(
674
        olm::MessageType(message_type), from_c(message), raw_length
675
676
677
678
    );
}


679
680
size_t olm_decrypt(
    OlmSession * session,
681
682
683
684
685
686
687
688
689
690
691
    size_t message_type,
    void * message, size_t message_length,
    void * plaintext, size_t max_plaintext_length
) {
    std::size_t raw_length = b64_input(
        from_c(message), message_length, from_c(session)->last_error
    );
    if (raw_length == std::size_t(-1)) {
        return std::size_t(-1);
    }
    return from_c(session)->decrypt(
692
        olm::MessageType(message_type), from_c(message), raw_length,
693
694
695
696
697
        from_c(plaintext), max_plaintext_length
    );
}

}