olm.py 15.5 KB
Newer Older
1
2
3
#! /usr/bin/python
from ctypes import *
import json
4
import os
5

6
7
8
# Load libolm from build/libolm.so in the directory containing this file.
lib = cdll.LoadLibrary(os.path.join(
    os.path.dirname(__file__), "build", "libolm.so")
)
# NOTE(review): the bare numeric lines in this block (9, 10, ... 15) look like
# pasted line-number residue; they are no-op expression statements.
9
10


11
12
# olm_error() takes no arguments; its result is cached in ERR below.
lib.olm_error.argtypes = []
# NOTE(review): ctypes reads ``restype`` (singular). ``restypes`` is just an
# unrelated attribute, so olm_error() is actually called with ctypes' default
# c_int return type. ERR is compared against wrapped-call results elsewhere in
# this file, so change this together with those checks, never on its own.
lib.olm_error.restypes = c_size_t
13

14
# Value that the errcheck hooks in this file treat as a failed libolm call.
ERR = lib.olm_error()
15

class OlmError(Exception):
    """Error raised when a wrapped libolm call reports failure."""


# --- OlmAccount bindings -----------------------------------------------------

lib.olm_account_size.argtypes = []
lib.olm_account_size.restype = c_size_t

lib.olm_account.argtypes = [c_void_p]
lib.olm_account.restype = c_void_p

lib.olm_account_last_error.argtypes = [c_void_p]
lib.olm_account_last_error.restype = c_char_p


def account_errcheck(res, func, args):
    """ctypes ``errcheck`` hook for functions taking an account pointer.

    Raises OlmError (function name plus the account's last error string)
    when ``res`` equals the library error sentinel ``ERR``; otherwise returns
    ``res`` unchanged.
    """
    # Compare both values as size_t so the check holds whether ERR and res
    # were read back through a signed (c_int) or unsigned (c_size_t) type.
    if c_size_t(res).value == c_size_t(ERR).value:
        raise OlmError("%s: %s" % (
            func.__name__, lib.olm_account_last_error(args[0])
        ))
    return res


def account_function(func, *types):
    """Declare a libolm account function.

    Prepends the account pointer to ``types`` to form ``argtypes``, declares
    a size_t return value and attaches ``account_errcheck``.
    """
    func.argtypes = (c_void_p,) + types
    # The ctypes attribute is ``restype``; a misspelling is silently ignored
    # and leaves results typed as c_int.
    func.restype = c_size_t
    func.errcheck = account_errcheck


# Pickling. The length query must be declared too, otherwise ctypes would
# pass the account pointer as a plain C int.
account_function(lib.olm_pickle_account_length)
account_function(
    lib.olm_pickle_account, c_void_p, c_size_t, c_void_p, c_size_t
)
account_function(
    lib.olm_unpickle_account, c_void_p, c_size_t, c_void_p, c_size_t
)

# Account creation: (random, random_length).
account_function(lib.olm_create_account_random_length)
account_function(lib.olm_create_account, c_void_p, c_size_t)

# Identity keys: (user_id_length, device_id_length, valid_after, valid_until)
# for the length query; buffers plus the same values for the call itself.
account_function(
    lib.olm_account_identity_keys_length,
    c_size_t, c_size_t, c_uint64, c_uint64
)
account_function(
    lib.olm_account_identity_keys,
    c_void_p, c_size_t, c_void_p, c_size_t, c_uint64, c_uint64,
    c_void_p, c_size_t
)

# One-time keys.
account_function(lib.olm_account_one_time_keys_length)
account_function(lib.olm_account_one_time_keys, c_void_p, c_size_t)
account_function(lib.olm_account_mark_keys_as_published)
account_function(lib.olm_account_max_number_of_one_time_keys)
account_function(
    lib.olm_account_generate_one_time_keys_random_length,
    c_size_t,  # Number of keys
)
account_function(
    lib.olm_account_generate_one_time_keys,
    c_size_t,  # Number of keys
    c_void_p, c_size_t,  # Random
)

def read_random(n):
    """Return ``n`` random bytes from the operating system's CSPRNG.

    Uses ``os.urandom``, which the Python documentation recommends for
    cryptographic use; unlike opening ``/dev/urandom`` directly it needs no
    file handle and also works on platforms without that device.
    """
    return os.urandom(n)

