Commit 2fc83aa9 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Sign megolm messages

Add ed25519 keys to the inbound and outbound sessions, and use them to sign and
verify megolm messages.

We just stuff the ed25519 public key in alongside the megolm session key (and
add a version byte), to save adding more boilerplate to the JS/python/etc
layers.
parent 50cd2b2a
......@@ -46,6 +46,11 @@ enum OlmErrorCode {
*/
OLM_BAD_LEGACY_ACCOUNT_PICKLE = 13,
/**
* Received message had a bad signature
*/
OLM_BAD_SIGNATURE = 14,
/* remember to update the list of string constants in error.c when updating
* this list. */
};
......
......@@ -97,7 +97,7 @@ size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session,
uint32_t message_index,
/* base64-encoded key */
/* base64-encoded keys */
uint8_t const * session_key, size_t session_key_length
);
......
......@@ -37,7 +37,8 @@ extern "C" {
size_t _olm_encode_group_message_length(
uint32_t chain_index,
size_t ciphertext_length,
size_t mac_length
size_t mac_length,
size_t signature_length
);
/**
......@@ -49,7 +50,8 @@ size_t _olm_encode_group_message_length(
* output: where to write the output. Should be at least
* olm_encode_group_message_length() bytes long.
* ciphertext_ptr: returns the address that the ciphertext
* should be written to, followed by the MAC.
* should be written to, followed by the MAC and the
* signature.
*
* Returns the size of the message, up to the MAC.
*/
......@@ -76,7 +78,7 @@ struct _OlmDecodeGroupMessageResults {
*/
void _olm_decode_group_message(
const uint8_t *input, size_t input_length,
size_t mac_length,
size_t mac_length, size_t signature_length,
/* output structure: updated with results */
struct _OlmDecodeGroupMessageResults *results
......
......@@ -30,6 +30,7 @@ static const char * ERRORS[] = {
"BAD_SESSION_KEY",
"UNKNOWN_MESSAGE_INDEX",
"BAD_LEGACY_ACCOUNT_PICKLE",
"BAD_SIGNATURE",
};
const char * _olm_error_to_string(enum OlmErrorCode error)
......
......@@ -19,6 +19,7 @@
#include "olm/base64.h"
#include "olm/cipher.h"
#include "olm/crypto.h"
#include "olm/error.h"
#include "olm/megolm.h"
#include "olm/memory.h"
......@@ -29,6 +30,7 @@
#define OLM_PROTOCOL_VERSION 3
#define PICKLE_VERSION 1
#define SESSION_KEY_VERSION 1
struct OlmInboundGroupSession {
/** our earliest known ratchet value */
......@@ -37,6 +39,9 @@ struct OlmInboundGroupSession {
/** The most recent ratchet value */
Megolm latest_ratchet;
/** The ed25519 signing key */
struct _olm_ed25519_public_key signing_key;
enum OlmErrorCode last_error;
};
......@@ -65,30 +70,56 @@ size_t olm_clear_inbound_group_session(
return sizeof(OlmInboundGroupSession);
}
#define SESSION_KEY_RAW_LENGTH \
(1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH)
/** init the session keys from the un-base64-ed session keys */
static size_t _init_group_session_keys(
OlmInboundGroupSession *session,
uint32_t message_index,
const uint8_t *key_buf
) {
const uint8_t *ptr = key_buf;
size_t version = *ptr++;
if (version != SESSION_KEY_VERSION) {
session->last_error = OLM_BAD_SESSION_KEY;
return (size_t)-1;
}
megolm_init(&session->initial_ratchet, ptr, message_index);
megolm_init(&session->latest_ratchet, ptr, message_index);
ptr += MEGOLM_RATCHET_LENGTH;
memcpy(
session->signing_key.public_key, ptr, ED25519_PUBLIC_KEY_LENGTH
);
ptr += ED25519_PUBLIC_KEY_LENGTH;
return 0;
}
size_t olm_init_inbound_group_session(
OlmInboundGroupSession *session,
uint32_t message_index,
const uint8_t * session_key, size_t session_key_length
) {
uint8_t key_buf[MEGOLM_RATCHET_LENGTH];
uint8_t key_buf[SESSION_KEY_RAW_LENGTH];
size_t raw_length = _olm_decode_base64_length(session_key_length);
size_t result;
if (raw_length == (size_t)-1) {
session->last_error = OLM_INVALID_BASE64;
return (size_t)-1;
}
if (raw_length != MEGOLM_RATCHET_LENGTH) {
if (raw_length != SESSION_KEY_RAW_LENGTH) {
session->last_error = OLM_BAD_SESSION_KEY;
return (size_t)-1;
}
_olm_decode_base64(session_key, session_key_length, key_buf);
megolm_init(&session->initial_ratchet, key_buf, message_index);
megolm_init(&session->latest_ratchet, key_buf, message_index);
_olm_unset(key_buf, MEGOLM_RATCHET_LENGTH);
return 0;
result = _init_group_session_keys(session, message_index, key_buf);
_olm_unset(key_buf, SESSION_KEY_RAW_LENGTH);
return result;
}
static size_t raw_pickle_length(
......@@ -98,6 +129,7 @@ static size_t raw_pickle_length(
length += _olm_pickle_uint32_length(PICKLE_VERSION);
length += megolm_pickle_length(&session->initial_ratchet);
length += megolm_pickle_length(&session->latest_ratchet);
length += _olm_pickle_ed25519_public_key_length(&session->signing_key);
return length;
}
......@@ -124,6 +156,7 @@ size_t olm_pickle_inbound_group_session(
pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
pos = megolm_pickle(&session->initial_ratchet, pos);
pos = megolm_pickle(&session->latest_ratchet, pos);
pos = _olm_pickle_ed25519_public_key(pos, &session->signing_key);
return _olm_enc_output(key, key_length, pickled, raw_length);
}
......@@ -153,6 +186,7 @@ size_t olm_unpickle_inbound_group_session(
}
pos = megolm_unpickle(&session->initial_ratchet, pos, end);
pos = megolm_unpickle(&session->latest_ratchet, pos, end);
pos = _olm_unpickle_ed25519_public_key(pos, end, &session->signing_key);
if (end != pos) {
/* We had the wrong number of bytes in the input. */
......@@ -175,6 +209,7 @@ static size_t _decrypt_max_plaintext_length(
_olm_decode_group_message(
message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher),
ED25519_SIGNATURE_LENGTH,
&decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) {
......@@ -224,6 +259,7 @@ static size_t _decrypt(
_olm_decode_group_message(
message, message_length,
megolm_cipher->ops->mac_length(megolm_cipher),
ED25519_SIGNATURE_LENGTH,
&decoded_results);
if (decoded_results.version != OLM_PROTOCOL_VERSION) {
......@@ -231,11 +267,28 @@ static size_t _decrypt(
return (size_t)-1;
}
if (!decoded_results.has_message_index || !decoded_results.ciphertext ) {
if (!decoded_results.has_message_index || !decoded_results.ciphertext) {
session->last_error = OLM_BAD_MESSAGE_FORMAT;
return (size_t)-1;
}
/* verify the signature. We could do this before decoding the message, but
* we allow for the possibility of future protocol versions which use a
* different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION"
* than "BAD_SIGNATURE" in this case.
*/
message_length -= ED25519_SIGNATURE_LENGTH;
r = _olm_crypto_ed25519_verify(
&session->signing_key,
message, message_length,
message + message_length
);
if (!r) {
session->last_error = OLM_BAD_SIGNATURE;
return (size_t)-1;
}
max_length = megolm_cipher->ops->decrypt_max_plaintext_length(
megolm_cipher,
decoded_results.ciphertext_length
......
......@@ -334,12 +334,14 @@ static const std::uint8_t GROUP_CIPHERTEXT_TAG = 022;
size_t _olm_encode_group_message_length(
uint32_t message_index,
size_t ciphertext_length,
size_t mac_length
size_t mac_length,
size_t signature_length
) {
size_t length = VERSION_LENGTH;
length += 1 + varint_length(message_index);
length += 1 + varstring_length(ciphertext_length);
length += mac_length;
length += signature_length;
return length;
}
......@@ -361,11 +363,12 @@ size_t _olm_encode_group_message(
void _olm_decode_group_message(
const uint8_t *input, size_t input_length,
size_t mac_length,
size_t mac_length, size_t signature_length,
struct _OlmDecodeGroupMessageResults *results
) {
std::uint8_t const * pos = input;
std::uint8_t const * end = input + input_length - mac_length;
std::size_t trailer_length = mac_length + signature_length;
std::uint8_t const * end = input + input_length - trailer_length;
std::uint8_t const * unknown = nullptr;
bool has_message_index = false;
......@@ -373,8 +376,7 @@ void _olm_decode_group_message(
results->ciphertext = nullptr;
results->ciphertext_length = 0;
if (pos == end) return;
if (input_length < mac_length) return;
if (input_length < trailer_length) return;
results->version = *(pos++);
while (pos != end) {
......
......@@ -20,6 +20,7 @@
#include "olm/base64.h"
#include "olm/cipher.h"
#include "olm/crypto.h"
#include "olm/error.h"
#include "olm/megolm.h"
#include "olm/memory.h"
......@@ -31,11 +32,15 @@
#define SESSION_ID_RANDOM_BYTES 4
#define GROUP_SESSION_ID_LENGTH (sizeof(struct timeval) + SESSION_ID_RANDOM_BYTES)
#define PICKLE_VERSION 1
#define SESSION_KEY_VERSION 1
struct OlmOutboundGroupSession {
/** the Megolm ratchet providing the encryption keys */
Megolm ratchet;
/** The ed25519 keypair used for signing the messages */
struct _olm_ed25519_key_pair signing_key;
/** unique identifier for this session */
uint8_t session_id[GROUP_SESSION_ID_LENGTH];
......@@ -74,6 +79,7 @@ static size_t raw_pickle_length(
size_t length = 0;
length += _olm_pickle_uint32_length(PICKLE_VERSION);
length += megolm_pickle_length(&(session->ratchet));
length += _olm_pickle_ed25519_key_pair_length(&(session->signing_key));
length += _olm_pickle_bytes_length(session->session_id,
GROUP_SESSION_ID_LENGTH);
return length;
......@@ -101,6 +107,7 @@ size_t olm_pickle_outbound_group_session(
pos = _olm_enc_output_pos(pickled, raw_length);
pos = _olm_pickle_uint32(pos, PICKLE_VERSION);
pos = megolm_pickle(&(session->ratchet), pos);
pos = _olm_pickle_ed25519_key_pair(pos, &(session->signing_key));
pos = _olm_pickle_bytes(pos, session->session_id, GROUP_SESSION_ID_LENGTH);
return _olm_enc_output(key, key_length, pickled, raw_length);
......@@ -130,6 +137,7 @@ size_t olm_unpickle_outbound_group_session(
return (size_t)-1;
}
pos = megolm_unpickle(&(session->ratchet), pos, end);
pos = _olm_unpickle_ed25519_key_pair(pos, end, &(session->signing_key));
pos = _olm_unpickle_bytes(pos, end, session->session_id, GROUP_SESSION_ID_LENGTH);
if (end != pos) {
......@@ -148,7 +156,9 @@ size_t olm_init_outbound_group_session_random_length(
/* we need data to initialize the megolm ratchet, plus some more for the
* session id.
*/
return MEGOLM_RATCHET_LENGTH + SESSION_ID_RANDOM_BYTES;
return MEGOLM_RATCHET_LENGTH +
ED25519_RANDOM_LENGTH +
SESSION_ID_RANDOM_BYTES;
}
size_t olm_init_outbound_group_session(
......@@ -164,6 +174,9 @@ size_t olm_init_outbound_group_session(
megolm_init(&(session->ratchet), random, 0);
random += MEGOLM_RATCHET_LENGTH;
_olm_crypto_ed25519_generate_key(random, &(session->signing_key));
random += ED25519_RANDOM_LENGTH;
/* initialise the session id. This just has to be unique. We use the
* current time plus some random data.
*/
......@@ -188,7 +201,8 @@ static size_t raw_message_length(
return _olm_encode_group_message_length(
session->ratchet.counter,
ciphertext_length, mac_length);
ciphertext_length, mac_length, ED25519_SIGNATURE_LENGTH
);
}
size_t olm_group_encrypt_message_length(
......@@ -241,6 +255,13 @@ static size_t _encrypt(
megolm_advance(&(session->ratchet));
/* sign the whole thing with the ed25519 key. */
_olm_crypto_ed25519_sign(
&(session->signing_key),
buffer, message_length,
buffer + message_length
);
return result;
}
......@@ -302,23 +323,40 @@ uint32_t olm_outbound_group_session_message_index(
return session->ratchet.counter;
}
#define SESSION_KEY_RAW_LENGTH \
(1 + MEGOLM_RATCHET_LENGTH + ED25519_PUBLIC_KEY_LENGTH)
size_t olm_outbound_group_session_key_length(
const OlmOutboundGroupSession *session
) {
return _olm_encode_base64_length(MEGOLM_RATCHET_LENGTH);
return _olm_encode_base64_length(SESSION_KEY_RAW_LENGTH);
}
size_t olm_outbound_group_session_key(
OlmOutboundGroupSession *session,
uint8_t * key, size_t key_length
) {
if (key_length < olm_outbound_group_session_key_length(session)) {
uint8_t *raw;
uint8_t *ptr;
size_t encoded_length = olm_outbound_group_session_key_length(session);
if (key_length < encoded_length) {
session->last_error = OLM_OUTPUT_BUFFER_TOO_SMALL;
return (size_t)-1;
}
return _olm_encode_base64(
megolm_get_data(&session->ratchet),
MEGOLM_RATCHET_LENGTH, key
/* put the raw data at the end of the output buffer. */
raw = ptr = key + encoded_length - SESSION_KEY_RAW_LENGTH;
*ptr++ = SESSION_KEY_VERSION;
memcpy(ptr, megolm_get_data(&session->ratchet), MEGOLM_RATCHET_LENGTH);
ptr += MEGOLM_RATCHET_LENGTH;
memcpy(
ptr, session->signing_key.public_key.public_key,
ED25519_PUBLIC_KEY_LENGTH
);
ptr += ED25519_PUBLIC_KEY_LENGTH;
return _olm_encode_base64(raw, SESSION_KEY_RAW_LENGTH, key);
}
......@@ -80,7 +80,6 @@ int main() {
assert_equals(pickle1, pickle2, pickle_length);
}
{
TestCase test_case("Group message send/receive");
......@@ -89,6 +88,7 @@ int main() {
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF"
"0123456789ABDEF0123456789ABCDEF";
......@@ -97,7 +97,7 @@ int main() {
uint8_t memory[size];
OlmOutboundGroupSession *session = olm_outbound_group_session(memory);
assert_equals((size_t)132,
assert_equals((size_t)164,
olm_init_outbound_group_session_random_length(session));
size_t res = olm_init_outbound_group_session(
......@@ -109,7 +109,6 @@ int main() {
uint8_t session_key[session_key_len];
olm_outbound_group_session_key(session, session_key, session_key_len);
/* encode the message */
uint8_t plaintext[] = "Message";
size_t plaintext_length = sizeof(plaintext) - 1;
......@@ -148,4 +147,73 @@ int main() {
assert_equals(plaintext, plaintext_buf, res);
}
{
TestCase test_case("Invalid signature group message");
uint8_t plaintext[] = "Message";
size_t plaintext_length = sizeof(plaintext) - 1;
uint8_t session_key[] =
"ATAxMjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzNDU2Nzg5QUJERUYw"
"MTIzNDU2Nzg5QUJDREVGMDEyMzQ1Njc4OUFCREVGMDEyMzQ1Njc4OUFCQ0RFRjAx"
"MjM0NTY3ODlBQkRFRjAxMjM0NTY3ODlBQkNERUYwMTIzDRt2DUEOrg/H+yUGjDTq"
"ryf8H1YF/BZjI04HwOVSZcY";
uint8_t message[] =
"AwgAEhAcbh6UpbByoyZxufQ+h2B+8XHMjhR69G8F4+qjMaFlnIXusJZX3r8LnROR"
"G9T3DXFdbVuvIWrLyRfm4i8QRbe8VPwGRFG57B1CtmxanuP8bHtnnYqlwPsD";
size_t msglen = sizeof(message)-1;
/* build the inbound session */
size_t size = olm_inbound_group_session_size();
uint8_t inbound_session_memory[size];
OlmInboundGroupSession *inbound_session =
olm_inbound_group_session(inbound_session_memory);
size_t res = olm_init_inbound_group_session(
inbound_session, 0U, session_key, sizeof(session_key)-1
);
assert_equals((size_t)0, res);
/* decode the message */
/* olm_group_decrypt_max_plaintext_length destroys the input so we have to
copy it. */
uint8_t msgcopy[msglen];
memcpy(msgcopy, message, msglen);
size = olm_group_decrypt_max_plaintext_length(
inbound_session, msgcopy, msglen
);
memcpy(msgcopy, message, msglen);
uint8_t plaintext_buf[size];
res = olm_group_decrypt(
inbound_session, msgcopy, msglen, plaintext_buf, size
);
assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res);
/* now twiddle the signature */
message[msglen-1] = 'E';
memcpy(msgcopy, message, msglen);
assert_equals(
size,
olm_group_decrypt_max_plaintext_length(
inbound_session, msgcopy, msglen
)
);
memcpy(msgcopy, message, msglen);
res = olm_group_decrypt(
inbound_session, msgcopy, msglen,
plaintext_buf, size
);
assert_equals((size_t)-1, res);
assert_equals(
std::string("BAD_SIGNATURE"),
std::string(olm_inbound_group_session_last_error(inbound_session))
);
}
}
......@@ -67,8 +67,8 @@ assert_equals(message2, output, 35);
TestCase test_case("Group message encode test");
size_t length = _olm_encode_group_message_length(200, 10, 8);
size_t expected_length = 1 + (1+2) + (2+10) + 8;
size_t length = _olm_encode_group_message_length(200, 10, 8, 64);
size_t expected_length = 1 + (1+2) + (2+10) + 8 + 64;
assert_equals(expected_length, length);
uint8_t output[50];
......@@ -99,9 +99,10 @@ assert_equals(message2, output, 35);
"\x03"
"\x08\xC8\x01"
"\x12\x0A" "ciphertext"
"hmacsha2";
"hmacsha2"
"ed25519signature";
_olm_decode_group_message(message, sizeof(message)-1, 8, &results);
_olm_decode_group_message(message, sizeof(message)-1, 8, 16, &results);
assert_equals(std::uint8_t(3), results.version);
assert_equals(1, results.has_message_index);
assert_equals(std::uint32_t(200), results.message_index);
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment