package jrummikub.ai.fdsolver; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; import static jrummikub.ai.fdsolver.Satisfiability.*; public class Solver { Set> vars = new HashSet>(); Set> dirtyVars = new HashSet>(); Set> unsolvedVars = new HashSet>(); Set constraints = new HashSet(); ArrayList stack = new ArrayList(); boolean contradiction = false; static private class StackFrame { Map, HashSet> invalidatedValues = new HashMap, HashSet>(); Var branchVar; Object branchValue; HashSet finishedConstraints = new HashSet(); public void invalidate(Var var, T invalid) { HashSet values = (HashSet) invalidatedValues.get(var); if (values == null) { invalidatedValues.put(var, new HashSet(Arrays.asList(invalid))); } else { values.add(invalid); } } @Override public String toString() { return "StackItem [invalidatedValues=" + invalidatedValues + "]"; } } public Solver() { stack.add(new StackFrame()); } private boolean isSolved() { return dirtyVars.isEmpty() && unsolvedVars.isEmpty(); } public boolean solve() { do { solveStep(); } while (!(contradiction || isSolved())); return !contradiction; } public void solveStep() { if (isSolved()) { contradiction = true; } else { propagateAll(); } if (contradiction) { if (stack.size() == 1) { return; } backtrack(); } else if (unsolvedVars.isEmpty()) { return; } else { branch(); } } public void propagateOnce() { Var dirtyVar = Collections.max(dirtyVars); dirtyVars.remove(dirtyVar); outerLoop: for (Iterator i = dirtyVar.getConstraints().iterator(); i.hasNext();) { Constraint constraint = i.next(); Satisfiability sat = constraint.getSatisfiability(); if (sat == UNSAT) { contradiction = true; break; } if (sat == TAUT) { i.remove(); finishedConstraint(constraint, dirtyVar); continue; } for (Propagator propagator : constraint.getPropagators(false)) { if (propagator.getWatchedVars().contains(dirtyVar)) { propagator.propagate(); if (contradiction) { break outerLoop; } } } } } public void propagateAll() { while (!(dirtyVars.isEmpty() || contradiction)) { propagateOnce(); } } public void addVar(Var var) { vars.add(var); if (var.getRange().size() != 1) { unsolvedVars.add(var); } } public void addConstraint(Constraint constraint) { constraints.add(constraint); for (Var var : constraint.getWatchedVars()) { var.makeDirty(); var.getConstraints().add(constraint); } } // backtracking and logging void branch() { branchOn(unsolvedVars.iterator().next()); } @SuppressWarnings("unchecked") void branchOn(Var var) { Set range = var.getRange(); int n = (int)(Math.random() * range.size()); Iterator it = range.iterator(); for (int i = 0; i < n; i++) { it.next(); } Object value = it.next(); branchWith((Var) var, value); } void branchWith(Var var, T value) { StackFrame stackFrame = new StackFrame(); stackFrame.branchVar = var; stackFrame.branchValue = value; stack.add(stackFrame); var.choose(value); } private StackFrame getTopStackFrame() { return stack.get(stack.size() - 1); } void logInvalidation(Var var, T invalid) { getTopStackFrame().invalidate(var, invalid); if (var.getRange().size() == 1) { unsolvedVars.remove(var); } } private void finishedConstraint(Constraint constraint, Var currentVar) { for (Var var : constraint.getWatchedVars()) { if (var != currentVar) { var.getConstraints().remove(constraint); } } constraints.remove(constraint); getTopStackFrame().finishedConstraints.add(constraint); } @SuppressWarnings("unchecked") // This would need rank-2 types which java lacks private void rollback(StackFrame item) { for (Map.Entry, HashSet> entry : item.invalidatedValues .entrySet()) { Var var = (Var) entry.getKey(); HashSet values = (HashSet) entry.getValue(); var.getRange().addAll(values); if (var.getRange().size() != 1) { unsolvedVars.add(var); } else { unsolvedVars.remove(var); } var.makeDirty(); // TODO think a bit more about this } for (Constraint constraint : item.finishedConstraints) { for (Var var : constraint.getWatchedVars()) { var.getConstraints().add(constraint); } constraints.add(constraint); } } @SuppressWarnings("unchecked") void backtrack() { contradiction = false; StackFrame topFrame = getTopStackFrame(); rollback(topFrame); stack.remove(stack.size() - 1); ((Var) topFrame.branchVar).invalidate(topFrame.branchValue); } // factory methods for vars public Var makeRangeVar(int low, int high) { ArrayList range = new ArrayList(); for (int i = low; i <= high; i++) { range.add(i); } return makeVar(range); } public Var makeBoolVar() { return makeVar(true, false); } public Var makeVar(Collection range) { Var var = new Var(this, range); addVar(var); return var; } public Var makeVar(T... range) { return makeVar(Arrays.asList(range)); } public void record() { for (Var var : vars) { var.record(); } } public void restore() { for (Var var : vars) { var.restore(); } } }