#include "idb.h"
#include "tlibc/magic.h"
#include "tlibc/filesystem.h"
#include "tlibc/collections/HashMap.h"
#include "cryptography/AES.h"
#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/// Known plaintext that is encrypted with the database key and stored in the
/// header of every encrypted table. Decrypting it and comparing with this value
/// is how Table_validateEncryption detects a wrong key.
static const char KEY_CHALLENGE_PLAIN[16] = "key is correct!";
#define KEY_CHALLENGE_PLAIN_SIZE sizeof(KEY_CHALLENGE_PLAIN)
#define KEY_CHALLENGE_CIPHER_SIZE AESBlockEncryptor_calcDstSize(KEY_CHALLENGE_PLAIN_SIZE)

/// Header stored at offset 0 of every `.idb-table` file.
/// Row data starts right after it, at offset sizeof(TableFileHeader).
typedef struct TableFileHeader {
    Magic32 magic;
    u16 version;
    /// set to true for the duration of idb_updateRows / idb_pushRows
    bool _dirty_bit;
    bool encrypted;
    /// size of one row as seen by the caller (before encryption)
    u32 row_size;
    /* encrypted KEY_CHALLENGE_PLAIN */
    u8 key_challenge[KEY_CHALLENGE_CIPHER_SIZE];
} ATTRIBUTE_ALIGNED(256) TableFileHeader;

typedef struct Table {
    /// in-memory copy of the table file header
    TableFileHeader header;
    IncrementalDB* db;
    str name;
    str table_file_path;
    str changes_file_path;
    FILE* table_file;
    FILE* changes_file;
    pthread_mutex_t mutex;
    /// number of complete rows stored in table_file
    u64 row_count;
    /// constructed only when the database has an encryption key
    AESBlockEncryptor enc;
    AESBlockDecryptor dec;
    /// scratch buffer of AESBlockEncryptor_calcDstSize(row_size) bytes,
    /// allocated only for encrypted tables
    Array(u8) enc_buf;
} Table;

typedef struct IncrementalDB {
    str db_dir;
    /// empty when the database is not encrypted, otherwise 16, 24 or 32 bytes
    Array(u8) aes_key;
    HashMap(Table**) tables_map;
    pthread_mutex_t mutex;
} IncrementalDB;

static const Magic32 TABLE_FILE_MAGIC = { .bytes = { 'I', 'D', 'B', 't' } };

/// Releases every resource owned by a table, including a partially constructed
/// one (all fields zeroed by the constructor before being filled in).
/// NOTE(review): pthread_mutex_destroy still runs on a zeroed mutex if
/// pthread_mutex_init failed during construction.
static void Table_close(Table* t){
    if(t == NULL)
        return;
    // Files may still be NULL when construction failed before they were opened;
    // fclose(NULL) is undefined behavior.
    if(t->table_file != NULL)
        fclose(t->table_file);
    if(t->changes_file != NULL)
        fclose(t->changes_file);
    str_destroy(t->name);
    str_destroy(t->table_file_path);
    str_destroy(t->changes_file_path);
    pthread_mutex_destroy(&t->mutex);
    Array_u8_destroy(&t->enc_buf);
    free(t);
}

// element destructor for HashMap(Table*)
static void TablePtr_free(void* t_ptr_ptr){
    Table_close(*(Table**)t_ptr_ptr);
}

/// Rejects table names containing characters that are unsafe in file names
/// (the name becomes part of the table's file paths).
/// @param name must be null-terminated
static Result(void) validateTableName(str name){
    char forbidden_characters[] = { '/', '\\', ':', ';', '?', '"', '\'', '\n', '\r', '\t'};
    for(u32 i = 0; i < ARRAY_LEN(forbidden_characters); i++) {
        char c = forbidden_characters[i];
        if(str_seekChar(name, c, 0) != -1){
            return RESULT_ERROR_FMT(
                "Table name '%s' contains forbidden character '%c'",
                name.data, c);
        }
    }
    return RESULT_VOID;
}

/// Reads the header from the start of the table file into t->header.
static Result(void) Table_readHeader(Table* t){
    Deferral(4);
    // seek for start of the file
    try_void(file_seek(t->table_file, 0, SeekOrigin_Start));
    // read header
    try_void(file_readStructsExactly(t->table_file, &t->header, sizeof(t->header), 1));
    Return RESULT_VOID;
}

