implemented CommonQueries

This commit is contained in:
2025-12-15 23:26:32 +05:00
parent 72696dea70
commit 49793e2929
29 changed files with 540 additions and 495 deletions

View File

@@ -230,13 +230,13 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){
idb_lockTable(self->servers.table);
Defer(idb_unlockTable(self->servers.table));
u32 servers_count = self->servers.list.len;
if(servers_count == 0){
u32 server_count = self->servers.list.len;
if(server_count == 0){
printf("No servers found in cache\n");
Return RESULT_VOID;
}
for(u32 id = 0; id < servers_count; id++){
for(u32 id = 0; id < server_count; id++){
ServerInfo* server = self->servers.list.data + id;
printf("[%02u] "FMT_str" "FMT_str"\n",
id, server->address_len, server->address, server->name_len, server->name);
@@ -256,7 +256,7 @@ static Result(void) ClientCLI_selectServerFromCache(ClientCLI* self){
if(sscanf(buf, FMT_u32, &id) != 1){
printf("ERROR: not a number\n");
}
else if(id >= servers_count){
else if(id >= server_count){
printf("ERROR: not a server number: %u\n", id);
}
else break;
@@ -316,9 +316,9 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){
str username = Client_getUserName(self->client);
Array(u8) user_data_key = Client_getUserDataKey(self->client);
str user_db_dir = str_from_cstr(strcat_malloc("client-db", path_seps, username.data));
Defer(free(user_db_dir.data));
try(self->db, p, idb_open(user_db_dir, user_data_key));
str user_db_path = str_from_cstr(strcat_malloc("client-db", path_seps, username.data));
Defer(free(user_db_path.data));
try(self->db, p, idb_open(user_db_path, user_data_key));
// Lock DB until this function returns.
idb_lockDB(self->db);
@@ -339,11 +339,11 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){
);
// build address-id map
try(u64 servers_count, u,
try(i64 server_count, u,
idb_getRowCount(self->servers.table, false)
);
HashMap_construct(&self->servers.addr_id_map, u64, NULL);
for(u64 id = 0; id < servers_count; id++){
HashMap_construct(&self->servers.addr_id_map, i64, NULL);
for(i64 id = 0; id < server_count; id++){
ServerInfo* server = self->servers.list.data + id;
str key = str_construct(server->address, server->address_len, true);
if(!HashMap_tryPush(&self->servers.addr_id_map, key, &id)){
@@ -391,11 +391,11 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
// try find server id in cache
ServerInfo* cached_row_ptr = NULL;
u64* id_ptr = NULL;
i64* id_ptr = NULL;
id_ptr = HashMap_tryGetPtr(&self->servers.addr_id_map, addr);
if(id_ptr){
// update existing server
u64 id = *id_ptr;
i64 id = *id_ptr;
try_void(idb_updateRow(self->servers.table, id, &server, false));
try_assert(id < self->servers.list.len);
cached_row_ptr = self->servers.list.data + id;
@@ -403,7 +403,7 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
}
else {
// push new server
try(u64 id, u, idb_pushRow(self->servers.table, &server, false));
try(i64 id, u, idb_pushRow(self->servers.table, &server, false));
try_assert(id == self->servers.list.len);
List_ServerInfo_pushMany(&self->servers.list, &server, 1);
cached_row_ptr = self->servers.list.data + id;
@@ -416,10 +416,10 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
static Result(void) ClientCLI_register(ClientCLI* self){
Deferral(8);
u64 user_id = 0;
i64 user_id = 0;
try_void(Client_register(self->client, &user_id));
printf("Registered successfully\n");
printf("user_id: "FMT_u64"\n", user_id);
printf("user_id: "FMT_i64"\n", user_id);
// TODO: use user_id somewhere
Return RESULT_VOID;
@@ -428,10 +428,10 @@ static Result(void) ClientCLI_register(ClientCLI* self){
static Result(void) ClientCLI_login(ClientCLI* self){
Deferral(8);
u64 user_id = 0, landing_channel_id = 0;
i64 user_id = 0, landing_channel_id = 0;
try_void(Client_login(self->client, &user_id, &landing_channel_id));
printf("Authorized successfully\n");
printf("user_id: "FMT_u64", landing_channel_id: "FMT_u64"\n", user_id, landing_channel_id);
printf("user_id: "FMT_i64", landing_channel_id: "FMT_i64"\n", user_id, landing_channel_id);
// TODO: use user_id, landing_channel_id somewhere
Return RESULT_VOID;

View File

@@ -2,20 +2,12 @@
#include <pthread.h>
#include "tlibc/collections/HashMap.h"
#include "tlibc/collections/List.h"
#include "tsqlite.h"
#include "tcp-chat/client.h"
#include "db/idb.h"
#include "db/tables.h"
List_declare(ServerInfo);
typedef struct ClientCLI {
Client* client;
IncrementalDB* db;
struct {
Table* table;
List(ServerInfo) list; // index is id
HashMap(u64) addr_id_map; // key is server address
} servers;
tsqlite_connection* db;
} ClientCLI;
void ClientCLI_construct(ClientCLI* self);

View File

@@ -72,7 +72,7 @@ Result(void) Client_getServerDescription(Client* self, str* out_desc){
Return RESULT_VOID;
}
Result(void) Client_register(Client* self, u64* out_user_id){
Result(void) Client_register(Client* self, i64* out_user_id){
Deferral(1);
try_assert(self != NULL);
try_assert(self->conn != NULL && "didn't connect to a server yet");
@@ -90,7 +90,7 @@ Result(void) Client_register(Client* self, u64* out_user_id){
Return RESULT_VOID;
}
Result(void) Client_login(Client* self, u64* out_user_id, u64* out_landing_channel_id){
Result(void) Client_login(Client* self, i64* out_user_id, i64* out_landing_channel_id){
Deferral(1);
try_assert(self != NULL);
try_assert(self->conn != NULL && "didn't connect to a server yet");

View File

@@ -21,10 +21,10 @@ typedef struct ServerConnection {
Array(u8) token;
Array(u8) session_key;
EncryptedSocketTCP sock;
u64 session_id;
i64 session_id;
str server_name;
str server_description;
u64 user_id;
i64 user_id;
} ServerConnection;
/// @param server_addr_cstr

View File

@@ -132,12 +132,12 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
}
Result(void) socket_TCP_enableAliveChecks(Socket s,
sec_t first_check_time, u32 checks_count, sec_t checks_interval)
sec_t first_check_time, u32 check_count, sec_t checks_interval)
{
#if KN_USE_WINSOCK
BOOL opt_SO_KEEPALIVE = 1; // enable keepalives
DWORD opt_TCP_KEEPIDLE = first_check_time;
DWORD opt_TCP_KEEPCNT = checks_count;
DWORD opt_TCP_KEEPCNT = check_count;
DWORD opt_TCP_KEEPINTVL = checks_interval;
try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE);
@@ -145,12 +145,12 @@ Result(void) socket_TCP_enableAliveChecks(Socket s,
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL);
// timeout for connect()
DWORD opt_TCP_MAXRT = checks_count * checks_interval;
DWORD opt_TCP_MAXRT = check_count * checks_interval;
try_setsockopt(s, IPPROTO_TCP, TCP_MAXRT);
#else
int opt_SO_KEEPALIVE = 1; // enable keepalives
int opt_TCP_KEEPIDLE = first_check_time;
int opt_TCP_KEEPCNT = checks_count;
int opt_TCP_KEEPCNT = check_count;
int opt_TCP_KEEPINTVL = checks_interval;
try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE);
@@ -158,7 +158,7 @@ Result(void) socket_TCP_enableAliveChecks(Socket s,
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL);
// read more in the article
int opt_TCP_USER_TIMEOUT = checks_count * checks_interval * 1000;
int opt_TCP_USER_TIMEOUT = check_count * checks_interval * 1000;
try_setsockopt(s, IPPROTO_TCP, TCP_USER_TIMEOUT);
#endif
return RESULT_VOID;

View File

@@ -38,7 +38,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags
/// Read more: https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/
/// RU translaton: https://habr.com/ru/articles/700470/
Result(void) socket_TCP_enableAliveChecks(Socket s,
sec_t first_check_time, u32 checks_count, sec_t checks_interval);
sec_t first_check_time, u32 check_count, sec_t checks_interval);
#define socket_TCP_enableAliveChecks_default(socket) \
socket_TCP_enableAliveChecks(socket, 1, 4, 5)

View File

@@ -63,7 +63,7 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he
}
void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header,
u64 session_id)
i64 session_id)
{
_PacketHeader_construct(ServerHandshake);
zeroStruct(ptr);
@@ -106,7 +106,7 @@ Result(void) LoginRequest_tryConstruct(LoginRequest *ptr, PacketHeader* header,
}
void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header,
u64 user_id, u64 landing_channel_id)
i64 user_id, i64 landing_channel_id)
{
_PacketHeader_construct(LoginResponse);
zeroStruct(ptr);
@@ -135,15 +135,73 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest *ptr, PacketHeader* he
}
void RegisterResponse_construct(RegisterResponse *ptr, PacketHeader* header,
u64 user_id)
i64 user_id)
{
_PacketHeader_construct(RegisterResponse);
zeroStruct(ptr);
ptr->user_id = user_id;
}
Result(u32) MessageBlock_writeMessage(
MessageMeta* msg, Array(u8) msg_content,
MessageBlockMeta* block_meta, Array(u8)* block_free_part)
{
Deferral(1);
try_assert(msg->data_size >= MESSAGE_SIZE_MIN && msg->data_size <= MESSAGE_SIZE_MAX);
try_assert(msg->data_size <= msg_content.len);
u32 offset_increment = sizeof(MessageMeta) + msg->data_size;
if(block_free_part->len < offset_increment){
Return RESULT_VALUE(u, 0);
}
memcpy(block_free_part->data, msg, sizeof(MessageMeta));
block_free_part->data += sizeof(MessageMeta);
block_free_part->len -= sizeof(MessageMeta);
memcpy(block_free_part->data, msg_content.data, msg->data_size);
block_free_part->data += msg->data_size;
block_free_part->len -= msg->data_size;
if(block_meta->message_count == 0)
block_meta->first_message_id = msg->id;
block_meta->message_count++;
block_meta->data_size += offset_increment;
Return RESULT_VALUE(u, offset_increment);
}
Result(u32) MessageBlock_readMessage(
Array(u8)* block_unread_part,
MessageMeta* msg, Array(u8) msg_content)
{
Deferral(1);
try_assert(block_unread_part->len >= sizeof(MessageMeta) + MESSAGE_SIZE_MIN);
try_assert(msg_content.len >= MESSAGE_SIZE_MIN && msg_content.len <= MESSAGE_SIZE_MAX);
memcpy(msg, block_unread_part->data, sizeof(MessageMeta));
block_unread_part->data += sizeof(MessageMeta);
block_unread_part->len -= sizeof(MessageMeta);
if(msg->magic.n != MESSAGE_MAGIC.n){
Return RESULT_VALUE(u, 0);
}
try_assert(block_unread_part->len >= msg->data_size);
try_assert(msg->data_size >= MESSAGE_SIZE_MIN && msg->data_size <= MESSAGE_SIZE_MAX);
try_assert(msg->id > 0);
try_assert(msg->sender_id > 0);
try_assert(msg->timestamp.d.year > 2024);
memcpy(msg_content.data, block_unread_part->data, msg->data_size);
block_unread_part->data += msg->data_size;
block_unread_part->len -= msg->data_size;
u32 offset_increment = sizeof(MessageMeta) + msg->data_size;
Return RESULT_VALUE(u, offset_increment);
}
void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header,
u64 channel_id, u16 data_size)
i64 channel_id, u16 data_size)
{
_PacketHeader_construct(SendMessageRequest);
zeroStruct(ptr);
@@ -152,30 +210,28 @@ void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header,
}
void SendMessageResponse_construct(SendMessageResponse *ptr, PacketHeader *header,
u64 message_id, DateTime receiving_time_utc)
i64 message_id, DateTime timestamp)
{
_PacketHeader_construct(SendMessageResponse);
zeroStruct(ptr);
ptr->message_id = message_id;
ptr->receiving_time_utc = receiving_time_utc;
ptr->timestamp = timestamp;
}
void GetMessageBlockRequest_construct(GetMessageBlockRequest *ptr, PacketHeader *header,
u64 channel_id, u64 first_message_id, u32 messages_count)
i64 channel_id, i64 first_message_id, u32 message_count)
{
_PacketHeader_construct(GetMessageBlockRequest);
zeroStruct(ptr);
ptr->channel_id = channel_id;
ptr->first_message_id = first_message_id;
ptr->messages_count = messages_count;
ptr->message_count = message_count;
}
void GetMessageBlockResponse_construct(GetMessageBlockResponse *ptr, PacketHeader *header,
u64 first_message_id, u32 messages_count, u32 data_size)
MessageBlockMeta* block_meta)
{
_PacketHeader_construct(GetMessageBlockResponse);
zeroStruct(ptr);
ptr->first_message_id = first_message_id;
ptr->messages_count = messages_count;
ptr->data_size = data_size;
memcpy(&ptr->block_meta, block_meta, sizeof(MessageBlockMeta));
}

View File

@@ -60,11 +60,11 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he
typedef struct ServerHandshake {
u64 session_id;
i64 session_id;
} ALIGN_PACKET_STRUCT ServerHandshake;
void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header,
u64 session_id);
i64 session_id);
typedef enum ServerPublicInfo {
@@ -99,12 +99,12 @@ Result(void) LoginRequest_tryConstruct(LoginRequest* ptr, PacketHeader* header,
typedef struct LoginResponse {
u64 user_id;
u64 landing_channel_id;
i64 user_id;
i64 landing_channel_id;
} ALIGN_PACKET_STRUCT LoginResponse;
void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header,
u64 user_id, u64 landing_channel_id);
i64 user_id, i64 landing_channel_id);
typedef struct RegisterRequest {
@@ -117,49 +117,80 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest* ptr, PacketHeader* he
typedef struct RegisterResponse {
u64 user_id;
i64 user_id;
} ALIGN_PACKET_STRUCT RegisterResponse;
void RegisterResponse_construct(RegisterResponse* ptr, PacketHeader* header,
u64 user_id);
i64 user_id);
typedef struct MessageMeta {
Magic32 magic;
u16 data_size;
i64 id;
i64 sender_id;
DateTime timestamp; /* UTC */
} ALIGN_PACKET_STRUCT MessageMeta;
#define MESSAGE_MAGIC ((Magic32){ .bytes = { 'M', 's', 'g', 'S' } })
typedef struct MessageBlockMeta {
i64 first_message_id;
u32 message_count;
u32 data_size;
} ALIGN_PACKET_STRUCT MessageBlockMeta;
/// @brief write msg_meta and msg_meta->data_size bytes from msg_content to buffer
/// @param block_meta set to {0} if block is empty yet
/// @param block_free_part .data and .len are adjusted to point to free part
/// @return amount of bytes written to block, may be 0 if msg_meta and msg_content don't fit
Result(u32) MessageBlock_writeMessage(
MessageMeta* msg_meta, Array(u8) msg_content,
MessageBlockMeta* block_meta, Array(u8)* block_free_part);
/// @brief read message meta and content from buffer
/// @param block_unread_part .data and .len are adjusted to point to unread part
/// @param msg_content .len must be >= MESSAGE_SIZE_MAX
/// @return amount of bytes read from block, may be 0 if it doesn't start with MESSAGE_MAGIC
Result(u32) MessageBlock_readMessage(
Array(u8)* block_unread_part,
MessageMeta* msg_meta, Array(u8) msg_content);
typedef struct SendMessageRequest {
u64 channel_id;
i64 channel_id;
u16 data_size;
/* stream of size data_size */
} ALIGN_PACKET_STRUCT SendMessageRequest;
void SendMessageRequest_construct(SendMessageRequest* ptr, PacketHeader* header,
u64 channel_id, u16 data_size);
i64 channel_id, u16 data_size);
typedef struct SendMessageResponse {
u64 message_id;
DateTime receiving_time_utc;
i64 message_id;
DateTime timestamp; /* UTC */
} ALIGN_PACKET_STRUCT SendMessageResponse;
void SendMessageResponse_construct(SendMessageResponse* ptr, PacketHeader* header,
u64 message_id, DateTime receiving_time_utc);
i64 message_id, DateTime timestamp);
typedef struct GetMessageBlockRequest {
u64 channel_id;
u64 first_message_id;
u32 messages_count;
i64 channel_id;
i64 first_message_id;
u32 message_count;
} ALIGN_PACKET_STRUCT GetMessageBlockRequest;
void GetMessageBlockRequest_construct(GetMessageBlockRequest* ptr, PacketHeader* header,
u64 channel_id, u64 first_message_id, u32 messages_count);
i64 channel_id, i64 first_message_id, u32 message_count);
typedef struct GetMessageBlockResponse {
u64 first_message_id;
u32 messages_count;
u32 data_size;
/* stream of size data_size : ((sequence MessageMeta), (sequence binary-data)) */
MessageBlockMeta block_meta;
/* stream of size data_size : sequence (MessageMeta, byte[MessageMeta.data_size]) */
} ALIGN_PACKET_STRUCT GetMessageBlockResponse;
void GetMessageBlockResponse_construct(GetMessageBlockResponse* ptr, PacketHeader* header,
u64 first_message_id, u32 messages_count, u32 data_size);
MessageBlockMeta* block_meta);

View File

@@ -1,186 +0,0 @@
#include "server_internal.h"
#include "tlibc/string/StringBuilder.h"
#include "tlibc/filesystem.h"
void Channel_free(Channel* self){
if(!self)
return;
str_destroy(self->name);
str_destroy(self->description);
List_MessageBlockMeta_destroy(&self->messages.blocks_meta_list);
LList_MessageBlock_destroy(&self->messages.blocks_queue);
free(self);
}
Result(Channel*) Channel_create(u64 chan_id, str name, str description,
IncrementalDB* db, bool lock_db)
{
Deferral(8);
Channel* self = (Channel*)malloc(sizeof(Channel));
zeroStruct(self);
bool success = false;
Defer(if(!success) Channel_free(self));
self->id = chan_id;
try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX);
self->name = str_copy(name);
try_assert(description.len <= CHANNEL_DESC_SIZE_MAX);
self->description = str_copy(description);
if(lock_db){
idb_lockDB(db);
Defer(idb_unlockDB(db));
}
StringBuilder sb = StringBuilder_alloc(CHANNEL_NAME_SIZE_MAX + 32 + 1);
Defer(StringBuilder_destroy(&sb));
StringBuilder_append_str(&sb, STR("channels"));
StringBuilder_append_char(&sb, path_sep);
StringBuilder_append_str(&sb, name);
str subdir = str_copy(StringBuilder_getStr(&sb));
Defer(str_destroy(subdir));
str message_blocks_str = STR("message_blocks");
str message_blocks_meta_str = STR("message_blocks_meta");
StringBuilder_removeFromEnd(&sb, -1);
StringBuilder_append_str(&sb, name);
StringBuilder_append_char(&sb, '_');
StringBuilder_append_str(&sb, message_blocks_str);
try(self->messages.blocks_table, p,
idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false)
);
idb_lockTable(self->messages.blocks_table);
Defer(idb_unlockTable(self->messages.blocks_table));
StringBuilder_removeFromEnd(&sb, message_blocks_str.len);
StringBuilder_append_str(&sb, message_blocks_meta_str);
try(self->messages.blocks_meta_table, p,
idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlockMeta), false)
);
idb_lockTable(self->messages.blocks_meta_table);
Defer(idb_unlockTable(self->messages.blocks_meta_table));
// load whole message_blocks_meta table to list
try_void(
idb_createListFromTable(self->messages.blocks_meta_table, (void*)&self->messages.blocks_meta_list, false)
);
// load N last blocks to the queue
self->messages.blocks_queue = LList_construct(MessageBlock, NULL);
u64 message_blocks_count = self->messages.blocks_meta_list.len;
u64 first_block_id = 0;
if(message_blocks_count > MESSAGE_BLOCKS_CACHE_COUNT)
first_block_id = message_blocks_count - MESSAGE_BLOCKS_CACHE_COUNT;
for(u64 id = first_block_id; id < message_blocks_count; id++){
LLNode(MessageBlock)* node = LLNode_MessageBlock_createZero();
LList_MessageBlock_insertAfter(
&self->messages.blocks_queue,
self->messages.blocks_queue.last,
node
);
try_void(idb_getRow(self->messages.blocks_table, id, node->value.data, false));
}
if(self->messages.blocks_meta_list.len > 0){
MessageBlockMeta last_block_meta = self->messages.blocks_meta_list.data[self->messages.blocks_meta_list.len - 1];
self->messages.count = last_block_meta.first_message_id + last_block_meta.messages_count - 1;
}
else {
self->messages.count = 0;
}
success = true;
Return RESULT_VALUE(p, self);
}
void Channel_unloadExcessBlocks(Channel* self){
while(self->messages.blocks_queue.count > MESSAGE_BLOCKS_CACHE_COUNT){
LLNode(MessageBlock)* node = self->messages.blocks_queue.first;
LList_MessageBlock_detatch(&self->messages.blocks_queue, node);
free(node);
}
}
Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id,
MessageMeta* out_message_meta, bool lock_tables)
{
Deferral(4);
if(lock_tables){
idb_lockTable(self->messages.blocks_table);
idb_lockTable(self->messages.blocks_meta_table);
Defer(
idb_unlockTable(self->messages.blocks_table);
idb_unlockTable(self->messages.blocks_meta_table);
);
}
// create new block if message won't fit in the last existing
MessageBlockMeta* incomplete_block_meta = self->messages.blocks_meta_list.data + self->messages.blocks_meta_list.len;
u64 new_message_id = incomplete_block_meta->first_message_id + incomplete_block_meta->messages_count;
u32 message_size_in_block = sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8);
if(incomplete_block_meta->data_size + message_size_in_block > MESSAGE_BLOCK_SIZE){
// create new MessageBlockMeta
incomplete_block_meta = List_MessageBlockMeta_expand(&self->messages.blocks_meta_list, 1);
incomplete_block_meta->first_message_id = new_message_id;
incomplete_block_meta->messages_count = 0;
incomplete_block_meta->data_size = 0;
// create new MessageBlock
LList_MessageBlock_insertAfter(
&self->messages.blocks_queue,
self->messages.blocks_queue.last,
LLNode_MessageBlock_createZero());
// unload old blocks from cache
Channel_unloadExcessBlocks(self);
}
// create message meta
out_message_meta->magic = MESSAGE_MAGIC;
out_message_meta->data_size = message_data.len;
out_message_meta->id = new_message_id;
out_message_meta->sender_id = sender_id;
DateTime_getUTC(&out_message_meta->receiving_time_utc);
// copy message data to message block
MessageBlock* incomplete_block = &self->messages.blocks_queue.last->value;
u8* data_ptr = incomplete_block->data + incomplete_block_meta->data_size;
memcpy(data_ptr, out_message_meta, sizeof(MessageMeta));
data_ptr += sizeof(MessageMeta);
memcpy(data_ptr, message_data.data, message_data.len);
incomplete_block_meta->data_size += sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8);
incomplete_block_meta->messages_count++;
// save to DB
try_void(idb_pushRow(self->messages.blocks_meta_table, incomplete_block_meta, false));
try_void(idb_pushRow(self->messages.blocks_table, incomplete_block, false));
Return RESULT_VOID;
}
Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count,
MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables)
{
Deferral(4);
if(lock_tables){
idb_lockTable(self->messages.blocks_table);
idb_lockTable(self->messages.blocks_meta_table);
Defer(
idb_unlockTable(self->messages.blocks_table);
idb_unlockTable(self->messages.blocks_meta_table);
);
}
// TODO: Maybe it's better to request message block id directly? Client doesn't know how much bytes `count` messages will take, this can lead to severe lags on slow internet
// TODO: binary search in list of blocks meta
// TODO: return if out_block == NULL
// TODO: check if block is in N_LAST_BLOCKS
// TODO: load block
// TODO: insert block in queue and keep it sorted
Return RESULT_VOID;
}

View File

@@ -1,11 +1,14 @@
#include "server/server_internal.h"
#include "network/tcp-chat-protocol/v1.h"
void ClientConnection_close(ClientConnection* conn){
if(!conn)
return;
tsqlite_connection_close(conn->db);
EncryptedSocketTCP_destroy(&conn->sock);
Array_u8_destroy(&conn->session_key);
Array_u8_destroy(&conn->message_block);
Array_u8_destroy(&conn->message_content);
CommonQueries_free(conn->queries.common);
free(conn);
}
@@ -21,10 +24,17 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args)
conn->server = args->server;
conn->client_end = args->client_end;
conn->session_id = args->session_id;
conn->authorized = false;
conn->user_id = -1;
conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE);
// buffers
conn->message_block = Array_u8_alloc(MESSAGE_BLOCK_COUNT_MAX * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX));
conn->message_content = Array_u8_alloc(MESSAGE_SIZE_MAX);
// database
try(conn->db, p, tsqlite_connection_open(args->server->db_path));
try(conn->queries.common, p, CommonQueries_compile(conn->db));
// correct session key will be received from client later
conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE);
Array_u8_memset(&conn->session_key, 0);
EncryptedSocketTCP_construct(&conn->sock, args->accepted_socket_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
try_void(socket_TCP_enableAliveChecks_default(args->accepted_socket_tcp));
@@ -57,4 +67,4 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args)
success = true;
Return RESULT_VALUE(p, conn);
}
}

