diff options
Diffstat (limited to 'src/handshake.c')
-rw-r--r-- | src/handshake.c | 52 |
1 files changed, 29 insertions, 23 deletions
diff --git a/src/handshake.c b/src/handshake.c index 2c22655..1bb18a1 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -105,16 +105,17 @@ static inline uint32_t as_uint(const fastd_handshake_record_t *record) { /** Generates a zero-separated list of supported methods */ -static uint8_t * create_method_list(size_t *len) { - size_t n, i; - for (n = 0; conf.methods[n].name; n++) { - } +static uint8_t * create_method_list(const fastd_string_stack_t *methods, size_t *len) { + size_t n = 0, i; + const fastd_string_stack_t *method; + for (method = methods; method; method = method->next) + n++; *len = 0; size_t lens[n]; - for (i = 0; i < n; i++) { - lens[i] = strlen(conf.methods[i].name) + 1; + for (method = methods, i = 0; method; method = method->next, i++) { + lens[i] = strlen(method->str) + 1; *len += lens[i]; } @@ -123,8 +124,8 @@ static uint8_t * create_method_list(size_t *len) { uint8_t *ptr = ret; - for (i = 0; i < n; i++) { - memcpy(ptr, conf.methods[i].name, lens[i]); + for (method = methods, i = 0; method; method = method->next, i++) { + memcpy(ptr, method->str, lens[i]); ptr += lens[i]; } @@ -160,7 +161,7 @@ static fastd_string_stack_t * parse_string_list(const uint8_t *data, size_t len) } /** Allocates and initializes a new handshake packet */ -static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, const fastd_method_info_t *method, bool with_method_list, size_t tail_space) { +static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, const fastd_method_info_t *method, const fastd_string_stack_t *methods, size_t tail_space) { size_t version_len = strlen(FASTD_VERSION); size_t protocol_len = strlen(conf.protocol->name); size_t method_len = method ? strlen(method->name) : 0; @@ -168,8 +169,8 @@ static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, size_t method_list_len = 0; uint8_t *method_list = NULL; - if (with_method_list) - method_list = create_method_list(&method_list_len); + if (methods) + method_list = create_method_list(methods, &method_list_len); fastd_handshake_buffer_t buffer = { .buffer = fastd_buffer_alloc(sizeof(fastd_handshake_packet_t), 1, @@ -193,10 +194,10 @@ static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, fastd_handshake_add(&buffer, RECORD_VERSION_NAME, version_len, FASTD_VERSION); fastd_handshake_add(&buffer, RECORD_PROTOCOL_NAME, protocol_len, conf.protocol->name); - if (method && (!with_method_list || !conf.secure_handshakes)) + if (method && (!methods || !conf.secure_handshakes)) fastd_handshake_add(&buffer, RECORD_METHOD_NAME, method_len, method->name); - if (with_method_list) { + if (methods) { fastd_handshake_add(&buffer, RECORD_METHOD_LIST, method_list_len, method_list); free(method_list); } @@ -206,12 +207,12 @@ static fastd_handshake_buffer_t new_handshake(uint8_t type, bool little_endian, /** Allocates and initializes a new initial handshake packet */ fastd_handshake_buffer_t fastd_handshake_new_init(size_t tail_space) { - return new_handshake(1, true, NULL, !conf.secure_handshakes, tail_space); + return new_handshake(1, true, NULL, conf.secure_handshakes ? NULL : conf.peer_group->methods, tail_space); } /** Allocates and initializes a new reply handshake packet */ -fastd_handshake_buffer_t fastd_handshake_new_reply(uint8_t type, bool little_endian, const fastd_method_info_t *method, bool with_method_list, size_t tail_space) { - fastd_handshake_buffer_t buffer = new_handshake(type, little_endian, method, with_method_list, tail_space); +fastd_handshake_buffer_t fastd_handshake_new_reply(uint8_t type, bool little_endian, const fastd_method_info_t *method, const fastd_string_stack_t *methods, size_t tail_space) { + fastd_handshake_buffer_t buffer = new_handshake(type, little_endian, method, methods, tail_space); fastd_handshake_add_uint8(&buffer, RECORD_REPLY_CODE, 0); return buffer; } @@ -390,16 +391,19 @@ static inline bool check_records(fastd_socket_t *sock, const fastd_peer_address_ } /** Returns the method info with a specified name and length */ -static inline const fastd_method_info_t * get_method_by_name(const char *name, size_t n) { +static inline const fastd_method_info_t * get_method_by_name(const fastd_string_stack_t *methods, const char *name, size_t n) { char name0[n+1]; memcpy(name0, name, n); name0[n] = 0; + if (!fastd_string_stack_contains(methods, name0)) + return NULL; + return fastd_method_get_by_name(name0); } /** Returns the most appropriate method to negotiate with a peer a handshake was received from */ -static inline const fastd_method_info_t * get_method(const fastd_handshake_t *handshake) { +static inline const fastd_method_info_t * get_method(const fastd_string_stack_t *methods, const fastd_handshake_t *handshake) { if (handshake->records[RECORD_METHOD_LIST].data && handshake->records[RECORD_METHOD_LIST].length) { fastd_string_stack_t *method_list = parse_string_list(handshake->records[RECORD_METHOD_LIST].data, handshake->records[RECORD_METHOD_LIST].length); @@ -407,10 +411,12 @@ static inline const fastd_method_info_t * get_method(const fastd_handshake_t *ha fastd_string_stack_t *method_name; for (method_name = method_list; method_name; method_name = method_name->next) { - const fastd_method_info_t *cur_method = fastd_method_get_by_name(method_name->str); + if (!fastd_string_stack_contains(methods, method_name->str)) + continue; - if (cur_method) - method = cur_method; + method = fastd_method_get_by_name(method_name->str); + if (!method) + exit_bug("get_method: can't find configured method"); } fastd_string_stack_free(method_list); @@ -421,7 +427,7 @@ static inline const fastd_method_info_t * get_method(const fastd_handshake_t *ha if (!handshake->records[RECORD_METHOD_NAME].data) return NULL; - return get_method_by_name((const char *)handshake->records[RECORD_METHOD_NAME].data, handshake->records[RECORD_METHOD_NAME].length); + return get_method_by_name(methods, (const char *)handshake->records[RECORD_METHOD_NAME].data, handshake->records[RECORD_METHOD_NAME].length); } /** Handles a handshake packet */ @@ -447,7 +453,7 @@ void fastd_handshake_handle(fastd_socket_t *sock, const fastd_peer_address_t *lo goto end_free; if (!conf.secure_handshakes || handshake.type > 1) { - method = get_method(&handshake); + method = get_method(fastd_peer_get_methods(peer), &handshake); if (handshake.records[RECORD_VERSION_NAME].data) handshake.peer_version = peer_version = fastd_strndup((const char *)handshake.records[RECORD_VERSION_NAME].data, handshake.records[RECORD_VERSION_NAME].length); |