summaryrefslogtreecommitdiffstats
path: root/src
diff options
context:
space:
mode:
authorMatthias Schiffer <mschiffer@universe-factory.net>2021-01-26 22:04:18 +0100
committerMatthias Schiffer <mschiffer@universe-factory.net>2021-01-26 22:04:18 +0100
commite48108cdef9555565d0715f4b6f39228cfc45376 (patch)
treecc4d1bf88a2a7d407dda841108ea30778bece53d /src
parent9e60d16555765a6cfc053798f56e5b914ea1834e (diff)
downloadrebel-e48108cdef9555565d0715f4b6f39228cfc45376.tar
rebel-e48108cdef9555565d0715f4b6f39228cfc45376.zip
Rewrite dependency resolution to reuse solutions
Diffstat (limited to 'src')
-rw-r--r--src/main.rs84
-rw-r--r--src/recipe.rs10
-rw-r--r--src/resolve.rs91
-rw-r--r--src/types.rs20
4 files changed, 130 insertions, 75 deletions
diff --git a/src/main.rs b/src/main.rs
index 8c5708f..b572afc 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -1,88 +1,38 @@
-use im::{HashMap, HashSet};
-use std::{fmt, path::Path, rc::Rc};
+use std::collections::HashSet;
+use std::path::Path;
mod recipe;
+mod resolve;
+mod types;
-#[derive(Debug)]
-enum Error {
- TaskNotFound(String),
- DependencyCycle(String),
-}
-
-impl fmt::Display for Error {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- match self {
- Error::TaskNotFound(id) => write!(f, "Task not found: {}", id),
- Error::DependencyCycle(id) => write!(f, "DependencyCycle: {}", id),
- }
- }
-}
-
-impl std::error::Error for Error {}
-
-type Result<T> = std::result::Result<T, Error>;
-
-#[derive(Default)]
-struct Context {
- tasks: HashMap<String, Rc<recipe::Task>>,
-}
-
-impl Context {
- fn collect_subtasks<'a>(
- &'a self,
- queued: &HashSet<&'a str>,
- id: &str,
- ) -> Result<HashSet<&'a str>> {
- let (task_id, task) = match self.tasks.get_key_value(id) {
- Some(t) => t,
- None => {
- return Err(Error::TaskNotFound(id.to_string()));
- }
- };
-
- if queued.contains(id) {
- return Err(Error::DependencyCycle(id.to_string()));
- }
-
- let queued_sub = queued.update(task_id);
-
- let mut subtasks: HashSet<&'a str> = HashSet::new();
- subtasks.insert(task_id);
-
- for dep in &task.depends {
- let deptasks = self.collect_subtasks(&queued_sub, dep)?;
- subtasks = subtasks.union(deptasks);
- }
-
- Ok(subtasks)
- }
-
- fn collect_tasks(&self, id: &str) -> Result<HashSet<&str>> {
- self.collect_subtasks(&HashSet::new(), id)
- }
-}
+use resolve::Result;
+use types::*;
fn main() -> Result<()> {
let recipes = recipe::read_recipes(Path::new("examples")).unwrap();
- let mut ctx = Context::default();
+ let mut tasks: TaskMap = TaskMap::default();
for (recipe_name, recipe) in recipes {
for (task_name, task) in recipe.tasks {
let full_name = format!("{}:{}", recipe_name, task_name);
- ctx.tasks.insert(full_name, Rc::new(task));
+ tasks.0.insert(full_name, task);
}
}
- let queue = ctx.collect_tasks("ls:build")?;
- let (runnable, queued): (HashSet<&str>, HashSet<&str>) = queue
+ let mut rsv = resolve::Resolver::new(&tasks);
+
+ rsv.add_goal(&"ls:build".to_string())?;
+ let taskset = rsv.to_taskset();
+
+ let (runnable, queued): (HashSet<TaskRef>, HashSet<TaskRef>) = taskset
.into_iter()
- .partition(|id| ctx.tasks.get(*id).unwrap().depends.is_empty());
+ .partition(|id| tasks.get(id).unwrap().depends.is_empty());
for t in &runnable {
- println!("Runnable: {} ({:?})", t, ctx.tasks.get(*t).unwrap().run);
+ println!("Runnable: {} ({:?})", t, tasks.get(t).unwrap().run);
}
for t in &queued {
- println!("Queued: {} ({:?})", t, ctx.tasks.get(*t).unwrap().run);
+ println!("Queued: {} ({:?})", t, tasks.get(t).unwrap().run);
}
Ok(())
diff --git a/src/recipe.rs b/src/recipe.rs
index d06b62b..20bca6d 100644
--- a/src/recipe.rs
+++ b/src/recipe.rs
@@ -1,14 +1,8 @@
-use im::HashMap;
use serde::Deserialize;
-use std::{fmt, fs::File, io, path::Path};
+use std::{collections::HashMap, fmt, fs::File, io, path::Path};
use walkdir::WalkDir;
-#[derive(Clone, Debug, Deserialize)]
-pub struct Task {
- #[serde(default)]
- pub depends: Vec<String>,
- pub run: String,
-}
+use crate::types::*;
#[derive(Clone, Debug, Deserialize)]
pub struct Recipe {
diff --git a/src/resolve.rs b/src/resolve.rs
new file mode 100644
index 0000000..0758016
--- /dev/null
+++ b/src/resolve.rs
@@ -0,0 +1,91 @@
+use std::collections::{HashMap, HashSet};
+use std::fmt;
+
+use crate::types::*;
+
+#[derive(Debug)]
+pub enum Error {
+ TaskNotFound(TaskRef),
+ DependencyCycle(TaskRef),
+}
+
+impl fmt::Display for Error {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ Error::TaskNotFound(id) => write!(f, "Task not found: {}", id),
+ Error::DependencyCycle(id) => write!(f, "DependencyCycle: {}", id),
+ }
+ }
+}
+
+impl std::error::Error for Error {}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(PartialEq)]
+enum ResolveState {
+ Resolving,
+ Resolved,
+}
+
+pub struct Resolver<'a> {
+ tasks: &'a TaskMap,
+ resolve_state: HashMap<TaskRef, ResolveState>,
+}
+
+impl<'a> Resolver<'a> {
+ pub fn new(tasks: &'a TaskMap) -> Self {
+ Resolver {
+ tasks: tasks,
+ resolve_state: HashMap::new(),
+ }
+ }
+
+ pub fn add_goal(&mut self, task: &TaskRef) -> Result<()> {
+ match self.resolve_state.get(task) {
+ Some(ResolveState::Resolving) => return Err(Error::DependencyCycle(task.clone())),
+ Some(ResolveState::Resolved) => return Ok(()),
+ None => (),
+ }
+
+ let task_def = match self.tasks.get(task) {
+ None => return Err(Error::TaskNotFound(task.clone())),
+ Some(task_def) => task_def,
+ };
+
+ self.resolve_state
+ .insert(task.clone(), ResolveState::Resolving);
+
+ for dep in &task_def.depends {
+ let res = self.add_goal(dep);
+ if res.is_err() {
+ self.resolve_state.remove(task);
+ return res;
+ }
+ }
+
+ *self
+ .resolve_state
+ .get_mut(task)
+ .expect("Missing resolve_state") = ResolveState::Resolved;
+
+ Ok(())
+ }
+
+ pub fn to_taskset(self) -> HashSet<TaskRef> {
+ fn tasks_resolved(this: &Resolver) -> bool {
+ for (_, resolved) in &this.resolve_state {
+ if *resolved != ResolveState::Resolved {
+ return false;
+ }
+ }
+ true
+ }
+ debug_assert!(tasks_resolved(&self));
+
+ self.resolve_state
+ .into_iter()
+ .map(|entry| entry.0)
+ .collect()
+ }
+}
diff --git a/src/types.rs b/src/types.rs
new file mode 100644
index 0000000..cb777f7
--- /dev/null
+++ b/src/types.rs
@@ -0,0 +1,20 @@
+use serde::Deserialize;
+use std::collections::HashMap;
+
+pub type TaskRef = String;
+
+#[derive(Clone, Debug, Deserialize)]
+pub struct Task {
+ #[serde(default)]
+ pub depends: Vec<TaskRef>,
+ pub run: String,
+}
+
+#[derive(Default)]
+pub struct TaskMap(pub HashMap<String, Task>);
+
+impl TaskMap {
+ pub fn get(&self, task: &TaskRef) -> Option<&Task> {
+ self.0.get(task)
+ }
+}