diff --git a/src/client/ServerConnection.c b/src/client/ServerConnection.c index 42039c4..b540996 100644 --- a/src/client/ServerConnection.c +++ b/src/client/ServerConnection.c @@ -63,16 +63,16 @@ Result(ServerConnection*) ServerConnection_open(ClientCredentials* client_creden // connect to server address try(Socket _s, i, socket_open_TCP()); - // TODO: set socket timeout to 5 seconds + try_void(socket_setTimeout(_s, SOCKET_TIMEOUT_MS_DEFAULT)); try_void(socket_connect(_s, conn->server_end)); EncryptedSocketTCP_construct(&conn->sock, _s, NETWORK_BUFFER_SIZE, conn->session_key); // send PacketHeader and ClientHandshake // encryption by server public key - PacketHeader packet_header; + PacketHeader packet_header = {0}; PacketHeader_construct(&packet_header, PROTOCOL_VERSION, PacketType_ClientHandshake, sizeof(ClientHandshake)); - ClientHandshake client_handshake; + ClientHandshake client_handshake = {0}; ClientHandshake_construct(&client_handshake, conn->session_key); try_void(EncryptedSocketTCP_sendStructRSA(&conn->sock, &conn->rsa_enc, &packet_header)); diff --git a/src/client/client.c b/src/client/client.c index f295540..9f16814 100644 --- a/src/client/client.c +++ b/src/client/client.c @@ -1,5 +1,6 @@ #include "client.h" #include "term.h" +#include "tlibc/time.h" static const str greeting_art = STR( " ^,,^ ╱|\n" diff --git a/src/network/internal.h b/src/network/internal.h index 0b51dba..f2feea1 100644 --- a/src/network/internal.h +++ b/src/network/internal.h @@ -10,20 +10,28 @@ #endif #endif +// include OS-dependent socket headers #if KN_USE_WINSOCK #include // There you can see what error codes mean. #include - #define RESULT_ERROR_SOCKET() RESULT_ERROR(sprintf_malloc(64, "Winsock error %i", WSAGetLastError()), true) #else #include #include + #include #include #include #include #include - #define RESULT_ERROR_SOCKET() RESULT_ERROR(strerror(errno), false) +#endif + +#if KN_USE_WINSOCK + #define RESULT_ERROR_SOCKET()\ + RESULT_ERROR(sprintf_malloc(64, "Winsock error %i (look in )", WSAGetLastError()), true); +#else + #define RESULT_ERROR_SOCKET()\ + RESULT_ERROR(strerror(errno), false); #endif struct sockaddr_in EndpointIPv4_toSockaddr(EndpointIPv4 end); diff --git a/src/network/socket.c b/src/network/socket.c index 9541361..c0546b3 100755 --- a/src/network/socket.c +++ b/src/network/socket.c @@ -20,7 +20,30 @@ void socket_close(Socket s){ } Result(void) socket_shutdown(Socket s, SocketShutdownType direction){ - if(shutdown(s, (int)direction) == -1) + if(shutdown(s, (int)direction) != 0) + return RESULT_ERROR_SOCKET(); + return RESULT_VOID; +} + +Result(void) socket_setTimeout(Socket s, u32 ms){ + void* opt; + u32 optlen; + +#if KN_USE_WINSOCK + opt = &ms; + optlen = sizeof(ms); +#else + struct timeval tv = { + .tv_sec = ms/1000, + .tv_usec = (ms%1000)*1000 + }; + opt = &tv; + optlen = sizeof(tv); +#endif + + if(setsockopt(s, SOL_SOCKET, SO_SNDTIMEO, opt, optlen) != 0) + return RESULT_ERROR_SOCKET(); + if(setsockopt(s, SOL_SOCKET, SO_RCVTIMEO, opt, optlen) != 0) return RESULT_ERROR_SOCKET(); return RESULT_VOID; } @@ -38,18 +61,18 @@ Result(void) socket_listen(Socket s, i32 backlog){ return RESULT_VOID; } -Result(Socket) socket_accept(Socket main_socket, NULLABLE(EndpointIPv4*) remote_end) { +Result(Socket) socket_accept(Socket listening_sock, NULLABLE(EndpointIPv4*) remote_end) { struct sockaddr_in remote_addr = {0}; i32 sockaddr_size = sizeof(remote_addr); - Socket user_connection = accept(main_socket, (void*)&remote_addr, (void*)&sockaddr_size); - if(user_connection == -1) + Socket accepted_sock = accept(listening_sock, (void*)&remote_addr, (void*)&sockaddr_size); + if(accepted_sock == -1) return RESULT_ERROR_SOCKET(); //TODO: add IPV6 support (struct sockaddr_in6) assert(sockaddr_size == sizeof(remote_addr)); if(remote_end) *remote_end = EndpointIPv4_fromSockaddr(remote_addr); - return RESULT_VALUE(i, user_connection); + return RESULT_VALUE(i, accepted_sock); } Result(void) socket_connect(Socket s, EndpointIPv4 remote_end){ diff --git a/src/network/socket.h b/src/network/socket.h index 72ff845..7e8dadb 100755 --- a/src/network/socket.h +++ b/src/network/socket.h @@ -2,7 +2,6 @@ #include "endpoint.h" #include "tlibc/errors.h" #include "tlibc/collections/Array.h" -#include "tlibc/time.h" typedef enum SocketShutdownType { SocketShutdownType_Receive = 0, @@ -18,13 +17,19 @@ typedef enum SocketRecvFlag { typedef i64 Socket; +#define SOCKET_TIMEOUT_MS_DEFAULT 5000 +#define SOCKET_TIMEOUT_MS_INFINITE 0 + Result(Socket) socket_open_TCP(); void socket_close(Socket s); Result(void) socket_shutdown(Socket s, SocketShutdownType direction); +Result(void) socket_setTimeout(Socket s, u32 ms); + Result(void) socket_bind(Socket s, EndpointIPv4 local_end); Result(void) socket_listen(Socket s, i32 backlog); -Result(Socket) socket_accept(Socket s, NULLABLE(EndpointIPv4*) remote_end); +Result(Socket) socket_accept(Socket listening_sock, NULLABLE(EndpointIPv4*) remote_end); Result(void) socket_connect(Socket s, EndpointIPv4 remote_end); + Result(void) socket_send(Socket s, Array(u8) buffer); Result(void) socket_sendto(Socket s, Array(u8) buffer, EndpointIPv4 dst); Result(i32) socket_recv(Socket s, Array(u8) buffer, SocketRecvFlag flags); diff --git a/src/server/ClientConnection.c b/src/server/ClientConnection.c index 0e690d5..f7949e8 100644 --- a/src/server/ClientConnection.c +++ b/src/server/ClientConnection.c @@ -28,14 +28,14 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred // correct session key will be received from client later Array_memset(conn->session_key, 0); EncryptedSocketTCP_construct(&conn->sock, sock_tcp, NETWORK_BUFFER_SIZE, conn->session_key); - // TODO: set socket timeout to 5 seconds + try_void(socket_setTimeout(sock_tcp, SOCKET_TIMEOUT_MS_DEFAULT)); // decrypt the rsa messages using server private key RSADecryptor rsa_dec; RSADecryptor_construct(&rsa_dec, &server_credentials->rsa_sk); // receive PacketHeader - PacketHeader packet_header; + PacketHeader packet_header = {0}; try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &packet_header)); try_void(PacketHeader_validateMagic(&packet_header)); if(packet_header.type != PacketType_ClientHandshake){ @@ -46,7 +46,7 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred } // receive ClientHandshake - ClientHandshake client_handshake; + ClientHandshake client_handshake = {0}; try_void(EncryptedSocketTCP_recvStructRSA(&conn->sock, &rsa_dec, &client_handshake)); // use received session key @@ -56,7 +56,7 @@ Result(ClientConnection*) ClientConnection_accept(ServerCredentials* server_cred // send PacketHeader and ServerHandshake over encrypted TCP socket PacketHeader_construct(&packet_header, PROTOCOL_VERSION, PacketType_ServerHandshake, sizeof(ServerHandshake)); - ServerHandshake server_handshake; + ServerHandshake server_handshake = {0}; ServerHandshake_construct(&server_handshake, session_id); try_void(EncryptedSocketTCP_sendStruct(&conn->sock, &packet_header)); diff --git a/src/server/server.c b/src/server/server.c index dc73526..5922d51 100644 --- a/src/server/server.c +++ b/src/server/server.c @@ -1,5 +1,6 @@ #include #include "tlibc/filesystem.h" +#include "tlibc/time.h" #include "db/idb.h" #include "server.h" #include "config.h"