added flags to socket_recv

This commit is contained in:
Timerix 2025-10-25 19:08:37 +05:00
parent eea36ec2a3
commit 42702ffbe7
7 changed files with 49 additions and 24 deletions

2
dependencies/tlibc vendored

@ -1 +1 @@
Subproject commit 0184d2e8c96d882b815eae05fa160a40e9f7faf2 Subproject commit f0992c02178758546b56a87574b6787ced797c7d

View File

@ -26,8 +26,9 @@ void EncryptorAES_construct(EncryptorAES* ptr, Array(u8) key){
rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable); rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable);
} }
void EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst){ Result(void) EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst){
assert(dst.size >= EncryptorAES_calcDstSize(src.size)); Deferral(4);
try_assert(dst.size >= EncryptorAES_calcDstSize(src.size));
// generate random initial vector // generate random initial vector
br_hmac_drbg_generate(&ptr->rng_ctx, ptr->iv, __AES_IV_SIZE); br_hmac_drbg_generate(&ptr->rng_ctx, ptr->iv, __AES_IV_SIZE);
@ -57,6 +58,8 @@ void EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst){
br_aes_ct64_cbcenc_run(&ptr->enc_ctx, ptr->iv, ptr->buf, src_size_padded); br_aes_ct64_cbcenc_run(&ptr->enc_ctx, ptr->iv, ptr->buf, src_size_padded);
memcpy(dst.data, ptr->buf, src_size_padded); memcpy(dst.data, ptr->buf, src_size_padded);
} }
Return RESULT_VOID;
} }
@ -66,10 +69,11 @@ void DecryptorAES_construct(DecryptorAES* ptr, Array(u8) key){
br_aes_ct64_cbcdec_init(&ptr->dec_ctx, key.data, key.size); br_aes_ct64_cbcdec_init(&ptr->dec_ctx, key.data, key.size);
} }
void DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size){ Result(void) DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size){
assert(src.size >= EncryptorAES_calcDstSize(0)); Deferral(4);
assert(src.size % 16 == 0 && "src must be array of 16-byte blocks"); try_assert(src.size >= EncryptorAES_calcDstSize(0));
assert(dst.size >= src.size); try_assert(src.size % 16 == 0 && "src must be array of 16-byte blocks");
try_assert(dst.size >= src.size);
// read IV from the beginning of src // read IV from the beginning of src
__Array_readNext(ptr->iv, &src, __AES_IV_SIZE); __Array_readNext(ptr->iv, &src, __AES_IV_SIZE);
@ -97,4 +101,6 @@ void DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32*
br_aes_ct64_cbcdec_run(&ptr->dec_ctx, ptr->iv, ptr->buf, src_size_padded); br_aes_ct64_cbcdec_run(&ptr->dec_ctx, ptr->iv, ptr->buf, src_size_padded);
memcpy(dst.data, ptr->buf, src.size); memcpy(dst.data, ptr->buf, src.size);
} }
Return RESULT_VOID;
} }

View File

@ -73,7 +73,7 @@ void EncryptorAES_construct(EncryptorAES* ptr, Array(u8) key);
/// @brief Encrypts `src` and writes output to `dst`. /// @brief Encrypts `src` and writes output to `dst`.
/// @param src array of any size /// @param src array of any size
/// @param dst array of size >= EncryptorAES_calcDstSize(src.size) /// @param dst array of size >= EncryptorAES_calcDstSize(src.size)
void EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst); Result(void) EncryptorAES_encrypt(EncryptorAES* ptr, Array(u8) src, Array(u8) dst);
#define EncryptorAES_calcDstSize(SRC_SIZE) (__AES_IV_SIZE + sizeof(EncryptedBlockHeader) + ALIGN_TO(SRC_SIZE, 16)) #define EncryptorAES_calcDstSize(SRC_SIZE) (__AES_IV_SIZE + sizeof(EncryptedBlockHeader) + ALIGN_TO(SRC_SIZE, 16))
@ -91,7 +91,7 @@ void DecryptorAES_construct(DecryptorAES* ptr, Array(u8) key);
/// @param src array of size at least EncryptorAES_calcDstSize(0). Size must be multiple of 16. /// @param src array of size at least EncryptorAES_calcDstSize(0). Size must be multiple of 16.
/// @param dst array of size >= src.size /// @param dst array of size >= src.size
/// @param decrypted_size size of original data without padding added by EncryptorAES_encrypt /// @param decrypted_size size of original data without padding added by EncryptorAES_encrypt
void DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size); Result(void) DecryptorAES_decrypt(DecryptorAES* ptr, Array(u8) src, Array(u8) dst, u32* decrypted_size);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////

View File

