Skip to content

Commit 000b055

Browse files
committed
Reset threads before search + exit search early
1 parent 452f58c commit 000b055

File tree

6 files changed

+70
-23
lines changed

6 files changed

+70
-23
lines changed

.github/workflows/release-pipeline.yaml

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ jobs:
2525
working-directory: src/Sapling
2626
id: get_version
2727
run: |
28-
VERSION=1.2.4
28+
VERSION=1.2.5
2929
echo "Application version: $VERSION"
3030
echo "::set-output name=version::$VERSION"
3131

src/Sapling.Engine/DataGen/Bench.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,8 @@ public static unsafe void Run(int depth = 14)
9999
Console.WriteLine(fen);
100100
gameState.ResetToFen(fen);
101101
stopwatch.Restart();
102-
var result = searcher.Search(gameState, depthLimit: depth, writeInfo: true);
102+
searcher.Reset(gameState);
103+
var result = searcher.Search(gameState, searcher.Stop, depthLimit: depth, threadId: 0);
103104

104105
totalTime += stopwatch.ElapsedMilliseconds;
105106
totalNodes += result.nodes;

src/Sapling.Engine/DataGen/DataGenerator.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,8 @@ private unsafe void RunWorker(DataGeneratorStats stats)
179179
}
180180
else
181181
{
182-
var (pv, _, s, _) = searcher.Search(gameState, nodeLimit: 6500, depthLimit: 60, writeInfo: false);
182+
searcher.Reset(gameState);
183+
var (pv, _, s, _) = searcher.Search(gameState, searcher.Stop, nodeLimit: 6500, depthLimit: 60);
183184
move = pv[0];
184185
score = s;
185186

src/Sapling.Engine/Search/ParallelSearcher.cs

+42-8
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,20 @@ public void Stop()
4646
}
4747
}
4848

49+
public void CancelSearch()
50+
{
51+
foreach (var searcher in Searchers)
52+
{
53+
searcher.Stop();
54+
}
55+
}
56+
4957
public (List<uint> pv, int depthSearched, int score, long nodes, TimeSpan duration) NodeBoundSearch(
5058
GameState state, int nodeLimit = 0, int maxDepth = 0)
5159
{
5260
var start = DateTime.Now;
53-
var searchResult = Searchers[0].Search(state, nodeLimit: nodeLimit, depthLimit: maxDepth);
61+
Searchers[0].Reset(state);
62+
var searchResult = Searchers[0].Search(state, CancelSearch, nodeLimit: nodeLimit, depthLimit: maxDepth);
5463
return (searchResult.pv, searchResult.depthSearched, searchResult.score,
5564
searchResult.nodes, DateTime.Now - start);
5665
}
@@ -69,7 +78,6 @@ public void Stop()
6978
// Prevent a previous searches timeout cancelling a new search
7079
return;
7180
}
72-
7381
// Stop all searchers once think time has been reached
7482
foreach (var searcher in Searchers)
7583
{
@@ -79,10 +87,12 @@ public void Stop()
7987
}
8088

8189
var start = DateTime.Now;
90+
DateTime? end = thinkTime > 0 ? DateTime.Now.AddMilliseconds(thinkTime) : null;
8291

8392
if (Searchers.Count == 1)
8493
{
85-
var searchResult = Searchers[0].Search(state, writeInfo: true);
94+
Searchers[0].Reset(state);
95+
var searchResult = Searchers[0].Search(state, CancelSearch, timeLimit: end, threadId: 0);
8696
return (searchResult.pv, searchResult.depthSearched, searchResult.score,
8797
searchResult.nodes, DateTime.Now - start);
8898
}
@@ -92,10 +102,28 @@ public void Stop()
92102
new ThreadLocal<(List<uint> move, int depthSearched, int score, long nodes)>(
93103
() => (new List<uint>(), 0, int.MinValue, 0), true);
94104

105+
foreach (var searcher in Searchers)
106+
{
107+
searcher.Reset(state);
108+
}
95109

96110
// Parallel search, with thread-local best move
97-
Parallel.For(0, Searchers.Count,
98-
i => { results.Value = Searchers[i].Search(state, searchers: Searchers, writeInfo: i == 0); });
111+
var threads = new Thread[Searchers.Count];
112+
for (int i = 0; i < Searchers.Count; i++)
113+
{
114+
int threadId = i;
115+
threads[i] = new Thread(() =>
116+
{
117+
results.Value = Searchers[threadId].Search(state, CancelSearch, timeLimit: end, threadId: threadId);
118+
});
119+
threads[i].Start();
120+
}
121+
122+
// Wait for all to complete
123+
foreach (var thread in threads)
124+
{
125+
thread.Join();
126+
}
99127

100128
var dt = DateTime.Now - start;
101129

@@ -120,7 +148,8 @@ public void Stop()
120148

121149
if (resultList.Count == 0)
122150
{
123-
var searchResult = Searchers[0].Search(state, depthLimit:0, writeInfo: true);
151+
Searchers[0].Reset(state);
152+
var searchResult = Searchers[0].Search(state, CancelSearch, depthLimit:0, threadId: 0);
124153
return (searchResult.pv, searchResult.depthSearched, searchResult.score,
125154
searchResult.nodes, DateTime.Now - start);
126155
}
@@ -161,17 +190,22 @@ public void Stop()
161190
var searchId = Guid.NewGuid();
162191
_prevSearchId = searchId;
163192