/// Writes t->header to the start of the table file and flushes both files.
static Result(void) Table_writeHeader(Table* t){
    Deferral(4);
    // seek for start of the file
    try_void(file_seek(t->table_file, 0, SeekOrigin_Start));
    // write header
    try_void(file_writeStructs(t->table_file, &t->header, sizeof(t->header), 1));
    // TODO: add more fflush calls
    // A failed flush means the header (and the dirty bit in it) may not have
    // reached the file, so it must not be ignored.
    if(fflush(t->table_file) != 0){
        Return RESULT_ERROR_FMT(
            "Can't flush table file '%s'", t->table_file_path.data);
    }
    if(fflush(t->changes_file) != 0){
        Return RESULT_ERROR_FMT(
            "Can't flush changes file '%s'", t->changes_file_path.data);
    }
    Return RESULT_VOID;
}

/// Updates the dirty bit in memory and persists the header.
static Result(void) Table_setDirtyBit(Table* t, bool val){
    Deferral(4);
    t->header._dirty_bit = val;
    try_void(Table_writeHeader(t));
    Return RESULT_VOID;
}

/// Re-reads the header from disk and returns its dirty bit.
static Result(bool) Table_getDirtyBit(Table* t){
    Deferral(4);
    try_void(Table_readHeader(t));
    Return RESULT_VALUE(i, t->header._dirty_bit);
}

/// Derives t->row_count from the table file size.
/// Fails if the data after the header is not a whole number of rows.
static Result(void) Table_calculateRowCount(Table* t){
    Deferral(4);
    try(i64 file_size, i, file_getSize(t->table_file));
    i64 data_size = file_size - sizeof(t->header);
    i64 row_size_in_file = t->header.encrypted
        ? AESBlockEncryptor_calcDstSize(t->header.row_size)
        : t->header.row_size;
    if(data_size % row_size_in_file != 0){
        //TODO: fix table instead of throwing error
        Return RESULT_ERROR_FMT(
            "Table '%s' has invalid size. Last row is incomplete.",
            t->name.data);
    }
    t->row_count = data_size / row_size_in_file;
    Return RESULT_VOID;
}

/// Checks magic, non-zero row size, format version, and that the dirty bit
/// is not set.
static Result(void) Table_validateHeader(Table* t){
    Deferral(4);
    if(t->header.magic.n != TABLE_FILE_MAGIC.n
        || t->header.row_size == 0)
    {
        Return RESULT_ERROR_FMT(
            "Table file '%s' has invalid header",
            t->table_file_path.data);
    }
    if(t->header.version != IDB_VERSION){
        Return RESULT_ERROR_FMT(
            "Table file '%s' was created for IDBv%u, which is incompatible with v%u",
            t->table_file_path.data, t->header.version, (u32)IDB_VERSION);
    }
    try(bool dirty_bit, i, Table_getDirtyBit(t));
    if(dirty_bit){
        //TODO: handle dirty bit instead of throwing error
        Return RESULT_ERROR_FMT(
            "Table file '%s' has dirty bit set",
            t->table_file_path.data);
    }
    Return RESULT_VOID;
}

/// Ensures the table's encryption flag matches whether the DB has a key, and
/// for encrypted tables verifies the key by decrypting the key challenge.
static Result(void) Table_validateEncryption(Table* t){
    Deferral(1);
    bool db_encrypted = t->db->aes_key.len != 0;
    if(t->header.encrypted && !db_encrypted){
        Return RESULT_ERROR_FMT(
            "Table '%s' is encrypted, but encryption key is not set. "
            "Database '%s' is encrypted and must have not-null encryption key.",
            t->name.data, t->db->db_dir.data);
    }
    if(!t->header.encrypted && db_encrypted){
        Return RESULT_ERROR_FMT(
            "Table '%s' is not encrypted, but encryption key is set. "
            "Do not set encryption key for not encrypted database '%s'.",
            t->name.data, t->db->db_dir.data);
    }

    // validate aes encryption key
    if(t->header.encrypted){
        // NOTE(review): the challenge is decrypted into the row-sized enc_buf;
        // this assumes enc_buf is large enough for the decrypted challenge
        // regardless of row_size - confirm against AESBlockDecryptor's contract.
        try_void(
            AESBlockDecryptor_decrypt(
                &t->dec,
                Array_u8_construct(t->header.key_challenge, KEY_CHALLENGE_CIPHER_SIZE),
                t->enc_buf
            )
        );
        if(memcmp(t->enc_buf.data, KEY_CHALLENGE_PLAIN, KEY_CHALLENGE_PLAIN_SIZE) != 0){
            Return RESULT_ERROR_FMT(
                "Encryption key for table '%s' is wrong",
                t->name.data);
        }
    }
    Return RESULT_VOID;
}

