From 4add849b9e191faa9304a67b80147794e01ae35c Mon Sep 17 00:00:00 2001 From: Timerix Date: Thu, 13 Nov 2025 02:31:00 +0500 Subject: [PATCH] implemented idb encryption --- src/db/idb.c | 154 +++++++++++++++++++++++++++++++++++++++++---------- src/db/idb.h | 5 +- 2 files changed, 127 insertions(+), 32 deletions(-) diff --git a/src/db/idb.c b/src/db/idb.c index 2e22c78..71cffbc 100644 --- a/src/db/idb.c +++ b/src/db/idb.c @@ -2,12 +2,14 @@ #include "magic.h" #include "tlibc/filesystem.h" #include "tlibc/collections/HashMap.h" +#include "cryptography/AES.h" #include typedef struct TableFileHeader { Magic32 magic; u16 version; bool _dirty_bit; + bool encrypted; u32 row_size; } ATTRIBUTE_ALIGNED(256) TableFileHeader; @@ -21,10 +23,14 @@ typedef struct Table { FILE* changes_file; pthread_mutex_t mutex; u64 row_count; + AESBlockEncryptor enc; + AESBlockDecryptor dec; + Array(u8) enc_buf; } Table; typedef struct IncrementalDB { str db_dir; + Array(u8) aes_key; HashMap(Table**) tables_map; pthread_mutex_t mutex; } IncrementalDB; @@ -39,16 +45,17 @@ void Table_close(Table* t){ free(t->table_file_path.data); free(t->changes_file_path.data); pthread_mutex_destroy(&t->mutex); + free(t->enc_buf.data); free(t); } // element destructor for HashMap(Table*) -void TablePtr_destroy(void* t_ptr_ptr){ +static void TablePtr_free(void* t_ptr_ptr){ Table_close(*(Table**)t_ptr_ptr); } /// @param name must be null-terminated -Result(void) validateTableName(str name){ +static Result(void) validateTableName(str name){ char forbidden_characters[] = { '/', '\\', ':', ';', '?', '"', '\'', '\n', '\r', '\t'}; for(u32 i = 0; i < ARRAY_LEN(forbidden_characters); i++) { char c = forbidden_characters[i]; @@ -62,7 +69,7 @@ Result(void) validateTableName(str name){ return RESULT_VOID; } -Result(void) Table_readHeader(Table* t){ +static Result(void) Table_readHeader(Table* t){ Deferral(4); // seek for start of the file try_void(file_seek(t->table_file, 0, SeekOrigin_Start)); @@ -71,7 +78,7 @@ Result(void) Table_readHeader(Table* t){ Return RESULT_VOID; } -Result(void) Table_writeHeader(Table* t){ +static Result(void) Table_writeHeader(Table* t){ Deferral(4); // seek for start of the file try_void(file_seek(t->table_file, 0, SeekOrigin_Start)); @@ -80,35 +87,42 @@ Result(void) Table_writeHeader(Table* t){ Return RESULT_VOID; } -Result(void) Table_setDirtyBit(Table* t, bool val){ +static Result(void) Table_setDirtyBit(Table* t, bool val){ Deferral(4); t->header._dirty_bit = val; try_void(Table_writeHeader(t)); Return RESULT_VOID; } -Result(bool) Table_getDirtyBit(Table* t){ +static Result(bool) Table_getDirtyBit(Table* t){ Deferral(4); try_void(Table_readHeader(t)); Return RESULT_VALUE(i, t->header._dirty_bit); } -Result(void) Table_calculateRowCount(Table* t){ +static u32 Table_calcEncryptedRowSize(Table* t){ + return AESBlockEncryptor_calcDstSize(t->header.row_size); +} + +static Result(void) Table_calculateRowCount(Table* t){ Deferral(4); try(i64 file_size, i, file_getSize(t->table_file)); i64 data_size = file_size - sizeof(t->header); - if(data_size % t->header.row_size != 0){ + i64 row_size_in_file = t->header.encrypted + ? Table_calcEncryptedRowSize(t) + : t->header.row_size; + if(data_size % row_size_in_file != 0){ //TODO: fix table instead of trowing error Return RESULT_ERROR_FMT( - "Table '%s' has invalid size. Last row is incomplete", + "Table '%s' has invalid size. Last row is incomplete.", t->name.data); } - t->row_count = data_size / t->header.row_size; + t->row_count = data_size / row_size_in_file; Return RESULT_VOID; } -Result(void) Table_validateHeader(Table* t){ +static Result(void) Table_validateHeader(Table* t){ Deferral(4); if(t->header.magic.n != TABLE_FILE_MAGIC.n || t->header.row_size == 0) @@ -131,7 +145,24 @@ Result(void) Table_validateHeader(Table* t){ Return RESULT_VOID; } -Result(void) Table_validateRowSize(Table* t, u32 row_size){ +static Result(void) Table_validateEncryption(Table* t){ + bool db_encrypted = t->db->aes_key.size != 0; + if(t->header.encrypted && !db_encrypted){ + return RESULT_ERROR_FMT("Table '%s' is encrypted, but db->aes_key is not set." + "Database '%s' is encrypted and must have not-null encryption key.", + t->name.data, t->db->db_dir.data); + } + + if(!t->header.encrypted && db_encrypted){ + return RESULT_ERROR_FMT("table '%s' is not encrypted, but db->aes_key is set." + "Do not set encryption key for not encrypted database '%s'.", + t->name.data, t->db->db_dir.data); + } + + return RESULT_VOID; +} + +static Result(void) Table_validateRowSize(Table* t, u32 row_size){ if(row_size != t->header.row_size){ ResultVar(void) error_result = RESULT_ERROR_FMT( "Requested row size (%u) doesn't match saved row size (%u)", @@ -142,18 +173,24 @@ Result(void) Table_validateRowSize(Table* t, u32 row_size){ return RESULT_VOID; } -Result(IncrementalDB*) idb_open(str db_dir){ +Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){ Deferral(16); + try_assert(aes_key.size == 0 || aes_key.size == 16 || aes_key.size == 24 || aes_key.size == 32); + IncrementalDB* db = (IncrementalDB*)malloc(sizeof(IncrementalDB)); + // value of *db must be set to zero or behavior of idb_close will be undefined + memset(db, 0, sizeof(IncrementalDB)); // if object construction fails, destroy incomplete object bool success = false; Defer(if(!success) idb_close(db)); - // value of *db must be set to zero or behavior of idb_close will be undefined - memset(db, 0, sizeof(IncrementalDB)); + if(aes_key.size != 0){ + db->aes_key = Array_copy(aes_key); + } + db->db_dir = str_copy(db_dir); try_void(dir_create(db->db_dir.data)); - HashMap_construct(&db->tables_map, Table*, TablePtr_destroy); + HashMap_construct(&db->tables_map, Table*, TablePtr_free); try_stderrcode(pthread_mutex_init(&db->mutex, NULL)); success = true; @@ -162,36 +199,37 @@ Result(IncrementalDB*) idb_open(str db_dir){ void idb_close(IncrementalDB* db){ free(db->db_dir.data); + free(db->aes_key.data); HashMap_destroy(&db->tables_map); pthread_mutex_destroy(&db->mutex); free(db); } -Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_size){ +Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size){ Deferral(16); // db lock try_stderrcode(pthread_mutex_lock(&db->mutex)); Defer(pthread_mutex_unlock(&db->mutex)); - Table** tpp = HashMap_tryGetPtr(&db->tables_map, _table_name); + Table** tpp = HashMap_tryGetPtr(&db->tables_map, table_name); if(tpp != NULL){ Table* existing_table = *tpp; try_void(Table_validateRowSize(existing_table, row_size)); Return RESULT_VALUE(p, existing_table); } - try_void(validateTableName(_table_name)); + try_void(validateTableName(table_name)); Table* t = (Table*)malloc(sizeof(Table)); + // value of *t must be set to zero or behavior of Table_close will be undefined + memset(t, 0, sizeof(Table)); // if object construction fails, destroy incomplete object bool success = false; Defer(if(!success) Table_close(t)); - // value of *t must be set to zero or behavior of Table_close will be undefined - memset(t, 0, sizeof(Table)); t->db = db; try_stderrcode(pthread_mutex_init(&t->mutex, NULL)); - t->name = str_copy(_table_name); + t->name = str_copy(table_name); t->table_file_path = str_from_cstr( strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-table")); t->changes_file_path = str_from_cstr( @@ -208,6 +246,7 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_ // read table file try_void(Table_readHeader(t)); try_void(Table_validateHeader(t)); + try_void(Table_validateEncryption(t)); try_void(Table_validateRowSize(t, row_size)); try_void(Table_calculateRowCount(t)); } @@ -216,9 +255,18 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_ t->header.magic.n = TABLE_FILE_MAGIC.n; t->header.row_size = row_size; t->header.version = IDB_VERSION; + t->header.encrypted = db->aes_key.size != 0; + t->header._dirty_bit = false; try_void(Table_writeHeader(t)); } + if(t->header.encrypted){ + AESBlockEncryptor_construct(&t->enc, db->aes_key, AESBlockEncryptor_DEFAULT_CLASS); + AESBlockDecryptor_construct(&t->dec, db->aes_key, AESBlockDecryptor_DEFAULT_CLASS); + u32 row_size_in_file = Table_calcEncryptedRowSize(t); + t->enc_buf = Array_alloc_size(row_size_in_file); + } + if(!HashMap_tryPush(&db->tables_map, t->name, &t)){ ResultVar(void) error_result = RESULT_ERROR_FMT( "Table '%s' is already open", @@ -243,12 +291,28 @@ Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count){ count, id, t->name.data, t->row_count); } - i64 file_pos = sizeof(t->header) + id * t->header.row_size; + u32 row_size = t->header.row_size; + u32 row_size_in_file = t->header.encrypted ? t->enc_buf.size : row_size; + i64 file_pos = sizeof(t->header) + id * row_size_in_file; // seek for the row position in file try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start)); + // read rows from file - try_void(file_readStructsExactly(t->table_file, dst, t->header.row_size, count)); + for(u64 i = 0; i < count; i++){ + void* row_ptr = (u8*)dst + row_size * i; + void* read_dst = t->header.encrypted ? t->enc_buf.data : row_ptr; + try_void(file_readStructsExactly(t->table_file, read_dst, row_size_in_file, 1)); + if(t->header.encrypted) { + try_void( + AESBlockDecryptor_decrypt( + &t->dec, + t->enc_buf, + Array_construct_size(row_ptr, row_size) + ) + ); + } + } Return RESULT_VOID; } @@ -269,15 +333,31 @@ Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count){ try_void(Table_setDirtyBit(t, true)); Defer(IGNORE_RESULT Table_setDirtyBit(t, false)); - i64 file_pos = sizeof(t->header) + id * t->header.row_size; + u32 row_size = t->header.row_size; + u32 row_size_in_file = t->header.encrypted ? t->enc_buf.size : row_size; + i64 file_pos = sizeof(t->header) + id * row_size_in_file; // TODO: set dirty bit in backup file too // TODO: save old values to the backup file // seek for the row position in file try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start)); + // replace rows in file - try_void(file_writeStructs(t->table_file, src, t->header.row_size, count)); + for(u64 i = 0; i < count; i++){ + void* row_ptr = (u8*)src + row_size * i; + if(t->header.encrypted){ + try_void( + AESBlockEncryptor_encrypt( + &t->enc, + Array_construct_size(row_ptr, row_size), + t->enc_buf + ) + ); + row_ptr = t->enc_buf.data; + } + try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1)); + } Return RESULT_VOID; } @@ -291,14 +371,30 @@ Result(u64) idb_pushRows(Table* t, const void* src, u64 count){ try_void(Table_setDirtyBit(t, true)); Defer(IGNORE_RESULT Table_setDirtyBit(t, false)); + u32 row_size = t->header.row_size; + u32 row_size_in_file = t->header.encrypted ? t->enc_buf.size : row_size; const u64 new_row_index = t->row_count; // seek for end of the file try_void(file_seek(t->table_file, 0, SeekOrigin_End)); - // write new rows to the file - try_void(file_writeStructs(t->table_file, src, t->header.row_size, count)); - t->row_count += count; + // write new rows to the file + for(u64 i = 0; i < count; i++){ + void* row_ptr = (u8*)src + row_size * i; + if(t->header.encrypted){ + try_void( + AESBlockEncryptor_encrypt( + &t->enc, + Array_construct_size(row_ptr, row_size), + t->enc_buf + ) + ); + row_ptr = t->enc_buf.data; + } + try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1)); + t->row_count++; + } + Return RESULT_VALUE(u, new_row_index); } diff --git a/src/db/idb.h b/src/db/idb.h index 807c9d2..9da7b23 100644 --- a/src/db/idb.h +++ b/src/db/idb.h @@ -9,11 +9,10 @@ typedef struct IncrementalDB IncrementalDB; typedef struct Table Table; -Result(IncrementalDB*) idb_open(str db_dir); +Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)); void idb_close(IncrementalDB* db); - -Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_size); +Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size); Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count); #define idb_getRow(T, ID, DST) idb_getRows(T, ID, DST, 1)