164-
165193
// Thread-local storage for best move in each thread
166194
var results =
167195
new ThreadLocal<(List<uint> move, int depthSearched, int score, long nodes)>(
168196
() => (new List<uint>(), 0, int.MinValue, 0), true);
169197

170198
var start = DateTime.Now;
171199

200+
201+
foreach (var searcher in Searchers)
202+
{
203+
searcher.Reset(state);
204+
}
205+
172206
// Parallel search, with thread-local best move
173207
Parallel.For(0, Searchers.Count,
174-
i => { results.Value = Searchers[i].Search(state, depthLimit: depth, writeInfo: i == 0); });
208+
i => { results.Value = Searchers[i].Search(state, CancelSearch, depthLimit: depth, threadId: i); });
175209
var dt = DateTime.Now - start;
176210

177211
Span<int> voteMap = stackalloc int[64 * 64];

src/Sapling.Engine/Search/Searcher.cs

+21-10
Original file line numberDiff line numberDiff line change
@@ -208,15 +208,12 @@ public static int GetAsperationWindow(int index)
208208
};
209209
}
210210

211-
public (List<uint> pv, int depthSearched, int score, long nodes) Search(GameState inputBoard, List<Searcher>? searchers = null, int nodeLimit = 0,
212-
int depthLimit = 0, bool writeInfo = false)
211+
public void Reset(GameState inputBoard)
213212
{
213+
_searchCancelled = false;
214214
NodesVisited = 0;
215215
BestSoFar = 0;
216216

217-
var depthSearched = 0;
218-
_searchCancelled = false;
219-
220217
NativeMemory.Clear(History, (nuint)HistoryLength * sizeof(int));
221218
NativeMemory.Clear(Counters, (nuint)CountersLength * sizeof(uint));
222219
NativeMemory.Clear(killers, (nuint)KillersLength * sizeof(uint));
@@ -230,6 +227,12 @@ public static int GetAsperationWindow(int index)
230227
NativeMemory.Clear(BucketCacheBlackBoards, (nuint)sizeof(BoardStateData) * NnueWeights.InputBuckets * 2);
231228

232229
Unsafe.CopyBlock(HashHistory, inputBoard.HashHistory, sizeof(ulong) * (uint)inputBoard.Board.TurnCount);
230+
}
231+
232+
public (List<uint> pv, int depthSearched, int score, long nodes) Search(GameState inputBoard, Action cancellSearch, List<Searcher>? searchers = null, int nodeLimit = 0,
233+
int depthLimit = 0, int threadId = -1, DateTime? timeLimit = null)
234+
{
235+
var depthSearched = 0;
233236

234237
var alpha = Constants.MinScore;
235238
var beta = Constants.MaxScore;
@@ -252,6 +255,8 @@ public static int GetAsperationWindow(int index)
252255
var startTime = DateTime.Now;
253256
for (var j = 1; j < maxDepth; j++)
254257
{
258+
var iterationStart = DateTime.Now;
259+
255260
if (_searchCancelled || (nodeLimit > 0 && NodesVisited > nodeLimit))
256261
{
257262
break;
@@ -297,7 +302,7 @@ public static int GetAsperationWindow(int index)
297302
depthSearched = j;
298303
bestEval = lastIterationEval;
299304

300-
if (writeInfo)
305+
if (threadId == 0)
301306
{
302307
var nodes = searchers?.Sum(s =>s.NodesVisited) ?? NodesVisited;
303308

@@ -316,15 +321,21 @@ public static int GetAsperationWindow(int index)
316321
}
317322

318323
Console.WriteLine(
319-
$"info depth {depthSearched} score {ScoreToString(bestEval)} nodes {nodes} nps {nps} time {(int)dt.TotalMilliseconds} pv{sb}");
324+
$"info depth {depthSearched} score {ScoreToString(bestEval)} nodes {nodes} nps {nps} time {(int)dt.TotalMilliseconds} pv{sb}");
325+
326+
var iterationDuration = DateTime.Now - iterationStart;
327+
var timeRemaining = timeLimit - DateTime.Now;
328+
if (timeRemaining.HasValue && timeRemaining.Value.TotalMilliseconds < iterationDuration.TotalMilliseconds * 1.5)
329+
{
330+
cancellSearch();
331+
break;
332+
}
320333
}
321334

322335
if (_searchCancelled || (nodeLimit > 0 && NodesVisited > nodeLimit))
323-
{
324336
break;
325-
}
326337
}
327-
338+
328339
return (GetPvMoveList(pvMoves), depthSearched, bestEval, NodesVisited);
329340
}
330341

src/Sapling/UciEngine.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ private void SetOption(string[] tokens)
6161
case "threads":
6262
if (tokens[3] == "value" && int.TryParse(tokens[4], out var searchThreads))
6363
{
64-
_threadCount = searchThreads;
65-
_parallelSearcher.SetThreads(searchThreads);
64+
_threadCount = Math.Clamp(searchThreads, 1, Environment.ProcessorCount);
65+
_parallelSearcher.SetThreads(_threadCount);
6666
LogToFile($"[Debug] Set Threads '{searchThreads}'");
6767
}
6868

0 commit comments

Comments
 (0)