From 0abee3f7dff35dd4682a1b68788542130f953236 Mon Sep 17 00:00:00 2001 From: Timerix Date: Fri, 21 Nov 2025 21:22:53 +0500 Subject: [PATCH] fixed bugs in server and moved token hash calculation to client --- dependencies/tlibc | 2 +- src/cli/ClientCLI/ClientCLI.c | 48 ++++++++-------- src/cli/ClientCLI/ClientCLI.h | 2 +- src/cli/ClientCLI/db_tables.h | 4 +- src/cli/modes/ServerMode.c | 2 +- src/client/ServerConnection.c | 36 ++++++++---- src/client/client.c | 55 ++++++++----------- src/client/client_internal.h | 14 +++-- src/cryptography/cryptography.h | 5 +- src/cryptography/hash.c | 3 +- src/server/ClientConnection.c | 1 + src/server/db_tables.h | 10 ++-- src/server/request_handlers/Login.c | 39 ++++++------- src/server/request_handlers/Register.c | 36 ++++++------ .../request_handlers/ServerPublicInfo.c | 10 ++-- .../request_handlers/request_handlers.h | 15 ++--- src/server/request_handlers/send_error.c | 16 +++--- src/server/request_handlers/template | 4 +- src/server/server.c | 12 ++-- src/server/server_internal.h | 3 +- 20 files changed, 161 insertions(+), 156 deletions(-) diff --git a/dependencies/tlibc b/dependencies/tlibc index 0b4574b..3a7f09b 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit 0b4574b53e7798c6197a2e6af6300a7e64ad7d06 +Subproject commit 3a7f09bb498658dff1b5f38e5f5bf9474ad833ba diff --git a/src/cli/ClientCLI/ClientCLI.c b/src/cli/ClientCLI/ClientCLI.c index a841eda..a2357c9 100644 --- a/src/cli/ClientCLI/ClientCLI.c +++ b/src/cli/ClientCLI/ClientCLI.c @@ -24,18 +24,20 @@ static const str farewell_art = STR( static Result(void) ClientCLI_askUserNameAndPassword(str* username_out, str* password_out); static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop); static Result(void) ClientCLI_openUserDB(ClientCLI* self); -static Result(Server*) ClientCLI_saveServerInfo(ClientCLI* self, +static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, str addr, str pk_base64, str name, str desc); -static Result(Server*) ClientCLI_joinNewServer(ClientCLI* self); -static Result(Server*) ClientCLI_selectServerFromCache(ClientCLI* self); -static Result(void) ClientCLI_showServerInfo(ClientCLI* self, Server* server); +static Result(ServerInfo*) ClientCLI_joinNewServer(ClientCLI* self); +static Result(ServerInfo*) ClientCLI_selectServerFromCache(ClientCLI* self); +static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server); static Result(void) ClientCLI_register(ClientCLI* self); static Result(void) ClientCLI_login(ClientCLI* self); void ClientCLI_destroy(ClientCLI* self){ if(!self) return; + Client_free(self->client); + idb_close(self->db); pthread_mutex_destroy(&self->servers_cache_mutex); List_destroy(self->servers_cache_list); @@ -213,7 +215,7 @@ static Result(void) ClientCLI_joinNewServer(ClientCLI* self){ str server_description = str_null; try_void(Client_getServerName(self->client, &server_name)); try_void(Client_getServerDescription(self->client, &server_description)); - try(Server* server, p, ClientCLI_saveServerInfo(self, + try(ServerInfo* server, p, ClientCLI_saveServerInfo(self, server_addr_str, server_pk_str, server_name, server_description)); @@ -229,14 +231,14 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){ try_stderrcode(pthread_mutex_lock(&self->servers_cache_mutex)); Defer(pthread_mutex_unlock(&self->servers_cache_mutex)); - u32 servers_count = List_len(self->servers_cache_list, Server); + u32 servers_count = List_len(self->servers_cache_list, ServerInfo); if(servers_count == 0){ printf("No servers found in cache\n"); Return RESULT_VOID; } for(u32 id = 0; id < servers_count; id++){ - Server* row = &List_index(self->servers_cache_list, Server, id); + ServerInfo* row = &List_index(self->servers_cache_list, ServerInfo, id); printf("[%02u] "FMT_str" "FMT_str"\n", id, row->address_len, row->address, row->name_len, row->name); } @@ -260,7 +262,7 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){ } else break; } - Server* server = &List_index(self->servers_cache_list, Server, id); + ServerInfo* server = &List_index(self->servers_cache_list, ServerInfo, id); printf("Connecting to '"FMT_str"'...\n", server->address_len, server->address); try_void(Client_connect(self->client, server->address, server->pk_base64)); @@ -296,7 +298,7 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){ Return RESULT_VOID; } -static Result(void) ClientCLI_showServerInfo(ClientCLI* self, Server* server){ +static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server){ Deferral(8); (void)self; @@ -320,16 +322,16 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){ try(self->db, p, idb_open(user_db_dir, user_data_key)); // load servers table - try(self->db_servers_table, p, idb_getOrCreateTable(self->db, STR("servers"), sizeof(Server))); + try(self->db_servers_table, p, idb_getOrCreateTable(self->db, STR("servers"), sizeof(ServerInfo))); // load whole table to list try(u64 servers_count, u, idb_getRowCount(self->db_servers_table)); - self->servers_cache_list = List_alloc(Server, servers_count); + self->servers_cache_list = List_alloc(ServerInfo, servers_count); try_void(idb_getRows(self->db_servers_table, 0, self->servers_cache_list.data, servers_count)); - self->servers_cache_list.size = sizeof(Server) * servers_count; + self->servers_cache_list.size = sizeof(ServerInfo) * servers_count; // build address-id map HashMap_construct(&self->servers_addr_id_map, u64, NULL); for(u64 id = 0; id < servers_count; id++){ - Server* row = &List_index(self->servers_cache_list, Server, id); + ServerInfo* row = &List_index(self->servers_cache_list, ServerInfo, id); str key = str_construct(row->address, row->address_len, true); if(!HashMap_tryPush(&self->servers_addr_id_map, key, &id)){ Return RESULT_ERROR_FMT("duplicate server address '"FMT_str"'", key.size, key.data); @@ -339,13 +341,13 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){ Return RESULT_VOID; } -static Result(Server*) ClientCLI_saveServerInfo(ClientCLI* self, +static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, str addr, str pk_base64, str name, str desc){ Deferral(8); // create new server info - Server server; - memset(&server, 0, sizeof(Server)); + ServerInfo server; + memset(&server, 0, sizeof(ServerInfo)); // address if(addr.size > HOSTADDR_SIZE_MAX) addr.size = HOSTADDR_SIZE_MAX; @@ -372,23 +374,23 @@ static Result(Server*) ClientCLI_saveServerInfo(ClientCLI* self, Defer(pthread_mutex_unlock(&self->servers_cache_mutex)); // try find server id in cache - Server* cached_row_ptr = NULL; + ServerInfo* cached_row_ptr = NULL; u64* id_ptr = NULL; id_ptr = HashMap_tryGetPtr(&self->servers_addr_id_map, addr); if(id_ptr){ // update existing server u64 id = *id_ptr; try_void(idb_updateRow(self->db_servers_table, id, &server)); - try_assert(id < List_len(self->servers_cache_list, Server)); - cached_row_ptr = &List_index(self->servers_cache_list, Server, id); - memcpy(cached_row_ptr, &server, sizeof(Server)); + try_assert(id < List_len(self->servers_cache_list, ServerInfo)); + cached_row_ptr = &List_index(self->servers_cache_list, ServerInfo, id); + memcpy(cached_row_ptr, &server, sizeof(ServerInfo)); } else { // push new server try(u64 id, u, idb_pushRow(self->db_servers_table, &server)); - try_assert(id == List_len(self->servers_cache_list, Server)); - List_pushMany(&self->servers_cache_list, Server, &server, 1); - cached_row_ptr = &List_index(self->servers_cache_list, Server, id); + try_assert(id == List_len(self->servers_cache_list, ServerInfo)); + List_pushMany(&self->servers_cache_list, ServerInfo, &server, 1); + cached_row_ptr = &List_index(self->servers_cache_list, ServerInfo, id); try_assert(HashMap_tryPush(&self->servers_addr_id_map, addr, &id)); } diff --git a/src/cli/ClientCLI/ClientCLI.h b/src/cli/ClientCLI/ClientCLI.h index 5fc210e..02b8f7d 100644 --- a/src/cli/ClientCLI/ClientCLI.h +++ b/src/cli/ClientCLI/ClientCLI.h @@ -11,7 +11,7 @@ typedef struct ClientCLI { IncrementalDB* db; Table* db_servers_table; pthread_mutex_t servers_cache_mutex; - List(Server) servers_cache_list; // index is id + List(ServerInfo) servers_cache_list; // index is id HashMap(u64) servers_addr_id_map; // key is server address } ClientCLI; diff --git a/src/cli/ClientCLI/db_tables.h b/src/cli/ClientCLI/db_tables.h index 37409ba..6f7ff61 100644 --- a/src/cli/ClientCLI/db_tables.h +++ b/src/cli/ClientCLI/db_tables.h @@ -2,7 +2,7 @@ #include "tcp-chat/common_constants.h" #include "tlibc/time.h" -typedef struct Server { +typedef struct ServerInfo { u16 address_len; char address[HOSTADDR_SIZE_MAX + 1]; u32 pk_base64_len; @@ -11,4 +11,4 @@ typedef struct Server { char name[SERVER_NAME_SIZE_MAX + 1]; u16 desc_len; char desc[SERVER_DESC_SIZE_MAX + 1]; -} ATTRIBUTE_ALIGNED(16*1024) Server; +} ATTRIBUTE_ALIGNED(16*1024) ServerInfo; diff --git a/src/cli/modes/ServerMode.c b/src/cli/modes/ServerMode.c index accc92c..b4ff472 100644 --- a/src/cli/modes/ServerMode.c +++ b/src/cli/modes/ServerMode.c @@ -33,7 +33,6 @@ Result(void) run_ServerMode(cstr config_path) { try_void(file_readWhole(config_file, &config_buf)); Defer(Array_free(config_buf)); str config_str = Array_castTo_str(config_buf, false); - config_buf.data = NULL; // init server try(Server* server, p, Server_create(config_str, NULL, log_func)); @@ -43,6 +42,7 @@ Result(void) run_ServerMode(cstr config_path) { file_close(config_file); config_file = NULL; Array_free(config_buf); + config_buf.data = NULL; // start infinite loop on main thread try_void(Server_run(server)); diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index 9833bb8..d03e194 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -1,18 +1,19 @@ #include "client_internal.h" #include "requests/requests.h" -void ServerConnection_close(ServerConnection* conn){ - if(!conn) +void ServerConnection_close(ServerConnection* self){ + if(!self) return; - RSA_destroyPublicKey(&conn->server_pk); - EncryptedSocketTCP_destroy(&conn->sock); - Array_free(conn->session_key); - str_free(conn->server_name); - str_free(conn->server_description); - free(conn); + RSA_destroyPublicKey(&self->server_pk); + EncryptedSocketTCP_destroy(&self->sock); + Array_free(self->token); + Array_free(self->session_key); + str_free(self->server_name); + str_free(self->server_description); + free(self); } -Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr server_pk_base64) +Result(ServerConnection*) ServerConnection_open(Client* client, cstr server_addr_cstr, cstr server_pk_base64) { Deferral(16); @@ -21,6 +22,8 @@ Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr serv bool success = false; Defer(if(!success) ServerConnection_close(conn)); + conn->client = client; + // TODO: parse domain name and get ip from it conn->server_end = EndpointIPv4_INVALID; try_void(EndpointIPv4_parse(server_addr_cstr, &conn->server_end)); @@ -31,8 +34,21 @@ Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr serv try_void(RSA_parsePublicKey_base64(server_pk_base64, &conn->server_pk)); RSAEncryptor_construct(&conn->rsa_enc, &conn->server_pk); + // lvl 2 hash - is used for authentification + conn->token = Array_alloc(u8, PASSWORD_HASH_SIZE); + // hash user_data_key with server_pk once + Array(u8) server_pk_data = Array_construct_size(conn->server_pk.n, + BR_RSA_KBUF_PUB_SIZE(conn->server_pk.nlen * 8)); + u8 server_pk_hash[PASSWORD_HASH_SIZE]; + Array(u8) server_pk_hash_array = Array_construct_size(server_pk_hash, PASSWORD_HASH_SIZE); + hash_password(conn->client->user_data_key, server_pk_data, + server_pk_hash, 1); + // hash user_data_key with server_pk_hash + hash_password(conn->token, server_pk_hash_array, + conn->token.data, PASSWORD_HASH_LVL_ROUNDS); + + // generate session random AES key conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); - // generate random session key br_hmac_drbg_context key_rng = { .vtable = &br_hmac_drbg_vtable }; rng_init_sha256_seedFromSystem(&key_rng.vtable); br_hmac_drbg_generate(&key_rng, conn->session_key.data, conn->session_key.size); diff --git a/src/client/client.c b/src/client/client.c index c94871c..f89a45b 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -6,9 +6,8 @@ void Client_free(Client* self){ return; str_free(self->username); - Array_free(self->token); Array_free(self->user_data_key); - ServerConnection_close(self->server_connection); + ServerConnection_close(self->conn); free(self); } @@ -22,21 +21,10 @@ Result(Client*) Client_create(str username, str password){ self->username = str_copy(username); - // concat password and username - List(u8) data_to_hash = List_alloc_size(password.size + username.size + PASSWORD_HASH_SIZE); - Defer(free(data_to_hash.data)); - List_push_size(&data_to_hash, password.data, password.size); - List_push_size(&data_to_hash, username.data, username.size); - // lvl 1 hash - is used as AES key for user data self->user_data_key = Array_alloc(u8, PASSWORD_HASH_SIZE); - hash_password(List_castTo_Array(data_to_hash), self->user_data_key.data, PASSWORD_HASH_LVL_ROUNDS); - // concat lvl 1 hash to data_to_hash - List_push_size(&data_to_hash, self->user_data_key.data, self->user_data_key.size); - // lvl 2 hash - is used for authentification - self->token = Array_alloc(u8, PASSWORD_HASH_SIZE); - // TODO: generate different token for each server - hash_password(List_castTo_Array(data_to_hash), self->token.data, PASSWORD_HASH_LVL_ROUNDS); + hash_password(str_castTo_Array(password), str_castTo_Array(username), + self->user_data_key.data, PASSWORD_HASH_LVL_ROUNDS); success = true; Return RESULT_VALUE(p, self); @@ -45,14 +33,15 @@ Result(Client*) Client_create(str username, str password){ Result(void) Client_connect(Client* self, cstr server_addr_cstr, cstr server_pk_base64){ Deferral(8); Client_disconnect(self); - try(self->server_connection, p, - ServerConnection_open(server_addr_cstr, server_pk_base64)); + try(self->conn, p, + ServerConnection_open(self, server_addr_cstr, server_pk_base64) + ); Return RESULT_VOID; } void Client_disconnect(Client* self){ - ServerConnection_close(self->server_connection); - self->server_connection = NULL; + ServerConnection_close(self->conn); + self->conn = NULL; } str Client_getUserName(Client* client){ @@ -66,9 +55,9 @@ Array(u8) Client_getUserDataKey(Client* client){ Result(void) Client_getServerName(Client* self, str* out_name){ Deferral(1); try_assert(self != NULL); - try_assert(self->server_connection != NULL && "didn't connect to a server yet"); + try_assert(self->conn != NULL && "didn't connect to a server yet"); - *out_name = self->server_connection->server_name; + *out_name = self->conn->server_name; Return RESULT_VOID; } @@ -76,9 +65,9 @@ Result(void) Client_getServerName(Client* self, str* out_name){ Result(void) Client_getServerDescription(Client* self, str* out_desc){ Deferral(1); try_assert(self != NULL); - try_assert(self->server_connection != NULL && "didn't connect to a server yet"); + try_assert(self->conn != NULL && "didn't connect to a server yet"); - *out_desc = self->server_connection->server_description; + *out_desc = self->conn->server_description; Return RESULT_VOID; } @@ -86,16 +75,16 @@ Result(void) Client_getServerDescription(Client* self, str* out_desc){ Result(void) Client_register(Client* self, u64* out_user_id){ Deferral(1); try_assert(self != NULL); - try_assert(self->server_connection != NULL && "didn't connect to a server yet"); + try_assert(self->conn != NULL && "didn't connect to a server yet"); PacketHeader req_head, res_head; RegisterRequest req; RegisterResponse res; // TODO: hash token with server public key - try_void(RegisterRequest_tryConstruct(&req, &req_head, self->username, self->token)); - try_void(sendRequest(&self->server_connection->sock, &req_head, &req)); - try_void(recvResponse(&self->server_connection->sock, &res_head, &res, PacketType_RegisterResponse)); - self->server_connection->user_id = res.user_id; + try_void(RegisterRequest_tryConstruct(&req, &req_head, self->username, self->conn->token)); + try_void(sendRequest(&self->conn->sock, &req_head, &req)); + try_void(recvResponse(&self->conn->sock, &res_head, &res, PacketType_RegisterResponse)); + self->conn->user_id = res.user_id; *out_user_id = res.user_id; Return RESULT_VOID; @@ -104,16 +93,16 @@ Result(void) Client_register(Client* self, u64* out_user_id){ Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id){ Deferral(1); try_assert(self != NULL); - try_assert(self->server_connection != NULL && "didn't connect to a server yet"); + try_assert(self->conn != NULL && "didn't connect to a server yet"); PacketHeader req_head, res_head; LoginRequest req; LoginResponse res; // TODO: hash token with server public key - try_void(LoginRequest_tryConstruct(&req, &req_head, self->username, self->token)); - try_void(sendRequest(&self->server_connection->sock, &req_head, &req)); - try_void(recvResponse(&self->server_connection->sock, &res_head, &res, PacketType_LoginResponse)); - self->server_connection->user_id = res.user_id; + try_void(LoginRequest_tryConstruct(&req, &req_head, self->username, self->conn->token)); + try_void(sendRequest(&self->conn->sock, &req_head, &req)); + try_void(recvResponse(&self->conn->sock, &res_head, &res, PacketType_LoginResponse)); + self->conn->user_id = res.user_id; *out_user_id = res.user_id; *out_landing_channel_id = res.landing_channel_id; diff --git a/src/client/client_internal.h b/src/client/client_internal.h index 21ca694..d9ebdb6 100644 --- a/src/client/client_internal.h +++ b/src/client/client_internal.h @@ -9,18 +9,19 @@ typedef struct ServerConnection ServerConnection; typedef struct Client { str username; Array(u8) user_data_key; - Array(u8) token; - ServerConnection* server_connection; + ServerConnection* conn; } Client; typedef struct ServerConnection { + Client* client; EndpointIPv4 server_end; br_rsa_public_key server_pk; RSAEncryptor rsa_enc; - u64 session_id; + Array(u8) token; Array(u8) session_key; EncryptedSocketTCP sock; + u64 session_id; str server_name; str server_description; u64 user_id; @@ -28,10 +29,13 @@ typedef struct ServerConnection { /// @param server_addr_cstr /// @param server_pk_base64 public key encoded by `RSA_serializePublicKey_base64()` -Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr server_pk_base64); -void ServerConnection_close(ServerConnection* conn); +Result(ServerConnection*) ServerConnection_open(Client* client, + cstr server_addr_cstr, cstr server_pk_base64); + + void ServerConnection_close(ServerConnection* conn); /// updates conn->server_name Result(void) ServerConnection_requestServerName(ServerConnection* conn); + /// updates conn->server_description Result(void) ServerConnection_requestServerDescription(ServerConnection* conn); diff --git a/src/cryptography/cryptography.h b/src/cryptography/cryptography.h index 576bf7a..3e7dee7 100755 --- a/src/cryptography/cryptography.h +++ b/src/cryptography/cryptography.h @@ -13,9 +13,10 @@ /// @brief hashes password multiple times using its own hash as salt /// @param password some byte array +/// @param salt some byte array /// @param out_buffer u8[PASSWORD_HASH_SIZE] /// @param rounds number of rounds -void hash_password(Array(u8) password, u8* out_buffer, i32 rounds); +void hash_password(Array(u8) password, Array(u8) salt, u8* out_buffer, i32 rounds); #define PASSWORD_HASH_LVL_ROUNDS 1e5 ////////////////////////////////////////////////////////////////////////////// @@ -27,6 +28,7 @@ void hash_password(Array(u8) password, u8* out_buffer, i32 rounds); /// @brief Initialize prng context with sha256 hashing algorithm /// and seed from system-provided cryptographic random bytes source. /// @param rng_vtable_ptr pointer to vtable field in prng context. The field must be initialized. +/// /// EXAMPLE: /// ``` /// br_hmac_drbg_context rng_ctx = { .vtable = &br_hmac_drbg_vtable }; @@ -36,6 +38,7 @@ void rng_init_sha256_seedFromSystem(const br_prng_class** rng_vtable_ptr); /// @brief Initialize prng context with sha256 hashing algorithm and seed from CLOCK_REALTIME. /// @param rng_vtable_ptr pointer to vtable field in prng context. The field must be initialized. +/// /// EXAMPLE: /// ``` /// br_hmac_drbg_context rng_ctx = { .vtable = &br_hmac_drbg_vtable }; diff --git a/src/cryptography/hash.c b/src/cryptography/hash.c index 181f15b..f567046 100755 --- a/src/cryptography/hash.c +++ b/src/cryptography/hash.c @@ -1,7 +1,7 @@ #include "cryptography.h" #include "assert.h" -void hash_password(Array(u8) password, u8* out_buffer, i32 iterations){ +void hash_password(Array(u8) password, Array(u8) salt, u8* out_buffer, i32 iterations){ assert(PASSWORD_HASH_SIZE == br_sha256_SIZE);; memset(out_buffer, 0, br_sha256_SIZE); br_sha256_context sha256_ctx; @@ -9,6 +9,7 @@ void hash_password(Array(u8) password, u8* out_buffer, i32 iterations){ for(i32 i = 0; i < iterations; i++){ br_sha256_update(&sha256_ctx, password.data, password.size); + br_sha256_update(&sha256_ctx, salt.data, salt.size); br_sha256_out(&sha256_ctx, out_buffer); br_sha256_update(&sha256_ctx, out_buffer, PASSWORD_HASH_SIZE); } diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 9214629..af76a63 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -18,6 +18,7 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) bool success = false; Defer(if(!success) ClientConnection_close(conn)); + conn->server = args->server; conn->client_end = args->client_end; conn->session_id = args->session_id; conn->authorized = false; diff --git a/src/server/db_tables.h b/src/server/db_tables.h index 336f957..a3ec148 100644 --- a/src/server/db_tables.h +++ b/src/server/db_tables.h @@ -2,16 +2,16 @@ #include "tcp-chat/common_constants.h" #include "tlibc/time.h" -typedef struct User { +typedef struct UserInfo { u16 name_len; char name[USERNAME_SIZE_MAX + 1]; // null-terminated - u8 token_hash[PASSWORD_HASH_SIZE]; // token is hashed again on server side + u8 token[PASSWORD_HASH_SIZE]; // token is hashed again on server side DateTime registration_time; -} ATTRIBUTE_ALIGNED(256) User; +} ATTRIBUTE_ALIGNED(256) UserInfo; -typedef struct Channel { +typedef struct ChannelInfo { u16 name_len; char name[CHANNEL_NAME_SIZE_MAX + 1]; u16 desc_len; char desc[CHANNEL_DESC_SIZE_MAX + 1]; -} ATTRIBUTE_ALIGNED(4*1024) Channel; +} ATTRIBUTE_ALIGNED(4*1024) ChannelInfo; diff --git a/src/server/request_handlers/Login.c b/src/server/request_handlers/Login.c index 449cc85..1a5be35 100644 --- a/src/server/request_handlers/Login.c +++ b/src/server/request_handlers/Login.c @@ -1,7 +1,7 @@ #include "request_handlers.h" -#define LOGGER server->logger -#define LOG_FUNC server->log_func +#define LOGGER conn->server->logger +#define LOG_FUNC conn->server->log_func declare_RequestHandler(Login) { @@ -14,7 +14,7 @@ declare_RequestHandler(Login) try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req)); if(conn->authorized){ - try_void(sendErrorMessage(server, log_ctx, conn, res_head, + try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("is authorized in already") )); @@ -25,30 +25,25 @@ declare_RequestHandler(Login) str username_str = str_null; str username_check_error = validateUsername_cstr(req.username, &username_str); if(username_check_error.data){ - try_void(sendErrorMessage(server, log_ctx, conn, res_head, + try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, username_check_error )); Return RESULT_VOID; } - // cakculate hash of received token - u8 token_hash[PASSWORD_HASH_SIZE]; - hash_password( - Array_construct_size(req.token, sizeof(req.token)), - token_hash, - PASSWORD_HASH_LVL_ROUNDS - ); - // lock users cache - try_stderrcode(pthread_mutex_lock(&server->users_cache_mutex)); + try_stderrcode(pthread_mutex_lock(&conn->server->users_cache_mutex)); bool unlocked_users_cache_mutex = false; - Defer(if(!unlocked_users_cache_mutex) pthread_mutex_unlock(&server->users_cache_mutex)); + Defer( + if(!unlocked_users_cache_mutex) + pthread_mutex_unlock(&conn->server->users_cache_mutex) + ); // try get id from name cache - u64* id_ptr = HashMap_tryGetPtr(&server->users_name_id_map, username_str); + u64* id_ptr = HashMap_tryGetPtr(&conn->server->users_name_id_map, username_str); if(id_ptr == NULL){ - try_void(sendErrorMessage_f(server, log_ctx, conn, res_head, + try_void(sendErrorMessage_f(log_ctx, conn, res_head, LogSeverity_Warn, "Username '%s' is not registered", username_str.data @@ -58,12 +53,12 @@ declare_RequestHandler(Login) u64 user_id = *id_ptr; // get user by id - try_assert(user_id < List_len(server->users_cache_list, User)); - User* u = &List_index(server->users_cache_list, User, user_id); + try_assert(user_id < List_len(conn->server->users_cache_list, UserInfo)); + UserInfo* u = &List_index(conn->server->users_cache_list, UserInfo, user_id); // validate token hash - if(memcmp(token_hash, u->token_hash, sizeof(token_hash)) != 0){ - try_void(sendErrorMessage(server, log_ctx, conn, res_head, + if(memcmp(req.token, u->token, sizeof(req.token)) != 0){ + try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("wrong password") )); @@ -71,7 +66,7 @@ declare_RequestHandler(Login) } // manually unlock mutex - pthread_mutex_unlock(&server->users_cache_mutex); + pthread_mutex_unlock(&conn->server->users_cache_mutex); unlocked_users_cache_mutex = true; // authorize @@ -80,7 +75,7 @@ declare_RequestHandler(Login) // send response LoginResponse res; - LoginResponse_construct(&res, res_head, user_id, server->landing_channel_id); + LoginResponse_construct(&res, res_head, user_id, conn->server->landing_channel_id); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); diff --git a/src/server/request_handlers/Register.c b/src/server/request_handlers/Register.c index 4d2c48b..e60e1a4 100644 --- a/src/server/request_handlers/Register.c +++ b/src/server/request_handlers/Register.c @@ -1,7 +1,7 @@ #include "request_handlers.h" -#define LOGGER server->logger -#define LOG_FUNC server->log_func +#define LOGGER conn->server->logger +#define LOG_FUNC conn->server->log_func declare_RequestHandler(Register) { @@ -14,7 +14,7 @@ declare_RequestHandler(Register) try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req)); if(conn->authorized){ - try_void(sendErrorMessage(server, log_ctx, conn, res_head, + try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("is authorized in already") )); @@ -25,7 +25,7 @@ declare_RequestHandler(Register) str username_str = str_null; str username_check_error = validateUsername_cstr(req.username, &username_str); if(username_check_error.data){ - try_void(sendErrorMessage(server, log_ctx, conn, res_head, + try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, username_check_error )); @@ -33,17 +33,17 @@ declare_RequestHandler(Register) } // lock users cache - try_stderrcode(pthread_mutex_lock(&server->users_cache_mutex)); + try_stderrcode(pthread_mutex_lock(&conn->server->users_cache_mutex)); bool unlocked_users_cache_mutex = false; // unlock mutex on error catch Defer( if(!unlocked_users_cache_mutex) - pthread_mutex_unlock(&server->users_cache_mutex) + pthread_mutex_unlock(&conn->server->users_cache_mutex) ); // check if name is taken - if(HashMap_tryGetPtr(&server->users_name_id_map, username_str) != NULL){ - try_void(sendErrorMessage_f(server, log_ctx, conn, res_head, + if(HashMap_tryGetPtr(&conn->server->users_name_id_map, username_str) != NULL){ + try_void(sendErrorMessage_f(log_ctx, conn, res_head, LogSeverity_Warn, "Username '%s' already exists", username_str.data)); @@ -51,28 +51,24 @@ declare_RequestHandler(Register) } // initialize new user - User user; - memset(&user, 0, sizeof(User)); + UserInfo user; + memset(&user, 0, sizeof(UserInfo)); user.name_len = username_str.size; memcpy(user.name, username_str.data, user.name_len); - hash_password( - Array_construct_size(req.token, sizeof(req.token)), - user.token_hash, - PASSWORD_HASH_LVL_ROUNDS - ); + memcpy(user.token, req.token, sizeof(req.token)); DateTime_getUTC(&user.registration_time); // save new user to db and cache - try(u64 user_id, u, idb_pushRow(server->db_users_table, &user)); - try_assert(user_id == List_len(server->users_cache_list, User)); - List_pushMany(&server->users_cache_list, User, &user, 1); - try_assert(HashMap_tryPush(&server->users_name_id_map, username_str, &user_id)); + try(u64 user_id, u, idb_pushRow(conn->server->db_users_table, &user)); + try_assert(user_id == List_len(conn->server->users_cache_list, UserInfo)); + List_pushMany(&conn->server->users_cache_list, UserInfo, &user, 1); + try_assert(HashMap_tryPush(&conn->server->users_name_id_map, username_str, &user_id)); // manually unlock mutex - pthread_mutex_unlock(&server->users_cache_mutex); + pthread_mutex_unlock(&conn->server->users_cache_mutex); unlocked_users_cache_mutex = true; logInfo(log_ctx, "registered user '%s'", username_str.data); diff --git a/src/server/request_handlers/ServerPublicInfo.c b/src/server/request_handlers/ServerPublicInfo.c index 1b2d951..b8dc0d8 100644 --- a/src/server/request_handlers/ServerPublicInfo.c +++ b/src/server/request_handlers/ServerPublicInfo.c @@ -1,7 +1,7 @@ #include "request_handlers.h" -#define LOGGER server->logger -#define LOG_FUNC server->log_func +#define LOGGER conn->server->logger +#define LOG_FUNC conn->server->log_func declare_RequestHandler(ServerPublicInfo) { @@ -17,17 +17,17 @@ declare_RequestHandler(ServerPublicInfo) Array(u8) content; switch(req.property){ default:{ - try_void(sendErrorMessage_f(server, log_ctx, conn, res_head, + try_void(sendErrorMessage_f(log_ctx, conn, res_head, LogSeverity_Warn, "Unknown ServerPublicInfo property %u", req.property)); Return RESULT_VOID; } case ServerPublicInfo_Name: - content = str_castTo_Array(server->name); + content = str_castTo_Array(conn->server->name); break; case ServerPublicInfo_Description: - content = str_castTo_Array(server->description); + content = str_castTo_Array(conn->server->description); break; } diff --git a/src/server/request_handlers/request_handlers.h b/src/server/request_handlers/request_handlers.h index 6d03e2b..e400760 100644 --- a/src/server/request_handlers/request_handlers.h +++ b/src/server/request_handlers/request_handlers.h @@ -4,29 +4,26 @@ Result(void) sendErrorMessage( - Server* server, cstr log_ctx, - ClientConnection* conn, PacketHeader* res_head, + cstr log_ctx, ClientConnection* conn, PacketHeader* res_head, LogSeverity log_severity, str msg); Result(void) __sendErrorMessage_fv( - Server* server, cstr log_ctx, - ClientConnection* conn, PacketHeader* res_head, + cstr log_ctx, ClientConnection* conn, PacketHeader* res_head, LogSeverity log_severity, cstr format, va_list argv); Result(void) sendErrorMessage_f( - Server* server, cstr log_ctx, - ClientConnection* conn, PacketHeader* res_head, - LogSeverity log_severity, cstr format, ...) ATTRIBUTE_CHECK_FORMAT_PRINTF(6, 7); + cstr log_ctx, ClientConnection* conn, PacketHeader* res_head, + LogSeverity log_severity, cstr format, ...) ATTRIBUTE_CHECK_FORMAT_PRINTF(5, 6); #define declare_RequestHandler(TYPE) \ Result(void) handleRequest_##TYPE(\ - Server* server, cstr log_ctx, cstr req_type_name, \ + cstr log_ctx, cstr req_type_name, \ ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head) #define case_handleRequest(TYPE) \ case PacketType_##TYPE##Request:\ - try_void(handleRequest_##TYPE(args->server, log_ctx, #TYPE, conn, &req_head, &res_head));\ + try_void(handleRequest_##TYPE(log_ctx, #TYPE, conn, &req_head, &res_head));\ break; declare_RequestHandler(ServerPublicInfo); diff --git a/src/server/request_handlers/send_error.c b/src/server/request_handlers/send_error.c index 06031c0..e19a838 100644 --- a/src/server/request_handlers/send_error.c +++ b/src/server/request_handlers/send_error.c @@ -1,11 +1,10 @@ #include "request_handlers.h" -#define LOGGER server->logger -#define LOG_FUNC server->log_func +#define LOGGER conn->server->logger +#define LOG_FUNC conn->server->log_func Result(void) sendErrorMessage( - Server* server, cstr log_ctx, - ClientConnection* conn, PacketHeader* res_head, + cstr log_ctx, ClientConnection* conn, PacketHeader* res_head, LogSeverity log_severity, str msg) { Deferral(1); @@ -26,21 +25,20 @@ Result(void) sendErrorMessage( } Result(void) __sendErrorMessage_fv( - Server* server, cstr log_ctx, - ClientConnection* conn, PacketHeader* res_head, + cstr log_ctx, ClientConnection* conn, PacketHeader* res_head, LogSeverity log_severity, cstr format, va_list argv) { Deferral(4); str msg = str_from_cstr(vsprintf_malloc(format, argv)); Defer(free(msg.data)); - try_void(sendErrorMessage(server, log_ctx, conn, res_head, log_severity, msg)); + try_void(sendErrorMessage(log_ctx, conn, res_head, log_severity, msg)); Return RESULT_VOID; } Result(void) sendErrorMessage_f( - Server* server, cstr log_ctx, + cstr log_ctx, ClientConnection* conn, PacketHeader* res_head, LogSeverity log_severity, cstr format, ...) { @@ -49,7 +47,7 @@ Result(void) sendErrorMessage_f( va_list argv; va_start(argv, format); Defer(va_end(argv)); - try_void(__sendErrorMessage_fv(server, log_ctx, conn, res_head, log_severity, format, argv)); + try_void(__sendErrorMessage_fv(log_ctx, conn, res_head, log_severity, format, argv)); Return RESULT_VOID; } diff --git a/src/server/request_handlers/template b/src/server/request_handlers/template index 51caf71..5ff1a88 100644 --- a/src/server/request_handlers/template +++ b/src/server/request_handlers/template @@ -1,7 +1,7 @@ #include "request_handlers.h" -#define LOGGER server->logger -#define LOG_FUNC server->log_func +#define LOGGER conn->server->logger +#define LOG_FUNC conn->server->log_func declare_RequestHandler(NAME) { diff --git a/src/server/server.c b/src/server/server.c index 2122f3a..a9eb513 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -23,6 +23,8 @@ void Server_free(Server* self){ pthread_mutex_destroy(&self->users_cache_mutex); List_destroy(self->users_cache_list); HashMap_destroy(&self->users_name_id_map); + + free(self); } @@ -91,16 +93,16 @@ Result(Server*) Server_create(str config_str, void* logger, LogFunction_t log_fu // build users cache logDebug(log_ctx, "loading users..."); pthread_mutex_init(&self->users_cache_mutex, NULL); - try(self->db_users_table, p, idb_getOrCreateTable(self->db, STR("users"), sizeof(User))); + try(self->db_users_table, p, idb_getOrCreateTable(self->db, STR("users"), sizeof(UserInfo))); // load whole table to list try(u64 users_count, u, idb_getRowCount(self->db_users_table)); - self->users_cache_list = List_alloc(User, users_count); + self->users_cache_list = List_alloc(UserInfo, users_count); try_void(idb_getRows(self->db_users_table, 0, self->users_cache_list.data, users_count)); - self->users_cache_list.size = sizeof(User) * users_count; + self->users_cache_list.size = sizeof(UserInfo) * users_count; // build name-id map HashMap_construct(&self->users_name_id_map, u64, NULL); for(u64 id = 0; id < users_count; id++){ - User* row = &List_index(self->users_cache_list, User, id); + UserInfo* row = &List_index(self->users_cache_list, UserInfo, id); str key = str_construct(row->name, row->name_len, true); if(!HashMap_tryPush(&self->users_name_id_map, key, &id)){ Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", key.size, key.data); @@ -191,7 +193,7 @@ static Result(void) try_handleConnection(ConnectionHandlerArgs* args, cstr log_c switch(req_head.type){ // send error message and close connection default: - try_void(sendErrorMessage_f(server, log_ctx, conn, &res_head, + try_void(sendErrorMessage_f(log_ctx, conn, &res_head, LogSeverity_Error, "Received unexpected packet of type %u", req_head.type)); diff --git a/src/server/server_internal.h b/src/server/server_internal.h index 55e23c4..0c033f0 100644 --- a/src/server/server_internal.h +++ b/src/server/server_internal.h @@ -28,12 +28,13 @@ typedef struct Server { IncrementalDB* db; Table* db_users_table; pthread_mutex_t users_cache_mutex; - List(User) users_cache_list; // index is id + List(UserInfo) users_cache_list; // index is id HashMap(u64) users_name_id_map; // key is user name } Server; typedef struct ClientConnection { + Server* server; u64 session_id; EndpointIPv4 client_end; Array(u8) session_key;