
Commit 58e1632

Tianshou v2 (#1259)
See change log. Resolves #1091 Resolves #810 Resolves #959 Resolves #898 Resolves #919 Resolves #913 Resolves #948 Resolves #949 Resolves #1204
2 parents 5e15adb + c74cc17 commit 58e1632

File tree

237 files changed

+17997
-16257
lines changed


.gitignore

Lines changed: 3 additions & 0 deletions

````diff
@@ -160,5 +160,8 @@ docs/conf.py
 /temp
 /temp*.py
 
+# Serena
+/.serena
+
 # determinism test snapshots
 /test/resources/determinism/
````

CHANGELOG.md

Lines changed: 263 additions & 2 deletions

README.md

Lines changed: 131 additions & 104 deletions
````diff
@@ -12,7 +12,6 @@
 1. Convenient high-level interfaces for applications of RL (training an implemented algorithm on a custom environment).
 1. Large scope: online (on- and off-policy) and offline RL, experimental support for multi-agent RL (MARL), experimental support for model-based RL, and more
 
-
 Unlike other reinforcement learning libraries, which may have complex codebases,
 unfriendly high-level APIs, or are not optimized for speed, Tianshou provides a high-performance, modularized framework
 and user-friendly interfaces for building deep reinforcement learning agents. One more aspect that sets Tianshou apart is its
````
````diff
@@ -149,9 +148,11 @@ If no errors are reported, you have successfully installed Tianshou.
 
 ## Documentation
 
-Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
+Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
 
-Find example scripts in the [test/](https://github.com/thu-ml/tianshou/blob/master/test) and [examples/](https://github.com/thu-ml/tianshou/blob/master/examples) folders.
+Tutorials and API documentation are hosted on [tianshou.readthedocs.io](https://tianshou.readthedocs.io/).
+**Important**: The documentation is currently being updated to reflect the changes in Tianshou v2.0.0. Not all features are documented yet, and some parts are outdated (they are marked as such). The documentation will be fully updated when
+the v2.0.0 release is finalized.
 
 ## Why Tianshou?
 
````
````diff
@@ -180,20 +181,23 @@ Check out the [GitHub Actions](https://github.com/thu-ml/tianshou/actions) page
 
 Atari and MuJoCo benchmark results can be found in the [examples/atari/](examples/atari/) and [examples/mujoco/](examples/mujoco/) folders respectively. **Our MuJoCo results reach or exceed the level of performance of most existing benchmarks.**
 
-### Policy Interface
+### Algorithm Abstraction
+
+Reinforcement learning algorithms are built on abstractions for
+
+- on-policy algorithms (`OnPolicyAlgorithm`),
+- off-policy algorithms (`OffPolicyAlgorithm`), and
+- offline algorithms (`OfflineAlgorithm`),
+
+all of which clearly separate the core algorithm from the training process and the respective environment interactions.
 
-All algorithms implement the following, highly general API:
+In each case, the implementation of an algorithm necessarily involves only the implementation of methods for
 
-- `__init__`: initialize the policy;
-- `forward`: compute actions based on given observations;
-- `process_buffer`: process initial buffer, which is useful for some offline learning algorithms
-- `process_fn`: preprocess data from the replay buffer (since we have reformulated _all_ algorithms to replay buffer-based algorithms);
-- `learn`: learn from a given batch of data;
-- `post_process_fn`: update the replay buffer from the learning process (e.g., prioritized replay buffer needs to update the weight);
-- `update`: the main interface for training, i.e., `process_fn -> learn -> post_process_fn`.
+- pre-processing a batch of data, augmenting it with necessary information/sufficient statistics for learning (`_preprocess_batch`),
+- updating model parameters based on an augmented batch of data (`_update_with_batch`).
 
-The implementation of this API suffices for a new algorithm to be applicable within Tianshou,
-making experimenation with new approaches particularly straightforward.
+The implementation of these methods suffices for a new algorithm to be applicable within Tianshou,
+making experimentation with new approaches particularly straightforward.
 
 ## Quick Start
 
````
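To make the contract in this hunk concrete, here is a toy, framework-free sketch of the split between batch pre-processing and the parameter update. Class and method names other than `_preprocess_batch`/`_update_with_batch` are simplified stand-ins, not Tianshou's actual implementation:

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass, field


@dataclass
class Batch:
    """A minimal stand-in for a batch sampled from a replay buffer."""
    rewards: list[float]
    returns: list[float] = field(default_factory=list)


class OffPolicyAlgorithmSketch(ABC):
    """Separates the core update logic from the surrounding training loop."""

    @abstractmethod
    def _preprocess_batch(self, batch: Batch) -> Batch:
        """Augment the batch with the statistics needed for learning."""

    @abstractmethod
    def _update_with_batch(self, batch: Batch) -> dict:
        """Update model parameters based on the augmented batch."""

    def update(self, batch: Batch) -> dict:
        # The base class fixes the order of operations once and for all.
        return self._update_with_batch(self._preprocess_batch(batch))


class OneStepReturnAlgorithm(OffPolicyAlgorithmSketch):
    """A trivial concrete algorithm: returns are just the immediate rewards."""

    def _preprocess_batch(self, batch: Batch) -> Batch:
        batch.returns = list(batch.rewards)  # no bootstrapping in this toy
        return batch

    def _update_with_batch(self, batch: Batch) -> dict:
        # A real algorithm would take a gradient step here; we report a mock loss.
        loss = sum(r * r for r in batch.returns) / len(batch.returns)
        return {"loss": loss}


stats = OneStepReturnAlgorithm().update(Batch(rewards=[1.0, 0.0, 1.0]))
```

A new algorithm plugs in by overriding only the two abstract methods; the environment-interaction loop stays untouched.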
````diff
@@ -203,70 +207,68 @@ Tianshou provides two API levels:
 - the procedural interface, which provides a maximum of control, especially for very advanced users and developers of reinforcement learning algorithms.
 
 In the following, let us consider an example application using the _CartPole_ gymnasium environment.
-We shall apply the deep Q network (DQN) learning algorithm using both APIs.
+We shall apply the deep Q-network (DQN) learning algorithm using both APIs.
 
 ### High-Level API
 
-To get started, we need some imports.
-
-```python
-from tianshou.highlevel.config import SamplingConfig
-from tianshou.highlevel.env import (
-    EnvFactoryRegistered,
-    VectorEnvType,
-)
-from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
-from tianshou.highlevel.params.policy_params import DQNParams
-from tianshou.highlevel.trainer import (
-    EpochTestCallbackDQNSetEps,
-    EpochTrainCallbackDQNSetEps,
-    EpochStopCallbackRewardThreshold
-)
-```
-
 In the high-level API, the basis for an RL experiment is an `ExperimentBuilder`
 with which we can build the experiment we then seek to run.
 Since we want to use DQN, we use the specialization `DQNExperimentBuilder`.
-The other imports serve to provide configuration options for our experiment.
 
 The high-level API provides largely declarative semantics, i.e. the code is
 almost exclusively concerned with configuration that controls what to do
 (rather than how to do it).
 
 ```python
+from tianshou.highlevel.config import OffPolicyTrainingConfig
+from tianshou.highlevel.env import (
+    EnvFactoryRegistered,
+    VectorEnvType,
+)
+from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
+from tianshou.highlevel.params.algorithm_params import DQNParams
+from tianshou.highlevel.trainer import (
+    EpochStopCallbackRewardThreshold,
+)
+
 experiment = (
-    DQNExperimentBuilder(
-        EnvFactoryRegistered(task="CartPole-v1", train_seed=0, test_seed=0, venv_type=VectorEnvType.DUMMY),
-        ExperimentConfig(
-            persistence_enabled=False,
-            watch=True,
-            watch_render=1 / 35,
-            watch_num_episodes=100,
-        ),
-        SamplingConfig(
-            num_epochs=10,
-            step_per_epoch=10000,
-            batch_size=64,
-            num_train_envs=10,
-            num_test_envs=100,
-            buffer_size=20000,
-            step_per_collect=10,
-            update_per_step=1 / 10,
-        ),
-    )
-    .with_dqn_params(
-        DQNParams(
-            lr=1e-3,
-            discount_factor=0.9,
-            estimation_step=3,
-            target_update_freq=320,
-        ),
-    )
-    .with_model_factory_default(hidden_sizes=(64, 64))
-    .with_epoch_train_callback(EpochTrainCallbackDQNSetEps(0.3))
-    .with_epoch_test_callback(EpochTestCallbackDQNSetEps(0.0))
-    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
-    .build()
+    DQNExperimentBuilder(
+        EnvFactoryRegistered(
+            task="CartPole-v1",
+            venv_type=VectorEnvType.DUMMY,
+            train_seed=0,
+            test_seed=10,
+        ),
+        ExperimentConfig(
+            persistence_enabled=False,
+            watch=True,
+            watch_render=1 / 35,
+            watch_num_episodes=100,
+        ),
+        OffPolicyTrainingConfig(
+            max_epochs=10,
+            epoch_num_steps=10000,
+            batch_size=64,
+            num_train_envs=10,
+            num_test_envs=100,
+            buffer_size=20000,
+            collection_step_num_env_steps=10,
+            update_step_num_gradient_steps_per_sample=1 / 10,
+        ),
+    )
+    .with_dqn_params(
+        DQNParams(
+            lr=1e-3,
+            gamma=0.9,
+            n_step_return_horizon=3,
+            target_update_freq=320,
+            eps_training=0.3,
+            eps_inference=0.0,
+        ),
+    )
+    .with_model_factory_default(hidden_sizes=(64, 64))
+    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
+    .build()
 )
 experiment.run()
 ```
````
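A side note on the renamed DQN parameters in this hunk: `gamma` is the discount factor, and `n_step_return_horizon=3` requests 3-step bootstrapped targets. As a quick refresher, the n-step target can be spelled out in plain Python (the helper below is ours, for illustration only, not Tianshou code):

```python
def n_step_target(rewards: list[float], bootstrap_value: float, gamma: float) -> float:
    """Discounted n-step return: r_0 + gamma*r_1 + ... + gamma^(n-1)*r_{n-1} + gamma^n * V(s_n)."""
    target = bootstrap_value
    for r in reversed(rewards):  # Horner-style evaluation of the discounted sum
        target = r + gamma * target
    return target


# With gamma=0.9 and a 3-step horizon, as configured in the example:
target = n_step_target([1.0, 1.0, 1.0], bootstrap_value=2.0, gamma=0.9)
# 1 + 0.9*1 + 0.81*1 + 0.729*2 = 4.168
```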
````diff
@@ -281,24 +283,25 @@ The experiment builder takes three arguments:
   episodes (`watch_num_episodes=100`). We have disabled persistence, because
   we do not want to save training logs, the agent or its configuration for
   future use.
-- the sampling configuration, which controls fundamental training parameters,
+- the training configuration, which controls fundamental training parameters,
   such as the total number of epochs we run the experiment for (`max_epochs=10`)
   and the number of environment steps each epoch shall consist of
-  (`step_per_epoch=10000`).
+  (`epoch_num_steps=10000`).
   Every epoch consists of a series of data collection (rollout) steps and
   training steps.
-  The parameter `step_per_collect` controls the amount of data that is
+  The parameter `collection_step_num_env_steps` controls the amount of data that is
   collected in each collection step and after each collection step, we
   perform a training step, applying a gradient-based update based on a sample
   of data (`batch_size=64`) taken from the buffer of data that has been
-  collected. For further details, see the documentation of `SamplingConfig`.
+  collected. For further details, see the documentation of the configuration class.
 
-We then proceed to configure some of the parameters of the DQN algorithm itself
-and of the neural network model we want to use.
-A DQN-specific detail is the use of callbacks to configure the algorithm's
-epsilon parameter for exploration. We want to use random exploration during rollouts
-(train callback), but we don't when evaluating the agent's performance in the test
-environments (test callback).
+We then proceed to configure some of the parameters of the DQN algorithm itself:
+For instance, we control the epsilon parameter for exploration.
+We want to use random exploration during rollouts for training (`eps_training`),
+but we don't when evaluating the agent's performance in the test environments
+(`eps_inference`).
+Furthermore, we configure model parameters of the network for the Q function,
+parametrising the number of hidden layers of the default MLP factory.
 
 Find the script in [examples/discrete/discrete_dqn_hl.py](examples/discrete/discrete_dqn_hl.py).
 Here's a run (with the training time cut short):
````
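As a back-of-the-envelope check of how the training parameters described above interact, here is plain-Python arithmetic with the values used in the example (not Tianshou code):

```python
# Plain arithmetic spelling out the training volume implied by the
# configuration values used in the example above.
max_epochs = 10
epoch_num_steps = 10_000
collection_step_num_env_steps = 10
update_step_num_gradient_steps_per_sample = 1 / 10

total_env_steps = max_epochs * epoch_num_steps  # 100,000 environment steps overall
collect_steps_per_epoch = epoch_num_steps // collection_step_num_env_steps

# After each collection step, gradient steps are taken in proportion to the
# number of environment steps just collected:
gradient_steps_per_collect = round(
    collection_step_num_env_steps * update_step_num_gradient_steps_per_sample
)  # 10 * 1/10 = 1 gradient step per collection step
total_gradient_steps = max_epochs * collect_steps_per_epoch * gradient_steps_per_collect
```

With these values, the run performs one gradient update per ten environment steps, i.e. 10,000 updates over 100,000 environment steps.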
````diff
@@ -309,15 +312,15 @@ Here's a run (with the training time cut short):
 
 Find many further applications of the high-level API in the `examples/` folder;
 look for scripts ending with `_hl.py`.
-Note that most of these examples require the extra package `argparse`
+Note that most of these examples require the extra `argparse`
 (install it by adding `--extras argparse` when invoking poetry).
 
 ### Procedural API
 
 Let us now consider an analogous example in the procedural API.
 Find the full script in [examples/discrete/discrete_dqn.py](https://github.com/thu-ml/tianshou/blob/master/examples/discrete/discrete_dqn.py).
 
-First, import some relevant packages:
+First, import the relevant packages:
 
 ```python
 import gymnasium as gym
````
````diff
@@ -326,7 +329,7 @@ from torch.utils.tensorboard import SummaryWriter
 import tianshou as ts
 ```
 
-Define some hyper-parameters:
+Define hyper-parameters:
 
 ```python
 task = 'CartPole-v1'
````
````diff
@@ -335,14 +338,13 @@ train_num, test_num = 10, 100
 gamma, n_step, target_freq = 0.9, 3, 320
 buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
-step_per_epoch, step_per_collect = 10000, 10
+epoch_num_steps, collection_step_num_env_steps = 10000, 10
 ```
 
 Initialize the logger:
 
 ```python
 logger = ts.utils.TensorboardLogger(SummaryWriter('log/dqn'))
-# For other loggers, see https://tianshou.readthedocs.io/en/master/01_tutorials/05_logger.html
 ```
 
 Make environments:
````
````diff
@@ -353,53 +355,78 @@ train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_
 test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
 ```
 
-Create the network as well as its optimizer:
+Create the network, policy, and algorithm:
 
 ```python
 from tianshou.utils.net.common import Net
+from tianshou.algorithm import DQN
+from tianshou.algorithm.modelfree.dqn import DiscreteQLearningPolicy
+from tianshou.algorithm.optim import AdamOptimizerFactory
 
 # Note: You can easily define other networks.
 # See https://tianshou.readthedocs.io/en/master/01_tutorials/00_dqn.html#build-the-network
 env = gym.make(task, render_mode="human")
 state_shape = env.observation_space.shape or env.observation_space.n
 action_shape = env.action_space.shape or env.action_space.n
-net = Net(state_shape=state_shape, action_shape=action_shape, hidden_sizes=[128, 128, 128])
-optim = torch.optim.Adam(net.parameters(), lr=lr)
-```
-
-Set up the policy and collectors:
+net = Net(
+    state_shape=state_shape, action_shape=action_shape,
+    hidden_sizes=[128, 128, 128]
+)
 
-```python
-policy = ts.policy.DQNPolicy(
+policy = DiscreteQLearningPolicy(
     model=net,
-    optim=optim,
-    discount_factor=gamma,
     action_space=env.action_space,
-    estimation_step=n_step,
+    eps_training=eps_train,
+    eps_inference=eps_test
+)
+
+# Create the algorithm with the policy and optimizer factory
+algorithm = DQN(
+    policy=policy,
+    optim=AdamOptimizerFactory(lr=lr),
+    gamma=gamma,
+    n_step_return_horizon=n_step,
     target_update_freq=target_freq
 )
-train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
-test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True) # because DQN uses epsilon-greedy method
 ```
 
-Let's train it:
+Set up the collectors:
 
 ```python
-result = ts.trainer.OffpolicyTrainer(
-    policy=policy,
+train_collector = ts.data.Collector(policy, train_envs,
+    ts.data.VectorReplayBuffer(buffer_size, train_num), exploration_noise=True)
+test_collector = ts.data.Collector(policy, test_envs,
+    exploration_noise=True) # because DQN uses epsilon-greedy method
+```
+
+Let's train it using the algorithm:
+
+```python
+from tianshou.highlevel.config import OffPolicyTrainingConfig
+
+# Create training configuration
+training_config = OffPolicyTrainingConfig(
+    max_epochs=epoch,
+    epoch_num_steps=epoch_num_steps,
+    batch_size=batch_size,
+    num_train_envs=train_num,
+    num_test_envs=test_num,
+    buffer_size=buffer_size,
+    collection_step_num_env_steps=collection_step_num_env_steps,
+    update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps,
+    test_step_num_episodes=test_num,
+)
+
+# Run training (trainer is created automatically by the algorithm)
+result = algorithm.run_training(
+    training_config=training_config,
     train_collector=train_collector,
     test_collector=test_collector,
-    max_epoch=epoch,
-    step_per_epoch=step_per_epoch,
-    step_per_collect=step_per_collect,
-    episode_per_test=test_num,
-    batch_size=batch_size,
-    update_per_step=1 / step_per_collect,
+    logger=logger,
     train_fn=lambda epoch, env_step: policy.set_eps(eps_train),
     test_fn=lambda epoch, env_step: policy.set_eps(eps_test),
     stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
-    logger=logger,
-).run()
+)
 print(f"Finished training in {result.timing.total_time} seconds")
 ```
 
````
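The `eps_train`/`eps_test` values wired into `train_fn` and `test_fn` in the procedural example drive epsilon-greedy action selection. Here is a toy plain-Python illustration of that mechanism (our own sketch, independent of Tianshou):

```python
import random


def epsilon_greedy(q_values: list[float], eps: float, rng: random.Random) -> int:
    """With probability eps pick a uniformly random action, else the greedy one."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=q_values.__getitem__)


rng = random.Random(0)
q = [0.1, 0.9]  # toy Q-values for a two-action environment

# With eps=0.0 (testing), the greedy action is always chosen:
assert all(epsilon_greedy(q, 0.0, rng) == 1 for _ in range(100))

# With eps=0.1 (training), roughly 5% of selections deviate from greedy
# (half of the random picks coincide with the greedy action anyway):
picks = [epsilon_greedy(q, 0.1, rng) for _ in range(10_000)]
exploratory_fraction = picks.count(0) / len(picks)
```

This is why the test collector can still set `exploration_noise=True` while behaving greedily: with epsilon at zero, the random branch is never taken.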