implemented EncryptedSocketTCP_recvStruct and EncryptedSocketTCP_recvRSA

This commit is contained in:
Timerix 2025-11-08 18:21:47 +05:00
parent ee522ac401
commit 2db37bb902
12 changed files with 267 additions and 153 deletions

2
dependencies/tlibc vendored

@ -1 +1 @@
Subproject commit 00a1a29d342c6b3bcb8015a661364e10a3134fe4
Subproject commit 30c141f587ac8fa635812c6f4713dbc20b18d7c9

View File

@ -67,45 +67,25 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
try_void(socket_connect(_s, conn->server_end));
EncryptedSocketTCP_construct(&conn->sock, _s, NETWORK_BUFFER_SIZE, conn->session_key);
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
// fix for valgrind false detected errors about uninitialized memory
Array_memset(buffer, 0xCC);
Defer(free(buffer.data));
// construct PacketHeader and ClientHandshake in buffer
PacketHeader_construct(buffer.data, PROTOCOL_VERSION,
PacketType_ClientHandshake, sizeof(ClientHandshake));
ClientHandshake_construct(
Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
// send PacketHeader and ClientHandshake
// encryption by server public key
PacketHeader packet_header;
PacketHeader_construct(&packet_header,
PROTOCOL_VERSION, PacketType_ClientHandshake, sizeof(ClientHandshake));
ClientHandshake client_handshake;
ClientHandshake_construct(&client_handshake,
conn->session_key);
u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
// encrypt message by server public key
Array(u8) bufferPart_encryptedClientHandshake = Array_sliceAfter(buffer, header_and_message_size);
try(u32 rsa_enc_size, u,
RSAEncryptor_encrypt(
&conn->rsa_enc,
Array_sliceBefore(buffer, header_and_message_size),
bufferPart_encryptedClientHandshake
)
);
bufferPart_encryptedClientHandshake.size = rsa_enc_size;
// send encrypted message
try_void(socket_send(conn->sock.sock, bufferPart_encryptedClientHandshake));
try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &packet_header));
try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &client_handshake));
// receive server response
try_void(
EncryptedSocketTCP_recv(&conn->sock,
Array_sliceBefore(buffer, sizeof(PacketHeader)),
SocketRecvFlag_WaitAll
)
);
PacketHeader* packet_header = buffer.data;
try_void(PacketHeader_validateMagic(packet_header));
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &packet_header));
try_void(PacketHeader_validateMagic(&packet_header));
// handle server response
switch(packet_header->type){
switch(packet_header.type){
case PacketType_ErrorMessage: {
u32 err_msg_size = packet_header->content_size;
u32 err_msg_size = packet_header.content_size;
if(err_msg_size > conn->sock.recv_buf.size)
err_msg_size = conn->sock.recv_buf.size;
Array(u8) err_buf = Array_alloc_size(err_msg_size + 1);
@ -119,8 +99,8 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
try_void(
EncryptedSocketTCP_recv(
&conn->sock,
Array_sliceBefore(err_buf, err_msg_size),
SocketRecvFlag_WaitAll
Array_sliceTo(err_buf, err_msg_size),
SocketRecvFlag_WholeBuffer
)
);
@ -129,23 +109,13 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
Return RESULT_ERROR((char*)err_buf.data, true);
}
case PacketType_ServerHandshake: {
Array(u8) bufferPart_ServerHandshake = {
.data = (u8*)buffer.data + sizeof(PacketHeader),
.size = sizeof(ServerHandshake)
};
try_void(
EncryptedSocketTCP_recv(
&conn->sock,
bufferPart_ServerHandshake,
SocketRecvFlag_WaitAll
)
);
ServerHandshake* server_handshake = bufferPart_ServerHandshake.data;
conn->session_id = server_handshake->session_id;
ServerHandshake server_handshake;
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &server_handshake));
conn->session_id = server_handshake.session_id;
break;
}
default:
Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header->type);
Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header.type);
}
success = true;

View File

@ -4,18 +4,23 @@
// write data from src to array and increment array data pointer
static inline void __Array_writeNext(Array(u8)* dst, u8* src, size_t size){
memcpy(dst->data, src, size);
*dst = Array_sliceAfter(*dst, size);
*dst = Array_sliceFrom(*dst, size);
}
// read data from array to dst and increment array data pointer
static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){
memcpy(dst, src->data, size);
*src = Array_sliceAfter(*src, size);
*src = Array_sliceFrom(*src, size);
}
//////////////////////////////////////////////////////////////////////////////
// AESBlockEncryptor //
//////////////////////////////////////////////////////////////////////////////
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class){
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr,
Array(u8) key, const br_block_cbcenc_class* enc_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->enc_class = enc_class;
@ -25,7 +30,15 @@ void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br
rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable);
}
Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Array(u8) dst){
void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->enc_class->init((void*)ptr->enc_keys, key.data, key.size);
}
Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4);
u32 encrypted_size = AESBlockEncryptor_calcDstSize(src.size);
try_assert(dst.size >= encrypted_size);
@ -67,15 +80,28 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Arr
}
//////////////////////////////////////////////////////////////////////////////
// AESBlockDecryptor //
//////////////////////////////////////////////////////////////////////////////
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class){
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr,
Array(u8) key, const br_block_cbcdec_class* dec_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->dec_class = dec_class;
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
}
Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Array(u8) dst){
void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
}
Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4);
try_assert(src.size >= AESBlockEncryptor_calcDstSize(0));
try_assert(src.size % 16 == 0 && "src must be array of 16-byte blocks");
@ -116,8 +142,13 @@ Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Arr
}
//////////////////////////////////////////////////////////////////////////////
// AESStreamEncryptor //
//////////////////////////////////////////////////////////////////////////////
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr,
Array(u8) key, const br_block_ctr_class* ctr_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class = ctr_class;
@ -131,7 +162,15 @@ void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const
ptr->block_counter = 0;
}
Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, Array(u8) dst){
void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
}
Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4);
u32 encrypted_size = AESStreamEncryptor_calcDstSize(src.size);
try_assert(dst.size >= encrypted_size);
@ -164,8 +203,13 @@ Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, A
}
//////////////////////////////////////////////////////////////////////////////
// AESStreamDecryptor //
//////////////////////////////////////////////////////////////////////////////
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr,
Array(u8) key, const br_block_ctr_class* ctr_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class = ctr_class;
@ -174,7 +218,15 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const
ptr->block_counter = 0;
}
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){
void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
}
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4);
// if it is the beginning of the stream, read IV

View File

@ -40,6 +40,9 @@ typedef struct AESBlockEncryptor {
/// @param enc_class &br_aes_XXX_cbcenc_vtable
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class);
/// @param key supported sizes: 16, 24, 32
void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key);
/// @brief Encrypts a complete message. For part-by-part encryption use AESStreamEncryptor.
/// @param src array of any size
/// @param dst array of size >= AESBlockEncryptor_calcDstSize(src.size)
@ -63,6 +66,9 @@ typedef struct AESBlockDecryptor {
/// @param dec_class &br_aes_XXX_cbcdec_vtable
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class);
/// @param key supported sizes: 16, 24, 32
void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key);
/// @brief Decrypts a complete message. For part-by-part decryption use AESStreamEncryptor.
/// @param src array of size at least AESBlockEncryptor_calcDstSize(0). Size must be multiple of 16.
/// @param dst array of size >= src.size
@ -88,6 +94,9 @@ typedef struct AESStreamEncryptor {
/// @param dec_class &br_aes_XXX_ctr_vtable
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
/// @param key supported sizes: 16, 24, 32
void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key);
/// use this only at the beginning of the stream
#define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_STREAM_IV_SIZE)
@ -114,6 +123,9 @@ typedef struct AESStreamDecryptor {
/// @param dec_class &br_aes_XXX_ctr_vtable
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
/// @param key supported sizes: 16, 24, 32
void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key);
/// @brief Reads IV from `src`, then decrypts data and writes it to dst
/// @param src array of any size
/// @param dst array of size >= src.size

View File