116
src/server/db/Channel.c Normal file
View File

@@ -0,0 +1,116 @@
#include "db_internal.h"
Result(bool) Channel_exists(CommonQueries* q, i64 id){
Deferral(1);
tsqlite_statement* st = q->channels.exists;
try_void(tsqlite_statement_reset(st));
try_void(tsqlite_statement_bind_i64(st, "id", id));
try(bool has_result, i, tsqlite_statement_step(st));
Return RESULT_VALUE(i, has_result);
}
Result(void) Channel_createOrUpdate(CommonQueries* q,
i64 id, str name, str description)
{
Deferral(4);
try_assert(id > 0);
try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX);
try_assert(description.len <= CHANNEL_DESC_SIZE_MAX);
// create channels table
try_void(tsqlite_statement_reset(q->channels.create_table));
try_void(tsqlite_statement_bind_i64(q->channels.create_table, "name_max", CHANNEL_NAME_SIZE_MAX));
try_void(tsqlite_statement_bind_i64(q->channels.create_table, "desc_max", CHANNEL_DESC_SIZE_MAX));
try_void(tsqlite_statement_step(q->channels.create_table));
// create messages table
try_void(tsqlite_statement_reset(q->messages.create_table));
try_void(tsqlite_statement_step(q->messages.create_table));
try(bool channel_exists, i, Channel_exists(q, id));
if(channel_exists){
// update existing channel
try_void(tsqlite_statement_reset(q->channels.update));
try_void(tsqlite_statement_bind_i64(q->channels.update, "id", id));
try_void(tsqlite_statement_bind_str(q->channels.update, "name", str_copy(name), free));
try_void(tsqlite_statement_bind_str(q->channels.update, "description", str_copy(description), free));
try_void(tsqlite_statement_step(q->channels.update));
}
else {
// insert new channel
try_void(tsqlite_statement_reset(q->channels.insert));
try_void(tsqlite_statement_bind_i64(q->channels.insert, "id", id));
try_void(tsqlite_statement_bind_str(q->channels.insert, "name", str_copy(name), free));
try_void(tsqlite_statement_bind_str(q->channels.insert, "description", str_copy(description), free));
try_void(tsqlite_statement_step(q->channels.insert));
}
Return RESULT_VOID;
}
Result(void) Channel_saveMessage(CommonQueries* q,
i64 channel_id, i64 sender_id, Array(u8) content,
DateTime* out_timestamp)
{
Deferral(1);
try_assert(content.len >= MESSAGE_SIZE_MIN && content.len <= MESSAGE_SIZE_MAX);
tsqlite_statement* st = q->messages.insert;
try_void(tsqlite_statement_reset(st));
try_void(tsqlite_statement_bind_i64(st, "channel_id", channel_id));
try_void(tsqlite_statement_bind_i64(st, "sender_id", sender_id));
try_void(tsqlite_statement_bind_blob(st, "content", Array_u8_copy(content), free));
try(bool has_result, i, tsqlite_statement_step(st));
try_assert(has_result);
try(i64 message_id, i, tsqlite_statement_getResult_i64(st));
str timestamp_str;
try_void(tsqlite_statement_getResult_str(st, &timestamp_str));
try_void(DateTime_parse(timestamp_str.data, out_timestamp));
Return RESULT_VALUE(i, message_id);
}
Result(void) Channel_loadMessageBlock(CommonQueries* q,
i64 channel_id, i64 first_message_id, u32 count,
MessageBlockMeta* block_meta, Array(u8) block_data)
{
Deferral(1);
try_assert(channel_id > 0);
try_assert(block_data.len >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX));
if(count == 0){
Return RESULT_VOID;
}
tsqlite_statement* st = q->messages.get_block;
try_void(tsqlite_statement_reset(st));
try_void(tsqlite_statement_bind_i64(st, "channel_id", channel_id));
try_void(tsqlite_statement_bind_i64(st, "first_message_id", first_message_id));
try_void(tsqlite_statement_bind_i64(st, "count", count));
zeroStruct(block_meta);
MessageMeta msg_meta = {0};
Array(u8) msg_content;
str tmp_str = str_null;
while(true){
try(bool has_result, i, tsqlite_statement_step(st));
if(!has_result)
break;
// id
try(msg_meta.id, i, tsqlite_statement_getResult_i64(st));
// sender_id
try(msg_meta.sender_id, i, tsqlite_statement_getResult_i64(st));
// content
try_void(tsqlite_statement_getResult_blob(st, &msg_content));
// timestamp
try_void(tsqlite_statement_getResult_str(st, &tmp_str));
try_void(DateTime_parse(tmp_str.data, &msg_meta.timestamp));
try(u32 write_n, u, MessageBlock_writeMessage(&msg_meta, msg_content, block_meta, &block_data));
try_assert(write_n > 0);
}
Return RESULT_VOID;
}

