diff --git a/dependencies/tlibc b/dependencies/tlibc index bdbe959..de88e9f 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit bdbe959e23fab5697cfdfee4ef92d2da1c755b3b +Subproject commit de88e9ff168cbb71d4ab4c48df949b2e25d1b23d diff --git a/dependencies/tlibtoml b/dependencies/tlibtoml index 11f16a7..5cb121d 160000 --- a/dependencies/tlibtoml +++ b/dependencies/tlibtoml @@ -1 +1 @@ -Subproject commit 11f16a79fc08dd231b4594e4b5eafe8e6bb561cb +Subproject commit 5cb121d1de96f07f7c36ce13afcfb6e511c5427d diff --git a/dependencies/tsqlite b/dependencies/tsqlite index 58840ce..4b15db7 160000 --- a/dependencies/tsqlite +++ b/dependencies/tsqlite @@ -1 +1 @@ -Subproject commit 58840cecd058476af12b9844e83080dd5cc22a38 +Subproject commit 4b15db7c1f8a14495f75e7314e4ffbb558a8268c diff --git a/include/tcp-chat/client.h b/include/tcp-chat/client.h index 2ce351f..cb85d51 100644 --- a/include/tcp-chat/client.h +++ b/include/tcp-chat/client.h @@ -20,12 +20,12 @@ Result(void) Client_connect(Client* client, cstr server_addr_cstr, cstr server_p void Client_disconnect(Client* client); /// @param self connected client -/// @param out_name owned by Client, fetched from server during Client_connect -Result(void) Client_getServerName(Client* self, str* out_name); +/// @param out_str heap-allocated string +Result(void) Client_getServerName(Client* self, str* out_str); /// @param self connected client -/// @param out_name owned by Client, fetched from server during Client_connect -Result(void) Client_getServerDescription(Client* self, str* out_desc); +/// @param out_str heap-allocated string +Result(void) Client_getServerDescription(Client* self, str* out_str); /// Create new account on connected server Result(void) Client_register(Client* self, i64* out_user_id); diff --git a/include/tcp-chat/tcp-chat.h b/include/tcp-chat/tcp-chat.h index 4f231e3..85d08fd 100644 --- a/include/tcp-chat/tcp-chat.h +++ b/include/tcp-chat/tcp-chat.h @@ -12,3 +12,5 @@ typedef enum TcpChatError { TcpChatError_Unknown, TcpChatError_RejectIncoming, } TcpChatError; + +#define MESSAGE_TIMESTAMP_FMT_SQL "%Y.%m.%d-%H:%M:%f" diff --git a/project.config b/project.config index 4580ff7..fde2540 100644 --- a/project.config +++ b/project.config @@ -47,7 +47,7 @@ case "$OS" in EXEC_FILE="$PROJECT.exe" SHARED_LIB_FILE="$PROJECT.dll" INCLUDE="$INCLUDE " - LINKER_LIBS="-static -lpthread -lws2_32 -lsqlite3" + LINKER_LIBS="-static -lpthread -lws2_32 -luuid -lsqlite3" ;; LINUX) EXEC_FILE="$PROJECT" diff --git a/src/cli/ClientCLI/ClientCLI.c b/src/cli/ClientCLI/ClientCLI.c index 78a2d75..8101f83 100644 --- a/src/cli/ClientCLI/ClientCLI.c +++ b/src/cli/ClientCLI/ClientCLI.c @@ -22,28 +22,28 @@ static const str farewell_art = STR( #define is_alias(LITERAL) str_equals(command, STR(LITERAL)) static Result(void) ClientCLI_askUserNameAndPassword(str* username_out, str* password_out); -static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop); static Result(void) ClientCLI_openUserDB(ClientCLI* self); -static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, - str addr, str pk_base64, str name, str desc); -static Result(ServerInfo*) ClientCLI_joinNewServer(ClientCLI* self); -static Result(ServerInfo*) ClientCLI_selectServerFromCache(ClientCLI* self); -static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server); +static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop); +static Result(SavedServer*) ClientCLI_joinNewServer(ClientCLI* self); +static Result(SavedServer*) ClientCLI_selectServerFromCache(ClientCLI* self); +static Result(void) ClientCLI_showSavedServer(ClientCLI* self, SavedServer* server); static Result(void) ClientCLI_register(ClientCLI* self); static Result(void) ClientCLI_login(ClientCLI* self); + void ClientCLI_destroy(ClientCLI* self){ if(!self) return; Client_free(self->client); - - idb_close(self->db); - List_ServerInfo_destroy(&self->servers.list); - HashMap_destroy(&self->servers.addr_id_map); + ClientQueries_free(self->queries); + tsqlite_connection_close(self->db); + List_SavedServer_destroyWithElements(&self->saved_servers, SavedServer_destroy); } + void ClientCLI_construct(ClientCLI* self){ zeroStruct(self); + self->saved_servers = List_SavedServer_alloc(0); } Result(void) ClientCLI_run(ClientCLI* self) { @@ -137,6 +137,33 @@ static Result(void) ClientCLI_askUserNameAndPassword(str* username_out, str* pas Return RESULT_VOID; } +static Result(void) ClientCLI_openUserDB(ClientCLI* self){ + Deferral(8); + + str username = Client_getUserName(self->client); + // TODO: encrypt user database + // Array(u8) user_data_key = Client_getUserDataKey(self->client); + + // build database file path + try(char* user_dir, p, path_getUserDir()); + Defer(free(user_dir)); + char* db_path = strcat_malloc( + user_dir, + path_seps".local"path_seps"tcp-chat-client"path_seps"user-db"path_seps, + username.data, ".sqlite" + ); + Defer(free(db_path)); + printf("loading database '%s'\n", db_path); + + try(self->db, p, ClientDatabase_open(db_path)); + try(self->queries, p, ClientQueries_compile(self->db)); + + // load whole servers table to list + try_void(SavedServer_getAll(self->queries, &self->saved_servers)); + + Return RESULT_VOID; +} + static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop){ Deferral(64); @@ -190,229 +217,136 @@ static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* st static Result(void) ClientCLI_joinNewServer(ClientCLI* self){ Deferral(8); + bool success = false; // ask server address + const u32 address_alloc_size = HOSTADDR_SIZE_MAX + 1; + str address = str_construct((char*)malloc(address_alloc_size), address_alloc_size, true); + Defer(if(!success) str_destroy(address)); printf("Enter server address (ip:port):\n"); - char server_addr_cstr[HOSTADDR_SIZE_MAX + 1]; - try_void(term_readLine(server_addr_cstr, sizeof(server_addr_cstr))); - str server_addr_str = str_from_cstr(server_addr_cstr); - str_trim(&server_addr_str, true); + try_void(term_readLine(address.data, address.len)); + address.len = strlen(address.data); + str_trim(&address, true); // ask server public key + const u32 server_pk_alloc_size = PUBLIC_KEY_BASE64_SIZE_MAX + 1; + str server_pk = str_construct((char*)malloc(server_pk_alloc_size), server_pk_alloc_size, true); + Defer(if(!success) str_destroy(server_pk)); printf("Enter server public key (RSA-Public-:):\n"); - char server_pk_cstr[PUBLIC_KEY_BASE64_SIZE_MAX + 1]; - try_void(term_readLine(server_pk_cstr, sizeof(server_pk_cstr))); - str server_pk_str = str_from_cstr(server_pk_cstr); - str_trim(&server_pk_str, true); + try_void(term_readLine(server_pk.data, server_pk.len)); + server_pk.len = strlen(server_pk.data); + str_trim(&server_pk, true); printf("Connecting to server...\n"); - try_void(Client_connect(self->client, server_addr_cstr, server_pk_cstr)); + try_void(Client_connect(self->client, address.data, server_pk.data)); printf("Connection established\n"); str server_name = str_null; - str server_description = str_null; try_void(Client_getServerName(self->client, &server_name)); + Defer(if(!success) str_destroy(server_name)); + str server_description = str_null; try_void(Client_getServerDescription(self->client, &server_description)); - try(ServerInfo* server, p, ClientCLI_saveServerInfo(self, - server_addr_str, server_pk_str, - server_name, server_description)); + Defer(if(!success) str_destroy(server_description)); - try_void(ClientCLI_showServerInfo(self, server)); + SavedServer server = { + .address = address, + .pk_base64 = server_pk, + .name = server_name, + .description = server_description + }; + try_void(SavedServer_save(self->queries, &server)); + List_SavedServer_pushMany(&self->saved_servers, &server, 1); + try_void(ClientCLI_showSavedServer(self, &server)); + + success = true; Return RESULT_VOID; } static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){ Deferral(8); + bool success = false; - // Lock table until this function returns. - // It may not change any data in table, but it uses associated cache structures. - idb_lockTable(self->servers.table); - Defer(idb_unlockTable(self->servers.table)); - - u32 server_count = self->servers.list.len; - if(server_count == 0){ - printf("No servers found in cache\n"); + u32 servers_count = self->saved_servers.len; + if(servers_count == 0){ + printf("No saved servers found\n"); Return RESULT_VOID; } - for(u32 id = 0; id < server_count; id++){ - ServerInfo* server = self->servers.list.data + id; + for(u32 i = 0; i < servers_count; i++){ + SavedServer* server = &self->saved_servers.data[i]; printf("[%02u] "FMT_str" "FMT_str"\n", - id, server->address_len, server->address, server->name_len, server->name); + i, str_unwrap(server->address), str_unwrap(server->name)); } char buf[32]; - u32 id = -1; + u32 selected_i = -1; while(true) { printf("Type 'q' to cancel\n"); - printf("Select server (number): "); + printf("Select server number: "); try_void(term_readLine(buf, sizeof(buf))); str input_line = str_from_cstr(buf); str_trim(&input_line, true); if(str_equals(input_line, STR("q"))){ Return RESULT_VOID; } - if(sscanf(buf, FMT_u32, &id) != 1){ + if(sscanf(buf, FMT_u32, &selected_i) != 1){ printf("ERROR: not a number\n"); } - else if(id >= server_count){ - printf("ERROR: not a server number: %u\n", id); + else if(selected_i >= servers_count){ + printf("ERROR: not a server number\n"); } else break; } - ServerInfo* server = self->servers.list.data + id; + SavedServer* selected_server = &self->saved_servers.data[selected_i]; - printf("Connecting to '"FMT_str"'...\n", server->address_len, server->address); - try_void(Client_connect(self->client, server->address, server->pk_base64)); + printf("Connecting to '"FMT_str"'...\n", str_unwrap(selected_server->address)); + try_void(Client_connect(self->client, selected_server->address.data, selected_server->pk_base64.data)); printf("Connection established\n"); + // update server name bool server_info_changed = false; - // update cached server name - str name = str_null; - try_void(Client_getServerName(self->client, &name)); - if(!str_equals(name, str_construct(server->name, server->name_len, true))){ + str updated_server_name = str_null; + try_void(Client_getServerName(self->client, &updated_server_name)); + Defer(if(!success) str_destroy(updated_server_name)); + if(!str_equals(updated_server_name, selected_server->name)){ server_info_changed = true; - if(name.len > SERVER_NAME_SIZE_MAX) - name.len = SERVER_NAME_SIZE_MAX; - server->name_len = name.len; - memcpy(server->name, name.data, server->name_len); + selected_server->name = updated_server_name; } - // update cached server description - str desc = str_null; - try_void(Client_getServerDescription(self->client, &desc)); - if(!str_equals(desc, str_construct(server->desc, server->desc_len, true))){ + + // update server description + str updated_server_description = str_null; + try_void(Client_getServerDescription(self->client, &updated_server_description)); + Defer(if(!success) str_destroy(updated_server_description)); + if(!str_equals(updated_server_description, selected_server->description)){ server_info_changed = true; - if(desc.len > SERVER_DESC_SIZE_MAX) - desc.len = SERVER_DESC_SIZE_MAX; - server->desc_len = desc.len; - memcpy(server->desc, desc.data, server->desc_len); + selected_server->description = updated_server_description; } + if(server_info_changed){ - try_void(idb_updateRow(self->servers.table, id, server, false)); + try_void(SavedServer_save(self->queries, selected_server)); } - try_void(ClientCLI_showServerInfo(self, server)); + try_void(ClientCLI_showSavedServer(self, selected_server)); + success = true; Return RESULT_VOID; } -static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server){ +static Result(void) ClientCLI_showSavedServer(ClientCLI* self, SavedServer* server){ Deferral(8); (void)self; - printf("Server Name: "FMT_str"\n", server->name_len, server->name); - printf("Host Address: "FMT_str"\n", server->address_len, server->address); - printf("Description:\n"FMT_str"\n\n", server->desc_len, server->desc); - printf("Public Key:\n" FMT_str"\n\n", server->pk_base64_len, server->pk_base64); + printf("Server Name: "FMT_str"\n", str_unwrap(server->name)); + printf("Host Address: "FMT_str"\n", str_unwrap(server->address)); + printf("Description:\n"FMT_str"\n\n", str_unwrap(server->description)); + printf("Public Key:\n" FMT_str"\n\n", str_unwrap(server->pk_base64)); printf("Type 'register' if you don't have an account on the server.\n"); printf("Type 'login' to authorize on the server.\n"); Return RESULT_VOID; } -static Result(void) ClientCLI_openUserDB(ClientCLI* self){ - Deferral(8); - - str username = Client_getUserName(self->client); - Array(u8) user_data_key = Client_getUserDataKey(self->client); - str user_db_path = str_from_cstr(strcat_malloc("client-db", path_seps, username.data)); - Defer(free(user_db_path.data)); - try(self->db, p, idb_open(user_db_path, user_data_key)); - - // Lock DB until this function returns. - idb_lockDB(self->db); - Defer(idb_unlockDB(self->db)); - - // Load servers table - try(self->servers.table, p, - idb_getOrCreateTable(self->db, str_null, STR("servers"), sizeof(ServerInfo), false) - ); - - // Lock table until this function returns. - idb_lockTable(self->servers.table); - Defer(idb_unlockTable(self->servers.table)); - - // load whole servers table to list - try_void( - idb_createListFromTable(self->servers.table, (void*)&self->servers.list, false) - ); - - // build address-id map - try(i64 server_count, u, - idb_getRowCount(self->servers.table, false) - ); - HashMap_construct(&self->servers.addr_id_map, i64, NULL); - for(i64 id = 0; id < server_count; id++){ - ServerInfo* server = self->servers.list.data + id; - str key = str_construct(server->address, server->address_len, true); - if(!HashMap_tryPush(&self->servers.addr_id_map, key, &id)){ - Return RESULT_ERROR_FMT( - "duplicate server address '"FMT_str"'", - key.len, key.data); - } - } - - Return RESULT_VOID; -} - -static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self, - str addr, str pk_base64, str name, str desc){ - Deferral(8); - - // create new server info - ServerInfo server; - zeroStruct(&server); - // address - if(addr.len > HOSTADDR_SIZE_MAX) - addr.len = HOSTADDR_SIZE_MAX; - server.address_len = addr.len; - memcpy(server.address, addr.data, server.address_len); - // public key - if(pk_base64.len > PUBLIC_KEY_BASE64_SIZE_MAX) - pk_base64.len = PUBLIC_KEY_BASE64_SIZE_MAX; - server.pk_base64_len = pk_base64.len; - memcpy(server.pk_base64, pk_base64.data, server.pk_base64_len); - // name - if(name.len > SERVER_NAME_SIZE_MAX) - name.len = SERVER_NAME_SIZE_MAX; - server.name_len = name.len; - memcpy(server.name, name.data, server.name_len); - // description - if(desc.len > SERVER_DESC_SIZE_MAX) - desc.len = SERVER_DESC_SIZE_MAX; - server.desc_len = desc.len; - memcpy(server.desc, desc.data, server.desc_len); - - // Lock table until this function returns. - // It may not change any data in table, but it uses associated cache structures. - idb_lockTable(self->servers.table); - Defer(idb_unlockTable(self->servers.table)); - - // try find server id in cache - ServerInfo* cached_row_ptr = NULL; - i64* id_ptr = NULL; - id_ptr = HashMap_tryGetPtr(&self->servers.addr_id_map, addr); - if(id_ptr){ - // update existing server - i64 id = *id_ptr; - try_void(idb_updateRow(self->servers.table, id, &server, false)); - try_assert(id < self->servers.list.len); - cached_row_ptr = self->servers.list.data + id; - memcpy(cached_row_ptr, &server, sizeof(ServerInfo)); - } - else { - // push new server - try(i64 id, u, idb_pushRow(self->servers.table, &server, false)); - try_assert(id == self->servers.list.len); - List_ServerInfo_pushMany(&self->servers.list, &server, 1); - cached_row_ptr = self->servers.list.data + id; - try_assert(HashMap_tryPush(&self->servers.addr_id_map, addr, &id)); - } - - Return RESULT_VALUE(p, cached_row_ptr); -} - static Result(void) ClientCLI_register(ClientCLI* self){ Deferral(8); @@ -420,6 +354,7 @@ static Result(void) ClientCLI_register(ClientCLI* self){ try_void(Client_register(self->client, &user_id)); printf("Registered successfully\n"); printf("user_id: "FMT_i64"\n", user_id); + try_assert(user_id > 0); // TODO: use user_id somewhere Return RESULT_VOID; @@ -432,6 +367,7 @@ static Result(void) ClientCLI_login(ClientCLI* self){ try_void(Client_login(self->client, &user_id, &landing_channel_id)); printf("Authorized successfully\n"); printf("user_id: "FMT_i64", landing_channel_id: "FMT_i64"\n", user_id, landing_channel_id); + try_assert(user_id > 0); // TODO: use user_id, landing_channel_id somewhere Return RESULT_VOID; diff --git a/src/cli/ClientCLI/ClientCLI.h b/src/cli/ClientCLI/ClientCLI.h index 6f2d4e7..b578690 100644 --- a/src/cli/ClientCLI/ClientCLI.h +++ b/src/cli/ClientCLI/ClientCLI.h @@ -2,12 +2,14 @@ #include #include "tlibc/collections/HashMap.h" #include "tlibc/collections/List.h" -#include "tsqlite.h" #include "tcp-chat/client.h" +#include "db/client_db.h" typedef struct ClientCLI { Client* client; tsqlite_connection* db; + ClientQueries* queries; + List(SavedServer) saved_servers; } ClientCLI; void ClientCLI_construct(ClientCLI* self); diff --git a/src/cli/ClientCLI/db/SavedServer.c b/src/cli/ClientCLI/db/SavedServer.c new file mode 100644 index 0000000..85f9a51 --- /dev/null +++ b/src/cli/ClientCLI/db/SavedServer.c @@ -0,0 +1,24 @@ +#include "client_db_internal.h" + +void SavedServer_destroy(SavedServer* self){ + if(!self) + return; + str_destroy(self->address); + str_destroy(self->pk_base64); + str_destroy(self->name); + str_destroy(self->description); +} + +Result(void) SavedServer_save(ClientQueries* q, SavedServer* server){ + (void)q; + (void)server; + Deferral(4); + Return RESULT_VOID; +} + +Result(void) SavedServer_getAll(ClientQueries* q, List(SavedServer)* dst_list){ + (void)q; + (void)dst_list; + Deferral(4); + Return RESULT_VOID; +} diff --git a/src/cli/ClientCLI/db/client_db.c b/src/cli/ClientCLI/db/client_db.c new file mode 100644 index 0000000..68e4434 --- /dev/null +++ b/src/cli/ClientCLI/db/client_db.c @@ -0,0 +1,24 @@ +#include "client_db_internal.h" +#include "tlibc/filesystem.h" + +Result(tsqlite_connection* db) ClientDatabase_open(cstr file_path){ + Deferral(64); + + try_void(dir_createParent(file_path)); + try(tsqlite_connection* db, p, tsqlite_connection_open(file_path)); + bool success = false; + Defer(if(!success) tsqlite_connection_close(db)); + + success = true; + Return RESULT_VALUE(p, db); +} + +Result(ClientQueries*) ClientQueries_compile(tsqlite_connection* db){ + (void)db; + Deferral(4); + Return RESULT_VOID; +} + +void ClientQueries_free(ClientQueries* self){ + (void)self; +} \ No newline at end of file diff --git a/src/cli/ClientCLI/db/client_db.h b/src/cli/ClientCLI/db/client_db.h new file mode 100644 index 0000000..c109bb2 --- /dev/null +++ b/src/cli/ClientCLI/db/client_db.h @@ -0,0 +1,30 @@ +#pragma once +#include "tcp-chat/tcp-chat.h" +#include "tsqlite.h" +#include "network/tcp-chat-protocol/v1.h" +#include "tlibc/collections/List.h" + +/// @brief open DB and create tables +Result(tsqlite_connection* db) ClientDatabase_open(cstr file_path); + +typedef struct ClientQueries ClientQueries; +Result(ClientQueries*) ClientQueries_compile(tsqlite_connection* db); +void ClientQueries_free(ClientQueries* self); + + +typedef struct SavedServer { + str address; + str pk_base64; + str name; + str description; +} SavedServer; + +List_declare(SavedServer); + +void SavedServer_destroy(SavedServer* self); + +/// @brief insert new DB row or update existing +Result(void) SavedServer_save(ClientQueries* q, SavedServer* server); + +/// @param dst_list there SavedServer values are pushed +Result(void) SavedServer_getAll(ClientQueries* q, List(SavedServer)* dst_list); diff --git a/src/cli/ClientCLI/db/client_db_internal.h b/src/cli/ClientCLI/db/client_db_internal.h new file mode 100644 index 0000000..1d5d8a9 --- /dev/null +++ b/src/cli/ClientCLI/db/client_db_internal.h @@ -0,0 +1,10 @@ +#pragma once +#include "client_db.h" + +typedef struct ClientQueries { + struct { + tsqlite_statement* find_by_id; + } servers; +} ClientQueries; + + diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index a7d8ba2..52fa3b9 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -8,8 +8,6 @@ void ServerConnection_close(ServerConnection* self){ EncryptedSocketTCP_destroy(&self->sock); Array_u8_destroy(&self->token); Array_u8_destroy(&self->session_key); - str_destroy(self->server_name); - str_destroy(self->server_description); free(self); } @@ -75,16 +73,11 @@ Result(ServerConnection*) ServerConnection_open(Client* client, cstr server_addr PacketType_ServerHandshake)); conn->session_id = server_handshake.session_id; - // get server name - try_void(ServerConnection_requestServerName(conn)); - // get server description - try_void(ServerConnection_requestServerDescription(conn)); - success = true; Return RESULT_VALUE(p, conn); } -Result(void) ServerConnection_requestServerName(ServerConnection* conn){ +Result(void) ServerConnection_requestServerName(ServerConnection* conn, str* out_str){ if(conn == NULL){ return RESULT_ERROR_LITERAL("Client is not connected to a server"); } @@ -98,12 +91,12 @@ Result(void) ServerConnection_requestServerName(ServerConnection* conn){ try_void(sendRequest(&conn->sock, &req_header, &public_info_req)); try_void(recvResponse(&conn->sock, &res_header, &public_info_res, PacketType_ServerPublicInfoResponse)); - try_void(recvStr(&conn->sock, public_info_res.data_size, &conn->server_name)); + try_void(recvStr(&conn->sock, public_info_res.data_size, out_str)); Return RESULT_VOID; } -Result(void) ServerConnection_requestServerDescription(ServerConnection* conn){ +Result(void) ServerConnection_requestServerDescription(ServerConnection* conn, str* out_str){ if(conn == NULL){ return RESULT_ERROR_LITERAL("Client is not connected to a server"); } @@ -117,7 +110,7 @@ Result(void) ServerConnection_requestServerDescription(ServerConnection* conn){ try_void(sendRequest(&conn->sock, &req_header, &public_info_req)); try_void(recvResponse(&conn->sock, &res_header, &public_info_res, PacketType_ServerPublicInfoResponse)); - try_void(recvStr(&conn->sock, public_info_res.data_size, &conn->server_description)); + try_void(recvStr(&conn->sock, public_info_res.data_size, out_str)); Return RESULT_VOID; } \ No newline at end of file diff --git a/src/client/client.c b/src/client/client.c index 632ee7c..e17615e 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -52,22 +52,22 @@ Array(u8) Client_getUserDataKey(Client* client){ return client->user_data_key; } -Result(void) Client_getServerName(Client* self, str* out_name){ +Result(void) Client_getServerName(Client* self, str* out_str){ Deferral(1); try_assert(self != NULL); try_assert(self->conn != NULL && "didn't connect to a server yet"); - *out_name = self->conn->server_name; + try_void(ServerConnection_requestServerName(self->conn, out_str)); Return RESULT_VOID; } -Result(void) Client_getServerDescription(Client* self, str* out_desc){ +Result(void) Client_getServerDescription(Client* self, str* out_str){ Deferral(1); try_assert(self != NULL); try_assert(self->conn != NULL && "didn't connect to a server yet"); - *out_desc = self->conn->server_description; + try_void(ServerConnection_requestServerDescription(self->conn, out_str)); Return RESULT_VOID; } diff --git a/src/client/client_internal.h b/src/client/client_internal.h index 5a75dd6..c573f40 100644 --- a/src/client/client_internal.h +++ b/src/client/client_internal.h @@ -22,8 +22,6 @@ typedef struct ServerConnection { Array(u8) session_key; EncryptedSocketTCP sock; i64 session_id; - str server_name; - str server_description; i64 user_id; } ServerConnection; @@ -34,8 +32,8 @@ Result(ServerConnection*) ServerConnection_open(Client* client, void ServerConnection_close(ServerConnection* conn); -/// updates conn->server_name -Result(void) ServerConnection_requestServerName(ServerConnection* conn); +/// @param out_str heap-allocated string +Result(void) ServerConnection_requestServerName(ServerConnection* conn, str* out_str); -/// updates conn->server_description -Result(void) ServerConnection_requestServerDescription(ServerConnection* conn); +/// @param out_str heap-allocated string +Result(void) ServerConnection_requestServerDescription(ServerConnection* conn, str* out_str); diff --git a/src/client/requests/ErrorMessage.c b/src/client/requests/ErrorMessage.c index e662911..91d4da9 100644 --- a/src/client/requests/ErrorMessage.c +++ b/src/client/requests/ErrorMessage.c @@ -1,6 +1,6 @@ #include "requests.h" -Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_s){ +Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_str){ Deferral(4); str s = str_construct(malloc(size + 1), size, true); @@ -17,7 +17,7 @@ Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_s){ ); s.data[s.len] = 0; - *out_s = s; + *out_str = s; success = true; Return RESULT_VOID; } diff --git a/src/client/requests/requests.h b/src/client/requests/requests.h index d7b6687..36d84dd 100644 --- a/src/client/requests/requests.h +++ b/src/client/requests/requests.h @@ -3,10 +3,12 @@ #include "client/client_internal.h" +/// @param out_err_msg heap-allocated string Result(void) recvErrorMessage(EncryptedSocketTCP* sock, PacketHeader* res_header, str* out_err_msg); -Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_s); +/// @param out_str heap-allocated string +Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_str); Result(void) _recvResponse(EncryptedSocketTCP* sock, PacketHeader* res_header, Array(u8) res, PacketType res_type); diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 4acdd1a..9f74731 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -3,12 +3,12 @@ void ClientConnection_close(ClientConnection* conn){ if(!conn) return; - tsqlite_connection_close(conn->db); EncryptedSocketTCP_destroy(&conn->sock); Array_u8_destroy(&conn->session_key); Array_u8_destroy(&conn->message_block); Array_u8_destroy(&conn->message_content); - CommonQueries_free(conn->queries.common); + ServerQueries_free(conn->queries); + tsqlite_connection_close(conn->db); free(conn); } @@ -31,7 +31,7 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) // database try(conn->db, p, tsqlite_connection_open(args->server->db_path)); - try(conn->queries.common, p, CommonQueries_compile(conn->db)); + try(conn->queries, p, ServerQueries_compile(conn->db)); // correct session key will be received from client later conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE); diff --git a/src/server/db/Channel.c b/src/server/db/Channel.c index 931c4fa..0668fc0 100644 --- a/src/server/db/Channel.c +++ b/src/server/db/Channel.c @@ -1,54 +1,45 @@ -#include "db_internal.h" +#include "server_db_internal.h" -Result(bool) Channel_exists(CommonQueries* q, i64 id){ +Result(bool) Channel_exists(ServerQueries* q, i64 id){ Deferral(1); + tsqlite_statement* st = q->channels.exists; - try_void(tsqlite_statement_reset(st)); - try_void(tsqlite_statement_bind_i64(st, "id", id)); + Defer(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_i64(st, "$id", id)); + try(bool has_result, i, tsqlite_statement_step(st)); + Return RESULT_VALUE(i, has_result); } -Result(void) Channel_createOrUpdate(CommonQueries* q, +Result(void) Channel_createOrUpdate(ServerQueries* q, i64 id, str name, str description) { Deferral(4); try_assert(id > 0); try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX); try_assert(description.len <= CHANNEL_DESC_SIZE_MAX); - - // create channels table - try_void(tsqlite_statement_reset(q->channels.create_table)); - try_void(tsqlite_statement_bind_i64(q->channels.create_table, "name_max", CHANNEL_NAME_SIZE_MAX)); - try_void(tsqlite_statement_bind_i64(q->channels.create_table, "desc_max", CHANNEL_DESC_SIZE_MAX)); - try_void(tsqlite_statement_step(q->channels.create_table)); - - // create messages table - try_void(tsqlite_statement_reset(q->messages.create_table)); - try_void(tsqlite_statement_step(q->messages.create_table)); + tsqlite_statement* st = NULL; + Defer(tsqlite_statement_reset(st)); try(bool channel_exists, i, Channel_exists(q, id)); if(channel_exists){ // update existing channel - try_void(tsqlite_statement_reset(q->channels.update)); - try_void(tsqlite_statement_bind_i64(q->channels.update, "id", id)); - try_void(tsqlite_statement_bind_str(q->channels.update, "name", str_copy(name), free)); - try_void(tsqlite_statement_bind_str(q->channels.update, "description", str_copy(description), free)); - try_void(tsqlite_statement_step(q->channels.update)); + st = q->channels.update; } else { // insert new channel - try_void(tsqlite_statement_reset(q->channels.insert)); - try_void(tsqlite_statement_bind_i64(q->channels.insert, "id", id)); - try_void(tsqlite_statement_bind_str(q->channels.insert, "name", str_copy(name), free)); - try_void(tsqlite_statement_bind_str(q->channels.insert, "description", str_copy(description), free)); - try_void(tsqlite_statement_step(q->channels.insert)); + st = q->channels.insert; } + try_void(tsqlite_statement_bind_i64(st, "$id", id)); + try_void(tsqlite_statement_bind_str(st, "$name", name, NULL)); + try_void(tsqlite_statement_bind_str(st, "$description", description, NULL)); + try_void(tsqlite_statement_step(st)); Return RESULT_VOID; } -Result(void) Channel_saveMessage(CommonQueries* q, +Result(void) Channel_saveMessage(ServerQueries* q, i64 channel_id, i64 sender_id, Array(u8) content, DateTime* out_timestamp) { @@ -56,10 +47,10 @@ Result(void) Channel_saveMessage(CommonQueries* q, try_assert(content.len >= MESSAGE_SIZE_MIN && content.len <= MESSAGE_SIZE_MAX); tsqlite_statement* st = q->messages.insert; - try_void(tsqlite_statement_reset(st)); - try_void(tsqlite_statement_bind_i64(st, "channel_id", channel_id)); - try_void(tsqlite_statement_bind_i64(st, "sender_id", sender_id)); - try_void(tsqlite_statement_bind_blob(st, "content", Array_u8_copy(content), free)); + Defer(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_i64(st, "$channel_id", channel_id)); + try_void(tsqlite_statement_bind_i64(st, "$sender_id", sender_id)); + try_void(tsqlite_statement_bind_blob(st, "$content", content, NULL)); try(bool has_result, i, tsqlite_statement_step(st)); try_assert(has_result); @@ -72,7 +63,7 @@ Result(void) Channel_saveMessage(CommonQueries* q, Return RESULT_VALUE(i, message_id); } -Result(void) Channel_loadMessageBlock(CommonQueries* q, +Result(void) Channel_loadMessageBlock(ServerQueries* q, i64 channel_id, i64 first_message_id, u32 count, MessageBlockMeta* block_meta, Array(u8) block_data) { @@ -84,10 +75,10 @@ Result(void) Channel_loadMessageBlock(CommonQueries* q, } tsqlite_statement* st = q->messages.get_block; - try_void(tsqlite_statement_reset(st)); - try_void(tsqlite_statement_bind_i64(st, "channel_id", channel_id)); - try_void(tsqlite_statement_bind_i64(st, "first_message_id", first_message_id)); - try_void(tsqlite_statement_bind_i64(st, "count", count)); + Defer(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_i64(st, "$channel_id", channel_id)); + try_void(tsqlite_statement_bind_i64(st, "$first_message_id", first_message_id)); + try_void(tsqlite_statement_bind_i64(st, "$count", count)); zeroStruct(block_meta); MessageMeta msg_meta = {0}; diff --git a/src/server/db/CommonQueries.c b/src/server/db/CommonQueries.c deleted file mode 100644 index b06a8c6..0000000 --- a/src/server/db/CommonQueries.c +++ /dev/null @@ -1,82 +0,0 @@ -#include "db_internal.h" - -void CommonQueries_free(CommonQueries* q){ - if(!q) - return; - tsqlite_statement_free(q->channels.create_table); - tsqlite_statement_free(q->channels.insert); - tsqlite_statement_free(q->channels.exists); - tsqlite_statement_free(q->channels.update); - tsqlite_statement_free(q->messages.create_table); - tsqlite_statement_free(q->messages.insert); - tsqlite_statement_free(q->messages.get_block); - free(q); -} - -Result(void) CommonQueries_compile(tsqlite_connection* db){ - Deferral(4); - - CommonQueries* q = (CommonQueries*)malloc(sizeof(*q)); - zeroStruct(q); - bool success = false; - Defer(if(!success) CommonQueries_free(q)); - - /////////////////////////////////////////////////////////////////////////// - // CHANNELS // - /////////////////////////////////////////////////////////////////////////// - try(q->channels.create_table, p, tsqlite_statement_compile(db, STR( - "CREATE TABLE IF NOT EXISTS channels (\n" - " id BIGINT PRIMARY KEY,\n" - " name VARCHAR($name_max) NOT NULL,\n" - " description VARCHAR($desc_max) NOT NULL\n" - ");" - ))); - - try(q->channels.insert, p, tsqlite_statement_compile(db, STR( - "INSERT INTO\n" - "channels (id, name, description)\n" - "VALUES ($id, $name, $description);" - ))); - - try(q->channels.exists, p, tsqlite_statement_compile(db, STR( - "SELECT 1 FROM channels WHERE id = $id;" - ))); - - try(q->channels.update, p, tsqlite_statement_compile(db, STR( - "UPDATE channels\n" - "SET name = $name, description = $description\n" - "WHERE id = $id;" - ))); - - /////////////////////////////////////////////////////////////////////////// - // MESSAGES // - /////////////////////////////////////////////////////////////////////////// - try(q->messages.create_table, p, tsqlite_statement_compile(db, STR( - "CREATE TABLE IF NOT EXISTS messages (\n" - " id BIGINT PRIMARY KEY,\n" - " channel_id BIGINT NOT NULL REFERENCES channels(id)\n" - " sender_id BIGINT NOT NULL REFERENCES users(id),\n" - " content BLOB NOT NULL,\n" - " timestamp DATETIME NOT NULL DEFAULT (\n" - " strftime('"MESSAGE_TIMESTAMP_FMT_SQL"', 'now', 'utc', 'subsecond')\n" - " )\n" - ");" - ))); - - try(q->messages.insert, p, tsqlite_statement_compile(db, STR( - "INSERT INTO\n" - "messages (channel_id, sender_id, content)\n" - "VALUES ($channel_id, $sender_id, $content)\n" - "RETURNING id, timestamp;" - ))); - - try(q->messages.get_block, p, tsqlite_statement_compile(db, STR( - "SELECT id, sender_id, content, timestamp FROM messages\n" - "WHERE id >= $first_message_id\n" - "AND channel_id = $channel_id\n" - "LIMIT $count;" - ))); - - success = true; - Return RESULT_VALUE(p, q); -} diff --git a/src/server/db/User.c b/src/server/db/User.c index 51490d8..1486fa2 100644 --- a/src/server/db/User.c +++ b/src/server/db/User.c @@ -1,3 +1,50 @@ -#include "db.h" +#include "server_db_internal.h" +Result(i64) User_findByUsername(ServerQueries* q, str username){ + Deferral(1); + tsqlite_statement* st = q->users.find_by_username; + Defer(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_str(st, "$username", username, NULL)); + + try(bool has_result, i, tsqlite_statement_step(st)); + i64 user_id = 0; + if(has_result){ + try(user_id, i, tsqlite_statement_getResult_i64(st)); + try_assert(user_id > 0); + } + + Return RESULT_VALUE(i, user_id); +} + +Result(i64) User_register(ServerQueries* q, str username, Array(u8) token){ + Deferral(1); + try_assert(username.len >= USERNAME_SIZE_MIN && username.len <= USERNAME_SIZE_MAX); + try_assert(token.len == PASSWORD_HASH_SIZE) + + tsqlite_statement* st = q->users.insert; + Defer(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_str(st, "$username", username, NULL)); + try_void(tsqlite_statement_bind_blob(st, "$token", token, NULL)); + + try(bool has_result, i, tsqlite_statement_step(st)); + try_assert(has_result); + try(i64 user_id, i, tsqlite_statement_getResult_i64(st)); + try_assert(user_id > 0); + + Return RESULT_VALUE(i, user_id); +} + +Result(bool) User_tryAuthorize(ServerQueries* q, u64 id, Array(u8) token){ + Deferral(1); + try_assert(token.len == PASSWORD_HASH_SIZE) + + tsqlite_statement* st = q->users.compare_token; + Defer(tsqlite_statement_reset(st)); + try_void(tsqlite_statement_bind_i64(st, "$id", id)); + try_void(tsqlite_statement_bind_blob(st, "$token", token, NULL)); + + try(bool has_result, i, tsqlite_statement_step(st)); + + Return RESULT_VALUE(i, has_result); +} diff --git a/src/server/db/db.h b/src/server/db/db.h deleted file mode 100644 index 2a04f04..0000000 --- a/src/server/db/db.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once -#include "tsqlite.h" -#include "network/tcp-chat-protocol/v1.h" - -// typedef struct ChannelInfo { -// i64 id; -// str name; -// str description; -// } ChannelInfo; - -typedef struct CommonQueries CommonQueries; - -Result(CommonQueries*) CommonQueries_compile(tsqlite_connection* db); -void CommonQueries_free(CommonQueries* self); - -Result(bool) Channel_exists(CommonQueries* q, i64 id); - -Result(void) Channel_createOrUpdate(CommonQueries* q, - i64 id, str name, str description); - -/// @return new message id -Result(i64) Channel_saveMessage(CommonQueries* q, - i64 channel_id, i64 sender_id, Array(u8) content, - DateTime* out_timestamp_utc); - -/// @brief try to find `count` messages starting from `first_message_id` -/// @param out_meta writes here information about found messages, .count can be 0 if no messages found -/// @param out_block .len must be >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX) -Result(void) Channel_loadMessageBlock(CommonQueries* q, - i64 channel_id, i64 first_message_id, u32 count, - MessageBlockMeta* out_block_meta, Array(u8) out_block_data); - -/// @return existing user id or 0 -Result(i64) User_getIdForUsername(CommonQueries* q, str username); - -/// @return new user id -Result(i64) User_register(CommonQueries* q, str username, Array(u8) token); - -Result(bool) User_tryAuthorize(CommonQueries* q, u64 id, Array(u8) token); diff --git a/src/server/db/db_internal.h b/src/server/db/db_internal.h deleted file mode 100644 index cc33a9f..0000000 --- a/src/server/db/db_internal.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once -#include "db.h" - -typedef struct CommonQueries { - struct { - /* (name_max, desc_max) -> void */ - tsqlite_statement* create_table; - /* (id, name, description) -> void */ - tsqlite_statement* insert; - /* (id) -> bool */ - tsqlite_statement* exists; - /* (id, name, description) -> void */ - tsqlite_statement* update; - } channels; - struct { - /* () -> void */ - tsqlite_statement* create_table; - /* (channel_id, sender_id, content) -> (id, timestamp) */ - tsqlite_statement* insert; - /* (channel_id, first_message_id, count) -> [(id, sender_id, content, timestamp)] */ - tsqlite_statement* get_block; - } messages; - struct { - tsqlite_statement* create_table; - tsqlite_statement* registration_begin; - tsqlite_statement* registration_end; - tsqlite_statement* get_credentials; - tsqlite_statement* get_public_info; - tsqlite_statement* update; - } users; -} CommonQueries; - -#define MESSAGE_TIMESTAMP_FMT_SQL "%Y.%m.%d-%H:%M:%f" diff --git a/src/server/db/server_db.c b/src/server/db/server_db.c new file mode 100644 index 0000000..84ae75d --- /dev/null +++ b/src/server/db/server_db.c @@ -0,0 +1,153 @@ +#include "server_db_internal.h" +#include "tlibc/filesystem.h" + +Result(tsqlite_connection*) ServerDatabase_open(cstr file_path){ + Deferral(64); + + try_void(dir_createParent(file_path)); + try(tsqlite_connection* db, p, tsqlite_connection_open(file_path)); + bool success = false; + Defer(if(!success) tsqlite_connection_close(db)); + + /////////////////////////////////////////////////////////////////////////// + // CHANNELS // + /////////////////////////////////////////////////////////////////////////// + try(tsqlite_statement* create_table_channels, p, tsqlite_statement_compile(db, STR( + "CREATE TABLE IF NOT EXISTS channels (\n" + " id INTEGER PRIMARY KEY AUTOINCREMENT,\n" + " name VARCHAR NOT NULL,\n" + " description VARCHAR NOT NULL\n" + ");" + ))); + Defer(tsqlite_statement_free(create_table_channels)); + try_void(tsqlite_statement_step(create_table_channels)); + + + /////////////////////////////////////////////////////////////////////////// + // MESSAGES // + /////////////////////////////////////////////////////////////////////////// + try(tsqlite_statement* create_table_messages, p, tsqlite_statement_compile(db, STR( + "CREATE TABLE IF NOT EXISTS messages (\n" + " id INTEGER PRIMARY KEY AUTOINCREMENT,\n" + " channel_id INTEGER NOT NULL REFERENCES channels(id),\n" + " sender_id INTEGER NOT NULL REFERENCES users(id),\n" + " content BLOB NOT NULL,\n" + " timestamp DATETIME NOT NULL DEFAULT (\n" + " strftime('"MESSAGE_TIMESTAMP_FMT_SQL"', 'now', 'utc', 'subsecond')\n" + " )\n" + ");" + ))); + Defer(tsqlite_statement_free(create_table_messages)); + try_void(tsqlite_statement_step(create_table_messages)); + + /////////////////////////////////////////////////////////////////////////// + // USERS // + /////////////////////////////////////////////////////////////////////////// + try(tsqlite_statement* create_table_users, p, tsqlite_statement_compile(db, STR( + "CREATE TABLE IF NOT EXISTS users (\n" + " id INTEGER PRIMARY KEY AUTOINCREMENT,\n" + " username VARCHAR NOT NULL,\n" + " token BLOB NOT NULL,\n" + " registration_time DATETIME NOT NULL DEFAULT (\n" + " strftime('"MESSAGE_TIMESTAMP_FMT_SQL"', 'now', 'utc', 'subsecond')\n" + " )\n" + ");" + ))); + Defer(tsqlite_statement_free(create_table_users)); + try_void(tsqlite_statement_step(create_table_users)); + + try(tsqlite_statement* create_index_username, p, tsqlite_statement_compile(db, STR( + "CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);" + ))); + Defer(tsqlite_statement_free(create_index_username)); + try_void(tsqlite_statement_step(create_index_username)); + + success = true; + Return RESULT_VALUE(p, db); +} + + + +void ServerQueries_free(ServerQueries* q){ + if(!q) + return; + + tsqlite_statement_free(q->channels.insert); + tsqlite_statement_free(q->channels.update); + tsqlite_statement_free(q->channels.exists); + + tsqlite_statement_free(q->messages.insert); + tsqlite_statement_free(q->messages.get_block); + + tsqlite_statement_free(q->users.insert); + tsqlite_statement_free(q->users.find_by_username); + tsqlite_statement_free(q->users.compare_token); + + free(q); +} + +Result(ServerQueries*) ServerQueries_compile(tsqlite_connection* db){ + Deferral(4); + + ServerQueries* q = (ServerQueries*)malloc(sizeof(*q)); + zeroStruct(q); + bool success = false; + Defer(if(!success) ServerQueries_free(q)); + + /////////////////////////////////////////////////////////////////////////// + // CHANNELS // + /////////////////////////////////////////////////////////////////////////// + try(q->channels.insert, p, tsqlite_statement_compile(db, STR( + "INSERT INTO\n" + "channels (id, name, description)\n" + "VALUES ($id, $name, $description);" + ))); + + try(q->channels.exists, p, tsqlite_statement_compile(db, STR( + "SELECT 1 FROM channels WHERE id = $id;" + ))); + + try(q->channels.update, p, tsqlite_statement_compile(db, STR( + "UPDATE channels\n" + "SET name = $name, description = $description\n" + "WHERE id = $id;" + ))); + + /////////////////////////////////////////////////////////////////////////// + // MESSAGES // + /////////////////////////////////////////////////////////////////////////// + try(q->messages.insert, p, tsqlite_statement_compile(db, STR( + "INSERT INTO\n" + "messages (channel_id, sender_id, content)\n" + "VALUES ($channel_id, $sender_id, $content)\n" + "RETURNING id, timestamp;" + ))); + + try(q->messages.get_block, p, tsqlite_statement_compile(db, STR( + "SELECT id, sender_id, content, timestamp FROM messages\n" + "WHERE id >= $first_message_id\n" + "AND channel_id = $channel_id\n" + "LIMIT $count;" + ))); + + /////////////////////////////////////////////////////////////////////////// + // USERS // + /////////////////////////////////////////////////////////////////////////// + try(q->users.insert, p, tsqlite_statement_compile(db, STR( + "INSERT INTO\n" + "users (username, token)\n" + "VALUES ($username, $token)\n" + "RETURNING id, registration_time;" + ))); + + try(q->users.find_by_username, p, tsqlite_statement_compile(db, STR( + "SELECT id FROM users WHERE username = $username;" + ))); + + try(q->users.compare_token, p, tsqlite_statement_compile(db, STR( + "SELECT 1 FROM users WHERE id = $id AND token = $token;" + ))); + + success = true; + Return RESULT_VALUE(p, q); +} diff --git a/src/server/db/server_db.h b/src/server/db/server_db.h new file mode 100644 index 0000000..b08dbff --- /dev/null +++ b/src/server/db/server_db.h @@ -0,0 +1,39 @@ +#pragma once +#include "tcp-chat/tcp-chat.h" +#include "tsqlite.h" +#include "network/tcp-chat-protocol/v1.h" + +/// @brief open DB and create tables +Result(tsqlite_connection*) ServerDatabase_open(cstr file_path); + +typedef struct ServerQueries ServerQueries; +Result(ServerQueries*) ServerQueries_compile(tsqlite_connection* db); +void ServerQueries_free(ServerQueries* self); + + +Result(bool) Channel_exists(ServerQueries* q, i64 id); + +Result(void) Channel_createOrUpdate(ServerQueries* q, + i64 id, str name, str description); + +/// @return new message id +Result(i64) Channel_saveMessage(ServerQueries* q, + i64 channel_id, i64 sender_id, Array(u8) content, + DateTime* out_timestamp_utc); + +/// @brief try to find `count` messages starting from `first_message_id` +/// @param out_meta writes here information about found messages, .count can be 0 if no messages found +/// @param out_block .len must be >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX) +Result(void) Channel_loadMessageBlock(ServerQueries* q, + i64 channel_id, i64 first_message_id, u32 count, + MessageBlockMeta* out_block_meta, Array(u8) out_block_data); + + +/// @return existing user id or 0 +Result(i64) User_findByUsername(ServerQueries* q, str username); + +/// @return new user id +Result(i64) User_register(ServerQueries* q, str username, Array(u8) token); + +/// @return true for successful authorization +Result(bool) User_tryAuthorize(ServerQueries* q, u64 id, Array(u8) token); diff --git a/src/server/db/server_db_internal.h b/src/server/db/server_db_internal.h new file mode 100644 index 0000000..212189c --- /dev/null +++ b/src/server/db/server_db_internal.h @@ -0,0 +1,27 @@ +#pragma once +#include "server_db.h" + +typedef struct ServerQueries { + struct { + /* (id, name, description) -> void */ + tsqlite_statement* insert; + /* (id, name, description) -> void */ + tsqlite_statement* update; + /* (id) -> 1 or nothing */ + tsqlite_statement* exists; + } channels; + struct { + /* (channel_id, sender_id, content) -> (id, timestamp) */ + tsqlite_statement* insert; + /* (channel_id, first_message_id, count) -> [(id, sender_id, content, timestamp)] */ + tsqlite_statement* get_block; + } messages; + struct { + /* (username, token) -> (id, registration_time) */ + tsqlite_statement* insert; + /* (username) -> (id) */ + tsqlite_statement* find_by_username; + /* (id, token) -> 1 or nothing */ + tsqlite_statement* compare_token; + } users; +} ServerQueries; diff --git a/src/server/responses/GetMessageBlock.c b/src/server/responses/GetMessageBlock.c index 762145d..3ac8ccc 100644 --- a/src/server/responses/GetMessageBlock.c +++ b/src/server/responses/GetMessageBlock.c @@ -29,7 +29,7 @@ declare_RequestHandler(GetMessageBlock) } // validate channel id - try(bool channel_exists, i, Channel_exists(conn->queries.common, req.channel_id)); + try(bool channel_exists, i, Channel_exists(conn->queries, req.channel_id)); if(!channel_exists){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("invalid channel id") )); @@ -39,7 +39,7 @@ declare_RequestHandler(GetMessageBlock) // reset block meta zeroStruct(&conn->message_block_meta); // get message block from channel - try_void(Channel_loadMessageBlock(conn->queries.common, + try_void(Channel_loadMessageBlock(conn->queries, req.channel_id, req.first_message_id, req.message_count, &conn->message_block_meta, conn->message_block)); diff --git a/src/server/responses/Login.c b/src/server/responses/Login.c index 41c2f7b..1f49a62 100644 --- a/src/server/responses/Login.c +++ b/src/server/responses/Login.c @@ -32,7 +32,7 @@ declare_RequestHandler(Login) } // get user by id - try(u64 user_id, i, User_getIdForUsername(conn->queries.common, username)); + try(u64 user_id, i, User_findByUsername(conn->queries, username)); if(user_id == 0){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("Username is not registered") )); @@ -41,7 +41,7 @@ declare_RequestHandler(Login) // TODO: get user token Array(u8) token = Array_u8_construct(req.token, sizeof(req.token)); - try(bool authorized, i, User_tryAuthorize(conn->queries.common, user_id, token)); + try(bool authorized, i, User_tryAuthorize(conn->queries, user_id, token)); // validate token hash if(!authorized){ try_void(sendErrorMessage(log_ctx, conn, res_head, @@ -52,7 +52,7 @@ declare_RequestHandler(Login) // authorize conn->authorized = true; conn->user_id = user_id; - logInfo("authorized user '%s'", username.data); + logInfo("authorized user '%s' with id "FMT_i64, username.data, user_id); // send response LoginResponse res; diff --git a/src/server/responses/Register.c b/src/server/responses/Register.c index 22d1462..8c6e8e1 100644 --- a/src/server/responses/Register.c +++ b/src/server/responses/Register.c @@ -32,7 +32,7 @@ declare_RequestHandler(Register) } // check if name is taken - try(u64 user_id, i, User_getIdForUsername(conn->queries.common, username)); + try(u64 user_id, i, User_findByUsername(conn->queries, username)); if(user_id != 0){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("Username is already taken") )); @@ -41,8 +41,9 @@ declare_RequestHandler(Register) // register new user Array(u8) token = Array_u8_construct(req.token, sizeof(req.token)); - try(user_id, i, User_register(conn->queries.common, username, token)); - logInfo("registered user '"FMT_str"'", str_expand(username)); + try(user_id, i, User_register(conn->queries, username, token)); + logInfo("registered user '"FMT_str"' with id "FMT_i64, + str_unwrap(username), user_id); // send response RegisterResponse res; diff --git a/src/server/responses/SendMessage.c b/src/server/responses/SendMessage.c index 4656845..94dba96 100644 --- a/src/server/responses/SendMessage.c +++ b/src/server/responses/SendMessage.c @@ -33,7 +33,7 @@ declare_RequestHandler(SendMessage) try_void(EncryptedSocketTCP_recv(&conn->sock, conn->message_content, SocketRecvFlag_WholeBuffer)); // validate channel id - try(bool channel_exists, i, Channel_exists(conn->queries.common, req.channel_id)); + try(bool channel_exists, i, Channel_exists(conn->queries, req.channel_id)); if(!channel_exists){ try_void(sendErrorMessage(log_ctx, conn, res_head, LogSeverity_Warn, STR("invalid channel id") )); @@ -42,7 +42,7 @@ declare_RequestHandler(SendMessage) // save message to channel DateTime timestamp; - try(i64 message_id, i, Channel_saveMessage(conn->queries.common, + try(i64 message_id, i, Channel_saveMessage(conn->queries, req.channel_id, conn->user_id, conn->message_content, ×tamp)); diff --git a/src/server/server.c b/src/server/server.c index 99793d7..cb193f2 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -1,8 +1,6 @@ #include #include "tlibc/filesystem.h" #include "tlibc/time.h" -#include "tlibc/base64.h" -#include "tlibc/algorithms.h" #include "server/server_internal.h" #include "server/responses/responses.h" #include "tlibtoml.h" @@ -20,6 +18,7 @@ void Server_free(Server* self){ RSA_destroyPublicKey(&self->rsa_pk); free(self->db_path); + ServerQueries_free(self->queries); tsqlite_connection_close(self->db); free(self); @@ -86,7 +85,9 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name, self->db_path = str_copy(*v_db_path).data; // open DB - try(self->db, p, tsqlite_connection_open(self->db_path)); + logInfo("loading database '%s'", self->db_path); + try(self->db, p, ServerDatabase_open(self->db_path)); + try(self->queries, p, ServerQueries_compile(self->db)); // [channels] logDebug("loading channels..."); @@ -101,13 +102,13 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name, if(val->type != TLIBTOML_TABLE) continue; - logInfo("loading channel '"FMT_str"'", str_expand(name)) + logInfo("loading channel '"FMT_str"'", str_unwrap(name)) TomlTable* config_channel = val->table; try(i64 id, u, TomlTable_get_integer(config_channel, STR("id"))); try(str* v_ch_desc, p, TomlTable_get_str(config_channel, STR("description"))) str description = *v_ch_desc; - try_void(Channel_createOrUpdate(self->server_queries, id, name, description)); + try_void(Channel_createOrUpdate(self->queries, id, name, description)); } success = true; diff --git a/src/server/server_internal.h b/src/server/server_internal.h index ee0d161..ad2e1e4 100644 --- a/src/server/server_internal.h +++ b/src/server/server_internal.h @@ -5,7 +5,7 @@ #include "cryptography/RSA.h" #include "network/encrypted_sockets.h" #include "network/tcp-chat-protocol/v1.h" -#include "server/db/db.h" +#include "db/server_db.h" typedef struct ClientConnection ClientConnection; @@ -25,6 +25,7 @@ typedef struct Server { /* database and cache*/ char* db_path; tsqlite_connection* db; + ServerQueries* queries; /* for server listener thread only */ } Server; @@ -44,9 +45,7 @@ typedef struct ClientConnection { /* database */ tsqlite_connection* db; - struct { - CommonQueries* common; - } queries; + ServerQueries* queries; } ClientConnection; typedef struct ConnectionHandlerArgs {