Commit 158f7ee8 authored by Mark Haines's avatar Mark Haines
Browse files

Fix crash where the message length was shorter than the length of the mac

parent a4b29278
...@@ -213,6 +213,7 @@ void olm::decode_message( ...@@ -213,6 +213,7 @@ void olm::decode_message(
reader.ciphertext_length = 0; reader.ciphertext_length = 0;
if (pos == end) return; if (pos == end) return;
if (input_length < mac_length) return;
reader.version = *(pos++); reader.version = *(pos++);
while (pos != end) { while (pos != end) {
......
...@@ -5,6 +5,7 @@ const char * test_cases[] = { ...@@ -5,6 +5,7 @@ const char * test_cases[] = {
"41776f", "41776f",
"7fff6f0101346d671201", "7fff6f0101346d671201",
"ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d", "ee776f41496f674177804177778041776f6716670a677d6f670a67c2677d",
"e9e9c9c1e9e9c9e9c9c1e9e9c9c1",
}; };
...@@ -39,14 +40,17 @@ void decrypt_case(int message_type, const char * test_case) { ...@@ -39,14 +40,17 @@ void decrypt_case(int message_type, const char * test_case) {
::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled)); ::olm_unpickle_session(session, "", 0, pickled, sizeof(pickled));
std::size_t message_length = strlen(test_case) / 2; std::size_t message_length = strlen(test_case) / 2;
std::uint8_t message[message_length]; std::uint8_t * message = (std::uint8_t *) ::malloc(message_length);
decode_hex(test_case, message, message_length); decode_hex(test_case, message, message_length);
size_t max_length = olm_decrypt_max_plaintext_length( size_t max_length = olm_decrypt_max_plaintext_length(
session, message_type, message, message_length session, message_type, message, message_length
); );
if (max_length == std::size_t(-1)) return; if (max_length == std::size_t(-1)) {
free(message);
return;
}
uint8_t plaintext[max_length]; uint8_t plaintext[max_length];
decode_hex(test_case, message, message_length); decode_hex(test_case, message, message_length);
...@@ -55,12 +59,13 @@ void decrypt_case(int message_type, const char * test_case) { ...@@ -55,12 +59,13 @@ void decrypt_case(int message_type, const char * test_case) {
message, message_length, message, message_length,
plaintext, max_length plaintext, max_length
); );
free(message);
} }
int main() { int main() {
{ {
TestCase("Olm decrypt test"); TestCase my_test("Olm decrypt test");
for (int i = 0; i < sizeof(test_cases)/ sizeof(const char *); ++i) { for (int i = 0; i < sizeof(test_cases)/ sizeof(const char *); ++i) {
decrypt_case(0, test_cases[i]); decrypt_case(0, test_cases[i]);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment