Skip to content

Commit e0c7392

Browse files
committed
better implement of tree search
1 parent 4504e17 commit e0c7392

2 files changed

Lines changed: 73 additions & 65 deletions

File tree

cppsrc/Search.cpp

Lines changed: 66 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,10 @@ void MCTS::addNoise(int cur, Val epsilon, Val alpha)
4646
sum += dirichlet[i];
4747
}
4848

49-
for (size_t i = 0; i < tr[cur].ch.size(); i++)
49+
for (int i = 0; i < BLSIZE; i++)
5050
{
51-
int ch = tr[cur].ch[i];
52-
tr[ch].policy = (1-epsilon) * tr[ch].policy + epsilon * dirichlet[i] / sum;
51+
int ch = (*tr[cur].ch)[i];
52+
if (ch) tr[ch].policy = (1-epsilon) * tr[ch].policy + epsilon * dirichlet[i] / sum;
5353
}
5454
}
5555

@@ -60,6 +60,7 @@ MCTS::MCTS(Board &_board, int _col, NN *_network, int _playouts):boardhash(_boar
6060
network = _network;
6161
playouts = _playouts;
6262
tr = new Node[(playouts+2)*BLSIZE];
63+
chlist = new Board[playouts + 2];
6364
Prior::setbyBoard(board);
6465
Prior::setPlayer(nowcol);
6566
starttime = clock();
@@ -87,7 +88,7 @@ void MCTS::unmake_move(int move)
8788

8889
void MCTS::createRoot()
8990
{
90-
trcnt = 1;
91+
trcnt = 1; chlistcnt = 0;
9192
stopflag = false;
9293
tr[root].cnt = 1;
9394
tr[root].fa = -1;
@@ -107,31 +108,38 @@ void MCTS::createRoot()
107108
addNoise(root, 0.25f, 0.1f);
108109
}
109110

110-
//LZ's selection, weaker than current
111+
//UCT-RAVE selection, use father hueristic
111112
int MCTS::selection(int cur)
112113
{
113-
if (tr[cur].ch.size() == 0)
114+
if (!tr[cur].ch)
114115
return cur;
115-
auto sum_policy = 0.0f;
116-
for (auto &ch : tr[cur].ch)
117-
if (tr[ch].cnt)
118-
sum_policy += tr[ch].policy;
119-
120-
auto numerator = sqrtf((Val)tr[cur].ch.size());
121-
auto fpu = 0.3f * sqrtf(sum_policy);
122-
int maxc = 0; Val maxv = -FLOAT_INF;
123-
124-
for (auto &ch : tr[cur].ch)
116+
int maxp = 0;
117+
Val maxv = -FLOAT_INF;
118+
for (int i=0;i<BLSIZE;i++)
125119
{
126-
auto winrate = tr[ch].cnt == 0 ? (-tr[cur].sumv / tr[cur].cnt) - fpu : tr[ch].sumv / tr[ch].cnt;
127-
auto ucb = winrate + UCBC * tr[ch].policy * numerator / (1.0f + tr[ch].cnt);
120+
int ch = (*tr[cur].ch)[i];
121+
if (!ch) continue;
122+
Val ucb;
123+
Val var_ele = UCBC*tr[ch].policy*sqrtf((Val)tr[cur].cnt) / (1 + tr[ch].cnt);
124+
125+
//rescale range
126+
Val father_val = (-tr[cur].sumv / tr[cur].cnt + 1.0f) / 1.1f - 1.0f;
127+
static const Val father_decay = 0.5f;
128+
Val frac1 = powf(father_decay, tr[ch].cnt);
129+
130+
if (tr[ch].is_end) frac1 = 0;
131+
132+
if (tr[ch].cnt == 0)
133+
ucb = father_val + var_ele;
134+
else
135+
ucb = frac1 * father_val + (1 - frac1)*tr[ch].sumv / tr[ch].cnt + var_ele;
128136
if (ucb > maxv)
129137
{
130138
maxv = ucb;
131-
maxc = ch;
139+
maxp = ch;
132140
}
133141
}
134-
return maxc;
142+
return maxp;
135143
}
136144

137145
bool MCTS::getTimeLimit(int played)
@@ -141,20 +149,24 @@ bool MCTS::getTimeLimit(int played)
141149
clock_t t1 = clock();
142150
if (timeout_turn && t1 - starttime >= timeout_turn - 500)
143151
return true;
152+
Val maxrate = 0;
153+
for (int i = 0; i < BLSIZE; i++)
154+
{
155+
int ch = (*tr[root].ch)[i];
156+
if (ch) maxrate = std::max(maxrate, (float)tr[ch].cnt / played);
157+
}
158+
144159
if (played > 600)
145-
for (auto it : tr[root].ch)
146-
if ((float)tr[it].cnt / played > 0.95)
160+
if (maxrate > 0.95)
147161
return true;
148162

149163
if (played > 1800)
150-
for (auto it : tr[root].ch)
151-
if ((float)tr[it].cnt / played > 0.90)
152-
return true;
164+
if (maxrate > 0.90)
165+
return true;
153166

154167
if (played > 4000)
155-
for (auto it : tr[root].ch)
156-
if ((float)tr[it].cnt / played > 0.85)
157-
return true;
168+
if (maxrate > 0.85)
169+
return true;
158170

159171
if (timeout_left)
160172
{
@@ -178,29 +190,14 @@ void MCTS::solve(BoardWeight &result)
178190
int cur = root;
179191
while (1)
180192
{
181-
Val maxv = -FLOAT_INF;
182-
//int maxp = selection(cur);
183-
184-
int maxp = cur;
185-
for (auto &ch : tr[cur].ch)
186-
{
187-
Val ucb;
188-
if (tr[ch].cnt == 0)
189-
ucb = (-tr[cur].sumv / tr[cur].cnt + 1.0f)/1.1f -1.0f + UCBC * tr[ch].policy*sqrtf((Val)tr[cur].cnt);
190-
else
191-
ucb = tr[ch].sumv / tr[ch].cnt + UCBC*tr[ch].policy*sqrtf((Val)tr[cur].cnt) / (1 + tr[ch].cnt);
192-
if (ucb > maxv)
193-
{
194-
maxv = ucb;
195-
maxp = ch;
196-
}
197-
}
193+
int maxp = selection(cur);
198194

199195
//leaf node
200196
if (maxp == cur) break;
201197
//forward search
202198
cur = maxp;
203199
make_move(tr[cur].move);
200+
//if (tr[cur].is_end) break;
204201
}
205202
//simulation & backpropagation
206203
simulation_back(cur);
@@ -214,9 +211,12 @@ void MCTS::solve(BoardWeight &result)
214211
{
215212
debug_s << board2showString(board, true);
216213
vector<std::pair<int, int>> pvlist;
217-
for (auto c : tr[root].ch)
218-
if (tr[c].cnt)
219-
pvlist.push_back({tr[c].cnt, c});
214+
for (int i = 0; i < BLSIZE; i++)
215+
{
216+
int ch = (*tr[root].ch)[i];
217+
if (ch && tr[ch].cnt)
218+
pvlist.push_back({ tr[ch].cnt, ch });
219+
}
220220
sort(pvlist.begin(), pvlist.end());
221221
for (int i = 0; i < std::min(10, (int)pvlist.size());i++)
222222
{
@@ -228,9 +228,10 @@ void MCTS::solve(BoardWeight &result)
228228
}
229229
logRefrsh();
230230
result.clear();
231-
for (auto ch : tr[root].ch)
231+
for (int i = 0; i < BLSIZE; i++)
232232
{
233-
result[tr[ch].move] = (Val)tr[ch].cnt;
233+
int ch = (*tr[root].ch)[i];
234+
if (ch) result[i] = (Val)tr[ch].cnt;
234235
//std::cout << tr[ch].move << ' ' << tr[ch].cnt << ' ' << tr[ch].sumv / tr[ch].cnt << '\n';
235236
}
236237
}
@@ -247,17 +248,20 @@ void MCTS::expand(int cur,RawOutput &output, Board &avail)
247248
{
248249
//board.debug();
249250
//std::cout<<"netwin:"<<output.v<<'\n';
250-
if (!tr[cur].ch.empty())
251-
tr[cur].ch.clear();
251+
tr[cur].ch = &chlist[chlistcnt];
252+
(*tr[cur].ch).clear();
253+
chlistcnt++;
252254
for (int i = 0; i < BLSIZE; i++)
253255
if (avail[i]) //for valid
254256
{
255-
tr[cur].ch.push_back(trcnt);
256-
tr[trcnt].sumv = 0;
257-
tr[trcnt].cnt = 0;
257+
(*tr[cur].ch)[i]=trcnt;
258+
tr[trcnt].sumv = tr[trcnt].sum_rave = 0.0f;
259+
tr[trcnt].ch = nullptr;
260+
tr[trcnt].cnt = tr[trcnt].cnt_rave = 0;
258261
tr[trcnt].policy = output.p[i];
259262
tr[trcnt].move = i;
260263
tr[trcnt].fa = cur;
264+
tr[trcnt].is_end = false;
261265
trcnt++;
262266
}
263267
}
@@ -270,24 +274,25 @@ void MCTS::simulation_back(int cur)
270274
auto result = getEvaluation(board, nowcol, network, use_transform, tr[cur].move);
271275
val = -result.first.v;
272276
expand(cur, result.first, result.second);
273-
auto &it = hash_table.find(boardhash());
274-
if (it != hash_table.end())
275-
counter++;
276-
else
277-
hash_table[boardhash()] = cur;
277+
if (val<-0.99 || val>0.99) tr[cur].is_end = true;
278+
//auto &it = hash_table.find(boardhash());
279+
//if (it != hash_table.end()) counter++;
280+
//else hash_table[boardhash()] = cur;
278281
}
279282
else
280283
{
281284
val = 1;
282285
if (board.count() == BLSIZE)
283286
val = 0;
287+
tr[cur].is_end = true;
284288
}
285289
if (cfg_swap3 && board.count() == 3 && board.countv(2) == 1 && tr[cur].fa>0) //if swap, player choice max rate point
286290
val = -fabs(val);
287291

292+
backprop:
288293
tr[cur].sumv += val;
289294
tr[cur].cnt++;
290-
//back
295+
291296
while (cur > 0)
292297
{
293298
unmake_move(tr[cur].move);

cppsrc/Search.h

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,11 @@ class MCTS
2626
NN* network;
2727
struct Node
2828
{
29-
Val sumv,policy;
30-
int cnt; int fa;
31-
vector<int> ch;
29+
Val sumv,policy, sum_rave;
30+
int cnt,cnt_rave; int fa;
3231
int move;
32+
bool is_end;
33+
Board *ch;
3334
void print()
3435
{
3536
//fout << "rate:" << score / cnt << " cnt:" << cnt << " " << Coord(p) << " ch:";
@@ -38,10 +39,11 @@ class MCTS
3839
}
3940
};
4041
Node *tr;
42+
Board *chlist;
4143
std::map<unsigned long long, int> hash_table;
4244
BoardHasher boardhash;
4345
const int root = 0;
44-
int trcnt = 1, viscnt = 0;
46+
int trcnt, chlistcnt, viscnt = 0;
4547
Board board;
4648
int nowcol;
4749
int counter = 0;
@@ -64,6 +66,7 @@ class MCTS
6466
~MCTS()
6567
{
6668
delete[] tr;
69+
delete[] chlist;
6770
}
6871
};
6972

0 commit comments

Comments
 (0)