diff --git a/src/cobo/handshake.c b/src/cobo/handshake.c index f75c08a8..d20956be 100644 --- a/src/cobo/handshake.c +++ b/src/cobo/handshake.c @@ -29,6 +29,7 @@ Place, Suite 330, Boston, MA 02111-1307 USA #include #include #include +#include #include #include "handshake.h" @@ -82,25 +83,37 @@ Place, Suite 330, Boston, MA 02111-1307 USA #define CLIENT_TO_SERVER_SIG 0x9B1CC028 #define HSHAKE_AGAIN -16 +#define HOSTNAME_MAX_LEN 256 typedef struct { + uint64_t random_number; + uint64_t session_id; + uint32_t counter; + char hostname[HOSTNAME_MAX_LEN]; +} random_number_t; + +typedef struct { + random_number_t random_number; uint32_t signature; uint16_t server_port; uint16_t client_port; uid_t uid; gid_t gid; - uint64_t session_id; unsigned char server_addr[MAX_ADDR_LEN]; unsigned char client_addr[MAX_ADDR_LEN]; } handshake_packet_t; +typedef struct { + random_number_t random_number; + uint32_t signature; +} random_number_packet_t; + typedef struct { int i_am_server; struct sockaddr server_addr; struct sockaddr client_addr; } connection_info_t; - static FILE *debug_file = NULL; static char *last_error_message = NULL; static char *last_security_message = NULL; @@ -109,6 +122,7 @@ static char *saved_key_filepath; static unsigned int saved_key_len; static connection_info_t *saved_conninfo; static int timeout_seconds = 0; +static uint32_t unique_number_counter = 0; /** Routines for creating a handshake_packet_t **/ static int encode_addr(struct sockaddr *addr, unsigned char *target_addr, uint16_t *port); @@ -116,31 +130,31 @@ static int encode_packet(handshake_packet_t *packet, uint64_t session_id, struct sockaddr *server_addr, struct sockaddr *client_addr); /** Routines for turning a handshake_packet_t into an encrypted buffer **/ -static int encrypt_packet(handshake_protocol_t *hdata, handshake_packet_t *packet, +static int encrypt_packet(handshake_protocol_t *hdata, void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size); -static int none_encrypt_packet(handshake_packet_t *packet, +static int none_encrypt_packet(void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size); -static int munge_encrypt_packet(handshake_packet_t *packet, +static int munge_encrypt_packet(void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size); static int filekey_encrypt_packet(char *key_filepath, int key_length_bytes, - handshake_packet_t *packet, + void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size); static int key_encrypt_packet(unsigned char *key, int key_length_bytes, - handshake_packet_t *packet, + void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size); /** Routines for decrypting and validating a handshake_packet_t **/ -static int decrypt_packet(handshake_protocol_t *hdata, handshake_packet_t *expected_packet, +static int decrypt_packet(handshake_protocol_t *hdata, void *recvd_packet, size_t packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size); -static int none_decrypt_packet(handshake_packet_t *expected_packet, +static int none_decrypt_packet(void *recvd_packet, size_t packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size); -static int munge_decrypt_packet(handshake_packet_t *expected_packet, +static int munge_decrypt_packet(void *recvd_packet, size_t packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size); static int key_decrypt_packet(unsigned char *key, unsigned int key_len, - handshake_packet_t *expected_packet, + void *recvd_packet, size_t packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size); -static int compare_packets(handshake_packet_t *expected_packet, - handshake_packet_t *recvd_packet); +static int compare_packets(handshake_packet_t *expected_packet, handshake_packet_t *recvd_packet); +static int compare_random_number_packets(random_number_packet_t *expected_random_number_packet, random_number_packet_t *recvd_random_number_packet); static int handshake_wrapper(int sockfd, handshake_protocol_t *hdata, uint64_t session_id, @@ -269,9 +283,10 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess int is_server) { int result, return_result, peer_result, socket_error = 0; - handshake_packet_t packet, expected_packet; - unsigned char *packet_buffer = NULL, *recvd_packet_buffer = NULL; - size_t packet_buffer_size = 0, recvd_packet_buffer_size = 0; + handshake_packet_t packet, expected_packet, recvd_packet; + random_number_packet_t random_number_packet, expected_random_number_packet, recvd_random_number_packet; + unsigned char *packet_buffer = NULL, *recvd_long_packet_buffer = NULL, *recvd_packet_buffer = NULL; + size_t packet_buffer_size = 0, recvd_long_packet_buffer_size = 0, recvd_packet_buffer_size = 0; /** * Exchange a public signature as a handshake to make sure @@ -286,8 +301,9 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess } /** - * Encode socket names, session, gid, and uid into a handshake_packet_t + * Encode socket names, session, random number, gid, and uid into a handshake_packet_t **/ + debug_printf("Creating outgoing packet for handshake\n"); result = encode_packet(&packet, session_id, &saved_conninfo->server_addr, &saved_conninfo->client_addr); if (result < 0) { @@ -297,15 +313,15 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess } packet.signature = is_server ? SERVER_TO_CLIENT_SIG : CLIENT_TO_SERVER_SIG; debug_printf("Encoded packet: server_port = %d, client_port = %d, " - "uid = %d, gid = %d, session_id = %llu, signature = %lx\n", - (int) packet.server_port, (int) packet.client_port, (int) packet.uid, (int) packet.gid, - (unsigned long long) packet.session_id, (unsigned long) packet.signature); + "uid = %d, gid = %d, session_id = %llu, signature = %lx, random_number =%lx\n", + (int) packet.server_port, (int) packet.client_port, (int) packet.uid, (int) packet.gid, + (unsigned long long) packet.random_number.session_id, (unsigned long) packet.signature, (unsigned long) packet.random_number.random_number); /** * Encrypt/Sign the handshake_packet_t, producing a packet_buffer **/ debug_printf("Encrypting outgoing packet\n"); - result = encrypt_packet(hdata, &packet, &packet_buffer, &packet_buffer_size); + result = encrypt_packet(hdata, &packet, sizeof(handshake_packet_t), &packet_buffer, &packet_buffer_size); if (result < 0) { debug_printf("Error in server encrypting outgoing packet"); return_result = result; @@ -317,6 +333,8 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess * Send the packet_buffer on the network **/ result = send_packet(sockfd, packet_buffer, packet_buffer_size); + free(packet_buffer); + packet_buffer = NULL; if (result < 0) { debug_printf("Problem sending packet on network: %s\n", strerror(errno)); return_result = result; @@ -327,7 +345,7 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess /** * Recieve a packet_buffer on the network **/ - result = recv_packet(sockfd, &recvd_packet_buffer, &recvd_packet_buffer_size); + result = recv_packet(sockfd, &recvd_long_packet_buffer, &recvd_long_packet_buffer_size); if (result < 0) { debug_printf("Problem receiving packet\n"); return_result = result; @@ -336,7 +354,18 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess } /** - * Produce an expected handshake_packet_t + * Decrypt their initial handshake packet, containing their random number + **/ + debug_printf("Decrypting their initial handshake packet, containing their random number\n"); + result = decrypt_packet(hdata, &recvd_packet, sizeof(handshake_packet_t), recvd_long_packet_buffer, recvd_long_packet_buffer_size); + if (result < 0) { + debug_printf("Error decrypting and checking received packet\n"); + return_result = result; + goto done; + } + + /** + * Produce an expected handshake_packet_t, random_number_packet_t **/ debug_printf("Creating an expected packet\n"); result = encode_packet(&expected_packet, session_id, &saved_conninfo->server_addr, &saved_conninfo->client_addr); @@ -345,22 +374,72 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess return_result = result; goto done; } - expected_packet.signature = is_server ? CLIENT_TO_SERVER_SIG : SERVER_TO_CLIENT_SIG; - - /** - * Decrypt the packet recieved on the network and compare - * it to the expected handshake_packet_t - **/ - debug_printf("Decrypting and checking packet\n"); - result = decrypt_packet(hdata, &expected_packet, recvd_packet_buffer, recvd_packet_buffer_size); + + expected_packet.signature = is_server ? CLIENT_TO_SERVER_SIG : SERVER_TO_CLIENT_SIG; + + debug_printf("Checking random number packets\n"); + + debug_printf("Checking initial packet\n"); + result = compare_packets( &expected_packet, &recvd_packet); + if (result < 0) { + debug_printf("Error checking initial packet\n"); + return_result = result; + goto done; + } + + //encrypt random number packet that they are expecting + random_number_packet.random_number = recvd_packet.random_number; + random_number_packet.signature = packet.signature; + result = encrypt_packet(hdata, &random_number_packet, sizeof(random_number_packet_t), &packet_buffer, &packet_buffer_size); + if (result < 0) { + debug_printf("Error in server encrypting outgoing random_number_packet"); + return_result = result; + goto done; + } + + //send their expected packet random number + debug_printf("Sending their expected random number packet on network\n"); + result = send_packet(sockfd, packet_buffer, packet_buffer_size); + free(packet_buffer); + packet_buffer = NULL; + if (result < 0) { + debug_printf("Problem sending packet on network: %s\n", strerror(errno)); + return_result = result; + socket_error = 1; + return HSHAKE_DROP_CONNECTION; + } + + //get thier random number packet + debug_printf("Receiving their random number packet from network\n"); + result = recv_packet(sockfd, &recvd_packet_buffer, &recvd_packet_buffer_size); + if (result < 0) { + debug_printf("Problem receiving packet\n"); + return_result = result; + socket_error = 1; + goto done; + } + //decrypt their signed random number packet + debug_printf("Decrypting their random number packet\n"); + result = decrypt_packet(hdata, &recvd_random_number_packet, sizeof(random_number_packet_t), recvd_packet_buffer, recvd_packet_buffer_size); if (result < 0) { debug_printf("Error decrypting and checking received packet\n"); return_result = result; goto done; } - debug_printf("Successfully completed initial handshake\n"); + expected_random_number_packet.random_number = packet.random_number; + expected_random_number_packet.signature = is_server ? CLIENT_TO_SERVER_SIG : SERVER_TO_CLIENT_SIG; + debug_printf("Checking random number packets\n"); + result = compare_random_number_packets(&recvd_random_number_packet, &expected_random_number_packet); + if (result < 0) { + debug_printf("Error checking packet\n"); + return_result = result; + goto done; + } + + debug_printf("Successfully completed initial handshake\n"); + return_result = 0; done: @@ -381,7 +460,6 @@ static int handshake_main(int sockfd, handshake_protocol_t *hdata, uint64_t sess return_result = peer_result; } } - if (packet_buffer) free(packet_buffer); if (recvd_packet_buffer) @@ -410,13 +488,67 @@ static int encode_addr(struct sockaddr *addr, unsigned char *target_addr, uint16 return 0; } +int random_number(random_number_t *rand_num, uint64_t session_id) { + int local_errno; + int timeout = 100 * 30; //30 seconds + useconds_t ten_milliseconds = 10000; + ssize_t result, bytes_to_read, bytes_read; + char hostname_buf[HOSTNAME_MAX_LEN]; + + // Get hostname + if (gethostname(hostname_buf, HOSTNAME_MAX_LEN) != 0) { + local_errno = errno; + error_printf("Trouble getting hostname. Error: %s\n",strerror(local_errno)); + return -1; + } + hostname_buf[HOSTNAME_MAX_LEN-1] = '\0'; + + //random number file + bytes_to_read = sizeof(rand_num->random_number); + bytes_read = 0; + do { + result = getrandom(((unsigned char*) &rand_num->random_number) + bytes_read, + bytes_to_read - bytes_read, + GRND_NONBLOCK); + if (result<0) { + local_errno = errno; + if (local_errno == EAGAIN || local_errno == EINTR){ + if (--timeout <= 0) { + error_printf("Random number collection failed with repeated EAGAIN or EINTR\n"); + return -1; + } + usleep(ten_milliseconds); + continue; + } + else{ + error_printf("Random number collection failed. Error: %s\n", strerror(local_errno)); + return -1; + } + } + bytes_read += result; + } while (bytes_read < bytes_to_read); + + rand_num->session_id = session_id; + rand_num->counter = unique_number_counter; + strcpy(rand_num->hostname, hostname_buf); + unique_number_counter++; + + return 0; +} + + static int encode_packet(handshake_packet_t *packet, uint64_t session_id, struct sockaddr *server_addr, struct sockaddr *client_addr) -{ +{ int result; packet->uid = getuid(); packet->gid = getgid(); - packet->session_id = session_id; + + result = random_number(&packet->random_number, session_id); + if (result < 0) { + debug_printf("Error encoding random number\n"); + return result; + } result = encode_addr(server_addr, packet->server_addr, &packet->server_port); if (result < 0) { @@ -433,39 +565,39 @@ static int encode_packet(handshake_packet_t *packet, uint64_t session_id, return 0; } -static int encrypt_packet(handshake_protocol_t *hdata, handshake_packet_t *packet, +static int encrypt_packet(handshake_protocol_t *hdata, void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size) { switch (hdata->mechanism) { case hs_none: debug_printf("Server skipping encryption of packet\n"); - return none_encrypt_packet(packet, packet_buffer, packet_buffer_size); + return none_encrypt_packet(packet, packet_size, packet_buffer, packet_buffer_size); case hs_munge: debug_printf("Server encrypting packet with munge\n"); - return munge_encrypt_packet(packet, packet_buffer, packet_buffer_size); + return munge_encrypt_packet(packet, packet_size, packet_buffer, packet_buffer_size); case hs_key_in_file: debug_printf("Server encrypting packet with key of size %d from file %s\n", hdata->data.key_in_file.key_length_bytes, hdata->data.key_in_file.key_filepath); return filekey_encrypt_packet(hdata->data.key_in_file.key_filepath, hdata->data.key_in_file.key_length_bytes, - packet, packet_buffer, packet_buffer_size); + packet, packet_size, packet_buffer, packet_buffer_size); case hs_explicit_key: debug_printf("Server encrypting packet with provided key of size %d\n", hdata->data.explicit_key.key_length_bytes); return key_encrypt_packet(hdata->data.explicit_key.key, hdata->data.explicit_key.key_length_bytes, - packet, packet_buffer, packet_buffer_size); + packet, packet_size, packet_buffer, packet_buffer_size); } abort(); return HSHAKE_INTERNAL_ERROR; } -static int none_encrypt_packet(handshake_packet_t *packet, +static int none_encrypt_packet(void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size) { #if defined(ENABLE_NULL_ENCRYPTION) - *packet_buffer_size = sizeof(*packet); + *packet_buffer_size = packet_size; *packet_buffer = malloc(*packet_buffer_size); memcpy(*packet_buffer, packet, *packet_buffer_size); return 0; @@ -504,7 +636,7 @@ static int munge_create_context(munge_ctx_t *output_ctx) } #endif -static int munge_encrypt_packet(handshake_packet_t *packet, +static int munge_encrypt_packet(void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size) { #if defined(MUNGE) @@ -519,7 +651,7 @@ static int munge_encrypt_packet(handshake_packet_t *packet, goto done; } - result = munge_encode((char **) packet_buffer, ctx, packet, sizeof(*packet)); + result = munge_encode((char **) packet_buffer, ctx, packet, packet_size); if (result != EMUNGE_SUCCESS) { error_printf("Munge failed to encrypt packet with error: %s\n", munge_ctx_strerror(ctx)); return_result = HSHAKE_INTERNAL_ERROR; @@ -632,7 +764,7 @@ static int read_key(char *key_filepath, int key_length_bytes) } static int filekey_encrypt_packet(char *key_filepath, int key_length_bytes, - handshake_packet_t *packet, + void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size) { int result; @@ -643,7 +775,7 @@ static int filekey_encrypt_packet(char *key_filepath, int key_length_bytes, return result; } - result = key_encrypt_packet(saved_key, key_length_bytes, packet, + result = key_encrypt_packet(saved_key, key_length_bytes, packet, packet_size, packet_buffer, packet_buffer_size); if (result < 0) { debug_printf("Error encrypting packet under filekey_encrypt\n"); @@ -698,7 +830,7 @@ static int get_hash_of_buffer(unsigned char *buffer, size_t buffer_size, #endif static int key_encrypt_packet(unsigned char *key, int key_length_bytes, - handshake_packet_t *packet, + void *packet, size_t packet_size, unsigned char **packet_buffer, size_t *packet_buffer_size) { #if defined(GCRYPT) @@ -713,7 +845,7 @@ static int key_encrypt_packet(unsigned char *key, int key_length_bytes, initialized = 1; } - result = get_hash_of_buffer((unsigned char *) packet, sizeof(*packet), + result = get_hash_of_buffer((unsigned char *) packet, packet_size, key, key_length_bytes, &hash_result, &hash_result_size); if (result < 0) { @@ -722,11 +854,11 @@ static int key_encrypt_packet(unsigned char *key, int key_length_bytes, } debug_printf("Adding packet of size %lu and hash of size %u to buffer\n", - (unsigned long) sizeof(*packet), hash_result_size); - *packet_buffer_size = sizeof(*packet) + hash_result_size; + (unsigned long) packet_size, hash_result_size); + *packet_buffer_size = packet_size + hash_result_size; *packet_buffer = malloc(*packet_buffer_size); - memcpy(*packet_buffer, packet, sizeof(*packet)); - memcpy(*packet_buffer + sizeof(*packet), hash_result, hash_result_size); + memcpy(*packet_buffer, packet, packet_size); + memcpy(*packet_buffer + packet_size, hash_result, hash_result_size); free(hash_result); @@ -741,7 +873,6 @@ static int reliable_write(int fd, const void *buf, size_t size) { int result; size_t bytes_written = 0; - while (bytes_written < size) { result = write(fd, ((unsigned char *) buf) + bytes_written, size - bytes_written); if (result == -1 && errno == EINTR) @@ -802,60 +933,57 @@ static int reliable_read(int fd, void *buf, size_t size) return bytes_read; } -static int decrypt_packet(handshake_protocol_t *hdata, handshake_packet_t *expected_packet, +static int decrypt_packet(handshake_protocol_t *hdata, void *recvd_packet, size_t packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size) { switch (hdata->mechanism) { case hs_none: debug_printf("Checking packet with no encryption\n"); - return none_decrypt_packet(expected_packet, recvd_buffer, recvd_buffer_size); + return none_decrypt_packet(recvd_packet, packet_size, recvd_buffer, recvd_buffer_size); case hs_munge: debug_printf("Decrypting and checking packet with munge\n"); - return munge_decrypt_packet(expected_packet, recvd_buffer, recvd_buffer_size); + return munge_decrypt_packet(recvd_packet, packet_size, recvd_buffer, recvd_buffer_size); case hs_key_in_file: debug_printf("Decrypting packet with key from file\n"); assert(saved_key); return key_decrypt_packet(saved_key, saved_key_len, - expected_packet, recvd_buffer, recvd_buffer_size); + recvd_packet, packet_size, recvd_buffer, recvd_buffer_size); case hs_explicit_key: debug_printf("Decrypting packet with explicit key\n"); return key_decrypt_packet(hdata->data.explicit_key.key, hdata->data.explicit_key.key_length_bytes, - expected_packet, recvd_buffer, recvd_buffer_size); + recvd_packet, packet_size, recvd_buffer, recvd_buffer_size); } abort(); return HSHAKE_INTERNAL_ERROR; } -static int none_decrypt_packet(handshake_packet_t *expected_packet, +static int none_decrypt_packet(void *recvd_packet, size_t packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size) { #if defined(ENABLE_NULL_ENCRYPTION) - handshake_packet_t recvd_packet; - - if (recvd_buffer_size != sizeof(recvd_packet)) { + if (recvd_buffer_size != packet_size) { error_printf("Received buffer of size %lu, but expected size %lu\n", - (unsigned long) recvd_buffer_size, (unsigned long) sizeof(recvd_packet)); + (unsigned long) recvd_buffer_size, (unsigned long) packet_size); return HSHAKE_DROP_CONNECTION; } - memcpy(&recvd_packet, recvd_buffer, recvd_buffer_size); - return compare_packets(expected_packet, &recvd_packet); + memcpy(recvd_packet, recvd_buffer, recvd_buffer_size); + return 0; #else error_printf("Null encryption must be explicitly enabled\n"); return HSHAKE_INTERNAL_ERROR; #endif } -static int munge_decrypt_packet(handshake_packet_t *expected_packet, +static int munge_decrypt_packet(void *recvd_packet, size_t recvd_packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size) { #if defined(MUNGE) munge_err_t result; munge_ctx_t ctx = NULL; void *payload = NULL; - int payload_size, return_result, iresult; + int payload_size, return_result = 0, iresult; uid_t uid; gid_t gid; - handshake_packet_t *recvd_packet; iresult = munge_create_context(&ctx); if (iresult < 0) { @@ -863,7 +991,7 @@ static int munge_decrypt_packet(handshake_packet_t *expected_packet, return_result = iresult; goto done; } - + result = munge_decode((char *) recvd_buffer, ctx, &payload, &payload_size, &uid, &gid); switch (result) { case EMUNGE_SUCCESS: @@ -902,31 +1030,14 @@ static int munge_decrypt_packet(handshake_packet_t *expected_packet, security_error_printf("Unknown error return from munge: %s\n", munge_strerror(result)); return_result = HSHAKE_ABORT; goto done; - } - - if (payload_size != sizeof(*recvd_packet)) { - security_error_printf("Recieved munge packet with invalid payload size of %d\n", (int) payload_size); - return_result = HSHAKE_ABORT; - goto done; } - recvd_packet = (handshake_packet_t *) payload; - /* Munge provides a UID and GID. That should match the copy in the payload */ - if (recvd_packet->uid != uid) { - security_error_printf("Packet came from uid %d, but payload claimed uid %d\n", - (int) recvd_packet->uid, (int) uid); - return_result = HSHAKE_ABORT; - goto done; - } - if (recvd_packet->gid != gid) { - security_error_printf("Packet came from gid %d, but payload claimed gid %d\n", - (int) recvd_packet->gid, (int) gid); + if (payload_size != recvd_packet_size) { + security_error_printf("Recieved munge packet with invalid payload size of %d\n", (int) payload_size); return_result = HSHAKE_ABORT; goto done; } - - return_result = compare_packets(expected_packet, recvd_packet); - + memcpy(recvd_packet, payload, recvd_packet_size); done: if (payload) free(payload); @@ -934,7 +1045,6 @@ static int munge_decrypt_packet(handshake_packet_t *expected_packet, munge_ctx_destroy(ctx); return return_result; - #else error_printf("Handshake not compiled with munge support\n"); return HSHAKE_INTERNAL_ERROR; @@ -942,25 +1052,22 @@ static int munge_decrypt_packet(handshake_packet_t *expected_packet, } static int key_decrypt_packet(unsigned char *key, unsigned int key_len, - handshake_packet_t *expected_packet, + void *recvd_packet, size_t packet_size, unsigned char *recvd_buffer, size_t recvd_buffer_size) { #if defined(GCRYPT) - handshake_packet_t *recvd_packet; unsigned char *calcd_hash_val = NULL, *recvd_hash_val; int result, return_result, hash_val_size; int i; - if (recvd_buffer_size < sizeof(*expected_packet)) { + if (recvd_buffer_size < packet_size) { error_printf("Packet was too small. Size was %d, expected at least %d\n", - (int) recvd_buffer_size, (int) sizeof(*expected_packet)); + (int) recvd_buffer_size, (int) packet_size); return_result = HSHAKE_INTERNAL_ERROR; goto done; } - - recvd_packet = (handshake_packet_t *) recvd_buffer; - - result = get_hash_of_buffer((unsigned char *) recvd_packet, sizeof(*recvd_packet), + memcpy(recvd_packet, recvd_buffer, packet_size); + result = get_hash_of_buffer((unsigned char *) recvd_packet, packet_size, key, key_len, &calcd_hash_val, &hash_val_size); if (result < 0) { @@ -969,14 +1076,14 @@ static int key_decrypt_packet(unsigned char *key, unsigned int key_len, goto done; } - if (recvd_buffer_size != sizeof(*recvd_packet) + hash_val_size) { + if (recvd_buffer_size != packet_size + hash_val_size) { error_printf("Packet was too small. Size was %d, expected %d\n", - (int) recvd_buffer_size, (int) sizeof(*recvd_packet) + hash_val_size); + (int) recvd_buffer_size, (int) packet_size + hash_val_size); return_result = HSHAKE_INTERNAL_ERROR; goto done; } - recvd_hash_val = recvd_buffer + sizeof(*recvd_packet); + recvd_hash_val = recvd_buffer + packet_size; for (i = 0; i < hash_val_size; i++) { if (recvd_hash_val[i] != calcd_hash_val[i]) { security_error_printf("Hash signature of packet did not match expected value\n"); @@ -984,8 +1091,6 @@ static int key_decrypt_packet(unsigned char *key, unsigned int key_len, goto done; } } - - return_result = compare_packets(expected_packet, recvd_packet); done: if (calcd_hash_val) @@ -998,22 +1103,19 @@ static int key_decrypt_packet(unsigned char *key, unsigned int key_len, #endif } -static int compare_packets(handshake_packet_t *expected_packet, - handshake_packet_t *recvd_packet) +static int compare_packets(handshake_packet_t *expected_packet, handshake_packet_t *recvd_packet) { int i; - - if (expected_packet->session_id != recvd_packet->session_id) { + if (expected_packet->random_number.session_id != recvd_packet->random_number.session_id) { //If sessions don't match, expect that we've just recv a packet //from another instance of handshake running on the same node. //Drop connection - error_printf("Received mismatching session IDs. Expected %lu, got %lu\n", - expected_packet->session_id, recvd_packet->session_id); + debug_printf("Received mismatching session IDs in initial packed. Expected %lu, got %lu. Two different spindle sessions probably tried to connect.\n", + expected_packet->random_number.session_id, recvd_packet->random_number.session_id); return HSHAKE_DROP_CONNECTION; } - if (expected_packet->signature != recvd_packet->signature) { - security_error_printf("Received handshake with malformed signature. Expected %x, got %x\n", + security_error_printf("Received initial packet with malformed signature. Packet expected %x, got %x\n", expected_packet->signature, recvd_packet->signature); return HSHAKE_ABORT; } @@ -1029,7 +1131,7 @@ static int compare_packets(handshake_packet_t *expected_packet, (int) expected_packet->client_port, (int) recvd_packet->client_port); return HSHAKE_ABORT; } - + if (expected_packet->uid != recvd_packet->uid) { security_error_printf("Received handshake from another uid. Expected %d, got %d\n", (int) expected_packet->uid, (int) recvd_packet->uid); @@ -1060,6 +1162,31 @@ static int compare_packets(handshake_packet_t *expected_packet, return 0; } +static int compare_random_number_packets(random_number_packet_t *expected_random_number_packet, random_number_packet_t *recvd_random_number_packet) +{ + + if (expected_random_number_packet->signature != recvd_random_number_packet->signature) { + security_error_printf("Received random number packet with malformed signature. Packet expected %x, got %x\n", + expected_random_number_packet->signature, recvd_random_number_packet->signature); + return HSHAKE_ABORT; + } + + const random_number_t *a = &expected_random_number_packet->random_number; + const random_number_t *b = &recvd_random_number_packet->random_number; + if (a->random_number != b->random_number || a->session_id != b->session_id || a->counter != b->counter || strcmp(a->hostname, b->hostname) != 0){ + security_error_printf("Received handshake from another random number. Expected random number %ld, got %ld, session id: %lu got %lu, counter: %d got %d, hostname: %s got %s\n", + a->random_number, b->random_number, + a->session_id, b->session_id, + (int) a->counter, (int) b->counter, + + a->hostname, b->hostname); + return HSHAKE_ABORT; + } + + debug_printf("Packets compared equal.\n"); + return 0; +} + static int share_result(int fd, int handshake_result) { int32_t result_to_send, peer_result;