@@ -46,10 +46,10 @@ void MCTS::addNoise(int cur, Val epsilon, Val alpha)
4646 sum += dirichlet[i];
4747 }
4848
49- for (size_t i = 0 ; i < tr[cur]. ch . size () ; i++)
49+ for (int i = 0 ; i < BLSIZE ; i++)
5050 {
51- int ch = tr[cur].ch [i];
52- tr[ch].policy = (1 -epsilon) * tr[ch].policy + epsilon * dirichlet[i] / sum;
51+ int ch = (* tr[cur].ch ) [i];
52+ if (ch) tr[ch].policy = (1 -epsilon) * tr[ch].policy + epsilon * dirichlet[i] / sum;
5353 }
5454}
5555
@@ -60,6 +60,7 @@ MCTS::MCTS(Board &_board, int _col, NN *_network, int _playouts):boardhash(_boar
6060 network = _network;
6161 playouts = _playouts;
6262 tr = new Node[(playouts+2 )*BLSIZE];
63+ chlist = new Board[playouts + 2 ];
6364 Prior::setbyBoard (board);
6465 Prior::setPlayer (nowcol);
6566 starttime = clock ();
@@ -87,7 +88,7 @@ void MCTS::unmake_move(int move)
8788
8889void MCTS::createRoot ()
8990{
90- trcnt = 1 ;
91+ trcnt = 1 ; chlistcnt = 0 ;
9192 stopflag = false ;
9293 tr[root].cnt = 1 ;
9394 tr[root].fa = -1 ;
@@ -107,31 +108,38 @@ void MCTS::createRoot()
107108 addNoise (root, 0 .25f , 0 .1f );
108109}
109110
110- // LZ's selection, weaker than current
111+ // UCT-RAVE selection, use father heuristic
111112int MCTS::selection (int cur)
112113{
113- if (tr[cur].ch . size () == 0 )
114+ if (! tr[cur].ch )
114115 return cur;
115- auto sum_policy = 0 .0f ;
116- for (auto &ch : tr[cur].ch )
117- if (tr[ch].cnt )
118- sum_policy += tr[ch].policy ;
119-
120- auto numerator = sqrtf ((Val)tr[cur].ch .size ());
121- auto fpu = 0 .3f * sqrtf (sum_policy);
122- int maxc = 0 ; Val maxv = -FLOAT_INF;
123-
124- for (auto &ch : tr[cur].ch )
116+ int maxp = 0 ;
117+ Val maxv = -FLOAT_INF;
118+ for (int i=0 ;i<BLSIZE;i++)
125119 {
126- auto winrate = tr[ch].cnt == 0 ? (-tr[cur].sumv / tr[cur].cnt ) - fpu : tr[ch].sumv / tr[ch].cnt ;
127- auto ucb = winrate + UCBC * tr[ch].policy * numerator / (1 .0f + tr[ch].cnt );
120+ int ch = (*tr[cur].ch )[i];
121+ if (!ch) continue ;
122+ Val ucb;
123+ Val var_ele = UCBC*tr[ch].policy *sqrtf ((Val)tr[cur].cnt ) / (1 + tr[ch].cnt );
124+
125+ // rescale range
126+ Val father_val = (-tr[cur].sumv / tr[cur].cnt + 1 .0f ) / 1 .1f - 1 .0f ;
127+ static const Val father_decay = 0 .5f ;
128+ Val frac1 = powf (father_decay, tr[ch].cnt );
129+
130+ if (tr[ch].is_end ) frac1 = 0 ;
131+
132+ if (tr[ch].cnt == 0 )
133+ ucb = father_val + var_ele;
134+ else
135+ ucb = frac1 * father_val + (1 - frac1)*tr[ch].sumv / tr[ch].cnt + var_ele;
128136 if (ucb > maxv)
129137 {
130138 maxv = ucb;
131- maxc = ch;
139+ maxp = ch;
132140 }
133141 }
134- return maxc ;
142+ return maxp ;
135143}
136144
137145bool MCTS::getTimeLimit (int played)
@@ -141,20 +149,24 @@ bool MCTS::getTimeLimit(int played)
141149 clock_t t1 = clock ();
142150 if (timeout_turn && t1 - starttime >= timeout_turn - 500 )
143151 return true ;
152+ Val maxrate = 0 ;
153+ for (int i = 0 ; i < BLSIZE; i++)
154+ {
155+ int ch = (*tr[root].ch )[i];
156+ if (ch) maxrate = std::max (maxrate, (float )tr[ch].cnt / played);
157+ }
158+
144159 if (played > 600 )
145- for (auto it : tr[root].ch )
146- if ((float )tr[it].cnt / played > 0.95 )
160+ if (maxrate > 0.95 )
147161 return true ;
148162
149163 if (played > 1800 )
150- for (auto it : tr[root].ch )
151- if ((float )tr[it].cnt / played > 0.90 )
152- return true ;
164+ if (maxrate > 0.90 )
165+ return true ;
153166
154167 if (played > 4000 )
155- for (auto it : tr[root].ch )
156- if ((float )tr[it].cnt / played > 0.85 )
157- return true ;
168+ if (maxrate > 0.85 )
169+ return true ;
158170
159171 if (timeout_left)
160172 {
@@ -178,29 +190,14 @@ void MCTS::solve(BoardWeight &result)
178190 int cur = root;
179191 while (1 )
180192 {
181- Val maxv = -FLOAT_INF;
182- // int maxp = selection(cur);
183-
184- int maxp = cur;
185- for (auto &ch : tr[cur].ch )
186- {
187- Val ucb;
188- if (tr[ch].cnt == 0 )
189- ucb = (-tr[cur].sumv / tr[cur].cnt + 1 .0f )/1 .1f -1 .0f + UCBC * tr[ch].policy *sqrtf ((Val)tr[cur].cnt );
190- else
191- ucb = tr[ch].sumv / tr[ch].cnt + UCBC*tr[ch].policy *sqrtf ((Val)tr[cur].cnt ) / (1 + tr[ch].cnt );
192- if (ucb > maxv)
193- {
194- maxv = ucb;
195- maxp = ch;
196- }
197- }
193+ int maxp = selection (cur);
198194
199195 // leaf node
200196 if (maxp == cur) break ;
201197 // forward search
202198 cur = maxp;
203199 make_move (tr[cur].move );
200+ // if (tr[cur].is_end) break;
204201 }
205202 // simulation & backpropagation
206203 simulation_back (cur);
@@ -214,9 +211,12 @@ void MCTS::solve(BoardWeight &result)
214211 {
215212 debug_s << board2showString (board, true );
216213 vector<std::pair<int , int >> pvlist;
217- for (auto c : tr[root].ch )
218- if (tr[c].cnt )
219- pvlist.push_back ({tr[c].cnt , c});
214+ for (int i = 0 ; i < BLSIZE; i++)
215+ {
216+ int ch = (*tr[root].ch )[i];
217+ if (ch && tr[ch].cnt )
218+ pvlist.push_back ({ tr[ch].cnt , ch });
219+ }
220220 sort (pvlist.begin (), pvlist.end ());
221221 for (int i = 0 ; i < std::min (10 , (int )pvlist.size ());i++)
222222 {
@@ -228,9 +228,10 @@ void MCTS::solve(BoardWeight &result)
228228 }
229229 logRefrsh ();
230230 result.clear ();
231- for (auto ch : tr[root]. ch )
231+ for (int i = 0 ; i < BLSIZE; i++ )
232232 {
233- result[tr[ch].move ] = (Val)tr[ch].cnt ;
233+ int ch = (*tr[root].ch )[i];
234+ if (ch) result[i] = (Val)tr[ch].cnt ;
234235 // std::cout << tr[ch].move << ' ' << tr[ch].cnt << ' ' << tr[ch].sumv / tr[ch].cnt << '\n';
235236 }
236237}
@@ -247,17 +248,20 @@ void MCTS::expand(int cur,RawOutput &output, Board &avail)
247248{
248249 // board.debug();
249250 // std::cout<<"netwin:"<<output.v<<'\n';
250- if (!tr[cur].ch .empty ())
251- tr[cur].ch .clear ();
251+ tr[cur].ch = &chlist[chlistcnt];
252+ (*tr[cur].ch ).clear ();
253+ chlistcnt++;
252254 for (int i = 0 ; i < BLSIZE; i++)
253255 if (avail[i]) // for valid
254256 {
255- tr[cur].ch .push_back (trcnt);
256- tr[trcnt].sumv = 0 ;
257- tr[trcnt].cnt = 0 ;
257+ (*tr[cur].ch )[i]=trcnt;
258+ tr[trcnt].sumv = tr[trcnt].sum_rave = 0 .0f ;
259+ tr[trcnt].ch = nullptr ;
260+ tr[trcnt].cnt = tr[trcnt].cnt_rave = 0 ;
258261 tr[trcnt].policy = output.p [i];
259262 tr[trcnt].move = i;
260263 tr[trcnt].fa = cur;
264+ tr[trcnt].is_end = false ;
261265 trcnt++;
262266 }
263267}
@@ -270,24 +274,25 @@ void MCTS::simulation_back(int cur)
270274 auto result = getEvaluation (board, nowcol, network, use_transform, tr[cur].move );
271275 val = -result.first .v ;
272276 expand (cur, result.first , result.second );
273- auto &it = hash_table.find (boardhash ());
274- if (it != hash_table.end ())
275- counter++;
276- else
277- hash_table[boardhash ()] = cur;
277+ if (val<-0.99 || val>0.99 ) tr[cur].is_end = true ;
278+ // auto &it = hash_table.find(boardhash());
279+ // if (it != hash_table.end()) counter++;
280+ // else hash_table[boardhash()] = cur;
278281 }
279282 else
280283 {
281284 val = 1 ;
282285 if (board.count () == BLSIZE)
283286 val = 0 ;
287+ tr[cur].is_end = true ;
284288 }
285289 if (cfg_swap3 && board.count () == 3 && board.countv (2 ) == 1 && tr[cur].fa >0 ) // if swap, player choice max rate point
286290 val = -fabs (val);
287291
292+ backprop:
288293 tr[cur].sumv += val;
289294 tr[cur].cnt ++;
290- // back
295+
291296 while (cur > 0 )
292297 {
293298 unmake_move (tr[cur].move );
0 commit comments