Commit 93666b21 authored by Marta Różańska's avatar Marta Różańska
Browse files

Merge pull request #591 in MEL/upperware from mcts-tree-trimming to ZPP2019Master

* commit 'ff10fb82':
  minor fixes
  Changed line memoryLimiter to better clarify what method does.
  Removed unnecessary assertion.
  Changed decision on which nodes are kept in queue to clarify what code does. Now all nodes are in queue.
  Removed whitespace.
  Changed the way nodes are removed from queue and the way whichNodeToPrune works (now it doesnt remove node).
  Memory limiter queue control when deleting empty solutions.
  Fixed comment in NodeImpl.
  Code style fixes according to Tomek's review.
  Some more fixes according to pr review.
  Removed one import.
  fixed according to pr review.
  Reversed changes to test.
  Small change in tree pruning.
  Fixed a few things. Added branch trimmer.
  Added test. Cleaned up classes.
  Moved code to memory_limit package. Cleaned up code. Removed bug.
  super raw commit.
  Added experimental trimming to tree.
parents e754554f ff10fb82
......@@ -27,6 +27,7 @@ public class MCTSSolver {
private double minTemperature;
private double maxTemperature;
private int iterations;
private final static int NODE_COUNT_LIMIT = 2000000;
private OneToManyChannel<Message, UtilityMessage> messageChannel;
private SolutionBuffer solutionBuffer = new SolutionBuffer();
private AvailablePolicies policyType;
......@@ -64,7 +65,7 @@ public class MCTSSolver {
private List<Thread> startWorkers(List<MCTSWrapper> mctsWrappers) {
return IntStream.range(0, numThreads).mapToObj(pid -> {
Thread thread = new Thread( () -> {
MCTSSingleTreeSolver mctsSingleTreeSolver = new MCTSSingleTreeSolver(minTemperature , 10, iterations, mctsWrappers.get(pid), policyType);
MCTSSingleTreeSolver mctsSingleTreeSolver = new MCTSSingleTreeSolver(minTemperature , 10, iterations, NODE_COUNT_LIMIT / numThreads, mctsWrappers.get(pid), policyType);
WorkerThread workerThread = new WorkerThread(pid, iterations, solutionBuffer, messageChannel, mctsSingleTreeSolver, SAVE_TREE);
workerThread.workerRun();
});
......
......@@ -2,10 +2,8 @@ package eu.melodic.upperware.mcts_solver.solver.mcts;
import cp_wrapper.solution.CpSolution;
import eu.melodic.upperware.mcts_solver.solver.mcts.cp_wrapper.MCTSWrapper;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.MoveProvider;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Policy;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Solution;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Tree;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.*;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl.MemoryLimiterImpl;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl.MoveProviderImpl;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl.NodeStatisticsImpl;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl.TreeImpl;
......@@ -24,19 +22,21 @@ public class MCTSSingleTreeSolver {
private MCTSWrapper mctsWrapper;
private MoveProvider moveProvider;
private Policy policy;
private MemoryLimiter memoryLimiter;
@Getter
private Tree mctsTree;
public MCTSSingleTreeSolver(double selectorCoefficient, double explorationCoefficient, int iterations, MCTSWrapper mctsWrapper, AvailablePolicies policy) {
public MCTSSingleTreeSolver(double selectorCoefficient, double explorationCoefficient, int iterations, int nodeCountLimit, MCTSWrapper mctsWrapper, AvailablePolicies policy) {
this.selectorCoefficient = selectorCoefficient;
this.explorationCoefficient = explorationCoefficient;
this.iterations = iterations;
this.mctsWrapper = mctsWrapper;
moveProvider = new MoveProviderImpl(mctsWrapper);
this.memoryLimiter = new MemoryLimiterImpl(nodeCountLimit);
moveProvider = new MoveProviderImpl(mctsWrapper, memoryLimiter);
this.policy = mctsWrapper.createPolicy(policy);
updateParameters();
mctsTree = new TreeImpl(this.policy, moveProvider);
mctsTree = new TreeImpl(this.policy, moveProvider, memoryLimiter);
}
public CpSolution solve() {
......
package eu.melodic.upperware.mcts_solver.solver.mcts.tree;
public interface MemoryLimiter {
// Tells tree whether it should cut branch.
boolean shouldPruneTree();
// Tells tree which node to prune.
Node whichNodeToPrune();
// Marks nodes as recently accessed. Goes from bottom to root of tree.
void updateRecentlyAccessedNodes(Node startingNode);
// Creates node with certain value.
Node createNode(Node parent, int value);
// Removes node from queue, does nothing if it's not in queue.
void removeNodeFromQueue(Node node);
}
\ No newline at end of file
......@@ -9,9 +9,13 @@ public interface Node extends Comparable<Node> {
List<Node> getChildren();
void linkToTree(Node parent); // Called after creation of a Node in order to add it to a tree.
Node update(Solution solution); // Updates node based on solution.
int getChildrenSize();
void visit(); // Visits node and registers it in node statistics.
int getChildrenSize();
void addChild(Node child);
Node getBestChild();
boolean isExpanded();
void setExpanded();
void setUnexpanded();
void removeChild(Node child);
void removeChildren();
}
......@@ -6,4 +6,7 @@ public interface NodeStatistics {
void update(Solution solution); // Updates statistics after finding some path (solution).
void markNewVisit();
double getEvaluation(NodeStatistics parentStats); // Evaluates node.
boolean isExpanded();
void setExpanded();
void setUnexpanded();
}
......@@ -6,23 +6,25 @@ import org.javatuples.Pair;
import java.util.stream.IntStream;
@Slf4j
public abstract class Tree {
@Getter
protected Node root;
private Policy policy;
private MoveProvider moveProvider; // MoveProvider is responsible for both tree search and expansion.
private MemoryLimiter memoryLimiter; // Responsible for prevention of memory overflow by limiting tree size.
private final int minDepthSubtreeRemoval;
public Tree(Policy policy, MoveProvider moveProvider) {
public Tree(Policy policy, MoveProvider moveProvider, MemoryLimiter memoryLimiter) {
this.policy = policy;
this.moveProvider = moveProvider;
this.memoryLimiter = memoryLimiter;
this.minDepthSubtreeRemoval = policy.minDepthSubtreeRemoval();
}
public Solution run(int iterations) {
return IntStream.range(0, iterations)
.mapToObj(i ->runIteration())
.mapToObj(i -> runIteration())
.max(Solution::compareTo).get();
}
......@@ -50,23 +52,41 @@ public abstract class Tree {
Pair<Node, Path> state = searchAndExpand();
Node leaf = state.getValue0();
Path path = state.getValue1();
Solution solution = rollout(path);
backPropagate(leaf, solution);
memoryLimiter.updateRecentlyAccessedNodes(leaf);
if (solution.isEmpty() && leaf.getNodeStatistics().getDepth() > minDepthSubtreeRemoval) {
removeSubtreeWithNoSolutions((leaf));
}
while (memoryLimiter.shouldPruneTree()) {
removeLeaf(memoryLimiter.whichNodeToPrune());
}
return solution;
}
private void removeSubtreeWithNoSolutions(Node subtreeRoot) {
log.debug("Removing subtree at depth {}", subtreeRoot.getNodeStatistics().getDepth());
removeNode(subtreeRoot);
removeLeaf(subtreeRoot);
}
private void removeNode(Node node) {
private void removeLeaf(Node node) {
assert(node.getChildrenSize() == 0);
// Shouldn't happen if node limit is not very small.
if (node == root) {
root.setUnexpanded();
return;
}
node.getParent().removeChild(node);
if (node.getParent() != root && node.getParent().getChildrenSize() == 0) {
removeNode(node.getParent());
memoryLimiter.removeNodeFromQueue(node);
if (node.getParent().getChildrenSize() == 0) {
removeLeaf(node.getParent());
}
}
}
package eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.MemoryLimiter;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Node;
public class MemoryLimiterImpl implements MemoryLimiter {
private int limit;
private int count = 0;
private Queue accessQueue = new Queue();
public MemoryLimiterImpl(int limit) {
this.limit = limit;
}
@Override
public boolean shouldPruneTree() {
return count > limit && !accessQueue.empty();
}
@Override
public Node whichNodeToPrune() {
return accessQueue.getFront();
}
@Override
public void updateRecentlyAccessedNodes(Node startingNode) {
Node current = startingNode;
while (current != null) {
if (current.isExpanded() && current.getParent() != null) { // If current is not leaf or root.
accessQueue.pushBack(current);
}
current = current.getParent();
}
}
@Override
public Node createNode(Node parent, int value) {
count++;
NodeImpl newNode = new NodeImpl(value);
newNode.linkToTree(parent);
accessQueue.pushBack(newNode);
return newNode;
}
@Override
public void removeNodeFromQueue(Node node) {
count--;
accessQueue.removeNodeFromQueue((NodeImpl) node);
}
}
package eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl;
import eu.melodic.upperware.mcts_solver.solver.mcts.cp_wrapper.MCTSWrapper;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.MemoryLimiter;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.MoveProvider;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Node;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Path;
......@@ -12,6 +13,7 @@ import java.util.stream.IntStream;
@AllArgsConstructor
public class MoveProviderImpl implements MoveProvider {
private MCTSWrapper mctsWrapper;
private MemoryLimiter memoryLimiter;
@Override
public Pair<Node, Path> searchAndExpand(Node root) {
......@@ -23,7 +25,9 @@ public class MoveProviderImpl implements MoveProvider {
current = traversingResult.getValue0();
Path path = traversingResult.getValue1();
Node expanded = expand(current);
if (current != expanded) {
current = expanded;
current.visit();
......@@ -38,8 +42,8 @@ public class MoveProviderImpl implements MoveProvider {
int depth = 0;
Path path = new Path();
// While has all available children.
while (depth < this.mctsWrapper.getSize() && current.getChildrenSize() == this.mctsWrapper.domainSize(depth)) {
// While has been expanded and is not leaf.
while (depth < this.mctsWrapper.getSize() && current.isExpanded()) {
current = current.getBestChild();
depth++;
current.visit();
......@@ -59,10 +63,9 @@ public class MoveProviderImpl implements MoveProvider {
return toExpand;
}
IntStream.range(mctsWrapper.getMinDomainValue(depth), mctsWrapper.getMaxDomainValue(depth) + 1).
forEach(value -> {
Node newNode = new NodeImpl(value);
newNode.linkToTree(toExpand);
});
forEach(value -> memoryLimiter.createNode(toExpand, value));
toExpand.setExpanded();
return toExpand.getChildren().get(mctsWrapper.generateRandomValue(depth));
}
......
......@@ -10,12 +10,16 @@ import java.util.List;
import static java.util.Collections.max;
@Getter
public class NodeImpl implements Node {
@Getter
private Node parent = null;
@Getter
private List<Node> children = new ArrayList<>();
@Getter
private int value;
@Getter
private NodeStatistics nodeStatistics;
private QueueLinker queueLinker = new QueueLinker();
public NodeImpl(Integer value) {
this.value = value;
......@@ -40,13 +44,13 @@ public class NodeImpl implements Node {
}
@Override
public int getChildrenSize() {
return children.size();
public void visit() {
nodeStatistics.markNewVisit();
}
@Override
public void visit() {
nodeStatistics.markNewVisit();
public int getChildrenSize() {
return children.size();
}
@Override
......@@ -60,10 +64,29 @@ public class NodeImpl implements Node {
}
@Override
public boolean isExpanded() {
return nodeStatistics.isExpanded();
}
@Override
public void setExpanded() {
nodeStatistics.setExpanded();
}
@Override
public void setUnexpanded() {
nodeStatistics.setUnexpanded();
}
public void removeChild(Node child) {
this.children.remove(child);
}
@Override
public void removeChildren() {
children.clear();
}
@Override
public int compareTo(Node other) {
NodeStatistics otherStats = other.getNodeStatistics();
......@@ -84,4 +107,9 @@ public class NodeImpl implements Node {
return -Integer.compare(value, other.getValue());
}
}
// Queue linker that is responsible for add, deleting and moving node in queue.
protected QueueLinker getQueueLinker() {
return queueLinker;
}
}
......@@ -17,6 +17,7 @@ public class NodeStatisticsImpl implements NodeStatistics {
private double maximalUtility = 0.0;
private int visitCount;
private int depth;
private boolean isExpanded = false; // True if nodes has been expanded and has children.
public NodeStatisticsImpl(int parentDepth) {
this.visitCount = 0;
......@@ -48,4 +49,14 @@ public class NodeStatisticsImpl implements NodeStatistics {
(1 - selectorCoefficient) * maximalUtility +
explorationCoefficient * Math.sqrt(Math.log((double) parentStats.getVisitCount() / (double) getVisitCount()));
}
@Override
public void setExpanded() {
this.isExpanded = true;
}
@Override
public void setUnexpanded() {
this.isExpanded = false;
}
}
package eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Node;
import lombok.Getter;
public class Queue {
@Getter
private NodeImpl front = null;
private NodeImpl back = null;
/*
Moves node to the back of queue.
If node was already in queue then its previous occurrence is forgotten and it's added as a new element.
*/
public void pushBack(Node newNode) {
NodeImpl node = (NodeImpl) newNode;
if (node.getQueueLinker().isInQueue()) { // If is in queue then remove it from queue for now.
removeNodeFromQueue(node);
}
// Current node is not in queue.
if (this.empty()) {
this.front = this.back = node;
node.getQueueLinker().addToQueue(node,null);
} else {
node.getQueueLinker().addToQueue(node, back);
this.back = node;
}
}
public boolean empty() {
return front == null;
}
public void removeNodeFromQueue(NodeImpl node) {
if (front == node) {
this.front = node.getQueueLinker().getNext();
}
if (back == node) {
this.back = node.getQueueLinker().getPrevious();
}
node.getQueueLinker().removeFromQueue();
}
}
package eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl;
import lombok.Getter;
import lombok.Setter;
@Getter
public class QueueLinker {
private boolean isInQueue = false;
@Setter
private NodeImpl next = null; // Next is a node in back direction.
@Setter
private NodeImpl previous = null; // Previous is a node in front direction.
public void addToQueue(NodeImpl current, NodeImpl previous) {
if (previous != null) {
previous.getQueueLinker().setNext(current);
}
this.isInQueue = true;
this.previous = previous;
this.next = null;
}
public void removeFromQueue() {
if (previous != null) {
previous.getQueueLinker().setNext(next);
}
if (next != null) {
next.getQueueLinker().setPrevious(previous);
}
this.isInQueue = false;
this.next = null;
this.previous = null;
}
}
package eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.MemoryLimiter;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.MoveProvider;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Policy;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Tree;
public class TreeImpl extends Tree {
public TreeImpl(Policy policy, MoveProvider moveProvider) {
super(policy, moveProvider);
public TreeImpl(Policy policy, MoveProvider moveProvider, MemoryLimiter memoryLimiter) {
super(policy, moveProvider, memoryLimiter);
this.root = new NodeImpl();
}
}
......@@ -22,7 +22,7 @@ public class MCTSSingleTreeSolverTest {
CPWrapper cpWrapper = new CPWrapper();
cpWrapper.parse(problem.keySet().iterator().next(), problem.values().iterator().next());
MCTSSingleTreeSolver mctsSingleTreeSolver = new MCTSSingleTreeSolver(0.1, 0.5, 150, new MCTSWrapper(cpWrapper, null), AvailablePolicies.RANDOM_POLICY);
MCTSSingleTreeSolver mctsSingleTreeSolver = new MCTSSingleTreeSolver(0.1, 0.5, 150, 200000, new MCTSWrapper(cpWrapper, null), AvailablePolicies.RANDOM_POLICY);
List<Integer> assignment = mctsSingleTreeSolver.search().getAssignment();
List<Double> domain1 = Arrays.asList(1.0,2.0,3.0,4.0,5.0);
......@@ -40,7 +40,7 @@ public class MCTSSingleTreeSolverTest {
CPWrapper cpWrapper = new CPWrapper();
cpWrapper.parse(problem.keySet().iterator().next(), problem.values().iterator().next());
MCTSSingleTreeSolver mctsSingleTreeSolver = new MCTSSingleTreeSolver(0.1, 0.8, 5000, new MCTSWrapper(cpWrapper, null), AvailablePolicies.RANDOM_POLICY);
MCTSSingleTreeSolver mctsSingleTreeSolver = new MCTSSingleTreeSolver(0.1, 0.8, 5000, 200000, new MCTSWrapper(cpWrapper, null), AvailablePolicies.RANDOM_POLICY);
List<Integer> assignment = mctsSingleTreeSolver.search().getAssignment();
List<Double> domain1 = Arrays.asList(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0);
......
package memory_limiter_tests;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree.Node;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl.NodeImpl;
import eu.melodic.upperware.mcts_solver.solver.mcts.tree_impl.Queue;
import org.junit.Test;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import static org.junit.Assert.*;
public class QueueTest {
private NodeImpl node1 = new NodeImpl(1);
private NodeImpl node2 = new NodeImpl(2);
private NodeImpl node3 = new NodeImpl(3);
@Test
public void QueueTest() {
Queue queue = new Queue();
assertTrue(queue.empty());
queue.pushBack(node1);
assertFalse(queue.empty());
assertNotNull(queue.getFront());
queue.removeNodeFromQueue(queue.getFront());
queue.pushBack(node1);
queue.pushBack(node2);
assertFalse(queue.empty());
assertEquals(queue.getFront(), node1);
queue.removeNodeFromQueue(queue.getFront());
assertFalse(queue.empty());
assertEquals(queue.getFront(), node2);
queue.removeNodeFromQueue(queue.getFront());
assertTrue(queue.empty());
queue.pushBack(node1);
queue.pushBack(node2);
queue.pushBack(node1);
assertFalse(queue.empty());
assertEquals(queue.getFront(), node2);
queue.removeNodeFromQueue(queue.getFront());
assertFalse(queue.empty());
assertEquals(queue.getFront(), node1);
queue.removeNodeFromQueue(queue.getFront());
assertTrue(queue.empty());
queue.pushBack(node1);
queue.pushBack(node1);
assertFalse(queue.empty());
assertEquals(queue.getFront(), node1);