From 1d7c40bc111551bda06944b38f8d0c5abb11e478 Mon Sep 17 00:00:00 2001 From: Philippe Vaillancourt Date: Wed, 2 May 2018 00:14:44 -0400 Subject: [PATCH] perf: Add tests for Ultimate-Tic-Tac-Toe and ability to set duration on each call of getAction() --- src/controller.ts | 4 +- src/macao.ts | 4 +- src/mcts.ts | 17 +- src/utils.ts | 3 + test/tic-tac-toe/tic-tac-toe.test.ts | 2 +- .../ultimate-tic-tac-toe.test.ts | 57 ++++ .../ultimate-tic-tac-toe.ts | 322 ++++++++++++++++++ 7 files changed, 401 insertions(+), 8 deletions(-) create mode 100644 test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.test.ts create mode 100644 test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.ts diff --git a/src/controller.ts b/src/controller.ts index f46b83f..ab4d7af 100644 --- a/src/controller.ts +++ b/src/controller.ts @@ -126,7 +126,7 @@ export class Controller { * @returns {Action} * @memberof Controller */ - getAction(state: State): Action { - return this.mcts_.getAction(state) + getAction(state: State, duration?: number): Action { + return this.mcts_.getAction(state, duration) } } diff --git a/src/macao.ts b/src/macao.ts index b4476d6..c4af8fe 100644 --- a/src/macao.ts +++ b/src/macao.ts @@ -62,8 +62,8 @@ export class Macao { * @returns {Action} * @memberof Macao */ - getAction(state: State): Action { - return this.controller_.getAction(state) + getAction(state: State, duration?: number): Action { + return this.controller_.getAction(state, duration) } /** diff --git a/src/mcts.ts b/src/mcts.ts index 6c52edf..72eca3a 100644 --- a/src/mcts.ts +++ b/src/mcts.ts @@ -8,6 +8,7 @@ import { CalculateReward } from './classes' import { spliceRandom, loopFor } from './utils' +import { performance } from 'perf_hooks' /** * @@ -356,7 +357,7 @@ export class DefaultUCB1 implements UCB1 { * @template Action */ export interface MCTSFacade { - getAction: (state: State) => Action + getAction: (state: State, duration?: number) => Action } /** @@ -403,13 +404,23 @@ export class DefaultMCTSFacade * @returns {Action} * @memberof DefaultMCTSFacade */ - getAction(state: State): Action { + getAction(state: State, duration?: number): Action { const rootNode = this.createRootNode_(state) - loopFor(this.duration_).milliseconds(() => { + loopFor(duration || this.duration_).milliseconds(() => { + performance.mark('select start') const node = this.select_.run(rootNode, this.explorationParam_) + performance.mark('select end') + performance.mark('simulate start') const score = this.simulate_.run(node.mctsState.state) + performance.mark('simulate end') + performance.mark('backPropagate start') this.backPropagate_.run(node, score) + performance.mark('backPropagate end') + performance.measure('select', 'select start', 'select end') + performance.measure('simulate', 'simulate start', 'simulate end') + performance.measure('backPropagate', 'backPropagate start', 'backPropagate end') }) + const bestChild = this.bestChild_.run(rootNode, 0) return bestChild.action as Action diff --git a/src/utils.ts b/src/utils.ts index 72772a3..3b09265 100644 --- a/src/utils.ts +++ b/src/utils.ts @@ -91,3 +91,6 @@ export const loopFor = (time: number) => { } } } + +// Function to get the average of a numbers array +export const arrAvg = (arr: number[]) => arr.reduce((a, b) => a + b, 0) / arr.length diff --git a/test/tic-tac-toe/tic-tac-toe.test.ts b/test/tic-tac-toe/tic-tac-toe.test.ts index fdd66e2..45d4b93 100644 --- a/test/tic-tac-toe/tic-tac-toe.test.ts +++ b/test/tic-tac-toe/tic-tac-toe.test.ts @@ -18,7 +18,7 @@ import { MCTSState } from '../../src/classes' import { TicTacToeState, TicTacToeMove, ticTacToeFuncs } from './tic-tac-toe' import { DataStore } from '../../src/data-store' import { loopFor } from '../../src/utils' -import Macao from '../../src/macao' +import { Macao } from '../../src/macao' xdescribe('The DefaultMCTSFacade instance', () => { let dataStore: DataGateway> diff --git a/test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.test.ts b/test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.test.ts new file mode 100644 index 0000000..4d46428 --- /dev/null +++ b/test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.test.ts @@ -0,0 +1,57 @@ +import { UTicTacToeState, UTicTacToeMove, uTicTacToeFuncs } from './ultimate-tic-tac-toe' +import { Macao } from '../../src/macao' +import { loopFor } from '../../src/utils' + +xdescribe('The Macao instance', () => { + let uTicTacToeBoard: number[][][][] + let state: UTicTacToeState + let mcts: Macao + + describe('when used to simulate 100 Ultimate Tic Tac Toe games', () => { + describe('given 85 ms per turn and an exploration param of 1.414', () => { + it('should end in a draw 95% of the time or better', () => { + let results = 0 + loopFor(100).turns(() => { + mcts = new Macao( + { + stateIsTerminal: uTicTacToeFuncs.stateIsTerminal, + generateActions: uTicTacToeFuncs.generateActions, + applyAction: uTicTacToeFuncs.applyAction, + calculateReward: uTicTacToeFuncs.calculateReward + }, + { duration: 85 } + ) + uTicTacToeBoard = [ + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ], + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ], + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ] + ] + state = { + board: uTicTacToeBoard, + player: -1, + previousAction: { bigRow: -1, bigCol: -1, smallRow: -1, smallCol: -1 } + } + while (!uTicTacToeFuncs.stateIsTerminal(state)) { + const action = mcts.getAction(state) + state = uTicTacToeFuncs.applyAction(state, action) + } + + results += uTicTacToeFuncs.calculateReward(state, 1) === 0 ? 1 : 0 + }) + expect(results).toBeGreaterThan(95) + }) + }) + }) +}) diff --git a/test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.ts b/test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.ts new file mode 100644 index 0000000..78a4aed --- /dev/null +++ b/test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.ts @@ -0,0 +1,322 @@ +import { Macao } from '../../src/macao' + +export interface UTicTacToeMove { + bigRow: number + bigCol: number + smallRow: number + smallCol: number +} + +export interface UTicTacToeState { + board: number[][][][] + player: number + previousAction: UTicTacToeMove +} + +export interface TicTacToeState { + board: (number | string)[][] +} + +export function convertToMove(row: number, col: number): UTicTacToeMove { + const bigRow = Math.floor(row / 3) + const bigCol = Math.floor(col / 3) + const smallRow = Math.floor(row % 3) + const smallCol = Math.floor(col % 3) + return { bigRow, bigCol, smallRow, smallCol } +} + +export function convertFromMove(move: UTicTacToeMove): { row: number; col: number } { + const row = Math.floor(move.bigRow * 3) + Math.floor(move.smallRow % 3) + const col = Math.floor(move.bigCol * 3) + Math.floor(move.smallCol % 3) + return { row, col } +} + +export const uTicTacToeBoard = [ + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ], + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ], + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ] +] + +export function possibleMovesUTicTacToe(state: UTicTacToeState): UTicTacToeMove[] { + const result: UTicTacToeMove[] = [] + if (state.previousAction.bigRow !== -1) { + const bigRow = state.previousAction.smallRow + const bigCol = state.previousAction.smallCol + const innerState = { board: state.board[bigRow][bigCol] } + + // Check if the inner board square the previous player played into is not terminal + if (!stateIsTerminalTicTacToe(innerState)) { + // Only check for moves in the big board square + innerState.board.forEach((smallRowArray, smallRow) => { + smallRowArray.forEach((value, smallCol) => { + if (value === 0) result.push({ bigRow, bigCol, smallRow, smallCol }) + }) + }) + return result + } + } + + // If that inner board is Terminal, we have to check all other inner boards + state.board.forEach((bigRowArray, bigRow) => { + bigRowArray.forEach((innerSquare, bigCol) => { + const innerState = { board: innerSquare } + // Check if inner board is not Terminal + if (!stateIsTerminalTicTacToe(innerState)) { + // Push all possible moves to result array + innerState.board.forEach((smallRowArray, smallRow) => { + smallRowArray.forEach((value, smallCol) => { + if (value === 0) { + result.push({ bigRow, bigCol, smallRow, smallCol }) + } + }) + }) + } + }) + }) + return result +} + +// Be careful not to mutate the board but to return a new one +export function playMoveUTicTacToe(state: UTicTacToeState, move: UTicTacToeMove): UTicTacToeState { + const jSONBoard = JSON.stringify(state.board) + const newBoard = JSON.parse(jSONBoard) + + newBoard[move.bigRow][move.bigCol][move.smallRow][move.smallCol] = state.player * -1 + const newState: UTicTacToeState = { + board: newBoard, + player: state.player * -1, + previousAction: move + } + return newState +} + +export function stateIsTerminalTicTacToe(state: TicTacToeState): boolean { + for (let i = 0; i < 3; i++) { + // check rows to see if there is a winner + if ( + state.board[i][0] === state.board[i][1] && + state.board[i][1] === state.board[i][2] && + state.board[i][0] !== 0 && + state.board[i][0] !== 'D' + ) { + return true + } + + // check cols to see if there is a winner + if ( + state.board[0][i] === state.board[1][i] && + state.board[1][i] === state.board[2][i] && + state.board[0][i] !== 0 && + state.board[0][i] !== 'D' + ) { + return true + } + } + + // check diags to see if there is a winner + if ( + state.board[0][0] === state.board[1][1] && + state.board[1][1] === state.board[2][2] && + state.board[0][0] !== 0 && + state.board[0][0] !== 'D' + ) { + return true + } + + if ( + state.board[0][2] === state.board[1][1] && + state.board[1][1] === state.board[2][0] && + state.board[0][2] !== 0 && + state.board[0][2] !== 'D' + ) { + return true + } + + // check to see if the board is full and therefore a draw + const flattenBoard = state.board.reduce((p, c) => p.concat(c)) + if (flattenBoard.every(value => value !== 0)) return true + + return false +} + +export function stateIsTerminalUTicTacToe(state: UTicTacToeState): boolean { + let metaboard: (number | string)[][] = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + + state.board.forEach((bigRow, bigRowIndex) => { + bigRow.forEach((innerBoard, bigColIndex) => { + const innerState = { board: innerBoard } + if (stateIsTerminalTicTacToe(innerState)) { + const score = calculateRewardTicTacToe(innerState, 1) + metaboard[bigRowIndex][bigColIndex] = score === 0 ? 'D' : score + } + }) + }) + + return stateIsTerminalTicTacToe({ board: metaboard }) +} + +export function calculateRewardTicTacToe(state: TicTacToeState, player: number): number { + for (let i = 0; i < 3; i++) { + // check rows to see if there is a winner + if ( + state.board[i][0] === state.board[i][1] && + state.board[i][1] === state.board[i][2] && + state.board[i][0] !== 0 + ) { + if (state.board[i][0] === player) return 1 + + return -1 + } + + // check cols to see if there is a winner + if ( + state.board[0][i] === state.board[1][i] && + state.board[1][i] === state.board[2][i] && + state.board[0][i] !== 0 + ) { + if (state.board[0][i] === player) return 1 + + return -1 + } + } + + // check diags to see if there is a winner + if ( + state.board[0][0] === state.board[1][1] && + state.board[1][1] === state.board[2][2] && + state.board[0][0] !== 0 + ) { + if (state.board[0][0] === player) return 1 + + return -1 + } + + if ( + state.board[0][2] === state.board[1][1] && + state.board[1][1] === state.board[2][0] && + state.board[0][2] !== 0 + ) { + if (state.board[0][2] === player) return 1 + + return -1 + } + + return 0 +} + +export function calculateRewardUTicTacToe(state: UTicTacToeState, player: number): number { + for (let i = 0; i < 3; i++) { + // check rows to see if there is a winner + if ( + calculateRewardTicTacToe({ board: state.board[i][0] }, player) === + calculateRewardTicTacToe({ board: state.board[i][1] }, player) && + calculateRewardTicTacToe({ board: state.board[i][1] }, player) === + calculateRewardTicTacToe({ board: state.board[i][2] }, player) + ) { + if (calculateRewardTicTacToe({ board: state.board[i][0] }, player) === 1) return 1 + + return -1 + } + + // check cols to see if there is a winner + if ( + calculateRewardTicTacToe({ board: state.board[0][i] }, player) === + calculateRewardTicTacToe({ board: state.board[1][i] }, player) && + calculateRewardTicTacToe({ board: state.board[1][i] }, player) === + calculateRewardTicTacToe({ board: state.board[2][i] }, player) + ) { + if (calculateRewardTicTacToe({ board: state.board[0][i] }, player) === 1) return 1 + + return -1 + } + } + + // check diags to see if there is a winner + if ( + calculateRewardTicTacToe({ board: state.board[0][0] }, player) === + calculateRewardTicTacToe({ board: state.board[1][1] }, player) && + calculateRewardTicTacToe({ board: state.board[1][1] }, player) === + calculateRewardTicTacToe({ board: state.board[2][2] }, player) + ) { + if (calculateRewardTicTacToe({ board: state.board[0][0] }, player) === 1) return 1 + + return -1 + } + + if ( + calculateRewardTicTacToe({ board: state.board[0][2] }, player) === + calculateRewardTicTacToe({ board: state.board[1][1] }, player) && + calculateRewardTicTacToe({ board: state.board[1][1] }, player) === + calculateRewardTicTacToe({ board: state.board[2][0] }, player) + ) { + if (calculateRewardTicTacToe({ board: state.board[0][2] }, player) === 1) return 1 + + return -1 + } + + // If there is no 3 in a row, the winner is whoever has won the most small boards + let result = 0 + + for (const row of state.board) { + for (const col of row) { + result += calculateRewardTicTacToe({ board: col }, player) + } + } + + return result +} + +export const uTicTacToeFuncs = { + generateActions: possibleMovesUTicTacToe, + applyAction: playMoveUTicTacToe, + stateIsTerminal: stateIsTerminalUTicTacToe, + calculateReward: calculateRewardUTicTacToe +} + +const testBoard = [ + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ], + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [-1, 1, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ], + [ + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]], + [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ] +] +const testState = { + board: testBoard, + player: -1, + previousAction: { + bigRow: 1, + bigCol: 1, + smallRow: 1, + smallCol: 0 + } +} + +// const mcts = new Macao(ticTacToeFuncs, {duration: 2000}); +// mcts.getAction(testState); //? + +// possibleMovesUTicTacToe(testState) //? +// stateIsTerminalUTicTacToe(testState) //? +// calculateRewardUTicTacToe(testState, 1) +// playMoveUTicTacToe(testState, {bigRow:0, bigCol:0, smallCol:0, smallRow:0}).board[0][0]