
Commit a44034d

Author: Yuanmo
Commit message: update README.md
1 parent e96e6af commit a44034d

1 file changed: +97 -74 lines


README.md

+97-74
@@ -11,7 +11,9 @@
 <a href="https://github.com/RLE-Foundation/rllte/discussions"> Forum </a> |
 <a href="https://hub.rllte.dev/"> Benchmarks </a></h3> -->

-<img src="https://img.shields.io/badge/License-MIT-%230677b8"> <img src="https://img.shields.io/badge/GPU-NVIDIA-%2377b900"> <img src="https://img.shields.io/badge/NPU-Ascend-%23c31d20"> <img src="https://img.shields.io/badge/Python-%3E%3D3.8-%2335709F"> <img src="https://img.shields.io/badge/Docs-Passing-%23009485"> <img src="https://img.shields.io/badge/Codestyle-Black-black"> <img src="https://img.shields.io/badge/PyPI-0.0.1-%23006DAD"> <img src="https://img.shields.io/badge/Coverage-97.00%25-green">
+<img src="https://img.shields.io/badge/License-MIT-%230677b8"> <img src="https://img.shields.io/badge/GPU-NVIDIA-%2377b900"> <img src="https://img.shields.io/badge/NPU-Ascend-%23c31d20"> <img src="https://img.shields.io/badge/Python-%3E%3D3.8-%2335709F"> <img src="https://img.shields.io/badge/Docs-Passing-%23009485"> <img src="https://img.shields.io/badge/Codestyle-Black-black"> <img src="https://img.shields.io/badge/PyPI-0.0.1-%23006DAD">
+
+<!-- <img src="https://img.shields.io/badge/Coverage-97.00%25-green"> -->

 <!-- | [English](README.md) | [中文](docs/README-zh-Hans.md) | -->

@@ -39,13 +41,13 @@
 # Overview
 Inspired by the long-term evolution (LTE) standard project in telecommunications, **RLLTE** aims to provide development components and standards for advancing RL research and applications. Beyond delivering top-notch algorithm implementations, **RLLTE** also serves as a **toolkit** for developing algorithms.

-<div align="center">
+<!-- <div align="center">
 <a href="https://youtu.be/PMF6fa72bmE" rel="nofollow">
 <img src='./docs/assets/images/youtube.png' style="width: 70%">
 </a>
 <br>
 An introduction to RLLTE.
-</div>
+</div> -->

 Why **RLLTE**?
 - 🧬 Long-term evolution for providing latest algorithms and tricks;
@@ -130,83 +132,104 @@ device = "cuda:0" -> device = "npu:0"
 ```

 ## Three Steps to Create Your RL Agent
+
+
 Developers only need three steps to implement an RL algorithm with **RLLTE**. The following example illustrates how to write an Advantage Actor-Critic (A2C) agent to solve Atari games.
 - Firstly, select a prototype:
 ``` py
 from rllte.common.prototype import OnPolicyAgent
 ```
 - Secondly, select necessary modules to build the agent:
-``` py
-from rllte.xploit.encoder import MnihCnnEncoder
-from rllte.xploit.policy import OnPolicySharedActorCritic
-from rllte.xploit.storage import VanillaRolloutStorage
-from rllte.xplore.distribution import Categorical
-```
-- Run the `.describe` function of the selected policy and you will see the following output:
-``` py
-OnPolicySharedActorCritic.describe()
-# Output:
-# ================================================================================
-# Name       : OnPolicySharedActorCritic
-# Structure  : self.encoder (shared by actor and critic), self.actor, self.critic
-# Forward    : obs -> self.encoder -> self.actor -> actions
-#            : obs -> self.encoder -> self.critic -> values
-#            : actions -> log_probs
-# Optimizers : self.optimizers['opt'] -> (self.encoder, self.actor, self.critic)
-# ================================================================================
-```
-This will illustrate the structure of the policy and indicate the optimizable parts. Finally, merge these modules and write an `.update` function:
-``` py
-from torch import nn
-import torch as th
-
-class A2C(OnPolicyAgent):
-    def __init__(self, env, tag, seed, device, num_steps) -> None:
-        super().__init__(env=env, tag=tag, seed=seed, device=device, num_steps=num_steps)
-        # create modules
-        encoder = MnihCnnEncoder(observation_space=env.observation_space, feature_dim=512)
-        policy = OnPolicySharedActorCritic(observation_space=env.observation_space,
-                                           action_space=env.action_space,
-                                           feature_dim=512,
-                                           opt_class=th.optim.Adam,
-                                           opt_kwargs=dict(lr=2.5e-4, eps=1e-5),
-                                           init_fn="xavier_uniform"
-                                           )
-        storage = VanillaRolloutStorage(observation_space=env.observation_space,
-                                        action_space=env.action_space,
-                                        device=device,
-                                        storage_size=self.num_steps,
-                                        num_envs=self.num_envs,
-                                        batch_size=256
-                                        )
-        dist = Categorical()
-        # set all the modules
-        self.set(encoder=encoder, policy=policy, storage=storage, distribution=dist)
-
-    def update(self):
-        for _ in range(4):
-            for batch in self.storage.sample():
-                # evaluate the sampled actions
-                new_values, new_log_probs, entropy = self.policy.evaluate_actions(obs=batch.observations, actions=batch.actions)
-                # policy loss part
-                policy_loss = - (batch.adv_targ * new_log_probs).mean()
-                # value loss part
-                value_loss = 0.5 * (new_values.flatten() - batch.returns).pow(2).mean()
-                # update
-                self.policy.optimizers['opt'].zero_grad(set_to_none=True)
-                (value_loss * 0.5 + policy_loss - entropy * 0.01).backward()
-                nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
-                self.policy.optimizers['opt'].step()
-```
-Then train the agent by
-``` py
-from rllte.env import make_atari_env
-if __name__ == "__main__":
-    device = "cuda"
-    env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
-    agent = A2C(env=env, tag="a2c_atari", seed=0, device=device, num_steps=128)
-    agent.train(num_train_steps=10000000)
-```
+
+<details>
+<summary>Click to expand code</summary>
+
+``` py
+from rllte.xploit.encoder import MnihCnnEncoder
+from rllte.xploit.policy import OnPolicySharedActorCritic
+from rllte.xploit.storage import VanillaRolloutStorage
+from rllte.xplore.distribution import Categorical
+```
+- Run the `.describe` function of the selected policy and you will see the following output:
+``` py
+OnPolicySharedActorCritic.describe()
+# Output:
+# ================================================================================
+# Name       : OnPolicySharedActorCritic
+# Structure  : self.encoder (shared by actor and critic), self.actor, self.critic
+# Forward    : obs -> self.encoder -> self.actor -> actions
+#            : obs -> self.encoder -> self.critic -> values
+#            : actions -> log_probs
+# Optimizers : self.optimizers['opt'] -> (self.encoder, self.actor, self.critic)
+# ================================================================================
+```
+This illustrates the structure of the policy and indicates the optimizable parts.
+
+</details>
+
+- Thirdly, merge these modules and write an `.update` function:
+
+<details>
+<summary>Click to expand code</summary>
+
+``` py
+from torch import nn
+import torch as th
+
+class A2C(OnPolicyAgent):
+    def __init__(self, env, tag, seed, device, num_steps) -> None:
+        super().__init__(env=env, tag=tag, seed=seed, device=device, num_steps=num_steps)
+        # create modules
+        encoder = MnihCnnEncoder(observation_space=env.observation_space, feature_dim=512)
+        policy = OnPolicySharedActorCritic(observation_space=env.observation_space,
+                                           action_space=env.action_space,
+                                           feature_dim=512,
+                                           opt_class=th.optim.Adam,
+                                           opt_kwargs=dict(lr=2.5e-4, eps=1e-5),
+                                           init_fn="xavier_uniform"
+                                           )
+        storage = VanillaRolloutStorage(observation_space=env.observation_space,
+                                        action_space=env.action_space,
+                                        device=device,
+                                        storage_size=self.num_steps,
+                                        num_envs=self.num_envs,
+                                        batch_size=256
+                                        )
+        dist = Categorical()
+        # set all the modules
+        self.set(encoder=encoder, policy=policy, storage=storage, distribution=dist)
+
+    def update(self):
+        for _ in range(4):
+            for batch in self.storage.sample():
+                # evaluate the sampled actions
+                new_values, new_log_probs, entropy = self.policy.evaluate_actions(obs=batch.observations, actions=batch.actions)
+                # policy loss part
+                policy_loss = - (batch.adv_targ * new_log_probs).mean()
+                # value loss part
+                value_loss = 0.5 * (new_values.flatten() - batch.returns).pow(2).mean()
+                # update
+                self.policy.optimizers['opt'].zero_grad(set_to_none=True)
+                (value_loss * 0.5 + policy_loss - entropy * 0.01).backward()
+                nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
+                self.policy.optimizers['opt'].step()
+```
+
+</details>
+
+- Finally, train the agent by:
+<details>
+<summary>Click to expand code</summary>
+``` py
+from rllte.env import make_atari_env
+if __name__ == "__main__":
+    device = "cuda"
+    env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
+    agent = A2C(env=env, tag="a2c_atari", seed=0, device=device, num_steps=128)
+    agent.train(num_train_steps=10000000)
+```
+</details>
+
 As shown in this example, only a few dozen lines of code are needed to create RL agents with **RLLTE**.

 ## Algorithm Decoupling and Module Replacement
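The last hunk's header carries over the README's note that running on an Ascend NPU only requires swapping the device string (`device = "cuda:0" -> device = "npu:0"`). Below is a minimal sketch of that swap applied to the A2C training snippet from this diff; the `a2c_example` module name is hypothetical (it stands for wherever the `A2C` class defined above is saved), and the sketch assumes nothing besides the device string needs to change.

``` py
from rllte.env import make_atari_env

from a2c_example import A2C  # hypothetical module holding the A2C class defined above

if __name__ == "__main__":
    # Per the README note in the hunk header, targeting an Ascend NPU only
    # means changing the device string; all other arguments stay the same.
    device = "npu:0"  # was "cuda" / "cuda:0" on NVIDIA GPUs
    env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
    agent = A2C(env=env, tag="a2c_atari", seed=0, device=device, num_steps=128)
    agent.train(num_train_steps=10000000)
```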
