package jrummikub.ai.fdsolver; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.Iterator; import java.util.Map; import java.util.Set; public class Solver { Set> vars = new HashSet>(); Set> dirtyVars = new HashSet>(); Set> unsolvedVars = new HashSet>(); Set constraints = new HashSet(); ArrayList stack = new ArrayList(); boolean contradiction = false; static private class StackFrame { Map, HashSet> invalidatedValues = new HashMap, HashSet>(); Var branchVar; Object branchValue; public void invalidate(Var var, T invalid) { HashSet values = (HashSet) invalidatedValues.get(var); if (values == null) { invalidatedValues.put(var, new HashSet(Arrays.asList(invalid))); } else { values.add(invalid); } } @Override public String toString() { return "StackItem [invalidatedValues=" + invalidatedValues + "]"; } } public Solver() { stack.add(new StackFrame()); } public boolean solve() { do { solveStep(); } while (!(contradiction || (dirtyVars.isEmpty() && unsolvedVars .isEmpty()))); return !contradiction; } public void solveStep() { if (unsolvedVars.isEmpty() && dirtyVars.isEmpty()) { contradiction = true; } else { propagateAll(); } if (contradiction) { if (stack.size() == 1) { return; } backtrack(); } else if (unsolvedVars.isEmpty()) { return; } else { branch(); } } public void propagateOnce() { Iterator> it = dirtyVars.iterator(); Var dirtyVar = it.next(); it.remove(); outerLoop: for (Constraint constraint : constraints) { if (constraint.getWatchedVars().contains(dirtyVar)) { for (Propagator propagator : constraint.getPropagators(false)) { if (propagator.getWatchedVars().contains(dirtyVar)) { propagator.propagate(); if (contradiction) { break outerLoop; } } } } } } public void propagateAll() { while (!(dirtyVars.isEmpty() || contradiction)) { propagateOnce(); } } public void addVar(Var var) { vars.add(var); if (var.getRange().size() != 1) { unsolvedVars.add(var); } } public void addConstraint(Constraint constraint) { constraints.add(constraint); for (Var var : constraint.getWatchedVars()) { var.makeDirty(); } } // backtracking and logging void branch() { branchOn(unsolvedVars.iterator().next()); } @SuppressWarnings("unchecked") void branchOn(Var var) { Object value = var.getRange().iterator().next(); branchWith((Var) var, value); } void branchWith(Var var, T value) { StackFrame stackFrame = new StackFrame(); stackFrame.branchVar = var; stackFrame.branchValue = value; stack.add(stackFrame); var.choose(value); } private StackFrame getTopStackFrame() { return stack.get(stack.size() - 1); } void logInvalidation(Var var, T invalid) { getTopStackFrame().invalidate(var, invalid); if (var.getRange().size() == 1) { unsolvedVars.remove(var); } } @SuppressWarnings("unchecked") // This would need rank-2 types which java lacks private void rollback(StackFrame item) { for (Map.Entry, HashSet> entry : item.invalidatedValues .entrySet()) { Var var = (Var) entry.getKey(); HashSet values = (HashSet) entry.getValue(); var.getRange().addAll(values); if (var.getRange().size() != 1) { unsolvedVars.add(var); } else { unsolvedVars.remove(var); } } } @SuppressWarnings("unchecked") void backtrack() { contradiction = false; StackFrame topFrame = getTopStackFrame(); rollback(topFrame); stack.remove(stack.size() - 1); ((Var) topFrame.branchVar).invalidate(topFrame.branchValue); } // factory methods for vars public Var makeRangeVar(int low, int high) { ArrayList range = new ArrayList(); for (int i = low; i <= high; i++) { range.add(i); } return makeVar(range); } public Var makeBoolVar() { return makeVar(true, false); } public Var makeVar(Collection range) { Var var = new Var(this, range); addVar(var); return var; } public Var makeVar(T... range) { return makeVar(Arrays.asList(range)); } }