diff --git a/include/tcp-chat/tcp-chat.h b/include/tcp-chat/tcp-chat.h new file mode 100644 index 0000000..4f231e3 --- /dev/null +++ b/include/tcp-chat/tcp-chat.h @@ -0,0 +1,14 @@ +#pragma once +#include "tlibc/errors.h" + +/// requires tlibc and tlibtoml init +Result(void) TcpChat_init(); +void TcpChat_deinit(); + +ErrorCodePage_declare(WINSOCK2); +ErrorCodePage_declare(TcpChat); + +typedef enum TcpChatError { + TcpChatError_Unknown, + TcpChatError_RejectIncoming, +} TcpChatError; diff --git a/src/cli/main.c b/src/cli/main.c index 3229afe..7df6eac 100644 --- a/src/cli/main.c +++ b/src/cli/main.c @@ -1,7 +1,8 @@ #include "tlibc/tlibc.h" -#include "network/network.h" +#include "tlibtoml.h" +#include "tcp-chat/tcp-chat.h" #include "cryptography/RSA.h" -#include "modes/modes.h" +#include "cli/modes/modes.h" #define _DEFAULT_CONFIG_PATH_CLIENT "tcp-chat-client.config" #define _DEFAULT_CONFIG_PATH_SERVER "tcp-chat-server.toml" @@ -11,11 +12,12 @@ int main(const int argc, cstr const* argv){ Deferral(32); - try_fatal_void(tlibc_init()); Defer(tlibc_deinit()); - try_fatal_void(network_init()); - Defer(network_deinit()); + try_fatal_void(tlibtoml_init()); + Defer(tlibtoml_deinit()); + try_fatal_void(TcpChat_init()); + Defer(TcpChat_deinit()); if(br_prng_seeder_system(NULL) == NULL){ printfe("Can't get system random seeder. Bearssl is compiled incorrectly."); diff --git a/src/network/internal.h b/src/network/internal.h index 5a669a3..e79d725 100644 --- a/src/network/internal.h +++ b/src/network/internal.h @@ -1,7 +1,6 @@ #pragma once -#include "tlibc/errors.h" +#include "tcp-chat/tcp-chat.h" #include "endpoint.h" -#include "network.h" #if !defined(KN_USE_WINSOCK) #if defined(_WIN64) || defined(_WIN32) diff --git a/src/network/network.h b/src/network/network.h deleted file mode 100755 index 1eb807a..0000000 --- a/src/network/network.h +++ /dev/null @@ -1,7 +0,0 @@ -#pragma once -#include "tlibc/errors.h" - -ErrorCodePage_declare(WINSOCK2); - -Result(void) network_init(); -void network_deinit(); diff --git a/src/network/tcp-chat-protocol/v1.c b/src/network/tcp-chat-protocol/v1.c index 2758abf..be9fe4d 100644 --- a/src/network/tcp-chat-protocol/v1.c +++ b/src/network/tcp-chat-protocol/v1.c @@ -161,10 +161,11 @@ void SendMessageResponse_construct(SendMessageResponse *ptr, PacketHeader *heade } void GetMessageBlockRequest_construct(GetMessageBlockRequest *ptr, PacketHeader *header, - u64 first_message_id, u32 messages_count) + u64 channel_id, u64 first_message_id, u32 messages_count) { _PacketHeader_construct(GetMessageBlockRequest); zeroStruct(ptr); + ptr->channel_id = channel_id; ptr->first_message_id = first_message_id; ptr->messages_count = messages_count; } diff --git a/src/network/tcp-chat-protocol/v1.h b/src/network/tcp-chat-protocol/v1.h index d45f27a..a5db0c9 100644 --- a/src/network/tcp-chat-protocol/v1.h +++ b/src/network/tcp-chat-protocol/v1.h @@ -144,12 +144,13 @@ void SendMessageResponse_construct(SendMessageResponse* ptr, PacketHeader* heade typedef struct GetMessageBlockRequest { + u64 channel_id; u64 first_message_id; u32 messages_count; } ALIGN_PACKET_STRUCT GetMessageBlockRequest; void GetMessageBlockRequest_construct(GetMessageBlockRequest* ptr, PacketHeader* header, - u64 first_message_id, u32 messages_count); + u64 channel_id, u64 first_message_id, u32 messages_count); typedef struct GetMessageBlockResponse { diff --git a/src/server/Channel.c b/src/server/Channel.c index 125988a..f128590 100644 --- a/src/server/Channel.c +++ b/src/server/Channel.c @@ -58,7 +58,7 @@ Result(Channel*) Channel_create(u64 chan_id, str name, str description, StringBuilder_removeFromEnd(&sb, message_blocks_str.len); StringBuilder_append_str(&sb, message_blocks_meta_str); try(self->messages.blocks_meta_table, p, - idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false) + idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlockMeta), false) ); idb_lockTable(self->messages.blocks_meta_table); Defer(idb_unlockTable(self->messages.blocks_meta_table)); @@ -96,11 +96,27 @@ Result(Channel*) Channel_create(u64 chan_id, str name, str description, Return RESULT_VALUE(p, self); } +void Channel_unloadExcessBlocks(Channel* self){ + while(self->messages.blocks_queue.count > MESSAGE_BLOCKS_CACHE_COUNT){ + LLNode(MessageBlock)* node = self->messages.blocks_queue.first; + LList_MessageBlock_detatch(&self->messages.blocks_queue, node); + free(node); + } +} Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id, - MessageMeta* out_message_meta) + MessageMeta* out_message_meta, bool lock_tables) { - Deferral(1); + Deferral(4); + + if(lock_tables){ + idb_lockTable(self->messages.blocks_table); + idb_lockTable(self->messages.blocks_meta_table); + Defer( + idb_unlockTable(self->messages.blocks_table); + idb_unlockTable(self->messages.blocks_meta_table); + ); + } // create new block if message won't fit in the last existing MessageBlockMeta* incomplete_block_meta = self->messages.blocks_meta_list.data + self->messages.blocks_meta_list.len; @@ -117,35 +133,54 @@ Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 send &self->messages.blocks_queue, self->messages.blocks_queue.last, LLNode_MessageBlock_createZero()); - //TODO: save to DB + // unload old blocks from cache + Channel_unloadExcessBlocks(self); } - // copy message to message block + // create message meta out_message_meta->magic = MESSAGE_MAGIC; out_message_meta->data_size = message_data.len; out_message_meta->id = new_message_id; out_message_meta->sender_id = sender_id; DateTime_getUTC(&out_message_meta->receiving_time_utc); + // copy message data to message block MessageBlock* incomplete_block = &self->messages.blocks_queue.last->value; u8* data_ptr = incomplete_block->data + incomplete_block_meta->data_size; memcpy(data_ptr, out_message_meta, sizeof(MessageMeta)); data_ptr += sizeof(MessageMeta); memcpy(data_ptr, message_data.data, message_data.len); + incomplete_block_meta->data_size += sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8); + incomplete_block_meta->messages_count++; + + // save to DB + try_void(idb_pushRow(self->messages.blocks_meta_table, incomplete_block_meta, false)); + try_void(idb_pushRow(self->messages.blocks_table, incomplete_block, false)); Return RESULT_VOID; } -// Result(void) Channel_loadMessage(Channel* self, u64 id, -// MessageMeta* out_message_meta, u8* out_data) -// { -// Deferral(1); -// Return RESULT_VOID; -// } +Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count, + MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables) +{ + Deferral(4); -// Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count, -// MessageBlockMeta* out_message_meta, MessageBlock* out_block) -// { -// Deferral(1); -// Return RESULT_VOID; -// } + if(lock_tables){ + idb_lockTable(self->messages.blocks_table); + idb_lockTable(self->messages.blocks_meta_table); + Defer( + idb_unlockTable(self->messages.blocks_table); + idb_unlockTable(self->messages.blocks_meta_table); + ); + } + + // TODO: Maybe it's better to request message block id directly? Client doesn't know how much bytes `count` messages will take, this can lead to severe lags on slow internet + + // TODO: binary search in list of blocks meta + // TODO: return if out_block == NULL + // TODO: check if block is in N_LAST_BLOCKS + // TODO: load block + // TODO: insert block in queue and keep it sorted + + Return RESULT_VOID; +} diff --git a/src/server/responses/GetMessageBlock.c b/src/server/responses/GetMessageBlock.c index 2379308..07ac415 100644 --- a/src/server/responses/GetMessageBlock.c +++ b/src/server/responses/GetMessageBlock.c @@ -9,18 +9,40 @@ declare_RequestHandler(GetMessageBlock) { Deferral(4); logInfo("requested %s", req_type_name); - + // receive request GetMessageBlockRequest req; try_void(PacketHeader_validateContentSize(req_head, sizeof(req))); try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req)); - (void)res_head; + if(!conn->authorized){ + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, STR("not authorized") )); + Return RESULT_VOID; + } + + // get message block from channel + Channel* ch = Server_tryGetChannel(conn->server, req.channel_id); + if(ch == NULL){ + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, STR("invalid channel id") )); + Return RESULT_VOID; + } + MessageBlockMeta meta; + Array(u8) block_data; + try_void(Channel_loadMessageBlock(ch, req.first_message_id, req.messages_count, + &meta, &block_data, true)); + Defer(Array_u8_destroy(&block_data)); + // send response - // GetMessageBlockResponse res; - // GetMessageBlockResponse_construct(&res, res_head, ); - // try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); - // try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); + GetMessageBlockResponse res; + GetMessageBlockResponse_construct(&res, res_head, + meta.first_message_id, meta.messages_count, meta.data_size); + try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); + try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); + if(block_data.len != 0){ + try_void(EncryptedSocketTCP_send(&conn->sock, block_data)); + } Return RESULT_VOID; } diff --git a/src/server/responses/Login.c b/src/server/responses/Login.c index 68df2da..4d62d4a 100644 --- a/src/server/responses/Login.c +++ b/src/server/responses/Login.c @@ -17,9 +17,7 @@ declare_RequestHandler(Login) if(conn->authorized){ try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, - STR("is authorized in already") - )); + LogSeverity_Warn, STR("is authorized in already") )); Return RESULT_VOID; } @@ -29,9 +27,7 @@ declare_RequestHandler(Login) if(name_error_str.data){ Defer(str_destroy(name_error_str)); try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, - name_error_str - )); + LogSeverity_Warn, name_error_str)); Return RESULT_VOID; } @@ -46,11 +42,8 @@ declare_RequestHandler(Login) // try get id from name cache u64* id_ptr = HashMap_tryGetPtr(&srv->users.name_id_map, username_str); if(id_ptr == NULL){ - try_void(sendErrorMessage_f(log_ctx, conn, res_head, - LogSeverity_Warn, - "Username '%s' is not registered", - username_str.data - )); + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, STR("Username is not registered") )); Return RESULT_VOID; } u64 user_id = *id_ptr; @@ -62,9 +55,7 @@ declare_RequestHandler(Login) // validate token hash if(memcmp(req.token, u->token, sizeof(req.token)) != 0){ try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, - STR("wrong password") - )); + LogSeverity_Warn, STR("wrong password") )); Return RESULT_VOID; } diff --git a/src/server/responses/Register.c b/src/server/responses/Register.c index 243485e..c0eb0ca 100644 --- a/src/server/responses/Register.c +++ b/src/server/responses/Register.c @@ -17,9 +17,7 @@ declare_RequestHandler(Register) if(conn->authorized){ try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, - STR("is authorized in already") - )); + LogSeverity_Warn, STR("is authorized in already") )); Return RESULT_VOID; } @@ -29,9 +27,7 @@ declare_RequestHandler(Register) if(name_error_str.data){ Defer(str_destroy(name_error_str)); try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, - name_error_str - )); + LogSeverity_Warn, name_error_str)); Return RESULT_VOID; } @@ -46,10 +42,8 @@ declare_RequestHandler(Register) // check if name is taken if(HashMap_tryGetPtr(&srv->users.name_id_map, username_str) != NULL){ - try_void(sendErrorMessage_f(log_ctx, conn, res_head, - LogSeverity_Warn, - "Username '%s' already exists", - username_str.data)); + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, STR("Username already exists") )); Return RESULT_VOID; } diff --git a/src/server/responses/SendMessage.c b/src/server/responses/SendMessage.c index 04477e5..e742b90 100644 --- a/src/server/responses/SendMessage.c +++ b/src/server/responses/SendMessage.c @@ -17,20 +17,15 @@ declare_RequestHandler(SendMessage) if(!conn->authorized){ try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, - STR("is not authorized") - )); + LogSeverity_Warn, STR("not authorized") )); Return RESULT_VOID; } if(req.data_size < MESSAGE_SIZE_MIN || req.data_size > MESSAGE_SIZE_MAX){ - try_void(sendErrorMessage_f(log_ctx, conn, res_head, - LogSeverity_Warn, - "message size must be >= %i and <= %i", - MESSAGE_SIZE_MIN, MESSAGE_SIZE_MAX - )); - // this will close socket connection - Return RESULT_ERROR("invalid message size", false); + try_void(sendErrorMessage(log_ctx, conn, res_head, + LogSeverity_Warn, STR("invalid message size") )); + // close socket connection to reject incoming data + Return RESULT_ERROR_CODE_FMT(TcpChat, TcpChatError_RejectIncoming, "invalid message size: %u", req.data_size); } // receive message data @@ -41,13 +36,12 @@ declare_RequestHandler(SendMessage) Channel* ch = Server_tryGetChannel(conn->server, req.channel_id); if(ch == NULL){ try_void(sendErrorMessage(log_ctx, conn, res_head, - LogSeverity_Warn, - STR("invalid channel id") - )); + LogSeverity_Warn, STR("invalid channel id") )); Return RESULT_VOID; } MessageMeta message_meta; - try_void(Channel_saveMessage(ch, message_data, conn->user_id, &message_meta)); + try_void(Channel_saveMessage(ch, message_data, conn->user_id, + &message_meta, true)); // send response SendMessageResponse res; diff --git a/src/server/server.c b/src/server/server.c index c74526a..77ee057 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -221,7 +221,15 @@ static void* handleConnection(void* _args){ if(r.error){ Error_addCallPos(r.error, ErrorCallPos_here()); str e_str = Error_toStr(r.error); - logError(FMT_str, e_str.len, e_str.data); + LogSeverity severity = LogSeverity_Error; + + if(r.error->error_code_page == ErrorCodePage_TcpChat){ + if(r.error->error_code == TcpChatError_RejectIncoming){ + severity = LogSeverity_Debug; + } + } + + log(severity, FMT_str, e_str.len, e_str.data); str_destroy(e_str); Error_free(r.error); } diff --git a/src/server/server_internal.h b/src/server/server_internal.h index 8709182..df5ef80 100644 --- a/src/server/server_internal.h +++ b/src/server/server_internal.h @@ -2,6 +2,7 @@ #include "tlibc/collections/HashMap.h" #include "tlibc/collections/List.h" #include "tlibc/collections/LList.h" +#include "tcp-chat/tcp-chat.h" #include "tcp-chat/server.h" #include "cryptography/AES.h" #include "cryptography/RSA.h" @@ -41,14 +42,13 @@ Result(Channel*) Channel_create(u64 id, str name, str description, void Channel_free(Channel* self); Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id, - MessageMeta* out_message_meta); - -/// @param out_data buffer of size >= MESSAGE_SIZE_MAX -Result(void) Channel_loadMessage(Channel* self, u64 id, - MessageMeta* out_meta, u8* out_data); + MessageMeta* out_message_meta, bool lock_tables); +/// @brief try to find `count` messages starting from `fisrt_message_id` +/// @param out_meta information about found messages, .count can be 0 if no messages found +/// @param out_block allocates buffer on heap and copies them there, .len can be 0 if no messages found Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count, - MessageBlockMeta* out_meta, MessageBlock* out_block); + MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables); typedef struct Server { diff --git a/src/network/network.c b/src/tcp-chat.c old mode 100755 new mode 100644 similarity index 50% rename from src/network/network.c rename to src/tcp-chat.c index e1392ed..5e76214 --- a/src/network/network.c +++ b/src/tcp-chat.c @@ -1,21 +1,27 @@ -#include "internal.h" +#include "network/internal.h" ErrorCodePage_define(WINSOCK2); +ErrorCodePage_define(TcpChat); + +Result(void) TcpChat_init(){ + Deferral(4); + + ErrorCodePage_register(TcpChat); -Result(void) network_init(){ #if _WIN32 ErrorCodePage_register(WINSOCK2); + // Initialize Winsock WSADATA wsaData = {0}; int result = WSAStartup(MAKEWORD(2,2), &wsaData); if (result != 0) { - return RESULT_ERROR_FMT("WSAStartup failed with error code 0x%X", result); + Return RESULT_ERROR_CODE_FMT(WINSOCK2, result, "WSAStartup failed with error code %i", result); } #endif - return RESULT_VOID; + Return RESULT_VOID; } -void network_deinit(){ +void TcpChat_deinit(){ #if _WIN32 // Deinitialize Winsock (void)WSACleanup();