From 94fcbe5dafb5ef9361f72977977ed32528bfd8a9 Mon Sep 17 00:00:00 2001 From: Timerix Date: Sun, 2 Nov 2025 15:26:30 +0500 Subject: [PATCH] fixed connection bugs --- src/client/ServerConnection.c | 13 ++++++------ src/client/client.c | 33 ++++++++++++++++++++++-------- src/cryptography/AES.c | 1 - src/cryptography/AES.h | 3 ++- src/cryptography/RSA.c | 27 ++++++++++++++++-------- src/cryptography/RSA.h | 5 +++-- src/network/encrypted_sockets.c | 1 + src/network/tcp-chat-protocol/v1.c | 2 +- src/server/ClientConnection.c | 5 ++++- 9 files changed, 61 insertions(+), 29 deletions(-) diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index 00bb61d..7cd3189 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -60,8 +60,8 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden br_hmac_drbg_generate(&key_rng, conn->session_key.data, conn->session_key.size); // connect to server address - Socket _s; - try(_s, i, socket_open_TCP()); + try(Socket _s, i, socket_open_TCP()); + // TODO: set socket timeout to 5 seconds try_void(socket_connect(_s, conn->server_end)); EncryptedSocketTCP_construct(&conn->sock, _s, conn->session_key); @@ -73,11 +73,12 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden // construct ClientHandshake in dec_buf try_void(ClientHandshake_tryConstruct((ClientHandshake*)dec_buf.data, conn->session_key)); + dec_size = sizeof(ClientHandshake); // encrypt by server public key try(enc_size, u, RSAEncryptor_encrypt( &conn->rsa_enc, - Array_sliceBefore(dec_buf, sizeof(ClientHandshake)), + Array_sliceBefore(dec_buf, dec_size), enc_buf ) ); @@ -107,7 +108,7 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden ); // receive error message of length packet_header->content_size - enc_size = AESStreamEncryptor_calcDstSize(packet_header->content_size); + enc_size = packet_header->content_size; if(enc_size > enc_buf.size) enc_size = enc_buf.size; try_void( @@ -124,12 +125,12 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden Return RESULT_ERROR((char*)err_buf.data, true); } case PacketType_ServerHandshake: { - enc_size = AESStreamEncryptor_calcDstSize(sizeof(ServerHandshake) - sizeof(PacketHeader)); + enc_size = sizeof(ServerHandshake) - sizeof(PacketHeader); try_void( EncryptedSocketTCP_recv( &conn->sock, Array_sliceBefore(enc_buf, enc_size), - Array_sliceAfter(dec_buf, sizeof(PacketHeader)), + dec_buf, SocketRecvFlag_WaitAll ) ); diff --git a/src/client/client.c b/src/client/client.c index ffcd6b1..7b8cfb4 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -23,16 +23,32 @@ static Result(void) commandExec(str command, bool* stop); static Result(void) askUserNameAndPassword(ClientCredentials** cred){ Deferral(8); - printf("username: "); - char username[1024]; - fgets(username, sizeof(username), stdin); + char username_buf[1024]; + str usrername = str_null; + while(true) { + printf("username: "); + fgets(username_buf, sizeof(username_buf), stdin); + usrername = str_from_cstr(username_buf); + if(usrername.size < 4){ + printf("ERROR: username length must be at least 4\n"); + } + else break; + } - printf("password: "); - char password[1024]; - // TODO: hide password - fgets(password, sizeof(password), stdin); + char password_buf[1024]; + str password = str_null; + while(true) { + printf("password: "); + // TODO: hide password + fgets(password_buf, sizeof(password_buf), stdin); + password = str_from_cstr(password_buf); + if(password.size < 8){ + printf("ERROR: password length must be at least 8\n"); + } + else break; + } - try(*cred, p, ClientCredentials_create(str_from_cstr(username), str_from_cstr(password))); + try(*cred, p, ClientCredentials_create(usrername, password)); Return RESULT_VOID; } @@ -106,6 +122,7 @@ static Result(void) commandExec(str command, bool* stop){ printf("connecting to server...\n"); try(_server_connection, p, ServerConnection_open(_client_credentials, new_server_link.data)); + printf("connection established\n"); // TODO: request server info // show server info diff --git a/src/cryptography/AES.c b/src/cryptography/AES.c index 718effe..5de266f 100755 --- a/src/cryptography/AES.c +++ b/src/cryptography/AES.c @@ -176,7 +176,6 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){ Deferral(4); - try_assert(src.size >= AESStreamEncryptor_calcDstSize(0)); u32 decrypted_size = __AESStreamDecryptor_calcDstSize(src.size); try_assert(dst.size >= decrypted_size); diff --git a/src/cryptography/AES.h b/src/cryptography/AES.h index a6287d2..3dbf197 100644 --- a/src/cryptography/AES.h +++ b/src/cryptography/AES.h @@ -88,6 +88,7 @@ typedef struct AESStreamEncryptor { /// @param dec_class &br_aes_XXX_ctr_vtable void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class); +/// use this only at the beginning of the stream #define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_BLOCK_IV_SIZE) /// @brief If ptr->block_counter == 0, writes random IV to `dst`. After that writes encrypted data to dst. @@ -114,7 +115,7 @@ typedef struct AESStreamDecryptor { void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class); /// @brief Reads IV from `src`, then decrypts data and writes it to dst -/// @param src array of size at least AESStreamEncryptor_calcDstSize(0). +/// @param src array of any size /// @param dst array of size >= src.size /// @return size of decrypted data Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst); diff --git a/src/cryptography/RSA.c b/src/cryptography/RSA.c index 97c1c56..7f5beeb 100644 --- a/src/cryptography/RSA.c +++ b/src/cryptography/RSA.c @@ -180,14 +180,15 @@ void RSAEncryptor_construct(RSAEncryptor* ptr, const br_rsa_public_key* pk){ } Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst){ - const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(ptr->pk->nlen * 8, 256); + u32 key_size_bytes = ptr->pk->nlen; + const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bytes * 8, 256); if(src.size > max_src_size){ return RESULT_ERROR_FMT("src.size (%u) must be <= %u (use RSAEncryptor_calcMaxSrcSize)", src.size, max_src_size); } - if(dst.size < ptr->pk->nlen){ + if(dst.size < key_size_bytes){ return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (key length in bytes)", - dst.size, (u32)ptr->pk->nlen); + dst.size, key_size_bytes); } size_t sz = br_rsa_i31_oaep_encrypt( &ptr->rng.vtable, &br_sha256_vtable, @@ -207,18 +208,26 @@ void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk){ ptr->sk = sk; } -Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buf){ - if(buf.size != ptr->sk->n_bitlen/8){ - return RESULT_ERROR_FMT("buf.size (%u) must be == %u (key length in bytes)", - buf.size, ptr->sk->n_bitlen/8); +Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst){ + u32 key_size_bits = ptr->sk->n_bitlen; + if(src.size != key_size_bits/8){ + return RESULT_ERROR_FMT("src.size (%u) must be == %u (key length in bytes)", + src.size, key_size_bits/8); } - size_t sz = buf.size; + const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bits, 256); + if(dst.size < max_src_size){ + return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (use RSAEncryptor_calcMaxSrcSize)", + dst.size, max_src_size); + } + + memcpy(dst.data, src.data, src.size); + size_t sz = src.size; size_t r = br_rsa_i31_oaep_decrypt( &br_sha256_vtable, NULL, 0, ptr->sk, - buf.data, &sz); + dst.data, &sz); if(r == 0){ return RESULT_ERROR("RSA encryption failed", false); diff --git a/src/cryptography/RSA.h b/src/cryptography/RSA.h index 82d2879..7cdb659 100644 --- a/src/cryptography/RSA.h +++ b/src/cryptography/RSA.h @@ -101,6 +101,7 @@ typedef struct RSADecryptor { /// RSA OAEP encryption with SHA256 hashing algorithm void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk); -/// @param buf buffer with size == key size in bytes +/// @param src buffer with size == key size in bytes +/// @param dst buffer with size >= `RSAEncryptor_calcMaxSrcSize(key_size_bits, 256)` /// @return size of decrypted data -Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buf); +Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst); diff --git a/src/network/encrypted_sockets.c b/src/network/encrypted_sockets.c index 46e6830..6f2243e 100644 --- a/src/network/encrypted_sockets.c +++ b/src/network/encrypted_sockets.c @@ -24,6 +24,7 @@ Result(i32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, { Deferral(4); try(i32 received_size, i, socket_recv(ptr->sock, encrypted_buf, flags)); + //TODO: return error if WaitAll flag was set and socket closed before filling the buffer //TODO: return something when received_size == 0 (socket has been closed) encrypted_buf.size = received_size; try(i32 decrypted_size, u, AESStreamDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf)); diff --git a/src/network/tcp-chat-protocol/v1.c b/src/network/tcp-chat-protocol/v1.c index dc0de44..dfd7af9 100644 --- a/src/network/tcp-chat-protocol/v1.c +++ b/src/network/tcp-chat-protocol/v1.c @@ -9,6 +9,6 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, Array(u8) sessio } void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){ - PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ClientHandshake, sizeof(session_id)); + PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(session_id)); ptr->session_id = session_id; } diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 609a303..c713621 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -31,6 +31,8 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred Array(u8) dec_buf = Array_alloc_size(8*1024); Defer(free(dec_buf.data)); u32 enc_size = 0, dec_size = 0; + + // TODO: set socket timeout to 5 seconds // receive message encrypted by server public key try(enc_size, u, @@ -47,7 +49,8 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred try(dec_size, u, RSADecryptor_decrypt( &rsa_dec, - Array_sliceBefore(enc_buf, enc_size) + Array_sliceBefore(enc_buf, enc_size), + dec_buf ) );