Compare commits

...

2 Commits

Author SHA1 Message Date
ee522ac401 fixed memory issues 2025-11-06 22:36:02 +05:00
d36fe9e5b3 added internal buffers to encrypted sockets 2025-11-06 22:27:41 +05:00
16 changed files with 286 additions and 150 deletions

2
dependencies/tlibc vendored

@ -1 +1 @@
Subproject commit 5fb2db2380b678381ef455a18c8210a6a3314e60 Subproject commit 00a1a29d342c6b3bcb8015a661364e10a3134fe4

View File

@ -5,7 +5,7 @@ void ServerConnection_close(ServerConnection* conn){
if(conn == NULL) if(conn == NULL)
return; return;
RSA_destroyPublicKey(&conn->server_pk); RSA_destroyPublicKey(&conn->server_pk);
socket_close(conn->sock.sock); EncryptedSocketTCP_destroy(&conn->sock);
free(conn->session_key.data); free(conn->session_key.data);
free(conn); free(conn);
} }
@ -34,7 +34,9 @@ Result(void) ServerLink_parse(cstr server_link_cstr, EndpointIPv4* server_end_ou
Return RESULT_ERROR_FMT("server link is invalid: %s", server_link_cstr); Return RESULT_ERROR_FMT("server link is invalid: %s", server_link_cstr);
} }
str server_key_str = str_sliceAfter(server_link_str, sep_pos + 1); str server_key_str = str_sliceAfter(server_link_str, sep_pos + 1);
try_void(RSA_parsePublicKey_base64(server_key_str, server_key_out)); char* server_key_cstr = str_copy(server_key_str).data;
Defer(free(server_key_cstr));
try_void(RSA_parsePublicKey_base64(server_key_cstr, server_key_out));
Return RESULT_VOID; Return RESULT_VOID;
} }
@ -63,78 +65,82 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
try(Socket _s, i, socket_open_TCP()); try(Socket _s, i, socket_open_TCP());
// TODO: set socket timeout to 5 seconds // TODO: set socket timeout to 5 seconds
try_void(socket_connect(_s, conn->server_end)); try_void(socket_connect(_s, conn->server_end));
EncryptedSocketTCP_construct(&conn->sock, _s, conn->session_key); EncryptedSocketTCP_construct(&conn->sock, _s, NETWORK_BUFFER_SIZE, conn->session_key);
Array(u8) enc_buf = Array_alloc_size(8*1024); Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
Defer(free(enc_buf.data)); // fix for valgrind false detected errors about uninitialized memory
Array(u8) dec_buf = Array_alloc_size(8*1024); Array_memset(buffer, 0xCC);
Defer(free(dec_buf.data)); Defer(free(buffer.data));
u32 enc_size = 0, dec_size = 0;
// construct ClientHandshake in dec_buf // construct PacketHeader and ClientHandshake in buffer
try_void(ClientHandshake_tryConstruct((ClientHandshake*)dec_buf.data, conn->session_key)); PacketHeader_construct(buffer.data, PROTOCOL_VERSION,
dec_size = sizeof(ClientHandshake); PacketType_ClientHandshake, sizeof(ClientHandshake));
// encrypt by server public key ClientHandshake_construct(
try(enc_size, u, Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
conn->session_key);
u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
// encrypt message by server public key
Array(u8) bufferPart_encryptedClientHandshake = Array_sliceAfter(buffer, header_and_message_size);
try(u32 rsa_enc_size, u,
RSAEncryptor_encrypt( RSAEncryptor_encrypt(
&conn->rsa_enc, &conn->rsa_enc,
Array_sliceBefore(dec_buf, dec_size), Array_sliceBefore(buffer, header_and_message_size),
enc_buf bufferPart_encryptedClientHandshake
) )
); );
try_void(socket_send(conn->sock.sock, Array_sliceBefore(enc_buf, enc_size))); bufferPart_encryptedClientHandshake.size = rsa_enc_size;
// send encrypted message
try_void(socket_send(conn->sock.sock, bufferPart_encryptedClientHandshake));
// receive server response // receive server response
enc_size = AESStreamEncryptor_calcDstSize(sizeof(PacketHeader)); try_void(
try(dec_size, u,
EncryptedSocketTCP_recv(&conn->sock, EncryptedSocketTCP_recv(&conn->sock,
Array_sliceBefore(enc_buf, enc_size), Array_sliceBefore(buffer, sizeof(PacketHeader)),
dec_buf,
SocketRecvFlag_WaitAll SocketRecvFlag_WaitAll
) )
); );
try_assert(dec_size == sizeof(PacketHeader)); PacketHeader* packet_header = buffer.data;
PacketHeader* packet_header = dec_buf.data;
try_void(PacketHeader_validateMagic(packet_header)); try_void(PacketHeader_validateMagic(packet_header));
// handle server response // handle server response
switch(packet_header->type){ switch(packet_header->type){
case PacketType_ErrorMessage: { case PacketType_ErrorMessage: {
Array(u8) err_buf = Array_alloc_size(packet_header->content_size + 1); u32 err_msg_size = packet_header->content_size;
if(err_msg_size > conn->sock.recv_buf.size)
err_msg_size = conn->sock.recv_buf.size;
Array(u8) err_buf = Array_alloc_size(err_msg_size + 1);
bool err_msg_completed = false; bool err_msg_completed = false;
Defer( Defer(
if(!err_msg_completed) if(!err_msg_completed)
free(err_buf.data); free(err_buf.data);
); );
// receive error message of length packet_header->content_size // receive error message
enc_size = packet_header->content_size;
if(enc_size > enc_buf.size)
enc_size = enc_buf.size;
try_void( try_void(
EncryptedSocketTCP_recv( EncryptedSocketTCP_recv(
&conn->sock, &conn->sock,
Array_sliceBefore(enc_buf, enc_size), Array_sliceBefore(err_buf, err_msg_size),
err_buf,
SocketRecvFlag_WaitAll SocketRecvFlag_WaitAll
) )
); );
((u8*)err_buf.data)[enc_size] = 0; ((u8*)err_buf.data)[err_msg_size] = 0;
err_msg_completed = true; err_msg_completed = true;
Return RESULT_ERROR((char*)err_buf.data, true); Return RESULT_ERROR((char*)err_buf.data, true);
} }
case PacketType_ServerHandshake: { case PacketType_ServerHandshake: {
enc_size = sizeof(ServerHandshake) - sizeof(PacketHeader); Array(u8) bufferPart_ServerHandshake = {
.data = (u8*)buffer.data + sizeof(PacketHeader),
.size = sizeof(ServerHandshake)
};
try_void( try_void(
EncryptedSocketTCP_recv( EncryptedSocketTCP_recv(
&conn->sock, &conn->sock,
Array_sliceBefore(enc_buf, enc_size), bufferPart_ServerHandshake,
dec_buf,
SocketRecvFlag_WaitAll SocketRecvFlag_WaitAll
) )
); );
ServerHandshake* server_handshake = dec_buf.data; ServerHandshake* server_handshake = bufferPart_ServerHandshake.data;
conn->session_id = server_handshake->session_id; conn->session_id = server_handshake->session_id;
break; break;
} }

View File

@ -27,7 +27,9 @@ static Result(void) askUserNameAndPassword(ClientCredentials** cred){
str usrername = str_null; str usrername = str_null;
while(true) { while(true) {
printf("username: "); printf("username: ");
fgets(username_buf, sizeof(username_buf), stdin); if(fgets(username_buf, sizeof(username_buf), stdin) == NULL){
Return RESULT_ERROR("STDIN is closed", false);
}
usrername = str_from_cstr(username_buf); usrername = str_from_cstr(username_buf);
if(usrername.size < 4){ if(usrername.size < 4){
printf("ERROR: username length must be at least 4\n"); printf("ERROR: username length must be at least 4\n");
@ -40,7 +42,9 @@ static Result(void) askUserNameAndPassword(ClientCredentials** cred){
while(true) { while(true) {
printf("password: "); printf("password: ");
// TODO: hide password // TODO: hide password
fgets(password_buf, sizeof(password_buf), stdin); if(fgets(password_buf, sizeof(password_buf), stdin) == NULL){
Return RESULT_ERROR("STDIN is closed", false);
}
password = str_from_cstr(password_buf); password = str_from_cstr(password_buf);
if(password.size < 8){ if(password.size < 8){
printf("ERROR: password length must be at least 8\n"); printf("ERROR: password length must be at least 8\n");
@ -59,15 +63,23 @@ Result(void) client_run() {
} }
fputs(greeting_art.data, stdout); fputs(greeting_art.data, stdout);
Defer(
ClientCredentials_free(_client_credentials);
ServerConnection_close(_server_connection);
);
try_void(askUserNameAndPassword(&_client_credentials)); try_void(askUserNameAndPassword(&_client_credentials));
Array(char) input_buf = Array_alloc(char, 10000); Array(char) input_buf = Array_alloc(char, 10000);
Defer(free(input_buf.data));
str command_input = str_null; str command_input = str_null;
bool stop = false; bool stop = false;
while(!stop){ while(!stop){
sleepMsec(50);
fputs("> ", stdout); fputs("> ", stdout);
if(fgets(input_buf.data, input_buf.size, stdin) == NULL) if(fgets(input_buf.data, input_buf.size, stdin) == NULL){
continue; Return RESULT_ERROR("STDIN is closed", false);
}
command_input = str_from_cstr(input_buf.data); command_input = str_from_cstr(input_buf.data);
str_trim(&command_input, true); str_trim(&command_input, true);
@ -77,15 +89,12 @@ Result(void) client_run() {
Result(void) com_result = commandExec(command_input, &stop); Result(void) com_result = commandExec(command_input, &stop);
if(com_result.error){ if(com_result.error){
str e_str = Error_toStr(com_result.error); str e_str = Error_toStr(com_result.error);
printfe("%s\n", e_str.data); printf("%s\n", e_str.data);
free(e_str.data); free(e_str.data);
Error_free(com_result.error); Error_free(com_result.error);
} }
} }
free(input_buf.data);
ClientCredentials_free(_client_credentials);
ServerConnection_close(_server_connection);
Return RESULT_VOID; Return RESULT_VOID;
} }
@ -116,7 +125,9 @@ static Result(void) commandExec(str command, bool* stop){
ServerConnection_close(_server_connection); ServerConnection_close(_server_connection);
puts("Enter server address (ip:port:public_key): "); puts("Enter server address (ip:port:public_key): ");
fgets(answer_buf, answer_buf_size, stdin); if(fgets(answer_buf, answer_buf_size, stdin) == NULL){
Return RESULT_ERROR("STDIN is closed", false);
}
str new_server_link = str_from_cstr(answer_buf); str new_server_link = str_from_cstr(answer_buf);
str_trim(&new_server_link, true); str_trim(&new_server_link, true);

View File

@ -176,13 +176,14 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){ Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){
Deferral(4); Deferral(4);
u32 decrypted_size = __AESStreamDecryptor_calcDstSize(src.size);
try_assert(dst.size >= decrypted_size);
// if it is the beginning of the stream, read IV // if it is the beginning of the stream, read IV
if(ptr->block_counter == 0){ if(ptr->block_counter == 0){
__Array_readNext(ptr->iv, &src, __AES_STREAM_IV_SIZE); __Array_readNext(ptr->iv, &src, __AES_STREAM_IV_SIZE);
} }
// size without IV
u32 decrypted_size = src.size;
try_assert(dst.size >= decrypted_size);
// decrypt full buffers // decrypt full buffers
while(src.size > __AES_BUFFER_SIZE){ while(src.size > __AES_BUFFER_SIZE){

View File

@ -89,7 +89,7 @@ typedef struct AESStreamEncryptor {
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class); void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
/// use this only at the beginning of the stream /// use this only at the beginning of the stream
#define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_BLOCK_IV_SIZE) #define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_STREAM_IV_SIZE)
/// @brief If ptr->block_counter == 0, writes random IV to `dst`. After that writes encrypted data to dst. /// @brief If ptr->block_counter == 0, writes random IV to `dst`. After that writes encrypted data to dst.
/// @param src array of any size /// @param src array of any size
@ -119,5 +119,3 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const
/// @param dst array of size >= src.size /// @param dst array of size >= src.size
/// @return size of decrypted data /// @return size of decrypted data
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst); Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst);
#define __AESStreamDecryptor_calcDstSize(src_size) (src_size - __AES_BLOCK_IV_SIZE)

View File

@ -111,10 +111,10 @@ str RSA_serializePublicKey_base64(const br_rsa_public_key* pk){
return str_construct(serialized_buf, offset, true); return str_construct(serialized_buf, offset, true);
} }
Result(void) RSA_parsePublicKey_base64(const str src, br_rsa_public_key* pk){ Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* pk){
Deferral(8); Deferral(8);
u32 n_bitlen = 0; u32 n_bitlen = 0;
if(sscanf(src.data, "RSA-Public-%u:", &n_bitlen) != 1){ if(sscanf(src, "RSA-Public-%u:", &n_bitlen) != 1){
Return RESULT_ERROR("can't parse key size", false); Return RESULT_ERROR("can't parse key size", false);
} }
u32 key_buffer_size = BR_RSA_KBUF_PUB_SIZE(n_bitlen); u32 key_buffer_size = BR_RSA_KBUF_PUB_SIZE(n_bitlen);
@ -122,11 +122,12 @@ Result(void) RSA_parsePublicKey_base64(const str src, br_rsa_public_key* pk){
pk->elen = 4; pk->elen = 4;
pk->nlen = key_buffer_size - 4; pk->nlen = key_buffer_size - 4;
pk->e = pk->n + pk->nlen; pk->e = pk->n + pk->nlen;
u32 offset = str_seekChar(src, ':', 10) + 1; str src_str = str_from_cstr(src);
u32 offset = str_seekChar(src_str, ':', 10) + 1;
if(offset == 0){ if(offset == 0){
Return RESULT_ERROR("missing ':' before key data", false); Return RESULT_ERROR("missing ':' before key data", false);
} }
str key_base64_str = src; str key_base64_str = src_str;
key_base64_str.data += offset; key_base64_str.data += offset;
key_base64_str.size -= offset; key_base64_str.size -= offset;
u32 decoded_size = base64_decodedSize(key_base64_str.data, key_base64_str.size); u32 decoded_size = base64_decodedSize(key_base64_str.data, key_base64_str.size);
@ -140,10 +141,10 @@ Result(void) RSA_parsePublicKey_base64(const str src, br_rsa_public_key* pk){
Return RESULT_VOID; Return RESULT_VOID;
} }
Result(void) RSA_parsePrivateKey_base64(const str src, br_rsa_private_key* sk){ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){
Deferral(8); Deferral(8);
u32 n_bitlen = 0; u32 n_bitlen = 0;
if(sscanf(src.data, "RSA-Private-%u:", &n_bitlen) != 1){ if(sscanf(src, "RSA-Private-%u:", &n_bitlen) != 1){
Return RESULT_ERROR("can't parse key size", false); Return RESULT_ERROR("can't parse key size", false);
} }
sk->n_bitlen = n_bitlen; sk->n_bitlen = n_bitlen;
@ -155,11 +156,12 @@ Result(void) RSA_parsePrivateKey_base64(const str src, br_rsa_private_key* sk){
sk->dp = sk->q + field_len; sk->dp = sk->q + field_len;
sk->dq = sk->dp + field_len; sk->dq = sk->dp + field_len;
sk->iq = sk->dq + field_len; sk->iq = sk->dq + field_len;
u32 offset = str_seekChar(src, ':', 10) + 1; str src_str = str_from_cstr(src);
u32 offset = str_seekChar(src_str, ':', 10) + 1;
if(offset == 0){ if(offset == 0){
Return RESULT_ERROR("missing ':' before key data", false); Return RESULT_ERROR("missing ':' before key data", false);
} }
str key_base64_str = src; str key_base64_str = src_str;
key_base64_str.data += offset; key_base64_str.data += offset;
key_base64_str.size -= offset; key_base64_str.size -= offset;
u32 decoded_size = base64_decodedSize(key_base64_str.data, key_base64_str.size); u32 decoded_size = base64_decodedSize(key_base64_str.data, key_base64_str.size);

View File

@ -47,7 +47,7 @@ str RSA_serializePrivateKey_base64(const br_rsa_private_key* sk);
/// @param src serialized private key format "RSA-Private-%SIZE%:%DATA_BASE64%" /// @param src serialized private key format "RSA-Private-%SIZE%:%DATA_BASE64%"
/// @param sk out private key. WARNING: .p is allocated on heap /// @param sk out private key. WARNING: .p is allocated on heap
Result(void) RSA_parsePrivateKey_base64(const str src, br_rsa_private_key* sk); Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk);
/// @brief Encode key data in human-readable format /// @brief Encode key data in human-readable format
/// @param src some data /// @param src some data
@ -56,7 +56,7 @@ str RSA_serializePublicKey_base64(const br_rsa_public_key* sk);
/// @param src serialized public key format "RSA-Public-%SIZE%:%DATA_BASE64%" /// @param src serialized public key format "RSA-Public-%SIZE%:%DATA_BASE64%"
/// @param sk out public key. WARNING: .p is allocated on heap /// @param sk out public key. WARNING: .p is allocated on heap
Result(void) RSA_parsePublicKey_base64(const str src, br_rsa_public_key* sk); Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* sk);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// RSAEncryptor // // RSAEncryptor //

View File

@ -1,63 +1,134 @@
#include "encrypted_sockets.h" #include "encrypted_sockets.h"
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, Socket sock, Array(u8) aes_key) void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
{ {
ptr->sock = sock; ptr->sock = sock;
AESStreamEncryptor_construct(&ptr->enc, aes_key, AESStream_DEFAULT_CLASS); AESStreamEncryptor_construct(&ptr->enc, aes_key, AESStream_DEFAULT_CLASS);
AESStreamDecryptor_construct(&ptr->dec, aes_key, AESStream_DEFAULT_CLASS); AESStreamDecryptor_construct(&ptr->dec, aes_key, AESStream_DEFAULT_CLASS);
ptr->recv_buf = Array_alloc_size(crypto_buffer_size);
ptr->send_buf = Array_alloc_size(crypto_buffer_size);
} }
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){
socket_close(ptr->sock);
free(ptr->recv_buf.data);
free(ptr->send_buf.data);
}
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr, Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf) Array(u8) buffer)
{ {
Deferral(4); Deferral(1);
try(u32 encrypted_size, u, AESStreamEncryptor_encrypt(&ptr->enc, decrypted_buf, encrypted_buf));
encrypted_buf.size = encrypted_size; try(u32 encrypted_size, u,
try_void(socket_send(ptr->sock, encrypted_buf)); AESStreamEncryptor_encrypt(
&ptr->enc,
buffer,
ptr->send_buf
)
);
try_void(
socket_send(
ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size)
)
);
Return RESULT_VOID; Return RESULT_VOID;
} }
Result(i32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf, Array(u8) buffer, SocketRecvFlag flags)
SocketRecvFlag flags)
{ {
Deferral(4); Deferral(1);
try(i32 received_size, i, socket_recv(ptr->sock, encrypted_buf, flags));
//TODO: return error if WaitAll flag was set and socket closed before filling the buffer u32 size_to_receive = buffer.size;
//TODO: return something when received_size == 0 (socket has been closed) if(ptr->dec.block_counter == 0){
encrypted_buf.size = received_size; // There is some metadata at the beginning of AES stream
try(i32 decrypted_size, u, AESStreamDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf)); size_to_receive = AESStreamEncryptor_calcDstSize(size_to_receive);
Return RESULT_VALUE(i, decrypted_size); }
try(i32 received_size, i,
socket_recv(
ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive),
flags
)
);
try(u32 decrypted_size, u,
AESStreamDecryptor_decrypt(
&ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size),
buffer
)
);
Return RESULT_VALUE(u, decrypted_size);
} }
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, Socket sock, Array(u8) aes_key) void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
{ {
ptr->sock = sock; ptr->sock = sock;
AESBlockEncryptor_construct(&ptr->enc, aes_key, AESBlockEncryptor_DEFAULT_CLASS); AESBlockEncryptor_construct(&ptr->enc, aes_key, AESBlockEncryptor_DEFAULT_CLASS);
AESBlockDecryptor_construct(&ptr->dec, aes_key, AESBlockDecryptor_DEFAULT_CLASS); AESBlockDecryptor_construct(&ptr->dec, aes_key, AESBlockDecryptor_DEFAULT_CLASS);
ptr->recv_buf = Array_alloc_size(crypto_buffer_size);
ptr->send_buf = Array_alloc_size(crypto_buffer_size);
}
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){
socket_close(ptr->sock);
free(ptr->recv_buf.data);
free(ptr->send_buf.data);
} }
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr, Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf, Array(u8) buffer, EndpointIPv4 remote_end)
EndpointIPv4 remote_end)
{ {
Deferral(4); Deferral(1);
try(u32 encrypted_size, u, AESBlockEncryptor_encrypt(&ptr->enc, decrypted_buf, encrypted_buf));
encrypted_buf.size = encrypted_size; try(u32 encrypted_size, u,
try_void(socket_sendto(ptr->sock, encrypted_buf, remote_end)); AESBlockEncryptor_encrypt(
&ptr->enc,
buffer,
ptr->send_buf
)
);
try_void(
socket_sendto(
ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size),
remote_end
)
);
Return RESULT_VOID; Return RESULT_VOID;
} }
Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr, Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end)
SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end)
{ {
Deferral(4); Deferral(1);
try(i32 received_size, i, socket_recvfrom(ptr->sock, encrypted_buf, flags, remote_end));
encrypted_buf.size = received_size; // There is some metadata at the start of each AES block
try(i32 decrypted_size, u, AESBlockDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf)); u32 size_to_receive = AESBlockEncryptor_calcDstSize(buffer.size);
Return RESULT_VALUE(i, decrypted_size); try(i32 received_size, i,
socket_recvfrom(
ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive),
flags,
remote_end
)
);
try(u32 decrypted_size, u,
AESBlockDecryptor_decrypt(
&ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size),
buffer
)
);
Return RESULT_VALUE(u, decrypted_size);
} }

