Commit 79485b22 authored by Richard van der Hoff's avatar Richard van der Hoff Committed by GitHub
Browse files

Merge pull request #19 from matrix-org/rav/megolm_signing

Sign megolm messages
parents 50cd2b2a 2fc83aa9
......@@ -46,6 +46,11 @@ enum OlmErrorCode {
*/
OLM_BAD_LEGACY_ACCOUNT_PICKLE = 13,
/**
* Received message had a bad signature
*/
OLM_BAD_SIGNATURE = 14,
/* remember to update the list of string constants in error.c when updating
* this list. */
};
......
......@@ -97,7 +97,7 @@ size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session,
uint32_t message_index,
/* base64-encoded key */
/* base64-encoded keys */
uint8_t const * session_key, size_t session_key_length
);
......
......@@ -37,7 +37,8 @@ extern "C" {
size_t _olm_encode_group_message_length(
uint32_t chain_index,
size_t ciphertext_length,
size_t mac_length
size_t mac_length,
size_t signature_length
);
/**
......@@ -49,7 +50,8 @@ size_t _olm_encode_group_message_length(
* output: where to write the output. Should be at least
* olm_encode_group_message_length() bytes long.
* ciphertext_ptr: returns the address that the ciphertext
* should be written to, followed by the MAC.
* should be written to, followed by the MAC and the
* signature.
*
* Returns the size of the message, up to the MAC.
*/
......@@ -76,7 +78,7 @@ struct _OlmDecodeGroupMessageResults {
*/
void _olm_decode_group_message(
const uint8_t *input, size_t input_length,
size_t mac_length,
size_t mac_length, size_t signature_length,
/* output structure: updated with results */
struct _OlmDecodeGroupMessageResults *results
......
......@@ -30,6 +30,7 @@ static const char * ERRORS[] = {
"BAD_SESSION_KEY",
"UNKNOWN_MESSAGE_INDEX",
"BAD_LEGACY_ACCOUNT_PICKLE",
"BAD_SIGNATURE",
};
const char * _olm_error_to_string(enum OlmErrorCode error)
......
......@@ -19,6 +19,7 @@
#include "olm/base64.h"
#include "olm/cipher.h"
#include "olm/crypto.h"
#include "olm/error.h"
#include "olm/megolm.h"
#include "olm/memory.h"
......@@ -29,6 +30,7 @@
#define OLM_PROTOCOL_VERSION 3
#define PICKLE_VERSION 1
#define SESSION_KEY_VERSION 1
struct OlmInboundGroupSession {
/** our earliest known ratchet value */
......@@ -37,6 +39,9 @@ struct OlmInboundGroupSession {
/** The most recent ratchet value */
Megolm latest_ratchet;
/** The ed25519 signing key */
struct _olm_ed25519_public_key signing_key;
enum OlmErrorCode last_error;
};
......@@ -65,30 +70,56 @@ size_t olm_clear_inbound_group_session(
return sizeof(OlmInboundGroupSession);
}
#define SESSION_KEY_RAW_LENGTH \
(1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH)
/** init the session keys from the un-base64-ed session keys */
static size_t _init_group_session_keys(
OlmInboundGroupSession *session,
uint32_t message_index,
const uint8_t *key_buf
) {
const uint8_t *ptr = key_buf;
size_t version = *ptr++;
if (version != SESSION_KEY_VERSION) {
session->last_error = OLM_BAD_SESSION_KEY;
return (size_t)-1;
}
megolm_init(&session->initial_ratchet, ptr, message_index);
megolm_init(&session->latest_ratchet, ptr, message_index);
ptr += MEGOLM_RATCHET_LENGTH;
memcpy(
session->signing_key.public_key, ptr, ED25519_PUBLIC_KEY_LENGTH
);
ptr += ED25519_PUBLIC_KEY_LENGTH;
return 0;
}
size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session,
uint32_t message_index,
const uint8_t * session_key, size_t session_key_length
) {
uint8_t key_buf[MEGOLM_RATCHET_LENGTH];
uint8_t key_buf[SESSION_KEY_RAW_LENGTH];
size_t raw_length = _olm_decode_base64_length(session_key_length);
size_t result;
if (raw_length == (size_t)-1) {
session->last_error = OLM_INVALID_BASE64;
return (size_t)-1;
}
if (raw_length != MEGOLM_RATCHET_LENGTH) {
if (raw_length != SESSION_KEY_RAW_LENGTH) {
session->last_error = OLM_BAD_SESSION_KEY;
return (size_t)-1;
}
_olm_decode_base64(session_key, session_key_length, key_buf);
megolm_init(&session->initial_ratchet, key_buf, message_index);
megolm_init(&session->latest_ratchet, key_buf, message_index);
_olm_unset(key_buf, MEGOLM_RATCHET_LENGTH);
return 0;
result = _init_group_session_keys(session, message_index, key_buf);
_olm_unset(key_buf, SESSION_KEY_RAW_LENGTH);
return result;
}
static size_t raw_pickle_length(
......@@ -98,6 +129,7 @@ static size_t raw_pickle_length(
length += _olm_pickle_uint32_length(PICKLE_VERSION);
length += megolm_pickle_length(&session->initial_ratchet);
length += megolm_pickle_length(&session->latest_ratchet);
length += _olm_pickle_ed25519_public_key_length(&session->signing_key);
return length;
}
......@@ -124,6 +156,7 @@ size_t olm_pickle_inbound_group_session(
pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
pos = megolm_pickle(&session->initial_ratchet, pos);
pos = megolm_pickle(&session->latest_ratchet, pos);
pos = _olm_pickle_ed25519_public_key(pos, &session->signing_key);
return _olm_enc_output(key, key_length, pickled, raw_length);
}
......@@ -153,6 +186,7 @@ size_t olm_unpickle_inbound_group_session(
}
pos = megolm_unpickle(&session->initial_ratchet, pos, end);
pos = megolm_unpickle(&session->latest_ratchet, pos, end);
pos = _olm_unpickle_ed25519_public_key(pos, end, &session->signing_key);
if (end != pos) {
/* We had the wrong number of bytes in the input. */
......@@ -175,6 +209,7 @@ static size_t _decrypt_max_plaintext_length(
_olm_decode_group_message(
message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher),
ED25519_SIGNATURE_LENGTH,
&decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) {
......@@ -224,6 +259,7 @@ static size_t _decrypt(
_olm_decode_group_message(
message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher),
ED25519_SIGNATURE_LENGTH,
&decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) {
......@@ -231,11 +267,28 @@ static size_t _decrypt(
return (size_t)-1;
}
if (!decoded_results.has_message_index || !decoded_results.ciphertext ) {
if (!decoded_results.has_message_index || !decoded_results.ciphertext) {
session->last_error = OLM_BAD_MESSAGE_FORMAT;
return (size_t)-1;
}
/* verify the signature. We could do this before decoding the message, but
* we allow for the possibility of future protocol versions which use a
* different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION"
* than "BAD_SIGNATURE" in this case.
*/
message_length -= ED25519_SIGNATURE_LENGTH;
r = _olm_crypto_ed25519_verify(
&session->signing_key,
message, message_length,
message + message_length
);
if (!r) {
session->last_error = OLM_BAD_SIGNATURE;
return (size_t)-1;
}
max_length = megolm_cipher->ops->decrypt_max_plaintext_length(
megolm_cipher,
decoded_results.ciphertext_length
......
......@@ -334,12 +334,14 @@ static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022;
size_t _olm_encode_group_message_length(
uint32_t message_index,
size_t ciphertext_length,
size_t mac_length
size_t mac_length,
size_t signature_length
) {
size_t length = VERSION_LENGTH;
length += 1 + varint_length(message_index);
length += 1 + varstring_length(ciphertext_length);
length += mac_length;
length += signature_length;
return length;
}
......@@ -361,11 +363,12 @@ size_t _olm_encode_group_message(
void _olm_decode_group_message(
const uint8_t *input, size_t input_length,
size_t mac_length,
size_t mac_length, size_t signature_length,
struct _OlmDecodeGroupMessageResults *results
) {
std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length - mac_length;
std::size_t trailer_length = mac_length + signature_length;
std::uint8_t const * end = input + input_length - trailer_length;
std::uint8_t const * unknown = nullptr;
bool has_message_index = false;
......@@ -373,8 +376,7 @@ void _olm_decode_group_message(
results->ciphertext = nullptr;
results->ciphertext_length = 0;
if (pos == end) return;
if (input_length < mac_length) return;
if (input_length < trailer_length) return;
results->version = *(pos++);
while (pos != end) {
......
......@@ -20,6 +20,7 @@
#include "olm/base64.h"
#include "olm/cipher.h"
#include "olm/crypto.h"
#include "olm/error.h"
#include "olm/megolm.h"
#include "olm/memory.h"
......@@ -31,11 +32,15 @@
#define SESSION_ID_RANDOM_BYTES 4
#define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES)
#define PICKLE_VERSION 1
#define SESSION_KEY_VERSION 1
struct OlmOutboundGroupSession {
/** the Megolm ratchet providing the encryption keys */
Megolm ratchet;
/** The ed25519 keypair used for signing the messages */
struct _olm_ed25519_key_pair signing_key;
/** unique identifier for this session */
uint8_t session_id[GROUP_SESSION_ID_LENGTH];
......@@ -74,6 +79,7 @@ static size_t raw_pickle_length(
size_t length = 0;
length += _olm_pickle_uint32_length(PICKLE_VERSION);
length += megolm_pickle_length(&(session->ratchet));
length += _olm_pickle_ed25519_key_pair_length(&(session->signing_key));
length += _olm_pickle_bytes_length(session->session_id,
GROUP_SESSION_ID_LENGTH);
return length;
......@@ -101,6 +107,7 @@ size_t olm_pickle_outbound_group_session(
pos = _olm_enc_output_pos(pickled, raw_length);
pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
pos = megolm_pickle(&(session->ratchet), pos);
pos = _olm_pickle_ed25519_key_pair(pos, &(session->signing_key));
pos = _olm_pickle_bytes(pos, session->session_id, GROUP_SESSION_ID_LENGTH);
return _olm_enc_output(key, key_length, pickled, raw_length);
......@@ -130,6 +137,7 @@ size_t olm_unpickle_outbound_group_session(
return (size_t)-1;
}
pos = megolm_unpickle(&(session->ratchet), pos, end);
pos = _olm_unpickle_ed25519_key_pair(pos, end, &(session->signing_key));
pos = _olm_unpickle_bytes(pos, end, session->session_id, GROUP_SESSION_ID_LENGTH);
if (end != pos) {
......@@ -148,7 +156,9 @@ size_t olm_init_outbound_group_session_random_length(
/* we need data to initialize the megolm ratchet, plus some more for the
* session id.
*/
return MEGOLM_RATCHET_LENGTH + SESSION_ID_RANDOM_BYTES;
return MEGOLM_RATCHET_LENGTH +
ED25519_RANDOM_LENGTH +
SESSION_ID_RANDOM_BYTES;
}
size_t olm_init_outbound_group_session(
......@@ -164,6 +174,9 @@ size_t olm_init_outbound_group_session(
megolm_init(&(session->ratchet), random, 0);
random += MEGOLM_RATCHET_LENGTH;
_olm_crypto_ed25519_generate_key(random, &(session->signing_key));
random += ED25519_RANDOM_LENGTH;
/* initialise the session id. This just has to be unique. We use the
* current time plus some random data.
*/
......@@ -188,7 +201,8 @@ static size_t raw_message_length(
return _olm_encode_group_message_length(
session->ratchet.counter,
ciphertext_length, mac_length);
ciphertext_length, mac_length, ED25519_SIGNATURE_LENGTH
);
}
size_t olm_group_encrypt_message_length(
......@@ -241,6 +255,13 @@ static size_t _encrypt(
megolm_advance(&(session->ratchet));
/* sign the whole thing with the ed25519 key. */
_olm_crypto_ed25519_sign(
&(session->signing_key),
buffer, message_length,
buffer + message_length
);
return result;
}
......@@ -302,23 +323,40 @@ uint32_t olm_outbound_group_session_message_index(
return session->ratchet.counter;
}
#define SESSION_KEY_RAW_LENGTH \
(1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH)
size_t olm_outbound_group_session_key_length(
const OlmOutboundGroupSession *session
) {
return _olm_encode_base64_length(MEGOLM_RATCHET_LENGTH);
return _olm_encode_base64_length(SESSION_KEY_RAW_LENGTH);
}
size_t olm_outbound_group_session_key(
OlmOutboundGroupSession *session,
uint8_t * key, size_t key_length
) {
if (key_length < olm_outbound_group_session_key_length(session)) {
uint8_t *raw;
uint8_t *ptr;
size_t encoded_length = olm_outbound_group_session_key_length(session);
if (key_length < encoded_length) {
session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
return (size_t)-1;
}
return _olm_encode_base64(
megolm_get_data(&session->ratchet),
MEGOLM_RATCHET_LENGTH, key
/* put the raw data at the end of the output buffer. */
raw = ptr = key + encoded_length - SESSION_KEY_RAW_LENGTH;
*ptr++ = SESSION_KEY_VERSION;
memcpy(ptr, megolm_get_data(&session->ratchet), MEGOLM_RATCHET_LENGTH);
ptr += MEGOLM_RATCHET_LENGTH;
memcpy(
ptr, session->signing_key.public_key.public_key,
ED25519_PUBLIC_KEY_LENGTH
);
ptr += ED25519_PUBLIC_KEY_LENGTH;
return _olm_encode_base64(raw, SESSION_KEY_RAW_LENGTH, key);
}
......@@ -80,7 +80,6 @@ int main() {
assert_equals(pickle1, pickle2, pickle_length);
}
{
TestCase test_case("Group message send/receive");
......@@ -89,6 +88,7 @@ int main() {
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF";
......@@ -97,7 +97,7 @@ int main() {
uint8_t memory[size];
OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
assert_equals((size_t)132,
assert_equals((size_t)164,
olm_init_outbound_group_session_random_length(session));
size_t res = olm_init_outbound_group_session(
......@@ -109,7 +109,6 @@ int main() {
uint8_t session_key[session_key_len];
olm_outbound_group_session_key(session, session_key, session_key_len);
/* encode the message */
uint8_t plaintext[] = "Message";
size_t plaintext_length = sizeof(plaintext) - 1;
......@@ -148,4 +147,73 @@ int main() {
assert_equals(plaintext, plaintext_buf, res);
}
{
TestCase test_case("Invalid signature group message");
uint8_t plaintext[] = "Message";
size_t plaintext_length = sizeof(plaintext) - 1;
uint8_t session_key[] =
"ATAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzNDU2Nzg5QUJERUYw"
"MTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGMDEyMzQ1Njc4OUFCQ0RFRjAx"
"MjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzDRt2DUEOrg/H+yUGjDTq"
"ryf8H1YF/BZjI04HwOVSZcY";
uint8_t message[] =
"AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8F4+qjMaFlnIXusJZX3r8LnROR"
"G9T3DXFdbVuvIWrLyRfm4i8QRbe8VPwGRFG57B1CtmxanuP8bHtnnYqlwPsD";
size_t msglen = sizeof(message)-1;
/* build the inbound session */
size_t size = olm_inbound_group_session_size();
uint8_t inbound_session_memory[size];
OlmInboundGroupSession *inbound_session =
olm_inbound_group_session(inbound_session_memory);
size_t res = olm_init_inbound_group_session(
inbound_session, 0U, session_key, sizeof(session_key)-1
);
assert_equals((size_t)0, res);
/* decode the message */
/* olm_group_decrypt_max_plaintext_length destroys the input so we have to
copy it. */
uint8_t msgcopy[msglen];
memcpy(msgcopy, message, msglen);
size = olm_group_decrypt_max_plaintext_length(
inbound_session, msgcopy, msglen
);
memcpy(msgcopy, message, msglen);
uint8_t plaintext_buf[size];
res = olm_group_decrypt(
inbound_session, msgcopy, msglen, plaintext_buf, size
);
assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res);
/* now twiddle the signature */
message[msglen-1] = 'E';
memcpy(msgcopy, message, msglen);
assert_equals(
size,
olm_group_decrypt_max_plaintext_length(
inbound_session, msgcopy, msglen
)
);
memcpy(msgcopy, message, msglen);
res = olm_group_decrypt(
inbound_session, msgcopy, msglen,
plaintext_buf, size
);
assert_equals((size_t)-1, res);
assert_equals(
std::string("BAD_SIGNATURE"),
std::string(olm_inbound_group_session_last_error(inbound_session))
);
}
}
......@@ -67,8 +67,8 @@ assert_equals(message2, output, 35);
TestCase test_case("Group message encode test");
size_t length = _olm_encode_group_message_length(200, 10, 8);
size_t expected_length = 1 + (1+2) + (2+10) + 8;
size_t length = _olm_encode_group_message_length(200, 10, 8, 64);
size_t expected_length = 1 + (1+2) + (2+10) + 8 + 64;
assert_equals(expected_length, length);
uint8_t output[50];
......@@ -99,9 +99,10 @@ assert_equals(message2, output, 35);
"\x03"
"\x08\xC8\x01"
"\x12\x0A" "ciphertext"
"hmacsha2";
"hmacsha2"
"ed25519signature";
_olm_decode_group_message(message, sizeof(message)-1, 8, &results);
_olm_decode_group_message(message, sizeof(message)-1, 8, 16, &results);
assert_equals(std::uint8_t(3), results.version);
assert_equals(1, results.has_message_index);
assert_equals(std::uint32_t(200), results.message_index);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment