From 2db37bb902fec37be06f20b53d27daf1ee32ff6d Mon Sep 17 00:00:00 2001 From: Timerix Date: Sat, 8 Nov 2025 18:21:47 +0500 Subject: [PATCH] implemented EncryptedSocketTCP_recvStruct and EncryptedSocketTCP_recvRSA --- dependencies/tlibc | 2 +- src/client/ServerConnection.c | 68 +++++++--------------- src/cryptography/AES.c | 72 ++++++++++++++++++++---- src/cryptography/AES.h | 12 ++++ src/cryptography/RSA.c | 32 ++++++----- src/cryptography/RSA.h | 5 +- src/network/encrypted_sockets.c | 99 +++++++++++++++++++++++++++++++-- src/network/encrypted_sockets.h | 28 ++++++++++ src/network/socket.c | 6 +- src/network/socket.h | 2 +- src/server/ClientConnection.c | 82 +++++++++------------------ src/server/server.c | 12 +--- 12 files changed, 267 insertions(+), 153 deletions(-) diff --git a/dependencies/tlibc b/dependencies/tlibc index 00a1a29..30c141f 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit 00a1a29d342c6b3bcb8015a661364e10a3134fe4 +Subproject commit 30c141f587ac8fa635812c6f4713dbc20b18d7c9 diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index 68a16d2..42039c4 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -67,45 +67,25 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden try_void(socket_connect(_s, conn->server_end)); EncryptedSocketTCP_construct(&conn->sock, _s, NETWORK_BUFFER_SIZE, conn->session_key); - Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE); - // fix for valgrind false detected errors about uninitialized memory - Array_memset(buffer, 0xCC); - Defer(free(buffer.data)); - - // construct PacketHeader and ClientHandshake in buffer - PacketHeader_construct(buffer.data, PROTOCOL_VERSION, - PacketType_ClientHandshake, sizeof(ClientHandshake)); - ClientHandshake_construct( - Array_sliceAfter(buffer, sizeof(PacketHeader)).data, + // send PacketHeader and ClientHandshake + // encryption by server public key + PacketHeader packet_header; + PacketHeader_construct(&packet_header, + PROTOCOL_VERSION, PacketType_ClientHandshake, sizeof(ClientHandshake)); + ClientHandshake client_handshake; + ClientHandshake_construct(&client_handshake, conn->session_key); - u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake); - // encrypt message by server public key - Array(u8) bufferPart_encryptedClientHandshake = Array_sliceAfter(buffer, header_and_message_size); - try(u32 rsa_enc_size, u, - RSAEncryptor_encrypt( - &conn->rsa_enc, - Array_sliceBefore(buffer, header_and_message_size), - bufferPart_encryptedClientHandshake - ) - ); - bufferPart_encryptedClientHandshake.size = rsa_enc_size; - // send encrypted message - try_void(socket_send(conn->sock.sock, bufferPart_encryptedClientHandshake)); + try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &packet_header)); + try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &client_handshake)); // receive server response - try_void( - EncryptedSocketTCP_recv(&conn->sock, - Array_sliceBefore(buffer, sizeof(PacketHeader)), - SocketRecvFlag_WaitAll - ) - ); - PacketHeader* packet_header = buffer.data; - try_void(PacketHeader_validateMagic(packet_header)); + try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &packet_header)); + try_void(PacketHeader_validateMagic(&packet_header)); // handle server response - switch(packet_header->type){ + switch(packet_header.type){ case PacketType_ErrorMessage: { - u32 err_msg_size = packet_header->content_size; + u32 err_msg_size = packet_header.content_size; if(err_msg_size > conn->sock.recv_buf.size) err_msg_size = conn->sock.recv_buf.size; Array(u8) err_buf = Array_alloc_size(err_msg_size + 1); @@ -119,8 +99,8 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden try_void( EncryptedSocketTCP_recv( &conn->sock, - Array_sliceBefore(err_buf, err_msg_size), - SocketRecvFlag_WaitAll + Array_sliceTo(err_buf, err_msg_size), + SocketRecvFlag_WholeBuffer ) ); @@ -129,23 +109,13 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden Return RESULT_ERROR((char*)err_buf.data, true); } case PacketType_ServerHandshake: { - Array(u8) bufferPart_ServerHandshake = { - .data = (u8*)buffer.data + sizeof(PacketHeader), - .size = sizeof(ServerHandshake) - }; - try_void( - EncryptedSocketTCP_recv( - &conn->sock, - bufferPart_ServerHandshake, - SocketRecvFlag_WaitAll - ) - ); - ServerHandshake* server_handshake = bufferPart_ServerHandshake.data; - conn->session_id = server_handshake->session_id; + ServerHandshake server_handshake; + try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &server_handshake)); + conn->session_id = server_handshake.session_id; break; } default: - Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header->type); + Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header.type); } success = true; diff --git a/src/cryptography/AES.c b/src/cryptography/AES.c index 00bbed7..ba9be05 100755 --- a/src/cryptography/AES.c +++ b/src/cryptography/AES.c @@ -4,18 +4,23 @@ // write data from src to array and increment array data pointer static inline void __Array_writeNext(Array(u8)* dst, u8* src, size_t size){ memcpy(dst->data, src, size); - *dst = Array_sliceAfter(*dst, size); + *dst = Array_sliceFrom(*dst, size); } // read data from array to dst and increment array data pointer static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){ memcpy(dst, src->data, size); - *src = Array_sliceAfter(*src, size); + *src = Array_sliceFrom(*src, size); } +////////////////////////////////////////////////////////////////////////////// +// AESBlockEncryptor // +////////////////////////////////////////////////////////////////////////////// -void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class){ +void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, + Array(u8) key, const br_block_cbcenc_class* enc_class) +{ assert(key.size == 16 || key.size == 24 || key.size == 32); ptr->enc_class = enc_class; @@ -25,7 +30,15 @@ void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable); } -Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Array(u8) dst){ +void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key) +{ + assert(key.size == 16 || key.size == 24 || key.size == 32); + ptr->enc_class->init((void*)ptr->enc_keys, key.data, key.size); +} + +Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, + Array(u8) src, Array(u8) dst) +{ Deferral(4); u32 encrypted_size = AESBlockEncryptor_calcDstSize(src.size); try_assert(dst.size >= encrypted_size); @@ -67,15 +80,28 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Arr } +////////////////////////////////////////////////////////////////////////////// +// AESBlockDecryptor // +////////////////////////////////////////////////////////////////////////////// -void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class){ +void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, + Array(u8) key, const br_block_cbcdec_class* dec_class) +{ assert(key.size == 16 || key.size == 24 || key.size == 32); ptr->dec_class = dec_class; ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size); } -Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Array(u8) dst){ +void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key) +{ + assert(key.size == 16 || key.size == 24 || key.size == 32); + ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size); +} + +Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, + Array(u8) src, Array(u8) dst) +{ Deferral(4); try_assert(src.size >= AESBlockEncryptor_calcDstSize(0)); try_assert(src.size % 16 == 0 && "src must be array of 16-byte blocks"); @@ -116,8 +142,13 @@ Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Arr } +////////////////////////////////////////////////////////////////////////////// +// AESStreamEncryptor // +////////////////////////////////////////////////////////////////////////////// -void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){ +void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, + Array(u8) key, const br_block_ctr_class* ctr_class) +{ assert(key.size == 16 || key.size == 24 || key.size == 32); ptr->ctr_class = ctr_class; @@ -131,7 +162,15 @@ void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const ptr->block_counter = 0; } -Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, Array(u8) dst){ +void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key) +{ + assert(key.size == 16 || key.size == 24 || key.size == 32); + ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size); +} + +Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, + Array(u8) src, Array(u8) dst) +{ Deferral(4); u32 encrypted_size = AESStreamEncryptor_calcDstSize(src.size); try_assert(dst.size >= encrypted_size); @@ -164,8 +203,13 @@ Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, A } +////////////////////////////////////////////////////////////////////////////// +// AESStreamDecryptor // +////////////////////////////////////////////////////////////////////////////// -void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){ +void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, + Array(u8) key, const br_block_ctr_class* ctr_class) +{ assert(key.size == 16 || key.size == 24 || key.size == 32); ptr->ctr_class = ctr_class; @@ -174,7 +218,15 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const ptr->block_counter = 0; } -Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){ +void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key) +{ + assert(key.size == 16 || key.size == 24 || key.size == 32); + ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size); +} + +Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, + Array(u8) src, Array(u8) dst) +{ Deferral(4); // if it is the beginning of the stream, read IV diff --git a/src/cryptography/AES.h b/src/cryptography/AES.h index cad1ef4..09932bf 100644 --- a/src/cryptography/AES.h +++ b/src/cryptography/AES.h @@ -40,6 +40,9 @@ typedef struct AESBlockEncryptor { /// @param enc_class &br_aes_XXX_cbcenc_vtable void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class); +/// @param key supported sizes: 16, 24, 32 +void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key); + /// @brief Encrypts a complete message. For part-by-part encryption use AESStreamEncryptor. /// @param src array of any size /// @param dst array of size >= AESBlockEncryptor_calcDstSize(src.size) @@ -63,6 +66,9 @@ typedef struct AESBlockDecryptor { /// @param dec_class &br_aes_XXX_cbcdec_vtable void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class); +/// @param key supported sizes: 16, 24, 32 +void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key); + /// @brief Decrypts a complete message. For part-by-part decryption use AESStreamEncryptor. /// @param src array of size at least AESBlockEncryptor_calcDstSize(0). Size must be multiple of 16. /// @param dst array of size >= src.size @@ -88,6 +94,9 @@ typedef struct AESStreamEncryptor { /// @param dec_class &br_aes_XXX_ctr_vtable void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class); +/// @param key supported sizes: 16, 24, 32 +void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key); + /// use this only at the beginning of the stream #define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_STREAM_IV_SIZE) @@ -114,6 +123,9 @@ typedef struct AESStreamDecryptor { /// @param dec_class &br_aes_XXX_ctr_vtable void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class); +/// @param key supported sizes: 16, 24, 32 +void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key); + /// @brief Reads IV from `src`, then decrypts data and writes it to dst /// @param src array of any size /// @param dst array of size >= src.size diff --git a/src/cryptography/RSA.c b/src/cryptography/RSA.c index 382c332..0ca141d 100644 --- a/src/cryptography/RSA.c +++ b/src/cryptography/RSA.c @@ -175,6 +175,11 @@ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){ Return RESULT_VOID; } + +////////////////////////////////////////////////////////////////////////////// +// RSAEncryptor // +////////////////////////////////////////////////////////////////////////////// + void RSAEncryptor_construct(RSAEncryptor* ptr, const br_rsa_public_key* pk){ ptr->pk = pk; ptr->rng.vtable = &br_hmac_drbg_vtable; @@ -185,11 +190,11 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst u32 key_size_bytes = ptr->pk->nlen; const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bytes * 8, 256); if(src.size > max_src_size){ - return RESULT_ERROR_FMT("src.size (%u) must be <= %u (use RSAEncryptor_calcMaxSrcSize)", + return RESULT_ERROR_FMT("src.size (%u) must be <= RSAEncryptor_calcMaxSrcSize() (%u)", src.size, max_src_size); } if(dst.size < key_size_bytes){ - return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (key length in bytes)", + return RESULT_ERROR_FMT("dst.size (%u) must be >= key length in bytes (%u)", dst.size, key_size_bytes); } size_t sz = br_rsa_i31_oaep_encrypt( @@ -206,30 +211,27 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst } +////////////////////////////////////////////////////////////////////////////// +// RSADecryptor // +////////////////////////////////////////////////////////////////////////////// + void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk){ ptr->sk = sk; } -Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst){ +Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer){ u32 key_size_bits = ptr->sk->n_bitlen; - if(src.size != key_size_bits/8){ - return RESULT_ERROR_FMT("src.size (%u) must be == %u (key length in bytes)", - src.size, key_size_bits/8); + if(buffer.size != key_size_bits/8){ + return RESULT_ERROR_FMT("buffer.size (%u) must be == key length in bytes (%u)", + buffer.size, key_size_bits/8); } - const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bits, 256); - if(dst.size < max_src_size){ - return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (use RSAEncryptor_calcMaxSrcSize)", - dst.size, max_src_size); - } - - memcpy(dst.data, src.data, src.size); - size_t sz = src.size; + size_t sz = buffer.size; size_t r = br_rsa_i31_oaep_decrypt( &br_sha256_vtable, NULL, 0, ptr->sk, - dst.data, &sz); + buffer.data, &sz); if(r == 0){ return RESULT_ERROR("RSA encryption failed", false); diff --git a/src/cryptography/RSA.h b/src/cryptography/RSA.h index bb2cfff..2a50c49 100644 --- a/src/cryptography/RSA.h +++ b/src/cryptography/RSA.h @@ -58,6 +58,7 @@ str RSA_serializePublicKey_base64(const br_rsa_public_key* sk); /// @param sk out public key. WARNING: .p is allocated on heap Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* sk); + ////////////////////////////////////////////////////////////////////////////// // RSAEncryptor // ////////////////////////////////////////////////////////////////////////////// @@ -90,6 +91,7 @@ https://crypto.stackexchange.com/a/42100 /// @return size of encrypted data Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst); + ////////////////////////////////////////////////////////////////////////////// // RSADecryptor // ////////////////////////////////////////////////////////////////////////////// @@ -102,6 +104,5 @@ typedef struct RSADecryptor { void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk); /// @param src buffer with size == key size in bytes -/// @param dst buffer with size >= `RSAEncryptor_calcMaxSrcSize(key_size_bits, 256)` /// @return size of decrypted data -Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst); +Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer); diff --git a/src/network/encrypted_sockets.c b/src/network/encrypted_sockets.c index 9107707..4241bfc 100644 --- a/src/network/encrypted_sockets.c +++ b/src/network/encrypted_sockets.c @@ -1,5 +1,9 @@ #include "encrypted_sockets.h" +////////////////////////////////////////////////////////////////////////////// +// EncryptedSocketTCP // +////////////////////////////////////////////////////////////////////////////// + void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, Socket sock, u32 crypto_buffer_size, Array(u8) aes_key) { @@ -16,6 +20,11 @@ void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){ free(ptr->send_buf.data); } +void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key){ + AESStreamEncryptor_changeKey(&ptr->enc, aes_key); + AESStreamDecryptor_changeKey(&ptr->dec, aes_key); +} + Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr, Array(u8) buffer) { @@ -31,7 +40,7 @@ Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr, try_void( socket_send( ptr->sock, - Array_sliceBefore(ptr->send_buf, encrypted_size) + Array_sliceTo(ptr->send_buf, encrypted_size) ) ); @@ -51,14 +60,14 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, try(i32 received_size, i, socket_recv( ptr->sock, - Array_sliceBefore(ptr->recv_buf, size_to_receive), + Array_sliceTo(ptr->recv_buf, size_to_receive), flags ) ); try(u32 decrypted_size, u, AESStreamDecryptor_decrypt( &ptr->dec, - Array_sliceBefore(ptr->recv_buf, received_size), + Array_sliceTo(ptr->recv_buf, received_size), buffer ) ); @@ -66,7 +75,80 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, Return RESULT_VALUE(u, decrypted_size); } +Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr, + RSAEncryptor* rsa_enc, Array(u8) buffer) +{ + Deferral(1); + try(u32 encrypted_size, u, + RSAEncryptor_encrypt( + rsa_enc, + buffer, + ptr->send_buf + ) + ); + try_void( + socket_send( + ptr->sock, + Array_sliceTo(ptr->send_buf, encrypted_size) + ) + ); + + Return RESULT_VOID; +} + + +Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr, + RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags) +{ + Deferral(1); + + // RSA encrypts message in block of size KEY_SIZE_BYTES. + // SocketRecvFlag_WholeBuffer should be always enabled to receive such blocks. + // If this flag is set in `flags` by caller, it means decrypted message size + // must be the same as buffer size. + bool fill_whole_buffer = (flags & SocketRecvFlag_WholeBuffer) != 0; + flags |= SocketRecvFlag_WholeBuffer; + u32 size_to_receive = rsa_dec->sk->n_bitlen / 8; + + try(i32 received_size, i, + socket_recv( + ptr->sock, + Array_sliceTo(ptr->recv_buf, size_to_receive), + flags + ) + ); + try(u32 decrypted_size, u, + RSADecryptor_decrypt( + rsa_dec, + Array_sliceTo(ptr->recv_buf, received_size) + ) + ); + + if(fill_whole_buffer){ + if(decrypted_size != buffer.size){ + Return RESULT_ERROR_FMT( + "SocketRecvFlag_WholeBuffer is set, " + "but decrypted_size (%u) != buffer.size (%u)", + decrypted_size, buffer.size + ); + } + } + else if(decrypted_size > buffer.size){ + Return RESULT_ERROR_FMT( + "decrypted_size (%u) > buffer.size (%u)", + decrypted_size, buffer.size + ); + } + + memcpy(buffer.data, ptr->recv_buf.data, decrypted_size); + Return RESULT_VALUE(u, decrypted_size); +} + + +////////////////////////////////////////////////////////////////////////////// +// EncryptedSocketUDP // +////////////////////////////////////////////////////////////////////////////// void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, Socket sock, u32 crypto_buffer_size, Array(u8) aes_key) @@ -84,6 +166,11 @@ void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){ free(ptr->send_buf.data); } +void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key){ + AESBlockEncryptor_changeKey(&ptr->enc, aes_key); + AESBlockDecryptor_changeKey(&ptr->dec, aes_key); +} + Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr, Array(u8) buffer, EndpointIPv4 remote_end) { @@ -99,7 +186,7 @@ Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr, try_void( socket_sendto( ptr->sock, - Array_sliceBefore(ptr->send_buf, encrypted_size), + Array_sliceTo(ptr->send_buf, encrypted_size), remote_end ) ); @@ -117,7 +204,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr, try(i32 received_size, i, socket_recvfrom( ptr->sock, - Array_sliceBefore(ptr->recv_buf, size_to_receive), + Array_sliceTo(ptr->recv_buf, size_to_receive), flags, remote_end ) @@ -125,7 +212,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr, try(u32 decrypted_size, u, AESBlockDecryptor_decrypt( &ptr->dec, - Array_sliceBefore(ptr->recv_buf, received_size), + Array_sliceTo(ptr->recv_buf, received_size), buffer ) ); diff --git a/src/network/encrypted_sockets.h b/src/network/encrypted_sockets.h index a0c1cab..af50f5c 100644 --- a/src/network/encrypted_sockets.h +++ b/src/network/encrypted_sockets.h @@ -1,6 +1,7 @@ #pragma once #include "network/socket.h" #include "cryptography/AES.h" +#include "cryptography/RSA.h" ////////////////////////////////////////////////////////////////////////////// // EncryptedSocketTCP // @@ -20,12 +21,37 @@ void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, /// closes the socket void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr); +void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key); + Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr, Array(u8) buffer); +#define EncryptedSocketTCP_sendStruct(socket, structPtr)\ + EncryptedSocketTCP_send(socket,\ + Array_construct_size(structPtr, sizeof(*structPtr))) + Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, Array(u8) buffer, SocketRecvFlag flags); +#define EncryptedSocketTCP_recvStruct(socket, structPtr)\ + EncryptedSocketTCP_recv(socket,\ + Array_construct_size(structPtr, sizeof(*structPtr)),\ + SocketRecvFlag_WholeBuffer) + +Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr, + RSAEncryptor* rsa_enc, Array(u8) buffer); + +#define EncryptedSocketTCP_sendStructRSA(socket, rsa_enc, structPtr)\ + EncryptedSocketTCP_sendRSA(socket, rsa_enc,\ + Array_construct_size(structPtr, sizeof(*structPtr))) + +Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr, + RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags); + +#define EncryptedSocketTCP_recvStructRSA(socket, rsa_dec, structPtr)\ + EncryptedSocketTCP_recvRSA(socket, rsa_dec,\ + Array_construct_size(structPtr, sizeof(*structPtr)),\ + SocketRecvFlag_WholeBuffer) ////////////////////////////////////////////////////////////////////////////// // EncryptedSocketUDP // @@ -45,6 +71,8 @@ void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, /// closes the socket void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr); +void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key); + Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr, Array(u8) buffer, EndpointIPv4 remote_end); diff --git a/src/network/socket.c b/src/network/socket.c index e615876..9541361 100755 --- a/src/network/socket.c +++ b/src/network/socket.c @@ -86,7 +86,7 @@ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){ int f = 0; if (flags & SocketRecvFlag_Peek) f |= MSG_PEEK; - if (flags & SocketRecvFlag_WaitAll) + if (flags & SocketRecvFlag_WholeBuffer) f |= MSG_WAITALL; return f; } @@ -96,7 +96,7 @@ Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){ if(r < 0){ return RESULT_ERROR_SOCKET(); } - if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size)) + if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size)) { return RESULT_ERROR("Socket closed", false); } @@ -111,7 +111,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU if(r < 0){ return RESULT_ERROR_SOCKET(); } - if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size)) + if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size)) { return RESULT_ERROR("Socket closed", false); } diff --git a/src/network/socket.h b/src/network/socket.h index b552dc5..72ff845 100755 --- a/src/network/socket.h +++ b/src/network/socket.h @@ -13,7 +13,7 @@ typedef enum SocketShutdownType { typedef enum SocketRecvFlag { SocketRecvFlag_None = 0, SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */, - SocketRecvFlag_WaitAll = 0b10 /* waits until buffer is full */, + SocketRecvFlag_WholeBuffer = 0b10 /* waits until buffer is full */, } SocketRecvFlag; typedef i64 Socket; diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 4df959a..0e690d5 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -25,74 +25,42 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred conn->client_end = client_end; conn->session_id = session_id; conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); - - Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE); - // fix for valgrind false detected errors about uninitialized memory - Array_memset(buffer, 0xCC); - Defer(free(buffer.data)); - + // correct session key will be received from client later + Array_memset(conn->session_key, 0); + EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key); // TODO: set socket timeout to 5 seconds - - // receive message encrypted by server public key - u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake); - Array(u8) bufferPart_encryptedClientHandshake = { - .data = (u8*)buffer.data + header_and_message_size, - .size = server_credentials->rsa_pk.nlen - }; - try_void( - socket_recv( - sock_tcp, - bufferPart_encryptedClientHandshake, - SocketRecvFlag_WaitAll - ) - ); - - // decrypt the message using server private key + + // decrypt the rsa messages using server private key RSADecryptor rsa_dec; RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk); - try(u32 rsa_dec_size, u, - RSADecryptor_decrypt( - &rsa_dec, - bufferPart_encryptedClientHandshake, - buffer - ) - ); - - // validate client handshake - if(rsa_dec_size != header_and_message_size){ - Return RESULT_ERROR_FMT( - "decrypted message (size: %u) is not a ClientHandshake (size: %u)", - rsa_dec_size, header_and_message_size - ); - } - PacketHeader* packet_header = buffer.data; - ClientHandshake* client_handshake = Array_sliceAfter(buffer, sizeof(PacketHeader)).data; - try_void(PacketHeader_validateMagic(packet_header)); - if(packet_header->type != PacketType_ClientHandshake){ + + // receive PacketHeader + PacketHeader packet_header; + try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &packet_header)); + try_void(PacketHeader_validateMagic(&packet_header)); + if(packet_header.type != PacketType_ClientHandshake){ Return RESULT_ERROR_FMT( "received message of unexpected type: %u", - packet_header->type + packet_header.type ); } + // receive ClientHandshake + ClientHandshake client_handshake; + try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &client_handshake)); + // use received session key - memcpy(conn->session_key.data, client_handshake->session_key, conn->session_key.size); - EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key); + memcpy(conn->session_key.data, client_handshake.session_key, conn->session_key.size); + EncryptedSocketTCP_changeKey(&conn->sock, conn->session_key); - // construct PacketHeader and ServerHandshake in buffer - PacketHeader_construct(buffer.data, PROTOCOL_VERSION, - PacketType_ServerHandshake, sizeof(ServerHandshake)); - ServerHandshake_construct( - Array_sliceAfter(buffer, sizeof(PacketHeader)).data, + // send PacketHeader and ServerHandshake over encrypted TCP socket + PacketHeader_construct(&packet_header, + PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(ServerHandshake)); + ServerHandshake server_handshake; + ServerHandshake_construct(&server_handshake, session_id); - // send ServerHandshake over encrypted TCP socket - header_and_message_size = sizeof(PacketHeader) + sizeof(ServerHandshake); - try_void( - EncryptedSocketTCP_send( - &conn->sock, - Array_sliceBefore(buffer, header_and_message_size) - ) - ); + try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &packet_header)); + try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &server_handshake)); success = true; Return RESULT_VALUE(p, conn); diff --git a/src/server/server.c b/src/server/server.c index e209a86..dc73526 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -72,7 +72,7 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){ //TODO: use async IO instead of threads to not waste system resources // while waiting for incoming data in 100500 threads try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args)); - try_stderrcode(pthread_detach(&conn_thread)); + try_stderrcode(pthread_detach(conn_thread)); } Return RESULT_VOID; @@ -113,15 +113,9 @@ static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_ ); logInfo(log_ctx, "session accepted"); - // handle requests - - Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE); - // fix for valgrind false detected errors about uninitialized memory - Array_memset(buffer, 0xCC); - Defer(free(buffer.data)); - u32 dec_size = 0; - + // handle unauthorized requests while(true){ + sleepMsec(10); }