diff options
Diffstat (limited to 'src/handshake.c')
-rw-r--r-- | src/handshake.c | 68 |
1 files changed, 44 insertions, 24 deletions
diff --git a/src/handshake.c b/src/handshake.c index a55844f..7dd5b44 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -78,8 +78,11 @@ static inline uint32_t as_uint32(const fastd_handshake_record_t *record) { } /** Reads a TLV record as a 16bit integer (little endian) */ -static inline uint16_t as_uint16_le(const fastd_handshake_record_t *record) { - return as_uint8(record) | (uint16_t)record->data[1] << 8; +static inline uint16_t as_uint16_endian(const fastd_handshake_record_t *record, bool little_endian) { + if (little_endian) + return as_uint8(record) | (uint16_t)record->data[1] << 8; + else + return as_uint16(record); } /** Reads a TLV record as a variable-length integer (little endian) */ @@ -157,7 +160,7 @@ static fastd_string_stack_t * parse_string_list(const uint8_t *data, size_t len) } /** Allocates and initializes a new handshake packet */ -static fastd_buffer_t new_handshake(uint8_t type, const fastd_method_info_t *method, bool with_method_list, size_t tail_space) { +static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, const fastd_method_info_t *method, bool with_method_list, size_t tail_space) { size_t version_len = strlen(FASTD_VERSION); size_t protocol_len = strlen(conf.protocol->name); size_t method_len = method ? strlen(method->name) : 0; @@ -168,22 +171,24 @@ static fastd_buffer_t new_handshake(uint8_t type, const fastd_method_info_t *met if (with_method_list) method_list = create_method_list(&method_list_len); - fastd_buffer_t buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 1, - 3*5 + /* handshake type, mode, reply code */ - 6 + /* MTU */ - 4+version_len + /* version name */ - 4+protocol_len + /* protocol name */ - 4+method_len + /* method name */ - 4+method_list_len + /* supported method name list */ - tail_space); - fastd_handshake_packet_t *packet = buffer.data; + fastd_handshake_buffer_t buffer = { + .buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 1, + 3*5 + /* handshake type, mode, reply code */ + 6 + /* MTU */ + 4+version_len + /* version name */ + 4+protocol_len + /* protocol name */ + 4+method_len + /* method name */ + 4+method_list_len + /* supported method name list */ + tail_space), + .little_endian = little_endian}; + fastd_handshake_packet_t *packet = buffer.buffer.data; packet->rsv = 0; packet->tlv_len = 0; fastd_handshake_add_uint8(&buffer, RECORD_HANDSHAKE_TYPE, type); fastd_handshake_add_uint8(&buffer, RECORD_MODE, conf.mode); - fastd_handshake_add_uint16_le(&buffer, RECORD_MTU, conf.mtu); + fastd_handshake_add_uint16_endian(&buffer, RECORD_MTU, conf.mtu); fastd_handshake_add(&buffer, RECORD_VERSION_NAME, version_len, FASTD_VERSION); fastd_handshake_add(&buffer, RECORD_PROTOCOL_NAME, protocol_len, conf.protocol->name); @@ -200,13 +205,13 @@ static fastd_buffer_t new_handshake(uint8_t type, const fastd_method_info_t *met } /** Allocates and initializes a new initial handshake packet */ -fastd_buffer_t fastd_handshake_new_init(size_t tail_space) { - return new_handshake(1, NULL, !conf.secure_handshakes, tail_space); +fastd_handshake_buffer_t fastd_handshake_new_init(size_t tail_space) { + return new_handshake(1, true, NULL, !conf.secure_handshakes, tail_space); } /** Allocates and initializes a new reply handshake packet */ -fastd_buffer_t fastd_handshake_new_reply(uint8_t type, const fastd_method_info_t *method, bool with_method_list, size_t tail_space) { - fastd_buffer_t buffer = new_handshake(type, method, with_method_list, tail_space); +fastd_handshake_buffer_t fastd_handshake_new_reply(uint8_t type, bool little_endian, const fastd_method_info_t *method, bool with_method_list, size_t tail_space) { + fastd_handshake_buffer_t buffer = new_handshake(type, little_endian, method, with_method_list, tail_space); fastd_handshake_add_uint8(&buffer, RECORD_REPLY_CODE, 0); return buffer; } @@ -258,8 +263,11 @@ static void print_error(const char *prefix, const fastd_peer_address_t *remote_a static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake, uint8_t reply_code, uint16_t error_detail) { print_error("sending", remote_addr, reply_code, error_detail); - fastd_buffer_t buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */); - fastd_handshake_packet_t *reply = buffer.data; + fastd_handshake_buffer_t buffer = { + .buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */), + .little_endian = handshake->little_endian + }; + fastd_handshake_packet_t *reply = buffer.buffer.data; reply->rsv = 0; reply->tlv_len = 0; @@ -268,7 +276,7 @@ static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_a fastd_handshake_add_uint8(&buffer, RECORD_REPLY_CODE, reply_code); fastd_handshake_add_uint(&buffer, RECORD_ERROR_DETAIL, error_detail); - fastd_send_handshake(sock, local_addr, remote_addr, peer, buffer); + fastd_send_handshake(sock, local_addr, remote_addr, peer, buffer.buffer); } /** Parses the TLV records of a handshake */ @@ -297,8 +305,20 @@ static inline fastd_handshake_t parse_tlvs(const fastd_buffer_t *buffer) { if (ptr+4 > end) break; - uint16_t type = ptr[0] + (ptr[1] << 8); - uint16_t len = ptr[2] + (ptr[3] << 8); + uint16_t type, len; + + if (!handshake.little_endian) { + type = ptr[1] + (ptr[0] << 8); + len = ptr[3] + (ptr[2] << 8); + + if (type > 0xff || (type == 0 && len > 0xff)) + handshake.little_endian = true; + } + + if (handshake.little_endian) { + type = ptr[0] + (ptr[1] << 8); + len = ptr[2] + (ptr[3] << 8); + } if (ptr+4+len > end) break; @@ -343,9 +363,9 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_ if (!conf.secure_handshakes || handshake->type > 1) { if (handshake->records[RECORD_MTU].length == 2) { - if (as_uint16_le(&handshake->records[RECORD_MTU]) != conf.mtu) { + if (as_uint16_endian(&handshake->records[RECORD_MTU], handshake->little_endian) != conf.mtu) { pr_warn("MTU configuration differs with peer %I: local MTU is %u, remote MTU is %u", - remote_addr, conf.mtu, as_uint16_le(&handshake->records[RECORD_MTU])); + remote_addr, conf.mtu, as_uint16_endian(&handshake->records[RECORD_MTU], handshake->little_endian)); } } } |