Finished solver core

git-svn-id: svn://sunsvr01.isp.uni-luebeck.de/swproj13/trunk@428 72836036-5685-4462-b002-a69064685172
This commit is contained in:
Jannis Harder 2011-06-14 00:10:38 +02:00
parent d223eaaa98
commit fc07d3bca6
3 changed files with 106 additions and 25 deletions

View file

@ -3,11 +3,13 @@ package jrummikub.ai.fdsolver;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import static jrummikub.ai.fdsolver.Satisfiability.*;
public class Solver {
Set<Var<?>> vars = new HashSet<Var<?>>();
@ -26,6 +28,7 @@ public class Solver {
Map<Var<?>, HashSet<?>> invalidatedValues = new HashMap<Var<?>, HashSet<?>>();
Var<?> branchVar;
Object branchValue;
HashSet<Constraint> finishedConstraints = new HashSet<Constraint>();
public <T> void invalidate(Var<T> var, T invalid) {
HashSet<T> values = (HashSet<T>) invalidatedValues.get(var);
@ -47,16 +50,19 @@ public class Solver {
stack.add(new StackFrame());
}
private boolean isSolved() {
return dirtyVars.isEmpty() && unsolvedVars.isEmpty();
}
public boolean solve() {
do {
solveStep();
} while (!(contradiction || (dirtyVars.isEmpty() && unsolvedVars
.isEmpty())));
} while (!(contradiction || isSolved()));
return !contradiction;
}
public void solveStep() {
if (unsolvedVars.isEmpty() && dirtyVars.isEmpty()) {
if (isSolved()) {
contradiction = true;
} else {
propagateAll();
@ -74,17 +80,25 @@ public class Solver {
}
public void propagateOnce() {
Iterator<Var<?>> it = dirtyVars.iterator();
Var<?> dirtyVar = it.next();
it.remove();
outerLoop: for (Constraint constraint : constraints) {
if (constraint.getWatchedVars().contains(dirtyVar)) {
for (Propagator propagator : constraint.getPropagators(false)) {
if (propagator.getWatchedVars().contains(dirtyVar)) {
propagator.propagate();
if (contradiction) {
break outerLoop;
}
Var<?> dirtyVar = Collections.max(dirtyVars);
dirtyVars.remove(dirtyVar);
outerLoop: for (Iterator<Constraint> i = dirtyVar.getConstraints().iterator(); i.hasNext();) {
Constraint constraint = i.next();
Satisfiability sat = constraint.getSatisfiability();
if (sat == UNSAT) {
contradiction = true;
break;
}
if (sat == TAUT) {
i.remove();
finishedConstraint(constraint, dirtyVar);
continue;
}
for (Propagator propagator : constraint.getPropagators(false)) {
if (propagator.getWatchedVars().contains(dirtyVar)) {
propagator.propagate();
if (contradiction) {
break outerLoop;
}
}
}
@ -109,6 +123,7 @@ public class Solver {
for (Var<?> var : constraint.getWatchedVars()) {
var.makeDirty();
var.getConstraints().add(constraint);
}
}
@ -120,7 +135,14 @@ public class Solver {
@SuppressWarnings("unchecked")
void branchOn(Var<?> var) {
Object value = var.getRange().iterator().next();
Set<?> range = var.getRange();
int n = (int)(Math.random() * range.size());
Iterator<?> it = range.iterator();
for (int i = 0; i < n; i++) {
it.next();
}
Object value = it.next();
branchWith((Var<Object>) var, value);
}
@ -143,6 +165,16 @@ public class Solver {
}
}
private void finishedConstraint(Constraint constraint, Var<?> currentVar) {
for (Var<?> var : constraint.getWatchedVars()) {
if (var != currentVar) {
var.getConstraints().remove(constraint);
}
}
constraints.remove(constraint);
getTopStackFrame().finishedConstraints.add(constraint);
}
@SuppressWarnings("unchecked")
// This would need rank-2 types which java lacks
private void rollback(StackFrame item) {
@ -156,6 +188,13 @@ public class Solver {
} else {
unsolvedVars.remove(var);
}
var.makeDirty(); // TODO think a bit more about this
}
for (Constraint constraint : item.finishedConstraints) {
for (Var<?> var : constraint.getWatchedVars()) {
var.getConstraints().add(constraint);
}
constraints.add(constraint);
}
}
@ -191,5 +230,17 @@ public class Solver {
public <T> Var<T> makeVar(T... range) {
return makeVar(Arrays.asList(range));
}
public void record() {
for (Var<?> var : vars) {
var.record();
}
}
public void restore() {
for (Var<?> var : vars) {
var.restore();
}
}
}

View file

@ -4,13 +4,16 @@ import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
public class Var<T> {
public class Var<T> implements Comparable<Var<T>> {
private Solver solver;
private HashSet<T> range;
private HashSet<Constraint> constraints;
private T recorded;
public Var(Solver solver, Collection<T> range) {
this.solver = solver;
this.range = new HashSet<T>(range);
constraints = new HashSet<Constraint>();
}
public T getValue() {
@ -23,7 +26,7 @@ public class Var<T> {
return range;
}
public void choose(T value) {
void choose(T value) {
for (Iterator<T> i = this.iterator(); i.hasNext();) {
if (i.next() != value) {
i.remove();
@ -31,15 +34,19 @@ public class Var<T> {
}
}
public void makeDirty() {
void makeDirty() {
this.solver.dirtyVars.add(this);
}
public void invalidate(T value) {
void invalidate(T value) {
range.remove(value);
solver.logInvalidation(this, value);
makeDirty();
}
HashSet<Constraint> getConstraints() {
return constraints;
}
public Iterator<T> iterator() {
final Iterator<T> iterator = range.iterator();
@ -69,10 +76,35 @@ public class Var<T> {
}
};
}
void record() {
recorded = getValue();
}
void restore() {
range.clear();
range.add(recorded);
}
@Override
public String toString() {
return "Var" + range;
}
private int neighborCount() {
int count = 0;
for (Constraint constraint : constraints) {
count += constraint.getWatchedVars().size();
}
return count;
}
@Override
public int compareTo(Var<T> other) {
int rangeCompare = ((Integer)range.size()).compareTo(other.range.size());
if (rangeCompare != 0)
return rangeCompare;
return ((Integer)neighborCount()).compareTo(other.neighborCount());
}
}

View file

@ -16,14 +16,12 @@ public class SolverTest {
solver.addConstraint(new LessThan<Integer>(false, y, x));
int lastx = 0, lasty = 0;
while (solver.solve()) {
lastx = x.getValue();
lasty = y.getValue();
solver.record();
solver.addConstraint(new LessThanConst<Integer>(false, x, x.getValue()));
}
assertEquals(2, lastx);
assertEquals(1, lasty);
solver.restore();
assertEquals(2, (int)x.getValue());
assertEquals(1, (int)y.getValue());
}
}