Skip to content

Commit

Permalink
perf: Add tests for Ultimate-Tic-Tac-Toe and ability to set duration …
Browse files Browse the repository at this point in the history
…on each call of getAction()
  • Loading branch information
Philippe Vaillancourt committed May 2, 2018
1 parent 2085c78 commit 1d7c40b
Show file tree
Hide file tree
Showing 7 changed files with 401 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/controller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ export class Controller<State extends Playerwise, Action> {
* @returns {Action}
* @memberof Controller
*/
getAction(state: State): Action {
return this.mcts_.getAction(state)
getAction(state: State, duration?: number): Action {
return this.mcts_.getAction(state, duration)
}
}
4 changes: 2 additions & 2 deletions src/macao.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ export class Macao<State extends Playerwise, Action> {
* @returns {Action}
* @memberof Macao
*/
getAction(state: State): Action {
return this.controller_.getAction(state)
getAction(state: State, duration?: number): Action {
return this.controller_.getAction(state, duration)
}

/**
Expand Down
17 changes: 14 additions & 3 deletions src/mcts.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
CalculateReward
} from './classes'
import { spliceRandom, loopFor } from './utils'
import { performance } from 'perf_hooks'

/**
*
Expand Down Expand Up @@ -356,7 +357,7 @@ export class DefaultUCB1<State, Action> implements UCB1<State, Action> {
* @template Action
*/
export interface MCTSFacade<State, Action> {
getAction: (state: State) => Action
getAction: (state: State, duration?: number) => Action
}

/**
Expand Down Expand Up @@ -403,13 +404,23 @@ export class DefaultMCTSFacade<State extends Playerwise, Action>
* @returns {Action}
* @memberof DefaultMCTSFacade
*/
getAction(state: State): Action {
getAction(state: State, duration?: number): Action {
const rootNode = this.createRootNode_(state)
loopFor(this.duration_).milliseconds(() => {
loopFor(duration || this.duration_).milliseconds(() => {
performance.mark('select start')
const node = this.select_.run(rootNode, this.explorationParam_)
performance.mark('select end')
performance.mark('simulate start')
const score = this.simulate_.run(node.mctsState.state)
performance.mark('simulate end')
performance.mark('backPropagate start')
this.backPropagate_.run(node, score)
performance.mark('backPropagate end')
performance.measure('select', 'select start', 'select end')
performance.measure('simulate', 'simulate start', 'simulate end')
performance.measure('backPropagate', 'backPropagate start', 'backPropagate end')
})

const bestChild = this.bestChild_.run(rootNode, 0)

return bestChild.action as Action
Expand Down
3 changes: 3 additions & 0 deletions src/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,6 @@ export const loopFor = (time: number) => {
}
}
}

/**
 * Computes the arithmetic mean of an array of numbers.
 *
 * The values are summed left to right starting from 0 and the total is then
 * divided by the element count, so an empty array yields `NaN` (0 / 0).
 *
 * @param arr The numbers to average.
 * @returns The mean of `arr`.
 */
export const arrAvg = (arr: number[]) => {
  let total = 0
  arr.forEach(value => {
    total += value
  })
  return total / arr.length
}
2 changes: 1 addition & 1 deletion test/tic-tac-toe/tic-tac-toe.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { MCTSState } from '../../src/classes'
import { TicTacToeState, TicTacToeMove, ticTacToeFuncs } from './tic-tac-toe'
import { DataStore } from '../../src/data-store'
import { loopFor } from '../../src/utils'
import Macao from '../../src/macao'
import { Macao } from '../../src/macao'

xdescribe('The DefaultMCTSFacade instance', () => {
let dataStore: DataGateway<string, MCTSState<TicTacToeState, TicTacToeMove>>
Expand Down
57 changes: 57 additions & 0 deletions test/ultimate-tic-tac-toe/ultimate-tic-tac-toe.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import { UTicTacToeState, UTicTacToeMove, uTicTacToeFuncs } from './ultimate-tic-tac-toe'
import { Macao } from '../../src/macao'
import { loopFor } from '../../src/utils'

// End-to-end check of Macao on Ultimate Tic Tac Toe: the engine plays both
// sides of 100 complete games and the outcomes are tallied.
// NOTE: the suite is declared with `xdescribe`, so Jest skips it unless it is
// switched back to `describe`.
xdescribe('The Macao instance', () => {
  let uTicTacToeBoard: number[][][][]
  let state: UTicTacToeState
  let mcts: Macao<UTicTacToeState, UTicTacToeMove>

  describe('when used to simulate 100 Ultimate Tic Tac Toe games', () => {
    describe('given 85 ms per turn and an exploration param of 1.414', () => {
      it('should end in a draw 95% of the time or better', () => {
        // Number of games whose final reward for player 1 was 0.
        let results = 0
        loopFor(100).turns(() => {
          // A fresh Macao instance is built for every game.
          mcts = new Macao(
            {
              stateIsTerminal: uTicTacToeFuncs.stateIsTerminal,
              generateActions: uTicTacToeFuncs.generateActions,
              applyAction: uTicTacToeFuncs.applyAction,
              calculateReward: uTicTacToeFuncs.calculateReward
            },
            { duration: 85 }
          )
          // Empty 3x3 grid of 3x3 boards; every cell starts at 0.
          uTicTacToeBoard = [
            [
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
            ],
            [
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
            ],
            [
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]],
              [[0, 0, 0], [0, 0, 0], [0, 0, 0]]
            ]
          ]
          // Initial state: all -1 coordinates stand in for "no previous action".
          state = {
            board: uTicTacToeBoard,
            player: -1,
            previousAction: { bigRow: -1, bigCol: -1, smallRow: -1, smallCol: -1 }
          }
          // Let the engine choose every move until the game is over.
          while (!uTicTacToeFuncs.stateIsTerminal(state)) {
            const action = mcts.getAction(state)
            state = uTicTacToeFuncs.applyAction(state, action)
          }

          // A reward of 0 is what this test counts as a draw.
          results += uTicTacToeFuncs.calculateReward(state, 1) === 0 ? 1 : 0
        })
        // "95% of the time or better" over 100 games means at least 95 draws,
        // so exactly 95 must pass as well.
        expect(results).toBeGreaterThanOrEqual(95)
      })
    })
  })
})
Loading

0 comments on commit 1d7c40b

Please sign in to comment.