diff --git a/dependencies/tlibc b/dependencies/tlibc index d6436d0..2c8e6fc 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit d6436d08338a0a762e727f0c816dd5a09782b180 +Subproject commit 2c8e6fc601a868851d8ce50f77b391e6a9b7e656 diff --git a/src/client/ClientCredential.c b/src/client/ClientCredential.c deleted file mode 100644 index e171c7e..0000000 --- a/src/client/ClientCredential.c +++ /dev/null @@ -1,49 +0,0 @@ -#include "client.h" -#include "tlibc/string/StringBuilder.h" - -void ClientCredentials_free(ClientCredentials* cred){ - if(cred == NULL) - return; - free(cred->username.data); - free(cred->aes_key.data); - free(cred); -} - - -Result(ClientCredentials*) ClientCredentials_create(str username, str password){ - Deferral(8); - ClientCredentials* cred = (ClientCredentials*)malloc(sizeof(ClientCredentials)); - memset(cred, 0, sizeof(ClientCredentials)); - bool success = false; - Defer( - if(!success) - free(cred); - ); - - cred->username = str_copy(username); - Defer( - if(!success) - free(cred->username.data); - ); - - // concat password and username - StringBuilder sb = StringBuilder_alloc(username.size + password.size + 1); - Defer(StringBuilder_destroy(&sb)); - StringBuilder_append_str(&sb, password); - StringBuilder_append_str(&sb, username); - Array(u8) password_and_username = str_castTo_Array(StringBuilder_getStr(&sb)); - cred->aes_key = Array_alloc(u8, PASSWORD_HASH_SIZE); - Defer( - if(!success){ - free(cred->aes_key.data); - } - ); - // lvl 1 hash - is used as AES key for user data - hash_password(password_and_username, cred->aes_key.data, __PASSWORD_HASH_LVL_ITERATIONS); - - AESBlockEncryptor_construct(&cred->user_data_aes_enc, cred->aes_key, AESBlockEncryptor_DEFAULT_CLASS); - AESBlockDecryptor_construct(&cred->user_data_aes_dec, cred->aes_key, AESBlockDecryptor_DEFAULT_CLASS); - - success = true; - Return RESULT_VALUE(p, cred); -} diff --git a/src/client/ClientCredentials.c b/src/client/ClientCredentials.c new file mode 100644 index 0000000..433d97c --- /dev/null +++ b/src/client/ClientCredentials.c @@ -0,0 +1,43 @@ +#include "client.h" +#include "tlibc/string/StringBuilder.h" + +void ClientCredentials_destroy(ClientCredentials* cred){ + if(!cred) + return; + free(cred->username.data); + free(cred->user_data_key.data); + free(cred->token.data); +} + + +Result(void) ClientCredentials_tryConstruct(ClientCredentials* cred, + str username, str password) +{ + Deferral(8); + + memset(cred, 0, sizeof(ClientCredentials)); + bool success = false; + Defer(if(!success) ClientCredentials_destroy(cred)); + + cred->username = str_copy(username); + + // concat password and username + StringBuilder sb = StringBuilder_alloc(username.size + password.size + 1); + Defer(StringBuilder_destroy(&sb)); + StringBuilder_append_str(&sb, password); + StringBuilder_append_str(&sb, username); + Array(u8) password_and_username = str_castTo_Array(StringBuilder_getStr(&sb)); + + // lvl 1 hash - is used as AES key for user data + cred->user_data_key = Array_alloc(u8, PASSWORD_HASH_SIZE); + hash_password(password_and_username, cred->user_data_key.data, __PASSWORD_HASH_LVL_ITERATIONS); + // lvl 2 hash - is used for authentification + cred->token = Array_alloc(u8, PASSWORD_HASH_SIZE); + hash_password(cred->user_data_key, cred->token.data, __PASSWORD_HASH_LVL_ITERATIONS); + + AESBlockEncryptor_construct(&cred->user_data_aes_enc, cred->user_data_key, AESBlockEncryptor_DEFAULT_CLASS); + AESBlockDecryptor_construct(&cred->user_data_aes_dec, cred->user_data_key, AESBlockDecryptor_DEFAULT_CLASS); + + success = true; + Return RESULT_VOID; +} diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index a26ecc8..4f6ee88 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -2,7 +2,7 @@ #include "network/tcp-chat-protocol/v1.h" void ServerConnection_close(ServerConnection* conn){ - if(conn == NULL) + if(!conn) return; RSA_destroyPublicKey(&conn->server_pk); EncryptedSocketTCP_destroy(&conn->sock); @@ -42,15 +42,12 @@ Result(void) ServerLink_parse(cstr server_link_cstr, EndpointIPv4* server_end_ou } Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_credentials, cstr server_link_cstr){ - Deferral(64); + Deferral(16); ServerConnection* conn = (ServerConnection*)malloc(sizeof(ServerConnection)); memset(conn, 0, sizeof(*conn)); bool success = false; - Defer( - if(!success) - ServerConnection_close(conn); - ); + Defer(if(!success) ServerConnection_close(conn)); try_void(ServerLink_parse(server_link_cstr, &conn->server_end, &conn->server_pk)); RSAEncryptor_construct(&conn->rsa_enc, &conn->server_pk); diff --git a/src/client/client.c b/src/client/client.c index 81d8920..7b4758e 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -17,12 +17,30 @@ static const str farewell_art = STR( "\\(_,J J L l`,)/\n" ); -static ClientCredentials* _client_credentials = NULL; -static ServerConnection* _server_connection = NULL; +Result(void) Client_createFromConfig(cstr config_path){ + Deferral(16); -static Result(void) commandExec(str command, bool* stop); + Client* client = (Client*)malloc(sizeof(Client)); + memset(client, 0, sizeof(Client)); + bool success = false; + Defer(if(!success) Client_free(client)); -static Result(void) askUserNameAndPassword(ClientCredentials** cred){ + success = true; + Return RESULT_VALUE(p, client); +} + +void Client_free(Client* client){ + if(!client) + return; + + ClientCredentials_destroy(&client->cred); + ServerConnection_close(client->server_connection); + free(client); +} + +static Result(void) commandExec(Client* client, str command, bool* stop); + +static Result(void) askUserNameAndPassword(ClientCredentials* cred){ Deferral(8); char username_buf[128]; @@ -33,6 +51,7 @@ static Result(void) askUserNameAndPassword(ClientCredentials** cred){ Return RESULT_ERROR("STDIN is closed", false); } username = str_from_cstr(username_buf); + str_trim(&username, true); if(username.size < USERNAME_SIZE_MIN || username.size > USERNAME_SIZE_MAX){ printf("ERROR: username length (in bytes) must be >= %i and <= %i\n", USERNAME_SIZE_MIN, USERNAME_SIZE_MAX); @@ -49,6 +68,7 @@ static Result(void) askUserNameAndPassword(ClientCredentials** cred){ Return RESULT_ERROR("STDIN is closed", false); } password = str_from_cstr(password_buf); + str_trim(&password, true); if(password.size < PASSWORD_SIZE_MIN || password.size > PASSWORD_SIZE_MAX){ printf("ERROR: password length (in bytes) must be >= %i and <= %i\n", PASSWORD_SIZE_MIN, PASSWORD_SIZE_MAX); @@ -56,23 +76,18 @@ static Result(void) askUserNameAndPassword(ClientCredentials** cred){ else break; } - try(*cred, p, ClientCredentials_create(username, password)); + try_void(ClientCredentials_tryConstruct(cred, username, password)); Return RESULT_VOID; } -Result(void) client_run() { - Deferral(32); +Result(void) Client_run(Client* client) { + Deferral(16); if(!term_init()){ Return RESULT_ERROR("can't init terminal", false); } fputs(greeting_art.data, stdout); - - Defer( - ClientCredentials_free(_client_credentials); - ServerConnection_close(_server_connection); - ); - try_void(askUserNameAndPassword(&_client_credentials)); + try_void(askUserNameAndPassword(&client->cred)); Array(char) input_buf = Array_alloc(char, 10000); Defer(free(input_buf.data)); @@ -90,7 +105,7 @@ Result(void) client_run() { if(command_input.size == 0) continue; - ResultVar(void) com_result = commandExec(command_input, &stop); + ResultVar(void) com_result = commandExec(client, command_input, &stop); if(com_result.error){ str e_str = Error_toStr(com_result.error); printf("%s\n", e_str.data); @@ -104,7 +119,7 @@ Result(void) client_run() { #define is_alias(LITERAL) str_equals(command, STR(LITERAL)) -static Result(void) commandExec(str command, bool* stop){ +static Result(void) commandExec(Client* client, str command, bool* stop){ Deferral(64); char answer_buf[10000]; const u32 answer_buf_size = sizeof(answer_buf); @@ -126,7 +141,7 @@ static Result(void) commandExec(str command, bool* stop){ ); } else if (is_alias("j") || is_alias("join")){ - ServerConnection_close(_server_connection); + ServerConnection_close(client->server_connection); puts("Enter server address (ip:port:public_key): "); if(fgets(answer_buf, answer_buf_size, stdin) == NULL){ @@ -136,7 +151,8 @@ static Result(void) commandExec(str command, bool* stop){ str_trim(&new_server_link, true); printf("connecting to server...\n"); - try(_server_connection, p, ServerConnection_open(_client_credentials, new_server_link.data)); + try(client->server_connection, p, + ServerConnection_open(&client->cred, new_server_link.data)); printf("connection established\n"); // TODO: request server info @@ -145,7 +161,7 @@ static Result(void) commandExec(str command, bool* stop){ // try log in // if not registered, request registration and then log in - // call serverConnection_run(): + // call Client_runIO(): // function with infinite loop which sends and receives messages // with navigation across server channels // @@ -159,9 +175,8 @@ static Result(void) commandExec(str command, bool* stop){ // regiser and then log in } else { - Return RESULT_ERROR_FMT("unknown kommand: '%s'\n" - "Use 'h' to see list of avaliable commands", - command.data); + printf("ERROR: unknown command.\n" + "Use 'h' to see list of avaliable commands\n"); } Return RESULT_VOID; diff --git a/src/client/client.h b/src/client/client.h index 738aaeb..801994b 100644 --- a/src/client/client.h +++ b/src/client/client.h @@ -3,17 +3,21 @@ #include "cryptography/RSA.h" #include "network/encrypted_sockets.h" -Result(void) client_run(); +typedef struct Client Client; typedef struct ClientCredentials { str username; - Array(u8) aes_key; + Array(u8) user_data_key; + Array(u8) token; AESBlockEncryptor user_data_aes_enc; AESBlockDecryptor user_data_aes_dec; } ClientCredentials; -Result(ClientCredentials*) ClientCredentials_create(str username, str password); -void ClientCredentials_free(ClientCredentials* cred); +Result(void) ClientCredentials_tryConstruct(ClientCredentials* cred, + str username, str password); + +void ClientCredentials_destroy(ClientCredentials* cred); + typedef struct ServerConnection { u64 session_id; @@ -24,5 +28,17 @@ typedef struct ServerConnection { EncryptedSocketTCP sock; } ServerConnection; -Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_credentials, cstr server_link_cstr); +Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_credentials, + cstr server_link_cstr); + void ServerConnection_close(ServerConnection* conn); + + +typedef struct Client { + ClientCredentials cred; + ServerConnection* server_connection; +} Client; + +Result(void) Client_createFromConfig(cstr config_path); +void Client_free(Client* client); +Result(void) Client_run(Client* client); diff --git a/src/cryptography/RSA.c b/src/cryptography/RSA.c index 0ca141d..a9d592c 100644 --- a/src/cryptography/RSA.c +++ b/src/cryptography/RSA.c @@ -11,7 +11,7 @@ Result(void) RSA_generateKeyPair(u32 key_size, br_rsa_private_key* sk, br_rsa_public_key* pk, const br_prng_class** rng_vtable_ptr) { - Deferral(8); + Deferral(4); bool success = false; void* sk_buf = malloc(BR_RSA_KBUF_PRIV_SIZE(key_size)); @@ -34,7 +34,7 @@ Result(void) RSA_generateKeyPair(u32 key_size, Result(void) RSA_generateKeyPairFromSystemRandom(u32 key_size, br_rsa_private_key* sk, br_rsa_public_key* pk) { - Deferral(8); + Deferral(4); br_hmac_drbg_context time_based_rng = { .vtable = &br_hmac_drbg_vtable }; rng_init_sha256_seedFromSystem(&time_based_rng.vtable); try_void(RSA_generateKeyPair(key_size, sk, pk, &time_based_rng.vtable)); @@ -44,7 +44,7 @@ Result(void) RSA_generateKeyPairFromSystemRandom(u32 key_size, Result(void) RSA_generateKeyPairFromPassword(u32 key_size, br_rsa_private_key* sk, br_rsa_public_key* pk, str password) { - Deferral(8); + Deferral(4); br_hmac_drbg_context password_based_rng = { .vtable = &br_hmac_drbg_vtable }; br_hmac_drbg_init(&password_based_rng, &br_sha256_vtable, password.data, password.size); try_void(RSA_generateKeyPair(key_size, sk, pk, &password_based_rng.vtable)); @@ -52,7 +52,7 @@ Result(void) RSA_generateKeyPairFromPassword(u32 key_size, } Result(void) RSA_computePublicKey(const br_rsa_private_key* sk, br_rsa_public_key* pk){ - Deferral(8); + Deferral(4); br_rsa_compute_modulus compute_modulus = br_rsa_i31_compute_modulus; br_rsa_compute_pubexp compute_pubexp = br_rsa_i31_compute_pubexp; @@ -112,7 +112,7 @@ str RSA_serializePublicKey_base64(const br_rsa_public_key* pk){ } Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* pk){ - Deferral(8); + Deferral(4); u32 n_bitlen = 0; if(sscanf(src, "RSA-Public-%u:", &n_bitlen) != 1){ Return RESULT_ERROR("can't parse key size", false); @@ -142,7 +142,7 @@ Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* pk){ } Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){ - Deferral(8); + Deferral(4); u32 n_bitlen = 0; if(sscanf(src, "RSA-Private-%u:", &n_bitlen) != 1){ Return RESULT_ERROR("can't parse key size", false); diff --git a/src/db/idb.c b/src/db/idb.c index 8d68443..2e22c78 100644 --- a/src/db/idb.c +++ b/src/db/idb.c @@ -63,7 +63,7 @@ Result(void) validateTableName(str name){ } Result(void) Table_readHeader(Table* t){ - Deferral(8); + Deferral(4); // seek for start of the file try_void(file_seek(t->table_file, 0, SeekOrigin_Start)); // read header @@ -72,7 +72,7 @@ Result(void) Table_readHeader(Table* t){ } Result(void) Table_writeHeader(Table* t){ - Deferral(8); + Deferral(4); // seek for start of the file try_void(file_seek(t->table_file, 0, SeekOrigin_Start)); // write header @@ -81,20 +81,20 @@ Result(void) Table_writeHeader(Table* t){ } Result(void) Table_setDirtyBit(Table* t, bool val){ - Deferral(8); + Deferral(4); t->header._dirty_bit = val; try_void(Table_writeHeader(t)); Return RESULT_VOID; } Result(bool) Table_getDirtyBit(Table* t){ - Deferral(8); + Deferral(4); try_void(Table_readHeader(t)); Return RESULT_VALUE(i, t->header._dirty_bit); } Result(void) Table_calculateRowCount(Table* t){ - Deferral(8); + Deferral(4); try(i64 file_size, i, file_getSize(t->table_file)); i64 data_size = file_size - sizeof(t->header); if(data_size % t->header.row_size != 0){ @@ -109,7 +109,7 @@ Result(void) Table_calculateRowCount(Table* t){ } Result(void) Table_validateHeader(Table* t){ - Deferral(8); + Deferral(4); if(t->header.magic.n != TABLE_FILE_MAGIC.n || t->header.row_size == 0) { @@ -147,10 +147,7 @@ Result(IncrementalDB*) idb_open(str db_dir){ IncrementalDB* db = (IncrementalDB*)malloc(sizeof(IncrementalDB)); // if object construction fails, destroy incomplete object bool success = false; - Defer({ - if(!success) - idb_close(db); - }); + Defer(if(!success) idb_close(db)); // value of *db must be set to zero or behavior of idb_close will be undefined memset(db, 0, sizeof(IncrementalDB)); @@ -171,7 +168,7 @@ void idb_close(IncrementalDB* db){ } Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_size){ - Deferral(64); + Deferral(16); // db lock try_stderrcode(pthread_mutex_lock(&db->mutex)); Defer(pthread_mutex_unlock(&db->mutex)); @@ -188,10 +185,7 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_ Table* t = (Table*)malloc(sizeof(Table)); // if object construction fails, destroy incomplete object bool success = false; - Defer({ - if(!success) - Table_close(t); - }); + Defer(if(!success) Table_close(t)); // value of *t must be set to zero or behavior of Table_close will be undefined memset(t, 0, sizeof(Table)); @@ -237,7 +231,7 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_ } Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count){ - Deferral(16); + Deferral(8); // table lock try_stderrcode(pthread_mutex_lock(&t->mutex)); Defer(pthread_mutex_unlock(&t->mutex)); @@ -260,7 +254,7 @@ Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count){ } Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count){ - Deferral(16); + Deferral(8); // table lock try_stderrcode(pthread_mutex_lock(&t->mutex)); Defer(pthread_mutex_unlock(&t->mutex)); @@ -289,7 +283,7 @@ Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count){ } Result(u64) idb_pushRows(Table* t, const void* src, u64 count){ - Deferral(16); + Deferral(8); // table lock try_stderrcode(pthread_mutex_lock(&t->mutex)); Defer(pthread_mutex_unlock(&t->mutex)); @@ -309,7 +303,7 @@ Result(u64) idb_pushRows(Table* t, const void* src, u64 count){ } Result(u64) idb_getRowCount(Table* t){ - Deferral(8); + Deferral(4); // table lock try_stderrcode(pthread_mutex_lock(&t->mutex)); Defer(pthread_mutex_unlock(&t->mutex)); diff --git a/src/main.c b/src/main.c index 5203e78..d93b4bf 100755 --- a/src/main.c +++ b/src/main.c @@ -6,8 +6,8 @@ #define _DEFAULT_CONFIG_PATH_SERVER "tcp-chat-server.config" typedef enum ProgramMode { - Client, - Server, + ClientMode, + ServerMode, RsaGenStdin, RsaGenRandom, } ProgramMode; @@ -22,7 +22,7 @@ int main(const int argc, cstr const* argv){ return 1; } - ProgramMode mode = Client; + ProgramMode mode = ClientMode; cstr server_endpoint_cstr = NULL; cstr config_path = NULL; u32 key_size = 0; @@ -46,12 +46,12 @@ int main(const int argc, cstr const* argv){ Return 0; } if(arg_is("-l") || arg_is("--listen")){ - if(mode != Client){ + if(mode != ClientMode){ printf("program mode is set already\n"); Return 1; } - mode = Server; + mode = ServerMode; if(++argi >= argc){ printfe("ERROR: no endpoint specified\n"); Return 1; @@ -66,7 +66,7 @@ int main(const int argc, cstr const* argv){ config_path = argv[argi]; } else if(arg_is("--rsa-gen-stdin")){ - if(mode != Client){ + if(mode != ClientMode){ printf("program mode is set already\n"); Return 1; } @@ -81,7 +81,7 @@ int main(const int argc, cstr const* argv){ } else if(arg_is("--rsa-gen-random")){ - if(mode != Client){ + if(mode != ClientMode){ printf("program mode is set already\n"); Return 1; } @@ -106,17 +106,23 @@ int main(const int argc, cstr const* argv){ Defer(network_deinit()); switch(mode){ - case Client: { - if(config_path == NULL) + case ClientMode: { + if(!config_path) config_path = _DEFAULT_CONFIG_PATH_CLIENT; - try_fatal_void(client_run()); + + try_fatal(Client* client, p, Client_createFromConfig(config_path)); + Defer(Client_free(client)); + try_fatal_void(Client_run(client)); break; } - case Server: { - if(config_path == NULL) + case ServerMode: { + if(!config_path) config_path = _DEFAULT_CONFIG_PATH_SERVER; - try_fatal_void(server_run(server_endpoint_cstr, config_path)); + + try_fatal(Server* server, p, Server_createFromConfig(config_path)); + Defer(Server_free(server)); + try_fatal_void(Server_run(server, server_endpoint_cstr)); break; } diff --git a/src/network/encrypted_sockets.c b/src/network/encrypted_sockets.c index 4241bfc..e5b2316 100644 --- a/src/network/encrypted_sockets.c +++ b/src/network/encrypted_sockets.c @@ -15,6 +15,8 @@ void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, } void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){ + if(!ptr) + return; socket_close(ptr->sock); free(ptr->recv_buf.data); free(ptr->send_buf.data); @@ -161,6 +163,9 @@ void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, } void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){ + if(!ptr) + return; + socket_close(ptr->sock); free(ptr->recv_buf.data); free(ptr->send_buf.data); diff --git a/src/network/internal.h b/src/network/internal.h index 79f3076..0dc413e 100644 --- a/src/network/internal.h +++ b/src/network/internal.h @@ -21,6 +21,7 @@ #include #include #include + #include #include #include #include diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 069d657..70ed334 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -2,25 +2,21 @@ #include "network/tcp-chat-protocol/v1.h" void ClientConnection_close(ClientConnection* conn){ - if(conn == NULL) + if(!conn) return; EncryptedSocketTCP_destroy(&conn->sock); free(conn->session_key.data); free(conn); } -Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_credentials, - ConnectionHandlerArgs* args) +Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) { - Deferral(32); + Deferral(8); ClientConnection* conn = (ClientConnection*)malloc(sizeof(ClientConnection)); memset(conn, 0, sizeof(*conn)); bool success = false; - Defer( - if(!success) - ClientConnection_close(conn); - ); + Defer(if(!success) ClientConnection_close(conn)); conn->client_end = args->client_end; conn->session_id = args->session_id; @@ -33,7 +29,7 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred // decrypt the rsa messages using server private key RSADecryptor rsa_dec; - RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk); + RSADecryptor_construct(&rsa_dec, &args->server->cred.rsa_sk); // receive PacketHeader PacketHeader packet_header = {0}; diff --git a/src/server/ServerCredentials.c b/src/server/ServerCredentials.c index 53727a9..29c2e81 100644 --- a/src/server/ServerCredentials.c +++ b/src/server/ServerCredentials.c @@ -1,26 +1,26 @@ #include "server.h" -Result(ServerCredentials*) ServerCredentials_create(cstr rsa_sk_base64, cstr rsa_pk_base64){ + +Result(void) ServerCredentials_tryConstruct(ServerCredentials* cred, + cstr rsa_sk_base64, cstr rsa_pk_base64) +{ Deferral(4); - ServerCredentials* cred = (ServerCredentials*)malloc(sizeof(ServerCredentials)); memset(cred, 0, sizeof(*cred)); bool success = false; - Defer( - if(!success) - ServerCredentials_free(cred); - ); + Defer(if(!success) ServerCredentials_destroy(cred)); try_void(RSA_parsePrivateKey_base64(rsa_sk_base64, &cred->rsa_sk)); try_void(RSA_parsePublicKey_base64(rsa_pk_base64, &cred->rsa_pk)); success = true; - Return RESULT_VALUE(p, cred); + Return RESULT_VOID; } -void ServerCredentials_free(ServerCredentials* cred){ +void ServerCredentials_destroy(ServerCredentials* cred){ + if(!cred) + return; RSA_destroyPrivateKey(&cred->rsa_sk); RSA_destroyPublicKey(&cred->rsa_pk); - free(cred); } diff --git a/src/server/server.c b/src/server/server.c index 6bceb11..74c19bb 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -8,15 +8,26 @@ #include "network/tcp-chat-protocol/v1.h" #include "server/request_handlers/request_handlers.h" -static void* handle_connection(void* _args); -static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_ctx); +static void* handleConnection(void* _args); +static Result(void) try_handleConnection(ConnectionHandlerArgs* args, cstr log_ctx); +void Server_free(Server* server){ + if(!server) + return; + free(server->name.data); + free(server->description.data); + ServerCredentials_destroy(&server->cred); +} -static ServerCredentials* _server_credentials = NULL; - - -static Result(void) parseConfig(cstr config_path){ - Deferral(8); +Result(Server*) Server_createFromConfig(cstr config_path){ + Deferral(16); + cstr log_ctx = "ServerInit"; + logInfo(log_ctx, "parsing config"); + + Server* server = (Server*)malloc(sizeof(Server)); + memset(server, 0, sizeof(Server)); + bool success = false; + Defer(if(!success) Server_free(server)); // open file try(FILE* config_file, p, file_open(config_path, FO_ReadExisting)); @@ -28,32 +39,40 @@ static Result(void) parseConfig(cstr config_path){ try_void(file_readBytesArray(config_file, config_buf)); str config_str = Array_castTo_str(config_buf, false); - str sk_base64; - str pk_base64; - try_void(config_findValue(config_str, STR("rsa_private_key"), &sk_base64, true)); - try_void(config_findValue(config_str, STR("rsa_public_key"), &pk_base64, true)); - char* sk_base64_cstr = str_copy(sk_base64).data; - char* pk_base64_cstr = str_copy(pk_base64).data; - Defer( - free(sk_base64_cstr); - free(pk_base64_cstr); - ); - try(_server_credentials, p, ServerCredentials_create(sk_base64_cstr, pk_base64_cstr)); + // parse name + str tmp_str = str_null; + try_void(config_findValue(config_str, STR("name"), &tmp_str, true)); + server->name = str_copy(tmp_str); + + // parse description + try_void(config_findValue(config_str, STR("description"), &tmp_str, true)); + server->description = str_copy(tmp_str); - Return RESULT_VOID; + // parse rsa_private_key + try_void(config_findValue(config_str, STR("rsa_private_key"), &tmp_str, true)); + char* sk_base64_cstr = str_copy(tmp_str).data; + Defer(free(sk_base64_cstr)); + + // parse rsa_public_key + try_void(config_findValue(config_str, STR("rsa_public_key"), &tmp_str, true)); + char* pk_base64_cstr = str_copy(tmp_str).data; + Defer(free(pk_base64_cstr)); + + try_void(ServerCredentials_tryConstruct(&server->cred, sk_base64_cstr, pk_base64_cstr)); + + success = true; + Return RESULT_VALUE(p, server); } -Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){ - Deferral(32); - cstr log_ctx = "MainThread"; +Result(void) Server_run(Server* server, cstr server_endpoint_cstr){ + Deferral(16); + cstr log_ctx = "ListenerThread"; logInfo(log_ctx, "starting server"); - logDebug(log_ctx, "parsing config"); - try_void(parseConfig(config_path)); - Defer(ServerCredentials_free(_server_credentials)); - logDebug(log_ctx, "initializing main socket"); EndpointIPv4 server_end; try_void(EndpointIPv4_parse(server_endpoint_cstr, &server_end)); + + logDebug(log_ctx, "initializing main socket"); try(Socket main_socket, i, socket_open_TCP()); try_void(socket_bind(main_socket, server_end)); try_void(socket_listen(main_socket, 512)); @@ -62,24 +81,26 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){ u64 session_id = 1; while(true){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)malloc(sizeof(ConnectionHandlerArgs)); - try(args->accepted_socket, i, socket_accept(main_socket, &args->client_end)); + args->server = server; + try(args->accepted_socket, i, + socket_accept(main_socket, &args->client_end)); args->session_id = session_id++; pthread_t conn_thread = {0}; //TODO: use async IO instead of threads to not waste system resources // while waiting for incoming data in 100500 threads - try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args)); + try_stderrcode(pthread_create(&conn_thread, NULL, handleConnection, args)); try_stderrcode(pthread_detach(conn_thread)); } Return RESULT_VOID; } -static void* handle_connection(void* _args){ +static void* handleConnection(void* _args){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args; char log_ctx[64]; sprintf(log_ctx, "Session-" IFWIN("%llx", "%lx"), args->session_id); - ResultVar(void) r = try_handle_connection(args, log_ctx); + ResultVar(void) r = try_handleConnection(args, log_ctx); if(r.error){ str error_s = Error_toStr(r.error); logError(log_ctx, "%s", error_s.data); @@ -89,8 +110,8 @@ static void* handle_connection(void* _args){ return NULL; } -static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_ctx){ - Deferral(64); +static Result(void) try_handleConnection(ConnectionHandlerArgs* args, cstr log_ctx){ + Deferral(16); Defer(free(args)); ClientConnection* conn = NULL; @@ -99,7 +120,7 @@ static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_ logInfo(log_ctx, "session closed"); ); // establish encrypted connection - try(conn, p, ClientConnection_accept(_server_credentials, args)); + try(conn, p, ClientConnection_accept(args)); logInfo(log_ctx, "session accepted"); // handle requests diff --git a/src/server/server.h b/src/server/server.h index a3de4ae..e501cc0 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -3,23 +3,17 @@ #include "cryptography/RSA.h" #include "network/encrypted_sockets.h" -Result(void) server_run(cstr server_endpoint_cstr, cstr config_path); - +typedef struct Server Server; typedef struct ServerCredentials { br_rsa_private_key rsa_sk; br_rsa_public_key rsa_pk; } ServerCredentials; -Result(ServerCredentials*) ServerCredentials_create(cstr rsa_sk_base64, cstr rsa_pk_base64); +Result(void) ServerCredentials_tryConstruct(ServerCredentials* cred, + cstr rsa_sk_base64, cstr rsa_pk_base64); -void ServerCredentials_free(ServerCredentials* cred); - - -typedef struct ServerInfo { - str name; - str description; -} ServerInfo; +void ServerCredentials_destroy(ServerCredentials* cred); typedef struct ClientConnection { @@ -32,12 +26,23 @@ typedef struct ClientConnection { typedef struct ConnectionHandlerArgs { + Server* server; Socket accepted_socket; EndpointIPv4 client_end; u64 session_id; } ConnectionHandlerArgs; -Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_credentials, - ConnectionHandlerArgs* args); +Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args); -void ClientConnection_close(ClientConnection* conn); \ No newline at end of file +void ClientConnection_close(ClientConnection* conn); + + +typedef struct Server { + str name; + str description; + ServerCredentials cred; +} Server; + +Result(Server*) Server_createFromConfig(cstr config_path); +void Server_free(Server* server); +Result(void) Server_run(Server* server, cstr server_endpoint_cstr);