diff options
Diffstat (limited to 'src/handshake.c')
-rw-r--r-- | src/handshake.c | 46 |
1 files changed, 25 insertions, 21 deletions
diff --git a/src/handshake.c b/src/handshake.c index 726c117..000d685 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -176,7 +176,7 @@ static fastd_string_stack_t * parse_string_list(const uint8_t *data, size_t len) } /** Allocates and initializes a new handshake packet */ -static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, const fastd_method_info_t *method, const fastd_string_stack_t *methods, size_t tail_space) { +static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, uint16_t mtu, const fastd_method_info_t *method, const fastd_string_stack_t *methods, size_t tail_space) { size_t version_len = strlen(FASTD_VERSION); size_t protocol_len = strlen(conf.protocol->name); size_t method_len = method ? strlen(method->name) : 0; @@ -190,7 +190,7 @@ static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, fastd_handshake_buffer_t buffer = { .buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 1, 3*5 + /* handshake type, mode, reply code */ - 6 + /* MTU */ + (mtu ? 6 : 0) + /* MTU */ 4+version_len + /* version name */ 4+protocol_len + /* protocol name */ 4+method_len + /* method name */ @@ -204,7 +204,9 @@ static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, fastd_handshake_add_uint8(&buffer, RECORD_HANDSHAKE_TYPE, type); fastd_handshake_add_uint8(&buffer, RECORD_MODE, get_mode_id()); - fastd_handshake_add_uint16_endian(&buffer, RECORD_MTU, conf.mtu); + + if (mtu) + fastd_handshake_add_uint16_endian(&buffer, RECORD_MTU, mtu); fastd_handshake_add(&buffer, RECORD_VERSION_NAME, version_len, FASTD_VERSION); fastd_handshake_add(&buffer, RECORD_PROTOCOL_NAME, protocol_len, conf.protocol->name); @@ -222,18 +224,18 @@ static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, /** Allocates and initializes a new initial handshake packet */ fastd_handshake_buffer_t fastd_handshake_new_init(size_t tail_space) { - return new_handshake(1, true, NULL, conf.secure_handshakes ? NULL : conf.peer_group->methods, tail_space); + return new_handshake(1, true, 0, NULL, conf.secure_handshakes ? NULL : conf.peer_group->methods, tail_space); } /** Allocates and initializes a new reply handshake packet */ -fastd_handshake_buffer_t fastd_handshake_new_reply(uint8_t type, bool little_endian, const fastd_method_info_t *method, const fastd_string_stack_t *methods, size_t tail_space) { - fastd_handshake_buffer_t buffer = new_handshake(type, little_endian, method, methods, tail_space); +fastd_handshake_buffer_t fastd_handshake_new_reply(uint8_t type, bool little_endian, uint16_t mtu, const fastd_method_info_t *method, const fastd_string_stack_t *methods, size_t tail_space) { + fastd_handshake_buffer_t buffer = new_handshake(type, little_endian, mtu, method, methods, tail_space); fastd_handshake_add_uint8(&buffer, RECORD_REPLY_CODE, 0); return buffer; } /** Prints the error corresponding to the given reply code and error detail */ -static void print_error(const char *prefix, const fastd_peer_address_t *remote_addr, uint8_t reply_code, uint16_t error_detail) { +static void print_error(const char *prefix, const fastd_peer_t *peer, const fastd_peer_address_t *remote_addr, uint8_t reply_code, uint16_t error_detail) { const char *error_field_str; if (error_detail >= RECORD_MAX) @@ -260,7 +262,7 @@ static void print_error(const char *prefix, const fastd_peer_address_t *remote_a break; case RECORD_MTU: - pr_warn("Handshake with %I failed: %s error: MTU configuration differs with peer (local MTU is %u)", remote_addr, prefix, conf.mtu); + pr_warn("Handshake with %I failed: %s error: MTU configuration differs with peer (local MTU is %u)", remote_addr, prefix, fastd_peer_get_mtu(peer)); break; case RECORD_METHOD_NAME: @@ -281,7 +283,7 @@ static void print_error(const char *prefix, const fastd_peer_address_t *remote_a /** Sends an error reply to a peer */ void fastd_handshake_send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake, uint8_t reply_code, uint16_t error_detail) { - print_error("sending", remote_addr, reply_code, error_detail); + print_error("sending", peer, remote_addr, reply_code, error_detail); fastd_handshake_buffer_t buffer = { .buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */), @@ -355,14 +357,14 @@ static inline fastd_handshake_t parse_tlvs(const fastd_buffer_t *buffer) { } /** Prints the error found in a received handshake */ -static inline void print_error_reply(const fastd_peer_address_t *remote_addr, const fastd_handshake_t *handshake) { +static inline void print_error_reply(const fastd_peer_t *peer, const fastd_peer_address_t *remote_addr, const fastd_handshake_t *handshake) { uint8_t reply_code = as_uint8(&handshake->records[RECORD_REPLY_CODE]); uint16_t error_detail = RECORD_MAX; if (handshake->records[RECORD_ERROR_DETAIL].length == 1 || handshake->records[RECORD_ERROR_DETAIL].length == 2) error_detail = as_uint(&handshake->records[RECORD_ERROR_DETAIL]); - print_error("received", remote_addr, reply_code, error_detail); + print_error("received", peer, remote_addr, reply_code, error_detail); } /** Does some basic validity checks on a received handshake */ @@ -381,15 +383,6 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_ } } - if (!conf.secure_handshakes || handshake->type > 1) { - if (handshake->records[RECORD_MTU].length == 2) { - if (as_uint16_endian(&handshake->records[RECORD_MTU], handshake->little_endian) != conf.mtu) { - fastd_handshake_send_error(sock, local_addr, remote_addr, peer, handshake, REPLY_UNACCEPTABLE_VALUE, RECORD_MTU); - return false; - } - } - } - if (handshake->type > 1) { if (handshake->records[RECORD_REPLY_CODE].length != 1) { pr_warn("received handshake reply without reply code from %I", remote_addr); @@ -397,7 +390,18 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_ } if (as_uint8(&handshake->records[RECORD_REPLY_CODE]) != REPLY_SUCCESS) { - print_error_reply(remote_addr, handshake); + print_error_reply(peer, remote_addr, handshake); + return false; + } + } + + return true; +} + +bool fastd_handshake_check_mtu(fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake) { + if (handshake->records[RECORD_MTU].length == 2) { + if (as_uint16_endian(&handshake->records[RECORD_MTU], handshake->little_endian) != fastd_peer_get_mtu(peer)) { + fastd_handshake_send_error(sock, local_addr, remote_addr, peer, handshake, REPLY_UNACCEPTABLE_VALUE, RECORD_MTU); return false; } } |