@ -175,6 +175,11 @@ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){
Return RESULT_VOID;
}
//////////////////////////////////////////////////////////////////////////////
// RSAEncryptor //
//////////////////////////////////////////////////////////////////////////////
void RSAEncryptor_construct(RSAEncryptor* ptr, const br_rsa_public_key* pk){
ptr->pk = pk;
ptr->rng.vtable = &br_hmac_drbg_vtable;
@ -185,11 +190,11 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst
u32 key_size_bytes = ptr->pk->nlen;
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bytes * 8, 256);
if(src.size > max_src_size){
return RESULT_ERROR_FMT("src.size (%u) must be <= %u (use RSAEncryptor_calcMaxSrcSize)",
return RESULT_ERROR_FMT("src.size (%u) must be <= RSAEncryptor_calcMaxSrcSize() (%u)",
src.size, max_src_size);
}
if(dst.size < key_size_bytes){
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (key length in bytes)",
return RESULT_ERROR_FMT("dst.size (%u) must be >= key length in bytes (%u)",
dst.size, key_size_bytes);
}
size_t sz = br_rsa_i31_oaep_encrypt(
@ -206,30 +211,27 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst
}
//////////////////////////////////////////////////////////////////////////////
// RSADecryptor //
//////////////////////////////////////////////////////////////////////////////
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk){
ptr->sk = sk;
}
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst){
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer){
u32 key_size_bits = ptr->sk->n_bitlen;
if(src.size != key_size_bits/8){
return RESULT_ERROR_FMT("src.size (%u) must be == %u (key length in bytes)",
src.size, key_size_bits/8);
if(buffer.size != key_size_bits/8){
return RESULT_ERROR_FMT("buffer.size (%u) must be == key length in bytes (%u)",
buffer.size, key_size_bits/8);
}
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bits, 256);
if(dst.size < max_src_size){
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (use RSAEncryptor_calcMaxSrcSize)",
dst.size, max_src_size);
}
memcpy(dst.data, src.data, src.size);
size_t sz = src.size;
size_t sz = buffer.size;
size_t r = br_rsa_i31_oaep_decrypt(
&br_sha256_vtable,
NULL, 0,
ptr->sk,
dst.data, &sz);
buffer.data, &sz);
if(r == 0){
return RESULT_ERROR("RSA encryption failed", false);

View File

@ -58,6 +58,7 @@ str RSA_serializePublicKey_base64(const br_rsa_public_key* sk);
/// @param sk out public key. WARNING: .p is allocated on heap
Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* sk);
//////////////////////////////////////////////////////////////////////////////
// RSAEncryptor //
//////////////////////////////////////////////////////////////////////////////
@ -90,6 +91,7 @@ https://crypto.stackexchange.com/a/42100
/// @return size of encrypted data
Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst);
//////////////////////////////////////////////////////////////////////////////
// RSADecryptor //
//////////////////////////////////////////////////////////////////////////////
@ -102,6 +104,5 @@ typedef struct RSADecryptor {
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk);
/// @param src buffer with size == key size in bytes
/// @param dst buffer with size >= `RSAEncryptor_calcMaxSrcSize(key_size_bits, 256)`
/// @return size of decrypted data
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst);
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer);

View File

