implemented channels on server

This commit is contained in:
2025-12-02 20:28:56 +05:00
parent c263d02b36
commit 6d1f450f32
26 changed files with 451 additions and 147 deletions

View File

@@ -43,7 +43,7 @@ void ClientCLI_destroy(ClientCLI* self){
HashMap_destroy(&self->servers.addr_id_map);
}
void ClientCLI_construct(ClientCLI* self){
memset(self, 0, sizeof(*self));
zeroStruct(self);
}
Result(void) ClientCLI_run(ClientCLI* self) {
@@ -326,7 +326,7 @@ static Result(void) ClientCLI_openUserDB(ClientCLI* self){
// Load servers table
try(self->servers.table, p,
idb_getOrCreateTable(self->db, STR("servers"), sizeof(ServerInfo), false)
idb_getOrCreateTable(self->db, str_null, STR("servers"), sizeof(ServerInfo), false)
);
// Lock table until this function returns.
@@ -362,7 +362,7 @@ static Result(ServerInfo*) ClientCLI_saveServerInfo(ClientCLI* self,
// create new server info
ServerInfo server;
memset(&server, 0, sizeof(ServerInfo));
zeroStruct(&server);
// address
if(addr.len > HOSTADDR_SIZE_MAX)
addr.len = HOSTADDR_SIZE_MAX;

View File

@@ -18,7 +18,7 @@ Result(ServerConnection*) ServerConnection_open(Client* client, cstr server_addr
Deferral(16);
ServerConnection* conn = (ServerConnection*)malloc(sizeof(ServerConnection));
memset(conn, 0, sizeof(*conn));
zeroStruct(conn);
bool success = false;
Defer(if(!success) ServerConnection_close(conn));

View File

@@ -15,7 +15,7 @@ Result(Client*) Client_create(str username, str password){
Deferral(16);
Client* self = (Client*)malloc(sizeof(Client));
memset(self, 0, sizeof(Client));
zeroStruct(self);
bool success = false;
Defer(if(!success) Client_free(self));

View File

@@ -55,7 +55,7 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr,
__Array_writeNext(&dst, ptr->iv, __AES_BLOCK_IV_SIZE);
EncryptedBlockHeader header;
memset(&header, 0, sizeof(header));
zeroStruct(&header);
memcpy(header.key_checksum, ptr->key_checksum, __AES_BLOCK_KEY_CHECKSUM_SIZE);
header.padding_size = (16 - src.len % 16) % 16;
// write header to buffer

View File

@@ -2,6 +2,7 @@
#include "tlibc/magic.h"
#include "tlibc/filesystem.h"
#include "tlibc/collections/HashMap.h"
#include "tlibc/string/StringBuilder.h"
#include "cryptography/AES.h"
#include <pthread.h>
@@ -216,11 +217,12 @@ void idb_close(IncrementalDB* db){
Result(IncrementalDB*) idb_open(str db_dir, NULLABLE(Array(u8) aes_key)){
Deferral(16);
try_assert(db_dir.len > 0);
try_assert(aes_key.len == 0 || aes_key.len == 16 || aes_key.len == 24 || aes_key.len == 32);
IncrementalDB* db = (IncrementalDB*)malloc(sizeof(IncrementalDB));
// value of *db must be set to zero or behavior of idb_close will be undefined
memset(db, 0, sizeof(IncrementalDB));
zeroStruct(db);
// if object construction fails, destroy incomplete object
bool success = false;
Defer(if(!success) idb_close(db));
@@ -255,7 +257,7 @@ void idb_unlockTable(Table* t){
}
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size, bool lock_db){
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str subdir, str table_name, u32 row_size, bool lock_db){
Deferral(16);
if(lock_db){
@@ -274,7 +276,7 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_s
Table* t = (Table*)malloc(sizeof(Table));
// value of *t must be set to zero or behavior of Table_close will be undefined
memset(t, 0, sizeof(Table));
zeroStruct(t);
// if object construction fails, destroy incomplete object
bool success = false;
Defer(if(!success) Table_close(t));
@@ -282,10 +284,27 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_s
t->db = db;
try_stderrcode(pthread_mutex_init(&t->mutex, NULL));
t->name = str_copy(table_name);
t->table_file_path = str_from_cstr(
strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-table"));
t->changes_file_path = str_from_cstr(
strcat_malloc(db->db_dir.data, path_seps, t->name.data, ".idb-changes"));
// set file paths
StringBuilder sb = StringBuilder_alloc(256);
Defer(StringBuilder_destroy(&sb));
StringBuilder_append_str(&sb, db->db_dir);
StringBuilder_append_char(&sb, path_sep);
if(subdir.len != 0){
StringBuilder_append_str(&sb, subdir);
try_void(dir_create(StringBuilder_getStr(&sb).data));
StringBuilder_append_char(&sb, path_sep);
}
StringBuilder_append_str(&sb, t->name);
// table file
str table_file_ext = STR(".idb-table");
StringBuilder_append_str(&sb, table_file_ext);
t->table_file_path = str_copy(StringBuilder_getStr(&sb));
// changes file
str changes_file_ext = STR(".idb-changes");
StringBuilder_removeFromEnd(&sb, table_file_ext.len);
StringBuilder_append_str(&sb, changes_file_ext);
t->changes_file_path = str_copy(StringBuilder_getStr(&sb));
bool table_file_exists = file_exists(t->table_file_path.data);

View File

@@ -30,7 +30,7 @@ void idb_lockTable(Table* t);
void idb_unlockTable(Table* t);
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str table_name, u32 row_size, bool lock_db);
Result(Table*) idb_getOrCreateTable(IncrementalDB* db, NULLABLE(str) subdir, str table_name, u32 row_size, bool lock_db);
Result(void) idb_getRows(Table* t, u64 start_from_id, void* dst, u64 count, bool lock_table);
#define idb_getRow(T, ID, DST, LOCK) idb_getRows(T, ID, DST, 1, LOCK)

View File

@@ -32,6 +32,7 @@ typedef struct ChannelInfo {
char desc[CHANNEL_DESC_SIZE_MAX + 1];
} ATTRIBUTE_ALIGNED(4*1024) ChannelInfo;
// not a table
typedef struct MessageMeta {
/*
@@ -50,12 +51,13 @@ typedef struct MessageMeta {
// Stores some number of messages. Look in MessageBlockMeta to see how much.
typedef struct MessageBlock {
/* ((sequence MessageMeta), (sequence binary-data)) */
u8 data[64*1024 - 4];
u8 data[MESSAGE_BLOCK_SIZE];
} ATTRIBUTE_ALIGNED(64) MessageBlock;
// is used to find in which MessageBlock a message is stored
typedef struct MessageBlockMeta {
u64 message_id_first;
u64 first_message_id;
u32 messages_count;
u32 data_size;
} ATTRIBUTE_ALIGNED(16) MessageBlockMeta;

View File

@@ -29,7 +29,7 @@ Result(void) PacketHeader_validateContentSize(PacketHeader* ptr, u64 expected_si
}
void PacketHeader_construct(PacketHeader* ptr, u8 protocol_version, u16 type, u64 content_size){
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
ptr->magic.n = PacketHeader_MAGIC.n;
ptr->protocol_version = protocol_version;
ptr->type = type;

View File

@@ -45,7 +45,7 @@ str validateUsername_str(str username){
void ErrorMessage_construct(ErrorMessage* ptr, PacketHeader* header, u32 msg_size){
_PacketHeader_construct(ErrorMessage);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
ptr->msg_size = msg_size;
}
@@ -54,7 +54,7 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* he
{
Deferral(1);
_PacketHeader_construct(ClientHandshake);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
try_assert(session_key.len == sizeof(ptr->session_key));
memcpy(ptr->session_key, session_key.data, session_key.len);
@@ -66,7 +66,7 @@ void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header,
u64 session_id)
{
_PacketHeader_construct(ServerHandshake);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
ptr->session_id = session_id;
}
@@ -74,7 +74,7 @@ void ServerPublicInfoRequest_construct(ServerPublicInfoRequest *ptr, PacketHeade
ServerPublicInfo property)
{
_PacketHeader_construct(ServerPublicInfoRequest);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
ptr->property = property;
}
@@ -82,7 +82,7 @@ void ServerPublicInfoResponse_construct(ServerPublicInfoResponse* ptr, PacketHea
u32 data_size)
{
_PacketHeader_construct(ServerPublicInfoResponse);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
ptr->data_size = data_size;
}
@@ -91,7 +91,7 @@ Result(void) LoginRequest_tryConstruct(LoginRequest *ptr, PacketHeader* header,
{
Deferral(1);
_PacketHeader_construct(LoginRequest);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
str name_error_str = validateUsername_str(username);
if(name_error_str.data){
@@ -109,7 +109,7 @@ void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header,
u64 user_id, u64 landing_channel_id)
{
_PacketHeader_construct(LoginResponse);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
ptr->user_id = user_id;
ptr->landing_channel_id = landing_channel_id;
@@ -120,7 +120,7 @@ Result(void) RegisterRequest_tryConstruct(RegisterRequest *ptr, PacketHeader* he
{
Deferral(1);
_PacketHeader_construct(RegisterRequest);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
str name_error_str = validateUsername_str(username);
if(name_error_str.data){
@@ -138,6 +138,43 @@ void RegisterResponse_construct(RegisterResponse *ptr, PacketHeader* header,
u64 user_id)
{
_PacketHeader_construct(RegisterResponse);
memset(ptr, 0, sizeof(*ptr));
zeroStruct(ptr);
ptr->user_id = user_id;
}
void SendMessageRequest_construct(SendMessageRequest *ptr, PacketHeader *header,
u64 channel_id, u16 data_size)
{
_PacketHeader_construct(SendMessageRequest);
zeroStruct(ptr);
ptr->channel_id = channel_id;
ptr->data_size = data_size;
}
void SendMessageResponse_construct(SendMessageResponse *ptr, PacketHeader *header,
u64 message_id, DateTime receiving_time_utc)
{
_PacketHeader_construct(SendMessageResponse);
zeroStruct(ptr);
ptr->message_id = message_id;
ptr->receiving_time_utc = receiving_time_utc;
}
void GetMessageBlockRequest_construct(GetMessageBlockRequest *ptr, PacketHeader *header,
u64 first_message_id, u32 messages_count)
{
_PacketHeader_construct(GetMessageBlockRequest);
zeroStruct(ptr);
ptr->first_message_id = first_message_id;
ptr->messages_count = messages_count;
}
void GetMessageBlockResponse_construct(GetMessageBlockResponse *ptr, PacketHeader *header,
u64 first_message_id, u32 messages_count, u32 data_size)
{
_PacketHeader_construct(GetMessageBlockResponse);
zeroStruct(ptr);
ptr->first_message_id = first_message_id;
ptr->messages_count = messages_count;
ptr->data_size = data_size;
}

View File

@@ -33,6 +33,10 @@ typedef enum PacketType {
PacketType_LoginResponse,
PacketType_RegisterRequest,
PacketType_RegisterResponse,
PacketType_SendMessageRequest,
PacketType_SendMessageResponse,
PacketType_GetMessageBlockRequest,
PacketType_GetMessageBlockResponse,
} ATTRIBUTE_PACKED PacketType;
@@ -140,21 +144,21 @@ void SendMessageResponse_construct(SendMessageResponse* ptr, PacketHeader* heade
typedef struct GetMessageBlockRequest {
u64 message_id_first;
u64 first_message_id;
u32 messages_count;
} ALIGN_PACKET_STRUCT GetMessageBlockRequest;
void GetMessageBlockRequest_construct(GetMessageBlockRequest* ptr, PacketHeader* header,
u64 message_id_first, u32 messages_count);
u64 first_message_id, u32 messages_count);
typedef struct GetMessageBlockResponse {
u64 message_id_first;
u64 first_message_id;
u32 messages_count;
u32 data_size;
/* stream of size data_size : ((sequence MessageMeta), (sequence binary-data)) */
} ALIGN_PACKET_STRUCT GetMessageBlockResponse;
void GetMessageBlockResponse_construct(GetMessageBlockResponse* ptr, PacketHeader* header,
u32 data_size);
u64 first_message_id, u32 messages_count, u32 data_size);

151
src/server/Channel.c Normal file
View File

@@ -0,0 +1,151 @@
#include "server_internal.h"
#include "tlibc/string/StringBuilder.h"
#include "tlibc/filesystem.h"
void Channel_free(Channel* self){
if(!self)
return;
str_destroy(self->name);
str_destroy(self->description);
List_MessageBlockMeta_destroy(&self->messages.blocks_meta_list);
LList_MessageBlock_destroy(&self->messages.blocks_queue);
free(self);
}
Result(Channel*) Channel_create(u64 chan_id, str name, str description,
IncrementalDB* db, bool lock_db)
{
Deferral(8);
Channel* self = (Channel*)malloc(sizeof(Channel));
zeroStruct(self);
bool success = false;
Defer(if(!success) Channel_free(self));
self->id = chan_id;
try_assert(name.len >= CHANNEL_NAME_SIZE_MIN && name.len <= CHANNEL_NAME_SIZE_MAX);
self->name = str_copy(name);
try_assert(description.len <= CHANNEL_DESC_SIZE_MAX);
self->description = str_copy(description);
if(lock_db){
idb_lockDB(db);
Defer(idb_unlockDB(db));
}
StringBuilder sb = StringBuilder_alloc(CHANNEL_NAME_SIZE_MAX + 32 + 1);
Defer(StringBuilder_destroy(&sb));
StringBuilder_append_str(&sb, STR("channels"));
StringBuilder_append_char(&sb, path_sep);
StringBuilder_append_str(&sb, name);
str subdir = str_copy(StringBuilder_getStr(&sb));
Defer(str_destroy(subdir));
str message_blocks_str = STR("message_blocks");
str message_blocks_meta_str = STR("message_blocks_meta");
StringBuilder_removeFromEnd(&sb, -1);
StringBuilder_append_str(&sb, name);
StringBuilder_append_char(&sb, '_');
StringBuilder_append_str(&sb, message_blocks_str);
try(self->messages.blocks_table, p,
idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false)
);
idb_lockTable(self->messages.blocks_table);
Defer(idb_unlockTable(self->messages.blocks_table));
StringBuilder_removeFromEnd(&sb, message_blocks_str.len);
StringBuilder_append_str(&sb, message_blocks_meta_str);
try(self->messages.blocks_meta_table, p,
idb_getOrCreateTable(db, subdir, StringBuilder_getStr(&sb), sizeof(MessageBlock), false)
);
idb_lockTable(self->messages.blocks_meta_table);
Defer(idb_unlockTable(self->messages.blocks_meta_table));
// load whole message_blocks_meta table to list
try_void(
idb_createListFromTable(self->messages.blocks_meta_table, (void*)&self->messages.blocks_meta_list, false)
);
// load N last blocks to the queue
self->messages.blocks_queue = LList_construct(MessageBlock, NULL);
u64 message_blocks_count = self->messages.blocks_meta_list.len;
u64 first_block_id = 0;
if(message_blocks_count > MESSAGE_BLOCKS_CACHE_COUNT)
first_block_id = message_blocks_count - MESSAGE_BLOCKS_CACHE_COUNT;
for(u64 id = first_block_id; id < message_blocks_count; id++){
LLNode(MessageBlock)* node = LLNode_MessageBlock_createZero();
LList_MessageBlock_insertAfter(
&self->messages.blocks_queue,
self->messages.blocks_queue.last,
node
);
try_void(idb_getRow(self->messages.blocks_table, id, node->value.data, false));
}
if(self->messages.blocks_meta_list.len > 0){
MessageBlockMeta last_block_meta = self->messages.blocks_meta_list.data[self->messages.blocks_meta_list.len - 1];
self->messages.count = last_block_meta.first_message_id + last_block_meta.messages_count - 1;
}
else {
self->messages.count = 0;
}
success = true;
Return RESULT_VALUE(p, self);
}
Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id,
MessageMeta* out_message_meta)
{
Deferral(1);
// create new block if message won't fit in the last existing
MessageBlockMeta* incomplete_block_meta = self->messages.blocks_meta_list.data + self->messages.blocks_meta_list.len;
u64 new_message_id = incomplete_block_meta->first_message_id + incomplete_block_meta->messages_count;
u32 message_size_in_block = sizeof(MessageMeta) + ALIGN_TO(message_data.len, 8);
if(incomplete_block_meta->data_size + message_size_in_block > MESSAGE_BLOCK_SIZE){
// create new MessageBlockMeta
incomplete_block_meta = List_MessageBlockMeta_expand(&self->messages.blocks_meta_list, 1);
incomplete_block_meta->first_message_id = new_message_id;
incomplete_block_meta->messages_count = 0;
incomplete_block_meta->data_size = 0;
// create new MessageBlock
LList_MessageBlock_insertAfter(
&self->messages.blocks_queue,
self->messages.blocks_queue.last,
LLNode_MessageBlock_createZero());
//TODO: save to DB
}
// copy message to message block
out_message_meta->magic = MESSAGE_MAGIC;
out_message_meta->data_size = message_data.len;
out_message_meta->id = new_message_id;
out_message_meta->sender_id = sender_id;
DateTime_getUTC(&out_message_meta->receiving_time_utc);
MessageBlock* incomplete_block = &self->messages.blocks_queue.last->value;
u8* data_ptr = incomplete_block->data + incomplete_block_meta->data_size;
memcpy(data_ptr, out_message_meta, sizeof(MessageMeta));
data_ptr += sizeof(MessageMeta);
memcpy(data_ptr, message_data.data, message_data.len);
Return RESULT_VOID;
}
// Result(void) Channel_loadMessage(Channel* self, u64 id,
// MessageMeta* out_message_meta, u8* out_data)
// {
// Deferral(1);
// Return RESULT_VOID;
// }
// Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count,
// MessageBlockMeta* out_message_meta, MessageBlock* out_block)
// {
// Deferral(1);
// Return RESULT_VOID;
// }

View File

@@ -14,7 +14,7 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args)
Deferral(8);
ClientConnection* conn = (ClientConnection*)malloc(sizeof(ClientConnection));
memset(conn, 0, sizeof(*conn));
zeroStruct(conn);
bool success = false;
Defer(if(!success) ClientConnection_close(conn));
@@ -22,6 +22,7 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args)
conn->client_end = args->client_end;
conn->session_id = args->session_id;
conn->authorized = false;
conn->user_id = -1;
conn->session_key = Array_u8_alloc(AES_SESSION_KEY_SIZE);
// correct session key will be received from client later
Array_u8_memset(&conn->session_key, 0);

View File

@@ -1,7 +1,8 @@
#include "responses.h"
#define LOGGER conn->server->logger
#define LOG_FUNC conn->server->log_func
#define srv conn->server
#define LOGGER srv->logger
#define LOG_FUNC srv->log_func
#define LOG_CONTEXT log_ctx
declare_RequestHandler(GetMessageBlock)
@@ -14,6 +15,7 @@ declare_RequestHandler(GetMessageBlock)
try_void(PacketHeader_validateContentSize(req_head, sizeof(req)));
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req));
(void)res_head;
// send response
// GetMessageBlockResponse res;
// GetMessageBlockResponse_construct(&res, res_head, );

View File

@@ -1,7 +1,8 @@
#include "responses.h"
#define LOGGER conn->server->logger
#define LOG_FUNC conn->server->log_func
#define srv conn->server
#define LOGGER srv->logger
#define LOG_FUNC srv->log_func
#define LOG_CONTEXT log_ctx
declare_RequestHandler(Login)
@@ -35,15 +36,15 @@ declare_RequestHandler(Login)
}
// lock users cache
idb_lockTable(conn->server->users.table);
idb_lockTable(srv->users.table);
bool unlocked_users_cache_mutex = false;
Defer(
if(!unlocked_users_cache_mutex)
idb_unlockTable(conn->server->users.table)
idb_unlockTable(srv->users.table)
);
// try get id from name cache
u64* id_ptr = HashMap_tryGetPtr(&conn->server->users.username_id_map, username_str);
u64* id_ptr = HashMap_tryGetPtr(&srv->users.name_id_map, username_str);
if(id_ptr == NULL){
try_void(sendErrorMessage_f(log_ctx, conn, res_head,
LogSeverity_Warn,
@@ -55,8 +56,8 @@ declare_RequestHandler(Login)
u64 user_id = *id_ptr;
// get user by id
try_assert(user_id < conn->server->users.cache_list.len);
UserInfo* u = conn->server->users.cache_list.data + user_id;
try_assert(user_id < srv->users.list.len);
UserInfo* u = srv->users.list.data + user_id;
// validate token hash
if(memcmp(req.token, u->token, sizeof(req.token)) != 0){
@@ -68,16 +69,17 @@ declare_RequestHandler(Login)
}
// manually unlock mutex
idb_unlockTable(conn->server->users.table);
idb_unlockTable(srv->users.table);
unlocked_users_cache_mutex = true;
// authorize
conn->authorized = true;
conn->user_id = user_id;
logInfo("authorized user '%s'", username_str.data);
// send response
LoginResponse res;
LoginResponse_construct(&res, res_head, user_id, conn->server->landing_channel_id);
LoginResponse_construct(&res, res_head, user_id, srv->landing_channel_id);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));

View File

@@ -1,7 +1,8 @@
#include "responses.h"
#define LOGGER conn->server->logger
#define LOG_FUNC conn->server->log_func
#define srv conn->server
#define LOGGER srv->logger
#define LOG_FUNC srv->log_func
#define LOG_CONTEXT log_ctx
declare_RequestHandler(Register)
@@ -35,16 +36,16 @@ declare_RequestHandler(Register)
}
// lock users cache
idb_lockTable(conn->server->users.table);
idb_lockTable(srv->users.table);
bool unlocked_users_cache_mutex = false;
// unlock mutex on error catch
Defer(
if(!unlocked_users_cache_mutex)
idb_unlockTable(conn->server->users.table)
idb_unlockTable(srv->users.table)
);
// check if name is taken
if(HashMap_tryGetPtr(&conn->server->users.username_id_map, username_str) != NULL){
if(HashMap_tryGetPtr(&srv->users.name_id_map, username_str) != NULL){
try_void(sendErrorMessage_f(log_ctx, conn, res_head,
LogSeverity_Warn,
"Username '%s' already exists",
@@ -54,7 +55,7 @@ declare_RequestHandler(Register)
// initialize new user
UserInfo user;
memset(&user, 0, sizeof(UserInfo));
zeroStruct(&user);
user.name_len = username_str.len;
memcpy(user.name, username_str.data, user.name_len);
@@ -65,14 +66,14 @@ declare_RequestHandler(Register)
// save new user to db and cache
try(u64 user_id, u,
idb_pushRow(conn->server->users.table, &user, false)
idb_pushRow(srv->users.table, &user, false)
);
try_assert(user_id == conn->server->users.cache_list.len);
List_UserInfo_pushMany(&conn->server->users.cache_list, &user, 1);
try_assert(HashMap_tryPush(&conn->server->users.username_id_map, username_str, &user_id));
try_assert(user_id == srv->users.list.len);
List_UserInfo_pushMany(&srv->users.list, &user, 1);
try_assert(HashMap_tryPush(&srv->users.name_id_map, username_str, &user_id));
// manually unlock mutex
idb_unlockTable(conn->server->users.table);
idb_unlockTable(srv->users.table);
unlocked_users_cache_mutex = true;
logInfo("registered user '%s'", username_str.data);

View File

@@ -1,7 +1,8 @@
#include "responses.h"
#define LOGGER conn->server->logger
#define LOG_FUNC conn->server->log_func
#define srv conn->server
#define LOGGER srv->logger
#define LOG_FUNC srv->log_func
#define LOG_CONTEXT log_ctx
declare_RequestHandler(SendMessage)
@@ -14,6 +15,14 @@ declare_RequestHandler(SendMessage)
try_void(PacketHeader_validateContentSize(req_head, sizeof(req)));
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req));
if(!conn->authorized){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn,
STR("is not authorized")
));
Return RESULT_VOID;
}
if(req.data_size < MESSAGE_SIZE_MIN || req.data_size > MESSAGE_SIZE_MAX){
try_void(sendErrorMessage_f(log_ctx, conn, res_head,
LogSeverity_Warn,
@@ -28,16 +37,24 @@ declare_RequestHandler(SendMessage)
Array(u8) message_data = Array_u8_alloc(req.data_size);
try_void(EncryptedSocketTCP_recv(&conn->sock, message_data, SocketRecvFlag_WholeBuffer));
for(u16 i = 0; i < message_data.len; i++){
u8 b = message_data.data[i];
// save message to channel
Channel* ch = Server_tryGetChannel(conn->server, req.channel_id);
if(ch == NULL){
try_void(sendErrorMessage(log_ctx, conn, res_head,
LogSeverity_Warn,
STR("invalid channel id")
));
Return RESULT_VOID;
}
MessageMeta message_meta;
try_void(Channel_saveMessage(ch, message_data, conn->user_id, &message_meta));
// send response
// SendMessageResponse res;
// SendMessageResponse_construct(&res, res_head, );
// try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
// try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));
SendMessageResponse res;
SendMessageResponse_construct(&res, res_head,
message_meta.id, message_meta.receiving_time_utc);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, res_head));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));
Return RESULT_VOID;
}

View File

@@ -1,7 +1,8 @@
#include "responses.h"
#define LOGGER conn->server->logger
#define LOG_FUNC conn->server->log_func
#define srv conn->server
#define LOGGER srv->logger
#define LOG_FUNC srv->log_func
#define LOG_CONTEXT log_ctx
declare_RequestHandler(ServerPublicInfo)
@@ -25,10 +26,10 @@ declare_RequestHandler(ServerPublicInfo)
Return RESULT_VOID;
}
case ServerPublicInfo_Name:
content = str_castTo_Array_u8(conn->server->name);
content = str_castTo_Array_u8(srv->name);
break;
case ServerPublicInfo_Description:
content = str_castTo_Array_u8(conn->server->description);
content = str_castTo_Array_u8(srv->description);
break;
}

View File

@@ -17,8 +17,7 @@ Result(void) sendErrorMessage_f(
#define declare_RequestHandler(TYPE) \
Result(void) handleRequest_##TYPE(\
cstr log_ctx, cstr req_type_name, \
Result(void) handleRequest_##TYPE(cstr log_ctx, cstr req_type_name, \
ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head)
#define case_handleRequest(TYPE) \
@@ -29,5 +28,5 @@ Result(void) sendErrorMessage_f(
declare_RequestHandler(ServerPublicInfo);
declare_RequestHandler(Login);
declare_RequestHandler(Register);
declare_RequestHandler(SendMessage);
declare_RequestHandler(GetMessageBlock);

View File

@@ -1,7 +1,8 @@
#include "responses.h"
#define LOGGER conn->server->logger
#define LOG_FUNC conn->server->log_func
#define srv conn->server
#define LOGGER srv->logger
#define LOG_FUNC srv->log_func
#define LOG_CONTEXT log_ctx
Result(void) sendErrorMessage(

View File

@@ -1,7 +1,8 @@
#include "responses.h"
#define LOGGER conn->server->logger
#define LOG_FUNC conn->server->log_func
#define srv conn->server
#define LOGGER srv->logger
#define LOG_FUNC srv->log_func
#define LOG_CONTEXT log_ctx
declare_RequestHandler(NAME)

View File

@@ -2,6 +2,7 @@
#include "tlibc/filesystem.h"
#include "tlibc/time.h"
#include "tlibc/base64.h"
#include "tlibc/algorithms.h"
#include "server/server_internal.h"
#include "network/tcp-chat-protocol/v1.h"
#include "server/responses/responses.h"
@@ -21,12 +22,15 @@ void Server_free(Server* self){
idb_close(self->db);
List_UserInfo_destroy(&self->users.cache_list);
HashMap_destroy(&self->users.username_id_map);
List_UserInfo_destroy(&self->users.list);
HashMap_destroy(&self->users.name_id_map);
for(u32 i = 0; i < self->channels.list.len; i++){
Channel_free(self->channels.list.data[i]);
}
List_ChannelPtr_destroy(&self->channels.list);
HashMap_destroy(&self->channels.name_id_map);
List_MessageBlockMeta_destroy(&self->messages.blocks_meta_list);
LList_MessageBlock_destroy(&self->messages.blocks_queue);
free(self->messages.incomplete_block);
free(self);
}
@@ -38,12 +42,12 @@ void Server_free(Server* self){
Result(Server*) Server_create(str config_file_content, cstr config_file_name,
void* logger, LogFunction_t log_func)
{
{
Deferral(16);
cstr log_ctx = "ServerInit";
Server* self = (Server*)malloc(sizeof(Server));
memset(self, 0, sizeof(Server));
zeroStruct(self);
bool success = false;
Defer(if(!success) Server_free(self));
@@ -51,102 +55,110 @@ Result(Server*) Server_create(str config_file_content, cstr config_file_name,
self->log_func = log_func;
logDebug("parsing config");
try(TomlTable* config_toml, p, toml_load_str_filename(config_file_content, config_file_name));
Defer(TomlTable_free(config_toml));
try(TomlTable* config_top, p, toml_load_str_filename(config_file_content, config_file_name));
Defer(TomlTable_free(config_top));
try(TomlTable* config_server, p, TomlTable_get_table(config_top, STR("server")))
// parse name
try(str* v_name, p, TomlTable_get_str(config_toml, STR("name")));
try(str* v_name, p, TomlTable_get_str(config_server, STR("name")));
self->name = str_copy(*v_name);
// parse description
try(str* v_desc, p, TomlTable_get_str(config_toml, STR("description")));
try(str* v_desc, p, TomlTable_get_str(config_server, STR("description")));
self->description = str_copy(*v_desc);
// parse landing_channel_id
try(i64 v_landing_channel_id, i, TomlTable_get_integer(config_toml, STR("landing_channel_id")));
self->landing_channel_id = v_landing_channel_id;
// parse local_address
try(str* v_local_address, p, TomlTable_get_str(config_toml, STR("local_address")));
try(str* v_local_address, p, TomlTable_get_str(config_server, STR("local_address")));
try_assert(v_local_address->isZeroTerminated);
try_void(EndpointIPv4_parse(v_local_address->data, &self->local_end));
// parse landing_channel_id
try(i64 v_landing_channel_id, i, TomlTable_get_integer(config_server, STR("landing_channel_id")));
self->landing_channel_id = v_landing_channel_id;
try(TomlTable* config_keys, p, TomlTable_get_table(config_top, STR("keys")))
// parse rsa_private_key
try(str* v_rsa_sk, p, TomlTable_get_str(config_toml, STR("rsa_private_key")));
try(str* v_rsa_sk, p, TomlTable_get_str(config_keys, STR("rsa_private_key")));
try_assert(v_rsa_sk->isZeroTerminated);
try_void(RSA_parsePrivateKey_base64(v_rsa_sk->data, &self->rsa_sk));
// parse rsa_public_key
try(str* v_rsa_pk, p, TomlTable_get_str(config_toml, STR("rsa_public_key")));
try(str* v_rsa_pk, p, TomlTable_get_str(config_keys, STR("rsa_public_key")));
try_assert(v_rsa_pk->isZeroTerminated);
try_void(RSA_parsePublicKey_base64(v_rsa_pk->data, &self->rsa_pk));
try(TomlTable* config_db, p, TomlTable_get_table(config_top, STR("database")))
// parse db_aes_key
try(str* v_db_aes_key, p, TomlTable_get_str(config_toml, STR("db_aes_key")));
try(str* v_db_aes_key, p, TomlTable_get_str(config_db, STR("aes_key")));
str db_aes_key_s = *v_db_aes_key;
Array(u8) db_aes_key = Array_u8_alloc(base64_decodedSize(db_aes_key_s.data, db_aes_key_s.len));
Defer(free(db_aes_key.data));
base64_decode(db_aes_key_s.data, db_aes_key_s.len, db_aes_key.data);
// parse db_dir and open db
try(str* v_db_dir, p, TomlTable_get_str(config_toml, STR("db_dir")));
// parse db_dir
try(str* v_db_dir, p, TomlTable_get_str(config_db, STR("dir")));
// open DB
try(self->db, p, idb_open(*v_db_dir, db_aes_key));
// build users cache
logDebug("loading users...");
logInfo("loading users...");
try(self->users.table, p,
idb_getOrCreateTable(self->db, STR("users"), sizeof(UserInfo), false)
idb_getOrCreateTable(self->db, str_null, STR("users"), sizeof(UserInfo), false)
);
// load whole users table to list
try_void(
idb_createListFromTable(self->users.table, (void*)&self->users.cache_list, false)
idb_createListFromTable(self->users.table, (void*)&self->users.list, false)
);
// build name-id map
try(u64 users_count, u, idb_getRowCount(self->users.table, false));
HashMap_construct(&self->users.username_id_map, u64, NULL);
HashMap_construct(&self->users.name_id_map, u64, NULL);
for(u64 id = 0; id < users_count; id++){
UserInfo* user = self->users.cache_list.data + id;
UserInfo* user = self->users.list.data + id;
str key = str_construct(user->name, user->name_len, true);
if(!HashMap_tryPush(&self->users.username_id_map, key, &id)){
Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", key.len, key.data);
if(!HashMap_tryPush(&self->users.name_id_map, key, &id)){
Return RESULT_ERROR_FMT("duplicate user name '"FMT_str"'", str_expand(key));
}
}
logDebug("loaded "FMT_u64" users", users_count);
logInfo("loaded "FMT_u64" users", users_count);
// parse channels
logDebug("loading channels...");
HashMap_construct(&self->channels.name_id_map, u64, NULL);
self->channels.list = List_ChannelPtr_alloc(32);
try(TomlTable* config_channels, p, TomlTable_get_table(config_top, STR("channels")));
HashMapIter channels_iter = HashMapIter_create(config_channels);
while(HashMapIter_moveNext(&channels_iter)){
HashMapKeyValue kv;
HashMapIter_getCurrent(&channels_iter, &kv);
str name = kv.key;
TomlValue* val = kv.value_ptr;
// skip if not table
if(val->type != TLIBTOML_TABLE)
continue;
logInfo("loading channel '"FMT_str"'", str_expand(name))
TomlTable* config_channel = val->table;
try(u64 id, u, TomlTable_get_integer(config_channel, STR("id")));
try(str* v_ch_desc, p, TomlTable_get_str(config_channel, STR("description")))
str description = *v_ch_desc;
// build messages cache
logDebug("loading messages...");
try(self->messages.blocks_table, p,
idb_getOrCreateTable(self->db, STR("message_blocks"), sizeof(MessageBlock), false)
);
try(self->messages.blocks_meta_table, p,
idb_getOrCreateTable(self->db, STR("message_blocks_meta"), sizeof(MessageBlockMeta), false)
);
if(!HashMap_tryPush(&self->channels.name_id_map, name, &id)){
Return RESULT_ERROR_FMT("duplicate channel '"FMT_str"'", str_expand(name));
}
// load whole message_blocks_meta table to list
try_void(
idb_createListFromTable(self->messages.blocks_meta_table, (void*)&self->messages.blocks_meta_list, false)
);
// load N last blocks to the queue
self->messages.incomplete_block = LLNode_MessageBlock_createZero();
self->messages.blocks_queue = LList_construct(MessageBlock, NULL);
try(u64 message_blocks_count, u, idb_getRowCount(self->messages.blocks_table, false));
u64 first_id = 0;
if(message_blocks_count > MESSAGE_BLOCKS_CACHE_COUNT)
first_id = message_blocks_count - MESSAGE_BLOCKS_CACHE_COUNT;
for(u64 id = first_id; id < message_blocks_count; id++){
LLNode(MessageBlock)* node = LLNode_MessageBlock_createZero();
LList_MessageBlock_insertAfter(
&self->messages.blocks_queue,
self->messages.blocks_queue.last,
node
);
try_void(idb_getRow(self->messages.blocks_table, id, node->value.data, false));
logDebug("loading messages...");
try(Channel* channel, p, Channel_create(id, name, description, self->db, false));
logDebug("loaded "FMT_u64" messages", channel->messages.count);
List_ChannelPtr_push(&self->channels.list, channel);
}
logDebug("loaded "FMT_u64" message blocks", message_blocks_count);
logDebug("loaded "FMT_u32" channels", self->channels.list.len);
// sort channels list by id
for(u32 i = 0; i < self->channels.list.len; i++);
insertionSort_inline(self->channels.list.data, self->channels.list.len, ->id);
success = true;
Return RESULT_VALUE(p, self);
@@ -189,6 +201,15 @@ Result(void) Server_run(Server* server){
Return RESULT_VOID;
}
Channel* Server_tryGetChannel(Server* self, u64 id){
i32 index;
binarySearch_inline(self->channels.list.data, self->channels.list.len, id, ->id, index);
if(index == -1){
return NULL;
}
Channel* channel = self->channels.list.data[index];
return channel;
}
static void* handleConnection(void* _args){
ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args;
@@ -244,7 +265,10 @@ static Result(void) try_handleConnection(ConnectionHandlerArgs* args, cstr log_c
case_handleRequest(ServerPublicInfo);
case_handleRequest(Login);
case_handleRequest(Register);
// authorized requests
case_handleRequest(SendMessage);
case_handleRequest(GetMessageBlock);
}
}

View File

@@ -17,6 +17,40 @@ LList_declare(MessageBlock);
#define MESSAGE_BLOCKS_CACHE_COUNT 50
typedef struct Channel {
u64 id;
str name;
str description;
struct {
u64 count;
Table* blocks_table;
Table* blocks_meta_table;
List(MessageBlockMeta) blocks_meta_list; // index is id
// last MESSAGE_BLOCKS_CACHE_COUNT MessageBlocks, ascending
// new messages are written to the last block
LList(MessageBlock) blocks_queue;
} messages;
} Channel;
typedef Channel* ChannelPtr;
List_declare(ChannelPtr);
Result(Channel*) Channel_create(u64 id, str name, str description,
IncrementalDB* db, bool lock_db);
void Channel_free(Channel* self);
Result(void) Channel_saveMessage(Channel* self, Array(u8) message_data, u64 sender_id,
MessageMeta* out_message_meta);
/// @param out_data buffer of size >= MESSAGE_SIZE_MAX
Result(void) Channel_loadMessage(Channel* self, u64 id,
MessageMeta* out_meta, u8* out_data);
Result(void) Channel_loadMessageBlock(Channel* self, u64 fisrt_message_id, u32 count,
MessageBlockMeta* out_meta, MessageBlock* out_block);
typedef struct Server {
/* from constructor */
void* logger;
@@ -34,19 +68,17 @@ typedef struct Server {
IncrementalDB* db;
struct {
Table* table;
List(UserInfo) cache_list; // index is id
HashMap(u64) username_id_map;
List(UserInfo) list; // index is id
HashMap(u64) name_id_map;
} users;
/* messages */
struct {
Table* blocks_table;
Table* blocks_meta_table;
List(MessageBlockMeta) blocks_meta_list; // index is id
LList(MessageBlock) blocks_queue; // last N MessageBlocks, ascending
LLNode(MessageBlock)* incomplete_block; // new messages are written here until block is full
} messages;
List(ChannelPtr) list;
HashMap(u64) name_id_map;
} channels;
} Server;
Channel* Server_tryGetChannel(Server* self, u64 id);
typedef struct ClientConnection {
Server* server;
@@ -55,6 +87,7 @@ typedef struct ClientConnection {
Array(u8) session_key;
EncryptedSocketTCP sock;
bool authorized;
u64 user_id; // -1 for unauthorized
} ClientConnection;
typedef struct ConnectionHandlerArgs {