class Account(object):
    """Python wrapper around a libolm account stored in a ctypes buffer."""

    def __init__(self):
        # olm_account() is handed a buffer of olm_account_size() bytes and
        # returns self.ptr -- presumably a pointer into self.buf, so the
        # buffer is kept referenced for as long as the pointer is used.
        self.buf = create_string_buffer(lib.olm_account_size())
        self.ptr = lib.olm_account(self.buf)

    def create(self):
        """Initialise a new account using fresh OS randomness."""
        random_length = lib.olm_create_account_random_length(self.ptr)
        random_buffer = create_string_buffer(read_random(random_length))
        lib.olm_create_account(self.ptr, random_buffer, random_length)

    def pickle(self, key):
        """Return the account pickled with ``key`` as a byte string."""
        key_buffer = create_string_buffer(key)
        # Wrap the pointer explicitly so it is passed at full pointer width
        # even if no argtypes were declared for this length function.
        pickle_length = lib.olm_pickle_account_length(c_void_p(self.ptr))
        pickle_buffer = create_string_buffer(pickle_length)
        lib.olm_pickle_account(
            self.ptr, key_buffer, len(key), pickle_buffer, pickle_length
        )
        return pickle_buffer.raw

    def unpickle(self, key, pickle):
        """Load account state from ``pickle`` using ``key``."""
        key_buffer = create_string_buffer(key)
        pickle_buffer = create_string_buffer(pickle)
        lib.olm_unpickle_account(
            self.ptr, key_buffer, len(key), pickle_buffer, len(pickle)
        )

    def identity_keys(self, user_id, device_id, valid_after, valid_until):
        """Return the JSON-decoded identity keys for this user and device.

        ``valid_after`` and ``valid_until`` are passed to libolm as 64-bit
        unsigned integers.
        """
        out_length = lib.olm_account_identity_keys_length(
            self.ptr, len(user_id), len(device_id), valid_after, valid_until
        )
        user_id_buffer = create_string_buffer(user_id)
        device_id_buffer = create_string_buffer(device_id)
        out_buffer = create_string_buffer(out_length)
        lib.olm_account_identity_keys(
            self.ptr,
            user_id_buffer, len(user_id),
            device_id_buffer, len(device_id),
            valid_after, valid_until,
            out_buffer, out_length
        )
        return json.loads(out_buffer.raw)

    def one_time_keys(self):
        """Return the account's one-time keys, decoded from JSON."""
        out_length = lib.olm_account_one_time_keys_length(self.ptr)
        out_buffer = create_string_buffer(out_length)
        lib.olm_account_one_time_keys(self.ptr, out_buffer, out_length)
        return json.loads(out_buffer.raw)

    def mark_keys_as_published(self):
        """Call olm_account_mark_keys_as_published on this account."""
        lib.olm_account_mark_keys_as_published(self.ptr)

    def max_number_of_one_time_keys(self):
        """Return libolm's maximum number of one-time keys for an account."""
        return lib.olm_account_max_number_of_one_time_keys(self.ptr)

    def generate_one_time_keys(self, count):
        """Generate ``count`` new one-time keys.

        Both libolm calls take the key count, so it must be forwarded: the
        amount of randomness required depends on it.
        """
        random_length = lib.olm_account_generate_one_time_keys_random_length(
            self.ptr, count
        )
        random_buffer = create_string_buffer(read_random(random_length))
        lib.olm_account_generate_one_time_keys(
            self.ptr, count, random_buffer, random_length
        )

    def clear(self):
        """Placeholder: no account-clearing call is bound; does nothing."""
        pass


# --- OlmSession bindings -----------------------------------------------------

lib.olm_session_size.argtypes = []
lib.olm_session_size.restype = c_size_t

lib.olm_session.argtypes = [c_void_p]
lib.olm_session.restype = c_void_p

# Not passed to session_function: that would replace this c_char_p return
# type with size_t and attach the error check to the error reporter itself.
lib.olm_session_last_error.argtypes = [c_void_p]
lib.olm_session_last_error.restype = c_char_p


def session_errcheck(res, func, args):
    """ctypes ``errcheck`` hook for functions taking a session pointer.

    Raises OlmError (function name plus the session's last error string)
    when ``res`` equals the library error sentinel ``ERR``; otherwise returns
    ``res`` unchanged.
    """
    # Compare both values as size_t so the check holds whether ERR and res
    # were read back through a signed (c_int) or unsigned (c_size_t) type.
    if c_size_t(res).value == c_size_t(ERR).value:
        raise OlmError("%s: %s" % (
            func.__name__, lib.olm_session_last_error(args[0])
        ))
    return res