View File

@@ -0,0 +1,82 @@
#include "db_internal.h"
void CommonQueries_free(CommonQueries* q){
if(!q)
return;
tsqlite_statement_free(q->channels.create_table);
tsqlite_statement_free(q->channels.insert);
tsqlite_statement_free(q->channels.exists);
tsqlite_statement_free(q->channels.update);
tsqlite_statement_free(q->messages.create_table);
tsqlite_statement_free(q->messages.insert);
tsqlite_statement_free(q->messages.get_block);
free(q);
}
Result(void) CommonQueries_compile(tsqlite_connection* db){
Deferral(4);
CommonQueries* q = (CommonQueries*)malloc(sizeof(*q));
zeroStruct(q);
bool success = false;
Defer(if(!success) CommonQueries_free(q));
///////////////////////////////////////////////////////////////////////////
// CHANNELS //
///////////////////////////////////////////////////////////////////////////
try(q->channels.create_table, p, tsqlite_statement_compile(db, STR(
"CREATE TABLE IF NOT EXISTS channels (\n"
" id BIGINT PRIMARY KEY,\n"
" name VARCHAR($name_max) NOT NULL,\n"
" description VARCHAR($desc_max) NOT NULL\n"
");"
)));
try(q->channels.insert, p, tsqlite_statement_compile(db, STR(
"INSERT INTO\n"
"channels (id, name, description)\n"
"VALUES ($id, $name, $description);"
)));
try(q->channels.exists, p, tsqlite_statement_compile(db, STR(
"SELECT 1 FROM channels WHERE id = $id;"
)));
try(q->channels.update, p, tsqlite_statement_compile(db, STR(
"UPDATE channels\n"
"SET name = $name, description = $description\n"
"WHERE id = $id;"
)));
///////////////////////////////////////////////////////////////////////////
// MESSAGES //
///////////////////////////////////////////////////////////////////////////
try(q->messages.create_table, p, tsqlite_statement_compile(db, STR(
"CREATE TABLE IF NOT EXISTS messages (\n"
" id BIGINT PRIMARY KEY,\n"
" channel_id BIGINT NOT NULL REFERENCES channels(id)\n"
" sender_id BIGINT NOT NULL REFERENCES users(id),\n"
" content BLOB NOT NULL,\n"
" timestamp DATETIME NOT NULL DEFAULT (\n"
" strftime('"MESSAGE_TIMESTAMP_FMT_SQL"', 'now', 'utc', 'subsecond')\n"
" )\n"
");"
)));
try(q->messages.insert, p, tsqlite_statement_compile(db, STR(
"INSERT INTO\n"
"messages (channel_id, sender_id, content)\n"
"VALUES ($channel_id, $sender_id, $content)\n"
"RETURNING id, timestamp;"
)));
try(q->messages.get_block, p, tsqlite_statement_compile(db, STR(
"SELECT id, sender_id, content, timestamp FROM messages\n"
"WHERE id >= $first_message_id\n"
"AND channel_id = $channel_id\n"
"LIMIT $count;"
)));
success = true;
Return RESULT_VALUE(p, q);
}

3
src/server/db/User.c Normal file
View File

