<a href="https://github.com/RLE-Foundation/rllte/discussions"> Forum </a> |
<a href="https://hub.rllte.dev/"> Benchmarks </a></h3> -->

<img src="https://img.shields.io/badge/License-MIT-%230677b8"> <img src="https://img.shields.io/badge/GPU-NVIDIA-%2377b900"> <img src="https://img.shields.io/badge/NPU-Ascend-%23c31d20"> <img src="https://img.shields.io/badge/Python-%3E%3D3.8-%2335709F"> <img src="https://img.shields.io/badge/Docs-Passing-%23009485"> <img src="https://img.shields.io/badge/Codestyle-Black-black"> <img src="https://img.shields.io/badge/PyPI-0.0.1-%23006DAD">

<!-- <img src="https://img.shields.io/badge/Coverage-97.00%25-green"> -->

<!-- | [English](README.md) | [中文](docs/README-zh-Hans.md) | -->

# Overview

Inspired by the long-term evolution (LTE) standard project in telecommunications, **RLLTE** aims to provide development components and standards for advancing RL research and applications. Beyond delivering top-notch algorithm implementations, **RLLTE** also serves as a **toolkit** for developing algorithms.

<!-- <div align="center">
<a href="https://youtu.be/PMF6fa72bmE" rel="nofollow">
<img src='./docs/assets/images/youtube.png' style="width: 70%">
</a>
<br>
An introduction to RLLTE.
</div> -->

Why **RLLTE**?
- 🧬 Long-term evolution for providing the latest algorithms and tricks;

```
device = "cuda:0" -> device = "npu:0"
```

## Three Steps to Create Your RL Agent

Developers only need three steps to implement an RL algorithm with **RLLTE**. The following example illustrates how to write an Advantage Actor-Critic (A2C) agent to solve Atari games.
- Firstly, select a prototype:
``` py
from rllte.common.prototype import OnPolicyAgent
```
- Secondly, select necessary modules to build the agent:

  <details>
  <summary>Click to expand code</summary>

  ``` py
  from rllte.xploit.encoder import MnihCnnEncoder
  from rllte.xploit.policy import OnPolicySharedActorCritic
  from rllte.xploit.storage import VanillaRolloutStorage
  from rllte.xplore.distribution import Categorical
  ```
  - Run the `.describe` function of the selected policy and you will see the following output:
  ``` py
  OnPolicySharedActorCritic.describe()
  # Output:
  # ================================================================================
  # Name       : OnPolicySharedActorCritic
  # Structure  : self.encoder (shared by actor and critic), self.actor, self.critic
  # Forward    : obs -> self.encoder -> self.actor -> actions
  #            : obs -> self.encoder -> self.critic -> values
  #            : actions -> log_probs
  # Optimizers : self.optimizers['opt'] -> (self.encoder, self.actor, self.critic)
  # ================================================================================
  ```
  This illustrates the structure of the policy and indicates the optimizable parts.

  </details>

- Thirdly, merge these modules and write an `.update` function:

  <details>
  <summary>Click to expand code</summary>

  ``` py
  from torch import nn
  import torch as th

  class A2C(OnPolicyAgent):
      def __init__(self, env, tag, seed, device, num_steps) -> None:
          super().__init__(env=env, tag=tag, seed=seed, device=device, num_steps=num_steps)
          # create modules
          encoder = MnihCnnEncoder(observation_space=env.observation_space, feature_dim=512)
          policy = OnPolicySharedActorCritic(observation_space=env.observation_space,
                                             action_space=env.action_space,
                                             feature_dim=512,
                                             opt_class=th.optim.Adam,
                                             opt_kwargs=dict(lr=2.5e-4, eps=1e-5),
                                             init_fn="xavier_uniform"
                                             )
          storage = VanillaRolloutStorage(observation_space=env.observation_space,
                                          action_space=env.action_space,
                                          device=device,
                                          storage_size=self.num_steps,
                                          num_envs=self.num_envs,
                                          batch_size=256
                                          )
          dist = Categorical()
          # set all the modules
          self.set(encoder=encoder, policy=policy, storage=storage, distribution=dist)

      def update(self):
          for _ in range(4):
              for batch in self.storage.sample():
                  # evaluate the sampled actions
                  new_values, new_log_probs, entropy = self.policy.evaluate_actions(obs=batch.observations, actions=batch.actions)
                  # policy loss: negative advantage-weighted log-probabilities
                  policy_loss = - (batch.adv_targ * new_log_probs).mean()
                  # value loss: mean squared error between predicted values and returns
                  value_loss = 0.5 * (new_values.flatten() - batch.returns).pow(2).mean()
                  # combine the losses (value coefficient 0.5, entropy bonus 0.01), then update
                  self.policy.optimizers['opt'].zero_grad(set_to_none=True)
                  (value_loss * 0.5 + policy_loss - entropy * 0.01).backward()
                  nn.utils.clip_grad_norm_(self.policy.parameters(), 0.5)
                  self.policy.optimizers['opt'].step()
  ```

  </details>

- Finally, train the agent by:

  <details>
  <summary>Click to expand code</summary>

  ``` py
  from rllte.env import make_atari_env
  if __name__ == "__main__":
      device = "cuda"
      env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
      agent = A2C(env=env, tag="a2c_atari", seed=0, device=device, num_steps=128)
      agent.train(num_train_steps=10000000)
  ```

  </details>

As shown in this example, only a few dozen lines of code are needed to create RL agents with **RLLTE**.
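
The same agent can also target an Ascend NPU by changing only the device string, as the `device = "cuda:0" -> device = "npu:0"` snippet earlier suggests. Below is a minimal sketch of that switch applied to the training script above; the `a2c_atari_npu` tag is just an illustrative name, and it is assumed here that the Atari wrappers and all selected modules run unmodified on NPU tensors:

``` py
from rllte.env import make_atari_env

if __name__ == "__main__":
    # the only change from the GPU script above is the device string
    device = "npu:0"  # was "cuda" / "cuda:0" on NVIDIA GPUs
    env = make_atari_env("PongNoFrameskip-v4", num_envs=8, seed=0, device=device)
    # assumption: the A2C class defined above is importable or defined in this script
    agent = A2C(env=env, tag="a2c_atari_npu", seed=0, device=device, num_steps=128)
    agent.train(num_train_steps=10000000)
```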

## Algorithm Decoupling and Module Replacement