Skip to content

Commit a2e609c

Browse files
committed
11
1 parent a3c1935 commit a2e609c

File tree

4 files changed

+306
-10
lines changed

4 files changed

+306
-10
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1442,6 +1442,8 @@ events.*
14421442
!/lzero/mcts/**/lib/*.cpp
14431443
!/lzero/mcts/**/lib/*.hpp
14441444
!/lzero/mcts/**/lib/*.h
1445+
!/lzero/mcts/ctree/ctree_alphazero/*.cpp
1446+
!/lzero/mcts/ctree/ctree_alphazero/*.h
14451447
**/tb/*
14461448
**/mcts/ctree/tests_cpp/*
14471449
**/*tmp*

lzero/mcts/ctree/ctree_alphazero/CMakeLists.txt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,8 @@ project(mcts_alphazero VERSION 1.0)
99
# This is required for embedding Python in the project
1010
find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
1111

12-
# Add pybind11 as a subdirectory,
13-
# so that its build files are generated alongside the current project.
14-
# This is necessary because the current project depends on pybind11
15-
add_subdirectory(pybind11)
12+
# Find pybind11 package installed via pip
13+
find_package(pybind11 CONFIG REQUIRED)
1614

1715
# Add two .cpp files to the mcts_alphazero module
1816
# These files are compiled and linked into the module

lzero/mcts/ctree/ctree_alphazero/make.sh

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,11 @@ mkdir -p build
6363
# Navigate into the "build" directory
6464
cd build || exit
6565

66-
# Run cmake on the parent directory with the specified architecture
67-
cmake .. -DCMAKE_OSX_ARCHITECTURES="arm64"
66+
# Get pybind11 cmake directory
67+
PYBIND11_CMAKE_DIR=$(python -c "import pybind11; print(pybind11.get_cmake_dir())")
68+
69+
# Run cmake on the parent directory with the specified architecture and pybind11 path
70+
cmake .. -DCMAKE_OSX_ARCHITECTURES="arm64" -DCMAKE_PREFIX_PATH="$PYBIND11_CMAKE_DIR"
6871

6972
# Run the "make" command to compile the project
7073
make

lzero/mcts/ctree/ctree_alphazero/mcts_alphazero.cpp

Lines changed: 297 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ class MCTS {
131131
return std::make_pair(action, child);
132132
}
133133

134-
// Expand a leaf node by adding its children based on policy probabilities
134+
    // Expand a single leaf node for the non-batch case.
135135
double _expand_leaf_node(std::shared_ptr<Node> node, py::object simulate_env, py::object policy_value_func) {
136136
std::map<int, double> action_probs_dict;
137137
double leaf_value;
@@ -157,6 +157,263 @@ class MCTS {
157157
return leaf_value;
158158
}
159159

160+
// Batch expand multiple leaf nodes using parallel batch inference.
161+
// Returns a list of leaf values for all nodes.
162+
std::vector<double> _batch_expand_leaf_nodes(
163+
const std::vector<std::shared_ptr<Node>>& leaf_nodes,
164+
const std::vector<py::object>& simulate_envs,
165+
py::object policy_value_func_batch
166+
) {
167+
if (leaf_nodes.empty() || simulate_envs.empty()) {
168+
return std::vector<double>();
169+
}
170+
171+
int batch_size = leaf_nodes.size();
172+
std::vector<double> leaf_values(batch_size);
173+
174+
py::list env_list;
175+
for (int i = 0; i < batch_size; ++i) {
176+
env_list.append(simulate_envs[i]);
177+
}
178+
179+
py::list batch_results = policy_value_func_batch(env_list).cast<py::list>();
180+
181+
for (int i = 0; i < batch_size; ++i) {
182+
py::object env = simulate_envs[i];
183+
std::shared_ptr<Node> node = leaf_nodes[i];
184+
185+
py::tuple result = batch_results[i].cast<py::tuple>();
186+
std::map<int, double> action_probs_dict = result[0].cast<std::map<int, double>>();
187+
double leaf_value = result[1].cast<double>();
188+
189+
leaf_values[i] = leaf_value;
190+
191+
py::list legal_actions_list = env.attr("legal_actions").cast<py::list>();
192+
std::vector<int> legal_actions = legal_actions_list.cast<std::vector<int>>();
193+
194+
for (const auto& kv : action_probs_dict) {
195+
int action = kv.first;
196+
double prior_p = kv.second;
197+
if (std::find(legal_actions.begin(), legal_actions.end(), action) !=
198+
legal_actions.end()) {
199+
node->children[action] = std::make_shared<Node>(node, prior_p);
200+
}
201+
}
202+
}
203+
204+
return leaf_values;
205+
}
206+
207+
// Batch version: Get next actions for multiple environments with batch inference optimization.
208+
std::vector<std::tuple<int, std::vector<double>, std::shared_ptr<Node>>> get_next_actions_batch(
209+
py::list state_configs_list,
210+
py::object policy_value_func_batch,
211+
double temperature,
212+
bool sample,
213+
py::list simulate_env_list
214+
) {
215+
int batch_size = py::len(state_configs_list);
216+
std::vector<std::tuple<int, std::vector<double>, std::shared_ptr<Node>>> results;
217+
results.reserve(batch_size);
218+
219+
std::vector<std::shared_ptr<Node>> roots;
220+
roots.reserve(batch_size);
221+
222+
std::vector<py::object> init_states;
223+
std::vector<py::object> katago_game_states;
224+
init_states.reserve(batch_size);
225+
katago_game_states.reserve(batch_size);
226+
227+
for (int i = 0; i < batch_size; ++i) {
228+
roots.push_back(std::make_shared<Node>());
229+
py::object state_config = state_configs_list[i].cast<py::object>();
230+
231+
py::object init_state = state_config["init_state"];
232+
if (!init_state.is_none()) {
233+
init_state = py::bytes(init_state.attr("tobytes")());
234+
}
235+
init_states.push_back(init_state);
236+
237+
py::object katago_game_state = state_config["katago_game_state"];
238+
if (!katago_game_state.is_none()) {
239+
katago_game_state = py::module::import("pickle").attr("dumps")(katago_game_state);
240+
}
241+
katago_game_states.push_back(katago_game_state);
242+
}
243+
244+
py::list env_list;
245+
for (int i = 0; i < batch_size; ++i) {
246+
py::object state_config = state_configs_list[i].cast<py::object>();
247+
py::object env = simulate_env_list[i];
248+
env.attr("reset")(
249+
state_config["start_player_index"].cast<int>(),
250+
init_states[i],
251+
state_config["katago_policy_init"].cast<bool>(),
252+
katago_game_states[i]
253+
);
254+
env_list.append(env);
255+
}
256+
257+
py::list batch_results = policy_value_func_batch(env_list).cast<py::list>();
258+
259+
for (int i = 0; i < batch_size; ++i) {
260+
py::object env = simulate_env_list[i];
261+
262+
py::tuple result = batch_results[i].cast<py::tuple>();
263+
std::map<int, double> action_probs_dict = result[0].cast<std::map<int, double>>();
264+
265+
py::list legal_actions_list = env.attr("legal_actions").cast<py::list>();
266+
std::vector<int> legal_actions = legal_actions_list.cast<std::vector<int>>();
267+
268+
for (const auto& kv : action_probs_dict) {
269+
if (std::find(legal_actions.begin(), legal_actions.end(), kv.first) !=
270+
legal_actions.end()) {
271+
roots[i]->children[kv.first] = std::make_shared<Node>(roots[i], kv.second);
272+
}
273+
}
274+
275+
if (sample) {
276+
_add_exploration_noise(roots[i]);
277+
}
278+
}
279+
280+
for (int n = 0; n < num_simulations; ++n) {
281+
std::vector<SimulationResult> simulation_results;
282+
simulation_results.reserve(batch_size);
283+
284+
for (int i = 0; i < batch_size; ++i) {
285+
py::object state_config = state_configs_list[i].cast<py::object>();
286+
py::object env = simulate_env_list[i];
287+
288+
env.attr("reset")(
289+
state_config["start_player_index"].cast<int>(),
290+
init_states[i],
291+
state_config["katago_policy_init"].cast<bool>(),
292+
katago_game_states[i]
293+
);
294+
env.attr("battle_mode") = env.attr("battle_mode_in_simulation_env");
295+
296+
SimulationResult sim_result = _simulate_to_leaf(roots[i], env);
297+
simulation_results.push_back(sim_result);
298+
}
299+
300+
std::vector<int> unfinished_indices;
301+
std::vector<std::shared_ptr<Node>> leaf_nodes_to_expand;
302+
std::vector<py::object> envs_to_infer;
303+
304+
for (int i = 0; i < batch_size; ++i) {
305+
if (!simulation_results[i].is_done) {
306+
unfinished_indices.push_back(i);
307+
leaf_nodes_to_expand.push_back(simulation_results[i].leaf_node);
308+
envs_to_infer.push_back(simulation_results[i].simulate_env);
309+
}
310+
}
311+
312+
std::vector<double> leaf_values;
313+
if (!unfinished_indices.empty()) {
314+
leaf_values = _batch_expand_leaf_nodes(
315+
leaf_nodes_to_expand,
316+
envs_to_infer,
317+
policy_value_func_batch
318+
);
319+
}
320+
321+
for (int i = 0; i < batch_size; ++i) {
322+
std::shared_ptr<Node> leaf_node = simulation_results[i].leaf_node;
323+
py::object env = simulation_results[i].simulate_env;
324+
double leaf_value;
325+
326+
if (simulation_results[i].is_done) {
327+
std::string battle_mode =
328+
env.attr("battle_mode_in_simulation_env").cast<std::string>();
329+
int winner = simulation_results[i].winner;
330+
331+
if (battle_mode == "self_play_mode") {
332+
if (winner == -1) {
333+
leaf_value = 0;
334+
} else {
335+
leaf_value =
336+
(env.attr("current_player").cast<int>() == winner) ? 1 : -1;
337+
}
338+
}
339+
else if (battle_mode == "play_with_bot_mode") {
340+
if (winner == -1) {
341+
leaf_value = 0;
342+
}
343+
else if (winner == 1) {
344+
leaf_value = 1;
345+
}
346+
else if (winner == 2) {
347+
leaf_value = -1;
348+
}
349+
}
350+
} else {
351+
auto it = std::find(unfinished_indices.begin(), unfinished_indices.end(), i);
352+
if (it != unfinished_indices.end()) {
353+
int result_idx = std::distance(unfinished_indices.begin(), it);
354+
leaf_value = leaf_values[result_idx];
355+
} else {
356+
leaf_value = 0;
357+
}
358+
}
359+
360+
std::string battle_mode =
361+
env.attr("battle_mode_in_simulation_env").cast<std::string>();
362+
if (battle_mode == "play_with_bot_mode") {
363+
leaf_node->update_recursive(leaf_value, battle_mode);
364+
}
365+
else if (battle_mode == "self_play_mode") {
366+
leaf_node->update_recursive(-leaf_value, battle_mode);
367+
}
368+
}
369+
}
370+
371+
for (int i = 0; i < batch_size; ++i) {
372+
py::object state_config = state_configs_list[i].cast<py::object>();
373+
py::object env = simulate_env_list[i];
374+
375+
env.attr("reset")(
376+
state_config["start_player_index"].cast<int>(),
377+
init_states[i],
378+
state_config["katago_policy_init"].cast<bool>(),
379+
katago_game_states[i]
380+
);
381+
382+
std::vector<std::pair<int, int>> action_visits;
383+
int action_space_n = env.attr("action_space").attr("n").cast<int>();
384+
for (int action = 0; action < action_space_n; ++action) {
385+
if (roots[i]->children.count(action)) {
386+
action_visits.emplace_back(action, roots[i]->children[action]->visit_count);
387+
} else {
388+
action_visits.emplace_back(action, 0);
389+
}
390+
}
391+
392+
std::vector<int> actions;
393+
std::vector<int> visits;
394+
for (const auto& av : action_visits) {
395+
actions.emplace_back(av.first);
396+
visits.emplace_back(av.second);
397+
}
398+
399+
std::vector<double> visits_d(visits.begin(), visits.end());
400+
std::vector<double> action_probs =
401+
visit_count_to_action_distribution(visits_d, temperature);
402+
403+
int action_selected;
404+
if (sample) {
405+
action_selected = random_choice(actions, action_probs);
406+
} else {
407+
auto max_it = std::max_element(action_probs.begin(), action_probs.end());
408+
action_selected = actions[std::distance(action_probs.begin(), max_it)];
409+
}
410+
411+
results.push_back(std::make_tuple(action_selected, action_probs, roots[i]));
412+
}
413+
414+
return results;
415+
}
416+
160417
// Main function to get the next action from MCTS
161418
std::tuple<int, std::vector<double>, std::shared_ptr<Node>> get_next_action(py::object state_config_for_env_reset, py::object policy_value_func, double temperature, bool sample) {
162419
std::shared_ptr<Node> root = std::make_shared<Node>();
@@ -228,7 +485,37 @@ class MCTS {
228485
return std::make_tuple(action_selected, action_probs, root);
229486
}
230487

231-
// Simulate a game starting from a given node
488+
    // Structure to store a single simulation result for parallel batch inference.
489+
struct SimulationResult {
490+
std::shared_ptr<Node> leaf_node;
491+
bool is_done;
492+
int winner;
493+
py::object simulate_env;
494+
};
495+
496+
// Single simulation from root to leaf node without inference, returns leaf node info.
497+
SimulationResult _simulate_to_leaf(std::shared_ptr<Node> node, py::object simulate_env) {
498+
while (!node->is_leaf()) {
499+
int action;
500+
std::shared_ptr<Node> child;
501+
std::tie(action, child) = _select_child(node, simulate_env);
502+
if (action == -1) {
503+
break;
504+
}
505+
simulate_env.attr("step")(action);
506+
node = child;
507+
}
508+
509+
bool done;
510+
int winner;
511+
py::tuple result = simulate_env.attr("get_done_winner")();
512+
done = result[0].cast<bool>();
513+
winner = result[1].cast<int>();
514+
515+
return SimulationResult{node, done, winner, simulate_env};
516+
}
517+
518+
    // Legacy single-environment simulation function (kept for backward compatibility).
232519
void _simulate(std::shared_ptr<Node> node, py::object simulate_env, py::object policy_value_func) {
233520
while (!node->is_leaf()) {
234521
int action;
@@ -372,5 +659,11 @@ PYBIND11_MODULE(mcts_alphazero, m) {
372659
py::arg("state_config_for_env_reset"),
373660
py::arg("policy_value_func"),
374661
py::arg("temperature"),
375-
py::arg("sample"));
376-
}
662+
py::arg("sample"))
663+
.def("get_next_actions_batch", &MCTS::get_next_actions_batch,
664+
py::arg("state_configs_list"),
665+
py::arg("policy_value_func_batch"),
666+
py::arg("temperature"),
667+
py::arg("sample"),
668+
py::arg("simulate_env_list"));
669+
}

0 commit comments

Comments
 (0)