fuzz_group_decrypt.cpp 1.97 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
#include "olm/olm.hh"

#include "fuzzing.hh"

/*
 * Allocates a heap buffer of `length` bytes for the harness, exiting with
 * status 1 if the allocation fails. At least one byte is always requested so
 * that a zero-length input still yields a usable, non-NULL pointer.
 */
static uint8_t * alloc_or_exit(size_t length) {
    uint8_t * buffer = (uint8_t *) malloc(length ? length : 1);
    if (buffer == NULL) {
        const char * message = "Out of memory\n";
        size_t ignored = write(STDERR_FILENO, message, strlen(message));
        (void) ignored;
        exit(1);
    }
    return buffer;
}

/*
 * Fuzzing harness for inbound group-session decryption.
 *
 * Usage: decrypt <pickle_key> <group_session>
 *   argv[1]  key handed to olm_unpickle_inbound_group_session
 *   argv[2]  path of a file holding the pickled inbound group session
 *   stdin    the group message to decrypt
 *
 * On success the plaintext followed by a newline is written to stdout and the
 * process exits with status 0. Too few arguments exits with status 3 and an
 * allocation failure exits with status 1. Open/read/unpickle/decrypt errors
 * are routed through check_errno / check_error from fuzzing.hh (presumably
 * they report and terminate -- confirm against that header).
 */
int main(int argc, const char *argv[]) {
    // write() results are deliberately discarded; they are stored only to
    // silence unused-result warnings.
    size_t ignored;
    if (argc <= 2) {
        const char * message = "Usage: decrypt <pickle_key> <group_session>\n";
        ignored = write(STDERR_FILENO, message, strlen(message));
        exit(3);
    }

    const char * key = argv[1];
    size_t key_length = strlen(key);

    int session_fd = check_errno(
        "Error opening session file", open(argv[2], O_RDONLY)
    );

    uint8_t *session_buffer;
    ssize_t session_length = check_errno(
        "Error reading session file", read_file(session_fd, &session_buffer)
    );

    int message_fd = STDIN_FILENO;
    uint8_t * message_buffer;
    ssize_t message_length = check_errno(
        "Error reading message file", read_file(message_fd, &message_buffer)
    );

    // olm.h documents that olm_group_decrypt_max_plaintext_length destroys
    // its input buffer, so it is given a scratch copy and the original
    // message stays intact for olm_group_decrypt below.
    uint8_t * tmp_buffer = alloc_or_exit(message_length);
    if (message_length > 0) {
        // Skip the copy for empty input: memcpy from a possibly-NULL source
        // is undefined even when the length is zero.
        memcpy(tmp_buffer, message_buffer, message_length);
    }

    uint8_t session_memory[olm_inbound_group_session_size()];
    OlmInboundGroupSession * session = olm_inbound_group_session(session_memory);
    check_error(
        olm_inbound_group_session_last_error,
        session,
        "Error unpickling session",
        olm_unpickle_inbound_group_session(
            session, key, key_length, session_buffer, session_length
        )
    );

    size_t max_length = check_error(
        olm_inbound_group_session_last_error,
        session,
        "Error getting plaintext length",
        olm_group_decrypt_max_plaintext_length(
            session, tmp_buffer, message_length
        )
    );
    // The scratch copy is not needed past this point.
    free(tmp_buffer);

    // max_length is derived from untrusted input, so the plaintext buffer is
    // heap-allocated: a stack VLA of that size could overflow the stack and be
    // reported as a crash of the harness rather than of the library.
    uint8_t * plaintext = alloc_or_exit(max_length);

    size_t length = check_error(
        olm_inbound_group_session_last_error,
        session,
        "Error decrypting message",
        olm_group_decrypt(
            session,
            message_buffer, message_length,
            plaintext, max_length
        )
    );

    ignored = write(STDOUT_FILENO, plaintext, length);
    ignored = write(STDOUT_FILENO, "\n", 1);
    (void) ignored;

    free(plaintext);
    // Report success explicitly rather than the return value of write().
    return 0;
}