From b2d02587fcd86f0c3910441d58c94dd0c9fea5b5 Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Sat, 15 Sep 2012 08:55:50 +0200 Subject: Add support for kernel AES implementations This gives AES128 a slight boost on my system, but it is still slower than XSalsa20... I should probably write userspace code that can make use of AES-NI and CLMUL. Or directly jump to the kernel space with the whole forwarding code. Nevertheless, this might run nicely on Geode CPUs and similar hardware with AES acceleration, at least if the context switches aren't too expensive... --- src/fastd.h | 1 + src/linux_alg.c | 117 ++++++++++++++++++++++++++++++++++++------------ src/linux_alg.h | 3 ++ src/method_aes128_gcm.c | 109 ++++++++++++++++++++++++++------------------ 4 files changed, 158 insertions(+), 72 deletions(-) (limited to 'src') diff --git a/src/fastd.h b/src/fastd.h index 80ec0d2..3981f46 100644 --- a/src/fastd.h +++ b/src/fastd.h @@ -216,6 +216,7 @@ struct _fastd_context { int sock6fd; int algfd_ghash; + int algfd_aesctr; size_t eth_addr_size; size_t n_eth_addr; diff --git a/src/linux_alg.c b/src/linux_alg.c index 70f4752..9591ce8 100644 --- a/src/linux_alg.c +++ b/src/linux_alg.c @@ -26,6 +26,7 @@ #include "linux_alg.h" +#include #include #include @@ -38,27 +39,43 @@ void fastd_linux_alg_init(fastd_context *ctx) { ctx->algfd_ghash = socket(AF_ALG, SOCK_SEQPACKET, 0); if (ctx->algfd_ghash < 0) - goto error_ghash; + goto ghash_done; - struct sockaddr_alg sa_ghash = {}; - sa_ghash.salg_family = AF_ALG; - strcpy((char*)sa_ghash.salg_type, "hash"); - strcpy((char*)sa_ghash.salg_name, "ghash"); - if (bind(ctx->algfd_ghash, (struct sockaddr*)&sa_ghash, sizeof(sa_ghash)) < 0) { + struct sockaddr_alg sa = {}; + sa.salg_family = AF_ALG; + strcpy((char*)sa.salg_type, "hash"); + strcpy((char*)sa.salg_name, "ghash"); + if (bind(ctx->algfd_ghash, (struct sockaddr*)&sa, sizeof(sa)) < 0) { close(ctx->algfd_ghash); - goto error_ghash; + ctx->algfd_ghash = -1; } - return; + ghash_done: + if (ctx->algfd_ghash < 0) + pr_info(ctx, "no kernel support for GHASH was found, falling back to userspace implementation"); + + ctx->algfd_aesctr = socket(AF_ALG, SOCK_SEQPACKET, 0); + if (ctx->algfd_aesctr < 0) + goto aesctr_done; - error_ghash: - pr_info(ctx, "no kernel support for GHASH was found, falling back to userspace implementation"); - ctx->algfd_ghash = -1; + strcpy((char*)sa.salg_type, "skcipher"); + strcpy((char*)sa.salg_name, "ctr(aes)"); + if (bind(ctx->algfd_aesctr, (struct sockaddr*)&sa, sizeof(sa)) < 0) { + close(ctx->algfd_aesctr); + ctx->algfd_aesctr = -1; + } + + aesctr_done: + if (ctx->algfd_aesctr < 0) + pr_info(ctx, "no kernel support for AES-CTR was found, falling back to userspace implementation"); } void fastd_linux_alg_close(fastd_context *ctx) { if (ctx->algfd_ghash >= 0) close(ctx->algfd_ghash); + + if (ctx->algfd_aesctr >= 0) + close(ctx->algfd_aesctr); } @@ -85,29 +102,73 @@ bool fastd_linux_alg_ghash(fastd_context *ctx, int fd, uint8_t out[16], const vo if (!len) return false; - const uint8_t *in = data; - const uint8_t *end = in+len; + if (write(fd, data, len) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_ghash: write"); + return false; + } - while (in < end) { - int bytes = write(fd, in, end-in); - if (bytes < 0) { - pr_error_errno(ctx, "fastd_linux_alg_ghash: write"); - return false; - } + if (read(fd, out, 16) < 16) { + pr_error_errno(ctx, "fastd_linux_alg_ghash: read"); + return false; + } - in += bytes; + return true; +} + +int fastd_linux_alg_aesctr_init(fastd_context *ctx, uint8_t *key, size_t keylen) { + if (ctx->algfd_aesctr < 0) + return -1; + + if (setsockopt(ctx->algfd_aesctr, SOL_ALG, ALG_SET_KEY, key, keylen) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr_init: setsockopt"); + return -1; } - end = out+16; + int ret = accept(ctx->algfd_aesctr, NULL, NULL); - while (out < end) { - int bytes = read(fd, out, end-out); - if (bytes < 0) { - pr_error_errno(ctx, "fastd_linux_alg_ghash: read"); - return false; - } + if (ret < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr_init: accept"); + return -1; + } + + return ret; +} + +bool fastd_linux_alg_aesctr(fastd_context *ctx, int fd, void *out, const void *in, size_t len, const uint8_t iv[16]) { + if (!len) + return false; + + struct iovec vec = { .iov_base = (void*)in, .iov_len = len }; - out += bytes; + static const size_t cmsglen = sizeof(struct cmsghdr)+sizeof(struct af_alg_iv)+16; + struct cmsghdr *cmsg = alloca(cmsglen); + cmsg->cmsg_len = cmsglen; + cmsg->cmsg_level = SOL_ALG; + cmsg->cmsg_type = ALG_SET_IV; + + struct af_alg_iv *alg_iv = (void*)CMSG_DATA(cmsg); + alg_iv->ivlen = 16; + memcpy(alg_iv->iv, iv, 16); + + struct msghdr msg = { + .msg_iov = &vec, + .msg_iovlen = 1, + .msg_control = cmsg, + .msg_controllen = cmsglen + }; + + if (sendmsg(fd, &msg, 0) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr: sendmsg"); + return false; + } + + msg.msg_control = NULL; + msg.msg_controllen = 0; + vec.iov_base = out; + + if (recvmsg(fd, &msg, 0) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr: recvmsg"); + return false; } return true; diff --git a/src/linux_alg.h b/src/linux_alg.h index e10f23a..0099474 100644 --- a/src/linux_alg.h +++ b/src/linux_alg.h @@ -36,4 +36,7 @@ void fastd_linux_alg_close(fastd_context *ctx); int fastd_linux_alg_ghash_init(fastd_context *ctx, uint8_t key[16]); bool fastd_linux_alg_ghash(fastd_context *ctx, int fd, uint8_t out[16], const void *data, size_t len); +int fastd_linux_alg_aesctr_init(fastd_context *ctx, uint8_t *key, size_t keylen); +bool fastd_linux_alg_aesctr(fastd_context *ctx, int fd, void *out, const void *in, size_t len, const uint8_t iv[16]); + #endif /* _FASTD_LINUX_ALG_H_ */ diff --git a/src/method_aes128_gcm.c b/src/method_aes128_gcm.c index 86203a2..ed8f026 100644 --- a/src/method_aes128_gcm.c +++ b/src/method_aes128_gcm.c @@ -31,6 +31,7 @@ #include +#define KEYBYTES 16 #define NONCEBYTES 7 #define BLOCKBYTES 16 #define BLOCKQWORDS (BLOCKBYTES/8) @@ -54,6 +55,7 @@ struct _fastd_method_session_state { struct timespec receive_last; uint64_t receive_reorder_seen; + int algfd_aesctr; int algfd_ghash; }; @@ -92,7 +94,11 @@ static size_t method_max_packet_size(fastd_context *ctx) { } -static size_t method_min_head_space(fastd_context *ctx) { +static size_t method_min_encrypt_head_space(fastd_context *ctx) { + return BLOCKBYTES; +} + +static size_t method_min_decrypt_head_space(fastd_context *ctx) { return 0; } @@ -129,10 +135,32 @@ static inline void xor_a(block_t *x, const block_t *a) { xor(x, x, a); } +static inline void xor_blocks(block_t *out, const block_t *in1, const block_t *in2, size_t n_blocks) { + int i; + for (i = 0; i < n_blocks; i++) + xor(&out[i], &in1[i], &in2[i]); +} + +static void aes128ctr(fastd_context *ctx, block_t *out, const block_t *in, size_t n_blocks, const block_t *iv, fastd_method_session_state *session) { + if (session->algfd_aesctr >= 0) { + if (fastd_linux_alg_aesctr(ctx, session->algfd_aesctr, out, in, n_blocks*BLOCKBYTES, iv->b)) + return; + + /* on error */ + close(session->algfd_aesctr); + session->algfd_aesctr = -1; + } + + block_t stream[n_blocks]; + crypto_stream_aes128ctr_afternm((uint8_t*)stream, sizeof(stream), iv->b, session->d); + + xor_blocks(out, in, stream, n_blocks); +} + static fastd_method_session_state* method_session_init(fastd_context *ctx, uint8_t *secret, size_t length, bool initiator) { int i; - if (length < crypto_stream_aes128ctr_KEYBYTES) + if (length < KEYBYTES) exit_bug(ctx, "aes128-gcm: tried to init with short secret"); fastd_method_session_state *session = malloc(sizeof(fastd_method_session_state)); @@ -144,11 +172,12 @@ static fastd_method_session_state* method_session_init(fastd_context *ctx, uint8 session->refresh_after.tv_sec += ctx->conf->key_refresh; crypto_stream_aes128ctr_beforenm(session->d, secret); + session->algfd_aesctr = fastd_linux_alg_aesctr_init(ctx, secret, KEYBYTES); - static const uint8_t zerononce[crypto_stream_aes128ctr_NONCEBYTES] = {}; + static const block_t zeroblock = {}; block_t Hbase[4]; - crypto_stream_aes128ctr_afternm(Hbase[0].b, BLOCKBYTES, zerononce, session->d); + aes128ctr(ctx, &Hbase[0], &zeroblock, 1, &zeroblock, session); block_t Rbase[4]; Rbase[0] = r; @@ -230,12 +259,6 @@ static void mulH_a(block_t *x, fastd_method_session_state *session) { *x = out; } -static inline void xor_blocks(block_t *out, const block_t *in1, const block_t *in2, size_t n_blocks) { - int i; - for (i = 0; i < n_blocks; i++) - xor(&out[i], &in1[i], &in2[i]); -} - static inline void put_size(block_t *out, size_t len) { memset(out, 0, BLOCKBYTES-5); out->b[BLOCKBYTES-5] = len >> 29; @@ -245,7 +268,7 @@ static inline void put_size(block_t *out, size_t len) { out->b[BLOCKBYTES-1] = len << 3; } -static inline void ghash(fastd_context *ctx, block_t *out, const block_t *blocks, size_t n_blocks, fastd_method_session_state *session) { +static void ghash(fastd_context *ctx, block_t *out, const block_t *blocks, size_t n_blocks, fastd_method_session_state *session) { if (session->algfd_ghash >= 0) { if (fastd_linux_alg_ghash(ctx, session->algfd_ghash, out->b, blocks, n_blocks*BLOCKBYTES)) return; @@ -265,39 +288,39 @@ static inline void ghash(fastd_context *ctx, block_t *out, const block_t *blocks } static bool method_encrypt(fastd_context *ctx, fastd_peer *peer, fastd_method_session_state *session, fastd_buffer *out, fastd_buffer in) { + fastd_buffer_pull_head(&in, BLOCKBYTES); + memset(in.data, 0, BLOCKBYTES); + size_t tail_len = ALIGN(in.len, BLOCKBYTES)-in.len; - *out = fastd_buffer_alloc(in.len, ALIGN(NONCEBYTES+BLOCKBYTES, 8), BLOCKBYTES+tail_len); + *out = fastd_buffer_alloc(in.len, ALIGN(NONCEBYTES, 8), BLOCKBYTES+tail_len); if (tail_len) memset(in.data+in.len, 0, tail_len); - uint8_t nonce[crypto_stream_aes128ctr_NONCEBYTES]; - memcpy(nonce, session->send_nonce, NONCEBYTES); - memset(nonce+NONCEBYTES, 0, crypto_stream_aes128ctr_NONCEBYTES-NONCEBYTES-1); - nonce[crypto_stream_aes128ctr_NONCEBYTES-1] = 1; + block_t nonce; + memcpy(nonce.b, session->send_nonce, NONCEBYTES); + memset(nonce.b+NONCEBYTES, 0, crypto_stream_aes128ctr_NONCEBYTES-NONCEBYTES-1); + nonce.b[crypto_stream_aes128ctr_NONCEBYTES-1] = 1; int n_blocks = (in.len+BLOCKBYTES-1)/BLOCKBYTES; - block_t stream[n_blocks+1]; - crypto_stream_aes128ctr_afternm((uint8_t*)stream, sizeof(stream), nonce, session->d); - block_t *inblocks = in.data; block_t *outblocks = out->data; - xor_blocks(outblocks, inblocks, stream+1, n_blocks); + aes128ctr(ctx, outblocks, inblocks, n_blocks, &nonce, session); if (tail_len) memset(out->data+out->len, 0, tail_len); - put_size(&outblocks[n_blocks], in.len); + put_size(&outblocks[n_blocks], in.len-BLOCKBYTES); - block_t *sig = outblocks-1; - ghash(ctx, sig, outblocks, n_blocks+1, session); - xor_a(sig, &stream[0]); + block_t sig; + ghash(ctx, &sig, outblocks+1, n_blocks, session); + xor_a(&outblocks[0], &sig); fastd_buffer_free(in); - fastd_buffer_pull_head(out, NONCEBYTES+BLOCKBYTES); + fastd_buffer_pull_head(out, NONCEBYTES); memcpy(out->data, session->send_nonce, NONCEBYTES); increment_nonce(session->send_nonce); @@ -311,13 +334,13 @@ static bool method_decrypt(fastd_context *ctx, fastd_peer *peer, fastd_method_se if (!method_session_is_valid(ctx, session)) return false; - uint8_t nonce[crypto_stream_aes128ctr_NONCEBYTES]; - memcpy(nonce, in.data, NONCEBYTES); - memset(nonce+NONCEBYTES, 0, crypto_stream_aes128ctr_NONCEBYTES-NONCEBYTES-1); - nonce[crypto_stream_aes128ctr_NONCEBYTES-1] = 1; + block_t nonce; + memcpy(nonce.b, in.data, NONCEBYTES); + memset(nonce.b+NONCEBYTES, 0, crypto_stream_aes128ctr_NONCEBYTES-NONCEBYTES-1); + nonce.b[crypto_stream_aes128ctr_NONCEBYTES-1] = 1; int64_t age; - if (!is_nonce_valid(nonce, session->receive_nonce, &age)) + if (!is_nonce_valid(nonce.b, session->receive_nonce, &age)) return false; if (age >= 0) { @@ -328,45 +351,43 @@ static bool method_decrypt(fastd_context *ctx, fastd_peer *peer, fastd_method_se return false; } - fastd_buffer_push_head(&in, NONCEBYTES+BLOCKBYTES); + fastd_buffer_push_head(&in, NONCEBYTES); size_t tail_len = ALIGN(in.len, BLOCKBYTES)-in.len; *out = fastd_buffer_alloc(in.len, 0, tail_len); int n_blocks = (in.len+BLOCKBYTES-1)/BLOCKBYTES; - block_t stream[n_blocks+1]; - crypto_stream_aes128ctr_afternm((uint8_t*)stream, sizeof(stream), nonce, session->d); - block_t *inblocks = in.data; block_t *outblocks = out->data; + aes128ctr(ctx, outblocks, inblocks, n_blocks, &nonce, session); + if (tail_len) memset(in.data+in.len, 0, tail_len); - put_size(&inblocks[n_blocks], in.len); + put_size(&inblocks[n_blocks], in.len-BLOCKBYTES); block_t sig; - ghash(ctx, &sig, inblocks, n_blocks+1, session); - xor_a(&sig, &stream[0]); + ghash(ctx, &sig, inblocks+1, n_blocks, session); - if (memcmp(&sig, inblocks-1, BLOCKBYTES) != 0) { + if (memcmp(&sig, &outblocks[0], BLOCKBYTES) != 0) { fastd_buffer_free(*out); /* restore input buffer */ - fastd_buffer_pull_head(&in, NONCEBYTES+BLOCKBYTES); + fastd_buffer_pull_head(&in, NONCEBYTES); return false; } - xor_blocks(outblocks, inblocks, stream+1, n_blocks); - fastd_buffer_free(in); + fastd_buffer_push_head(out, BLOCKBYTES); + if (age < 0) { session->receive_reorder_seen >>= age; session->receive_reorder_seen |= (1 >> (age+1)); - memcpy(session->receive_nonce, nonce, NONCEBYTES); + memcpy(session->receive_nonce, nonce.b, NONCEBYTES); session->receive_last = ctx->now; } else if (age == 0 || session->receive_reorder_seen & (1 << (age-1))) { @@ -386,8 +407,8 @@ const fastd_method fastd_method_aes128_gcm = { .name = "aes128-gcm", .max_packet_size = method_max_packet_size, - .min_encrypt_head_space = method_min_head_space, - .min_decrypt_head_space = method_min_head_space, + .min_encrypt_head_space = method_min_encrypt_head_space, + .min_decrypt_head_space = method_min_decrypt_head_space, .min_encrypt_tail_space = method_min_encrypt_tail_space, .min_decrypt_tail_space = method_min_decrypt_tail_space, -- cgit v1.2.3