Skip to content

Commit 128feb6

Browse files
authored
Added support for new PettingZoo API (#751)
1 parent b0c8d28 commit 128feb6

File tree

3 files changed

+35
-9
lines changed

3 files changed

+35
-9
lines changed

test/pettingzoo/test_pistonball.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import pprint
22

3-
import pytest
43
from pistonball import get_args, train_agent, watch
54

65

7-
@pytest.mark.skip(reason="TODO(Markus28): fix later")
86
def test_piston_ball(args=get_args()):
97
if args.watch:
108
watch(args)

test/pettingzoo/test_tic_tac_toe.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import pprint
22

3-
import pytest
43
from tic_tac_toe import get_args, train_agent, watch
54

65

7-
@pytest.mark.skip(reason="TODO(Markus28): fix later")
86
def test_tic_tac_toe(args=get_args()):
97
if args.watch:
108
watch(args)

tianshou/env/pettingzoo_env.py

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,20 @@
1+
import warnings
12
from abc import ABC
23
from typing import Any, Dict, List, Tuple, Union
34

45
import gym.spaces
6+
import pettingzoo
7+
from packaging import version
58
from pettingzoo.utils.env import AECEnv
69
from pettingzoo.utils.wrappers import BaseWrapper
710

11+
if version.parse(pettingzoo.__version__) < version.parse("1.21.0"):
12+
warnings.warn(
13+
f"You are using PettingZoo {pettingzoo.__version__}. "
14+
f"Future tianshou versions may not support PettingZoo<1.21.0. "
15+
f"Consider upgrading your PettingZoo version.", DeprecationWarning
16+
)
17+
818

919
class PettingZooEnv(AECEnv, ABC):
1020
"""The interface for petting zoo environments.
@@ -57,7 +67,20 @@ def __init__(self, env: BaseWrapper):
5767

5868
def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
5969
self.env.reset(*args, **kwargs)
60-
observation, _, _, info = self.env.last(self)
70+
71+
# Here, we do not label the return values explicitly to keep compatibility with
72+
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
73+
last_return = self.env.last(self)
74+
75+
if len(last_return) == 4:
76+
warnings.warn(
77+
"The PettingZoo environment is using the old step API. "
78+
"This API may not be supported in future versions of tianshou. "
79+
"We recommend that you update the environment code or apply a "
80+
"compatibility wrapper.", DeprecationWarning
81+
)
82+
83+
observation, info = last_return[0], last_return[-1]
6184
if isinstance(observation, dict) and 'action_mask' in observation:
6285
observation_dict = {
6386
'agent_id': self.env.agent_selection,
@@ -83,9 +106,16 @@ def reset(self, *args: Any, **kwargs: Any) -> Union[dict, Tuple[dict, dict]]:
83106
else:
84107
return observation_dict
85108

86-
def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
109+
def step(
110+
self, action: Any
111+
) -> Union[Tuple[Dict, List[int], bool, Dict], Tuple[Dict, List[int], bool, bool,
112+
Dict]]:
87113
self.env.step(action)
88-
observation, rew, done, info = self.env.last()
114+
115+
# Here, we do not label the return values explicitly to keep compatibility with
116+
# old step API. TODO: Change once PettingZoo>=1.21.0 is required
117+
last_return = self.env.last()
118+
observation = last_return[0]
89119
if isinstance(observation, dict) and 'action_mask' in observation:
90120
obs = {
91121
'agent_id': self.env.agent_selection,
@@ -105,15 +135,15 @@ def step(self, action: Any) -> Tuple[Dict, List[int], bool, Dict]:
105135

106136
for agent_id, reward in self.env.rewards.items():
107137
self.rewards[self.agent_idx[agent_id]] = reward
108-
return obs, self.rewards, done, info
138+
return (obs, self.rewards, *last_return[2:]) # type: ignore
109139

110140
def close(self) -> None:
111141
self.env.close()
112142

113143
def seed(self, seed: Any = None) -> None:
114144
try:
115145
self.env.seed(seed)
116-
except NotImplementedError:
146+
except (NotImplementedError, AttributeError):
117147
self.env.reset(seed=seed)
118148

119149
def render(self, mode: str = "human") -> Any:

0 commit comments

Comments
 (0)