implemented EncryptedSocketTCP_recvStruct and EncryptedSocketTCP_recvRSA
This commit is contained in:
parent
ee522ac401
commit
2db37bb902
2
dependencies/tlibc
vendored
2
dependencies/tlibc
vendored
@ -1 +1 @@
|
||||
Subproject commit 00a1a29d342c6b3bcb8015a661364e10a3134fe4
|
||||
Subproject commit 30c141f587ac8fa635812c6f4713dbc20b18d7c9
|
||||
@ -67,45 +67,25 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
|
||||
try_void(socket_connect(_s, conn->server_end));
|
||||
EncryptedSocketTCP_construct(&conn->sock, _s, NETWORK_BUFFER_SIZE, conn->session_key);
|
||||
|
||||
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
|
||||
// fix for valgrind false detected errors about uninitialized memory
|
||||
Array_memset(buffer, 0xCC);
|
||||
Defer(free(buffer.data));
|
||||
|
||||
// construct PacketHeader and ClientHandshake in buffer
|
||||
PacketHeader_construct(buffer.data, PROTOCOL_VERSION,
|
||||
PacketType_ClientHandshake, sizeof(ClientHandshake));
|
||||
ClientHandshake_construct(
|
||||
Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
|
||||
// send PacketHeader and ClientHandshake
|
||||
// encryption by server public key
|
||||
PacketHeader packet_header;
|
||||
PacketHeader_construct(&packet_header,
|
||||
PROTOCOL_VERSION, PacketType_ClientHandshake, sizeof(ClientHandshake));
|
||||
ClientHandshake client_handshake;
|
||||
ClientHandshake_construct(&client_handshake,
|
||||
conn->session_key);
|
||||
u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
|
||||
// encrypt message by server public key
|
||||
Array(u8) bufferPart_encryptedClientHandshake = Array_sliceAfter(buffer, header_and_message_size);
|
||||
try(u32 rsa_enc_size, u,
|
||||
RSAEncryptor_encrypt(
|
||||
&conn->rsa_enc,
|
||||
Array_sliceBefore(buffer, header_and_message_size),
|
||||
bufferPart_encryptedClientHandshake
|
||||
)
|
||||
);
|
||||
bufferPart_encryptedClientHandshake.size = rsa_enc_size;
|
||||
// send encrypted message
|
||||
try_void(socket_send(conn->sock.sock, bufferPart_encryptedClientHandshake));
|
||||
try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &packet_header));
|
||||
try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &client_handshake));
|
||||
|
||||
// receive server response
|
||||
try_void(
|
||||
EncryptedSocketTCP_recv(&conn->sock,
|
||||
Array_sliceBefore(buffer, sizeof(PacketHeader)),
|
||||
SocketRecvFlag_WaitAll
|
||||
)
|
||||
);
|
||||
PacketHeader* packet_header = buffer.data;
|
||||
try_void(PacketHeader_validateMagic(packet_header));
|
||||
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &packet_header));
|
||||
try_void(PacketHeader_validateMagic(&packet_header));
|
||||
|
||||
// handle server response
|
||||
switch(packet_header->type){
|
||||
switch(packet_header.type){
|
||||
case PacketType_ErrorMessage: {
|
||||
u32 err_msg_size = packet_header->content_size;
|
||||
u32 err_msg_size = packet_header.content_size;
|
||||
if(err_msg_size > conn->sock.recv_buf.size)
|
||||
err_msg_size = conn->sock.recv_buf.size;
|
||||
Array(u8) err_buf = Array_alloc_size(err_msg_size + 1);
|
||||
@ -119,8 +99,8 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
|
||||
try_void(
|
||||
EncryptedSocketTCP_recv(
|
||||
&conn->sock,
|
||||
Array_sliceBefore(err_buf, err_msg_size),
|
||||
SocketRecvFlag_WaitAll
|
||||
Array_sliceTo(err_buf, err_msg_size),
|
||||
SocketRecvFlag_WholeBuffer
|
||||
)
|
||||
);
|
||||
|
||||
@ -129,23 +109,13 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
|
||||
Return RESULT_ERROR((char*)err_buf.data, true);
|
||||
}
|
||||
case PacketType_ServerHandshake: {
|
||||
Array(u8) bufferPart_ServerHandshake = {
|
||||
.data = (u8*)buffer.data + sizeof(PacketHeader),
|
||||
.size = sizeof(ServerHandshake)
|
||||
};
|
||||
try_void(
|
||||
EncryptedSocketTCP_recv(
|
||||
&conn->sock,
|
||||
bufferPart_ServerHandshake,
|
||||
SocketRecvFlag_WaitAll
|
||||
)
|
||||
);
|
||||
ServerHandshake* server_handshake = bufferPart_ServerHandshake.data;
|
||||
conn->session_id = server_handshake->session_id;
|
||||
ServerHandshake server_handshake;
|
||||
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &server_handshake));
|
||||
conn->session_id = server_handshake.session_id;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header->type);
|
||||
Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header.type);
|
||||
}
|
||||
|
||||
success = true;
|
||||
|
||||
@ -4,18 +4,23 @@
|
||||
// write data from src to array and increment array data pointer
|
||||
static inline void __Array_writeNext(Array(u8)* dst, u8* src, size_t size){
|
||||
memcpy(dst->data, src, size);
|
||||
*dst = Array_sliceAfter(*dst, size);
|
||||
*dst = Array_sliceFrom(*dst, size);
|
||||
}
|
||||
|
||||
// read data from array to dst and increment array data pointer
|
||||
static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){
|
||||
memcpy(dst, src->data, size);
|
||||
*src = Array_sliceAfter(*src, size);
|
||||
*src = Array_sliceFrom(*src, size);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// AESBlockEncryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class){
|
||||
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr,
|
||||
Array(u8) key, const br_block_cbcenc_class* enc_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->enc_class = enc_class;
|
||||
@ -25,7 +30,15 @@ void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br
|
||||
rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable);
|
||||
}
|
||||
|
||||
Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||
void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->enc_class->init((void*)ptr->enc_keys, key.data, key.size);
|
||||
}
|
||||
|
||||
Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr,
|
||||
Array(u8) src, Array(u8) dst)
|
||||
{
|
||||
Deferral(4);
|
||||
u32 encrypted_size = AESBlockEncryptor_calcDstSize(src.size);
|
||||
try_assert(dst.size >= encrypted_size);
|
||||
@ -67,15 +80,28 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Arr
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// AESBlockDecryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class){
|
||||
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr,
|
||||
Array(u8) key, const br_block_cbcdec_class* dec_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->dec_class = dec_class;
|
||||
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
|
||||
}
|
||||
|
||||
Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||
void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
|
||||
}
|
||||
|
||||
Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr,
|
||||
Array(u8) src, Array(u8) dst)
|
||||
{
|
||||
Deferral(4);
|
||||
try_assert(src.size >= AESBlockEncryptor_calcDstSize(0));
|
||||
try_assert(src.size % 16 == 0 && "src must be array of 16-byte blocks");
|
||||
@ -116,8 +142,13 @@ Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Arr
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// AESStreamEncryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){
|
||||
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr,
|
||||
Array(u8) key, const br_block_ctr_class* ctr_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->ctr_class = ctr_class;
|
||||
@ -131,7 +162,15 @@ void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const
|
||||
ptr->block_counter = 0;
|
||||
}
|
||||
|
||||
Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||
void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
|
||||
}
|
||||
|
||||
Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr,
|
||||
Array(u8) src, Array(u8) dst)
|
||||
{
|
||||
Deferral(4);
|
||||
u32 encrypted_size = AESStreamEncryptor_calcDstSize(src.size);
|
||||
try_assert(dst.size >= encrypted_size);
|
||||
@ -164,8 +203,13 @@ Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, A
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// AESStreamDecryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){
|
||||
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr,
|
||||
Array(u8) key, const br_block_ctr_class* ctr_class)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
|
||||
ptr->ctr_class = ctr_class;
|
||||
@ -174,7 +218,15 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const
|
||||
ptr->block_counter = 0;
|
||||
}
|
||||
|
||||
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||
void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key)
|
||||
{
|
||||
assert(key.size == 16 || key.size == 24 || key.size == 32);
|
||||
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
|
||||
}
|
||||
|
||||
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
|
||||
Array(u8) src, Array(u8) dst)
|
||||
{
|
||||
Deferral(4);
|
||||
|
||||
// if it is the beginning of the stream, read IV
|
||||
|
||||
@ -40,6 +40,9 @@ typedef struct AESBlockEncryptor {
|
||||
/// @param enc_class &br_aes_XXX_cbcenc_vtable
|
||||
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class);
|
||||
|
||||
/// @param key supported sizes: 16, 24, 32
|
||||
void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key);
|
||||
|
||||
/// @brief Encrypts a complete message. For part-by-part encryption use AESStreamEncryptor.
|
||||
/// @param src array of any size
|
||||
/// @param dst array of size >= AESBlockEncryptor_calcDstSize(src.size)
|
||||
@ -63,6 +66,9 @@ typedef struct AESBlockDecryptor {
|
||||
/// @param dec_class &br_aes_XXX_cbcdec_vtable
|
||||
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class);
|
||||
|
||||
/// @param key supported sizes: 16, 24, 32
|
||||
void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key);
|
||||
|
||||
/// @brief Decrypts a complete message. For part-by-part decryption use AESStreamEncryptor.
|
||||
/// @param src array of size at least AESBlockEncryptor_calcDstSize(0). Size must be multiple of 16.
|
||||
/// @param dst array of size >= src.size
|
||||
@ -88,6 +94,9 @@ typedef struct AESStreamEncryptor {
|
||||
/// @param dec_class &br_aes_XXX_ctr_vtable
|
||||
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
|
||||
|
||||
/// @param key supported sizes: 16, 24, 32
|
||||
void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key);
|
||||
|
||||
/// use this only at the beginning of the stream
|
||||
#define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_STREAM_IV_SIZE)
|
||||
|
||||
@ -114,6 +123,9 @@ typedef struct AESStreamDecryptor {
|
||||
/// @param dec_class &br_aes_XXX_ctr_vtable
|
||||
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
|
||||
|
||||
/// @param key supported sizes: 16, 24, 32
|
||||
void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key);
|
||||
|
||||
/// @brief Reads IV from `src`, then decrypts data and writes it to dst
|
||||
/// @param src array of any size
|
||||
/// @param dst array of size >= src.size
|
||||
|
||||
@ -175,6 +175,11 @@ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// RSAEncryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void RSAEncryptor_construct(RSAEncryptor* ptr, const br_rsa_public_key* pk){
|
||||
ptr->pk = pk;
|
||||
ptr->rng.vtable = &br_hmac_drbg_vtable;
|
||||
@ -185,11 +190,11 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst
|
||||
u32 key_size_bytes = ptr->pk->nlen;
|
||||
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bytes * 8, 256);
|
||||
if(src.size > max_src_size){
|
||||
return RESULT_ERROR_FMT("src.size (%u) must be <= %u (use RSAEncryptor_calcMaxSrcSize)",
|
||||
return RESULT_ERROR_FMT("src.size (%u) must be <= RSAEncryptor_calcMaxSrcSize() (%u)",
|
||||
src.size, max_src_size);
|
||||
}
|
||||
if(dst.size < key_size_bytes){
|
||||
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (key length in bytes)",
|
||||
return RESULT_ERROR_FMT("dst.size (%u) must be >= key length in bytes (%u)",
|
||||
dst.size, key_size_bytes);
|
||||
}
|
||||
size_t sz = br_rsa_i31_oaep_encrypt(
|
||||
@ -206,30 +211,27 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// RSADecryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk){
|
||||
ptr->sk = sk;
|
||||
}
|
||||
|
||||
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer){
|
||||
u32 key_size_bits = ptr->sk->n_bitlen;
|
||||
if(src.size != key_size_bits/8){
|
||||
return RESULT_ERROR_FMT("src.size (%u) must be == %u (key length in bytes)",
|
||||
src.size, key_size_bits/8);
|
||||
if(buffer.size != key_size_bits/8){
|
||||
return RESULT_ERROR_FMT("buffer.size (%u) must be == key length in bytes (%u)",
|
||||
buffer.size, key_size_bits/8);
|
||||
}
|
||||
|
||||
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bits, 256);
|
||||
if(dst.size < max_src_size){
|
||||
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (use RSAEncryptor_calcMaxSrcSize)",
|
||||
dst.size, max_src_size);
|
||||
}
|
||||
|
||||
memcpy(dst.data, src.data, src.size);
|
||||
size_t sz = src.size;
|
||||
size_t sz = buffer.size;
|
||||
size_t r = br_rsa_i31_oaep_decrypt(
|
||||
&br_sha256_vtable,
|
||||
NULL, 0,
|
||||
ptr->sk,
|
||||
dst.data, &sz);
|
||||
buffer.data, &sz);
|
||||
|
||||
if(r == 0){
|
||||
return RESULT_ERROR("RSA encryption failed", false);
|
||||
|
||||
@ -58,6 +58,7 @@ str RSA_serializePublicKey_base64(const br_rsa_public_key* sk);
|
||||
/// @param sk out public key. WARNING: .p is allocated on heap
|
||||
Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* sk);
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// RSAEncryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
@ -90,6 +91,7 @@ https://crypto.stackexchange.com/a/42100
|
||||
/// @return size of encrypted data
|
||||
Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst);
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// RSADecryptor //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
@ -102,6 +104,5 @@ typedef struct RSADecryptor {
|
||||
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk);
|
||||
|
||||
/// @param src buffer with size == key size in bytes
|
||||
/// @param dst buffer with size >= `RSAEncryptor_calcMaxSrcSize(key_size_bits, 256)`
|
||||
/// @return size of decrypted data
|
||||
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst);
|
||||
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer);
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
#include "encrypted_sockets.h"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketTCP //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
|
||||
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
|
||||
{
|
||||
@ -16,6 +20,11 @@ void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){
|
||||
free(ptr->send_buf.data);
|
||||
}
|
||||
|
||||
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key){
|
||||
AESStreamEncryptor_changeKey(&ptr->enc, aes_key);
|
||||
AESStreamDecryptor_changeKey(&ptr->dec, aes_key);
|
||||
}
|
||||
|
||||
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
|
||||
Array(u8) buffer)
|
||||
{
|
||||
@ -31,7 +40,7 @@ Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
|
||||
try_void(
|
||||
socket_send(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->send_buf, encrypted_size)
|
||||
Array_sliceTo(ptr->send_buf, encrypted_size)
|
||||
)
|
||||
);
|
||||
|
||||
@ -51,14 +60,14 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
|
||||
try(i32 received_size, i,
|
||||
socket_recv(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->recv_buf, size_to_receive),
|
||||
Array_sliceTo(ptr->recv_buf, size_to_receive),
|
||||
flags
|
||||
)
|
||||
);
|
||||
try(u32 decrypted_size, u,
|
||||
AESStreamDecryptor_decrypt(
|
||||
&ptr->dec,
|
||||
Array_sliceBefore(ptr->recv_buf, received_size),
|
||||
Array_sliceTo(ptr->recv_buf, received_size),
|
||||
buffer
|
||||
)
|
||||
);
|
||||
@ -66,7 +75,80 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
|
||||
Return RESULT_VALUE(u, decrypted_size);
|
||||
}
|
||||
|
||||
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
|
||||
RSAEncryptor* rsa_enc, Array(u8) buffer)
|
||||
{
|
||||
Deferral(1);
|
||||
|
||||
try(u32 encrypted_size, u,
|
||||
RSAEncryptor_encrypt(
|
||||
rsa_enc,
|
||||
buffer,
|
||||
ptr->send_buf
|
||||
)
|
||||
);
|
||||
try_void(
|
||||
socket_send(
|
||||
ptr->sock,
|
||||
Array_sliceTo(ptr->send_buf, encrypted_size)
|
||||
)
|
||||
);
|
||||
|
||||
Return RESULT_VOID;
|
||||
}
|
||||
|
||||
|
||||
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
|
||||
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags)
|
||||
{
|
||||
Deferral(1);
|
||||
|
||||
// RSA encrypts message in block of size KEY_SIZE_BYTES.
|
||||
// SocketRecvFlag_WholeBuffer should be always enabled to receive such blocks.
|
||||
// If this flag is set in `flags` by caller, it means decrypted message size
|
||||
// must be the same as buffer size.
|
||||
bool fill_whole_buffer = (flags & SocketRecvFlag_WholeBuffer) != 0;
|
||||
flags |= SocketRecvFlag_WholeBuffer;
|
||||
u32 size_to_receive = rsa_dec->sk->n_bitlen / 8;
|
||||
|
||||
try(i32 received_size, i,
|
||||
socket_recv(
|
||||
ptr->sock,
|
||||
Array_sliceTo(ptr->recv_buf, size_to_receive),
|
||||
flags
|
||||
)
|
||||
);
|
||||
try(u32 decrypted_size, u,
|
||||
RSADecryptor_decrypt(
|
||||
rsa_dec,
|
||||
Array_sliceTo(ptr->recv_buf, received_size)
|
||||
)
|
||||
);
|
||||
|
||||
if(fill_whole_buffer){
|
||||
if(decrypted_size != buffer.size){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"SocketRecvFlag_WholeBuffer is set, "
|
||||
"but decrypted_size (%u) != buffer.size (%u)",
|
||||
decrypted_size, buffer.size
|
||||
);
|
||||
}
|
||||
}
|
||||
else if(decrypted_size > buffer.size){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"decrypted_size (%u) > buffer.size (%u)",
|
||||
decrypted_size, buffer.size
|
||||
);
|
||||
}
|
||||
|
||||
memcpy(buffer.data, ptr->recv_buf.data, decrypted_size);
|
||||
Return RESULT_VALUE(u, decrypted_size);
|
||||
}
|
||||
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketUDP //
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
|
||||
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
|
||||
@ -84,6 +166,11 @@ void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){
|
||||
free(ptr->send_buf.data);
|
||||
}
|
||||
|
||||
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key){
|
||||
AESBlockEncryptor_changeKey(&ptr->enc, aes_key);
|
||||
AESBlockDecryptor_changeKey(&ptr->dec, aes_key);
|
||||
}
|
||||
|
||||
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
|
||||
Array(u8) buffer, EndpointIPv4 remote_end)
|
||||
{
|
||||
@ -99,7 +186,7 @@ Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
|
||||
try_void(
|
||||
socket_sendto(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->send_buf, encrypted_size),
|
||||
Array_sliceTo(ptr->send_buf, encrypted_size),
|
||||
remote_end
|
||||
)
|
||||
);
|
||||
@ -117,7 +204,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
|
||||
try(i32 received_size, i,
|
||||
socket_recvfrom(
|
||||
ptr->sock,
|
||||
Array_sliceBefore(ptr->recv_buf, size_to_receive),
|
||||
Array_sliceTo(ptr->recv_buf, size_to_receive),
|
||||
flags,
|
||||
remote_end
|
||||
)
|
||||
@ -125,7 +212,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
|
||||
try(u32 decrypted_size, u,
|
||||
AESBlockDecryptor_decrypt(
|
||||
&ptr->dec,
|
||||
Array_sliceBefore(ptr->recv_buf, received_size),
|
||||
Array_sliceTo(ptr->recv_buf, received_size),
|
||||
buffer
|
||||
)
|
||||
);
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
#include "network/socket.h"
|
||||
#include "cryptography/AES.h"
|
||||
#include "cryptography/RSA.h"
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketTCP //
|
||||
@ -20,12 +21,37 @@ void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
|
||||
/// closes the socket
|
||||
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr);
|
||||
|
||||
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key);
|
||||
|
||||
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
|
||||
Array(u8) buffer);
|
||||
|
||||
#define EncryptedSocketTCP_sendStruct(socket, structPtr)\
|
||||
EncryptedSocketTCP_send(socket,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)))
|
||||
|
||||
Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
|
||||
Array(u8) buffer, SocketRecvFlag flags);
|
||||
|
||||
#define EncryptedSocketTCP_recvStruct(socket, structPtr)\
|
||||
EncryptedSocketTCP_recv(socket,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)),\
|
||||
SocketRecvFlag_WholeBuffer)
|
||||
|
||||
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
|
||||
RSAEncryptor* rsa_enc, Array(u8) buffer);
|
||||
|
||||
#define EncryptedSocketTCP_sendStructRSA(socket, rsa_enc, structPtr)\
|
||||
EncryptedSocketTCP_sendRSA(socket, rsa_enc,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)))
|
||||
|
||||
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
|
||||
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags);
|
||||
|
||||
#define EncryptedSocketTCP_recvStructRSA(socket, rsa_dec, structPtr)\
|
||||
EncryptedSocketTCP_recvRSA(socket, rsa_dec,\
|
||||
Array_construct_size(structPtr, sizeof(*structPtr)),\
|
||||
SocketRecvFlag_WholeBuffer)
|
||||
|
||||
//////////////////////////////////////////////////////////////////////////////
|
||||
// EncryptedSocketUDP //
|
||||
@ -45,6 +71,8 @@ void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
|
||||
/// closes the socket
|
||||
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr);
|
||||
|
||||
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key);
|
||||
|
||||
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
|
||||
Array(u8) buffer, EndpointIPv4 remote_end);
|
||||
|
||||
|
||||
@ -86,7 +86,7 @@ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){
|
||||
int f = 0;
|
||||
if (flags & SocketRecvFlag_Peek)
|
||||
f |= MSG_PEEK;
|
||||
if (flags & SocketRecvFlag_WaitAll)
|
||||
if (flags & SocketRecvFlag_WholeBuffer)
|
||||
f |= MSG_WAITALL;
|
||||
return f;
|
||||
}
|
||||
@ -96,7 +96,7 @@ Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
|
||||
if(r < 0){
|
||||
return RESULT_ERROR_SOCKET();
|
||||
}
|
||||
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
|
||||
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
|
||||
{
|
||||
return RESULT_ERROR("Socket closed", false);
|
||||
}
|
||||
@ -111,7 +111,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
|
||||
if(r < 0){
|
||||
return RESULT_ERROR_SOCKET();
|
||||
}
|
||||
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size))
|
||||
if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
|
||||
{
|
||||
return RESULT_ERROR("Socket closed", false);
|
||||
}
|
||||
|
||||
@ -13,7 +13,7 @@ typedef enum SocketShutdownType {
|
||||
typedef enum SocketRecvFlag {
|
||||
SocketRecvFlag_None = 0,
|
||||
SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */,
|
||||
SocketRecvFlag_WaitAll = 0b10 /* waits until buffer is full */,
|
||||
SocketRecvFlag_WholeBuffer = 0b10 /* waits until buffer is full */,
|
||||
} SocketRecvFlag;
|
||||
|
||||
typedef i64 Socket;
|
||||
|
||||
@ -25,74 +25,42 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred
|
||||
conn->client_end = client_end;
|
||||
conn->session_id = session_id;
|
||||
conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE);
|
||||
|
||||
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
|
||||
// fix for valgrind false detected errors about uninitialized memory
|
||||
Array_memset(buffer, 0xCC);
|
||||
Defer(free(buffer.data));
|
||||
|
||||
// correct session key will be received from client later
|
||||
Array_memset(conn->session_key, 0);
|
||||
EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
|
||||
// TODO: set socket timeout to 5 seconds
|
||||
|
||||
// receive message encrypted by server public key
|
||||
u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
|
||||
Array(u8) bufferPart_encryptedClientHandshake = {
|
||||
.data = (u8*)buffer.data + header_and_message_size,
|
||||
.size = server_credentials->rsa_pk.nlen
|
||||
};
|
||||
try_void(
|
||||
socket_recv(
|
||||
sock_tcp,
|
||||
bufferPart_encryptedClientHandshake,
|
||||
SocketRecvFlag_WaitAll
|
||||
)
|
||||
);
|
||||
|
||||
// decrypt the message using server private key
|
||||
// decrypt the rsa messages using server private key
|
||||
RSADecryptor rsa_dec;
|
||||
RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk);
|
||||
try(u32 rsa_dec_size, u,
|
||||
RSADecryptor_decrypt(
|
||||
&rsa_dec,
|
||||
bufferPart_encryptedClientHandshake,
|
||||
buffer
|
||||
)
|
||||
);
|
||||
|
||||
// validate client handshake
|
||||
if(rsa_dec_size != header_and_message_size){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"decrypted message (size: %u) is not a ClientHandshake (size: %u)",
|
||||
rsa_dec_size, header_and_message_size
|
||||
);
|
||||
}
|
||||
PacketHeader* packet_header = buffer.data;
|
||||
ClientHandshake* client_handshake = Array_sliceAfter(buffer, sizeof(PacketHeader)).data;
|
||||
try_void(PacketHeader_validateMagic(packet_header));
|
||||
if(packet_header->type != PacketType_ClientHandshake){
|
||||
// receive PacketHeader
|
||||
PacketHeader packet_header;
|
||||
try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &packet_header));
|
||||
try_void(PacketHeader_validateMagic(&packet_header));
|
||||
if(packet_header.type != PacketType_ClientHandshake){
|
||||
Return RESULT_ERROR_FMT(
|
||||
"received message of unexpected type: %u",
|
||||
packet_header->type
|
||||
packet_header.type
|
||||
);
|
||||
}
|
||||
|
||||
// use received session key
|
||||
memcpy(conn->session_key.data, client_handshake->session_key, conn->session_key.size);
|
||||
EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
|
||||
// receive ClientHandshake
|
||||
ClientHandshake client_handshake;
|
||||
try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &client_handshake));
|
||||
|
||||
// construct PacketHeader and ServerHandshake in buffer
|
||||
PacketHeader_construct(buffer.data, PROTOCOL_VERSION,
|
||||
PacketType_ServerHandshake, sizeof(ServerHandshake));
|
||||
ServerHandshake_construct(
|
||||
Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
|
||||
// use received session key
|
||||
memcpy(conn->session_key.data, client_handshake.session_key, conn->session_key.size);
|
||||
EncryptedSocketTCP_changeKey(&conn->sock, conn->session_key);
|
||||
|
||||
// send PacketHeader and ServerHandshake over encrypted TCP socket
|
||||
PacketHeader_construct(&packet_header,
|
||||
PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(ServerHandshake));
|
||||
ServerHandshake server_handshake;
|
||||
ServerHandshake_construct(&server_handshake,
|
||||
session_id);
|
||||
// send ServerHandshake over encrypted TCP socket
|
||||
header_and_message_size = sizeof(PacketHeader) + sizeof(ServerHandshake);
|
||||
try_void(
|
||||
EncryptedSocketTCP_send(
|
||||
&conn->sock,
|
||||
Array_sliceBefore(buffer, header_and_message_size)
|
||||
)
|
||||
);
|
||||
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &packet_header));
|
||||
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &server_handshake));
|
||||
|
||||
success = true;
|
||||
Return RESULT_VALUE(p, conn);
|
||||
|
||||
@ -72,7 +72,7 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){
|
||||
//TODO: use async IO instead of threads to not waste system resources
|
||||
// while waiting for incoming data in 100500 threads
|
||||
try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args));
|
||||
try_stderrcode(pthread_detach(&conn_thread));
|
||||
try_stderrcode(pthread_detach(conn_thread));
|
||||
}
|
||||
|
||||
Return RESULT_VOID;
|
||||
@ -113,15 +113,9 @@ static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_
|
||||
);
|
||||
logInfo(log_ctx, "session accepted");
|
||||
|
||||
// handle requests
|
||||
|
||||
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE);
|
||||
// fix for valgrind false detected errors about uninitialized memory
|
||||
Array_memset(buffer, 0xCC);
|
||||
Defer(free(buffer.data));
|
||||
u32 dec_size = 0;
|
||||
|
||||
// handle unauthorized requests
|
||||
while(true){
|
||||
|
||||
sleepMsec(10);
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user