1414
1515#include " open_spiel/algorithms/infostate_tree.h"
1616
17+ #include < algorithm>
18+ #include < cmath>
19+ #include < cstddef>
20+ #include < functional>
1721#include < limits>
1822#include < memory>
23+ #include < ostream>
1924#include < stack>
2025#include < string>
2126#include < utility>
2227#include < vector>
2328
29+ #include " open_spiel/abseil-cpp/absl/strings/str_cat.h"
30+ #include " open_spiel/abseil-cpp/absl/strings/str_join.h"
31+ #include " open_spiel/abseil-cpp/absl/types/optional.h"
2432#include " open_spiel/action_view.h"
33+ #include " open_spiel/observer.h"
34+ #include " open_spiel/spiel.h"
35+ #include " open_spiel/spiel_utils.h"
2536
2637namespace open_spiel {
2738namespace algorithms {
@@ -161,10 +172,12 @@ void InfostateNode::SwapParent(std::unique_ptr<InfostateNode> self,
161172InfostateTree::InfostateTree (const std::vector<const State*>& start_states,
162173 const std::vector<double >& chance_reach_probs,
163174 std::shared_ptr<Observer> infostate_observer,
164- Player acting_player, int max_move_ahead_limit)
175+ Player acting_player, bool store_world_states,
176+ int max_move_ahead_limit)
165177 : acting_player_(acting_player),
166178 infostate_observer_ (std::move(infostate_observer)),
167- root_(MakeRootNode()) {
179+ root_(MakeRootNode()),
180+ store_all_world_states_(store_world_states) {
168181 SPIEL_CHECK_FALSE (start_states.empty ());
169182 SPIEL_CHECK_EQ (start_states.size (), chance_reach_probs.size ());
170183 SPIEL_CHECK_GE (acting_player_, 0 );
@@ -178,7 +191,8 @@ InfostateTree::InfostateTree(const std::vector<const State*>& start_states,
178191 }
179192
180193 for (int i = 0 ; i < start_states.size (); ++i) {
181- RecursivelyBuildTree (root_.get (), /* depth=*/ 1 , *start_states[i],
194+ RecursivelyBuildTree (root_.get (), /* depth=*/ 1 ,
195+ std::shared_ptr<const State>(start_states[i]->Clone ()),
182196 start_max_move_number + max_move_ahead_limit,
183197 chance_reach_probs[i]);
184198 }
@@ -225,11 +239,11 @@ std::unique_ptr<InfostateNode> InfostateTree::MakeNode(
225239 : std::vector<Action>();
226240 // Instantiate node using new to make sure that we can call
227241 // the private constructor.
228- auto node = std::unique_ptr<InfostateNode>( new InfostateNode (
242+ auto node = new InfostateNode (
229243 *this , parent, parent->num_children (), type, infostate_string,
230244 terminal_utility, terminal_ch_reach_prob, depth, std::move (legal_actions),
231- std::move (terminal_history))) ;
232- return node;
245+ std::move (terminal_history));
246+ return std::unique_ptr<InfostateNode>{ node} ;
233247}
234248
235249std::unique_ptr<InfostateNode> InfostateTree::MakeRootNode () const {
@@ -241,46 +255,53 @@ std::unique_ptr<InfostateNode> InfostateTree::MakeRootNode() const {
241255 /* depth=*/ 0 , /* legal_actions=*/ {}, /* terminal_history=*/ {}));
242256}
243257
244- void InfostateTree::UpdateLeafNode (InfostateNode* node, const State& state ,
245- size_t leaf_depth ,
246- double chance_reach_probs) {
247- tree_height_ = std::max (tree_height_, leaf_depth );
248- node->corresponding_states_ .push_back (state. Clone ( ));
258+ void InfostateTree::UpdateNode (InfostateNode* node,
259+ std::shared_ptr< const State> state ,
260+ size_t node_depth, double chance_reach_probs) {
261+ tree_height_ = std::max (tree_height_, node_depth );
262+ node->corresponding_states_ .push_back (std::move (state ));
249263 node->corresponding_ch_reaches_ .push_back (chance_reach_probs);
250264}
251265
252266void InfostateTree::RecursivelyBuildTree (InfostateNode* parent, size_t depth,
253- const State& state, int move_limit,
267+ std::shared_ptr<const State> state,
268+ int move_limit,
254269 double chance_reach_prob) {
255- if (state.IsTerminal ())
256- return BuildTerminalNode (parent, depth, state, chance_reach_prob);
257- else if (state.IsPlayerActing (acting_player_))
258- return BuildDecisionNode (parent, depth, state, move_limit,
259- chance_reach_prob);
260- else
261- return BuildObservationNode (parent, depth, state, move_limit,
262- chance_reach_prob);
270+ auto [child_node, leaf_update] = std::invoke ([&] {
271+ if (state->IsTerminal ())
272+ return BuildTerminalNode (parent, depth, state, chance_reach_prob);
273+ else if (state->IsPlayerActing (acting_player_))
274+ return BuildDecisionNode (parent, depth, state, move_limit,
275+ chance_reach_prob);
276+ else
277+ return BuildObservationNode (parent, depth, state, move_limit,
278+ chance_reach_prob);
279+ });
280+ if (store_all_world_states_ || leaf_update) {
281+ UpdateNode (child_node, std::move (state), depth, chance_reach_prob);
282+ }
263283}
264284
265- void InfostateTree::BuildTerminalNode ( InfostateNode* parent, size_t depth,
266- const State& state ,
267- double chance_reach_prob) {
268- const double terminal_utility = state. Returns ()[acting_player_];
285+ std::pair< InfostateNode*, bool > InfostateTree::BuildTerminalNode (
286+ InfostateNode* parent, size_t depth ,
287+ const std::shared_ptr< const State>& state, double chance_reach_prob) {
288+ const double terminal_utility = state-> Returns ()[acting_player_];
269289 InfostateNode* terminal_node = parent->AddChild (
270290 MakeNode (parent, kTerminalInfostateNode ,
271- infostate_observer_->StringFrom (state, acting_player_),
272- terminal_utility, chance_reach_prob, depth, & state));
273- UpdateLeafNode ( terminal_node, state, depth, chance_reach_prob) ;
291+ infostate_observer_->StringFrom (* state, acting_player_),
292+ terminal_utility, chance_reach_prob, depth, state. get () ));
293+ return { terminal_node, true } ;
274294}
275295
276- void InfostateTree::BuildDecisionNode (InfostateNode* parent, size_t depth,
277- const State& state, int move_limit,
278- double chance_reach_prob) {
296+ std::pair<InfostateNode*, bool > InfostateTree::BuildDecisionNode (
297+ InfostateNode* parent, size_t depth,
298+ const std::shared_ptr<const State>& state, int move_limit,
299+ double chance_reach_prob) {
279300 SPIEL_DCHECK_EQ (parent->type (), kObservationInfostateNode );
280301 std::string info_state =
281- infostate_observer_->StringFrom (state, acting_player_);
302+ infostate_observer_->StringFrom (* state, acting_player_);
282303 InfostateNode* decision_node = parent->GetChild (info_state);
283- const bool is_leaf_node = state. MoveNumber () >= move_limit;
304+ const bool is_leaf_node = state-> MoveNumber () >= move_limit;
284305
285306 if (decision_node) {
286307 // The decision node has been already constructed along with children
@@ -289,49 +310,62 @@ void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
289310 SPIEL_DCHECK_EQ (decision_node->type (), kDecisionInfostateNode );
290311
291312 if (is_leaf_node) { // Do not build deeper.
292- return UpdateLeafNode ( decision_node, state, depth, chance_reach_prob) ;
313+ return { decision_node, true } ;
293314 }
294315
295- if (state. IsSimultaneousNode ()) {
296- const ActionView action_view (state);
316+ if (state-> IsSimultaneousNode ()) {
317+ const ActionView action_view (* state);
297318 for (int i = 0 ; i < action_view.legal_actions [acting_player_].size ();
298319 ++i) {
299320 InfostateNode* observation_node = decision_node->child_at (i);
300321 SPIEL_DCHECK_EQ (observation_node->type (), kObservationInfostateNode );
301322
302323 for (Action flat_actions :
303324 action_view.fixed_action (acting_player_, i)) {
304- std::unique_ptr<State> child = state.Child (flat_actions);
305- RecursivelyBuildTree (observation_node, depth + 2 , *child, move_limit,
306- chance_reach_prob);
325+ auto child_state = std::shared_ptr{state->Child (flat_actions)};
326+ // Only now we can advance the state, when we have all actions.
327+ RecursivelyBuildTree (observation_node, depth + 2 , child_state,
328+ move_limit, chance_reach_prob);
329+ if (store_all_world_states_ &&
330+ !observation_node->is_filler_node ()) {
331+ UpdateNode (observation_node, std::move (child_state), depth + 2 ,
332+ chance_reach_prob);
333+ }
307334 }
308335 }
309336 } else {
310- std::vector<Action> legal_actions = state. LegalActions (acting_player_);
337+ std::vector<Action> legal_actions = state-> LegalActions (acting_player_);
311338 for (int i = 0 ; i < legal_actions.size (); ++i) {
312339 InfostateNode* observation_node = decision_node->child_at (i);
313340 SPIEL_DCHECK_EQ (observation_node->type (), kObservationInfostateNode );
314- std::unique_ptr<State> child = state.Child (legal_actions.at (i));
315- RecursivelyBuildTree (observation_node, depth + 2 , *child, move_limit,
316- chance_reach_prob);
341+ auto child_state = std::shared_ptr{state->Child (legal_actions.at (i))};
342+ // Only now we can advance the state, when we have all actions.
343+ RecursivelyBuildTree (observation_node, depth + 2 , child_state,
344+ move_limit, chance_reach_prob);
345+ if (store_all_world_states_ &&
346+ !observation_node->is_filler_node ()) {
347+ UpdateNode (observation_node, std::move (child_state), depth,
348+ chance_reach_prob);
349+ }
317350 }
318351 }
319352 } else { // The decision node was not found yet.
320- decision_node = parent->AddChild (MakeNode (
321- parent, kDecisionInfostateNode , info_state,
322- /* terminal_utility=*/ NAN, /* chance_reach_prob=*/ NAN, depth, &state));
353+ decision_node = parent->AddChild (
354+ MakeNode (parent, kDecisionInfostateNode , info_state,
355+ /* terminal_utility=*/ NAN,
356+ /* chance_reach_prob=*/ NAN, depth, state.get ()));
323357
324358 if (is_leaf_node) { // Do not build deeper.
325- return UpdateLeafNode ( decision_node, state, depth, chance_reach_prob) ;
359+ return { decision_node, true } ;
326360 }
327361
328362 // Build observation nodes right away after the decision node.
329363 // This is because the player might be acting multiple times in a row:
330364 // each time it might get some observations that branch the infostate
331365 // tree.
332366
333- if (state. IsSimultaneousNode ()) {
334- ActionView action_view (state);
367+ if (state-> IsSimultaneousNode ()) {
368+ ActionView action_view (* state);
335369 for (int i = 0 ; i < action_view.legal_actions [acting_player_].size ();
336370 ++i) {
337371 // We build a dummy observation node.
@@ -344,89 +378,105 @@ void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
344378 /* infostate_string=*/ kFillerInfostate ,
345379 /* terminal_utility=*/ NAN, /* chance_reach_prob=*/ NAN, depth,
346380 /* originating_state=*/ nullptr ));
347-
348381 for (Action flat_actions :
349382 action_view.fixed_action (acting_player_, i)) {
383+ auto child_state = std::shared_ptr{state->Child (flat_actions)};
350384 // Only now we can advance the state, when we have all actions.
351- std::unique_ptr<State> child = state.Child (flat_actions);
352- RecursivelyBuildTree (observation_node, depth + 2 , *child, move_limit,
353- chance_reach_prob);
385+ RecursivelyBuildTree (observation_node, depth + 2 , child_state,
386+ move_limit, chance_reach_prob);
387+ if (store_all_world_states_ &&
388+ !observation_node->is_filler_node ()) {
389+ UpdateNode (observation_node, std::move (child_state), depth,
390+ chance_reach_prob);
391+ }
354392 }
355393 }
356394 } else { // Not a sim move node.
357- for (Action a : state. LegalActions ()) {
358- std::unique_ptr<State> child = state. Child (a);
395+ for (Action a : state-> LegalActions ()) {
396+ std::shared_ptr child = state-> Child (a);
359397 InfostateNode* observation_node = decision_node->AddChild (
360398 MakeNode (decision_node, kObservationInfostateNode ,
361399 infostate_observer_->StringFrom (*child, acting_player_),
362400 /* terminal_utility=*/ NAN, /* chance_reach_prob=*/ NAN, depth,
363401 child.get ()));
364- RecursivelyBuildTree (observation_node, depth + 2 , * child, move_limit,
402+ RecursivelyBuildTree (observation_node, depth + 2 , child, move_limit,
365403 chance_reach_prob);
404+ if (store_all_world_states_ &&
405+ !observation_node->is_filler_node ()) {
406+ UpdateNode (observation_node, std::move (child), depth,
407+ chance_reach_prob);
408+ }
366409 }
367410 }
368411 }
412+ return {decision_node, false };
369413}
370414
371- void InfostateTree::BuildObservationNode (InfostateNode* parent, size_t depth,
372- const State& state, int move_limit,
373- double chance_reach_prob) {
374- SPIEL_DCHECK_TRUE (state.IsChanceNode () ||
375- !state.IsPlayerActing (acting_player_));
376- const bool is_leaf_node = state.MoveNumber () >= move_limit;
415+ std::pair<InfostateNode*, bool > InfostateTree::BuildObservationNode (
416+ InfostateNode* parent, size_t depth,
417+ const std::shared_ptr<const State>& state, int move_limit,
418+ double chance_reach_prob) {
419+ SPIEL_DCHECK_TRUE (state->IsChanceNode () ||
420+ !state->IsPlayerActing (acting_player_));
421+ const bool is_leaf_node = state->MoveNumber () >= move_limit;
377422 const std::string info_state =
378- infostate_observer_->StringFrom (state, acting_player_);
423+ infostate_observer_->StringFrom (* state, acting_player_);
379424
380425 InfostateNode* observation_node = parent->GetChild (info_state);
381426 if (!observation_node) {
382- observation_node = parent->AddChild (MakeNode (
383- parent, kObservationInfostateNode , info_state,
384- /* terminal_utility=*/ NAN, /* chance_reach_prob=*/ NAN, depth, &state));
427+ observation_node = parent->AddChild (
428+ MakeNode (parent, kObservationInfostateNode , info_state,
429+ /* terminal_utility=*/ NAN, /* chance_reach_prob=*/ NAN, depth,
430+ state.get ()));
385431 }
386432 SPIEL_DCHECK_EQ (observation_node->type (), kObservationInfostateNode );
387433
388434 if (is_leaf_node) { // Do not build deeper.
389- return UpdateLeafNode ( observation_node, state, depth, chance_reach_prob) ;
435+ return { observation_node, true } ;
390436 }
391437
392- if (state. IsChanceNode ()) {
393- for (std::pair<Action, double > action_prob : state. ChanceOutcomes ()) {
394- std::unique_ptr<State> child = state. Child (action_prob. first );
395- RecursivelyBuildTree (observation_node, depth + 1 , *child , move_limit,
438+ if (state-> IsChanceNode ()) {
439+ for (std::pair<Action, double > action_prob : state-> ChanceOutcomes ()) {
440+ RecursivelyBuildTree (observation_node, depth + 1 ,
441+ state-> Child (action_prob. first ) , move_limit,
396442 chance_reach_prob * action_prob.second );
397443 }
398444 } else {
399- for (Action a : state.LegalActions ()) {
400- std::unique_ptr<State> child = state.Child (a);
401- RecursivelyBuildTree (observation_node, depth + 1 , *child, move_limit,
402- chance_reach_prob);
445+ for (Action a : state->LegalActions ()) {
446+ RecursivelyBuildTree (observation_node, depth + 1 , state->Child (a),
447+ move_limit, chance_reach_prob);
403448 }
404449 }
450+ return {observation_node, false };
405451}
406452int InfostateTree::root_branching_factor () const {
407453 return root_->num_children ();
408454}
409455
410456std::shared_ptr<InfostateTree> MakeInfostateTree (const Game& game,
411457 Player acting_player,
458+ bool store_world_states,
412459 int max_move_limit) {
413460 // Uses new instead of make_shared, because shared_ptr is not a friend and
414461 // can't call private constructors.
415462 return std::shared_ptr<InfostateTree>(new InfostateTree (
416463 {game.NewInitialState ().get ()}, /* chance_reach_probs=*/ {1 .},
417- game.MakeObserver (kInfoStateObsType , {}), acting_player, max_move_limit));
464+ game.MakeObserver (kInfoStateObsType , {}), acting_player,
465+ store_world_states, max_move_limit));
418466}
419467
420468std::shared_ptr<InfostateTree> MakeInfostateTree (
421- const std::vector<InfostateNode*>& start_nodes, int max_move_ahead_limit) {
469+ const std::vector<InfostateNode*>& start_nodes, bool store_world_states,
470+ int max_move_ahead_limit) {
422471 std::vector<const InfostateNode*> const_nodes (start_nodes.begin (),
423472 start_nodes.end ());
424- return MakeInfostateTree (const_nodes, max_move_ahead_limit);
473+ return MakeInfostateTree (const_nodes, store_world_states,
474+ max_move_ahead_limit);
425475}
426476
427477std::shared_ptr<InfostateTree> MakeInfostateTree (
428478 const std::vector<const InfostateNode*>& start_nodes,
429- int max_move_ahead_limit) {
479+ bool store_world_states, int max_move_ahead_limit) {
430480 SPIEL_CHECK_FALSE (start_nodes.empty ());
431481 const InfostateNode* some_node = start_nodes[0 ];
432482 const InfostateTree& originating_tree = some_node->tree ();
@@ -458,17 +508,18 @@ std::shared_ptr<InfostateTree> MakeInfostateTree(
458508 // can't call private constructors.
459509 return std::shared_ptr<InfostateTree>(new InfostateTree (
460510 start_states, chance_reach_probs, originating_tree.infostate_observer_ ,
461- originating_tree.acting_player_ , max_move_ahead_limit));
511+ originating_tree.acting_player_ , store_world_states,
512+ max_move_ahead_limit));
462513}
463514
464515std::shared_ptr<InfostateTree> MakeInfostateTree (
465516 const std::vector<const State*>& start_states,
466517 const std::vector<double >& chance_reach_probs,
467518 std::shared_ptr<Observer> infostate_observer, Player acting_player,
468- int max_move_ahead_limit) {
469- return std::shared_ptr<InfostateTree>(
470- new InfostateTree ( start_states, chance_reach_probs, infostate_observer,
471- acting_player, max_move_ahead_limit));
519+ bool store_world_states, int max_move_ahead_limit) {
520+ return std::shared_ptr<InfostateTree>(new InfostateTree (
521+ start_states, chance_reach_probs, std::move ( infostate_observer) ,
522+ acting_player, store_world_states , max_move_ahead_limit));
472523}
473524SequenceId InfostateTree::empty_sequence () const {
474525 return root ().sequence_id ();
0 commit comments