Skip to content

Commit b6c6ed9

Browse files
committed
Add Dirichlet Noise
Copy paste from suragnair/alpha-zero-general#186
1 parent 865270f commit b6c6ed9

3 files changed

Lines changed: 25 additions & 9 deletions

File tree

Coach.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, game, nnet, args):
2424
self.nnet = nnet
2525
self.pnet = self.nnet.__class__(self.game, self.nnet.args) # the competitor network
2626
self.args = args
27-
self.mcts = MCTS(self.game, self.nnet, self.args)
27+
self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=(self.args.dirichletAlpha>0))
2828
self.trainExamplesHistory = [] # history of examples from args.numItersForTrainExamplesHistory latest iterations
2929
self.skipFirstSelfPlay = False # can be overriden in loadTrainExamples()
3030

@@ -86,7 +86,7 @@ def learn(self):
8686
iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)
8787

8888
for _ in tqdm(range(self.args.numEps), desc="Self Play", ncols=100):
89-
self.mcts = MCTS(self.game, self.nnet, self.args) # reset search tree
89+
self.mcts = MCTS(self.game, self.nnet, self.args, dirichlet_noise=(self.args.dirichletAlpha>0)) # reset search tree
9090
iterationTrainExamples += self.executeEpisode()
9191

9292
# save the iteration examples to the history

MCTS.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ class MCTS():
1313
This class handles the MCTS tree.
1414
"""
1515

16-
def __init__(self, game, nnet, args):
16+
def __init__(self, game, nnet, args, dirichlet_noise=False):
1717
self.game = game
1818
self.nnet = nnet
1919
self.args = args
20+
self.dirichlet_noise = dirichlet_noise
2021
self.Qsa = {} # stores Q values for s,a (as defined in the paper)
2122
self.Nsa = {} # stores #times edge s,a was visited
2223
self.Ns = {} # stores #times board s was visited
@@ -35,7 +36,8 @@ def getActionProb(self, canonicalBoard, temp=1):
3536
proportional to Nsa[(s,a)]**(1./temp)
3637
"""
3738
for i in range(self.args.numMCTSSims):
38-
self.search(canonicalBoard)
39+
dir_noise = (i == 0 and self.dirichlet_noise)
40+
self.search(canonicalBoard, dirichlet_noise=dir_noise)
3941

4042
s = self.game.stringRepresentation(canonicalBoard)
4143
counts = [self.Nsa[(s, a)] if (s, a) in self.Nsa else 0 for a in range(self.game.getActionSize())]
@@ -52,7 +54,7 @@ def getActionProb(self, canonicalBoard, temp=1):
5254
probs = [x / counts_sum for x in counts]
5355
return probs
5456

55-
def search(self, canonicalBoard):
57+
def search(self, canonicalBoard, dirichlet_noise=False):
5658
"""
5759
This function performs one iteration of MCTS. It is recursively called
5860
till a leaf node is found. The action chosen at each node is one that
@@ -81,10 +83,11 @@ def search(self, canonicalBoard):
8183
return -self.Es[s]
8284

8385
if s not in self.Ps:
84-
# leaf node
85-
self.Ps[s], v = self.nnet.predict(canonicalBoard)
8686
valids = self.game.getValidMoves(canonicalBoard, 1)
87-
self.Ps[s] = self.Ps[s] * valids # masking invalid moves
87+
# leaf node
88+
self.Ps[s], v = self.nnet.predict(canonicalBoard, valids)
89+
if dirichlet_noise:
90+
self.applyDirNoise(s, valids)
8891
sum_Ps_s = np.sum(self.Ps[s])
8992
if sum_Ps_s > 0:
9093
self.Ps[s] /= sum_Ps_s # renormalize
@@ -102,6 +105,10 @@ def search(self, canonicalBoard):
102105
return -v
103106

104107
valids = self.Vs[s]
108+
if dirichlet_noise:
109+
self.applyDirNoise(s, valids)
110+
sum_Ps_s = np.sum(self.Ps[s])
111+
self.Ps[s] /= sum_Ps_s # renormalize
105112
cur_best = -float('inf')
106113
best_act = -1
107114

@@ -134,3 +141,12 @@ def search(self, canonicalBoard):
134141

135142
self.Ns[s] += 1
136143
return -v
144+
145+
146+
def applyDirNoise(self, s, valids):
147+
dir_values = np.random.dirichlet([self.args.dirichletAlpha] * np.count_nonzero(valids))
148+
dir_idx = 0
149+
for idx in range(len(self.Ps[s])):
150+
if valids[idx]:
151+
self.Ps[s][idx] = (0.75 * self.Ps[s][idx]) + (0.25 * dir_values[dir_idx])
152+
dir_idx += 1

main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def main():
7373
parser.add_argument('--maxlenOfQueue' , '-q' , action='store', default=200000, type=int , help='Number of game examples to train the neural networks')
7474
parser.add_argument('--numMCTSSims' , '-m' , action='store', default=25 , type=int , help='Number of games moves for MCTS to simulate.')
7575
parser.add_argument('--cpuct' , '-c' , action='store', default=1.0 , type=float, help='')
76-
# parser.add_argument('--dirichletAlpha' , '-a' , action='store', default=0.1 , type=float, help='α=0.3 for chess, scaled in inverse proportion to the approximate number of legal moves in a typical position')
76+
parser.add_argument('--dirichletAlpha' , '-a' , action='store', default=0.1 , type=float, help='α=0.3 for chess, scaled in inverse proportion to the approximate number of legal moves in a typical position')
7777
parser.add_argument('--numItersForTrainExamplesHistory', '-n', action='store', default=5, type=int, help='')
7878

7979
parser.add_argument('--learn-rate' , '-l' , action='store', default=0.001, type=float, help='')

0 commit comments

Comments
 (0)