/// Fails if the caller's row size differs from the one stored in the header.
static Result(void) Table_validateRowSize(Table* t, u32 row_size){
    if(row_size != t->header.row_size){
        return RESULT_ERROR_FMT(
            "Requested row size (%u) doesn't match saved row size (%u)",
            row_size, t->header.row_size);
    }
    return RESULT_VOID;
}

/// Overwrites `size` bytes at `ptr` with zeros through a volatile pointer so the
/// stores are not removed as dead writes before the memory is freed.
static void secureZero(void* ptr, size_t size){
    volatile u8* p = (volatile u8*)ptr;
    while(size > 0){
        *p++ = 0;
        size--;
    }
}

/// Closes every open table and frees the database object. Accepts NULL.
void idb_close(IncrementalDB* db){
    if(db == NULL)
        return;
    str_destroy(db->db_dir);
    // don't leave the encryption key behind in freed heap memory
    if(db->aes_key.data != NULL)
        secureZero(db->aes_key.data, db->aes_key.len);
    Array_u8_destroy(&db->aes_key);
    HashMap_destroy(&db->tables_map);
    pthread_mutex_destroy(&db->mutex);
    free(db);
}

/// Opens (creating the directory if needed) a database rooted at db_dir.
/// @param aes_key empty for an unencrypted database, otherwise 16, 24 or 32 bytes;
///                the key is copied.
Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){
    Deferral(16);
    try_assert(aes_key.len == 0 || aes_key.len == 16 || aes_key.len == 24 || aes_key.len == 32);

    IncrementalDB* db = (IncrementalDB*)malloc(sizeof(IncrementalDB));
    // value of *db must be set to zero or behavior of idb_close will be undefined
    memset(db, 0, sizeof(IncrementalDB));
    // if object construction fails, destroy incomplete object
    bool success = false;
    Defer(if(!success) idb_close(db));

    if(aes_key.len != 0){
        db->aes_key = Array_u8_copy(aes_key);
    }

    db->db_dir = str_copy(db_dir);
    try_void(dir_create(db->db_dir.data));

    HashMap_construct(&db->tables_map, Table*, TablePtr_free);

    try_stderrcode(pthread_mutex_init(&db->mutex, NULL));

    success = true;
    Return RESULT_VALUE(p, db);
}

void idb_lockDB(IncrementalDB* db){
    try_fatal_stderrcode(pthread_mutex_lock(&db->mutex));
}

void idb_unlockDB(IncrementalDB* db){
    try_fatal_stderrcode(pthread_mutex_unlock(&db->mutex));
}

void idb_lockTable(Table* t){
    try_fatal_stderrcode(pthread_mutex_lock(&t->mutex));
}

void idb_unlockTable(Table* t){
    try_fatal_stderrcode(pthread_mutex_unlock(&t->mutex));
}

