     register_connector,
 )
 from ray.rllib.policy.sample_batch import SampleBatch
-from ray.rllib.utils.typing import AgentConnectorDataType, AgentConnectorsOutput
+from ray.rllib.utils.typing import (
+    AgentConnectorDataType,
+    AgentConnectorsOutput,
+)
 from ray.util.annotations import PublicAPI
+from ray.rllib.evaluation.collectors.agent_collector import AgentCollector


 @PublicAPI(stability="alpha")
@@ -32,11 +36,24 @@ def __init__(self, ctx: ConnectorContext):
         super().__init__(ctx)

         self._view_requirements = ctx.view_requirements
-        self._agent_data = defaultdict(lambda: defaultdict(SampleBatch))
+        # A dict mapping env_id to a dict mapping agent_id to an AgentCollector.
+        self.agent_collectors = defaultdict(
+            lambda: defaultdict(
+                lambda: AgentCollector(
+                    self._view_requirements,
+                    max_seq_len=ctx.config["model"]["max_seq_len"],
+                    intial_states=ctx.initial_states,
+                    disable_action_flattening=ctx.config.get(
+                        "_disable_action_flattening", False
+                    ),
+                    is_policy_recurrent=ctx.is_policy_recurrent,
+                )
+            )
+        )

     def reset(self, env_id: str):
-        if env_id in self._agent_data:
-            del self._agent_data[env_id]
+        if env_id in self.agent_collectors:
+            del self.agent_collectors[env_id]

     def _get_sample_batch_for_action(
         self, view_requirements, agent_batch
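
For reference, the buffer introduced above is a two-level defaultdict: a collector is built lazily the first time an (env_id, agent_id) pair is seen, and each env_id owns its own inner dict, which is what makes the reset() deletion meaningful. A minimal sketch of that access pattern, using a hypothetical FakeCollector stand-in rather than the real AgentCollector:

    from collections import defaultdict

    class FakeCollector:
        # Stand-in for AgentCollector; only here to show the lazy construction.
        def __init__(self):
            self.rows = []

    # env_id -> agent_id -> collector, mirroring self.agent_collectors above.
    collectors = defaultdict(lambda: defaultdict(FakeCollector))

    # First access for an (env_id, agent_id) pair creates a fresh collector.
    collectors["env_0"]["agent_0"].rows.append("first_obs")

    # Every env_id gets its own inner dict, so dropping one env on reset()
    # leaves the collectors of all other envs untouched.
    assert collectors["env_0"] is not collectors["env_1"]
    del collectors["env_0"]
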
@@ -61,6 +78,9 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:

         env_id = ac_data.env_id
         agent_id = ac_data.agent_id
+        # TODO: we don't keep episode_id around, so use env_id as episode_id?
+        episode_id = env_id if SampleBatch.EPS_ID not in d else d[SampleBatch.EPS_ID]
+
         assert env_id is not None and agent_id is not None, (
             f"ViewRequirementAgentConnector requires env_id({env_id}) "
             "and agent_id({agent_id})"
@@ -77,40 +97,22 @@ def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
         # view_requirement.used_for_training is False.
         training_dict = d

-        # Agent batch is our buffer of necessary history for computing
-        # a SampleBatch for policy forward pass.
-        # This is used by both training and inference.
-        agent_batch = self._agent_data[env_id][agent_id]
-        for col, req in vr.items():
-            # Not used for action computation.
-            if not req.used_for_compute_actions:
-                continue
-
-            # Create the batch of data from the different buffers.
-            if col == SampleBatch.OBS:
-                # NEXT_OBS from the training sample is the current OBS
-                # to run Policy with.
-                data_col = SampleBatch.NEXT_OBS
-            else:
-                data_col = req.data_col or col
-                if data_col not in d:
-                    continue
-
-            if col not in agent_batch:
-                agent_batch[col] = []
-            # Stack along batch dim.
-            agent_batch[col].append(d[data_col])
-
-            # Only keep the useful part of the history.
-            h = -1
-            if req.shift_from is not None:
-                h = req.shift_from
-            elif type(req.shift) == int:
-                h = req.shift
-            assert h <= 0, "Cannot use future data to compute action"
-            agent_batch[col] = agent_batch[col][h:]
-
-        sample_batch = self._get_sample_batch_for_action(vr, agent_batch)
+        agent_collector = self.agent_collectors[env_id][agent_id]
+
+        if SampleBatch.NEXT_OBS not in d:
+            raise ValueError(f"connector data {d} should contain next_obs.")
+
+        if agent_collector.is_empty():
+            agent_collector.add_init_obs(
+                episode_id=episode_id,
+                agent_index=agent_id,
+                env_id=env_id,
+                t=-1,
+                init_obs=d[SampleBatch.NEXT_OBS],
+            )
+        else:
+            agent_collector.add_action_reward_next_obs(d)
+        sample_batch = agent_collector.build_for_inference()

         return_data = AgentConnectorDataType(
             env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
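
Taken together, the new transform() hands all trajectory bookkeeping to AgentCollector: the first observation of an episode seeds the collector, later steps are appended, and a SampleBatch for the policy forward pass is rebuilt on every call. A rough usage sketch under that assumption; the helper name and per-step dicts are hypothetical, and only calls visible in the diff (is_empty, add_init_obs, add_action_reward_next_obs, build_for_inference) are used:

    from ray.rllib.policy.sample_batch import SampleBatch

    def run_steps_through_collector(collector, episode_steps, env_id, agent_id):
        # `collector` is assumed to be an AgentCollector; `episode_steps` is a
        # hypothetical list of per-step dicts, each containing SampleBatch.NEXT_OBS.
        for step_data in episode_steps:
            if collector.is_empty():
                # The first observation of the episode seeds the buffers at t=-1.
                collector.add_init_obs(
                    episode_id=step_data.get(SampleBatch.EPS_ID, env_id),
                    agent_index=agent_id,
                    env_id=env_id,
                    t=-1,
                    init_obs=step_data[SampleBatch.NEXT_OBS],
                )
            else:
                # Later steps append action/reward/next_obs to the running trajectory.
                collector.add_action_reward_next_obs(step_data)
            # Every call yields a SampleBatch for the policy forward pass.
            yield collector.build_for_inference()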