Compare commits
5 Commits
88c2f8aa51
...
e2edd4070a
| Author | SHA1 | Date | |
|---|---|---|---|
| e2edd4070a | |||
| d461cae077 | |||
| 49793e2929 | |||
| 72696dea70 | |||
| 084a1828b2 |
3
.gitmodules
vendored
3
.gitmodules
vendored
@@ -7,3 +7,6 @@
|
||||
[submodule "dependencies/tlibtoml"]
|
||||
path = dependencies/tlibtoml
|
||||
url = https://timerix.ddns.net/git/Timerix/tlibtoml.git
|
||||
[submodule "dependencies/tsqlite"]
|
||||
path = dependencies/tsqlite
|
||||
url = https://timerix.ddns.net/git/Timerix/tsqlite.git
|
||||
|
||||
1
.vscode/c_cpp_properties.json
vendored
1
.vscode/c_cpp_properties.json
vendored
@@ -9,6 +9,7 @@
|
||||
"dependencies/BearSSL/inc",
|
||||
"dependencies/tlibc/include",
|
||||
"dependencies/tlibtoml/include",
|
||||
"dependencies/tsqlite/include",
|
||||
"${default}"
|
||||
],
|
||||
"cStandard": "c99"
|
||||
|
||||
62
.vscode/launch.json
vendored
62
.vscode/launch.json
vendored
@@ -2,29 +2,65 @@
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "gdb_debug",
|
||||
"name": "(gdb) Client | Build and debug",
|
||||
"type": "cppdbg",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/bin/tcp-chat",
|
||||
"windows": { "program": "${workspaceFolder}/bin/tcp-chat.exe" },
|
||||
// "args": [ "-l" ],
|
||||
"preLaunchTask": "build_exec_dbg",
|
||||
|
||||
"stopAtEntry": false,
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
"externalConsole": false,
|
||||
"internalConsoleOptions": "neverOpen",
|
||||
"MIMode": "gdb",
|
||||
"miDebuggerPath": "gdb",
|
||||
"setupCommands": [
|
||||
{
|
||||
"text": "-enable-pretty-printing",
|
||||
"ignoreFailures": true
|
||||
},
|
||||
{
|
||||
"text": "-gdb-set disassembly-flavor intel",
|
||||
"ignoreFailures": true
|
||||
}
|
||||
]
|
||||
"miDebuggerPath": "gdb"
|
||||
},
|
||||
{
|
||||
"name": "(gdb) Client | Just debug",
|
||||
"type": "cppdbg",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/bin/tcp-chat",
|
||||
"windows": { "program": "${workspaceFolder}/bin/tcp-chat.exe" },
|
||||
|
||||
"stopAtEntry": false,
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
"externalConsole": false,
|
||||
"internalConsoleOptions": "neverOpen",
|
||||
"MIMode": "gdb",
|
||||
"miDebuggerPath": "gdb"
|
||||
},
|
||||
|
||||
{
|
||||
"name": "(gdb) Server | Build and debug",
|
||||
"type": "cppdbg",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/bin/tcp-chat",
|
||||
"windows": { "program": "${workspaceFolder}/bin/tcp-chat.exe" },
|
||||
"args": [ "-l" ],
|
||||
"preLaunchTask": "build_exec_dbg",
|
||||
|
||||
"stopAtEntry": false,
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
"externalConsole": false,
|
||||
"internalConsoleOptions": "neverOpen",
|
||||
"MIMode": "gdb",
|
||||
"miDebuggerPath": "gdb"
|
||||
},
|
||||
{
|
||||
"name": "(gdb) Server | Just debug",
|
||||
"type": "cppdbg",
|
||||
"request": "launch",
|
||||
"program": "${workspaceFolder}/bin/tcp-chat",
|
||||
"windows": { "program": "${workspaceFolder}/bin/tcp-chat.exe" },
|
||||
"args": [ "-l" ],
|
||||
|
||||
"stopAtEntry": false,
|
||||
"cwd": "${workspaceFolder}/bin",
|
||||
"externalConsole": false,
|
||||
"internalConsoleOptions": "neverOpen",
|
||||
"MIMode": "gdb",
|
||||
"miDebuggerPath": "gdb"
|
||||
}
|
||||
]
|
||||
}
|
||||
2
dependencies/tlibc
vendored
2
dependencies/tlibc
vendored
Submodule dependencies/tlibc updated: 0d422cd7e5...de88e9ff16
2
dependencies/tlibtoml
vendored
2
dependencies/tlibtoml
vendored
Submodule dependencies/tlibtoml updated: bd38585b35...5cb121d1de
1
dependencies/tsqlite
vendored
Submodule
1
dependencies/tsqlite
vendored
Submodule
Submodule dependencies/tsqlite added at 4b15db7c1f
30
dependencies/tsqlite.config
vendored
Normal file
30
dependencies/tsqlite.config
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# This is a dependency config.
|
||||
# You can copy it to another project to add tsqlite dependency.
|
||||
|
||||
DEP_WORKING_DIR="$DEPENDENCIES_DIR/tsqlite"
|
||||
|
||||
user_config_path="project.config.user"
|
||||
absolute_dep_dir=$(realpath "$DEPENDENCIES_DIR")
|
||||
|
||||
function setup_user_config(){
|
||||
# Set variable `DEPENDENCIES_DIR`` in `tsqlite/project.config.user`
|
||||
# to the directory where `tlibc`` is installed
|
||||
file_copy_default_if_not_present "$user_config_path" "$user_config_path.default"
|
||||
replace_var_value_in_script "$user_config_path" "DEPENDENCIES_DIR" "$absolute_dep_dir"
|
||||
}
|
||||
|
||||
if [[ "$TASK" = *_dbg ]]; then
|
||||
dep_build_target="build_static_lib_dbg"
|
||||
else
|
||||
dep_build_target="build_static_lib"
|
||||
fi
|
||||
DEP_PRE_BUILD_COMMAND="setup_user_config"
|
||||
DEP_BUILD_COMMAND="cbuild $dep_build_target"
|
||||
DEP_POST_BUILD_COMMAND=""
|
||||
DEP_CLEAN_COMMAND="cbuild clean"
|
||||
DEP_DYNAMIC_OUT_FILES=""
|
||||
DEP_STATIC_OUT_FILES="bin/tsqlite.a"
|
||||
DEP_OTHER_OUT_FILES=""
|
||||
PRESERVE_OUT_DIRECTORY_STRUCTURE=false
|
||||
@@ -20,15 +20,15 @@ Result(void) Client_connect(Client* client, cstr server_addr_cstr, cstr server_p
|
||||
void Client_disconnect(Client* client);
|
||||
|
||||
/// @param self connected client
|
||||
/// @param out_name owned by Client, fetched from server during Client_connect
|
||||
Result(void) Client_getServerName(Client* self, str* out_name);
|
||||
/// @param out_str heap-allocated string
|
||||
Result(void) Client_getServerName(Client* self, str* out_str);
|
||||
|
||||
/// @param self connected client
|
||||
/// @param out_name owned by Client, fetched from server during Client_connect
|
||||
Result(void) Client_getServerDescription(Client* self, str* out_desc);
|
||||
/// @param out_str heap-allocated string
|
||||
Result(void) Client_getServerDescription(Client* self, str* out_str);
|
||||
|
||||
/// Create new account on connected server
|
||||
Result(void) Client_register(Client* self, u64* out_user_id);
|
||||
Result(void) Client_register(Client* self, i64* out_user_id);
|
||||
|
||||
/// Authorize on connected server
|
||||
Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id);
|
||||
Result(void) Client_login(Client* self, i64* out_user_id, i64* out_landing_channel_id);
|
||||
|
||||
@@ -18,4 +18,4 @@
|
||||
#define CHANNEL_DESC_SIZE_MAX 1023
|
||||
#define MESSAGE_SIZE_MIN 1
|
||||
#define MESSAGE_SIZE_MAX 4000
|
||||
#define MESSAGE_BLOCK_SIZE (64*1024)
|
||||
#define MESSAGE_BLOCK_COUNT_MAX 50
|
||||
@@ -12,3 +12,5 @@ typedef enum TcpChatError {
|
||||
TcpChatError_Unknown,
|
||||
TcpChatError_RejectIncoming,
|
||||
} TcpChatError;
|
||||
|
||||
#define MESSAGE_TIMESTAMP_FMT_SQL "%Y.%m.%d-%H:%M:%f"
|
||||
|
||||
@@ -24,7 +24,7 @@ SRC_CPP="$(find src -name '*.cpp')"
|
||||
# See cbuild/example_dependency_configs
|
||||
DEPENDENCY_CONFIGS_DIR='dependencies'
|
||||
# List of dependency config files in DEPENDENCY_CONFIGS_DIR separated by space.
|
||||
ENABLED_DEPENDENCIES='bearssl tlibc tlibtoml'
|
||||
ENABLED_DEPENDENCIES='bearssl tlibc tlibtoml tsqlite'
|
||||
|
||||
# OBJDIR structure:
|
||||
# ├── objects/ - Compiled object files. Cleans on each call of build task
|
||||
@@ -38,7 +38,8 @@ STATIC_LIB_FILE="$PROJECT.a"
|
||||
INCLUDE="-Isrc -Iinclude
|
||||
-I$DEPENDENCIES_DIR/BearSSL/inc
|
||||
-I$DEPENDENCIES_DIR/tlibc/include
|
||||
-I$DEPENDENCIES_DIR/tlibtoml/include"
|
||||
-I$DEPENDENCIES_DIR/tlibtoml/include
|
||||
-I$DEPENDENCIES_DIR/tsqlite/include"
|
||||
|
||||
# OS-specific options
|
||||
case "$OS" in
|
||||
@@ -46,13 +47,13 @@ case "$OS" in
|
||||
EXEC_FILE="$PROJECT.exe"
|
||||
SHARED_LIB_FILE="$PROJECT.dll"
|
||||
INCLUDE="$INCLUDE "
|
||||
LINKER_LIBS="-static -lpthread -lws2_32"
|
||||
LINKER_LIBS="-static -lpthread -lws2_32 -luuid -lsqlite3"
|
||||
;;
|
||||
LINUX)
|
||||
EXEC_FILE="$PROJECT"
|
||||
SHARED_LIB_FILE="$PROJECT.so"
|
||||
INCLUDE="$INCLUDE "
|
||||
LINKER_LIBS=""
|
||||
LINKER_LIBS="-lsqlite3"
|
||||
;;
|
||||
*)
|
||||
error "operating system $OS has no configuration variants"
|
||||
|
||||
@@ -22,28 +22,28 @@ static const str farewell_art = STR(
|
||||
#define is_alias(LITERAL) str_equals(command, STR(LITERAL))
|
||||
|
||||
static Result(void) ClientCLI_askUserNameAndPassword(str* username_out, str* password_out);
|
||||
static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop);
|
||||
static Result(void) ClientCLI_openUserDB(ClientCLI* self);
|
||||
static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
|
||||
str addr, str pk_base64, str name, str desc);
|
||||
static Result(ServerInfo*) ClientCLI_joinNewServer(ClientCLI* self);
|
||||
static Result(ServerInfo*) ClientCLI_selectServerFromCache(ClientCLI* self);
|
||||
static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server);
|
||||
static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop);
|
||||
static Result(SavedServer*) ClientCLI_joinNewServer(ClientCLI* self);
|
||||
static Result(SavedServer*) ClientCLI_selectServerFromCache(ClientCLI* self);
|
||||
static Result(void) ClientCLI_showSavedServer(ClientCLI* self, SavedServer* server);
|
||||
static Result(void) ClientCLI_register(ClientCLI* self);
|
||||
static Result(void) ClientCLI_login(ClientCLI* self);
|
||||
|
||||
|
||||
void ClientCLI_destroy(ClientCLI* self){
|
||||
if(!self)
|
||||
return;
|
||||
|
||||
Client_free(self->client);
|
||||
|
||||
idb_close(self->db);
|
||||
List_ServerInfo_destroy(&self->servers.list);
|
||||
HashMap_destroy(&self->servers.addr_id_map);
|
||||
ClientQueries_free(self->queries);
|
||||
tsqlite_connection_close(self->db);
|
||||
List_SavedServer_destroyWithElements(&self->saved_servers, SavedServer_destroy);
|
||||
}
|
||||
|
||||
void ClientCLI_construct(ClientCLI* self){
|
||||
zeroStruct(self);
|
||||
self->saved_servers = List_SavedServer_alloc(0);
|
||||
}
|
||||
|
||||
Result(void) ClientCLI_run(ClientCLI* self) {
|
||||
@@ -137,6 +137,33 @@ static Result(void) ClientCLI_askUserNameAndPassword(str* username_out, str* pas
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) ClientCLI_openUserDB(ClientCLI* self){
|
||||
Deferral(8);
|
||||
|
||||
str username = Client_getUserName(self->client);
|
||||
// TODO: encrypt user database
|
||||
// Array(u8) user_data_key = Client_getUserDataKey(self->client);
|
||||
|
||||
// build database file path
|
||||
try(char* user_dir, p, path_getUserDir());
|
||||
Defer(free(user_dir));
|
||||
char* db_path = strcat_malloc(
|
||||
user_dir,
|
||||
path_seps".local"path_seps"tcp-chat-client"path_seps"user-db"path_seps,
|
||||
username.data, ".sqlite"
|
||||
);
|
||||
Defer(free(db_path));
|
||||
printf("loading database '%s'\n", db_path);
|
||||
|
||||
try(self->db, p, ClientDatabase_open(db_path));
|
||||
try(self->queries, p, ClientQueries_compile(self->db));
|
||||
|
||||
// load whole servers table to list
|
||||
try_void(SavedServer_getAll(self->queries, &self->saved_servers));
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* stop){
|
||||
Deferral(64);
|
||||
|
||||
@@ -190,236 +217,144 @@ static Result(void) ClientCLI_execCommand(ClientCLI* self, str command, bool* st
|
||||
|
||||
static Result(void) ClientCLI_joinNewServer(ClientCLI* self){
|
||||
Deferral(8);
|
||||
bool success = false;
|
||||
|
||||
// ask server address
|
||||
const u32 address_alloc_size = HOSTADDR_SIZE_MAX + 1;
|
||||
str address = str_construct((char*)malloc(address_alloc_size), address_alloc_size, true);
|
||||
Defer(if(!success) str_destroy(address));
|
||||
printf("Enter server address (ip:port):\n");
|
||||
char server_addr_cstr[HOSTADDR_SIZE_MAX + 1];
|
||||
try_void(term_readLine(server_addr_cstr, sizeof(server_addr_cstr)));
|
||||
str server_addr_str = str_from_cstr(server_addr_cstr);
|
||||
str_trim(&server_addr_str, true);
|
||||
try_void(term_readLine(address.data, address.len));
|
||||
address.len = strlen(address.data);
|
||||
str_trim(&address, true);
|
||||
|
||||
// ask server public key
|
||||
const u32 server_pk_alloc_size = PUBLIC_KEY_BASE64_SIZE_MAX + 1;
|
||||
str server_pk = str_construct((char*)malloc(server_pk_alloc_size), server_pk_alloc_size, true);
|
||||
Defer(if(!success) str_destroy(server_pk));
|
||||
printf("Enter server public key (RSA-Public-<SIZE>:<DATA>):\n");
|
||||
char server_pk_cstr[PUBLIC_KEY_BASE64_SIZE_MAX + 1];
|
||||
try_void(term_readLine(server_pk_cstr, sizeof(server_pk_cstr)));
|
||||
str server_pk_str = str_from_cstr(server_pk_cstr);
|
||||
str_trim(&server_pk_str, true);
|
||||
try_void(term_readLine(server_pk.data, server_pk.len));
|
||||
server_pk.len = strlen(server_pk.data);
|
||||
str_trim(&server_pk, true);
|
||||
|
||||
printf("Connecting to server...\n");
|
||||
try_void(Client_connect(self->client, server_addr_cstr, server_pk_cstr));
|
||||
try_void(Client_connect(self->client, address.data, server_pk.data));
|
||||
printf("Connection established\n");
|
||||
|
||||
str server_name = str_null;
|
||||
str server_description = str_null;
|
||||
try_void(Client_getServerName(self->client, &server_name));
|
||||
Defer(if(!success) str_destroy(server_name));
|
||||
str server_description = str_null;
|
||||
try_void(Client_getServerDescription(self->client, &server_description));
|
||||
try(ServerInfo* server, p, ClientCLI_saveServerInfo(self,
|
||||
server_addr_str, server_pk_str,
|
||||
server_name, server_description));
|
||||
Defer(if(!success) str_destroy(server_description));
|
||||
|
||||
try_void(ClientCLI_showServerInfo(self, server));
|
||||
SavedServer server = SavedServer_construct(
|
||||
address,
|
||||
server_pk,
|
||||
server_name,
|
||||
server_description
|
||||
);
|
||||
try_void(SavedServer_createOrUpdate(self->queries, &server));
|
||||
List_SavedServer_pushMany(&self->saved_servers, &server, 1);
|
||||
|
||||
try_void(ClientCLI_showSavedServer(self, &server));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){
|
||||
Deferral(8);
|
||||
bool success = false;
|
||||
|
||||
// Lock table until this function returns.
|
||||
// It may not change any data in table, but it uses associated cache structures.
|
||||
idb_lockTable(self->servers.table);
|
||||
Defer(idb_unlockTable(self->servers.table));
|
||||
|
||||
u32 servers_count = self->servers.list.len;
|
||||
u32 servers_count = self->saved_servers.len;
|
||||
if(servers_count == 0){
|
||||
printf("No servers found in cache\n");
|
||||
printf("No saved servers found\n");
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
for(u32 id = 0; id < servers_count; id++){
|
||||
ServerInfo* server = self->servers.list.data + id;
|
||||
for(u32 i = 0; i < servers_count; i++){
|
||||
SavedServer* server = &self->saved_servers.data[i];
|
||||
printf("[%02u] "FMT_str" "FMT_str"\n",
|
||||
id, server->address_len, server->address, server->name_len, server->name);
|
||||
i, str_unwrap(server->address), str_unwrap(server->name));
|
||||
}
|
||||
|
||||
char buf[32];
|
||||
u32 id = -1;
|
||||
u32 selected_i = -1;
|
||||
while(true) {
|
||||
printf("Type 'q' to cancel\n");
|
||||
printf("Select server (number): ");
|
||||
printf("Select server number: ");
|
||||
try_void(term_readLine(buf, sizeof(buf)));
|
||||
str input_line = str_from_cstr(buf);
|
||||
str_trim(&input_line, true);
|
||||
if(str_equals(input_line, STR("q"))){
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
if(sscanf(buf, FMT_u32, &id) != 1){
|
||||
if(sscanf(buf, FMT_u32, &selected_i) != 1){
|
||||
printf("ERROR: not a number\n");
|
||||
}
|
||||
else if(id >= servers_count){
|
||||
printf("ERROR: not a server number: %u\n", id);
|
||||
else if(selected_i >= servers_count){
|
||||
printf("ERROR: not a server number\n");
|
||||
}
|
||||
else break;
|
||||
}
|
||||
ServerInfo* server = self->servers.list.data + id;
|
||||
SavedServer* selected_server = &self->saved_servers.data[selected_i];
|
||||
|
||||
printf("Connecting to '"FMT_str"'...\n", server->address_len, server->address);
|
||||
try_void(Client_connect(self->client, server->address, server->pk_base64));
|
||||
printf("Connecting to '"FMT_str"'...\n", str_unwrap(selected_server->address));
|
||||
try_void(Client_connect(self->client, selected_server->address.data, selected_server->pk_base64.data));
|
||||
printf("Connection established\n");
|
||||
|
||||
// update server name
|
||||
bool server_info_changed = false;
|
||||
// update cached server name
|
||||
str name = str_null;
|
||||
try_void(Client_getServerName(self->client, &name));
|
||||
if(!str_equals(name, str_construct(server->name, server->name_len, true))){
|
||||
str updated_server_name = str_null;
|
||||
try_void(Client_getServerName(self->client, &updated_server_name));
|
||||
Defer(if(!success) str_destroy(updated_server_name));
|
||||
if(!str_equals(updated_server_name, selected_server->name)){
|
||||
server_info_changed = true;
|
||||
if(name.len > SERVER_NAME_SIZE_MAX)
|
||||
name.len = SERVER_NAME_SIZE_MAX;
|
||||
server->name_len = name.len;
|
||||
memcpy(server->name, name.data, server->name_len);
|
||||
selected_server->name = updated_server_name;
|
||||
}
|
||||
// update cached server description
|
||||
str desc = str_null;
|
||||
try_void(Client_getServerDescription(self->client, &desc));
|
||||
if(!str_equals(desc, str_construct(server->desc, server->desc_len, true))){
|
||||
|
||||
// update server description
|
||||
str updated_server_description = str_null;
|
||||
try_void(Client_getServerDescription(self->client, &updated_server_description));
|
||||
Defer(if(!success) str_destroy(updated_server_description));
|
||||
if(!str_equals(updated_server_description, selected_server->description)){
|
||||
server_info_changed = true;
|
||||
if(desc.len > SERVER_DESC_SIZE_MAX)
|
||||
desc.len = SERVER_DESC_SIZE_MAX;
|
||||
server->desc_len = desc.len;
|
||||
memcpy(server->desc, desc.data, server->desc_len);
|
||||
selected_server->description = updated_server_description;
|
||||
}
|
||||
|
||||
if(server_info_changed){
|
||||
try_void(idb_updateRow(self->servers.table, id, server, false));
|
||||
try_void(SavedServer_createOrUpdate(self->queries, selected_server));
|
||||
}
|
||||
|
||||
try_void(ClientCLI_showServerInfo(self, server));
|
||||
try_void(ClientCLI_showSavedServer(self, selected_server));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) ClientCLI_showServerInfo(ClientCLI* self, ServerInfo* server){
|
||||
static Result(void) ClientCLI_showSavedServer(ClientCLI* self, SavedServer* server){
|
||||
Deferral(8);
|
||||
(void)self;
|
||||
|
||||
printf("Server Name: "FMT_str"\n", server->name_len, server->name);
|
||||
printf("Host Address: "FMT_str"\n", server->address_len, server->address);
|
||||
printf("Description:\n"FMT_str"\n\n", server->desc_len, server->desc);
|
||||
printf("Public Key:\n" FMT_str"\n\n", server->pk_base64_len, server->pk_base64);
|
||||
printf("Server Name: "FMT_str"\n", str_unwrap(server->name));
|
||||
printf("Host Address: "FMT_str"\n", str_unwrap(server->address));
|
||||
printf("Description:\n"FMT_str"\n\n", str_unwrap(server->description));
|
||||
printf("Public Key:\n" FMT_str"\n\n", str_unwrap(server->pk_base64));
|
||||
printf("Type 'register' if you don't have an account on the server.\n");
|
||||
printf("Type 'login' to authorize on the server.\n");
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) ClientCLI_openUserDB(ClientCLI* self){
|
||||
Deferral(8);
|
||||
|
||||
str username = Client_getUserName(self->client);
|
||||
Array(u8) user_data_key = Client_getUserDataKey(self->client);
|
||||
str user_db_dir = str_from_cstr(strcat_malloc("client-db", path_seps, username.data));
|
||||
Defer(free(user_db_dir.data));
|
||||
try(self->db, p, idb_open(user_db_dir, user_data_key));
|
||||
|
||||
// Lock DB until this function returns.
|
||||
idb_lockDB(self->db);
|
||||
Defer(idb_unlockDB(self->db));
|
||||
|
||||
// Load servers table
|
||||
try(self->servers.table, p,
|
||||
idb_getOrCreateTable(self->db, str_null, STR("servers"), sizeof(ServerInfo), false)
|
||||
);
|
||||
|
||||
// Lock table until this function returns.
|
||||
idb_lockTable(self->servers.table);
|
||||
Defer(idb_unlockTable(self->servers.table));
|
||||
|
||||
// load whole servers table to list
|
||||
try_void(
|
||||
idb_createListFromTable(self->servers.table, (void*)&self->servers.list, false)
|
||||
);
|
||||
|
||||
// build address-id map
|
||||
try(u64 servers_count, u,
|
||||
idb_getRowCount(self->servers.table, false)
|
||||
);
|
||||
HashMap_construct(&self->servers.addr_id_map, u64, NULL);
|
||||
for(u64 id = 0; id < servers_count; id++){
|
||||
ServerInfo* server = self->servers.list.data + id;
|
||||
str key = str_construct(server->address, server->address_len, true);
|
||||
if(!HashMap_tryPush(&self->servers.addr_id_map, key, &id)){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"duplicate server address '"FMT_str"'",
|
||||
key.len, key.data);
|
||||
}
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
|
||||
str addr, str pk_base64, str name, str desc){
|
||||
Deferral(8);
|
||||
|
||||
// create new server info
|
||||
ServerInfo server;
|
||||
zeroStruct(&server);
|
||||
// address
|
||||
if(addr.len > HOSTADDR_SIZE_MAX)
|
||||
addr.len = HOSTADDR_SIZE_MAX;
|
||||
server.address_len = addr.len;
|
||||
memcpy(server.address, addr.data, server.address_len);
|
||||
// public key
|
||||
if(pk_base64.len > PUBLIC_KEY_BASE64_SIZE_MAX)
|
||||
pk_base64.len = PUBLIC_KEY_BASE64_SIZE_MAX;
|
||||
server.pk_base64_len = pk_base64.len;
|
||||
memcpy(server.pk_base64, pk_base64.data, server.pk_base64_len);
|
||||
// name
|
||||
if(name.len > SERVER_NAME_SIZE_MAX)
|
||||
name.len = SERVER_NAME_SIZE_MAX;
|
||||
server.name_len = name.len;
|
||||
memcpy(server.name, name.data, server.name_len);
|
||||
// description
|
||||
if(desc.len > SERVER_DESC_SIZE_MAX)
|
||||
desc.len = SERVER_DESC_SIZE_MAX;
|
||||
server.desc_len = desc.len;
|
||||
memcpy(server.desc, desc.data, server.desc_len);
|
||||
|
||||
// Lock table until this function returns.
|
||||
// It may not change any data in table, but it uses associated cache structures.
|
||||
idb_lockTable(self->servers.table);
|
||||
Defer(idb_unlockTable(self->servers.table));
|
||||
|
||||
// try find server id in cache
|
||||
ServerInfo* cached_row_ptr = NULL;
|
||||
u64* id_ptr = NULL;
|
||||
id_ptr = HashMap_tryGetPtr(&self->servers.addr_id_map, addr);
|
||||
if(id_ptr){
|
||||
// update existing server
|
||||
u64 id = *id_ptr;
|
||||
try_void(idb_updateRow(self->servers.table, id, &server, false));
|
||||
try_assert(id < self->servers.list.len);
|
||||
cached_row_ptr = self->servers.list.data + id;
|
||||
memcpy(cached_row_ptr, &server, sizeof(ServerInfo));
|
||||
}
|
||||
else {
|
||||
// push new server
|
||||
try(u64 id, u, idb_pushRow(self->servers.table, &server, false));
|
||||
try_assert(id == self->servers.list.len);
|
||||
List_ServerInfo_pushMany(&self->servers.list, &server, 1);
|
||||
cached_row_ptr = self->servers.list.data + id;
|
||||
try_assert(HashMap_tryPush(&self->servers.addr_id_map, addr, &id));
|
||||
}
|
||||
|
||||
Return RESULT_VALUE(p, cached_row_ptr);
|
||||
}
|
||||
|
||||
static Result(void) ClientCLI_register(ClientCLI* self){
|
||||
Deferral(8);
|
||||
|
||||
u64 user_id = 0;
|
||||
i64 user_id = 0;
|
||||
try_void(Client_register(self->client, &user_id));
|
||||
printf("Registered successfully\n");
|
||||
printf("user_id: "FMT_u64"\n", user_id);
|
||||
printf("user_id: "FMT_i64"\n", user_id);
|
||||
try_assert(user_id > 0);
|
||||
// TODO: use user_id somewhere
|
||||
|
||||
Return RESULT_VOID;
|
||||
@@ -428,10 +363,11 @@ static Result(void) ClientCLI_register(ClientCLI* self){
|
||||
static Result(void) ClientCLI_login(ClientCLI* self){
|
||||
Deferral(8);
|
||||
|
||||
u64 user_id = 0, landing_channel_id = 0;
|
||||
i64 user_id = 0, landing_channel_id = 0;
|
||||
try_void(Client_login(self->client, &user_id, &landing_channel_id));
|
||||
printf("Authorized successfully\n");
|
||||
printf("user_id: "FMT_u64", landing_channel_id: "FMT_u64"\n", user_id, landing_channel_id);
|
||||
printf("user_id: "FMT_i64", landing_channel_id: "FMT_i64"\n", user_id, landing_channel_id);
|
||||
try_assert(user_id > 0);
|
||||
// TODO: use user_id, landing_channel_id somewhere
|
||||
|
||||
Return RESULT_VOID;
|
||||
|
||||
@@ -3,19 +3,13 @@
|
||||
#include "tlibc/collections/HashMap.h"
|
||||
#include "tlibc/collections/List.h"
|
||||
#include "tcp-chat/client.h"
|
||||
#include "db/idb.h"
|
||||
#include "db/tables.h"
|
||||
|
||||
List_declare(ServerInfo);
|
||||
#include "db/client_db.h"
|
||||
|
||||
typedef struct ClientCLI {
|
||||
Client* client;
|
||||
IncrementalDB* db;
|
||||
struct {
|
||||
Table* table;
|
||||
List(ServerInfo) list; // index is id
|
||||
HashMap(u64) addr_id_map; // key is server address
|
||||
} servers;
|
||||
tsqlite_connection* db;
|
||||
ClientQueries* queries;
|
||||
List(SavedServer) saved_servers;
|
||||
} ClientCLI;
|
||||
|
||||
void ClientCLI_construct(ClientCLI* self);
|
||||
|
||||
98
src/cli/ClientCLI/db/SavedServer.c
Normal file
98
src/cli/ClientCLI/db/SavedServer.c
Normal file
@@ -0,0 +1,98 @@
|
||||
#include "client_db_internal.h"
|
||||
|
||||
void SavedServer_destroy(SavedServer* self){
|
||||
if(!self)
|
||||
return;
|
||||
str_destroy(self->address);
|
||||
str_destroy(self->pk_base64);
|
||||
str_destroy(self->name);
|
||||
str_destroy(self->description);
|
||||
}
|
||||
|
||||
Result(bool) SavedServer_exists(ClientQueries* q, str address){
|
||||
Deferral(4);
|
||||
|
||||
tsqlite_statement* st = q->servers.exists;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_str(st, "$address", address, NULL));
|
||||
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
|
||||
Return RESULT_VALUE(i, has_result);
|
||||
}
|
||||
|
||||
Result(bool) SavedServer_comparePublicKey(ClientQueries* q, str address, str pk_base64){
|
||||
Deferral(4);
|
||||
|
||||
tsqlite_statement* st = q->servers.compare_pk;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_str(st, "$address", address, NULL));
|
||||
try_void(tsqlite_statement_bind_str(st, "$pk_base64", pk_base64, NULL));
|
||||
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
|
||||
Return RESULT_VALUE(i, has_result);
|
||||
}
|
||||
|
||||
Result(void) SavedServer_createOrUpdate(ClientQueries* q, SavedServer* server){
|
||||
Deferral(4);
|
||||
try_assert(server->address.len >= HOSTADDR_SIZE_MIN && server->address.len <= HOSTADDR_SIZE_MAX);
|
||||
try_assert(server->pk_base64.len > 0 && server->pk_base64.len <= PUBLIC_KEY_BASE64_SIZE_MAX);
|
||||
try_assert(server->name.len >= SERVER_NAME_SIZE_MIN && server->name.len <= SERVER_NAME_SIZE_MAX);
|
||||
try_assert(server->description.len <= SERVER_DESC_SIZE_MAX);
|
||||
|
||||
try(bool server_exists, i, SavedServer_exists(q, server->address));
|
||||
tsqlite_statement* st = NULL;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
if(server_exists){
|
||||
st = q->servers.update;
|
||||
try(bool pk_matches, i, SavedServer_comparePublicKey(q, server->address, server->pk_base64));
|
||||
if(!pk_matches){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"trying to update server '"FMT_str"' but public keys don't match",
|
||||
str_unwrap(server->address));
|
||||
}
|
||||
}
|
||||
else {
|
||||
st = q->servers.insert;
|
||||
try_void(tsqlite_statement_bind_str(st, "$pk_base64", server->pk_base64, NULL));
|
||||
}
|
||||
try_void(tsqlite_statement_bind_str(st, "$address", server->address, NULL));
|
||||
try_void(tsqlite_statement_bind_str(st, "$name", server->name, NULL));
|
||||
try_void(tsqlite_statement_bind_str(st, "$description", server->description, NULL));
|
||||
try_void(tsqlite_statement_step(st));
|
||||
|
||||
Return RESULT_VALUE(i, !server_exists);
|
||||
}
|
||||
|
||||
Result(void) SavedServer_getAll(ClientQueries* q, List(SavedServer)* dst_list){
|
||||
Deferral(4);
|
||||
|
||||
tsqlite_statement* st = q->servers.get_all;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
|
||||
SavedServer server = SavedServer_construct(str_null, str_null, str_null, str_null);
|
||||
str tmp_str = str_null;
|
||||
while(true){
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
if(!has_result)
|
||||
break;
|
||||
|
||||
// address
|
||||
try_void(tsqlite_statement_getResult_str(st, &tmp_str));
|
||||
server.address = str_copy(tmp_str);
|
||||
// pk_base64
|
||||
try_void(tsqlite_statement_getResult_str(st, &tmp_str));
|
||||
server.pk_base64 = str_copy(tmp_str);
|
||||
// name
|
||||
try_void(tsqlite_statement_getResult_str(st, &tmp_str));
|
||||
server.name = str_copy(tmp_str);
|
||||
// description
|
||||
try_void(tsqlite_statement_getResult_str(st, &tmp_str));
|
||||
server.description = str_copy(tmp_str);
|
||||
|
||||
List_SavedServer_pushMany(dst_list, &server, 1);
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
81
src/cli/ClientCLI/db/client_db.c
Normal file
81
src/cli/ClientCLI/db/client_db.c
Normal file
@@ -0,0 +1,81 @@
|
||||
#include "client_db_internal.h"
|
||||
#include "tlibc/filesystem.h"
|
||||
|
||||
Result(tsqlite_connection* db) ClientDatabase_open(cstr file_path){
|
||||
Deferral(64);
|
||||
|
||||
try_void(dir_createParent(file_path));
|
||||
try(tsqlite_connection* db, p, tsqlite_connection_open(file_path));
|
||||
bool success = false;
|
||||
Defer(if(!success) tsqlite_connection_close(db));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// SERVERS //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(tsqlite_statement* create_table_servers, p, tsqlite_statement_compile(db, STR(
|
||||
"CREATE TABLE IF NOT EXISTS servers (\n"
|
||||
" address VARCHAR PRIMARY KEY,\n"
|
||||
" pk_base64 VARCHAR NOT NULL,\n"
|
||||
" name VARCHAR NOT NULL,\n"
|
||||
" description VARCHAR NOT NULL\n"
|
||||
");"
|
||||
)));
|
||||
Defer(tsqlite_statement_free(create_table_servers));
|
||||
try_void(tsqlite_statement_step(create_table_servers));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, db);
|
||||
}
|
||||
|
||||
|
||||
void ClientQueries_free(ClientQueries* q){
|
||||
if(!q)
|
||||
return;
|
||||
|
||||
tsqlite_statement_free(q->servers.insert);
|
||||
tsqlite_statement_free(q->servers.update);
|
||||
tsqlite_statement_free(q->servers.exists);
|
||||
tsqlite_statement_free(q->servers.compare_pk);
|
||||
tsqlite_statement_free(q->servers.get_all);
|
||||
|
||||
free(q);
|
||||
}
|
||||
|
||||
Result(ClientQueries*) ClientQueries_compile(tsqlite_connection* db){
|
||||
Deferral(4);
|
||||
|
||||
ClientQueries* q = (ClientQueries*)malloc(sizeof(*q));
|
||||
zeroStruct(q);
|
||||
bool success = false;
|
||||
Defer(if(!success) ClientQueries_free(q));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// SERVERS //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(q->servers.insert, p, tsqlite_statement_compile(db, STR(
|
||||
"INSERT INTO\n"
|
||||
"servers (address, pk_base64, name, description)\n"
|
||||
"VALUES ($address, $pk_base64, $name, $description);"
|
||||
)));
|
||||
|
||||
try(q->servers.update, p, tsqlite_statement_compile(db, STR(
|
||||
"UPDATE servers\n"
|
||||
"SET name = $name, description = $description\n"
|
||||
"WHERE address = $address;"
|
||||
)));
|
||||
|
||||
try(q->servers.exists, p, tsqlite_statement_compile(db, STR(
|
||||
"SELECT 1 FROM servers WHERE address = $address;"
|
||||
)));
|
||||
|
||||
try(q->servers.compare_pk, p, tsqlite_statement_compile(db, STR(
|
||||
"SELECT 1 FROM servers WHERE address = $address AND pk_base64 = $pk_base64;"
|
||||
)));
|
||||
|
||||
try(q->servers.get_all, p, tsqlite_statement_compile(db, STR(
|
||||
"SELECT * FROM servers;"
|
||||
)));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, q);
|
||||
}
|
||||
38
src/cli/ClientCLI/db/client_db.h
Normal file
38
src/cli/ClientCLI/db/client_db.h
Normal file
@@ -0,0 +1,38 @@
|
||||
#pragma once
|
||||
#include "tcp-chat/tcp-chat.h"
|
||||
#include "tsqlite.h"
|
||||
#include "network/tcp-chat-protocol/v1.h"
|
||||
#include "tlibc/collections/List.h"
|
||||
|
||||
/// @brief open DB and create tables
|
||||
Result(tsqlite_connection* db) ClientDatabase_open(cstr file_path);
|
||||
|
||||
typedef struct ClientQueries ClientQueries;
|
||||
Result(ClientQueries*) ClientQueries_compile(tsqlite_connection* db);
|
||||
void ClientQueries_free(ClientQueries* self);
|
||||
|
||||
|
||||
typedef struct SavedServer {
|
||||
str address;
|
||||
str pk_base64;
|
||||
str name;
|
||||
str description;
|
||||
} SavedServer;
|
||||
|
||||
List_declare(SavedServer);
|
||||
|
||||
#define SavedServer_construct(ADDR, PK, NAME, DESC) ((SavedServer){ \
|
||||
.address = ADDR, .pk_base64 = PK, .name = NAME, .description = DESC })
|
||||
|
||||
void SavedServer_destroy(SavedServer* self);
|
||||
|
||||
/// @return true if new row was created
|
||||
Result(bool) SavedServer_createOrUpdate(ClientQueries* q, SavedServer* server);
|
||||
|
||||
/// @param dst_list there SavedServer values are pushed
|
||||
Result(void) SavedServer_getAll(ClientQueries* q, List(SavedServer)* dst_list);
|
||||
|
||||
Result(bool) SavedServer_exists(ClientQueries* q, str address);
|
||||
|
||||
/// @return true if provided key and saved key match
|
||||
Result(bool) SavedServer_comparePublicKey(ClientQueries* q, str address, str pk_base64);
|
||||
17
src/cli/ClientCLI/db/client_db_internal.h
Normal file
17
src/cli/ClientCLI/db/client_db_internal.h
Normal file
@@ -0,0 +1,17 @@
|
||||
#pragma once
|
||||
#include "client_db.h"
|
||||
|
||||
typedef struct ClientQueries {
|
||||
struct {
|
||||
/* ($address, $pk_base64, $name, $description) -> void */
|
||||
tsqlite_statement* insert;
|
||||
/* ($address, $name, $description) -> void */
|
||||
tsqlite_statement* update;
|
||||
/* ($address) -> 1 or nothing */
|
||||
tsqlite_statement* exists;
|
||||
/* ($address, $pk_base64) -> 1 or nothing */
|
||||
tsqlite_statement* compare_pk;
|
||||
/* () -> [(*)] */
|
||||
tsqlite_statement* get_all;
|
||||
} servers;
|
||||
} ClientQueries;
|
||||
@@ -15,7 +15,7 @@ Result(void) run_RsaGenStdin(u32 key_size) {
|
||||
do {
|
||||
read_n = fread(input_buf.data, 1, input_buf.len, stdin);
|
||||
if(read_n < 0){
|
||||
Return RESULT_ERROR("ERROR: can't read stdin", false);
|
||||
Return RESULT_ERROR_LITERAL("ERROR: can't read stdin");
|
||||
}
|
||||
// put bytes to rng as seed
|
||||
br_hmac_drbg_update(&rng, input_buf.data, read_n);
|
||||
|
||||
@@ -8,8 +8,6 @@ void ServerConnection_close(ServerConnection* self){
|
||||
EncryptedSocketTCP_destroy(&self->sock);
|
||||
Array_u8_destroy(&self->token);
|
||||
Array_u8_destroy(&self->session_key);
|
||||
str_destroy(self->server_name);
|
||||
str_destroy(self->server_description);
|
||||
free(self);
|
||||
}
|
||||
|
||||
@@ -75,18 +73,13 @@ Result(ServerConnection*) ServerConnection_open(Client* client, cstr server_addr
|
||||
PacketType_ServerHandshake));
|
||||
conn->session_id = server_handshake.session_id;
|
||||
|
||||
// get server name
|
||||
try_void(ServerConnection_requestServerName(conn));
|
||||
// get server description
|
||||
try_void(ServerConnection_requestServerDescription(conn));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, conn);
|
||||
}
|
||||
|
||||
Result(void) ServerConnection_requestServerName(ServerConnection* conn){
|
||||
Result(void) ServerConnection_requestServerName(ServerConnection* conn, str* out_str){
|
||||
if(conn == NULL){
|
||||
return RESULT_ERROR("Client is not connected to a server", false);
|
||||
return RESULT_ERROR_LITERAL("Client is not connected to a server");
|
||||
}
|
||||
Deferral(4);
|
||||
|
||||
@@ -98,14 +91,14 @@ Result(void) ServerConnection_requestServerName(ServerConnection* conn){
|
||||
try_void(sendRequest(&conn->sock, &req_header, &public_info_req));
|
||||
try_void(recvResponse(&conn->sock, &res_header, &public_info_res,
|
||||
PacketType_ServerPublicInfoResponse));
|
||||
try_void(recvStr(&conn->sock, public_info_res.data_size, &conn->server_name));
|
||||
try_void(recvStr(&conn->sock, public_info_res.data_size, out_str));
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Result(void) ServerConnection_requestServerDescription(ServerConnection* conn){
|
||||
Result(void) ServerConnection_requestServerDescription(ServerConnection* conn, str* out_str){
|
||||
if(conn == NULL){
|
||||
return RESULT_ERROR("Client is not connected to a server", false);
|
||||
return RESULT_ERROR_LITERAL("Client is not connected to a server");
|
||||
}
|
||||
Deferral(4);
|
||||
|
||||
@@ -117,7 +110,7 @@ Result(void) ServerConnection_requestServerDescription(ServerConnection* conn){
|
||||
try_void(sendRequest(&conn->sock, &req_header, &public_info_req));
|
||||
try_void(recvResponse(&conn->sock, &res_header, &public_info_res,
|
||||
PacketType_ServerPublicInfoResponse));
|
||||
try_void(recvStr(&conn->sock, public_info_res.data_size, &conn->server_description));
|
||||
try_void(recvStr(&conn->sock, public_info_res.data_size, out_str));
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
@@ -52,27 +52,27 @@ Array(u8) Client_getUserDataKey(Client* client){
|
||||
return client->user_data_key;
|
||||
}
|
||||
|
||||
Result(void) Client_getServerName(Client* self, str* out_name){
|
||||
Result(void) Client_getServerName(Client* self, str* out_str){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->conn != NULL && "didn't connect to a server yet");
|
||||
|
||||
*out_name = self->conn->server_name;
|
||||
try_void(ServerConnection_requestServerName(self->conn, out_str));
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Result(void) Client_getServerDescription(Client* self, str* out_desc){
|
||||
Result(void) Client_getServerDescription(Client* self, str* out_str){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->conn != NULL && "didn't connect to a server yet");
|
||||
|
||||
*out_desc = self->conn->server_description;
|
||||
try_void(ServerConnection_requestServerDescription(self->conn, out_str));
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Result(void) Client_register(Client* self, u64* out_user_id){
|
||||
Result(void) Client_register(Client* self, i64* out_user_id){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->conn != NULL && "didn't connect to a server yet");
|
||||
@@ -90,7 +90,7 @@ Result(void) Client_register(Client* self, u64* out_user_id){
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id){
|
||||
Result(void) Client_login(Client* self, i64* out_user_id, i64* out_landing_channel_id){
|
||||
Deferral(1);
|
||||
try_assert(self != NULL);
|
||||
try_assert(self->conn != NULL && "didn't connect to a server yet");
|
||||
|
||||
@@ -21,10 +21,8 @@ typedef struct ServerConnection {
|
||||
Array(u8) token;
|
||||
Array(u8) session_key;
|
||||
EncryptedSocketTCP sock;
|
||||
u64 session_id;
|
||||
str server_name;
|
||||
str server_description;
|
||||
u64 user_id;
|
||||
i64 session_id;
|
||||
i64 user_id;
|
||||
} ServerConnection;
|
||||
|
||||
/// @param server_addr_cstr
|
||||
@@ -34,8 +32,8 @@ Result(ServerConnection*) ServerConnection_open(Client* client,
|
||||
|
||||
void ServerConnection_close(ServerConnection* conn);
|
||||
|
||||
/// updates conn->server_name
|
||||
Result(void) ServerConnection_requestServerName(ServerConnection* conn);
|
||||
/// @param out_str heap-allocated string
|
||||
Result(void) ServerConnection_requestServerName(ServerConnection* conn, str* out_str);
|
||||
|
||||
/// updates conn->server_description
|
||||
Result(void) ServerConnection_requestServerDescription(ServerConnection* conn);
|
||||
/// @param out_str heap-allocated string
|
||||
Result(void) ServerConnection_requestServerDescription(ServerConnection* conn, str* out_str);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#include "requests.h"
|
||||
|
||||
Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_s){
|
||||
Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_str){
|
||||
Deferral(4);
|
||||
|
||||
str s = str_construct(malloc(size + 1), size, true);
|
||||
@@ -17,7 +17,7 @@ Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_s){
|
||||
);
|
||||
|
||||
s.data[s.len] = 0;
|
||||
*out_s = s;
|
||||
*out_str = s;
|
||||
success = true;
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ Result(void) _recvResponse(EncryptedSocketTCP* sock,
|
||||
if(res_header->type == PacketType_ErrorMessage){
|
||||
str err_msg;
|
||||
try_void(recvErrorMessage(sock, res_header, &err_msg));
|
||||
Return RESULT_ERROR(err_msg.data, true);
|
||||
Return RESULT_ERROR(err_msg, true);
|
||||
}
|
||||
|
||||
try_void(PacketHeader_validateType(res_header, res_type));
|
||||
|
||||
@@ -3,10 +3,12 @@
|
||||
#include "client/client_internal.h"
|
||||
|
||||
|
||||
/// @param out_err_msg heap-allocated string
|
||||
Result(void) recvErrorMessage(EncryptedSocketTCP* sock, PacketHeader* res_header,
|
||||
str* out_err_msg);
|
||||
|
||||
Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_s);
|
||||
/// @param out_str heap-allocated string
|
||||
Result(void) recvStr(EncryptedSocketTCP* sock, u32 size, str* out_str);
|
||||
|
||||
Result(void) _recvResponse(EncryptedSocketTCP* sock,
|
||||
PacketHeader* res_header, Array(u8) res, PacketType res_type);
|
||||
|
||||
@@ -127,7 +127,7 @@ Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr,
|
||||
|
||||
// validate decrypted data
|
||||
if(memcmp(header.key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE) != 0){
|
||||
Return RESULT_ERROR("decrypted data is invalid or key is wrong", false);
|
||||
Return RESULT_ERROR_LITERAL("decrypted data is invalid or key is wrong");
|
||||
}
|
||||
|
||||
// size of decrypted data without padding
|
||||
@@ -266,7 +266,7 @@ Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
|
||||
key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE);
|
||||
// validate decrypted data
|
||||
if(memcmp(key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE) != 0){
|
||||
Return RESULT_ERROR("decrypted data is invalid or key is wrong", false);
|
||||
Return RESULT_ERROR_LITERAL("decrypted data is invalid or key is wrong");
|
||||
}
|
||||
}
|
||||
// size without IV
|
||||
|
||||
@@ -25,7 +25,7 @@ Result(void) RSA_generateKeyPair(u32 key_size,
|
||||
|
||||
success = br_rsa_i31_keygen(rng_vtable_ptr, sk, sk_buf, pk, pk_buf, key_size, DEFAULT_PUBLIC_EXPONENT);
|
||||
if(!success){
|
||||
Return RESULT_ERROR("br_rsa_i31_keygen() failed", false);
|
||||
Return RESULT_ERROR_LITERAL("br_rsa_i31_keygen() failed");
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
@@ -58,7 +58,7 @@ Result(void) RSA_computePublicKey(const br_rsa_private_key* sk, br_rsa_public_ke
|
||||
|
||||
size_t modulus_size = compute_modulus(NULL, sk);
|
||||
if (modulus_size == 0) {
|
||||
Return RESULT_ERROR("compute_modulus", false);
|
||||
Return RESULT_ERROR_LITERAL("compute_modulus");
|
||||
}
|
||||
void* modulus = malloc(modulus_size);
|
||||
bool success = false;
|
||||
@@ -67,12 +67,12 @@ Result(void) RSA_computePublicKey(const br_rsa_private_key* sk, br_rsa_public_ke
|
||||
free(modulus)
|
||||
);
|
||||
if (compute_modulus(modulus, sk) != modulus_size) {
|
||||
Return RESULT_ERROR("compute_modulus", false);
|
||||
Return RESULT_ERROR_LITERAL("compute_modulus");
|
||||
}
|
||||
|
||||
u32 pubexp_little_endian = compute_pubexp(sk);
|
||||
if (pubexp_little_endian == 0) {
|
||||
Return RESULT_ERROR("compute_pubexp", false);
|
||||
Return RESULT_ERROR_LITERAL("compute_pubexp");
|
||||
}
|
||||
u8 pubexp_big_endian[4];
|
||||
pubexp_big_endian[0] = pubexp_little_endian >> 24;
|
||||
@@ -115,7 +115,7 @@ Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* pk){
|
||||
Deferral(4);
|
||||
u32 n_bitlen = 0;
|
||||
if(sscanf(src, "RSA-Public-%u:", &n_bitlen) != 1){
|
||||
Return RESULT_ERROR("can't parse key size", false);
|
||||
Return RESULT_ERROR_LITERAL("can't parse key size");
|
||||
}
|
||||
u32 key_buffer_size = BR_RSA_KBUF_PUB_SIZE(n_bitlen);
|
||||
pk->n = malloc(key_buffer_size);
|
||||
@@ -125,7 +125,7 @@ Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* pk){
|
||||
str src_str = str_from_cstr(src);
|
||||
u32 offset = str_seekChar(src_str, ':', 10) + 1;
|
||||
if(offset == 0){
|
||||
Return RESULT_ERROR("missing ':' before key data", false);
|
||||
Return RESULT_ERROR_LITERAL("missing ':' before key data");
|
||||
}
|
||||
str key_base64_str = src_str;
|
||||
key_base64_str.data += offset;
|
||||
@@ -136,7 +136,7 @@ Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* pk){
|
||||
}
|
||||
decoded_size = base64_decode(key_base64_str.data, key_base64_str.len, pk->n);
|
||||
if(decoded_size != key_buffer_size){
|
||||
Return RESULT_ERROR("key decoding failed", false);
|
||||
Return RESULT_ERROR_LITERAL("key decoding failed");
|
||||
}
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
@@ -145,7 +145,7 @@ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){
|
||||
Deferral(4);
|
||||
u32 n_bitlen = 0;
|
||||
if(sscanf(src, "RSA-Private-%u:", &n_bitlen) != 1){
|
||||
Return RESULT_ERROR("can't parse key size", false);
|
||||
Return RESULT_ERROR_LITERAL("can't parse key size");
|
||||
}
|
||||
sk->n_bitlen = n_bitlen;
|
||||
u32 key_buffer_size = BR_RSA_KBUF_PRIV_SIZE(n_bitlen);
|
||||
@@ -159,7 +159,7 @@ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){
|
||||
str src_str = str_from_cstr(src);
|
||||
u32 offset = str_seekChar(src_str, ':', 10) + 1;
|
||||
if(offset == 0){
|
||||
Return RESULT_ERROR("missing ':' before key data", false);
|
||||
Return RESULT_ERROR_LITERAL("missing ':' before key data");
|
||||
}
|
||||
str key_base64_str = src_str;
|
||||
key_base64_str.data += offset;
|
||||
@@ -170,7 +170,7 @@ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){
|
||||
}
|
||||
decoded_size = base64_decode(key_base64_str.data, key_base64_str.len, sk->p);
|
||||
if(decoded_size != key_buffer_size){
|
||||
Return RESULT_ERROR("key decoding failed", false);
|
||||
Return RESULT_ERROR_LITERAL("key decoding failed");
|
||||
}
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
@@ -205,7 +205,7 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst
|
||||
src.data, src.len);
|
||||
|
||||
if(sz == 0){
|
||||
return RESULT_ERROR("RSA encryption failed", false);
|
||||
return RESULT_ERROR_LITERAL("RSA encryption failed");
|
||||
}
|
||||
return RESULT_VALUE(u, sz);
|
||||
}
|
||||
@@ -234,7 +234,7 @@ Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer){
|
||||
buffer.data, &sz);
|
||||
|
||||
if(r == 0){
|
||||
return RESULT_ERROR("RSA encryption failed", false);
|
||||
return RESULT_ERROR_LITERAL("RSA encryption failed");
|
||||
}
|
||||
return RESULT_VALUE(u, sz);
|
||||
}
|
||||
|
||||
521
src/db/idb.c
521
src/db/idb.c
@@ -1,521 +0,0 @@
|
||||
#include "idb.h"
|
||||
#include "tlibc/magic.h"
|
||||
#include "tlibc/filesystem.h"
|
||||
#include "tlibc/collections/HashMap.h"
|
||||
#include "tlibc/string/StringBuilder.h"
|
||||
#include "cryptography/AES.h"
|
||||
#include <pthread.h>
|
||||
|
||||
static const char KEY_CHALLENGE_PLAIN[16] = "key is correct!";
|
||||
#define KEY_CHALLENGE_PLAIN_SIZE sizeof(KEY_CHALLENGE_PLAIN)
|
||||
#define KEY_CHALLENGE_CIPHER_SIZE AESBlockEncryptor_calcDstSize(KEY_CHALLENGE_PLAIN_SIZE)
|
||||
|
||||
typedef struct TableFileHeader {
|
||||
Magic32 magic;
|
||||
u16 version;
|
||||
bool _dirty_bit;
|
||||
bool encrypted;
|
||||
u32 row_size;
|
||||
/* encrypted KEY_CHALLENGE_PLAIN */
|
||||
u8 key_challenge[KEY_CHALLENGE_CIPHER_SIZE];
|
||||
} ATTRIBUTE_ALIGNED(256) TableFileHeader;
|
||||
|
||||
typedef struct Table {
|
||||
TableFileHeader header;
|
||||
IncrementalDB* db;
|
||||
str name;
|
||||
str table_file_path;
|
||||
str changes_file_path;
|
||||
FILE* table_file;
|
||||
FILE* changes_file;
|
||||
pthread_mutex_t mutex;
|
||||
u64 row_count;
|
||||
AESBlockEncryptor enc;
|
||||
AESBlockDecryptor dec;
|
||||
Array(u8) enc_buf;
|
||||
} Table;
|
||||
|
||||
typedef struct IncrementalDB {
|
||||
str db_dir;
|
||||
Array(u8) aes_key;
|
||||
HashMap(Table**) tables_map;
|
||||
pthread_mutex_t mutex;
|
||||
} IncrementalDB;
|
||||
|
||||
static const Magic32 TABLE_FILE_MAGIC = { .bytes = { 'I', 'D', 'B', 't' } };
|
||||
|
||||
|
||||
static void Table_close(Table* t){
|
||||
if(t == NULL)
|
||||
return;
|
||||
fclose(t->table_file);
|
||||
fclose(t->changes_file);
|
||||
str_destroy(t->name);
|
||||
str_destroy(t->table_file_path);
|
||||
str_destroy(t->changes_file_path);
|
||||
pthread_mutex_destroy(&t->mutex);
|
||||
Array_u8_destroy(&t->enc_buf);
|
||||
free(t);
|
||||
}
|
||||
|
||||
// element destructor for HashMap(Table*)
|
||||
static void TablePtr_free(void* t_ptr_ptr){
|
||||
Table_close(*(Table**)t_ptr_ptr);
|
||||
}
|
||||
|
||||
/// @param name must be null-terminated
|
||||
static Result(void) validateTableName(str name){
|
||||
char forbidden_characters[] = { '/', '\\', ':', ';', '?', '"', '\'', '\n', '\r', '\t'};
|
||||
for(u32 i = 0; i < ARRAY_LEN(forbidden_characters); i++) {
|
||||
char c = forbidden_characters[i];
|
||||
if(str_seekChar(name, c, 0) != -1){
|
||||
return RESULT_ERROR_FMT(
|
||||
"Table name '%s' contains forbidden character '%c'",
|
||||
name.data, c);
|
||||
}
|
||||
}
|
||||
|
||||
return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) Table_readHeader(Table* t){
|
||||
Deferral(4);
|
||||
// seek for start of the file
|
||||
try_void(file_seek(t->table_file, 0, SeekOrigin_Start));
|
||||
// read header
|
||||
try_void(file_readStructsExactly(t->table_file, &t->header, sizeof(t->header), 1));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) Table_writeHeader(Table* t){
|
||||
Deferral(4);
|
||||
// seek for start of the file
|
||||
try_void(file_seek(t->table_file, 0, SeekOrigin_Start));
|
||||
// write header
|
||||
try_void(file_writeStructs(t->table_file, &t->header, sizeof(t->header), 1));
|
||||
// TODO: add more fflush calls
|
||||
fflush(t->table_file);
|
||||
fflush(t->changes_file);
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) Table_setDirtyBit(Table* t, bool val){
|
||||
Deferral(4);
|
||||
t->header._dirty_bit = val;
|
||||
try_void(Table_writeHeader(t));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(bool) Table_getDirtyBit(Table* t){
|
||||
Deferral(4);
|
||||
try_void(Table_readHeader(t));
|
||||
Return RESULT_VALUE(i, t->header._dirty_bit);
|
||||
}
|
||||
|
||||
static Result(void) Table_calculateRowCount(Table* t){
|
||||
Deferral(4);
|
||||
try(i64 file_size, i, file_getSize(t->table_file));
|
||||
i64 data_size = file_size - sizeof(t->header);
|
||||
i64 row_size_in_file = t->header.encrypted
|
||||
? AESBlockEncryptor_calcDstSize(t->header.row_size)
|
||||
: t->header.row_size;
|
||||
if(data_size % row_size_in_file != 0){
|
||||
//TODO: fix table instead of trowing error
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table '%s' has invalid size. Last row is incomplete.",
|
||||
t->name.data);
|
||||
}
|
||||
|
||||
t->row_count = data_size / row_size_in_file;
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) Table_validateHeader(Table* t){
|
||||
Deferral(4);
|
||||
if(t->header.magic.n != TABLE_FILE_MAGIC.n
|
||||
|| t->header.row_size == 0)
|
||||
{
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table file '%s' has invalid header",
|
||||
t->table_file_path.data);
|
||||
}
|
||||
|
||||
if(t->header.version != IDB_VERSION){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table file '%s' was created for IDBv%u, which is incompatible with v%u",
|
||||
t->table_file_path.data, t->header.version, (u32)IDB_VERSION);
|
||||
}
|
||||
|
||||
try(bool dirty_bit, i, Table_getDirtyBit(t));
|
||||
if(dirty_bit){
|
||||
//TODO: handle dirty bit instead of throwing error
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table file '%s' has dirty bit set",
|
||||
t->table_file_path.data);
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) Table_validateEncryption(Table* t){
|
||||
Deferral(1);
|
||||
|
||||
bool db_encrypted = t->db->aes_key.len != 0;
|
||||
if(t->header.encrypted && !db_encrypted){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table '%s' is encrypted, but encryption key is not set."
|
||||
"Database '%s' is encrypted and must have not-null encryption key.",
|
||||
t->name.data, t->db->db_dir.data);
|
||||
}
|
||||
|
||||
if(!t->header.encrypted && db_encrypted){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table '%s' is not encrypted, but encryption key is set."
|
||||
"Do not set encryption key for not encrypted database '%s'.",
|
||||
t->name.data, t->db->db_dir.data);
|
||||
}
|
||||
|
||||
// validate aes encryption key
|
||||
if(t->header.encrypted){
|
||||
try_void(
|
||||
AESBlockDecryptor_decrypt(
|
||||
&t->dec,
|
||||
Array_u8_construct(t->header.key_challenge, KEY_CHALLENGE_CIPHER_SIZE),
|
||||
t->enc_buf
|
||||
)
|
||||
);
|
||||
if(memcmp(t->enc_buf.data, KEY_CHALLENGE_PLAIN, KEY_CHALLENGE_PLAIN_SIZE) != 0){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Encryption key for table '%s' is wrong",
|
||||
t->name.data);
|
||||
}
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
static Result(void) Table_validateRowSize(Table* t, u32 row_size){
|
||||
if(row_size != t->header.row_size){
|
||||
return RESULT_ERROR_FMT(
|
||||
"Requested row size (%u) doesn't match saved row size (%u)",
|
||||
row_size, t->header.row_size);
|
||||
}
|
||||
|
||||
return RESULT_VOID;
|
||||
}
|
||||
|
||||
|
||||
void idb_close(IncrementalDB* db){
|
||||
if(db == NULL)
|
||||
return;
|
||||
str_destroy(db->db_dir);
|
||||
Array_u8_destroy(&db->aes_key);
|
||||
HashMap_destroy(&db->tables_map);
|
||||
pthread_mutex_destroy(&db->mutex);
|
||||
free(db);
|
||||
}
|
||||
|
||||
Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){
|
||||
Deferral(16);
|
||||
try_assert(db_dir.len > 0);
|
||||
try_assert(aes_key.len == 0 || aes_key.len == 16 || aes_key.len == 24 || aes_key.len == 32);
|
||||
|
||||
IncrementalDB* db = (IncrementalDB*)malloc(sizeof(IncrementalDB));
|
||||
// value of *db must be set to zero or behavior of idb_close will be undefined
|
||||
zeroStruct(db);
|
||||
// if object construction fails, destroy incomplete object
|
||||
bool success = false;
|
||||
Defer(if(!success) idb_close(db));
|
||||
|
||||
if(aes_key.len != 0){
|
||||
db->aes_key = Array_u8_copy(aes_key);
|
||||
}
|
||||
|
||||
db->db_dir = str_copy(db_dir);
|
||||
try_void(dir_create(db->db_dir.data));
|
||||
HashMap_construct(&db->tables_map, Table*, TablePtr_free);
|
||||
try_stderrcode(pthread_mutex_init(&db->mutex, NULL));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, db);
|
||||
}
|
||||
|
||||
void idb_lockDB(IncrementalDB* db){
|
||||
try_fatal_stderrcode(pthread_mutex_lock(&db->mutex));
|
||||
}
|
||||
|
||||
void idb_unlockDB(IncrementalDB* db){
|
||||
try_fatal_stderrcode(pthread_mutex_unlock(&db->mutex));
|
||||
}
|
||||
|
||||
void idb_lockTable(Table* t){
|
||||
try_fatal_stderrcode(pthread_mutex_lock(&t->mutex));
|
||||
}
|
||||
|
||||
void idb_unlockTable(Table* t){
|
||||
try_fatal_stderrcode(pthread_mutex_unlock(&t->mutex));
|
||||
}
|
||||
|
||||
|
||||
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str subdir, str table_name, u32 row_size, bool lock_db){
|
||||
Deferral(16);
|
||||
|
||||
if(lock_db){
|
||||
idb_lockDB(db);
|
||||
Defer(idb_unlockDB(db));
|
||||
}
|
||||
|
||||
Table** tpp = HashMap_tryGetPtr(&db->tables_map, table_name);
|
||||
if(tpp != NULL){
|
||||
Table* existing_table = *tpp;
|
||||
try_void(Table_validateRowSize(existing_table, row_size));
|
||||
Return RESULT_VALUE(p, existing_table);
|
||||
}
|
||||
|
||||
try_void(validateTableName(table_name));
|
||||
|
||||
Table* t = (Table*)malloc(sizeof(Table));
|
||||
// value of *t must be set to zero or behavior of Table_close will be undefined
|
||||
zeroStruct(t);
|
||||
// if object construction fails, destroy incomplete object
|
||||
bool success = false;
|
||||
Defer(if(!success) Table_close(t));
|
||||
|
||||
t->db = db;
|
||||
try_stderrcode(pthread_mutex_init(&t->mutex, NULL));
|
||||
t->name = str_copy(table_name);
|
||||
|
||||
// set file paths
|
||||
StringBuilder sb = StringBuilder_alloc(256);
|
||||
Defer(StringBuilder_destroy(&sb));
|
||||
StringBuilder_append_str(&sb, db->db_dir);
|
||||
StringBuilder_append_char(&sb, path_sep);
|
||||
if(subdir.len != 0){
|
||||
StringBuilder_append_str(&sb, subdir);
|
||||
try_void(dir_create(StringBuilder_getStr(&sb).data));
|
||||
StringBuilder_append_char(&sb, path_sep);
|
||||
}
|
||||
StringBuilder_append_str(&sb, t->name);
|
||||
// table file
|
||||
str table_file_ext = STR(".idb-table");
|
||||
StringBuilder_append_str(&sb, table_file_ext);
|
||||
t->table_file_path = str_copy(StringBuilder_getStr(&sb));
|
||||
// changes file
|
||||
str changes_file_ext = STR(".idb-changes");
|
||||
StringBuilder_removeFromEnd(&sb, table_file_ext.len);
|
||||
StringBuilder_append_str(&sb, changes_file_ext);
|
||||
t->changes_file_path = str_copy(StringBuilder_getStr(&sb));
|
||||
|
||||
bool table_file_exists = file_exists(t->table_file_path.data);
|
||||
|
||||
// open or create file with table data
|
||||
try(t->table_file, p, file_openOrCreateReadWrite(t->table_file_path.data));
|
||||
// open or create file with backups of updated rows
|
||||
try(t->changes_file, p, file_openOrCreateReadWrite(t->changes_file_path.data));
|
||||
|
||||
// init encryptor and decryptor now to use them in table header validation/creation
|
||||
if(db->aes_key.len != 0) {
|
||||
AESBlockEncryptor_construct(&t->enc, db->aes_key, AESBlockEncryptor_DEFAULT_CLASS);
|
||||
AESBlockDecryptor_construct(&t->dec, db->aes_key, AESBlockDecryptor_DEFAULT_CLASS);
|
||||
u32 row_size_in_file = AESBlockEncryptor_calcDstSize(row_size);
|
||||
t->enc_buf = Array_u8_alloc(row_size_in_file);
|
||||
}
|
||||
|
||||
// init header
|
||||
if(table_file_exists){
|
||||
// read table file
|
||||
try_void(Table_readHeader(t));
|
||||
try_void(Table_validateHeader(t));
|
||||
try_void(Table_validateEncryption(t));
|
||||
try_void(Table_validateRowSize(t, row_size));
|
||||
try_void(Table_calculateRowCount(t));
|
||||
}
|
||||
else {
|
||||
// create table file
|
||||
t->header.magic.n = TABLE_FILE_MAGIC.n;
|
||||
t->header.version = IDB_VERSION;
|
||||
t->header.encrypted = db->aes_key.len != 0;
|
||||
t->header._dirty_bit = false;
|
||||
t->header.row_size = row_size;
|
||||
memset(t->header.key_challenge, 0, KEY_CHALLENGE_CIPHER_SIZE);
|
||||
try_void(
|
||||
AESBlockEncryptor_encrypt(
|
||||
&t->enc,
|
||||
Array_u8_construct((void*)KEY_CHALLENGE_PLAIN, KEY_CHALLENGE_PLAIN_SIZE),
|
||||
Array_u8_construct(t->header.key_challenge, KEY_CHALLENGE_CIPHER_SIZE)
|
||||
)
|
||||
);
|
||||
try_void(Table_writeHeader(t));
|
||||
}
|
||||
|
||||
if(!HashMap_tryPush(&db->tables_map, t->name, &t)){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Table '%s' is already open",
|
||||
t->name.data);
|
||||
}
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, t);
|
||||
}
|
||||
|
||||
Result(void) idb_getRows(Table* t, u64 id, void* dst, u64 count, bool lock_table){
|
||||
Deferral(8);
|
||||
|
||||
if(lock_table){
|
||||
idb_lockTable(t);
|
||||
Defer(idb_unlockTable(t));
|
||||
}
|
||||
|
||||
if(id + count > t->row_count){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Can't read "FMT_u64" rows at index "FMT_u64
|
||||
" because table '%s' has only "FMT_u64" rows",
|
||||
count, id, t->name.data, t->row_count);
|
||||
}
|
||||
|
||||
u32 row_size = t->header.row_size;
|
||||
u32 row_size_in_file = t->header.encrypted ? t->enc_buf.len : row_size;
|
||||
i64 file_pos = sizeof(t->header) + id * row_size_in_file;
|
||||
|
||||
// seek for the row position in file
|
||||
try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start));
|
||||
|
||||
// read rows from file
|
||||
for(u64 i = 0; i < count; i++){
|
||||
void* row_ptr = (u8*)dst + row_size * i;
|
||||
void* read_dst = t->header.encrypted
|
||||
? t->enc_buf.data
|
||||
: row_ptr;
|
||||
try_void(file_readStructsExactly(t->table_file, read_dst, row_size_in_file, 1));
|
||||
if(t->header.encrypted) {
|
||||
try_void(
|
||||
AESBlockDecryptor_decrypt(
|
||||
&t->dec,
|
||||
t->enc_buf,
|
||||
Array_u8_construct(row_ptr, row_size)
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count, bool lock_table){
|
||||
Deferral(8);
|
||||
|
||||
if(lock_table){
|
||||
idb_lockTable(t);
|
||||
Defer(idb_unlockTable(t));
|
||||
}
|
||||
|
||||
if(id + count > t->row_count){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"Can't update "FMT_u64" rows at index "FMT_u64
|
||||
" because table '%s' has only "FMT_u64" rows",
|
||||
count, id, t->name.data, t->row_count);
|
||||
}
|
||||
|
||||
try_void(Table_setDirtyBit(t, true));
|
||||
Defer(IGNORE_RESULT Table_setDirtyBit(t, false));
|
||||
|
||||
u32 row_size = t->header.row_size;
|
||||
u32 row_size_in_file = t->header.encrypted ? t->enc_buf.len : row_size;
|
||||
i64 file_pos = sizeof(t->header) + id * row_size_in_file;
|
||||
|
||||
// TODO: set dirty bit in backup file too
|
||||
// TODO: save old values to the backup file
|
||||
|
||||
// seek for the row position in file
|
||||
try_void(file_seek(t->table_file, file_pos, SeekOrigin_Start));
|
||||
|
||||
// replace rows in file
|
||||
for(u64 i = 0; i < count; i++){
|
||||
void* row_ptr = (u8*)src + row_size * i;
|
||||
if(t->header.encrypted){
|
||||
try_void(
|
||||
AESBlockEncryptor_encrypt(
|
||||
&t->enc,
|
||||
Array_u8_construct(row_ptr, row_size),
|
||||
t->enc_buf
|
||||
)
|
||||
);
|
||||
row_ptr = t->enc_buf.data;
|
||||
}
|
||||
try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1));
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Result(u64) idb_pushRows(Table* t, const void* src, u64 count, bool lock_table){
|
||||
Deferral(8);
|
||||
|
||||
if(lock_table){
|
||||
idb_lockTable(t);
|
||||
Defer(idb_unlockTable(t));
|
||||
}
|
||||
|
||||
try_void(Table_setDirtyBit(t, true));
|
||||
Defer(IGNORE_RESULT Table_setDirtyBit(t, false));
|
||||
|
||||
u32 row_size = t->header.row_size;
|
||||
u32 row_size_in_file = t->header.encrypted ? t->enc_buf.len : row_size;
|
||||
const u64 new_row_index = t->row_count;
|
||||
|
||||
// seek for end of the file
|
||||
try_void(file_seek(t->table_file, 0, SeekOrigin_End));
|
||||
|
||||
// write new rows to the file
|
||||
for(u64 i = 0; i < count; i++){
|
||||
void* row_ptr = (u8*)src + row_size * i;
|
||||
if(t->header.encrypted){
|
||||
try_void(
|
||||
AESBlockEncryptor_encrypt(
|
||||
&t->enc,
|
||||
Array_u8_construct(row_ptr, row_size),
|
||||
t->enc_buf
|
||||
)
|
||||
);
|
||||
row_ptr = t->enc_buf.data;
|
||||
}
|
||||
try_void(file_writeStructs(t->table_file, row_ptr, row_size_in_file, 1));
|
||||
t->row_count++;
|
||||
}
|
||||
|
||||
Return RESULT_VALUE(u, new_row_index);
|
||||
}
|
||||
|
||||
Result(u64) idb_getRowCount(Table* t, bool lock_table){
|
||||
Deferral(4);
|
||||
|
||||
if(lock_table){
|
||||
idb_lockTable(t);
|
||||
Defer(idb_unlockTable(t));
|
||||
}
|
||||
|
||||
u64 count = t->row_count;
|
||||
Return RESULT_VALUE(u, count);
|
||||
}
|
||||
|
||||
Result(void) idb_createListFromTable(Table* t, List_* l, bool lock_table){
|
||||
Deferral(1);
|
||||
|
||||
if(lock_table){
|
||||
idb_lockTable(t);
|
||||
Defer(idb_unlockTable(t));
|
||||
}
|
||||
|
||||
u64 row_count = t->row_count;
|
||||
u64 row_size = t->header.row_size;
|
||||
u64 total_size = row_count * row_size;
|
||||
*l = _List_alloc(total_size, row_size);
|
||||
l->len = row_count;
|
||||
bool success = false;
|
||||
Defer(if(!success) _List_destroy(l));
|
||||
|
||||
try_void(idb_getRows(t, 0, l->data, row_count, false));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
47
src/db/idb.h
47
src/db/idb.h
@@ -1,47 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "tlibc/errors.h"
|
||||
#include "tlibc/collections/List.h"
|
||||
|
||||
#define IDB_VERSION 2
|
||||
#define IDB_AES_KEY_SIZE 32
|
||||
|
||||
typedef struct IncrementalDB IncrementalDB;
|
||||
typedef struct Table Table;
|
||||
|
||||
|
||||
Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key));
|
||||
void idb_close(IncrementalDB* db);
|
||||
|
||||
/// before performing atransaction on DB lock it manually or call functions with parameter lock_db=true
|
||||
void idb_lockDB(IncrementalDB* db);
|
||||
|
||||
/// USAGE:
|
||||
/// idb_lockDB(db);
|
||||
/// Defer(idb_unlockDB(db));
|
||||
void idb_unlockDB(IncrementalDB* db);
|
||||
|
||||
/// before performing a transaction on Table lock it manually or call function with parameter lock_db=true
|
||||
void idb_lockTable(Table* t);
|
||||
|
||||
/// USAGE:
|
||||
/// idb_lockTable(t);
|
||||
/// Defer(idb_unlockTable(t));
|
||||
void idb_unlockTable(Table* t);
|
||||
|
||||
|
||||
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, NULLABLE(str) subdir, str table_name, u32 row_size, bool lock_db);
|
||||
|
||||
Result(void) idb_getRows(Table* t, u64 start_from_id, void* dst, u64 count, bool lock_table);
|
||||
#define idb_getRow(T, ID, DST, LOCK) idb_getRows(T, ID, DST, 1, LOCK)
|
||||
|
||||
Result(u64) idb_pushRows(Table* t, const void* src, u64 count, bool lock_table);
|
||||
#define idb_pushRow(T, SRC, LOCK) idb_pushRows(T, SRC, 1, LOCK)
|
||||
|
||||
Result(void) idb_updateRows(Table* t, u64 start_from_id, const void* src, u64 count, bool lock_table);
|
||||
#define idb_updateRow(T, ID, SRC, LOCK) idb_updateRows(T, ID, SRC, 1, LOCK)
|
||||
|
||||
Result(u64) idb_getRowCount(Table* t, bool lock_table);
|
||||
|
||||
/// construct new list and load whole table into it
|
||||
Result(void) idb_createListFromTable(Table* t, List_* l, bool lock_table);
|
||||
@@ -1,63 +0,0 @@
|
||||
#pragma once
|
||||
#include "tcp-chat/common_constants.h"
|
||||
#include "tlibc/time.h"
|
||||
#include "tlibc/magic.h"
|
||||
|
||||
// TODO: add table versions
|
||||
|
||||
typedef struct UserInfo {
|
||||
u16 name_len;
|
||||
char name[USERNAME_SIZE_MAX + 1]; // null-terminated
|
||||
u8 token[PASSWORD_HASH_SIZE]; // token is hashed again on server side
|
||||
DateTime registration_time_utc;
|
||||
} ATTRIBUTE_ALIGNED(256) UserInfo;
|
||||
|
||||
|
||||
typedef struct ServerInfo {
|
||||
u16 address_len;
|
||||
char address[HOSTADDR_SIZE_MAX + 1];
|
||||
u32 pk_base64_len;
|
||||
char pk_base64[PUBLIC_KEY_BASE64_SIZE_MAX + 1];
|
||||
u16 name_len;
|
||||
char name[SERVER_NAME_SIZE_MAX + 1];
|
||||
u16 desc_len;
|
||||
char desc[SERVER_DESC_SIZE_MAX + 1];
|
||||
} ATTRIBUTE_ALIGNED(16*1024) ServerInfo;
|
||||
|
||||
|
||||
typedef struct ChannelInfo {
|
||||
u16 name_len;
|
||||
char name[CHANNEL_NAME_SIZE_MAX + 1];
|
||||
u16 desc_len;
|
||||
char desc[CHANNEL_DESC_SIZE_MAX + 1];
|
||||
} ATTRIBUTE_ALIGNED(4*1024) ChannelInfo;
|
||||
|
||||
|
||||
// not a table
|
||||
typedef struct MessageMeta {
|
||||
/*
|
||||
In block messages can be stored with some padding (zero bytes) between them.
|
||||
To distinguish message from padding, each message starts with MESSAGE_MAGIC.
|
||||
*/
|
||||
Magic32 magic;
|
||||
u16 data_size;
|
||||
u64 id;
|
||||
u64 sender_id;
|
||||
DateTime receiving_time_utc;
|
||||
} ATTRIBUTE_ALIGNED(64) MessageMeta;
|
||||
|
||||
#define MESSAGE_MAGIC ((Magic32){ .bytes = { 'M', 's', 'g', 'S' } })
|
||||
|
||||
// Stores some number of messages. Look in MessageBlockMeta to see how much.
|
||||
typedef struct MessageBlock {
|
||||
/* ((sequence MessageMeta), (sequence binary-data)) */
|
||||
u8 data[MESSAGE_BLOCK_SIZE];
|
||||
} ATTRIBUTE_ALIGNED(64) MessageBlock;
|
||||
|
||||
// is used to find in which MessageBlock a message is stored
|
||||
typedef struct MessageBlockMeta {
|
||||
u64 first_message_id;
|
||||
u32 messages_count;
|
||||
u32 data_size;
|
||||
} ATTRIBUTE_ALIGNED(16) MessageBlockMeta;
|
||||
|
||||
@@ -98,7 +98,7 @@ Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
|
||||
}
|
||||
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.len))
|
||||
{
|
||||
return RESULT_ERROR("Socket closed", false);
|
||||
return RESULT_ERROR_LITERAL("Socket closed");
|
||||
}
|
||||
return RESULT_VALUE(i, r);
|
||||
}
|
||||
@@ -113,7 +113,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
|
||||
}
|
||||
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.len))
|
||||
{
|
||||
return RESULT_ERROR("Socket closed", false);
|
||||
return RESULT_ERROR_LITERAL("Socket closed");
|
||||
}
|
||||
|
||||
//TODO: add IPV6 support (struct sockaddr_in6)
|
||||
@@ -132,12 +132,12 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
|
||||
}
|
||||
|
||||
Result(void) socket_TCP_enableAliveChecks(Socket s,
|
||||
sec_t first_check_time, u32 checks_count, sec_t checks_interval)
|
||||
sec_t first_check_time, u32 check_count, sec_t checks_interval)
|
||||
{
|
||||
#if KN_USE_WINSOCK
|
||||
BOOL opt_SO_KEEPALIVE = 1; // enable keepalives
|
||||
DWORD opt_TCP_KEEPIDLE = first_check_time;
|
||||
DWORD opt_TCP_KEEPCNT = checks_count;
|
||||
DWORD opt_TCP_KEEPCNT = check_count;
|
||||
DWORD opt_TCP_KEEPINTVL = checks_interval;
|
||||
try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE);
|
||||
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE);
|
||||
@@ -145,12 +145,12 @@ Result(void) socket_TCP_enableAliveChecks(Socket s,
|
||||
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL);
|
||||
|
||||
// timeout for connect()
|
||||
DWORD opt_TCP_MAXRT = checks_count * checks_interval;
|
||||
DWORD opt_TCP_MAXRT = check_count * checks_interval;
|
||||
try_setsockopt(s, IPPROTO_TCP, TCP_MAXRT);
|
||||
#else
|
||||
int opt_SO_KEEPALIVE = 1; // enable keepalives
|
||||
int opt_TCP_KEEPIDLE = first_check_time;
|
||||
int opt_TCP_KEEPCNT = checks_count;
|
||||
int opt_TCP_KEEPCNT = check_count;
|
||||
int opt_TCP_KEEPINTVL = checks_interval;
|
||||
try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE);
|
||||
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE);
|
||||
@@ -158,7 +158,7 @@ Result(void) socket_TCP_enableAliveChecks(Socket s,
|
||||
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL);
|
||||
|
||||
// read more in the article
|
||||
int opt_TCP_USER_TIMEOUT = checks_count * checks_interval * 1000;
|
||||
int opt_TCP_USER_TIMEOUT = check_count * checks_interval * 1000;
|
||||
try_setsockopt(s, IPPROTO_TCP, TCP_USER_TIMEOUT);
|
||||
#endif
|
||||
return RESULT_VOID;
|
||||
|
||||
@@ -38,7 +38,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags
|
||||
/// Read more: https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/
|
||||
/// RU translaton: https://habr.com/ru/articles/700470/
|
||||
Result(void) socket_TCP_enableAliveChecks(Socket s,
|
||||
sec_t first_check_time, u32 checks_count, sec_t checks_interval);
|
||||
sec_t first_check_time, u32 check_count, sec_t checks_interval);
|
||||
#define socket_TCP_enableAliveChecks_default(socket) \
|
||||
socket_TCP_enableAliveChecks(socket, 1, 4, 5)
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ const Magic64 PacketHeader_MAGIC = { .bytes = { 't', 'c', 'p', '-', 'c', 'h', 'a
|
||||
|
||||
Result(void) PacketHeader_validateMagic(PacketHeader* ptr){
|
||||
if (ptr->magic.n != PacketHeader_MAGIC.n){
|
||||
return RESULT_ERROR("invalid packet magic", false);
|
||||
return RESULT_ERROR_LITERAL("invalid packet magic");
|
||||
}
|
||||
return RESULT_VOID;
|
||||
}
|
||||
|
||||
@@ -63,7 +63,7 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he
|
||||
}
|
||||
|
||||
void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header,
|
||||
u64 session_id)
|
||||
i64 session_id)
|
||||
{
|
||||
_PacketHeader_construct(ServerHandshake);
|
||||
zeroStruct(ptr);
|
||||
@@ -95,7 +95,7 @@ Result(void) LoginRequest_tryConstruct(LoginRequest *ptr, PacketHeader* header,
|
||||
|
||||
str name_error_str = validateUsername_str(username);
|
||||
if(name_error_str.data){
|
||||
Return RESULT_ERROR(name_error_str.data, true);
|
||||
Return RESULT_ERROR(name_error_str, true);
|
||||
}
|
||||
memcpy(ptr->username, username.data, username.len);
|
||||
|
||||
@@ -106,7 +106,7 @@ Result(void) LoginRequest_tryConstruct(LoginRequest *ptr, PacketHeader* header,
|
||||
}
|
||||
|
||||
void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header,
|
||||
u64 user_id, u64 landing_channel_id)
|
||||
i64 user_id, i64 landing_channel_id)
|
||||
{
|
||||
_PacketHeader_construct(LoginResponse);
|
||||
zeroStruct(ptr);
|
||||
@@ -124,7 +124,7 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest *ptr, PacketHeader* he
|
||||
|
||||
str name_error_str = validateUsername_str(username);
|
||||
if(name_error_str.data){
|
||||
Return RESULT_ERROR(name_error_str.data, true);
|
||||
Return RESULT_ERROR(name_error_str, true);
|
||||
}
|
||||
memcpy(ptr->username, username.data, username.len);
|
||||
|
||||
@@ -135,15 +135,73 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest *ptr, PacketHeader* he
|
||||
}
|
||||
|
||||
void RegisterResponse_construct(RegisterResponse *ptr, PacketHeader* header,
|
||||
u64 user_id)
|
||||
i64 user_id)
|
||||
{
|
||||
_PacketHeader_construct(RegisterResponse);
|
||||
zeroStruct(ptr);
|
||||
ptr->user_id = user_id;
|
||||
}
|
||||
|
||||
Result(u32) MessageBlock_writeMessage(
|
||||
MessageMeta* msg, Array(u8) msg_content,
|
||||
MessageBlockMeta* block_meta, Array(u8)* block_free_part)
|
||||
{
|
||||
Deferral(1);
|
||||
try_assert(msg->data_size >= MESSAGE_SIZE_MIN && msg->data_size <= MESSAGE_SIZE_MAX);
|
||||
try_assert(msg->data_size <= msg_content.len);
|
||||
|
||||
u32 offset_increment = sizeof(MessageMeta) + msg->data_size;
|
||||
if(block_free_part->len < offset_increment){
|
||||
Return RESULT_VALUE(u, 0);
|
||||
}
|
||||
|
||||
memcpy(block_free_part->data, msg, sizeof(MessageMeta));
|
||||
block_free_part->data += sizeof(MessageMeta);
|
||||
block_free_part->len -= sizeof(MessageMeta);
|
||||
|
||||
memcpy(block_free_part->data, msg_content.data, msg->data_size);
|
||||
block_free_part->data += msg->data_size;
|
||||
block_free_part->len -= msg->data_size;
|
||||
|
||||
if(block_meta->message_count == 0)
|
||||
block_meta->first_message_id = msg->id;
|
||||
block_meta->message_count++;
|
||||
block_meta->data_size += offset_increment;
|
||||
|
||||
Return RESULT_VALUE(u, offset_increment);
|
||||
}
|
||||
|
||||
Result(u32) MessageBlock_readMessage(
|
||||
Array(u8)* block_unread_part,
|
||||
MessageMeta* msg, Array(u8) msg_content)
|
||||
{
|
||||
Deferral(1);
|
||||
try_assert(block_unread_part->len >= sizeof(MessageMeta) + MESSAGE_SIZE_MIN);
|
||||
try_assert(msg_content.len >= MESSAGE_SIZE_MIN && msg_content.len <= MESSAGE_SIZE_MAX);
|
||||
|
||||
memcpy(msg, block_unread_part->data, sizeof(MessageMeta));
|
||||
block_unread_part->data += sizeof(MessageMeta);
|
||||
block_unread_part->len -= sizeof(MessageMeta);
|
||||
|
||||
if(msg->magic.n != MESSAGE_MAGIC.n){
|
||||
Return RESULT_VALUE(u, 0);
|
||||
}
|
||||
try_assert(block_unread_part->len >= msg->data_size);
|
||||
try_assert(msg->data_size >= MESSAGE_SIZE_MIN && msg->data_size <= MESSAGE_SIZE_MAX);
|
||||
try_assert(msg->id > 0);
|
||||
try_assert(msg->sender_id > 0);
|
||||
try_assert(msg->timestamp.d.year > 2024);
|
||||
|
||||
memcpy(msg_content.data, block_unread_part->data, msg->data_size);
|
||||
block_unread_part->data += msg->data_size;
|
||||
block_unread_part->len -= msg->data_size;
|
||||
|
||||
u32 offset_increment = sizeof(MessageMeta) + msg->data_size;
|
||||
Return RESULT_VALUE(u, offset_increment);
|
||||
}
|
||||
|
||||
void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header,
|
||||
u64 channel_id, u16 data_size)
|
||||
i64 channel_id, u16 data_size)
|
||||
{
|
||||
_PacketHeader_construct(SendMessageRequest);
|
||||
zeroStruct(ptr);
|
||||
@@ -152,30 +210,28 @@ void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header,
|
||||
}
|
||||
|
||||
void SendMessageResponse_construct(SendMessageResponse *ptr, PacketHeader *header,
|
||||
u64 message_id, DateTime receiving_time_utc)
|
||||
i64 message_id, DateTime timestamp)
|
||||
{
|
||||
_PacketHeader_construct(SendMessageResponse);
|
||||
zeroStruct(ptr);
|
||||
ptr->message_id = message_id;
|
||||
ptr->receiving_time_utc = receiving_time_utc;
|
||||
ptr->timestamp = timestamp;
|
||||
}
|
||||
|
||||
void GetMessageBlockRequest_construct(GetMessageBlockRequest *ptr, PacketHeader *header,
|
||||
u64 channel_id, u64 first_message_id, u32 messages_count)
|
||||
i64 channel_id, i64 first_message_id, u32 message_count)
|
||||
{
|
||||
_PacketHeader_construct(GetMessageBlockRequest);
|
||||
zeroStruct(ptr);
|
||||
ptr->channel_id = channel_id;
|
||||
ptr->first_message_id = first_message_id;
|
||||
ptr->messages_count = messages_count;
|
||||
ptr->message_count = message_count;
|
||||
}
|
||||
|
||||
void GetMessageBlockResponse_construct(GetMessageBlockResponse *ptr, PacketHeader *header,
|
||||
u64 first_message_id, u32 messages_count, u32 data_size)
|
||||
MessageBlockMeta* block_meta)
|
||||
{
|
||||
_PacketHeader_construct(GetMessageBlockResponse);
|
||||
zeroStruct(ptr);
|
||||
ptr->first_message_id = first_message_id;
|
||||
ptr->messages_count = messages_count;
|
||||
ptr->data_size = data_size;
|
||||
memcpy(&ptr->block_meta, block_meta, sizeof(MessageBlockMeta));
|
||||
}
|
||||
|
||||
@@ -60,11 +60,11 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he
|
||||
|
||||
|
||||
typedef struct ServerHandshake {
|
||||
u64 session_id;
|
||||
i64 session_id;
|
||||
} ALIGN_PACKET_STRUCT ServerHandshake;
|
||||
|
||||
void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header,
|
||||
u64 session_id);
|
||||
i64 session_id);
|
||||
|
||||
|
||||
typedef enum ServerPublicInfo {
|
||||
@@ -99,12 +99,12 @@ Result(void) LoginRequest_tryConstruct(LoginRequest* ptr, PacketHeader* header,
|
||||
|
||||
|
||||
typedef struct LoginResponse {
|
||||
u64 user_id;
|
||||
u64 landing_channel_id;
|
||||
i64 user_id;
|
||||
i64 landing_channel_id;
|
||||
} ALIGN_PACKET_STRUCT LoginResponse;
|
||||
|
||||
void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header,
|
||||
u64 user_id, u64 landing_channel_id);
|
||||
i64 user_id, i64 landing_channel_id);
|
||||
|
||||
|
||||
typedef struct RegisterRequest {
|
||||
@@ -117,49 +117,80 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest* ptr, PacketHeader* he
|
||||
|
||||
|
||||
typedef struct RegisterResponse {
|
||||
u64 user_id;
|
||||
i64 user_id;
|
||||
} ALIGN_PACKET_STRUCT RegisterResponse;
|
||||
|
||||
void RegisterResponse_construct(RegisterResponse* ptr, PacketHeader* header,
|
||||
u64 user_id);
|
||||
i64 user_id);
|
||||
|
||||
|
||||
typedef struct MessageMeta {
|
||||
Magic32 magic;
|
||||
u16 data_size;
|
||||
i64 id;
|
||||
i64 sender_id;
|
||||
DateTime timestamp; /* UTC */
|
||||
} ALIGN_PACKET_STRUCT MessageMeta;
|
||||
|
||||
#define MESSAGE_MAGIC ((Magic32){ .bytes = { 'M', 's', 'g', 'S' } })
|
||||
|
||||
typedef struct MessageBlockMeta {
|
||||
i64 first_message_id;
|
||||
u32 message_count;
|
||||
u32 data_size;
|
||||
} ALIGN_PACKET_STRUCT MessageBlockMeta;
|
||||
|
||||
/// @brief write msg_meta and msg_meta->data_size bytes from msg_content to buffer
|
||||
/// @param block_meta set to {0} if block is empty yet
|
||||
/// @param block_free_part .data and .len are adjusted to point to free part
|
||||
/// @return amount of bytes written to block, may be 0 if msg_meta and msg_content don't fit
|
||||
Result(u32) MessageBlock_writeMessage(
|
||||
MessageMeta* msg_meta, Array(u8) msg_content,
|
||||
MessageBlockMeta* block_meta, Array(u8)* block_free_part);
|
||||
|
||||
/// @brief read message meta and content from buffer
|
||||
/// @param block_unread_part .data and .len are adjusted to point to unread part
|
||||
/// @param msg_content .len must be >= MESSAGE_SIZE_MAX
|
||||
/// @return amount of bytes read from block, may be 0 if it doesn't start with MESSAGE_MAGIC
|
||||
Result(u32) MessageBlock_readMessage(
|
||||
Array(u8)* block_unread_part,
|
||||
MessageMeta* msg_meta, Array(u8) msg_content);
|
||||
|
||||
|
||||
typedef struct SendMessageRequest {
|
||||
u64 channel_id;
|
||||
i64 channel_id;
|
||||
u16 data_size;
|
||||
/* stream of size data_size */
|
||||
} ALIGN_PACKET_STRUCT SendMessageRequest;
|
||||
|
||||
void SendMessageRequest_construct(SendMessageRequest* ptr, PacketHeader* header,
|
||||
u64 channel_id, u16 data_size);
|
||||
i64 channel_id, u16 data_size);
|
||||
|
||||
|
||||
typedef struct SendMessageResponse {
|
||||
u64 message_id;
|
||||
DateTime receiving_time_utc;
|
||||
i64 message_id;
|
||||
DateTime timestamp; /* UTC */
|
||||
} ALIGN_PACKET_STRUCT SendMessageResponse;
|
||||
|
||||
void SendMessageResponse_construct(SendMessageResponse* ptr, PacketHeader* header,
|
||||
u64 message_id, DateTime receiving_time_utc);
|
||||
i64 message_id, DateTime timestamp);
|
||||
|
||||
|
||||
typedef struct GetMessageBlockRequest {
|
||||
u64 channel_id;
|
||||
u64 first_message_id;
|
||||
u32 messages_count;
|
||||
i64 channel_id;
|
||||
i64 first_message_id;
|
||||
u32 message_count;
|
||||
} ALIGN_PACKET_STRUCT GetMessageBlockRequest;
|
||||
|
||||
void GetMessageBlockRequest_construct(GetMessageBlockRequest* ptr, PacketHeader* header,
|
||||
u64 channel_id, u64 first_message_id, u32 messages_count);
|
||||
i64 channel_id, i64 first_message_id, u32 message_count);
|
||||
|
||||
|
||||
typedef struct GetMessageBlockResponse {
|
||||
u64 first_message_id;
|
||||
u32 messages_count;
|
||||
u32 data_size;
|
||||
/* stream of size data_size : ((sequence MessageMeta), (sequence binary-data)) */
|
||||
MessageBlockMeta block_meta;
|
||||
/* stream of size data_size : sequence (MessageMeta, byte[MessageMeta.data_size]) */
|
||||
} ALIGN_PACKET_STRUCT GetMessageBlockResponse;
|
||||
|
||||
void GetMessageBlockResponse_construct(GetMessageBlockResponse* ptr, PacketHeader* header,
|
||||
u64 first_message_id, u32 messages_count, u32 data_size);
|
||||
MessageBlockMeta* block_meta);
|
||||
|
||||
|
||||
@@ -1,186 +0,0 @@
|
||||
#include "server_internal.h"
|
||||
#include "tlibc/string/StringBuilder.h"
|
||||
#include "tlibc/filesystem.h"
|
||||
|
||||
void Channel_free(Channel* self){
|
||||
if(!self)
|
||||
return;
|
||||
str_destroy(self->name);
|
||||
str_destroy(self->description);
|
||||
List_MessageBlockMeta_destroy(&self->messages.blocks_meta_list);
|
||||
LList_MessageBlock_destroy(&self->messages.blocks_queue);
|
||||
free(self);
|
||||
}
|
||||
|
||||
Result(Channel*) Channel_create(u64 chan_id, str name, str description,
|
||||
IncrementalDB* db, bool lock_db)
|
||||
{
|
||||
Deferral(8);
|
||||
|
||||
Channel* self = (Channel*)malloc(sizeof(Channel));
|
||||
zeroStruct(self);
|
||||
bool success = false;
|
||||
Defer(if(!success) Channel_free(self));
|
||||
|
||||
self->id = chan_id;
|
||||
try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX);
|
||||
self->name = str_copy(name);
|
||||
try_assert(description.len <= CHANNEL_DESC_SIZE_MAX);
|
||||
self->description = str_copy(description);
|
||||
|
||||
if(lock_db){
|
||||
idb_lockDB(db);
|
||||
Defer(idb_unlockDB(db));
|
||||
}
|
||||
|
||||
StringBuilder sb = StringBuilder_alloc(CHANNEL_NAME_SIZE_MAX + 32 + 1);
|
||||
Defer(StringBuilder_destroy(&sb));
|
||||
StringBuilder_append_str(&sb, STR("channels"));
|
||||
StringBuilder_append_char(&sb, path_sep);
|
||||
StringBuilder_append_str(&sb, name);
|
||||
|
||||
str subdir = str_copy(StringBuilder_getStr(&sb));
|
||||
Defer(str_destroy(subdir));
|
||||
str message_blocks_str = STR("message_blocks");
|
||||
str message_blocks_meta_str = STR("message_blocks_meta");
|
||||
|
||||
StringBuilder_removeFromEnd(&sb, -1);
|
||||
StringBuilder_append_str(&sb, name);
|
||||
StringBuilder_append_char(&sb, '_');
|
||||
StringBuilder_append_str(&sb, message_blocks_str);
|
||||
try(self->messages.blocks_table, p,
|
||||
idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false)
|
||||
);
|
||||
idb_lockTable(self->messages.blocks_table);
|
||||
Defer(idb_unlockTable(self->messages.blocks_table));
|
||||
|
||||
|
||||
StringBuilder_removeFromEnd(&sb, message_blocks_str.len);
|
||||
StringBuilder_append_str(&sb, message_blocks_meta_str);
|
||||
try(self->messages.blocks_meta_table, p,
|
||||
idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlockMeta), false)
|
||||
);
|
||||
idb_lockTable(self->messages.blocks_meta_table);
|
||||
Defer(idb_unlockTable(self->messages.blocks_meta_table));
|
||||
|
||||
// load whole message_blocks_meta table to list
|
||||
try_void(
|
||||
idb_createListFromTable(self->messages.blocks_meta_table, (void*)&self->messages.blocks_meta_list, false)
|
||||
);
|
||||
|
||||
// load N last blocks to the queue
|
||||
self->messages.blocks_queue = LList_construct(MessageBlock, NULL);
|
||||
u64 message_blocks_count = self->messages.blocks_meta_list.len;
|
||||
u64 first_block_id = 0;
|
||||
if(message_blocks_count > MESSAGE_BLOCKS_CACHE_COUNT)
|
||||
first_block_id = message_blocks_count - MESSAGE_BLOCKS_CACHE_COUNT;
|
||||
for(u64 id = first_block_id; id < message_blocks_count; id++){
|
||||
LLNode(MessageBlock)* node = LLNode_MessageBlock_createZero();
|
||||
LList_MessageBlock_insertAfter(
|
||||
&self->messages.blocks_queue,
|
||||
self->messages.blocks_queue.last,
|
||||
node
|
||||
);
|
||||
try_void(idb_getRow(self->messages.blocks_table, id, node->value.data, false));
|
||||
}
|
||||
|
||||
if(self->messages.blocks_meta_list.len > 0){
|
||||
MessageBlockMeta last_block_meta = self->messages.blocks_meta_list.data[self->messages.blocks_meta_list.len - 1];
|
||||
self->messages.count = last_block_meta.first_message_id + last_block_meta.messages_count - 1;
|
||||
}
|
||||
else {
|
||||
self->messages.count = 0;
|
||||
}
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, self);
|
||||
}
|
||||
|
||||
void Channel_unloadExcessBlocks(Channel* self){
|
||||
while(self->messages.blocks_queue.count > MESSAGE_BLOCKS_CACHE_COUNT){
|
||||
LLNode(MessageBlock)* node = self->messages.blocks_queue.first;
|
||||
LList_MessageBlock_detatch(&self->messages.blocks_queue, node);
|
||||
free(node);
|
||||
}
|
||||
}
|
||||
|
||||
Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id,
|
||||
MessageMeta* out_message_meta, bool lock_tables)
|
||||
{
|
||||
Deferral(4);
|
||||
|
||||
if(lock_tables){
|
||||
idb_lockTable(self->messages.blocks_table);
|
||||
idb_lockTable(self->messages.blocks_meta_table);
|
||||
Defer(
|
||||
idb_unlockTable(self->messages.blocks_table);
|
||||
idb_unlockTable(self->messages.blocks_meta_table);
|
||||
);
|
||||
}
|
||||
|
||||
// create new block if message won't fit in the last existing
|
||||
MessageBlockMeta* incomplete_block_meta = self->messages.blocks_meta_list.data + self->messages.blocks_meta_list.len;
|
||||
u64 new_message_id = incomplete_block_meta->first_message_id + incomplete_block_meta->messages_count;
|
||||
u32 message_size_in_block = sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8);
|
||||
if(incomplete_block_meta->data_size + message_size_in_block > MESSAGE_BLOCK_SIZE){
|
||||
// create new MessageBlockMeta
|
||||
incomplete_block_meta = List_MessageBlockMeta_expand(&self->messages.blocks_meta_list, 1);
|
||||
incomplete_block_meta->first_message_id = new_message_id;
|
||||
incomplete_block_meta->messages_count = 0;
|
||||
incomplete_block_meta->data_size = 0;
|
||||
// create new MessageBlock
|
||||
LList_MessageBlock_insertAfter(
|
||||
&self->messages.blocks_queue,
|
||||
self->messages.blocks_queue.last,
|
||||
LLNode_MessageBlock_createZero());
|
||||
// unload old blocks from cache
|
||||
Channel_unloadExcessBlocks(self);
|
||||
}
|
||||
|
||||
// create message meta
|
||||
out_message_meta->magic = MESSAGE_MAGIC;
|
||||
out_message_meta->data_size = message_data.len;
|
||||
out_message_meta->id = new_message_id;
|
||||
out_message_meta->sender_id = sender_id;
|
||||
DateTime_getUTC(&out_message_meta->receiving_time_utc);
|
||||
|
||||
// copy message data to message block
|
||||
MessageBlock* incomplete_block = &self->messages.blocks_queue.last->value;
|
||||
u8* data_ptr = incomplete_block->data + incomplete_block_meta->data_size;
|
||||
memcpy(data_ptr, out_message_meta, sizeof(MessageMeta));
|
||||
data_ptr += sizeof(MessageMeta);
|
||||
memcpy(data_ptr, message_data.data, message_data.len);
|
||||
incomplete_block_meta->data_size += sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8);
|
||||
incomplete_block_meta->messages_count++;
|
||||
|
||||
// save to DB
|
||||
try_void(idb_pushRow(self->messages.blocks_meta_table, incomplete_block_meta, false));
|
||||
try_void(idb_pushRow(self->messages.blocks_table, incomplete_block, false));
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count,
|
||||
MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables)
|
||||
{
|
||||
Deferral(4);
|
||||
|
||||
if(lock_tables){
|
||||
idb_lockTable(self->messages.blocks_table);
|
||||
idb_lockTable(self->messages.blocks_meta_table);
|
||||
Defer(
|
||||
idb_unlockTable(self->messages.blocks_table);
|
||||
idb_unlockTable(self->messages.blocks_meta_table);
|
||||
);
|
||||
}
|
||||
|
||||
// TODO: Maybe it's better to request message block id directly? Client doesn't know how much bytes `count` messages will take, this can lead to severe lags on slow internet
|
||||
|
||||
// TODO: binary search in list of blocks meta
|
||||
// TODO: return if out_block == NULL
|
||||
// TODO: check if block is in N_LAST_BLOCKS
|
||||
// TODO: load block
|
||||
// TODO: insert block in queue and keep it sorted
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
@@ -1,11 +1,14 @@
|
||||
#include "server/server_internal.h"
|
||||
#include "network/tcp-chat-protocol/v1.h"
|
||||
|
||||
void ClientConnection_close(ClientConnection* conn){
|
||||
if(!conn)
|
||||
return;
|
||||
EncryptedSocketTCP_destroy(&conn->sock);
|
||||
Array_u8_destroy(&conn->session_key);
|
||||
Array_u8_destroy(&conn->message_block);
|
||||
Array_u8_destroy(&conn->message_content);
|
||||
ServerQueries_free(conn->queries);
|
||||
tsqlite_connection_close(conn->db);
|
||||
free(conn);
|
||||
}
|
||||
|
||||
@@ -21,10 +24,17 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args)
|
||||
conn->server = args->server;
|
||||
conn->client_end = args->client_end;
|
||||
conn->session_id = args->session_id;
|
||||
conn->authorized = false;
|
||||
conn->user_id = -1;
|
||||
conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE);
|
||||
|
||||
// buffers
|
||||
conn->message_block = Array_u8_alloc(MESSAGE_BLOCK_COUNT_MAX * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX));
|
||||
conn->message_content = Array_u8_alloc(MESSAGE_SIZE_MAX);
|
||||
|
||||
// database
|
||||
try(conn->db, p, tsqlite_connection_open(args->server->db_path));
|
||||
try(conn->queries, p, ServerQueries_compile(conn->db));
|
||||
|
||||
// correct session key will be received from client later
|
||||
conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE);
|
||||
Array_u8_memset(&conn->session_key, 0);
|
||||
EncryptedSocketTCP_construct(&conn->sock, args->accepted_socket_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
|
||||
try_void(socket_TCP_enableAliveChecks_default(args->accepted_socket_tcp));
|
||||
@@ -57,4 +67,4 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args)
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, conn);
|
||||
}
|
||||
}
|
||||
|
||||
105
src/server/db/Channel.c
Normal file
105
src/server/db/Channel.c
Normal file
@@ -0,0 +1,105 @@
|
||||
#include "server_db_internal.h"
|
||||
|
||||
Result(bool) Channel_exists(ServerQueries* q, i64 id){
|
||||
Deferral(4);
|
||||
|
||||
tsqlite_statement* st = q->channels.exists;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_i64(st, "$id", id));
|
||||
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
|
||||
Return RESULT_VALUE(i, has_result);
|
||||
}
|
||||
|
||||
Result(bool) Channel_createOrUpdate(ServerQueries* q,
|
||||
i64 id, str name, str description)
|
||||
{
|
||||
Deferral(4);
|
||||
try_assert(id > 0);
|
||||
try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX);
|
||||
try_assert(description.len <= CHANNEL_DESC_SIZE_MAX);
|
||||
|
||||
try(bool channel_exists, i, Channel_exists(q, id));
|
||||
tsqlite_statement* st = NULL;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
if(channel_exists){
|
||||
st = q->channels.update;
|
||||
}
|
||||
else {
|
||||
st = q->channels.insert;
|
||||
}
|
||||
try_void(tsqlite_statement_bind_i64(st, "$id", id));
|
||||
try_void(tsqlite_statement_bind_str(st, "$name", name, NULL));
|
||||
try_void(tsqlite_statement_bind_str(st, "$description", description, NULL));
|
||||
try_void(tsqlite_statement_step(st));
|
||||
|
||||
Return RESULT_VALUE(i, !channel_exists);
|
||||
}
|
||||
|
||||
Result(void) Channel_saveMessage(ServerQueries* q,
|
||||
i64 channel_id, i64 sender_id, Array(u8) content,
|
||||
DateTime* out_timestamp)
|
||||
{
|
||||
Deferral(4);
|
||||
try_assert(content.len >= MESSAGE_SIZE_MIN && content.len <= MESSAGE_SIZE_MAX);
|
||||
|
||||
tsqlite_statement* st = q->messages.insert;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_i64(st, "$channel_id", channel_id));
|
||||
try_void(tsqlite_statement_bind_i64(st, "$sender_id", sender_id));
|
||||
try_void(tsqlite_statement_bind_blob(st, "$content", content, NULL));
|
||||
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
try_assert(has_result);
|
||||
|
||||
try(i64 message_id, i, tsqlite_statement_getResult_i64(st));
|
||||
str timestamp_str;
|
||||
try_void(tsqlite_statement_getResult_str(st, ×tamp_str));
|
||||
try_void(DateTime_parse(timestamp_str.data, out_timestamp));
|
||||
|
||||
Return RESULT_VALUE(i, message_id);
|
||||
}
|
||||
|
||||
Result(void) Channel_loadMessageBlock(ServerQueries* q,
|
||||
i64 channel_id, i64 first_message_id, u32 count,
|
||||
MessageBlockMeta* block_meta, Array(u8) block_data)
|
||||
{
|
||||
Deferral(4);
|
||||
try_assert(channel_id > 0);
|
||||
try_assert(block_data.len >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX));
|
||||
if(count == 0){
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
tsqlite_statement* st = q->messages.get_block;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_i64(st, "$channel_id", channel_id));
|
||||
try_void(tsqlite_statement_bind_i64(st, "$first_message_id", first_message_id));
|
||||
try_void(tsqlite_statement_bind_i64(st, "$count", count));
|
||||
|
||||
zeroStruct(block_meta);
|
||||
MessageMeta msg_meta = {0};
|
||||
Array(u8) msg_content;
|
||||
str tmp_str = str_null;
|
||||
while(true){
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
if(!has_result)
|
||||
break;
|
||||
|
||||
// id
|
||||
try(msg_meta.id, i, tsqlite_statement_getResult_i64(st));
|
||||
// sender_id
|
||||
try(msg_meta.sender_id, i, tsqlite_statement_getResult_i64(st));
|
||||
// content
|
||||
try_void(tsqlite_statement_getResult_blob(st, &msg_content));
|
||||
// timestamp
|
||||
try_void(tsqlite_statement_getResult_str(st, &tmp_str));
|
||||
try_void(DateTime_parse(tmp_str.data, &msg_meta.timestamp));
|
||||
|
||||
try(u32 write_n, u, MessageBlock_writeMessage(&msg_meta, msg_content, block_meta, &block_data));
|
||||
try_assert(write_n > 0);
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
50
src/server/db/User.c
Normal file
50
src/server/db/User.c
Normal file
@@ -0,0 +1,50 @@
|
||||
#include "server_db_internal.h"
|
||||
|
||||
Result(i64) User_findByUsername(ServerQueries* q, str username){
|
||||
Deferral(4);
|
||||
|
||||
tsqlite_statement* st = q->users.find_by_username;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_str(st, "$username", username, NULL));
|
||||
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
i64 user_id = 0;
|
||||
if(has_result){
|
||||
try(user_id, i, tsqlite_statement_getResult_i64(st));
|
||||
try_assert(user_id > 0);
|
||||
}
|
||||
|
||||
Return RESULT_VALUE(i, user_id);
|
||||
}
|
||||
|
||||
Result(i64) User_register(ServerQueries* q, str username, Array(u8) token){
|
||||
Deferral(4);
|
||||
try_assert(username.len >= USERNAME_SIZE_MIN && username.len <= USERNAME_SIZE_MAX);
|
||||
try_assert(token.len == PASSWORD_HASH_SIZE)
|
||||
|
||||
tsqlite_statement* st = q->users.insert;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_str(st, "$username", username, NULL));
|
||||
try_void(tsqlite_statement_bind_blob(st, "$token", token, NULL));
|
||||
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
try_assert(has_result);
|
||||
try(i64 user_id, i, tsqlite_statement_getResult_i64(st));
|
||||
try_assert(user_id > 0);
|
||||
|
||||
Return RESULT_VALUE(i, user_id);
|
||||
}
|
||||
|
||||
Result(bool) User_tryAuthorize(ServerQueries* q, u64 id, Array(u8) token){
|
||||
Deferral(4);
|
||||
try_assert(token.len == PASSWORD_HASH_SIZE)
|
||||
|
||||
tsqlite_statement* st = q->users.compare_token;
|
||||
Defer(tsqlite_statement_reset(st));
|
||||
try_void(tsqlite_statement_bind_i64(st, "$id", id));
|
||||
try_void(tsqlite_statement_bind_blob(st, "$token", token, NULL));
|
||||
|
||||
try(bool has_result, i, tsqlite_statement_step(st));
|
||||
|
||||
Return RESULT_VALUE(i, has_result);
|
||||
}
|
||||
153
src/server/db/server_db.c
Normal file
153
src/server/db/server_db.c
Normal file
@@ -0,0 +1,153 @@
|
||||
#include "server_db_internal.h"
|
||||
#include "tlibc/filesystem.h"
|
||||
|
||||
Result(tsqlite_connection*) ServerDatabase_open(cstr file_path){
|
||||
Deferral(64);
|
||||
|
||||
try_void(dir_createParent(file_path));
|
||||
try(tsqlite_connection* db, p, tsqlite_connection_open(file_path));
|
||||
bool success = false;
|
||||
Defer(if(!success) tsqlite_connection_close(db));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// CHANNELS //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(tsqlite_statement* create_table_channels, p, tsqlite_statement_compile(db, STR(
|
||||
"CREATE TABLE IF NOT EXISTS channels (\n"
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
|
||||
" name VARCHAR NOT NULL,\n"
|
||||
" description VARCHAR NOT NULL\n"
|
||||
");"
|
||||
)));
|
||||
Defer(tsqlite_statement_free(create_table_channels));
|
||||
try_void(tsqlite_statement_step(create_table_channels));
|
||||
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// MESSAGES //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(tsqlite_statement* create_table_messages, p, tsqlite_statement_compile(db, STR(
|
||||
"CREATE TABLE IF NOT EXISTS messages (\n"
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
|
||||
" channel_id INTEGER NOT NULL REFERENCES channels(id),\n"
|
||||
" sender_id INTEGER NOT NULL REFERENCES users(id),\n"
|
||||
" content BLOB NOT NULL,\n"
|
||||
" timestamp DATETIME NOT NULL DEFAULT (\n"
|
||||
" strftime('"MESSAGE_TIMESTAMP_FMT_SQL"', 'now', 'utc', 'subsecond')\n"
|
||||
" )\n"
|
||||
");"
|
||||
)));
|
||||
Defer(tsqlite_statement_free(create_table_messages));
|
||||
try_void(tsqlite_statement_step(create_table_messages));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// USERS //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(tsqlite_statement* create_table_users, p, tsqlite_statement_compile(db, STR(
|
||||
"CREATE TABLE IF NOT EXISTS users (\n"
|
||||
" id INTEGER PRIMARY KEY AUTOINCREMENT,\n"
|
||||
" username VARCHAR NOT NULL,\n"
|
||||
" token BLOB NOT NULL,\n"
|
||||
" registration_time DATETIME NOT NULL DEFAULT (\n"
|
||||
" strftime('"MESSAGE_TIMESTAMP_FMT_SQL"', 'now', 'utc', 'subsecond')\n"
|
||||
" )\n"
|
||||
");"
|
||||
)));
|
||||
Defer(tsqlite_statement_free(create_table_users));
|
||||
try_void(tsqlite_statement_step(create_table_users));
|
||||
|
||||
try(tsqlite_statement* create_index_username, p, tsqlite_statement_compile(db, STR(
|
||||
"CREATE UNIQUE INDEX IF NOT EXISTS idx_users_username ON users(username);"
|
||||
)));
|
||||
Defer(tsqlite_statement_free(create_index_username));
|
||||
try_void(tsqlite_statement_step(create_index_username));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, db);
|
||||
}
|
||||
|
||||
|
||||
|
||||
void ServerQueries_free(ServerQueries* q){
|
||||
if(!q)
|
||||
return;
|
||||
|
||||
tsqlite_statement_free(q->channels.insert);
|
||||
tsqlite_statement_free(q->channels.update);
|
||||
tsqlite_statement_free(q->channels.exists);
|
||||
|
||||
tsqlite_statement_free(q->messages.insert);
|
||||
tsqlite_statement_free(q->messages.get_block);
|
||||
|
||||
tsqlite_statement_free(q->users.insert);
|
||||
tsqlite_statement_free(q->users.find_by_username);
|
||||
tsqlite_statement_free(q->users.compare_token);
|
||||
|
||||
free(q);
|
||||
}
|
||||
|
||||
Result(ServerQueries*) ServerQueries_compile(tsqlite_connection* db){
|
||||
Deferral(4);
|
||||
|
||||
ServerQueries* q = (ServerQueries*)malloc(sizeof(*q));
|
||||
zeroStruct(q);
|
||||
bool success = false;
|
||||
Defer(if(!success) ServerQueries_free(q));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// CHANNELS //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(q->channels.insert, p, tsqlite_statement_compile(db, STR(
|
||||
"INSERT INTO\n"
|
||||
"channels (id, name, description)\n"
|
||||
"VALUES ($id, $name, $description);"
|
||||
)));
|
||||
|
||||
try(q->channels.exists, p, tsqlite_statement_compile(db, STR(
|
||||
"SELECT 1 FROM channels WHERE id = $id;"
|
||||
)));
|
||||
|
||||
try(q->channels.update, p, tsqlite_statement_compile(db, STR(
|
||||
"UPDATE channels\n"
|
||||
"SET name = $name, description = $description\n"
|
||||
"WHERE id = $id;"
|
||||
)));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// MESSAGES //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(q->messages.insert, p, tsqlite_statement_compile(db, STR(
|
||||
"INSERT INTO\n"
|
||||
"messages (channel_id, sender_id, content)\n"
|
||||
"VALUES ($channel_id, $sender_id, $content)\n"
|
||||
"RETURNING id, timestamp;"
|
||||
)));
|
||||
|
||||
try(q->messages.get_block, p, tsqlite_statement_compile(db, STR(
|
||||
"SELECT id, sender_id, content, timestamp FROM messages\n"
|
||||
"WHERE id >= $first_message_id\n"
|
||||
"AND channel_id = $channel_id\n"
|
||||
"LIMIT $count;"
|
||||
)));
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
// USERS //
|
||||
///////////////////////////////////////////////////////////////////////////
|
||||
try(q->users.insert, p, tsqlite_statement_compile(db, STR(
|
||||
"INSERT INTO\n"
|
||||
"users (username, token)\n"
|
||||
"VALUES ($username, $token)\n"
|
||||
"RETURNING id, registration_time;"
|
||||
)));
|
||||
|
||||
try(q->users.find_by_username, p, tsqlite_statement_compile(db, STR(
|
||||
"SELECT id FROM users WHERE username = $username;"
|
||||
)));
|
||||
|
||||
try(q->users.compare_token, p, tsqlite_statement_compile(db, STR(
|
||||
"SELECT 1 FROM users WHERE id = $id AND token = $token;"
|
||||
)));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, q);
|
||||
}
|
||||
40
src/server/db/server_db.h
Normal file
40
src/server/db/server_db.h
Normal file
@@ -0,0 +1,40 @@
|
||||
#pragma once
|
||||
#include "tcp-chat/tcp-chat.h"
|
||||
#include "tsqlite.h"
|
||||
#include "network/tcp-chat-protocol/v1.h"
|
||||
|
||||
/// @brief open DB and create tables
|
||||
Result(tsqlite_connection*) ServerDatabase_open(cstr file_path);
|
||||
|
||||
typedef struct ServerQueries ServerQueries;
|
||||
Result(ServerQueries*) ServerQueries_compile(tsqlite_connection* db);
|
||||
void ServerQueries_free(ServerQueries* self);
|
||||
|
||||
|
||||
Result(bool) Channel_exists(ServerQueries* q, i64 id);
|
||||
|
||||
/// @return true if new row was created
|
||||
Result(bool) Channel_createOrUpdate(ServerQueries* q,
|
||||
i64 id, str name, str description);
|
||||
|
||||
/// @return new message id
|
||||
Result(i64) Channel_saveMessage(ServerQueries* q,
|
||||
i64 channel_id, i64 sender_id, Array(u8) content,
|
||||
DateTime* out_timestamp_utc);
|
||||
|
||||
/// @brief try to find `count` messages starting from `first_message_id`
|
||||
/// @param out_meta writes here information about found messages, .count can be 0 if no messages found
|
||||
/// @param out_block .len must be >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX)
|
||||
Result(void) Channel_loadMessageBlock(ServerQueries* q,
|
||||
i64 channel_id, i64 first_message_id, u32 count,
|
||||
MessageBlockMeta* out_block_meta, Array(u8) out_block_data);
|
||||
|
||||
|
||||
/// @return existing user id or 0
|
||||
Result(i64) User_findByUsername(ServerQueries* q, str username);
|
||||
|
||||
/// @return new user id
|
||||
Result(i64) User_register(ServerQueries* q, str username, Array(u8) token);
|
||||
|
||||
/// @return true for successful authorization
|
||||
Result(bool) User_tryAuthorize(ServerQueries* q, u64 id, Array(u8) token);
|
||||
27
src/server/db/server_db_internal.h
Normal file
27
src/server/db/server_db_internal.h
Normal file
@@ -0,0 +1,27 @@
|
||||
#pragma once
|
||||
#include "server_db.h"
|
||||
|
||||
typedef struct ServerQueries {
|
||||
struct {
|
||||
/* ($id, $name, $description) -> void */
|
||||
tsqlite_statement* insert;
|
||||
/* ($id, $name, $description) -> void */
|
||||
tsqlite_statement* update;
|
||||
/* ($id) -> 1 or nothing */
|
||||
tsqlite_statement* exists;
|
||||
} channels;
|
||||
struct {
|
||||
/* ($channel_id, $sender_id, $content) -> (id, timestamp) */
|
||||
tsqlite_statement* insert;
|
||||
/* ($channel_id, $first_message_id, $count) -> [(id, sender_id, content, timestamp)] */
|
||||
tsqlite_statement* get_block;
|
||||
} messages;
|
||||
struct {
|
||||
/* ($username, $token) -> (id, registration_time) */
|
||||
tsqlite_statement* insert;
|
||||
/* ($username) -> (id) */
|
||||
tsqlite_statement* find_by_username;
|
||||
/* ($id, $token) -> 1 or nothing */
|
||||
tsqlite_statement* compare_token;
|
||||
} users;
|
||||
} ServerQueries;
|
||||
@@ -20,28 +20,38 @@ declare_RequestHandler(GetMessageBlock)
|
||||
LogSeverity_Warn, STR("not authorized") ));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
// validate message_count
|
||||
if(req.message_count < 1 || req.message_count > MESSAGE_BLOCK_COUNT_MAX){
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
LogSeverity_Warn, STR("invalid message count in request") ));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
// get message block from channel
|
||||
Channel* ch = Server_tryGetChannel(conn->server, req.channel_id);
|
||||
if(ch == NULL){
|
||||
// validate channel id
|
||||
try(bool channel_exists, i, Channel_exists(conn->queries, req.channel_id));
|
||||
if(!channel_exists){
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
LogSeverity_Warn, STR("invalid channel id") ));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
MessageBlockMeta meta;
|
||||
Array(u8) block_data;
|
||||
try_void(Channel_loadMessageBlock(ch, req.first_message_id, req.messages_count,
|
||||
&meta, &block_data, true));
|
||||
Defer(Array_u8_destroy(&block_data));
|
||||
|
||||
// reset block meta
|
||||
zeroStruct(&conn->message_block_meta);
|
||||
// get message block from channel
|
||||
try_void(Channel_loadMessageBlock(conn->queries,
|
||||
req.channel_id, req.first_message_id, req.message_count,
|
||||
&conn->message_block_meta, conn->message_block));
|
||||
|
||||
// send response
|
||||
GetMessageBlockResponse res;
|
||||
GetMessageBlockResponse_construct(&res, res_head,
|
||||
meta.first_message_id, meta.messages_count, meta.data_size);
|
||||
GetMessageBlockResponse_construct(&res, res_head, &conn->message_block_meta);
|
||||
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
|
||||
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));
|
||||
if(block_data.len != 0){
|
||||
try_void(EncryptedSocketTCP_send(&conn->sock, block_data));
|
||||
if(conn->message_block_meta.data_size != 0){
|
||||
try_void(EncryptedSocketTCP_send(&conn->sock,
|
||||
Array_u8_sliceTo(conn->message_block, conn->message_block_meta.data_size))
|
||||
);
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
|
||||
@@ -22,8 +22,8 @@ declare_RequestHandler(Login)
|
||||
}
|
||||
|
||||
// validate username
|
||||
str username_str = str_null;
|
||||
str name_error_str = validateUsername_cstr(req.username, &username_str);
|
||||
str username = str_null;
|
||||
str name_error_str = validateUsername_cstr(req.username, &username);
|
||||
if(name_error_str.data){
|
||||
Defer(str_destroy(name_error_str));
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
@@ -31,42 +31,28 @@ declare_RequestHandler(Login)
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
// lock users cache
|
||||
idb_lockTable(srv->users.table);
|
||||
bool unlocked_users_cache_mutex = false;
|
||||
Defer(
|
||||
if(!unlocked_users_cache_mutex)
|
||||
idb_unlockTable(srv->users.table)
|
||||
);
|
||||
|
||||
// try get id from name cache
|
||||
u64* id_ptr = HashMap_tryGetPtr(&srv->users.name_id_map, username_str);
|
||||
if(id_ptr == NULL){
|
||||
// get user by id
|
||||
try(u64 user_id, i, User_findByUsername(conn->queries, username));
|
||||
if(user_id == 0){
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
LogSeverity_Warn, STR("Username is not registered") ));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
u64 user_id = *id_ptr;
|
||||
|
||||
// get user by id
|
||||
try_assert(user_id < srv->users.list.len);
|
||||
UserInfo* u = srv->users.list.data + user_id;
|
||||
|
||||
// TODO: get user token
|
||||
Array(u8) token = Array_u8_construct(req.token, sizeof(req.token));
|
||||
try(bool authorized, i, User_tryAuthorize(conn->queries, user_id, token));
|
||||
// validate token hash
|
||||
if(memcmp(req.token, u->token, sizeof(req.token)) != 0){
|
||||
if(!authorized){
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
LogSeverity_Warn, STR("wrong password") ));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
// manually unlock mutex
|
||||
idb_unlockTable(srv->users.table);
|
||||
unlocked_users_cache_mutex = true;
|
||||
|
||||
// authorize
|
||||
conn->authorized = true;
|
||||
conn->user_id = user_id;
|
||||
logInfo("authorized user '%s'", username_str.data);
|
||||
logInfo("authorized user '%s' with id "FMT_i64, username.data, user_id);
|
||||
|
||||
// send response
|
||||
LoginResponse res;
|
||||
|
||||
@@ -22,8 +22,8 @@ declare_RequestHandler(Register)
|
||||
}
|
||||
|
||||
// validate username
|
||||
str username_str = str_null;
|
||||
str name_error_str = validateUsername_cstr(req.username, &username_str);
|
||||
str username = str_null;
|
||||
str name_error_str = validateUsername_cstr(req.username, &username);
|
||||
if(name_error_str.data){
|
||||
Defer(str_destroy(name_error_str));
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
@@ -31,46 +31,19 @@ declare_RequestHandler(Register)
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
// lock users cache
|
||||
idb_lockTable(srv->users.table);
|
||||
bool unlocked_users_cache_mutex = false;
|
||||
// unlock mutex on error catch
|
||||
Defer(
|
||||
if(!unlocked_users_cache_mutex)
|
||||
idb_unlockTable(srv->users.table)
|
||||
);
|
||||
|
||||
// check if name is taken
|
||||
if(HashMap_tryGetPtr(&srv->users.name_id_map, username_str) != NULL){
|
||||
try(u64 user_id, i, User_findByUsername(conn->queries, username));
|
||||
if(user_id != 0){
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
LogSeverity_Warn, STR("Username already exists") ));
|
||||
LogSeverity_Warn, STR("Username is already taken") ));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
// initialize new user
|
||||
UserInfo user;
|
||||
zeroStruct(&user);
|
||||
|
||||
user.name_len = username_str.len;
|
||||
memcpy(user.name, username_str.data, user.name_len);
|
||||
|
||||
memcpy(user.token, req.token, sizeof(req.token));
|
||||
|
||||
DateTime_getUTC(&user.registration_time_utc);
|
||||
|
||||
// save new user to db and cache
|
||||
try(u64 user_id, u,
|
||||
idb_pushRow(srv->users.table, &user, false)
|
||||
);
|
||||
try_assert(user_id == srv->users.list.len);
|
||||
List_UserInfo_pushMany(&srv->users.list, &user, 1);
|
||||
try_assert(HashMap_tryPush(&srv->users.name_id_map, username_str, &user_id));
|
||||
|
||||
// manually unlock mutex
|
||||
idb_unlockTable(srv->users.table);
|
||||
unlocked_users_cache_mutex = true;
|
||||
|
||||
logInfo("registered user '%s'", username_str.data);
|
||||
// register new user
|
||||
Array(u8) token = Array_u8_construct(req.token, sizeof(req.token));
|
||||
try(user_id, i, User_register(conn->queries, username, token));
|
||||
logInfo("registered user '"FMT_str"' with id "FMT_i64,
|
||||
str_unwrap(username), user_id);
|
||||
|
||||
// send response
|
||||
RegisterResponse res;
|
||||
|
||||
@@ -21,6 +21,7 @@ declare_RequestHandler(SendMessage)
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
// validate content size
|
||||
if(req.data_size < MESSAGE_SIZE_MIN || req.data_size > MESSAGE_SIZE_MAX){
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
LogSeverity_Warn, STR("invalid message size") ));
|
||||
@@ -29,24 +30,26 @@ declare_RequestHandler(SendMessage)
|
||||
}
|
||||
|
||||
// receive message data
|
||||
Array(u8) message_data = Array_u8_alloc(req.data_size);
|
||||
try_void(EncryptedSocketTCP_recv(&conn->sock, message_data, SocketRecvFlag_WholeBuffer));
|
||||
try_void(EncryptedSocketTCP_recv(&conn->sock, conn->message_content, SocketRecvFlag_WholeBuffer));
|
||||
|
||||
// save message to channel
|
||||
Channel* ch = Server_tryGetChannel(conn->server, req.channel_id);
|
||||
if(ch == NULL){
|
||||
// validate channel id
|
||||
try(bool channel_exists, i, Channel_exists(conn->queries, req.channel_id));
|
||||
if(!channel_exists){
|
||||
try_void(sendErrorMessage(log_ctx, conn, res_head,
|
||||
LogSeverity_Warn, STR("invalid channel id") ));
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
MessageMeta message_meta;
|
||||
try_void(Channel_saveMessage(ch, message_data, conn->user_id,
|
||||
&message_meta, true));
|
||||
|
||||
// save message to channel
|
||||
DateTime timestamp;
|
||||
try(i64 message_id, i, Channel_saveMessage(conn->queries,
|
||||
req.channel_id, conn->user_id, conn->message_content,
|
||||
×tamp));
|
||||
|
||||
// send response
|
||||
SendMessageResponse res;
|
||||
SendMessageResponse_construct(&res, res_head,
|
||||
message_meta.id, message_meta.receiving_time_utc);
|
||||
message_id, timestamp);
|
||||
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
|
||||
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
#pragma once
|
||||
#include "network/tcp-chat-protocol/v1.h"
|
||||
#include "server/server_internal.h"
|
||||
|
||||
|
||||
Result(void) sendErrorMessage(
|
||||
cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
|
||||
LogSeverity log_severity, str msg);
|
||||
|
||||
@@ -9,7 +9,7 @@ Result(void) sendErrorMessage(
|
||||
cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
|
||||
LogSeverity log_severity, str msg)
|
||||
{
|
||||
Deferral(1);
|
||||
Deferral(4);
|
||||
|
||||
//limit ErrorMessage size to fit into EncryptedSocketTCP.internal_buffer_size
|
||||
if(msg.len > ERROR_MESSAGE_MAX_SIZE)
|
||||
@@ -44,7 +44,7 @@ Result(void) sendErrorMessage_f(
|
||||
ClientConnection* conn, PacketHeader* res_head,
|
||||
LogSeverity log_severity, cstr format, ...)
|
||||
{
|
||||
Deferral(1);
|
||||
Deferral(4);
|
||||
|
||||
va_list argv;
|
||||
va_start(argv, format);
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
#include <pthread.h>
|
||||
#include "tlibc/filesystem.h"
|
||||
#include "tlibc/time.h"
|
||||
#include "tlibc/base64.h"
|
||||
#include "tlibc/algorithms.h"
|
||||
#include "server/server_internal.h"
|
||||
#include "network/tcp-chat-protocol/v1.h"
|
||||
#include "server/responses/responses.h"
|
||||
#include "tlibtoml.h"
|
||||
|
||||
@@ -20,17 +17,9 @@ void Server_free(Server* self){
|
||||
RSA_destroyPrivateKey(&self->rsa_sk);
|
||||
RSA_destroyPublicKey(&self->rsa_pk);
|
||||
|
||||
idb_close(self->db);
|
||||
|
||||
List_UserInfo_destroy(&self->users.list);
|
||||
HashMap_destroy(&self->users.name_id_map);
|
||||
|
||||
for(u32 i = 0; i < self->channels.list.len; i++){
|
||||
Channel_free(self->channels.list.data[i]);
|
||||
}
|
||||
List_ChannelPtr_destroy(&self->channels.list);
|
||||
HashMap_destroy(&self->channels.name_id_map);
|
||||
|
||||
free(self->db_path);
|
||||
ServerQueries_free(self->queries);
|
||||
tsqlite_connection_close(self->db);
|
||||
|
||||
free(self);
|
||||
}
|
||||
@@ -58,78 +47,51 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name,
|
||||
try(TomlTable* config_top, p, toml_load_str_filename(config_file_content, config_file_name));
|
||||
Defer(TomlTable_free(config_top));
|
||||
|
||||
// [server]
|
||||
try(TomlTable* config_server, p, TomlTable_get_table(config_top, STR("server")))
|
||||
// parse name
|
||||
// name
|
||||
try(str* v_name, p, TomlTable_get_str(config_server, STR("name")));
|
||||
self->name = str_copy(*v_name);
|
||||
|
||||
// parse description
|
||||
// description
|
||||
try(str* v_desc, p, TomlTable_get_str(config_server, STR("description")));
|
||||
self->description = str_copy(*v_desc);
|
||||
|
||||
// parse local_address
|
||||
// local_address
|
||||
try(str* v_local_address, p, TomlTable_get_str(config_server, STR("local_address")));
|
||||
try_assert(v_local_address->isZeroTerminated);
|
||||
try_void(EndpointIPv4_parse(v_local_address->data, &self->local_end));
|
||||
|
||||
// parse landing_channel_id
|
||||
// landing_channel_id
|
||||
try(i64 v_landing_channel_id, i, TomlTable_get_integer(config_server, STR("landing_channel_id")));
|
||||
self->landing_channel_id = v_landing_channel_id;
|
||||
|
||||
// [keys]
|
||||
try(TomlTable* config_keys, p, TomlTable_get_table(config_top, STR("keys")))
|
||||
// parse rsa_private_key
|
||||
// rsa_private_key
|
||||
try(str* v_rsa_sk, p, TomlTable_get_str(config_keys, STR("rsa_private_key")));
|
||||
try_assert(v_rsa_sk->isZeroTerminated);
|
||||
try_void(RSA_parsePrivateKey_base64(v_rsa_sk->data, &self->rsa_sk));
|
||||
|
||||
// parse rsa_public_key
|
||||
// rsa_public_key
|
||||
try(str* v_rsa_pk, p, TomlTable_get_str(config_keys, STR("rsa_public_key")));
|
||||
try_assert(v_rsa_pk->isZeroTerminated);
|
||||
try_void(RSA_parsePublicKey_base64(v_rsa_pk->data, &self->rsa_pk));
|
||||
|
||||
// [db]
|
||||
try(TomlTable* config_db, p, TomlTable_get_table(config_top, STR("database")))
|
||||
// parse db_aes_key
|
||||
try(str* v_db_aes_key, p, TomlTable_get_str(config_db, STR("aes_key")));
|
||||
str db_aes_key_s = *v_db_aes_key;
|
||||
Array(u8) db_aes_key = Array_u8_alloc(base64_decodedSize(db_aes_key_s.data, db_aes_key_s.len));
|
||||
Defer(free(db_aes_key.data));
|
||||
base64_decode(db_aes_key_s.data, db_aes_key_s.len, db_aes_key.data);
|
||||
|
||||
// parse db_dir
|
||||
try(str* v_db_dir, p, TomlTable_get_str(config_db, STR("dir")));
|
||||
// path
|
||||
try(str* v_db_path, p, TomlTable_get_str(config_db, STR("path")));
|
||||
self->db_path = str_copy(*v_db_path).data;
|
||||
|
||||
// open DB
|
||||
try(self->db, p, idb_open(*v_db_dir, db_aes_key));
|
||||
logInfo("loading database '%s'", self->db_path);
|
||||
try(self->db, p, ServerDatabase_open(self->db_path));
|
||||
try(self->queries, p, ServerQueries_compile(self->db));
|
||||
|
||||
// build users cache
|
||||
logInfo("loading users...");
|
||||
try(self->users.table, p,
|
||||
idb_getOrCreateTable(self->db, str_null, STR("users"), sizeof(UserInfo), false)
|
||||
);
|
||||
|
||||
// load whole users table to list
|
||||
try_void(
|
||||
idb_createListFromTable(self->users.table, (void*)&self->users.list, false)
|
||||
);
|
||||
|
||||
// build name-id map
|
||||
try(u64 users_count, u, idb_getRowCount(self->users.table, false));
|
||||
HashMap_construct(&self->users.name_id_map, u64, NULL);
|
||||
for(u64 id = 0; id < users_count; id++){
|
||||
UserInfo* user = self->users.list.data + id;
|
||||
str key = str_construct(user->name, user->name_len, true);
|
||||
if(!HashMap_tryPush(&self->users.name_id_map, key, &id)){
|
||||
Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", str_expand(key));
|
||||
}
|
||||
}
|
||||
logInfo("loaded "FMT_u64" users", users_count);
|
||||
|
||||
// parse channels
|
||||
// [channels]
|
||||
logDebug("loading channels...");
|
||||
HashMap_construct(&self->channels.name_id_map, u64, NULL);
|
||||
self->channels.list = List_ChannelPtr_alloc(32);
|
||||
try(TomlTable* config_channels, p, TomlTable_get_table(config_top, STR("channels")));
|
||||
|
||||
HashMapIter channels_iter = HashMapIter_create(config_channels);
|
||||
while(HashMapIter_moveNext(&channels_iter)){
|
||||
HashMapKeyValue kv;
|
||||
@@ -140,25 +102,14 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name,
|
||||
if(val->type != TLIBTOML_TABLE)
|
||||
continue;
|
||||
|
||||
logInfo("loading channel '"FMT_str"'", str_expand(name))
|
||||
logInfo("loading channel '"FMT_str"'", str_unwrap(name))
|
||||
TomlTable* config_channel = val->table;
|
||||
try(u64 id, u, TomlTable_get_integer(config_channel, STR("id")));
|
||||
try(i64 id, u, TomlTable_get_integer(config_channel, STR("id")));
|
||||
try(str* v_ch_desc, p, TomlTable_get_str(config_channel, STR("description")))
|
||||
str description = *v_ch_desc;
|
||||
|
||||
if(!HashMap_tryPush(&self->channels.name_id_map, name, &id)){
|
||||
Return RESULT_ERROR_FMT("duplicate channel '"FMT_str"'", str_expand(name));
|
||||
}
|
||||
|
||||
logDebug("loading messages...");
|
||||
try(Channel* channel, p, Channel_create(id, name, description, self->db, false));
|
||||
logDebug("loaded "FMT_u64" messages", channel->messages.count);
|
||||
List_ChannelPtr_push(&self->channels.list, channel);
|
||||
try_void(Channel_createOrUpdate(self->queries, id, name, description));
|
||||
}
|
||||
logDebug("loaded "FMT_u32" channels", self->channels.list.len);
|
||||
// sort channels list by id
|
||||
for(u32 i = 0; i < self->channels.list.len; i++);
|
||||
insertionSort_inline(self->channels.list.data, self->channels.list.len, ->id);
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, self);
|
||||
@@ -184,7 +135,7 @@ Result(void) Server_run(Server* server){
|
||||
Defer(free(local_end_str.data));
|
||||
logInfo("server is listening at %s", local_end_str.data);
|
||||
|
||||
u64 session_id = 1;
|
||||
i64 session_id = 1;
|
||||
while(true){
|
||||
ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)malloc(sizeof(ConnectionHandlerArgs));
|
||||
args->server = server;
|
||||
@@ -201,16 +152,6 @@ Result(void) Server_run(Server* server){
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
Channel* Server_tryGetChannel(Server* self, u64 id){
|
||||
i32 index;
|
||||
binarySearch_inline(self->channels.list.data, self->channels.list.len, id, ->id, index);
|
||||
if(index == -1){
|
||||
return NULL;
|
||||
}
|
||||
Channel* channel = self->channels.list.data[index];
|
||||
return channel;
|
||||
}
|
||||
|
||||
static void* handleConnection(void* _args){
|
||||
ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args;
|
||||
Server* server = args->server;
|
||||
|
||||
@@ -1,56 +1,14 @@
|
||||
#pragma once
|
||||
#include "tlibc/collections/HashMap.h"
|
||||
#include "tlibc/collections/List.h"
|
||||
#include "tlibc/collections/LList.h"
|
||||
#include "tcp-chat/tcp-chat.h"
|
||||
#include "tcp-chat/server.h"
|
||||
#include "cryptography/AES.h"
|
||||
#include "cryptography/RSA.h"
|
||||
#include "network/encrypted_sockets.h"
|
||||
#include "db/idb.h"
|
||||
#include "db/tables.h"
|
||||
#include "network/tcp-chat-protocol/v1.h"
|
||||
#include "db/server_db.h"
|
||||
|
||||
typedef struct ClientConnection ClientConnection;
|
||||
|
||||
List_declare(UserInfo);
|
||||
List_declare(MessageBlockMeta);
|
||||
LList_declare(MessageBlock);
|
||||
|
||||
#define MESSAGE_BLOCKS_CACHE_COUNT 50
|
||||
|
||||
typedef struct Channel {
|
||||
u64 id;
|
||||
str name;
|
||||
str description;
|
||||
struct {
|
||||
u64 count;
|
||||
Table* blocks_table;
|
||||
Table* blocks_meta_table;
|
||||
List(MessageBlockMeta) blocks_meta_list; // index is id
|
||||
// last MESSAGE_BLOCKS_CACHE_COUNT MessageBlocks, ascending
|
||||
// new messages are written to the last block
|
||||
LList(MessageBlock) blocks_queue;
|
||||
} messages;
|
||||
} Channel;
|
||||
|
||||
typedef Channel* ChannelPtr;
|
||||
List_declare(ChannelPtr);
|
||||
|
||||
Result(Channel*) Channel_create(u64 id, str name, str description,
|
||||
IncrementalDB* db, bool lock_db);
|
||||
|
||||
void Channel_free(Channel* self);
|
||||
|
||||
Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id,
|
||||
MessageMeta* out_message_meta, bool lock_tables);
|
||||
|
||||
/// @brief try to find `count` messages starting from `fisrt_message_id`
|
||||
/// @param out_meta information about found messages, .count can be 0 if no messages found
|
||||
/// @param out_block allocates buffer on heap and copies them there, .len can be 0 if no messages found
|
||||
Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count,
|
||||
MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables);
|
||||
|
||||
|
||||
typedef struct Server {
|
||||
/* from constructor */
|
||||
void* logger;
|
||||
@@ -59,42 +17,42 @@ typedef struct Server {
|
||||
/* from config */
|
||||
str name;
|
||||
str description;
|
||||
u64 landing_channel_id;
|
||||
i64 landing_channel_id;
|
||||
EndpointIPv4 local_end;
|
||||
br_rsa_private_key rsa_sk;
|
||||
br_rsa_public_key rsa_pk;
|
||||
|
||||
/* database and cache */
|
||||
IncrementalDB* db;
|
||||
struct {
|
||||
Table* table;
|
||||
List(UserInfo) list; // index is id
|
||||
HashMap(u64) name_id_map;
|
||||
} users;
|
||||
struct {
|
||||
List(ChannelPtr) list;
|
||||
HashMap(u64) name_id_map;
|
||||
} channels;
|
||||
/* database and cache*/
|
||||
char* db_path;
|
||||
tsqlite_connection* db;
|
||||
ServerQueries* queries; /* for server listener thread only */
|
||||
} Server;
|
||||
|
||||
Channel* Server_tryGetChannel(Server* self, u64 id);
|
||||
|
||||
|
||||
typedef struct ClientConnection {
|
||||
Server* server;
|
||||
u64 session_id;
|
||||
i64 session_id;
|
||||
EndpointIPv4 client_end;
|
||||
Array(u8) session_key;
|
||||
EncryptedSocketTCP sock;
|
||||
bool authorized;
|
||||
u64 user_id; // -1 for unauthorized
|
||||
i64 user_id; // 0 for unauthorized
|
||||
|
||||
/* buffers */
|
||||
MessageBlockMeta message_block_meta;
|
||||
Array(u8) message_block;
|
||||
Array(u8) message_content;
|
||||
|
||||
/* database */
|
||||
tsqlite_connection* db;
|
||||
ServerQueries* queries;
|
||||
} ClientConnection;
|
||||
|
||||
typedef struct ConnectionHandlerArgs {
|
||||
Server* server;
|
||||
Socket accepted_socket_tcp;
|
||||
EndpointIPv4 client_end;
|
||||
u64 session_id;
|
||||
i64 session_id;
|
||||
} ConnectionHandlerArgs;
|
||||
|
||||
Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args);
|
||||
|
||||
@@ -5,16 +5,17 @@ description = """\
|
||||
Qqqqq...\
|
||||
"""
|
||||
local_address = '127.0.0.1:9988'
|
||||
landing_channel_id = 0
|
||||
landing_channel_id = 1
|
||||
|
||||
# do not create channels with the same id
|
||||
[channels.general]
|
||||
id = 0
|
||||
id = 1
|
||||
description = "a text channel"
|
||||
|
||||
[database]
|
||||
dir = 'server-db'
|
||||
aes_key = '<generate with './tcp-chat --random-bytes-base64 32'>'
|
||||
path = 'tcp-chat-server/server.sqlite'
|
||||
# on windows use backslashes
|
||||
# path = 'tcp-chat-server\server.sqlite'
|
||||
|
||||
[keys]
|
||||
rsa_private_key = '<generate with './tcp-chat --rsa-gen-random'>'
|
||||
|
||||
Reference in New Issue
Block a user