View File

@ -10,16 +10,21 @@ typedef struct EncryptedSocketTCP {
Socket sock; Socket sock;
AESStreamEncryptor enc; AESStreamEncryptor enc;
AESStreamDecryptor dec; AESStreamDecryptor dec;
Array(u8) send_buf;
Array(u8) recv_buf;
} EncryptedSocketTCP; } EncryptedSocketTCP;
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, Socket sock, Array(u8) aes_key); void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key);
/// closes the socket
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr);
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr, Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf); Array(u8) buffer);
Result(void) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf, Array(u8) buffer, SocketRecvFlag flags);
SocketRecvFlag flags);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -30,15 +35,18 @@ typedef struct EncryptedSocketUDP {
Socket sock; Socket sock;
AESBlockEncryptor enc; AESBlockEncryptor enc;
AESBlockDecryptor dec; AESBlockDecryptor dec;
Array(u8) send_buf;
Array(u8) recv_buf;
} EncryptedSocketUDP; } EncryptedSocketUDP;
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, Socket sock, Array(u8) aes_key); void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key);
/// closes the socket
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr);
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr, Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf, Array(u8) buffer, EndpointIPv4 remote_end);
EndpointIPv4 remote_end);
Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr, Result(u32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end);
SocketRecvFlag flags,
NULLABLE(EndpointIPv4*) remote_end);

