diff --git a/.vscode/launch.json b/.vscode/launch.json index fedd752..f03bfb9 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -7,7 +7,7 @@ "request": "launch", "program": "${workspaceFolder}/bin/tcp-chat", "windows": { "program": "${workspaceFolder}/bin/tcp-chat.exe" }, - "args": [ "-l", "127.0.0.1:9988" ], + "args": [ "-l" ], "preLaunchTask": "build_exec_dbg", "stopAtEntry": false, "cwd": "${workspaceFolder}/bin", diff --git a/dependencies/tlibc b/dependencies/tlibc index 3034e4d..1775b27 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit 3034e4d1864e008a28f870ac2688053bcb6a9ce8 +Subproject commit 1775b27980d550dd9a50a81b11d797c51253ab22 diff --git a/src/config.c b/src/config.c index eaae4dc..a970e21 100644 --- a/src/config.c +++ b/src/config.c @@ -28,7 +28,7 @@ Result(void) config_findValue(str config_str, str key, str* value, bool throwNot if(throwNotFoundError){ char* key_cstr = str_copy(key).data; - char* err_msg = sprintf_malloc(key.size + 64, "can't find key '%s'", key_cstr); + char* err_msg = sprintf_malloc("can't find key '%s'", key_cstr); free(key_cstr); return RESULT_ERROR(err_msg, true); } diff --git a/src/main.c b/src/main.c index d93b4bf..e5e83f3 100755 --- a/src/main.c +++ b/src/main.c @@ -1,6 +1,8 @@ #include "network/network.h" #include "client/client.h" #include "server/server.h" +#include "tlibc/tlibc.h" +#include "tlibc/base64.h" #define _DEFAULT_CONFIG_PATH_CLIENT "tcp-chat-client.config" #define _DEFAULT_CONFIG_PATH_SERVER "tcp-chat-server.config" @@ -10,6 +12,8 @@ typedef enum ProgramMode { ServerMode, RsaGenStdin, RsaGenRandom, + RandomBytes, + RandomBytesBase64, } ProgramMode; #define arg_is(LITERAL) str_equals(arg_str, STR(LITERAL)) @@ -17,15 +21,19 @@ typedef enum ProgramMode { int main(const int argc, cstr const* argv){ Deferral(32); + try_fatal_void(tlibc_init()); + Defer(tlibc_deinit()); + try_fatal_void(network_init()); + Defer(network_deinit()); + if(br_prng_seeder_system(NULL) == NULL){ printfe("Can't get system random seeder. Bearssl is compiled incorrectly."); return 1; } ProgramMode mode = ClientMode; - cstr server_endpoint_cstr = NULL; cstr config_path = NULL; - u32 key_size = 0; + u32 size_arg = 0; for(int argi = 1; argi < argc; argi++){ str arg_str = str_from_cstr(argv[argi]); @@ -34,7 +42,7 @@ int main(const int argc, cstr const* argv){ "USAGE:\n" "no arguments Interactive client mode.\n" "-h, --help Show this message.\n" - "-l, --listen [addr:port] Start server.\n" + "-l, --listen Start server.\n" "--config [path] Load config from specified path.\n" " Default path for config is '" _DEFAULT_CONFIG_PATH_CLIENT "' or '" _DEFAULT_CONFIG_PATH_SERVER "'\n" "--rsa-gen-stdin [size] Generate RSA private and public keys based on stdin data (64Kb max).\n" @@ -42,6 +50,10 @@ int main(const int argc, cstr const* argv){ " Usage: `cat somefile | tcp-chat --gen-rsa-stdin`\n" "--rsa-gen-random [size] Generate random RSA private and public keys.\n" " size: 2048 / 3072 (default) / 4096\n" + "--random-bytes [size] Generate random bytes.\n" + " size: any number (default=32)\n" + "--random-bytes-base64 [size] Generate random bytes and print them in base64 encoding.\n" + " size: any number (default=32)\n" ); Return 0; } @@ -50,13 +62,7 @@ int main(const int argc, cstr const* argv){ printf("program mode is set already\n"); Return 1; } - mode = ServerMode; - if(++argi >= argc){ - printfe("ERROR: no endpoint specified\n"); - Return 1; - } - server_endpoint_cstr = argv[argi]; } else if(arg_is("--config")){ if(++argi >= argc){ @@ -73,9 +79,9 @@ int main(const int argc, cstr const* argv){ mode = RsaGenStdin; if(++argi >= argc){ - key_size = RSA_DEFAULT_KEY_SIZE; + size_arg = RSA_DEFAULT_KEY_SIZE; } - else if(sscanf(argv[argi], "%u", &key_size) != 1){ + else if(sscanf(argv[argi], "%u", &size_arg) != 1){ printfe("ERROR: no key size specified\n"); } } @@ -88,12 +94,40 @@ int main(const int argc, cstr const* argv){ mode = RsaGenRandom; if(++argi >= argc){ - key_size = RSA_DEFAULT_KEY_SIZE; + size_arg = RSA_DEFAULT_KEY_SIZE; } - else if(sscanf(argv[argi], "%u", &key_size) != 1){ + else if(sscanf(argv[argi], "%u", &size_arg) != 1){ printfe("ERROR: no key size specified\n"); } } + else if(arg_is("--random-bytes")){ + if(mode != ClientMode){ + printf("program mode is set already\n"); + Return 1; + } + + mode = RandomBytes; + if(++argi >= argc){ + size_arg = 32; + } + else if(sscanf(argv[argi], "%u", &size_arg) != 1){ + printfe("ERROR: no size specified\n"); + } + } + else if(arg_is("--random-bytes-base64")){ + if(mode != ClientMode){ + printf("program mode is set already\n"); + Return 1; + } + + mode = RandomBytesBase64; + if(++argi >= argc){ + size_arg = 32; + } + else if(sscanf(argv[argi], "%u", &size_arg) != 1){ + printfe("ERROR: no size specified\n"); + } + } else { printfe("ERROR: unknown argument '%s'\n" "Use '-h' to see list of avaliable arguments\n", @@ -101,9 +135,6 @@ int main(const int argc, cstr const* argv){ Return 1; } } - - try_fatal_void(network_init()); - Defer(network_deinit()); switch(mode){ case ClientMode: { @@ -122,58 +153,104 @@ int main(const int argc, cstr const* argv){ try_fatal(Server* server, p, Server_createFromConfig(config_path)); Defer(Server_free(server)); - try_fatal_void(Server_run(server, server_endpoint_cstr)); + try_fatal_void(Server_run(server)); break; } case RsaGenStdin: { - size_t input_max_size = 64*1024; - char* input_buf = malloc(input_max_size); - Defer(free(input_buf)); - size_t read_n = fread(input_buf, 1, input_max_size, stdin); - if(read_n == 0){ + printfe("reading stdin...\n"); + Array(u8) input_buf = Array_alloc_size(64*1024); + Defer(free(input_buf.data)); + br_hmac_drbg_context rng = { .vtable = &br_hmac_drbg_vtable }; + br_hmac_drbg_init(&rng, &br_sha256_vtable, NULL, 0); + i64 read_n = 0; + do { + read_n = fread(input_buf.data, 1, input_buf.size, stdin); + if(read_n < 0){ printfe("ERROR: no input\n"); Return 1; } - str input_str = str_construct(input_buf, read_n, false); - printfe("generating RSA key pair based on stdin...\n"); - br_rsa_private_key sk; - br_rsa_public_key pk; - try_fatal_void(RSA_generateKeyPairFromPassword(key_size, &sk, &pk, input_str)); - Defer( - RSA_destroyPrivateKey(&sk); - RSA_destroyPublicKey(&pk); - ); - - str sk_str = RSA_serializePrivateKey_base64(&sk); - printf("rsa_private_key = %s\n", sk_str.data); - free(sk_str.data); - - str pk_str = RSA_serializePublicKey_base64(&pk); - printf("\nrsa_public_key = %s\n", pk_str.data); - free(pk_str.data); - } + // put bytes to rng as seed + br_hmac_drbg_update(&rng, input_buf.data, read_n); + } while(read_n == input_buf.size); + printfe("generating RSA key pair based on stdin...\n"); + br_rsa_private_key sk; + br_rsa_public_key pk; + try_fatal_void(RSA_generateKeyPair(size_arg, &sk, &pk, &rng.vtable)); + Defer( + RSA_destroyPrivateKey(&sk); + RSA_destroyPublicKey(&pk); + ); + + str sk_str = RSA_serializePrivateKey_base64(&sk); + printf("rsa_private_key = %s\n", sk_str.data); + free(sk_str.data); + + str pk_str = RSA_serializePublicKey_base64(&pk); + printf("\nrsa_public_key = %s\n", pk_str.data); + free(pk_str.data); break; + } case RsaGenRandom: { - printfe("generating random RSA key pair...\n"); - br_rsa_private_key sk; - br_rsa_public_key pk; - try_fatal_void(RSA_generateKeyPairFromSystemRandom(key_size, &sk, &pk)); - Defer( - RSA_destroyPrivateKey(&sk); - RSA_destroyPublicKey(&pk); - ); + printfe("generating random RSA key pair...\n"); + br_rsa_private_key sk; + br_rsa_public_key pk; + try_fatal_void(RSA_generateKeyPairFromSystemRandom(size_arg, &sk, &pk)); + Defer( + RSA_destroyPrivateKey(&sk); + RSA_destroyPublicKey(&pk); + ); - str sk_str = RSA_serializePrivateKey_base64(&sk); - printf("rsa_private_key = %s\n", sk_str.data); - free(sk_str.data); - - str pk_str = RSA_serializePublicKey_base64(&pk); - printf("\nrsa_public_key = %s\n", pk_str.data); - free(pk_str.data); - } + str sk_str = RSA_serializePrivateKey_base64(&sk); + printf("rsa_private_key = %s\n", sk_str.data); + free(sk_str.data); + + str pk_str = RSA_serializePublicKey_base64(&pk); + printf("\nrsa_public_key = %s\n", pk_str.data); + free(pk_str.data); break; + } + + case RandomBytes: { + printfe("generating random bytes...\n"); + br_hmac_drbg_context rng = { .vtable = &br_hmac_drbg_vtable }; + rng_init_sha256_seedFromSystem(&rng.vtable); + Array(u8) random_buf = Array_alloc_size(1024); + u32 full_buffers_n = size_arg / random_buf.size; + u32 remaining_n = size_arg % random_buf.size; + while(full_buffers_n > 0){ + full_buffers_n--; + br_hmac_drbg_generate(&rng, random_buf.data, random_buf.size); + fwrite(random_buf.data, 1, random_buf.size, stdout); + } + + br_hmac_drbg_generate(&rng, random_buf.data, remaining_n); + fwrite(random_buf.data, 1, remaining_n, stdout); + break; + } + + case RandomBytesBase64: { + printfe("generating random bytes...\n"); + br_hmac_drbg_context rng = { .vtable = &br_hmac_drbg_vtable }; + rng_init_sha256_seedFromSystem(&rng.vtable); + Array(u8) random_buf = Array_alloc_size(1024); + Array(u8) base64_buf = Array_alloc_size(base64_encodedSize(random_buf.size)); + u32 full_buffers_n = size_arg / random_buf.size; + u32 remaining_n = size_arg % random_buf.size; + u32 enc_size = 0; + while(full_buffers_n > 0){ + full_buffers_n--; + br_hmac_drbg_generate(&rng, random_buf.data, random_buf.size); + enc_size = base64_encode(random_buf.data, random_buf.size, base64_buf.data); + fwrite(base64_buf.data, 1, enc_size, stdout); + } + + br_hmac_drbg_generate(&rng, random_buf.data, remaining_n); + enc_size = base64_encode(random_buf.data, remaining_n, base64_buf.data); + fwrite(base64_buf.data, 1, enc_size, stdout); + break; + } default: printfe("ERROR: invalid program mode %i\n", mode); diff --git a/src/network/internal.h b/src/network/internal.h index 0dc413e..5a669a3 100644 --- a/src/network/internal.h +++ b/src/network/internal.h @@ -1,6 +1,7 @@ #pragma once #include "tlibc/errors.h" #include "endpoint.h" +#include "network.h" #if !defined(KN_USE_WINSOCK) #if defined(_WIN64) || defined(_WIN32) @@ -30,10 +31,10 @@ #if KN_USE_WINSOCK #define RESULT_ERROR_SOCKET()\ - RESULT_ERROR(sprintf_malloc(64, "Winsock error %i (look in )", WSAGetLastError()), true); + RESULT_ERROR_CODE_FMT(WINSOCK2, WSAGetLastError(), "Winsock error %i (look in )", WSAGetLastError()); #else #define RESULT_ERROR_SOCKET()\ - RESULT_ERROR(strerror(errno), false); + RESULT_ERROR_ERRNO(); #endif struct sockaddr_in EndpointIPv4_toSockaddr(EndpointIPv4 end); diff --git a/src/network/network.c b/src/network/network.c index fb129b8..e1392ed 100755 --- a/src/network/network.c +++ b/src/network/network.c @@ -1,8 +1,10 @@ -#include "network.h" #include "internal.h" +ErrorCodePage_define(WINSOCK2); + Result(void) network_init(){ #if _WIN32 + ErrorCodePage_register(WINSOCK2); // Initialize Winsock WSADATA wsaData = {0}; int result = WSAStartup(MAKEWORD(2,2), &wsaData); diff --git a/src/network/network.h b/src/network/network.h index 12f080c..1eb807a 100755 --- a/src/network/network.h +++ b/src/network/network.h @@ -1,5 +1,7 @@ #pragma once #include "tlibc/errors.h" +ErrorCodePage_declare(WINSOCK2); + Result(void) network_init(); void network_deinit(); diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 70ed334..de80e4e 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -24,8 +24,8 @@ Result(ClientConnection*) ClientConnection_accept(ConnectionHandlerArgs* args) conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); // correct session key will be received from client later Array_memset(conn->session_key, 0); - EncryptedSocketTCP_construct(&conn->sock, args->accepted_socket, NETWORK_BUFFER_SIZE, conn->session_key); - try_void(socket_TCP_enableAliveChecks_default(args->accepted_socket)); + EncryptedSocketTCP_construct(&conn->sock, args->accepted_socket_tcp, NETWORK_BUFFER_SIZE, conn->session_key); + try_void(socket_TCP_enableAliveChecks_default(args->accepted_socket_tcp)); // decrypt the rsa messages using server private key RSADecryptor rsa_dec; diff --git a/src/server/request_handlers/ServerPublicInfo.c b/src/server/request_handlers/ServerPublicInfo.c index 82d46c0..d36061f 100644 --- a/src/server/request_handlers/ServerPublicInfo.c +++ b/src/server/request_handlers/ServerPublicInfo.c @@ -11,6 +11,20 @@ declare_RequestHandler(ServerPublicInfo) //TODO: try find requested info Array(u8) content; + switch(req.property){ + default: + try(char* err_msg, p, sendErrorMessage(conn, res_head, + "unknown ServerPublicInfo property %u", req.property)); + logWarn(log_ctx, "%s", err_msg); + Return RESULT_VOID; + break; + case ServerPublicInfo_Name: + content = str_castTo_Array(server->name); + break; + case ServerPublicInfo_Description: + content = str_castTo_Array(server->name); + break; + } PacketHeader_construct(res_head, PROTOCOL_VERSION, PacketType_ServerPublicInfoResponse, content.size); diff --git a/src/server/request_handlers/request_handlers.h b/src/server/request_handlers/request_handlers.h index 75528b2..df93459 100644 --- a/src/server/request_handlers/request_handlers.h +++ b/src/server/request_handlers/request_handlers.h @@ -4,20 +4,20 @@ #include "log.h" -Result(char*) __sendErrorMessage(ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head, - u32 msg_buf_size, cstr format, va_list argv); -Result(char*) sendErrorMessage(ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head, - u32 msg_buf_size, cstr format, ...) ATTRIBUTE_CHECK_FORMAT_PRINTF(5, 6); +Result(char*) __sendErrorMessage_va(ClientConnection* conn, PacketHeader* res_head, + cstr format, va_list argv); +Result(char*) sendErrorMessage(ClientConnection* conn, PacketHeader* res_head, + cstr format, ...) ATTRIBUTE_CHECK_FORMAT_PRINTF(3, 4); #define declare_RequestHandler(TYPE) \ Result(void) handleRequest_##TYPE( \ - cstr log_ctx, cstr req_type_name, \ + Server* server, cstr log_ctx, cstr req_type_name, \ ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head) #define case_handleRequest(TYPE) \ case PacketType_##TYPE##Request:\ - try_void(handleRequest_##TYPE(log_ctx, #TYPE, conn, &req_head, &res_head));\ + try_void(handleRequest_##TYPE(args->server, log_ctx, #TYPE, conn, &req_head, &res_head));\ break; declare_RequestHandler(ServerPublicInfo); diff --git a/src/server/request_handlers/send_error.c b/src/server/request_handlers/send_error.c index a6eb197..f0d8c99 100644 --- a/src/server/request_handlers/send_error.c +++ b/src/server/request_handlers/send_error.c @@ -1,17 +1,18 @@ #include "request_handlers.h" -Result(char*) __sendErrorMessage(ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head, - u32 msg_buf_size, cstr format, va_list argv) +Result(char*) __sendErrorMessage_va(ClientConnection* conn, PacketHeader* res_head, + cstr format, va_list argv) { Deferral(4); - (void)req_head; - //TODO: limit ErrorMessage size to fit into EncryptedSocketTCP.internal_buffer_size - Array(u8) err_buf = Array_alloc(u8, msg_buf_size); + Array(u8) err_buf; + err_buf.data = vsprintf_malloc(format, argv); + err_buf.size = strlen(err_buf.data); + //limit ErrorMessage size to fit into EncryptedSocketTCP.internal_buffer_size + if(err_buf.size > NETWORK_BUFFER_SIZE) + err_buf.size = NETWORK_BUFFER_SIZE; bool err_complete = false; Defer(if(!err_complete) free(err_buf.data)); - vsprintf(err_buf.data, format, argv); - err_buf.size = strlen(err_buf.data); PacketHeader_construct(res_head, PROTOCOL_VERSION, PacketType_ErrorMessage, err_buf.size); @@ -22,12 +23,12 @@ Result(char*) __sendErrorMessage(ClientConnection* conn, PacketHeader* req_head, Return RESULT_VALUE(p, err_buf.data); } -Result(char*) sendErrorMessage(ClientConnection* conn, PacketHeader* req_head, PacketHeader* res_head, - u32 msg_buf_size, cstr format, ...) +Result(char*) sendErrorMessage(ClientConnection* conn, PacketHeader* res_head, + cstr format, ...) { va_list argv; va_start(argv, format); - ResultVar(char*) err_msg = __sendErrorMessage(conn, req_head, res_head, msg_buf_size, format, argv); + ResultVar(char*) err_msg = __sendErrorMessage_va(conn, res_head, format, argv); va_end(argv); return err_msg; } diff --git a/src/server/server.c b/src/server/server.c index 74c19bb..55973f1 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -1,7 +1,7 @@ #include #include "tlibc/filesystem.h" #include "tlibc/time.h" -#include "db/idb.h" +#include "tlibc/base64.h" #include "server.h" #include "config.h" #include "log.h" @@ -17,6 +17,7 @@ void Server_free(Server* server){ free(server->name.data); free(server->description.data); ServerCredentials_destroy(&server->cred); + idb_close(server->db); } Result(Server*) Server_createFromConfig(cstr config_path){ @@ -48,6 +49,12 @@ Result(Server*) Server_createFromConfig(cstr config_path){ try_void(config_findValue(config_str, STR("description"), &tmp_str, true)); server->description = str_copy(tmp_str); + // parse local_address + try_void(config_findValue(config_str, STR("local_address"), &tmp_str, true)); + char* local_end_cstr = str_copy(tmp_str).data; + Defer(free(local_end_cstr)); + try_void(EndpointIPv4_parse(local_end_cstr, &server->local_end)); + // parse rsa_private_key try_void(config_findValue(config_str, STR("rsa_private_key"), &tmp_str, true)); char* sk_base64_cstr = str_copy(tmp_str).data; @@ -60,29 +67,37 @@ Result(Server*) Server_createFromConfig(cstr config_path){ try_void(ServerCredentials_tryConstruct(&server->cred, sk_base64_cstr, pk_base64_cstr)); + // parse db_key + try_void(config_findValue(config_str, STR("db_aes_key"), &tmp_str, true)); + Array(u8) db_aes_key = Array_alloc_size(base64_decodedSize(tmp_str.data, tmp_str.size)); + base64_decode(tmp_str.data, tmp_str.size, db_aes_key.data); + + // parse db_dir and open db + try_void(config_findValue(config_str, STR("db_dir"), &tmp_str, true)); + try(server->db, p, idb_open(tmp_str, db_aes_key)); + success = true; Return RESULT_VALUE(p, server); } -Result(void) Server_run(Server* server, cstr server_endpoint_cstr){ +Result(void) Server_run(Server* server){ Deferral(16); cstr log_ctx = "ListenerThread"; logInfo(log_ctx, "starting server"); - EndpointIPv4 server_end; - try_void(EndpointIPv4_parse(server_endpoint_cstr, &server_end)); - logDebug(log_ctx, "initializing main socket"); try(Socket main_socket, i, socket_open_TCP()); - try_void(socket_bind(main_socket, server_end)); + try_void(socket_bind(main_socket, server->local_end)); try_void(socket_listen(main_socket, 512)); - logInfo(log_ctx, "server is listening at %s", server_endpoint_cstr); + str local_end_str = EndpointIPv4_toStr(server->local_end); + Defer(free(local_end_str.data)); + logInfo(log_ctx, "server is listening at %s", local_end_str.data); u64 session_id = 1; while(true){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)malloc(sizeof(ConnectionHandlerArgs)); args->server = server; - try(args->accepted_socket, i, + try(args->accepted_socket_tcp, i, socket_accept(main_socket, &args->client_end)); args->session_id = session_id++; pthread_t conn_thread = {0}; @@ -137,9 +152,8 @@ static Result(void) try_handleConnection(ConnectionHandlerArgs* args, cstr log_c // send error message and close connection default: try(char* err_msg, p, - sendErrorMessage( - conn, &req_head, &res_head, - 128, "Received unexpected packet of type %u", + sendErrorMessage(conn, &res_head, + "Received unexpected packet of type %u", req_head.type ) ); diff --git a/src/server/server.h b/src/server/server.h index e501cc0..2de6a9c 100644 --- a/src/server/server.h +++ b/src/server/server.h @@ -2,6 +2,7 @@ #include "cryptography/AES.h" #include "cryptography/RSA.h" #include "network/encrypted_sockets.h" +#include "db/idb.h" typedef struct Server Server; @@ -27,7 +28,7 @@ typedef struct ClientConnection { typedef struct ConnectionHandlerArgs { Server* server; - Socket accepted_socket; + Socket accepted_socket_tcp; EndpointIPv4 client_end; u64 session_id; } ConnectionHandlerArgs; @@ -40,9 +41,11 @@ void ClientConnection_close(ClientConnection* conn); typedef struct Server { str name; str description; + EndpointIPv4 local_end; ServerCredentials cred; + IncrementalDB* db; } Server; Result(Server*) Server_createFromConfig(cstr config_path); void Server_free(Server* server); -Result(void) Server_run(Server* server, cstr server_endpoint_cstr); +Result(void) Server_run(Server* server);