summaryrefslogtreecommitdiffstats
path: root/crates
diff options
context:
space:
mode:
Diffstat (limited to 'crates')
-rw-r--r--crates/driver/Cargo.toml2
-rw-r--r--crates/driver/src/driver.rs78
2 files changed, 67 insertions, 13 deletions
diff --git a/crates/driver/Cargo.toml b/crates/driver/Cargo.toml
index 673e0ba..df1fb19 100644
--- a/crates/driver/Cargo.toml
+++ b/crates/driver/Cargo.toml
@@ -17,7 +17,7 @@ enum-kinds = "0.5.1"
handlebars = "5.1.2"
indoc = "2.0.4"
lazy_static = "1.4.0"
-nix = { version = "0.28.0", features = ["poll"] }
+nix = { version = "0.28.0", features = ["poll", "signal"] }
scoped-tls-hkt = "0.1.2"
serde = { version = "1", features = ["derive", "rc"] }
serde_yaml = "0.9"
diff --git a/crates/driver/src/driver.rs b/crates/driver/src/driver.rs
index cc4bfbf..5a00882 100644
--- a/crates/driver/src/driver.rs
+++ b/crates/driver/src/driver.rs
@@ -1,10 +1,17 @@
use std::{
collections::{HashMap, HashSet},
+ iter,
os::unix::{net::UnixStream, prelude::*},
};
use indoc::indoc;
-use nix::poll;
+use nix::{
+ poll,
+ sys::{
+ signal,
+ signalfd::{SfdFlags, SignalFd},
+ },
+};
use common::{error::*, string_hash::*, types::*};
use runner::Runner;
@@ -143,11 +150,18 @@ impl<'ctx> CompletionState<'ctx> {
}
}
+#[derive(Debug)]
enum SpawnResult {
Spawned(UnixStream),
Skipped(TaskOutput),
}
+#[derive(Debug, PartialEq, Eq, Hash)]
+enum TaskWaitResult {
+ Failed,
+ Interrupted,
+}
+
#[derive(Debug)]
pub struct Driver<'ctx> {
rdeps: HashMap<TaskRef<'ctx>, Vec<TaskRef<'ctx>>>,
@@ -351,11 +365,13 @@ impl<'ctx> Driver<'ctx> {
Ok(())
}
- fn wait_for_task(&mut self) -> Result<bool> {
+ fn wait_for_task(&mut self, signal_fd: &mut SignalFd) -> Result<Option<TaskWaitResult>> {
let mut pollfds: Vec<_> = self
.tasks_running
.values()
- .map(|(socket, _)| poll::PollFd::new(socket.as_fd(), poll::PollFlags::POLLIN))
+ .map(|(socket, _)| socket.as_fd())
+ .chain(iter::once(signal_fd.as_fd()))
+ .map(|fd| poll::PollFd::new(fd, poll::PollFlags::POLLIN))
.collect();
while poll::poll(&mut pollfds, poll::PollTimeout::NONE).context("poll()")? == 0 {}
@@ -380,6 +396,11 @@ impl<'ctx> Driver<'ctx> {
continue;
}
+ if fd == signal_fd.as_raw_fd() {
+ let _signal = signal_fd.read_signal().expect("read_signal()").unwrap();
+ return Ok(Some(TaskWaitResult::Interrupted));
+ }
+
let (socket, task_ref) = self.tasks_running.remove(&fd).unwrap();
match Runner::result(&socket) {
@@ -388,12 +409,12 @@ impl<'ctx> Driver<'ctx> {
}
Err(error) => {
eprintln!("{}", error);
- return Ok(false);
+ return Ok(Some(TaskWaitResult::Failed));
}
}
}
- Ok(true)
+ Ok(None)
}
fn is_done(&self) -> bool {
@@ -402,25 +423,58 @@ impl<'ctx> Driver<'ctx> {
&& self.tasks_running.is_empty()
}
+ fn setup_signalfd() -> Result<SignalFd> {
+ let mut signals = signal::SigSet::empty();
+ signals.add(signal::Signal::SIGINT);
+ signal::pthread_sigmask(signal::SigmaskHow::SIG_BLOCK, Some(&signals), None)
+ .expect("pthread_sigmask()");
+ SignalFd::with_flags(&signals, SfdFlags::SFD_CLOEXEC)
+ .context("Failed to create signal file descriptor")
+ }
+
+ fn raise_sigint() {
+ let mut signals = signal::SigSet::empty();
+ signals.add(signal::Signal::SIGINT);
+ signal::pthread_sigmask(signal::SigmaskHow::SIG_UNBLOCK, Some(&signals), None)
+ .expect("pthread_sigmask()");
+ signal::raise(signal::Signal::SIGINT).expect("raise()");
+ unreachable!();
+ }
+
pub fn run(&mut self, runner: &Runner, keep_going: bool) -> Result<bool> {
let mut success = true;
+ let mut interrupted = false;
+
+ let mut signal_fd = Self::setup_signalfd()?;
self.run_tasks(runner)?;
while !self.tasks_running.is_empty() {
- if !self.wait_for_task()? {
- success = false;
+ match self.wait_for_task(&mut signal_fd)? {
+ Some(TaskWaitResult::Failed) => {
+ success = false;
+ }
+ Some(TaskWaitResult::Interrupted) => {
+ if interrupted {
+ Self::raise_sigint();
+ }
+ eprintln!("Interrupt received, not spawning new tasks. Interrupt again to stop immediately.");
+ interrupted = true;
+ }
+ None => {}
}
- if success || keep_going {
+ if !interrupted && (success || keep_going) {
self.run_tasks(runner)?;
}
}
- if success {
- assert!(self.is_done(), "No runnable tasks left");
- self.state.print_summary();
+ if interrupted || !success {
+ return Ok(false);
}
- Ok(success)
+ assert!(self.is_done(), "No runnable tasks left");
+ self.state.print_summary();
+
+ Ok(true)
}
}