summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMatthias Schiffer <mschiffer@universe-factory.net>2021-10-27 00:20:50 +0200
committerMatthias Schiffer <mschiffer@universe-factory.net>2021-10-27 00:44:28 +0200
commit95b557b7f3b54a4685660d281bf5fc9b1fba2f70 (patch)
treebb7703209a516f6ef2e563d30f980472826803ae
parentddb92b57fb92a43dcd3ac03ba987636bbc5d1843 (diff)
downloadrebel-95b557b7f3b54a4685660d281bf5fc9b1fba2f70.tar
rebel-95b557b7f3b54a4685660d281bf5fc9b1fba2f70.zip
runner: add Steal wrapper, use for socket
A deref wrapper that allows taking out its contents, which is convenient after a fork.
-rw-r--r--crates/runner/src/lib.rs39
-rw-r--r--crates/runner/src/util/mod.rs1
-rw-r--r--crates/runner/src/util/steal.rs40
3 files changed, 59 insertions, 21 deletions
diff --git a/crates/runner/src/lib.rs b/crates/runner/src/lib.rs
index d0cc531..8066c58 100644
--- a/crates/runner/src/lib.rs
+++ b/crates/runner/src/lib.rs
@@ -27,7 +27,7 @@ use uds::UnixSeqpacketConn;
use common::{error::*, types::*};
use jobserver::Jobserver;
-use util::{checkable::Checkable, clone, unix};
+use util::{checkable::Checkable, clone, steal::Steal, unix};
#[derive(Debug, Clone)]
pub struct Options {
@@ -36,6 +36,7 @@ pub struct Options {
#[derive(Debug)]
struct RunnerContext {
+ socket: Steal<UnixSeqpacketConn>,
jobserver: Jobserver,
tasks: HashSet<Pid>,
}
@@ -53,13 +54,9 @@ fn handle_sigchld(ctx: &mut RunnerContext) -> Result<()> {
}
}
-fn handle_request(
- ctx: &mut RunnerContext,
- socket: UnixSeqpacketConn,
- request_socket: UnixStream,
-) -> UnixSeqpacketConn {
- let run = |socket| {
- drop(socket);
+fn handle_request(ctx: &mut RunnerContext, request_socket: UnixStream) {
+ let run = |()| {
+ ctx.socket.steal();
let task: Task =
bincode::deserialize_from(&request_socket).expect("Failed to decode task description");
@@ -71,30 +68,30 @@ fn handle_request(
drop(request_socket);
};
- let (pid, socket) = unsafe { clone::spawn(None, socket, run) }.expect("fork()");
+ let pid = unsafe { clone::spawn(None, (), run) }.expect("fork()").0;
assert!(ctx.tasks.insert(pid));
-
- socket
}
-fn handle_socket(ctx: &mut RunnerContext, socket: UnixSeqpacketConn) -> Option<UnixSeqpacketConn> {
+fn handle_socket(ctx: &mut RunnerContext) -> bool {
let mut fd = 0;
- match socket
+ match ctx
+ .socket
.recv_fds(&mut [0], slice::from_mut(&mut fd))
.expect("recv_fds()")
{
(1, _, n_fd) => {
assert!(n_fd == 1);
}
- _ => return None,
+ _ => return false,
}
let request_socket = unsafe { UnixStream::from_raw_fd(fd) };
- Some(handle_request(ctx, socket, request_socket))
+ handle_request(ctx, request_socket);
+ true
}
-fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, options: &Options) {
+fn runner(uid: Uid, gid: Gid, socket: UnixSeqpacketConn, _lockfile: File, options: &Options) {
ns::mount_proc();
ns::setup_userns(Uid::from_raw(0), Gid::from_raw(0), uid, gid);
@@ -107,6 +104,7 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op
.unwrap_or_else(|| unix::nproc().expect("Failed to get number of available CPUs"));
let jobserver = Jobserver::new(jobs).expect("Failed to initialize jobserver pipe");
let mut ctx = RunnerContext {
+ socket: socket.into(),
jobserver,
tasks: HashSet::new(),
};
@@ -119,7 +117,7 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op
let mut pollfds = [
poll::PollFd::new(sfd.as_raw_fd(), poll::PollFlags::POLLIN),
- poll::PollFd::new(socket.as_raw_fd(), poll::PollFlags::POLLIN),
+ poll::PollFd::new(ctx.socket.as_raw_fd(), poll::PollFlags::POLLIN),
];
loop {
@@ -139,10 +137,9 @@ fn runner(uid: Uid, gid: Gid, mut socket: UnixSeqpacketConn, _lockfile: File, op
.revents()
.expect("Unknown events in poll() return");
if events.contains(poll::PollFlags::POLLIN) {
- socket = match handle_socket(&mut ctx, socket) {
- Some(socket) => socket,
- None => break,
- };
+ if !handle_socket(&mut ctx) {
+ break;
+ }
} else if events.intersects(!poll::PollFlags::POLLIN) {
panic!("Unexpected error status for socket file descriptor");
}
diff --git a/crates/runner/src/util/mod.rs b/crates/runner/src/util/mod.rs
index 5310ccf..1f89592 100644
--- a/crates/runner/src/util/mod.rs
+++ b/crates/runner/src/util/mod.rs
@@ -2,4 +2,5 @@ pub mod checkable;
pub mod cjson;
pub mod clone;
pub mod fs;
+pub mod steal;
pub mod unix;
diff --git a/crates/runner/src/util/steal.rs b/crates/runner/src/util/steal.rs
new file mode 100644
index 0000000..91b2cdf
--- /dev/null
+++ b/crates/runner/src/util/steal.rs
@@ -0,0 +1,40 @@
+use std::ops::{Deref, DerefMut};
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+pub struct Steal<T>(pub Option<T>);
+
+impl<T> Steal<T> {
+ pub fn new(value: T) -> Steal<T> {
+ Steal(Some(value))
+ }
+
+ pub fn steal(&mut self) -> T {
+ self.0
+ .take()
+ .expect("Attempted to steal already stoken value")
+ }
+}
+
+impl<T> From<T> for Steal<T> {
+ fn from(value: T) -> Self {
+ Steal::new(value)
+ }
+}
+
+impl<T> Deref for Steal<T> {
+ type Target = T;
+
+ fn deref(&self) -> &Self::Target {
+ self.0
+ .as_ref()
+ .expect("Attempted to dereference stolen value")
+ }
+}
+
+impl<T> DerefMut for Steal<T> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ self.0
+ .as_mut()
+ .expect("Attempted to dereference stolen value")
+ }
+}