Compare commits

..

4 Commits

21 changed files with 644 additions and 207 deletions

2
dependencies/tlibc vendored

@ -1 +1 @@
Subproject commit 00a1a29d342c6b3bcb8015a661364e10a3134fe4 Subproject commit d6436d08338a0a762e727f0c816dd5a09782b180

View File

@ -63,49 +63,27 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
// connect to server address // connect to server address
try(Socket _s, i, socket_open_TCP()); try(Socket _s, i, socket_open_TCP());
// TODO: set socket timeout to 5 seconds try_void(socket_TCP_enableAliveChecks_default(_s));
try_void(socket_connect(_s, conn->server_end)); try_void(socket_connect(_s, conn->server_end));
EncryptedSocketTCP_construct(&conn->sock, _s, NETWORK_BUFFER_SIZE, conn->session_key); EncryptedSocketTCP_construct(&conn->sock, _s, NETWORK_BUFFER_SIZE, conn->session_key);
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE); // send PacketHeader and ClientHandshake
// fix for valgrind false detected errors about uninitialized memory // encryption by server public key
Array_memset(buffer, 0xCC); PacketHeader packet_header = {0};
Defer(free(buffer.data)); ClientHandshake client_handshake = {0};
try_void(ClientHandshake_tryConstruct(&client_handshake, &packet_header,
// construct PacketHeader and ClientHandshake in buffer conn->session_key));
PacketHeader_construct(buffer.data, PROTOCOL_VERSION, try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &packet_header));
PacketType_ClientHandshake, sizeof(ClientHandshake)); try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &client_handshake));
ClientHandshake_construct(
Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
conn->session_key);
u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
// encrypt message by server public key
Array(u8) bufferPart_encryptedClientHandshake = Array_sliceAfter(buffer, header_and_message_size);
try(u32 rsa_enc_size, u,
RSAEncryptor_encrypt(
&conn->rsa_enc,
Array_sliceBefore(buffer, header_and_message_size),
bufferPart_encryptedClientHandshake
)
);
bufferPart_encryptedClientHandshake.size = rsa_enc_size;
// send encrypted message
try_void(socket_send(conn->sock.sock, bufferPart_encryptedClientHandshake));
// receive server response // receive server response
try_void( try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &packet_header));
EncryptedSocketTCP_recv(&conn->sock, try_void(PacketHeader_validateMagic(&packet_header));
Array_sliceBefore(buffer, sizeof(PacketHeader)),
SocketRecvFlag_WaitAll
)
);
PacketHeader* packet_header = buffer.data;
try_void(PacketHeader_validateMagic(packet_header));
// handle server response // handle server response
switch(packet_header->type){ switch(packet_header.type){
case PacketType_ErrorMessage: { case PacketType_ErrorMessage: {
u32 err_msg_size = packet_header->content_size; u32 err_msg_size = packet_header.content_size;
if(err_msg_size > conn->sock.recv_buf.size) if(err_msg_size > conn->sock.recv_buf.size)
err_msg_size = conn->sock.recv_buf.size; err_msg_size = conn->sock.recv_buf.size;
Array(u8) err_buf = Array_alloc_size(err_msg_size + 1); Array(u8) err_buf = Array_alloc_size(err_msg_size + 1);
@ -119,8 +97,8 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
try_void( try_void(
EncryptedSocketTCP_recv( EncryptedSocketTCP_recv(
&conn->sock, &conn->sock,
Array_sliceBefore(err_buf, err_msg_size), Array_sliceTo(err_buf, err_msg_size),
SocketRecvFlag_WaitAll SocketRecvFlag_WholeBuffer
) )
); );
@ -129,23 +107,13 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
Return RESULT_ERROR((char*)err_buf.data, true); Return RESULT_ERROR((char*)err_buf.data, true);
} }
case PacketType_ServerHandshake: { case PacketType_ServerHandshake: {
Array(u8) bufferPart_ServerHandshake = { ServerHandshake server_handshake;
.data = (u8*)buffer.data + sizeof(PacketHeader), try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &server_handshake));
.size = sizeof(ServerHandshake) conn->session_id = server_handshake.session_id;
};
try_void(
EncryptedSocketTCP_recv(
&conn->sock,
bufferPart_ServerHandshake,
SocketRecvFlag_WaitAll
)
);
ServerHandshake* server_handshake = bufferPart_ServerHandshake.data;
conn->session_id = server_handshake->session_id;
break; break;
} }
default: default:
Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header->type); Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header.type);
} }
success = true; success = true;

View File

