
Commit ef50681
Author: Yuanmo
Merge commit with 2 parents: 766004a + 5b984b0

9 files changed (+379 −15 lines)


pyproject.toml (+3 −9)

@@ -13,7 +13,7 @@ packages = ["rllte"]
 
 [project]
 name = "rllte-core"
-version = "0.0.1.beta12"
+version = "0.0.1.beta13"
 authors = [
     { name="Reinforcement Learning Evolution Foundation", email="[email protected]" },
 ]
@@ -33,7 +33,7 @@ classifiers = [
 ]
 
 dependencies = [
-    "gymnasium[accept-rom-license]",
+    "gymnasium[accept-rom-license, other]",
     "torch",
     "torchvision",
     "termcolor",
@@ -56,13 +56,7 @@ tests = [
     "isort>=5.0",
     "black"
 ]
-envs = [
-    "envpool",
-    "ale-py==0.8.1",
-    "dm-control",
-    "procgen",
-    "minigrid"
-]
+
 docs = [
     "mkdocs-material",
     "mkgendocs"

rllte/common/prototype/base_reward.py (+15 −4)

@@ -34,6 +34,7 @@
 from rllte.common.preprocessing import process_action_space, process_observation_space
 from rllte.common.utils import TorchRunningMeanStd, RewardForwardFilter
 
+
 class BaseReward(ABC):
     """Base class of reward module.
 
@@ -61,8 +62,12 @@ def __init__(
         obs_norm_type: str = "rms",
     ) -> None:
         # get environment information
-        self.observation_space = envs.observation_space
-        self.action_space = envs.action_space
+        if isinstance(envs, VectorEnv):
+            self.observation_space = envs.single_observation_space
+            self.action_space = envs.single_action_space
+        else:
+            self.observation_space = envs.observation_space
+            self.action_space = envs.action_space
         self.n_envs = envs.unwrapped.num_envs
         ## process the observation and action space
         self.obs_shape: Tuple = process_observation_space(self.observation_space)  # type: ignore
@@ -138,6 +143,7 @@ def init_normalization(self) -> None:
         """Initialize the normalization parameters for observations if the RMS is used."""
         # TODO: better initialization parameters?
         num_steps, num_iters = 128, 20
+        # for the vectorized environments with `Gymnasium2Torch` from rllte
         try:
             _, _ = self.envs.reset()
             if self.obs_norm_type == "rms":
@@ -157,14 +163,19 @@ def init_normalization(self) -> None:
                     self.obs_norm.update(all_next_obs)
                     all_next_obs = []
         except:
-            # for the outdated gym version
+            # for the normal vectorized environments
            _ = self.envs.reset()
             if self.obs_norm_type == "rms":
                 all_next_obs = []
                 for step in range(num_steps * num_iters):
                     actions = [self.action_space.sample() for _ in range(self.n_envs)]
                     actions = np.stack(actions)
-                    next_obs, _, _, _ = self.envs.step(actions)
+                    try:
+                        # for the old gym output
+                        next_obs, _, _, _ = self.envs.step(actions)
+                    except:
+                        # for the new gymnasium output
+                        next_obs, _, _, _, _ = self.envs.step(actions)
                     all_next_obs += th.as_tensor(next_obs).view(-1, *self.obs_shape)
                     # update the running mean and std
                     if len(all_next_obs) % (num_steps * self.n_envs) == 0:
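For context, the compatibility handled above boils down to two step-API conventions. The sketch below is a minimal, standalone illustration using gymnasium's `SyncVectorEnv` (the `CartPole-v1` environment is only an assumption for the demo); unlike the patch, which uses try/except, it distinguishes the cases by tuple length.

```python
# Minimal sketch of the two conventions handled above (assumes gymnasium and
# CartPole-v1 are available; this is not part of the patch itself).
import gymnasium as gym
import numpy as np

envs = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(4)])

# Like the patch, prefer the per-environment spaces on a VectorEnv.
obs_space = envs.single_observation_space
act_space = envs.single_action_space

obs, info = envs.reset(seed=0)
actions = np.stack([act_space.sample() for _ in range(envs.num_envs)])

out = envs.step(actions)
if len(out) == 5:
    # new gymnasium output: (obs, reward, terminated, truncated, info)
    next_obs, reward, terminated, truncated, info = out
else:
    # old gym output: (obs, reward, done, info)
    next_obs, reward, done, info = out
```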

rllte/env/README.md (+94, new file)

@@ -0,0 +1,94 @@
+Integrating RL environments in RLLTE is incredibly easy and efficient!
+
+## Menu
+1. [Installation](#installation)
+2. [Usage](#usage)
+
+## Installation
+
+The instructions below assume you are running inside a conda environment.
+
+### Atari
+```
+pip install ale-py==0.8.1
+```
+
+### Craftax
+
+You will need a Jax GPU-enabled conda environment:
+
+```
+conda create -n rllte jaxlib==*cuda jax python=3.11 -c conda-forge
+pip install craftax
+pip install brax
+pip install -e .[envs]
+pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
+```
+
+### DMC
+```
+pip install dm-control
+```
+
+### SuperMarioBros
+```
+pip install gym-super-mario-bros==7.4.0
+```
+
+### Minigrid
+```
+pip install minigrid
+```
+
+### Miniworld
+```
+pip install miniworld
+```
+
+### Procgen
+```
+pip install procgen
+```
+
+### Envpool
+```
+pip install envpool
+```
+
+## Usage
+
+Each environment has a `make_env()` function in `rllte/env/<your_RL_env>/__init__.py` and its necessary wrappers in `rllte/env/<your_RL_env>/wrappers.py`. To add your custom environments, simply follow the same logic as the currently available environments, and the RL training will work flawlessly!
+
+## Example training
+
+```
+from rllte.agent import PPO
+from rllte.env import (
+    make_mario_env,
+    make_envpool_vizdoom_env,
+    make_envpool_procgen_env,
+    make_minigrid_env,
+    make_envpool_atari_env,
+    make_craftax_env
+)
+
+# define params
+device = "cuda"
+
+# define environment
+env = make_craftax_env(
+    num_envs=32,
+    device=device,
+)
+
+# define agent
+agent = PPO(
+    env=env,
+    device=device
+)
+
+# start training
+agent.train(
+    num_train_steps=10_000_000,
+)
+```
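As a complement to the Usage section of this new README, here is a minimal sketch of what a custom factory could look like when following the same `make_env()` pattern. The `MyTask-v0` id and `make_my_task_env` name are hypothetical; real factories live in `rllte/env/<your_RL_env>/__init__.py` with their wrappers in `wrappers.py`.

```python
# Hypothetical custom factory sketch; "MyTask-v0" and make_my_task_env are
# illustrative names and are not part of rllte.
import gymnasium as gym

def make_my_task_env(env_id: str = "MyTask-v0", num_envs: int = 8, device: str = "cpu"):
    # a vectorized env already exposes num_envs, single_observation_space and
    # single_action_space, which is what the agents and reward modules read;
    # rllte's bundled factories additionally move tensors to `device`, which
    # this sketch omits for brevity
    envs = gym.vector.SyncVectorEnv([lambda: gym.make(env_id) for _ in range(num_envs)])
    return envs
```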

rllte/env/__init__.py (+15 −2)

@@ -28,7 +28,6 @@
 from .testing import make_multidiscrete_env as make_multidiscrete_env
 from .testing import make_box_env as make_box_env
 from .testing import make_discrete_env as make_discrete_env
-
 from .utils import make_rllte_env as make_rllte_env
 
 try:
@@ -52,6 +51,11 @@
 except Exception:
     pass
 
+try:
+    from .miniworld import make_miniworld_env as make_miniworld_env
+except Exception:
+    pass
+
 try:
     from .procgen import make_envpool_procgen_env as make_envpool_procgen_env
     from .procgen import make_procgen_env as make_procgen_env
@@ -60,6 +64,15 @@
 
 try:
     from .mario import make_mario_env as make_mario_env
-    from .mario import make_mario_multilevel_env as make_mario_multilevel_env
+except Exception:
+    pass
+
+try:
+    from .craftax import make_craftax_env as make_craftax_env
+except Exception:
+    pass
+
+try:
+    from .vizdoom import make_envpool_vizdoom_env as make_envpool_vizdoom_env
 except Exception:
     pass
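Because each optional environment family is imported inside a `try`/`except` block, a missing backend does not break `import rllte.env`; the corresponding factory name is simply absent. A short sketch of how calling code could cope with that (the fallback handling here is only illustrative):

```python
# Sketch only: if the craftax/jax extras are not installed, the guarded import
# in rllte/env/__init__.py silently skips make_craftax_env, so importing it
# here raises ImportError.
try:
    from rllte.env import make_craftax_env
except ImportError:
    make_craftax_env = None  # craftax backend not installed

if make_craftax_env is not None:
    env = make_craftax_env(num_envs=32, device="cpu")
```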

rllte/env/craftax/__init__.py (+34, new file)

@@ -0,0 +1,34 @@
+from craftax.envs.craftax_pixels_env import CraftaxPixelsEnv
+from craftax_classic.envs.craftax_pixels_env import CraftaxClassicPixelsEnv
+from environment_base.wrappers import (
+    LogWrapper,
+    BatchEnvWrapper,
+    OptimisticResetVecEnvWrapper,
+)
+
+from rllte.env.craftax.wrappers import TorchWrapper, ResizeTorchWrapper, RecordEpisodeStatistics4Craftax
+
+def make_craftax_env(
+    env_id: str = "Craftax-Classic",
+    num_envs: int = 32,
+    reset_ratio: int = 16,
+    device: str = "cpu",
+):
+
+    if env_id == "Craftax-Classic":
+        env = CraftaxClassicPixelsEnv()
+    elif env_id == "Craftax":
+        env = CraftaxPixelsEnv()
+    else:
+        raise ValueError(f"Unknown environment: {env_id}")
+
+    env = LogWrapper(env)
+    env = OptimisticResetVecEnvWrapper(env, num_envs=num_envs, reset_ratio=reset_ratio)
+    env = TorchWrapper(env, device=device)
+    env = ResizeTorchWrapper(env, (84, 84))
+    env = RecordEpisodeStatistics4Craftax(env)
+    env.num_envs = num_envs
+    env.single_observation_space = env.observation_space
+    env.single_action_space = env.action_space
+    return env
+
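A short usage sketch of this factory (it assumes the Craftax/Jax dependencies from `rllte/env/README.md` are installed; the commented results follow from the wrappers applied above):

```python
# Usage sketch, assuming the craftax extras are installed.
from rllte.env.craftax import make_craftax_env

env = make_craftax_env(env_id="Craftax-Classic", num_envs=32, reset_ratio=16, device="cpu")
print(env.num_envs)                  # 32
print(env.single_observation_space)  # Box resized to (C, 84, 84) by ResizeTorchWrapper
print(env.single_action_space)       # Discrete(n), taken from the underlying Craftax env
```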

rllte/env/craftax/wrappers.py (+114, new file)

@@ -0,0 +1,114 @@
+import numpy as np
+import jax
+import gymnasium as gym
+import torch
+from dataclasses import asdict
+from brax.io import torch as brax_torch
+
+class TorchWrapper(gym.Wrapper):
+    """Wrapper that converts Jax tensors to PyTorch tensors."""
+
+    def __init__(self, env, device):
+        super().__init__(env)
+        self.device = device
+        self.env = env
+        self.default_params = env.default_params
+        self.metadata = {
+            'render.modes': ['human', 'rgb_array'],
+        }
+
+        # define obs and action space
+        obs_shape = env.observation_space(self.default_params).shape
+        self.observation_space = gym.spaces.Box(
+            low=-1e6, high=1e6, shape=obs_shape)
+        self.action_space = gym.spaces.Discrete(env.action_space(self.default_params).n)
+
+        # jit the reset function
+        def reset(key):
+            key1, key2 = jax.random.split(key)
+            obs, state = self.env.reset(key2)
+            return state, obs, key1, asdict(state)
+        self._reset = jax.jit(reset)
+
+        # jit the step function
+        def step(state, action):
+            obs, env_state, reward, done, info = self.env.step(rng=self._key, state=state, action=action)
+            return env_state, obs, reward, done, {**asdict(env_state), **info}
+        self._step = jax.jit(step)
+
+    def reset(self, seed=0, options=None):
+        self.seed(seed)
+        self._state, obs, self._key, info = self._reset(self._key)
+        return brax_torch.jax_to_torch(obs, device=self.device), info
+
+    def step(self, action):
+        action = brax_torch.torch_to_jax(action)
+        self._state, obs, reward, done, info = self._step(self._state, action)
+        obs = brax_torch.jax_to_torch(obs, device=self.device)
+        reward = brax_torch.jax_to_torch(reward, device=self.device)
+        terminateds = brax_torch.jax_to_torch(done, device=self.device)
+        truncateds = brax_torch.jax_to_torch(done, device=self.device)
+        info = brax_torch.jax_to_torch(info, device=self.device)
+        return obs, reward, terminateds, truncateds, info
+
+    def seed(self, seed: int = 0):
+        self._key = jax.random.PRNGKey(seed)
+
+class ResizeTorchWrapper(gym.Wrapper):
+    """Wrapper that resizes observations to a given shape."""
+
+    def __init__(self, env, shape):
+        super().__init__(env)
+        self.env = env
+        num_channels = env.observation_space.shape[-1]
+        self.shape = (num_channels, shape[0], shape[1])
+
+        # define obs and action space
+        self.observation_space = gym.spaces.Box(
+            low=-1e6, high=1e6, shape=self.shape)
+
+    def reset(self, seed=0, options=None):
+        obs, info = self.env.reset(seed, options)
+        obs = obs.permute(0, 3, 1, 2)
+        obs = torch.nn.functional.interpolate(obs, size=self.shape[1:], mode='nearest')
+        return obs, info
+
+    def step(self, action):
+        obs, reward, terminateds, truncateds, info = self.env.step(action)
+        obs = obs.permute(0, 3, 1, 2)
+        obs = torch.nn.functional.interpolate(obs, size=self.shape[1:], mode='nearest')
+        return obs, reward, terminateds, truncateds, info
+
+class RecordEpisodeStatistics4Craftax(gym.Wrapper):
+    def __init__(self, env: gym.Env, deque_size: int = 100) -> None:
+        super().__init__(env)
+        self.num_envs = getattr(env, "num_envs", 1)
+        self.episode_returns = None
+        self.episode_lengths = None
+
+    def reset(self, **kwargs):
+        observations, infos = super().reset(**kwargs)
+        self.episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+        self.episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+        self.returned_episode_returns = np.zeros(self.num_envs, dtype=np.float32)
+        self.returned_episode_lengths = np.zeros(self.num_envs, dtype=np.int32)
+        return observations, infos
+
+    def step(self, actions):
+        observations, rewards, terms, truncs, infos = super().step(actions)
+        self.episode_returns += rewards.cpu().numpy()
+        self.episode_lengths += 1
+        self.returned_episode_returns[:] = self.episode_returns
+        self.returned_episode_lengths[:] = self.episode_lengths
+        self.episode_returns *= 1 - infos["returned_episode"].cpu().numpy().astype(np.int32)
+        self.episode_lengths *= 1 - infos["returned_episode"].cpu().numpy().astype(np.int32)
+        infos["episode"] = {}
+        infos["episode"]["r"] = self.returned_episode_returns
+        infos["episode"]["l"] = self.returned_episode_lengths
+
+        for idx, d in enumerate(terms):
+            if not d:
+                infos["episode"]["r"][idx] = 0
+                infos["episode"]["l"][idx] = 0
+
+        return observations, rewards, terms, truncs, infos
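To show how the three wrappers compose, here is a minimal rollout sketch (assuming a working Craftax/Jax install): observations and rewards come back as PyTorch tensors via `TorchWrapper`, and per-episode statistics are exposed under `info["episode"]` by `RecordEpisodeStatistics4Craftax`.

```python
# Rollout sketch, assuming the craftax extras are installed.
import torch
from rllte.env.craftax import make_craftax_env

env = make_craftax_env(num_envs=8, device="cpu")
obs, info = env.reset(seed=0)

for _ in range(10):
    # random batched actions; TorchWrapper converts them to Jax arrays internally
    actions = torch.randint(0, env.single_action_space.n, (env.num_envs,))
    obs, reward, term, trunc, info = env.step(actions)
    # episode statistics recorded by RecordEpisodeStatistics4Craftax
    ep_returns = info["episode"]["r"]
    ep_lengths = info["episode"]["l"]
```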
