diff --git a/src/crypto.cpp b/src/crypto.cpp index ae8f306..03675c5 100644 --- a/src/crypto.cpp +++ b/src/crypto.cpp @@ -4,9 +4,13 @@ #include #include #include +#include +#include + #include #include "crypto.h" +#include "utils.h" class EvpCipherCtx { EVP_CIPHER_CTX *ptr; @@ -26,44 +30,166 @@ public: }; -int AES256::encrypt(const std::vector& key, const std::vector& iv, - const std::vector& input, const size_t input_len, - std::vector& output, size_t& output_len) -{ - //auto ctx = std::unique_ptr(EVP_CIPHER_CTX_new()); +AES256_CBC::AES256_CBC(){ + encryption_key = utils::generate_random(kKeySize); +} +AES256_CBC::AES256_CBC(const std::vector& key){ + encryption_key = key; +} +AES256_CBC::AES256_CBC(std::vector&& key){ + encryption_key = key; +} + +EvpCipherCtx AES256_CBC::init(Cipher::Mode mode, + const std::vector& key, const std::vector& iv){ + if (key.size() != kKeySize){ + throw std::runtime_error("Wrong key size"); + } + if (iv.size() != kIvSize){ + throw std::runtime_error("Wrong IV size"); + } + EvpCipherCtx ctx; - if (EVP_EncryptInit(ctx.get(), EVP_aes_256_cbc(), - (key.data()), (iv.data())) == 0){ - throw std::exception(); - } + if (mode == Cipher::Mode::kEncrypt){ + if (EVP_EncryptInit(ctx.get(), EVP_aes_256_cbc(), + (key.data()), (iv.data())) == 0){ + throw std::runtime_error("EVP_EncryptInit"); + } + } else if (mode == Cipher::Mode::kDecrypt){ + if (EVP_DecryptInit(ctx.get(), EVP_aes_256_cbc(), + (key.data()), (iv.data())) == 0){ + throw std::runtime_error("EVP_DecryptInit"); + } + } else throw std::invalid_argument("Invalid Cipher mode"); + return ctx; +} +size_t AES256_CBC::process_chunk(Cipher::Mode mode, EvpCipherCtx& ctx, + std::vector::const_iterator begin, + std::vector::const_iterator end, + std::vector& output, size_t output_offset, + bool resize_in, bool resize_out) +{ + auto chunk_size = end - begin; int len; - output.reserve(input.size() + block_len); - if (1 != EVP_EncryptUpdate(ctx.get(), - (output.data()), &len, (input.data()), input.size())){ - throw std::exception(); + if(resize_in){ + // Make sure ouput is large enough to add encrypted data + padding + output.resize(output_offset + chunk_size + kIvSize); } - output_len = len; - if (1 != EVP_EncryptFinal_ex(ctx.get(), (output.data()) + len, &len)){ - throw std::exception(); + if (mode == Cipher::Mode::kEncrypt) { + if (1 != EVP_EncryptUpdate(ctx.get(), + output.data() + output_offset, &len, &*begin, chunk_size)){ + throw std::runtime_error("EVP_EncryptUpdate"); + } + } else { + if (1 != EVP_DecryptUpdate(ctx.get(), + output.data() + output_offset, &len, &*begin, chunk_size)){ + throw std::runtime_error("EVP_DecryptUpdate"); + } } - output_len += len; - output.resize(output_len); + if (resize_out) + output.resize(output_offset + len); + return len; +} +size_t AES256_CBC::process_final(Cipher::Mode mode, EvpCipherCtx& ctx, + std::vector& output, size_t output_offset, + bool resize_in, bool resize_out) +{ + int len; + if (resize_in) { + // Make sure output is large enough to add the last block + output.resize(output_offset + kIvSize); + } + if (mode == Cipher::Mode::kEncrypt) { + if (1 != EVP_EncryptFinal_ex(ctx.get(), output.data() + output_offset, &len)){ + throw std::runtime_error("EVP_EncryptFinal"); + } + } else { + if (1 != EVP_DecryptFinal_ex(ctx.get(), (output.data()) + output_offset, &len)){ + throw std::runtime_error("EVP_DecryptFinal"); + } + } + if (resize_out) + output.resize(output_offset + len); + return len; +} + +size_t AES256_CBC::process_all(Cipher::Mode mode, EvpCipherCtx& ctx, + std::vector::const_iterator begin, + std::vector::const_iterator end, + std::vector& output, size_t output_offset, + bool resize_in, bool resize_out) +{ + int len = process_chunk(mode, ctx, begin, end, + output, output_offset, resize_in, false); + len += process_final(mode, ctx, + output, output_offset + len, false, resize_out); + return len; +} + + +std::vector AES256_CBC::encrypt(std::vector& plaintext) { + return encrypt(encryption_key, utils::generate_random(kIvSize), plaintext); +} +std::vector AES256_CBC::decrypt(std::vector& ciphertext) { + return decrypt(encryption_key, ciphertext); +} + +std::vector AES256_CBC::encrypt(const std::vector& key, const std::vector& iv, + const std::vector& input) +{ + auto ctx = init(Cipher::Mode::kEncrypt, key, iv); + + // Make sure ouput is large enough to contain IV + encrypted data + padding + std::vector output(input.size() + (2 * kIvSize)); + std::copy(iv.begin(), iv.end(), output.begin()); + process_all(Cipher::Mode::kEncrypt, ctx, input.begin(), input.end(), output, kIvSize); + return output; +} + + +int AES256_CBC::encrypt(const std::vector& key, const std::vector& iv, + const std::vector& input, std::vector& output) +{ + auto ctx = init(Cipher::Mode::kEncrypt, key, iv); + process_all(Cipher::Mode::kEncrypt, ctx, input.begin(), input.end(), output, 0); return 0; } -int AES256::encrypt(const std::vector& key, const std::vector& iv, +int AES256_CBC::encrypt(const std::vector& key, const std::vector& iv, std::istream input, const size_t input_len, std::ostream output, size_t& output_len){ - auto inbuf = std::vector(input_len); - auto outbuf = std::vector(input_len + 16); + auto inbuf = std::vector(input_len); + auto outbuf = std::vector(input_len + 16); input.read(reinterpret_cast(inbuf.data()), input_len); - encrypt(key, iv, inbuf, input_len, outbuf, output_len); + encrypt(key, iv, inbuf, outbuf); - output.write(reinterpret_cast(outbuf.data()), output_len); + output.write(reinterpret_cast(outbuf.data()), outbuf.size()); return 0; } + +int AES256_CBC::decrypt(const std::vector& key, const std::vector& iv, + const std::vector& input, std::vector& output) +{ + auto ctx = init(Cipher::Mode::kDecrypt, key, iv); + process_all(Cipher::Mode::kDecrypt, ctx, input.begin(), input.end(), output, 0); + return 0; +} +std::vector AES256_CBC::decrypt(const std::vector& key, + const std::vector& ciphertext){ + std::vector iv(ciphertext.begin(), ciphertext.begin() + kIvSize); + return decrypt(key, iv, ciphertext.begin() + kIvSize, ciphertext.end()); +} +std::vector AES256_CBC::decrypt( + const std::vector& key, const std::vector& iv, + std::vector::const_iterator begin, std::vector::const_iterator end) +{ + auto ctx = init(Cipher::Mode::kDecrypt, key, iv); + std::vector output; + process_all(Cipher::Mode::kDecrypt, ctx, begin, end, output, 0); + return output; +} diff --git a/src/crypto.h b/src/crypto.h index 7b6994b..3a29b95 100644 --- a/src/crypto.h +++ b/src/crypto.h @@ -2,18 +2,113 @@ #include #include -class AES256 { +enum class CipherType {AES256_CBC}; - unsigned char key[256]; - unsigned char iv[128]; - const size_t block_len = 16; +class Cipher { +protected: + std::vector encryption_key; public: + enum Mode { + kEncrypt, + kDecrypt + }; + + virtual std::vector encrypt(std::vector& plaintext) = 0; + virtual std::vector decrypt(std::vector& ciphertext) = 0; +/* + virtual std::pair, std::vector> + encrypt_all(std::vector& plaintext) = 0; + + virtual std::vector + encrypt_all(std::vector& plaintext, + std::ostream); + + virtual std::vector + decrypt_all(std::vector& ciphertext, std::vector& key) = 0; + + virtual void encrypt_init(std::vector& key, std::vector& iv) = 0; + + virtual void encrypt_chunk() = 0; + + virtual void encrypt_fini() = 0; + + virtual void decrypt_init(std::vector& key, std::vector& iv) = 0; + + virtual void decrypt_chunk() = 0; + + virtual void decrypt_fini() = 0; +*/ +}; + +class EvpCipherCtx; +class AES256_CBC : public Cipher { + + const size_t kKeySize = 32; + const size_t kIvSize = 16; + +public: + AES256_CBC(); + AES256_CBC(const std::vector& key); + AES256_CBC(std::vector&& key); + + std::vector encrypt(std::vector& plaintext) override; + std::vector decrypt(std::vector& ciphertext) override; + + std::vector encrypt(const std::vector& key, const std::vector& iv, + const std::vector& input); + + //int encrypt(const std::vector& key, + // const std::vector& input, std::vector& output); + int encrypt(const std::vector& key, const std::vector& iv, - const std::vector& input, const size_t input_len, - std::vector& output, size_t& output_len); + const std::vector& input, std::vector& output); int encrypt(const std::vector& key, const std::vector& iv, std::istream input, const size_t input_len, std::ostream output, size_t& output_len); + + int decrypt(const std::vector& key, const std::vector& iv, + const std::vector& input, std::vector& output); + + + std::pair, std::vector> + encrypt(std::vector plaintext); + + std::vector decrypt(const std::vector& key, + const std::vector& ciphertext); + + std::vector decrypt( + const std::vector& key, const std::vector& iv, + std::vector::const_iterator begin, std::vector::const_iterator end); + +private: + EvpCipherCtx init(Cipher::Mode mode, + const std::vector& key, const std::vector& iv); + + size_t process_chunk(Cipher::Mode mode, EvpCipherCtx& ctx, + std::vector::const_iterator begin, + std::vector::const_iterator end, + std::vector& output, size_t output_offset, + bool resize_in = true, bool resize_out = true); + + size_t process_final(Cipher::Mode mode, EvpCipherCtx& ctx, + std::vector& output, size_t output_offset, + bool resize_in = true, bool resize_out = true); + + size_t process_all(Cipher::Mode mode, EvpCipherCtx& ctx, + std::vector::const_iterator begin, + std::vector::const_iterator end, + std::vector& output, size_t output_offset, + bool resize_in = true, bool resize_out = true); }; + +static std::unique_ptr createCipher(CipherType type){ + switch(type){ + case CipherType::AES256_CBC: + return std::make_unique(); + default: + throw std::runtime_error("Unknown cipher"); + } +} + diff --git a/test/crypto-test.cpp b/test/crypto-test.cpp index fd5bb7a..9f32f4e 100644 --- a/test/crypto-test.cpp +++ b/test/crypto-test.cpp @@ -29,15 +29,64 @@ const std::vector test1_enc { 0x55, 0xd2, 0x04, 0x73, 0x16, 0x39, 0xc7, 0x6a, 0xd3, 0x61, 0x2c, 0x22, 0x59, 0x25, 0xa6, 0x20 }; -TEST(encryptTest, test1){ - const std::vector input(test1_str.begin(), test1_str.end()); - size_t output_len = input.size() + 16; +TEST(CryptoTest, encrypt1){ + const std::vector plaintext(test1_str.begin(), test1_str.end()); + size_t output_len = plaintext.size() + 16; std::vector output(output_len); const std::vector key(test1_key.begin(), test1_key.end()); const std::vector iv(test1_iv.begin(), test1_iv.end()); - AES256 a; - a.encrypt(key, iv, input, input.size(), output, output_len); + AES256_CBC a; + a.encrypt(key, iv, plaintext, output); EXPECT_EQ(test1_enc, output); } + +TEST(CryptoTest, encrypt2){ + const std::vector plaintext(test1_str.begin(), test1_str.end()); + //size_t output_len = plaintext.size() + 16; + //std::vector output(output_len); + const std::vector key(test1_key.begin(), test1_key.end()); + const std::vector iv(test1_iv.begin(), test1_iv.end()); + + AES256_CBC a; + auto output = a.encrypt(key, iv, plaintext); + auto temp = iv; + temp.insert(temp.end(), test1_enc.begin(), test1_enc.end()); + + EXPECT_EQ(temp.size(), output.size()); + EXPECT_EQ(std::vector(temp.begin() + 50, temp.end()), + std::vector(output.begin() + 50, output.end())); + EXPECT_EQ(temp, output); +} + +TEST(CryptoTest, decrypt1){ + const std::vector plaintext(test1_str.begin(), test1_str.end()); + size_t output_len = test1_enc.size(); + std::vector output(output_len); + const std::vector key(test1_key.begin(), test1_key.end()); + const std::vector iv(test1_iv.begin(), test1_iv.end()); + + AES256_CBC a; + a.decrypt(key, iv, test1_enc, output); + + EXPECT_EQ(plaintext, output); +} +TEST(CryptoTest, decrypt2){ + const std::vector plaintext(test1_str.begin(), test1_str.end()); + size_t output_len = test1_enc.size(); + const std::vector key(test1_key.begin(), test1_key.end()); + const std::vector iv(test1_iv.begin(), test1_iv.end()); + // constructs encrypted input (iv + encrypted_data) + auto input = std::vector(iv); + input.insert(input.end(),test1_enc.begin(), test1_enc.end()); + + AES256_CBC a; + auto output = a.decrypt(key, input); + + EXPECT_EQ(plaintext, output); + + AES256_CBC b(key); + output = b.decrypt(input); + EXPECT_EQ(plaintext, output); +}