View File

@ -61,16 +61,24 @@ Result(void) socket_connect(Socket s, EndpointIPv4 remote_end){
Result(void) socket_send(Socket s, Array(u8) buffer){ Result(void) socket_send(Socket s, Array(u8) buffer){
i32 r = send(s, buffer.data, buffer.size, 0); i32 r = send(s, buffer.data, buffer.size, 0);
if(r < 0) if(r < 0){
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
}
if((u32)r != buffer.size){
return RESULT_ERROR_FMT("Socket was unable to send data");
}
return RESULT_VOID; return RESULT_VOID;
} }
Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst){ Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst){
struct sockaddr_in sockaddr = EndpointIPv4_toSockaddr(dst); struct sockaddr_in sockaddr = EndpointIPv4_toSockaddr(dst);
i32 r = sendto(s, buffer.data, buffer.size, 0, (void*)&sockaddr, sizeof(sockaddr)); i32 r = sendto(s, buffer.data, buffer.size, 0, (void*)&sockaddr, sizeof(sockaddr));
if(r < 0) if(r < 0){
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
}
if((u32)r != buffer.size){
return RESULT_ERROR_FMT("Socket was unable to send data");
}
return RESULT_VOID; return RESULT_VOID;
} }
@ -85,8 +93,13 @@ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){
Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){ Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
i32 r = recv(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags)); i32 r = recv(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags));
if(r < 0) if(r < 0){
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
}
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
{
return RESULT_ERROR("Socket closed", false);
}
return RESULT_VALUE(i, r); return RESULT_VALUE(i, r);
} }
@ -95,8 +108,13 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
i32 sockaddr_size = sizeof(remote_addr); i32 sockaddr_size = sizeof(remote_addr);
i32 r = recvfrom(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags), i32 r = recvfrom(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags),
(struct sockaddr*)&remote_addr, (void*)&sockaddr_size); (struct sockaddr*)&remote_addr, (void*)&sockaddr_size);
if(r < 0) if(r < 0){
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
}
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
{
return RESULT_ERROR("Socket closed", false);
}
//TODO: add IPV6 support (struct sockaddr_in6) //TODO: add IPV6 support (struct sockaddr_in6)
assert(sockaddr_size == sizeof(remote_addr)); assert(sockaddr_size == sizeof(remote_addr));

