11import gym
2- from typing import Callable , Dict , List , Tuple , Type , Optional , Union
2+ import logging
3+ from typing import Callable , Dict , List , Tuple , Type , Optional , Union , Set
34
45from ray .rllib .env .base_env import BaseEnv
56from ray .rllib .env .env_context import EnvContext
6- from ray .rllib .utils .annotations import ExperimentalAPI , override , PublicAPI
7+ from ray .rllib .utils .annotations import ExperimentalAPI , override , PublicAPI , \
8+ DeveloperAPI
79from ray .rllib .utils .typing import AgentID , EnvID , EnvType , MultiAgentDict , \
810 MultiEnvDict
911
1012# If the obs space is Dict type, look for the global state under this key.
1113ENV_STATE = "state"
1214
15+ logger = logging .getLogger (__name__ )
16+
1317
1418@PublicAPI
1519class MultiAgentEnv (gym .Env ):
@@ -20,6 +24,15 @@ class MultiAgentEnv(gym.Env):
2024 referred to as "agents" or "RL agents".
2125 """
2226
27+ def __init__ (self ):
28+ self .observation_space = None
29+ self .action_space = None
30+ self ._agent_ids = {}
31+
32+ # do the action and observation spaces map from agent ids to spaces
33+ # for the individual agents?
34+ self ._spaces_in_preferred_format = None
35+
2336 @PublicAPI
2437 def reset (self ) -> MultiAgentDict :
2538 """Resets the env and returns observations from ready agents.
@@ -81,20 +94,127 @@ def step(
8194 """
8295 raise NotImplementedError
8396
97+ @ExperimentalAPI
98+ def observation_space_contains (self , x : MultiAgentDict ) -> bool :
99+ """Checks if the observation space contains the given key.
100+
101+ Args:
102+ x: Observations to check.
103+
104+ Returns:
105+ True if the observation space contains the given all observations
106+ in x.
107+ """
108+ if not hasattr (self , "_spaces_in_preferred_format" ) or \
109+ self ._spaces_in_preferred_format is None :
110+ self ._spaces_in_preferred_format = \
111+ self ._check_if_space_maps_agent_id_to_sub_space ()
112+ if self ._spaces_in_preferred_format :
113+ return self .observation_space .contains (x )
114+
115+ logger .warning ("observation_space_contains() has not been implemented" )
116+ return True
117+
118+ @ExperimentalAPI
119+ def action_space_contains (self , x : MultiAgentDict ) -> bool :
120+ """Checks if the action space contains the given action.
121+
122+ Args:
123+ x: Actions to check.
124+
125+ Returns:
126+ True if the action space contains all actions in x.
127+ """
128+ if not hasattr (self , "_spaces_in_preferred_format" ) or \
129+ self ._spaces_in_preferred_format is None :
130+ self ._spaces_in_preferred_format = \
131+ self ._check_if_space_maps_agent_id_to_sub_space ()
132+ if self ._spaces_in_preferred_format :
133+ return self .action_space .contains (x )
134+
135+ logger .warning ("action_space_contains() has not been implemented" )
136+ return True
137+
138+ @ExperimentalAPI
139+ def action_space_sample (self , agent_ids : list = None ) -> MultiAgentDict :
140+ """Returns a random action for each environment, and potentially each
141+ agent in that environment.
142+
143+ Args:
144+ agent_ids: List of agent ids to sample actions for. If None or
145+ empty list, sample actions for all agents in the
146+ environment.
147+
148+ Returns:
149+ A random action for each environment.
150+ """
151+ if not hasattr (self , "_spaces_in_preferred_format" ) or \
152+ self ._spaces_in_preferred_format is None :
153+ self ._spaces_in_preferred_format = \
154+ self ._check_if_space_maps_agent_id_to_sub_space ()
155+ if self ._spaces_in_preferred_format :
156+ if agent_ids is None :
157+ agent_ids = self .get_agent_ids ()
158+ samples = self .action_space .sample ()
159+ return {agent_id : samples [agent_id ] for agent_id in agent_ids }
160+ logger .warning ("action_space_sample() has not been implemented" )
161+ del agent_ids
162+ return {}
163+
164+ @ExperimentalAPI
165+ def observation_space_sample (self , agent_ids : list = None ) -> MultiEnvDict :
166+ """Returns a random observation from the observation space for each
167+ agent if agent_ids is None, otherwise returns a random observation for
168+ the agents in agent_ids.
169+
170+ Args:
171+ agent_ids: List of agent ids to sample actions for. If None or
172+ empty list, sample actions for all agents in the
173+ environment.
174+
175+ Returns:
176+ A random action for each environment.
177+ """
178+
179+ if not hasattr (self , "_spaces_in_preferred_format" ) or \
180+ self ._spaces_in_preferred_format is None :
181+ self ._spaces_in_preferred_format = \
182+ self ._check_if_space_maps_agent_id_to_sub_space ()
183+ if self ._spaces_in_preferred_format :
184+ if agent_ids is None :
185+ agent_ids = self .get_agent_ids ()
186+ samples = self .observation_space .sample ()
187+ samples = {agent_id : samples [agent_id ] for agent_id in agent_ids }
188+ return samples
189+ logger .warning ("observation_space_sample() has not been implemented" )
190+ del agent_ids
191+ return {}
192+
193+ @PublicAPI
194+ def get_agent_ids (self ) -> Set [AgentID ]:
195+ """Returns a set of agent ids in the environment.
196+
197+ Returns:
198+ set of agent ids.
199+ """
200+ if not isinstance (self ._agent_ids , set ):
201+ self ._agent_ids = set (self ._agent_ids )
202+ return self ._agent_ids
203+
84204 @PublicAPI
85205 def render (self , mode = None ) -> None :
86206 """Tries to render the environment."""
87207
88208 # By default, do nothing.
89209 pass
90210
91- # yapf: disable
92- # __grouping_doc_begin__
211+ # yapf: disable
212+ # __grouping_doc_begin__
93213 @ExperimentalAPI
94214 def with_agent_groups (
95- self ,
96- groups : Dict [str , List [AgentID ]],
97- obs_space : gym .Space = None ,
215+ self ,
216+ groups : Dict [str , List [AgentID ]],
217+ obs_space : gym .Space = None ,
98218 act_space : gym .Space = None ) -> "MultiAgentEnv" :
99219 """Convenience method for grouping together agents in this env.
100220
@@ -132,8 +252,9 @@ def with_agent_groups(
132252 from ray .rllib .env .wrappers .group_agents_wrapper import \
133253 GroupAgentsWrapper
134254 return GroupAgentsWrapper (self , groups , obs_space , act_space )
135- # __grouping_doc_end__
136- # yapf: enable
255+
256+ # __grouping_doc_end__
257+ # yapf: enable
137258
138259 @PublicAPI
139260 def to_base_env (
@@ -182,6 +303,20 @@ def to_base_env(
182303
183304 return env
184305
306+ @DeveloperAPI
307+ def _check_if_space_maps_agent_id_to_sub_space (self ) -> bool :
308+ # do the action and observation spaces map from agent ids to spaces
309+ # for the individual agents?
310+ obs_space_check = (
311+ hasattr (self , "observation_space" )
312+ and isinstance (self .observation_space , gym .spaces .Dict )
313+ and set (self .observation_space .keys ()) == self .get_agent_ids ())
314+ action_space_check = (
315+ hasattr (self , "action_space" )
316+ and isinstance (self .action_space , gym .spaces .Dict )
317+ and set (self .action_space .keys ()) == self .get_agent_ids ())
318+ return obs_space_check and action_space_check
319+
185320
186321def make_multi_agent (
187322 env_name_or_creator : Union [str , Callable [[EnvContext ], EnvType ]],
@@ -242,6 +377,40 @@ def __init__(self, config=None):
242377 self .dones = set ()
243378 self .observation_space = self .agents [0 ].observation_space
244379 self .action_space = self .agents [0 ].action_space
380+ self ._agent_ids = set (range (num ))
381+
382+ @override (MultiAgentEnv )
383+ def observation_space_sample (self ,
384+ agent_ids : list = None ) -> MultiAgentDict :
385+ if agent_ids is None :
386+ agent_ids = list (range (len (self .agents )))
387+ obs = {
388+ agent_id : self .observation_space .sample ()
389+ for agent_id in agent_ids
390+ }
391+
392+ return obs
393+
394+ @override (MultiAgentEnv )
395+ def action_space_sample (self ,
396+ agent_ids : list = None ) -> MultiAgentDict :
397+ if agent_ids is None :
398+ agent_ids = list (range (len (self .agents )))
399+ actions = {
400+ agent_id : self .action_space .sample ()
401+ for agent_id in agent_ids
402+ }
403+
404+ return actions
405+
406+ @override (MultiAgentEnv )
407+ def action_space_contains (self , x : MultiAgentDict ) -> bool :
408+ return all (self .action_space .contains (val ) for val in x .values ())
409+
410+ @override (MultiAgentEnv )
411+ def observation_space_contains (self , x : MultiAgentDict ) -> bool :
412+ return all (
413+ self .observation_space .contains (val ) for val in x .values ())
245414
246415 @override (MultiAgentEnv )
247416 def reset (self ):
@@ -277,7 +446,7 @@ def __init__(self, make_env: Callable[[int], EnvType],
277446
278447 Args:
279448 make_env (Callable[[int], EnvType]): Factory that produces a new
280- MultiAgentEnv intance . Must be defined, if the number of
449+ MultiAgentEnv instance . Must be defined, if the number of
281450 existing envs is less than num_envs.
282451 existing_envs (List[MultiAgentEnv]): List of already existing
283452 multi-agent envs.
@@ -355,18 +524,31 @@ def try_render(self, env_id: Optional[EnvID] = None) -> None:
355524 @override (BaseEnv )
356525 @PublicAPI
357526 def observation_space (self ) -> gym .spaces .Dict :
358- space = {
359- _id : env .observation_space
360- for _id , env in enumerate (self .envs )
361- }
362- return gym .spaces .Dict (space )
527+ self .envs [0 ].observation_space
363528
364529 @property
365530 @override (BaseEnv )
366531 @PublicAPI
367532 def action_space (self ) -> gym .Space :
368- space = {_id : env .action_space for _id , env in enumerate (self .envs )}
369- return gym .spaces .Dict (space )
533+ return self .envs [0 ].action_space
534+
535+ @override (BaseEnv )
536+ def observation_space_contains (self , x : MultiEnvDict ) -> bool :
537+ return all (
538+ self .envs [0 ].observation_space_contains (val ) for val in x .values ())
539+
540+ @override (BaseEnv )
541+ def action_space_contains (self , x : MultiEnvDict ) -> bool :
542+ return all (
543+ self .envs [0 ].action_space_contains (val ) for val in x .values ())
544+
545+ @override (BaseEnv )
546+ def observation_space_sample (self , agent_ids : list = None ) -> MultiEnvDict :
547+ return self .envs [0 ].observation_space_sample (agent_ids )
548+
549+ @override (BaseEnv )
550+ def action_space_sample (self , agent_ids : list = None ) -> MultiEnvDict :
551+ return self .envs [0 ].action_space_sample (agent_ids )
370552
371553
372554class _MultiAgentEnvState :
0 commit comments