diff options
-rw-r--r-- | src/jrummikub/ai/fdsolver/Solver.java | 81 | ||||
-rw-r--r-- | src/jrummikub/ai/fdsolver/Var.java | 40 | ||||
-rw-r--r-- | test/jrummikub/ai/fdsolver/SolverTest.java | 10 |
3 files changed, 106 insertions, 25 deletions
diff --git a/src/jrummikub/ai/fdsolver/Solver.java b/src/jrummikub/ai/fdsolver/Solver.java index 6660467..c47b1c1 100644 --- a/src/jrummikub/ai/fdsolver/Solver.java +++ b/src/jrummikub/ai/fdsolver/Solver.java @@ -3,11 +3,13 @@ package jrummikub.ai.fdsolver; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; +import static jrummikub.ai.fdsolver.Satisfiability.*; public class Solver { Set<Var<?>> vars = new HashSet<Var<?>>(); @@ -26,6 +28,7 @@ public class Solver { Map<Var<?>, HashSet<?>> invalidatedValues = new HashMap<Var<?>, HashSet<?>>(); Var<?> branchVar; Object branchValue; + HashSet<Constraint> finishedConstraints = new HashSet<Constraint>(); public <T> void invalidate(Var<T> var, T invalid) { HashSet<T> values = (HashSet<T>) invalidatedValues.get(var); @@ -47,16 +50,19 @@ public class Solver { stack.add(new StackFrame()); } + private boolean isSolved() { + return dirtyVars.isEmpty() && unsolvedVars.isEmpty(); + } + public boolean solve() { do { solveStep(); - } while (!(contradiction || (dirtyVars.isEmpty() && unsolvedVars - .isEmpty()))); + } while (!(contradiction || isSolved())); return !contradiction; } public void solveStep() { - if (unsolvedVars.isEmpty() && dirtyVars.isEmpty()) { + if (isSolved()) { contradiction = true; } else { propagateAll(); @@ -74,17 +80,25 @@ public class Solver { } public void propagateOnce() { - Iterator<Var<?>> it = dirtyVars.iterator(); - Var<?> dirtyVar = it.next(); - it.remove(); - outerLoop: for (Constraint constraint : constraints) { - if (constraint.getWatchedVars().contains(dirtyVar)) { - for (Propagator propagator : constraint.getPropagators(false)) { - if (propagator.getWatchedVars().contains(dirtyVar)) { - propagator.propagate(); - if (contradiction) { - break outerLoop; - } + Var<?> dirtyVar = Collections.max(dirtyVars); + dirtyVars.remove(dirtyVar); + outerLoop: for (Iterator<Constraint> i = dirtyVar.getConstraints().iterator(); i.hasNext();) { + Constraint constraint = i.next(); + Satisfiability sat = constraint.getSatisfiability(); + if (sat == UNSAT) { + contradiction = true; + break; + } + if (sat == TAUT) { + i.remove(); + finishedConstraint(constraint, dirtyVar); + continue; + } + for (Propagator propagator : constraint.getPropagators(false)) { + if (propagator.getWatchedVars().contains(dirtyVar)) { + propagator.propagate(); + if (contradiction) { + break outerLoop; } } } @@ -109,6 +123,7 @@ public class Solver { for (Var<?> var : constraint.getWatchedVars()) { var.makeDirty(); + var.getConstraints().add(constraint); } } @@ -120,7 +135,14 @@ public class Solver { @SuppressWarnings("unchecked") void branchOn(Var<?> var) { - Object value = var.getRange().iterator().next(); + + Set<?> range = var.getRange(); + int n = (int)(Math.random() * range.size()); + Iterator<?> it = range.iterator(); + for (int i = 0; i < n; i++) { + it.next(); + } + Object value = it.next(); branchWith((Var<Object>) var, value); } @@ -143,6 +165,16 @@ public class Solver { } } + private void finishedConstraint(Constraint constraint, Var<?> currentVar) { + for (Var<?> var : constraint.getWatchedVars()) { + if (var != currentVar) { + var.getConstraints().remove(constraint); + } + } + constraints.remove(constraint); + getTopStackFrame().finishedConstraints.add(constraint); + } + @SuppressWarnings("unchecked") // This would need rank-2 types which java lacks private void rollback(StackFrame item) { @@ -156,6 +188,13 @@ public class Solver { } else { unsolvedVars.remove(var); } + var.makeDirty(); // TODO think a bit more about this + } + for (Constraint constraint : item.finishedConstraints) { + for (Var<?> var : constraint.getWatchedVars()) { + var.getConstraints().add(constraint); + } + constraints.add(constraint); } } @@ -191,5 +230,17 @@ public class Solver { public <T> Var<T> makeVar(T... range) { return makeVar(Arrays.asList(range)); } + + public void record() { + for (Var<?> var : vars) { + var.record(); + } + } + + public void restore() { + for (Var<?> var : vars) { + var.restore(); + } + } } diff --git a/src/jrummikub/ai/fdsolver/Var.java b/src/jrummikub/ai/fdsolver/Var.java index eb6f432..e023f34 100644 --- a/src/jrummikub/ai/fdsolver/Var.java +++ b/src/jrummikub/ai/fdsolver/Var.java @@ -4,13 +4,16 @@ import java.util.Collection; import java.util.HashSet; import java.util.Iterator; -public class Var<T> { +public class Var<T> implements Comparable<Var<T>> { private Solver solver; private HashSet<T> range; + private HashSet<Constraint> constraints; + private T recorded; public Var(Solver solver, Collection<T> range) { this.solver = solver; this.range = new HashSet<T>(range); + constraints = new HashSet<Constraint>(); } public T getValue() { @@ -23,7 +26,7 @@ public class Var<T> { return range; } - public void choose(T value) { + void choose(T value) { for (Iterator<T> i = this.iterator(); i.hasNext();) { if (i.next() != value) { i.remove(); @@ -31,15 +34,19 @@ public class Var<T> { } } - public void makeDirty() { + void makeDirty() { this.solver.dirtyVars.add(this); } - public void invalidate(T value) { + void invalidate(T value) { range.remove(value); solver.logInvalidation(this, value); makeDirty(); } + + HashSet<Constraint> getConstraints() { + return constraints; + } public Iterator<T> iterator() { final Iterator<T> iterator = range.iterator(); @@ -69,10 +76,35 @@ public class Var<T> { } }; } + + void record() { + recorded = getValue(); + } + + void restore() { + range.clear(); + range.add(recorded); + } @Override public String toString() { return "Var" + range; } + + private int neighborCount() { + int count = 0; + for (Constraint constraint : constraints) { + count += constraint.getWatchedVars().size(); + } + return count; + } + + @Override + public int compareTo(Var<T> other) { + int rangeCompare = ((Integer)range.size()).compareTo(other.range.size()); + if (rangeCompare != 0) + return rangeCompare; + return ((Integer)neighborCount()).compareTo(other.neighborCount()); + } } diff --git a/test/jrummikub/ai/fdsolver/SolverTest.java b/test/jrummikub/ai/fdsolver/SolverTest.java index 535ef57..ae32b07 100644 --- a/test/jrummikub/ai/fdsolver/SolverTest.java +++ b/test/jrummikub/ai/fdsolver/SolverTest.java @@ -16,14 +16,12 @@ public class SolverTest { solver.addConstraint(new LessThan<Integer>(false, y, x)); - int lastx = 0, lasty = 0; while (solver.solve()) { - lastx = x.getValue(); - lasty = y.getValue(); + solver.record(); solver.addConstraint(new LessThanConst<Integer>(false, x, x.getValue())); } - - assertEquals(2, lastx); - assertEquals(1, lasty); + solver.restore(); + assertEquals(2, (int)x.getValue()); + assertEquals(1, (int)y.getValue()); } } |