@@ -0,0 +1,3 @@
#include "db.h"

39
src/server/db/db.h Normal file
View File

@@ -0,0 +1,39 @@
#pragma once
#include "tsqlite.h"
#include "network/tcp-chat-protocol/v1.h"
// typedef struct ChannelInfo {
// i64 id;
// str name;
// str description;
// } ChannelInfo;
typedef struct CommonQueries CommonQueries;
Result(CommonQueries*) CommonQueries_compile(tsqlite_connection* db);
void CommonQueries_free(CommonQueries* self);
Result(bool) Channel_exists(CommonQueries* q, i64 id);
Result(void) Channel_createOrUpdate(CommonQueries* q,
i64 id, str name, str description);
/// @return new message id
Result(i64) Channel_saveMessage(CommonQueries* q,
i64 channel_id, i64 sender_id, Array(u8) content,
DateTime* out_timestamp_utc);
/// @brief try to find `count` messages starting from `first_message_id`
/// @param out_meta writes here information about found messages, .count can be 0 if no messages found
/// @param out_block .len must be >= count * (sizeof(MessageMeta) + MESSAGE_SIZE_MAX)
Result(void) Channel_loadMessageBlock(CommonQueries* q,
i64 channel_id, i64 first_message_id, u32 count,
MessageBlockMeta* out_block_meta, Array(u8) out_block_data);
/// @return existing user id or 0
Result(i64) User_getIdForUsername(CommonQueries* q, str username);
/// @return new user id
Result(i64) User_register(CommonQueries* q, str username, Array(u8) token);
Result(bool) User_tryAuthorize(CommonQueries* q, u64 id, Array(u8) token);

View File

@@ -0,0 +1,33 @@
#pragma once
#include "db.h"
typedef struct CommonQueries {
struct {
/* (name_max, desc_max) -> void */
tsqlite_statement* create_table;
/* (id, name, description) -> void */
tsqlite_statement* insert;
/* (id) -> bool */
tsqlite_statement* exists;
/* (id, name, description) -> void */
tsqlite_statement* update;
} channels;
struct {
/* () -> void */
tsqlite_statement* create_table;
/* (channel_id, sender_id, content) -> (id, timestamp) */
tsqlite_statement* insert;
/* (channel_id, first_message_id, count) -> [(id, sender_id, content, timestamp)] */
tsqlite_statement* get_block;
} messages;
struct {
tsqlite_statement* create_table;
tsqlite_statement* registration_begin;
tsqlite_statement* registration_end;
tsqlite_statement* get_credentials;
tsqlite_statement* get_public_info;
tsqlite_statement* update;
} users;
} CommonQueries;
#define MESSAGE_TIMESTAMP_FMT_SQL "%Y.%m.%d-%H:%M:%f"

View File

@@ -20,28 +20,38 @@ declare_RequestHandler(GetMessageBlock)
LogSeverity_Warn, STR("not authorized") ));
Return RESULT_VOID;
}
// validate message_count
if(req.message_count < 1 || req.message_count > MESSAGE_BLOCK_COUNT_MAX){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, STR("invalid message count in request") ));
Return RESULT_VOID;
}
// get message block from channel
Channel* ch = Server_tryGetChannel(conn->server, req.channel_id);
if(ch == NULL){
// validate channel id
try(bool channel_exists, i, Channel_exists(conn->queries.common, req.channel_id));
if(!channel_exists){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, STR("invalid channel id") ));
Return RESULT_VOID;
}
MessageBlockMeta meta;
Array(u8) block_data;
try_void(Channel_loadMessageBlock(ch, req.first_message_id, req.messages_count,
&meta, &block_data, true));
Defer(Array_u8_destroy(&block_data));
// reset block meta
zeroStruct(&conn->message_block_meta);
// get message block from channel
try_void(Channel_loadMessageBlock(conn->queries.common,
req.channel_id, req.first_message_id, req.message_count,
&conn->message_block_meta, conn->message_block));
// send response
GetMessageBlockResponse res;
GetMessageBlockResponse_construct(&res, res_head,
meta.first_message_id, meta.messages_count, meta.data_size);
GetMessageBlockResponse_construct(&res, res_head, &conn->message_block_meta);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));
if(block_data.len != 0){
try_void(EncryptedSocketTCP_send(&conn->sock, block_data));
if(conn->message_block_meta.data_size != 0){
try_void(EncryptedSocketTCP_send(&conn->sock,
Array_u8_sliceTo(conn->message_block, conn->message_block_meta.data_size))
);
}
Return RESULT_VOID;

View File

@@ -22,8 +22,8 @@ declare_RequestHandler(Login)
}
// validate username
str username_str = str_null;
str name_error_str = validateUsername_cstr(req.username, &username_str);
str username = str_null;
str name_error_str = validateUsername_cstr(req.username, &username);
if(name_error_str.data){
Defer(str_destroy(name_error_str));
try_void(sendErrorMessage(log_ctx, conn, res_head,
@@ -31,42 +31,28 @@ declare_RequestHandler(Login)
Return RESULT_VOID;
}
// lock users cache
idb_lockTable(srv->users.table);
bool unlocked_users_cache_mutex = false;
Defer(
if(!unlocked_users_cache_mutex)
idb_unlockTable(srv->users.table)
);
// try get id from name cache
u64* id_ptr = HashMap_tryGetPtr(&srv->users.name_id_map, username_str);
if(id_ptr == NULL){
// get user by id
try(u64 user_id, i, User_getIdForUsername(conn->queries.common, username));
if(user_id == 0){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, STR("Username is not registered") ));
Return RESULT_VOID;
}
u64 user_id = *id_ptr;
// get user by id
try_assert(user_id < srv->users.list.len);
UserInfo* u = srv->users.list.data + user_id;
// TODO: get user token
Array(u8) token = Array_u8_construct(req.token, sizeof(req.token));
try(bool authorized, i, User_tryAuthorize(conn->queries.common, user_id, token));
// validate token hash
if(memcmp(req.token, u->token, sizeof(req.token)) != 0){
if(!authorized){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, STR("wrong password") ));
Return RESULT_VOID;
}
// manually unlock mutex
idb_unlockTable(srv->users.table);
unlocked_users_cache_mutex = true;
// authorize
conn->authorized = true;
conn->user_id = user_id;
logInfo("authorized user '%s'", username_str.data);
logInfo("authorized user '%s'", username.data);
// send response
LoginResponse res;

View File

