fixed bugs in server and moved token hash calculation to client

This commit is contained in:
Timerix 2025-11-21 21:22:53 +05:00
parent baca2fb4d3
commit 0abee3f7df
20 changed files with 161 additions and 156 deletions

2
dependencies/tlibc vendored

@ -1 +1 @@
Subproject commit 0b4574b53e7798c6197a2e6af6300a7e64ad7d06 Subproject commit 3a7f09bb498658dff1b5f38e5f5bf9474ad833ba

View File

@ -24,18 +24,20 @@ static const str farewell_art = STR(
static Result(void) ClientCLI_askUserNameAndPassword(str* username_out, str* password_out); static Result(void) ClientCLI_askUserNameAndPassword(str* username_out, str* password_out);
static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop); static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop);
static Result(void) ClientCLI_openUserDB(ClientCLI* self); static Result(void) ClientCLI_openUserDB(ClientCLI* self);
static Result(Server*) ClientCLI_saveServerInfo(ClientCLI* self, static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
str addr, str pk_base64, str name, str desc); str addr, str pk_base64, str name, str desc);
static Result(Server*) ClientCLI_joinNewServer(ClientCLI* self); static Result(ServerInfo*) ClientCLI_joinNewServer(ClientCLI* self);
static Result(Server*) ClientCLI_selectServerFromCache(ClientCLI* self); static Result(ServerInfo*) ClientCLI_selectServerFromCache(ClientCLI* self);
static Result(void) ClientCLI_showServerInfo(ClientCLI* self, Server* server); static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server);
static Result(void) ClientCLI_register(ClientCLI* self); static Result(void) ClientCLI_register(ClientCLI* self);
static Result(void) ClientCLI_login(ClientCLI* self); static Result(void) ClientCLI_login(ClientCLI* self);
void ClientCLI_destroy(ClientCLI* self){ void ClientCLI_destroy(ClientCLI* self){
if(!self) if(!self)
return; return;
Client_free(self->client); Client_free(self->client);
idb_close(self->db); idb_close(self->db);
pthread_mutex_destroy(&self->servers_cache_mutex); pthread_mutex_destroy(&self->servers_cache_mutex);
List_destroy(self->servers_cache_list); List_destroy(self->servers_cache_list);
@ -213,7 +215,7 @@ static Result(void) ClientCLI_joinNewServer(ClientCLI* self){
str server_description = str_null; str server_description = str_null;
try_void(Client_getServerName(self->client, &server_name)); try_void(Client_getServerName(self->client, &server_name));
try_void(Client_getServerDescription(self->client, &server_description)); try_void(Client_getServerDescription(self->client, &server_description));
try(Server* server, p, ClientCLI_saveServerInfo(self, try(ServerInfo* server, p, ClientCLI_saveServerInfo(self,
server_addr_str, server_pk_str, server_addr_str, server_pk_str,
server_name, server_description)); server_name, server_description));
@ -229,14 +231,14 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){
try_stderrcode(pthread_mutex_lock(&self->servers_cache_mutex)); try_stderrcode(pthread_mutex_lock(&self->servers_cache_mutex));
Defer(pthread_mutex_unlock(&self->servers_cache_mutex)); Defer(pthread_mutex_unlock(&self->servers_cache_mutex));
u32 servers_count = List_len(self->servers_cache_list, Server); u32 servers_count = List_len(self->servers_cache_list, ServerInfo);
if(servers_count == 0){ if(servers_count == 0){
printf("No servers found in cache\n"); printf("No servers found in cache\n");
Return RESULT_VOID; Return RESULT_VOID;
} }
for(u32 id = 0; id < servers_count; id++){ for(u32 id = 0; id < servers_count; id++){
Server* row = &List_index(self->servers_cache_list, Server, id); ServerInfo* row = &List_index(self->servers_cache_list, ServerInfo, id);
printf("[%02u] "FMT_str" "FMT_str"\n", printf("[%02u] "FMT_str" "FMT_str"\n",
id, row->address_len, row->address, row->name_len, row->name); id, row->address_len, row->address, row->name_len, row->name);
} }
@ -260,7 +262,7 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){
} }
else break; else break;
} }
Server* server = &List_index(self->servers_cache_list, Server, id); ServerInfo* server = &List_index(self->servers_cache_list, ServerInfo, id);
printf("Connecting to '"FMT_str"'...\n", server->address_len, server->address); printf("Connecting to '"FMT_str"'...\n", server->address_len, server->address);
try_void(Client_connect(self->client, server->address, server->pk_base64)); try_void(Client_connect(self->client, server->address, server->pk_base64));
@ -296,7 +298,7 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){
Return RESULT_VOID; Return RESULT_VOID;
} }
static Result(void) ClientCLI_showServerInfo(ClientCLI* self, Server* server){ static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server){
Deferral(8); Deferral(8);
(void)self; (void)self;
@ -320,16 +322,16 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){
try(self->db, p, idb_open(user_db_dir, user_data_key)); try(self->db, p, idb_open(user_db_dir, user_data_key));
// load servers table // load servers table
try(self->db_servers_table, p, idb_getOrCreateTable(self->db, STR("servers"), sizeof(Server))); try(self->db_servers_table, p, idb_getOrCreateTable(self->db, STR("servers"), sizeof(ServerInfo)));
// load whole table to list // load whole table to list
try(u64 servers_count, u, idb_getRowCount(self->db_servers_table)); try(u64 servers_count, u, idb_getRowCount(self->db_servers_table));
self->servers_cache_list = List_alloc(Server, servers_count); self->servers_cache_list = List_alloc(ServerInfo, servers_count);
try_void(idb_getRows(self->db_servers_table, 0, self->servers_cache_list.data, servers_count)); try_void(idb_getRows(self->db_servers_table, 0, self->servers_cache_list.data, servers_count));
self->servers_cache_list.size = sizeof(Server) * servers_count; self->servers_cache_list.size = sizeof(ServerInfo) * servers_count;
// build address-id map // build address-id map
HashMap_construct(&self->servers_addr_id_map, u64, NULL); HashMap_construct(&self->servers_addr_id_map, u64, NULL);
for(u64 id = 0; id < servers_count; id++){ for(u64 id = 0; id < servers_count; id++){
Server* row = &List_index(self->servers_cache_list, Server, id); ServerInfo* row = &List_index(self->servers_cache_list, ServerInfo, id);
str key = str_construct(row->address, row->address_len, true); str key = str_construct(row->address, row->address_len, true);
if(!HashMap_tryPush(&self->servers_addr_id_map, key, &id)){ if(!HashMap_tryPush(&self->servers_addr_id_map, key, &id)){
Return RESULT_ERROR_FMT("duplicate server address '"FMT_str"'", key.size, key.data); Return RESULT_ERROR_FMT("duplicate server address '"FMT_str"'", key.size, key.data);
@ -339,13 +341,13 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){
Return RESULT_VOID; Return RESULT_VOID;
} }
static Result(Server*) ClientCLI_saveServerInfo(ClientCLI* self, static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
str addr, str pk_base64, str name, str desc){ str addr, str pk_base64, str name, str desc){
Deferral(8); Deferral(8);
// create new server info // create new server info
Server server; ServerInfo server;
memset(&server, 0, sizeof(Server)); memset(&server, 0, sizeof(ServerInfo));
// address // address
if(addr.size > HOSTADDR_SIZE_MAX) if(addr.size > HOSTADDR_SIZE_MAX)
addr.size = HOSTADDR_SIZE_MAX; addr.size = HOSTADDR_SIZE_MAX;
@ -372,23 +374,23 @@ static Result(Server*) ClientCLI_saveServerInfo(ClientCLI* self,
Defer(pthread_mutex_unlock(&self->servers_cache_mutex)); Defer(pthread_mutex_unlock(&self->servers_cache_mutex));
// try find server id in cache // try find server id in cache
Server* cached_row_ptr = NULL; ServerInfo* cached_row_ptr = NULL;
u64* id_ptr = NULL; u64* id_ptr = NULL;
id_ptr = HashMap_tryGetPtr(&self->servers_addr_id_map, addr); id_ptr = HashMap_tryGetPtr(&self->servers_addr_id_map, addr);
if(id_ptr){ if(id_ptr){
// update existing server // update existing server
u64 id = *id_ptr; u64 id = *id_ptr;
try_void(idb_updateRow(self->db_servers_table, id, &server)); try_void(idb_updateRow(self->db_servers_table, id, &server));
try_assert(id < List_len(self->servers_cache_list, Server)); try_assert(id < List_len(self->servers_cache_list, ServerInfo));
cached_row_ptr = &List_index(self->servers_cache_list, Server, id); cached_row_ptr = &List_index(self->servers_cache_list, ServerInfo, id);
memcpy(cached_row_ptr, &server, sizeof(Server)); memcpy(cached_row_ptr, &server, sizeof(ServerInfo));
} }
else { else {
// push new server // push new server
try(u64 id, u, idb_pushRow(self->db_servers_table, &server)); try(u64 id, u, idb_pushRow(self->db_servers_table, &server));
try_assert(id == List_len(self->servers_cache_list, Server)); try_assert(id == List_len(self->servers_cache_list, ServerInfo));
List_pushMany(&self->servers_cache_list, Server, &server, 1); List_pushMany(&self->servers_cache_list, ServerInfo, &server, 1);
cached_row_ptr = &List_index(self->servers_cache_list, Server, id); cached_row_ptr = &List_index(self->servers_cache_list, ServerInfo, id);
try_assert(HashMap_tryPush(&self->servers_addr_id_map, addr, &id)); try_assert(HashMap_tryPush(&self->servers_addr_id_map, addr, &id));
} }

View File

@ -11,7 +11,7 @@ typedef struct ClientCLI {
IncrementalDB* db; IncrementalDB* db;
Table* db_servers_table; Table* db_servers_table;
pthread_mutex_t servers_cache_mutex; pthread_mutex_t servers_cache_mutex;
List(Server) servers_cache_list; // index is id List(ServerInfo) servers_cache_list; // index is id
HashMap(u64) servers_addr_id_map; // key is server address HashMap(u64) servers_addr_id_map; // key is server address
} ClientCLI; } ClientCLI;

View File

@ -2,7 +2,7 @@
#include "tcp-chat/common_constants.h" #include "tcp-chat/common_constants.h"
#include "tlibc/time.h" #include "tlibc/time.h"
typedef struct Server { typedef struct ServerInfo {
u16 address_len; u16 address_len;
char address[HOSTADDR_SIZE_MAX + 1]; char address[HOSTADDR_SIZE_MAX + 1];
u32 pk_base64_len; u32 pk_base64_len;
@ -11,4 +11,4 @@ typedef struct Server {
char name[SERVER_NAME_SIZE_MAX + 1]; char name[SERVER_NAME_SIZE_MAX + 1];
u16 desc_len; u16 desc_len;
char desc[SERVER_DESC_SIZE_MAX + 1]; char desc[SERVER_DESC_SIZE_MAX + 1];
} ATTRIBUTE_ALIGNED(16*1024) Server; } ATTRIBUTE_ALIGNED(16*1024) ServerInfo;

View File

@ -33,7 +33,6 @@ Result(void) run_ServerMode(cstr config_path) {
try_void(file_readWhole(config_file, &config_buf)); try_void(file_readWhole(config_file, &config_buf));
Defer(Array_free(config_buf)); Defer(Array_free(config_buf));
str config_str = Array_castTo_str(config_buf, false); str config_str = Array_castTo_str(config_buf, false);
config_buf.data = NULL;
// init server // init server
try(Server* server, p, Server_create(config_str, NULL, log_func)); try(Server* server, p, Server_create(config_str, NULL, log_func));
@ -43,6 +42,7 @@ Result(void) run_ServerMode(cstr config_path) {
file_close(config_file); file_close(config_file);
config_file = NULL; config_file = NULL;
Array_free(config_buf); Array_free(config_buf);
config_buf.data = NULL;
// start infinite loop on main thread // start infinite loop on main thread
try_void(Server_run(server)); try_void(Server_run(server));

View File

@ -1,18 +1,19 @@
#include "client_internal.h" #include "client_internal.h"
#include "requests/requests.h" #include "requests/requests.h"
void ServerConnection_close(ServerConnection* conn){ void ServerConnection_close(ServerConnection* self){
if(!conn) if(!self)
return; return;
RSA_destroyPublicKey(&conn->server_pk); RSA_destroyPublicKey(&self->server_pk);
EncryptedSocketTCP_destroy(&conn->sock); EncryptedSocketTCP_destroy(&self->sock);
Array_free(conn->session_key); Array_free(self->token);
str_free(conn->server_name); Array_free(self->session_key);
str_free(conn->server_description); str_free(self->server_name);
free(conn); str_free(self->server_description);
free(self);
} }
Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr server_pk_base64) Result(ServerConnection*) ServerConnection_open(Client* client, cstr server_addr_cstr, cstr server_pk_base64)
{ {
Deferral(16); Deferral(16);
@ -21,6 +22,8 @@ Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr serv
bool success = false; bool success = false;
Defer(if(!success) ServerConnection_close(conn)); Defer(if(!success) ServerConnection_close(conn));
conn->client = client;
// TODO: parse domain name and get ip from it // TODO: parse domain name and get ip from it
conn->server_end = EndpointIPv4_INVALID; conn->server_end = EndpointIPv4_INVALID;
try_void(EndpointIPv4_parse(server_addr_cstr, &conn->server_end)); try_void(EndpointIPv4_parse(server_addr_cstr, &conn->server_end));
@ -31,8 +34,21 @@ Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr serv
try_void(RSA_parsePublicKey_base64(server_pk_base64, &conn->server_pk)); try_void(RSA_parsePublicKey_base64(server_pk_base64, &conn->server_pk));
RSAEncryptor_construct(&conn->rsa_enc, &conn->server_pk); RSAEncryptor_construct(&conn->rsa_enc, &conn->server_pk);
// lvl 2 hash - is used for authentification
conn->token = Array_alloc(u8, PASSWORD_HASH_SIZE);
// hash user_data_key with server_pk once
Array(u8) server_pk_data = Array_construct_size(conn->server_pk.n,
BR_RSA_KBUF_PUB_SIZE(conn->server_pk.nlen * 8));
u8 server_pk_hash[PASSWORD_HASH_SIZE];
Array(u8) server_pk_hash_array = Array_construct_size(server_pk_hash, PASSWORD_HASH_SIZE);
hash_password(conn->client->user_data_key, server_pk_data,
server_pk_hash, 1);
// hash user_data_key with server_pk_hash
hash_password(conn->token, server_pk_hash_array,
conn->token.data, PASSWORD_HASH_LVL_ROUNDS);
// generate session random AES key
conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE);
// generate random session key
br_hmac_drbg_context key_rng = { .vtable = &br_hmac_drbg_vtable }; br_hmac_drbg_context key_rng = { .vtable = &br_hmac_drbg_vtable };
rng_init_sha256_seedFromSystem(&key_rng.vtable); rng_init_sha256_seedFromSystem(&key_rng.vtable);
br_hmac_drbg_generate(&key_rng, conn->session_key.data, conn->session_key.size); br_hmac_drbg_generate(&key_rng, conn->session_key.data, conn->session_key.size);

View File

@ -6,9 +6,8 @@ void Client_free(Client* self){
return; return;
str_free(self->username); str_free(self->username);
Array_free(self->token);
Array_free(self->user_data_key); Array_free(self->user_data_key);
ServerConnection_close(self->server_connection); ServerConnection_close(self->conn);
free(self); free(self);
} }
@ -22,21 +21,10 @@ Result(Client*) Client_create(str username, str password){
self->username = str_copy(username); self->username = str_copy(username);
// concat password and username
List(u8) data_to_hash = List_alloc_size(password.size + username.size + PASSWORD_HASH_SIZE);
Defer(free(data_to_hash.data));
List_push_size(&data_to_hash, password.data, password.size);
List_push_size(&data_to_hash, username.data, username.size);
// lvl 1 hash - is used as AES key for user data // lvl 1 hash - is used as AES key for user data
self->user_data_key = Array_alloc(u8, PASSWORD_HASH_SIZE); self->user_data_key = Array_alloc(u8, PASSWORD_HASH_SIZE);
hash_password(List_castTo_Array(data_to_hash), self->user_data_key.data, PASSWORD_HASH_LVL_ROUNDS); hash_password(str_castTo_Array(password), str_castTo_Array(username),
// concat lvl 1 hash to data_to_hash self->user_data_key.data, PASSWORD_HASH_LVL_ROUNDS);
List_push_size(&data_to_hash, self->user_data_key.data, self->user_data_key.size);
// lvl 2 hash - is used for authentification
self->token = Array_alloc(u8, PASSWORD_HASH_SIZE);
// TODO: generate different token for each server
hash_password(List_castTo_Array(data_to_hash), self->token.data, PASSWORD_HASH_LVL_ROUNDS);
success = true; success = true;
Return RESULT_VALUE(p, self); Return RESULT_VALUE(p, self);
@ -45,14 +33,15 @@ Result(Client*) Client_create(str username, str password){
Result(void) Client_connect(Client* self, cstr server_addr_cstr, cstr server_pk_base64){ Result(void) Client_connect(Client* self, cstr server_addr_cstr, cstr server_pk_base64){
Deferral(8); Deferral(8);
Client_disconnect(self); Client_disconnect(self);
try(self->server_connection, p, try(self->conn, p,
ServerConnection_open(server_addr_cstr, server_pk_base64)); ServerConnection_open(self, server_addr_cstr, server_pk_base64)
);
Return RESULT_VOID; Return RESULT_VOID;
} }
void Client_disconnect(Client* self){ void Client_disconnect(Client* self){
ServerConnection_close(self->server_connection); ServerConnection_close(self->conn);
self->server_connection = NULL; self->conn = NULL;
} }
str Client_getUserName(Client* client){ str Client_getUserName(Client* client){
@ -66,9 +55,9 @@ Array(u8) Client_getUserDataKey(Client* client){
Result(void) Client_getServerName(Client* self, str* out_name){ Result(void) Client_getServerName(Client* self, str* out_name){
Deferral(1); Deferral(1);
try_assert(self != NULL); try_assert(self != NULL);
try_assert(self->server_connection != NULL && "didn't connect to a server yet"); try_assert(self->conn != NULL && "didn't connect to a server yet");
*out_name = self->server_connection->server_name; *out_name = self->conn->server_name;
Return RESULT_VOID; Return RESULT_VOID;
} }
@ -76,9 +65,9 @@ Result(void) Client_getServerName(Client* self, str* out_name){
Result(void) Client_getServerDescription(Client* self, str* out_desc){ Result(void) Client_getServerDescription(Client* self, str* out_desc){
Deferral(1); Deferral(1);
try_assert(self != NULL); try_assert(self != NULL);
try_assert(self->server_connection != NULL && "didn't connect to a server yet"); try_assert(self->conn != NULL && "didn't connect to a server yet");
*out_desc = self->server_connection->server_description; *out_desc = self->conn->server_description;
Return RESULT_VOID; Return RESULT_VOID;
} }
@ -86,16 +75,16 @@ Result(void) Client_getServerDescription(Client* self, str* out_desc){
Result(void) Client_register(Client* self, u64* out_user_id){ Result(void) Client_register(Client* self, u64* out_user_id){
Deferral(1); Deferral(1);
try_assert(self != NULL); try_assert(self != NULL);
try_assert(self->server_connection != NULL && "didn't connect to a server yet"); try_assert(self->conn != NULL && "didn't connect to a server yet");
PacketHeader req_head, res_head; PacketHeader req_head, res_head;
RegisterRequest req; RegisterRequest req;
RegisterResponse res; RegisterResponse res;
// TODO: hash token with server public key // TODO: hash token with server public key
try_void(RegisterRequest_tryConstruct(&req, &req_head, self->username, self->token)); try_void(RegisterRequest_tryConstruct(&req, &req_head, self->username, self->conn->token));
try_void(sendRequest(&self->server_connection->sock, &req_head, &req)); try_void(sendRequest(&self->conn->sock, &req_head, &req));
try_void(recvResponse(&self->server_connection->sock, &res_head, &res, PacketType_RegisterResponse)); try_void(recvResponse(&self->conn->sock, &res_head, &res, PacketType_RegisterResponse));
self->server_connection->user_id = res.user_id; self->conn->user_id = res.user_id;
*out_user_id = res.user_id; *out_user_id = res.user_id;
Return RESULT_VOID; Return RESULT_VOID;
@ -104,16 +93,16 @@ Result(void) Client_register(Client* self, u64* out_user_id){
Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id){ Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id){
Deferral(1); Deferral(1);
try_assert(self != NULL); try_assert(self != NULL);
try_assert(self->server_connection != NULL && "didn't connect to a server yet"); try_assert(self->conn != NULL && "didn't connect to a server yet");
PacketHeader req_head, res_head; PacketHeader req_head, res_head;
LoginRequest req; LoginRequest req;
LoginResponse res; LoginResponse res;
// TODO: hash token with server public key // TODO: hash token with server public key
try_void(LoginRequest_tryConstruct(&req, &req_head, self->username, self->token)); try_void(LoginRequest_tryConstruct(&req, &req_head, self->username, self->conn->token));
try_void(sendRequest(&self->server_connection->sock, &req_head, &req)); try_void(sendRequest(&self->conn->sock, &req_head, &req));
try_void(recvResponse(&self->server_connection->sock, &res_head, &res, PacketType_LoginResponse)); try_void(recvResponse(&self->conn->sock, &res_head, &res, PacketType_LoginResponse));
self->server_connection->user_id = res.user_id; self->conn->user_id = res.user_id;
*out_user_id = res.user_id; *out_user_id = res.user_id;
*out_landing_channel_id = res.landing_channel_id; *out_landing_channel_id = res.landing_channel_id;

View File

@ -9,18 +9,19 @@ typedef struct ServerConnection ServerConnection;
typedef struct Client { typedef struct Client {
str username; str username;
Array(u8) user_data_key; Array(u8) user_data_key;
Array(u8) token; ServerConnection* conn;
ServerConnection* server_connection;
} Client; } Client;
typedef struct ServerConnection { typedef struct ServerConnection {
Client* client;
EndpointIPv4 server_end; EndpointIPv4 server_end;
br_rsa_public_key server_pk; br_rsa_public_key server_pk;
RSAEncryptor rsa_enc; RSAEncryptor rsa_enc;
u64 session_id; Array(u8) token;
Array(u8) session_key; Array(u8) session_key;
EncryptedSocketTCP sock; EncryptedSocketTCP sock;
u64 session_id;
str server_name; str server_name;
str server_description; str server_description;
u64 user_id; u64 user_id;
@ -28,10 +29,13 @@ typedef struct ServerConnection {
/// @param server_addr_cstr /// @param server_addr_cstr
/// @param server_pk_base64 public key encoded by `RSA_serializePublicKey_base64()` /// @param server_pk_base64 public key encoded by `RSA_serializePublicKey_base64()`
Result(ServerConnection*) ServerConnection_open(cstr server_addr_cstr, cstr server_pk_base64); Result(ServerConnection*) ServerConnection_open(Client* client,
void ServerConnection_close(ServerConnection* conn); cstr server_addr_cstr, cstr server_pk_base64);
void ServerConnection_close(ServerConnection* conn);
/// updates conn->server_name /// updates conn->server_name
Result(void) ServerConnection_requestServerName(ServerConnection* conn); Result(void) ServerConnection_requestServerName(ServerConnection* conn);
/// updates conn->server_description /// updates conn->server_description
Result(void) ServerConnection_requestServerDescription(ServerConnection* conn); Result(void) ServerConnection_requestServerDescription(ServerConnection* conn);

View File

@ -13,9 +13,10 @@
/// @brief hashes password multiple times using its own hash as salt /// @brief hashes password multiple times using its own hash as salt
/// @param password some byte array /// @param password some byte array
/// @param salt some byte array
/// @param out_buffer u8[PASSWORD_HASH_SIZE] /// @param out_buffer u8[PASSWORD_HASH_SIZE]
/// @param rounds number of rounds /// @param rounds number of rounds
void hash_password(Array(u8) password, u8* out_buffer, i32 rounds); void hash_password(Array(u8) password, Array(u8) salt, u8* out_buffer, i32 rounds);
#define PASSWORD_HASH_LVL_ROUNDS 1e5 #define PASSWORD_HASH_LVL_ROUNDS 1e5
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -27,6 +28,7 @@ void hash_password(Array(u8) password, u8* out_buffer, i32 rounds);
/// @brief Initialize prng context with sha256 hashing algorithm /// @brief Initialize prng context with sha256 hashing algorithm
/// and seed from system-provided cryptographic random bytes source. /// and seed from system-provided cryptographic random bytes source.
/// @param rng_vtable_ptr pointer to vtable field in prng context. The field must be initialized. /// @param rng_vtable_ptr pointer to vtable field in prng context. The field must be initialized.
///
/// EXAMPLE: /// EXAMPLE:
/// ``` /// ```
/// br_hmac_drbg_context rng_ctx = { .vtable = &br_hmac_drbg_vtable }; /// br_hmac_drbg_context rng_ctx = { .vtable = &br_hmac_drbg_vtable };
@ -36,6 +38,7 @@ void rng_init_sha256_seedFromSystem(const br_prng_class** rng_vtable_ptr);
/// @brief Initialize prng context with sha256 hashing algorithm and seed from CLOCK_REALTIME. /// @brief Initialize prng context with sha256 hashing algorithm and seed from CLOCK_REALTIME.
/// @param rng_vtable_ptr pointer to vtable field in prng context. The field must be initialized. /// @param rng_vtable_ptr pointer to vtable field in prng context. The field must be initialized.
///
/// EXAMPLE: /// EXAMPLE:
/// ``` /// ```
/// br_hmac_drbg_context rng_ctx = { .vtable = &br_hmac_drbg_vtable }; /// br_hmac_drbg_context rng_ctx = { .vtable = &br_hmac_drbg_vtable };

View File

@ -1,7 +1,7 @@
#include "cryptography.h" #include "cryptography.h"
#include "assert.h" #include "assert.h"
void hash_password(Array(u8) password, u8* out_buffer, i32 iterations){ void hash_password(Array(u8) password, Array(u8) salt, u8* out_buffer, i32 iterations){
assert(PASSWORD_HASH_SIZE == br_sha256_SIZE);; assert(PASSWORD_HASH_SIZE == br_sha256_SIZE);;
memset(out_buffer, 0, br_sha256_SIZE); memset(out_buffer, 0, br_sha256_SIZE);
br_sha256_context sha256_ctx; br_sha256_context sha256_ctx;
@ -9,6 +9,7 @@ void hash_password(Array(u8) password, u8* out_buffer, i32 iterations){
for(i32 i = 0; i < iterations; i++){ for(i32 i = 0; i < iterations; i++){
br_sha256_update(&sha256_ctx, password.data, password.size); br_sha256_update(&sha256_ctx, password.data, password.size);
br_sha256_update(&sha256_ctx, salt.data, salt.size);
br_sha256_out(&sha256_ctx, out_buffer); br_sha256_out(&sha256_ctx, out_buffer);
br_sha256_update(&sha256_ctx, out_buffer, PASSWORD_HASH_SIZE); br_sha256_update(&sha256_ctx, out_buffer, PASSWORD_HASH_SIZE);
} }

View File

@ -18,6 +18,7 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args)
bool success = false; bool success = false;
Defer(if(!success) ClientConnection_close(conn)); Defer(if(!success) ClientConnection_close(conn));
conn->server = args->server;
conn->client_end = args->client_end; conn->client_end = args->client_end;
conn->session_id = args->session_id; conn->session_id = args->session_id;
conn->authorized = false; conn->authorized = false;

View File

@ -2,16 +2,16 @@
#include "tcp-chat/common_constants.h" #include "tcp-chat/common_constants.h"
#include "tlibc/time.h" #include "tlibc/time.h"
typedef struct User { typedef struct UserInfo {
u16 name_len; u16 name_len;
char name[USERNAME_SIZE_MAX + 1]; // null-terminated char name[USERNAME_SIZE_MAX + 1]; // null-terminated
u8 token_hash[PASSWORD_HASH_SIZE]; // token is hashed again on server side u8 token[PASSWORD_HASH_SIZE]; // token is hashed again on server side
DateTime registration_time; DateTime registration_time;
} ATTRIBUTE_ALIGNED(256) User; } ATTRIBUTE_ALIGNED(256) UserInfo;
typedef struct Channel { typedef struct ChannelInfo {
u16 name_len; u16 name_len;
char name[CHANNEL_NAME_SIZE_MAX + 1]; char name[CHANNEL_NAME_SIZE_MAX + 1];
u16 desc_len; u16 desc_len;
char desc[CHANNEL_DESC_SIZE_MAX + 1]; char desc[CHANNEL_DESC_SIZE_MAX + 1];
} ATTRIBUTE_ALIGNED(4*1024) Channel; } ATTRIBUTE_ALIGNED(4*1024) ChannelInfo;

View File

@ -1,7 +1,7 @@
#include "request_handlers.h" #include "request_handlers.h"
#define LOGGER server->logger #define LOGGER conn->server->logger
#define LOG_FUNC server->log_func #define LOG_FUNC conn->server->log_func
declare_RequestHandler(Login) declare_RequestHandler(Login)
{ {
@ -14,7 +14,7 @@ declare_RequestHandler(Login)
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req)); try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req));
if(conn->authorized){ if(conn->authorized){
try_void(sendErrorMessage(server, log_ctx, conn, res_head, try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
STR("is authorized in already") STR("is authorized in already")
)); ));
@ -25,30 +25,25 @@ declare_RequestHandler(Login)
str username_str = str_null; str username_str = str_null;
str username_check_error = validateUsername_cstr(req.username, &username_str); str username_check_error = validateUsername_cstr(req.username, &username_str);
if(username_check_error.data){ if(username_check_error.data){
try_void(sendErrorMessage(server, log_ctx, conn, res_head, try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
username_check_error username_check_error
)); ));
Return RESULT_VOID; Return RESULT_VOID;
} }
// cakculate hash of received token
u8 token_hash[PASSWORD_HASH_SIZE];
hash_password(
Array_construct_size(req.token, sizeof(req.token)),
token_hash,
PASSWORD_HASH_LVL_ROUNDS
);
// lock users cache // lock users cache
try_stderrcode(pthread_mutex_lock(&server->users_cache_mutex)); try_stderrcode(pthread_mutex_lock(&conn->server->users_cache_mutex));
bool unlocked_users_cache_mutex = false; bool unlocked_users_cache_mutex = false;
Defer(if(!unlocked_users_cache_mutex) pthread_mutex_unlock(&server->users_cache_mutex)); Defer(
if(!unlocked_users_cache_mutex)
pthread_mutex_unlock(&conn->server->users_cache_mutex)
);
// try get id from name cache // try get id from name cache
u64* id_ptr = HashMap_tryGetPtr(&server->users_name_id_map, username_str); u64* id_ptr = HashMap_tryGetPtr(&conn->server->users_name_id_map, username_str);
if(id_ptr == NULL){ if(id_ptr == NULL){
try_void(sendErrorMessage_f(server, log_ctx, conn, res_head, try_void(sendErrorMessage_f(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
"Username '%s' is not registered", "Username '%s' is not registered",
username_str.data username_str.data
@ -58,12 +53,12 @@ declare_RequestHandler(Login)
u64 user_id = *id_ptr; u64 user_id = *id_ptr;
// get user by id // get user by id
try_assert(user_id < List_len(server->users_cache_list, User)); try_assert(user_id < List_len(conn->server->users_cache_list, UserInfo));
User* u = &List_index(server->users_cache_list, User, user_id); UserInfo* u = &List_index(conn->server->users_cache_list, UserInfo, user_id);
// validate token hash // validate token hash
if(memcmp(token_hash, u->token_hash, sizeof(token_hash)) != 0){ if(memcmp(req.token, u->token, sizeof(req.token)) != 0){
try_void(sendErrorMessage(server, log_ctx, conn, res_head, try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
STR("wrong password") STR("wrong password")
)); ));
@ -71,7 +66,7 @@ declare_RequestHandler(Login)
} }
// manually unlock mutex // manually unlock mutex
pthread_mutex_unlock(&server->users_cache_mutex); pthread_mutex_unlock(&conn->server->users_cache_mutex);
unlocked_users_cache_mutex = true; unlocked_users_cache_mutex = true;
// authorize // authorize
@ -80,7 +75,7 @@ declare_RequestHandler(Login)
// send response // send response
LoginResponse res; LoginResponse res;
LoginResponse_construct(&res, res_head, user_id, server->landing_channel_id); LoginResponse_construct(&res, res_head, user_id, conn->server->landing_channel_id);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head)); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res)); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));

View File

@ -1,7 +1,7 @@
#include "request_handlers.h" #include "request_handlers.h"
#define LOGGER server->logger #define LOGGER conn->server->logger
#define LOG_FUNC server->log_func #define LOG_FUNC conn->server->log_func
declare_RequestHandler(Register) declare_RequestHandler(Register)
{ {
@ -14,7 +14,7 @@ declare_RequestHandler(Register)
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req)); try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req));
if(conn->authorized){ if(conn->authorized){
try_void(sendErrorMessage(server, log_ctx, conn, res_head, try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
STR("is authorized in already") STR("is authorized in already")
)); ));
@ -25,7 +25,7 @@ declare_RequestHandler(Register)
str username_str = str_null; str username_str = str_null;
str username_check_error = validateUsername_cstr(req.username, &username_str); str username_check_error = validateUsername_cstr(req.username, &username_str);
if(username_check_error.data){ if(username_check_error.data){
try_void(sendErrorMessage(server, log_ctx, conn, res_head, try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
username_check_error username_check_error
)); ));
@ -33,17 +33,17 @@ declare_RequestHandler(Register)
} }
// lock users cache // lock users cache
try_stderrcode(pthread_mutex_lock(&server->users_cache_mutex)); try_stderrcode(pthread_mutex_lock(&conn->server->users_cache_mutex));
bool unlocked_users_cache_mutex = false; bool unlocked_users_cache_mutex = false;
// unlock mutex on error catch // unlock mutex on error catch
Defer( Defer(
if(!unlocked_users_cache_mutex) if(!unlocked_users_cache_mutex)
pthread_mutex_unlock(&server->users_cache_mutex) pthread_mutex_unlock(&conn->server->users_cache_mutex)
); );
// check if name is taken // check if name is taken
if(HashMap_tryGetPtr(&server->users_name_id_map, username_str) != NULL){ if(HashMap_tryGetPtr(&conn->server->users_name_id_map, username_str) != NULL){
try_void(sendErrorMessage_f(server, log_ctx, conn, res_head, try_void(sendErrorMessage_f(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
"Username '%s' already exists", "Username '%s' already exists",
username_str.data)); username_str.data));
@ -51,28 +51,24 @@ declare_RequestHandler(Register)
} }
// initialize new user // initialize new user
User user; UserInfo user;
memset(&user, 0, sizeof(User)); memset(&user, 0, sizeof(UserInfo));
user.name_len = username_str.size; user.name_len = username_str.size;
memcpy(user.name, username_str.data, user.name_len); memcpy(user.name, username_str.data, user.name_len);
hash_password( memcpy(user.token, req.token, sizeof(req.token));
Array_construct_size(req.token, sizeof(req.token)),
user.token_hash,
PASSWORD_HASH_LVL_ROUNDS
);
DateTime_getUTC(&user.registration_time); DateTime_getUTC(&user.registration_time);
// save new user to db and cache // save new user to db and cache
try(u64 user_id, u, idb_pushRow(server->db_users_table, &user)); try(u64 user_id, u, idb_pushRow(conn->server->db_users_table, &user));
try_assert(user_id == List_len(server->users_cache_list, User)); try_assert(user_id == List_len(conn->server->users_cache_list, UserInfo));
List_pushMany(&server->users_cache_list, User, &user, 1); List_pushMany(&conn->server->users_cache_list, UserInfo, &user, 1);
try_assert(HashMap_tryPush(&server->users_name_id_map, username_str, &user_id)); try_assert(HashMap_tryPush(&conn->server->users_name_id_map, username_str, &user_id));
// manually unlock mutex // manually unlock mutex
pthread_mutex_unlock(&server->users_cache_mutex); pthread_mutex_unlock(&conn->server->users_cache_mutex);
unlocked_users_cache_mutex = true; unlocked_users_cache_mutex = true;
logInfo(log_ctx, "registered user '%s'", username_str.data); logInfo(log_ctx, "registered user '%s'", username_str.data);

View File

@ -1,7 +1,7 @@
#include "request_handlers.h" #include "request_handlers.h"
#define LOGGER server->logger #define LOGGER conn->server->logger
#define LOG_FUNC server->log_func #define LOG_FUNC conn->server->log_func
declare_RequestHandler(ServerPublicInfo) declare_RequestHandler(ServerPublicInfo)
{ {
@ -17,17 +17,17 @@ declare_RequestHandler(ServerPublicInfo)
Array(u8) content; Array(u8) content;
switch(req.property){ switch(req.property){
default:{ default:{
try_void(sendErrorMessage_f(server, log_ctx, conn, res_head, try_void(sendErrorMessage_f(log_ctx, conn, res_head,
LogSeverity_Warn, LogSeverity_Warn,
"Unknown ServerPublicInfo property %u", "Unknown ServerPublicInfo property %u",
req.property)); req.property));
Return RESULT_VOID; Return RESULT_VOID;
} }
case ServerPublicInfo_Name: case ServerPublicInfo_Name:
content = str_castTo_Array(server->name); content = str_castTo_Array(conn->server->name);
break; break;
case ServerPublicInfo_Description: case ServerPublicInfo_Description:
content = str_castTo_Array(server->description); content = str_castTo_Array(conn->server->description);
break; break;
} }

View File

@ -4,29 +4,26 @@
Result(void) sendErrorMessage( Result(void) sendErrorMessage(
Server* server, cstr log_ctx, cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
ClientConnection* conn, PacketHeader* res_head,
LogSeverity log_severity, str msg); LogSeverity log_severity, str msg);
Result(void) __sendErrorMessage_fv( Result(void) __sendErrorMessage_fv(
Server* server, cstr log_ctx, cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
ClientConnection* conn, PacketHeader* res_head,
LogSeverity log_severity, cstr format, va_list argv); LogSeverity log_severity, cstr format, va_list argv);
Result(void) sendErrorMessage_f( Result(void) sendErrorMessage_f(
Server* server, cstr log_ctx, cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
ClientConnection* conn, PacketHeader* res_head, LogSeverity log_severity, cstr format, ...) ATTRIBUTE_CHECK_FORMAT_PRINTF(5, 6);
LogSeverity log_severity, cstr format, ...) ATTRIBUTE_CHECK_FORMAT_PRINTF(6, 7);
#define declare_RequestHandler(TYPE) \ #define declare_RequestHandler(TYPE) \
Result(void) handleRequest_##TYPE(\ Result(void) handleRequest_##TYPE(\
Server* server, cstr log_ctx, cstr req_type_name, \ cstr log_ctx, cstr req_type_name, \
ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head) ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head)
#define case_handleRequest(TYPE) \ #define case_handleRequest(TYPE) \
case PacketType_##TYPE##Request:\ case PacketType_##TYPE##Request:\
try_void(handleRequest_##TYPE(args->server, log_ctx, #TYPE, conn, &req_head, &res_head));\ try_void(handleRequest_##TYPE(log_ctx, #TYPE, conn, &req_head, &res_head));\
break; break;
declare_RequestHandler(ServerPublicInfo); declare_RequestHandler(ServerPublicInfo);

View File

@ -1,11 +1,10 @@
#include "request_handlers.h" #include "request_handlers.h"
#define LOGGER server->logger #define LOGGER conn->server->logger
#define LOG_FUNC server->log_func #define LOG_FUNC conn->server->log_func
Result(void) sendErrorMessage( Result(void) sendErrorMessage(
Server* server, cstr log_ctx, cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
ClientConnection* conn, PacketHeader* res_head,
LogSeverity log_severity, str msg) LogSeverity log_severity, str msg)
{ {
Deferral(1); Deferral(1);
@ -26,21 +25,20 @@ Result(void) sendErrorMessage(
} }
Result(void) __sendErrorMessage_fv( Result(void) __sendErrorMessage_fv(
Server* server, cstr log_ctx, cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
ClientConnection* conn, PacketHeader* res_head,
LogSeverity log_severity, cstr format, va_list argv) LogSeverity log_severity, cstr format, va_list argv)
{ {
Deferral(4); Deferral(4);
str msg = str_from_cstr(vsprintf_malloc(format, argv)); str msg = str_from_cstr(vsprintf_malloc(format, argv));
Defer(free(msg.data)); Defer(free(msg.data));
try_void(sendErrorMessage(server, log_ctx, conn, res_head, log_severity, msg)); try_void(sendErrorMessage(log_ctx, conn, res_head, log_severity, msg));
Return RESULT_VOID; Return RESULT_VOID;
} }
Result(void) sendErrorMessage_f( Result(void) sendErrorMessage_f(
Server* server, cstr log_ctx, cstr log_ctx,
ClientConnection* conn, PacketHeader* res_head, ClientConnection* conn, PacketHeader* res_head,
LogSeverity log_severity, cstr format, ...) LogSeverity log_severity, cstr format, ...)
{ {
@ -49,7 +47,7 @@ Result(void) sendErrorMessage_f(
va_list argv; va_list argv;
va_start(argv, format); va_start(argv, format);
Defer(va_end(argv)); Defer(va_end(argv));
try_void(__sendErrorMessage_fv(server, log_ctx, conn, res_head, log_severity, format, argv)); try_void(__sendErrorMessage_fv(log_ctx, conn, res_head, log_severity, format, argv));
Return RESULT_VOID; Return RESULT_VOID;
} }

View File

@ -1,7 +1,7 @@
#include "request_handlers.h" #include "request_handlers.h"
#define LOGGER server->logger #define LOGGER conn->server->logger
#define LOG_FUNC server->log_func #define LOG_FUNC conn->server->log_func
declare_RequestHandler(NAME) declare_RequestHandler(NAME)
{ {

View File

@ -23,6 +23,8 @@ void Server_free(Server* self){
pthread_mutex_destroy(&self->users_cache_mutex); pthread_mutex_destroy(&self->users_cache_mutex);
List_destroy(self->users_cache_list); List_destroy(self->users_cache_list);
HashMap_destroy(&self->users_name_id_map); HashMap_destroy(&self->users_name_id_map);
free(self);
} }
@ -91,16 +93,16 @@ Result(Server*) Server_create(str config_str, void* logger, LogFunction_t log_fu
// build users cache // build users cache
logDebug(log_ctx, "loading users..."); logDebug(log_ctx, "loading users...");
pthread_mutex_init(&self->users_cache_mutex, NULL); pthread_mutex_init(&self->users_cache_mutex, NULL);
try(self->db_users_table, p, idb_getOrCreateTable(self->db, STR("users"), sizeof(User))); try(self->db_users_table, p, idb_getOrCreateTable(self->db, STR("users"), sizeof(UserInfo)));
// load whole table to list // load whole table to list
try(u64 users_count, u, idb_getRowCount(self->db_users_table)); try(u64 users_count, u, idb_getRowCount(self->db_users_table));
self->users_cache_list = List_alloc(User, users_count); self->users_cache_list = List_alloc(UserInfo, users_count);
try_void(idb_getRows(self->db_users_table, 0, self->users_cache_list.data, users_count)); try_void(idb_getRows(self->db_users_table, 0, self->users_cache_list.data, users_count));
self->users_cache_list.size = sizeof(User) * users_count; self->users_cache_list.size = sizeof(UserInfo) * users_count;
// build name-id map // build name-id map
HashMap_construct(&self->users_name_id_map, u64, NULL); HashMap_construct(&self->users_name_id_map, u64, NULL);
for(u64 id = 0; id < users_count; id++){ for(u64 id = 0; id < users_count; id++){
User* row = &List_index(self->users_cache_list, User, id); UserInfo* row = &List_index(self->users_cache_list, UserInfo, id);
str key = str_construct(row->name, row->name_len, true); str key = str_construct(row->name, row->name_len, true);
if(!HashMap_tryPush(&self->users_name_id_map, key, &id)){ if(!HashMap_tryPush(&self->users_name_id_map, key, &id)){
Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", key.size, key.data); Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", key.size, key.data);
@ -191,7 +193,7 @@ static Result(void) try_handleConnection(ConnectionHandlerArgs* args, cstr log_c
switch(req_head.type){ switch(req_head.type){
// send error message and close connection // send error message and close connection
default: default:
try_void(sendErrorMessage_f(server, log_ctx, conn, &res_head, try_void(sendErrorMessage_f(log_ctx, conn, &res_head,
LogSeverity_Error, LogSeverity_Error,
"Received unexpected packet of type %u", "Received unexpected packet of type %u",
req_head.type)); req_head.type));

View File

@ -28,12 +28,13 @@ typedef struct Server {
IncrementalDB* db; IncrementalDB* db;
Table* db_users_table; Table* db_users_table;
pthread_mutex_t users_cache_mutex; pthread_mutex_t users_cache_mutex;
List(User) users_cache_list; // index is id List(UserInfo) users_cache_list; // index is id
HashMap(u64) users_name_id_map; // key is user name HashMap(u64) users_name_id_map; // key is user name
} Server; } Server;
typedef struct ClientConnection { typedef struct ClientConnection {
Server* server;
u64 session_id; u64 session_id;
EndpointIPv4 client_end; EndpointIPv4 client_end;
Array(u8) session_key; Array(u8) session_key;