diff options
Diffstat (limited to 'crates/rebel-runner/src/lib.rs')
-rw-r--r-- | crates/rebel-runner/src/lib.rs | 217 |
1 files changed, 217 insertions, 0 deletions
diff --git a/crates/rebel-runner/src/lib.rs b/crates/rebel-runner/src/lib.rs new file mode 100644 index 0000000..7dde05d --- /dev/null +++ b/crates/rebel-runner/src/lib.rs @@ -0,0 +1,217 @@ +mod init; +mod jobserver; +mod ns; +mod paths; +mod tar; +mod task; +mod util; + +use std::{ + collections::HashSet, + fs::File, + net, + os::unix::{net::UnixStream, prelude::*}, + process, slice, +}; + +use capctl::prctl; +use nix::{ + errno::Errno, + fcntl::Flock, + poll, + sched::CloneFlags, + sys::{ + signal, + signalfd::{SfdFlags, SignalFd}, + stat, wait, + }, + unistd::{self, Gid, Pid, Uid}, +}; +use uds::UnixSeqpacketConn; + +use rebel_common::{error::*, types::*}; + +use jobserver::Jobserver; +use util::{checkable::Checkable, clone, steal::Steal, unix}; + +#[derive(Debug, Clone)] +pub struct Options { + pub jobs: Option<usize>, +} + +#[derive(Debug)] +struct RunnerContext { + socket: Steal<UnixSeqpacketConn>, + jobserver: Jobserver, + tasks: HashSet<Pid>, +} + +fn handle_sigchld(ctx: &mut RunnerContext) -> Result<()> { + loop { + let status = match wait::waitpid(Pid::from_raw(-1), Some(wait::WaitPidFlag::WNOHANG)) { + Ok(wait::WaitStatus::StillAlive) | Err(Errno::ECHILD) => return Ok(()), + res => res.expect("waitpid()"), + }; + let pid = status.pid().unwrap(); + if ctx.tasks.remove(&pid) { + status.check()?; + } + } +} + +fn handle_request(ctx: &mut RunnerContext, request_socket: UnixStream) { + let run = || { + ctx.socket.steal(); + + let task: Task = + bincode::deserialize_from(&request_socket).expect("Failed to decode task description"); + + prctl::set_name(&task.label).expect("prctl(PR_SET_NAME)"); + + let result = task::handle(task, &mut ctx.jobserver); + bincode::serialize_into(&request_socket, &result).expect("Failed to send task result"); + drop(request_socket); + }; + + let pid = unsafe { clone::spawn(None, run) }.expect("fork()"); + assert!(ctx.tasks.insert(pid)); +} + +fn handle_socket(ctx: &mut RunnerContext) -> bool { + let mut fd = 0; + + match ctx + .socket + .recv_fds(&mut [0], slice::from_mut(&mut fd)) + .expect("recv_fds()") + { + (1, _, n_fd) => { + assert!(n_fd == 1); + } + _ => return false, + } + + let request_socket = unsafe { UnixStream::from_raw_fd(fd) }; + handle_request(ctx, request_socket); + true +} + +fn borrow_socket_fd(socket: &UnixSeqpacketConn) -> BorrowedFd<'_> { + unsafe { BorrowedFd::borrow_raw(socket.as_raw_fd()) } +} + +fn runner( + uid: Uid, + gid: Gid, + socket: UnixSeqpacketConn, + _lockfile: Flock<File>, + options: &Options, +) -> ! { + unistd::setsid().expect("setsid()"); + ns::mount_proc(); + ns::setup_userns(Uid::from_raw(0), Gid::from_raw(0), uid, gid); + + stat::umask(stat::Mode::from_bits_truncate(0o022)); + + init::init_runner().unwrap(); + + let jobs = options + .jobs + .unwrap_or_else(|| unix::nproc().expect("Failed to get number of available CPUs")); + let jobserver = Jobserver::new(jobs).expect("Failed to initialize jobserver pipe"); + let mut ctx = RunnerContext { + socket: socket.into(), + jobserver, + tasks: HashSet::new(), + }; + + let mut signals = signal::SigSet::empty(); + signals.add(signal::Signal::SIGCHLD); + signal::pthread_sigmask(signal::SigmaskHow::SIG_BLOCK, Some(&signals), None) + .expect("pthread_sigmask()"); + let mut signal_fd = SignalFd::with_flags(&signals, SfdFlags::SFD_CLOEXEC) + .expect("Failed to create signal file descriptor"); + + loop { + let socket_fd = borrow_socket_fd(&ctx.socket); + let mut pollfds = [ + poll::PollFd::new(signal_fd.as_fd(), poll::PollFlags::POLLIN), + poll::PollFd::new(socket_fd.as_fd(), poll::PollFlags::POLLIN), + ]; + poll::poll(&mut pollfds, poll::PollTimeout::NONE).expect("poll()"); + + let signal_events = pollfds[0] + .revents() + .expect("Unknown events in poll() return"); + let socket_events = pollfds[1] + .revents() + .expect("Unknown events in poll() return"); + + if signal_events.contains(poll::PollFlags::POLLIN) { + let _signal = signal_fd.read_signal().expect("read_signal()").unwrap(); + handle_sigchld(&mut ctx).expect("Task process exited abnormally"); + } else if signal_events.intersects(!poll::PollFlags::POLLIN) { + panic!("Unexpected error status for signal file descriptor"); + } + + if socket_events.contains(poll::PollFlags::POLLIN) { + if !handle_socket(&mut ctx) { + break; + } + } else if socket_events.intersects(!poll::PollFlags::POLLIN) { + panic!("Unexpected error status for socket file descriptor"); + } + } + + process::exit(0); +} + +pub struct Runner { + socket: UnixSeqpacketConn, +} + +impl Runner { + /// Creates a new container runner + /// + /// # Safety + /// + /// Do not call in multithreaded processes. + pub unsafe fn new(options: &Options) -> Result<Self> { + let lockfile = unix::lock(paths::LOCKFILE, true, false) + .context("Failed to get lock on build directory, is another instance running?")?; + + let uid = unistd::geteuid(); + let gid = unistd::getegid(); + + let (local, remote) = UnixSeqpacketConn::pair().expect("socketpair()"); + + match clone::clone( + CloneFlags::CLONE_NEWUSER | CloneFlags::CLONE_NEWNS | CloneFlags::CLONE_NEWPID, + ) + .expect("clone()") + { + unistd::ForkResult::Parent { .. } => Ok(Runner { socket: local }), + unistd::ForkResult::Child => { + drop(local); + runner(uid, gid, remote, lockfile, options); + } + } + } + + pub fn spawn(&self, task: &Task) -> UnixStream { + let (local, remote) = UnixStream::pair().expect("socketpair()"); + + self.socket + .send_fds(&[0], &[remote.as_raw_fd()]) + .expect("send()"); + + bincode::serialize_into(&local, task).expect("Task submission failed"); + local.shutdown(net::Shutdown::Write).expect("shutdown()"); + + local + } + + pub fn result(socket: &UnixStream) -> Result<TaskOutput> { + bincode::deserialize_from(socket).expect("Failed to read task result") + } +} |