Pentago-Swap AI Player: Code Implementation in Java
1. Monte Carlo Tree Class
1.1 Get the Next Move
1.2 Select the Best Node
1.3 Expand the Given Node
1.4 Rollout Step
1.5 Backpropagation
1.6 Find the Best Child Node
2. UCB (Upper Confidence Bound) Class
package student_player;
import java.util.ArrayList;
import java.util.Random;
import pentago_swap.PentagoBoardState;
import pentago_swap.PentagoMove;
import pentago_swap.PentagoBoardState.Piece;
import pentago_swap.PentagoBoardState.Quadrant;
/**
 * Monte Carlo search tree used by the Pentago-Swap agent.
 *
 * <p>Holds the root node of the tree and the id of the player the search is run for.
 */
public class MCTree
{
    /** Root of the search tree; wraps the board state given to the constructor. */
    Node root;

    /** Id of the player this tree searches on behalf of. */
    int playerID;

    /**
     * Creates a tree whose root has no move and no parent.
     *
     * @param pbState  board state stored in the root node
     * @param playerID id of the player the search is run for
     */
    public MCTree(PentagoBoardState pbState, int playerID)
    {
        this.playerID = playerID;
        // Root: no move led here, no parent; the two trailing zeros are presumably the
        // initial win and visit counters -- confirm against the Node constructor.
        this.root = new Node(null, pbState, null, 0, 0);
    }
}
1.1 Get the Next Move
/**
 * Picks the move to play from the position held by {@code root}.
 *
 * <p>The decision is made in this order:
 * <ol>
 * <li>a scripted opening move while at most one piece is on the board;</li>
 * <li>any legal move after which {@code getWinner()} reports this player;</li>
 * <li>a time-boxed Monte Carlo tree search, followed by the best-ranked root child whose
 * move does not let the opponent win with their very next move.</li>
 * </ol>
 * If every ranked child allows such a reply, the first-ranked child is played anyway.
 *
 * @param root node wrapping the current board state; the search adds children to it and
 *             updates its statistics
 * @return the chosen move
 * @throws IllegalStateException if the position has no legal move
 */
public PentagoMove getNextMove(Node root)
{
    final long searchBudgetMillis = 1500;
    // Written into a rejected child's win total so that it ranks below real candidates.
    final int losingMovePenalty = -99999;

    PentagoMove opening = openingMove(root.pbState);
    if (opening != null)
    {
        return opening;
    }

    ArrayList<PentagoMove> legalMoves = root.pbState.getAllLegalMoves();
    if (legalMoves.isEmpty())
    {
        throw new IllegalStateException("getNextMove called on a position with no legal moves");
    }

    // An immediate win needs no search, so look for one before spending the time budget.
    for (PentagoMove move : legalMoves)
    {
        PentagoBoardState afterMove = (PentagoBoardState) root.pbState.clone();
        afterMove.processMove(move);
        if (afterMove.getWinner() == playerID)
        {
            return move;
        }
    }

    runSearch(root, searchBudgetMillis);

    // Walk down the ranking until a move is found that does not hand the opponent an
    // immediate win. The loop is bounded by the number of children, so it always ends.
    int childCount = root.children.size();
    boolean[] rejected = new boolean[childCount];
    int firstChoice = -1;
    for (int attempt = 0; attempt < childCount; attempt++)
    {
        int candidate = findBestChildAt(root);
        if (candidate < 0 || candidate >= childCount || rejected[candidate])
        {
            // Nothing rankable is left, or the ranking keeps returning a rejected child.
            break;
        }
        if (firstChoice < 0)
        {
            firstChoice = candidate;
        }
        Node child = root.children.get(candidate);
        if (!opponentCanWinAfter(root.pbState, child.move))
        {
            return child.move;
        }
        rejected[candidate] = true;
        child.win = losingMovePenalty;
        System.out.println("I am gonna lose");
    }

    if (firstChoice >= 0)
    {
        // Every ranked move loses to some reply; play the statistically best one.
        return root.children.get(firstChoice).move;
    }
    // The search produced no rankable child (e.g. no iteration completed in time).
    return legalMoves.get(0);
}

/**
 * Returns the scripted opening move, or {@code null} once two or more pieces are on the
 * 6x6 board.
 *
 * <p>With zero or one piece on the board, (1, 1) is taken if it is empty; otherwise the
 * single piece is at (1, 1) and (4, 1) is taken instead.
 */
private PentagoMove openingMove(PentagoBoardState state)
{
    int pieces = 0;
    for (int p = 0; p < 6 && pieces <= 1; p++)
    {
        for (int q = 0; q < 6; q++)
        {
            if (state.getPieceAt(p, q) != Piece.EMPTY)
            {
                pieces++;
            }
        }
    }
    if (pieces > 1)
    {
        return null;
    }
    if (state.getPieceAt(1, 1) == Piece.EMPTY)
    {
        return new PentagoMove(1, 1, Quadrant.TR, Quadrant.BR, playerID);
    }
    return new PentagoMove(4, 1, Quadrant.TL, Quadrant.TR, playerID);
}

/** Runs selection, expansion, rollout and backpropagation until the time budget is used. */
private void runSearch(Node root, long budgetMillis)
{
    long deadline = System.currentTimeMillis() + budgetMillis;
    while (System.currentTimeMillis() < deadline)
    {
        // Selection: follow the best-UCB child from the root down to a leaf.
        Node leaf = selectBestNode(root);
        Node nodeToExplore = leaf;
        // Expansion: a leaf that was visited before gets its children created, and the
        // rollout then starts from its first child.
        if (expandNode(leaf) != 0)
        {
            if (leaf.children.size() == 0)
            {
                // No legal move from this leaf, so there is nothing to roll out.
                // NOTE(review): the tree is unchanged here, so the same leaf may be
                // selected again until the deadline -- confirm whether it should be scored.
                continue;
            }
            nodeToExplore = leaf.children.get(0);
        }
        // Rollout and backpropagation.
        backPropagation(nodeToExplore, rollout(nodeToExplore));
    }
}

/**
 * Tells whether, after {@code move} is played on a copy of {@code state}, some single reply
 * makes the replying player the winner.
 */
private boolean opponentCanWinAfter(PentagoBoardState state, PentagoMove move)
{
    PentagoBoardState afterOurMove = (PentagoBoardState) state.clone();
    afterOurMove.processMove(move);
    // Capture who is to move BEFORE their reply is processed; reading the turn player
    // afterwards may name the other side. Assumes getTurnPlayer() reports the player
    // about to move -- confirm against the framework.
    int opponent = afterOurMove.getTurnPlayer();
    for (PentagoMove reply : afterOurMove.getAllLegalMoves())
    {
        PentagoBoardState afterReply = (PentagoBoardState) afterOurMove.clone();
        afterReply.processMove(reply);
        if (afterReply.getWinner() == opponent)
        {
            return true;
        }
    }
    return false;
}
1.2 Select the Best Node
/**
 * Selection step: walks down from {@code root}, at every level moving to the child chosen
 * by {@code UCB.findBestNodeWithUCB}, and stops at the first node that reports itself as a
 * leaf.
 *
 * @param root node to start the descent from
 * @return the leaf reached; {@code root} itself if it is already a leaf
 */
public Node selectBestNode(Node root)
{
    Node current = root;
    while (true)
    {
        if (current.isLeaf())
        {
            return current;
        }
        current = UCB.findBestNodeWithUCB(current);
    }
}
1.3 Expand the Given Node
/**
 * Expansion step.
 *
 * <p>A node that has never been visited is left untouched. A node that has been visited
 * gets one child per legal move of its board state; each child holds the move and a copy of
 * the state with that move applied.
 *
 * @param bestNode the leaf picked by the selection step
 * @return {@code 0} if the node has never been visited (nothing is added);
 *         {@code 1} if it was visited before and its children were generated
 */
public int expandNode(Node bestNode)
{
    // Guard: a never-visited node is rolled out directly by the caller.
    if (bestNode.visit == 0)
    {
        return 0;
    }

    // Enumerate the moves on a copy of the node's state.
    PentagoBoardState snapshot = (PentagoBoardState) bestNode.pbState.clone();
    ArrayList<PentagoMove> legalMoves = snapshot.getAllLegalMoves();
    for (int i = 0; i < legalMoves.size(); i++)
    {
        PentagoMove move = legalMoves.get(i);
        // Every child gets its own copy of the state with the move applied.
        PentagoBoardState successor = (PentagoBoardState) bestNode.pbState.clone();
        successor.processMove(move);
        Node child = new Node(move, successor, bestNode, 0, 0);
        child.parent = bestNode;
        bestNode.children.add(child);
    }
    return 1;
}
1.4 Rollout Step
/**
 * Rollout step: plays random moves on a copy of the node's board state until the game is
 * over and scores the result from this player's point of view.
 *
 * <p>The node's own state is not modified.
 *
 * @param node node whose state the simulation starts from
 * @return {@code 10} if this player won, {@code -10} if the opponent won, {@code 0} for any
 *         other outcome
 */
public double rollout(Node node)
{
    PentagoBoardState simulation = (PentagoBoardState) node.pbState.clone();
    while (!simulation.gameOver())
    {
        simulation.processMove((PentagoMove) simulation.getRandomMove());
    }

    int winner = simulation.getWinner();
    if (winner == playerID)
    {
        return 10;
    }
    // A loss is identified by the opponent's id, not by whoever the final position reports
    // as the turn player: that depends on who made the last move, not on who won.
    // Assumes the two player ids are 0 and 1 -- confirm against the framework.
    if (winner == 1 - playerID)
    {
        return -10;
    }
    return 0;
}
1.5 Backpropagation
/**
 * Backpropagation step: records one rollout on {@code nodeToExplore} and on every ancestor
 * up to the root.
 *
 * <p>Each node on the path gets exactly one additional visit. The rollout reward is added to
 * the rolled-out node itself and to those ancestors whose state has this player to move.
 * Only this rollout's reward and a single visit are propagated, never the node's
 * accumulated totals, so calling the method again on an already-visited node does not
 * double count.
 *
 * @param nodeToExplore node the rollout was started from
 * @param win           reward returned by the rollout
 */
public void backPropagation(Node nodeToExplore, double win)
{
    // Update the rolled-out node itself.
    nodeToExplore.visit++;
    nodeToExplore.win += win;

    // Walk up to the root.
    for (Node ancestor = nodeToExplore.parent; ancestor != null; ancestor = ancestor.parent)
    {
        // NOTE(review): ancestors where the other player is to move never receive rewards
        // from their subtree; kept as in the original -- confirm this is intended.
        if (ancestor.pbState.getTurnPlayer() == playerID)
        {
            ancestor.win += win;
        }
        ancestor.visit++;
    }
}
1.6 Find the Best Child Node
/**
 * Finds the index of the best-scoring child of {@code root}.
 *
 * <p>A child's score is {@code win / visit + 2 * sqrt(ln(root.visit / visit))}, computed in
 * floating point. Children that have never been visited have no defined score and are
 * skipped. On equal scores the later child wins.
 *
 * <p>NOTE(review): the exploration term here is {@code ln(N / n)} whereas
 * {@code UCB.ucbValue} uses {@code ln(N) / n}; the formula is kept as it was -- confirm
 * which one is intended for the final move choice.
 *
 * @param root node whose children are ranked
 * @return index of the best child; {@code 0} if there are children but none could be
 *         scored; {@code -1} only if {@code root} has no children
 */
public int findBestChildAt(Node root)
{
    double parentVisits = root.visit;
    // Scores can be negative (rollouts may return -10), so start below every real score.
    double bestScore = Double.NEGATIVE_INFINITY;
    int bestChild = -1;
    for (int i = 0; i < root.children.size(); i++)
    {
        Node child = root.children.get(i);
        double visits = child.visit;
        if (visits <= 0)
        {
            // Dividing by a zero visit count would throw or produce NaN.
            continue;
        }
        // Clamp the log at zero so a ratio below 1 cannot turn the score into NaN.
        double exploration = 2 * Math.sqrt(Math.max(0.0, Math.log(parentVisits / visits)));
        double score = child.win / visits + exploration;
        if (score >= bestScore)
        {
            bestScore = score;
            bestChild = i;
        }
    }
    if (bestChild < 0 && root.children.size() > 0)
    {
        // No child has statistics yet; fall back to the first one rather than -1.
        bestChild = 0;
    }
    return bestChild;
}
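UCB formula: $\mathrm{UCB}_i = \frac{w_i}{n_i} + 2\sqrt{\frac{\ln N}{n_i}}$, where $w_i$ and $n_i$ are the accumulated reward and visit count of child $i$, and $N$ is the visit count of its parent.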
/**
 * Upper Confidence Bound helpers used by the selection step of the tree search.
 */
class UCB
{
    /**
     * Computes {@code nodeWin / nodeVisit + 2 * sqrt(ln(totalVisit) / nodeVisit)}.
     *
     * @param totalVisit visit count of the parent node
     * @param nodeWin    accumulated reward of the child
     * @param nodeVisit  visit count of the child; the result is not a finite number when
     *                   this is zero
     * @return the UCB score of the child
     */
    public static double ucbValue(double totalVisit, double nodeWin, double nodeVisit)
    {
        return (nodeWin / nodeVisit) + 2 * Math.sqrt(Math.log(totalVisit) / nodeVisit);
    }

    /**
     * Picks the child of {@code node} to descend into.
     *
     * <p>The first child that has never been visited is returned immediately. Otherwise the
     * child with the highest UCB score is returned; on equal scores the earlier child wins.
     * If no score compares greater than the starting bound (e.g. every score is NaN), the
     * first child is returned.
     *
     * @param node parent node; must have at least one child, otherwise the final
     *             {@code get} call fails
     * @return the selected child
     */
    public static Node findBestNodeWithUCB(Node node)
    {
        double parentVisits = node.visit;
        // Accumulated rewards can be negative, so a score may fall below -1; start from
        // negative infinity instead of -1 so that such children are still compared.
        double bestScore = Double.NEGATIVE_INFINITY;
        int bestChildAt = 0;
        for (int i = 0; i < node.children.size(); i++)
        {
            Node child = node.children.get(i);
            if (child.visit == 0)
            {
                // Unvisited children are always explored first.
                bestChildAt = i;
                break;
            }
            double score = ucbValue(parentVisits, child.win, child.visit);
            if (score > bestScore)
            {
                bestScore = score;
                bestChildAt = i;
            }
        }
        return node.children.get(bestChildAt);
    }
}