Commit 79485b22 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Merge pull request #19 from matrix-org/rav/megolm_signing

Sign megolm messages
parents 50cd2b2a 2fc83aa9
...@@ -46,6 +46,11 @@ enum OlmErrorCode { ...@@ -46,6 +46,11 @@ enum OlmErrorCode {
*/ */
OLM_BAD_LEGACY_ACCOUNT_PICKLE = 13, OLM_BAD_LEGACY_ACCOUNT_PICKLE = 13,
/**
* Received message had a bad signature
*/
OLM_BAD_SIGNATURE = 14,
/* remember to update the list of string constants in error.c when updating /* remember to update the list of string constants in error.c when updating
* this list. */ * this list. */
}; };
......
...@@ -97,7 +97,7 @@ size_t olm_init_inbound_group_session( ...@@ -97,7 +97,7 @@ size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session, OlmInboundGroupSession *session,
uint32_t message_index, uint32_t message_index,
/* base64-encoded key */ /* base64-encoded keys */
uint8_t const * session_key, size_t session_key_length uint8_t const * session_key, size_t session_key_length
); );
......
...@@ -37,7 +37,8 @@ extern "C" { ...@@ -37,7 +37,8 @@ extern "C" {
size_t _olm_encode_group_message_length( size_t _olm_encode_group_message_length(
uint32_t chain_index, uint32_t chain_index,
size_t ciphertext_length, size_t ciphertext_length,
size_t mac_length size_t mac_length,
size_t signature_length
); );
/** /**
...@@ -49,7 +50,8 @@ size_t _olm_encode_group_message_length( ...@@ -49,7 +50,8 @@ size_t _olm_encode_group_message_length(
* output: where to write the output. Should be at least * output: where to write the output. Should be at least
* olm_encode_group_message_length() bytes long. * olm_encode_group_message_length() bytes long.
* ciphertext_ptr: returns the address that the ciphertext * ciphertext_ptr: returns the address that the ciphertext
* should be written to, followed by the MAC. * should be written to, followed by the MAC and the
* signature.
* *
* Returns the size of the message, up to the MAC. * Returns the size of the message, up to the MAC.
*/ */
...@@ -76,7 +78,7 @@ struct _OlmDecodeGroupMessageResults { ...@@ -76,7 +78,7 @@ struct _OlmDecodeGroupMessageResults {
*/ */
void _olm_decode_group_message( void _olm_decode_group_message(
const uint8_t *input, size_t input_length, const uint8_t *input, size_t input_length,
size_t mac_length, size_t mac_length, size_t signature_length,
/* output structure: updated with results */ /* output structure: updated with results */
struct _OlmDecodeGroupMessageResults *results struct _OlmDecodeGroupMessageResults *results
......
...@@ -30,6 +30,7 @@ static const char * ERRORS[] = { ...@@ -30,6 +30,7 @@ static const char * ERRORS[] = {
"BAD_SESSION_KEY", "BAD_SESSION_KEY",
"UNKNOWN_MESSAGE_INDEX", "UNKNOWN_MESSAGE_INDEX",
"BAD_LEGACY_ACCOUNT_PICKLE", "BAD_LEGACY_ACCOUNT_PICKLE",
"BAD_SIGNATURE",
}; };
const char * _olm_error_to_string(enum OlmErrorCode error) const char * _olm_error_to_string(enum OlmErrorCode error)
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "olm/base64.h" #include "olm/base64.h"
#include "olm/cipher.h" #include "olm/cipher.h"
#include "olm/crypto.h"
#include "olm/error.h" #include "olm/error.h"
#include "olm/megolm.h" #include "olm/megolm.h"
#include "olm/memory.h" #include "olm/memory.h"
...@@ -29,6 +30,7 @@ ...@@ -29,6 +30,7 @@
#define OLM_PROTOCOL_VERSION 3 #define OLM_PROTOCOL_VERSION 3
#define PICKLE_VERSION 1 #define PICKLE_VERSION 1
#define SESSION_KEY_VERSION 1
struct OlmInboundGroupSession { struct OlmInboundGroupSession {
/** our earliest known ratchet value */ /** our earliest known ratchet value */
...@@ -37,6 +39,9 @@ struct OlmInboundGroupSession { ...@@ -37,6 +39,9 @@ struct OlmInboundGroupSession {
/** The most recent ratchet value */ /** The most recent ratchet value */
Megolm latest_ratchet; Megolm latest_ratchet;
/** The ed25519 signing key */
struct _olm_ed25519_public_key signing_key;
enum OlmErrorCode last_error; enum OlmErrorCode last_error;
}; };
...@@ -65,30 +70,56 @@ size_t olm_clear_inbound_group_session( ...@@ -65,30 +70,56 @@ size_t olm_clear_inbound_group_session(
return sizeof(OlmInboundGroupSession); return sizeof(OlmInboundGroupSession);
} }
#define SESSION_KEY_RAW_LENGTH \
(1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH)
/** init the session keys from the un-base64-ed session keys */
static size_t _init_group_session_keys(
OlmInboundGroupSession *session,
uint32_t message_index,
const uint8_t *key_buf
) {
const uint8_t *ptr = key_buf;
size_t version = *ptr++;
if (version != SESSION_KEY_VERSION) {
session->last_error = OLM_BAD_SESSION_KEY;
return (size_t)-1;
}
megolm_init(&session->initial_ratchet, ptr, message_index);
megolm_init(&session->latest_ratchet, ptr, message_index);
ptr += MEGOLM_RATCHET_LENGTH;
memcpy(
session->signing_key.public_key, ptr, ED25519_PUBLIC_KEY_LENGTH
);
ptr += ED25519_PUBLIC_KEY_LENGTH;
return 0;
}
size_t olm_init_inbound_group_session( size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session, OlmInboundGroupSession *session,
uint32_t message_index, uint32_t message_index,
const uint8_t * session_key, size_t session_key_length const uint8_t * session_key, size_t session_key_length
) { ) {
uint8_t key_buf[MEGOLM_RATCHET_LENGTH]; uint8_t key_buf[SESSION_KEY_RAW_LENGTH];
size_t raw_length = _olm_decode_base64_length(session_key_length); size_t raw_length = _olm_decode_base64_length(session_key_length);
size_t result;
if (raw_length == (size_t)-1) { if (raw_length == (size_t)-1) {
session->last_error = OLM_INVALID_BASE64; session->last_error = OLM_INVALID_BASE64;
return (size_t)-1; return (size_t)-1;
} }
if (raw_length != MEGOLM_RATCHET_LENGTH) { if (raw_length != SESSION_KEY_RAW_LENGTH) {
session->last_error = OLM_BAD_SESSION_KEY; session->last_error = OLM_BAD_SESSION_KEY;
return (size_t)-1; return (size_t)-1;
} }
_olm_decode_base64(session_key, session_key_length, key_buf); _olm_decode_base64(session_key, session_key_length, key_buf);
megolm_init(&session->initial_ratchet, key_buf, message_index); result = _init_group_session_keys(session, message_index, key_buf);
megolm_init(&session->latest_ratchet, key_buf, message_index); _olm_unset(key_buf, SESSION_KEY_RAW_LENGTH);
_olm_unset(key_buf, MEGOLM_RATCHET_LENGTH); return result;
return 0;
} }
static size_t raw_pickle_length( static size_t raw_pickle_length(
...@@ -98,6 +129,7 @@ static size_t raw_pickle_length( ...@@ -98,6 +129,7 @@ static size_t raw_pickle_length(
length += _olm_pickle_uint32_length(PICKLE_VERSION); length += _olm_pickle_uint32_length(PICKLE_VERSION);
length += megolm_pickle_length(&session->initial_ratchet); length += megolm_pickle_length(&session->initial_ratchet);
length += megolm_pickle_length(&session->latest_ratchet); length += megolm_pickle_length(&session->latest_ratchet);
length += _olm_pickle_ed25519_public_key_length(&session->signing_key);
return length; return length;
} }
...@@ -124,6 +156,7 @@ size_t olm_pickle_inbound_group_session( ...@@ -124,6 +156,7 @@ size_t olm_pickle_inbound_group_session(
pos = _olm_pickle_uint32(pos, PICKLE_VERSION); pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
pos = megolm_pickle(&session->initial_ratchet, pos); pos = megolm_pickle(&session->initial_ratchet, pos);
pos = megolm_pickle(&session->latest_ratchet, pos); pos = megolm_pickle(&session->latest_ratchet, pos);
pos = _olm_pickle_ed25519_public_key(pos, &session->signing_key);
return _olm_enc_output(key, key_length, pickled, raw_length); return _olm_enc_output(key, key_length, pickled, raw_length);
} }
...@@ -153,6 +186,7 @@ size_t olm_unpickle_inbound_group_session( ...@@ -153,6 +186,7 @@ size_t olm_unpickle_inbound_group_session(
} }
pos = megolm_unpickle(&session->initial_ratchet, pos, end); pos = megolm_unpickle(&session->initial_ratchet, pos, end);
pos = megolm_unpickle(&session->latest_ratchet, pos, end); pos = megolm_unpickle(&session->latest_ratchet, pos, end);
pos = _olm_unpickle_ed25519_public_key(pos, end, &session->signing_key);
if (end != pos) { if (end != pos) {
/* We had the wrong number of bytes in the input. */ /* We had the wrong number of bytes in the input. */
...@@ -175,6 +209,7 @@ static size_t _decrypt_max_plaintext_length( ...@@ -175,6 +209,7 @@ static size_t _decrypt_max_plaintext_length(
_olm_decode_group_message( _olm_decode_group_message(
message, message_length, message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher), megolm_cipher->ops->mac_length(megolm_cipher),
ED25519_SIGNATURE_LENGTH,
&decoded_results); &decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) { if (decoded_results.version != OLM_PROTOCOL_VERSION) {
...@@ -224,6 +259,7 @@ static size_t _decrypt( ...@@ -224,6 +259,7 @@ static size_t _decrypt(
_olm_decode_group_message( _olm_decode_group_message(
message, message_length, message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher), megolm_cipher->ops->mac_length(megolm_cipher),
ED25519_SIGNATURE_LENGTH,
&decoded_results); &decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) { if (decoded_results.version != OLM_PROTOCOL_VERSION) {
...@@ -231,11 +267,28 @@ static size_t _decrypt( ...@@ -231,11 +267,28 @@ static size_t _decrypt(
return (size_t)-1; return (size_t)-1;
} }
if (!decoded_results.has_message_index || !decoded_results.ciphertext ) { if (!decoded_results.has_message_index || !decoded_results.ciphertext) {
session->last_error = OLM_BAD_MESSAGE_FORMAT; session->last_error = OLM_BAD_MESSAGE_FORMAT;
return (size_t)-1; return (size_t)-1;
} }
/* verify the signature. We could do this before decoding the message, but
* we allow for the possibility of future protocol versions which use a
* different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION"
* than "BAD_SIGNATURE" in this case.
*/
message_length -= ED25519_SIGNATURE_LENGTH;
r = _olm_crypto_ed25519_verify(
&session->signing_key,
message, message_length,
message + message_length
);
if (!r) {
session->last_error = OLM_BAD_SIGNATURE;
return (size_t)-1;
}
max_length = megolm_cipher->ops->decrypt_max_plaintext_length( max_length = megolm_cipher->ops->decrypt_max_plaintext_length(
megolm_cipher, megolm_cipher,
decoded_results.ciphertext_length decoded_results.ciphertext_length
......
...@@ -334,12 +334,14 @@ static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022; ...@@ -334,12 +334,14 @@ static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022;
size_t _olm_encode_group_message_length( size_t _olm_encode_group_message_length(
uint32_t message_index, uint32_t message_index,
size_t ciphertext_length, size_t ciphertext_length,
size_t mac_length size_t mac_length,
size_t signature_length
) { ) {
size_t length = VERSION_LENGTH; size_t length = VERSION_LENGTH;
length += 1 + varint_length(message_index); length += 1 + varint_length(message_index);
length += 1 + varstring_length(ciphertext_length); length += 1 + varstring_length(ciphertext_length);
length += mac_length; length += mac_length;
length += signature_length;
return length; return length;
} }
...@@ -361,11 +363,12 @@ size_t _olm_encode_group_message( ...@@ -361,11 +363,12 @@ size_t _olm_encode_group_message(
void _olm_decode_group_message( void _olm_decode_group_message(
const uint8_t *input, size_t input_length, const uint8_t *input, size_t input_length,
size_t mac_length, size_t mac_length, size_t signature_length,
struct _OlmDecodeGroupMessageResults *results struct _OlmDecodeGroupMessageResults *results
) { ) {
std::uint8_t const * pos = input; std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length - mac_length; std::size_t trailer_length = mac_length + signature_length;
std::uint8_t const * end = input + input_length - trailer_length;
std::uint8_t const * unknown = nullptr; std::uint8_t const * unknown = nullptr;
bool has_message_index = false; bool has_message_index = false;
...@@ -373,8 +376,7 @@ void _olm_decode_group_message( ...@@ -373,8 +376,7 @@ void _olm_decode_group_message(
results->ciphertext = nullptr; results->ciphertext = nullptr;
results->ciphertext_length = 0; results->ciphertext_length = 0;
if (pos == end) return; if (input_length < trailer_length) return;
if (input_length < mac_length) return;
results->version = *(pos++); results->version = *(pos++);
while (pos != end) { while (pos != end) {
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "olm/base64.h" #include "olm/base64.h"
#include "olm/cipher.h" #include "olm/cipher.h"
#include "olm/crypto.h"
#include "olm/error.h" #include "olm/error.h"
#include "olm/megolm.h" #include "olm/megolm.h"
#include "olm/memory.h" #include "olm/memory.h"
...@@ -31,11 +32,15 @@ ...@@ -31,11 +32,15 @@
#define SESSION_ID_RANDOM_BYTES 4 #define SESSION_ID_RANDOM_BYTES 4
#define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES) #define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES)
#define PICKLE_VERSION 1 #define PICKLE_VERSION 1
#define SESSION_KEY_VERSION 1
struct OlmOutboundGroupSession { struct OlmOutboundGroupSession {
/** the Megolm ratchet providing the encryption keys */ /** the Megolm ratchet providing the encryption keys */
Megolm ratchet; Megolm ratchet;
/** The ed25519 keypair used for signing the messages */
struct _olm_ed25519_key_pair signing_key;
/** unique identifier for this session */ /** unique identifier for this session */
uint8_t session_id[GROUP_SESSION_ID_LENGTH]; uint8_t session_id[GROUP_SESSION_ID_LENGTH];
...@@ -74,6 +79,7 @@ static size_t raw_pickle_length( ...@@ -74,6 +79,7 @@ static size_t raw_pickle_length(
size_t length = 0; size_t length = 0;
length += _olm_pickle_uint32_length(PICKLE_VERSION); length += _olm_pickle_uint32_length(PICKLE_VERSION);
length += megolm_pickle_length(&(session->ratchet)); length += megolm_pickle_length(&(session->ratchet));
length += _olm_pickle_ed25519_key_pair_length(&(session->signing_key));
length += _olm_pickle_bytes_length(session->session_id, length += _olm_pickle_bytes_length(session->session_id,
GROUP_SESSION_ID_LENGTH); GROUP_SESSION_ID_LENGTH);
return length; return length;
...@@ -101,6 +107,7 @@ size_t olm_pickle_outbound_group_session( ...@@ -101,6 +107,7 @@ size_t olm_pickle_outbound_group_session(
pos = _olm_enc_output_pos(pickled, raw_length); pos = _olm_enc_output_pos(pickled, raw_length);
pos = _olm_pickle_uint32(pos, PICKLE_VERSION); pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
pos = megolm_pickle(&(session->ratchet), pos); pos = megolm_pickle(&(session->ratchet), pos);
pos = _olm_pickle_ed25519_key_pair(pos, &(session->signing_key));
pos = _olm_pickle_bytes(pos, session->session_id, GROUP_SESSION_ID_LENGTH); pos = _olm_pickle_bytes(pos, session->session_id, GROUP_SESSION_ID_LENGTH);
return _olm_enc_output(key, key_length, pickled, raw_length); return _olm_enc_output(key, key_length, pickled, raw_length);
...@@ -130,6 +137,7 @@ size_t olm_unpickle_outbound_group_session( ...@@ -130,6 +137,7 @@ size_t olm_unpickle_outbound_group_session(
return (size_t)-1; return (size_t)-1;
} }
pos = megolm_unpickle(&(session->ratchet), pos, end); pos = megolm_unpickle(&(session->ratchet), pos, end);
pos = _olm_unpickle_ed25519_key_pair(pos, end, &(session->signing_key));
pos = _olm_unpickle_bytes(pos, end, session->session_id, GROUP_SESSION_ID_LENGTH); pos = _olm_unpickle_bytes(pos, end, session->session_id, GROUP_SESSION_ID_LENGTH);
if (end != pos) { if (end != pos) {
...@@ -148,7 +156,9 @@ size_t olm_init_outbound_group_session_random_length( ...@@ -148,7 +156,9 @@ size_t olm_init_outbound_group_session_random_length(
/* we need data to initialize the megolm ratchet, plus some more for the /* we need data to initialize the megolm ratchet, plus some more for the
* session id. * session id.
*/ */
return MEGOLM_RATCHET_LENGTH + SESSION_ID_RANDOM_BYTES; return MEGOLM_RATCHET_LENGTH +
ED25519_RANDOM_LENGTH +
SESSION_ID_RANDOM_BYTES;
} }
size_t olm_init_outbound_group_session( size_t olm_init_outbound_group_session(
...@@ -164,6 +174,9 @@ size_t olm_init_outbound_group_session( ...@@ -164,6 +174,9 @@ size_t olm_init_outbound_group_session(
megolm_init(&(session->ratchet), random, 0); megolm_init(&(session->ratchet), random, 0);
random += MEGOLM_RATCHET_LENGTH; random += MEGOLM_RATCHET_LENGTH;
_olm_crypto_ed25519_generate_key(random, &(session->signing_key));
random += ED25519_RANDOM_LENGTH;
/* initialise the session id. This just has to be unique. We use the /* initialise the session id. This just has to be unique. We use the
* current time plus some random data. * current time plus some random data.
*/ */
...@@ -188,7 +201,8 @@ static size_t raw_message_length( ...@@ -188,7 +201,8 @@ static size_t raw_message_length(
return _olm_encode_group_message_length( return _olm_encode_group_message_length(
session->ratchet.counter, session->ratchet.counter,
ciphertext_length, mac_length); ciphertext_length, mac_length, ED25519_SIGNATURE_LENGTH
);
} }
size_t olm_group_encrypt_message_length( size_t olm_group_encrypt_message_length(
...@@ -241,6 +255,13 @@ static size_t _encrypt( ...@@ -241,6 +255,13 @@ static size_t _encrypt(
megolm_advance(&(session->ratchet)); megolm_advance(&(session->ratchet));
/* sign the whole thing with the ed25519 key. */
_olm_crypto_ed25519_sign(
&(session->signing_key),
buffer, message_length,
buffer + message_length
);
return result; return result;
} }
...@@ -302,23 +323,40 @@ uint32_t olm_outbound_group_session_message_index( ...@@ -302,23 +323,40 @@ uint32_t olm_outbound_group_session_message_index(
return session->ratchet.counter; return session->ratchet.counter;
} }
#define SESSION_KEY_RAW_LENGTH \
(1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH)
size_t olm_outbound_group_session_key_length( size_t olm_outbound_group_session_key_length(
const OlmOutboundGroupSession *session const OlmOutboundGroupSession *session
) { ) {
return _olm_encode_base64_length(MEGOLM_RATCHET_LENGTH); return _olm_encode_base64_length(SESSION_KEY_RAW_LENGTH);
} }
size_t olm_outbound_group_session_key( size_t olm_outbound_group_session_key(
OlmOutboundGroupSession *session, OlmOutboundGroupSession *session,
uint8_t * key, size_t key_length uint8_t * key, size_t key_length
) { ) {
if (key_length < olm_outbound_group_session_key_length(session)) { uint8_t *raw;
uint8_t *ptr;
size_t encoded_length = olm_outbound_group_session_key_length(session);
if (key_length < encoded_length) {
session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL; session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
return (size_t)-1; return (size_t)-1;
} }
return _olm_encode_base64( /* put the raw data at the end of the output buffer. */
megolm_get_data(&session->ratchet), raw = ptr = key + encoded_length - SESSION_KEY_RAW_LENGTH;
MEGOLM_RATCHET_LENGTH, key *ptr++ = SESSION_KEY_VERSION;
memcpy(ptr, megolm_get_data(&session->ratchet), MEGOLM_RATCHET_LENGTH);
ptr += MEGOLM_RATCHET_LENGTH;
memcpy(
ptr, session->signing_key.public_key.public_key,
ED25519_PUBLIC_KEY_LENGTH
); );
ptr += ED25519_PUBLIC_KEY_LENGTH;
return _olm_encode_base64(raw, SESSION_KEY_RAW_LENGTH, key);
} }
...@@ -80,7 +80,6 @@ int main() { ...@@ -80,7 +80,6 @@ int main() {
assert_equals(pickle1, pickle2, pickle_length); assert_equals(pickle1, pickle2, pickle_length);
} }
{ {
TestCase test_case("Group message send/receive"); TestCase test_case("Group message send/receive");
...@@ -89,6 +88,7 @@ int main() { ...@@ -89,6 +88,7 @@ int main() {
"0123456789ABDEF0123456789ABCDEF" "0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF" "0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF" "0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"; "0123456789ABDEF0123456789ABCDEF";
...@@ -97,7 +97,7 @@ int main() { ...@@ -97,7 +97,7 @@ int main() {
uint8_t memory[size]; uint8_t memory[size];
OlmOutboundGroupSession *session = olm_outbound_group_session(memory); OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
assert_equals((size_t)132, assert_equals((size_t)164,
olm_init_outbound_group_session_random_length(session)); olm_init_outbound_group_session_random_length(session));
size_t res = olm_init_outbound_group_session( size_t res = olm_init_outbound_group_session(
...@@ -109,7 +109,6 @@ int main() { ...@@ -109,7 +109,6 @@ int main() {
uint8_t session_key[session_key_len]; uint8_t session_key[session_key_len];
olm_outbound_group_session_key(session, session_key, session_key_len); olm_outbound_group_session_key(session, session_key, session_key_len);
/* encode the message */ /* encode the message */
uint8_t plaintext[] = "Message"; uint8_t plaintext[] = "Message";
size_t plaintext_length = sizeof(plaintext) - 1; size_t plaintext_length = sizeof(plaintext) - 1;
...@@ -148,4 +147,73 @@ int main() { ...@@ -148,4 +147,73 @@ int main() {
assert_equals(plaintext, plaintext_buf, res); assert_equals(plaintext, plaintext_buf, res);
} }
{
TestCase test_case("Invalid signature group message");
uint8_t plaintext[] = "Message";