fixed connection bugs
This commit is contained in:
parent
8179609d47
commit
94fcbe5daf
@ -60,8 +60,8 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
|
|||||||
br_hmac_drbg_generate(&key_rng, conn->session_key.data, conn->session_key.size);
|
br_hmac_drbg_generate(&key_rng, conn->session_key.data, conn->session_key.size);
|
||||||
|
|
||||||
// connect to server address
|
// connect to server address
|
||||||
Socket _s;
|
try(Socket _s, i, socket_open_TCP());
|
||||||
try(_s, i, socket_open_TCP());
|
// TODO: set socket timeout to 5 seconds
|
||||||
try_void(socket_connect(_s, conn->server_end));
|
try_void(socket_connect(_s, conn->server_end));
|
||||||
EncryptedSocketTCP_construct(&conn->sock, _s, conn->session_key);
|
EncryptedSocketTCP_construct(&conn->sock, _s, conn->session_key);
|
||||||
|
|
||||||
@ -73,11 +73,12 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
|
|||||||
|
|
||||||
// construct ClientHandshake in dec_buf
|
// construct ClientHandshake in dec_buf
|
||||||
try_void(ClientHandshake_tryConstruct((ClientHandshake*)dec_buf.data, conn->session_key));
|
try_void(ClientHandshake_tryConstruct((ClientHandshake*)dec_buf.data, conn->session_key));
|
||||||
|
dec_size = sizeof(ClientHandshake);
|
||||||
// encrypt by server public key
|
// encrypt by server public key
|
||||||
try(enc_size, u,
|
try(enc_size, u,
|
||||||
RSAEncryptor_encrypt(
|
RSAEncryptor_encrypt(
|
||||||
&conn->rsa_enc,
|
&conn->rsa_enc,
|
||||||
Array_sliceBefore(dec_buf, sizeof(ClientHandshake)),
|
Array_sliceBefore(dec_buf, dec_size),
|
||||||
enc_buf
|
enc_buf
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
@ -107,7 +108,7 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
|
|||||||
);
|
);
|
||||||
|
|
||||||
// receive error message of length packet_header->content_size
|
// receive error message of length packet_header->content_size
|
||||||
enc_size = AESStreamEncryptor_calcDstSize(packet_header->content_size);
|
enc_size = packet_header->content_size;
|
||||||
if(enc_size > enc_buf.size)
|
if(enc_size > enc_buf.size)
|
||||||
enc_size = enc_buf.size;
|
enc_size = enc_buf.size;
|
||||||
try_void(
|
try_void(
|
||||||
@ -124,12 +125,12 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden
|
|||||||
Return RESULT_ERROR((char*)err_buf.data, true);
|
Return RESULT_ERROR((char*)err_buf.data, true);
|
||||||
}
|
}
|
||||||
case PacketType_ServerHandshake: {
|
case PacketType_ServerHandshake: {
|
||||||
enc_size = AESStreamEncryptor_calcDstSize(sizeof(ServerHandshake) - sizeof(PacketHeader));
|
enc_size = sizeof(ServerHandshake) - sizeof(PacketHeader);
|
||||||
try_void(
|
try_void(
|
||||||
EncryptedSocketTCP_recv(
|
EncryptedSocketTCP_recv(
|
||||||
&conn->sock,
|
&conn->sock,
|
||||||
Array_sliceBefore(enc_buf, enc_size),
|
Array_sliceBefore(enc_buf, enc_size),
|
||||||
Array_sliceAfter(dec_buf, sizeof(PacketHeader)),
|
dec_buf,
|
||||||
SocketRecvFlag_WaitAll
|
SocketRecvFlag_WaitAll
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|||||||
@ -23,16 +23,32 @@ static Result(void) commandExec(str command, bool* stop);
|
|||||||
static Result(void) askUserNameAndPassword(ClientCredentials** cred){
|
static Result(void) askUserNameAndPassword(ClientCredentials** cred){
|
||||||
Deferral(8);
|
Deferral(8);
|
||||||
|
|
||||||
|
char username_buf[1024];
|
||||||
|
str usrername = str_null;
|
||||||
|
while(true) {
|
||||||
printf("username: ");
|
printf("username: ");
|
||||||
char username[1024];
|
fgets(username_buf, sizeof(username_buf), stdin);
|
||||||
fgets(username, sizeof(username), stdin);
|
usrername = str_from_cstr(username_buf);
|
||||||
|
if(usrername.size < 4){
|
||||||
|
printf("ERROR: username length must be at least 4\n");
|
||||||
|
}
|
||||||
|
else break;
|
||||||
|
}
|
||||||
|
|
||||||
|
char password_buf[1024];
|
||||||
|
str password = str_null;
|
||||||
|
while(true) {
|
||||||
printf("password: ");
|
printf("password: ");
|
||||||
char password[1024];
|
|
||||||
// TODO: hide password
|
// TODO: hide password
|
||||||
fgets(password, sizeof(password), stdin);
|
fgets(password_buf, sizeof(password_buf), stdin);
|
||||||
|
password = str_from_cstr(password_buf);
|
||||||
|
if(password.size < 8){
|
||||||
|
printf("ERROR: password length must be at least 8\n");
|
||||||
|
}
|
||||||
|
else break;
|
||||||
|
}
|
||||||
|
|
||||||
try(*cred, p, ClientCredentials_create(str_from_cstr(username), str_from_cstr(password)));
|
try(*cred, p, ClientCredentials_create(usrername, password));
|
||||||
Return RESULT_VOID;
|
Return RESULT_VOID;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -106,6 +122,7 @@ static Result(void) commandExec(str command, bool* stop){
|
|||||||
|
|
||||||
printf("connecting to server...\n");
|
printf("connecting to server...\n");
|
||||||
try(_server_connection, p, ServerConnection_open(_client_credentials, new_server_link.data));
|
try(_server_connection, p, ServerConnection_open(_client_credentials, new_server_link.data));
|
||||||
|
printf("connection established\n");
|
||||||
|
|
||||||
// TODO: request server info
|
// TODO: request server info
|
||||||
// show server info
|
// show server info
|
||||||
|
|||||||
@ -176,7 +176,6 @@ void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const
|
|||||||
|
|
||||||
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){
|
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||||
Deferral(4);
|
Deferral(4);
|
||||||
try_assert(src.size >= AESStreamEncryptor_calcDstSize(0));
|
|
||||||
u32 decrypted_size = __AESStreamDecryptor_calcDstSize(src.size);
|
u32 decrypted_size = __AESStreamDecryptor_calcDstSize(src.size);
|
||||||
try_assert(dst.size >= decrypted_size);
|
try_assert(dst.size >= decrypted_size);
|
||||||
|
|
||||||
|
|||||||
@ -88,6 +88,7 @@ typedef struct AESStreamEncryptor {
|
|||||||
/// @param dec_class &br_aes_XXX_ctr_vtable
|
/// @param dec_class &br_aes_XXX_ctr_vtable
|
||||||
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
|
void AESStreamEncryptor_construct(AESStreamEncryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
|
||||||
|
|
||||||
|
/// use this only at the beginning of the stream
|
||||||
#define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_BLOCK_IV_SIZE)
|
#define AESStreamEncryptor_calcDstSize(src_size) (src_size + __AES_BLOCK_IV_SIZE)
|
||||||
|
|
||||||
/// @brief If ptr->block_counter == 0, writes random IV to `dst`. After that writes encrypted data to dst.
|
/// @brief If ptr->block_counter == 0, writes random IV to `dst`. After that writes encrypted data to dst.
|
||||||
@ -114,7 +115,7 @@ typedef struct AESStreamDecryptor {
|
|||||||
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
|
void AESStreamDecryptor_construct(AESStreamDecryptor* ptr, Array(u8) key, const br_block_ctr_class* ctr_class);
|
||||||
|
|
||||||
/// @brief Reads IV from `src`, then decrypts data and writes it to dst
|
/// @brief Reads IV from `src`, then decrypts data and writes it to dst
|
||||||
/// @param src array of size at least AESStreamEncryptor_calcDstSize(0).
|
/// @param src array of any size
|
||||||
/// @param dst array of size >= src.size
|
/// @param dst array of size >= src.size
|
||||||
/// @return size of decrypted data
|
/// @return size of decrypted data
|
||||||
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst);
|
Result(u32) AESStreamDecryptor_decrypt(AESStreamDecryptor* ptr, Array(u8) src, Array(u8) dst);
|
||||||
|
|||||||
@ -180,14 +180,15 @@ void RSAEncryptor_construct(RSAEncryptor* ptr, const br_rsa_public_key* pk){
|
|||||||
}
|
}
|
||||||
|
|
||||||
Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst){
|
Result(u32) RSAEncryptor_encrypt(RSAEncryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||||
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(ptr->pk->nlen * 8, 256);
|
u32 key_size_bytes = ptr->pk->nlen;
|
||||||
|
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bytes * 8, 256);
|
||||||
if(src.size > max_src_size){
|
if(src.size > max_src_size){
|
||||||
return RESULT_ERROR_FMT("src.size (%u) must be <= %u (use RSAEncryptor_calcMaxSrcSize)",
|
return RESULT_ERROR_FMT("src.size (%u) must be <= %u (use RSAEncryptor_calcMaxSrcSize)",
|
||||||
src.size, max_src_size);
|
src.size, max_src_size);
|
||||||
}
|
}
|
||||||
if(dst.size < ptr->pk->nlen){
|
if(dst.size < key_size_bytes){
|
||||||
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (key length in bytes)",
|
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (key length in bytes)",
|
||||||
dst.size, (u32)ptr->pk->nlen);
|
dst.size, key_size_bytes);
|
||||||
}
|
}
|
||||||
size_t sz = br_rsa_i31_oaep_encrypt(
|
size_t sz = br_rsa_i31_oaep_encrypt(
|
||||||
&ptr->rng.vtable, &br_sha256_vtable,
|
&ptr->rng.vtable, &br_sha256_vtable,
|
||||||
@ -207,18 +208,26 @@ void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk){
|
|||||||
ptr->sk = sk;
|
ptr->sk = sk;
|
||||||
}
|
}
|
||||||
|
|
||||||
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buf){
|
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst){
|
||||||
if(buf.size != ptr->sk->n_bitlen/8){
|
u32 key_size_bits = ptr->sk->n_bitlen;
|
||||||
return RESULT_ERROR_FMT("buf.size (%u) must be == %u (key length in bytes)",
|
if(src.size != key_size_bits/8){
|
||||||
buf.size, ptr->sk->n_bitlen/8);
|
return RESULT_ERROR_FMT("src.size (%u) must be == %u (key length in bytes)",
|
||||||
|
src.size, key_size_bits/8);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t sz = buf.size;
|
const u32 max_src_size = RSAEncryptor_calcMaxSrcSize(key_size_bits, 256);
|
||||||
|
if(dst.size < max_src_size){
|
||||||
|
return RESULT_ERROR_FMT("dst.size (%u) must be >= %u (use RSAEncryptor_calcMaxSrcSize)",
|
||||||
|
dst.size, max_src_size);
|
||||||
|
}
|
||||||
|
|
||||||
|
memcpy(dst.data, src.data, src.size);
|
||||||
|
size_t sz = src.size;
|
||||||
size_t r = br_rsa_i31_oaep_decrypt(
|
size_t r = br_rsa_i31_oaep_decrypt(
|
||||||
&br_sha256_vtable,
|
&br_sha256_vtable,
|
||||||
NULL, 0,
|
NULL, 0,
|
||||||
ptr->sk,
|
ptr->sk,
|
||||||
buf.data, &sz);
|
dst.data, &sz);
|
||||||
|
|
||||||
if(r == 0){
|
if(r == 0){
|
||||||
return RESULT_ERROR("RSA encryption failed", false);
|
return RESULT_ERROR("RSA encryption failed", false);
|
||||||
|
|||||||
@ -101,6 +101,7 @@ typedef struct RSADecryptor {
|
|||||||
/// RSA OAEP encryption with SHA256 hashing algorithm
|
/// RSA OAEP encryption with SHA256 hashing algorithm
|
||||||
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk);
|
void RSADecryptor_construct(RSADecryptor* ptr, const br_rsa_private_key* sk);
|
||||||
|
|
||||||
/// @param buf buffer with size == key size in bytes
|
/// @param src buffer with size == key size in bytes
|
||||||
|
/// @param dst buffer with size >= `RSAEncryptor_calcMaxSrcSize(key_size_bits, 256)`
|
||||||
/// @return size of decrypted data
|
/// @return size of decrypted data
|
||||||
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) buf);
|
Result(u32) RSADecryptor_decrypt(RSADecryptor* ptr, Array(u8) src, Array(u8) dst);
|
||||||
|
|||||||
@ -24,6 +24,7 @@ Result(i32) EncryptedSocketTCP_recv(EncryptedSocketTCP* ptr,
|
|||||||
{
|
{
|
||||||
Deferral(4);
|
Deferral(4);
|
||||||
try(i32 received_size, i, socket_recv(ptr->sock, encrypted_buf, flags));
|
try(i32 received_size, i, socket_recv(ptr->sock, encrypted_buf, flags));
|
||||||
|
//TODO: return error if WaitAll flag was set and socket closed before filling the buffer
|
||||||
//TODO: return something when received_size == 0 (socket has been closed)
|
//TODO: return something when received_size == 0 (socket has been closed)
|
||||||
encrypted_buf.size = received_size;
|
encrypted_buf.size = received_size;
|
||||||
try(i32 decrypted_size, u, AESStreamDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf));
|
try(i32 decrypted_size, u, AESStreamDecryptor_decrypt(&ptr->dec, encrypted_buf, decrypted_buf));
|
||||||
|
|||||||
@ -9,6 +9,6 @@ Result(void) ClientHandshake_tryConstruct(ClientHandshake* ptr, Array(u8) sessio
|
|||||||
}
|
}
|
||||||
|
|
||||||
void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){
|
void ServerHandshake_construct(ServerHandshake* ptr, u64 session_id){
|
||||||
PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ClientHandshake, sizeof(session_id));
|
PacketHeader_construct(&ptr->header, PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(session_id));
|
||||||
ptr->session_id = session_id;
|
ptr->session_id = session_id;
|
||||||
}
|
}
|
||||||
|
|||||||
@ -32,6 +32,8 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred
|
|||||||
Defer(free(dec_buf.data));
|
Defer(free(dec_buf.data));
|
||||||
u32 enc_size = 0, dec_size = 0;
|
u32 enc_size = 0, dec_size = 0;
|
||||||
|
|
||||||
|
// TODO: set socket timeout to 5 seconds
|
||||||
|
|
||||||
// receive message encrypted by server public key
|
// receive message encrypted by server public key
|
||||||
try(enc_size, u,
|
try(enc_size, u,
|
||||||
socket_recv(
|
socket_recv(
|
||||||
@ -47,7 +49,8 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred
|
|||||||
try(dec_size, u,
|
try(dec_size, u,
|
||||||
RSADecryptor_decrypt(
|
RSADecryptor_decrypt(
|
||||||
&rsa_dec,
|
&rsa_dec,
|
||||||
Array_sliceBefore(enc_buf, enc_size)
|
Array_sliceBefore(enc_buf, enc_size),
|
||||||
|
dec_buf
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user