Commit 653790ea authored by Mark Haines's avatar Mark Haines
Browse files

Return the message index when decrypting group messages.

Applications can use the index to detect replays of the same message.
parent 6ea9fb45
...@@ -140,7 +140,8 @@ size_t olm_group_decrypt( ...@@ -140,7 +140,8 @@ size_t olm_group_decrypt(
uint8_t * message, size_t message_length, uint8_t * message, size_t message_length,
/* output */ /* output */
uint8_t * plaintext, size_t max_plaintext_length uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
); );
......
...@@ -403,8 +403,8 @@ DemoUser.prototype.decryptGroup = function(jsonpacket, callback) { ...@@ -403,8 +403,8 @@ DemoUser.prototype.decryptGroup = function(jsonpacket, callback) {
throw new Error("Unknown session id " + session_id); throw new Error("Unknown session id " + session_id);
} }
var plaintext = session.decrypt(packet.body); var result = session.decrypt(packet.body);
done(plaintext); done(result.plaintext);
}, callback); }, callback);
}; };
......
...@@ -73,10 +73,12 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function( ...@@ -73,10 +73,12 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function(
// So we copy the array to a new buffer // So we copy the array to a new buffer
var message_buffer = stack(message_array); var message_buffer = stack(message_array);
var plaintext_buffer = stack(max_plaintext_length + NULL_BYTE_PADDING_LENGTH); var plaintext_buffer = stack(max_plaintext_length + NULL_BYTE_PADDING_LENGTH);
var message_index = stack(4);
var plaintext_length = inbound_group_session_method(Module["_olm_group_decrypt"])( var plaintext_length = inbound_group_session_method(Module["_olm_group_decrypt"])(
this.ptr, this.ptr,
message_buffer, message_array.length, message_buffer, message_array.length,
plaintext_buffer, max_plaintext_length plaintext_buffer, max_plaintext_length,
message_index
); );
// Pointer_stringify requires a null-terminated argument (the optional // Pointer_stringify requires a null-terminated argument (the optional
...@@ -86,7 +88,10 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function( ...@@ -86,7 +88,10 @@ InboundGroupSession.prototype['decrypt'] = restore_stack(function(
0, "i8" 0, "i8"
); );
return Pointer_stringify(plaintext_buffer); return {
"plaintext": Pointer_stringify(plaintext_buffer),
"message_index": Module['getValue'](message_index, "i32")
}
}); });
InboundGroupSession.prototype['session_id'] = restore_stack(function() { InboundGroupSession.prototype['session_id'] = restore_stack(function() {
......
...@@ -328,7 +328,7 @@ def do_group_decrypt(args): ...@@ -328,7 +328,7 @@ def do_group_decrypt(args):
session = InboundGroupSession() session = InboundGroupSession()
session.unpickle(args.key, read_base64_file(args.session_file)) session.unpickle(args.key, read_base64_file(args.session_file))
message = args.message_file.read() message = args.message_file.read()
plaintext = session.decrypt(message) plaintext, message_index = session.decrypt(message)
with open(args.session_file, "wb") as f: with open(args.session_file, "wb") as f:
f.write(session.pickle(args.key)) f.write(session.pickle(args.key))
args.plaintext_file.write(plaintext) args.plaintext_file.write(plaintext)
......
...@@ -43,6 +43,7 @@ inbound_group_session_function( ...@@ -43,6 +43,7 @@ inbound_group_session_function(
lib.olm_group_decrypt, lib.olm_group_decrypt,
c_void_p, c_size_t, # message c_void_p, c_size_t, # message
c_void_p, c_size_t, # plaintext c_void_p, c_size_t, # plaintext
POINTER(c_uint32), # message_index
) )
inbound_group_session_function(lib.olm_inbound_group_session_id_length) inbound_group_session_function(lib.olm_inbound_group_session_id_length)
...@@ -82,11 +83,14 @@ class InboundGroupSession(object): ...@@ -82,11 +83,14 @@ class InboundGroupSession(object):
) )
plaintext_buffer = create_string_buffer(max_plaintext_length) plaintext_buffer = create_string_buffer(max_plaintext_length)
message_buffer = create_string_buffer(message) message_buffer = create_string_buffer(message)
message_index = c_uint32()
plaintext_length = lib.olm_group_decrypt( plaintext_length = lib.olm_group_decrypt(
self.ptr, message_buffer, len(message), self.ptr, message_buffer, len(message),
plaintext_buffer, max_plaintext_length plaintext_buffer, max_plaintext_length,
byref(message_index)
) )
return plaintext_buffer.raw[:plaintext_length] return plaintext_buffer.raw[:plaintext_length], message_index
def session_id(self): def session_id(self):
id_length = lib.olm_inbound_group_session_id_length(self.ptr) id_length = lib.olm_inbound_group_session_id_length(self.ptr)
......
...@@ -263,7 +263,8 @@ size_t olm_group_decrypt_max_plaintext_length( ...@@ -263,7 +263,8 @@ size_t olm_group_decrypt_max_plaintext_length(
static size_t _decrypt( static size_t _decrypt(
OlmInboundGroupSession *session, OlmInboundGroupSession *session,
uint8_t * message, size_t message_length, uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
) { ) {
struct _OlmDecodeGroupMessageResults decoded_results; struct _OlmDecodeGroupMessageResults decoded_results;
size_t max_length, r; size_t max_length, r;
...@@ -286,6 +287,8 @@ static size_t _decrypt( ...@@ -286,6 +287,8 @@ static size_t _decrypt(
return (size_t)-1; return (size_t)-1;
} }
*message_index = decoded_results.message_index;
/* verify the signature. We could do this before decoding the message, but /* verify the signature. We could do this before decoding the message, but
* we allow for the possibility of future protocol versions which use a * we allow for the possibility of future protocol versions which use a
* different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION" * different signing mechanism; we would rather throw "BAD_MESSAGE_VERSION"
...@@ -349,7 +352,8 @@ static size_t _decrypt( ...@@ -349,7 +352,8 @@ static size_t _decrypt(
size_t olm_group_decrypt( size_t olm_group_decrypt(
OlmInboundGroupSession *session, OlmInboundGroupSession *session,
uint8_t * message, size_t message_length, uint8_t * message, size_t message_length,
uint8_t * plaintext, size_t max_plaintext_length uint8_t * plaintext, size_t max_plaintext_length,
uint32_t * message_index
) { ) {
size_t raw_message_length; size_t raw_message_length;
...@@ -361,7 +365,8 @@ size_t olm_group_decrypt( ...@@ -361,7 +365,8 @@ size_t olm_group_decrypt(
return _decrypt( return _decrypt(
session, message, raw_message_length, session, message, raw_message_length,
plaintext, max_plaintext_length plaintext, max_plaintext_length,
message_index
); );
} }
......
...@@ -161,8 +161,9 @@ int main() { ...@@ -161,8 +161,9 @@ int main() {
memcpy(msgcopy, msg, msglen); memcpy(msgcopy, msg, msglen);
size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen); size = olm_group_decrypt_max_plaintext_length(inbound_session, msgcopy, msglen);
uint8_t plaintext_buf[size]; uint8_t plaintext_buf[size];
uint32_t message_index;
res = olm_group_decrypt(inbound_session, msg, msglen, res = olm_group_decrypt(inbound_session, msg, msglen,
plaintext_buf, size); plaintext_buf, size, &message_index);
assert_equals(plaintext_length, res); assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res); assert_equals(plaintext, plaintext_buf, res);
} }
...@@ -208,8 +209,9 @@ int main() { ...@@ -208,8 +209,9 @@ int main() {
memcpy(msgcopy, message, msglen); memcpy(msgcopy, message, msglen);
uint8_t plaintext_buf[size]; uint8_t plaintext_buf[size];
uint32_t message_index;
res = olm_group_decrypt( res = olm_group_decrypt(
inbound_session, msgcopy, msglen, plaintext_buf, size inbound_session, msgcopy, msglen, plaintext_buf, size, &message_index
); );
assert_equals(plaintext_length, res); assert_equals(plaintext_length, res);
assert_equals(plaintext, plaintext_buf, res); assert_equals(plaintext, plaintext_buf, res);
...@@ -227,7 +229,7 @@ int main() { ...@@ -227,7 +229,7 @@ int main() {
memcpy(msgcopy, message, msglen); memcpy(msgcopy, message, msglen);
res = olm_group_decrypt( res = olm_group_decrypt(
inbound_session, msgcopy, msglen, inbound_session, msgcopy, msglen,
plaintext_buf, size plaintext_buf, size, &message_index
); );
assert_equals((size_t)-1, res); assert_equals((size_t)-1, res);
assert_equals( assert_equals(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment