From a0be6d31b4da42e854ae1d05cffc193e8072223a Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Sun, 1 Jul 2012 17:01:13 +0200 Subject: Add support for multiple crypto methods without reconfiguration --- src/config.c | 36 ++++++++-- src/fastd.c | 60 +++++++++++++++-- src/fastd.h | 19 +++++- src/handshake.c | 153 ++++++++++++++++++++++++++++++++++-------- src/handshake.h | 3 +- src/protocol_ec25519_fhmqvc.c | 101 ++++++++++++++++------------ 6 files changed, 285 insertions(+), 87 deletions(-) diff --git a/src/config.c b/src/config.c index e69abcc..5c7fd89 100644 --- a/src/config.c +++ b/src/config.c @@ -80,7 +80,8 @@ static void default_config(fastd_config *conf) { conf->forward = false; conf->protocol = &fastd_protocol_ec25519_fhmqvc; - conf->method = &fastd_method_null; + conf->method_default = &fastd_method_null; + memset(conf->methods, 0, sizeof(conf->methods)); conf->secret = NULL; conf->key_valid = 3600; /* 60 minutes */ conf->key_refresh = 3300; /* 55 minutes */ @@ -136,21 +137,41 @@ bool fastd_config_protocol(fastd_context *ctx, fastd_config *conf, const char *n return true; } -bool fastd_config_method(fastd_context *ctx, fastd_config *conf, const char *name) { +static inline const fastd_method* parse_method_name(const char *name) { if (!strcmp(name, "null")) - conf->method = &fastd_method_null; + return &fastd_method_null; #ifdef WITH_METHOD_XSALSA20_POLY1305 else if (!strcmp(name, "xsalsa20-poly1305")) - conf->method = &fastd_method_xsalsa20_poly1305; + return &fastd_method_xsalsa20_poly1305; #endif #ifdef WITH_METHOD_AES128_GCM else if (!strcmp(name, "aes128-gcm")) - conf->method = &fastd_method_aes128_gcm; + return &fastd_method_aes128_gcm; #endif else + return NULL; +} + +bool fastd_config_method(fastd_context *ctx, fastd_config *conf, const char *name) { + const fastd_method *method = parse_method_name(name); + + if (!method) return false; - return true; + conf->method_default = method; + + int i; + for (i = 0; i < MAX_METHODS; i++) { + if (conf->methods[i] == method) + return true; + + if (conf->methods[i] == NULL) { + conf->methods[i] = method; + return true; + } + } + + exit_bug(ctx, "MAX_METHODS too low"); } bool fastd_config_add_log_file(fastd_context *ctx, fastd_config *conf, const char *name, int level) { @@ -680,6 +701,9 @@ void fastd_configure(fastd_context *ctx, fastd_config *conf, int argc, char *con if (conf->log_stderr_level < 0 && conf->log_syslog_level < 0 && !conf->log_files) conf->log_stderr_level = FASTD_DEFAULT_LOG_LEVEL; + if (!conf->methods[0]) + conf->methods[0] = conf->method_default; + if (conf->generate_key || conf->show_key) return; diff --git a/src/fastd.c b/src/fastd.c index 6a326fb..bd9c13c 100644 --- a/src/fastd.c +++ b/src/fastd.c @@ -235,6 +235,54 @@ static void close_sockets(fastd_context *ctx) { } } +static size_t methods_max_packet_size(fastd_context *ctx) { + size_t ret = ctx->conf->methods[0]->max_packet_size(ctx); + + int i; + for (i = 0; i < MAX_METHODS; i++) { + if (!ctx->conf->methods[i]) + break; + + size_t s = ctx->conf->methods[i]->max_packet_size(ctx); + if (s > ret) + ret = s; + } + + return ret; +} + +static size_t methods_min_encrypt_head_space(fastd_context *ctx) { + size_t ret = ctx->conf->methods[0]->min_encrypt_head_space(ctx); + + int i; + for (i = 0; i < MAX_METHODS; i++) { + if (!ctx->conf->methods[i]) + break; + + size_t s = ctx->conf->methods[i]->min_encrypt_head_space(ctx); + if (s > ret) + ret = s; + } + + return ret; +} + +static size_t methods_min_decrypt_head_space(fastd_context *ctx) { + size_t ret = ctx->conf->methods[0]->min_decrypt_head_space(ctx); + + int i; + for (i = 0; i < MAX_METHODS; i++) { + if (!ctx->conf->methods[i]) + break; + + size_t s = ctx->conf->methods[i]->min_decrypt_head_space(ctx); + if (s > ret) + ret = s; + } + + return ret; +} + static void fastd_send_type(fastd_context *ctx, const fastd_peer_address *address, uint8_t packet_type, fastd_buffer buffer) { int sockfd; struct msghdr msg; @@ -312,7 +360,7 @@ void fastd_handle_receive(fastd_context *ctx, fastd_peer *peer, fastd_buffer buf fastd_peer *dest_peer; for (dest_peer = ctx->peers; dest_peer; dest_peer = dest_peer->next) { if (dest_peer != peer && fastd_peer_is_established(dest_peer)) { - fastd_buffer send_buffer = fastd_buffer_alloc(buffer.len, ctx->conf->method->min_encrypt_head_space(ctx), 0); + fastd_buffer send_buffer = fastd_buffer_alloc(buffer.len, methods_min_encrypt_head_space(ctx), 0); memcpy(send_buffer.data, buffer.data, buffer.len); ctx->conf->protocol->send(ctx, dest_peer, send_buffer); } @@ -431,7 +479,7 @@ static void handle_tasks(fastd_context *ctx) { case TASK_KEEPALIVE: pr_debug(ctx, "sending keepalive to %P", task->peer); - ctx->conf->protocol->send(ctx, task->peer, fastd_buffer_alloc(0, ctx->conf->method->min_encrypt_head_space(ctx), 0)); + ctx->conf->protocol->send(ctx, task->peer, fastd_buffer_alloc(0, methods_min_encrypt_head_space(ctx), 0)); break; default: @@ -444,7 +492,7 @@ static void handle_tasks(fastd_context *ctx) { static void handle_tun(fastd_context *ctx) { size_t max_len = fastd_max_packet_size(ctx); - fastd_buffer buffer = fastd_buffer_alloc(max_len, ctx->conf->method->min_encrypt_head_space(ctx), 0); + fastd_buffer buffer = fastd_buffer_alloc(max_len, methods_min_encrypt_head_space(ctx), 0); ssize_t len = read(ctx->tunfd, buffer.data, max_len); if (len < 0) { @@ -481,7 +529,7 @@ static void handle_tun(fastd_context *ctx) { if (peer == NULL) { for (peer = ctx->peers; peer; peer = peer->next) { if (fastd_peer_is_established(peer)) { - fastd_buffer send_buffer = fastd_buffer_alloc(len, ctx->conf->method->min_encrypt_head_space(ctx), 0); + fastd_buffer send_buffer = fastd_buffer_alloc(len, methods_min_encrypt_head_space(ctx), 0); memcpy(send_buffer.data, buffer.data, len); ctx->conf->protocol->send(ctx, peer, send_buffer); } @@ -492,8 +540,8 @@ static void handle_tun(fastd_context *ctx) { } static void handle_socket(fastd_context *ctx, int sockfd) { - size_t max_len = ctx->conf->method->max_packet_size(ctx); - fastd_buffer buffer = fastd_buffer_alloc(max_len, ctx->conf->method->min_decrypt_head_space(ctx), 0); + size_t max_len = methods_max_packet_size(ctx); + fastd_buffer buffer = fastd_buffer_alloc(max_len, methods_min_decrypt_head_space(ctx), 0); uint8_t packet_type; struct iovec iov[2] = { diff --git a/src/fastd.h b/src/fastd.h index e92f598..4d3d2a4 100644 --- a/src/fastd.h +++ b/src/fastd.h @@ -49,6 +49,10 @@ #define FASTD_VERSION "0.4" +/* This must be adjusted when new methods are added */ +#define MAX_METHODS 3 + + struct _fastd_buffer { void *base; size_t base_len; @@ -68,7 +72,7 @@ struct _fastd_protocol { void (*peer_configure)(fastd_context *ctx, fastd_peer_config *peer_conf); void (*handshake_init)(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer_config *peer_conf); - void (*handshake_handle)(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer_config *peer_conf, const fastd_handshake *handshake); + void (*handshake_handle)(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer_config *peer_conf, const fastd_handshake *handshake, const fastd_method *method); void (*handle_recv)(fastd_context *ctx, fastd_peer *peer, fastd_buffer buffer); void (*send)(fastd_context *ctx, fastd_peer *peer, fastd_buffer buffer); @@ -152,7 +156,8 @@ struct _fastd_config { bool forward; const fastd_protocol *protocol; - const fastd_method *method; + const fastd_method *methods[MAX_METHODS]; + const fastd_method *method_default; char *secret; unsigned key_valid; unsigned key_refresh; @@ -318,6 +323,16 @@ static inline fastd_string_stack* fastd_string_stack_dup(const char *str) { return ret; } +static inline fastd_string_stack* fastd_string_stack_dupn(const char *str, size_t len) { + size_t str_len = strnlen(str, len); + fastd_string_stack *ret = malloc(sizeof(fastd_string_stack) + str_len + 1); + ret->next = NULL; + strncpy(ret->str, str, str_len); + ret->str[str_len] = 0; + + return ret; +} + static inline fastd_string_stack* fastd_string_stack_push(fastd_string_stack *stack, const char *str) { fastd_string_stack *ret = malloc(sizeof(fastd_string_stack) + strlen(str) + 1); ret->next = stack; diff --git a/src/handshake.c b/src/handshake.c index c464896..01a582f 100644 --- a/src/handshake.c +++ b/src/handshake.c @@ -47,6 +47,7 @@ static const char const *RECORD_TYPES[RECORD_MAX] = { "MTU", "method name", "version name", + "method list", }; static const char const *REPLY_TYPES[REPLY_MAX] = { @@ -59,16 +60,45 @@ static const char const *REPLY_TYPES[REPLY_MAX] = { #define AS_UINT16(ptr) ((*(uint8_t*)(ptr).data) + (*((uint8_t*)(ptr).data+1) << 8)) +static uint8_t* create_method_list(fastd_context *ctx, size_t *len) { + *len = strlen(ctx->conf->methods[0]->name); + + int i; + for (i = 1; i < MAX_METHODS; i++) { + if (!ctx->conf->methods[i]) + break; + + *len += strlen(ctx->conf->methods[i]->name) + 1; + } + + uint8_t *ret = malloc(*len+1); + char *ptr = (char*)ret; + + for (i = 0; i < MAX_METHODS; i++) { + if (!ctx->conf->methods[i]) + break; + + ptr = stpcpy(ptr, ctx->conf->methods[i]->name) + 1; + } + + return ret; +} + fastd_buffer fastd_handshake_new_init(fastd_context *ctx, size_t tail_space) { - size_t protocol_len = strlen(ctx->conf->protocol->name); - size_t method_len = strlen(ctx->conf->method->name); size_t version_len = strlen(FASTD_VERSION); + size_t protocol_len = strlen(ctx->conf->protocol->name); + size_t method_len = strlen(ctx->conf->method_default->name); + + size_t method_list_len; + uint8_t *method_list = create_method_list(ctx, &method_list_len); + fastd_buffer buffer = fastd_buffer_alloc(sizeof(fastd_packet), 0, - 2*5 + /* handshake type, mode */ - 6 + /* MTU */ - 4+protocol_len + /* protocol name */ - 4+method_len + /* method name */ - 4+version_len + /* version name */ + 2*5 + /* handshake type, mode */ + 6 + /* MTU */ + 4+version_len + /* version name */ + 4+protocol_len + /* protocol name */ + 4+method_len + /* method name */ + 4+method_list_len + /* supported method name list */ tail_space ); fastd_packet *request = buffer.data; @@ -80,16 +110,33 @@ fastd_buffer fastd_handshake_new_init(fastd_context *ctx, size_t tail_space) { fastd_handshake_add_uint8(ctx, &buffer, RECORD_MODE, ctx->conf->mode); fastd_handshake_add_uint16(ctx, &buffer, RECORD_MTU, ctx->conf->mtu); - fastd_handshake_add(ctx, &buffer, RECORD_PROTOCOL_NAME, protocol_len, ctx->conf->protocol->name); - fastd_handshake_add(ctx, &buffer, RECORD_METHOD_NAME, method_len, ctx->conf->method->name); fastd_handshake_add(ctx, &buffer, RECORD_VERSION_NAME, version_len, FASTD_VERSION); + fastd_handshake_add(ctx, &buffer, RECORD_PROTOCOL_NAME, protocol_len, ctx->conf->protocol->name); + fastd_handshake_add(ctx, &buffer, RECORD_METHOD_NAME, method_len, ctx->conf->method_default->name); + fastd_handshake_add(ctx, &buffer, RECORD_METHOD_LIST, method_list_len, method_list); + + free(method_list); return buffer; } -fastd_buffer fastd_handshake_new_reply(fastd_context *ctx, const fastd_handshake *handshake, size_t tail_space) { +static const fastd_method* method_from_name(fastd_context *ctx, const char *name, size_t n) { + int i; + for (i = 0; i < MAX_METHODS; i++) { + if (!ctx->conf->methods[i]) + break; + + if (strncmp(name, ctx->conf->methods[i]->name, n) == 0) + return ctx->conf->methods[i]; + } + + return NULL; +} + +fastd_buffer fastd_handshake_new_reply(fastd_context *ctx, const fastd_handshake *handshake, const fastd_method *method, size_t tail_space) { bool first = (AS_UINT8(handshake->records[RECORD_HANDSHAKE_TYPE]) == 1); size_t version_len = strlen(FASTD_VERSION); + size_t method_len = strlen(method->name); size_t extra_size = 0; if (first) @@ -98,6 +145,7 @@ fastd_buffer fastd_handshake_new_reply(fastd_context *ctx, const fastd_handshake fastd_buffer buffer = fastd_buffer_alloc(sizeof(fastd_packet), 0, 2*5 + /* handshake type, reply code */ + 4+method_len + /* method name */ extra_size + tail_space ); @@ -108,6 +156,7 @@ fastd_buffer fastd_handshake_new_reply(fastd_context *ctx, const fastd_handshake fastd_handshake_add_uint8(ctx, &buffer, RECORD_HANDSHAKE_TYPE, AS_UINT8(handshake->records[RECORD_HANDSHAKE_TYPE])+1); fastd_handshake_add_uint8(ctx, &buffer, RECORD_REPLY_CODE, 0); + fastd_handshake_add(ctx, &buffer, RECORD_METHOD_NAME, method_len, method->name); if (first) { fastd_handshake_add_uint16(ctx, &buffer, RECORD_MTU, ctx->conf->mtu); @@ -117,6 +166,20 @@ fastd_buffer fastd_handshake_new_reply(fastd_context *ctx, const fastd_handshake return buffer; } +static fastd_string_stack* parse_string_list(uint8_t *data, size_t len) { + uint8_t *end = data+len; + fastd_string_stack *ret = NULL; + + while (data < end) { + fastd_string_stack *part = fastd_string_stack_dupn((char*)data, end-data); + part->next = ret; + ret = part; + data += strlen(part->str) + 1; + } + + return ret; +} + void fastd_handshake_handle(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer_config *peer_conf, fastd_buffer buffer) { if (buffer.len < sizeof(fastd_packet)) { pr_warn(ctx, "received a short handshake from %I", address); @@ -190,17 +253,43 @@ void fastd_handshake_handle(fastd_context *ctx, const fastd_peer_address *addres goto send_reply; } - if (!handshake.records[RECORD_METHOD_NAME].data) { - reply_code = REPLY_MANDATORY_MISSING; - error_detail = RECORD_METHOD_NAME; - goto send_reply; + const fastd_method *method = NULL; + + if (handshake.records[RECORD_METHOD_LIST].data && handshake.records[RECORD_METHOD_LIST].length) { + fastd_string_stack *method_list = parse_string_list(handshake.records[RECORD_METHOD_LIST].data, handshake.records[RECORD_METHOD_LIST].length); + + fastd_string_stack *method_name = method_list; + while (method_name) { + const fastd_method *cur_method = method_from_name(ctx, method_name->str, SIZE_MAX); + + if (cur_method) + method = cur_method; + + method_name = method_name->next; + } + + fastd_string_stack_free(method_list); + + if (!method) { + reply_code = REPLY_UNACCEPTABLE_VALUE; + error_detail = RECORD_METHOD_LIST; + goto send_reply; + } } + else { + if (!handshake.records[RECORD_METHOD_NAME].data) { + reply_code = REPLY_MANDATORY_MISSING; + error_detail = RECORD_METHOD_NAME; + goto send_reply; + } + if (handshake.records[RECORD_METHOD_NAME].length != strlen(ctx->conf->method_default->name) + || strncmp((char*)handshake.records[RECORD_METHOD_NAME].data, ctx->conf->method_default->name, handshake.records[RECORD_METHOD_NAME].length)) { + reply_code = REPLY_UNACCEPTABLE_VALUE; + error_detail = RECORD_METHOD_NAME; + goto send_reply; + } - if (handshake.records[RECORD_METHOD_NAME].length != strlen(ctx->conf->method->name) - || strncmp((char*)handshake.records[RECORD_METHOD_NAME].data, ctx->conf->method->name, handshake.records[RECORD_METHOD_NAME].length)) { - reply_code = REPLY_UNACCEPTABLE_VALUE; - error_detail = RECORD_METHOD_NAME; - goto send_reply; + method = ctx->conf->method_default; } send_reply: @@ -218,17 +307,10 @@ void fastd_handshake_handle(fastd_context *ctx, const fastd_peer_address *addres fastd_send_handshake(ctx, address, reply_buffer); } else { - ctx->conf->protocol->handshake_handle(ctx, address, peer_conf, &handshake); + ctx->conf->protocol->handshake_handle(ctx, address, peer_conf, &handshake, method); } } else { - if ((handshake.type & 1) == 0) { - /*if (packet->req_id != peer->last_req_id) { - pr_warn(ctx, "received handshake reply with request ID %u from %P while %u was expected", packet->req_id, peer, peer->last_req_id); - goto end_free; - }*/ - } - if (handshake.records[RECORD_REPLY_CODE].length != 1) { pr_warn(ctx, "received handshake reply without reply code from %I", address); goto end_free; @@ -237,7 +319,22 @@ void fastd_handshake_handle(fastd_context *ctx, const fastd_peer_address *addres uint8_t reply_code = AS_UINT8(handshake.records[RECORD_REPLY_CODE]); if (reply_code == REPLY_SUCCESS) { - ctx->conf->protocol->handshake_handle(ctx, address, peer_conf, &handshake); + const fastd_method *method = ctx->conf->method_default; + + if (handshake.records[RECORD_METHOD_NAME].data) { + method = method_from_name(ctx, handshake.records[RECORD_METHOD_NAME].data, handshake.records[RECORD_METHOD_NAME].length); + } + + /* + * If we receive an invalid method here, some went really wrong on the other side. + * It doesn't even make sense to send an error reply here. + */ + if (!method) { + pr_warn(ctx, "Handshake with %I failed because an invalid method name was sent", address); + goto end_free; + } + + ctx->conf->protocol->handshake_handle(ctx, address, peer_conf, &handshake, method); } else { const char *error_field_str; diff --git a/src/handshake.h b/src/handshake.h index 2e26194..0db4e2f 100644 --- a/src/handshake.h +++ b/src/handshake.h @@ -45,6 +45,7 @@ typedef enum _fastd_handshake_record_type { RECORD_MTU, RECORD_METHOD_NAME, RECORD_VERSION_NAME, + RECORD_METHOD_LIST, RECORD_MAX, } fastd_handshake_record_type; @@ -68,7 +69,7 @@ struct _fastd_handshake { fastd_buffer fastd_handshake_new_init(fastd_context *ctx, size_t tail_space); -fastd_buffer fastd_handshake_new_reply(fastd_context *ctx, const fastd_handshake *handshake, size_t tail_space); +fastd_buffer fastd_handshake_new_reply(fastd_context *ctx, const fastd_handshake *handshake, const fastd_method *method, size_t tail_space); void fastd_handshake_handle(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer_config *peer_conf, fastd_buffer buffer); diff --git a/src/protocol_ec25519_fhmqvc.c b/src/protocol_ec25519_fhmqvc.c index 8599e29..b27f4d6 100644 --- a/src/protocol_ec25519_fhmqvc.c +++ b/src/protocol_ec25519_fhmqvc.c @@ -82,6 +82,7 @@ typedef struct _protocol_session { bool handshakes_cleaned; bool refreshing; + const fastd_method *method; fastd_method_session_state *method_state; } protocol_session; @@ -100,7 +101,7 @@ struct _fastd_protocol_peer_state { #define RECORD_T RECORD_PROTOCOL5 -static void protocol_send(fastd_context *ctx, fastd_peer *peer, fastd_buffer buffer); +static void send_empty(fastd_context *ctx, fastd_peer *peer, protocol_session *session); static inline bool read_key(uint8_t key[32], const char *hexkey) { @@ -123,7 +124,7 @@ static inline bool is_handshake_key_preferred(fastd_context *ctx, const handshak } static inline bool is_session_valid(fastd_context *ctx, const protocol_session *session) { - return ctx->conf->method->session_is_valid(ctx, session->method_state); + return (session->method && session->method->session_is_valid(ctx, session->method_state)); } static fastd_peer* get_peer(fastd_context *ctx, const fastd_peer_config *peer_conf) { @@ -146,7 +147,7 @@ static bool backoff(fastd_context *ctx, const fastd_peer *peer) { static inline void check_session_refresh(fastd_context *ctx, fastd_peer *peer) { protocol_session *session = &peer->protocol_state->session; - if (!session->refreshing && ctx->conf->method->session_is_initiator(ctx, session->method_state) && ctx->conf->method->session_want_refresh(ctx, session->method_state)) { + if (!session->refreshing && session->method->session_is_initiator(ctx, session->method_state) && session->method->session_want_refresh(ctx, session->method_state)) { pr_verbose(ctx, "refreshing session with %P", peer); session->handshakes_cleaned = true; session->refreshing = true; @@ -244,7 +245,8 @@ static void protocol_handshake_init(fastd_context *ctx, const fastd_peer_address fastd_send_handshake(ctx, address, buffer); } -static void respond_handshake(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer *peer, const handshake_key *handshake_key, const ecc_public_key_256 *peer_handshake_key, const fastd_handshake *handshake) { +static void respond_handshake(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer *peer, const handshake_key *handshake_key, const ecc_public_key_256 *peer_handshake_key, + const fastd_handshake *handshake, const fastd_method *method) { pr_debug(ctx, "responding handshake with %P[%I]...", peer, address); uint8_t hashinput[5*PUBLICKEYBYTES]; @@ -292,7 +294,7 @@ static void respond_handshake(fastd_context *ctx, const fastd_peer_address *addr crypto_auth_hmacsha256(hmacbuf, hashinput, 2*PUBLICKEYBYTES, shared_handshake_key); - fastd_buffer buffer = fastd_handshake_new_reply(ctx, handshake, 4*(4+PUBLICKEYBYTES) + 4+HMACBYTES); + fastd_buffer buffer = fastd_handshake_new_reply(ctx, handshake, method, 4*(4+PUBLICKEYBYTES) + 4+HMACBYTES); fastd_handshake_add(ctx, &buffer, RECORD_SENDER_KEY, PUBLICKEYBYTES, ctx->conf->protocol_config->public_key.p); fastd_handshake_add(ctx, &buffer, RECORD_RECEIPIENT_KEY, PUBLICKEYBYTES, peer->config->protocol_config->public_key.p); @@ -303,7 +305,7 @@ static void respond_handshake(fastd_context *ctx, const fastd_peer_address *addr fastd_send_handshake(ctx, address, buffer); } -static bool establish(fastd_context *ctx, fastd_peer *peer, const fastd_peer_address *address, bool initiator, +static bool establish(fastd_context *ctx, fastd_peer *peer, const fastd_method *method, const fastd_peer_address *address, bool initiator, const ecc_public_key_256 *A, const ecc_public_key_256 *B, const ecc_public_key_256 *X, const ecc_public_key_256 *Y, const ecc_public_key_256 *sigma, uint64_t serial) { uint8_t hashinput[5*PUBLICKEYBYTES]; @@ -317,11 +319,13 @@ static bool establish(fastd_context *ctx, fastd_peer *peer, const fastd_peer_add pr_verbose(ctx, "%I authorized as %P", address, peer); if (is_session_valid(ctx, &peer->protocol_state->session) && !is_session_valid(ctx, &peer->protocol_state->old_session)) { - ctx->conf->method->session_free(ctx, peer->protocol_state->old_session.method_state); + if (peer->protocol_state->old_session.method) + peer->protocol_state->old_session.method->session_free(ctx, peer->protocol_state->old_session.method_state); peer->protocol_state->old_session = peer->protocol_state->session; } else { - ctx->conf->method->session_free(ctx, peer->protocol_state->session.method_state); + if (peer->protocol_state->session.method) + peer->protocol_state->session.method->session_free(ctx, peer->protocol_state->session.method_state); } memcpy(hashinput, X->p, PUBLICKEYBYTES); @@ -334,7 +338,8 @@ static bool establish(fastd_context *ctx, fastd_peer *peer, const fastd_peer_add peer->protocol_state->session.established = ctx->now; peer->protocol_state->session.handshakes_cleaned = false; peer->protocol_state->session.refreshing = false; - peer->protocol_state->session.method_state = ctx->conf->method->session_init(ctx, hash, HASHBYTES, initiator); + peer->protocol_state->session.method = method; + peer->protocol_state->session.method_state = method->session_init(ctx, hash, HASHBYTES, initiator); peer->protocol_state->last_serial = serial; fastd_peer_seen(ctx, peer); @@ -347,17 +352,18 @@ static bool establish(fastd_context *ctx, fastd_peer *peer, const fastd_peer_add fastd_peer_set_established(ctx, peer); - pr_verbose(ctx, "new session with %P established.", peer); + pr_verbose(ctx, "new session with %P established using method `%s'.", peer, method->name); fastd_task_schedule_keepalive(ctx, peer, ctx->conf->keepalive_interval*1000); if (!initiator) - protocol_send(ctx, peer, fastd_buffer_alloc(0, ctx->conf->method->min_encrypt_head_space(ctx), 0)); + send_empty(ctx, peer, &peer->protocol_state->session); return true; } -static void finish_handshake(fastd_context *ctx, const fastd_peer_address *address, fastd_peer *peer, const handshake_key *handshake_key, const ecc_public_key_256 *peer_handshake_key, const fastd_handshake *handshake) { +static void finish_handshake(fastd_context *ctx, const fastd_peer_address *address, fastd_peer *peer, const handshake_key *handshake_key, const ecc_public_key_256 *peer_handshake_key, + const fastd_handshake *handshake, const fastd_method *method) { pr_debug(ctx, "finishing handshake with %P[%I]...", peer, address); uint8_t hashinput[5*PUBLICKEYBYTES]; @@ -412,11 +418,11 @@ static void finish_handshake(fastd_context *ctx, const fastd_peer_address *addre memcpy(hashinput+PUBLICKEYBYTES, handshake_key->public_key.p, PUBLICKEYBYTES); crypto_auth_hmacsha256(hmacbuf, hashinput, 2*PUBLICKEYBYTES, shared_handshake_key); - if (!establish(ctx, peer, address, true, &handshake_key->public_key, peer_handshake_key, &ctx->conf->protocol_config->public_key, + if (!establish(ctx, peer, method, address, true, &handshake_key->public_key, peer_handshake_key, &ctx->conf->protocol_config->public_key, &peer->config->protocol_config->public_key, &sigma, handshake_key->serial)) return; - fastd_buffer buffer = fastd_handshake_new_reply(ctx, handshake, 4*(4+PUBLICKEYBYTES) + 4+HMACBYTES); + fastd_buffer buffer = fastd_handshake_new_reply(ctx, handshake, method, 4*(4+PUBLICKEYBYTES) + 4+HMACBYTES); fastd_handshake_add(ctx, &buffer, RECORD_SENDER_KEY, PUBLICKEYBYTES, ctx->conf->protocol_config->public_key.p); fastd_handshake_add(ctx, &buffer, RECORD_RECEIPIENT_KEY, PUBLICKEYBYTES, peer->config->protocol_config->public_key.p); @@ -428,7 +434,8 @@ static void finish_handshake(fastd_context *ctx, const fastd_peer_address *addre } -static void handle_finish_handshake(fastd_context *ctx, const fastd_peer_address *address, fastd_peer *peer, const handshake_key *handshake_key, const ecc_public_key_256 *peer_handshake_key, const fastd_handshake *handshake) { +static void handle_finish_handshake(fastd_context *ctx, const fastd_peer_address *address, fastd_peer *peer, const handshake_key *handshake_key, const ecc_public_key_256 *peer_handshake_key, + const fastd_handshake *handshake, const fastd_method *method) { pr_debug(ctx, "handling handshake finish with %P[%I]...", peer, address); uint8_t hashinput[5*PUBLICKEYBYTES]; @@ -478,7 +485,7 @@ static void handle_finish_handshake(fastd_context *ctx, const fastd_peer_address return; } - establish(ctx, peer, address, false, peer_handshake_key, &handshake_key->public_key, &peer->config->protocol_config->public_key, + establish(ctx, peer, method, address, false, peer_handshake_key, &handshake_key->public_key, &peer->config->protocol_config->public_key, &ctx->conf->protocol_config->public_key, &sigma, handshake_key->serial); } @@ -516,7 +523,7 @@ static inline bool has_field(const fastd_handshake *handshake, uint8_t type, siz return (handshake->records[type].length == length); } -static void protocol_handshake_handle(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer_config *peer_conf, const fastd_handshake *handshake) { +static void protocol_handshake_handle(fastd_context *ctx, const fastd_peer_address *address, const fastd_peer_config *peer_conf, const fastd_handshake *handshake, const fastd_method *method) { handshake_key *handshake_key; char *peer_version_name = NULL; @@ -582,7 +589,7 @@ static void protocol_handshake_handle(fastd_context *ctx, const fastd_peer_addre peer->last_handshake_response = ctx->now; peer->last_handshake_response_address = *address; - respond_handshake(ctx, address, peer, &ctx->protocol_state->handshake_key, handshake->records[RECORD_SENDER_HANDSHAKE_KEY].data, handshake); + respond_handshake(ctx, address, peer, &ctx->protocol_state->handshake_key, handshake->records[RECORD_SENDER_HANDSHAKE_KEY].data, handshake, method); break; case 2: @@ -603,7 +610,7 @@ static void protocol_handshake_handle(fastd_context *ctx, const fastd_peer_addre pr_verbose(ctx, "received handshake response from %P[%I] using fastd %s", peer, address, peer_version_name); free(peer_version_name); - finish_handshake(ctx, address, peer, handshake_key, handshake->records[RECORD_SENDER_HANDSHAKE_KEY].data, handshake); + finish_handshake(ctx, address, peer, handshake_key, handshake->records[RECORD_SENDER_HANDSHAKE_KEY].data, handshake, method); break; case 3: @@ -620,7 +627,7 @@ static void protocol_handshake_handle(fastd_context *ctx, const fastd_peer_addre pr_debug(ctx, "received handshake finish from %P[%I]", peer, address); - handle_finish_handshake(ctx, address, peer, handshake_key, handshake->records[RECORD_SENDER_HANDSHAKE_KEY].data, handshake); + handle_finish_handshake(ctx, address, peer, handshake_key, handshake->records[RECORD_SENDER_HANDSHAKE_KEY].data, handshake, method); break; default: @@ -643,17 +650,17 @@ static void protocol_handle_recv(fastd_context *ctx, fastd_peer *peer, fastd_buf bool ok = false; if (is_session_valid(ctx, &peer->protocol_state->old_session)) { - if (ctx->conf->method->decrypt(ctx, peer, peer->protocol_state->old_session.method_state, &recv_buffer, buffer)) + if (peer->protocol_state->old_session.method->decrypt(ctx, peer, peer->protocol_state->old_session.method_state, &recv_buffer, buffer)) ok = true; } if (!ok) { - if (ctx->conf->method->decrypt(ctx, peer, peer->protocol_state->session.method_state, &recv_buffer, buffer)) { + if (peer->protocol_state->session.method->decrypt(ctx, peer, peer->protocol_state->session.method_state, &recv_buffer, buffer)) { ok = true; if (peer->protocol_state->old_session.method_state) { pr_debug(ctx, "invalidating old session with %P", peer); - ctx->conf->method->session_free(ctx, peer->protocol_state->old_session.method_state); + peer->protocol_state->old_session.method->session_free(ctx, peer->protocol_state->old_session.method_state); peer->protocol_state->old_session.method_state = NULL; } @@ -662,8 +669,8 @@ static void protocol_handle_recv(fastd_context *ctx, fastd_peer *peer, fastd_buf fastd_task_delete_peer_handshakes(ctx, peer); peer->protocol_state->session.handshakes_cleaned = true; - if (ctx->conf->method->session_is_initiator(ctx, peer->protocol_state->session.method_state)) - protocol_send(ctx, peer, fastd_buffer_alloc(0, ctx->conf->method->min_encrypt_head_space(ctx), 0)); + if (peer->protocol_state->session.method->session_is_initiator(ctx, peer->protocol_state->session.method_state)) + send_empty(ctx, peer, &peer->protocol_state->session); } check_session_refresh(ctx, peer); @@ -688,33 +695,38 @@ static void protocol_handle_recv(fastd_context *ctx, fastd_peer *peer, fastd_buf fastd_buffer_free(buffer); } +static void session_send(fastd_context *ctx, fastd_peer *peer, fastd_buffer buffer, protocol_session *session) { + fastd_buffer send_buffer; + if (!session->method->encrypt(ctx, peer, session->method_state, &send_buffer, buffer)) { + fastd_buffer_free(buffer); + return; + } + + fastd_send(ctx, &peer->address, send_buffer); + + fastd_task_delete_peer_keepalives(ctx, peer); + fastd_task_schedule_keepalive(ctx, peer, ctx->conf->keepalive_interval*1000); +} + static void protocol_send(fastd_context *ctx, fastd_peer *peer, fastd_buffer buffer) { - if (!peer->protocol_state || !is_session_valid(ctx, &peer->protocol_state->session)) - goto fail; + if (!peer->protocol_state || !is_session_valid(ctx, &peer->protocol_state->session)) { + fastd_buffer_free(buffer); + return; + } check_session_refresh(ctx, peer); - protocol_session *session; - if (ctx->conf->method->session_is_initiator(ctx, peer->protocol_state->session.method_state) && is_session_valid(ctx, &peer->protocol_state->old_session)) { + if (peer->protocol_state->session.method->session_is_initiator(ctx, peer->protocol_state->session.method_state) && is_session_valid(ctx, &peer->protocol_state->old_session)) { pr_debug(ctx, "sending packet for old session to %P", peer); - session = &peer->protocol_state->old_session; + session_send(ctx, peer, buffer, &peer->protocol_state->old_session); } else { - session = &peer->protocol_state->session; + session_send(ctx, peer, buffer, &peer->protocol_state->session); } +} - fastd_buffer send_buffer; - if (!ctx->conf->method->encrypt(ctx, peer, session->method_state, &send_buffer, buffer)) - goto fail; - - fastd_send(ctx, &peer->address, send_buffer); - - fastd_task_delete_peer_keepalives(ctx, peer); - fastd_task_schedule_keepalive(ctx, peer, ctx->conf->keepalive_interval*1000); - return; - - fail: - fastd_buffer_free(buffer); +static void send_empty(fastd_context *ctx, fastd_peer *peer, protocol_session *session) { + session_send(ctx, peer, fastd_buffer_alloc(0, session->method->min_encrypt_head_space(ctx), 0), session); } static void protocol_init_peer_state(fastd_context *ctx, fastd_peer *peer) { @@ -729,7 +741,8 @@ static void protocol_init_peer_state(fastd_context *ctx, fastd_peer *peer) { } static void reset_session(fastd_context *ctx, protocol_session *session) { - ctx->conf->method->session_free(ctx, session->method_state); + if (session->method) + session->method->session_free(ctx, session->method_state); memset(session, 0, sizeof(protocol_session)); } -- cgit v1.2.3