@@ -131,7 +131,7 @@ class MCTS {
131131 return std::make_pair (action, child);
132132 }
133133
134- // Expand a leaf node by adding its children based on policy probabilities
134+ // Expand a single leaf node for non-batch case.
135135 double _expand_leaf_node (std::shared_ptr<Node> node, py::object simulate_env, py::object policy_value_func) {
136136 std::map<int , double > action_probs_dict;
137137 double leaf_value;
@@ -157,6 +157,263 @@ class MCTS {
157157 return leaf_value;
158158 }
159159
160+ // Batch expand multiple leaf nodes using parallel batch inference.
161+ // Returns a list of leaf values for all nodes.
162+ std::vector<double > _batch_expand_leaf_nodes (
163+ const std::vector<std::shared_ptr<Node>>& leaf_nodes,
164+ const std::vector<py::object>& simulate_envs,
165+ py::object policy_value_func_batch
166+ ) {
167+ if (leaf_nodes.empty () || simulate_envs.empty ()) {
168+ return std::vector<double >();
169+ }
170+
171+ int batch_size = leaf_nodes.size ();
172+ std::vector<double > leaf_values (batch_size);
173+
174+ py::list env_list;
175+ for (int i = 0 ; i < batch_size; ++i) {
176+ env_list.append (simulate_envs[i]);
177+ }
178+
179+ py::list batch_results = policy_value_func_batch (env_list).cast <py::list>();
180+
181+ for (int i = 0 ; i < batch_size; ++i) {
182+ py::object env = simulate_envs[i];
183+ std::shared_ptr<Node> node = leaf_nodes[i];
184+
185+ py::tuple result = batch_results[i].cast <py::tuple>();
186+ std::map<int , double > action_probs_dict = result[0 ].cast <std::map<int , double >>();
187+ double leaf_value = result[1 ].cast <double >();
188+
189+ leaf_values[i] = leaf_value;
190+
191+ py::list legal_actions_list = env.attr (" legal_actions" ).cast <py::list>();
192+ std::vector<int > legal_actions = legal_actions_list.cast <std::vector<int >>();
193+
194+ for (const auto & kv : action_probs_dict) {
195+ int action = kv.first ;
196+ double prior_p = kv.second ;
197+ if (std::find (legal_actions.begin (), legal_actions.end (), action) !=
198+ legal_actions.end ()) {
199+ node->children [action] = std::make_shared<Node>(node, prior_p);
200+ }
201+ }
202+ }
203+
204+ return leaf_values;
205+ }
206+
207+ // Batch version: Get next actions for multiple environments with batch inference optimization.
208+ std::vector<std::tuple<int , std::vector<double >, std::shared_ptr<Node>>> get_next_actions_batch (
209+ py::list state_configs_list,
210+ py::object policy_value_func_batch,
211+ double temperature,
212+ bool sample,
213+ py::list simulate_env_list
214+ ) {
215+ int batch_size = py::len (state_configs_list);
216+ std::vector<std::tuple<int , std::vector<double >, std::shared_ptr<Node>>> results;
217+ results.reserve (batch_size);
218+
219+ std::vector<std::shared_ptr<Node>> roots;
220+ roots.reserve (batch_size);
221+
222+ std::vector<py::object> init_states;
223+ std::vector<py::object> katago_game_states;
224+ init_states.reserve (batch_size);
225+ katago_game_states.reserve (batch_size);
226+
227+ for (int i = 0 ; i < batch_size; ++i) {
228+ roots.push_back (std::make_shared<Node>());
229+ py::object state_config = state_configs_list[i].cast <py::object>();
230+
231+ py::object init_state = state_config[" init_state" ];
232+ if (!init_state.is_none ()) {
233+ init_state = py::bytes (init_state.attr (" tobytes" )());
234+ }
235+ init_states.push_back (init_state);
236+
237+ py::object katago_game_state = state_config[" katago_game_state" ];
238+ if (!katago_game_state.is_none ()) {
239+ katago_game_state = py::module::import (" pickle" ).attr (" dumps" )(katago_game_state);
240+ }
241+ katago_game_states.push_back (katago_game_state);
242+ }
243+
244+ py::list env_list;
245+ for (int i = 0 ; i < batch_size; ++i) {
246+ py::object state_config = state_configs_list[i].cast <py::object>();
247+ py::object env = simulate_env_list[i];
248+ env.attr (" reset" )(
249+ state_config[" start_player_index" ].cast <int >(),
250+ init_states[i],
251+ state_config[" katago_policy_init" ].cast <bool >(),
252+ katago_game_states[i]
253+ );
254+ env_list.append (env);
255+ }
256+
257+ py::list batch_results = policy_value_func_batch (env_list).cast <py::list>();
258+
259+ for (int i = 0 ; i < batch_size; ++i) {
260+ py::object env = simulate_env_list[i];
261+
262+ py::tuple result = batch_results[i].cast <py::tuple>();
263+ std::map<int , double > action_probs_dict = result[0 ].cast <std::map<int , double >>();
264+
265+ py::list legal_actions_list = env.attr (" legal_actions" ).cast <py::list>();
266+ std::vector<int > legal_actions = legal_actions_list.cast <std::vector<int >>();
267+
268+ for (const auto & kv : action_probs_dict) {
269+ if (std::find (legal_actions.begin (), legal_actions.end (), kv.first ) !=
270+ legal_actions.end ()) {
271+ roots[i]->children [kv.first ] = std::make_shared<Node>(roots[i], kv.second );
272+ }
273+ }
274+
275+ if (sample) {
276+ _add_exploration_noise (roots[i]);
277+ }
278+ }
279+
280+ for (int n = 0 ; n < num_simulations; ++n) {
281+ std::vector<SimulationResult> simulation_results;
282+ simulation_results.reserve (batch_size);
283+
284+ for (int i = 0 ; i < batch_size; ++i) {
285+ py::object state_config = state_configs_list[i].cast <py::object>();
286+ py::object env = simulate_env_list[i];
287+
288+ env.attr (" reset" )(
289+ state_config[" start_player_index" ].cast <int >(),
290+ init_states[i],
291+ state_config[" katago_policy_init" ].cast <bool >(),
292+ katago_game_states[i]
293+ );
294+ env.attr (" battle_mode" ) = env.attr (" battle_mode_in_simulation_env" );
295+
296+ SimulationResult sim_result = _simulate_to_leaf (roots[i], env);
297+ simulation_results.push_back (sim_result);
298+ }
299+
300+ std::vector<int > unfinished_indices;
301+ std::vector<std::shared_ptr<Node>> leaf_nodes_to_expand;
302+ std::vector<py::object> envs_to_infer;
303+
304+ for (int i = 0 ; i < batch_size; ++i) {
305+ if (!simulation_results[i].is_done ) {
306+ unfinished_indices.push_back (i);
307+ leaf_nodes_to_expand.push_back (simulation_results[i].leaf_node );
308+ envs_to_infer.push_back (simulation_results[i].simulate_env );
309+ }
310+ }
311+
312+ std::vector<double > leaf_values;
313+ if (!unfinished_indices.empty ()) {
314+ leaf_values = _batch_expand_leaf_nodes (
315+ leaf_nodes_to_expand,
316+ envs_to_infer,
317+ policy_value_func_batch
318+ );
319+ }
320+
321+ for (int i = 0 ; i < batch_size; ++i) {
322+ std::shared_ptr<Node> leaf_node = simulation_results[i].leaf_node ;
323+ py::object env = simulation_results[i].simulate_env ;
324+ double leaf_value;
325+
326+ if (simulation_results[i].is_done ) {
327+ std::string battle_mode =
328+ env.attr (" battle_mode_in_simulation_env" ).cast <std::string>();
329+ int winner = simulation_results[i].winner ;
330+
331+ if (battle_mode == " self_play_mode" ) {
332+ if (winner == -1 ) {
333+ leaf_value = 0 ;
334+ } else {
335+ leaf_value =
336+ (env.attr (" current_player" ).cast <int >() == winner) ? 1 : -1 ;
337+ }
338+ }
339+ else if (battle_mode == " play_with_bot_mode" ) {
340+ if (winner == -1 ) {
341+ leaf_value = 0 ;
342+ }
343+ else if (winner == 1 ) {
344+ leaf_value = 1 ;
345+ }
346+ else if (winner == 2 ) {
347+ leaf_value = -1 ;
348+ }
349+ }
350+ } else {
351+ auto it = std::find (unfinished_indices.begin (), unfinished_indices.end (), i);
352+ if (it != unfinished_indices.end ()) {
353+ int result_idx = std::distance (unfinished_indices.begin (), it);
354+ leaf_value = leaf_values[result_idx];
355+ } else {
356+ leaf_value = 0 ;
357+ }
358+ }
359+
360+ std::string battle_mode =
361+ env.attr (" battle_mode_in_simulation_env" ).cast <std::string>();
362+ if (battle_mode == " play_with_bot_mode" ) {
363+ leaf_node->update_recursive (leaf_value, battle_mode);
364+ }
365+ else if (battle_mode == " self_play_mode" ) {
366+ leaf_node->update_recursive (-leaf_value, battle_mode);
367+ }
368+ }
369+ }
370+
371+ for (int i = 0 ; i < batch_size; ++i) {
372+ py::object state_config = state_configs_list[i].cast <py::object>();
373+ py::object env = simulate_env_list[i];
374+
375+ env.attr (" reset" )(
376+ state_config[" start_player_index" ].cast <int >(),
377+ init_states[i],
378+ state_config[" katago_policy_init" ].cast <bool >(),
379+ katago_game_states[i]
380+ );
381+
382+ std::vector<std::pair<int , int >> action_visits;
383+ int action_space_n = env.attr (" action_space" ).attr (" n" ).cast <int >();
384+ for (int action = 0 ; action < action_space_n; ++action) {
385+ if (roots[i]->children .count (action)) {
386+ action_visits.emplace_back (action, roots[i]->children [action]->visit_count );
387+ } else {
388+ action_visits.emplace_back (action, 0 );
389+ }
390+ }
391+
392+ std::vector<int > actions;
393+ std::vector<int > visits;
394+ for (const auto & av : action_visits) {
395+ actions.emplace_back (av.first );
396+ visits.emplace_back (av.second );
397+ }
398+
399+ std::vector<double > visits_d (visits.begin (), visits.end ());
400+ std::vector<double > action_probs =
401+ visit_count_to_action_distribution (visits_d, temperature);
402+
403+ int action_selected;
404+ if (sample) {
405+ action_selected = random_choice (actions, action_probs);
406+ } else {
407+ auto max_it = std::max_element (action_probs.begin (), action_probs.end ());
408+ action_selected = actions[std::distance (action_probs.begin (), max_it)];
409+ }
410+
411+ results.push_back (std::make_tuple (action_selected, action_probs, roots[i]));
412+ }
413+
414+ return results;
415+ }
416+
160417 // Main function to get the next action from MCTS
161418 std::tuple<int , std::vector<double >, std::shared_ptr<Node>> get_next_action (py::object state_config_for_env_reset, py::object policy_value_func, double temperature, bool sample) {
162419 std::shared_ptr<Node> root = std::make_shared<Node>();
@@ -228,7 +485,37 @@ class MCTS {
228485 return std::make_tuple (action_selected, action_probs, root);
229486 }
230487
231- // Simulate a game starting from a given node
488+ // Outcome of one root-to-leaf descent, recorded before any network inference so that leaves from several environments can be evaluated in a single batch.
// Built positionally by _simulate_to_leaf as {leaf_node, is_done, winner, simulate_env}; keep the field order in sync with that initializer.
489+ struct SimulationResult {
490+ std::shared_ptr<Node> leaf_node; // node at which the descent stopped
491+ bool is_done; // first element of simulate_env.get_done_winner()
492+ int winner; // second element of get_done_winner(); get_next_actions_batch scores -1 as 0 — presumably a draw, TODO confirm against the env
493+ py::object simulate_env; // handle of the environment that was stepped during the descent
494+ };
495+
496+ // Single simulation from root to leaf node without inference, returns leaf node info.
// Walks down from `node`, choosing a child with _select_child and applying the chosen action to `simulate_env` via its `step` method, until a leaf is reached or no action is selected.
// The policy/value function is never called here; the caller decides how to evaluate or expand the returned node.
// @param node  Node to start the descent from.
// @param simulate_env  Python environment; it is stepped in place, so on return it sits at the position of the returned node.
// @return SimulationResult holding the node where the descent stopped, the (done, winner) pair reported by get_done_winner(), and the same env handle.
497+ SimulationResult _simulate_to_leaf (std::shared_ptr<Node> node, py::object simulate_env) {
498+ while (!node->is_leaf ()) {
499+ int action;
500+ std::shared_ptr<Node> child;
501+ std::tie (action, child) = _select_child (node, simulate_env);
502+ if (action == -1 ) { // presumably _select_child's "nothing selectable" marker — TODO confirm; the current non-leaf node is returned in that case
503+ break ;
504+ }
505+ simulate_env.attr (" step" )(action); // advance the environment along the chosen edge
506+ node = child;
507+ }
508+
509+ bool done;
510+ int winner;
511+ py::tuple result = simulate_env.attr (" get_done_winner" )(); // queried on every descent, whether or not the loop stopped early
512+ done = result[0 ].cast <bool >();
513+ winner = result[1 ].cast <int >();
514+
515+ return SimulationResult{node, done, winner, simulate_env};
516+ }
517+
518+ // Legacy single environment simulation function (kept for backward compatibility).
232519 void _simulate (std::shared_ptr<Node> node, py::object simulate_env, py::object policy_value_func) {
233520 while (!node->is_leaf ()) {
234521 int action;
@@ -372,5 +659,11 @@ PYBIND11_MODULE(mcts_alphazero, m) {
372659 py::arg (" state_config_for_env_reset" ),
373660 py::arg (" policy_value_func" ),
374661 py::arg (" temperature" ),
375- py::arg (" sample" ));
376- }
662+ py::arg (" sample" ))
663+ .def (" get_next_actions_batch" , &MCTS::get_next_actions_batch,
664+ py::arg (" state_configs_list" ),
665+ py::arg (" policy_value_func_batch" ),
666+ py::arg (" temperature" ),
667+ py::arg (" sample" ),
668+ py::arg (" simulate_env_list" ));
669+ }
0 commit comments