implemented aes key validation
This commit is contained in:
parent
d32f7d4b89
commit
baca2fb4d3
2
.vscode/launch.json
vendored
2
.vscode/launch.json
vendored
@ -7,7 +7,7 @@
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/bin/tcp-chat",
|
||||
"windows": { "program": "${workspaceFolder}/bin/tcp-chat.exe" },
|
||||
"args": [ "-l" ],
|
||||
// "args": [ "-l" ],
|
||||
"preLaunchTask": "build_exec_dbg",
|
||||
"stopAtEntry": false,
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
|
||||
@ -42,8 +42,8 @@ void ClientCLI_destroy(ClientCLI* self){
|
||||
HashMap_destroy(&self->servers_addr_id_map);
|
||||
}
|
||||
void ClientCLI_construct(ClientCLI* self){
|
||||
self->client = NULL;
|
||||
self->db = NULL;
|
||||
memset(self, 0, sizeof(*self));
|
||||
pthread_mutex_init(&self->servers_cache_mutex, NULL);
|
||||
}
|
||||
|
||||
Result(void) ClientCLI_run(ClientCLI* self) {
|
||||
@ -237,7 +237,8 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){
|
||||
|
||||
for(u32 id = 0; id < servers_count; id++){
|
||||
Server* row = &List_index(self->servers_cache_list, Server, id);
|
||||
printf("[%02u] "FMT_str"\n", id, row->name_len, row->name);
|
||||
printf("[%02u] "FMT_str" "FMT_str"\n",
|
||||
id, row->address_len, row->address, row->name_len, row->name);
|
||||
}
|
||||
|
||||
char buf[32];
|
||||
@ -319,7 +320,6 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){
|
||||
try(self->db, p, idb_open(user_db_dir, user_data_key));
|
||||
|
||||
// load servers table
|
||||
pthread_mutex_init(&self->servers_cache_mutex, NULL);
|
||||
try(self->db_servers_table, p, idb_getOrCreateTable(self->db, STR("servers"), sizeof(Server)));
|
||||
// load whole table to list
|
||||
try(u64 servers_count, u, idb_getRowCount(self->db_servers_table));
|
||||
|
||||
@ -9,7 +9,7 @@ Result(void) run_ClientMode(cstr config_path) {
|
||||
ClientCLI_construct(&client);
|
||||
Defer(ClientCLI_destroy(&client));
|
||||
// start infinite loop on main thread
|
||||
try_fatal_void(ClientCLI_run(&client));
|
||||
try_void(ClientCLI_run(&client));
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
@ -66,7 +66,7 @@ Array(u8) Client_getUserDataKey(Client* client){
|
||||
Result(void) Client_getServerName(Client* self, str* out_name){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->server_connection != NULL);
|
||||
try_assert(self->server_connection != NULL && "didn't connect to a server yet");
|
||||
|
||||
*out_name = self->server_connection->server_name;
|
||||
|
||||
@ -76,7 +76,7 @@ Result(void) Client_getServerName(Client* self, str* out_name){
|
||||
Result(void) Client_getServerDescription(Client* self, str* out_desc){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->server_connection != NULL);
|
||||
try_assert(self->server_connection != NULL && "didn't connect to a server yet");
|
||||
|
||||
*out_desc = self->server_connection->server_description;
|
||||
|
||||
@ -86,7 +86,7 @@ Result(void) Client_getServerDescription(Client* self, str* out_desc){
|
||||
Result(void) Client_register(Client* self, u64* out_user_id){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->server_connection != NULL);
|
||||
try_assert(self->server_connection != NULL && "didn't connect to a server yet");
|
||||
|
||||
PacketHeader req_head, res_head;
|
||||
RegisterRequest req;
|
||||
@ -104,7 +104,7 @@ Result(void) Client_register(Client* self, u64* out_user_id){
|
||||
Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->server_connection != NULL);
|
||||
try_assert(self->server_connection != NULL && "didn't connect to a server yet");
|
||||
|
||||
PacketHeader req_head, res_head;
|
||||
LoginRequest req;
|
||||
|
||||
@ -13,6 +13,12 @@ static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){
|
||||
*src = Array_sliceFrom(*src, size);
|
||||
}
|
||||
|
||||
static void __calcKeyCheckSum(Array(u8) key, void* dst){
|
||||
br_sha256_context sha_ctx;
|
||||
br_sha256_init(&sha_ctx);
|
||||
br_sha256_update(&sha_ctx, key.data, key.size);
|
||||
br_sha256_out(&sha_ctx, dst);
|
||||
}
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// AESBlockEncryptor //
|
||||
@ -21,19 +27,19 @@ static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){
|
||||
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr,
|
||||
Array(u8) key, const br_block_cbcenc_class* enc_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->enc_class = enc_class;
|
||||
ptr->enc_class->init((void*)ptr->enc_keys, key.data, key.size);
|
||||
AESBlockEncryptor_changeKey(ptr, key);
|
||||
|
||||
ptr->rng_ctx.vtable = &br_hmac_drbg_vtable;
|
||||
rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable);
|
||||
|
||||
}
|
||||
|
||||
void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->enc_class->init((void*)ptr->enc_keys, key.data, key.size);
|
||||
__calcKeyCheckSum(key, ptr->key_checksum);
|
||||
}
|
||||
|
||||
Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr,
|
||||
@ -50,6 +56,7 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr,
|
||||
|
||||
EncryptedBlockHeader header;
|
||||
memset(&header, 0, sizeof(header));
|
||||
memcpy(header.key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE);
|
||||
header.padding_size = (16 - src.size % 16) % 16;
|
||||
// write header to buffer
|
||||
memcpy(ptr->buf, &header, sizeof(header));
|
||||
@ -89,16 +96,15 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr,
|
||||
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr,
|
||||
Array(u8) key, const br_block_cbcdec_class* dec_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->dec_class = dec_class;
|
||||
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
|
||||
AESBlockDecryptor_changeKey(ptr, key);
|
||||
}
|
||||
|
||||
void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
|
||||
__calcKeyCheckSum(key, ptr->key_checksum);
|
||||
}
|
||||
|
||||
Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr,
|
||||
@ -118,7 +124,14 @@ Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr,
|
||||
__Array_readNext((void*)&header, &src, sizeof(header));
|
||||
// decrypt header
|
||||
ptr->dec_class->run((void*)ptr->dec_keys, ptr->iv, &header, sizeof(header));
|
||||
|
||||
// validate decrypted data
|
||||
if(memcmp(header.key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE) != 0){
|
||||
Return RESULT_ERROR("decrypted data is invalid or key is wrong", false);
|
||||
}
|
||||
|
||||
// size of decrypted data without padding
|
||||
try_assert(src.size >= header.padding_size && "invalid padding size");
|
||||
u32 decrypted_size = src.size - header.padding_size;
|
||||
src.size = decrypted_size;
|
||||
|
||||
@ -153,23 +166,21 @@ Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr,
|
||||
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr,
|
||||
Array(u8) key, const br_block_ctr_class* ctr_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->ctr_class = ctr_class;
|
||||
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
|
||||
AESStreamEncryptor_changeKey(ptr, key);
|
||||
ptr->block_counter = 0;
|
||||
|
||||
br_hmac_drbg_context rng_ctx;
|
||||
rng_ctx.vtable = &br_hmac_drbg_vtable;
|
||||
rng_init_sha256_seedFromSystem(&rng_ctx.vtable);
|
||||
br_hmac_drbg_generate(&rng_ctx, ptr->iv, __AES_STREAM_IV_SIZE);
|
||||
|
||||
ptr->block_counter = 0;
|
||||
}
|
||||
|
||||
void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
|
||||
__calcKeyCheckSum(key, ptr->key_checksum);
|
||||
}
|
||||
|
||||
Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr,
|
||||
@ -178,17 +189,27 @@ Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr,
|
||||
Deferral(4);
|
||||
|
||||
u32 encrypted_size = src.size;
|
||||
// if it is the beginning of the stream, write IV
|
||||
// if it is the beginning of the stream,
|
||||
if(ptr->block_counter == 0){
|
||||
// write IV generated during initialization
|
||||
__Array_writeNext(&dst, ptr->iv, __AES_STREAM_IV_SIZE);
|
||||
encrypted_size = AESStreamEncryptor_calcDstSize(encrypted_size);
|
||||
|
||||
// encrypt checksum
|
||||
u8 key_checksum[__AES_BLOCK_KEY_CHECKSUM_SIZE];
|
||||
memcpy(key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE);
|
||||
ptr->block_counter = ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->iv, ptr->block_counter,
|
||||
key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE);
|
||||
// write checksum to dst
|
||||
__Array_writeNext(&dst, key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE);
|
||||
}
|
||||
try_assert(dst.size >= encrypted_size);
|
||||
|
||||
// encrypt full buffers
|
||||
while(src.size > __AES_BUFFER_SIZE){
|
||||
__Array_readNext(ptr->buf, &src, __AES_BUFFER_SIZE);
|
||||
ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->block_counter = ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->iv, ptr->block_counter,
|
||||
ptr->buf, __AES_BUFFER_SIZE);
|
||||
__Array_writeNext(&dst, ptr->buf, __AES_BUFFER_SIZE);
|
||||
@ -197,13 +218,12 @@ Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr,
|
||||
// encrypt remaining data
|
||||
if(src.size > 0){
|
||||
memcpy(ptr->buf, src.data, src.size);
|
||||
ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->block_counter = ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->iv, ptr->block_counter,
|
||||
ptr->buf, src.size);
|
||||
memcpy(dst.data, ptr->buf, src.size);
|
||||
}
|
||||
|
||||
ptr->block_counter++;
|
||||
Return RESULT_VALUE(u, encrypted_size);
|
||||
}
|
||||
|
||||
@ -215,11 +235,8 @@ Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr,
|
||||
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr,
|
||||
Array(u8) key, const br_block_ctr_class* ctr_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->ctr_class = ctr_class;
|
||||
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
|
||||
|
||||
AESStreamDecryptor_changeKey(ptr, key);
|
||||
ptr->block_counter = 0;
|
||||
}
|
||||
|
||||
@ -227,6 +244,7 @@ void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
|
||||
__calcKeyCheckSum(key, ptr->key_checksum);
|
||||
}
|
||||
|
||||
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
|
||||
@ -234,9 +252,22 @@ Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
|
||||
{
|
||||
Deferral(4);
|
||||
|
||||
// if it is the beginning of the stream, read IV
|
||||
// if it is the beginning of the stream
|
||||
if(ptr->block_counter == 0){
|
||||
// read random IV
|
||||
__Array_readNext(ptr->iv, &src, __AES_STREAM_IV_SIZE);
|
||||
|
||||
// read checksum
|
||||
u8 key_checksum[__AES_BLOCK_KEY_CHECKSUM_SIZE];
|
||||
__Array_readNext(key_checksum, &src, __AES_BLOCK_KEY_CHECKSUM_SIZE);
|
||||
// decrypt checksum
|
||||
ptr->block_counter = ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->iv, ptr->block_counter,
|
||||
key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE);
|
||||
// validate decrypted data
|
||||
if(memcmp(key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE) != 0){
|
||||
Return RESULT_ERROR("decrypted data is invalid or key is wrong", false);
|
||||
}
|
||||
}
|
||||
// size without IV
|
||||
u32 decrypted_size = src.size;
|
||||
@ -245,7 +276,7 @@ Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
|
||||
// decrypt full buffers
|
||||
while(src.size > __AES_BUFFER_SIZE){
|
||||
__Array_readNext(ptr->buf, &src, __AES_BUFFER_SIZE);
|
||||
ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->block_counter = ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->iv, ptr->block_counter,
|
||||
ptr->buf, __AES_BUFFER_SIZE);
|
||||
__Array_writeNext(&dst, ptr->buf, __AES_BUFFER_SIZE);
|
||||
@ -254,12 +285,11 @@ Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
|
||||
// decrypt remaining data
|
||||
if(src.size > 0){
|
||||
memcpy(ptr->buf, src.data, src.size);
|
||||
ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->block_counter = ptr->ctr_class->run((void*)ptr->ctr_keys,
|
||||
ptr->iv, ptr->block_counter,
|
||||
ptr->buf, src.size);
|
||||
memcpy(dst.data, ptr->buf, src.size);
|
||||
}
|
||||
|
||||
ptr->block_counter++;
|
||||
Return RESULT_VALUE(u, decrypted_size);
|
||||
}
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
#include "tlibc/collections/Array.h"
|
||||
#include "tlibc/errors.h"
|
||||
#include "tlibc/magic.h"
|
||||
#include "bearssl_block.h"
|
||||
#include "cryptography.h"
|
||||
|
||||
@ -15,7 +16,9 @@
|
||||
#define AESStream_DEFAULT_CLASS (&br_aes_big_ctr_vtable)
|
||||
|
||||
|
||||
#define __AES_BLOCK_KEY_CHECKSUM_SIZE br_sha256_SIZE
|
||||
typedef struct EncryptedBlockHeader {
|
||||
u8 key_checksum[__AES_BLOCK_KEY_CHECKSUM_SIZE];
|
||||
u8 padding_size;
|
||||
} ATTRIBUTE_ALIGNED(16) EncryptedBlockHeader;
|
||||
|
||||
@ -32,6 +35,7 @@ typedef struct AESBlockEncryptor {
|
||||
u8 enc_keys[sizeof(br_aes_big_cbcenc_keys)];
|
||||
u8 buf[__AES_BUFFER_SIZE];
|
||||
u8 iv[__AES_BLOCK_IV_SIZE];
|
||||
u8 key_checksum[__AES_BLOCK_KEY_CHECKSUM_SIZE];
|
||||
br_hmac_drbg_context rng_ctx;
|
||||
} AESBlockEncryptor;
|
||||
|
||||
@ -59,6 +63,7 @@ typedef struct AESBlockDecryptor {
|
||||
u8 dec_keys[sizeof(br_aes_big_cbcdec_keys)];
|
||||
u8 buf[__AES_BUFFER_SIZE];
|
||||
u8 iv[__AES_BLOCK_IV_SIZE];
|
||||
u8 key_checksum[__AES_BLOCK_KEY_CHECKSUM_SIZE];
|
||||
} AESBlockDecryptor;
|
||||
|
||||
/// @param key supported sizes: 16, 24, 32
|
||||
@ -85,6 +90,7 @@ typedef struct AESStreamEncryptor {
|
||||
u8 ctr_keys[sizeof(br_aes_big_ctr_keys)];
|
||||
u8 buf[__AES_BUFFER_SIZE];
|
||||
u8 iv[__AES_STREAM_IV_SIZE];
|
||||
u8 key_checksum[__AES_BLOCK_KEY_CHECKSUM_SIZE];
|
||||
u32 block_counter;
|
||||
} AESStreamEncryptor;
|
||||
|
||||
@ -97,7 +103,7 @@ void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const
|
||||
void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key);
|
||||
|
||||
/// use this only at the beginning of the stream
|
||||
#define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_STREAM_IV_SIZE)
|
||||
#define AESStreamEncryptor_calcDstSize(src_size) (__AES_STREAM_IV_SIZE + __AES_BLOCK_KEY_CHECKSUM_SIZE + src_size)
|
||||
|
||||
/// @brief If ptr->block_counter == 0, writes random IV to `dst`. After that writes encrypted data to dst.
|
||||
/// @param src array of any size
|
||||
@ -114,6 +120,7 @@ typedef struct AESStreamDecryptor {
|
||||
u8 ctr_keys[sizeof(br_aes_big_ctr_keys)];
|
||||
u8 buf[__AES_BUFFER_SIZE];
|
||||
u8 iv[__AES_STREAM_IV_SIZE];
|
||||
u8 key_checksum[__AES_BLOCK_KEY_CHECKSUM_SIZE];
|
||||
u32 block_counter;
|
||||
} AESStreamDecryptor;
|
||||
|
||||
|
||||
78
src/db/idb.c
78
src/db/idb.c
@ -5,12 +5,18 @@
|
||||
#include "cryptography/AES.h"
|
||||
#include <pthread.h>
|
||||
|
||||
static const char KEY_CHALLENGE_PLAIN[16] = "key is correct!";
|
||||
#define KEY_CHALLENGE_PLAIN_SIZE sizeof(KEY_CHALLENGE_PLAIN)
|
||||
#define KEY_CHALLENGE_CIPHER_SIZE AESBlockEncryptor_calcDstSize(KEY_CHALLENGE_PLAIN_SIZE)
|
||||
|
||||
typedef struct TableFileHeader {
|
||||
Magic32 magic;
|
||||
u16 version;
|
||||
bool _dirty_bit;
|
||||
bool encrypted;
|
||||
u32 row_size;
|
||||
/* encrypted KEY_CHALLENGE_PLAIN */
|
||||
u8 key_challenge[KEY_CHALLENGE_CIPHER_SIZE];
|
||||
} ATTRIBUTE_ALIGNED(256) TableFileHeader;
|
||||
|
||||
typedef struct Table {
|
||||
@ -105,16 +111,12 @@ static Result(bool) Table_getDirtyBit(Table* t){
|
||||
Return RESULT_VALUE(i, t->header._dirty_bit);
|
||||
}
|
||||
|
||||
static u32 Table_calcEncryptedRowSize(Table* t){
|
||||
return AESBlockEncryptor_calcDstSize(t->header.row_size);
|
||||
}
|
||||
|
||||
static Result(void) Table_calculateRowCount(Table* t){
|
||||
Deferral(4);
|
||||
try(i64 file_size, i, file_getSize(t->table_file));
|
||||
i64 data_size = file_size - sizeof(t->header);
|
||||
i64 row_size_in_file = t->header.encrypted
|
||||
? Table_calcEncryptedRowSize(t)
|
||||
? AESBlockEncryptor_calcDstSize(t->header.row_size)
|
||||
: t->header.row_size;
|
||||
if(data_size % row_size_in_file != 0){
|
||||
//TODO: fix table instead of trowing error
|
||||
@ -137,7 +139,11 @@ static Result(void) Table_validateHeader(Table* t){
|
||||
t->table_file_path.data);
|
||||
}
|
||||
|
||||
//TODO: check version
|
||||
if(t->header.version != IDB_VERSION){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table file '%s' was created for IDBv%u, which is incompatible with v%u",
|
||||
t->table_file_path.data, t->header.version, (u32)IDB_VERSION);
|
||||
}
|
||||
|
||||
try(bool dirty_bit, i, Table_getDirtyBit(t));
|
||||
if(dirty_bit){
|
||||
@ -151,20 +157,40 @@ static Result(void) Table_validateHeader(Table* t){
|
||||
}
|
||||
|
||||
static Result(void) Table_validateEncryption(Table* t){
|
||||
Deferral(1);
|
||||
|
||||
bool db_encrypted = t->db->aes_key.size != 0;
|
||||
if(t->header.encrypted && !db_encrypted){
|
||||
return RESULT_ERROR_FMT("Table '%s' is encrypted, but db->aes_key is not set."
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table '%s' is encrypted, but encryption key is not set."
|
||||
"Database '%s' is encrypted and must have not-null encryption key.",
|
||||
t->name.data, t->db->db_dir.data);
|
||||
}
|
||||
|
||||
if(!t->header.encrypted && db_encrypted){
|
||||
return RESULT_ERROR_FMT("table '%s' is not encrypted, but db->aes_key is set."
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table '%s' is not encrypted, but encryption key is set."
|
||||
"Do not set encryption key for not encrypted database '%s'.",
|
||||
t->name.data, t->db->db_dir.data);
|
||||
}
|
||||
|
||||
return RESULT_VOID;
|
||||
// validate aes encryption key
|
||||
if(t->header.encrypted){
|
||||
try_void(
|
||||
AESBlockDecryptor_decrypt(
|
||||
&t->dec,
|
||||
Array_construct_size(t->header.key_challenge, KEY_CHALLENGE_CIPHER_SIZE),
|
||||
t->enc_buf
|
||||
)
|
||||
);
|
||||
if(memcmp(t->enc_buf.data, KEY_CHALLENGE_PLAIN, KEY_CHALLENGE_PLAIN_SIZE) != 0){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Encryption key for table '%s' is wrong",
|
||||
t->name.data);
|
||||
}
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) Table_validateRowSize(Table* t, u32 row_size){
|
||||
@ -177,6 +203,7 @@ static Result(void) Table_validateRowSize(Table* t, u32 row_size){
|
||||
return RESULT_VOID;
|
||||
}
|
||||
|
||||
|
||||
Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){
|
||||
Deferral(16);
|
||||
try_assert(aes_key.size == 0 || aes_key.size == 16 || aes_key.size == 24 || aes_key.size == 32);
|
||||
@ -190,7 +217,6 @@ Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){
|
||||
|
||||
if(aes_key.size != 0){
|
||||
db->aes_key = Array_copy(aes_key);
|
||||
//TODO: validate aes encryption key
|
||||
}
|
||||
|
||||
db->db_dir = str_copy(db_dir);
|
||||
@ -242,14 +268,23 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_s
|
||||
t->changes_file_path = str_from_cstr(
|
||||
strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-changes"));
|
||||
|
||||
bool table_exists = file_exists(t->table_file_path.data);
|
||||
bool table_file_exists = file_exists(t->table_file_path.data);
|
||||
|
||||
// open or create file with table data
|
||||
try(t->table_file, p, file_openOrCreateReadWrite(t->table_file_path.data));
|
||||
// open or create file with backups of updated rows
|
||||
try(t->changes_file, p, file_openOrCreateReadWrite(t->changes_file_path.data));
|
||||
|
||||
if(table_exists){
|
||||
// init encryptor and decryptor now to use them in table header validation/creation
|
||||
if(db->aes_key.size != 0) {
|
||||
AESBlockEncryptor_construct(&t->enc, db->aes_key, AESBlockEncryptor_DEFAULT_CLASS);
|
||||
AESBlockDecryptor_construct(&t->dec, db->aes_key, AESBlockDecryptor_DEFAULT_CLASS);
|
||||
u32 row_size_in_file = AESBlockEncryptor_calcDstSize(row_size);
|
||||
t->enc_buf = Array_alloc_size(row_size_in_file);
|
||||
}
|
||||
|
||||
// init header
|
||||
if(table_file_exists){
|
||||
// read table file
|
||||
try_void(Table_readHeader(t));
|
||||
try_void(Table_validateHeader(t));
|
||||
@ -260,20 +295,21 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_s
|
||||
else {
|
||||
// create table file
|
||||
t->header.magic.n = TABLE_FILE_MAGIC.n;
|
||||
t->header.row_size = row_size;
|
||||
t->header.version = IDB_VERSION;
|
||||
t->header.encrypted = db->aes_key.size != 0;
|
||||
t->header._dirty_bit = false;
|
||||
t->header.row_size = row_size;
|
||||
memset(t->header.key_challenge, 0, KEY_CHALLENGE_CIPHER_SIZE);
|
||||
try_void(
|
||||
AESBlockEncryptor_encrypt(
|
||||
&t->enc,
|
||||
Array_construct_size((void*)KEY_CHALLENGE_PLAIN, KEY_CHALLENGE_PLAIN_SIZE),
|
||||
Array_construct_size(t->header.key_challenge, KEY_CHALLENGE_CIPHER_SIZE)
|
||||
)
|
||||
);
|
||||
try_void(Table_writeHeader(t));
|
||||
}
|
||||
|
||||
if(t->header.encrypted){
|
||||
AESBlockEncryptor_construct(&t->enc, db->aes_key, AESBlockEncryptor_DEFAULT_CLASS);
|
||||
AESBlockDecryptor_construct(&t->dec, db->aes_key, AESBlockDecryptor_DEFAULT_CLASS);
|
||||
u32 row_size_in_file = Table_calcEncryptedRowSize(t);
|
||||
t->enc_buf = Array_alloc_size(row_size_in_file);
|
||||
}
|
||||
|
||||
if(!HashMap_tryPush(&db->tables_map, t->name, &t)){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table '%s' is already open",
|
||||
@ -331,7 +367,7 @@ Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count){
|
||||
try_stderrcode(pthread_mutex_lock(&t->mutex));
|
||||
Defer(pthread_mutex_unlock(&t->mutex));
|
||||
|
||||
if(id + count >= t->row_count){
|
||||
if(id + count > t->row_count){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Can't update "FMT_u64" rows at index "FMT_u64
|
||||
" because table '%s' has only "FMT_u64" rows",
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
#include "tlibc/errors.h"
|
||||
|
||||
#define IDB_VERSION 1
|
||||
#define IDB_VERSION 2
|
||||
#define IDB_AES_KEY_SIZE 32
|
||||
|
||||
typedef struct IncrementalDB IncrementalDB;
|
||||
|
||||
Loading…
Reference in New Issue
Block a user