View File

@ -1,14 +1,9 @@
#include "v1.h" #include "v1.h"
Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, Array(u8) session_key){ void ClientHandshake_construct(ClientHandshake* ptr, Array(u8) session_key){
Deferral(1); memcpy(ptr->session_key, session_key.data, sizeof(ptr->session_key));
try_assert(session_key.size == AES_SESSION_KEY_SIZE);
PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ClientHandshake, session_key.size);
memcpy(ptr->session_key, session_key.data, session_key.size);
Return RESULT_VOID;
} }
void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){ void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){
PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(session_id));
ptr->session_id = session_id; ptr->session_id = session_id;
} }

View File

@ -3,7 +3,7 @@
#include "network/tcp-chat-protocol/constant.h" #include "network/tcp-chat-protocol/constant.h"
#define PROTOCOL_VERSION 1 /* 1.0.0 */ #define PROTOCOL_VERSION 1 /* 1.0.0 */
#define NETWORK_BUFFER_SIZE 65536
typedef enum PacketType { typedef enum PacketType {
PacketType_Invalid, PacketType_Invalid,
@ -14,21 +14,18 @@ typedef enum PacketType {
typedef struct ErrorMessage { typedef struct ErrorMessage {
PacketHeader header;
/* content stream of size `header.content_size` */ /* content stream of size `header.content_size` */
} ErrorMessage; } ErrorMessage;
typedef struct ClientHandshake { typedef struct ClientHandshake {
PacketHeader header;
u8 session_key[AES_SESSION_KEY_SIZE]; u8 session_key[AES_SESSION_KEY_SIZE];
} ClientHandshake; } ClientHandshake;
Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, Array(u8) session_key); void ClientHandshake_construct(ClientHandshake* ptr, Array(u8) session_key);
typedef struct ServerHandshake { typedef struct ServerHandshake {
PacketHeader header;
u64 session_id; u64 session_id;
} ServerHandshake; } ServerHandshake;

View File

@ -4,7 +4,7 @@
void ClientConnection_close(ClientConnection* conn){ void ClientConnection_close(ClientConnection* conn){
if(conn == NULL) if(conn == NULL)
return; return;
socket_close(conn->sock.sock); EncryptedSocketTCP_destroy(&conn->sock);
free(conn->session_key.data); free(conn->session_key.data);
free(conn); free(conn);
} }
@ -26,19 +26,23 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred
conn->session_id = session_id; conn->session_id = session_id;
conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE);
Array(u8) enc_buf = Array_alloc_size(8*1024); Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
Defer(free(enc_buf.data)); // fix for valgrind false detected errors about uninitialized memory
Array(u8) dec_buf = Array_alloc_size(8*1024); Array_memset(buffer, 0xCC);
Defer(free(dec_buf.data)); Defer(free(buffer.data));
u32 enc_size = 0, dec_size = 0;
// TODO: set socket timeout to 5 seconds // TODO: set socket timeout to 5 seconds
// receive message encrypted by server public key // receive message encrypted by server public key
try(enc_size, u, u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
Array(u8) bufferPart_encryptedClientHandshake = {
.data = (u8*)buffer.data + header_and_message_size,
.size = server_credentials->rsa_pk.nlen
};
try_void(
socket_recv( socket_recv(
sock_tcp, sock_tcp,
Array_sliceBefore(enc_buf, server_credentials->rsa_pk.nlen), bufferPart_encryptedClientHandshake,
SocketRecvFlag_WaitAll SocketRecvFlag_WaitAll
) )
); );
@ -46,42 +50,47 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred
// decrypt the message using server private key // decrypt the message using server private key
RSADecryptor rsa_dec; RSADecryptor rsa_dec;
RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk); RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk);
try(dec_size, u, try(u32 rsa_dec_size, u,
RSADecryptor_decrypt( RSADecryptor_decrypt(
&rsa_dec, &rsa_dec,
Array_sliceBefore(enc_buf, enc_size), bufferPart_encryptedClientHandshake,
dec_buf buffer
) )
); );
// validate client handshake // validate client handshake
if(dec_size != sizeof(ClientHandshake)){ if(rsa_dec_size != header_and_message_size){
Return RESULT_ERROR_FMT( Return RESULT_ERROR_FMT(
"decrypted message (size: %u) is not a ClientHandshake (size: %u)", "decrypted message (size: %u) is not a ClientHandshake (size: %u)",
dec_size, (u32)sizeof(ClientHandshake) rsa_dec_size, header_and_message_size
); );
} }
ClientHandshake* client_handshake = dec_buf.data; PacketHeader* packet_header = buffer.data;
try_void(PacketHeader_validateMagic(&client_handshake->header)); ClientHandshake* client_handshake = Array_sliceAfter(buffer, sizeof(PacketHeader)).data;
if(client_handshake->header.type != PacketType_ClientHandshake){ try_void(PacketHeader_validateMagic(packet_header));
if(packet_header->type != PacketType_ClientHandshake){
Return RESULT_ERROR_FMT( Return RESULT_ERROR_FMT(
"received message of unexpected type: %u", "received message of unexpected type: %u",
client_handshake->header.type packet_header->type
); );
} }
// use received session key // use received session key
memcpy(conn->session_key.data, client_handshake->session_key, conn->session_key.size); memcpy(conn->session_key.data, client_handshake->session_key, conn->session_key.size);
EncryptedSocketTCP_construct(&conn->sock, sock_tcp, conn->session_key); EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
// construct ServerHandshake in dec_buf // construct PacketHeader and ServerHandshake in buffer
ServerHandshake_construct((ServerHandshake*)dec_buf.data, session_id); PacketHeader_construct(buffer.data, PROTOCOL_VERSION,
PacketType_ServerHandshake, sizeof(ServerHandshake));
ServerHandshake_construct(
Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
session_id);
// send ServerHandshake over encrypted TCP socket // send ServerHandshake over encrypted TCP socket
header_and_message_size = sizeof(PacketHeader) + sizeof(ServerHandshake);
try_void( try_void(
EncryptedSocketTCP_send( EncryptedSocketTCP_send(
&conn->sock, &conn->sock,
Array_sliceBefore(dec_buf, sizeof(ServerHandshake)), Array_sliceBefore(buffer, header_and_message_size)
enc_buf
) )
); );

