summaryrefslogtreecommitdiffstats
path: root/src/jrummikub/ai/fdsolver/Solver.java
diff options
context:
space:
mode:
Diffstat (limited to 'src/jrummikub/ai/fdsolver/Solver.java')
-rw-r--r--src/jrummikub/ai/fdsolver/Solver.java192
1 files changed, 181 insertions, 11 deletions
diff --git a/src/jrummikub/ai/fdsolver/Solver.java b/src/jrummikub/ai/fdsolver/Solver.java
index 1164c36..6660467 100644
--- a/src/jrummikub/ai/fdsolver/Solver.java
+++ b/src/jrummikub/ai/fdsolver/Solver.java
@@ -1,25 +1,195 @@
package jrummikub.ai.fdsolver;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.Set;
+
public class Solver {
+ Set<Var<?>> vars = new HashSet<Var<?>>();
+
+ Set<Var<?>> dirtyVars = new HashSet<Var<?>>();
+
+ Set<Var<?>> unsolvedVars = new HashSet<Var<?>>();
+
+ Set<Constraint> constraints = new HashSet<Constraint>();
+
+ ArrayList<StackFrame> stack = new ArrayList<StackFrame>();
+
+ boolean contradiction = false;
- public void add(Constraint constraint) {
- // TODO Auto-generated method stub
-
+ static private class StackFrame {
+ Map<Var<?>, HashSet<?>> invalidatedValues = new HashMap<Var<?>, HashSet<?>>();
+ Var<?> branchVar;
+ Object branchValue;
+
+ public <T> void invalidate(Var<T> var, T invalid) {
+ HashSet<T> values = (HashSet<T>) invalidatedValues.get(var);
+ if (values == null) {
+ invalidatedValues.put(var,
+ new HashSet<T>(Arrays.asList(invalid)));
+ } else {
+ values.add(invalid);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return "StackItem [invalidatedValues=" + invalidatedValues + "]";
+ }
+ }
+
+ public Solver() {
+ stack.add(new StackFrame());
}
public boolean solve() {
- // TODO Auto-generated method stub
- return false;
+ do {
+ solveStep();
+ } while (!(contradiction || (dirtyVars.isEmpty() && unsolvedVars
+ .isEmpty())));
+ return !contradiction;
+ }
+
+ public void solveStep() {
+ if (unsolvedVars.isEmpty() && dirtyVars.isEmpty()) {
+ contradiction = true;
+ } else {
+ propagateAll();
+ }
+ if (contradiction) {
+ if (stack.size() == 1) {
+ return;
+ }
+ backtrack();
+ } else if (unsolvedVars.isEmpty()) {
+ return;
+ } else {
+ branch();
+ }
+ }
+
+ public void propagateOnce() {
+ Iterator<Var<?>> it = dirtyVars.iterator();
+ Var<?> dirtyVar = it.next();
+ it.remove();
+ outerLoop: for (Constraint constraint : constraints) {
+ if (constraint.getWatchedVars().contains(dirtyVar)) {
+ for (Propagator propagator : constraint.getPropagators(false)) {
+ if (propagator.getWatchedVars().contains(dirtyVar)) {
+ propagator.propagate();
+ if (contradiction) {
+ break outerLoop;
+ }
+ }
+ }
+ }
+ }
+ }
+
+ public void propagateAll() {
+ while (!(dirtyVars.isEmpty() || contradiction)) {
+ propagateOnce();
+ }
+ }
+
+ public void addVar(Var<?> var) {
+ vars.add(var);
+ if (var.getRange().size() != 1) {
+ unsolvedVars.add(var);
+ }
+ }
+
+ public void addConstraint(Constraint constraint) {
+ constraints.add(constraint);
+
+ for (Var<?> var : constraint.getWatchedVars()) {
+ var.makeDirty();
+ }
+ }
+
+ // backtracking and logging
+
+ void branch() {
+ branchOn(unsolvedVars.iterator().next());
+ }
+
+ @SuppressWarnings("unchecked")
+ void branchOn(Var<?> var) {
+ Object value = var.getRange().iterator().next();
+ branchWith((Var<Object>) var, value);
+ }
+
+ <T> void branchWith(Var<T> var, T value) {
+ StackFrame stackFrame = new StackFrame();
+ stackFrame.branchVar = var;
+ stackFrame.branchValue = value;
+ stack.add(stackFrame);
+ var.choose(value);
+ }
+
+ private StackFrame getTopStackFrame() {
+ return stack.get(stack.size() - 1);
+ }
+
+ <T> void logInvalidation(Var<T> var, T invalid) {
+ getTopStackFrame().invalidate(var, invalid);
+ if (var.getRange().size() == 1) {
+ unsolvedVars.remove(var);
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ // This would need rank-2 types which java lacks
+ private void rollback(StackFrame item) {
+ for (Map.Entry<Var<?>, HashSet<?>> entry : item.invalidatedValues
+ .entrySet()) {
+ Var<Object> var = (Var<Object>) entry.getKey();
+ HashSet<Object> values = (HashSet<Object>) entry.getValue();
+ var.getRange().addAll(values);
+ if (var.getRange().size() != 1) {
+ unsolvedVars.add(var);
+ } else {
+ unsolvedVars.remove(var);
+ }
+ }
+ }
+
+ @SuppressWarnings("unchecked")
+ void backtrack() {
+ contradiction = false;
+ StackFrame topFrame = getTopStackFrame();
+ rollback(topFrame);
+ stack.remove(stack.size() - 1);
+ ((Var<Object>) topFrame.branchVar).invalidate(topFrame.branchValue);
+ }
+
+ // factory methods for vars
+
+ public Var<Integer> makeRangeVar(int low, int high) {
+ ArrayList<Integer> range = new ArrayList<Integer>();
+ for (int i = low; i <= high; i++) {
+ range.add(i);
+ }
+ return makeVar(range);
+ }
+
+ public Var<Boolean> makeBoolVar() {
+ return makeVar(true, false);
}
- public void push() {
- // TODO Auto-generated method stub
-
+ public <T> Var<T> makeVar(Collection<T> range) {
+ Var<T> var = new Var<T>(this, range);
+ addVar(var);
+ return var;
}
- public void pop() {
- // TODO Auto-generated method stub
-
+ public <T> Var<T> makeVar(T... range) {
+ return makeVar(Arrays.asList(range));
}
}