From 8179609d47397a90d09b2f7bfb1ed65ce71ddaeb Mon Sep 17 00:00:00 2001 From: Timerix Date: Sat, 1 Nov 2025 19:51:43 +0500 Subject: [PATCH] implemented client-server connection, but found out RSA is broken --- .vscode/launch.json | 1 + dependencies/tlibc | 2 +- src/client/ClientCredential.c | 8 +- src/client/ServerConnection.c | 120 ++++++++++++----------- src/client/client.c | 17 ++-- src/client/client.h | 13 ++- src/config.c | 36 +++++++ src/config.h | 6 ++ src/cryptography/AES.c | 6 +- src/cryptography/RSA.c | 6 +- src/cryptography/RSA.h | 5 +- src/log.h | 8 ++ src/main.c | 94 ++++++++++++++---- src/network/encrypted_sockets.c | 1 + src/network/tcp-chat-protocol/constant.h | 4 +- src/network/tcp-chat-protocol/v1.c | 4 + src/network/tcp-chat-protocol/v1.h | 5 + src/server/ClientConnection.c | 81 ++++++++++++++- src/server/ServerCredentials.c | 26 +++++ src/server/server.c | 94 +++++++++++++++--- src/server/server.h | 27 ++++- 21 files changed, 441 insertions(+), 123 deletions(-) create mode 100644 src/config.c create mode 100644 src/config.h create mode 100644 src/log.h create mode 100644 src/server/ServerCredentials.c diff --git a/.vscode/launch.json b/.vscode/launch.json index f25ba41..fedd752 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -7,6 +7,7 @@ "request": "launch", "program": "${workspaceFolder}/bin/tcp-chat", "windows": { "program": "${workspaceFolder}/bin/tcp-chat.exe" }, + "args": [ "-l", "127.0.0.1:9988" ], "preLaunchTask": "build_exec_dbg", "stopAtEntry": false, "cwd": "${workspaceFolder}/bin", diff --git a/dependencies/tlibc b/dependencies/tlibc index 75c94e8..5fb2db2 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit 75c94e88d9a7e12839d343bf4c7978cb7ecf91e1 +Subproject commit 5fb2db2380b678381ef455a18c8210a6a3314e60 diff --git a/src/client/ClientCredential.c b/src/client/ClientCredential.c index 87dd8d0..e171c7e 100644 --- a/src/client/ClientCredential.c +++ b/src/client/ClientCredential.c @@ -1,7 +1,7 @@ #include "client.h" #include "tlibc/string/StringBuilder.h" -void ClientCredential_free(ClientCredential* cred){ +void ClientCredentials_free(ClientCredentials* cred){ if(cred == NULL) return; free(cred->username.data); @@ -10,10 +10,10 @@ void ClientCredential_free(ClientCredential* cred){ } -Result(ClientCredential*) ClientCredential_create(str username, str password){ +Result(ClientCredentials*) ClientCredentials_create(str username, str password){ Deferral(8); - ClientCredential* cred = (ClientCredential*)malloc(sizeof(ClientCredential)); - memset(cred, 0, sizeof(ClientCredential)); + ClientCredentials* cred = (ClientCredentials*)malloc(sizeof(ClientCredentials)); + memset(cred, 0, sizeof(ClientCredentials)); bool success = false; Defer( if(!success) diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index 90e81b8..00bb61d 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -5,8 +5,7 @@ void ServerConnection_close(ServerConnection* conn){ if(conn == NULL) return; RSA_destroyPublicKey(&conn->server_pk); - socket_close(conn->system_socket.sock); - socket_close(conn->content_socket.sock); + socket_close(conn->sock.sock); free(conn->session_key.data); free(conn); } @@ -17,10 +16,9 @@ void ServerConnection_close(ServerConnection* conn){ Result(void) ServerLink_parse(cstr server_link_cstr, EndpointIPv4* server_end_out, br_rsa_public_key* server_key_out){ Deferral(8); str server_link_str = str_from_cstr(server_link_cstr); - i32 sep_pos = 0; // parse address and port - sep_pos = str_seekChar(server_link_str, ':', sep_pos); + i32 sep_pos = str_seekChar(server_link_str, ':', 0); if(sep_pos == -1){ Return RESULT_ERROR_FMT("server link is invalid: %s", server_link_cstr); } @@ -31,22 +29,21 @@ Result(void) ServerLink_parse(cstr server_link_cstr, EndpointIPv4* server_end_ou } // parse public key - sep_pos = str_seekChar(server_link_str, ':', sep_pos); + sep_pos = str_seekChar(server_link_str, ':', sep_pos + 1); if(sep_pos == -1){ Return RESULT_ERROR_FMT("server link is invalid: %s", server_link_cstr); } - str server_key_str = server_link_str; - server_key_str.data += sep_pos + 1; + str server_key_str = str_sliceAfter(server_link_str, sep_pos + 1); try_void(RSA_parsePublicKey_base64(server_key_str, server_key_out)); Return RESULT_VOID; } -Result(ServerConnection*) ServerConnection_open(ClientCredential* client_credential, cstr server_link_cstr){ +Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_credentials, cstr server_link_cstr){ Deferral(64); ServerConnection* conn = (ServerConnection*)malloc(sizeof(ServerConnection)); - memset(conn, 0, sizeof(ServerConnection)); + memset(conn, 0, sizeof(*conn)); bool success = false; Defer( if(!success) @@ -57,6 +54,7 @@ Result(ServerConnection*) ServerConnection_open(ClientCredential* client_credent RSAEncryptor_construct(&conn->rsa_enc, &conn->server_pk); conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); + // generate random session key br_hmac_drbg_context key_rng = { .vtable = &br_hmac_drbg_vtable }; rng_init_sha256_seedFromSystem(&key_rng.vtable); br_hmac_drbg_generate(&key_rng, conn->session_key.data, conn->session_key.size); @@ -64,73 +62,85 @@ Result(ServerConnection*) ServerConnection_open(ClientCredential* client_credent // connect to server address Socket _s; try(_s, i, socket_open_TCP()); - EncryptedSocketTCP_construct(&conn->system_socket, _s, conn->session_key); - try_void(socket_connect(conn->system_socket.sock, conn->server_end)); + try_void(socket_connect(_s, conn->server_end)); + EncryptedSocketTCP_construct(&conn->sock, _s, conn->session_key); - Array(u8) encrypted_buf = Array_alloc_size(8*1024); - Defer(free(encrypted_buf.data)); - Array(u8) decrypted_buf = Array_alloc_size(8*1024); - Defer(free(decrypted_buf.data)); - u32 encrypted_size = 0, decrypted_size = 0; + Array(u8) enc_buf = Array_alloc_size(8*1024); + Defer(free(enc_buf.data)); + Array(u8) dec_buf = Array_alloc_size(8*1024); + Defer(free(dec_buf.data)); + u32 enc_size = 0, dec_size = 0; - // send handshake to the server - ClientHandshake client_handshake; - try_void(ClientHandshake_tryConstruct(&client_handshake, conn->session_key)); - // encryption by server public key - try(encrypted_size, u, RSAEncryptor_encrypt(&conn->rsa_enc, - struct_castTo_Array(&client_handshake), - encrypted_buf)); - try_void(socket_send(conn->system_socket.sock, - Array_construct_size(encrypted_buf.data, encrypted_size))); + // construct ClientHandshake in dec_buf + try_void(ClientHandshake_tryConstruct((ClientHandshake*)dec_buf.data, conn->session_key)); + // encrypt by server public key + try(enc_size, u, + RSAEncryptor_encrypt( + &conn->rsa_enc, + Array_sliceBefore(dec_buf, sizeof(ClientHandshake)), + enc_buf + ) + ); + try_void(socket_send(conn->sock.sock, Array_sliceBefore(enc_buf, enc_size))); // receive server response - encrypted_size = AESStreamEncryptor_calcDstSize(sizeof(PacketHeader)); - try(decrypted_size, u, EncryptedSocketTCP_recv(&conn->system_socket, - Array_construct_size(encrypted_buf.data, encrypted_size), - decrypted_buf, - SocketRecvFlag_WaitAll)); - try_assert(decrypted_size == sizeof(PacketHeader)); - PacketHeader* packet_header = decrypted_buf.data; + enc_size = AESStreamEncryptor_calcDstSize(sizeof(PacketHeader)); + try(dec_size, u, + EncryptedSocketTCP_recv(&conn->sock, + Array_sliceBefore(enc_buf, enc_size), + dec_buf, + SocketRecvFlag_WaitAll + ) + ); + try_assert(dec_size == sizeof(PacketHeader)); + PacketHeader* packet_header = dec_buf.data; try_void(PacketHeader_validateMagic(packet_header)); + // handle server response switch(packet_header->type){ - case PacketType_ErrorMessage: + case PacketType_ErrorMessage: { Array(u8) err_buf = Array_alloc_size(packet_header->content_size + 1); bool err_msg_completed = false; Defer( if(!err_msg_completed) free(err_buf.data); ); - encrypted_size = AESStreamEncryptor_calcDstSize(packet_header->content_size); - if(encrypted_size > encrypted_buf.size) - encrypted_size = encrypted_buf.size; - try_void(EncryptedSocketTCP_recv(&conn->system_socket, - Array_construct_size(encrypted_buf.data, encrypted_size), - err_buf, - SocketRecvFlag_WaitAll)); - ((u8*)err_buf.data)[encrypted_size] = 0; + + // receive error message of length packet_header->content_size + enc_size = AESStreamEncryptor_calcDstSize(packet_header->content_size); + if(enc_size > enc_buf.size) + enc_size = enc_buf.size; + try_void( + EncryptedSocketTCP_recv( + &conn->sock, + Array_sliceBefore(enc_buf, enc_size), + err_buf, + SocketRecvFlag_WaitAll + ) + ); + + ((u8*)err_buf.data)[enc_size] = 0; err_msg_completed = true; Return RESULT_ERROR((char*)err_buf.data, true); - case PacketType_ServerHandshake: - encrypted_size = AESStreamEncryptor_calcDstSize(sizeof(ServerHandshake) - sizeof(PacketHeader)); - try_void(EncryptedSocketTCP_recv(&conn->system_socket, - Array_construct_size(encrypted_buf.data, encrypted_size), - Array_construct_size((u8*)decrypted_buf.data + sizeof(PacketHeader), decrypted_buf.size - sizeof(PacketHeader)), - SocketRecvFlag_WaitAll - )); - ServerHandshake* server_handshake = decrypted_buf.data; + } + case PacketType_ServerHandshake: { + enc_size = AESStreamEncryptor_calcDstSize(sizeof(ServerHandshake) - sizeof(PacketHeader)); + try_void( + EncryptedSocketTCP_recv( + &conn->sock, + Array_sliceBefore(enc_buf, enc_size), + Array_sliceAfter(dec_buf, sizeof(PacketHeader)), + SocketRecvFlag_WaitAll + ) + ); + ServerHandshake* server_handshake = dec_buf.data; conn->session_id = server_handshake->session_id; break; + } default: Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header->type); } - - try(_s, i, socket_open_TCP()); - EncryptedSocketTCP_construct(&conn->content_socket, _s, conn->session_key); - // TODO: how to connect the second socket to the server and associate it with current session? - //try_void(socket_connect(conn->content_socket.sock, conn->server_end????)); - success = true; Return RESULT_VALUE(p, conn); } diff --git a/src/client/client.c b/src/client/client.c index cb58c7c..ffcd6b1 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -15,12 +15,12 @@ static const str farewell_art = STR( "\\(_,J J L l`,)/\n" ); -static ClientCredential* _client_credential = NULL; +static ClientCredentials* _client_credentials = NULL; static ServerConnection* _server_connection = NULL; static Result(void) commandExec(str command, bool* stop); -static Result(void) askUserNameAndPassword(ClientCredential** cred){ +static Result(void) askUserNameAndPassword(ClientCredentials** cred){ Deferral(8); printf("username: "); @@ -32,7 +32,7 @@ static Result(void) askUserNameAndPassword(ClientCredential** cred){ // TODO: hide password fgets(password, sizeof(password), stdin); - try(*cred, p, ClientCredential_create(str_from_cstr(username), str_from_cstr(password))); + try(*cred, p, ClientCredentials_create(str_from_cstr(username), str_from_cstr(password))); Return RESULT_VOID; } @@ -43,7 +43,7 @@ Result(void) client_run() { } fputs(greeting_art.data, stdout); - try_void(askUserNameAndPassword(&_client_credential)); + try_void(askUserNameAndPassword(&_client_credentials)); Array(char) input_buf = Array_alloc(char, 10000); str command_input = str_null; @@ -68,7 +68,7 @@ Result(void) client_run() { } free(input_buf.data); - ClientCredential_free(_client_credential); + ClientCredentials_free(_client_credentials); ServerConnection_close(_server_connection); Return RESULT_VOID; } @@ -77,7 +77,7 @@ Result(void) client_run() { static Result(void) commandExec(str command, bool* stop){ Deferral(64); - char answer_buf[512]; + char answer_buf[10000]; const u32 answer_buf_size = sizeof(answer_buf); if(is_alias("q") || is_alias("quit") || is_alias("exit")){ fputs(farewell_art.data, stdout); @@ -89,6 +89,7 @@ static Result(void) commandExec(str command, bool* stop){ else if(is_alias("h") || is_alias("help")){ puts( "COMMANDS:\n" + "h, help Show this message.\n" "q, quit, exit Close the program.\n" "clear Clear the screen.\n" "j, join Join a server.\n" @@ -103,8 +104,8 @@ static Result(void) commandExec(str command, bool* stop){ str new_server_link = str_from_cstr(answer_buf); str_trim(&new_server_link, true); - printf("connecting to server %s\n", new_server_link.data); - try(_server_connection, p, ServerConnection_open(_client_credential, new_server_link.data)); + printf("connecting to server...\n"); + try(_server_connection, p, ServerConnection_open(_client_credentials, new_server_link.data)); // TODO: request server info // show server info diff --git a/src/client/client.h b/src/client/client.h index 565211f..738aaeb 100644 --- a/src/client/client.h +++ b/src/client/client.h @@ -5,15 +5,15 @@ Result(void) client_run(); -typedef struct ClientCredential { +typedef struct ClientCredentials { str username; Array(u8) aes_key; AESBlockEncryptor user_data_aes_enc; AESBlockDecryptor user_data_aes_dec; -} ClientCredential; +} ClientCredentials; -Result(ClientCredential*) ClientCredential_create(str username, str password); -void ClientCredential_free(ClientCredential* cred); +Result(ClientCredentials*) ClientCredentials_create(str username, str password); +void ClientCredentials_free(ClientCredentials* cred); typedef struct ServerConnection { u64 session_id; @@ -21,9 +21,8 @@ typedef struct ServerConnection { br_rsa_public_key server_pk; RSAEncryptor rsa_enc; Array(u8) session_key; - EncryptedSocketTCP system_socket; - EncryptedSocketTCP content_socket; + EncryptedSocketTCP sock; } ServerConnection; -Result(ServerConnection*) ServerConnection_open(ClientCredential* client_credential, cstr server_link_cstr); +Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_credentials, cstr server_link_cstr); void ServerConnection_close(ServerConnection* conn); diff --git a/src/config.c b/src/config.c new file mode 100644 index 0000000..eaae4dc --- /dev/null +++ b/src/config.c @@ -0,0 +1,36 @@ +#include "config.h" + +Result(void) config_findValue(str config_str, str key, str* value, bool throwNotFoundError){ + u32 line_n = 0; + while(config_str.size > 0){ + line_n++; + i32 line_end = str_seekChar(config_str, '\n', 0); + if(line_end < 0) + line_end = config_str.size - 1; + str line = str_sliceBefore(config_str, line_end); + config_str = str_sliceAfter(config_str, line_end + 1); + + i32 sep_pos = str_seekChar(line, '=', 1); + if(sep_pos < 0){ + //not a 'key = value' line + continue; + } + + str line_key = str_sliceBefore(line, sep_pos - 1); + str_trim(&line_key, false); + if(str_equals(line_key, key)){ + str line_value = str_sliceAfter(line, sep_pos + 1); + str_trim(&line_value, false); + *value = line_value; + return RESULT_VOID; + } + } + + if(throwNotFoundError){ + char* key_cstr = str_copy(key).data; + char* err_msg = sprintf_malloc(key.size + 64, "can't find key '%s'", key_cstr); + free(key_cstr); + return RESULT_ERROR(err_msg, true); + } + return RESULT_VOID; +} diff --git a/src/config.h b/src/config.h new file mode 100644 index 0000000..61f8c76 --- /dev/null +++ b/src/config.h @@ -0,0 +1,6 @@ +#pragma once +#include "tlibc/errors.h" +#include "tlibc/string/str.h" + +/// searches for pattern `key = value` +Result(void) config_findValue(str config_str, str key, str* value, bool throwNotFoundError); diff --git a/src/cryptography/AES.c b/src/cryptography/AES.c index 089d38e..718effe 100755 --- a/src/cryptography/AES.c +++ b/src/cryptography/AES.c @@ -4,15 +4,13 @@ // write data from src to array and increment array data pointer static inline void __Array_writeNext(Array(u8)* dst, u8* src, size_t size){ memcpy(dst->data, src, size); - dst->data = (u8*)dst->data + size; - dst->size -= size; + *dst = Array_sliceAfter(*dst, size); } // read data from array to dst and increment array data pointer static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){ memcpy(dst, src->data, size); - src->data = (u8*)src->data + size; - src->size -= size; + *src = Array_sliceAfter(*src, size); } diff --git a/src/cryptography/RSA.c b/src/cryptography/RSA.c index 9e1f122..97c1c56 100644 --- a/src/cryptography/RSA.c +++ b/src/cryptography/RSA.c @@ -31,7 +31,7 @@ Result(void) RSA_generateKeyPair(u32 key_size, Return RESULT_VOID; } -Result(void) RSA_generateKeyPairFromTime(u32 key_size, +Result(void) RSA_generateKeyPairFromSystemRandom(u32 key_size, br_rsa_private_key* sk, br_rsa_public_key* pk) { Deferral(8); @@ -122,7 +122,7 @@ Result(void) RSA_parsePublicKey_base64(const str src, br_rsa_public_key* pk){ pk->elen = 4; pk->nlen = key_buffer_size - 4; pk->e = pk->n + pk->nlen; - u32 offset = str_seekChar(src, ':', 14) + 1; + u32 offset = str_seekChar(src, ':', 10) + 1; if(offset == 0){ Return RESULT_ERROR("missing ':' before key data", false); } @@ -155,7 +155,7 @@ Result(void) RSA_parsePrivateKey_base64(const str src, br_rsa_private_key* sk){ sk->dp = sk->q + field_len; sk->dq = sk->dp + field_len; sk->iq = sk->dq + field_len; - u32 offset = str_seekChar(src, ':', 14) + 1; + u32 offset = str_seekChar(src, ':', 10) + 1; if(offset == 0){ Return RESULT_ERROR("missing ':' before key data", false); } diff --git a/src/cryptography/RSA.h b/src/cryptography/RSA.h index 373726b..82d2879 100644 --- a/src/cryptography/RSA.h +++ b/src/cryptography/RSA.h @@ -12,7 +12,7 @@ #define RSA_DEFAULT_KEY_SIZE 3072 -/// @brief generate random key pair based on system time +/// @brief generate random key pair /// @param key_size size of public key in bits (2048/3072/4096) /// @param sk key for decryption /// @param pk key for encryption @@ -21,7 +21,8 @@ Result(void) RSA_generateKeyPair(u32 key_size, br_rsa_private_key* sk, br_rsa_public_key* pk, const br_prng_class** rng_vtable_ptr); -Result(void) RSA_generateKeyPairFromTime(u32 key_size, +/// @brief generate random key pair using system crypto-rng provider +Result(void) RSA_generateKeyPairFromSystemRandom(u32 key_size, br_rsa_private_key* sk, br_rsa_public_key* pk); Result(void) RSA_generateKeyPairFromPassword(u32 key_size, diff --git a/src/log.h b/src/log.h new file mode 100644 index 0000000..184450a --- /dev/null +++ b/src/log.h @@ -0,0 +1,8 @@ +#pragma once +#include + +#define log(context, severity, format, ...) printf("[%s][" severity "]: " format "\n", context ,##__VA_ARGS__) +#define logDebug(context, format, ...) log(context, "DEBUG", format ,##__VA_ARGS__) +#define logInfo(context, format, ...) log(context, "INFO", format ,##__VA_ARGS__) +#define logWarning(context, format, ...) log(context, "WARNING", format ,##__VA_ARGS__) +#define logError(context, format, ...) log(context, "ERROR", format ,##__VA_ARGS__) diff --git a/src/main.c b/src/main.c index b5c6fba..5203e78 100755 --- a/src/main.c +++ b/src/main.c @@ -2,10 +2,14 @@ #include "client/client.h" #include "server/server.h" +#define _DEFAULT_CONFIG_PATH_CLIENT "tcp-chat-client.config" +#define _DEFAULT_CONFIG_PATH_SERVER "tcp-chat-server.config" + typedef enum ProgramMode { Client, Server, - RsaGen, + RsaGenStdin, + RsaGenRandom, } ProgramMode; #define arg_is(LITERAL) str_equals(arg_str, STR(LITERAL)) @@ -19,8 +23,9 @@ int main(const int argc, cstr const* argv){ } ProgramMode mode = Client; - cstr server_endpoint_cstr; - u32 key_size = RSA_DEFAULT_KEY_SIZE; + cstr server_endpoint_cstr = NULL; + cstr config_path = NULL; + u32 key_size = 0; for(int argi = 1; argi < argc; argi++){ str arg_str = str_from_cstr(argv[argi]); @@ -30,7 +35,12 @@ int main(const int argc, cstr const* argv){ "no arguments Interactive client mode.\n" "-h, --help Show this message.\n" "-l, --listen [addr:port] Start server.\n" - "--rsa-gen [size] Generate RSA private and public keys based on stdin data (64Kb max)\n" + "--config [path] Load config from specified path.\n" + " Default path for config is '" _DEFAULT_CONFIG_PATH_CLIENT "' or '" _DEFAULT_CONFIG_PATH_SERVER "'\n" + "--rsa-gen-stdin [size] Generate RSA private and public keys based on stdin data (64Kb max).\n" + " size: 2048 / 3072 (default) / 4096\n" + " Usage: `cat somefile | tcp-chat --gen-rsa-stdin`\n" + "--rsa-gen-random [size] Generate random RSA private and public keys.\n" " size: 2048 / 3072 (default) / 4096\n" ); Return 0; @@ -48,18 +58,39 @@ int main(const int argc, cstr const* argv){ } server_endpoint_cstr = argv[argi]; } - else if(arg_is("--rsa-gen")){ + else if(arg_is("--config")){ + if(++argi >= argc){ + printfe("ERROR: no config path specified\n"); + Return 1; + } + config_path = argv[argi]; + } + else if(arg_is("--rsa-gen-stdin")){ if(mode != Client){ printf("program mode is set already\n"); Return 1; } - mode = RsaGen; + mode = RsaGenStdin; if(++argi >= argc){ + key_size = RSA_DEFAULT_KEY_SIZE; + } + else if(sscanf(argv[argi], "%u", &key_size) != 1){ printfe("ERROR: no key size specified\n"); + } + } + + else if(arg_is("--rsa-gen-random")){ + if(mode != Client){ + printf("program mode is set already\n"); Return 1; } - if(sscanf(argv[argi], "%u", &key_size) != 1){ + + mode = RsaGenRandom; + if(++argi >= argc){ + key_size = RSA_DEFAULT_KEY_SIZE; + } + else if(sscanf(argv[argi], "%u", &key_size) != 1){ printfe("ERROR: no key size specified\n"); } } @@ -75,13 +106,21 @@ int main(const int argc, cstr const* argv){ Defer(network_deinit()); switch(mode){ - case Client: + case Client: { + if(config_path == NULL) + config_path = _DEFAULT_CONFIG_PATH_CLIENT; try_fatal_void(client_run()); break; - case Server: - try_fatal_void(server_run(server_endpoint_cstr)); + } + + case Server: { + if(config_path == NULL) + config_path = _DEFAULT_CONFIG_PATH_SERVER; + try_fatal_void(server_run(server_endpoint_cstr, config_path)); break; - case RsaGen:{ + } + + case RsaGenStdin: { size_t input_max_size = 64*1024; char* input_buf = malloc(input_max_size); Defer(free(input_buf)); @@ -91,7 +130,7 @@ int main(const int argc, cstr const* argv){ Return 1; } str input_str = str_construct(input_buf, read_n, false); - printfe("generating RSA key pair...\n"); + printfe("generating RSA key pair based on stdin...\n"); br_rsa_private_key sk; br_rsa_public_key pk; try_fatal_void(RSA_generateKeyPairFromPassword(key_size, &sk, &pk, input_str)); @@ -99,18 +138,37 @@ int main(const int argc, cstr const* argv){ RSA_destroyPrivateKey(&sk); RSA_destroyPublicKey(&pk); ); - puts("-----BEGIN RSA PRIVATE KEY-----"); + str sk_str = RSA_serializePrivateKey_base64(&sk); - puts(sk_str.data); + printf("rsa_private_key = %s\n", sk_str.data); free(sk_str.data); - puts("-----END RSA PRIVATE KEY-------"); - puts("-----BEGIN RSA PUBLIC KEY------"); + str pk_str = RSA_serializePublicKey_base64(&pk); - puts(pk_str.data); + printf("\nrsa_public_key = %s\n", pk_str.data); free(pk_str.data); - puts("-----END RSA PUBLIC KEY--------"); } break; + + case RsaGenRandom: { + printfe("generating random RSA key pair...\n"); + br_rsa_private_key sk; + br_rsa_public_key pk; + try_fatal_void(RSA_generateKeyPairFromSystemRandom(key_size, &sk, &pk)); + Defer( + RSA_destroyPrivateKey(&sk); + RSA_destroyPublicKey(&pk); + ); + + str sk_str = RSA_serializePrivateKey_base64(&sk); + printf("rsa_private_key = %s\n", sk_str.data); + free(sk_str.data); + + str pk_str = RSA_serializePublicKey_base64(&pk); + printf("\nrsa_public_key = %s\n", pk_str.data); + free(pk_str.data); + } + break; + default: printfe("ERROR: invalid program mode %i\n", mode); Return 1; diff --git a/src/network/encrypted_sockets.c b/src/network/encrypted_sockets.c index 5ab8f67..46e6830 100644 --- a/src/network/encrypted_sockets.c +++ b/src/network/encrypted_sockets.c @@ -24,6 +24,7 @@ Result(i32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, { Deferral(4); try(i32 received_size, i, socket_recv(ptr->sock, encrypted_buf, flags)); + //TODO: return something when received_size == 0 (socket has been closed) encrypted_buf.size = received_size; try(i32 decrypted_size, u, AESStreamDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf)); Return RESULT_VALUE(i, decrypted_size); diff --git a/src/network/tcp-chat-protocol/constant.h b/src/network/tcp-chat-protocol/constant.h index a7011f2..81c437e 100644 --- a/src/network/tcp-chat-protocol/constant.h +++ b/src/network/tcp-chat-protocol/constant.h @@ -6,11 +6,13 @@ extern const Magic64 PacketHeader_MAGIC; +// sizeof(PacketHeader) must be 64 typedef struct PacketHeader { Magic64 magic; u8 protocol_version; - u8 _reserved; + u8 _reserved1; u16 type; + u32 _reserved4; u64 content_size; } __attribute__((aligned(64))) PacketHeader; diff --git a/src/network/tcp-chat-protocol/v1.c b/src/network/tcp-chat-protocol/v1.c index fcbb101..dc0de44 100644 --- a/src/network/tcp-chat-protocol/v1.c +++ b/src/network/tcp-chat-protocol/v1.c @@ -8,3 +8,7 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, Array(u8) sessio Return RESULT_VOID; } +void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){ + PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ClientHandshake, sizeof(session_id)); + ptr->session_id = session_id; +} diff --git a/src/network/tcp-chat-protocol/v1.h b/src/network/tcp-chat-protocol/v1.h index 5ce2ee7..1361818 100644 --- a/src/network/tcp-chat-protocol/v1.h +++ b/src/network/tcp-chat-protocol/v1.h @@ -4,6 +4,7 @@ #define PROTOCOL_VERSION 1 /* 1.0.0 */ + typedef enum PacketType { PacketType_Invalid, PacketType_ErrorMessage, @@ -11,11 +12,13 @@ typedef enum PacketType { PacketType_ServerHandshake, } __attribute__((__packed__)) PacketType; + typedef struct ErrorMessage { PacketHeader header; /* content stream of size `header.content_size` */ } ErrorMessage; + typedef struct ClientHandshake { PacketHeader header; u8 session_key[AES_SESSION_KEY_SIZE]; @@ -28,3 +31,5 @@ typedef struct ServerHandshake { PacketHeader header; u64 session_id; } ServerHandshake; + +void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id); diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 64bab43..609a303 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -1,10 +1,87 @@ #include "server.h" +#include "network/tcp-chat-protocol/v1.h" void ClientConnection_close(ClientConnection* conn){ if(conn == NULL) return; - socket_close(conn->system_socket.sock); - socket_close(conn->content_socket.sock); + socket_close(conn->sock.sock); free(conn->session_key.data); free(conn); } + +Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_credentials, + Socket sock_tcp, EndpointIPv4 client_end, u64 session_id) +{ + Deferral(32); + + ClientConnection* conn = (ClientConnection*)malloc(sizeof(ClientConnection)); + memset(conn, 0, sizeof(*conn)); + bool success = false; + Defer( + if(!success) + ClientConnection_close(conn); + ); + + conn->client_end = client_end; + conn->session_id = session_id; + conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); + + Array(u8) enc_buf = Array_alloc_size(8*1024); + Defer(free(enc_buf.data)); + Array(u8) dec_buf = Array_alloc_size(8*1024); + Defer(free(dec_buf.data)); + u32 enc_size = 0, dec_size = 0; + + // receive message encrypted by server public key + try(enc_size, u, + socket_recv( + sock_tcp, + Array_sliceBefore(enc_buf, server_credentials->rsa_pk.nlen), + SocketRecvFlag_WaitAll + ) + ); + + // decrypt the message using server private key + RSADecryptor rsa_dec; + RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk); + try(dec_size, u, + RSADecryptor_decrypt( + &rsa_dec, + Array_sliceBefore(enc_buf, enc_size) + ) + ); + + // validate client handshake + if(dec_size != sizeof(ClientHandshake)){ + Return RESULT_ERROR_FMT( + "decrypted message (size: %u) is not a ClientHandshake (size: %u)", + dec_size, (u32)sizeof(ClientHandshake) + ); + } + ClientHandshake* client_handshake = dec_buf.data; + try_void(PacketHeader_validateMagic(&client_handshake->header)); + if(client_handshake->header.type != PacketType_ClientHandshake){ + Return RESULT_ERROR_FMT( + "received message of unexpected type: %u", + client_handshake->header.type + ); + } + + // use received session key + memcpy(conn->session_key.data, client_handshake->session_key, conn->session_key.size); + EncryptedSocketTCP_construct(&conn->sock, sock_tcp, conn->session_key); + + // construct ServerHandshake in dec_buf + ServerHandshake_construct((ServerHandshake*)dec_buf.data, session_id); + // send ServerHandshake over encrypted TCP socket + try_void( + EncryptedSocketTCP_send( + &conn->sock, + Array_sliceBefore(dec_buf, sizeof(ServerHandshake)), + enc_buf + ) + ); + + success = true; + Return RESULT_VALUE(p, conn); +} \ No newline at end of file diff --git a/src/server/ServerCredentials.c b/src/server/ServerCredentials.c new file mode 100644 index 0000000..5466a20 --- /dev/null +++ b/src/server/ServerCredentials.c @@ -0,0 +1,26 @@ +#include "server.h" + + +Result(ServerCredentials*) ServerCredentials_create(const str rsa_sk_base64, const str rsa_pk_base64){ + Deferral(4); + + ServerCredentials* cred = (ServerCredentials*)malloc(sizeof(ServerCredentials)); + memset(cred, 0, sizeof(*cred)); + bool success = false; + Defer( + if(!success) + ServerCredentials_free(cred); + ); + + try_void(RSA_parsePrivateKey_base64(rsa_sk_base64, &cred->rsa_sk)); + try_void(RSA_parsePublicKey_base64(rsa_pk_base64, &cred->rsa_pk)); + + success = true; + Return RESULT_VALUE(p, cred); +} + +void ServerCredentials_free(ServerCredentials* cred){ + RSA_destroyPrivateKey(&cred->rsa_sk); + RSA_destroyPublicKey(&cred->rsa_pk); + free(cred); +} diff --git a/src/server/server.c b/src/server/server.c index 1b64a19..4f2cb61 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -1,8 +1,9 @@ -#include "server.h" -#include "db/idb.h" #include - -static void* handle_connection(void* _args); +#include "tlibc/filesystem.h" +#include "db/idb.h" +#include "server.h" +#include "config.h" +#include "log.h" typedef struct ConnectionHandlerArgs { Socket accepted_socket; @@ -10,22 +11,57 @@ typedef struct ConnectionHandlerArgs { u64 session_id; } ConnectionHandlerArgs; -Result(void) server_run(cstr server_endpoint_str){ +static void* handle_connection(void* _args); +static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_ctx); + + +static ServerCredentials* _server_credentials = NULL; + + +static Result(void) parseConfig(cstr config_path){ + Deferral(8); + + // open file + try(FILE* config_file, p, file_open(config_path, FO_ReadExisting)); + // read whole file into Array(char) + try(i64 config_file_size, i, file_getSize(config_file)); + Array(char) config_buf = Array_alloc(char, config_file_size); + Defer(free(config_buf.data)); + try_void(file_readBytesArray(config_file, config_buf)); + str config_str = Array_castTo_str(config_buf, false); + + str sk_base64; + str pk_base64; + try_void(config_findValue(config_str, STR("rsa_private_key"), &sk_base64, true)); + try_void(config_findValue(config_str, STR("rsa_public_key"), &pk_base64, true)); + try(_server_credentials, p, ServerCredentials_create(sk_base64, pk_base64)); + + Return RESULT_VOID; +} + +Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){ Deferral(32); + cstr log_ctx = "Server/MainThread"; + logInfo(log_ctx, "starting server"); + logDebug(log_ctx, "parsing config"); + try_void(parseConfig(config_path)); + + logDebug(log_ctx, "initializing main socket"); EndpointIPv4 server_end; - EndpointIPv4_parse(server_endpoint_str, &server_end); - //TODO: add log + EndpointIPv4_parse(server_endpoint_cstr, &server_end); try(Socket main_socket, i, socket_open_TCP()); try_void(socket_bind(main_socket, server_end)); try_void(socket_listen(main_socket, 512)); + logInfo(log_ctx, "server is listening at %s", server_endpoint_cstr); u64 session_id = 1; - while(true){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)malloc(sizeof(ConnectionHandlerArgs)); try(args->accepted_socket, i, socket_accept(main_socket, &args->client_end)); args->session_id = session_id++; pthread_t conn_thread = {0}; + //TODO: use async IO instead of threads to not waste system resources + // while waiting for incoming data in 100500 threads try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args)); } @@ -33,13 +69,41 @@ Result(void) server_run(cstr server_endpoint_str){ } static void* handle_connection(void* _args){ - Deferral(64); - //ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args; - Defer(free(_args)); - // TODO: receive handshake and session key + ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args; + char log_ctx[64]; + sprintf(log_ctx, "Session-" IFWIN("%llx", "%lx"), args->session_id); - //ClientConnection conn; + Result(void) r = try_handle_connection(args, log_ctx); + if(r.error){ + str error_s = Error_toStr(r.error); + logError(log_ctx, "%s", error_s.data); + free(error_s.data); + } - - Return NULL; + return NULL; +} + +static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_ctx){ + Deferral(64); + Defer(free(args)); + + ClientConnection* conn = NULL; + Defer(ClientConnection_close(conn)); + // establish encrypted connection + try(conn, p, + ClientConnection_accept( + _server_credentials, + args->accepted_socket, + args->client_end, + args->session_id + ) + ); + logDebug(log_ctx, "session accepted"); + + // handle requests + while(true){ + sleepMsec(10); + } + + Return RESULT_VOID; } diff --git a/src/server/server.h b/src/server/server.h index 06fc5cc..808a3d3 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -3,12 +3,33 @@ #include "cryptography/RSA.h" #include "network/encrypted_sockets.h" -Result(void) server_run(cstr server_endpoint_str); +Result(void) server_run(cstr server_endpoint_cstr, cstr config_path); + + +typedef struct ServerCredentials { + br_rsa_private_key rsa_sk; + br_rsa_public_key rsa_pk; +} ServerCredentials; + +Result(ServerCredentials*) ServerCredentials_create(const str rsa_sk_base64, const str rsa_pk_base64); + +void ServerCredentials_free(ServerCredentials* cred); + + +typedef struct ServerMetadata { + str name; + str description; +} ServerMetadata; + typedef struct ClientConnection { u64 session_id; EndpointIPv4 client_end; Array(u8) session_key; - EncryptedSocketTCP system_socket; - EncryptedSocketTCP content_socket; + EncryptedSocketTCP sock; } ClientConnection; + +Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_credentials, + Socket sock, EndpointIPv4 client_end, u64 session_id); + +void ClientConnection_close(ClientConnection* conn); \ No newline at end of file