View File

@ -1,7 +1,7 @@
#include "server.h" #include "server.h"
Result(ServerCredentials*) ServerCredentials_create(const str rsa_sk_base64, const str rsa_pk_base64){ Result(ServerCredentials*) ServerCredentials_create(cstr rsa_sk_base64, cstr rsa_pk_base64){
Deferral(4); Deferral(4);
ServerCredentials* cred = (ServerCredentials*)malloc(sizeof(ServerCredentials)); ServerCredentials* cred = (ServerCredentials*)malloc(sizeof(ServerCredentials));

View File

@ -4,6 +4,7 @@
#include "server.h" #include "server.h"
#include "config.h" #include "config.h"
#include "log.h" #include "log.h"
#include "network/tcp-chat-protocol/v1.h"
typedef struct ConnectionHandlerArgs { typedef struct ConnectionHandlerArgs {
Socket accepted_socket; Socket accepted_socket;
@ -23,6 +24,7 @@ static Result(void) parseConfig(cstr config_path){
// open file // open file
try(FILE* config_file, p, file_open(config_path, FO_ReadExisting)); try(FILE* config_file, p, file_open(config_path, FO_ReadExisting));
Defer(file_close(config_file));
// read whole file into Array(char) // read whole file into Array(char)
try(i64 config_file_size, i, file_getSize(config_file)); try(i64 config_file_size, i, file_getSize(config_file));
Array(char) config_buf = Array_alloc(char, config_file_size); Array(char) config_buf = Array_alloc(char, config_file_size);
@ -34,7 +36,13 @@ static Result(void) parseConfig(cstr config_path){
str pk_base64; str pk_base64;
try_void(config_findValue(config_str, STR("rsa_private_key"), &sk_base64, true)); try_void(config_findValue(config_str, STR("rsa_private_key"), &sk_base64, true));
try_void(config_findValue(config_str, STR("rsa_public_key"), &pk_base64, true)); try_void(config_findValue(config_str, STR("rsa_public_key"), &pk_base64, true));
try(_server_credentials, p, ServerCredentials_create(sk_base64, pk_base64)); char* sk_base64_cstr = str_copy(sk_base64).data;
char* pk_base64_cstr = str_copy(pk_base64).data;
Defer(
free(sk_base64_cstr);
free(pk_base64_cstr);
);
try(_server_credentials, p, ServerCredentials_create(sk_base64_cstr, pk_base64_cstr));
Return RESULT_VOID; Return RESULT_VOID;
} }
@ -45,6 +53,7 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){
logInfo(log_ctx, "starting server"); logInfo(log_ctx, "starting server");
logDebug(log_ctx, "parsing config"); logDebug(log_ctx, "parsing config");
try_void(parseConfig(config_path)); try_void(parseConfig(config_path));
Defer(ServerCredentials_free(_server_credentials));
logDebug(log_ctx, "initializing main socket"); logDebug(log_ctx, "initializing main socket");
EndpointIPv4 server_end; EndpointIPv4 server_end;
@ -63,6 +72,7 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){
//TODO: use async IO instead of threads to not waste system resources //TODO: use async IO instead of threads to not waste system resources
// while waiting for incoming data in 100500 threads // while waiting for incoming data in 100500 threads
try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args)); try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args));
try_stderrcode(pthread_detach(&conn_thread));
} }
Return RESULT_VOID; Return RESULT_VOID;
@ -88,7 +98,10 @@ static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_
Defer(free(args)); Defer(free(args));
ClientConnection* conn = NULL; ClientConnection* conn = NULL;
Defer(ClientConnection_close(conn)); Defer(
ClientConnection_close(conn);
logInfo(log_ctx, "session closed");
);
// establish encrypted connection // establish encrypted connection
try(conn, p, try(conn, p,
ClientConnection_accept( ClientConnection_accept(
@ -98,9 +111,16 @@ static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_
args->session_id args->session_id
) )
); );
logDebug(log_ctx, "session accepted"); logInfo(log_ctx, "session accepted");
// handle requests // handle requests
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
// fix for valgrind false detected errors about uninitialized memory
Array_memset(buffer, 0xCC);
Defer(free(buffer.data));
u32 dec_size = 0;
while(true){ while(true){
sleepMsec(10); sleepMsec(10);
} }

View File

@ -11,7 +11,7 @@ typedef struct ServerCredentials {
br_rsa_public_key rsa_pk; br_rsa_public_key rsa_pk;
} ServerCredentials; } ServerCredentials;
Result(ServerCredentials*) ServerCredentials_create(const str rsa_sk_base64, const str rsa_pk_base64); Result(ServerCredentials*) ServerCredentials_create(cstr rsa_sk_base64, cstr rsa_pk_base64);
void ServerCredentials_free(ServerCredentials* cred); void ServerCredentials_free(ServerCredentials* cred);