Commit 294cf482 authored by Richard van der Hoff's avatar Richard van der Hoff
Browse files

Convert cipher.hh to plain C

parent f9139dfa
......@@ -13,30 +13,30 @@
* limitations under the License.
*/
#ifndef OLM_CIPHER_HH_
#define OLM_CIPHER_HH_
#ifndef OLM_CIPHER_H_
#define OLM_CIPHER_H_
#include <cstdint>
#include <cstddef>
#include <stdint.h>
#include <stdlib.h>
namespace olm {
#ifdef __cplusplus
extern "C" {
#endif
class Cipher {
public:
virtual ~Cipher();
struct olm_cipher;
struct cipher_ops {
/**
* Returns the length of the message authentication code that will be
* appended to the output.
*/
virtual std::size_t mac_length() const = 0;
size_t (*mac_length)(const struct olm_cipher *cipher);
/**
* Returns the length of cipher-text for a given length of plain-text.
*/
virtual std::size_t encrypt_ciphertext_length(
std::size_t plaintext_length
) const = 0;
size_t (*encrypt_ciphertext_length)(const struct olm_cipher *cipher,
size_t plaintext_length);
/*
* Encrypts the plain-text into the output buffer and authenticates the
......@@ -49,23 +49,25 @@ public:
*
* The plain-text pointers and cipher-text pointers may be the same.
*
* Returns std::size_t(-1) if the length of the cipher-text or the output
* Returns size_t(-1) if the length of the cipher-text or the output
* buffer is too small. Otherwise returns the length of the output buffer.
*/
virtual std::size_t encrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const = 0;
size_t (*encrypt)(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * plaintext, size_t plaintext_length,
uint8_t * ciphertext, size_t ciphertext_length,
uint8_t * output, size_t output_length
);
/**
* Returns the maximum length of plain-text that a given length of
* cipher-text can contain.
*/
virtual std::size_t decrypt_max_plaintext_length(
std::size_t ciphertext_length
) const = 0;
size_t (*decrypt_max_plaintext_length)(
const struct olm_cipher *cipher,
size_t ciphertext_length
);
/**
* Authenticates the input and decrypts the cipher-text into the plain-text
......@@ -77,56 +79,56 @@ public:
*
* The plain-text pointers and cipher-text pointers may be the same.
*
* Returns std::size_t(-1) if the length of the plain-text buffer is too
* Returns size_t(-1) if the length of the plain-text buffer is too
* small or if the authentication check fails. Otherwise returns the length
* of the plain text.
*/
virtual std::size_t decrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * input, std::size_t input_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) const = 0;
size_t (*decrypt)(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * input, size_t input_length,
uint8_t const * ciphertext, size_t ciphertext_length,
uint8_t * plaintext, size_t max_plaintext_length
);
/** destroy any private data associated with this cipher */
void (*destruct)(struct olm_cipher *cipher);
};
struct olm_cipher {
const struct cipher_ops *ops;
/* cipher-specific fields follow */
};
class CipherAesSha256 : public Cipher {
public:
CipherAesSha256(
std::uint8_t const * kdf_info, std::size_t kdf_info_length
);
struct olm_cipher_aes_sha_256 {
struct olm_cipher base_cipher;
virtual std::size_t mac_length() const;
virtual std::size_t encrypt_ciphertext_length(
std::size_t plaintext_length
) const;
virtual std::size_t encrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const;
virtual std::size_t decrypt_max_plaintext_length(
std::size_t ciphertext_length
) const;
virtual std::size_t decrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * input, std::size_t input_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) const;
private:
std::uint8_t const * kdf_info;
std::size_t kdf_info_length;
uint8_t const * kdf_info;
size_t kdf_info_length;
};
} // namespace
/**
* initialises a cipher type which uses AES256 for encryption and SHA256 for
* authentication.
*
* cipher: structure to be initialised
*
* kdf_info: context string for the HKDF used for deriving the AES256 key, HMAC
* key, and AES IV, from the key material passed to encrypt/decrypt. Note that
* this is NOT copied so must have a lifetime at least as long as the cipher
* instance.
*
* kdf_info_length: length of context string kdf_info
*/
struct olm_cipher *olm_cipher_aes_sha_256_init(
struct olm_cipher_aes_sha_256 *cipher,
uint8_t const * kdf_info,
size_t kdf_info_length);
#ifdef __cplusplus
} /* extern "C" */
#endif
#endif /* OLM_CIPHER_HH_ */
#endif /* OLM_CIPHER_H_ */
......@@ -17,9 +17,9 @@
#include "olm/list.hh"
#include "olm/error.h"
namespace olm {
struct olm_cipher;
class Cipher;
namespace olm {
typedef std::uint8_t SharedKey[olm::KEY_LENGTH];
......@@ -69,14 +69,14 @@ struct Ratchet {
Ratchet(
KdfInfo const & kdf_info,
Cipher const & ratchet_cipher
olm_cipher const *ratchet_cipher
);
/** A some strings identifying the application to feed into the KDF. */
KdfInfo const & kdf_info;
/** The AEAD cipher to use for encrypting messages. */
Cipher const & ratchet_cipher;
olm_cipher const *ratchet_cipher;
/** The last error that happened encrypting or decrypting a message. */
OlmErrorCode last_error;
......
......@@ -12,15 +12,11 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "olm/cipher.hh"
#include "olm/cipher.h"
#include "olm/crypto.hh"
#include "olm/memory.hh"
#include <cstring>
olm::Cipher::~Cipher() {
}
namespace {
struct DerivedKeys {
......@@ -51,41 +47,34 @@ static void derive_keys(
static const std::size_t MAC_LENGTH = 8;
} // namespace
olm::CipherAesSha256::CipherAesSha256(
std::uint8_t const * kdf_info, std::size_t kdf_info_length
) : kdf_info(kdf_info), kdf_info_length(kdf_info_length) {
}
std::size_t olm::CipherAesSha256::mac_length() const {
size_t aes_sha_256_cipher_mac_length(const struct olm_cipher *cipher) {
return MAC_LENGTH;
}
std::size_t olm::CipherAesSha256::encrypt_ciphertext_length(
std::size_t plaintext_length
) const {
size_t aes_sha_256_cipher_encrypt_ciphertext_length(
const struct olm_cipher *cipher, size_t plaintext_length
) {
return olm::aes_encrypt_cbc_length(plaintext_length);
}
size_t aes_sha_256_cipher_encrypt(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * plaintext, size_t plaintext_length,
uint8_t * ciphertext, size_t ciphertext_length,
uint8_t * output, size_t output_length
) {
auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
std::size_t olm::CipherAesSha256::encrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * plaintext, std::size_t plaintext_length,
std::uint8_t * ciphertext, std::size_t ciphertext_length,
std::uint8_t * output, std::size_t output_length
) const {
if (encrypt_ciphertext_length(plaintext_length) < ciphertext_length) {
if (aes_sha_256_cipher_encrypt_ciphertext_length(cipher, plaintext_length)
< ciphertext_length) {
return std::size_t(-1);
}
struct DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
olm::aes_encrypt_cbc(
keys.aes_key, keys.aes_iv, plaintext, plaintext_length, ciphertext
......@@ -102,22 +91,26 @@ std::size_t olm::CipherAesSha256::encrypt(
}
std::size_t olm::CipherAesSha256::decrypt_max_plaintext_length(
std::size_t ciphertext_length
) const {
size_t aes_sha_256_cipher_decrypt_max_plaintext_length(
const struct olm_cipher *cipher,
size_t ciphertext_length
) {
return ciphertext_length;
}
std::size_t olm::CipherAesSha256::decrypt(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t const * input, std::size_t input_length,
std::uint8_t const * ciphertext, std::size_t ciphertext_length,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) const {
size_t aes_sha_256_cipher_decrypt(
const struct olm_cipher *cipher,
uint8_t const * key, size_t key_length,
uint8_t const * input, size_t input_length,
uint8_t const * ciphertext, size_t ciphertext_length,
uint8_t * plaintext, size_t max_plaintext_length
) {
auto *c = reinterpret_cast<const olm_cipher_aes_sha_256 *>(cipher);
DerivedKeys keys;
std::uint8_t mac[SHA256_OUTPUT_LENGTH];
derive_keys(kdf_info, kdf_info_length, key, key_length, keys);
derive_keys(c->kdf_info, c->kdf_info_length, key, key_length, keys);
crypto_hmac_sha256(
keys.mac_key, olm::KEY_LENGTH, input, input_length - MAC_LENGTH, mac
......@@ -136,3 +129,30 @@ std::size_t olm::CipherAesSha256::decrypt(
olm::unset(keys);
return plaintext_length;
}
void aes_sha_256_cipher_destruct(struct olm_cipher *cipher) {
}
const cipher_ops aes_sha_256_cipher_ops = {
aes_sha_256_cipher_mac_length,
aes_sha_256_cipher_encrypt_ciphertext_length,
aes_sha_256_cipher_encrypt,
aes_sha_256_cipher_decrypt_max_plaintext_length,
aes_sha_256_cipher_decrypt,
aes_sha_256_cipher_destruct
};
} // namespace
olm_cipher *olm_cipher_aes_sha_256_init(struct olm_cipher_aes_sha_256 *cipher,
uint8_t const * kdf_info,
size_t kdf_info_length)
{
cipher->base_cipher.ops = &aes_sha_256_cipher_ops;
cipher->kdf_info = kdf_info;
cipher->kdf_info_length = kdf_info_length;
return &(cipher->base_cipher);
}
......@@ -15,9 +15,9 @@
#include "olm/olm.h"
#include "olm/session.hh"
#include "olm/account.hh"
#include "olm/cipher.h"
#include "olm/utility.hh"
#include "olm/base64.hh"
#include "olm/cipher.hh"
#include "olm/memory.hh"
#include <new>
......@@ -59,15 +59,24 @@ static std::uint8_t const * from_c(void const * bytes) {
static const std::uint8_t CIPHER_KDF_INFO[] = "Pickle";
static const olm::CipherAesSha256 PICKLE_CIPHER(
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
);
const olm_cipher *get_pickle_cipher() {
static olm_cipher *cipher = NULL;
static olm_cipher_aes_sha_256 PICKLE_CIPHER;
if (!cipher) {
cipher = olm_cipher_aes_sha_256_init(
&PICKLE_CIPHER,
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1
);
}
return cipher;
}
std::size_t enc_output_length(
size_t raw_length
) {
std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
length += PICKLE_CIPHER.mac_length();
auto *cipher = get_pickle_cipher();
std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
length += cipher->ops->mac_length(cipher);
return olm::encode_base64_length(length);
}
......@@ -76,8 +85,9 @@ std::uint8_t * enc_output_pos(
std::uint8_t * output,
size_t raw_length
) {
std::size_t length = PICKLE_CIPHER.encrypt_ciphertext_length(raw_length);
length += PICKLE_CIPHER.mac_length();
auto *cipher = get_pickle_cipher();
std::size_t length = cipher->ops->encrypt_ciphertext_length(cipher, raw_length);
length += cipher->ops->mac_length(cipher);
return output + olm::encode_base64_length(length) - length;
}
......@@ -85,13 +95,15 @@ std::size_t enc_output(
std::uint8_t const * key, std::size_t key_length,
std::uint8_t * output, size_t raw_length
) {
std::size_t ciphertext_length = PICKLE_CIPHER.encrypt_ciphertext_length(
raw_length
auto *cipher = get_pickle_cipher();
std::size_t ciphertext_length = cipher->ops->encrypt_ciphertext_length(
cipher, raw_length
);
std::size_t length = ciphertext_length + PICKLE_CIPHER.mac_length();
std::size_t length = ciphertext_length + cipher->ops->mac_length(cipher);
std::size_t base64_length = olm::encode_base64_length(length);
std::uint8_t * raw_output = output + base64_length - length;
PICKLE_CIPHER.encrypt(
cipher->ops->encrypt(
cipher,
key, key_length,
raw_output, raw_length,
raw_output, ciphertext_length,
......@@ -112,8 +124,10 @@ std::size_t enc_input(
return std::size_t(-1);
}
olm::decode_base64(input, b64_length, input);
std::size_t raw_length = enc_length - PICKLE_CIPHER.mac_length();
std::size_t result = PICKLE_CIPHER.decrypt(
auto *cipher = get_pickle_cipher();
std::size_t raw_length = enc_length - cipher->ops->mac_length(cipher);
std::size_t result = cipher->ops->decrypt(
cipher,
key, key_length,
input, enc_length,
input, raw_length,
......
......@@ -15,7 +15,7 @@
#include "olm/ratchet.hh"
#include "olm/message.hh"
#include "olm/memory.hh"
#include "olm/cipher.hh"
#include "olm/cipher.h"
#include "olm/pickle.hh"
#include <cstring>
......@@ -94,12 +94,13 @@ static void create_message_keys(
static std::size_t verify_mac_and_decrypt(
olm::Cipher const & cipher,
olm_cipher const *cipher,
olm::MessageKey const & message_key,
olm::MessageReader const & reader,
std::uint8_t * plaintext, std::size_t max_plaintext_length
) {
return cipher.decrypt(
return cipher->ops->decrypt(
cipher,
message_key.key, sizeof(message_key.key),
reader.input, reader.input_length,
reader.ciphertext, reader.ciphertext_length,
......@@ -183,7 +184,7 @@ static std::size_t verify_mac_and_decrypt_for_new_chain(
olm::Ratchet::Ratchet(
olm::KdfInfo const & kdf_info,
Cipher const & ratchet_cipher
olm_cipher const * ratchet_cipher
) : kdf_info(kdf_info),
ratchet_cipher(ratchet_cipher),
last_error(OlmErrorCode::OLM_SUCCESS) {
......@@ -405,11 +406,12 @@ std::size_t olm::Ratchet::encrypt_output_length(
if (!sender_chain.empty()) {
counter = sender_chain[0].chain_key.index;
}
std::size_t padded = ratchet_cipher.encrypt_ciphertext_length(
std::size_t padded = ratchet_cipher->ops->encrypt_ciphertext_length(
ratchet_cipher,
plaintext_length
);
return olm::encode_message_length(
counter, olm::KEY_LENGTH, padded, ratchet_cipher.mac_length()
counter, olm::KEY_LENGTH, padded, ratchet_cipher->ops->mac_length(ratchet_cipher)
);
}
......@@ -452,7 +454,8 @@ std::size_t olm::Ratchet::encrypt(
create_message_keys(chain_index, sender_chain[0].chain_key, kdf_info, keys);
advance_chain_key(chain_index, sender_chain[0].chain_key, sender_chain[0].chain_key);
std::size_t ciphertext_length = ratchet_cipher.encrypt_ciphertext_length(
std::size_t ciphertext_length = ratchet_cipher->ops->encrypt_ciphertext_length(
ratchet_cipher,
plaintext_length
);
std::uint32_t counter = keys.index;
......@@ -467,7 +470,8 @@ std::size_t olm::Ratchet::encrypt(
olm::store_array(writer.ratchet_key, ratchet_key.public_key);
ratchet_cipher.encrypt(
ratchet_cipher->ops->encrypt(
ratchet_cipher,
keys.key, sizeof(keys.key),
plaintext, plaintext_length,
writer.ciphertext, ciphertext_length,
......@@ -484,7 +488,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length(
) {
olm::MessageReader reader;
olm::decode_message(
reader, input, input_length, ratchet_cipher.mac_length()
reader, input, input_length,
ratchet_cipher->ops->mac_length(ratchet_cipher)
);
if (!reader.ciphertext) {
......@@ -492,7 +497,8 @@ std::size_t olm::Ratchet::decrypt_max_plaintext_length(
return std::size_t(-1);
}
return ratchet_cipher.decrypt_max_plaintext_length(reader.ciphertext_length);
return ratchet_cipher->ops->decrypt_max_plaintext_length(
ratchet_cipher, reader.ciphertext_length);
}
......@@ -502,7 +508,8 @@ std::size_t olm::Ratchet::decrypt(
) {
olm::MessageReader reader;
olm::decode_message(
reader, input, input_length, ratchet_cipher.mac_length()
reader, input, input_length,
ratchet_cipher->ops->mac_length(ratchet_cipher)
);
if (reader.version != PROTOCOL_VERSION) {
......@@ -515,7 +522,8 @@ std::size_t olm::Ratchet::decrypt(
return std::size_t(-1);
}
std::size_t max_length = ratchet_cipher.decrypt_max_plaintext_length(
std::size_t max_length = ratchet_cipher->ops->decrypt_max_plaintext_length(
ratchet_cipher,
reader.ciphertext_length
);
......
......@@ -13,7 +13,7 @@
* limitations under the License.
*/
#include "olm/session.hh"
#include "olm/cipher.hh"
#include "olm/cipher.h"
#include "olm/crypto.hh"
#include "olm/account.hh"
#include "olm/memory.hh"
......@@ -30,19 +30,27 @@ static const std::uint8_t ROOT_KDF_INFO[] = "OLM_ROOT";
static const std::uint8_t RATCHET_KDF_INFO[] = "OLM_RATCHET";
static const std::uint8_t CIPHER_KDF_INFO[] = "OLM_KEYS";
static const olm::CipherAesSha256 OLM_CIPHER(
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) -1
);
static const olm::KdfInfo OLM_KDF_INFO = {
ROOT_KDF_INFO, sizeof(ROOT_KDF_INFO) - 1,
RATCHET_KDF_INFO, sizeof(RATCHET_KDF_INFO) - 1
};
const olm_cipher *get_cipher() {
static olm_cipher *cipher;
static olm_cipher_aes_sha_256 OLM_CIPHER;
if (!cipher) {
cipher = olm_cipher_aes_sha_256_init(
&OLM_CIPHER,
CIPHER_KDF_INFO, sizeof(CIPHER_KDF_INFO) - 1
);
}
return cipher;
}
} // namespace
olm::Session::Session(
) : ratchet(OLM_KDF_INFO, OLM_CIPHER),
) : ratchet(OLM_KDF_INFO, get_cipher()),
last_error(OlmErrorCode::OLM_SUCCESS),
received_message(false) {
......@@ -149,7 +157,7 @@ std::size_t olm::Session::new_inbound_session(
olm::MessageReader message_reader;
decode_message(
message_reader, reader.message, reader.message_length,
ratchet.ratchet_cipher.mac_length()
ratchet.ratchet_cipher->ops->mac_length(ratchet.ratchet_cipher)
);
if (!message_reader.ratchet_key
......
......@@ -13,7 +13,7 @@
* limitations under the License.
*/
#include "olm/ratchet.hh"
#include "olm/cipher.hh"
#include "olm/cipher.h"
#include "unittest.hh"
......@@ -28,8 +28,9 @@ olm::KdfInfo kdf_info = {
ratchet_info, sizeof(ratchet_info) - 1
};
olm::CipherAesSha256 cipher(
message_info, sizeof(message_info) - 1
olm_cipher_aes_sha_256 cipher0;
olm_cipher *cipher = olm_cipher_aes_sha_256_init(
&cipher0, message_info, sizeof(message_info) - 1
);
std::uint8_t random_bytes[] = "0123456789ABDEF0123456789ABCDEF";
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment