diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index 34c9947..90e81b8 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -67,9 +67,9 @@ Result(ServerConnection*) ServerConnection_open(ClientCredential* client_credent EncryptedSocketTCP_construct(&conn->system_socket, _s, conn->session_key); try_void(socket_connect(conn->system_socket.sock, conn->server_end)); - Array(u8) encrypted_buf = Array_alloc_size(64*1024); + Array(u8) encrypted_buf = Array_alloc_size(8*1024); Defer(free(encrypted_buf.data)); - Array(u8) decrypted_buf = Array_alloc_size(64*1024); + Array(u8) decrypted_buf = Array_alloc_size(8*1024); Defer(free(decrypted_buf.data)); u32 encrypted_size = 0, decrypted_size = 0; @@ -84,9 +84,9 @@ Result(ServerConnection*) ServerConnection_open(ClientCredential* client_credent Array_construct_size(encrypted_buf.data, encrypted_size))); // receive server response - encrypted_size = sizeof(PacketHeader); + encrypted_size = AESStreamEncryptor_calcDstSize(sizeof(PacketHeader)); try(decrypted_size, u, EncryptedSocketTCP_recv(&conn->system_socket, - Array_construct_size(encrypted_buf.data, sizeof(PacketHeader)), + Array_construct_size(encrypted_buf.data, encrypted_size), decrypted_buf, SocketRecvFlag_WaitAll)); try_assert(decrypted_size == sizeof(PacketHeader)); @@ -96,18 +96,30 @@ Result(ServerConnection*) ServerConnection_open(ClientCredential* client_credent switch(packet_header->type){ case PacketType_ErrorMessage: Array(u8) err_buf = Array_alloc_size(packet_header->content_size + 1); - bool err_msg_return = false; + bool err_msg_completed = false; Defer( - if(!err_msg_return) + if(!err_msg_completed) free(err_buf.data); ); - ((u8*)err_buf.data)[packet_header->content_size] = 0; - try_void(EncryptedSocketTCP_recv(&conn->system_socket, encrypted_buf, err_buf, SocketRecvFlag_WaitAll)); - err_msg_return = true; + encrypted_size = AESStreamEncryptor_calcDstSize(packet_header->content_size); + if(encrypted_size > encrypted_buf.size) + encrypted_size = encrypted_buf.size; + try_void(EncryptedSocketTCP_recv(&conn->system_socket, + Array_construct_size(encrypted_buf.data, encrypted_size), + err_buf, + SocketRecvFlag_WaitAll)); + ((u8*)err_buf.data)[encrypted_size] = 0; + err_msg_completed = true; Return RESULT_ERROR((char*)err_buf.data, true); - break; - case PacketType_ClientHandshake: - //TODO: receive the rest of the struct + case PacketType_ServerHandshake: + encrypted_size = AESStreamEncryptor_calcDstSize(sizeof(ServerHandshake) - sizeof(PacketHeader)); + try_void(EncryptedSocketTCP_recv(&conn->system_socket, + Array_construct_size(encrypted_buf.data, encrypted_size), + Array_construct_size((u8*)decrypted_buf.data + sizeof(PacketHeader), decrypted_buf.size - sizeof(PacketHeader)), + SocketRecvFlag_WaitAll + )); + ServerHandshake* server_handshake = decrypted_buf.data; + conn->session_id = server_handshake->session_id; break; default: Return RESULT_ERROR_FMT("unexpected response type: %i", packet_header->type); diff --git a/src/client/client.c b/src/client/client.c index b90ff10..beed1a6 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -56,7 +56,7 @@ Result(void) client_run() { Defer(rl_free(command_input_prev)); str command_input = str_null; bool stop = false; - while((command_input_raw = readline("> ")) && !stop){ + while(!stop && (command_input_raw = readline("> "))){ rl_free(command_input_prev); command_input_prev = command_input_raw; command_input = str_from_cstr(command_input_raw); @@ -104,7 +104,7 @@ static Result(void) commandExec(str command, bool* stop){ else if (is_alias("j") || is_alias("join")){ ServerConnection_close(_server_connection); - puts("Enter server address (ip:port): "); + puts("Enter server address (ip:port:public_key): "); fgets(answer_buf, answer_buf_size, stdin); str new_server_link = str_from_cstr(answer_buf); str_trim(&new_server_link, true); diff --git a/src/server/server.c b/src/server/server.c index 00ccbdd..1b64a19 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -7,6 +7,7 @@ static void* handle_connection(void* _args); typedef struct ConnectionHandlerArgs { Socket accepted_socket; EndpointIPv4 client_end; + u64 session_id; } ConnectionHandlerArgs; Result(void) server_run(cstr server_endpoint_str){ @@ -18,9 +19,12 @@ Result(void) server_run(cstr server_endpoint_str){ try_void(socket_bind(main_socket, server_end)); try_void(socket_listen(main_socket, 512)); + u64 session_id = 1; + while(true){ ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)malloc(sizeof(ConnectionHandlerArgs)); try(args->accepted_socket, i, socket_accept(main_socket, &args->client_end)); + args->session_id = session_id++; pthread_t conn_thread = {0}; try_stderrcode(pthread_create(&conn_thread, NULL, handle_connection, args)); } @@ -31,6 +35,7 @@ Result(void) server_run(cstr server_endpoint_str){ static void* handle_connection(void* _args){ Deferral(64); //ConnectionHandlerArgs* args = (ConnectionHandlerArgs*)_args; + Defer(free(_args)); // TODO: receive handshake and session key //ClientConnection conn;