added internal buffers to encrypted sockets

This commit is contained in:
2025-11-06 22:27:41 +05:00
parent 375dd842d4
commit d36fe9e5b3
10 changed files with 233 additions and 126 deletions

View File

@@ -1,63 +1,134 @@
#include "encrypted_sockets.h"
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, Socket sock, Array(u8) aes_key)
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
{
ptr->sock = sock;
AESStreamEncryptor_construct(&ptr->enc, aes_key, AESStream_DEFAULT_CLASS);
AESStreamDecryptor_construct(&ptr->dec, aes_key, AESStream_DEFAULT_CLASS);
ptr->recv_buf = Array_alloc_size(crypto_buffer_size);
ptr->send_buf = Array_alloc_size(crypto_buffer_size);
}
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){
socket_close(ptr->sock);
free(ptr->recv_buf.data);
free(ptr->send_buf.data);
}
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf)
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) buffer)
{
Deferral(4);
try(u32 encrypted_size, u, AESStreamEncryptor_encrypt(&ptr->enc, decrypted_buf, encrypted_buf));
encrypted_buf.size = encrypted_size;
try_void(socket_send(ptr->sock, encrypted_buf));
Deferral(1);
try(u32 encrypted_size, u,
AESStreamEncryptor_encrypt(
&ptr->enc,
buffer,
ptr->send_buf
)
);
try_void(
socket_send(
ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size)
)
);
Return RESULT_VOID;
}
Result(i32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf,
SocketRecvFlag flags)
Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) buffer, SocketRecvFlag flags)
{
Deferral(4);
try(i32 received_size, i, socket_recv(ptr->sock, encrypted_buf, flags));
//TODO: return error if WaitAll flag was set and socket closed before filling the buffer
//TODO: return something when received_size == 0 (socket has been closed)
encrypted_buf.size = received_size;
try(i32 decrypted_size, u, AESStreamDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf));
Return RESULT_VALUE(i, decrypted_size);
Deferral(1);
u32 size_to_receive = buffer.size;
if(ptr->dec.block_counter == 0){
// There is some metadata at the beginning of AES stream
size_to_receive = AESStreamEncryptor_calcDstSize(size_to_receive);
}
try(i32 received_size, i,
socket_recv(
ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive),
flags
)
);
try(u32 decrypted_size, u,
AESStreamDecryptor_decrypt(
&ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size),
buffer
)
);
Return RESULT_VALUE(u, decrypted_size);
}
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, Socket sock, Array(u8) aes_key)
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
{
ptr->sock = sock;
AESBlockEncryptor_construct(&ptr->enc, aes_key, AESBlockEncryptor_DEFAULT_CLASS);
AESBlockDecryptor_construct(&ptr->dec, aes_key, AESBlockDecryptor_DEFAULT_CLASS);
ptr->recv_buf = Array_alloc_size(crypto_buffer_size);
ptr->send_buf = Array_alloc_size(crypto_buffer_size);
}
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf,
EndpointIPv4 remote_end)
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){
socket_close(ptr->sock);
free(ptr->recv_buf.data);
free(ptr->send_buf.data);
}
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) buffer, EndpointIPv4 remote_end)
{
Deferral(4);
try(u32 encrypted_size, u, AESBlockEncryptor_encrypt(&ptr->enc, decrypted_buf, encrypted_buf));
encrypted_buf.size = encrypted_size;
try_void(socket_sendto(ptr->sock, encrypted_buf, remote_end));
Deferral(1);
try(u32 encrypted_size, u,
AESBlockEncryptor_encrypt(
&ptr->enc,
buffer,
ptr->send_buf
)
);
try_void(
socket_sendto(
ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size),
remote_end
)
);
Return RESULT_VOID;
}
Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf,
SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end)
Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end)
{
Deferral(4);
try(i32 received_size, i, socket_recvfrom(ptr->sock, encrypted_buf, flags, remote_end));
encrypted_buf.size = received_size;
try(i32 decrypted_size, u, AESBlockDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf));
Return RESULT_VALUE(i, decrypted_size);
Deferral(1);
// There is some metadata at the start of each AES block
u32 size_to_receive = AESBlockEncryptor_calcDstSize(buffer.size);
try(i32 received_size, i,
socket_recvfrom(
ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive),
flags,
remote_end
)
);
try(u32 decrypted_size, u,
AESBlockDecryptor_decrypt(
&ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size),
buffer
)
);
Return RESULT_VALUE(u, decrypted_size);
}

