Skip to content

Commit 6a6ba8d

Browse files
committed
Merge pull request #1074 from maichmueller:feat/infostatetree_store_states_option
PiperOrigin-RevId: 859196436 Change-Id: I65ba29402c1cbf1fc38558af5e594908a1fcff7d
2 parents 50389e9 + b969a1a commit 6a6ba8d

File tree

3 files changed

+206
-113
lines changed

3 files changed

+206
-113
lines changed

open_spiel/algorithms/infostate_tree.cc

Lines changed: 134 additions & 83 deletions
Original file line number | Diff line number | Diff line change
@@ -14,14 +14,25 @@
1414

1515
#include "open_spiel/algorithms/infostate_tree.h"
1616

17+
#include <algorithm>
18+
#include <cmath>
19+
#include <cstddef>
20+
#include <functional>
1721
#include <limits>
1822
#include <memory>
23+
#include <ostream>
1924
#include <stack>
2025
#include <string>
2126
#include <utility>
2227
#include <vector>
2328

29+
#include "open_spiel/abseil-cpp/absl/strings/str_cat.h"
30+
#include "open_spiel/abseil-cpp/absl/strings/str_join.h"
31+
#include "open_spiel/abseil-cpp/absl/types/optional.h"
2432
#include "open_spiel/action_view.h"
33+
#include "open_spiel/observer.h"
34+
#include "open_spiel/spiel.h"
35+
#include "open_spiel/spiel_utils.h"
2536

2637
namespace open_spiel {
2738
namespace algorithms {
@@ -161,10 +172,12 @@ void InfostateNode::SwapParent(std::unique_ptr<InfostateNode> self,
161172
InfostateTree::InfostateTree(const std::vector<const State*>& start_states,
162173
const std::vector<double>& chance_reach_probs,
163174
std::shared_ptr<Observer> infostate_observer,
164-
Player acting_player, int max_move_ahead_limit)
175+
Player acting_player, bool store_world_states,
176+
int max_move_ahead_limit)
165177
: acting_player_(acting_player),
166178
infostate_observer_(std::move(infostate_observer)),
167-
root_(MakeRootNode()) {
179+
root_(MakeRootNode()),
180+
store_all_world_states_(store_world_states) {
168181
SPIEL_CHECK_FALSE(start_states.empty());
169182
SPIEL_CHECK_EQ(start_states.size(), chance_reach_probs.size());
170183
SPIEL_CHECK_GE(acting_player_, 0);
@@ -178,7 +191,8 @@ InfostateTree::InfostateTree(const std::vector<const State*>& start_states,
178191
}
179192

180193
for (int i = 0; i < start_states.size(); ++i) {
181-
RecursivelyBuildTree(root_.get(), /*depth=*/1, *start_states[i],
194+
RecursivelyBuildTree(root_.get(), /*depth=*/1,
195+
std::shared_ptr<const State>(start_states[i]->Clone()),
182196
start_max_move_number + max_move_ahead_limit,
183197
chance_reach_probs[i]);
184198
}
@@ -225,11 +239,11 @@ std::unique_ptr<InfostateNode> InfostateTree::MakeNode(
225239
: std::vector<Action>();
226240
// Instantiate node using new to make sure that we can call
227241
// the private constructor.
228-
auto node = std::unique_ptr<InfostateNode>(new InfostateNode(
242+
auto node = new InfostateNode(
229243
*this, parent, parent->num_children(), type, infostate_string,
230244
terminal_utility, terminal_ch_reach_prob, depth, std::move(legal_actions),
231-
std::move(terminal_history)));
232-
return node;
245+
std::move(terminal_history));
246+
return std::unique_ptr<InfostateNode>{node};
233247
}
234248

235249
std::unique_ptr<InfostateNode> InfostateTree::MakeRootNode() const {
@@ -241,46 +255,53 @@ std::unique_ptr<InfostateNode> InfostateTree::MakeRootNode() const {
241255
/*depth=*/0, /*legal_actions=*/{}, /*terminal_history=*/{}));
242256
}
243257

244-
void InfostateTree::UpdateLeafNode(InfostateNode* node, const State& state,
245-
size_t leaf_depth,
246-
double chance_reach_probs) {
247-
tree_height_ = std::max(tree_height_, leaf_depth);
248-
node->corresponding_states_.push_back(state.Clone());
258+
void InfostateTree::UpdateNode(InfostateNode* node,
259+
std::shared_ptr<const State> state,
260+
size_t node_depth, double chance_reach_probs) {
261+
tree_height_ = std::max(tree_height_, node_depth);
262+
node->corresponding_states_.push_back(std::move(state));
249263
node->corresponding_ch_reaches_.push_back(chance_reach_probs);
250264
}
251265

252266
void InfostateTree::RecursivelyBuildTree(InfostateNode* parent, size_t depth,
253-
const State& state, int move_limit,
267+
std::shared_ptr<const State> state,
268+
int move_limit,
254269
double chance_reach_prob) {
255-
if (state.IsTerminal())
256-
return BuildTerminalNode(parent, depth, state, chance_reach_prob);
257-
else if (state.IsPlayerActing(acting_player_))
258-
return BuildDecisionNode(parent, depth, state, move_limit,
259-
chance_reach_prob);
260-
else
261-
return BuildObservationNode(parent, depth, state, move_limit,
262-
chance_reach_prob);
270+
auto [child_node, leaf_update] = std::invoke([&] {
271+
if (state->IsTerminal())
272+
return BuildTerminalNode(parent, depth, state, chance_reach_prob);
273+
else if (state->IsPlayerActing(acting_player_))
274+
return BuildDecisionNode(parent, depth, state, move_limit,
275+
chance_reach_prob);
276+
else
277+
return BuildObservationNode(parent, depth, state, move_limit,
278+
chance_reach_prob);
279+
});
280+
if (store_all_world_states_ || leaf_update) {
281+
UpdateNode(child_node, std::move(state), depth, chance_reach_prob);
282+
}
263283
}
264284

265-
void InfostateTree::BuildTerminalNode(InfostateNode* parent, size_t depth,
266-
const State& state,
267-
double chance_reach_prob) {
268-
const double terminal_utility = state.Returns()[acting_player_];
285+
std::pair<InfostateNode*, bool> InfostateTree::BuildTerminalNode(
286+
InfostateNode* parent, size_t depth,
287+
const std::shared_ptr<const State>& state, double chance_reach_prob) {
288+
const double terminal_utility = state->Returns()[acting_player_];
269289
InfostateNode* terminal_node = parent->AddChild(
270290
MakeNode(parent, kTerminalInfostateNode,
271-
infostate_observer_->StringFrom(state, acting_player_),
272-
terminal_utility, chance_reach_prob, depth, &state));
273-
UpdateLeafNode(terminal_node, state, depth, chance_reach_prob);
291+
infostate_observer_->StringFrom(*state, acting_player_),
292+
terminal_utility, chance_reach_prob, depth, state.get()));
293+
return {terminal_node, true};
274294
}
275295

276-
void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
277-
const State& state, int move_limit,
278-
double chance_reach_prob) {
296+
std::pair<InfostateNode*, bool> InfostateTree::BuildDecisionNode(
297+
InfostateNode* parent, size_t depth,
298+
const std::shared_ptr<const State>& state, int move_limit,
299+
double chance_reach_prob) {
279300
SPIEL_DCHECK_EQ(parent->type(), kObservationInfostateNode);
280301
std::string info_state =
281-
infostate_observer_->StringFrom(state, acting_player_);
302+
infostate_observer_->StringFrom(*state, acting_player_);
282303
InfostateNode* decision_node = parent->GetChild(info_state);
283-
const bool is_leaf_node = state.MoveNumber() >= move_limit;
304+
const bool is_leaf_node = state->MoveNumber() >= move_limit;
284305

285306
if (decision_node) {
286307
// The decision node has been already constructed along with children
@@ -289,49 +310,62 @@ void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
289310
SPIEL_DCHECK_EQ(decision_node->type(), kDecisionInfostateNode);
290311

291312
if (is_leaf_node) { // Do not build deeper.
292-
return UpdateLeafNode(decision_node, state, depth, chance_reach_prob);
313+
return {decision_node, true};
293314
}
294315

295-
if (state.IsSimultaneousNode()) {
296-
const ActionView action_view(state);
316+
if (state->IsSimultaneousNode()) {
317+
const ActionView action_view(*state);
297318
for (int i = 0; i < action_view.legal_actions[acting_player_].size();
298319
++i) {
299320
InfostateNode* observation_node = decision_node->child_at(i);
300321
SPIEL_DCHECK_EQ(observation_node->type(), kObservationInfostateNode);
301322

302323
for (Action flat_actions :
303324
action_view.fixed_action(acting_player_, i)) {
304-
std::unique_ptr<State> child = state.Child(flat_actions);
305-
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
306-
chance_reach_prob);
325+
auto child_state = std::shared_ptr{state->Child(flat_actions)};
326+
// Only now we can advance the state, when we have all actions.
327+
RecursivelyBuildTree(observation_node, depth + 2, child_state,
328+
move_limit, chance_reach_prob);
329+
if (store_all_world_states_ &&
330+
!observation_node->is_filler_node()) {
331+
UpdateNode(observation_node, std::move(child_state), depth + 2,
332+
chance_reach_prob);
333+
}
307334
}
308335
}
309336
} else {
310-
std::vector<Action> legal_actions = state.LegalActions(acting_player_);
337+
std::vector<Action> legal_actions = state->LegalActions(acting_player_);
311338
for (int i = 0; i < legal_actions.size(); ++i) {
312339
InfostateNode* observation_node = decision_node->child_at(i);
313340
SPIEL_DCHECK_EQ(observation_node->type(), kObservationInfostateNode);
314-
std::unique_ptr<State> child = state.Child(legal_actions.at(i));
315-
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
316-
chance_reach_prob);
341+
auto child_state = std::shared_ptr{state->Child(legal_actions.at(i))};
342+
// Only now we can advance the state, when we have all actions.
343+
RecursivelyBuildTree(observation_node, depth + 2, child_state,
344+
move_limit, chance_reach_prob);
345+
if (store_all_world_states_ &&
346+
!observation_node->is_filler_node()) {
347+
UpdateNode(observation_node, std::move(child_state), depth,
348+
chance_reach_prob);
349+
}
317350
}
318351
}
319352
} else { // The decision node was not found yet.
320-
decision_node = parent->AddChild(MakeNode(
321-
parent, kDecisionInfostateNode, info_state,
322-
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth, &state));
353+
decision_node = parent->AddChild(
354+
MakeNode(parent, kDecisionInfostateNode, info_state,
355+
/*terminal_utility=*/NAN,
356+
/*chance_reach_prob=*/NAN, depth, state.get()));
323357

324358
if (is_leaf_node) { // Do not build deeper.
325-
return UpdateLeafNode(decision_node, state, depth, chance_reach_prob);
359+
return {decision_node, true};
326360
}
327361

328362
// Build observation nodes right away after the decision node.
329363
// This is because the player might be acting multiple times in a row:
330364
// each time it might get some observations that branch the infostate
331365
// tree.
332366

333-
if (state.IsSimultaneousNode()) {
334-
ActionView action_view(state);
367+
if (state->IsSimultaneousNode()) {
368+
ActionView action_view(*state);
335369
for (int i = 0; i < action_view.legal_actions[acting_player_].size();
336370
++i) {
337371
// We build a dummy observation node.
@@ -344,89 +378,105 @@ void InfostateTree::BuildDecisionNode(InfostateNode* parent, size_t depth,
344378
/*infostate_string=*/kFillerInfostate,
345379
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth,
346380
/*originating_state=*/nullptr));
347-
348381
for (Action flat_actions :
349382
action_view.fixed_action(acting_player_, i)) {
383+
auto child_state = std::shared_ptr{state->Child(flat_actions)};
350384
// Only now we can advance the state, when we have all actions.
351-
std::unique_ptr<State> child = state.Child(flat_actions);
352-
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
353-
chance_reach_prob);
385+
RecursivelyBuildTree(observation_node, depth + 2, child_state,
386+
move_limit, chance_reach_prob);
387+
if (store_all_world_states_ &&
388+
!observation_node->is_filler_node()) {
389+
UpdateNode(observation_node, std::move(child_state), depth,
390+
chance_reach_prob);
391+
}
354392
}
355393
}
356394
} else { // Not a sim move node.
357-
for (Action a : state.LegalActions()) {
358-
std::unique_ptr<State> child = state.Child(a);
395+
for (Action a : state->LegalActions()) {
396+
std::shared_ptr child = state->Child(a);
359397
InfostateNode* observation_node = decision_node->AddChild(
360398
MakeNode(decision_node, kObservationInfostateNode,
361399
infostate_observer_->StringFrom(*child, acting_player_),
362400
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth,
363401
child.get()));
364-
RecursivelyBuildTree(observation_node, depth + 2, *child, move_limit,
402+
RecursivelyBuildTree(observation_node, depth + 2, child, move_limit,
365403
chance_reach_prob);
404+
if (store_all_world_states_ &&
405+
!observation_node->is_filler_node()) {
406+
UpdateNode(observation_node, std::move(child), depth,
407+
chance_reach_prob);
408+
}
366409
}
367410
}
368411
}
412+
return {decision_node, false};
369413
}
370414

371-
void InfostateTree::BuildObservationNode(InfostateNode* parent, size_t depth,
372-
const State& state, int move_limit,
373-
double chance_reach_prob) {
374-
SPIEL_DCHECK_TRUE(state.IsChanceNode() ||
375-
!state.IsPlayerActing(acting_player_));
376-
const bool is_leaf_node = state.MoveNumber() >= move_limit;
415+
std::pair<InfostateNode*, bool> InfostateTree::BuildObservationNode(
416+
InfostateNode* parent, size_t depth,
417+
const std::shared_ptr<const State>& state, int move_limit,
418+
double chance_reach_prob) {
419+
SPIEL_DCHECK_TRUE(state->IsChanceNode() ||
420+
!state->IsPlayerActing(acting_player_));
421+
const bool is_leaf_node = state->MoveNumber() >= move_limit;
377422
const std::string info_state =
378-
infostate_observer_->StringFrom(state, acting_player_);
423+
infostate_observer_->StringFrom(*state, acting_player_);
379424

380425
InfostateNode* observation_node = parent->GetChild(info_state);
381426
if (!observation_node) {
382-
observation_node = parent->AddChild(MakeNode(
383-
parent, kObservationInfostateNode, info_state,
384-
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth, &state));
427+
observation_node = parent->AddChild(
428+
MakeNode(parent, kObservationInfostateNode, info_state,
429+
/*terminal_utility=*/NAN, /*chance_reach_prob=*/NAN, depth,
430+
state.get()));
385431
}
386432
SPIEL_DCHECK_EQ(observation_node->type(), kObservationInfostateNode);
387433

388434
if (is_leaf_node) { // Do not build deeper.
389-
return UpdateLeafNode(observation_node, state, depth, chance_reach_prob);
435+
return {observation_node, true};
390436
}
391437

392-
if (state.IsChanceNode()) {
393-
for (std::pair<Action, double> action_prob : state.ChanceOutcomes()) {
394-
std::unique_ptr<State> child = state.Child(action_prob.first);
395-
RecursivelyBuildTree(observation_node, depth + 1, *child, move_limit,
438+
if (state->IsChanceNode()) {
439+
for (std::pair<Action, double> action_prob : state->ChanceOutcomes()) {
440+
RecursivelyBuildTree(observation_node, depth + 1,
441+
state->Child(action_prob.first), move_limit,
396442
chance_reach_prob * action_prob.second);
397443
}
398444
} else {
399-
for (Action a : state.LegalActions()) {
400-
std::unique_ptr<State> child = state.Child(a);
401-
RecursivelyBuildTree(observation_node, depth + 1, *child, move_limit,
402-
chance_reach_prob);
445+
for (Action a : state->LegalActions()) {
446+
RecursivelyBuildTree(observation_node, depth + 1, state->Child(a),
447+
move_limit, chance_reach_prob);
403448
}
404449
}
450+
return {observation_node, false};
405451
}
406452
int InfostateTree::root_branching_factor() const {
407453
return root_->num_children();
408454
}
409455

410456
std::shared_ptr<InfostateTree> MakeInfostateTree(const Game& game,
411457
Player acting_player,
458+
bool store_world_states,
412459
int max_move_limit) {
413460
// Uses new instead of make_shared, because shared_ptr is not a friend and
414461
// can't call private constructors.
415462
return std::shared_ptr<InfostateTree>(new InfostateTree(
416463
{game.NewInitialState().get()}, /*chance_reach_probs=*/{1.},
417-
game.MakeObserver(kInfoStateObsType, {}), acting_player, max_move_limit));
464+
game.MakeObserver(kInfoStateObsType, {}), acting_player,
465+
store_world_states, max_move_limit));
418466
}
419467

420468
std::shared_ptr<InfostateTree> MakeInfostateTree(
421-
const std::vector<InfostateNode*>& start_nodes, int max_move_ahead_limit) {
469+
const std::vector<InfostateNode*>& start_nodes, bool store_world_states,
470+
int max_move_ahead_limit) {
422471
std::vector<const InfostateNode*> const_nodes(start_nodes.begin(),
423472
start_nodes.end());
424-
return MakeInfostateTree(const_nodes, max_move_ahead_limit);
473+
return MakeInfostateTree(const_nodes, store_world_states,
474+
max_move_ahead_limit);
425475
}
426476

427477
std::shared_ptr<InfostateTree> MakeInfostateTree(
428478
const std::vector<const InfostateNode*>& start_nodes,
429-
int max_move_ahead_limit) {
479+
bool store_world_states, int max_move_ahead_limit) {
430480
SPIEL_CHECK_FALSE(start_nodes.empty());
431481
const InfostateNode* some_node = start_nodes[0];
432482
const InfostateTree& originating_tree = some_node->tree();
@@ -458,17 +508,18 @@ std::shared_ptr<InfostateTree> MakeInfostateTree(
458508
// can't call private constructors.
459509
return std::shared_ptr<InfostateTree>(new InfostateTree(
460510
start_states, chance_reach_probs, originating_tree.infostate_observer_,
461-
originating_tree.acting_player_, max_move_ahead_limit));
511+
originating_tree.acting_player_, store_world_states,
512+
max_move_ahead_limit));
462513
}
463514

464515
std::shared_ptr<InfostateTree> MakeInfostateTree(
465516
const std::vector<const State*>& start_states,
466517
const std::vector<double>& chance_reach_probs,
467518
std::shared_ptr<Observer> infostate_observer, Player acting_player,
468-
int max_move_ahead_limit) {
469-
return std::shared_ptr<InfostateTree>(
470-
new InfostateTree(start_states, chance_reach_probs, infostate_observer,
471-
acting_player, max_move_ahead_limit));
519+
bool store_world_states, int max_move_ahead_limit) {
520+
return std::shared_ptr<InfostateTree>(new InfostateTree(
521+
start_states, chance_reach_probs, std::move(infostate_observer),
522+
acting_player, store_world_states, max_move_ahead_limit));
472523
}
473524
SequenceId InfostateTree::empty_sequence() const {
474525
return root().sequence_id();

0 commit comments

Comments
 (0)