@ -12,7 +12,7 @@ Result(void) EncryptedSocket_send(EncryptedSocket* ptr,
Array(u8) decrypted_buf, Array(u8) encrypted_buf) Array(u8) decrypted_buf, Array(u8) encrypted_buf)
{ {
Deferral(4); Deferral(4);
EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf); try_void(EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf));
try_void(socket_send(ptr->sock, encrypted_buf)); try_void(socket_send(ptr->sock, encrypted_buf));
Return RESULT_VOID; Return RESULT_VOID;
} }
@ -22,28 +22,29 @@ Result(void) EncryptedSocket_sendto(EncryptedSocket* ptr,
EndpointIPv4 remote_end) EndpointIPv4 remote_end)
{ {
Deferral(4); Deferral(4);
EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf); try_void(EncryptorAES_encrypt(&ptr->enc, decrypted_buf, encrypted_buf));
try_void(socket_sendto(ptr->sock, encrypted_buf, remote_end)); try_void(socket_sendto(ptr->sock, encrypted_buf, remote_end));
Return RESULT_VOID; Return RESULT_VOID;
} }
Result(i32) EncryptedSocket_recv(EncryptedSocket* ptr, Result(i32) EncryptedSocket_recv(EncryptedSocket* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf) Array(u8) encrypted_buf, Array(u8) decrypted_buf,
SocketRecvFlag flags)
{ {
Deferral(4); Deferral(4);
try(i32 r, i, socket_recv(ptr->sock, encrypted_buf)); try(i32 r, i, socket_recv(ptr->sock, encrypted_buf, flags));
encrypted_buf.size = r; encrypted_buf.size = r;
DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r); try_void(DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r));
Return RESULT_VALUE(i, r); Return RESULT_VALUE(i, r);
} }
Result(i32) EncryptedSocket_recvfrom(EncryptedSocket* ptr, Result(i32) EncryptedSocket_recvfrom(EncryptedSocket* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf, Array(u8) encrypted_buf, Array(u8) decrypted_buf,
NULLABLE(EndpointIPv4*) remote_end) SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end)
{ {
Deferral(4); Deferral(4);
try(i32 r, i, socket_recvfrom(ptr->sock, encrypted_buf, remote_end)); try(i32 r, i, socket_recvfrom(ptr->sock, encrypted_buf, flags, remote_end));
encrypted_buf.size = r; encrypted_buf.size = r;
DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r); try_void(DecryptorAES_decrypt(&ptr->dec, encrypted_buf, decrypted_buf, (u32*)&r));
Return RESULT_VALUE(i, r); Return RESULT_VALUE(i, r);
} }

View File

@ -18,8 +18,10 @@ Result(void) EncryptedSocket_sendto(EncryptedSocket* ptr,
EndpointIPv4 remote_end); EndpointIPv4 remote_end);
Result(i32) EncryptedSocket_recv(EncryptedSocket* ptr, Result(i32) EncryptedSocket_recv(EncryptedSocket* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf); Array(u8) encrypted_buf, Array(u8) decrypted_buf,
SocketRecvFlag flags);
Result(i32) EncryptedSocket_recvfrom(EncryptedSocket* ptr, Result(i32) EncryptedSocket_recvfrom(EncryptedSocket* ptr,
Array(u8) encrypted_buf, Array(u8) decrypted_buf, Array(u8) encrypted_buf, Array(u8) decrypted_buf,
SocketRecvFlag flags,
NULLABLE(EndpointIPv4*) remote_end); NULLABLE(EndpointIPv4*) remote_end);

View File

@ -74,17 +74,26 @@ Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst){
return RESULT_VOID; return RESULT_VOID;
} }
Result(i32) socket_recv(Socket s, Array(u8) buffer){ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){
i32 r = recv(s, buffer.data, buffer.size, 0); int f = 0;
if (flags & SocketRecvFlag_Peek)
f |= MSG_PEEK;
if (flags & SocketRecvFlag_WaitAll)
f |= MSG_WAITALL;
return f;
}
Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
i32 r = recv(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags));
if(r < 0) if(r < 0)
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
return RESULT_VALUE(i, r); return RESULT_VALUE(i, r);
} }
Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, NULLABLE(EndpointIPv4*) remote_end){ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end){
struct sockaddr_in remote_addr = {0}; struct sockaddr_in remote_addr = {0};
i32 sockaddr_size = sizeof(remote_addr); i32 sockaddr_size = sizeof(remote_addr);
i32 r = recvfrom(s, buffer.data, buffer.size, 0, i32 r = recvfrom(s, buffer.data, buffer.size, SocketRecvFlags_toStd(flags),
(struct sockaddr*)&remote_addr, (void*)&sockaddr_size); (struct sockaddr*)&remote_addr, (void*)&sockaddr_size);
if(r < 0) if(r < 0)
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();

View File

@ -2,6 +2,7 @@
#include "endpoint.h" #include "endpoint.h"
#include "tlibc/errors.h" #include "tlibc/errors.h"
#include "tlibc/collections/Array.h" #include "tlibc/collections/Array.h"
#include "tlibc/time.h"
typedef enum SocketShutdownType { typedef enum SocketShutdownType {
SocketShutdownType_Receive = 0, SocketShutdownType_Receive = 0,
@ -9,6 +10,12 @@ typedef enum SocketShutdownType {
SocketShutdownType_Both = 2, SocketShutdownType_Both = 2,
} SocketShutdownType; } SocketShutdownType;
typedef enum SocketRecvFlag {
SocketRecvFlag_None = 0,
SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */,
SocketRecvFlag_WaitAll = 0b10 /* waits until buffer is full */,
} SocketRecvFlag;
typedef i64 Socket; typedef i64 Socket;
Result(Socket) socket_open_TCP(); Result(Socket) socket_open_TCP();
@ -20,5 +27,5 @@ Result(Socket) socket_accept(Socket s, NULLABLE(EndpointIPv4*) remote_end);
Result(void) socket_connect(Socket s, EndpointIPv4 remote_end); Result(void) socket_connect(Socket s, EndpointIPv4 remote_end);
Result(void) socket_send(Socket s, Array(u8) buffer); Result(void) socket_send(Socket s, Array(u8) buffer);
Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst); Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst);
Result(i32) socket_recv(Socket s, Array(u8) buffer); Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags);
Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, NULLABLE(EndpointIPv4*) remote_end); Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end);