@ -1,5 +1,9 @@
#include "encrypted_sockets.h"
//////////////////////////////////////////////////////////////////////////////
// EncryptedSocketTCP //
//////////////////////////////////////////////////////////////////////////////
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
{
@ -16,6 +20,11 @@ void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){
free(ptr->send_buf.data);
}
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key){
AESStreamEncryptor_changeKey(&ptr->enc, aes_key);
AESStreamDecryptor_changeKey(&ptr->dec, aes_key);
}
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) buffer)
{
@ -31,7 +40,7 @@ Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
try_void(
socket_send(
ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size)
Array_sliceTo(ptr->send_buf, encrypted_size)
)
);
@ -51,14 +60,14 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
try(i32 received_size, i,
socket_recv(
ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive),
Array_sliceTo(ptr->recv_buf, size_to_receive),
flags
)
);
try(u32 decrypted_size, u,
AESStreamDecryptor_decrypt(
&ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size),
Array_sliceTo(ptr->recv_buf, received_size),
buffer
)
);
@ -66,7 +75,80 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Return RESULT_VALUE(u, decrypted_size);
}
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
RSAEncryptor* rsa_enc, Array(u8) buffer)
{
Deferral(1);
try(u32 encrypted_size, u,
RSAEncryptor_encrypt(
rsa_enc,
buffer,
ptr->send_buf
)
);
try_void(
socket_send(
ptr->sock,
Array_sliceTo(ptr->send_buf, encrypted_size)
)
);
Return RESULT_VOID;
}
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags)
{
Deferral(1);
// RSA encrypts message in block of size KEY_SIZE_BYTES.
// SocketRecvFlag_WholeBuffer should be always enabled to receive such blocks.
// If this flag is set in `flags` by caller, it means decrypted message size
// must be the same as buffer size.
bool fill_whole_buffer = (flags & SocketRecvFlag_WholeBuffer) != 0;
flags |= SocketRecvFlag_WholeBuffer;
u32 size_to_receive = rsa_dec->sk->n_bitlen / 8;
try(i32 received_size, i,
socket_recv(
ptr->sock,
Array_sliceTo(ptr->recv_buf, size_to_receive),
flags
)
);
try(u32 decrypted_size, u,
RSADecryptor_decrypt(
rsa_dec,
Array_sliceTo(ptr->recv_buf, received_size)
)
);
if(fill_whole_buffer){
if(decrypted_size != buffer.size){
Return RESULT_ERROR_FMT(
"SocketRecvFlag_WholeBuffer is set, "
"but decrypted_size (%u) != buffer.size (%u)",
decrypted_size, buffer.size
);
}
}
else if(decrypted_size > buffer.size){
Return RESULT_ERROR_FMT(
"decrypted_size (%u) > buffer.size (%u)",
decrypted_size, buffer.size
);
}
memcpy(buffer.data, ptr->recv_buf.data, decrypted_size);
Return RESULT_VALUE(u, decrypted_size);
}
//////////////////////////////////////////////////////////////////////////////
// EncryptedSocketUDP //
//////////////////////////////////////////////////////////////////////////////
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
@ -84,6 +166,11 @@ void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){
free(ptr->send_buf.data);
}
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key){
AESBlockEncryptor_changeKey(&ptr->enc, aes_key);
AESBlockDecryptor_changeKey(&ptr->dec, aes_key);
}
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) buffer, EndpointIPv4 remote_end)
{
@ -99,7 +186,7 @@ Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
try_void(
socket_sendto(
ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size),
Array_sliceTo(ptr->send_buf, encrypted_size),
remote_end
)
);
@ -117,7 +204,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
try(i32 received_size, i,
socket_recvfrom(
ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive),
Array_sliceTo(ptr->recv_buf, size_to_receive),
flags,
remote_end
)
@ -125,7 +212,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
try(u32 decrypted_size, u,
AESBlockDecryptor_decrypt(
&ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size),
Array_sliceTo(ptr->recv_buf, received_size),
buffer
)
);

View File

@ -1,6 +1,7 @@
#pragma once
#include "network/socket.h"
#include "cryptography/AES.h"
#include "cryptography/RSA.h"
//////////////////////////////////////////////////////////////////////////////
// EncryptedSocketTCP //
@ -20,12 +21,37 @@ void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
/// closes the socket
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr);
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key);
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) buffer);
#define EncryptedSocketTCP_sendStruct(socket, structPtr)\
EncryptedSocketTCP_send(socket,\
Array_construct_size(structPtr, sizeof(*structPtr)))
Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) buffer, SocketRecvFlag flags);
#define EncryptedSocketTCP_recvStruct(socket, structPtr)\
EncryptedSocketTCP_recv(socket,\
Array_construct_size(structPtr, sizeof(*structPtr)),\
SocketRecvFlag_WholeBuffer)
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
RSAEncryptor* rsa_enc, Array(u8) buffer);
#define EncryptedSocketTCP_sendStructRSA(socket, rsa_enc, structPtr)\
EncryptedSocketTCP_sendRSA(socket, rsa_enc,\
Array_construct_size(structPtr, sizeof(*structPtr)))
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags);
#define EncryptedSocketTCP_recvStructRSA(socket, rsa_dec, structPtr)\
EncryptedSocketTCP_recvRSA(socket, rsa_dec,\
Array_construct_size(structPtr, sizeof(*structPtr)),\
SocketRecvFlag_WholeBuffer)
//////////////////////////////////////////////////////////////////////////////
// EncryptedSocketUDP //
@ -45,6 +71,8 @@ void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
/// closes the socket
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr);
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key);
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) buffer, EndpointIPv4 remote_end);

View File

@ -86,7 +86,7 @@ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){
int f = 0;
if (flags & SocketRecvFlag_Peek)
f |= MSG_PEEK;
if (flags & SocketRecvFlag_WaitAll)
if (flags & SocketRecvFlag_WholeBuffer)
f |= MSG_WAITALL;
return f;
}
@ -96,7 +96,7 @@ Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
if(r < 0){
return RESULT_ERROR_SOCKET();
}
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
{
return RESULT_ERROR("Socket closed", false);
}
@ -111,7 +111,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
if(r < 0){
return RESULT_ERROR_SOCKET();
}
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
{
return RESULT_ERROR("Socket closed", false);
}

