diff options
Diffstat (limited to 'src/linux_alg.c')
-rw-r--r-- | src/linux_alg.c | 117 |
1 files changed, 89 insertions, 28 deletions
diff --git a/src/linux_alg.c b/src/linux_alg.c index 70f4752..9591ce8 100644 --- a/src/linux_alg.c +++ b/src/linux_alg.c @@ -26,6 +26,7 @@ #include "linux_alg.h" +#include <alloca.h> #include <linux/if_alg.h> #include <unistd.h> @@ -38,27 +39,43 @@ void fastd_linux_alg_init(fastd_context *ctx) { ctx->algfd_ghash = socket(AF_ALG, SOCK_SEQPACKET, 0); if (ctx->algfd_ghash < 0) - goto error_ghash; + goto ghash_done; - struct sockaddr_alg sa_ghash = {}; - sa_ghash.salg_family = AF_ALG; - strcpy((char*)sa_ghash.salg_type, "hash"); - strcpy((char*)sa_ghash.salg_name, "ghash"); - if (bind(ctx->algfd_ghash, (struct sockaddr*)&sa_ghash, sizeof(sa_ghash)) < 0) { + struct sockaddr_alg sa = {}; + sa.salg_family = AF_ALG; + strcpy((char*)sa.salg_type, "hash"); + strcpy((char*)sa.salg_name, "ghash"); + if (bind(ctx->algfd_ghash, (struct sockaddr*)&sa, sizeof(sa)) < 0) { close(ctx->algfd_ghash); - goto error_ghash; + ctx->algfd_ghash = -1; } - return; + ghash_done: + if (ctx->algfd_ghash < 0) + pr_info(ctx, "no kernel support for GHASH was found, falling back to userspace implementation"); + + ctx->algfd_aesctr = socket(AF_ALG, SOCK_SEQPACKET, 0); + if (ctx->algfd_aesctr < 0) + goto aesctr_done; - error_ghash: - pr_info(ctx, "no kernel support for GHASH was found, falling back to userspace implementation"); - ctx->algfd_ghash = -1; + strcpy((char*)sa.salg_type, "skcipher"); + strcpy((char*)sa.salg_name, "ctr(aes)"); + if (bind(ctx->algfd_aesctr, (struct sockaddr*)&sa, sizeof(sa)) < 0) { + close(ctx->algfd_aesctr); + ctx->algfd_aesctr = -1; + } + + aesctr_done: + if (ctx->algfd_aesctr < 0) + pr_info(ctx, "no kernel support for AES-CTR was found, falling back to userspace implementation"); } void fastd_linux_alg_close(fastd_context *ctx) { if (ctx->algfd_ghash >= 0) close(ctx->algfd_ghash); + + if (ctx->algfd_aesctr >= 0) + close(ctx->algfd_aesctr); } @@ -85,29 +102,73 @@ bool fastd_linux_alg_ghash(fastd_context *ctx, int fd, uint8_t out[16], const vo if (!len) return false; - const uint8_t *in = data; - const uint8_t *end = in+len; + if (write(fd, data, len) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_ghash: write"); + return false; + } - while (in < end) { - int bytes = write(fd, in, end-in); - if (bytes < 0) { - pr_error_errno(ctx, "fastd_linux_alg_ghash: write"); - return false; - } + if (read(fd, out, 16) < 16) { + pr_error_errno(ctx, "fastd_linux_alg_ghash: read"); + return false; + } - in += bytes; + return true; +} + +int fastd_linux_alg_aesctr_init(fastd_context *ctx, uint8_t *key, size_t keylen) { + if (ctx->algfd_aesctr < 0) + return -1; + + if (setsockopt(ctx->algfd_aesctr, SOL_ALG, ALG_SET_KEY, key, keylen) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr_init: setsockopt"); + return -1; } - end = out+16; + int ret = accept(ctx->algfd_aesctr, NULL, NULL); - while (out < end) { - int bytes = read(fd, out, end-out); - if (bytes < 0) { - pr_error_errno(ctx, "fastd_linux_alg_ghash: read"); - return false; - } + if (ret < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr_init: accept"); + return -1; + } + + return ret; +} + +bool fastd_linux_alg_aesctr(fastd_context *ctx, int fd, void *out, const void *in, size_t len, const uint8_t iv[16]) { + if (!len) + return false; + + struct iovec vec = { .iov_base = (void*)in, .iov_len = len }; - out += bytes; + static const size_t cmsglen = sizeof(struct cmsghdr)+sizeof(struct af_alg_iv)+16; + struct cmsghdr *cmsg = alloca(cmsglen); + cmsg->cmsg_len = cmsglen; + cmsg->cmsg_level = SOL_ALG; + cmsg->cmsg_type = ALG_SET_IV; + + struct af_alg_iv *alg_iv = (void*)CMSG_DATA(cmsg); + alg_iv->ivlen = 16; + memcpy(alg_iv->iv, iv, 16); + + struct msghdr msg = { + .msg_iov = &vec, + .msg_iovlen = 1, + .msg_control = cmsg, + .msg_controllen = cmsglen + }; + + if (sendmsg(fd, &msg, 0) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr: sendmsg"); + return false; + } + + msg.msg_control = NULL; + msg.msg_controllen = 0; + vec.iov_base = out; + + if (recvmsg(fd, &msg, 0) < 0) { + pr_error_errno(ctx, "fastd_linux_alg_aesctr: recvmsg"); + return false; } return true; |