@@ -22,8 +22,8 @@ declare_RequestHandler(Register)
}
// validate username
str username_str = str_null;
str name_error_str = validateUsername_cstr(req.username, &username_str);
str username = str_null;
str name_error_str = validateUsername_cstr(req.username, &username);
if(name_error_str.data){
Defer(str_destroy(name_error_str));
try_void(sendErrorMessage(log_ctx, conn, res_head,
@@ -31,46 +31,18 @@ declare_RequestHandler(Register)
Return RESULT_VOID;
}
// lock users cache
idb_lockTable(srv->users.table);
bool unlocked_users_cache_mutex = false;
// unlock mutex on error catch
Defer(
if(!unlocked_users_cache_mutex)
idb_unlockTable(srv->users.table)
);
// check if name is taken
if(HashMap_tryGetPtr(&srv->users.name_id_map, username_str) != NULL){
try(u64 user_id, i, User_getIdForUsername(conn->queries.common, username));
if(user_id != 0){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, STR("Username already exists") ));
LogSeverity_Warn, STR("Username is already taken") ));
Return RESULT_VOID;
}
// initialize new user
UserInfo user;
zeroStruct(&user);
user.name_len = username_str.len;
memcpy(user.name, username_str.data, user.name_len);
memcpy(user.token, req.token, sizeof(req.token));
DateTime_getUTC(&user.registration_time_utc);
// save new user to db and cache
try(u64 user_id, u,
idb_pushRow(srv->users.table, &user, false)
);
try_assert(user_id == srv->users.list.len);
List_UserInfo_pushMany(&srv->users.list, &user, 1);
try_assert(HashMap_tryPush(&srv->users.name_id_map, username_str, &user_id));
// manually unlock mutex
idb_unlockTable(srv->users.table);
unlocked_users_cache_mutex = true;
logInfo("registered user '%s'", username_str.data);
// register new user
Array(u8) token = Array_u8_construct(req.token, sizeof(req.token));
try(user_id, i, User_register(conn->queries.common, username, token));
logInfo("registered user '"FMT_str"'", str_expand(username));
// send response
RegisterResponse res;

View File

@@ -21,6 +21,7 @@ declare_RequestHandler(SendMessage)
Return RESULT_VOID;
}
// validate content size
if(req.data_size < MESSAGE_SIZE_MIN || req.data_size > MESSAGE_SIZE_MAX){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, STR("invalid message size") ));
@@ -29,24 +30,26 @@ declare_RequestHandler(SendMessage)
}
// receive message data
Array(u8) message_data = Array_u8_alloc(req.data_size);
try_void(EncryptedSocketTCP_recv(&conn->sock, message_data, SocketRecvFlag_WholeBuffer));
try_void(EncryptedSocketTCP_recv(&conn->sock, conn->message_content, SocketRecvFlag_WholeBuffer));
// save message to channel
Channel* ch = Server_tryGetChannel(conn->server, req.channel_id);
if(ch == NULL){
// validate channel id
try(bool channel_exists, i, Channel_exists(conn->queries.common, req.channel_id));
if(!channel_exists){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn, STR("invalid channel id") ));
Return RESULT_VOID;
}
MessageMeta message_meta;
try_void(Channel_saveMessage(ch, message_data, conn->user_id,
&message_meta, true));
// save message to channel
DateTime timestamp;
try(i64 message_id, i, Channel_saveMessage(conn->queries.common,
req.channel_id, conn->user_id, conn->message_content,
&timestamp));
// send response
SendMessageResponse res;
SendMessageResponse_construct(&res, res_head,
message_meta.id, message_meta.receiving_time_utc);
message_id, timestamp);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));

View File

@@ -1,8 +1,6 @@
#pragma once
#include "network/tcp-chat-protocol/v1.h"
#include "server/server_internal.h"
Result(void) sendErrorMessage(
cstr log_ctx, ClientConnection* conn, PacketHeader* res_head,
LogSeverity log_severity, str msg);

View File