View File

@@ -10,16 +10,21 @@ typedef struct EncryptedSocketTCP {
Socket sock;
AESStreamEncryptor enc;
AESStreamDecryptor dec;
Array(u8) send_buf;
Array(u8) recv_buf;
} EncryptedSocketTCP;
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, Socket sock, Array(u8) aes_key);
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key);
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf);
/// closes the socket
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr);
Result(void) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf,
SocketRecvFlag flags);
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) buffer);
Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) buffer, SocketRecvFlag flags);
//////////////////////////////////////////////////////////////////////////////
@@ -30,15 +35,18 @@ typedef struct EncryptedSocketUDP {
Socket sock;
AESBlockEncryptor enc;
AESBlockDecryptor dec;
Array(u8) send_buf;
Array(u8) recv_buf;
} EncryptedSocketUDP;
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, Socket sock, Array(u8) aes_key);
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key);
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf,
EndpointIPv4 remote_end);
/// closes the socket
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr);
Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf,
SocketRecvFlag flags,
NULLABLE(EndpointIPv4*) remote_end);
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) buffer, EndpointIPv4 remote_end);
Result(u32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end);

View File

@@ -61,16 +61,24 @@ Result(void) socket_connect(Socket s, EndpointIPv4 remote_end){
Result(void) socket_send(Socket s, Array(u8) buffer){
i32 r = send(s, buffer.data, buffer.size, 0);
if(r < 0)
if(r < 0){
return RESULT_ERROR_SOCKET();
}
if((u32)r != buffer.size){
return RESULT_ERROR_FMT("Socket was unable to send data");
}
return RESULT_VOID;
}
Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst){
struct sockaddr_in sockaddr = EndpointIPv4_toSockaddr(dst);
i32 r = sendto(s, buffer.data, buffer.size, 0, (void*)&sockaddr, sizeof(sockaddr));
if(r < 0)
if(r < 0){
return RESULT_ERROR_SOCKET();
}
if((u32)r != buffer.size){
return RESULT_ERROR_FMT("Socket was unable to send data");
}
return RESULT_VOID;
}
@@ -85,8 +93,13 @@ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){
Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
i32 r = recv(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags));
if(r < 0)
if(r < 0){
return RESULT_ERROR_SOCKET();
}
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
{
return RESULT_ERROR("Socket closed", false);
}
return RESULT_VALUE(i, r);
}
@@ -95,9 +108,14 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
i32 sockaddr_size = sizeof(remote_addr);
i32 r = recvfrom(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags),
(struct sockaddr*)&remote_addr, (void*)&sockaddr_size);
if(r < 0)
if(r < 0){
return RESULT_ERROR_SOCKET();
}
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
{
return RESULT_ERROR("Socket closed", false);
}
//TODO: add IPV6 support (struct sockaddr_in6)
assert(sockaddr_size == sizeof(remote_addr));

View File

@@ -1,14 +1,9 @@
#include "v1.h"
Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, Array(u8) session_key){
Deferral(1);
try_assert(session_key.size == AES_SESSION_KEY_SIZE);
PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ClientHandshake, session_key.size);
memcpy(ptr->session_key, session_key.data, session_key.size);
Return RESULT_VOID;
void ClientHandshake_construct(ClientHandshake* ptr, Array(u8) session_key){
memcpy(ptr->session_key, session_key.data, sizeof(ptr->session_key));
}
void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){
PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(session_id));
ptr->session_id = session_id;
}

View File

@@ -3,7 +3,7 @@
#include "network/tcp-chat-protocol/constant.h"
#define PROTOCOL_VERSION 1 /* 1.0.0 */
#define NETWORK_BUFFER_SIZE 65536
typedef enum PacketType {
PacketType_Invalid,
@@ -14,21 +14,18 @@ typedef enum PacketType {
typedef struct ErrorMessage {
PacketHeader header;
/* content stream of size `header.content_size` */
} ErrorMessage;
typedef struct ClientHandshake {
PacketHeader header;
u8 session_key[AES_SESSION_KEY_SIZE];
} ClientHandshake;
Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, Array(u8) session_key);
void ClientHandshake_construct(ClientHandshake* ptr, Array(u8) session_key);
typedef struct ServerHandshake {
PacketHeader header;
u64 session_id;
} ServerHandshake;