@@ -61,6 +61,8 @@ MCTS::MCTS(Board &_board, int _col, NN *_network, int _playouts):boardhash(_boar
6161 playouts = _playouts;
6262 tr = new Node[(playouts+2 )*BLSIZE];
6363 chlist = new Board[playouts + 2 ];
64+ ravelist = new BoardWeight[playouts + 2 ];
65+ raveclist = new Board[playouts + 2 ];
6466 Prior::setbyBoard (board);
6567 Prior::setPlayer (nowcol);
6668 starttime = clock ();
@@ -126,13 +128,16 @@ int MCTS::selection(int cur)
126128 Val father_val = (-tr[cur].sumv / tr[cur].cnt + 1 .0f ) / 1 .1f - 1 .0f ;
127129 static const Val father_decay = 0 .5f ;
128130 Val frac1 = powf (father_decay, tr[ch].cnt );
129-
130- if (tr[ch].is_end ) frac1 = 0 ;
131+ Val rave_cnt = (Val)(*tr[cur].cnt_rave )[i];
132+ Val rave_win = (*tr[cur].sum_rave )[i] / rave_cnt;
133+ // Val rave_beta = rave_cnt /(rave_cnt + tr[ch].cnt + 2*rave_cnt*tr[ch].cnt);
134+ Val rave_beta = 0 .0f ;
135+ if (tr[ch].is_end ) frac1 = 0 , rave_beta = 0 ;
131136
132137 if (tr[ch].cnt == 0 )
133138 ucb = father_val + var_ele;
134139 else
135- ucb = frac1 * father_val + (1 - frac1)*tr[ch].sumv / tr[ch].cnt + var_ele;
140+ ucb = rave_beta * rave_win+( 1 -rave_beta)*( frac1 * father_val + (1 - frac1)*tr[ch].sumv / tr[ch].cnt ) + var_ele;
136141 if (ucb > maxv)
137142 {
138143 maxv = ucb;
@@ -249,15 +254,21 @@ void MCTS::expand(int cur,RawOutput &output, Board &avail)
249254 // board.debug();
250255 // std::cout<<"netwin:"<<output.v<<'\n';
251256 tr[cur].ch = &chlist[chlistcnt];
257+ tr[cur].sum_rave = &ravelist[chlistcnt];
258+ tr[cur].cnt_rave = &raveclist[chlistcnt];
252259 (*tr[cur].ch ).clear ();
260+ (*tr[cur].sum_rave ).clear ();
261+ (*tr[cur].cnt_rave ).clear ();
253262 chlistcnt++;
254263 for (int i = 0 ; i < BLSIZE; i++)
255264 if (avail[i]) // for valid
256265 {
257266 (*tr[cur].ch )[i]=trcnt;
258- tr[trcnt].sumv = tr[trcnt]. sum_rave = 0 .0f ;
267+ tr[trcnt].sumv = 0 .0f ;
259268 tr[trcnt].ch = nullptr ;
260- tr[trcnt].cnt = tr[trcnt].cnt_rave = 0 ;
269+ tr[trcnt].sum_rave = nullptr ;
270+ tr[trcnt].cnt_rave = nullptr ;
271+ tr[trcnt].cnt = 0 ;
261272 tr[trcnt].policy = output.p [i];
262273 tr[trcnt].move = i;
263274 tr[trcnt].fa = cur;
@@ -290,6 +301,16 @@ void MCTS::simulation_back(int cur)
290301 val = -fabs (val);
291302
292303backprop:
304+ int tcur = cur;
305+ int move = tr[cur].move ;
306+ while (tcur > 0 )
307+ {
308+ tcur = tr[tcur].fa ;
309+ (*tr[tcur].sum_rave )[move] += val;
310+ (*tr[tcur].cnt_rave )[move] ++;
311+ tcur = tr[tcur].fa ;
312+ }
313+
293314 tr[cur].sumv += val;
294315 tr[cur].cnt ++;
295316
0 commit comments