implemented idb encryption

This commit is contained in:
Timerix 2025-11-13 02:31:00 +05:00
parent 2f51cd07ff
commit 4add849b9e
2 changed files with 127 additions and 32 deletions

View File

@ -2,12 +2,14 @@
#include "magic.h"
#include "tlibc/filesystem.h"
#include "tlibc/collections/HashMap.h"
#include "cryptography/AES.h"
#include <pthread.h>
typedef struct TableFileHeader {
Magic32 magic;
u16 version;
bool _dirty_bit;
bool encrypted;
u32 row_size;
} ATTRIBUTE_ALIGNED(256) TableFileHeader;
@ -21,10 +23,14 @@ typedef struct Table {
FILE* changes_file;
pthread_mutex_t mutex;
u64 row_count;
AESBlockEncryptor enc;
AESBlockDecryptor dec;
Array(u8) enc_buf;
} Table;
typedef struct IncrementalDB {
str db_dir;
Array(u8) aes_key;
HashMap(Table**) tables_map;
pthread_mutex_t mutex;
} IncrementalDB;
@ -39,16 +45,17 @@ void Table_close(Table* t){
free(t->table_file_path.data);
free(t->changes_file_path.data);
pthread_mutex_destroy(&t->mutex);
free(t->enc_buf.data);
free(t);
}
// element destructor for HashMap(Table*)
void TablePtr_destroy(void* t_ptr_ptr){
static void TablePtr_free(void* t_ptr_ptr){
Table_close(*(Table**)t_ptr_ptr);
}
/// @param name must be null-terminated
Result(void) validateTableName(str name){
static Result(void) validateTableName(str name){
char forbidden_characters[] = { '/', '\\', ':', ';', '?', '"', '\'', '\n', '\r', '\t'};
for(u32 i = 0; i < ARRAY_LEN(forbidden_characters); i++) {
char c = forbidden_characters[i];
@ -62,7 +69,7 @@ Result(void) validateTableName(str name){
return RESULT_VOID;
}
Result(void) Table_readHeader(Table* t){
static Result(void) Table_readHeader(Table* t){
Deferral(4);
// seek for start of the file
try_void(file_seek(t->table_file, 0, SeekOrigin_Start));
@ -71,7 +78,7 @@ Result(void) Table_readHeader(Table* t){
Return RESULT_VOID;
}
Result(void) Table_writeHeader(Table* t){
static Result(void) Table_writeHeader(Table* t){
Deferral(4);
// seek for start of the file
try_void(file_seek(t->table_file, 0, SeekOrigin_Start));
@ -80,35 +87,42 @@ Result(void) Table_writeHeader(Table* t){
Return RESULT_VOID;
}
Result(void) Table_setDirtyBit(Table* t, bool val){
static Result(void) Table_setDirtyBit(Table* t, bool val){
Deferral(4);
t->header._dirty_bit = val;
try_void(Table_writeHeader(t));
Return RESULT_VOID;
}
Result(bool) Table_getDirtyBit(Table* t){
static Result(bool) Table_getDirtyBit(Table* t){
Deferral(4);
try_void(Table_readHeader(t));
Return RESULT_VALUE(i, t->header._dirty_bit);
}
Result(void) Table_calculateRowCount(Table* t){
static u32 Table_calcEncryptedRowSize(Table* t){
return AESBlockEncryptor_calcDstSize(t->header.row_size);
}
static Result(void) Table_calculateRowCount(Table* t){
Deferral(4);
try(i64 file_size, i, file_getSize(t->table_file));
i64 data_size = file_size - sizeof(t->header);
if(data_size % t->header.row_size != 0){
i64 row_size_in_file = t->header.encrypted
? Table_calcEncryptedRowSize(t)
: t->header.row_size;
if(data_size % row_size_in_file != 0){
//TODO: fix table instead of trowing error
Return RESULT_ERROR_FMT(
"Table '%s' has invalid size. Last row is incomplete",
"Table '%s' has invalid size. Last row is incomplete.",
t->name.data);
}
t->row_count = data_size / t->header.row_size;
t->row_count = data_size / row_size_in_file;
Return RESULT_VOID;
}
Result(void) Table_validateHeader(Table* t){
static Result(void) Table_validateHeader(Table* t){
Deferral(4);
if(t->header.magic.n != TABLE_FILE_MAGIC.n
|| t->header.row_size == 0)
@ -131,7 +145,24 @@ Result(void) Table_validateHeader(Table* t){
Return RESULT_VOID;
}
Result(void) Table_validateRowSize(Table* t, u32 row_size){
static Result(void) Table_validateEncryption(Table* t){
bool db_encrypted = t->db->aes_key.size != 0;
if(t->header.encrypted && !db_encrypted){
return RESULT_ERROR_FMT("Table '%s' is encrypted, but db->aes_key is not set."
"Database '%s' is encrypted and must have not-null encryption key.",
t->name.data, t->db->db_dir.data);
}
if(!t->header.encrypted && db_encrypted){
return RESULT_ERROR_FMT("table '%s' is not encrypted, but db->aes_key is set."
"Do not set encryption key for not encrypted database '%s'.",
t->name.data, t->db->db_dir.data);
}
return RESULT_VOID;
}
static Result(void) Table_validateRowSize(Table* t, u32 row_size){
if(row_size != t->header.row_size){
ResultVar(void) error_result = RESULT_ERROR_FMT(
"Requested row size (%u) doesn't match saved row size (%u)",
@ -142,18 +173,24 @@ Result(void) Table_validateRowSize(Table* t, u32 row_size){
return RESULT_VOID;
}
Result(IncrementalDB*) idb_open(str db_dir){
Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){
Deferral(16);
try_assert(aes_key.size == 0 || aes_key.size == 16 || aes_key.size == 24 || aes_key.size == 32);
IncrementalDB* db = (IncrementalDB*)malloc(sizeof(IncrementalDB));
// value of *db must be set to zero or behavior of idb_close will be undefined
memset(db, 0, sizeof(IncrementalDB));
// if object construction fails, destroy incomplete object
bool success = false;
Defer(if(!success) idb_close(db));
// value of *db must be set to zero or behavior of idb_close will be undefined
memset(db, 0, sizeof(IncrementalDB));
if(aes_key.size != 0){
db->aes_key = Array_copy(aes_key);
}
db->db_dir = str_copy(db_dir);
try_void(dir_create(db->db_dir.data));
HashMap_construct(&db->tables_map, Table*, TablePtr_destroy);
HashMap_construct(&db->tables_map, Table*, TablePtr_free);
try_stderrcode(pthread_mutex_init(&db->mutex, NULL));
success = true;
@ -162,36 +199,37 @@ Result(IncrementalDB*) idb_open(str db_dir){
void idb_close(IncrementalDB* db){
free(db->db_dir.data);
free(db->aes_key.data);
HashMap_destroy(&db->tables_map);
pthread_mutex_destroy(&db->mutex);
free(db);
}
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_size){
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size){
Deferral(16);
// db lock
try_stderrcode(pthread_mutex_lock(&db->mutex));
Defer(pthread_mutex_unlock(&db->mutex));
Table** tpp = HashMap_tryGetPtr(&db->tables_map, _table_name);
Table** tpp = HashMap_tryGetPtr(&db->tables_map, table_name);
if(tpp != NULL){
Table* existing_table = *tpp;
try_void(Table_validateRowSize(existing_table, row_size));
Return RESULT_VALUE(p, existing_table);
}
try_void(validateTableName(_table_name));
try_void(validateTableName(table_name));
Table* t = (Table*)malloc(sizeof(Table));
// value of *t must be set to zero or behavior of Table_close will be undefined
memset(t, 0, sizeof(Table));
// if object construction fails, destroy incomplete object
bool success = false;
Defer(if(!success) Table_close(t));
// value of *t must be set to zero or behavior of Table_close will be undefined
memset(t, 0, sizeof(Table));
t->db = db;
try_stderrcode(pthread_mutex_init(&t->mutex, NULL));
t->name = str_copy(_table_name);
t->name = str_copy(table_name);
t->table_file_path = str_from_cstr(
strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-table"));
t->changes_file_path = str_from_cstr(
@ -208,6 +246,7 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_
// read table file
try_void(Table_readHeader(t));
try_void(Table_validateHeader(t));
try_void(Table_validateEncryption(t));
try_void(Table_validateRowSize(t, row_size));
try_void(Table_calculateRowCount(t));
}
@ -216,9 +255,18 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_
t->header.magic.n = TABLE_FILE_MAGIC.n;
t->header.row_size = row_size;
t->header.version = IDB_VERSION;
t->header.encrypted = db->aes_key.size != 0;
t->header._dirty_bit = false;
try_void(Table_writeHeader(t));
}
if(t->header.encrypted){
AESBlockEncryptor_construct(&t->enc, db->aes_key, AESBlockEncryptor_DEFAULT_CLASS);
AESBlockDecryptor_construct(&t->dec, db->aes_key, AESBlockDecryptor_DEFAULT_CLASS);
u32 row_size_in_file = Table_calcEncryptedRowSize(t);
t->enc_buf = Array_alloc_size(row_size_in_file);
}
if(!HashMap_tryPush(&db->tables_map, t->name, &t)){
ResultVar(void) error_result = RESULT_ERROR_FMT(
"Table '%s' is already open",
@ -243,12 +291,28 @@ Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count){
count, id, t->name.data, t->row_count);
}
i64 file_pos = sizeof(t->header) + id * t->header.row_size;
u32 row_size = t->header.row_size;
u32 row_size_in_file = t->header.encrypted ? t->enc_buf.size : row_size;
i64 file_pos = sizeof(t->header) + id * row_size_in_file;
// seek for the row position in file
try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start));
// read rows from file
try_void(file_readStructsExactly(t->table_file, dst, t->header.row_size, count));
for(u64 i = 0; i < count; i++){
void* row_ptr = (u8*)dst + row_size * i;
void* read_dst = t->header.encrypted ? t->enc_buf.data : row_ptr;
try_void(file_readStructsExactly(t->table_file, read_dst, row_size_in_file, 1));
if(t->header.encrypted) {
try_void(
AESBlockDecryptor_decrypt(
&t->dec,
t->enc_buf,
Array_construct_size(row_ptr, row_size)
)
);
}
}
Return RESULT_VOID;
}
@ -269,15 +333,31 @@ Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count){
try_void(Table_setDirtyBit(t, true));
Defer(IGNORE_RESULT Table_setDirtyBit(t, false));
i64 file_pos = sizeof(t->header) + id * t->header.row_size;
u32 row_size = t->header.row_size;
u32 row_size_in_file = t->header.encrypted ? t->enc_buf.size : row_size;
i64 file_pos = sizeof(t->header) + id * row_size_in_file;
// TODO: set dirty bit in backup file too
// TODO: save old values to the backup file
// seek for the row position in file
try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start));
// replace rows in file
try_void(file_writeStructs(t->table_file, src, t->header.row_size, count));
for(u64 i = 0; i < count; i++){
void* row_ptr = (u8*)src + row_size * i;
if(t->header.encrypted){
try_void(
AESBlockEncryptor_encrypt(
&t->enc,
Array_construct_size(row_ptr, row_size),
t->enc_buf
)
);
row_ptr = t->enc_buf.data;
}
try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1));
}
Return RESULT_VOID;
}
@ -291,14 +371,30 @@ Result(u64) idb_pushRows(Table* t, const void* src, u64 count){
try_void(Table_setDirtyBit(t, true));
Defer(IGNORE_RESULT Table_setDirtyBit(t, false));
u32 row_size = t->header.row_size;
u32 row_size_in_file = t->header.encrypted ? t->enc_buf.size : row_size;
const u64 new_row_index = t->row_count;
// seek for end of the file
try_void(file_seek(t->table_file, 0, SeekOrigin_End));
// write new rows to the file
try_void(file_writeStructs(t->table_file, src, t->header.row_size, count));
t->row_count += count;
// write new rows to the file
for(u64 i = 0; i < count; i++){
void* row_ptr = (u8*)src + row_size * i;
if(t->header.encrypted){
try_void(
AESBlockEncryptor_encrypt(
&t->enc,
Array_construct_size(row_ptr, row_size),
t->enc_buf
)
);
row_ptr = t->enc_buf.data;
}
try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1));
t->row_count++;
}
Return RESULT_VALUE(u, new_row_index);
}

View File

@ -9,11 +9,10 @@ typedef struct IncrementalDB IncrementalDB;
typedef struct Table Table;
Result(IncrementalDB*) idb_open(str db_dir);
Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key));
void idb_close(IncrementalDB* db);
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_size);
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size);
Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count);
#define idb_getRow(T, ID, DST) idb_getRows(T, ID, DST, 1)