implemented EncryptedSocketTCP_recvStruct and EncryptedSocketTCP_recvRSA
This commit is contained in:
@@ -1,5 +1,9 @@
|
||||
#include "encrypted_sockets.h"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketTCP //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
|
||||
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
|
||||
{
|
||||
@@ -16,6 +20,11 @@ void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){
|
||||
free(ptr->send_buf.data);
|
||||
}
|
||||
|
||||
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key){
|
||||
AESStreamEncryptor_changeKey(&ptr->enc, aes_key);
|
||||
AESStreamDecryptor_changeKey(&ptr->dec, aes_key);
|
||||
}
|
||||
|
||||
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
|
||||
Array(u8) buffer)
|
||||
{
|
||||
@@ -31,7 +40,7 @@ Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
|
||||
try_void(
|
||||
socket_send(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->send_buf, encrypted_size)
|
||||
Array_sliceTo(ptr->send_buf, encrypted_size)
|
||||
)
|
||||
);
|
||||
|
||||
@@ -51,14 +60,14 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
|
||||
try(i32 received_size, i,
|
||||
socket_recv(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->recv_buf, size_to_receive),
|
||||
Array_sliceTo(ptr->recv_buf, size_to_receive),
|
||||
flags
|
||||
)
|
||||
);
|
||||
try(u32 decrypted_size, u,
|
||||
AESStreamDecryptor_decrypt(
|
||||
&ptr->dec,
|
||||
Array_sliceBefore(ptr->recv_buf, received_size),
|
||||
Array_sliceTo(ptr->recv_buf, received_size),
|
||||
buffer
|
||||
)
|
||||
);
|
||||
@@ -66,7 +75,80 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
|
||||
Return RESULT_VALUE(u, decrypted_size);
|
||||
}
|
||||
|
||||
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
|
||||
RSAEncryptor* rsa_enc, Array(u8) buffer)
|
||||
{
|
||||
Deferral(1);
|
||||
|
||||
try(u32 encrypted_size, u,
|
||||
RSAEncryptor_encrypt(
|
||||
rsa_enc,
|
||||
buffer,
|
||||
ptr->send_buf
|
||||
)
|
||||
);
|
||||
try_void(
|
||||
socket_send(
|
||||
ptr->sock,
|
||||
Array_sliceTo(ptr->send_buf, encrypted_size)
|
||||
)
|
||||
);
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
|
||||
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
|
||||
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags)
|
||||
{
|
||||
Deferral(1);
|
||||
|
||||
// RSA encrypts message in block of size KEY_SIZE_BYTES.
|
||||
// SocketRecvFlag_WholeBuffer should be always enabled to receive such blocks.
|
||||
// If this flag is set in `flags` by caller, it means decrypted message size
|
||||
// must be the same as buffer size.
|
||||
bool fill_whole_buffer = (flags & SocketRecvFlag_WholeBuffer) != 0;
|
||||
flags |= SocketRecvFlag_WholeBuffer;
|
||||
u32 size_to_receive = rsa_dec->sk->n_bitlen / 8;
|
||||
|
||||
try(i32 received_size, i,
|
||||
socket_recv(
|
||||
ptr->sock,
|
||||
Array_sliceTo(ptr->recv_buf, size_to_receive),
|
||||
flags
|
||||
)
|
||||
);
|
||||
try(u32 decrypted_size, u,
|
||||
RSADecryptor_decrypt(
|
||||
rsa_dec,
|
||||
Array_sliceTo(ptr->recv_buf, received_size)
|
||||
)
|
||||
);
|
||||
|
||||
if(fill_whole_buffer){
|
||||
if(decrypted_size != buffer.size){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"SocketRecvFlag_WholeBuffer is set, "
|
||||
"but decrypted_size (%u) != buffer.size (%u)",
|
||||
decrypted_size, buffer.size
|
||||
);
|
||||
}
|
||||
}
|
||||
else if(decrypted_size > buffer.size){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"decrypted_size (%u) > buffer.size (%u)",
|
||||
decrypted_size, buffer.size
|
||||
);
|
||||
}
|
||||
|
||||
memcpy(buffer.data, ptr->recv_buf.data, decrypted_size);
|
||||
Return RESULT_VALUE(u, decrypted_size);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketUDP //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
|
||||
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
|
||||
@@ -84,6 +166,11 @@ void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){
|
||||
free(ptr->send_buf.data);
|
||||
}
|
||||
|
||||
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key){
|
||||
AESBlockEncryptor_changeKey(&ptr->enc, aes_key);
|
||||
AESBlockDecryptor_changeKey(&ptr->dec, aes_key);
|
||||
}
|
||||
|
||||
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
|
||||
Array(u8) buffer, EndpointIPv4 remote_end)
|
||||
{
|
||||
@@ -99,7 +186,7 @@ Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
|
||||
try_void(
|
||||
socket_sendto(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->send_buf, encrypted_size),
|
||||
Array_sliceTo(ptr->send_buf, encrypted_size),
|
||||
remote_end
|
||||
)
|
||||
);
|
||||
@@ -117,7 +204,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
|
||||
try(i32 received_size, i,
|
||||
socket_recvfrom(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->recv_buf, size_to_receive),
|
||||
Array_sliceTo(ptr->recv_buf, size_to_receive),
|
||||
flags,
|
||||
remote_end
|
||||
)
|
||||
@@ -125,7 +212,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
|
||||
try(u32 decrypted_size, u,
|
||||
AESBlockDecryptor_decrypt(
|
||||
&ptr->dec,
|
||||
Array_sliceBefore(ptr->recv_buf, received_size),
|
||||
Array_sliceTo(ptr->recv_buf, received_size),
|
||||
buffer
|
||||
)
|
||||
);
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
#include "network/socket.h"
|
||||
#include "cryptography/AES.h"
|
||||
#include "cryptography/RSA.h"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketTCP //
|
||||
@@ -20,12 +21,37 @@ void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
|
||||
/// closes the socket
|
||||
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr);
|
||||
|
||||
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key);
|
||||
|
||||
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
|
||||
Array(u8) buffer);
|
||||
|
||||
#define EncryptedSocketTCP_sendStruct(socket, structPtr)\
|
||||
EncryptedSocketTCP_send(socket,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)))
|
||||
|
||||
Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
|
||||
Array(u8) buffer, SocketRecvFlag flags);
|
||||
|
||||
#define EncryptedSocketTCP_recvStruct(socket, structPtr)\
|
||||
EncryptedSocketTCP_recv(socket,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)),\
|
||||
SocketRecvFlag_WholeBuffer)
|
||||
|
||||
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
|
||||
RSAEncryptor* rsa_enc, Array(u8) buffer);
|
||||
|
||||
#define EncryptedSocketTCP_sendStructRSA(socket, rsa_enc, structPtr)\
|
||||
EncryptedSocketTCP_sendRSA(socket, rsa_enc,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)))
|
||||
|
||||
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
|
||||
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags);
|
||||
|
||||
#define EncryptedSocketTCP_recvStructRSA(socket, rsa_dec, structPtr)\
|
||||
EncryptedSocketTCP_recvRSA(socket, rsa_dec,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)),\
|
||||
SocketRecvFlag_WholeBuffer)
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketUDP //
|
||||
@@ -45,6 +71,8 @@ void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
|
||||
/// closes the socket
|
||||
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr);
|
||||
|
||||
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key);
|
||||
|
||||
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
|
||||
Array(u8) buffer, EndpointIPv4 remote_end);
|
||||
|
||||
|
||||
@@ -86,7 +86,7 @@ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){
|
||||
int f = 0;
|
||||
if (flags & SocketRecvFlag_Peek)
|
||||
f |= MSG_PEEK;
|
||||
if (flags & SocketRecvFlag_WaitAll)
|
||||
if (flags & SocketRecvFlag_WholeBuffer)
|
||||
f |= MSG_WAITALL;
|
||||
return f;
|
||||
}
|
||||
@@ -96,7 +96,7 @@ Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
|
||||
if(r < 0){
|
||||
return RESULT_ERROR_SOCKET();
|
||||
}
|
||||
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
|
||||
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
|
||||
{
|
||||
return RESULT_ERROR("Socket closed", false);
|
||||
}
|
||||
@@ -111,7 +111,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
|
||||
if(r < 0){
|
||||
return RESULT_ERROR_SOCKET();
|
||||
}
|
||||
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
|
||||
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
|
||||
{
|
||||
return RESULT_ERROR("Socket closed", false);
|
||||
}
|
||||
|
||||
@@ -13,7 +13,7 @@ typedef enum SocketShutdownType {
|
||||
typedef enum SocketRecvFlag {
|
||||
SocketRecvFlag_None = 0,
|
||||
SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */,
|
||||
SocketRecvFlag_WaitAll = 0b10 /* waits until buffer is full */,
|
||||
SocketRecvFlag_WholeBuffer = 0b10 /* waits until buffer is full */,
|
||||
} SocketRecvFlag;
|
||||
|
||||
typedef i64 Socket;
|
||||
|
||||
Reference in New Issue
Block a user