Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 20 additions & 15 deletions Coach.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import logging
import os
import sys
from collections import deque
from multiprocessing import Manager
from multiprocessing import Process
from pickle import Pickler, Unpickler
from random import shuffle

import numpy as np
from tqdm import tqdm

from Arena import Arena
from MCTS import MCTS
Expand All @@ -25,11 +25,10 @@ def __init__(self, game, nnet, args):
self.nnet = nnet
self.pnet = self.nnet.__class__(self.game) # the competitor network
self.args = args
self.mcts = MCTS(self.game, self.nnet, self.args)
self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations
self.skipFirstSelfPlay = False # can be overriden in loadTrainExamples()

def executeEpisode(self):
def executeEpisode(self, return_examples=None):
"""
This function executes one episode of self-play, starting with player 1.
As the game is played, each turn is added as a training example to
Expand All @@ -50,12 +49,13 @@ def executeEpisode(self):
self.curPlayer = 1
episodeStep = 0

mcts = MCTS(self.game, self.nnet, self.args)
while True:
episodeStep += 1
canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
temp = int(episodeStep < self.args.tempThreshold)

pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
pi = mcts.getActionProb(canonicalBoard, temp=temp)
sym = self.game.getSymmetries(canonicalBoard, pi)
for b, p in sym:
trainExamples.append([b, self.curPlayer, p, None])
Expand All @@ -66,7 +66,8 @@ def executeEpisode(self):
r = self.game.getGameEnded(board, self.curPlayer)

if r != 0:
return [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer))) for x in trainExamples]
return_examples += [(x[0], x[2], r * ((-1) ** (x[1] != self.curPlayer))) for x in trainExamples]
return

def learn(self):
"""
Expand All @@ -80,23 +81,27 @@ def learn(self):
for i in range(1, self.args.numIters + 1):
# bookkeeping
log.info(f'Starting Iter #{i} ...')

# examples of the iteration
if not self.skipFirstSelfPlay or i > 1:
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

for _ in tqdm(range(self.args.numEps), desc="Self Play"):
self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
iterationTrainExamples += self.executeEpisode()

# save the iteration examples to the history
self.trainExamplesHistory.append(iterationTrainExamples)
manager = Manager()
examples = manager.list()
processes = []
for _ in range(self.args.numEps):
p = Process(target=self.executeEpisode, args=(examples,))
p.start()
processes.append(p)
[p.join() for p in processes]

# save the iteration examples to the history
self.trainExamplesHistory.append(examples[-self.args.maxlenOfQueue:])

if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
log.warning(
f"Removing the oldest entry in trainExamples. len(trainExamplesHistory) = {len(self.trainExamplesHistory)}")
self.trainExamplesHistory.pop(0)
# backup history to a file
# NB! the examples were collected using the model from the previous iteration, so (i-1)
# NB! the examples were collected using the model from the previous iteration, so (i-1)
self.saveTrainExamples(i - 1)

# shuffle examples before training
Expand Down
2 changes: 1 addition & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ def update(self, val, n=1):

class dotdict(dict):
def __getattr__(self, name):
return self[name]
return self.get(name)