From 49793e2929de1410f4667ab4f27b996b79ad6f79 Mon Sep 17 00:00:00 2001 From: Timerix Date: Mon, 15 Dec 2025 23:26:32 +0500 Subject: [PATCH] implemented CommonQueries --- .vscode/c_cpp_properties.json | 1 + dependencies/tlibc | 2 +- dependencies/tsqlite | 2 +- include/tcp-chat/client.h | 4 +- include/tcp-chat/common_constants.h | 2 +- project.config | 3 +- src/cli/ClientCLI/ClientCLI.c | 34 ++--- src/cli/ClientCLI/ClientCLI.h | 12 +- src/client/client.c | 4 +- src/client/client_internal.h | 4 +- src/network/socket.c | 10 +- src/network/socket.h | 2 +- src/network/tcp-chat-protocol/v1.c | 80 +++++++++-- src/network/tcp-chat-protocol/v1.h | 73 +++++++--- src/server/Channel.c | 186 ------------------------- src/server/ClientConnection.c | 20 ++- src/server/db/Channel.c | 116 +++++++++++++++ src/server/db/CommonQueries.c | 82 +++++++++++ src/server/db/User.c | 3 + src/server/db/db.h | 39 ++++++ src/server/db/db_internal.h | 33 +++++ src/server/responses/GetMessageBlock.c | 34 +++-- src/server/responses/Login.c | 34 ++--- src/server/responses/Register.c | 46 ++---- src/server/responses/SendMessage.c | 21 +-- src/server/responses/responses.h | 2 - src/server/server.c | 98 +++---------- src/server/server_internal.h | 81 +++-------- tcp-chat-server.toml.default | 7 +- 29 files changed, 540 insertions(+), 495 deletions(-) delete mode 100644 src/server/Channel.c create mode 100644 src/server/db/Channel.c create mode 100644 src/server/db/CommonQueries.c create mode 100644 src/server/db/User.c create mode 100644 src/server/db/db.h create mode 100644 src/server/db/db_internal.h diff --git a/.vscode/c_cpp_properties.json b/.vscode/c_cpp_properties.json index ff6a5f5..c817fd8 100755 --- a/.vscode/c_cpp_properties.json +++ b/.vscode/c_cpp_properties.json @@ -9,6 +9,7 @@ "dependencies/BearSSL/inc", "dependencies/tlibc/include", "dependencies/tlibtoml/include", + "dependencies/tsqlite/include", "${default}" ], "cStandard": "c99" diff --git a/dependencies/tlibc b/dependencies/tlibc index 08d45fa..bdbe959 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit 08d45faa83bbb1b56c3933b117e1dce6b9f21e71 +Subproject commit bdbe959e23fab5697cfdfee4ef92d2da1c755b3b diff --git a/dependencies/tsqlite b/dependencies/tsqlite index 979bbfe..58840ce 160000 --- a/dependencies/tsqlite +++ b/dependencies/tsqlite @@ -1 +1 @@ -Subproject commit 979bbfe2bab9b0b04b1e75008ef8d0ba8bb96757 +Subproject commit 58840cecd058476af12b9844e83080dd5cc22a38 diff --git a/include/tcp-chat/client.h b/include/tcp-chat/client.h index 14a04a5..2ce351f 100644 --- a/include/tcp-chat/client.h +++ b/include/tcp-chat/client.h @@ -28,7 +28,7 @@ Result(void) Client_getServerName(Client* self, str* out_name); Result(void) Client_getServerDescription(Client* self, str* out_desc); /// Create new account on connected server -Result(void) Client_register(Client* self, u64* out_user_id); +Result(void) Client_register(Client* self, i64* out_user_id); /// Authorize on connected server -Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id); +Result(void) Client_login(Client* self, i64* out_user_id, i64* out_landing_channel_id); diff --git a/include/tcp-chat/common_constants.h b/include/tcp-chat/common_constants.h index d6d353c..badd428 100644 --- a/include/tcp-chat/common_constants.h +++ b/include/tcp-chat/common_constants.h @@ -18,4 +18,4 @@ #define CHANNEL_DESC_SIZE_MAX 1023 #define MESSAGE_SIZE_MIN 1 #define MESSAGE_SIZE_MAX 4000 -#define MESSAGE_BLOCK_SIZE (64*1024) \ No newline at end of file +#define MESSAGE_BLOCK_COUNT_MAX 50 \ No newline at end of file diff --git a/project.config b/project.config index f2fbd75..4580ff7 100644 --- a/project.config +++ b/project.config @@ -38,7 +38,8 @@ STATIC_LIB_FILE="$PROJECT.a" INCLUDE="-Isrc -Iinclude -I$DEPENDENCIES_DIR/BearSSL/inc -I$DEPENDENCIES_DIR/tlibc/include - -I$DEPENDENCIES_DIR/tlibtoml/include" + -I$DEPENDENCIES_DIR/tlibtoml/include + -I$DEPENDENCIES_DIR/tsqlite/include" # OS-specific options case "$OS" in diff --git a/src/cli/ClientCLI/ClientCLI.c b/src/cli/ClientCLI/ClientCLI.c index b65c4f4..78a2d75 100644 --- a/src/cli/ClientCLI/ClientCLI.c +++ b/src/cli/ClientCLI/ClientCLI.c @@ -230,13 +230,13 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){ idb_lockTable(self->servers.table); Defer(idb_unlockTable(self->servers.table)); - u32 servers_count = self->servers.list.len; - if(servers_count == 0){ + u32 server_count = self->servers.list.len; + if(server_count == 0){ printf("No servers found in cache\n"); Return RESULT_VOID; } - for(u32 id = 0; id < servers_count; id++){ + for(u32 id = 0; id < server_count; id++){ ServerInfo* server = self->servers.list.data + id; printf("[%02u] "FMT_str" "FMT_str"\n", id, server->address_len, server->address, server->name_len, server->name); @@ -256,7 +256,7 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){ if(sscanf(buf, FMT_u32, &id) != 1){ printf("ERROR: not a number\n"); } - else if(id >= servers_count){ + else if(id >= server_count){ printf("ERROR: not a server number: %u\n", id); } else break; @@ -316,9 +316,9 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){ str username = Client_getUserName(self->client); Array(u8) user_data_key = Client_getUserDataKey(self->client); - str user_db_dir = str_from_cstr(strcat_malloc("client-db", path_seps, username.data)); - Defer(free(user_db_dir.data)); - try(self->db, p, idb_open(user_db_dir, user_data_key)); + str user_db_path = str_from_cstr(strcat_malloc("client-db", path_seps, username.data)); + Defer(free(user_db_path.data)); + try(self->db, p, idb_open(user_db_path, user_data_key)); // Lock DB until this function returns. idb_lockDB(self->db); @@ -339,11 +339,11 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){ ); // build address-id map - try(u64 servers_count, u, + try(i64 server_count, u, idb_getRowCount(self->servers.table, false) ); - HashMap_construct(&self->servers.addr_id_map, u64, NULL); - for(u64 id = 0; id < servers_count; id++){ + HashMap_construct(&self->servers.addr_id_map, i64, NULL); + for(i64 id = 0; id < server_count; id++){ ServerInfo* server = self->servers.list.data + id; str key = str_construct(server->address, server->address_len, true); if(!HashMap_tryPush(&self->servers.addr_id_map, key, &id)){ @@ -391,11 +391,11 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, // try find server id in cache ServerInfo* cached_row_ptr = NULL; - u64* id_ptr = NULL; + i64* id_ptr = NULL; id_ptr = HashMap_tryGetPtr(&self->servers.addr_id_map, addr); if(id_ptr){ // update existing server - u64 id = *id_ptr; + i64 id = *id_ptr; try_void(idb_updateRow(self->servers.table, id, &server, false)); try_assert(id < self->servers.list.len); cached_row_ptr = self->servers.list.data + id; @@ -403,7 +403,7 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, } else { // push new server - try(u64 id, u, idb_pushRow(self->servers.table, &server, false)); + try(i64 id, u, idb_pushRow(self->servers.table, &server, false)); try_assert(id == self->servers.list.len); List_ServerInfo_pushMany(&self->servers.list, &server, 1); cached_row_ptr = self->servers.list.data + id; @@ -416,10 +416,10 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, static Result(void) ClientCLI_register(ClientCLI* self){ Deferral(8); - u64 user_id = 0; + i64 user_id = 0; try_void(Client_register(self->client, &user_id)); printf("Registered successfully\n"); - printf("user_id: "FMT_u64"\n", user_id); + printf("user_id: "FMT_i64"\n", user_id); // TODO: use user_id somewhere Return RESULT_VOID; @@ -428,10 +428,10 @@ static Result(void) ClientCLI_register(ClientCLI* self){ static Result(void) ClientCLI_login(ClientCLI* self){ Deferral(8); - u64 user_id = 0, landing_channel_id = 0; + i64 user_id = 0, landing_channel_id = 0; try_void(Client_login(self->client, &user_id, &landing_channel_id)); printf("Authorized successfully\n"); - printf("user_id: "FMT_u64", landing_channel_id: "FMT_u64"\n", user_id, landing_channel_id); + printf("user_id: "FMT_i64", landing_channel_id: "FMT_i64"\n", user_id, landing_channel_id); // TODO: use user_id, landing_channel_id somewhere Return RESULT_VOID; diff --git a/src/cli/ClientCLI/ClientCLI.h b/src/cli/ClientCLI/ClientCLI.h index 93bdeb6..6f2d4e7 100644 --- a/src/cli/ClientCLI/ClientCLI.h +++ b/src/cli/ClientCLI/ClientCLI.h @@ -2,20 +2,12 @@ #include #include "tlibc/collections/HashMap.h" #include "tlibc/collections/List.h" +#include "tsqlite.h" #include "tcp-chat/client.h" -#include "db/idb.h" -#include "db/tables.h" - -List_declare(ServerInfo); typedef struct ClientCLI { Client* client; - IncrementalDB* db; - struct { - Table* table; - List(ServerInfo) list; // index is id - HashMap(u64) addr_id_map; // key is server address - } servers; + tsqlite_connection* db; } ClientCLI; void ClientCLI_construct(ClientCLI* self); diff --git a/src/client/client.c b/src/client/client.c index 2709246..632ee7c 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -72,7 +72,7 @@ Result(void) Client_getServerDescription(Client* self, str* out_desc){ Return RESULT_VOID; } -Result(void) Client_register(Client* self, u64* out_user_id){ +Result(void) Client_register(Client* self, i64* out_user_id){ Deferral(1); try_assert(self != NULL); try_assert(self->conn != NULL && "didn't connect to a server yet"); @@ -90,7 +90,7 @@ Result(void) Client_register(Client* self, u64* out_user_id){ Return RESULT_VOID; } -Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id){ +Result(void) Client_login(Client* self, i64* out_user_id, i64* out_landing_channel_id){ Deferral(1); try_assert(self != NULL); try_assert(self->conn != NULL && "didn't connect to a server yet"); diff --git a/src/client/client_internal.h b/src/client/client_internal.h index d9ebdb6..5a75dd6 100644 --- a/src/client/client_internal.h +++ b/src/client/client_internal.h @@ -21,10 +21,10 @@ typedef struct ServerConnection { Array(u8) token; Array(u8) session_key; EncryptedSocketTCP sock; - u64 session_id; + i64 session_id; str server_name; str server_description; - u64 user_id; + i64 user_id; } ServerConnection; /// @param server_addr_cstr diff --git a/src/network/socket.c b/src/network/socket.c index 0cc39cf..a7610b0 100755 --- a/src/network/socket.c +++ b/src/network/socket.c @@ -132,12 +132,12 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU } Result(void) socket_TCP_enableAliveChecks(Socket s, - sec_t first_check_time, u32 checks_count, sec_t checks_interval) + sec_t first_check_time, u32 check_count, sec_t checks_interval) { #if KN_USE_WINSOCK BOOL opt_SO_KEEPALIVE = 1; // enable keepalives DWORD opt_TCP_KEEPIDLE = first_check_time; - DWORD opt_TCP_KEEPCNT = checks_count; + DWORD opt_TCP_KEEPCNT = check_count; DWORD opt_TCP_KEEPINTVL = checks_interval; try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE); try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE); @@ -145,12 +145,12 @@ Result(void) socket_TCP_enableAliveChecks(Socket s, try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL); // timeout for connect() - DWORD opt_TCP_MAXRT = checks_count * checks_interval; + DWORD opt_TCP_MAXRT = check_count * checks_interval; try_setsockopt(s, IPPROTO_TCP, TCP_MAXRT); #else int opt_SO_KEEPALIVE = 1; // enable keepalives int opt_TCP_KEEPIDLE = first_check_time; - int opt_TCP_KEEPCNT = checks_count; + int opt_TCP_KEEPCNT = check_count; int opt_TCP_KEEPINTVL = checks_interval; try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE); try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE); @@ -158,7 +158,7 @@ Result(void) socket_TCP_enableAliveChecks(Socket s, try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL); // read more in the article - int opt_TCP_USER_TIMEOUT = checks_count * checks_interval * 1000; + int opt_TCP_USER_TIMEOUT = check_count * checks_interval * 1000; try_setsockopt(s, IPPROTO_TCP, TCP_USER_TIMEOUT); #endif return RESULT_VOID; diff --git a/src/network/socket.h b/src/network/socket.h index 6d6cdfc..bbb6eec 100755 --- a/src/network/socket.h +++ b/src/network/socket.h @@ -38,7 +38,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags /// Read more: https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/ /// RU translaton: https://habr.com/ru/articles/700470/ Result(void) socket_TCP_enableAliveChecks(Socket s, - sec_t first_check_time, u32 checks_count, sec_t checks_interval); + sec_t first_check_time, u32 check_count, sec_t checks_interval); #define socket_TCP_enableAliveChecks_default(socket) \ socket_TCP_enableAliveChecks(socket, 1, 4, 5) diff --git a/src/network/tcp-chat-protocol/v1.c b/src/network/tcp-chat-protocol/v1.c index 2388e06..bba1c49 100644 --- a/src/network/tcp-chat-protocol/v1.c +++ b/src/network/tcp-chat-protocol/v1.c @@ -63,7 +63,7 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he } void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header, - u64 session_id) + i64 session_id) { _PacketHeader_construct(ServerHandshake); zeroStruct(ptr); @@ -106,7 +106,7 @@ Result(void) LoginRequest_tryConstruct(LoginRequest *ptr, PacketHeader* header, } void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header, - u64 user_id, u64 landing_channel_id) + i64 user_id, i64 landing_channel_id) { _PacketHeader_construct(LoginResponse); zeroStruct(ptr); @@ -135,15 +135,73 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest *ptr, PacketHeader* he } void RegisterResponse_construct(RegisterResponse *ptr, PacketHeader* header, - u64 user_id) + i64 user_id) { _PacketHeader_construct(RegisterResponse); zeroStruct(ptr); ptr->user_id = user_id; } +Result(u32) MessageBlock_writeMessage( + MessageMeta* msg, Array(u8) msg_content, + MessageBlockMeta* block_meta, Array(u8)* block_free_part) +{ + Deferral(1); + try_assert(msg->data_size >= MESSAGE_SIZE_MIN && msg->data_size <= MESSAGE_SIZE_MAX); + try_assert(msg->data_size <= msg_content.len); + + u32 offset_increment = sizeof(MessageMeta) + msg->data_size; + if(block_free_part->len < offset_increment){ + Return RESULT_VALUE(u, 0); + } + + memcpy(block_free_part->data, msg, sizeof(MessageMeta)); + block_free_part->data += sizeof(MessageMeta); + block_free_part->len -= sizeof(MessageMeta); + + memcpy(block_free_part->data, msg_content.data, msg->data_size); + block_free_part->data += msg->data_size; + block_free_part->len -= msg->data_size; + + if(block_meta->message_count == 0) + block_meta->first_message_id = msg->id; + block_meta->message_count++; + block_meta->data_size += offset_increment; + + Return RESULT_VALUE(u, offset_increment); +} + +Result(u32) MessageBlock_readMessage( + Array(u8)* block_unread_part, + MessageMeta* msg, Array(u8) msg_content) +{ + Deferral(1); + try_assert(block_unread_part->len >= sizeof(MessageMeta) + MESSAGE_SIZE_MIN); + try_assert(msg_content.len >= MESSAGE_SIZE_MIN && msg_content.len <= MESSAGE_SIZE_MAX); + + memcpy(msg, block_unread_part->data, sizeof(MessageMeta)); + block_unread_part->data += sizeof(MessageMeta); + block_unread_part->len -= sizeof(MessageMeta); + + if(msg->magic.n != MESSAGE_MAGIC.n){ + Return RESULT_VALUE(u, 0); + } + try_assert(block_unread_part->len >= msg->data_size); + try_assert(msg->data_size >= MESSAGE_SIZE_MIN && msg->data_size <= MESSAGE_SIZE_MAX); + try_assert(msg->id > 0); + try_assert(msg->sender_id > 0); + try_assert(msg->timestamp.d.year > 2024); + + memcpy(msg_content.data, block_unread_part->data, msg->data_size); + block_unread_part->data += msg->data_size; + block_unread_part->len -= msg->data_size; + + u32 offset_increment = sizeof(MessageMeta) + msg->data_size; + Return RESULT_VALUE(u, offset_increment); +} + void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header, - u64 channel_id, u16 data_size) + i64 channel_id, u16 data_size) { _PacketHeader_construct(SendMessageRequest); zeroStruct(ptr); @@ -152,30 +210,28 @@ void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header, } void SendMessageResponse_construct(SendMessageResponse *ptr, PacketHeader *header, - u64 message_id, DateTime receiving_time_utc) + i64 message_id, DateTime timestamp) { _PacketHeader_construct(SendMessageResponse); zeroStruct(ptr); ptr->message_id = message_id; - ptr->receiving_time_utc = receiving_time_utc; + ptr->timestamp = timestamp; } void GetMessageBlockRequest_construct(GetMessageBlockRequest *ptr, PacketHeader *header, - u64 channel_id, u64 first_message_id, u32 messages_count) + i64 channel_id, i64 first_message_id, u32 message_count) { _PacketHeader_construct(GetMessageBlockRequest); zeroStruct(ptr); ptr->channel_id = channel_id; ptr->first_message_id = first_message_id; - ptr->messages_count = messages_count; + ptr->message_count = message_count; } void GetMessageBlockResponse_construct(GetMessageBlockResponse *ptr, PacketHeader *header, - u64 first_message_id, u32 messages_count, u32 data_size) + MessageBlockMeta* block_meta) { _PacketHeader_construct(GetMessageBlockResponse); zeroStruct(ptr); - ptr->first_message_id = first_message_id; - ptr->messages_count = messages_count; - ptr->data_size = data_size; + memcpy(&ptr->block_meta, block_meta, sizeof(MessageBlockMeta)); } diff --git a/src/network/tcp-chat-protocol/v1.h b/src/network/tcp-chat-protocol/v1.h index a5db0c9..bdd0fad 100644 --- a/src/network/tcp-chat-protocol/v1.h +++ b/src/network/tcp-chat-protocol/v1.h @@ -60,11 +60,11 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he typedef struct ServerHandshake { - u64 session_id; + i64 session_id; } ALIGN_PACKET_STRUCT ServerHandshake; void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header, - u64 session_id); + i64 session_id); typedef enum ServerPublicInfo { @@ -99,12 +99,12 @@ Result(void) LoginRequest_tryConstruct(LoginRequest* ptr, PacketHeader* header, typedef struct LoginResponse { - u64 user_id; - u64 landing_channel_id; + i64 user_id; + i64 landing_channel_id; } ALIGN_PACKET_STRUCT LoginResponse; void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header, - u64 user_id, u64 landing_channel_id); + i64 user_id, i64 landing_channel_id); typedef struct RegisterRequest { @@ -117,49 +117,80 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest* ptr, PacketHeader* he typedef struct RegisterResponse { - u64 user_id; + i64 user_id; } ALIGN_PACKET_STRUCT RegisterResponse; void RegisterResponse_construct(RegisterResponse* ptr, PacketHeader* header, - u64 user_id); + i64 user_id); + + +typedef struct MessageMeta { + Magic32 magic; + u16 data_size; + i64 id; + i64 sender_id; + DateTime timestamp; /* UTC */ +} ALIGN_PACKET_STRUCT MessageMeta; + +#define MESSAGE_MAGIC ((Magic32){ .bytes = { 'M', 's', 'g', 'S' } }) + +typedef struct MessageBlockMeta { + i64 first_message_id; + u32 message_count; + u32 data_size; +} ALIGN_PACKET_STRUCT MessageBlockMeta; + +/// @brief write msg_meta and msg_meta->data_size bytes from msg_content to buffer +/// @param block_meta set to {0} if block is empty yet +/// @param block_free_part .data and .len are adjusted to point to free part +/// @return amount of bytes written to block, may be 0 if msg_meta and msg_content don't fit +Result(u32) MessageBlock_writeMessage( + MessageMeta* msg_meta, Array(u8) msg_content, + MessageBlockMeta* block_meta, Array(u8)* block_free_part); + +/// @brief read message meta and content from buffer +/// @param block_unread_part .data and .len are adjusted to point to unread part +/// @param msg_content .len must be >= MESSAGE_SIZE_MAX +/// @return amount of bytes read from block, may be 0 if it doesn't start with MESSAGE_MAGIC +Result(u32) MessageBlock_readMessage( + Array(u8)* block_unread_part, + MessageMeta* msg_meta, Array(u8) msg_content); typedef struct SendMessageRequest { - u64 channel_id; + i64 channel_id; u16 data_size; /* stream of size data_size */ } ALIGN_PACKET_STRUCT SendMessageRequest; void SendMessageRequest_construct(SendMessageRequest* ptr, PacketHeader* header, - u64 channel_id, u16 data_size); + i64 channel_id, u16 data_size); typedef struct SendMessageResponse { - u64 message_id; - DateTime receiving_time_utc; + i64 message_id; + DateTime timestamp; /* UTC */ } ALIGN_PACKET_STRUCT SendMessageResponse; void SendMessageResponse_construct(SendMessageResponse* ptr, PacketHeader* header, - u64 message_id, DateTime receiving_time_utc); + i64 message_id, DateTime timestamp); typedef struct GetMessageBlockRequest { - u64 channel_id; - u64 first_message_id; - u32 messages_count; + i64 channel_id; + i64 first_message_id; + u32 message_count; } ALIGN_PACKET_STRUCT GetMessageBlockRequest; void GetMessageBlockRequest_construct(GetMessageBlockRequest* ptr, PacketHeader* header, - u64 channel_id, u64 first_message_id, u32 messages_count); + i64 channel_id, i64 first_message_id, u32 message_count); typedef struct GetMessageBlockResponse { - u64 first_message_id; - u32 messages_count; - u32 data_size; - /* stream of size data_size : ((sequence MessageMeta), (sequence binary-data)) */ + MessageBlockMeta block_meta; + /* stream of size data_size : sequence (MessageMeta, byte[MessageMeta.data_size]) */ } ALIGN_PACKET_STRUCT GetMessageBlockResponse; void GetMessageBlockResponse_construct(GetMessageBlockResponse* ptr, PacketHeader* header, - u64 first_message_id, u32 messages_count, u32 data_size); + MessageBlockMeta* block_meta); diff --git a/src/server/Channel.c b/src/server/Channel.c deleted file mode 100644 index f128590..0000000 --- a/src/server/Channel.c +++ /dev/null @@ -1,186 +0,0 @@ -#include "server_internal.h" -#include "tlibc/string/StringBuilder.h" -#include "tlibc/filesystem.h" - -void Channel_free(Channel* self){ - if(!self) - return; - str_destroy(self->name); - str_destroy(self->description); - List_MessageBlockMeta_destroy(&self->messages.blocks_meta_list); - LList_MessageBlock_destroy(&self->messages.blocks_queue); - free(self); -} - -Result(Channel*) Channel_create(u64 chan_id, str name, str description, - IncrementalDB* db, bool lock_db) -{ - Deferral(8); - - Channel* self = (Channel*)malloc(sizeof(Channel)); - zeroStruct(self); - bool success = false; - Defer(if(!success) Channel_free(self)); - - self->id = chan_id; - try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX); - self->name = str_copy(name); - try_assert(description.len <= CHANNEL_DESC_SIZE_MAX); - self->description = str_copy(description); - - if(lock_db){ - idb_lockDB(db); - Defer(idb_unlockDB(db)); - } - - StringBuilder sb = StringBuilder_alloc(CHANNEL_NAME_SIZE_MAX + 32 + 1); - Defer(StringBuilder_destroy(&sb)); - StringBuilder_append_str(&sb, STR("channels")); - StringBuilder_append_char(&sb, path_sep); - StringBuilder_append_str(&sb, name); - - str subdir = str_copy(StringBuilder_getStr(&sb)); - Defer(str_destroy(subdir)); - str message_blocks_str = STR("message_blocks"); - str message_blocks_meta_str = STR("message_blocks_meta"); - - StringBuilder_removeFromEnd(&sb, -1); - StringBuilder_append_str(&sb, name); - StringBuilder_append_char(&sb, '_'); - StringBuilder_append_str(&sb, message_blocks_str); - try(self->messages.blocks_table, p, - idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false) - ); - idb_lockTable(self->messages.blocks_table); - Defer(idb_unlockTable(self->messages.blocks_table)); - - - StringBuilder_removeFromEnd(&sb, message_blocks_str.len); - StringBuilder_append_str(&sb, message_blocks_meta_str); - try(self->messages.blocks_meta_table, p, - idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlockMeta), false) - ); - idb_lockTable(self->messages.blocks_meta_table); - Defer(idb_unlockTable(self->messages.blocks_meta_table)); - - // load whole message_blocks_meta table to list - try_void( - idb_createListFromTable(self->messages.blocks_meta_table, (void*)&self->messages.blocks_meta_list, false) - ); - - // load N last blocks to the queue - self->messages.blocks_queue = LList_construct(MessageBlock, NULL); - u64 message_blocks_count = self->messages.blocks_meta_list.len; - u64 first_block_id = 0; - if(message_blocks_count > MESSAGE_BLOCKS_CACHE_COUNT) - first_block_id = message_blocks_count - MESSAGE_BLOCKS_CACHE_COUNT; - for(u64 id = first_block_id; id < message_blocks_count; id++){ - LLNode(MessageBlock)* node = LLNode_MessageBlock_createZero(); - LList_MessageBlock_insertAfter( - &self->messages.blocks_queue, - self->messages.blocks_queue.last, - node - ); - try_void(idb_getRow(self->messages.blocks_table, id, node->value.data, false)); - } - - if(self->messages.blocks_meta_list.len > 0){ - MessageBlockMeta last_block_meta = self->messages.blocks_meta_list.data[self->messages.blocks_meta_list.len - 1]; - self->messages.count = last_block_meta.first_message_id + last_block_meta.messages_count - 1; - } - else { - self->messages.count = 0; - } - - success = true; - Return RESULT_VALUE(p, self); -} - -void Channel_unloadExcessBlocks(Channel* self){ - while(self->messages.blocks_queue.count > MESSAGE_BLOCKS_CACHE_COUNT){ - LLNode(MessageBlock)* node = self->messages.blocks_queue.first; - LList_MessageBlock_detatch(&self->messages.blocks_queue, node); - free(node); - } -} - -Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id, - MessageMeta* out_message_meta, bool lock_tables) -{ - Deferral(4); - - if(lock_tables){ - idb_lockTable(self->messages.blocks_table); - idb_lockTable(self->messages.blocks_meta_table); - Defer( - idb_unlockTable(self->messages.blocks_table); - idb_unlockTable(self->messages.blocks_meta_table); - ); - } - - // create new block if message won't fit in the last existing - MessageBlockMeta* incomplete_block_meta = self->messages.blocks_meta_list.data + self->messages.blocks_meta_list.len; - u64 new_message_id = incomplete_block_meta->first_message_id + incomplete_block_meta->messages_count; - u32 message_size_in_block = sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8); - if(incomplete_block_meta->data_size + message_size_in_block > MESSAGE_BLOCK_SIZE){ - // create new MessageBlockMeta - incomplete_block_meta = List_MessageBlockMeta_expand(&self->messages.blocks_meta_list, 1); - incomplete_block_meta->first_message_id = new_message_id; - incomplete_block_meta->messages_count = 0; - incomplete_block_meta->data_size = 0; - // create new MessageBlock - LList_MessageBlock_insertAfter( - &self->messages.blocks_queue, - self->messages.blocks_queue.last, - LLNode_MessageBlock_createZero()); - // unload old blocks from cache - Channel_unloadExcessBlocks(self); - } - - // create message meta - out_message_meta->magic = MESSAGE_MAGIC; - out_message_meta->data_size = message_data.len; - out_message_meta->id = new_message_id; - out_message_meta->sender_id = sender_id; - DateTime_getUTC(&out_message_meta->receiving_time_utc); - - // copy message data to message block - MessageBlock* incomplete_block = &self->messages.blocks_queue.last->value; - u8* data_ptr = incomplete_block->data + incomplete_block_meta->data_size; - memcpy(data_ptr, out_message_meta, sizeof(MessageMeta)); - data_ptr += sizeof(MessageMeta); - memcpy(data_ptr, message_data.data, message_data.len); - incomplete_block_meta->data_size += sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8); - incomplete_block_meta->messages_count++; - - // save to DB - try_void(idb_pushRow(self->messages.blocks_meta_table, incomplete_block_meta, false)); - try_void(idb_pushRow(self->messages.blocks_table, incomplete_block, false)); - - Return RESULT_VOID; -} - -Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count, - MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables) -{ - Deferral(4); - - if(lock_tables){ - idb_lockTable(self->messages.blocks_table); - idb_lockTable(self->messages.blocks_meta_table); - Defer( - idb_unlockTable(self->messages.blocks_table); - idb_unlockTable(self->messages.blocks_meta_table); - ); - } - - // TODO: Maybe it's better to request message block id directly? Client doesn't know how much bytes `count` messages will take, this can lead to severe lags on slow internet - - // TODO: binary search in list of blocks meta - // TODO: return if out_block == NULL - // TODO: check if block is in N_LAST_BLOCKS - // TODO: load block - // TODO: insert block in queue and keep it sorted - - Return RESULT_VOID; -} diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 39d22ad..4acdd1a 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -1,11 +1,14 @@ #include "server/server_internal.h" -#include "network/tcp-chat-protocol/v1.h" void ClientConnection_close(ClientConnection* conn){ if(!conn) return; + tsqlite_connection_close(conn->db); EncryptedSocketTCP_destroy(&conn->sock); Array_u8_destroy(&conn->session_key); + Array_u8_destroy(&conn->message_block); + Array_u8_destroy(&conn->message_content); + CommonQueries_free(conn->queries.common); free(conn); } @@ -21,10 +24,17 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) conn->server = args->server; conn->client_end = args->client_end; conn->session_id = args->session_id; - conn->authorized = false; - conn->user_id = -1; - conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE); + + // buffers + conn->message_block = Array_u8_alloc(MESSAGE_BLOCK_COUNT_MAX * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX)); + conn->message_content = Array_u8_alloc(MESSAGE_SIZE_MAX); + + // database + try(conn->db, p, tsqlite_connection_open(args->server->db_path)); + try(conn->queries.common, p, CommonQueries_compile(conn->db)); + // correct session key will be received from client later + conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE); Array_u8_memset(&conn->session_key, 0); EncryptedSocketTCP_construct(&conn->sock, args->accepted_socket_tcp, NETWORK_BUFFER_SIZE, conn->session_key); try_void(socket_TCP_enableAliveChecks_default(args->accepted_socket_tcp)); @@ -57,4 +67,4 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) success = true; Return RESULT_VALUE(p, conn); -} \ No newline at end of file +} diff --git a/src/server/db/Channel.c b/src/server/db/Channel.c new file mode 100644 index 0000000..931c4fa --- /dev/null +++ b/src/server/db/Channel.c @@ -0,0 +1,116 @@ +#include "db_internal.h" + +Result(bool) Channel_exists(CommonQueries* q, i64 id){ + Deferral(1); + tsqlite_statement* st = q->channels.exists; + try_void(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_i64(st, "id", id)); + try(bool has_result, i, tsqlite_statement_step(st)); + Return RESULT_VALUE(i, has_result); +} + +Result(void) Channel_createOrUpdate(CommonQueries* q, + i64 id, str name, str description) +{ + Deferral(4); + try_assert(id > 0); + try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX); + try_assert(description.len <= CHANNEL_DESC_SIZE_MAX); + + // create channels table + try_void(tsqlite_statement_reset(q->channels.create_table)); + try_void(tsqlite_statement_bind_i64(q->channels.create_table, "name_max", CHANNEL_NAME_SIZE_MAX)); + try_void(tsqlite_statement_bind_i64(q->channels.create_table, "desc_max", CHANNEL_DESC_SIZE_MAX)); + try_void(tsqlite_statement_step(q->channels.create_table)); + + // create messages table + try_void(tsqlite_statement_reset(q->messages.create_table)); + try_void(tsqlite_statement_step(q->messages.create_table)); + + try(bool channel_exists, i, Channel_exists(q, id)); + if(channel_exists){ + // update existing channel + try_void(tsqlite_statement_reset(q->channels.update)); + try_void(tsqlite_statement_bind_i64(q->channels.update, "id", id)); + try_void(tsqlite_statement_bind_str(q->channels.update, "name", str_copy(name), free)); + try_void(tsqlite_statement_bind_str(q->channels.update, "description", str_copy(description), free)); + try_void(tsqlite_statement_step(q->channels.update)); + } + else { + // insert new channel + try_void(tsqlite_statement_reset(q->channels.insert)); + try_void(tsqlite_statement_bind_i64(q->channels.insert, "id", id)); + try_void(tsqlite_statement_bind_str(q->channels.insert, "name", str_copy(name), free)); + try_void(tsqlite_statement_bind_str(q->channels.insert, "description", str_copy(description), free)); + try_void(tsqlite_statement_step(q->channels.insert)); + } + + Return RESULT_VOID; +} + +Result(void) Channel_saveMessage(CommonQueries* q, + i64 channel_id, i64 sender_id, Array(u8) content, + DateTime* out_timestamp) +{ + Deferral(1); + try_assert(content.len >= MESSAGE_SIZE_MIN && content.len <= MESSAGE_SIZE_MAX); + + tsqlite_statement* st = q->messages.insert; + try_void(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_i64(st, "channel_id", channel_id)); + try_void(tsqlite_statement_bind_i64(st, "sender_id", sender_id)); + try_void(tsqlite_statement_bind_blob(st, "content", Array_u8_copy(content), free)); + + try(bool has_result, i, tsqlite_statement_step(st)); + try_assert(has_result); + + try(i64 message_id, i, tsqlite_statement_getResult_i64(st)); + str timestamp_str; + try_void(tsqlite_statement_getResult_str(st, ×tamp_str)); + try_void(DateTime_parse(timestamp_str.data, out_timestamp)); + + Return RESULT_VALUE(i, message_id); +} + +Result(void) Channel_loadMessageBlock(CommonQueries* q, + i64 channel_id, i64 first_message_id, u32 count, + MessageBlockMeta* block_meta, Array(u8) block_data) +{ + Deferral(1); + try_assert(channel_id > 0); + try_assert(block_data.len >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX)); + if(count == 0){ + Return RESULT_VOID; + } + + tsqlite_statement* st = q->messages.get_block; + try_void(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_i64(st, "channel_id", channel_id)); + try_void(tsqlite_statement_bind_i64(st, "first_message_id", first_message_id)); + try_void(tsqlite_statement_bind_i64(st, "count", count)); + + zeroStruct(block_meta); + MessageMeta msg_meta = {0}; + Array(u8) msg_content; + str tmp_str = str_null; + while(true){ + try(bool has_result, i, tsqlite_statement_step(st)); + if(!has_result) + break; + + // id + try(msg_meta.id, i, tsqlite_statement_getResult_i64(st)); + // sender_id + try(msg_meta.sender_id, i, tsqlite_statement_getResult_i64(st)); + // content + try_void(tsqlite_statement_getResult_blob(st, &msg_content)); + // timestamp + try_void(tsqlite_statement_getResult_str(st, &tmp_str)); + try_void(DateTime_parse(tmp_str.data, &msg_meta.timestamp)); + + try(u32 write_n, u, MessageBlock_writeMessage(&msg_meta, msg_content, block_meta, &block_data)); + try_assert(write_n > 0); + } + + Return RESULT_VOID; +} diff --git a/src/server/db/CommonQueries.c b/src/server/db/CommonQueries.c new file mode 100644 index 0000000..b06a8c6 --- /dev/null +++ b/src/server/db/CommonQueries.c @@ -0,0 +1,82 @@ +#include "db_internal.h" + +void CommonQueries_free(CommonQueries* q){ + if(!q) + return; + tsqlite_statement_free(q->channels.create_table); + tsqlite_statement_free(q->channels.insert); + tsqlite_statement_free(q->channels.exists); + tsqlite_statement_free(q->channels.update); + tsqlite_statement_free(q->messages.create_table); + tsqlite_statement_free(q->messages.insert); + tsqlite_statement_free(q->messages.get_block); + free(q); +} + +Result(void) CommonQueries_compile(tsqlite_connection* db){ + Deferral(4); + + CommonQueries* q = (CommonQueries*)malloc(sizeof(*q)); + zeroStruct(q); + bool success = false; + Defer(if(!success) CommonQueries_free(q)); + + /////////////////////////////////////////////////////////////////////////// + // CHANNELS // + /////////////////////////////////////////////////////////////////////////// + try(q->channels.create_table, p, tsqlite_statement_compile(db, STR( + "CREATE TABLE IF NOT EXISTS channels (\n" + " id BIGINT PRIMARY KEY,\n" + " name VARCHAR($name_max) NOT NULL,\n" + " description VARCHAR($desc_max) NOT NULL\n" + ");" + ))); + + try(q->channels.insert, p, tsqlite_statement_compile(db, STR( + "INSERT INTO\n" + "channels (id, name, description)\n" + "VALUES ($id, $name, $description);" + ))); + + try(q->channels.exists, p, tsqlite_statement_compile(db, STR( + "SELECT 1 FROM channels WHERE id = $id;" + ))); + + try(q->channels.update, p, tsqlite_statement_compile(db, STR( + "UPDATE channels\n" + "SET name = $name, description = $description\n" + "WHERE id = $id;" + ))); + + /////////////////////////////////////////////////////////////////////////// + // MESSAGES // + /////////////////////////////////////////////////////////////////////////// + try(q->messages.create_table, p, tsqlite_statement_compile(db, STR( + "CREATE TABLE IF NOT EXISTS messages (\n" + " id BIGINT PRIMARY KEY,\n" + " channel_id BIGINT NOT NULL REFERENCES channels(id)\n" + " sender_id BIGINT NOT NULL REFERENCES users(id),\n" + " content BLOB NOT NULL,\n" + " timestamp DATETIME NOT NULL DEFAULT (\n" + " strftime('"MESSAGE_TIMESTAMP_FMT_SQL"', 'now', 'utc', 'subsecond')\n" + " )\n" + ");" + ))); + + try(q->messages.insert, p, tsqlite_statement_compile(db, STR( + "INSERT INTO\n" + "messages (channel_id, sender_id, content)\n" + "VALUES ($channel_id, $sender_id, $content)\n" + "RETURNING id, timestamp;" + ))); + + try(q->messages.get_block, p, tsqlite_statement_compile(db, STR( + "SELECT id, sender_id, content, timestamp FROM messages\n" + "WHERE id >= $first_message_id\n" + "AND channel_id = $channel_id\n" + "LIMIT $count;" + ))); + + success = true; + Return RESULT_VALUE(p, q); +} diff --git a/src/server/db/User.c b/src/server/db/User.c new file mode 100644 index 0000000..51490d8 --- /dev/null +++ b/src/server/db/User.c @@ -0,0 +1,3 @@ +#include "db.h" + + diff --git a/src/server/db/db.h b/src/server/db/db.h new file mode 100644 index 0000000..2a04f04 --- /dev/null +++ b/src/server/db/db.h @@ -0,0 +1,39 @@ +#pragma once +#include "tsqlite.h" +#include "network/tcp-chat-protocol/v1.h" + +// typedef struct ChannelInfo { +// i64 id; +// str name; +// str description; +// } ChannelInfo; + +typedef struct CommonQueries CommonQueries; + +Result(CommonQueries*) CommonQueries_compile(tsqlite_connection* db); +void CommonQueries_free(CommonQueries* self); + +Result(bool) Channel_exists(CommonQueries* q, i64 id); + +Result(void) Channel_createOrUpdate(CommonQueries* q, + i64 id, str name, str description); + +/// @return new message id +Result(i64) Channel_saveMessage(CommonQueries* q, + i64 channel_id, i64 sender_id, Array(u8) content, + DateTime* out_timestamp_utc); + +/// @brief try to find `count` messages starting from `first_message_id` +/// @param out_meta writes here information about found messages, .count can be 0 if no messages found +/// @param out_block .len must be >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX) +Result(void) Channel_loadMessageBlock(CommonQueries* q, + i64 channel_id, i64 first_message_id, u32 count, + MessageBlockMeta* out_block_meta, Array(u8) out_block_data); + +/// @return existing user id or 0 +Result(i64) User_getIdForUsername(CommonQueries* q, str username); + +/// @return new user id +Result(i64) User_register(CommonQueries* q, str username, Array(u8) token); + +Result(bool) User_tryAuthorize(CommonQueries* q, u64 id, Array(u8) token); diff --git a/src/server/db/db_internal.h b/src/server/db/db_internal.h new file mode 100644 index 0000000..cc33a9f --- /dev/null +++ b/src/server/db/db_internal.h @@ -0,0 +1,33 @@ +#pragma once +#include "db.h" + +typedef struct CommonQueries { + struct { + /* (name_max, desc_max) -> void */ + tsqlite_statement* create_table; + /* (id, name, description) -> void */ + tsqlite_statement* insert; + /* (id) -> bool */ + tsqlite_statement* exists; + /* (id, name, description) -> void */ + tsqlite_statement* update; + } channels; + struct { + /* () -> void */ + tsqlite_statement* create_table; + /* (channel_id, sender_id, content) -> (id, timestamp) */ + tsqlite_statement* insert; + /* (channel_id, first_message_id, count) -> [(id, sender_id, content, timestamp)] */ + tsqlite_statement* get_block; + } messages; + struct { + tsqlite_statement* create_table; + tsqlite_statement* registration_begin; + tsqlite_statement* registration_end; + tsqlite_statement* get_credentials; + tsqlite_statement* get_public_info; + tsqlite_statement* update; + } users; +} CommonQueries; + +#define MESSAGE_TIMESTAMP_FMT_SQL "%Y.%m.%d-%H:%M:%f" diff --git a/src/server/responses/GetMessageBlock.c b/src/server/responses/GetMessageBlock.c index 07ac415..762145d 100644 --- a/src/server/responses/GetMessageBlock.c +++ b/src/server/responses/GetMessageBlock.c @@ -20,28 +20,38 @@ declare_RequestHandler(GetMessageBlock) LogSeverity_Warn, STR("not authorized") )); Return RESULT_VOID; } + + // validate message_count + if(req.message_count < 1 || req.message_count > MESSAGE_BLOCK_COUNT_MAX){ + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, STR("invalid message count in request") )); + Return RESULT_VOID; + } - // get message block from channel - Channel* ch = Server_tryGetChannel(conn->server, req.channel_id); - if(ch == NULL){ + // validate channel id + try(bool channel_exists, i, Channel_exists(conn->queries.common, req.channel_id)); + if(!channel_exists){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("invalid channel id") )); Return RESULT_VOID; } - MessageBlockMeta meta; - Array(u8) block_data; - try_void(Channel_loadMessageBlock(ch, req.first_message_id, req.messages_count, - &meta, &block_data, true)); - Defer(Array_u8_destroy(&block_data)); + + // reset block meta + zeroStruct(&conn->message_block_meta); + // get message block from channel + try_void(Channel_loadMessageBlock(conn->queries.common, + req.channel_id, req.first_message_id, req.message_count, + &conn->message_block_meta, conn->message_block)); // send response GetMessageBlockResponse res; - GetMessageBlockResponse_construct(&res, res_head, - meta.first_message_id, meta.messages_count, meta.data_size); + GetMessageBlockResponse_construct(&res, res_head, &conn->message_block_meta); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); - if(block_data.len != 0){ - try_void(EncryptedSocketTCP_send(&conn->sock, block_data)); + if(conn->message_block_meta.data_size != 0){ + try_void(EncryptedSocketTCP_send(&conn->sock, + Array_u8_sliceTo(conn->message_block, conn->message_block_meta.data_size)) + ); } Return RESULT_VOID; diff --git a/src/server/responses/Login.c b/src/server/responses/Login.c index 4d62d4a..41c2f7b 100644 --- a/src/server/responses/Login.c +++ b/src/server/responses/Login.c @@ -22,8 +22,8 @@ declare_RequestHandler(Login) } // validate username - str username_str = str_null; - str name_error_str = validateUsername_cstr(req.username, &username_str); + str username = str_null; + str name_error_str = validateUsername_cstr(req.username, &username); if(name_error_str.data){ Defer(str_destroy(name_error_str)); try_void(sendErrorMessage(log_ctx, conn, res_head, @@ -31,42 +31,28 @@ declare_RequestHandler(Login) Return RESULT_VOID; } - // lock users cache - idb_lockTable(srv->users.table); - bool unlocked_users_cache_mutex = false; - Defer( - if(!unlocked_users_cache_mutex) - idb_unlockTable(srv->users.table) - ); - - // try get id from name cache - u64* id_ptr = HashMap_tryGetPtr(&srv->users.name_id_map, username_str); - if(id_ptr == NULL){ + // get user by id + try(u64 user_id, i, User_getIdForUsername(conn->queries.common, username)); + if(user_id == 0){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("Username is not registered") )); Return RESULT_VOID; } - u64 user_id = *id_ptr; - - // get user by id - try_assert(user_id < srv->users.list.len); - UserInfo* u = srv->users.list.data + user_id; + // TODO: get user token + Array(u8) token = Array_u8_construct(req.token, sizeof(req.token)); + try(bool authorized, i, User_tryAuthorize(conn->queries.common, user_id, token)); // validate token hash - if(memcmp(req.token, u->token, sizeof(req.token)) != 0){ + if(!authorized){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("wrong password") )); Return RESULT_VOID; } - // manually unlock mutex - idb_unlockTable(srv->users.table); - unlocked_users_cache_mutex = true; - // authorize conn->authorized = true; conn->user_id = user_id; - logInfo("authorized user '%s'", username_str.data); + logInfo("authorized user '%s'", username.data); // send response LoginResponse res; diff --git a/src/server/responses/Register.c b/src/server/responses/Register.c index c0eb0ca..22d1462 100644 --- a/src/server/responses/Register.c +++ b/src/server/responses/Register.c @@ -22,8 +22,8 @@ declare_RequestHandler(Register) } // validate username - str username_str = str_null; - str name_error_str = validateUsername_cstr(req.username, &username_str); + str username = str_null; + str name_error_str = validateUsername_cstr(req.username, &username); if(name_error_str.data){ Defer(str_destroy(name_error_str)); try_void(sendErrorMessage(log_ctx, conn, res_head, @@ -31,46 +31,18 @@ declare_RequestHandler(Register) Return RESULT_VOID; } - // lock users cache - idb_lockTable(srv->users.table); - bool unlocked_users_cache_mutex = false; - // unlock mutex on error catch - Defer( - if(!unlocked_users_cache_mutex) - idb_unlockTable(srv->users.table) - ); - // check if name is taken - if(HashMap_tryGetPtr(&srv->users.name_id_map, username_str) != NULL){ + try(u64 user_id, i, User_getIdForUsername(conn->queries.common, username)); + if(user_id != 0){ try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, STR("Username already exists") )); + LogSeverity_Warn, STR("Username is already taken") )); Return RESULT_VOID; } - - // initialize new user - UserInfo user; - zeroStruct(&user); - - user.name_len = username_str.len; - memcpy(user.name, username_str.data, user.name_len); - - memcpy(user.token, req.token, sizeof(req.token)); - DateTime_getUTC(&user.registration_time_utc); - - // save new user to db and cache - try(u64 user_id, u, - idb_pushRow(srv->users.table, &user, false) - ); - try_assert(user_id == srv->users.list.len); - List_UserInfo_pushMany(&srv->users.list, &user, 1); - try_assert(HashMap_tryPush(&srv->users.name_id_map, username_str, &user_id)); - - // manually unlock mutex - idb_unlockTable(srv->users.table); - unlocked_users_cache_mutex = true; - - logInfo("registered user '%s'", username_str.data); + // register new user + Array(u8) token = Array_u8_construct(req.token, sizeof(req.token)); + try(user_id, i, User_register(conn->queries.common, username, token)); + logInfo("registered user '"FMT_str"'", str_expand(username)); // send response RegisterResponse res; diff --git a/src/server/responses/SendMessage.c b/src/server/responses/SendMessage.c index e742b90..4656845 100644 --- a/src/server/responses/SendMessage.c +++ b/src/server/responses/SendMessage.c @@ -21,6 +21,7 @@ declare_RequestHandler(SendMessage) Return RESULT_VOID; } + // validate content size if(req.data_size < MESSAGE_SIZE_MIN || req.data_size > MESSAGE_SIZE_MAX){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("invalid message size") )); @@ -29,24 +30,26 @@ declare_RequestHandler(SendMessage) } // receive message data - Array(u8) message_data = Array_u8_alloc(req.data_size); - try_void(EncryptedSocketTCP_recv(&conn->sock, message_data, SocketRecvFlag_WholeBuffer)); + try_void(EncryptedSocketTCP_recv(&conn->sock, conn->message_content, SocketRecvFlag_WholeBuffer)); - // save message to channel - Channel* ch = Server_tryGetChannel(conn->server, req.channel_id); - if(ch == NULL){ + // validate channel id + try(bool channel_exists, i, Channel_exists(conn->queries.common, req.channel_id)); + if(!channel_exists){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("invalid channel id") )); Return RESULT_VOID; } - MessageMeta message_meta; - try_void(Channel_saveMessage(ch, message_data, conn->user_id, - &message_meta, true)); + + // save message to channel + DateTime timestamp; + try(i64 message_id, i, Channel_saveMessage(conn->queries.common, + req.channel_id, conn->user_id, conn->message_content, + ×tamp)); // send response SendMessageResponse res; SendMessageResponse_construct(&res, res_head, - message_meta.id, message_meta.receiving_time_utc); + message_id, timestamp); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); diff --git a/src/server/responses/responses.h b/src/server/responses/responses.h index e27c2f9..246755c 100644 --- a/src/server/responses/responses.h +++ b/src/server/responses/responses.h @@ -1,8 +1,6 @@ #pragma once -#include "network/tcp-chat-protocol/v1.h" #include "server/server_internal.h" - Result(void) sendErrorMessage( cstr log_ctx, ClientConnection* conn, PacketHeader* res_head, LogSeverity log_severity, str msg); diff --git a/src/server/server.c b/src/server/server.c index 77ee057..99793d7 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -4,7 +4,6 @@ #include "tlibc/base64.h" #include "tlibc/algorithms.h" #include "server/server_internal.h" -#include "network/tcp-chat-protocol/v1.h" #include "server/responses/responses.h" #include "tlibtoml.h" @@ -20,17 +19,8 @@ void Server_free(Server* self){ RSA_destroyPrivateKey(&self->rsa_sk); RSA_destroyPublicKey(&self->rsa_pk); - idb_close(self->db); - - List_UserInfo_destroy(&self->users.list); - HashMap_destroy(&self->users.name_id_map); - - for(u32 i = 0; i < self->channels.list.len; i++){ - Channel_free(self->channels.list.data[i]); - } - List_ChannelPtr_destroy(&self->channels.list); - HashMap_destroy(&self->channels.name_id_map); - + free(self->db_path); + tsqlite_connection_close(self->db); free(self); } @@ -58,78 +48,49 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name, try(TomlTable* config_top, p, toml_load_str_filename(config_file_content, config_file_name)); Defer(TomlTable_free(config_top)); + // [server] try(TomlTable* config_server, p, TomlTable_get_table(config_top, STR("server"))) - // parse name + // name try(str* v_name, p, TomlTable_get_str(config_server, STR("name"))); self->name = str_copy(*v_name); - // parse description + // description try(str* v_desc, p, TomlTable_get_str(config_server, STR("description"))); self->description = str_copy(*v_desc); - // parse local_address + // local_address try(str* v_local_address, p, TomlTable_get_str(config_server, STR("local_address"))); try_assert(v_local_address->isZeroTerminated); try_void(EndpointIPv4_parse(v_local_address->data, &self->local_end)); - // parse landing_channel_id + // landing_channel_id try(i64 v_landing_channel_id, i, TomlTable_get_integer(config_server, STR("landing_channel_id"))); self->landing_channel_id = v_landing_channel_id; + // [keys] try(TomlTable* config_keys, p, TomlTable_get_table(config_top, STR("keys"))) - // parse rsa_private_key + // rsa_private_key try(str* v_rsa_sk, p, TomlTable_get_str(config_keys, STR("rsa_private_key"))); try_assert(v_rsa_sk->isZeroTerminated); try_void(RSA_parsePrivateKey_base64(v_rsa_sk->data, &self->rsa_sk)); - // parse rsa_public_key + // rsa_public_key try(str* v_rsa_pk, p, TomlTable_get_str(config_keys, STR("rsa_public_key"))); try_assert(v_rsa_pk->isZeroTerminated); try_void(RSA_parsePublicKey_base64(v_rsa_pk->data, &self->rsa_pk)); + // [db] try(TomlTable* config_db, p, TomlTable_get_table(config_top, STR("database"))) - // parse db_aes_key - try(str* v_db_aes_key, p, TomlTable_get_str(config_db, STR("aes_key"))); - str db_aes_key_s = *v_db_aes_key; - Array(u8) db_aes_key = Array_u8_alloc(base64_decodedSize(db_aes_key_s.data, db_aes_key_s.len)); - Defer(free(db_aes_key.data)); - base64_decode(db_aes_key_s.data, db_aes_key_s.len, db_aes_key.data); - - // parse db_dir - try(str* v_db_dir, p, TomlTable_get_str(config_db, STR("dir"))); + // path + try(str* v_db_path, p, TomlTable_get_str(config_db, STR("path"))); + self->db_path = str_copy(*v_db_path).data; // open DB - try(self->db, p, idb_open(*v_db_dir, db_aes_key)); + try(self->db, p, tsqlite_connection_open(self->db_path)); - // build users cache - logInfo("loading users..."); - try(self->users.table, p, - idb_getOrCreateTable(self->db, str_null, STR("users"), sizeof(UserInfo), false) - ); - - // load whole users table to list - try_void( - idb_createListFromTable(self->users.table, (void*)&self->users.list, false) - ); - - // build name-id map - try(u64 users_count, u, idb_getRowCount(self->users.table, false)); - HashMap_construct(&self->users.name_id_map, u64, NULL); - for(u64 id = 0; id < users_count; id++){ - UserInfo* user = self->users.list.data + id; - str key = str_construct(user->name, user->name_len, true); - if(!HashMap_tryPush(&self->users.name_id_map, key, &id)){ - Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", str_expand(key)); - } - } - logInfo("loaded "FMT_u64" users", users_count); - - // parse channels + // [channels] logDebug("loading channels..."); - HashMap_construct(&self->channels.name_id_map, u64, NULL); - self->channels.list = List_ChannelPtr_alloc(32); try(TomlTable* config_channels, p, TomlTable_get_table(config_top, STR("channels"))); - HashMapIter channels_iter = HashMapIter_create(config_channels); while(HashMapIter_moveNext(&channels_iter)){ HashMapKeyValue kv; @@ -142,23 +103,12 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name, logInfo("loading channel '"FMT_str"'", str_expand(name)) TomlTable* config_channel = val->table; - try(u64 id, u, TomlTable_get_integer(config_channel, STR("id"))); + try(i64 id, u, TomlTable_get_integer(config_channel, STR("id"))); try(str* v_ch_desc, p, TomlTable_get_str(config_channel, STR("description"))) str description = *v_ch_desc; - if(!HashMap_tryPush(&self->channels.name_id_map, name, &id)){ - Return RESULT_ERROR_FMT("duplicate channel '"FMT_str"'", str_expand(name)); - } - - logDebug("loading messages..."); - try(Channel* channel, p, Channel_create(id, name, description, self->db, false)); - logDebug("loaded "FMT_u64" messages", channel->messages.count); - List_ChannelPtr_push(&self->channels.list, channel); + try_void(Channel_createOrUpdate(self->server_queries, id, name, description)); } - logDebug("loaded "FMT_u32" channels", self->channels.list.len); - // sort channels list by id - for(u32 i = 0; i < self->channels.list.len; i++); - insertionSort_inline(self->channels.list.data, self->channels.list.len, ->id); success = true; Return RESULT_VALUE(p, self); @@ -184,7 +134,7 @@ Result(void) Server_run(Server* server){ Defer(free(local_end_str.data)); logInfo("server is listening at %s", local_end_str.data); - u64 session_id = 1; + i64 session_id = 1; while(true){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)malloc(sizeof(ConnectionHandlerArgs)); args->server = server; @@ -201,16 +151,6 @@ Result(void) Server_run(Server* server){ Return RESULT_VOID; } -Channel* Server_tryGetChannel(Server* self, u64 id){ - i32 index; - binarySearch_inline(self->channels.list.data, self->channels.list.len, id, ->id, index); - if(index == -1){ - return NULL; - } - Channel* channel = self->channels.list.data[index]; - return channel; -} - static void* handleConnection(void* _args){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args; Server* server = args->server; diff --git a/src/server/server_internal.h b/src/server/server_internal.h index df5ef80..ee0d161 100644 --- a/src/server/server_internal.h +++ b/src/server/server_internal.h @@ -1,56 +1,14 @@ #pragma once -#include "tlibc/collections/HashMap.h" -#include "tlibc/collections/List.h" -#include "tlibc/collections/LList.h" #include "tcp-chat/tcp-chat.h" #include "tcp-chat/server.h" #include "cryptography/AES.h" #include "cryptography/RSA.h" #include "network/encrypted_sockets.h" -#include "db/idb.h" -#include "db/tables.h" +#include "network/tcp-chat-protocol/v1.h" +#include "server/db/db.h" typedef struct ClientConnection ClientConnection; -List_declare(UserInfo); -List_declare(MessageBlockMeta); -LList_declare(MessageBlock); - -#define MESSAGE_BLOCKS_CACHE_COUNT 50 - -typedef struct Channel { - u64 id; - str name; - str description; - struct { - u64 count; - Table* blocks_table; - Table* blocks_meta_table; - List(MessageBlockMeta) blocks_meta_list; // index is id - // last MESSAGE_BLOCKS_CACHE_COUNT MessageBlocks, ascending - // new messages are written to the last block - LList(MessageBlock) blocks_queue; - } messages; -} Channel; - -typedef Channel* ChannelPtr; -List_declare(ChannelPtr); - -Result(Channel*) Channel_create(u64 id, str name, str description, - IncrementalDB* db, bool lock_db); - -void Channel_free(Channel* self); - -Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id, - MessageMeta* out_message_meta, bool lock_tables); - -/// @brief try to find `count` messages starting from `fisrt_message_id` -/// @param out_meta information about found messages, .count can be 0 if no messages found -/// @param out_block allocates buffer on heap and copies them there, .len can be 0 if no messages found -Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count, - MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables); - - typedef struct Server { /* from constructor */ void* logger; @@ -59,42 +17,43 @@ typedef struct Server { /* from config */ str name; str description; - u64 landing_channel_id; + i64 landing_channel_id; EndpointIPv4 local_end; br_rsa_private_key rsa_sk; br_rsa_public_key rsa_pk; - /* database and cache */ - IncrementalDB* db; - struct { - Table* table; - List(UserInfo) list; // index is id - HashMap(u64) name_id_map; - } users; - struct { - List(ChannelPtr) list; - HashMap(u64) name_id_map; - } channels; + /* database and cache*/ + char* db_path; + tsqlite_connection* db; } Server; -Channel* Server_tryGetChannel(Server* self, u64 id); - typedef struct ClientConnection { Server* server; - u64 session_id; + i64 session_id; EndpointIPv4 client_end; Array(u8) session_key; EncryptedSocketTCP sock; bool authorized; - u64 user_id; // -1 for unauthorized + i64 user_id; // 0 for unauthorized + + /* buffers */ + MessageBlockMeta message_block_meta; + Array(u8) message_block; + Array(u8) message_content; + + /* database */ + tsqlite_connection* db; + struct { + CommonQueries* common; + } queries; } ClientConnection; typedef struct ConnectionHandlerArgs { Server* server; Socket accepted_socket_tcp; EndpointIPv4 client_end; - u64 session_id; + i64 session_id; } ConnectionHandlerArgs; Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args); diff --git a/tcp-chat-server.toml.default b/tcp-chat-server.toml.default index 0349bb3..52f31fd 100644 --- a/tcp-chat-server.toml.default +++ b/tcp-chat-server.toml.default @@ -5,16 +5,15 @@ description = """\ Qqqqq...\ """ local_address = '127.0.0.1:9988' -landing_channel_id = 0 +landing_channel_id = 1 # do not create channels with the same id [channels.general] -id = 0 +id = 1 description = "a text channel" [database] -dir = 'server-db' -aes_key = '' +path = 'server-db/server.sqlite' [keys] rsa_private_key = ''