@ -1,5 +1,7 @@
#include "client.h" #include "client.h"
#include "term.h" #include "term.h"
#include "tlibc/time.h"
#include "network/tcp-chat-protocol/v1.h"
static const str greeting_art = STR( static const str greeting_art = STR(
" ^,,^ |\n" " ^,,^ |\n"
@ -23,21 +25,22 @@ static Result(void) commandExec(str command, bool* stop);
static Result(void) askUserNameAndPassword(ClientCredentials** cred){ static Result(void) askUserNameAndPassword(ClientCredentials** cred){
Deferral(8); Deferral(8);
char username_buf[1024]; char username_buf[128];
str usrername = str_null; str username = str_null;
while(true) { while(true) {
printf("username: "); printf("username: ");
if(fgets(username_buf, sizeof(username_buf), stdin) == NULL){ if(fgets(username_buf, sizeof(username_buf), stdin) == NULL){
Return RESULT_ERROR("STDIN is closed", false); Return RESULT_ERROR("STDIN is closed", false);
} }
usrername = str_from_cstr(username_buf); username = str_from_cstr(username_buf);
if(usrername.size < 4){ if(username.size < USERNAME_SIZE_MIN || username.size > USERNAME_SIZE_MAX){
printf("ERROR: username length must be at least 4\n"); printf("ERROR: username length (in bytes) must be >= %i and <= %i\n",
USERNAME_SIZE_MIN, USERNAME_SIZE_MAX);
} }
else break; else break;
} }
char password_buf[1024]; char password_buf[128];
str password = str_null; str password = str_null;
while(true) { while(true) {
printf("password: "); printf("password: ");
@ -46,13 +49,14 @@ static Result(void) askUserNameAndPassword(ClientCredentials** cred){
Return RESULT_ERROR("STDIN is closed", false); Return RESULT_ERROR("STDIN is closed", false);
} }
password = str_from_cstr(password_buf); password = str_from_cstr(password_buf);
if(password.size < 8){ if(password.size < PASSWORD_SIZE_MIN || password.size > PASSWORD_SIZE_MAX){
printf("ERROR: password length must be at least 8\n"); printf("ERROR: password length (in bytes) must be >= %i and <= %i\n",
PASSWORD_SIZE_MIN, PASSWORD_SIZE_MAX);
} }
else break; else break;
} }
try(*cred, p, ClientCredentials_create(usrername, password)); try(*cred, p, ClientCredentials_create(username, password));
Return RESULT_VOID; Return RESULT_VOID;
} }
@ -86,7 +90,7 @@ Result(void) client_run() {
if(command_input.size == 0) if(command_input.size == 0)
continue; continue;
Result(void) com_result = commandExec(command_input, &stop); ResultVar(void) com_result = commandExec(command_input, &stop);
if(com_result.error){ if(com_result.error){
str e_str = Error_toStr(com_result.error); str e_str = Error_toStr(com_result.error);
printf("%s\n", e_str.data); printf("%s\n", e_str.data);
@ -138,11 +142,21 @@ static Result(void) commandExec(str command, bool* stop){
// TODO: request server info // TODO: request server info
// show server info // show server info
// save server info to user's db // save server info to user's db
// request log in // try log in
// if not registered, request registration and then log in // if not registered, request registration and then log in
// call serverConnection_run():
// function with infinite loop which sends and receives messages
// with navigation across server channels
//
} }
else if(is_alias("c") || is_alias("connect")){ else if(is_alias("c") || is_alias("connect")){
// TODO: read saved servers from database // TODO: read saved servers from database
// show scrollable list of them
// select one
// try log in
// if not registered, ask user if they want to register
// regiser and then log in
} }
else { else {
Return RESULT_ERROR_FMT("unknown kommand: '%s'\n" Return RESULT_ERROR_FMT("unknown kommand: '%s'\n"

View File

@ -4,18 +4,23 @@
// write data from src to array and increment array data pointer // write data from src to array and increment array data pointer
static inline void __Array_writeNext(Array(u8)* dst, u8* src, size_t size){ static inline void __Array_writeNext(Array(u8)* dst, u8* src, size_t size){
memcpy(dst->data, src, size); memcpy(dst->data, src, size);
*dst = Array_sliceAfter(*dst, size); *dst = Array_sliceFrom(*dst, size);
} }
// read data from array to dst and increment array data pointer // read data from array to dst and increment array data pointer
static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){ static inline void __Array_readNext(u8* dst, Array(u8)* src, size_t size){
memcpy(dst, src->data, size); memcpy(dst, src->data, size);
*src = Array_sliceAfter(*src, size); *src = Array_sliceFrom(*src, size);
} }
//////////////////////////////////////////////////////////////////////////////
// AESBlockEncryptor //
//////////////////////////////////////////////////////////////////////////////
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class){ void AESBlockEncryptor_construct(AESBlockEncryptor* ptr,
Array(u8) key, const br_block_cbcenc_class* enc_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32); assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->enc_class = enc_class; ptr->enc_class = enc_class;
@ -25,7 +30,15 @@ void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br
rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable); rng_init_sha256_seedFromSystem(&ptr->rng_ctx.vtable);
} }
Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Array(u8) dst){ void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->enc_class->init((void*)ptr->enc_keys, key.data, key.size);
}
Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4); Deferral(4);
u32 encrypted_size = AESBlockEncryptor_calcDstSize(src.size); u32 encrypted_size = AESBlockEncryptor_calcDstSize(src.size);
try_assert(dst.size >= encrypted_size); try_assert(dst.size >= encrypted_size);
@ -67,15 +80,28 @@ Result(u32) AESBlockEncryptor_encrypt(AESBlockEncryptor* ptr, Array(u8) src, Arr
} }
//////////////////////////////////////////////////////////////////////////////
// AESBlockDecryptor //
//////////////////////////////////////////////////////////////////////////////
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class){ void AESBlockDecryptor_construct(AESBlockDecryptor* ptr,
Array(u8) key, const br_block_cbcdec_class* dec_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32); assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->dec_class = dec_class; ptr->dec_class = dec_class;
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size); ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
} }
Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Array(u8) dst){ void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->dec_class->init((void*)ptr->dec_keys, key.data, key.size);
}
Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4); Deferral(4);
try_assert(src.size >= AESBlockEncryptor_calcDstSize(0)); try_assert(src.size >= AESBlockEncryptor_calcDstSize(0));
try_assert(src.size % 16 == 0 && "src must be array of 16-byte blocks"); try_assert(src.size % 16 == 0 && "src must be array of 16-byte blocks");
@ -116,8 +142,13 @@ Result(u32) AESBlockDecryptor_decrypt(AESBlockDecryptor* ptr, Array(u8) src, Arr
} }
//////////////////////////////////////////////////////////////////////////////
// AESStreamEncryptor //
//////////////////////////////////////////////////////////////////////////////
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){ void AESStreamEncryptor_construct(AESStreamEncryptor* ptr,
Array(u8) key, const br_block_ctr_class* ctr_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32); assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class = ctr_class; ptr->ctr_class = ctr_class;
@ -131,7 +162,15 @@ void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const
ptr->block_counter = 0; ptr->block_counter = 0;
} }
Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, Array(u8) dst){ void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
}
Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4); Deferral(4);
u32 encrypted_size = AESStreamEncryptor_calcDstSize(src.size); u32 encrypted_size = AESStreamEncryptor_calcDstSize(src.size);
try_assert(dst.size >= encrypted_size); try_assert(dst.size >= encrypted_size);
@ -164,8 +203,13 @@ Result(u32) AESStreamEncryptor_encrypt(AESStreamEncryptor* ptr, Array(u8) src, A
} }
//////////////////////////////////////////////////////////////////////////////
// AESStreamDecryptor //
//////////////////////////////////////////////////////////////////////////////
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class){ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr,
Array(u8) key, const br_block_ctr_class* ctr_class)
{
assert(key.size == 16 || key.size == 24 || key.size == 32); assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class = ctr_class; ptr->ctr_class = ctr_class;
@ -174,7 +218,15 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const
ptr->block_counter = 0; ptr->block_counter = 0;
} }
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){ void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key)
{
assert(key.size == 16 || key.size == 24 || key.size == 32);
ptr->ctr_class->init((void*)ptr->ctr_keys, key.data, key.size);
}
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr,
Array(u8) src, Array(u8) dst)
{
Deferral(4); Deferral(4);
// if it is the beginning of the stream, read IV // if it is the beginning of the stream, read IV

View File

@ -18,7 +18,7 @@
//TODO: use PKS#7 instead of this garbage //TODO: use PKS#7 instead of this garbage
typedef struct EncryptedBlockHeader { typedef struct EncryptedBlockHeader {
u8 padding_size; u8 padding_size;
} __attribute__((aligned(16))) EncryptedBlockHeader; } ATTRIBUTE_ALIGNED(16) EncryptedBlockHeader;
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// AESBlockEncryptor // // AESBlockEncryptor //
@ -40,6 +40,9 @@ typedef struct AESBlockEncryptor {
/// @param enc_class &br_aes_XXX_cbcenc_vtable /// @param enc_class &br_aes_XXX_cbcenc_vtable
void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class); void AESBlockEncryptor_construct(AESBlockEncryptor* ptr, Array(u8) key, const br_block_cbcenc_class* enc_class);
/// @param key supported sizes: 16, 24, 32
void AESBlockEncryptor_changeKey(AESBlockEncryptor* ptr, Array(u8) key);
/// @brief Encrypts a complete message. For part-by-part encryption use AESStreamEncryptor. /// @brief Encrypts a complete message. For part-by-part encryption use AESStreamEncryptor.
/// @param src array of any size /// @param src array of any size
/// @param dst array of size >= AESBlockEncryptor_calcDstSize(src.size) /// @param dst array of size >= AESBlockEncryptor_calcDstSize(src.size)
@ -63,6 +66,9 @@ typedef struct AESBlockDecryptor {
/// @param dec_class &br_aes_XXX_cbcdec_vtable /// @param dec_class &br_aes_XXX_cbcdec_vtable
void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class); void AESBlockDecryptor_construct(AESBlockDecryptor* ptr, Array(u8) key, const br_block_cbcdec_class* dec_class);
/// @param key supported sizes: 16, 24, 32
void AESBlockDecryptor_changeKey(AESBlockDecryptor* ptr, Array(u8) key);
/// @brief Decrypts a complete message. For part-by-part decryption use AESStreamEncryptor. /// @brief Decrypts a complete message. For part-by-part decryption use AESStreamEncryptor.
/// @param src array of size at least AESBlockEncryptor_calcDstSize(0). Size must be multiple of 16. /// @param src array of size at least AESBlockEncryptor_calcDstSize(0). Size must be multiple of 16.
/// @param dst array of size >= src.size /// @param dst array of size >= src.size
@ -88,6 +94,9 @@ typedef struct AESStreamEncryptor {
/// @param dec_class &br_aes_XXX_ctr_vtable /// @param dec_class &br_aes_XXX_ctr_vtable
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class); void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
/// @param key supported sizes: 16, 24, 32
void AESStreamEncryptor_changeKey(AESStreamEncryptor* ptr, Array(u8) key);
/// use this only at the beginning of the stream /// use this only at the beginning of the stream
#define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_STREAM_IV_SIZE) #define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_STREAM_IV_SIZE)
@ -114,6 +123,9 @@ typedef struct AESStreamDecryptor {
/// @param dec_class &br_aes_XXX_ctr_vtable /// @param dec_class &br_aes_XXX_ctr_vtable
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class); void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
/// @param key supported sizes: 16, 24, 32
void AESStreamDecryptor_changeKey(AESStreamDecryptor* ptr, Array(u8) key);
/// @brief Reads IV from `src`, then decrypts data and writes it to dst /// @brief Reads IV from `src`, then decrypts data and writes it to dst
/// @param src array of any size /// @param src array of any size
/// @param dst array of size >= src.size /// @param dst array of size >= src.size

View File

@ -175,6 +175,11 @@ Result(void) RSA_parsePrivateKey_base64(cstr src, br_rsa_private_key* sk){
Return RESULT_VOID; Return RESULT_VOID;
} }
//////////////////////////////////////////////////////////////////////////////
// RSAEncryptor //
//////////////////////////////////////////////////////////////////////////////
void RSAEncryptor_construct(RSAEncryptor* ptr, const br_rsa_public_key* pk){ void RSAEncryptor_construct(RSAEncryptor* ptr, const br_rsa_public_key* pk){
ptr->pk = pk; ptr->pk = pk;
ptr->rng.vtable = &br_hmac_drbg_vtable; ptr->rng.vtable = &br_hmac_drbg_vtable;
@ -185,11 +190,11 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst
u32 key_size_bytes = ptr->pk->nlen; u32 key_size_bytes = ptr->pk->nlen;
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bytes * 8, 256); const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bytes * 8, 256);
if(src.size > max_src_size){ if(src.size > max_src_size){
return RESULT_ERROR_FMT("src.size (%u) must be <= %u (use RSAEncryptor_calcMaxSrcSize)", return RESULT_ERROR_FMT("src.size (%u) must be <= RSAEncryptor_calcMaxSrcSize() (%u)",
src.size, max_src_size); src.size, max_src_size);
} }
if(dst.size < key_size_bytes){ if(dst.size < key_size_bytes){
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (key length in bytes)", return RESULT_ERROR_FMT("dst.size (%u) must be >= key length in bytes (%u)",
dst.size, key_size_bytes); dst.size, key_size_bytes);
} }
size_t sz = br_rsa_i31_oaep_encrypt( size_t sz = br_rsa_i31_oaep_encrypt(
@ -206,30 +211,27 @@ Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst
} }
//////////////////////////////////////////////////////////////////////////////
// RSADecryptor //
//////////////////////////////////////////////////////////////////////////////
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk){ void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk){
ptr->sk = sk; ptr->sk = sk;
} }
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst){ Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer){
u32 key_size_bits = ptr->sk->n_bitlen; u32 key_size_bits = ptr->sk->n_bitlen;
if(src.size != key_size_bits/8){ if(buffer.size != key_size_bits/8){
return RESULT_ERROR_FMT("src.size (%u) must be == %u (key length in bytes)", return RESULT_ERROR_FMT("buffer.size (%u) must be == key length in bytes (%u)",
src.size, key_size_bits/8); buffer.size, key_size_bits/8);
} }
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bits, 256); size_t sz = buffer.size;
if(dst.size < max_src_size){
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (use RSAEncryptor_calcMaxSrcSize)",
dst.size, max_src_size);
}
memcpy(dst.data, src.data, src.size);
size_t sz = src.size;
size_t r = br_rsa_i31_oaep_decrypt( size_t r = br_rsa_i31_oaep_decrypt(
&br_sha256_vtable, &br_sha256_vtable,
NULL, 0, NULL, 0,
ptr->sk, ptr->sk,
dst.data, &sz); buffer.data, &sz);
if(r == 0){ if(r == 0){
return RESULT_ERROR("RSA encryption failed", false); return RESULT_ERROR("RSA encryption failed", false);

View File

@ -58,6 +58,7 @@ str RSA_serializePublicKey_base64(const br_rsa_public_key* sk);
/// @param sk out public key. WARNING: .p is allocated on heap /// @param sk out public key. WARNING: .p is allocated on heap
Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* sk); Result(void) RSA_parsePublicKey_base64(cstr src, br_rsa_public_key* sk);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// RSAEncryptor // // RSAEncryptor //
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -90,6 +91,7 @@ https://crypto.stackexchange.com/a/42100
/// @return size of encrypted data /// @return size of encrypted data
Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst); Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst);
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// RSADecryptor // // RSADecryptor //
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
@ -102,6 +104,5 @@ typedef struct RSADecryptor {
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk); void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk);
/// @param src buffer with size == key size in bytes /// @param src buffer with size == key size in bytes
/// @param dst buffer with size >= `RSAEncryptor_calcMaxSrcSize(key_size_bits, 256)`
/// @return size of decrypted data /// @return size of decrypted data
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst); Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buffer);

View File

@ -9,7 +9,7 @@ typedef struct TableFileHeader {
u16 version; u16 version;
bool _dirty_bit; bool _dirty_bit;
u32 row_size; u32 row_size;
} __attribute__((aligned(256))) TableFileHeader; } ATTRIBUTE_ALIGNED(256) TableFileHeader;
typedef struct Table { typedef struct Table {
TableFileHeader header; TableFileHeader header;
@ -133,7 +133,7 @@ Result(void) Table_validateHeader(Table* t){
Result(void) Table_validateRowSize(Table* t, u32 row_size){ Result(void) Table_validateRowSize(Table* t, u32 row_size){
if(row_size != t->header.row_size){ if(row_size != t->header.row_size){
Result(void) error_result = RESULT_ERROR_FMT( ResultVar(void) error_result = RESULT_ERROR_FMT(
"Requested row size (%u) doesn't match saved row size (%u)", "Requested row size (%u) doesn't match saved row size (%u)",
row_size, t->header.row_size); row_size, t->header.row_size);
return error_result; return error_result;
@ -226,7 +226,7 @@ Result(Table*) idb_getOrCreateTable(IncrementalDB* db, str _table_name, u32 row_
} }
if(!HashMap_tryPush(&db->tables_map, t->name, &t)){ if(!HashMap_tryPush(&db->tables_map, t->name, &t)){
Result(void) error_result = RESULT_ERROR_FMT( ResultVar(void) error_result = RESULT_ERROR_FMT(
"Table '%s' is already open", "Table '%s' is already open",
t->name.data); t->name.data);
Return error_result; Return error_result;
@ -273,7 +273,7 @@ Result(void) idb_updateRows(Table* t, u64 id, const void* src, u64 count){
} }
try_void(Table_setDirtyBit(t, true)); try_void(Table_setDirtyBit(t, true));
Defer(Table_setDirtyBit(t, false)); Defer(IGNORE_RESULT Table_setDirtyBit(t, false));
i64 file_pos = sizeof(t->header) + id * t->header.row_size; i64 file_pos = sizeof(t->header) + id * t->header.row_size;
@ -295,7 +295,7 @@ Result(u64) idb_pushRows(Table* t, const void* src, u64 count){
Defer(pthread_mutex_unlock(&t->mutex)); Defer(pthread_mutex_unlock(&t->mutex));
try_void(Table_setDirtyBit(t, true)); try_void(Table_setDirtyBit(t, true));
Defer(Table_setDirtyBit(t, false)); Defer(IGNORE_RESULT Table_setDirtyBit(t, false));
const u64 new_row_index = t->row_count; const u64 new_row_index = t->row_count;

View File

@ -1,5 +1,9 @@
#include "encrypted_sockets.h" #include "encrypted_sockets.h"
//////////////////////////////////////////////////////////////////////////////
// EncryptedSocketTCP //
//////////////////////////////////////////////////////////////////////////////
void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr, void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key) Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
{ {
@ -16,6 +20,11 @@ void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr){
free(ptr->send_buf.data); free(ptr->send_buf.data);
} }
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key){
AESStreamEncryptor_changeKey(&ptr->enc, aes_key);
AESStreamDecryptor_changeKey(&ptr->dec, aes_key);
}
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr, Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) buffer) Array(u8) buffer)
{ {
@ -31,7 +40,7 @@ Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
try_void( try_void(
socket_send( socket_send(
ptr->sock, ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size) Array_sliceTo(ptr->send_buf, encrypted_size)
) )
); );
@ -51,14 +60,14 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
try(i32 received_size, i, try(i32 received_size, i,
socket_recv( socket_recv(
ptr->sock, ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive), Array_sliceTo(ptr->recv_buf, size_to_receive),
flags flags
) )
); );
try(u32 decrypted_size, u, try(u32 decrypted_size, u,
AESStreamDecryptor_decrypt( AESStreamDecryptor_decrypt(
&ptr->dec, &ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size), Array_sliceTo(ptr->recv_buf, received_size),
buffer buffer
) )
); );
@ -66,7 +75,80 @@ Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Return RESULT_VALUE(u, decrypted_size); Return RESULT_VALUE(u, decrypted_size);
} }
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
RSAEncryptor* rsa_enc, Array(u8) buffer)
{
Deferral(1);
try(u32 encrypted_size, u,
RSAEncryptor_encrypt(
rsa_enc,
buffer,
ptr->send_buf
)
);
try_void(
socket_send(
ptr->sock,
Array_sliceTo(ptr->send_buf, encrypted_size)
)
);
Return RESULT_VOID;
}
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags)
{
Deferral(1);
// RSA encrypts message in block of size KEY_SIZE_BYTES.
// SocketRecvFlag_WholeBuffer should be always enabled to receive such blocks.
// If this flag is set in `flags` by caller, it means decrypted message size
// must be the same as buffer size.
bool fill_whole_buffer = (flags & SocketRecvFlag_WholeBuffer) != 0;
flags |= SocketRecvFlag_WholeBuffer;
u32 size_to_receive = rsa_dec->sk->n_bitlen / 8;
try(i32 received_size, i,
socket_recv(
ptr->sock,
Array_sliceTo(ptr->recv_buf, size_to_receive),
flags
)
);
try(u32 decrypted_size, u,
RSADecryptor_decrypt(
rsa_dec,
Array_sliceTo(ptr->recv_buf, received_size)
)
);
if(fill_whole_buffer){
if(decrypted_size != buffer.size){
Return RESULT_ERROR_FMT(
"SocketRecvFlag_WholeBuffer is set, "
"but decrypted_size (%u) != buffer.size (%u)",
decrypted_size, buffer.size
);
}
}
else if(decrypted_size > buffer.size){
Return RESULT_ERROR_FMT(
"decrypted_size (%u) > buffer.size (%u)",
decrypted_size, buffer.size
);
}
memcpy(buffer.data, ptr->recv_buf.data, decrypted_size);
Return RESULT_VALUE(u, decrypted_size);
}
//////////////////////////////////////////////////////////////////////////////
// EncryptedSocketUDP //
//////////////////////////////////////////////////////////////////////////////
void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr, void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
Socket sock, u32 crypto_buffer_size, Array(u8) aes_key) Socket sock, u32 crypto_buffer_size, Array(u8) aes_key)
@ -84,6 +166,11 @@ void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr){
free(ptr->send_buf.data); free(ptr->send_buf.data);
} }
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key){
AESBlockEncryptor_changeKey(&ptr->enc, aes_key);
AESBlockDecryptor_changeKey(&ptr->dec, aes_key);
}
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr, Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) buffer, EndpointIPv4 remote_end) Array(u8) buffer, EndpointIPv4 remote_end)
{ {
@ -99,7 +186,7 @@ Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
try_void( try_void(
socket_sendto( socket_sendto(
ptr->sock, ptr->sock,
Array_sliceBefore(ptr->send_buf, encrypted_size), Array_sliceTo(ptr->send_buf, encrypted_size),
remote_end remote_end
) )
); );
@ -117,7 +204,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
try(i32 received_size, i, try(i32 received_size, i,
socket_recvfrom( socket_recvfrom(
ptr->sock, ptr->sock,
Array_sliceBefore(ptr->recv_buf, size_to_receive), Array_sliceTo(ptr->recv_buf, size_to_receive),
flags, flags,
remote_end remote_end
) )
@ -125,7 +212,7 @@ Result(i32) EncryptedSocketUDP_recvfrom(EncryptedSocketUDP* ptr,
try(u32 decrypted_size, u, try(u32 decrypted_size, u,
AESBlockDecryptor_decrypt( AESBlockDecryptor_decrypt(
&ptr->dec, &ptr->dec,
Array_sliceBefore(ptr->recv_buf, received_size), Array_sliceTo(ptr->recv_buf, received_size),
buffer buffer
) )
); );

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include "network/socket.h" #include "network/socket.h"
#include "cryptography/AES.h" #include "cryptography/AES.h"
#include "cryptography/RSA.h"
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// EncryptedSocketTCP // // EncryptedSocketTCP //
@ -20,12 +21,37 @@ void EncryptedSocketTCP_construct(EncryptedSocketTCP* ptr,
/// closes the socket /// closes the socket
void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr); void EncryptedSocketTCP_destroy(EncryptedSocketTCP* ptr);
void EncryptedSocketTCP_changeKey(EncryptedSocketTCP* ptr, Array(u8) aes_key);
Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr, Result(void) EncryptedSocketTCP_send(EncryptedSocketTCP* ptr,
Array(u8) buffer); Array(u8) buffer);
#define EncryptedSocketTCP_sendStruct(socket, structPtr)\
EncryptedSocketTCP_send(socket,\
Array_construct_size(structPtr, sizeof(*structPtr)))
Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr, Result(u32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
Array(u8) buffer, SocketRecvFlag flags); Array(u8) buffer, SocketRecvFlag flags);
#define EncryptedSocketTCP_recvStruct(socket, structPtr)\
EncryptedSocketTCP_recv(socket,\
Array_construct_size(structPtr, sizeof(*structPtr)),\
SocketRecvFlag_WholeBuffer)
Result(void) EncryptedSocketTCP_sendRSA(EncryptedSocketTCP* ptr,
RSAEncryptor* rsa_enc, Array(u8) buffer);
#define EncryptedSocketTCP_sendStructRSA(socket, rsa_enc, structPtr)\
EncryptedSocketTCP_sendRSA(socket, rsa_enc,\
Array_construct_size(structPtr, sizeof(*structPtr)))
Result(u32) EncryptedSocketTCP_recvRSA(EncryptedSocketTCP* ptr,
RSADecryptor* rsa_dec, Array(u8) buffer, SocketRecvFlag flags);
#define EncryptedSocketTCP_recvStructRSA(socket, rsa_dec, structPtr)\
EncryptedSocketTCP_recvRSA(socket, rsa_dec,\
Array_construct_size(structPtr, sizeof(*structPtr)),\
SocketRecvFlag_WholeBuffer)
////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////
// EncryptedSocketUDP // // EncryptedSocketUDP //
@ -45,6 +71,8 @@ void EncryptedSocketUDP_construct(EncryptedSocketUDP* ptr,
/// closes the socket /// closes the socket
void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr); void EncryptedSocketUDP_destroy(EncryptedSocketUDP* ptr);
void EncryptedSocketUDP_changeKey(EncryptedSocketUDP* ptr, Array(u8) aes_key);
Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr, Result(void) EncryptedSocketUDP_sendto(EncryptedSocketUDP* ptr,
Array(u8) buffer, EndpointIPv4 remote_end); Array(u8) buffer, EndpointIPv4 remote_end);

View File

@ -10,20 +10,29 @@
#endif #endif
#endif #endif
// include OS-dependent socket headers
#if KN_USE_WINSOCK #if KN_USE_WINSOCK
#include <winsock2.h> #include <winsock2.h>
#include <ws2ipdef.h>
// There you can see what error codes mean. // There you can see what error codes mean.
#include <winerror.h> #include <winerror.h>
#define RESULT_ERROR_SOCKET() RESULT_ERROR(sprintf_malloc(64, "Winsock error %i", WSAGetLastError()), true)
#else #else
#include <sys/types.h> #include <sys/types.h>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h> #include <netinet/in.h>
#include <arpa/inet.h> #include <arpa/inet.h>
#include <netdb.h> #include <netdb.h>
#include <unistd.h> #include <unistd.h>
#define RESULT_ERROR_SOCKET() RESULT_ERROR(strerror(errno), false) #endif
#if KN_USE_WINSOCK
#define RESULT_ERROR_SOCKET()\
RESULT_ERROR(sprintf_malloc(64, "Winsock error %i (look in <winerror.h>)", WSAGetLastError()), true);
#else
#define RESULT_ERROR_SOCKET()\
RESULT_ERROR(strerror(errno), false);
#endif #endif
struct sockaddr_in EndpointIPv4_toSockaddr(EndpointIPv4 end); struct sockaddr_in EndpointIPv4_toSockaddr(EndpointIPv4 end);

View File

@ -13,13 +13,9 @@ Result(void) network_init(){
return RESULT_VOID; return RESULT_VOID;
} }
Result(void) network_deinit(){ void network_deinit(){
#if _WIN32 #if _WIN32
// Deinitialize Winsock // Deinitialize Winsock
int result = WSACleanup(); (void)WSACleanup();
if (result != 0) {
return RESULT_ERROR_FMT("WSACleanup failed with error code 0x%X", result);
}
#endif #endif
return RESULT_VOID;
} }

View File

@ -2,4 +2,4 @@
#include "tlibc/errors.h" #include "tlibc/errors.h"
Result(void) network_init(); Result(void) network_init();
Result(void) network_deinit(); void network_deinit();

View File

@ -3,7 +3,7 @@
#include <assert.h> #include <assert.h>
Result(Socket) socket_open_TCP(){ Result(Socket) socket_open_TCP(){
Socket s = socket(AF_INET, SOCK_STREAM, 0); Socket s = socket(AF_INET, SOCK_STREAM, IPPROTO_TCP);
if(s == -1){ if(s == -1){
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
} }
@ -20,7 +20,7 @@ void socket_close(Socket s){
} }
Result(void) socket_shutdown(Socket s, SocketShutdownType direction){ Result(void) socket_shutdown(Socket s, SocketShutdownType direction){
if(shutdown(s, (int)direction) == -1) if(shutdown(s, (int)direction) != 0)
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
return RESULT_VOID; return RESULT_VOID;
} }
@ -38,18 +38,18 @@ Result(void) socket_listen(Socket s, i32 backlog){
return RESULT_VOID; return RESULT_VOID;
} }
Result(Socket) socket_accept(Socket main_socket, NULLABLE(EndpointIPv4*) remote_end) { Result(Socket) socket_accept(Socket listening_sock, NULLABLE(EndpointIPv4*) remote_end) {
struct sockaddr_in remote_addr = {0}; struct sockaddr_in remote_addr = {0};
i32 sockaddr_size = sizeof(remote_addr); i32 sockaddr_size = sizeof(remote_addr);
Socket user_connection = accept(main_socket, (void*)&remote_addr, (void*)&sockaddr_size); Socket accepted_sock = accept(listening_sock, (void*)&remote_addr, (void*)&sockaddr_size);
if(user_connection == -1) if(accepted_sock == -1)
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
//TODO: add IPV6 support (struct sockaddr_in6) //TODO: add IPV6 support (struct sockaddr_in6)
assert(sockaddr_size == sizeof(remote_addr)); assert(sockaddr_size == sizeof(remote_addr));
if(remote_end) if(remote_end)
*remote_end = EndpointIPv4_fromSockaddr(remote_addr); *remote_end = EndpointIPv4_fromSockaddr(remote_addr);
return RESULT_VALUE(i, user_connection); return RESULT_VALUE(i, accepted_sock);
} }
Result(void) socket_connect(Socket s, EndpointIPv4 remote_end){ Result(void) socket_connect(Socket s, EndpointIPv4 remote_end){
@ -86,7 +86,7 @@ static inline int SocketRecvFlags_toStd(SocketRecvFlag flags){
int f = 0; int f = 0;
if (flags & SocketRecvFlag_Peek) if (flags & SocketRecvFlag_Peek)
f |= MSG_PEEK; f |= MSG_PEEK;
if (flags & SocketRecvFlag_WaitAll) if (flags & SocketRecvFlag_WholeBuffer)
f |= MSG_WAITALL; f |= MSG_WAITALL;
return f; return f;
} }
@ -96,7 +96,7 @@ Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags){
if(r < 0){ if(r < 0){
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
} }
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size)) if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
{ {
return RESULT_ERROR("Socket closed", false); return RESULT_ERROR("Socket closed", false);
} }
@ -111,7 +111,7 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
if(r < 0){ if(r < 0){
return RESULT_ERROR_SOCKET(); return RESULT_ERROR_SOCKET();
} }
if(r == 0 || (flags & SocketRecvFlag_WaitAll && (u32)r != buffer.size)) if(r == 0 || (flags & SocketRecvFlag_WholeBuffer && (u32)r != buffer.size))
{ {
return RESULT_ERROR("Socket closed", false); return RESULT_ERROR("Socket closed", false);
} }
@ -123,3 +123,60 @@ Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NU
*remote_end = EndpointIPv4_fromSockaddr(remote_addr); *remote_end = EndpointIPv4_fromSockaddr(remote_addr);
return RESULT_VALUE(i, r); return RESULT_VALUE(i, r);
} }
#define try_setsockopt(socket, level, OPT){ \
if(setsockopt(socket, level, OPT, (void*)&opt_##OPT, sizeof(opt_##OPT)) != 0)\
return RESULT_ERROR_SOCKET();\
}
Result(void) socket_TCP_enableAliveChecks(Socket s,
sec_t first_check_time, u32 checks_count, sec_t checks_interval)
{
#if KN_USE_WINSOCK
BOOL opt_SO_KEEPALIVE = 1; // enable keepalives
DWORD opt_TCP_KEEPIDLE = first_check_time;
DWORD opt_TCP_KEEPCNT = checks_count;
DWORD opt_TCP_KEEPINTVL = checks_interval;
try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPCNT);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL);
// timeout for connect()
DWORD opt_TCP_MAXRT = checks_count * checks_interval;
try_setsockopt(s, IPPROTO_TCP, TCP_MAXRT);
#else
int opt_SO_KEEPALIVE = 1; // enable keepalives
int opt_TCP_KEEPIDLE = first_check_time;
int opt_TCP_KEEPCNT = checks_count;
int opt_TCP_KEEPINTVL = checks_interval;
try_setsockopt(s, SOL_SOCKET, SO_KEEPALIVE);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPIDLE);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPCNT);
try_setsockopt(s, IPPROTO_TCP, TCP_KEEPINTVL);
// read more in the article
int opt_TCP_USER_TIMEOUT = checks_count * checks_interval * 1000;
try_setsockopt(s, IPPROTO_TCP, TCP_USER_TIMEOUT);
#endif
return RESULT_VOID;
}
Result(void) socket_setTimeout(Socket s, u32 ms){
#if KN_USE_WINSOCK
DWORD opt_SO_SNDTIMEO = ms;
DWORD opt_SO_RCVTIMEO = opt_SO_SNDTIMEO;
#else
struct timeval opt_SO_SNDTIMEO = {
.tv_sec = ms/1000,
.tv_usec = (ms%1000)*1000
};
struct timeval opt_SO_RCVTIMEO = opt_SO_SNDTIMEO;
#endif
try_setsockopt(s, SOL_SOCKET, SO_SNDTIMEO);
try_setsockopt(s, SOL_SOCKET, SO_RCVTIMEO);
return RESULT_VOID;
}

