Skip to content

Commit f414915

Browse files
committed
modify search
add UCT_RAVE, but disabled
1 parent e0c7392 commit f414915

3 files changed

Lines changed: 35 additions & 9 deletions

File tree

cppsrc/Board.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,7 @@ int BoardArray<int>::countv(int col) const
4545
template <>
4646
void BoardArray<float>::clear()
4747
{
48-
for (int i = 0; i < BLSIZE; i++)
49-
m[i] = 0;
48+
memset(m, 0, sizeof(m));
5049
}
5150

5251
template <>

cppsrc/Search.cpp

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ MCTS::MCTS(Board &_board, int _col, NN *_network, int _playouts):boardhash(_boar
6161
playouts = _playouts;
6262
tr = new Node[(playouts+2)*BLSIZE];
6363
chlist = new Board[playouts + 2];
64+
ravelist = new BoardWeight[playouts + 2];
65+
raveclist = new Board[playouts + 2];
6466
Prior::setbyBoard(board);
6567
Prior::setPlayer(nowcol);
6668
starttime = clock();
@@ -126,13 +128,16 @@ int MCTS::selection(int cur)
126128
Val father_val = (-tr[cur].sumv / tr[cur].cnt + 1.0f) / 1.1f - 1.0f;
127129
static const Val father_decay = 0.5f;
128130
Val frac1 = powf(father_decay, tr[ch].cnt);
129-
130-
if (tr[ch].is_end) frac1 = 0;
131+
Val rave_cnt = (Val)(*tr[cur].cnt_rave)[i];
132+
Val rave_win = (*tr[cur].sum_rave)[i] / rave_cnt;
133+
//Val rave_beta = rave_cnt /(rave_cnt + tr[ch].cnt + 2*rave_cnt*tr[ch].cnt);
134+
Val rave_beta = 0.0f;
135+
if (tr[ch].is_end) frac1 = 0, rave_beta = 0;
131136

132137
if (tr[ch].cnt == 0)
133138
ucb = father_val + var_ele;
134139
else
135-
ucb = frac1 * father_val + (1 - frac1)*tr[ch].sumv / tr[ch].cnt + var_ele;
140+
ucb = rave_beta * rave_win+(1-rave_beta)*(frac1 * father_val + (1 - frac1)*tr[ch].sumv / tr[ch].cnt) + var_ele;
136141
if (ucb > maxv)
137142
{
138143
maxv = ucb;
@@ -249,15 +254,21 @@ void MCTS::expand(int cur,RawOutput &output, Board &avail)
249254
//board.debug();
250255
//std::cout<<"netwin:"<<output.v<<'\n';
251256
tr[cur].ch = &chlist[chlistcnt];
257+
tr[cur].sum_rave = &ravelist[chlistcnt];
258+
tr[cur].cnt_rave = &raveclist[chlistcnt];
252259
(*tr[cur].ch).clear();
260+
(*tr[cur].sum_rave).clear();
261+
(*tr[cur].cnt_rave).clear();
253262
chlistcnt++;
254263
for (int i = 0; i < BLSIZE; i++)
255264
if (avail[i]) //for valid
256265
{
257266
(*tr[cur].ch)[i]=trcnt;
258-
tr[trcnt].sumv = tr[trcnt].sum_rave = 0.0f;
267+
tr[trcnt].sumv = 0.0f;
259268
tr[trcnt].ch = nullptr;
260-
tr[trcnt].cnt = tr[trcnt].cnt_rave = 0;
269+
tr[trcnt].sum_rave = nullptr;
270+
tr[trcnt].cnt_rave = nullptr;
271+
tr[trcnt].cnt = 0;
261272
tr[trcnt].policy = output.p[i];
262273
tr[trcnt].move = i;
263274
tr[trcnt].fa = cur;
@@ -290,6 +301,16 @@ void MCTS::simulation_back(int cur)
290301
val = -fabs(val);
291302

292303
backprop:
304+
int tcur = cur;
305+
int move = tr[cur].move;
306+
while (tcur > 0)
307+
{
308+
tcur = tr[tcur].fa;
309+
(*tr[tcur].sum_rave)[move] += val;
310+
(*tr[tcur].cnt_rave)[move] ++;
311+
tcur = tr[tcur].fa;
312+
}
313+
293314
tr[cur].sumv += val;
294315
tr[cur].cnt++;
295316

cppsrc/Search.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ class MCTS
2626
NN* network;
2727
struct Node
2828
{
29-
Val sumv,policy, sum_rave;
30-
int cnt,cnt_rave; int fa;
29+
Val sumv,policy;
30+
int cnt; int fa;
3131
int move;
3232
bool is_end;
3333
Board *ch;
34+
Board *cnt_rave;
35+
BoardWeight *sum_rave;
3436
void print()
3537
{
3638
//fout << "rate:" << score / cnt << " cnt:" << cnt << " " << Coord(p) << " ch:";
@@ -40,6 +42,8 @@ class MCTS
4042
};
4143
Node *tr;
4244
Board *chlist;
45+
Board *raveclist;
46+
BoardWeight *ravelist;
4347
std::map<unsigned long long, int> hash_table;
4448
BoardHasher boardhash;
4549
const int root = 0;
@@ -67,6 +71,8 @@ class MCTS
6771
{
6872
delete[] tr;
6973
delete[] chlist;
74+
delete[] ravelist;
75+
delete[] raveclist;
7076
}
7177
};
7278

0 commit comments

Comments
 (0)