// fuzz_group_decrypt.cpp: fuzz harness that decrypts fuzzer input with olm_group_decrypt().
#include "olm/olm.hh"

#include "fuzzing.hh"

// Stand-ins for AFL++'s persistent-mode macros. They are used only when the
// compiler does not already define __AFL_FUZZ_TESTCASE_LEN (presumably a plain
// compiler build rather than an AFL++ instrumenting compiler; NOTE(review):
// confirm). They let the same harness be built and run by hand on stdin.
// The reserved __AFL_* names are intentional: they must match the macros the
// harness body uses.
#ifndef __AFL_FUZZ_TESTCASE_LEN
  // Length in bytes of the current test case (the last read() result).
  ssize_t fuzz_len;
  #define __AFL_FUZZ_TESTCASE_LEN fuzz_len
  // Storage for the current test case; at most sizeof(fuzz_buf) bytes per case.
  unsigned char fuzz_buf[1024000];
  #define __AFL_FUZZ_TESTCASE_BUF fuzz_buf
  // No setup is needed in the fallback; this only declares sync().
  #define __AFL_FUZZ_INIT() void sync(void);
  // Each iteration performs one read() of up to sizeof(fuzz_buf) bytes from
  // fd 0 (stdin) and treats it as the next test case. The loop ends on EOF or
  // on a read error (read() <= 0). The iteration count x is ignored. A short
  // read (e.g. from a pipe) may split one input across several iterations.
  #define __AFL_LOOP(x) ((fuzz_len = read(0, fuzz_buf, sizeof(fuzz_buf))) > 0 ? 1 : 0)
  // Presumably a harmless placeholder for AFL's deferred-init hook;
  // it just calls sync().
  #define __AFL_INIT() sync()
#endif

// With AFL++ this expands to AFL++'s own setup code (see AFL++ docs).
// With the fallback above it only declares sync().
__AFL_FUZZ_INIT();

int main(int argc, const char *argv[]) {
    if (argc <= 2) {
        const char * message = "Usage: decrypt <pickle_key> <group_session>\n";
20
        (void)write(STDERR_FILENO, message, strlen(message));
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
        exit(3);
    }

    const char * key = argv[1];
    size_t key_length = strlen(key);


    int session_fd = check_errno(
        "Error opening session file", open(argv[2], O_RDONLY)
    );

    uint8_t *session_buffer;
    ssize_t session_length = check_errno(
        "Error reading session file", read_file(session_fd, &session_buffer)
    );

    uint8_t session_memory[olm_inbound_group_session_size()];
    OlmInboundGroupSession * session = olm_inbound_group_session(session_memory);
    check_error(
        olm_inbound_group_session_last_error,
        session,
        "Error unpickling session",
        olm_unpickle_inbound_group_session(
            session, key, key_length, session_buffer, session_length
        )
    );

48
49
50
#ifdef __AFL_HAVE_MANUAL_CONTROL
    __AFL_INIT();
#endif
51

52
53
54
    size_t test_case_buf_len = 1024;
    uint8_t * message_buffer = (uint8_t *) malloc(test_case_buf_len);
    uint8_t * tmp_buffer = (uint8_t *) malloc(test_case_buf_len);
55

56
57
    while (__AFL_LOOP(10000)) {
        size_t message_length = __AFL_FUZZ_TESTCASE_LEN;
58

59
60
61
62
63
64
65
66
67
68
69
70
        if (message_length > test_case_buf_len) {
            message_buffer = (uint8_t *)realloc(message_buffer, message_length);
            tmp_buffer = (uint8_t *)realloc(tmp_buffer, message_length);

            if (!message_buffer || !tmp_buffer) return 1;
        }

        memcpy(message_buffer, __AFL_FUZZ_TESTCASE_BUF, message_length);
        memcpy(tmp_buffer, message_buffer, message_length);

        size_t max_length = check_error(
            olm_inbound_group_session_last_error,
71
            session,
72
73
74
75
76
            "Error getting plaintext length",
            olm_group_decrypt_max_plaintext_length(
                session, tmp_buffer, message_length
            )
        );
77

78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
        uint8_t plaintext[max_length];

        uint32_t ratchet_index;

        size_t length = check_error(
            olm_inbound_group_session_last_error,
            session,
            "Error decrypting message",
            olm_group_decrypt(
                session,
                message_buffer, message_length,
                plaintext, max_length, &ratchet_index
            )
        );

        (void)write(STDOUT_FILENO, plaintext, length);
        (void)write(STDOUT_FILENO, "\n", 1);
    }
96
97
98
99
100
101

    free(session_buffer);
    free(message_buffer);
    free(tmp_buffer);

    return EXIT_SUCCESS;
102
}