diff options
Diffstat (limited to 'src/jrummikub/ai/fdsolver/Solver.java')
-rw-r--r-- | src/jrummikub/ai/fdsolver/Solver.java | 192 |
1 files changed, 181 insertions, 11 deletions
diff --git a/src/jrummikub/ai/fdsolver/Solver.java b/src/jrummikub/ai/fdsolver/Solver.java index 1164c36..6660467 100644 --- a/src/jrummikub/ai/fdsolver/Solver.java +++ b/src/jrummikub/ai/fdsolver/Solver.java @@ -1,25 +1,195 @@ package jrummikub.ai.fdsolver; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.Map; +import java.util.Set; + public class Solver { + Set<Var<?>> vars = new HashSet<Var<?>>(); + + Set<Var<?>> dirtyVars = new HashSet<Var<?>>(); + + Set<Var<?>> unsolvedVars = new HashSet<Var<?>>(); + + Set<Constraint> constraints = new HashSet<Constraint>(); + + ArrayList<StackFrame> stack = new ArrayList<StackFrame>(); + + boolean contradiction = false; - public void add(Constraint constraint) { - // TODO Auto-generated method stub - + static private class StackFrame { + Map<Var<?>, HashSet<?>> invalidatedValues = new HashMap<Var<?>, HashSet<?>>(); + Var<?> branchVar; + Object branchValue; + + public <T> void invalidate(Var<T> var, T invalid) { + HashSet<T> values = (HashSet<T>) invalidatedValues.get(var); + if (values == null) { + invalidatedValues.put(var, + new HashSet<T>(Arrays.asList(invalid))); + } else { + values.add(invalid); + } + } + + @Override + public String toString() { + return "StackItem [invalidatedValues=" + invalidatedValues + "]"; + } + } + + public Solver() { + stack.add(new StackFrame()); } public boolean solve() { - // TODO Auto-generated method stub - return false; + do { + solveStep(); + } while (!(contradiction || (dirtyVars.isEmpty() && unsolvedVars + .isEmpty()))); + return !contradiction; + } + + public void solveStep() { + if (unsolvedVars.isEmpty() && dirtyVars.isEmpty()) { + contradiction = true; + } else { + propagateAll(); + } + if (contradiction) { + if (stack.size() == 1) { + return; + } + backtrack(); + } else if (unsolvedVars.isEmpty()) { + return; + } else { + branch(); + } + } + + public void propagateOnce() { + Iterator<Var<?>> it = dirtyVars.iterator(); + Var<?> dirtyVar = it.next(); + it.remove(); + outerLoop: for (Constraint constraint : constraints) { + if (constraint.getWatchedVars().contains(dirtyVar)) { + for (Propagator propagator : constraint.getPropagators(false)) { + if (propagator.getWatchedVars().contains(dirtyVar)) { + propagator.propagate(); + if (contradiction) { + break outerLoop; + } + } + } + } + } + } + + public void propagateAll() { + while (!(dirtyVars.isEmpty() || contradiction)) { + propagateOnce(); + } + } + + public void addVar(Var<?> var) { + vars.add(var); + if (var.getRange().size() != 1) { + unsolvedVars.add(var); + } + } + + public void addConstraint(Constraint constraint) { + constraints.add(constraint); + + for (Var<?> var : constraint.getWatchedVars()) { + var.makeDirty(); + } + } + + // backtracking and logging + + void branch() { + branchOn(unsolvedVars.iterator().next()); + } + + @SuppressWarnings("unchecked") + void branchOn(Var<?> var) { + Object value = var.getRange().iterator().next(); + branchWith((Var<Object>) var, value); + } + + <T> void branchWith(Var<T> var, T value) { + StackFrame stackFrame = new StackFrame(); + stackFrame.branchVar = var; + stackFrame.branchValue = value; + stack.add(stackFrame); + var.choose(value); + } + + private StackFrame getTopStackFrame() { + return stack.get(stack.size() - 1); + } + + <T> void logInvalidation(Var<T> var, T invalid) { + getTopStackFrame().invalidate(var, invalid); + if (var.getRange().size() == 1) { + unsolvedVars.remove(var); + } + } + + @SuppressWarnings("unchecked") + // This would need rank-2 types which java lacks + private void rollback(StackFrame item) { + for (Map.Entry<Var<?>, HashSet<?>> entry : item.invalidatedValues + .entrySet()) { + Var<Object> var = (Var<Object>) entry.getKey(); + HashSet<Object> values = (HashSet<Object>) entry.getValue(); + var.getRange().addAll(values); + if (var.getRange().size() != 1) { + unsolvedVars.add(var); + } else { + unsolvedVars.remove(var); + } + } + } + + @SuppressWarnings("unchecked") + void backtrack() { + contradiction = false; + StackFrame topFrame = getTopStackFrame(); + rollback(topFrame); + stack.remove(stack.size() - 1); + ((Var<Object>) topFrame.branchVar).invalidate(topFrame.branchValue); + } + + // factory methods for vars + + public Var<Integer> makeRangeVar(int low, int high) { + ArrayList<Integer> range = new ArrayList<Integer>(); + for (int i = low; i <= high; i++) { + range.add(i); + } + return makeVar(range); + } + + public Var<Boolean> makeBoolVar() { + return makeVar(true, false); } - public void push() { - // TODO Auto-generated method stub - + public <T> Var<T> makeVar(Collection<T> range) { + Var<T> var = new Var<T>(this, range); + addVar(var); + return var; } - public void pop() { - // TODO Auto-generated method stub - + public <T> Var<T> makeVar(T... range) { + return makeVar(Arrays.asList(range)); } } |