summaryrefslogtreecommitdiffstats
path: root/src/handshake.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/handshake.c')
-rw-r--r--src/handshake.c62
1 files changed, 47 insertions, 15 deletions
diff --git a/src/handshake.c b/src/handshake.c
index 9fe62cf..97c47cb 100644
--- a/src/handshake.c
+++ b/src/handshake.c
@@ -58,10 +58,42 @@ static const char *const RECORD_TYPES[RECORD_MAX] = {
/** Reads a TLV record as an 8bit integer */
-#define AS_UINT8(ptr) (*(uint8_t *)(ptr).data)
+static inline uint8_t as_uint8(const fastd_handshake_record_t *record) {
+ return record->data[0];
+}
+
+/** Reads a TLV record as a 16bit integer (little endian) */
+static inline uint16_t as_uint16(const fastd_handshake_record_t *record) {
+ return as_uint8(record) | (uint16_t)record->data[1] << 8;
+}
+
+/** Reads a TLV record as a 24bit integer (little endian) */
+static inline uint32_t as_uint24(const fastd_handshake_record_t *record) {
+ return as_uint16(record) | (uint32_t)record->data[2] << 16;
+}
-/** Reads a TLV record as a 16bit integer (big endian) */
-#define AS_UINT16(ptr) ((*(uint8_t *)(ptr).data) + (*((uint8_t *)(ptr).data+1) << 8))
+/** Reads a TLV record as a 32bit integer (little endian) */
+static inline uint32_t as_uint32(const fastd_handshake_record_t *record) {
+ return as_uint24(record) | (uint32_t)record->data[3] << 24;
+}
+
+/** Reads a TLV record as a variable-length integer (little endian) */
+static inline uint32_t as_uint(const fastd_handshake_record_t *record) {
+ switch(record->length) {
+ case 0:
+ return 0;
+ case 1:
+ return as_uint8(record);
+ case 2:
+ return as_uint16(record);
+ case 3:
+ return as_uint24(record);
+ case 4:
+ return as_uint32(record);
+ default:
+ return 0xffffffffULL;
+ }
+}
/** Generates a zero-separated list of supported methods */
@@ -167,7 +199,7 @@ fastd_buffer_t fastd_handshake_new_reply(uint8_t type, const fastd_method_info_t
}
/** Prints the error corresponding to the given reply code and error detail */
-static void print_error(const char *prefix, const fastd_peer_address_t *remote_addr, uint8_t reply_code, uint8_t error_detail) {
+static void print_error(const char *prefix, const fastd_peer_address_t *remote_addr, uint8_t reply_code, uint16_t error_detail) {
const char *error_field_str;
if (error_detail >= RECORD_MAX)
@@ -210,7 +242,7 @@ static void print_error(const char *prefix, const fastd_peer_address_t *remote_a
}
/** Sends an error reply to a peer */
-static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake, uint8_t reply_code, uint8_t error_detail) {
+static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake, uint8_t reply_code, uint16_t error_detail) {
print_error("sending", remote_addr, reply_code, error_detail);
fastd_buffer_t buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */);
@@ -221,7 +253,7 @@ static void send_error(fastd_socket_t *sock, const fastd_peer_address_t *local_a
fastd_handshake_add_uint8(&buffer, RECORD_HANDSHAKE_TYPE, handshake->type+1);
fastd_handshake_add_uint8(&buffer, RECORD_REPLY_CODE, reply_code);
- fastd_handshake_add_uint8(&buffer, RECORD_ERROR_DETAIL, error_detail);
+ fastd_handshake_add_uint(&buffer, RECORD_ERROR_DETAIL, error_detail);
fastd_send_handshake(sock, local_addr, remote_addr, peer, buffer);
}
@@ -271,11 +303,11 @@ static inline fastd_handshake_t parse_tlvs(const fastd_buffer_t *buffer) {
/** Prints the error found in a received handshake */
static inline void print_error_reply(const fastd_peer_address_t *remote_addr, const fastd_handshake_t *handshake) {
- uint8_t reply_code = AS_UINT8(handshake->records[RECORD_REPLY_CODE]);
- uint8_t error_detail = RECORD_MAX;
+ uint8_t reply_code = as_uint8(&handshake->records[RECORD_REPLY_CODE]);
+ uint16_t error_detail = RECORD_MAX;
- if (handshake->records[RECORD_ERROR_DETAIL].length == 1)
- error_detail = AS_UINT8(handshake->records[RECORD_ERROR_DETAIL]);
+ if (handshake->records[RECORD_ERROR_DETAIL].length == 1 || handshake->records[RECORD_ERROR_DETAIL].length == 2)
+ error_detail = as_uint(&handshake->records[RECORD_ERROR_DETAIL]);
print_error("received", remote_addr, reply_code, error_detail);
}
@@ -290,7 +322,7 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_
}
if (handshake->records[RECORD_MODE].data) {
- if (handshake->records[RECORD_MODE].length != 1 || AS_UINT8(handshake->records[RECORD_MODE]) != conf.mode) {
+ if (handshake->records[RECORD_MODE].length != 1 || as_uint8(&handshake->records[RECORD_MODE]) != conf.mode) {
send_error(sock, local_addr, remote_addr, peer, handshake, REPLY_UNACCEPTABLE_VALUE, RECORD_MODE);
return false;
}
@@ -298,9 +330,9 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_
if (!conf.secure_handshakes || handshake->type > 1) {
if (handshake->records[RECORD_MTU].length == 2) {
- if (AS_UINT16(handshake->records[RECORD_MTU]) != conf.mtu) {
+ if (as_uint16(&handshake->records[RECORD_MTU]) != conf.mtu) {
pr_warn("MTU configuration differs with peer %I: local MTU is %u, remote MTU is %u",
- remote_addr, conf.mtu, AS_UINT16(handshake->records[RECORD_MTU]));
+ remote_addr, conf.mtu, as_uint16(&handshake->records[RECORD_MTU]));
}
}
}
@@ -311,7 +343,7 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_
return false;
}
- if (AS_UINT8(handshake->records[RECORD_REPLY_CODE]) != REPLY_SUCCESS) {
+ if (as_uint8(&handshake->records[RECORD_REPLY_CODE]) != REPLY_SUCCESS) {
print_error_reply(remote_addr, handshake);
return false;
}
@@ -372,7 +404,7 @@ void fastd_handshake_handle(fastd_socket_t *sock, const fastd_peer_address_t *lo
goto end_free;
}
- handshake.type = AS_UINT8(handshake.records[RECORD_HANDSHAKE_TYPE]);
+ handshake.type = as_uint8(&handshake.records[RECORD_HANDSHAKE_TYPE]);
if (!check_records(sock, local_addr, remote_addr, peer, &handshake))
goto end_free;