/// Returns the already-open table named table_name, or opens/creates its files.
/// @param row_size must match the stored row size of an existing table and must
///                 be non-zero for a new table
/// @param lock_db  lock the database mutex for the duration of the call
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size, bool lock_db){
    Deferral(16);
    if(lock_db){
        idb_lockDB(db);
        Defer(idb_unlockDB(db));
    }

    // table is already open: reuse it if the row size matches
    Table** tpp = HashMap_tryGetPtr(&db->tables_map, table_name);
    if(tpp != NULL){
        Table* existing_table = *tpp;
        try_void(Table_validateRowSize(existing_table, row_size));
        Return RESULT_VALUE(p, existing_table);
    }

    try_void(validateTableName(table_name));

    Table* t = (Table*)malloc(sizeof(Table));
    // value of *t must be set to zero or behavior of Table_close will be undefined
    memset(t, 0, sizeof(Table));
    // if object construction fails, destroy incomplete object
    bool success = false;
    Defer(if(!success) Table_close(t));

    t->db = db;
    try_stderrcode(pthread_mutex_init(&t->mutex, NULL));
    t->name = str_copy(table_name);
    t->table_file_path = str_from_cstr(
        strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-table"));
    t->changes_file_path = str_from_cstr(
        strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-changes"));

    bool table_file_exists = file_exists(t->table_file_path.data);
    // A new table with zero-sized rows would get a header that
    // Table_validateHeader rejects, so it could never be reopened.
    // Checked before the files are created so nothing is left on disk.
    if(!table_file_exists && row_size == 0){
        Return RESULT_ERROR_FMT(
            "Can't create table '%s' with row size 0", t->name.data);
    }

    // open or create file with table data
    try(t->table_file, p, file_openOrCreateReadWrite(t->table_file_path.data));
    // open or create file with backups of updated rows
    try(t->changes_file, p, file_openOrCreateReadWrite(t->changes_file_path.data));

    // init encryptor and decryptor now to use them in table header validation/creation
    if(db->aes_key.len != 0) {
        AESBlockEncryptor_construct(&t->enc, db->aes_key, AESBlockEncryptor_DEFAULT_CLASS);
        AESBlockDecryptor_construct(&t->dec, db->aes_key, AESBlockDecryptor_DEFAULT_CLASS);
        u32 row_size_in_file = AESBlockEncryptor_calcDstSize(row_size);
        t->enc_buf = Array_u8_alloc(row_size_in_file);
    }

    // init header
    if(table_file_exists){
        // read table file
        try_void(Table_readHeader(t));
        try_void(Table_validateHeader(t));
        try_void(Table_validateEncryption(t));
        try_void(Table_validateRowSize(t, row_size));
        try_void(Table_calculateRowCount(t));
    }
    else {
        // create table file
        t->header.magic.n = TABLE_FILE_MAGIC.n;
        t->header.version = IDB_VERSION;
        t->header.encrypted = db->aes_key.len != 0;
        t->header._dirty_bit = false;
        t->header.row_size = row_size;
        memset(t->header.key_challenge, 0, KEY_CHALLENGE_CIPHER_SIZE);
        // t->enc is constructed only when the database has a key, so the
        // challenge can only be encrypted for encrypted tables. Unencrypted
        // tables keep an all-zero challenge, which Table_validateEncryption
        // never checks.
        if(t->header.encrypted){
            try_void(
                AESBlockEncryptor_encrypt(
                    &t->enc,
                    Array_u8_construct((void*)KEY_CHALLENGE_PLAIN, KEY_CHALLENGE_PLAIN_SIZE),
                    Array_u8_construct(t->header.key_challenge, KEY_CHALLENGE_CIPHER_SIZE)
                )
            );
        }
        try_void(Table_writeHeader(t));
    }

    if(!HashMap_tryPush(&db->tables_map, t->name, &t)){
        Return RESULT_ERROR_FMT(
            "Table '%s' is already open",
            t->name.data);
    }

    success = true;
    Return RESULT_VALUE(p, t);
}

/// Reads `count` rows starting at row index `id` into dst
/// (dst must hold count * row_size bytes).
Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count, bool lock_table){
    Deferral(8);
    if(lock_table){
        idb_lockTable(t);
        Defer(idb_unlockTable(t));
    }

    // written so that id + count cannot wrap around and bypass the check
    if(id > t->row_count || count > t->row_count - id){
        Return RESULT_ERROR_FMT(
            "Can't read "FMT_u64" rows at index "FMT_u64
            " because table '%s' has only "FMT_u64" rows",
            count, id, t->name.data, t->row_count);
    }

    u32 row_size = t->header.row_size;
    u32 row_size_in_file = t->header.encrypted ? t->enc_buf.len : row_size;
    i64 file_pos = sizeof(t->header) + id * row_size_in_file;

    // seek for the row position in file
    try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start));

    // read rows from file
    for(u64 i = 0; i < count; i++){
        void* row_ptr = (u8*)dst + row_size * i;
        // encrypted rows are read into the scratch buffer and decrypted into dst
        void* read_dst = t->header.encrypted ? t->enc_buf.data : row_ptr;
        try_void(file_readStructsExactly(t->table_file, read_dst, row_size_in_file, 1));
        if(t->header.encrypted) {
            try_void(
                AESBlockDecryptor_decrypt(
                    &t->dec,
                    t->enc_buf,
                    Array_u8_construct(row_ptr, row_size)
                )
            );
        }
    }

    Return RESULT_VOID;
}

