diff --git a/dependencies/tlibc b/dependencies/tlibc index ea6c20f..0d422cd 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit ea6c20f430f4536631f379697ae56fc8b5fd5d17 +Subproject commit 0d422cd7e591795398649ff5cdb7fa2b5679e92d diff --git a/dependencies/tlibtoml b/dependencies/tlibtoml index a0e280d..bd38585 160000 --- a/dependencies/tlibtoml +++ b/dependencies/tlibtoml @@ -1 +1 @@ -Subproject commit a0e280d77b6d8967239b14915609eef422f6ba0b +Subproject commit bd38585b35937a977b5f3e048be56faed38b2e16 diff --git a/include/tcp-chat/common_constants.h b/include/tcp-chat/common_constants.h index e77faab..d6d353c 100644 --- a/include/tcp-chat/common_constants.h +++ b/include/tcp-chat/common_constants.h @@ -18,3 +18,4 @@ #define CHANNEL_DESC_SIZE_MAX 1023 #define MESSAGE_SIZE_MIN 1 #define MESSAGE_SIZE_MAX 4000 +#define MESSAGE_BLOCK_SIZE (64*1024) \ No newline at end of file diff --git a/src/cli/ClientCLI/ClientCLI.c b/src/cli/ClientCLI/ClientCLI.c index 7a864dc..b65c4f4 100644 --- a/src/cli/ClientCLI/ClientCLI.c +++ b/src/cli/ClientCLI/ClientCLI.c @@ -43,7 +43,7 @@ void ClientCLI_destroy(ClientCLI* self){ HashMap_destroy(&self->servers.addr_id_map); } void ClientCLI_construct(ClientCLI* self){ - memset(self, 0, sizeof(*self)); + zeroStruct(self); } Result(void) ClientCLI_run(ClientCLI* self) { @@ -326,7 +326,7 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){ // Load servers table try(self->servers.table, p, - idb_getOrCreateTable(self->db, STR("servers"), sizeof(ServerInfo), false) + idb_getOrCreateTable(self->db, str_null, STR("servers"), sizeof(ServerInfo), false) ); // Lock table until this function returns. @@ -362,7 +362,7 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, // create new server info ServerInfo server; - memset(&server, 0, sizeof(ServerInfo)); + zeroStruct(&server); // address if(addr.len > HOSTADDR_SIZE_MAX) addr.len = HOSTADDR_SIZE_MAX; diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index 8786dee..8b5deeb 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -18,7 +18,7 @@ Result(ServerConnection*) ServerConnection_open(Client* client, cstr server_addr Deferral(16); ServerConnection* conn = (ServerConnection*)malloc(sizeof(ServerConnection)); - memset(conn, 0, sizeof(*conn)); + zeroStruct(conn); bool success = false; Defer(if(!success) ServerConnection_close(conn)); diff --git a/src/client/client.c b/src/client/client.c index a1a3ced..2709246 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -15,7 +15,7 @@ Result(Client*) Client_create(str username, str password){ Deferral(16); Client* self = (Client*)malloc(sizeof(Client)); - memset(self, 0, sizeof(Client)); + zeroStruct(self); bool success = false; Defer(if(!success) Client_free(self)); diff --git a/src/cryptography/AES.c b/src/cryptography/AES.c index 28201a8..0a12840 100755 --- a/src/cryptography/AES.c +++ b/src/cryptography/AES.c @@ -55,7 +55,7 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, __Array_writeNext(&dst, ptr->iv, __AES_BLOCK_IV_SIZE); EncryptedBlockHeader header; - memset(&header, 0, sizeof(header)); + zeroStruct(&header); memcpy(header.key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE); header.padding_size = (16 - src.len % 16) % 16; // write header to buffer diff --git a/src/db/idb.c b/src/db/idb.c index f515858..2daffc4 100644 --- a/src/db/idb.c +++ b/src/db/idb.c @@ -2,6 +2,7 @@ #include "tlibc/magic.h" #include "tlibc/filesystem.h" #include "tlibc/collections/HashMap.h" +#include "tlibc/string/StringBuilder.h" #include "cryptography/AES.h" #include @@ -216,11 +217,12 @@ void idb_close(IncrementalDB* db){ Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){ Deferral(16); + try_assert(db_dir.len > 0); try_assert(aes_key.len == 0 || aes_key.len == 16 || aes_key.len == 24 || aes_key.len == 32); IncrementalDB* db = (IncrementalDB*)malloc(sizeof(IncrementalDB)); // value of *db must be set to zero or behavior of idb_close will be undefined - memset(db, 0, sizeof(IncrementalDB)); + zeroStruct(db); // if object construction fails, destroy incomplete object bool success = false; Defer(if(!success) idb_close(db)); @@ -255,7 +257,7 @@ void idb_unlockTable(Table* t){ } -Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size, bool lock_db){ +Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str subdir, str table_name, u32 row_size, bool lock_db){ Deferral(16); if(lock_db){ @@ -274,7 +276,7 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_s Table* t = (Table*)malloc(sizeof(Table)); // value of *t must be set to zero or behavior of Table_close will be undefined - memset(t, 0, sizeof(Table)); + zeroStruct(t); // if object construction fails, destroy incomplete object bool success = false; Defer(if(!success) Table_close(t)); @@ -282,10 +284,27 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_s t->db = db; try_stderrcode(pthread_mutex_init(&t->mutex, NULL)); t->name = str_copy(table_name); - t->table_file_path = str_from_cstr( - strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-table")); - t->changes_file_path = str_from_cstr( - strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-changes")); + + // set file paths + StringBuilder sb = StringBuilder_alloc(256); + Defer(StringBuilder_destroy(&sb)); + StringBuilder_append_str(&sb, db->db_dir); + StringBuilder_append_char(&sb, path_sep); + if(subdir.len != 0){ + StringBuilder_append_str(&sb, subdir); + try_void(dir_create(StringBuilder_getStr(&sb).data)); + StringBuilder_append_char(&sb, path_sep); + } + StringBuilder_append_str(&sb, t->name); + // table file + str table_file_ext = STR(".idb-table"); + StringBuilder_append_str(&sb, table_file_ext); + t->table_file_path = str_copy(StringBuilder_getStr(&sb)); + // changes file + str changes_file_ext = STR(".idb-changes"); + StringBuilder_removeFromEnd(&sb, table_file_ext.len); + StringBuilder_append_str(&sb, changes_file_ext); + t->changes_file_path = str_copy(StringBuilder_getStr(&sb)); bool table_file_exists = file_exists(t->table_file_path.data); diff --git a/src/db/idb.h b/src/db/idb.h index c51fd6c..f752766 100644 --- a/src/db/idb.h +++ b/src/db/idb.h @@ -30,7 +30,7 @@ void idb_lockTable(Table* t); void idb_unlockTable(Table* t); -Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size, bool lock_db); +Result(Table*) idb_getOrCreateTable(IncrementalDB* db, NULLABLE(str) subdir, str table_name, u32 row_size, bool lock_db); Result(void) idb_getRows(Table* t, u64 start_from_id, void* dst, u64 count, bool lock_table); #define idb_getRow(T, ID, DST, LOCK) idb_getRows(T, ID, DST, 1, LOCK) diff --git a/src/db/tables.h b/src/db/tables.h index 42ad2f7..dc2d197 100644 --- a/src/db/tables.h +++ b/src/db/tables.h @@ -32,6 +32,7 @@ typedef struct ChannelInfo { char desc[CHANNEL_DESC_SIZE_MAX + 1]; } ATTRIBUTE_ALIGNED(4*1024) ChannelInfo; + // not a table typedef struct MessageMeta { /* @@ -50,12 +51,13 @@ typedef struct MessageMeta { // Stores some number of messages. Look in MessageBlockMeta to see how much. typedef struct MessageBlock { /* ((sequence MessageMeta), (sequence binary-data)) */ - u8 data[64*1024 - 4]; + u8 data[MESSAGE_BLOCK_SIZE]; } ATTRIBUTE_ALIGNED(64) MessageBlock; // is used to find in which MessageBlock a message is stored typedef struct MessageBlockMeta { - u64 message_id_first; + u64 first_message_id; u32 messages_count; + u32 data_size; } ATTRIBUTE_ALIGNED(16) MessageBlockMeta; diff --git a/src/network/tcp-chat-protocol/constant.c b/src/network/tcp-chat-protocol/constant.c index c15a91b..9d0f2d7 100644 --- a/src/network/tcp-chat-protocol/constant.c +++ b/src/network/tcp-chat-protocol/constant.c @@ -29,7 +29,7 @@ Result(void) PacketHeader_validateContentSize(PacketHeader* ptr, u64 expected_si } void PacketHeader_construct(PacketHeader* ptr, u8 protocol_version, u16 type, u64 content_size){ - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); ptr->magic.n = PacketHeader_MAGIC.n; ptr->protocol_version = protocol_version; ptr->type = type; diff --git a/src/network/tcp-chat-protocol/v1.c b/src/network/tcp-chat-protocol/v1.c index 5d88c51..2758abf 100644 --- a/src/network/tcp-chat-protocol/v1.c +++ b/src/network/tcp-chat-protocol/v1.c @@ -45,7 +45,7 @@ str validateUsername_str(str username){ void ErrorMessage_construct(ErrorMessage* ptr, PacketHeader* header, u32 msg_size){ _PacketHeader_construct(ErrorMessage); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); ptr->msg_size = msg_size; } @@ -54,7 +54,7 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he { Deferral(1); _PacketHeader_construct(ClientHandshake); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); try_assert(session_key.len == sizeof(ptr->session_key)); memcpy(ptr->session_key, session_key.data, session_key.len); @@ -66,7 +66,7 @@ void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header, u64 session_id) { _PacketHeader_construct(ServerHandshake); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); ptr->session_id = session_id; } @@ -74,7 +74,7 @@ void ServerPublicInfoRequest_construct(ServerPublicInfoRequest *ptr, PacketHeade ServerPublicInfo property) { _PacketHeader_construct(ServerPublicInfoRequest); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); ptr->property = property; } @@ -82,7 +82,7 @@ void ServerPublicInfoResponse_construct(ServerPublicInfoResponse* ptr, PacketHea u32 data_size) { _PacketHeader_construct(ServerPublicInfoResponse); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); ptr->data_size = data_size; } @@ -91,7 +91,7 @@ Result(void) LoginRequest_tryConstruct(LoginRequest *ptr, PacketHeader* header, { Deferral(1); _PacketHeader_construct(LoginRequest); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); str name_error_str = validateUsername_str(username); if(name_error_str.data){ @@ -109,7 +109,7 @@ void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header, u64 user_id, u64 landing_channel_id) { _PacketHeader_construct(LoginResponse); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); ptr->user_id = user_id; ptr->landing_channel_id = landing_channel_id; @@ -120,7 +120,7 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest *ptr, PacketHeader* he { Deferral(1); _PacketHeader_construct(RegisterRequest); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); str name_error_str = validateUsername_str(username); if(name_error_str.data){ @@ -138,6 +138,43 @@ void RegisterResponse_construct(RegisterResponse *ptr, PacketHeader* header, u64 user_id) { _PacketHeader_construct(RegisterResponse); - memset(ptr, 0, sizeof(*ptr)); + zeroStruct(ptr); ptr->user_id = user_id; } + +void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header, + u64 channel_id, u16 data_size) +{ + _PacketHeader_construct(SendMessageRequest); + zeroStruct(ptr); + ptr->channel_id = channel_id; + ptr->data_size = data_size; +} + +void SendMessageResponse_construct(SendMessageResponse *ptr, PacketHeader *header, + u64 message_id, DateTime receiving_time_utc) +{ + _PacketHeader_construct(SendMessageResponse); + zeroStruct(ptr); + ptr->message_id = message_id; + ptr->receiving_time_utc = receiving_time_utc; +} + +void GetMessageBlockRequest_construct(GetMessageBlockRequest *ptr, PacketHeader *header, + u64 first_message_id, u32 messages_count) +{ + _PacketHeader_construct(GetMessageBlockRequest); + zeroStruct(ptr); + ptr->first_message_id = first_message_id; + ptr->messages_count = messages_count; +} + +void GetMessageBlockResponse_construct(GetMessageBlockResponse *ptr, PacketHeader *header, + u64 first_message_id, u32 messages_count, u32 data_size) +{ + _PacketHeader_construct(GetMessageBlockResponse); + zeroStruct(ptr); + ptr->first_message_id = first_message_id; + ptr->messages_count = messages_count; + ptr->data_size = data_size; +} diff --git a/src/network/tcp-chat-protocol/v1.h b/src/network/tcp-chat-protocol/v1.h index 2c0f92e..d45f27a 100644 --- a/src/network/tcp-chat-protocol/v1.h +++ b/src/network/tcp-chat-protocol/v1.h @@ -33,6 +33,10 @@ typedef enum PacketType { PacketType_LoginResponse, PacketType_RegisterRequest, PacketType_RegisterResponse, + PacketType_SendMessageRequest, + PacketType_SendMessageResponse, + PacketType_GetMessageBlockRequest, + PacketType_GetMessageBlockResponse, } ATTRIBUTE_PACKED PacketType; @@ -140,21 +144,21 @@ void SendMessageResponse_construct(SendMessageResponse* ptr, PacketHeader* heade typedef struct GetMessageBlockRequest { - u64 message_id_first; + u64 first_message_id; u32 messages_count; } ALIGN_PACKET_STRUCT GetMessageBlockRequest; void GetMessageBlockRequest_construct(GetMessageBlockRequest* ptr, PacketHeader* header, - u64 message_id_first, u32 messages_count); + u64 first_message_id, u32 messages_count); typedef struct GetMessageBlockResponse { - u64 message_id_first; + u64 first_message_id; u32 messages_count; u32 data_size; /* stream of size data_size : ((sequence MessageMeta), (sequence binary-data)) */ } ALIGN_PACKET_STRUCT GetMessageBlockResponse; void GetMessageBlockResponse_construct(GetMessageBlockResponse* ptr, PacketHeader* header, - u32 data_size); + u64 first_message_id, u32 messages_count, u32 data_size); diff --git a/src/server/Channel.c b/src/server/Channel.c new file mode 100644 index 0000000..125988a --- /dev/null +++ b/src/server/Channel.c @@ -0,0 +1,151 @@ +#include "server_internal.h" +#include "tlibc/string/StringBuilder.h" +#include "tlibc/filesystem.h" + +void Channel_free(Channel* self){ + if(!self) + return; + str_destroy(self->name); + str_destroy(self->description); + List_MessageBlockMeta_destroy(&self->messages.blocks_meta_list); + LList_MessageBlock_destroy(&self->messages.blocks_queue); + free(self); +} + +Result(Channel*) Channel_create(u64 chan_id, str name, str description, + IncrementalDB* db, bool lock_db) +{ + Deferral(8); + + Channel* self = (Channel*)malloc(sizeof(Channel)); + zeroStruct(self); + bool success = false; + Defer(if(!success) Channel_free(self)); + + self->id = chan_id; + try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX); + self->name = str_copy(name); + try_assert(description.len <= CHANNEL_DESC_SIZE_MAX); + self->description = str_copy(description); + + if(lock_db){ + idb_lockDB(db); + Defer(idb_unlockDB(db)); + } + + StringBuilder sb = StringBuilder_alloc(CHANNEL_NAME_SIZE_MAX + 32 + 1); + Defer(StringBuilder_destroy(&sb)); + StringBuilder_append_str(&sb, STR("channels")); + StringBuilder_append_char(&sb, path_sep); + StringBuilder_append_str(&sb, name); + + str subdir = str_copy(StringBuilder_getStr(&sb)); + Defer(str_destroy(subdir)); + str message_blocks_str = STR("message_blocks"); + str message_blocks_meta_str = STR("message_blocks_meta"); + + StringBuilder_removeFromEnd(&sb, -1); + StringBuilder_append_str(&sb, name); + StringBuilder_append_char(&sb, '_'); + StringBuilder_append_str(&sb, message_blocks_str); + try(self->messages.blocks_table, p, + idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false) + ); + idb_lockTable(self->messages.blocks_table); + Defer(idb_unlockTable(self->messages.blocks_table)); + + + StringBuilder_removeFromEnd(&sb, message_blocks_str.len); + StringBuilder_append_str(&sb, message_blocks_meta_str); + try(self->messages.blocks_meta_table, p, + idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false) + ); + idb_lockTable(self->messages.blocks_meta_table); + Defer(idb_unlockTable(self->messages.blocks_meta_table)); + + // load whole message_blocks_meta table to list + try_void( + idb_createListFromTable(self->messages.blocks_meta_table, (void*)&self->messages.blocks_meta_list, false) + ); + + // load N last blocks to the queue + self->messages.blocks_queue = LList_construct(MessageBlock, NULL); + u64 message_blocks_count = self->messages.blocks_meta_list.len; + u64 first_block_id = 0; + if(message_blocks_count > MESSAGE_BLOCKS_CACHE_COUNT) + first_block_id = message_blocks_count - MESSAGE_BLOCKS_CACHE_COUNT; + for(u64 id = first_block_id; id < message_blocks_count; id++){ + LLNode(MessageBlock)* node = LLNode_MessageBlock_createZero(); + LList_MessageBlock_insertAfter( + &self->messages.blocks_queue, + self->messages.blocks_queue.last, + node + ); + try_void(idb_getRow(self->messages.blocks_table, id, node->value.data, false)); + } + + if(self->messages.blocks_meta_list.len > 0){ + MessageBlockMeta last_block_meta = self->messages.blocks_meta_list.data[self->messages.blocks_meta_list.len - 1]; + self->messages.count = last_block_meta.first_message_id + last_block_meta.messages_count - 1; + } + else { + self->messages.count = 0; + } + + success = true; + Return RESULT_VALUE(p, self); +} + + +Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id, + MessageMeta* out_message_meta) +{ + Deferral(1); + + // create new block if message won't fit in the last existing + MessageBlockMeta* incomplete_block_meta = self->messages.blocks_meta_list.data + self->messages.blocks_meta_list.len; + u64 new_message_id = incomplete_block_meta->first_message_id + incomplete_block_meta->messages_count; + u32 message_size_in_block = sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8); + if(incomplete_block_meta->data_size + message_size_in_block > MESSAGE_BLOCK_SIZE){ + // create new MessageBlockMeta + incomplete_block_meta = List_MessageBlockMeta_expand(&self->messages.blocks_meta_list, 1); + incomplete_block_meta->first_message_id = new_message_id; + incomplete_block_meta->messages_count = 0; + incomplete_block_meta->data_size = 0; + // create new MessageBlock + LList_MessageBlock_insertAfter( + &self->messages.blocks_queue, + self->messages.blocks_queue.last, + LLNode_MessageBlock_createZero()); + //TODO: save to DB + } + + // copy message to message block + out_message_meta->magic = MESSAGE_MAGIC; + out_message_meta->data_size = message_data.len; + out_message_meta->id = new_message_id; + out_message_meta->sender_id = sender_id; + DateTime_getUTC(&out_message_meta->receiving_time_utc); + + MessageBlock* incomplete_block = &self->messages.blocks_queue.last->value; + u8* data_ptr = incomplete_block->data + incomplete_block_meta->data_size; + memcpy(data_ptr, out_message_meta, sizeof(MessageMeta)); + data_ptr += sizeof(MessageMeta); + memcpy(data_ptr, message_data.data, message_data.len); + + Return RESULT_VOID; +} + +// Result(void) Channel_loadMessage(Channel* self, u64 id, +// MessageMeta* out_message_meta, u8* out_data) +// { +// Deferral(1); +// Return RESULT_VOID; +// } + +// Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count, +// MessageBlockMeta* out_message_meta, MessageBlock* out_block) +// { +// Deferral(1); +// Return RESULT_VOID; +// } diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 11123ad..39d22ad 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -14,7 +14,7 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) Deferral(8); ClientConnection* conn = (ClientConnection*)malloc(sizeof(ClientConnection)); - memset(conn, 0, sizeof(*conn)); + zeroStruct(conn); bool success = false; Defer(if(!success) ClientConnection_close(conn)); @@ -22,6 +22,7 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) conn->client_end = args->client_end; conn->session_id = args->session_id; conn->authorized = false; + conn->user_id = -1; conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE); // correct session key will be received from client later Array_u8_memset(&conn->session_key, 0); diff --git a/src/server/responses/GetMessageBlock.c b/src/server/responses/GetMessageBlock.c index d2ab40e..2379308 100644 --- a/src/server/responses/GetMessageBlock.c +++ b/src/server/responses/GetMessageBlock.c @@ -1,7 +1,8 @@ #include "responses.h" -#define LOGGER conn->server->logger -#define LOG_FUNC conn->server->log_func +#define srv conn->server +#define LOGGER srv->logger +#define LOG_FUNC srv->log_func #define LOG_CONTEXT log_ctx declare_RequestHandler(GetMessageBlock) @@ -14,6 +15,7 @@ declare_RequestHandler(GetMessageBlock) try_void(PacketHeader_validateContentSize(req_head, sizeof(req))); try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req)); + (void)res_head; // send response // GetMessageBlockResponse res; // GetMessageBlockResponse_construct(&res, res_head, ); diff --git a/src/server/responses/Login.c b/src/server/responses/Login.c index f2102bd..68df2da 100644 --- a/src/server/responses/Login.c +++ b/src/server/responses/Login.c @@ -1,7 +1,8 @@ #include "responses.h" -#define LOGGER conn->server->logger -#define LOG_FUNC conn->server->log_func +#define srv conn->server +#define LOGGER srv->logger +#define LOG_FUNC srv->log_func #define LOG_CONTEXT log_ctx declare_RequestHandler(Login) @@ -35,15 +36,15 @@ declare_RequestHandler(Login) } // lock users cache - idb_lockTable(conn->server->users.table); + idb_lockTable(srv->users.table); bool unlocked_users_cache_mutex = false; Defer( if(!unlocked_users_cache_mutex) - idb_unlockTable(conn->server->users.table) + idb_unlockTable(srv->users.table) ); // try get id from name cache - u64* id_ptr = HashMap_tryGetPtr(&conn->server->users.username_id_map, username_str); + u64* id_ptr = HashMap_tryGetPtr(&srv->users.name_id_map, username_str); if(id_ptr == NULL){ try_void(sendErrorMessage_f(log_ctx, conn, res_head, LogSeverity_Warn, @@ -55,8 +56,8 @@ declare_RequestHandler(Login) u64 user_id = *id_ptr; // get user by id - try_assert(user_id < conn->server->users.cache_list.len); - UserInfo* u = conn->server->users.cache_list.data + user_id; + try_assert(user_id < srv->users.list.len); + UserInfo* u = srv->users.list.data + user_id; // validate token hash if(memcmp(req.token, u->token, sizeof(req.token)) != 0){ @@ -68,16 +69,17 @@ declare_RequestHandler(Login) } // manually unlock mutex - idb_unlockTable(conn->server->users.table); + idb_unlockTable(srv->users.table); unlocked_users_cache_mutex = true; // authorize conn->authorized = true; + conn->user_id = user_id; logInfo("authorized user '%s'", username_str.data); // send response LoginResponse res; - LoginResponse_construct(&res, res_head, user_id, conn->server->landing_channel_id); + LoginResponse_construct(&res, res_head, user_id, srv->landing_channel_id); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); diff --git a/src/server/responses/Register.c b/src/server/responses/Register.c index 4647fa5..243485e 100644 --- a/src/server/responses/Register.c +++ b/src/server/responses/Register.c @@ -1,7 +1,8 @@ #include "responses.h" -#define LOGGER conn->server->logger -#define LOG_FUNC conn->server->log_func +#define srv conn->server +#define LOGGER srv->logger +#define LOG_FUNC srv->log_func #define LOG_CONTEXT log_ctx declare_RequestHandler(Register) @@ -35,16 +36,16 @@ declare_RequestHandler(Register) } // lock users cache - idb_lockTable(conn->server->users.table); + idb_lockTable(srv->users.table); bool unlocked_users_cache_mutex = false; // unlock mutex on error catch Defer( if(!unlocked_users_cache_mutex) - idb_unlockTable(conn->server->users.table) + idb_unlockTable(srv->users.table) ); // check if name is taken - if(HashMap_tryGetPtr(&conn->server->users.username_id_map, username_str) != NULL){ + if(HashMap_tryGetPtr(&srv->users.name_id_map, username_str) != NULL){ try_void(sendErrorMessage_f(log_ctx, conn, res_head, LogSeverity_Warn, "Username '%s' already exists", @@ -54,7 +55,7 @@ declare_RequestHandler(Register) // initialize new user UserInfo user; - memset(&user, 0, sizeof(UserInfo)); + zeroStruct(&user); user.name_len = username_str.len; memcpy(user.name, username_str.data, user.name_len); @@ -65,14 +66,14 @@ declare_RequestHandler(Register) // save new user to db and cache try(u64 user_id, u, - idb_pushRow(conn->server->users.table, &user, false) + idb_pushRow(srv->users.table, &user, false) ); - try_assert(user_id == conn->server->users.cache_list.len); - List_UserInfo_pushMany(&conn->server->users.cache_list, &user, 1); - try_assert(HashMap_tryPush(&conn->server->users.username_id_map, username_str, &user_id)); + try_assert(user_id == srv->users.list.len); + List_UserInfo_pushMany(&srv->users.list, &user, 1); + try_assert(HashMap_tryPush(&srv->users.name_id_map, username_str, &user_id)); // manually unlock mutex - idb_unlockTable(conn->server->users.table); + idb_unlockTable(srv->users.table); unlocked_users_cache_mutex = true; logInfo("registered user '%s'", username_str.data); diff --git a/src/server/responses/SendMessage.c b/src/server/responses/SendMessage.c index 4b11245..04477e5 100644 --- a/src/server/responses/SendMessage.c +++ b/src/server/responses/SendMessage.c @@ -1,7 +1,8 @@ #include "responses.h" -#define LOGGER conn->server->logger -#define LOG_FUNC conn->server->log_func +#define srv conn->server +#define LOGGER srv->logger +#define LOG_FUNC srv->log_func #define LOG_CONTEXT log_ctx declare_RequestHandler(SendMessage) @@ -14,6 +15,14 @@ declare_RequestHandler(SendMessage) try_void(PacketHeader_validateContentSize(req_head, sizeof(req))); try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req)); + if(!conn->authorized){ + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, + STR("is not authorized") + )); + Return RESULT_VOID; + } + if(req.data_size < MESSAGE_SIZE_MIN || req.data_size > MESSAGE_SIZE_MAX){ try_void(sendErrorMessage_f(log_ctx, conn, res_head, LogSeverity_Warn, @@ -28,16 +37,24 @@ declare_RequestHandler(SendMessage) Array(u8) message_data = Array_u8_alloc(req.data_size); try_void(EncryptedSocketTCP_recv(&conn->sock, message_data, SocketRecvFlag_WholeBuffer)); - for(u16 i = 0; i < message_data.len; i++){ - u8 b = message_data.data[i]; + // save message to channel + Channel* ch = Server_tryGetChannel(conn->server, req.channel_id); + if(ch == NULL){ + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, + STR("invalid channel id") + )); + Return RESULT_VOID; } - + MessageMeta message_meta; + try_void(Channel_saveMessage(ch, message_data, conn->user_id, &message_meta)); // send response - // SendMessageResponse res; - // SendMessageResponse_construct(&res, res_head, ); - // try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); - // try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); + SendMessageResponse res; + SendMessageResponse_construct(&res, res_head, + message_meta.id, message_meta.receiving_time_utc); + try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); + try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); Return RESULT_VOID; } diff --git a/src/server/responses/ServerPublicInfo.c b/src/server/responses/ServerPublicInfo.c index c6bb85e..18eb28e 100644 --- a/src/server/responses/ServerPublicInfo.c +++ b/src/server/responses/ServerPublicInfo.c @@ -1,7 +1,8 @@ #include "responses.h" -#define LOGGER conn->server->logger -#define LOG_FUNC conn->server->log_func +#define srv conn->server +#define LOGGER srv->logger +#define LOG_FUNC srv->log_func #define LOG_CONTEXT log_ctx declare_RequestHandler(ServerPublicInfo) @@ -25,10 +26,10 @@ declare_RequestHandler(ServerPublicInfo) Return RESULT_VOID; } case ServerPublicInfo_Name: - content = str_castTo_Array_u8(conn->server->name); + content = str_castTo_Array_u8(srv->name); break; case ServerPublicInfo_Description: - content = str_castTo_Array_u8(conn->server->description); + content = str_castTo_Array_u8(srv->description); break; } diff --git a/src/server/responses/responses.h b/src/server/responses/responses.h index e400760..e27c2f9 100644 --- a/src/server/responses/responses.h +++ b/src/server/responses/responses.h @@ -17,8 +17,7 @@ Result(void) sendErrorMessage_f( #define declare_RequestHandler(TYPE) \ - Result(void) handleRequest_##TYPE(\ - cstr log_ctx, cstr req_type_name, \ + Result(void) handleRequest_##TYPE(cstr log_ctx, cstr req_type_name, \ ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head) #define case_handleRequest(TYPE) \ @@ -29,5 +28,5 @@ Result(void) sendErrorMessage_f( declare_RequestHandler(ServerPublicInfo); declare_RequestHandler(Login); declare_RequestHandler(Register); - - +declare_RequestHandler(SendMessage); +declare_RequestHandler(GetMessageBlock); diff --git a/src/server/responses/send_error.c b/src/server/responses/send_error.c index ad337d0..0866ba4 100644 --- a/src/server/responses/send_error.c +++ b/src/server/responses/send_error.c @@ -1,7 +1,8 @@ #include "responses.h" -#define LOGGER conn->server->logger -#define LOG_FUNC conn->server->log_func +#define srv conn->server +#define LOGGER srv->logger +#define LOG_FUNC srv->log_func #define LOG_CONTEXT log_ctx Result(void) sendErrorMessage( diff --git a/src/server/responses/template b/src/server/responses/template index d8a5837..d6a9798 100644 --- a/src/server/responses/template +++ b/src/server/responses/template @@ -1,7 +1,8 @@ #include "responses.h" -#define LOGGER conn->server->logger -#define LOG_FUNC conn->server->log_func +#define srv conn->server +#define LOGGER srv->logger +#define LOG_FUNC srv->log_func #define LOG_CONTEXT log_ctx declare_RequestHandler(NAME) diff --git a/src/server/server.c b/src/server/server.c index 7a84377..c74526a 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -2,6 +2,7 @@ #include "tlibc/filesystem.h" #include "tlibc/time.h" #include "tlibc/base64.h" +#include "tlibc/algorithms.h" #include "server/server_internal.h" #include "network/tcp-chat-protocol/v1.h" #include "server/responses/responses.h" @@ -21,12 +22,15 @@ void Server_free(Server* self){ idb_close(self->db); - List_UserInfo_destroy(&self->users.cache_list); - HashMap_destroy(&self->users.username_id_map); + List_UserInfo_destroy(&self->users.list); + HashMap_destroy(&self->users.name_id_map); + + for(u32 i = 0; i < self->channels.list.len; i++){ + Channel_free(self->channels.list.data[i]); + } + List_ChannelPtr_destroy(&self->channels.list); + HashMap_destroy(&self->channels.name_id_map); - List_MessageBlockMeta_destroy(&self->messages.blocks_meta_list); - LList_MessageBlock_destroy(&self->messages.blocks_queue); - free(self->messages.incomplete_block); free(self); } @@ -38,12 +42,12 @@ void Server_free(Server* self){ Result(Server*) Server_create(str config_file_content, cstr config_file_name, void* logger, LogFunction_t log_func) - { +{ Deferral(16); cstr log_ctx = "ServerInit"; Server* self = (Server*)malloc(sizeof(Server)); - memset(self, 0, sizeof(Server)); + zeroStruct(self); bool success = false; Defer(if(!success) Server_free(self)); @@ -51,102 +55,110 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name, self->log_func = log_func; logDebug("parsing config"); - try(TomlTable* config_toml, p, toml_load_str_filename(config_file_content, config_file_name)); - Defer(TomlTable_free(config_toml)); + try(TomlTable* config_top, p, toml_load_str_filename(config_file_content, config_file_name)); + Defer(TomlTable_free(config_top)); + try(TomlTable* config_server, p, TomlTable_get_table(config_top, STR("server"))) // parse name - try(str* v_name, p, TomlTable_get_str(config_toml, STR("name"))); + try(str* v_name, p, TomlTable_get_str(config_server, STR("name"))); self->name = str_copy(*v_name); // parse description - try(str* v_desc, p, TomlTable_get_str(config_toml, STR("description"))); + try(str* v_desc, p, TomlTable_get_str(config_server, STR("description"))); self->description = str_copy(*v_desc); - - // parse landing_channel_id - try(i64 v_landing_channel_id, i, TomlTable_get_integer(config_toml, STR("landing_channel_id"))); - self->landing_channel_id = v_landing_channel_id; // parse local_address - try(str* v_local_address, p, TomlTable_get_str(config_toml, STR("local_address"))); + try(str* v_local_address, p, TomlTable_get_str(config_server, STR("local_address"))); try_assert(v_local_address->isZeroTerminated); try_void(EndpointIPv4_parse(v_local_address->data, &self->local_end)); + + // parse landing_channel_id + try(i64 v_landing_channel_id, i, TomlTable_get_integer(config_server, STR("landing_channel_id"))); + self->landing_channel_id = v_landing_channel_id; + try(TomlTable* config_keys, p, TomlTable_get_table(config_top, STR("keys"))) // parse rsa_private_key - try(str* v_rsa_sk, p, TomlTable_get_str(config_toml, STR("rsa_private_key"))); + try(str* v_rsa_sk, p, TomlTable_get_str(config_keys, STR("rsa_private_key"))); try_assert(v_rsa_sk->isZeroTerminated); try_void(RSA_parsePrivateKey_base64(v_rsa_sk->data, &self->rsa_sk)); // parse rsa_public_key - try(str* v_rsa_pk, p, TomlTable_get_str(config_toml, STR("rsa_public_key"))); + try(str* v_rsa_pk, p, TomlTable_get_str(config_keys, STR("rsa_public_key"))); try_assert(v_rsa_pk->isZeroTerminated); try_void(RSA_parsePublicKey_base64(v_rsa_pk->data, &self->rsa_pk)); + try(TomlTable* config_db, p, TomlTable_get_table(config_top, STR("database"))) // parse db_aes_key - try(str* v_db_aes_key, p, TomlTable_get_str(config_toml, STR("db_aes_key"))); + try(str* v_db_aes_key, p, TomlTable_get_str(config_db, STR("aes_key"))); str db_aes_key_s = *v_db_aes_key; Array(u8) db_aes_key = Array_u8_alloc(base64_decodedSize(db_aes_key_s.data, db_aes_key_s.len)); Defer(free(db_aes_key.data)); base64_decode(db_aes_key_s.data, db_aes_key_s.len, db_aes_key.data); - // parse db_dir and open db - try(str* v_db_dir, p, TomlTable_get_str(config_toml, STR("db_dir"))); + // parse db_dir + try(str* v_db_dir, p, TomlTable_get_str(config_db, STR("dir"))); + + // open DB try(self->db, p, idb_open(*v_db_dir, db_aes_key)); // build users cache - logDebug("loading users..."); + logInfo("loading users..."); try(self->users.table, p, - idb_getOrCreateTable(self->db, STR("users"), sizeof(UserInfo), false) + idb_getOrCreateTable(self->db, str_null, STR("users"), sizeof(UserInfo), false) ); // load whole users table to list try_void( - idb_createListFromTable(self->users.table, (void*)&self->users.cache_list, false) + idb_createListFromTable(self->users.table, (void*)&self->users.list, false) ); // build name-id map try(u64 users_count, u, idb_getRowCount(self->users.table, false)); - HashMap_construct(&self->users.username_id_map, u64, NULL); + HashMap_construct(&self->users.name_id_map, u64, NULL); for(u64 id = 0; id < users_count; id++){ - UserInfo* user = self->users.cache_list.data + id; + UserInfo* user = self->users.list.data + id; str key = str_construct(user->name, user->name_len, true); - if(!HashMap_tryPush(&self->users.username_id_map, key, &id)){ - Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", key.len, key.data); + if(!HashMap_tryPush(&self->users.name_id_map, key, &id)){ + Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", str_expand(key)); } } - logDebug("loaded "FMT_u64" users", users_count); + logInfo("loaded "FMT_u64" users", users_count); + // parse channels + logDebug("loading channels..."); + HashMap_construct(&self->channels.name_id_map, u64, NULL); + self->channels.list = List_ChannelPtr_alloc(32); + try(TomlTable* config_channels, p, TomlTable_get_table(config_top, STR("channels"))); + + HashMapIter channels_iter = HashMapIter_create(config_channels); + while(HashMapIter_moveNext(&channels_iter)){ + HashMapKeyValue kv; + HashMapIter_getCurrent(&channels_iter, &kv); + str name = kv.key; + TomlValue* val = kv.value_ptr; + // skip if not table + if(val->type != TLIBTOML_TABLE) + continue; + + logInfo("loading channel '"FMT_str"'", str_expand(name)) + TomlTable* config_channel = val->table; + try(u64 id, u, TomlTable_get_integer(config_channel, STR("id"))); + try(str* v_ch_desc, p, TomlTable_get_str(config_channel, STR("description"))) + str description = *v_ch_desc; - // build messages cache - logDebug("loading messages..."); - try(self->messages.blocks_table, p, - idb_getOrCreateTable(self->db, STR("message_blocks"), sizeof(MessageBlock), false) - ); - try(self->messages.blocks_meta_table, p, - idb_getOrCreateTable(self->db, STR("message_blocks_meta"), sizeof(MessageBlockMeta), false) - ); + if(!HashMap_tryPush(&self->channels.name_id_map, name, &id)){ + Return RESULT_ERROR_FMT("duplicate channel '"FMT_str"'", str_expand(name)); + } - // load whole message_blocks_meta table to list - try_void( - idb_createListFromTable(self->messages.blocks_meta_table, (void*)&self->messages.blocks_meta_list, false) - ); - - // load N last blocks to the queue - self->messages.incomplete_block = LLNode_MessageBlock_createZero(); - self->messages.blocks_queue = LList_construct(MessageBlock, NULL); - try(u64 message_blocks_count, u, idb_getRowCount(self->messages.blocks_table, false)); - u64 first_id = 0; - if(message_blocks_count > MESSAGE_BLOCKS_CACHE_COUNT) - first_id = message_blocks_count - MESSAGE_BLOCKS_CACHE_COUNT; - for(u64 id = first_id; id < message_blocks_count; id++){ - LLNode(MessageBlock)* node = LLNode_MessageBlock_createZero(); - LList_MessageBlock_insertAfter( - &self->messages.blocks_queue, - self->messages.blocks_queue.last, - node - ); - try_void(idb_getRow(self->messages.blocks_table, id, node->value.data, false)); + logDebug("loading messages..."); + try(Channel* channel, p, Channel_create(id, name, description, self->db, false)); + logDebug("loaded "FMT_u64" messages", channel->messages.count); + List_ChannelPtr_push(&self->channels.list, channel); } - logDebug("loaded "FMT_u64" message blocks", message_blocks_count); + logDebug("loaded "FMT_u32" channels", self->channels.list.len); + // sort channels list by id + for(u32 i = 0; i < self->channels.list.len; i++); + insertionSort_inline(self->channels.list.data, self->channels.list.len, ->id); success = true; Return RESULT_VALUE(p, self); @@ -189,6 +201,15 @@ Result(void) Server_run(Server* server){ Return RESULT_VOID; } +Channel* Server_tryGetChannel(Server* self, u64 id){ + i32 index; + binarySearch_inline(self->channels.list.data, self->channels.list.len, id, ->id, index); + if(index == -1){ + return NULL; + } + Channel* channel = self->channels.list.data[index]; + return channel; +} static void* handleConnection(void* _args){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args; @@ -244,7 +265,10 @@ static Result(void) try_handleConnection(ConnectionHandlerArgs* args, cstr log_c case_handleRequest(ServerPublicInfo); case_handleRequest(Login); case_handleRequest(Register); + // authorized requests + case_handleRequest(SendMessage); + case_handleRequest(GetMessageBlock); } } diff --git a/src/server/server_internal.h b/src/server/server_internal.h index d260ca1..8709182 100644 --- a/src/server/server_internal.h +++ b/src/server/server_internal.h @@ -17,6 +17,40 @@ LList_declare(MessageBlock); #define MESSAGE_BLOCKS_CACHE_COUNT 50 +typedef struct Channel { + u64 id; + str name; + str description; + struct { + u64 count; + Table* blocks_table; + Table* blocks_meta_table; + List(MessageBlockMeta) blocks_meta_list; // index is id + // last MESSAGE_BLOCKS_CACHE_COUNT MessageBlocks, ascending + // new messages are written to the last block + LList(MessageBlock) blocks_queue; + } messages; +} Channel; + +typedef Channel* ChannelPtr; +List_declare(ChannelPtr); + +Result(Channel*) Channel_create(u64 id, str name, str description, + IncrementalDB* db, bool lock_db); + +void Channel_free(Channel* self); + +Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id, + MessageMeta* out_message_meta); + +/// @param out_data buffer of size >= MESSAGE_SIZE_MAX +Result(void) Channel_loadMessage(Channel* self, u64 id, + MessageMeta* out_meta, u8* out_data); + +Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count, + MessageBlockMeta* out_meta, MessageBlock* out_block); + + typedef struct Server { /* from constructor */ void* logger; @@ -34,19 +68,17 @@ typedef struct Server { IncrementalDB* db; struct { Table* table; - List(UserInfo) cache_list; // index is id - HashMap(u64) username_id_map; + List(UserInfo) list; // index is id + HashMap(u64) name_id_map; } users; - /* messages */ struct { - Table* blocks_table; - Table* blocks_meta_table; - List(MessageBlockMeta) blocks_meta_list; // index is id - LList(MessageBlock) blocks_queue; // last N MessageBlocks, ascending - LLNode(MessageBlock)* incomplete_block; // new messages are written here until block is full - } messages; + List(ChannelPtr) list; + HashMap(u64) name_id_map; + } channels; } Server; +Channel* Server_tryGetChannel(Server* self, u64 id); + typedef struct ClientConnection { Server* server; @@ -55,6 +87,7 @@ typedef struct ClientConnection { Array(u8) session_key; EncryptedSocketTCP sock; bool authorized; + u64 user_id; // -1 for unauthorized } ClientConnection; typedef struct ConnectionHandlerArgs { diff --git a/tcp-chat-server.toml.default b/tcp-chat-server.toml.default index d9a6ea1..0349bb3 100644 --- a/tcp-chat-server.toml.default +++ b/tcp-chat-server.toml.default @@ -1,13 +1,21 @@ +[server] name = "Test Server" description = """\ Lorem ipsum labuba aboba.\n\ Qqqqq...\ """ - -landing_channel_id = 0 local_address = '127.0.0.1:9988' -db_dir = 'server-db' -db_aes_key = '' +landing_channel_id = 0 +# do not create channels with the same id +[channels.general] +id = 0 +description = "a text channel" + +[database] +dir = 'server-db' +aes_key = '' + +[keys] rsa_private_key = '' rsa_public_key = ''