From 8cbd59792e7f03de927593994fb85466b7432d39 Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Sun, 20 Oct 2013 02:37:04 +0200 Subject: Refactor handshake code, prevent downgrade attacks --- src/handshake.c | 379 +++++++++++++++++++++++++++----------------------------- 1 file changed, 181 insertions(+), 198 deletions(-) (limited to 'src/handshake.c') diff --git a/src/handshake.c b/src/handshake.c index d0134b9..46a1357 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -81,123 +81,160 @@ static uint8_t* create_method_list(fastd_context_t *ctx, size_t *len) { return ret; } -fastd_buffer_t fastd_handshake_new_init(fastd_context_t *ctx, size_t tail_space) { +static inline bool string_equal(const char *str, const char *buf, size_t maxlen) { + if (strlen(str) != strnlen(buf, maxlen)) + return false; + + return !strncmp(str, buf, maxlen); +} + +static inline bool record_equal(const char *str, const fastd_handshake_record_t *record) { + return string_equal(str, (const char*)record->data, record->length); +} + +static const fastd_method_t* method_from_name(fastd_context_t *ctx, const char *name, size_t n) { + int i; + for (i = 0; i < MAX_METHODS; i++) { + if (!ctx->conf->methods[i]) + break; + + if (string_equal(ctx->conf->methods[i]->name, name, n)) + return ctx->conf->methods[i]; + } + + return NULL; +} + +static fastd_string_stack_t* parse_string_list(const uint8_t *data, size_t len) { + const uint8_t *end = data+len; + fastd_string_stack_t *ret = NULL; + + while (data < end) { + fastd_string_stack_t *part = fastd_string_stack_dupn((char*)data, end-data); + part->next = ret; + ret = part; + data += strlen(part->str) + 1; + } + + return ret; +} + +static fastd_buffer_t new_handshake(fastd_context_t *ctx, uint8_t type, const fastd_method_t *method, bool with_method_list, size_t tail_space) { size_t version_len = strlen(FASTD_VERSION); size_t protocol_len = strlen(ctx->conf->protocol->name); - size_t method_len = strlen(ctx->conf->method_default->name); + size_t method_len = method ? strlen(method->name) : 0; - size_t method_list_len; - uint8_t *method_list = create_method_list(ctx, &method_list_len); + size_t method_list_len = 0; + uint8_t *method_list = NULL; + + if (with_method_list) + method_list = create_method_list(ctx, &method_list_len); fastd_buffer_t buffer = fastd_buffer_alloc(ctx, sizeof(fastd_handshake_packet_t), 0, - 2*5 + /* handshake type, mode */ + 3*5 + /* handshake type, mode, reply code */ 6 + /* MTU */ 4+version_len + /* version name */ 4+protocol_len + /* protocol name */ 4+method_len + /* method name */ 4+method_list_len + /* supported method name list */ - tail_space - ); + tail_space); fastd_handshake_packet_t *packet = buffer.data; packet->rsv = 0; packet->tlv_len = 0; - fastd_handshake_add_uint8(ctx, &buffer, RECORD_HANDSHAKE_TYPE, 1); + fastd_handshake_add_uint8(ctx, &buffer, RECORD_HANDSHAKE_TYPE, type); fastd_handshake_add_uint8(ctx, &buffer, RECORD_MODE, ctx->conf->mode); fastd_handshake_add_uint16(ctx, &buffer, RECORD_MTU, ctx->conf->mtu); fastd_handshake_add(ctx, &buffer, RECORD_VERSION_NAME, version_len, FASTD_VERSION); fastd_handshake_add(ctx, &buffer, RECORD_PROTOCOL_NAME, protocol_len, ctx->conf->protocol->name); - fastd_handshake_add(ctx, &buffer, RECORD_METHOD_NAME, method_len, ctx->conf->method_default->name); - fastd_handshake_add(ctx, &buffer, RECORD_METHOD_LIST, method_list_len, method_list); - free(method_list); + if (method) + fastd_handshake_add(ctx, &buffer, RECORD_METHOD_NAME, method_len, method->name); + + if (with_method_list) { + fastd_handshake_add(ctx, &buffer, RECORD_METHOD_LIST, method_list_len, method_list); + free(method_list); + } return buffer; } -static const fastd_method_t* method_from_name(fastd_context_t *ctx, const char *name, size_t n) { - int i; - for (i = 0; i < MAX_METHODS; i++) { - if (!ctx->conf->methods[i]) - break; - - if (strncmp(name, ctx->conf->methods[i]->name, n) == 0) - return ctx->conf->methods[i]; - } +fastd_buffer_t fastd_handshake_new_init(fastd_context_t *ctx, size_t tail_space) { + if (ctx->conf->secure_handshakes) + return new_handshake(ctx, 1, NULL, false, tail_space); + else + return new_handshake(ctx, 1, ctx->conf->method_default, true, tail_space); +} - return NULL; +fastd_buffer_t fastd_handshake_new_reply(fastd_context_t *ctx, const fastd_handshake_t *handshake, const fastd_method_t *method, bool with_method_list, size_t tail_space) { + fastd_buffer_t buffer = new_handshake(ctx, handshake->type+1, method, with_method_list, tail_space); + fastd_handshake_add_uint8(ctx, &buffer, RECORD_REPLY_CODE, 0); + return buffer; } -fastd_buffer_t fastd_handshake_new_reply(fastd_context_t *ctx, const fastd_handshake_t *handshake, const fastd_method_t *method, size_t tail_space) { - bool first = (AS_UINT8(handshake->records[RECORD_HANDSHAKE_TYPE]) == 1); - size_t version_len = strlen(FASTD_VERSION); - size_t method_len = strlen(method->name); - size_t extra_size = 0; - - if (first) - extra_size = 6 + /* MTU */ - 4+version_len; /* version name */ - - fastd_buffer_t buffer = fastd_buffer_alloc(ctx, sizeof(fastd_handshake_packet_t), 1, - 2*5 + /* handshake type, reply code */ - 4+method_len + /* method name */ - extra_size + - tail_space - ); - fastd_handshake_packet_t *packet = buffer.data; +static void print_error(fastd_context_t *ctx, const char *prefix, const fastd_peer_address_t *remote_addr, uint8_t reply_code, uint8_t error_detail) { + const char *error_field_str; - packet->rsv = 0; - packet->tlv_len = 0; + if (error_detail >= RECORD_MAX) + error_field_str = ""; + else + error_field_str = RECORD_TYPES[error_detail]; - fastd_handshake_add_uint8(ctx, &buffer, RECORD_HANDSHAKE_TYPE, AS_UINT8(handshake->records[RECORD_HANDSHAKE_TYPE])+1); - fastd_handshake_add_uint8(ctx, &buffer, RECORD_REPLY_CODE, 0); - fastd_handshake_add(ctx, &buffer, RECORD_METHOD_NAME, method_len, method->name); + switch (reply_code) { + case REPLY_SUCCESS: + break; - if (first) { - fastd_handshake_add_uint16(ctx, &buffer, RECORD_MTU, ctx->conf->mtu); - fastd_handshake_add(ctx, &buffer, RECORD_VERSION_NAME, version_len, FASTD_VERSION); - } + case REPLY_MANDATORY_MISSING: + pr_warn(ctx, "Handshake with %I failed: %s error: mandatory field `%s' missing", remote_addr, prefix, error_field_str); + break; - return buffer; + case REPLY_UNACCEPTABLE_VALUE: + pr_warn(ctx, "Handshake with %I failed: %s error: unacceptable value for field `%s'", remote_addr, prefix, error_field_str); + break; + + default: + pr_warn(ctx, "Handshake with %I failed: %s error: unknown code %i", remote_addr, prefix, reply_code); + } } -static fastd_string_stack_t* parse_string_list(const uint8_t *data, size_t len) { - const uint8_t *end = data+len; - fastd_string_stack_t *ret = NULL; +static void send_error(fastd_context_t *ctx, fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake, uint8_t reply_code, uint8_t error_detail) { + print_error(ctx, "sending", remote_addr, reply_code, error_detail); - while (data < end) { - fastd_string_stack_t *part = fastd_string_stack_dupn((char*)data, end-data); - part->next = ret; - ret = part; - data += strlen(part->str) + 1; - } + fastd_buffer_t buffer = fastd_buffer_alloc(ctx, sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */); + fastd_handshake_packet_t *reply = buffer.data; - return ret; + reply->rsv = 0; + reply->tlv_len = 0; + + fastd_handshake_add_uint8(ctx, &buffer, RECORD_HANDSHAKE_TYPE, handshake->type+1); + fastd_handshake_add_uint8(ctx, &buffer, RECORD_REPLY_CODE, reply_code); + fastd_handshake_add_uint8(ctx, &buffer, RECORD_ERROR_DETAIL, error_detail); + + fastd_send_handshake(ctx, sock, local_addr, remote_addr, peer, buffer); } -void fastd_handshake_handle(fastd_context_t *ctx, fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, fastd_buffer_t buffer) { - if (buffer.len < sizeof(fastd_handshake_packet_t)) { - pr_warn(ctx, "received a short handshake from %I", remote_addr); - goto end_free; - } +static inline fastd_handshake_t parse_tlvs(const fastd_buffer_t *buffer) { + fastd_handshake_t handshake = {}; - fastd_handshake_packet_t *packet = buffer.data; + if (buffer->len < sizeof(fastd_handshake_packet_t)) + return handshake; - size_t len = buffer.len - sizeof(fastd_handshake_packet_t); + fastd_handshake_packet_t *packet = buffer->data; + + size_t len = buffer->len - sizeof(fastd_handshake_packet_t); if (packet->tlv_len) { - size_t tlv_len = fastd_handshake_tlv_len(&buffer); - if (tlv_len > len) { - pr_warn(ctx, "received a short handshake from %I", remote_addr); - goto end_free; - } + size_t tlv_len = fastd_handshake_tlv_len(buffer); + if (tlv_len > len) + return handshake; len = tlv_len; } uint8_t *ptr = packet->tlv_data, *end = packet->tlv_data + len; - fastd_handshake_t handshake = { .tlv_len = len, .tlv_data = packet->tlv_data }; + handshake.tlv_len = len; + handshake.tlv_data = packet->tlv_data; while (true) { if (ptr+4 > end) @@ -217,166 +254,112 @@ void fastd_handshake_handle(fastd_context_t *ctx, fastd_socket_t *sock, const fa ptr += 4+len; } - if (handshake.records[RECORD_HANDSHAKE_TYPE].length != 1) { - pr_debug(ctx, "received handshake without handshake type from %I", remote_addr); - goto end_free; - } + return handshake; +} - handshake.type = AS_UINT8(handshake.records[RECORD_HANDSHAKE_TYPE]); +static inline void print_error_reply(fastd_context_t *ctx, const fastd_peer_address_t *remote_addr, const fastd_handshake_t *handshake) { + uint8_t reply_code = AS_UINT8(handshake->records[RECORD_REPLY_CODE]); + uint8_t error_detail = RECORD_MAX; - if (handshake.records[RECORD_MTU].length == 2) { - if (AS_UINT16(handshake.records[RECORD_MTU]) != ctx->conf->mtu) { - pr_warn(ctx, "MTU configuration differs with peer %I: local MTU is %u, remote MTU is %u", - remote_addr, ctx->conf->mtu, AS_UINT16(handshake.records[RECORD_MTU])); - } - } + if (handshake->records[RECORD_ERROR_DETAIL].length == 1) + error_detail = AS_UINT8(handshake->records[RECORD_ERROR_DETAIL]); - if (handshake.type == 1) { - uint8_t reply_code = REPLY_SUCCESS; - uint8_t error_detail = 0; + print_error(ctx, "received", remote_addr, reply_code, error_detail); +} - if (!handshake.records[RECORD_MODE].data) { - reply_code = REPLY_MANDATORY_MISSING; - error_detail = RECORD_MODE; - goto send_reply; +static inline bool check_records(fastd_context_t *ctx, fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, const fastd_handshake_t *handshake) { + if (!ctx->conf->secure_handshakes || handshake->type > 1) { + if (handshake->records[RECORD_PROTOCOL_NAME].data) { + if (!record_equal(ctx->conf->protocol->name, &handshake->records[RECORD_PROTOCOL_NAME])) { + send_error(ctx, sock, local_addr, remote_addr, peer, handshake, REPLY_UNACCEPTABLE_VALUE, RECORD_PROTOCOL_NAME); + return false; + } } - if (handshake.records[RECORD_MODE].length != 1 || AS_UINT8(handshake.records[RECORD_MODE]) != ctx->conf->mode) { - reply_code = REPLY_UNACCEPTABLE_VALUE; - error_detail = RECORD_MODE; - goto send_reply; + if (handshake->records[RECORD_MODE].data) { + if (handshake->records[RECORD_MODE].length != 1 || AS_UINT8(handshake->records[RECORD_MODE]) != ctx->conf->mode) { + send_error(ctx, sock, local_addr, remote_addr, peer, handshake, REPLY_UNACCEPTABLE_VALUE, RECORD_MODE); + return false; + } } - if (!handshake.records[RECORD_PROTOCOL_NAME].data) { - reply_code = REPLY_MANDATORY_MISSING; - error_detail = RECORD_PROTOCOL_NAME; - goto send_reply; + if (handshake->records[RECORD_MTU].length == 2) { + if (AS_UINT16(handshake->records[RECORD_MTU]) != ctx->conf->mtu) { + pr_warn(ctx, "MTU configuration differs with peer %I: local MTU is %u, remote MTU is %u", + remote_addr, ctx->conf->mtu, AS_UINT16(handshake->records[RECORD_MTU])); + } } + } - if (handshake.records[RECORD_PROTOCOL_NAME].length != strlen(ctx->conf->protocol->name) - || strncmp((char*)handshake.records[RECORD_PROTOCOL_NAME].data, ctx->conf->protocol->name, handshake.records[RECORD_PROTOCOL_NAME].length)) { - reply_code = REPLY_UNACCEPTABLE_VALUE; - error_detail = RECORD_PROTOCOL_NAME; - goto send_reply; + if (handshake->type > 1) { + if (handshake->records[RECORD_REPLY_CODE].length != 1) { + pr_warn(ctx, "received handshake reply without reply code from %I", remote_addr); + return false; } - const fastd_method_t *method = NULL; + if (AS_UINT8(handshake->records[RECORD_REPLY_CODE]) != REPLY_SUCCESS) { + print_error_reply(ctx, remote_addr, handshake); + return false; + } + } - if (handshake.records[RECORD_METHOD_LIST].data && handshake.records[RECORD_METHOD_LIST].length) { - fastd_string_stack_t *method_list = parse_string_list(handshake.records[RECORD_METHOD_LIST].data, handshake.records[RECORD_METHOD_LIST].length); + return true; +} - fastd_string_stack_t *method_name = method_list; - while (method_name) { - const fastd_method_t *cur_method = method_from_name(ctx, method_name->str, SIZE_MAX); +static inline const fastd_method_t* get_method(fastd_context_t *ctx, const fastd_handshake_t *handshake) { + if (handshake->records[RECORD_METHOD_LIST].data && handshake->records[RECORD_METHOD_LIST].length) { + fastd_string_stack_t *method_list = parse_string_list(handshake->records[RECORD_METHOD_LIST].data, handshake->records[RECORD_METHOD_LIST].length); - if (cur_method) - method = cur_method; + const fastd_method_t *method; + fastd_string_stack_t *method_name = method_list; - method_name = method_name->next; - } + while (method_name) { + const fastd_method_t *cur_method = method_from_name(ctx, method_name->str, SIZE_MAX); - fastd_string_stack_free(method_list); + if (cur_method) + method = cur_method; - if (!method) { - reply_code = REPLY_UNACCEPTABLE_VALUE; - error_detail = RECORD_METHOD_LIST; - goto send_reply; - } + method_name = method_name->next; } - else { - if (!handshake.records[RECORD_METHOD_NAME].data) { - reply_code = REPLY_MANDATORY_MISSING; - error_detail = RECORD_METHOD_NAME; - goto send_reply; - } - if (handshake.records[RECORD_METHOD_NAME].length != strlen(ctx->conf->method_default->name) - || strncmp((char*)handshake.records[RECORD_METHOD_NAME].data, ctx->conf->method_default->name, handshake.records[RECORD_METHOD_NAME].length)) { - reply_code = REPLY_UNACCEPTABLE_VALUE; - error_detail = RECORD_METHOD_NAME; - goto send_reply; - } - - method = ctx->conf->method_default; - } - - send_reply: - if (reply_code) { - fastd_buffer_t reply_buffer = fastd_buffer_alloc(ctx, sizeof(fastd_handshake_packet_t), 0, 3*5 /* enough space for handshake type, reply code and error detail */); - fastd_handshake_packet_t *reply = reply_buffer.data; - reply->rsv = 0; - reply->tlv_len = 0; + fastd_string_stack_free(method_list); - fastd_handshake_add_uint8(ctx, &reply_buffer, RECORD_HANDSHAKE_TYPE, 2); - fastd_handshake_add_uint8(ctx, &reply_buffer, RECORD_REPLY_CODE, reply_code); - fastd_handshake_add_uint8(ctx, &reply_buffer, RECORD_ERROR_DETAIL, error_detail); - - fastd_send_handshake(ctx, sock, local_addr, remote_addr, peer, reply_buffer); - } - else { - ctx->conf->protocol->handshake_handle(ctx, sock, local_addr, remote_addr, peer, &handshake, method); - } + return method; } - else { - if (handshake.records[RECORD_REPLY_CODE].length != 1) { - pr_warn(ctx, "received handshake reply without reply code from %I", remote_addr); - goto end_free; - } - - uint8_t reply_code = AS_UINT8(handshake.records[RECORD_REPLY_CODE]); - if (reply_code == REPLY_SUCCESS) { - const fastd_method_t *method = ctx->conf->method_default; - - if (handshake.records[RECORD_METHOD_NAME].data) { - method = method_from_name(ctx, (const char*)handshake.records[RECORD_METHOD_NAME].data, handshake.records[RECORD_METHOD_NAME].length); - } + if (!handshake->records[RECORD_METHOD_NAME].data) + return NULL; - /* - * If we receive an invalid method here, some went really wrong on the other side. - * It doesn't even make sense to send an error reply here. - */ - if (!method) { - pr_warn(ctx, "Handshake with %I failed because an invalid method name was sent", remote_addr); - goto end_free; - } + return method_from_name(ctx, (const char*)handshake->records[RECORD_METHOD_NAME].data, handshake->records[RECORD_METHOD_NAME].length); +} - ctx->conf->protocol->handshake_handle(ctx, sock, local_addr, remote_addr, peer, &handshake, method); - } - else { - const char *error_field_str; +void fastd_handshake_handle(fastd_context_t *ctx, fastd_socket_t *sock, const fastd_peer_address_t *local_addr, const fastd_peer_address_t *remote_addr, fastd_peer_t *peer, fastd_buffer_t buffer) { + fastd_handshake_t handshake = parse_tlvs(&buffer); - if (reply_code >= REPLY_MAX) { - pr_warn(ctx, "Handshake with %I failed with unknown code %i", remote_addr, reply_code); - goto end_free; - } + if (!handshake.tlv_data) { + pr_warn(ctx, "received a short handshake from %I", remote_addr); + goto end_free; + } - if (handshake.records[RECORD_ERROR_DETAIL].length != 1) { - pr_warn(ctx, "Handshake with %I failed with code %s", remote_addr, REPLY_TYPES[reply_code]); - goto end_free; - } + if (handshake.records[RECORD_HANDSHAKE_TYPE].length != 1) { + pr_debug(ctx, "received handshake without handshake type from %I", remote_addr); + goto end_free; + } - uint8_t error_detail = AS_UINT8(handshake.records[RECORD_ERROR_DETAIL]); - if (error_detail >= RECORD_MAX) - error_field_str = ""; - else - error_field_str = RECORD_TYPES[error_detail]; + handshake.type = AS_UINT8(handshake.records[RECORD_HANDSHAKE_TYPE]); - switch (reply_code) { - case REPLY_MANDATORY_MISSING: - pr_warn(ctx, "Handshake with %I failed: mandatory field `%s' missing", remote_addr, error_field_str); - break; + if (!check_records(ctx, sock, local_addr, remote_addr, peer, &handshake)) + goto end_free; - case REPLY_UNACCEPTABLE_VALUE: - pr_warn(ctx, "Handshake with %I failed: unacceptable value for field `%s'", remote_addr, error_field_str); - break; + const fastd_method_t *method = get_method(ctx, &handshake); - default: /* just to silence the warning */ - break; - } - } + if (handshake.type > 1 && !method) { + send_error(ctx, sock, local_addr, remote_addr, peer, &handshake, REPLY_UNACCEPTABLE_VALUE, RECORD_METHOD_NAME); + goto end_free; } + ctx->conf->protocol->handshake_handle(ctx, sock, local_addr, remote_addr, peer, &handshake, method); + end_free: fastd_buffer_free(buffer); } -- cgit v1.2.3