/// Overwrites `count` existing rows starting at row index `id` with rows from src.
/// The dirty bit is set in the header for the duration of the write.
Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count, bool lock_table){
    Deferral(8);
    if(lock_table){
        idb_lockTable(t);
        Defer(idb_unlockTable(t));
    }

    // written so that id + count cannot wrap around and bypass the check
    if(id > t->row_count || count > t->row_count - id){
        Return RESULT_ERROR_FMT(
            "Can't update "FMT_u64" rows at index "FMT_u64
            " because table '%s' has only "FMT_u64" rows",
            count, id, t->name.data, t->row_count);
    }

    try_void(Table_setDirtyBit(t, true));
    Defer(IGNORE_RESULT Table_setDirtyBit(t, false));

    u32 row_size = t->header.row_size;
    u32 row_size_in_file = t->header.encrypted ? t->enc_buf.len : row_size;
    i64 file_pos = sizeof(t->header) + id * row_size_in_file;

    // TODO: set dirty bit in backup file too
    // TODO: save old values to the backup file

    // seek for the row position in file
    try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start));

    // replace rows in file
    for(u64 i = 0; i < count; i++){
        void* row_ptr = (u8*)src + row_size * i;
        if(t->header.encrypted){
            try_void(
                AESBlockEncryptor_encrypt(
                    &t->enc,
                    Array_u8_construct(row_ptr, row_size),
                    t->enc_buf
                )
            );
            row_ptr = t->enc_buf.data;
        }
        try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1));
    }

    Return RESULT_VOID;
}

/// Appends `count` rows from src to the end of the table.
/// @return index of the first appended row
Result(u64) idb_pushRows(Table* t, const void* src, u64 count, bool lock_table){
    Deferral(8);
    if(lock_table){
        idb_lockTable(t);
        Defer(idb_unlockTable(t));
    }

    try_void(Table_setDirtyBit(t, true));
    Defer(IGNORE_RESULT Table_setDirtyBit(t, false));

    u32 row_size = t->header.row_size;
    u32 row_size_in_file = t->header.encrypted ? t->enc_buf.len : row_size;
    const u64 new_row_index = t->row_count;

    // seek for end of the file
    try_void(file_seek(t->table_file, 0, SeekOrigin_End));

    // write new rows to the file
    for(u64 i = 0; i < count; i++){
        void* row_ptr = (u8*)src + row_size * i;
        if(t->header.encrypted){
            try_void(
                AESBlockEncryptor_encrypt(
                    &t->enc,
                    Array_u8_construct(row_ptr, row_size),
                    t->enc_buf
                )
            );
            row_ptr = t->enc_buf.data;
        }
        try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1));
        // counted per row so row_count matches what was written if a later row fails
        t->row_count++;
    }

    Return RESULT_VALUE(u, new_row_index);
}

/// Returns the number of rows currently in the table.
Result(u64) idb_getRowCount(Table* t, bool lock_table){
    Deferral(4);
    if(lock_table){
        idb_lockTable(t);
        Defer(idb_unlockTable(t));
    }
    u64 count = t->row_count;
    Return RESULT_VALUE(u, count);
}

/// Allocates a list in *l and fills it with every row of the table.
/// On failure the list is destroyed.
Result(void) idb_createListFromTable(Table* t, List_* l, bool lock_table){
    // up to two Defer statements are registered below (unlock + list cleanup),
    // so the deferral capacity must be at least 2
    Deferral(4);
    if(lock_table){
        idb_lockTable(t);
        Defer(idb_unlockTable(t));
    }

    u64 row_count = t->row_count;
    u64 row_size = t->header.row_size;
    u64 total_size = row_count * row_size;
    *l = _List_alloc(total_size, row_size);
    l->len = row_count;

    bool success = false;
    Defer(if(!success) _List_destroy(l));

    // the table lock (if requested) is already held, so don't lock again
    try_void(idb_getRows(t, 0, l->data, row_count, false));

    success = true;
    Return RESULT_VOID;
}