diff options
author | Matthias Schiffer <mschiffer@universe-factory.net> | 2021-10-27 00:20:50 +0200 |
---|---|---|
committer | Matthias Schiffer <mschiffer@universe-factory.net> | 2021-10-27 00:44:28 +0200 |
commit | 95b557b7f3b54a4685660d281bf5fc9b1fba2f70 (patch) | |
tree | bb7703209a516f6ef2e563d30f980472826803ae | |
parent | ddb92b57fb92a43dcd3ac03ba987636bbc5d1843 (diff) | |
download | rebel-95b557b7f3b54a4685660d281bf5fc9b1fba2f70.tar rebel-95b557b7f3b54a4685660d281bf5fc9b1fba2f70.zip |
runner: add Steal wrapper, use for socket
A deref wrapper that allows taking out its contents, which is convenient
after a fork.
-rw-r--r-- | crates/runner/src/lib.rs | 39 | ||||
-rw-r--r-- | crates/runner/src/util/mod.rs | 1 | ||||
-rw-r--r-- | crates/runner/src/util/steal.rs | 40 |
3 files changed, 59 insertions, 21 deletions
diff --git a/crates/runner/src/lib.rs b/crates/runner/src/lib.rs index d0cc531..8066c58 100644 --- a/crates/runner/src/lib.rs +++ b/crates/runner/src/lib.rs @@ -27,7 +27,7 @@ use uds::UnixSeqpacketConn; use common::{error::*, types::*}; use jobserver::Jobserver; -use util::{checkable::Checkable, clone, unix}; +use util::{checkable::Checkable, clone, steal::Steal, unix}; #[derive(Debug, Clone)] pub struct Options { @@ -36,6 +36,7 @@ pub struct Options { #[derive(Debug)] struct RunnerContext { + socket: Steal<UnixSeqpacketConn>, jobserver: Jobserver, tasks: HashSet<Pid>, } @@ -53,13 +54,9 @@ fn handle_sigchld(ctx: &mut RunnerContext) -> Result<()> { } } -fn handle_request( - ctx: &mut RunnerContext, - socket: UnixSeqpacketConn, - request_socket: UnixStream, -) -> UnixSeqpacketConn { - let run = |socket| { - drop(socket); +fn handle_request(ctx: &mut RunnerContext, request_socket: UnixStream) { + let run = |()| { + ctx.socket.steal(); let task: Task = bincode::deserialize_from(&request_socket).expect("Failed to decode task description"); @@ -71,30 +68,30 @@ fn handle_request( drop(request_socket); }; - let (pid, socket) = unsafe { clone::spawn(None, socket, run) }.expect("fork()"); + let pid = unsafe { clone::spawn(None, (), run) }.expect("fork()").0; assert!(ctx.tasks.insert(pid)); - - socket } -fn handle_socket(ctx: &mut RunnerContext, socket: UnixSeqpacketConn) -> Option<UnixSeqpacketConn> { +fn handle_socket(ctx: &mut RunnerContext) -> bool { let mut fd = 0; - match socket + match ctx + .socket .recv_fds(&mut [0], slice::from_mut(&mut fd)) .expect("recv_fds()") { (1, _, n_fd) => { assert!(n_fd == 1); } - _ => return None, + _ => return false, } let request_socket = unsafe { UnixStream::from_raw_fd(fd) }; - Some(handle_request(ctx, socket, request_socket)) + handle_request(ctx, request_socket); + true } -fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, options: &Options) { +fn runner(uid: Uid, gid: Gid, socket: UnixSeqpacketConn, _lockfile: File, options: &Options) { ns::mount_proc(); ns::setup_userns(Uid::from_raw(0), Gid::from_raw(0), uid, gid); @@ -107,6 +104,7 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op .unwrap_or_else(|| unix::nproc().expect("Failed to get number of available CPUs")); let jobserver = Jobserver::new(jobs).expect("Failed to initialize jobserver pipe"); let mut ctx = RunnerContext { + socket: socket.into(), jobserver, tasks: HashSet::new(), }; @@ -119,7 +117,7 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op let mut pollfds = [ poll::PollFd::new(sfd.as_raw_fd(), poll::PollFlags::POLLIN), - poll::PollFd::new(socket.as_raw_fd(), poll::PollFlags::POLLIN), + poll::PollFd::new(ctx.socket.as_raw_fd(), poll::PollFlags::POLLIN), ]; loop { @@ -139,10 +137,9 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op .revents() .expect("Unknown events in poll() return"); if events.contains(poll::PollFlags::POLLIN) { - socket = match handle_socket(&mut ctx, socket) { - Some(socket) => socket, - None => break, - }; + if !handle_socket(&mut ctx) { + break; + } } else if events.intersects(!poll::PollFlags::POLLIN) { panic!("Unexpected error status for socket file descriptor"); } diff --git a/crates/runner/src/util/mod.rs b/crates/runner/src/util/mod.rs index 5310ccf..1f89592 100644 --- a/crates/runner/src/util/mod.rs +++ b/crates/runner/src/util/mod.rs @@ -2,4 +2,5 @@ pub mod checkable; pub mod cjson; pub mod clone; pub mod fs; +pub mod steal; pub mod unix; diff --git a/crates/runner/src/util/steal.rs b/crates/runner/src/util/steal.rs new file mode 100644 index 0000000..91b2cdf --- /dev/null +++ b/crates/runner/src/util/steal.rs @@ -0,0 +1,40 @@ +use std::ops::{Deref, DerefMut}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub struct Steal<T>(pub Option<T>); + +impl<T> Steal<T> { + pub fn new(value: T) -> Steal<T> { + Steal(Some(value)) + } + + pub fn steal(&mut self) -> T { + self.0 + .take() + .expect("Attempted to steal already stoken value") + } +} + +impl<T> From<T> for Steal<T> { + fn from(value: T) -> Self { + Steal::new(value) + } +} + +impl<T> Deref for Steal<T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + self.0 + .as_ref() + .expect("Attempted to dereference stolen value") + } +} + +impl<T> DerefMut for Steal<T> { + fn deref_mut(&mut self) -> &mut Self::Target { + self.0 + .as_mut() + .expect("Attempted to dereference stolen value") + } +} |