1+ import warnings
12from abc import ABC
23from typing import Any , Dict , List , Tuple , Union
34
45import gym .spaces
6+ import pettingzoo
7+ from packaging import version
58from pettingzoo .utils .env import AECEnv
69from pettingzoo .utils .wrappers import BaseWrapper
710
# Import-time check: warn users on PettingZoo releases older than 1.21.0,
# because later tianshou versions may stop supporting them.
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
    warnings.warn(
        f"You are using PettingZoo {pettingzoo.__version__}. "
        "Future tianshou versions may not support PettingZoo<1.21.0. "
        "Consider upgrading your PettingZoo version.",
        DeprecationWarning,
    )
17+
818
919class PettingZooEnv (AECEnv , ABC ):
1020 """The interface for petting zoo environments.
@@ -57,7 +67,20 @@ def __init__(self, env: BaseWrapper):
5767
5868 def reset (self , * args : Any , ** kwargs : Any ) -> Union [dict , Tuple [dict , dict ]]:
5969 self .env .reset (* args , ** kwargs )
60- observation , _ , _ , info = self .env .last (self )
70+
71+ # Here, we do not label the return values explicitly to keep compatibility with
72+ # old step API. TODO: Change once PettingZoo>=1.21.0 is required
73+ last_return = self .env .last (self )
74+
75+ if len (last_return ) == 4 :
76+ warnings .warn (
77+ "The PettingZoo environment is using the old step API. "
78+ "This API may not be supported in future versions of tianshou. "
79+ "We recommend that you update the environment code or apply a "
80+ "compatibility wrapper." , DeprecationWarning
81+ )
82+
83+ observation , info = last_return [0 ], last_return [- 1 ]
6184 if isinstance (observation , dict ) and 'action_mask' in observation :
6285 observation_dict = {
6386 'agent_id' : self .env .agent_selection ,
@@ -83,9 +106,16 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
83106 else :
84107 return observation_dict
85108
86- def step (self , action : Any ) -> Tuple [Dict , List [int ], bool , Dict ]:
109+ def step (
110+ self , action : Any
111+ ) -> Union [Tuple [Dict , List [int ], bool , Dict ], Tuple [Dict , List [int ], bool , bool ,
112+ Dict ]]:
87113 self .env .step (action )
88- observation , rew , done , info = self .env .last ()
114+
115+ # Here, we do not label the return values explicitly to keep compatibility with
116+ # old step API. TODO: Change once PettingZoo>=1.21.0 is required
117+ last_return = self .env .last ()
118+ observation = last_return [0 ]
89119 if isinstance (observation , dict ) and 'action_mask' in observation :
90120 obs = {
91121 'agent_id' : self .env .agent_selection ,
@@ -105,15 +135,15 @@ def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
105135
106136 for agent_id , reward in self .env .rewards .items ():
107137 self .rewards [self .agent_idx [agent_id ]] = reward
108- return obs , self .rewards , done , info
138+ return ( obs , self .rewards , * last_return [ 2 :]) # type: ignore
109139
def close(self) -> None:
    """Close the wrapped environment.

    Delegates to ``self.env.close()``; nothing is returned.
    """
    wrapped_env = self.env
    wrapped_env.close()
112142
def seed(self, seed: Any = None) -> None:
    """Seed the wrapped environment.

    First tries ``self.env.seed(seed)``. If that raises
    ``NotImplementedError`` or ``AttributeError``, the seed is handed to
    ``self.env.reset(seed=seed)`` instead, which also resets the
    wrapped environment.

    :param seed: value forwarded unchanged to the wrapped environment.
    """
    try:
        self.env.seed(seed)
    except (NotImplementedError, AttributeError):
        # The wrapped env has no usable ``seed`` method; seed it through
        # ``reset`` instead.
        self.env.reset(seed=seed)
118148
119149 def render (self , mode : str = "human" ) -> Any :
0 commit comments