diff options
Diffstat (limited to 'src/handshake.c')
-rw-r--r-- | src/handshake.c | 62 |
1 files changed, 47 insertions, 15 deletions
diff --git a/src/handshake.c b/src/handshake.c index 9fe62cf..97c47cb 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -58,10 +58,42 @@ static const char *const RECORD_TYPES[RECORD_MAX] = { /** Reads a TLV record as an 8bit integer */ -#define AS_UINT8(ptr) (*(uint8_t *)(ptr).data) +static inline uint8_t as_uint8(const fastd_handshake_record_t *record) { + return record->data[0]; +} + +/** Reads a TLV record as a 16bit integer (little endian) */ +static inline uint16_t as_uint16(const fastd_handshake_record_t *record) { + return as_uint8(record) | (uint16_t)record->data[1] << 8; +} + +/** Reads a TLV record as a 24bit integer (little endian) */ +static inline uint32_t as_uint24(const fastd_handshake_record_t *record) { + return as_uint16(record) | (uint32_t)record->data[2] << 16; +} -/** Reads a TLV record as a 16bit integer (big endian) */ -#define AS_UINT16(ptr) ((*(uint8_t *)(ptr).data) + (*((uint8_t *)(ptr).data+1) << 8)) +/** Reads a TLV record as a 32bit integer (little endian) */ +static inline uint32_t as_uint32(const fastd_handshake_record_t *record) { + return as_uint24(record) | (uint32_t)record->data[3] << 24; +} + +/** Reads a TLV record as a variable-length integer (little endian) */ +static inline uint32_t as_uint(const fastd_handshake_record_t *record) { + switch(record->length) { + case 0: + return 0; + case 1: + return as_uint8(record); + case 2: + return as_uint16(record); + case 3: + return as_uint24(record); + case 4: + return as_uint32(record); + default: + return 0xffffffffULL; + } +} /** Generates a zero-separated list of supported methods */ @@ -167,7 +199,7 @@ fastd_buffer_t fastd_handshake_new_reply(uint8_t type, const fastd_method_info_t } /** Prints the error corresponding to the given reply code and error detail */ -static void print_error(const char *prefix, const fastd_peer_address_t *remote_addr, uint8_t reply_code, uint8_t error_detail) { +static void print_error(const char *prefix, const fastd_peer_address_t *remote_addr, uint8_t reply_code, uint16_t error_detail) { const char *error_field_str; if (error_detail >= RECORD_MAX) @@ -210,7 +242,7 @@ static void print_error(const char *prefix, const fastd_peer_address_t *remote_a } /** Sends an error reply to a peer */ -static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake, uint8_t reply_code, uint8_t error_detail) { +static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake, uint8_t reply_code, uint16_t error_detail) { print_error("sending", remote_addr, reply_code, error_detail); fastd_buffer_t buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */); @@ -221,7 +253,7 @@ static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_a fastd_handshake_add_uint8(&buffer, RECORD_HANDSHAKE_TYPE, handshake->type+1); fastd_handshake_add_uint8(&buffer, RECORD_REPLY_CODE, reply_code); - fastd_handshake_add_uint8(&buffer, RECORD_ERROR_DETAIL, error_detail); + fastd_handshake_add_uint(&buffer, RECORD_ERROR_DETAIL, error_detail); fastd_send_handshake(sock, local_addr, remote_addr, peer, buffer); } @@ -271,11 +303,11 @@ static inline fastd_handshake_t parse_tlvs(const fastd_buffer_t *buffer) { /** Prints the error found in a received handshake */ static inline void print_error_reply(const fastd_peer_address_t *remote_addr, const fastd_handshake_t *handshake) { - uint8_t reply_code = AS_UINT8(handshake->records[RECORD_REPLY_CODE]); - uint8_t error_detail = RECORD_MAX; + uint8_t reply_code = as_uint8(&handshake->records[RECORD_REPLY_CODE]); + uint16_t error_detail = RECORD_MAX; - if (handshake->records[RECORD_ERROR_DETAIL].length == 1) - error_detail = AS_UINT8(handshake->records[RECORD_ERROR_DETAIL]); + if (handshake->records[RECORD_ERROR_DETAIL].length == 1 || handshake->records[RECORD_ERROR_DETAIL].length == 2) + error_detail = as_uint(&handshake->records[RECORD_ERROR_DETAIL]); print_error("received", remote_addr, reply_code, error_detail); } @@ -290,7 +322,7 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_ } if (handshake->records[RECORD_MODE].data) { - if (handshake->records[RECORD_MODE].length != 1 || AS_UINT8(handshake->records[RECORD_MODE]) != conf.mode) { + if (handshake->records[RECORD_MODE].length != 1 || as_uint8(&handshake->records[RECORD_MODE]) != conf.mode) { send_error(sock, local_addr, remote_addr, peer, handshake, REPLY_UNACCEPTABLE_VALUE, RECORD_MODE); return false; } @@ -298,9 +330,9 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_ if (!conf.secure_handshakes || handshake->type > 1) { if (handshake->records[RECORD_MTU].length == 2) { - if (AS_UINT16(handshake->records[RECORD_MTU]) != conf.mtu) { + if (as_uint16(&handshake->records[RECORD_MTU]) != conf.mtu) { pr_warn("MTU configuration differs with peer %I: local MTU is %u, remote MTU is %u", - remote_addr, conf.mtu, AS_UINT16(handshake->records[RECORD_MTU])); + remote_addr, conf.mtu, as_uint16(&handshake->records[RECORD_MTU])); } } } @@ -311,7 +343,7 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_ return false; } - if (AS_UINT8(handshake->records[RECORD_REPLY_CODE]) != REPLY_SUCCESS) { + if (as_uint8(&handshake->records[RECORD_REPLY_CODE]) != REPLY_SUCCESS) { print_error_reply(remote_addr, handshake); return false; } @@ -372,7 +404,7 @@ void fastd_handshake_handle(fastd_socket_t *sock, const fastd_peer_address_t *lo goto end_free; } - handshake.type = AS_UINT8(handshake.records[RECORD_HANDSHAKE_TYPE]); + handshake.type = as_uint8(&handshake.records[RECORD_HANDSHAKE_TYPE]); if (!check_records(sock, local_addr, remote_addr, peer, &handshake)) goto end_free; |