summaryrefslogtreecommitdiffstats
path: root/crates/rebel-runner/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'crates/rebel-runner/src/lib.rs')
-rw-r--r--crates/rebel-runner/src/lib.rs217
1 files changed, 217 insertions, 0 deletions
diff --git a/crates/rebel-runner/src/lib.rs b/crates/rebel-runner/src/lib.rs
new file mode 100644
index 0000000..7dde05d
--- /dev/null
+++ b/crates/rebel-runner/src/lib.rs
@@ -0,0 +1,217 @@
+mod init;
+mod jobserver;
+mod ns;
+mod paths;
+mod tar;
+mod task;
+mod util;
+
+use std::{
+ collections::HashSet,
+ fs::File,
+ net,
+ os::unix::{net::UnixStream, prelude::*},
+ process, slice,
+};
+
+use capctl::prctl;
+use nix::{
+ errno::Errno,
+ fcntl::Flock,
+ poll,
+ sched::CloneFlags,
+ sys::{
+ signal,
+ signalfd::{SfdFlags, SignalFd},
+ stat, wait,
+ },
+ unistd::{self, Gid, Pid, Uid},
+};
+use uds::UnixSeqpacketConn;
+
+use rebel_common::{error::*, types::*};
+
+use jobserver::Jobserver;
+use util::{checkable::Checkable, clone, steal::Steal, unix};
+
+#[derive(Debug, Clone)]
+pub struct Options {
+ pub jobs: Option<usize>,
+}
+
+#[derive(Debug)]
+struct RunnerContext {
+ socket: Steal<UnixSeqpacketConn>,
+ jobserver: Jobserver,
+ tasks: HashSet<Pid>,
+}
+
+fn handle_sigchld(ctx: &mut RunnerContext) -> Result<()> {
+ loop {
+ let status = match wait::waitpid(Pid::from_raw(-1), Some(wait::WaitPidFlag::WNOHANG)) {
+ Ok(wait::WaitStatus::StillAlive) | Err(Errno::ECHILD) => return Ok(()),
+ res => res.expect("waitpid()"),
+ };
+ let pid = status.pid().unwrap();
+ if ctx.tasks.remove(&pid) {
+ status.check()?;
+ }
+ }
+}
+
+fn handle_request(ctx: &mut RunnerContext, request_socket: UnixStream) {
+ let run = || {
+ ctx.socket.steal();
+
+ let task: Task =
+ bincode::deserialize_from(&request_socket).expect("Failed to decode task description");
+
+ prctl::set_name(&task.label).expect("prctl(PR_SET_NAME)");
+
+ let result = task::handle(task, &mut ctx.jobserver);
+ bincode::serialize_into(&request_socket, &result).expect("Failed to send task result");
+ drop(request_socket);
+ };
+
+ let pid = unsafe { clone::spawn(None, run) }.expect("fork()");
+ assert!(ctx.tasks.insert(pid));
+}
+
+fn handle_socket(ctx: &mut RunnerContext) -> bool {
+ let mut fd = 0;
+
+ match ctx
+ .socket
+ .recv_fds(&mut [0], slice::from_mut(&mut fd))
+ .expect("recv_fds()")
+ {
+ (1, _, n_fd) => {
+ assert!(n_fd == 1);
+ }
+ _ => return false,
+ }
+
+ let request_socket = unsafe { UnixStream::from_raw_fd(fd) };
+ handle_request(ctx, request_socket);
+ true
+}
+
+fn borrow_socket_fd(socket: &UnixSeqpacketConn) -> BorrowedFd<'_> {
+ unsafe { BorrowedFd::borrow_raw(socket.as_raw_fd()) }
+}
+
+fn runner(
+ uid: Uid,
+ gid: Gid,
+ socket: UnixSeqpacketConn,
+ _lockfile: Flock<File>,
+ options: &Options,
+) -> ! {
+ unistd::setsid().expect("setsid()");
+ ns::mount_proc();
+ ns::setup_userns(Uid::from_raw(0), Gid::from_raw(0), uid, gid);
+
+ stat::umask(stat::Mode::from_bits_truncate(0o022));
+
+ init::init_runner().unwrap();
+
+ let jobs = options
+ .jobs
+ .unwrap_or_else(|| unix::nproc().expect("Failed to get number of available CPUs"));
+ let jobserver = Jobserver::new(jobs).expect("Failed to initialize jobserver pipe");
+ let mut ctx = RunnerContext {
+ socket: socket.into(),
+ jobserver,
+ tasks: HashSet::new(),
+ };
+
+ let mut signals = signal::SigSet::empty();
+ signals.add(signal::Signal::SIGCHLD);
+ signal::pthread_sigmask(signal::SigmaskHow::SIG_BLOCK, Some(&signals), None)
+ .expect("pthread_sigmask()");
+ let mut signal_fd = SignalFd::with_flags(&signals, SfdFlags::SFD_CLOEXEC)
+ .expect("Failed to create signal file descriptor");
+
+ loop {
+ let socket_fd = borrow_socket_fd(&ctx.socket);
+ let mut pollfds = [
+ poll::PollFd::new(signal_fd.as_fd(), poll::PollFlags::POLLIN),
+ poll::PollFd::new(socket_fd.as_fd(), poll::PollFlags::POLLIN),
+ ];
+ poll::poll(&mut pollfds, poll::PollTimeout::NONE).expect("poll()");
+
+ let signal_events = pollfds[0]
+ .revents()
+ .expect("Unknown events in poll() return");
+ let socket_events = pollfds[1]
+ .revents()
+ .expect("Unknown events in poll() return");
+
+ if signal_events.contains(poll::PollFlags::POLLIN) {
+ let _signal = signal_fd.read_signal().expect("read_signal()").unwrap();
+ handle_sigchld(&mut ctx).expect("Task process exited abnormally");
+ } else if signal_events.intersects(!poll::PollFlags::POLLIN) {
+ panic!("Unexpected error status for signal file descriptor");
+ }
+
+ if socket_events.contains(poll::PollFlags::POLLIN) {
+ if !handle_socket(&mut ctx) {
+ break;
+ }
+ } else if socket_events.intersects(!poll::PollFlags::POLLIN) {
+ panic!("Unexpected error status for socket file descriptor");
+ }
+ }
+
+ process::exit(0);
+}
+
+pub struct Runner {
+ socket: UnixSeqpacketConn,
+}
+
+impl Runner {
+ /// Creates a new container runner
+ ///
+ /// # Safety
+ ///
+ /// Do not call in multithreaded processes.
+ pub unsafe fn new(options: &Options) -> Result<Self> {
+ let lockfile = unix::lock(paths::LOCKFILE, true, false)
+ .context("Failed to get lock on build directory, is another instance running?")?;
+
+ let uid = unistd::geteuid();
+ let gid = unistd::getegid();
+
+ let (local, remote) = UnixSeqpacketConn::pair().expect("socketpair()");
+
+ match clone::clone(
+ CloneFlags::CLONE_NEWUSER | CloneFlags::CLONE_NEWNS | CloneFlags::CLONE_NEWPID,
+ )
+ .expect("clone()")
+ {
+ unistd::ForkResult::Parent { .. } => Ok(Runner { socket: local }),
+ unistd::ForkResult::Child => {
+ drop(local);
+ runner(uid, gid, remote, lockfile, options);
+ }
+ }
+ }
+
+ pub fn spawn(&self, task: &Task) -> UnixStream {
+ let (local, remote) = UnixStream::pair().expect("socketpair()");
+
+ self.socket
+ .send_fds(&[0], &[remote.as_raw_fd()])
+ .expect("send()");
+
+ bincode::serialize_into(&local, task).expect("Task submission failed");
+ local.shutdown(net::Shutdown::Write).expect("shutdown()");
+
+ local
+ }
+
+ pub fn result(socket: &UnixStream) -> Result<TaskOutput> {
+ bincode::deserialize_from(socket).expect("Failed to read task result")
+ }
+}