Commit 294cf482 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Convert cipher.hh to plain C

parent f9139dfa
...@@ -13,30 +13,30 @@ ...@@ -13,30 +13,30 @@
* limitations under the License. * limitations under the License.
*/ */
#ifndef OLM_CIPHER_HH_ #ifndef OLM_CIPHER_H_
#define OLM_CIPHER_HH_ #define OLM_CIPHER_H_
#include <cstdint> #include <stdint.h>
#include <cstddef> #include <stdlib.h>
namespace olm { #ifdef __cplusplus
extern "C" {
#endif
class Cipher { struct olm_cipher;
public:
virtual ~Cipher();
struct cipher_ops {
/** /**
* Returns the length of the message authentication code that will be * Returns the length of the message authentication code that will be
* appended to the output. * appended to the output.
*/ */
virtual std::size_t mac_length() const = 0; size_t (*mac_length)(const struct olm_cipher *cipher);
/** /**
* Returns the length of cipher-text for a given length of plain-text. * Returns the length of cipher-text for a given length of plain-text.
*/ */
virtual std::size_t encrypt_ciphertext_length( size_t (*encrypt_ciphertext_length)(const struct olm_cipher *cipher,
std::size_t plaintext_length size_t plaintext_length);
) const = 0;
/* /*
* Encrypts the plain-text into the output buffer and authenticates the * Encrypts the plain-text into the output buffer and authenticates the
...@@ -49,23 +49,25 @@ public: ...@@ -49,23 +49,25 @@ public:
* *
* The plain-text pointers and cipher-text pointers may be the same. * The plain-text pointers and cipher-text pointers may be the same.
* *
* Returns std::size_t(-1) if the length of the cipher-text or the output * Returns size_t(-1) if the length of the cipher-text or the output
* buffer is too small. Otherwise returns the length of the output buffer. * buffer is too small. Otherwise returns the length of the output buffer.
*/ */
virtual std::size_t encrypt( size_t (*encrypt)(
std::uint8_t const * key, std::size_t key_length, const struct olm_cipher *cipher,
std::uint8_t const * plaintext, std::size_t plaintext_length, uint8_t const * key, size_t key_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length, uint8_t const * plaintext, size_t plaintext_length,
std::uint8_t * output, std::size_t output_length uint8_t * ciphertext, size_t ciphertext_length,
) const = 0; uint8_t * output, size_t output_length
);
/** /**
* Returns the maximum length of plain-text that a given length of * Returns the maximum length of plain-text that a given length of
* cipher-text can contain. * cipher-text can contain.
*/ */
virtual std::size_t decrypt_max_plaintext_length( size_t (*decrypt_max_plaintext_length)(
std::size_t ciphertext_length const struct olm_cipher *cipher,
) const = 0; size_t ciphertext_length
);
/** /**
* Authenticates the input and decrypts the cipher-text into the plain-text * Authenticates the input and decrypts the cipher-text into the plain-text
...@@ -77,56 +79,56 @@ public: ...@@ -77,56 +79,56 @@ public:
* *
* The plain-text pointers and cipher-text pointers may be the same. * The plain-text pointers and cipher-text pointers may be the same.
* *
* Returns std::size_t(-1) if the length of the plain-text buffer is too * Returns size_t(-1) if the length of the plain-text buffer is too
* small or if the authentication check fails. Otherwise returns the length * small or if the authentication check fails. Otherwise returns the length
* of the plain text. * of the plain text.
*/ */
virtual std::size_t decrypt( size_t (*decrypt)(
std::uint8_t const * key, std::size_t key_length, const struct olm_cipher *cipher,
std::uint8_t const * input, std::size_t input_length, uint8_t const * key, size_t key_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length, uint8_t const * input, size_t input_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length uint8_t const * ciphertext, size_t ciphertext_length,
) const = 0; uint8_t * plaintext, size_t max_plaintext_length
);
/** destroy any private data associated with this cipher */
void (*destruct)(struct olm_cipher *cipher);
}; };
struct olm_cipher {
const struct cipher_ops *ops;
/* cipher-specific fields follow */
};
class CipherAesSha256 : public Cipher { struct olm_cipher_aes_sha_256 {
public: struct olm_cipher base_cipher;
CipherAesSha256(
std::uint8_t const * kdf_info, std::size_t kdf_info_length
);
virtual std::size_t mac_length() const; uint8_t const * kdf_info;
size_t kdf_info_length;
virtual std::size_t encrypt_ciphertext_length(
std::size_t plaintext_length
) const;
virtual std::size_t encrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const;
virtual std::size_t decrypt_max_plaintext_length(
std::size_t ciphertext_length
) const;
virtual std::size_t decrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * input, std::size_t input_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) const;
private:
std::uint8_t const * kdf_info;
std::size_t kdf_info_length;
}; };
} // namespace /**
* initialises a cipher type which uses AES256 for encryption and SHA256 for
* authentication.
*
* cipher: structure to be initialised
*
* kdf_info: context string for the HKDF used for deriving the AES256 key, HMAC
* key, and AES IV, from the key material passed to encrypt/decrypt. Note that
* this is NOT copied so must have a lifetime at least as long as the cipher
* instance.
*
* kdf_info_length: length of context string kdf_info
*/
struct olm_cipher *olm_cipher_aes_sha_256_init(
struct olm_cipher_aes_sha_256 *cipher,
uint8_t const * kdf_info,
size_t kdf_info_length);
#ifdef __cplusplus
} /* extern "C" */
#endif
#endif /* OLM_CIPHER_HH_ */ #endif /* OLM_CIPHER_H_ */
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#include "olm/list.hh" #include "olm/list.hh"
#include "olm/error.h" #include "olm/error.h"
namespace olm { struct olm_cipher;
class Cipher; namespace olm {
typedef std::uint8_t SharedKey[olm::KEY_LENGTH]; typedef std::uint8_t SharedKey[olm::KEY_LENGTH];
...@@ -69,14 +69,14 @@ struct Ratchet { ...@@ -69,14 +69,14 @@ struct Ratchet {
Ratchet( Ratchet(
KdfInfo const & kdf_info, KdfInfo const & kdf_info,
Cipher const & ratchet_cipher olm_cipher const *ratchet_cipher
); );
/** A some strings identifying the application to feed into the KDF. */ /** A some strings identifying the application to feed into the KDF. */
KdfInfo const & kdf_info; KdfInfo const & kdf_info;
/** The AEAD cipher to use for encrypting messages. */ /** The AEAD cipher to use for encrypting messages. */
Cipher const & ratchet_cipher; olm_cipher const *ratchet_cipher;
/** The last error that happened encrypting or decrypting a message. */ /** The last error that happened encrypting or decrypting a message. */
OlmErrorCode last_error; OlmErrorCode last_error;
......
...@@ -12,15 +12,11 @@ ...@@ -12,15 +12,11 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
#include "olm/cipher.hh" #include "olm/cipher.h"
#include "olm/crypto.hh" #include "olm/crypto.hh"
#include "olm/memory.hh" #include "olm/memory.hh"
#include <cstring> #include <cstring>
olm::Cipher::~Cipher() {
}
namespace { namespace {
struct DerivedKeys { struct DerivedKeys {
...@@ -51,41 +47,34 @@ static void derive_keys( ...@@ -51,41 +47,34 @@ static void derive_keys(
static const std::size_t MAC_LENGTH = 8; static const std::size_t MAC_LENGTH = 8;
} // namespace size_t aes_sha_256_cipher_mac_length(const struct olm_cipher *cipher) {
olm::CipherAesSha256::CipherAesSha256(
std::uint8_t const * kdf_info, std::size_t kdf_info_length
) : kdf_info(kdf_info), kdf_info_length(kdf_info_length) {
}
std::size_t olm::CipherAesSha256::mac_length() const {
return MAC_LENGTH; return MAC_LENGTH;
} }
size_t aes_sha_256_cipher_encrypt_ciphertext_length(
std::size_t olm::CipherAesSha256::encrypt_ciphertext_length( const struct olm_cipher *cipher, size_t plaintext_length
std::size_t plaintext_length ) {
) const {
return olm::aes_encrypt_cbc_length(plaintext_length); return olm::aes_encrypt_cbc_length(plaintext_length);
} }
size_t aes_sha_256_cipher_encrypt(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * plaintext, size_t plaintext_length,
uint8_t * ciphertext, size_t ciphertext_length,
uint8_t * output, size_t output_length
) {
auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
std::size_t olm::CipherAesSha256::encrypt( if (aes_sha_256_cipher_encrypt_ciphertext_length(cipher, plaintext_length)
std::uint8_t const * key, std::size_t key_length, < ciphertext_length) {
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const {
if (encrypt_ciphertext_length(plaintext_length) < ciphertext_length) {
return std::size_t(-1); return std::size_t(-1);
} }
struct DerivedKeys keys; struct DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH]; std::uint8_t mac[SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys); derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
olm::aes_encrypt_cbc( olm::aes_encrypt_cbc(
keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext
...@@ -102,22 +91,26 @@ std::size_t olm::CipherAesSha256::encrypt( ...@@ -102,22 +91,26 @@ std::size_t olm::CipherAesSha256::encrypt(
} }
std::size_t olm::CipherAesSha256::decrypt_max_plaintext_length( size_t aes_sha_256_cipher_decrypt_max_plaintext_length(
std::size_t ciphertext_length const struct olm_cipher *cipher,
) const { size_t ciphertext_length
) {
return ciphertext_length; return ciphertext_length;
} }
std::size_t olm::CipherAesSha256::decrypt( size_t aes_sha_256_cipher_decrypt(
std::uint8_t const * key, std::size_t key_length, const struct olm_cipher *cipher,
std::uint8_t const * input, std::size_t input_length, uint8_t const * key, size_t key_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length, uint8_t const * input, size_t input_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length uint8_t const * ciphertext, size_t ciphertext_length,
) const { uint8_t * plaintext, size_t max_plaintext_length
) {
auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
DerivedKeys keys; DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH]; std::uint8_t mac[SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys); derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
crypto_hmac_sha256( crypto_hmac_sha256(
keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac
...@@ -136,3 +129,30 @@ std::size_t olm::CipherAesSha256::decrypt( ...@@ -136,3 +129,30 @@ std::size_t olm::CipherAesSha256::decrypt(
olm::unset(keys); olm::unset(keys);
return plaintext_length; return plaintext_length;
} }
void aes_sha_256_cipher_destruct(struct olm_cipher *cipher) {
}
const cipher_ops aes_sha_256_cipher_ops = {
aes_sha_256_cipher_mac_length,
aes_sha_256_cipher_encrypt_ciphertext_length,
aes_sha_256_cipher_encrypt,
aes_sha_256_cipher_decrypt_max_plaintext_length,
aes_sha_256_cipher_decrypt,
aes_sha_256_cipher_destruct
};
} // namespace
olm_cipher *olm_cipher_aes_sha_256_init(struct olm_cipher_aes_sha_256 *cipher,
uint8_t const * kdf_info,
size_t kdf_info_length)
{
cipher->base_cipher.ops = &aes_sha_256_cipher_ops;
cipher->kdf_info = kdf_info;
cipher->kdf_info_length = kdf_info_length;
return &(cipher->base_cipher);
}
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
#include "olm/olm.h" #include "olm/olm.h"
#include "olm/session.hh" #include "olm/session.hh"
#include "olm/account.hh" #include "olm/account.hh"
#include "olm/cipher.h"
#include "olm/utility.hh" #include "olm/utility.hh"
#include "olm/base64.hh" #include "olm/base64.hh"
#include "olm/cipher.hh"
#include "olm/memory.hh" #include "olm/memory.hh"
#include <new> #include <new>
...@@ -59,15 +59,24 @@ static std::uint8_t const * from_c(void const * bytes) { ...@@ -59,15 +59,24 @@ static std::uint8_t const * from_c(void const * bytes) {
static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle"; static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle";
static const olm::CipherAesSha256 PICKLE_CIPHER( const olm_cipher *get_pickle_cipher() {
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1 static olm_cipher *cipher = NULL;
); static olm_cipher_aes_sha_256 PICKLE_CIPHER;
if (!cipher) {
cipher = olm_cipher_aes_sha_256_init(
&PICKLE_CIPHER,
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1
);
}
return cipher;
}
std::size_t enc_output_length( std::size_t enc_output_length(
size_t raw_length size_t raw_length
) { ) {
std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length); auto *cipher = get_pickle_cipher();
length += PICKLE_CIPHER.mac_length(); std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
length += cipher->ops->mac_length(cipher);
return olm::encode_base64_length(length); return olm::encode_base64_length(length);
} }
...@@ -76,8 +85,9 @@ std::uint8_t * enc_output_pos( ...@@ -76,8 +85,9 @@ std::uint8_t * enc_output_pos(
std::uint8_t * output, std::uint8_t * output,
size_t raw_length size_t raw_length
) { ) {
std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length); auto *cipher = get_pickle_cipher();
length += PICKLE_CIPHER.mac_length(); std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
length += cipher->ops->mac_length(cipher);
return output + olm::encode_base64_length(length) - length; return output + olm::encode_base64_length(length) - length;
} }
...@@ -85,13 +95,15 @@ std::size_t enc_output( ...@@ -85,13 +95,15 @@ std::size_t enc_output(
std::uint8_t const * key, std::size_t key_length, std::uint8_t const * key, std::size_t key_length,
std::uint8_t * output, size_t raw_length std::uint8_t * output, size_t raw_length
) { ) {
std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length( auto *cipher = get_pickle_cipher();
raw_length std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length(
cipher, raw_length
); );
std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length(); std::size_t length = ciphertext_length + cipher->ops->mac_length(cipher);
std::size_t base64_length = olm::encode_base64_length(length); std::size_t base64_length = olm::encode_base64_length(length);
std::uint8_t * raw_output = output + base64_length - length; std::uint8_t * raw_output = output + base64_length - length;
PICKLE_CIPHER.encrypt( cipher->ops->encrypt(
cipher,
key, key_length, key, key_length,
raw_output, raw_length, raw_output, raw_length,
raw_output, ciphertext_length, raw_output, ciphertext_length,
...@@ -112,8 +124,10 @@ std::size_t enc_input( ...@@ -112,8 +124,10 @@ std::size_t enc_input(
return std::size_t(-1); return std::size_t(-1);
} }
olm::decode_base64(input, b64_length, input); olm::decode_base64(input, b64_length, input);
std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length(); auto *cipher = get_pickle_cipher();
std::size_t result = PICKLE_CIPHER.decrypt( std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher);
std::size_t result = cipher->ops->decrypt(
cipher,
key, key_length, key, key_length,
input, enc_length, input, enc_length,
input, raw_length, input, raw_length,
......
...@@ -15,7 +15,7 @@ ...@@ -15,7 +15,7 @@
#include "olm/ratchet.hh" #include "olm/ratchet.hh"
#include "olm/message.hh" #include "olm/message.hh"
#include "olm/memory.hh" #include "olm/memory.hh"
#include "olm/cipher.hh" #include "olm/cipher.h"
#include "olm/pickle.hh" #include "olm/pickle.hh"
#include <cstring> #include <cstring>
...@@ -94,12 +94,13 @@ static void create_message_keys( ...@@ -94,12 +94,13 @@ static void create_message_keys(
static std::size_t verify_mac_and_decrypt( static std::size_t verify_mac_and_decrypt(
olm::Cipher const & cipher, olm_cipher const *cipher,
olm::MessageKey const & message_key, olm::MessageKey const & message_key,
olm::MessageReader const & reader, olm::MessageReader const & reader,
std::uint8_t * plaintext, std::size_t max_plaintext_length std::uint8_t * plaintext, std::size_t max_plaintext_length
) { ) {
return cipher.decrypt( return cipher->ops->decrypt(
cipher,
message_key.key, sizeof(message_key.key), message_key.key, sizeof(message_key.key),
reader.input, reader.input_length, reader.input, reader.input_length,
reader.ciphertext, reader.ciphertext_length, reader.ciphertext, reader.ciphertext_length,
...@@ -183,7 +184,7 @@ static std::size_t verify_mac_and_decrypt_for_new_chain( ...@@ -183,7 +184,7 @@ static std::size_t verify_mac_and_decrypt_for_new_chain(
olm::Ratchet::Ratchet( olm::Ratchet::Ratchet(
olm::KdfInfo const & kdf_info, olm::KdfInfo const & kdf_info,
Cipher const & ratchet_cipher olm_cipher const * ratchet_cipher
) : kdf_info(kdf_info), ) : kdf_info(kdf_info),
ratchet_cipher(ratchet_cipher), ratchet_cipher(ratchet_cipher),
last_error(OlmErrorCode::OLM_SUCCESS) { last_error(OlmErrorCode::OLM_SUCCESS) {
...@@ -405,11 +406,12 @@ std::size_t olm::Ratchet::encrypt_output_length( ...@@ -405,11 +406,12 @@ std::size_t olm::Ratchet::encrypt_output_length(
if (!sender_chain.empty()) { if (!sender_chain.empty()) {
counter = sender_chain[0].chain_key.index; counter = sender_chain[0].chain_key.index;
} }
std::size_t padded = ratchet_cipher.encrypt_ciphertext_length( std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length(
ratchet_cipher,
plaintext_length plaintext_length
); );
return olm::encode_message_length( return olm::encode_message_length(
counter, olm::KEY_LENGTH, padded, ratchet_cipher.mac_length() counter, olm::KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher)
); );
} }
...@@ -452,7 +454,8 @@ std::size_t olm::Ratchet::encrypt( ...@@ -452,7 +454,8 @@ std::size_t olm::Ratchet::encrypt(
create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys); create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys);
advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key); advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key);
std::size_t ciphertext_length = ratchet_cipher.encrypt_ciphertext_length( std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length(
ratchet_cipher,
plaintext_length plaintext_length
); );
std::uint32_t counter = keys.index; std::uint32_t counter = keys.index;
...@@ -467,7 +470,8 @@ std::size_t olm::Ratchet::encrypt( ...@@ -467,7 +470,8 @@ std::size_t olm::Ratchet::encrypt(
olm::store_array(writer.ratchet_key, ratchet_key.public_key); olm::store_array(writer.ratchet_key, ratchet_key.public_key);
ratchet_cipher.encrypt( ratchet_cipher->ops->encrypt(