Skip to content

Commit 99c99bb

Browse files
authored
Fix 2 bugs and refactor RunningMeanStd to support dict obs norm (#695)
* fix #689 * fix #672 * refactor RMS class * fix #688
1 parent 6505484 commit 99c99bb

File tree

9 files changed

+59
-20
lines changed

9 files changed

+59
-20
lines changed

docs/tutorials/cheatsheet.rst

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
Cheat Sheet
22
===========
33

4-
This page shows some code snippets of how to use Tianshou to develop new algorithms / apply algorithms to new scenarios.
4+
This page shows some code snippets of how to use Tianshou to develop new
5+
algorithms / apply algorithms to new scenarios.
56

6-
By the way, some of these issues can be resolved by using a ``gym.Wrapper``. It could be a universal solution in the policy-environment interaction. But you can also use the batch processor :ref:`preprocess_fn`.
7+
By the way, some of these issues can be resolved by using a ``gym.Wrapper``.
8+
It could be a universal solution in the policy-environment interaction. But
9+
you can also use the batch processor :ref:`preprocess_fn` or vectorized
10+
environment wrapper :class:`~tianshou.env.VectorEnvWrapper`.
711

812

913
.. _network_api:
@@ -22,6 +26,18 @@ Build New Policy
2226
See :class:`~tianshou.policy.BasePolicy`.
2327

2428

29+
.. _eval_policy:
30+
31+
Manually Evaluate Policy
32+
------------------------
33+
34+
If you'd like to manually see the action generated by a well-trained agent:
35+
::
36+
37+
# assume obs is a single environment observation
38+
action = policy(Batch(obs=np.array([obs]))).act[0]
39+
40+
2541
.. _customize_training:
2642

2743
Customize Training Process

docs/tutorials/dqn.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,12 @@ Watch the Agent's Performance
256256
collector = ts.data.Collector(policy, env, exploration_noise=True)
257257
collector.collect(n_episode=1, render=1 / 35)
258258

259+
If you'd like to manually see the action generated by a well-trained agent:
260+
::
261+
262+
# assume obs is a single environment observation
263+
action = policy(Batch(obs=np.array([obs]))).act[0]
264+
259265

260266
.. _customized_trainer:
261267

test/base/test_env.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -211,6 +211,14 @@ def assert_get(v, expected):
211211
v.close()
212212

213213

214+
def test_attr_unwrapped():
215+
train_envs = DummyVectorEnv([lambda: gym.make("CartPole-v1")])
216+
train_envs.set_env_attr("test_attribute", 1337)
217+
assert train_envs.get_env_attr("test_attribute") == [1337]
218+
assert hasattr(train_envs.workers[0].env, "test_attribute")
219+
assert hasattr(train_envs.workers[0].env.unwrapped, "test_attribute")
220+
221+
214222
def test_env_obs_dtype():
215223
for obs_type in ["array", "object"]:
216224
envs = SubprocVectorEnv(
@@ -349,6 +357,7 @@ def test_venv_wrapper_envpool_gym_reset_return_info():
349357
test_venv_wrapper_envpool()
350358
test_env_obs_dtype()
351359
test_vecenv()
360+
test_attr_unwrapped()
352361
test_async_env()
353362
test_async_check_id()
354363
test_env_reset_optional_kwargs()

tianshou/data/collector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@ def reset_env(self, gym_reset_kwargs: Optional[Dict[str, Any]] = None) -> None:
145145
)
146146
obs = processed_data.get("obs", obs)
147147
info = processed_data.get("info", info)
148-
self.data.info = info
148+
self.data.info = info
149149
else:
150150
obs = rval
151151
if self.preprocess_fn:

tianshou/env/venv_wrappers.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -68,24 +68,17 @@ class VectorEnvNormObs(VectorEnvWrapper):
6868
"""An observation normalization wrapper for vectorized environments.
6969
7070
:param bool update_obs_rms: whether to update obs_rms. Default to True.
71-
:param float clip_obs: the maximum absolute value for observation. Default to
72-
10.0.
73-
:param float epsilon: To avoid division by zero.
7471
"""
7572

7673
def __init__(
7774
self,
7875
venv: BaseVectorEnv,
7976
update_obs_rms: bool = True,
80-
clip_obs: float = 10.0,
81-
epsilon: float = np.finfo(np.float32).eps.item(),
8277
) -> None:
8378
super().__init__(venv)
8479
# initialize observation running mean/std
8580
self.update_obs_rms = update_obs_rms
8681
self.obs_rms = RunningMeanStd()
87-
self.clip_max = clip_obs
88-
self.eps = epsilon
8982

9083
def reset(
9184
self,
@@ -127,8 +120,7 @@ def step(
127120

128121
def _norm_obs(self, obs: np.ndarray) -> np.ndarray:
129122
if self.obs_rms:
130-
obs = (obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.eps)
131-
obs = np.clip(obs, -self.clip_max, self.clip_max)
123+
return self.obs_rms.norm(obs) # type: ignore
132124
return obs
133125

134126
def set_obs_rms(self, obs_rms: RunningMeanStd) -> None:

tianshou/env/worker/dummy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def get_env_attr(self, key: str) -> Any:
1717
return getattr(self.env, key)
1818

1919
def set_env_attr(self, key: str, value: Any) -> None:
20-
setattr(self.env, key, value)
20+
setattr(self.env.unwrapped, key, value)
2121

2222
def reset(self, **kwargs: Any) -> Union[np.ndarray, Tuple[np.ndarray, dict]]:
2323
if "seed" in kwargs:

tianshou/env/worker/ray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
class _SetAttrWrapper(gym.Wrapper):
1515

1616
def set_env_attr(self, key: str, value: Any) -> None:
17-
setattr(self.env, key, value)
17+
setattr(self.env.unwrapped, key, value)
1818

1919
def get_env_attr(self, key: str) -> Any:
2020
return getattr(self.env, key)

tianshou/env/worker/subproc.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def _setup_buf(space: gym.Space) -> Union[dict, tuple, ShArray]:
4949
if isinstance(space, gym.spaces.Dict):
5050
assert isinstance(space.spaces, OrderedDict)
5151
return {k: _setup_buf(v) for k, v in space.spaces.items()}
52-
elif isinstance(space, gym.spaces.Tuple):
53-
assert isinstance(space.spaces, tuple)
54-
return tuple([_setup_buf(t) for t in space.spaces])
52+
elif isinstance(space, gym.spaces.Tuple): # type: ignore
53+
assert isinstance(space.spaces, tuple) # type: ignore
54+
return tuple([_setup_buf(t) for t in space.spaces]) # type: ignore
5555
else:
5656
return ShArray(space.dtype, space.shape) # type: ignore
5757

@@ -122,7 +122,7 @@ def _encode_obs(
122122
elif cmd == "getattr":
123123
p.send(getattr(env, data) if hasattr(env, data) else None)
124124
elif cmd == "setattr":
125-
setattr(env, data["key"], data["value"])
125+
setattr(env.unwrapped, data["key"], data["value"])
126126
else:
127127
p.close()
128128
raise NotImplementedError

tianshou/utils/statistics.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from numbers import Number
2-
from typing import List, Union
2+
from typing import List, Optional, Union
33

44
import numpy as np
55
import torch
@@ -70,15 +70,31 @@ class RunningMeanStd(object):
7070
"""Calculates the running mean and std of a data stream.
7171
7272
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
73+
74+
:param mean: the initial mean estimation for data array. Default to 0.
75+
:param std: the initial standard deviation estimate for the data array. Defaults to 1.
76+
:param float clip_max: the maximum absolute value for data array. Default to
77+
10.0.
78+
:param float epsilon: a small constant added to the variance before taking the square root, to avoid division by zero.
7379
"""
7480

7581
def __init__(
7682
self,
7783
mean: Union[float, np.ndarray] = 0.0,
78-
std: Union[float, np.ndarray] = 1.0
84+
std: Union[float, np.ndarray] = 1.0,
85+
clip_max: Optional[float] = 10.0,
86+
epsilon: float = np.finfo(np.float32).eps.item(),
7987
) -> None:
8088
self.mean, self.var = mean, std
89+
self.clip_max = clip_max
8190
self.count = 0
91+
self.eps = epsilon
92+
93+
def norm(self, data_array: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
94+
data_array = (data_array - self.mean) / np.sqrt(self.var + self.eps)
95+
if self.clip_max:
96+
data_array = np.clip(data_array, -self.clip_max, self.clip_max)
97+
return data_array
8298

8399
def update(self, data_array: np.ndarray) -> None:
84100
"""Add a batch of item into RMS with the same shape, modify mean/var/count."""

0 commit comments

Comments (0)