diff options
-rw-r--r-- | src/handshake.c | 34 | ||||
-rw-r--r-- | src/handshake.h | 50 |
2 files changed, 44 insertions, 40 deletions
diff --git a/src/handshake.c b/src/handshake.c index 47da0e4..82d47b8 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -98,10 +98,10 @@ fastd_buffer_t fastd_handshake_new_init(fastd_context_t *ctx, size_t tail_space) 4+method_list_len + /* supported method name list */ tail_space ); - fastd_handshake_packet_t *request = buffer.data; + fastd_handshake_packet_t *packet = buffer.data; - request->rsv1 = 0; - request->rsv2 = 0; + packet->rsv = 0; + packet->tlv_len = 0; fastd_handshake_add_uint8(ctx, &buffer, RECORD_HANDSHAKE_TYPE, 1); fastd_handshake_add_uint8(ctx, &buffer, RECORD_MODE, ctx->conf->mode); @@ -146,10 +146,10 @@ fastd_buffer_t fastd_handshake_new_reply(fastd_context_t *ctx, const fastd_hands extra_size + tail_space ); - fastd_handshake_packet_t *request = buffer.data; + fastd_handshake_packet_t *packet = buffer.data; - request->rsv1 = 0; - request->rsv2 = 0; + packet->rsv = 0; + packet->tlv_len = 0; fastd_handshake_add_uint8(ctx, &buffer, RECORD_HANDSHAKE_TYPE, AS_UINT8(handshake->records[RECORD_HANDSHAKE_TYPE])+1); fastd_handshake_add_uint8(ctx, &buffer, RECORD_REPLY_CODE, 0); @@ -186,15 +186,27 @@ void fastd_handshake_handle(fastd_context_t *ctx, fastd_socket_t *sock, const fa fastd_handshake_t handshake = { .buffer = buffer }; fastd_handshake_packet_t *packet = buffer.data; - uint8_t *ptr = packet->tlv_data; + size_t len = buffer.len - sizeof(fastd_handshake_packet_t); + if (packet->tlv_len) { + size_t tlv_len = ntohs(packet->tlv_len); + if (tlv_len > len) { + pr_warn(ctx, "received a short handshake from %I", remote_addr); + goto end_free; + } + + len = tlv_len; + } + + uint8_t *ptr = packet->tlv_data, *end = packet->tlv_data + len; + while (true) { - if (ptr+4 > (uint8_t*)buffer.data + buffer.len) + if (ptr+4 > end) break; uint16_t type = ptr[0] + (ptr[1] << 8); uint16_t len = ptr[2] + (ptr[3] << 8); - if (ptr+4+len > (uint8_t*)buffer.data + buffer.len) + if (ptr+4+len > end) break; if (type < RECORD_MAX) { @@ -292,8 +304,8 @@ void fastd_handshake_handle(fastd_context_t *ctx, fastd_socket_t *sock, const fa fastd_buffer_t reply_buffer = fastd_buffer_alloc(ctx, sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */); fastd_handshake_packet_t *reply = reply_buffer.data; - reply->rsv1 = 0; - reply->rsv2 = 0; + reply->rsv = 0; + reply->tlv_len = 0; fastd_handshake_add_uint8(ctx, &reply_buffer, RECORD_HANDSHAKE_TYPE, 2); fastd_handshake_add_uint8(ctx, &reply_buffer, RECORD_REPLY_CODE, reply_code); diff --git a/src/handshake.h b/src/handshake.h index 2fae1ad..e548be3 100644 --- a/src/handshake.h +++ b/src/handshake.h @@ -59,8 +59,8 @@ typedef enum fastd_reply_code { typedef struct __attribute__((__packed__)) fastd_handshake_packet { - uint8_t rsv1; - uint16_t rsv2; + uint8_t rsv; + uint16_t tlv_len; uint8_t tlv_data[]; } fastd_handshake_packet_t; @@ -82,50 +82,42 @@ fastd_buffer_t fastd_handshake_new_reply(fastd_context_t *ctx, const fastd_hands void fastd_handshake_handle(fastd_context_t *ctx, fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, fastd_buffer_t buffer); -static inline void fastd_handshake_add(fastd_context_t *ctx, fastd_buffer_t *buffer, fastd_handshake_record_type_t type, size_t len, const void *data) { - if ((uint8_t*)buffer->data + buffer->len + 4 + len > (uint8_t*)buffer->base + buffer->base_len) +static inline uint8_t* fastd_handshake_extend(fastd_context_t *ctx, fastd_buffer_t *buffer, fastd_handshake_record_type_t type, size_t len) { + uint8_t *dst = buffer->data + buffer->len; + + if (buffer->data + buffer->len + 4 + len > buffer->base + buffer->base_len) exit_bug(ctx, "not enough buffer allocated for handshake"); - uint8_t *dst = (uint8_t*)buffer->data + buffer->len; + buffer->len += 4 + len; + + fastd_handshake_packet_t *packet = buffer->data; + packet->tlv_len = htons(ntohs(packet->tlv_len) + 4 + len); dst[0] = type; dst[1] = type >> 8; dst[2] = len; dst[3] = len >> 8; - memcpy(dst+4, data, len); - buffer->len += 4 + len; + return dst+4; } -static inline void fastd_handshake_add_uint8(fastd_context_t *ctx, fastd_buffer_t *buffer, fastd_handshake_record_type_t type, uint8_t value) { - if ((uint8_t*)buffer->data + buffer->len + 5 > (uint8_t*)buffer->base + buffer->base_len) - exit_bug(ctx, "not enough buffer allocated for handshake"); +static inline void fastd_handshake_add(fastd_context_t *ctx, fastd_buffer_t *buffer, fastd_handshake_record_type_t type, size_t len, const void *data) { + uint8_t *dst = fastd_handshake_extend(ctx, buffer, type, len); - uint8_t *dst = (uint8_t*)buffer->data + buffer->len; + memcpy(dst, data, len); +} - dst[0] = type; - dst[1] = type >> 8; - dst[2] = 1; - dst[3] = 0; - dst[4] = value; +static inline void fastd_handshake_add_uint8(fastd_context_t *ctx, fastd_buffer_t *buffer, fastd_handshake_record_type_t type, uint8_t value) { + uint8_t *dst = fastd_handshake_extend(ctx, buffer, type, 1); - buffer->len += 5; + dst[0] = value; } static inline void fastd_handshake_add_uint16(fastd_context_t *ctx, fastd_buffer_t *buffer, fastd_handshake_record_type_t type, uint16_t value) { - if ((uint8_t*)buffer->data + buffer->len + 6 > (uint8_t*)buffer->base + buffer->base_len) - exit_bug(ctx, "not enough buffer allocated for handshake"); - - uint8_t *dst = (uint8_t*)buffer->data + buffer->len; - - dst[0] = type; - dst[1] = type >> 8; - dst[2] = 2; - dst[3] = 0; - dst[4] = value; - dst[5] = value >> 8; + uint8_t *dst = fastd_handshake_extend(ctx, buffer, type, 2); - buffer->len += 6; + dst[0] = value; + dst[1] = value >> 8; } |