Skip to content

Commit 8ddcf89

Browse files
authored
[RLlib] Implemented ViewRequirementConnector (ray-project#26998)
1 parent 837ef77 commit 8ddcf89

File tree

11 files changed

+970
-630
lines changed

11 files changed

+970
-630
lines changed

rllib/connectors/agent/view_requirement.py

Lines changed: 40 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,12 @@
77
register_connector,
88
)
99
from ray.rllib.policy.sample_batch import SampleBatch
10-
from ray.rllib.utils.typing import AgentConnectorDataType, AgentConnectorsOutput
10+
from ray.rllib.utils.typing import (
11+
AgentConnectorDataType,
12+
AgentConnectorsOutput,
13+
)
1114
from ray.util.annotations import PublicAPI
15+
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
1216

1317

1418
@PublicAPI(stability="alpha")
@@ -32,11 +36,24 @@ def __init__(self, ctx: ConnectorContext):
3236
super().__init__(ctx)
3337

3438
self._view_requirements = ctx.view_requirements
35-
self._agent_data = defaultdict(lambda: defaultdict(SampleBatch))
39+
40+
# a dict of env_id to a dict of agent_id to an AgentCollector object
41+
env_default = defaultdict(
42+
lambda: AgentCollector(
43+
self._view_requirements,
44+
max_seq_len=ctx.config["model"]["max_seq_len"],
45+
intial_states=ctx.initial_states,
46+
disable_action_flattening=ctx.config.get(
47+
"_disable_action_flattening", False
48+
),
49+
is_policy_recurrent=ctx.is_policy_recurrent,
50+
)
51+
)
52+
self.agent_collectors = defaultdict(lambda: env_default)
3653

3754
def reset(self, env_id: str):
38-
if env_id in self._agent_data:
39-
del self._agent_data[env_id]
55+
if env_id in self.agent_collectors:
56+
del self.agent_collectors[env_id]
4057

4158
def _get_sample_batch_for_action(
4259
self, view_requirements, agent_batch
@@ -61,6 +78,9 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
6178

6279
env_id = ac_data.env_id
6380
agent_id = ac_data.agent_id
81+
# TODO: episode_id is not tracked by this connector, so fall back to env_id as the episode_id when the data does not carry one.
82+
episode_id = env_id if SampleBatch.EPS_ID not in d else d[SampleBatch.EPS_ID]
83+
6484
assert env_id is not None and agent_id is not None, (
6585
f"ViewRequirementAgentConnector requires env_id({env_id}) "
6686
"and agent_id({agent_id})"
@@ -77,40 +97,22 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
7797
# view_requirement.used_for_training is False.
7898
training_dict = d
7999

80-
# Agent batch is our buffer of necessary history for computing
81-
# a SampleBatch for policy forward pass.
82-
# This is used by both training and inference.
83-
agent_batch = self._agent_data[env_id][agent_id]
84-
for col, req in vr.items():
85-
# Not used for action computation.
86-
if not req.used_for_compute_actions:
87-
continue
88-
89-
# Create the batch of data from the different buffers.
90-
if col == SampleBatch.OBS:
91-
# NEXT_OBS from the training sample is the current OBS
92-
# to run Policy with.
93-
data_col = SampleBatch.NEXT_OBS
94-
else:
95-
data_col = req.data_col or col
96-
if data_col not in d:
97-
continue
98-
99-
if col not in agent_batch:
100-
agent_batch[col] = []
101-
# Stack along batch dim.
102-
agent_batch[col].append(d[data_col])
103-
104-
# Only keep the useful part of the history.
105-
h = -1
106-
if req.shift_from is not None:
107-
h = req.shift_from
108-
elif type(req.shift) == int:
109-
h = req.shift
110-
assert h <= 0, "Cannot use future data to compute action"
111-
agent_batch[col] = agent_batch[col][h:]
112-
113-
sample_batch = self._get_sample_batch_for_action(vr, agent_batch)
100+
agent_collector = self.agent_collectors[env_id][agent_id]
101+
102+
if SampleBatch.NEXT_OBS not in d:
103+
raise ValueError(f"connector data {d} should contain next_obs.")
104+
105+
if agent_collector.is_empty():
106+
agent_collector.add_init_obs(
107+
episode_id=episode_id,
108+
agent_index=agent_id,
109+
env_id=env_id,
110+
t=-1,
111+
init_obs=d[SampleBatch.NEXT_OBS],
112+
)
113+
else:
114+
agent_collector.add_action_reward_next_obs(d)
115+
sample_batch = agent_collector.build_for_inference()
114116

115117
return_data = AgentConnectorDataType(
116118
env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)

rllib/connectors/connector.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -38,25 +38,27 @@ class ConnectorContext:
3838
def __init__(
3939
self,
4040
config: AlgorithmConfigDict = None,
41-
model_initial_states: List[TensorType] = None,
41+
initial_states: List[TensorType] = None,
4242
observation_space: gym.Space = None,
4343
action_space: gym.Space = None,
4444
view_requirements: Dict[str, ViewRequirement] = None,
45+
is_policy_recurrent: bool = False,
4546
):
4647
"""Construct a ConnectorContext instance.
4748
4849
Args:
49-
model_initial_states: States that are used for constructing
50+
initial_states: States that are used for constructing
5051
the initial input dict for RNN models. [] if a model is not recurrent.
5152
action_space_struct: a policy's action space, in python
5253
data format. E.g., python dict instead of DictSpace, python tuple
5354
instead of TupleSpace.
5455
"""
5556
self.config = config or {}
56-
self.initial_states = model_initial_states or []
57+
self.initial_states = initial_states or []
5758
self.observation_space = observation_space
5859
self.action_space = action_space
5960
self.view_requirements = view_requirements
61+
self.is_policy_recurrent = is_policy_recurrent
6062

6163
@staticmethod
6264
def from_policy(policy: "Policy") -> "ConnectorContext":
@@ -69,11 +71,12 @@ def from_policy(policy: "Policy") -> "ConnectorContext":
6971
A ConnectorContext instance.
7072
"""
7173
return ConnectorContext(
72-
policy.config,
73-
policy.get_initial_state(),
74-
policy.observation_space,
75-
policy.action_space,
76-
policy.view_requirements,
74+
config=policy.config,
75+
initial_states=policy.get_initial_state(),
76+
observation_space=policy.observation_space,
77+
action_space=policy.action_space,
78+
view_requirements=policy.view_requirements,
79+
is_policy_recurrent=policy.is_recurrent(),
7780
)
7881

7982

0 commit comments

Comments
 (0)