View File

@ -13,7 +13,7 @@ typedef enum SocketShutdownType {
typedef enum SocketRecvFlag {
SocketRecvFlag_None = 0,
SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */,
SocketRecvFlag_WaitAll = 0b10 /* waits until buffer is full */,
SocketRecvFlag_WholeBuffer = 0b10 /* waits until buffer is full */,
} SocketRecvFlag;
typedef i64 Socket;

View File

@ -25,74 +25,42 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred
conn->client_end = client_end;
conn->session_id = session_id;
conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE);
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
// fix for valgrind false detected errors about uninitialized memory
Array_memset(buffer, 0xCC);
Defer(free(buffer.data));
// correct session key will be received from client later
Array_memset(conn->session_key, 0);
EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
// TODO: set socket timeout to 5 seconds
// receive message encrypted by server public key
u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
Array(u8) bufferPart_encryptedClientHandshake = {
.data = (u8*)buffer.data + header_and_message_size,
.size = server_credentials->rsa_pk.nlen
};
try_void(
socket_recv(
sock_tcp,
bufferPart_encryptedClientHandshake,
SocketRecvFlag_WaitAll
)
);
// decrypt the message using server private key
// decrypt the rsa messages using server private key
RSADecryptor rsa_dec;
RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk);
try(u32 rsa_dec_size, u,
RSADecryptor_decrypt(
&rsa_dec,
bufferPart_encryptedClientHandshake,
buffer
)
);
// validate client handshake
if(rsa_dec_size != header_and_message_size){
Return RESULT_ERROR_FMT(
"decrypted message (size: %u) is not a ClientHandshake (size: %u)",
rsa_dec_size, header_and_message_size
);
}
PacketHeader* packet_header = buffer.data;
ClientHandshake* client_handshake = Array_sliceAfter(buffer, sizeof(PacketHeader)).data;
try_void(PacketHeader_validateMagic(packet_header));
if(packet_header->type != PacketType_ClientHandshake){
// receive PacketHeader
PacketHeader packet_header;
try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &packet_header));
try_void(PacketHeader_validateMagic(&packet_header));
if(packet_header.type != PacketType_ClientHandshake){
Return RESULT_ERROR_FMT(
"received message of unexpected type: %u",
packet_header->type
packet_header.type
);
}
// use received session key
memcpy(conn->session_key.data, client_handshake->session_key, conn->session_key.size);
EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
// receive ClientHandshake
ClientHandshake client_handshake;
try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &client_handshake));
// construct PacketHeader and ServerHandshake in buffer
PacketHeader_construct(buffer.data, PROTOCOL_VERSION,
PacketType_ServerHandshake, sizeof(ServerHandshake));
ServerHandshake_construct(
Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
// use received session key
memcpy(conn->session_key.data, client_handshake.session_key, conn->session_key.size);
EncryptedSocketTCP_changeKey(&conn->sock, conn->session_key);
// send PacketHeader and ServerHandshake over encrypted TCP socket
PacketHeader_construct(&packet_header,
PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(ServerHandshake));
ServerHandshake server_handshake;
ServerHandshake_construct(&server_handshake,
session_id);
// send ServerHandshake over encrypted TCP socket
header_and_message_size = sizeof(PacketHeader) + sizeof(ServerHandshake);
try_void(
EncryptedSocketTCP_send(
&conn->sock,
Array_sliceBefore(buffer, header_and_message_size)
)
);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &packet_header));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &server_handshake));
success = true;
Return RESULT_VALUE(p, conn);

View File

@ -72,7 +72,7 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){
//TODO: use async IO instead of threads to not waste system resources
// while waiting for incoming data in 100500 threads
try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args));
try_stderrcode(pthread_detach(&conn_thread));
try_stderrcode(pthread_detach(conn_thread));
}
Return RESULT_VOID;
@ -113,15 +113,9 @@ static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_
);
logInfo(log_ctx, "session accepted");
// handle requests
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
// fix for valgrind false detected errors about uninitialized memory
Array_memset(buffer, 0xCC);
Defer(free(buffer.data));
u32 dec_size = 0;
// handle unauthorized requests
while(true){
sleepMsec(10);
}