@@ -4,7 +4,6 @@
#include "tlibc/base64.h"
#include "tlibc/algorithms.h"
#include "server/server_internal.h"
#include "network/tcp-chat-protocol/v1.h"
#include "server/responses/responses.h"
#include "tlibtoml.h"
@@ -20,17 +19,8 @@ void Server_free(Server* self){
RSA_destroyPrivateKey(&self->rsa_sk);
RSA_destroyPublicKey(&self->rsa_pk);
idb_close(self->db);
List_UserInfo_destroy(&self->users.list);
HashMap_destroy(&self->users.name_id_map);
for(u32 i = 0; i < self->channels.list.len; i++){
Channel_free(self->channels.list.data[i]);
}
List_ChannelPtr_destroy(&self->channels.list);
HashMap_destroy(&self->channels.name_id_map);
free(self->db_path);
tsqlite_connection_close(self->db);
free(self);
}
@@ -58,78 +48,49 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name,
try(TomlTable* config_top, p, toml_load_str_filename(config_file_content, config_file_name));
Defer(TomlTable_free(config_top));
// [server]
try(TomlTable* config_server, p, TomlTable_get_table(config_top, STR("server")))
// parse name
// name
try(str* v_name, p, TomlTable_get_str(config_server, STR("name")));
self->name = str_copy(*v_name);
// parse description
// description
try(str* v_desc, p, TomlTable_get_str(config_server, STR("description")));
self->description = str_copy(*v_desc);
// parse local_address
// local_address
try(str* v_local_address, p, TomlTable_get_str(config_server, STR("local_address")));
try_assert(v_local_address->isZeroTerminated);
try_void(EndpointIPv4_parse(v_local_address->data, &self->local_end));
// parse landing_channel_id
// landing_channel_id
try(i64 v_landing_channel_id, i, TomlTable_get_integer(config_server, STR("landing_channel_id")));
self->landing_channel_id = v_landing_channel_id;
// [keys]
try(TomlTable* config_keys, p, TomlTable_get_table(config_top, STR("keys")))
// parse rsa_private_key
// rsa_private_key
try(str* v_rsa_sk, p, TomlTable_get_str(config_keys, STR("rsa_private_key")));
try_assert(v_rsa_sk->isZeroTerminated);
try_void(RSA_parsePrivateKey_base64(v_rsa_sk->data, &self->rsa_sk));
// parse rsa_public_key
// rsa_public_key
try(str* v_rsa_pk, p, TomlTable_get_str(config_keys, STR("rsa_public_key")));
try_assert(v_rsa_pk->isZeroTerminated);
try_void(RSA_parsePublicKey_base64(v_rsa_pk->data, &self->rsa_pk));
// [db]
try(TomlTable* config_db, p, TomlTable_get_table(config_top, STR("database")))
// parse db_aes_key
try(str* v_db_aes_key, p, TomlTable_get_str(config_db, STR("aes_key")));
str db_aes_key_s = *v_db_aes_key;
Array(u8) db_aes_key = Array_u8_alloc(base64_decodedSize(db_aes_key_s.data, db_aes_key_s.len));
Defer(free(db_aes_key.data));
base64_decode(db_aes_key_s.data, db_aes_key_s.len, db_aes_key.data);
// parse db_dir
try(str* v_db_dir, p, TomlTable_get_str(config_db, STR("dir")));
// path
try(str* v_db_path, p, TomlTable_get_str(config_db, STR("path")));
self->db_path = str_copy(*v_db_path).data;
// open DB
try(self->db, p, idb_open(*v_db_dir, db_aes_key));
try(self->db, p, tsqlite_connection_open(self->db_path));
// build users cache
logInfo("loading users...");
try(self->users.table, p,
idb_getOrCreateTable(self->db, str_null, STR("users"), sizeof(UserInfo), false)
);
// load whole users table to list
try_void(
idb_createListFromTable(self->users.table, (void*)&self->users.list, false)
);
// build name-id map
try(u64 users_count, u, idb_getRowCount(self->users.table, false));
HashMap_construct(&self->users.name_id_map, u64, NULL);
for(u64 id = 0; id < users_count; id++){
UserInfo* user = self->users.list.data + id;
str key = str_construct(user->name, user->name_len, true);
if(!HashMap_tryPush(&self->users.name_id_map, key, &id)){
Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", str_expand(key));
}
}
logInfo("loaded "FMT_u64" users", users_count);
// parse channels
// [channels]
logDebug("loading channels...");
HashMap_construct(&self->channels.name_id_map, u64, NULL);
self->channels.list = List_ChannelPtr_alloc(32);
try(TomlTable* config_channels, p, TomlTable_get_table(config_top, STR("channels")));
HashMapIter channels_iter = HashMapIter_create(config_channels);
while(HashMapIter_moveNext(&channels_iter)){
HashMapKeyValue kv;
@@ -142,23 +103,12 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name,
logInfo("loading channel '"FMT_str"'", str_expand(name))
TomlTable* config_channel = val->table;
try(u64 id, u, TomlTable_get_integer(config_channel, STR("id")));
try(i64 id, u, TomlTable_get_integer(config_channel, STR("id")));
try(str* v_ch_desc, p, TomlTable_get_str(config_channel, STR("description")))
str description = *v_ch_desc;
if(!HashMap_tryPush(&self->channels.name_id_map, name, &id)){
Return RESULT_ERROR_FMT("duplicate channel '"FMT_str"'", str_expand(name));
}
logDebug("loading messages...");
try(Channel* channel, p, Channel_create(id, name, description, self->db, false));
logDebug("loaded "FMT_u64" messages", channel->messages.count);
List_ChannelPtr_push(&self->channels.list, channel);
try_void(Channel_createOrUpdate(self->server_queries, id, name, description));
}
logDebug("loaded "FMT_u32" channels", self->channels.list.len);
// sort channels list by id
for(u32 i = 0; i < self->channels.list.len; i++);
insertionSort_inline(self->channels.list.data, self->channels.list.len, ->id);
success = true;
Return RESULT_VALUE(p, self);
@@ -184,7 +134,7 @@ Result(void) Server_run(Server* server){
Defer(free(local_end_str.data));
logInfo("server is listening at %s", local_end_str.data);
u64 session_id = 1;
i64 session_id = 1;
while(true){
ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)malloc(sizeof(ConnectionHandlerArgs));
args->server = server;
@@ -201,16 +151,6 @@ Result(void) Server_run(Server* server){
Return RESULT_VOID;
}
Channel* Server_tryGetChannel(Server* self, u64 id){
i32 index;
binarySearch_inline(self->channels.list.data, self->channels.list.len, id, ->id, index);
if(index == -1){
return NULL;
}
Channel* channel = self->channels.list.data[index];
return channel;
}
static void* handleConnection(void* _args){
ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args;
Server* server = args->server;

View File

@@ -1,56 +1,14 @@
#pragma once
#include "tlibc/collections/HashMap.h"
#include "tlibc/collections/List.h"
#include "tlibc/collections/LList.h"
#include "tcp-chat/tcp-chat.h"
#include "tcp-chat/server.h"
#include "cryptography/AES.h"
#include "cryptography/RSA.h"
#include "network/encrypted_sockets.h"
#include "db/idb.h"
#include "db/tables.h"
#include "network/tcp-chat-protocol/v1.h"
#include "server/db/db.h"
typedef struct ClientConnection ClientConnection;
List_declare(UserInfo);
List_declare(MessageBlockMeta);
LList_declare(MessageBlock);
#define MESSAGE_BLOCKS_CACHE_COUNT 50
typedef struct Channel {
u64 id;
str name;
str description;
struct {
u64 count;
Table* blocks_table;
Table* blocks_meta_table;
List(MessageBlockMeta) blocks_meta_list; // index is id
// last MESSAGE_BLOCKS_CACHE_COUNT MessageBlocks, ascending
// new messages are written to the last block
LList(MessageBlock) blocks_queue;
} messages;
} Channel;
typedef Channel* ChannelPtr;
List_declare(ChannelPtr);
Result(Channel*) Channel_create(u64 id, str name, str description,
IncrementalDB* db, bool lock_db);
void Channel_free(Channel* self);
Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id,
MessageMeta* out_message_meta, bool lock_tables);
/// @brief try to find `count` messages starting from `fisrt_message_id`
/// @param out_meta information about found messages, .count can be 0 if no messages found
/// @param out_block allocates buffer on heap and copies them there, .len can be 0 if no messages found
Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count,
MessageBlockMeta* out_meta, NULLABLE(Array(u8)*) out_block, bool lock_tables);
typedef struct Server {
/* from constructor */
void* logger;
@@ -59,42 +17,43 @@ typedef struct Server {
/* from config */
str name;
str description;
u64 landing_channel_id;
i64 landing_channel_id;
EndpointIPv4 local_end;
br_rsa_private_key rsa_sk;
br_rsa_public_key rsa_pk;
/* database and cache */
IncrementalDB* db;
struct {
Table* table;
List(UserInfo) list; // index is id
HashMap(u64) name_id_map;
} users;
struct {
List(ChannelPtr) list;
HashMap(u64) name_id_map;
} channels;
/* database and cache*/
char* db_path;
tsqlite_connection* db;
} Server;
Channel* Server_tryGetChannel(Server* self, u64 id);
typedef struct ClientConnection {
Server* server;
u64 session_id;
i64 session_id;
EndpointIPv4 client_end;
Array(u8) session_key;
EncryptedSocketTCP sock;
bool authorized;
u64 user_id; // -1 for unauthorized
i64 user_id; // 0 for unauthorized
/* buffers */
MessageBlockMeta message_block_meta;
Array(u8) message_block;
Array(u8) message_content;
/* database */
tsqlite_connection* db;
struct {
CommonQueries* common;
} queries;
} ClientConnection;
typedef struct ConnectionHandlerArgs {
Server* server;
Socket accepted_socket_tcp;
EndpointIPv4 client_end;
u64 session_id;
i64 session_id;
} ConnectionHandlerArgs;
Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args);