def session_function(func, *types):
    """Declare a libolm session function.

    Prepends the session pointer to ``types`` to form ``argtypes``, declares
    a size_t return value and attaches ``session_errcheck``.
    """
    func.argtypes = (c_void_p,) + types
    # The ctypes attribute is ``restype``; a misspelling is silently ignored
    # and leaves results typed as c_int.
    func.restype = c_size_t
    func.errcheck = session_errcheck


# Pickling. The length query must be declared too, otherwise ctypes would
# pass the session pointer as a plain C int.
session_function(lib.olm_pickle_session_length)
session_function(
    lib.olm_pickle_session, c_void_p, c_size_t, c_void_p, c_size_t
)
session_function(
    lib.olm_unpickle_session, c_void_p, c_size_t, c_void_p, c_size_t
)

# Session establishment.
session_function(lib.olm_create_outbound_session_random_length)
session_function(
    lib.olm_create_outbound_session,
    c_void_p,  # Account
    c_void_p, c_size_t,  # Identity Key
    c_void_p, c_size_t,  # One Time Key
    c_void_p, c_size_t,  # Random
)
session_function(
    lib.olm_create_inbound_session,
    c_void_p,  # Account
    c_void_p, c_size_t,  # Pre Key Message
)
session_function(lib.olm_matches_inbound_session, c_void_p, c_size_t)

# Encryption.
session_function(lib.olm_encrypt_message_type)
session_function(lib.olm_encrypt_random_length)
session_function(lib.olm_encrypt_message_length, c_size_t)
session_function(
    lib.olm_encrypt,
    c_void_p, c_size_t,  # Plaintext
    c_void_p, c_size_t,  # Random
    c_void_p, c_size_t,  # Message
)

# Decryption.
session_function(
    lib.olm_decrypt_max_plaintext_length,
    c_size_t,  # Message Type
    c_void_p, c_size_t,  # Message
)
session_function(
    lib.olm_decrypt,
    c_size_t,  # Message Type
    c_void_p, c_size_t,  # Message
    c_void_p, c_size_t,  # Plaintext
)

class Session(object):
    """Python wrapper around a libolm session stored in a ctypes buffer."""

    def __init__(self):
        # olm_session() is handed a buffer of olm_session_size() bytes and
        # returns self.ptr -- presumably a pointer into self.buf, so the
        # buffer is kept referenced for as long as the pointer is used.
        self.buf = create_string_buffer(lib.olm_session_size())
        self.ptr = lib.olm_session(self.buf)

    def pickle(self, key):
        """Return the session pickled with ``key`` as a byte string."""
        key_buffer = create_string_buffer(key)
        # Wrap the pointer explicitly so it is passed at full pointer width
        # even if no argtypes were declared for this length function.
        pickle_length = lib.olm_pickle_session_length(c_void_p(self.ptr))
        pickle_buffer = create_string_buffer(pickle_length)
        lib.olm_pickle_session(
            self.ptr, key_buffer, len(key), pickle_buffer, pickle_length
        )
        return pickle_buffer.raw

    def unpickle(self, key, pickle):
        """Load session state from ``pickle`` using ``key``."""
        key_buffer = create_string_buffer(key)
        pickle_buffer = create_string_buffer(pickle)
        lib.olm_unpickle_session(
            self.ptr, key_buffer, len(key), pickle_buffer, len(pickle)
        )

    def create_outbound(self, account, identity_key, one_time_key):
        """Start an outbound session to a remote identity/one-time key."""
        r_length = lib.olm_create_outbound_session_random_length(self.ptr)
        random_buffer = create_string_buffer(read_random(r_length))
        identity_key_buffer = create_string_buffer(identity_key)
        one_time_key_buffer = create_string_buffer(one_time_key)
        lib.olm_create_outbound_session(
            self.ptr,
            account.ptr,
            identity_key_buffer, len(identity_key),
            one_time_key_buffer, len(one_time_key),
            random_buffer, r_length
        )

    def create_inbound(self, account, one_time_key_message):
        """Create an inbound session for ``account`` from a pre-key message."""
        one_time_key_message_buffer = create_string_buffer(one_time_key_message)
        lib.olm_create_inbound_session(
            self.ptr,
            account.ptr,
            one_time_key_message_buffer, len(one_time_key_message)
        )

    def matches_inbound(self, one_time_key_message):
        """Return whether the pre-key message belongs to this session.

        Calls olm_matches_inbound_session. The previous version called
        olm_create_inbound_session, which takes an account argument; with
        only three arguments ctypes rejected the call.
        """
        one_time_key_message_buffer = create_string_buffer(one_time_key_message)
        return bool(lib.olm_matches_inbound_session(
            self.ptr,
            one_time_key_message_buffer, len(one_time_key_message)
        ))

    def encrypt(self, plaintext):
        """Encrypt ``plaintext``; return ``(message_type, message_bytes)``."""
        r_length = lib.olm_encrypt_random_length(self.ptr)
        random_buffer = create_string_buffer(read_random(r_length))

        # Read the type before encrypting: it describes the message that the
        # next olm_encrypt call will produce.
        message_type = lib.olm_encrypt_message_type(self.ptr)
        message_length = lib.olm_encrypt_message_length(
            self.ptr, len(plaintext)
        )
        message_buffer = create_string_buffer(message_length)
        plaintext_buffer = create_string_buffer(plaintext)

        lib.olm_encrypt(
            self.ptr,
            plaintext_buffer, len(plaintext),
            random_buffer, r_length,
            message_buffer, message_length,
        )
        return message_type, message_buffer.raw

    def decrypt(self, message_type, message):
        """Decrypt ``message`` of ``message_type``; return the plaintext bytes."""
        message_buffer = create_string_buffer(message)
        max_plaintext_length = lib.olm_decrypt_max_plaintext_length(
            self.ptr, message_type, message_buffer, len(message)
        )
        plaintext_buffer = create_string_buffer(max_plaintext_length)
        # NOTE(review): a fresh copy of the message is made, presumably because
        # the length query may overwrite message_buffer in place -- confirm
        # against olm.h before removing this second copy.
        message_buffer = create_string_buffer(message)
        plaintext_length = lib.olm_decrypt(
            self.ptr, message_type, message_buffer, len(message),
            plaintext_buffer, max_plaintext_length
        )
        return plaintext_buffer.raw[:plaintext_length]

    def clear(self):
        """Placeholder: no session-clearing call is bound; does nothing."""
        pass


