From 42702ffbe7c76d657883cda8f395d7eda1b99e5b Mon Sep 17 00:00:00 2001 From: Timerix Date: Sat, 25 Oct 2025 19:08:37 +0500 Subject: [PATCH] added flags to socket_recv --- dependencies/tlibc | 2 +- src/cryptography/AES.c | 18 ++++++++++++------ src/cryptography/cryptography.h | 4 ++-- src/network/EncryptedSocket.c | 17 +++++++++-------- src/network/EncryptedSocket.h | 4 +++- src/network/socket.c | 17 +++++++++++++---- src/network/socket.h | 11 +++++++++-- 7 files changed, 49 insertions(+), 24 deletions(-) diff --git a/dependencies/tlibc b/dependencies/tlibc index 0184d2e..f0992c0 160000 --- a/dependencies/tlibc +++ b/dependencies/tlibc @@ -1 +1 @@ -Subproject commit 0184d2e8c96d882b815eae05fa160a40e9f7faf2 +Subproject commit f0992c02178758546b56a87574b6787ced797c7d diff --git a/src/cryptography/AES.c b/src/cryptography/AES.c index 150dc23..658b3cf 100755 --- a/src/cryptography/AES.c +++ b/src/cryptography/AES.c @@ -26,8 +26,9 @@ void EncryptorAES_construct(EncryptorAES* ptr, Array(u8) key){ rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable); } -void EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst){ - assert(dst.size >= EncryptorAES_calcDstSize(src.size)); +Result(void) EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst){ + Deferral(4); + try_assert(dst.size >= EncryptorAES_calcDstSize(src.size)); // generate random initial vector br_hmac_drbg_generate(&ptr->rng_ctx, ptr->iv, __AES_IV_SIZE); @@ -57,6 +58,8 @@ void EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst){ br_aes_ct64_cbcenc_run(&ptr->enc_ctx, ptr->iv, ptr->buf, src_size_padded); memcpy(dst.data, ptr->buf, src_size_padded); } + + Return RESULT_VOID; } @@ -66,10 +69,11 @@ void DecryptorAES_construct(DecryptorAES* ptr, Array(u8) key){ br_aes_ct64_cbcdec_init(&ptr->dec_ctx, key.data, key.size); } -void DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size){ - assert(src.size >= EncryptorAES_calcDstSize(0)); - assert(src.size % 16 == 0 && "src must be array of 16-byte blocks"); - assert(dst.size >= src.size); +Result(void) DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size){ + Deferral(4); + try_assert(src.size >= EncryptorAES_calcDstSize(0)); + try_assert(src.size % 16 == 0 && "src must be array of 16-byte blocks"); + try_assert(dst.size >= src.size); // read IV from the beginning of src __Array_readNext(ptr->iv, &src, __AES_IV_SIZE); @@ -97,4 +101,6 @@ void DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* br_aes_ct64_cbcdec_run(&ptr->dec_ctx, ptr->iv, ptr->buf, src_size_padded); memcpy(dst.data, ptr->buf, src.size); } + + Return RESULT_VOID; } diff --git a/src/cryptography/cryptography.h b/src/cryptography/cryptography.h index fcd7e60..e5aadc0 100755 --- a/src/cryptography/cryptography.h +++ b/src/cryptography/cryptography.h @@ -73,7 +73,7 @@ void EncryptorAES_construct(EncryptorAES* ptr, Array(u8) key); /// @brief Encrypts `src` and writes output to `dst`. /// @param src array of any size /// @param dst array of size >= EncryptorAES_calcDstSize(src.size) -void EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst); +Result(void) EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst); #define EncryptorAES_calcDstSize(SRC_SIZE) (__AES_IV_SIZE + sizeof(EncryptedBlockHeader) + ALIGN_TO(SRC_SIZE, 16)) @@ -91,7 +91,7 @@ void DecryptorAES_construct(DecryptorAES* ptr, Array(u8) key); /// @param src array of size at least EncryptorAES_calcDstSize(0). Size must be multiple of 16. /// @param dst array of size >= src.size /// @param decrypted_size size of original data without padding added by EncryptorAES_encrypt -void DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size); +Result(void) DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size); ////////////////////////////////////////////////////////////////////////////// diff --git a/src/network/EncryptedSocket.c b/src/network/EncryptedSocket.c index 0baa93a..170894e 100644 --- a/src/network/EncryptedSocket.c +++ b/src/network/EncryptedSocket.c @@ -12,7 +12,7 @@ Result(void) EncryptedSocket_send(EncryptedSocket* ptr, Array(u8) decrypted_buf, Array(u8) encrypted_buf) { Deferral(4); - EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf); + try_void(EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf)); try_void(socket_send(ptr->sock, encrypted_buf)); Return RESULT_VOID; } @@ -22,28 +22,29 @@ Result(void) EncryptedSocket_sendto(EncryptedSocket* ptr, EndpointIPv4 remote_end) { Deferral(4); - EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf); + try_void(EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf)); try_void(socket_sendto(ptr->sock, encrypted_buf, remote_end)); Return RESULT_VOID; } Result(i32) EncryptedSocket_recv(EncryptedSocket* ptr, - Array(u8) encrypted_buf, Array(u8) decrypted_buf) + Array(u8) encrypted_buf, Array(u8) decrypted_buf, + SocketRecvFlag flags) { Deferral(4); - try(i32 r, i, socket_recv(ptr->sock, encrypted_buf)); + try(i32 r, i, socket_recv(ptr->sock, encrypted_buf, flags)); encrypted_buf.size = r; - DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r); + try_void(DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r)); Return RESULT_VALUE(i, r); } Result(i32) EncryptedSocket_recvfrom(EncryptedSocket* ptr, Array(u8) encrypted_buf, Array(u8) decrypted_buf, - NULLABLE(EndpointIPv4*) remote_end) + SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end) { Deferral(4); - try(i32 r, i, socket_recvfrom(ptr->sock, encrypted_buf, remote_end)); + try(i32 r, i, socket_recvfrom(ptr->sock, encrypted_buf, flags, remote_end)); encrypted_buf.size = r; - DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r); + try_void(DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r)); Return RESULT_VALUE(i, r); } diff --git a/src/network/EncryptedSocket.h b/src/network/EncryptedSocket.h index e687302..a18050b 100644 --- a/src/network/EncryptedSocket.h +++ b/src/network/EncryptedSocket.h @@ -18,8 +18,10 @@ Result(void) EncryptedSocket_sendto(EncryptedSocket* ptr, EndpointIPv4 remote_end); Result(i32) EncryptedSocket_recv(EncryptedSocket* ptr, - Array(u8) encrypted_buf, Array(u8) decrypted_buf); + Array(u8) encrypted_buf, Array(u8) decrypted_buf, + SocketRecvFlag flags); Result(i32) EncryptedSocket_recvfrom(EncryptedSocket* ptr, Array(u8) encrypted_buf, Array(u8) decrypted_buf, + SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end); diff --git a/src/network/socket.c b/src/network/socket.c index fbbd663..91639ef 100755 --- a/src/network/socket.c +++ b/src/network/socket.c @@ -74,17 +74,26 @@ Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst){ return RESULT_VOID; } -Result(i32) socket_recv(Socket s, Array(u8) buffer){ - i32 r = recv(s, buffer.data, buffer.size, 0); +static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){ + int f = 0; + if (flags & SocketRecvFlag_Peek) + f |= MSG_PEEK; + if (flags & SocketRecvFlag_WaitAll) + f |= MSG_WAITALL; + return f; +} + +Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){ + i32 r = recv(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags)); if(r < 0) return RESULT_ERROR_SOCKET(); return RESULT_VALUE(i, r); } -Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, NULLABLE(EndpointIPv4*) remote_end){ +Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end){ struct sockaddr_in remote_addr = {0}; i32 sockaddr_size = sizeof(remote_addr); - i32 r = recvfrom(s, buffer.data, buffer.size, 0, + i32 r = recvfrom(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags), (struct sockaddr*)&remote_addr, (void*)&sockaddr_size); if(r < 0) return RESULT_ERROR_SOCKET(); diff --git a/src/network/socket.h b/src/network/socket.h index 6d4bab6..b552dc5 100755 --- a/src/network/socket.h +++ b/src/network/socket.h @@ -2,6 +2,7 @@ #include "endpoint.h" #include "tlibc/errors.h" #include "tlibc/collections/Array.h" +#include "tlibc/time.h" typedef enum SocketShutdownType { SocketShutdownType_Receive = 0, @@ -9,6 +10,12 @@ typedef enum SocketShutdownType { SocketShutdownType_Both = 2, } SocketShutdownType; +typedef enum SocketRecvFlag { + SocketRecvFlag_None = 0, + SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */, + SocketRecvFlag_WaitAll = 0b10 /* waits until buffer is full */, +} SocketRecvFlag; + typedef i64 Socket; Result(Socket) socket_open_TCP(); @@ -20,5 +27,5 @@ Result(Socket) socket_accept(Socket s, NULLABLE(EndpointIPv4*) remote_end); Result(void) socket_connect(Socket s, EndpointIPv4 remote_end); Result(void) socket_send(Socket s, Array(u8) buffer); Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst); -Result(i32) socket_recv(Socket s, Array(u8) buffer); -Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, NULLABLE(EndpointIPv4*) remote_end); +Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags); +Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end);