From 95b557b7f3b54a4685660d281bf5fc9b1fba2f70 Mon Sep 17 00:00:00 2001 From: Matthias Schiffer Date: Wed, 27 Oct 2021 00:20:50 +0200 Subject: runner: add Steal wrapper, use for socket A deref wrapper that allows taking out its contents, which is convenient after a fork. --- crates/runner/src/lib.rs | 39 ++++++++++++++++++--------------------- crates/runner/src/util/mod.rs | 1 + crates/runner/src/util/steal.rs | 40 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 21 deletions(-) create mode 100644 crates/runner/src/util/steal.rs diff --git a/crates/runner/src/lib.rs b/crates/runner/src/lib.rs index d0cc531..8066c58 100644 --- a/crates/runner/src/lib.rs +++ b/crates/runner/src/lib.rs @@ -27,7 +27,7 @@ use uds::UnixSeqpacketConn; use common::{error::*, types::*}; use jobserver::Jobserver; -use util::{checkable::Checkable, clone, unix}; +use util::{checkable::Checkable, clone, steal::Steal, unix}; #[derive(Debug, Clone)] pub struct Options { @@ -36,6 +36,7 @@ pub struct Options { #[derive(Debug)] struct RunnerContext { + socket: Steal, jobserver: Jobserver, tasks: HashSet, } @@ -53,13 +54,9 @@ fn handle_sigchld(ctx: &mut RunnerContext) -> Result<()> { } } -fn handle_request( - ctx: &mut RunnerContext, - socket: UnixSeqpacketConn, - request_socket: UnixStream, -) -> UnixSeqpacketConn { - let run = |socket| { - drop(socket); +fn handle_request(ctx: &mut RunnerContext, request_socket: UnixStream) { + let run = |()| { + ctx.socket.steal(); let task: Task = bincode::deserialize_from(&request_socket).expect("Failed to decode task description"); @@ -71,30 +68,30 @@ fn handle_request( drop(request_socket); }; - let (pid, socket) = unsafe { clone::spawn(None, socket, run) }.expect("fork()"); + let pid = unsafe { clone::spawn(None, (), run) }.expect("fork()").0; assert!(ctx.tasks.insert(pid)); - - socket } -fn handle_socket(ctx: &mut RunnerContext, socket: UnixSeqpacketConn) -> Option { +fn handle_socket(ctx: &mut RunnerContext) -> bool { let mut fd = 0; - match socket + match ctx + .socket .recv_fds(&mut [0], slice::from_mut(&mut fd)) .expect("recv_fds()") { (1, _, n_fd) => { assert!(n_fd == 1); } - _ => return None, + _ => return false, } let request_socket = unsafe { UnixStream::from_raw_fd(fd) }; - Some(handle_request(ctx, socket, request_socket)) + handle_request(ctx, request_socket); + true } -fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, options: &Options) { +fn runner(uid: Uid, gid: Gid, socket: UnixSeqpacketConn, _lockfile: File, options: &Options) { ns::mount_proc(); ns::setup_userns(Uid::from_raw(0), Gid::from_raw(0), uid, gid); @@ -107,6 +104,7 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op .unwrap_or_else(|| unix::nproc().expect("Failed to get number of available CPUs")); let jobserver = Jobserver::new(jobs).expect("Failed to initialize jobserver pipe"); let mut ctx = RunnerContext { + socket: socket.into(), jobserver, tasks: HashSet::new(), }; @@ -119,7 +117,7 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op let mut pollfds = [ poll::PollFd::new(sfd.as_raw_fd(), poll::PollFlags::POLLIN), - poll::PollFd::new(socket.as_raw_fd(), poll::PollFlags::POLLIN), + poll::PollFd::new(ctx.socket.as_raw_fd(), poll::PollFlags::POLLIN), ]; loop { @@ -139,10 +137,9 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op .revents() .expect("Unknown events in poll() return"); if events.contains(poll::PollFlags::POLLIN) { - socket = match handle_socket(&mut ctx, socket) { - Some(socket) => socket, - None => break, - }; + if !handle_socket(&mut ctx) { + break; + } } else if events.intersects(!poll::PollFlags::POLLIN) { panic!("Unexpected error status for socket file descriptor"); } diff --git a/crates/runner/src/util/mod.rs b/crates/runner/src/util/mod.rs index 5310ccf..1f89592 100644 --- a/crates/runner/src/util/mod.rs +++ b/crates/runner/src/util/mod.rs @@ -2,4 +2,5 @@ pub mod checkable; pub mod cjson; pub mod clone; pub mod fs; +pub mod steal; pub mod unix; diff --git a/crates/runner/src/util/steal.rs b/crates/runner/src/util/steal.rs new file mode 100644 index 0000000..91b2cdf --- /dev/null +++ b/crates/runner/src/util/steal.rs @@ -0,0 +1,40 @@ +use std::ops::{Deref, DerefMut}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Steal(pub Option); + +impl Steal { + pub fn new(value: T) -> Steal { + Steal(Some(value)) + } + + pub fn steal(&mut self) -> T { + self.0 + .take() + .expect("Attempted to steal already stoken value") + } +} + +impl From for Steal { + fn from(value: T) -> Self { + Steal::new(value) + } +} + +impl Deref for Steal { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.0 + .as_ref() + .expect("Attempted to dereference stolen value") + } +} + +impl DerefMut for Steal { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0 + .as_mut() + .expect("Attempted to dereference stolen value") + } +} -- cgit v1.2.3