View File

@ -1,8 +1,8 @@
#pragma once #pragma once
#include "endpoint.h" #include "endpoint.h"
#include "tlibc/errors.h" #include "tlibc/errors.h"
#include "tlibc/collections/Array.h"
#include "tlibc/time.h" #include "tlibc/time.h"
#include "tlibc/collections/Array.h"
typedef enum SocketShutdownType { typedef enum SocketShutdownType {
SocketShutdownType_Receive = 0, SocketShutdownType_Receive = 0,
@ -13,7 +13,7 @@ typedef enum SocketShutdownType {
typedef enum SocketRecvFlag { typedef enum SocketRecvFlag {
SocketRecvFlag_None = 0, SocketRecvFlag_None = 0,
SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */, SocketRecvFlag_Peek = 0b1 /* next recv call will read the same data */,
SocketRecvFlag_WaitAll = 0b10 /* waits until buffer is full */, SocketRecvFlag_WholeBuffer = 0b10 /* waits until buffer is full */,
} SocketRecvFlag; } SocketRecvFlag;
typedef i64 Socket; typedef i64 Socket;
@ -21,11 +21,28 @@ typedef i64 Socket;
Result(Socket) socket_open_TCP(); Result(Socket) socket_open_TCP();
void socket_close(Socket s); void socket_close(Socket s);
Result(void) socket_shutdown(Socket s, SocketShutdownType direction); Result(void) socket_shutdown(Socket s, SocketShutdownType direction);
Result(void) socket_bind(Socket s, EndpointIPv4 local_end); Result(void) socket_bind(Socket s, EndpointIPv4 local_end);
Result(void) socket_listen(Socket s, i32 backlog); Result(void) socket_listen(Socket s, i32 backlog);
Result(Socket) socket_accept(Socket s, NULLABLE(EndpointIPv4*) remote_end); Result(Socket) socket_accept(Socket listening_sock, NULLABLE(EndpointIPv4*) remote_end);
Result(void) socket_connect(Socket s, EndpointIPv4 remote_end); Result(void) socket_connect(Socket s, EndpointIPv4 remote_end);
Result(void) socket_send(Socket s, Array(u8) buffer); Result(void) socket_send(Socket s, Array(u8) buffer);
Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst); Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst);
Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags); Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags);
Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end); Result(i32) socket_recvfrom(Socket s, Array(u8) buffer, SocketRecvFlag flags, NULLABLE(EndpointIPv4*) remote_end);
/// Enables sending SO_KEEPALIVE packets when socket is idling.
/// Also enables TCP_USER_TIMEOUT to handle situations
/// when socket is not sending KEEPALIVE packets.
/// Read more: https://blog.cloudflare.com/when-tcp-sockets-refuse-to-die/
/// RU translaton: https://habr.com/ru/articles/700470/
Result(void) socket_TCP_enableAliveChecks(Socket s,
sec_t first_check_time, u32 checks_count, sec_t checks_interval);
#define socket_TCP_enableAliveChecks_default(socket) \
socket_TCP_enableAliveChecks(socket, 1, 4, 5)
#define SOCKET_TIMEOUT_MS_DEFAULT 5000
#define SOCKET_TIMEOUT_MS_INFINITE 0
/// @brief sets general timeout for send() and recv()
Result(void) socket_setTimeout(Socket s, u32 ms);

View File

@ -9,6 +9,25 @@ Result(void) PacketHeader_validateMagic(PacketHeader* ptr){
return RESULT_VOID; return RESULT_VOID;
} }
Result(void) PacketHeader_validateType(PacketHeader* ptr, u16 expected_type){
if(ptr->type != expected_type){
return RESULT_ERROR_FMT(
"expected message of type %u, but received of type %u",
expected_type, ptr->type);
}
return RESULT_VOID;
}
Result(void) PacketHeader_validateContentSize(PacketHeader* ptr, u64 expected_size){
if(ptr->content_size != expected_size){
return RESULT_ERROR_FMT(
"expected message with content_size " IFWIN("%llu", "%lu")
", but received with content_size " IFWIN("%llu", "%lu"),
expected_size, ptr->content_size);
}
return RESULT_VOID;
}
void PacketHeader_construct(PacketHeader* ptr, u8 protocol_version, u16 type, u64 content_size){ void PacketHeader_construct(PacketHeader* ptr, u8 protocol_version, u16 type, u64 content_size){
ptr->magic.n = PacketHeader_MAGIC.n; ptr->magic.n = PacketHeader_MAGIC.n;
ptr->protocol_version = protocol_version; ptr->protocol_version = protocol_version;

View File

@ -14,7 +14,9 @@ typedef struct PacketHeader {
u16 type; u16 type;
u32 _reserved4; u32 _reserved4;
u64 content_size; u64 content_size;
} __attribute__((aligned(64))) PacketHeader; } ATTRIBUTE_ALIGNED(64) PacketHeader;
void PacketHeader_construct(PacketHeader* ptr, u8 protocol_version, u16 type, u64 content_size); void PacketHeader_construct(PacketHeader* ptr, u8 protocol_version, u16 type, u64 content_size);
Result(void) PacketHeader_validateMagic(PacketHeader* ptr); Result(void) PacketHeader_validateMagic(PacketHeader* ptr);
Result(void) PacketHeader_validateType(PacketHeader* ptr, u16 expected_type);
Result(void) PacketHeader_validateContentSize(PacketHeader* ptr, u64 expected_size);

View File

@ -1,9 +1,73 @@
#include "v1.h" #include "v1.h"
void ClientHandshake_construct(ClientHandshake* ptr, Array(u8) session_key){ #define _PacketHeader_construct(T) \
memcpy(ptr->session_key, session_key.data, sizeof(ptr->session_key)); PacketHeader_construct(header, PROTOCOL_VERSION, PacketType_##T, sizeof(T))
Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* header,
Array(u8) session_key)
{
Deferral(1);
_PacketHeader_construct(ClientHandshake);
try_assert(session_key.size == sizeof(ptr->session_key));
memcpy(ptr->session_key, session_key.data, session_key.size);
Return RESULT_VOID;
} }
void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){ void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header,
u64 session_id)
{
_PacketHeader_construct(ServerHandshake);
ptr->session_id = session_id; ptr->session_id = session_id;
} }
void ServerPublicInfoRequest_construct(ServerPublicInfoRequest *ptr, PacketHeader* header,
ServerPublicInfo property)
{
_PacketHeader_construct(ServerPublicInfoRequest);
ptr->property = property;
}
Result(void) LoginRequest_tryConstruct(LoginRequest *ptr, PacketHeader* header,
Array(u8) token)
{
Deferral(1);
_PacketHeader_construct(LoginRequest);
try_assert(token.size == sizeof(ptr->token));
memcpy(ptr->token, token.data, token.size);
Return RESULT_VOID;
}
void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header,
u64 user_id, u64 landing_channel_id)
{
_PacketHeader_construct(LoginResponse);
ptr->user_id = user_id;
ptr->landing_channel_id = landing_channel_id;
}
Result(void) RegisterRequest_tryConstruct(RegisterRequest *ptr, PacketHeader* header,
str username, Array(u8) token)
{
Deferral(1);
_PacketHeader_construct(RegisterRequest);
try_assert(username.size >= USERNAME_SIZE_MIN && username.size <= USERNAME_SIZE_MAX);
ptr->username_size = username.size;
memcpy(ptr->username, username.data, username.size);
try_assert(token.size == sizeof(ptr->token));
memcpy(ptr->token, token.data, token.size);
Return RESULT_VOID;
}
void RegisterResponse_construct(RegisterResponse *ptr, PacketHeader* header,
u64 user_id)
{
_PacketHeader_construct(RegisterResponse);
ptr->user_id = user_id;
}

View File

@ -1,32 +1,105 @@
#pragma once #pragma once
#include "tlibc/errors.h" #include "tlibc/errors.h"
#include "tlibc/string/str.h"
#include "network/tcp-chat-protocol/constant.h" #include "network/tcp-chat-protocol/constant.h"
#include "cryptography/cryptography.h"
#define PROTOCOL_VERSION 1 /* 1.0.0 */ #define PROTOCOL_VERSION 1 /* 1.0.0 */
#define NETWORK_BUFFER_SIZE 65536 #define NETWORK_BUFFER_SIZE 65536
#define ALIGN_PACKET_STRUCT ATTRIBUTE_ALIGNED(8)
typedef enum PacketType { typedef enum PacketType {
PacketType_Invalid, PacketType_Invalid,
PacketType_ErrorMessage, PacketType_ErrorMessage,
PacketType_ClientHandshake, PacketType_ClientHandshake,
PacketType_ServerHandshake, PacketType_ServerHandshake,
} __attribute__((__packed__)) PacketType; PacketType_ServerPublicInfoRequest,
PacketType_ServerPublicInfoResponse,
PacketType_LoginRequest,
PacketType_LoginResponse,
PacketType_RegisterRequest,
PacketType_RegisterResponse,
} ATTRIBUTE_PACKED PacketType;
typedef struct ErrorMessage { // typedef struct ErrorMessage {
/* content stream of size `header.content_size` */ // /* stream of size header.content_size */
} ErrorMessage; // } ErrorMessage;
typedef struct ClientHandshake { typedef struct ClientHandshake {
u8 session_key[AES_SESSION_KEY_SIZE]; u8 session_key[AES_SESSION_KEY_SIZE];
} ClientHandshake; } ALIGN_PACKET_STRUCT ClientHandshake;
void ClientHandshake_construct(ClientHandshake* ptr, Array(u8) session_key); Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, PacketHeader* header,
Array(u8) session_key);
typedef struct ServerHandshake { typedef struct ServerHandshake {
u64 session_id; u64 session_id;
} ServerHandshake; } ALIGN_PACKET_STRUCT ServerHandshake;
void ServerHandshake_construct(ServerHandshake* ptr, PacketHeader* header,
u64 session_id);
typedef enum ServerPublicInfo {
ServerPublicInfo_Name,
ServerPublicInfo_Description,
} ATTRIBUTE_PACKED ServerPublicInfo;
typedef struct ServerPublicInfoRequest {
u32 property;
} ALIGN_PACKET_STRUCT ServerPublicInfoRequest;
void ServerPublicInfoRequest_construct(ServerPublicInfoRequest* ptr, PacketHeader* header,
ServerPublicInfo property);
// typedef struct ServerPublicInfoResponse {
// /* stream of size header.content_size */
// } ServerPublicInfoResponse;
typedef struct LoginRequest {
u8 token[PASSWORD_HASH_SIZE];
} ALIGN_PACKET_STRUCT LoginRequest;
Result(void) LoginRequest_tryConstruct(LoginRequest* ptr, PacketHeader* header,
Array(u8) token);
typedef struct LoginResponse {
u64 user_id;
u64 landing_channel_id;
} ALIGN_PACKET_STRUCT LoginResponse;
void LoginResponse_construct(LoginResponse* ptr, PacketHeader* header,
u64 user_id, u64 landing_channel_id);
#define USERNAME_SIZE_MIN 4
#define USERNAME_SIZE_MAX 64
#define PASSWORD_SIZE_MIN 8
#define PASSWORD_SIZE_MAX 32
typedef struct RegisterRequest {
u32 username_size;
char username[USERNAME_SIZE_MAX];
u8 token[PASSWORD_HASH_SIZE];
} ALIGN_PACKET_STRUCT RegisterRequest;
Result(void) RegisterRequest_tryConstruct(RegisterRequest* ptr, PacketHeader* header,
str username, Array(u8) token);
typedef struct RegisterResponse {
u64 user_id;
} ALIGN_PACKET_STRUCT RegisterResponse;
void RegisterResponse_construct(RegisterResponse* ptr, PacketHeader* header,
u64 user_id);
void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id);

View File

@ -25,74 +25,36 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred
conn->client_end = client_end; conn->client_end = client_end;
conn->session_id = session_id; conn->session_id = session_id;
conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE); conn->session_key = Array_alloc_size(AES_SESSION_KEY_SIZE);
// correct session key will be received from client later
Array_memset(conn->session_key, 0);
EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key);
try_void(socket_TCP_enableAliveChecks_default(sock_tcp));
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE); // decrypt the rsa messages using server private key
// fix for valgrind false detected errors about uninitialized memory
Array_memset(buffer, 0xCC);
Defer(free(buffer.data));
// TODO: set socket timeout to 5 seconds
// receive message encrypted by server public key
u32 header_and_message_size = sizeof(PacketHeader) + sizeof(ClientHandshake);
Array(u8) bufferPart_encryptedClientHandshake = {
.data = (u8*)buffer.data + header_and_message_size,
.size = server_credentials->rsa_pk.nlen
};
try_void(
socket_recv(
sock_tcp,
bufferPart_encryptedClientHandshake,
SocketRecvFlag_WaitAll
)
);
// decrypt the message using server private key
RSADecryptor rsa_dec; RSADecryptor rsa_dec;
RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk); RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk);
try(u32 rsa_dec_size, u,
RSADecryptor_decrypt(
&rsa_dec,
bufferPart_encryptedClientHandshake,
buffer
)
);
// validate client handshake // receive PacketHeader
if(rsa_dec_size != header_and_message_size){ PacketHeader packet_header = {0};
Return RESULT_ERROR_FMT( try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &packet_header));
"decrypted message (size: %u) is not a ClientHandshake (size: %u)", try_void(PacketHeader_validateMagic(&packet_header));
rsa_dec_size, header_and_message_size try_void(PacketHeader_validateType(&packet_header, PacketType_ClientHandshake));
); try_void(PacketHeader_validateContentSize(&packet_header, sizeof(ClientHandshake)));
}
PacketHeader* packet_header = buffer.data; // receive ClientHandshake
ClientHandshake* client_handshake = Array_sliceAfter(buffer, sizeof(PacketHeader)).data; ClientHandshake client_handshake = {0};
try_void(PacketHeader_validateMagic(packet_header)); try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &client_handshake));
if(packet_header->type != PacketType_ClientHandshake){
Return RESULT_ERROR_FMT(
"received message of unexpected type: %u",
packet_header->type
);
}
// use received session key // use received session key
memcpy(conn->session_key.data, client_handshake->session_key, conn->session_key.size); memcpy(conn->session_key.data, client_handshake.session_key, conn->session_key.size);
EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key); EncryptedSocketTCP_changeKey(&conn->sock, conn->session_key);
// construct PacketHeader and ServerHandshake in buffer // send PacketHeader and ServerHandshake over encrypted TCP socket
PacketHeader_construct(buffer.data, PROTOCOL_VERSION, ServerHandshake server_handshake = {0};
PacketType_ServerHandshake, sizeof(ServerHandshake)); ServerHandshake_construct(&server_handshake, &packet_header,
ServerHandshake_construct(
Array_sliceAfter(buffer, sizeof(PacketHeader)).data,
session_id); session_id);
// send ServerHandshake over encrypted TCP socket try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &packet_header));
header_and_message_size = sizeof(PacketHeader) + sizeof(ServerHandshake); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &server_handshake));
try_void(
EncryptedSocketTCP_send(
&conn->sock,
Array_sliceBefore(buffer, header_and_message_size)
)
);
success = true; success = true;
Return RESULT_VALUE(p, conn); Return RESULT_VALUE(p, conn);

View File

@ -1,5 +1,6 @@
#include <pthread.h> #include <pthread.h>
#include "tlibc/filesystem.h" #include "tlibc/filesystem.h"
#include "tlibc/time.h"
#include "db/idb.h" #include "db/idb.h"
#include "server.h" #include "server.h"
#include "config.h" #include "config.h"
@ -57,7 +58,7 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){
logDebug(log_ctx, "initializing main socket"); logDebug(log_ctx, "initializing main socket");
EndpointIPv4 server_end; EndpointIPv4 server_end;
EndpointIPv4_parse(server_endpoint_cstr, &server_end); try_void(EndpointIPv4_parse(server_endpoint_cstr, &server_end));
try(Socket main_socket, i, socket_open_TCP()); try(Socket main_socket, i, socket_open_TCP());
try_void(socket_bind(main_socket, server_end)); try_void(socket_bind(main_socket, server_end));
try_void(socket_listen(main_socket, 512)); try_void(socket_listen(main_socket, 512));
@ -72,7 +73,7 @@ Result(void) server_run(cstr server_endpoint_cstr, cstr config_path){
//TODO: use async IO instead of threads to not waste system resources //TODO: use async IO instead of threads to not waste system resources
// while waiting for incoming data in 100500 threads // while waiting for incoming data in 100500 threads
try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args)); try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args));
try_stderrcode(pthread_detach(&conn_thread)); try_stderrcode(pthread_detach(conn_thread));
} }
Return RESULT_VOID; Return RESULT_VOID;
@ -83,7 +84,7 @@ static void* handle_connection(void* _args){
char log_ctx[64]; char log_ctx[64];
sprintf(log_ctx, "Session-" IFWIN("%llx", "%lx"), args->session_id); sprintf(log_ctx, "Session-" IFWIN("%llx", "%lx"), args->session_id);
Result(void) r = try_handle_connection(args, log_ctx); ResultVar(void) r = try_handle_connection(args, log_ctx);
if(r.error){ if(r.error){
str error_s = Error_toStr(r.error); str error_s = Error_toStr(r.error);
logError(log_ctx, "%s", error_s.data); logError(log_ctx, "%s", error_s.data);
@ -113,16 +114,89 @@ static Result(void) try_handle_connection(ConnectionHandlerArgs* args, cstr log_
); );
logInfo(log_ctx, "session accepted"); logInfo(log_ctx, "session accepted");
// handle requests // handle unauthorized requests
bool ahtorized = false;
PacketHeader req_header = {0};
PacketHeader res_header = {0};
while(!ahtorized){
sleepMsec(50);
//TODO: implement some additional check if socket is dead or not
Array(u8) buffer = Array_alloc_size(NETWORK_BUFFER_SIZE); try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req_header));
// fix for valgrind false detected errors about uninitialized memory try_void(PacketHeader_validateMagic(&req_header));
Array_memset(buffer, 0xCC); //TODO: move request handlers to separate functions
Defer(free(buffer.data)); switch(req_header.type){
u32 dec_size = 0; default:{
Array(u8) err_buf = Array_alloc(u8, 128);
bool err_complete = false;
Defer(if(!err_complete) free(err_buf.data));
sprintf(err_buf.data, "Received unexpected packet of type %u",
req_header.type);
err_buf.size = strlen(err_buf.data);
PacketHeader_construct(&res_header,
PROTOCOL_VERSION, PacketType_ErrorMessage, err_buf.size);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res_header));
//TODO: limit ErrorMessage size to fit into EncryptedSocketTCP.internal_buffer_size
try_void(EncryptedSocketTCP_send(&conn->sock, err_buf));
err_complete = true;
Return RESULT_ERROR(err_buf.data, true);
}
case PacketType_ServerPublicInfoRequest:{
ServerPublicInfoRequest req = {0};
try_void(PacketHeader_validateContentSize(&req_header, sizeof(req)));
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req));
//TODO: try find requested info
Array(u8) content;
PacketHeader_construct(&res_header,
PROTOCOL_VERSION, PacketType_ServerPublicInfoResponse, content.size);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res_header));
try_void(EncryptedSocketTCP_send(&conn->sock, content));
break;
}
case PacketType_LoginRequest:{
LoginRequest req = {0};
try_void(PacketHeader_validateContentSize(&req_header, sizeof(req)));
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req));
//TODO: try authorize client
u64 user_id;
u64 landing_channel_id;
LoginResponse res = {0};
LoginResponse_construct(&res, &res_header, user_id, landing_channel_id);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res_header));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));
ahtorized = true;
logInfo(log_ctx, "client authorized");
break;
}
case PacketType_RegisterRequest:{
RegisterRequest req = {0};
try_void(PacketHeader_validateContentSize(&req_header, sizeof(req)));
try_void(EncryptedSocketTCP_recvStruct(&conn->sock, &req));
//TODO: try register client
u64 user_id;
RegisterResponse res = {0};
RegisterResponse_construct(&res, &res_header, user_id);
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res_header));
try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &res));
break;
}
}
}
// handle authorized requests
while(true){ while(true){
sleepMsec(10); sleepMsec(50);
} }
Return RESULT_VOID; Return RESULT_VOID;