if __name__ == '__main__':
    # Command-line front end: each sub-command loads/stores pickled accounts
    # and sessions in local files, encrypted with --key.
    import argparse
    import contextlib
    import errno
    import sys
    import os
    import yaml

    parser = argparse.ArgumentParser()
    parser.add_argument("--key", help="Account encryption key", default="")
    commands = parser.add_subparsers()

    create_account = commands.add_parser(
        "create_account", help="Create a new account"
    )
    create_account.add_argument("account_file", help="Local account file")

    def do_create_account(args):
        if os.path.exists(args.account_file):
            sys.stderr.write("Account %r file already exists\n" % (
                args.account_file,
            ))
            sys.exit(1)
        account = Account()
        account.create()
        with open(args.account_file, "wb") as f:
            f.write(account.pickle(args.key))

    create_account.set_defaults(func=do_create_account)

    keys = commands.add_parser("keys", help="List public keys for an account")
    keys.add_argument("--user-id", default="@user:example.com")
    keys.add_argument("--device-id", default="default_device_id")
    keys.add_argument("--valid-after", default=0, type=int)
    keys.add_argument("--valid-until", default=0, type=int)
    keys.add_argument("account_file", help="Local account file")

    def do_keys(args):
        account = Account()
        with open(args.account_file, "rb") as f:
            account.unpickle(args.key, f.read())
        result = {
            "device_keys": account.identity_keys(
                args.user_id, args.device_id,
                args.valid_after, args.valid_until,
            ),
            "one_time_keys": account.one_time_keys(),
        }
        try:
            yaml.safe_dump(result, sys.stdout, default_flow_style=False)
        except IOError as e:
            # Tolerate the reader closing the pipe early (e.g. `| head`),
            # but let every other error surface instead of hiding it.
            if e.errno != errno.EPIPE:
                raise

    keys.set_defaults(func=do_keys)

    outbound = commands.add_parser(
        "outbound", help="Create an outbound session"
    )
    outbound.add_argument("account_file", help="Local account file")
    outbound.add_argument("session_file", help="Local session file")
    outbound.add_argument("identity_key", help="Remote identity key")
    outbound.add_argument("one_time_key", help="Remote one time key")

    def do_outbound(args):
        if os.path.exists(args.session_file):
            sys.stderr.write("Session %r file already exists\n" % (
                args.session_file,
            ))
            sys.exit(1)
        account = Account()
        with open(args.account_file, "rb") as f:
            account.unpickle(args.key, f.read())
        session = Session()
        session.create_outbound(
            account, args.identity_key, args.one_time_key
        )
        with open(args.session_file, "wb") as f:
            f.write(session.pickle(args.key))

    outbound.set_defaults(func=do_outbound)

    @contextlib.contextmanager
    def open_in(path):
        """Yield a binary reader for ``path``; "-" means stdin (left open)."""
        if path == "-":
            # Python 3 exposes the byte stream as .buffer; Python 2's
            # sys.stdin is already a byte stream.
            yield getattr(sys.stdin, "buffer", sys.stdin)
        else:
            with open(path, "rb") as f:
                yield f

    @contextlib.contextmanager
    def open_out(path):
        """Yield a binary writer for ``path``; "-" means stdout (left open)."""
        if path == "-":
            yield getattr(sys.stdout, "buffer", sys.stdout)
        else:
            with open(path, "wb") as f:
                yield f

    # Positional file arguments below use nargs="?" so that their "-"
    # (stdin/stdout) default actually applies when they are omitted.
    inbound = commands.add_parser("inbound", help="Create an inbound session")
    inbound.add_argument("account_file", help="Local account file")
    inbound.add_argument("session_file", help="Local session file")
    inbound.add_argument(
        "message_file", help="Message", nargs="?", default="-"
    )
    inbound.add_argument(
        "plaintext_file", help="Plaintext", nargs="?", default="-"
    )

    def do_inbound(args):
        if os.path.exists(args.session_file):
            sys.stderr.write("Session %r file already exists\n" % (
                args.session_file,
            ))
            sys.exit(1)
        account = Account()
        with open(args.account_file, "rb") as f:
            account.unpickle(args.key, f.read())
        with open_in(args.message_file) as f:
            message_type = f.read(8)
            message = f.read()
        if message_type != b"PRE_KEY ":
            sys.stderr.write("Expecting a PRE_KEY message\n")
            sys.exit(1)
        session = Session()
        session.create_inbound(account, message)
        plaintext = session.decrypt(0, message)
        with open(args.session_file, "wb") as f:
            f.write(session.pickle(args.key))
        with open_out(args.plaintext_file) as f:
            f.write(plaintext)

    inbound.set_defaults(func=do_inbound)

    encrypt = commands.add_parser("encrypt", help="Encrypt a message")
    encrypt.add_argument("session_file", help="Local session file")
    encrypt.add_argument(
        "plaintext_file", help="Plaintext", nargs="?", default="-"
    )
    encrypt.add_argument(
        "message_file", help="Message", nargs="?", default="-"
    )

    def do_encrypt(args):
        session = Session()
        with open(args.session_file, "rb") as f:
            session.unpickle(args.key, f.read())
        with open_in(args.plaintext_file) as f:
            plaintext = f.read()
        message_type, message = session.encrypt(plaintext)
        with open(args.session_file, "wb") as f:
            f.write(session.pickle(args.key))
        with open_out(args.message_file) as f:
            # 8-byte tag: message type 0 -> PRE_KEY, 1 -> MESSAGE.
            f.write([b"PRE_KEY ", b"MESSAGE "][message_type])
            f.write(message)

    encrypt.set_defaults(func=do_encrypt)

    decrypt = commands.add_parser("decrypt", help="Decrypt a message")
    decrypt.add_argument("session_file", help="Local session file")
    decrypt.add_argument(
        "message_file", help="Message", nargs="?", default="-"
    )
    decrypt.add_argument(
        "plaintext_file", help="Plaintext", nargs="?", default="-"
    )

    def do_decrypt(args):
        session = Session()
        with open(args.session_file, "rb") as f:
            session.unpickle(args.key, f.read())
        with open_in(args.message_file) as f:
            message_type = f.read(8)
            message = f.read()
        if message_type not in {b"PRE_KEY ", b"MESSAGE "}:
            sys.stderr.write("Expecting a PRE_KEY or MESSAGE message\n")
            sys.exit(1)
        message_type = 1 if message_type == b"MESSAGE " else 0
        plaintext = session.decrypt(message_type, message)
        with open(args.session_file, "wb") as f:
            f.write(session.pickle(args.key))
        with open_out(args.plaintext_file) as f:
            f.write(plaintext)

    decrypt.set_defaults(func=do_decrypt)

    args = parser.parse_args()
    if not hasattr(args, "func"):
        # Python 3's argparse does not require a sub-command by default.
        parser.error("a command is required")
    args.func(args)