summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/jrummikub/ai/fdsolver/Solver.java81
-rw-r--r--src/jrummikub/ai/fdsolver/Var.java40
-rw-r--r--test/jrummikub/ai/fdsolver/SolverTest.java10
3 files changed, 106 insertions, 25 deletions
diff --git a/src/jrummikub/ai/fdsolver/Solver.java b/src/jrummikub/ai/fdsolver/Solver.java
index 6660467..c47b1c1 100644
--- a/src/jrummikub/ai/fdsolver/Solver.java
+++ b/src/jrummikub/ai/fdsolver/Solver.java
@@ -3,11 +3,13 @@ package jrummikub.ai.fdsolver;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
+import static jrummikub.ai.fdsolver.Satisfiability.*;
public class Solver {
Set<Var<?>> vars = new HashSet<Var<?>>();
@@ -26,6 +28,7 @@ public class Solver {
Map<Var<?>, HashSet<?>> invalidatedValues = new HashMap<Var<?>, HashSet<?>>();
Var<?> branchVar;
Object branchValue;
+ HashSet<Constraint> finishedConstraints = new HashSet<Constraint>();
public <T> void invalidate(Var<T> var, T invalid) {
HashSet<T> values = (HashSet<T>) invalidatedValues.get(var);
@@ -47,16 +50,19 @@ public class Solver {
stack.add(new StackFrame());
}
+ private boolean isSolved() {
+ return dirtyVars.isEmpty() && unsolvedVars.isEmpty();
+ }
+
public boolean solve() {
do {
solveStep();
- } while (!(contradiction || (dirtyVars.isEmpty() && unsolvedVars
- .isEmpty())));
+ } while (!(contradiction || isSolved()));
return !contradiction;
}
public void solveStep() {
- if (unsolvedVars.isEmpty() && dirtyVars.isEmpty()) {
+ if (isSolved()) {
contradiction = true;
} else {
propagateAll();
@@ -74,17 +80,25 @@ public class Solver {
}
public void propagateOnce() {
- Iterator<Var<?>> it = dirtyVars.iterator();
- Var<?> dirtyVar = it.next();
- it.remove();
- outerLoop: for (Constraint constraint : constraints) {
- if (constraint.getWatchedVars().contains(dirtyVar)) {
- for (Propagator propagator : constraint.getPropagators(false)) {
- if (propagator.getWatchedVars().contains(dirtyVar)) {
- propagator.propagate();
- if (contradiction) {
- break outerLoop;
- }
+ Var<?> dirtyVar = Collections.max(dirtyVars);
+ dirtyVars.remove(dirtyVar);
+ outerLoop: for (Iterator<Constraint> i = dirtyVar.getConstraints().iterator(); i.hasNext();) {
+ Constraint constraint = i.next();
+ Satisfiability sat = constraint.getSatisfiability();
+ if (sat == UNSAT) {
+ contradiction = true;
+ break;
+ }
+ if (sat == TAUT) {
+ i.remove();
+ finishedConstraint(constraint, dirtyVar);
+ continue;
+ }
+ for (Propagator propagator : constraint.getPropagators(false)) {
+ if (propagator.getWatchedVars().contains(dirtyVar)) {
+ propagator.propagate();
+ if (contradiction) {
+ break outerLoop;
}
}
}
@@ -109,6 +123,7 @@ public class Solver {
for (Var<?> var : constraint.getWatchedVars()) {
var.makeDirty();
+ var.getConstraints().add(constraint);
}
}
@@ -120,7 +135,14 @@ public class Solver {
@SuppressWarnings("unchecked")
void branchOn(Var<?> var) {
- Object value = var.getRange().iterator().next();
+
+ Set<?> range = var.getRange();
+ int n = (int)(Math.random() * range.size());
+ Iterator<?> it = range.iterator();
+ for (int i = 0; i < n; i++) {
+ it.next();
+ }
+ Object value = it.next();
branchWith((Var<Object>) var, value);
}
@@ -143,6 +165,16 @@ public class Solver {
}
}
+ private void finishedConstraint(Constraint constraint, Var<?> currentVar) {
+ for (Var<?> var : constraint.getWatchedVars()) {
+ if (var != currentVar) {
+ var.getConstraints().remove(constraint);
+ }
+ }
+ constraints.remove(constraint);
+ getTopStackFrame().finishedConstraints.add(constraint);
+ }
+
@SuppressWarnings("unchecked")
// This would need rank-2 types which java lacks
private void rollback(StackFrame item) {
@@ -156,6 +188,13 @@ public class Solver {
} else {
unsolvedVars.remove(var);
}
+ var.makeDirty(); // TODO think a bit more about this
+ }
+ for (Constraint constraint : item.finishedConstraints) {
+ for (Var<?> var : constraint.getWatchedVars()) {
+ var.getConstraints().add(constraint);
+ }
+ constraints.add(constraint);
}
}
@@ -191,5 +230,17 @@ public class Solver {
public <T> Var<T> makeVar(T... range) {
return makeVar(Arrays.asList(range));
}
+
+ public void record() {
+ for (Var<?> var : vars) {
+ var.record();
+ }
+ }
+
+ public void restore() {
+ for (Var<?> var : vars) {
+ var.restore();
+ }
+ }
}
diff --git a/src/jrummikub/ai/fdsolver/Var.java b/src/jrummikub/ai/fdsolver/Var.java
index eb6f432..e023f34 100644
--- a/src/jrummikub/ai/fdsolver/Var.java
+++ b/src/jrummikub/ai/fdsolver/Var.java
@@ -4,13 +4,16 @@ import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
-public class Var<T> {
+public class Var<T> implements Comparable<Var<T>> {
private Solver solver;
private HashSet<T> range;
+ private HashSet<Constraint> constraints;
+ private T recorded;
public Var(Solver solver, Collection<T> range) {
this.solver = solver;
this.range = new HashSet<T>(range);
+ constraints = new HashSet<Constraint>();
}
public T getValue() {
@@ -23,7 +26,7 @@ public class Var<T> {
return range;
}
- public void choose(T value) {
+ void choose(T value) {
for (Iterator<T> i = this.iterator(); i.hasNext();) {
if (i.next() != value) {
i.remove();
@@ -31,15 +34,19 @@ public class Var<T> {
}
}
- public void makeDirty() {
+ void makeDirty() {
this.solver.dirtyVars.add(this);
}
- public void invalidate(T value) {
+ void invalidate(T value) {
range.remove(value);
solver.logInvalidation(this, value);
makeDirty();
}
+
+ HashSet<Constraint> getConstraints() {
+ return constraints;
+ }
public Iterator<T> iterator() {
final Iterator<T> iterator = range.iterator();
@@ -69,10 +76,35 @@ public class Var<T> {
}
};
}
+
+ void record() {
+ recorded = getValue();
+ }
+
+ void restore() {
+ range.clear();
+ range.add(recorded);
+ }
@Override
public String toString() {
return "Var" + range;
}
+
+ private int neighborCount() {
+ int count = 0;
+ for (Constraint constraint : constraints) {
+ count += constraint.getWatchedVars().size();
+ }
+ return count;
+ }
+
+ @Override
+ public int compareTo(Var<T> other) {
+ int rangeCompare = ((Integer)range.size()).compareTo(other.range.size());
+ if (rangeCompare != 0)
+ return rangeCompare;
+ return ((Integer)neighborCount()).compareTo(other.neighborCount());
+ }
}
diff --git a/test/jrummikub/ai/fdsolver/SolverTest.java b/test/jrummikub/ai/fdsolver/SolverTest.java
index 535ef57..ae32b07 100644
--- a/test/jrummikub/ai/fdsolver/SolverTest.java
+++ b/test/jrummikub/ai/fdsolver/SolverTest.java
@@ -16,14 +16,12 @@ public class SolverTest {
solver.addConstraint(new LessThan<Integer>(false, y, x));
- int lastx = 0, lasty = 0;
while (solver.solve()) {
- lastx = x.getValue();
- lasty = y.getValue();
+ solver.record();
solver.addConstraint(new LessThanConst<Integer>(false, x, x.getValue()));
}
-
- assertEquals(2, lastx);
- assertEquals(1, lasty);
+ solver.restore();
+ assertEquals(2, (int)x.getValue());
+ assertEquals(1, (int)y.getValue());
}
}