
Commit 10d9190

Add Trainers as generators (#559)
The new proposed feature is to have trainers as generators. The usage pattern is:

```python
trainer = OnpolicyTrainer(...)
for epoch, epoch_stat, info in trainer:
    print(f"Epoch: {epoch}")
    print(epoch_stat)
    print(info)
    do_something_with_policy()
    query_something_about_policy()
    make_a_plot_with(epoch_stat)
    display(info)
```

- `epoch` (int): the epoch number
- `epoch_stat` (dict): a large collection of metrics of the current epoch, including the gathered statistics
- `info` (dict): the usual result dict returned by the non-generator version of the trainer

You can even iterate over several different trainers at the same time:

```python
trainer1 = OnpolicyTrainer(...)
trainer2 = OnpolicyTrainer(...)
for result1, result2, ... in zip(trainer1, trainer2, ...):
    compare_results(result1, result2, ...)
```

Co-authored-by: Jiayi Weng <trinkle23897@gmail.com>
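The usage pattern above relies only on Python's iterator protocol (`__iter__`/`__next__`). As a minimal, self-contained sketch of the idea, with a hypothetical `ToyTrainer` standing in for the real tianshou classes, a trainer that yields `(epoch, epoch_stat, info)` once per epoch can be written as:

```python
class ToyTrainer:
    """Hypothetical illustration of a trainer as an iterator (not tianshou code)."""

    def __init__(self, max_epoch):
        self.max_epoch = max_epoch
        self.epoch = 0
        self.best_reward = float("-inf")

    def __iter__(self):
        return self

    def __next__(self):
        if self.epoch >= self.max_epoch:
            raise StopIteration  # ends the for-loop cleanly
        self.epoch += 1
        reward = float(self.epoch)  # stand-in for a real test-phase reward
        self.best_reward = max(self.best_reward, reward)
        epoch_stat = {"test_reward": reward}       # per-epoch metrics
        info = {"best_reward": self.best_reward}   # cumulative summary
        return self.epoch, epoch_stat, info


# Because each trainer object keeps its own state, several can be
# advanced in lockstep with zip(), as in the usage pattern above.
for (e1, _, info1), (e2, _, info2) in zip(ToyTrainer(2), ToyTrainer(3)):
    print(e1, e2, info1["best_reward"], info2["best_reward"])
```

Since `StopIteration` is raised once `max_epoch` is reached, a plain `for` loop terminates on its own; no explicit epoch counting is needed on the caller side.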
1 parent 2336a7d commit 10d9190

File tree

14 files changed (+864 additions, -488 deletions)


.github/ISSUE_TEMPLATE.md

Lines changed: 2 additions & 2 deletions
````diff
@@ -7,6 +7,6 @@
 - [ ] I have searched through the [issue tracker](https://github.com/thu-ml/tianshou/issues) for duplicates
 - [ ] I have mentioned version numbers, operating system and environment, where applicable:
   ```python
-  import tianshou, torch, numpy, sys
-  print(tianshou.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
+  import tianshou, gym, torch, numpy, sys
+  print(tianshou.__version__, gym.__version__, torch.__version__, numpy.__version__, sys.version, sys.platform)
   ```
````

Makefile

Lines changed: 1 addition & 3 deletions
```diff
@@ -22,10 +22,8 @@ lint:
	flake8 ${LINT_PATHS} --count --show-source --statistics
 
 format:
-	# sort imports
	$(call check_install, isort)
	isort ${LINT_PATHS}
-	# reformat using yapf
	$(call check_install, yapf)
	yapf -ir ${LINT_PATHS}
 
@@ -57,6 +55,6 @@ doc-clean:
 
 clean: doc-clean
 
-commit-checks: format lint mypy check-docstyle spelling
+commit-checks: lint check-codestyle mypy check-docstyle spelling
 
 .PHONY: clean spelling doc mypy lint format check-codestyle check-docstyle commit-checks
```

docs/api/tianshou.trainer.rst

Lines changed: 43 additions & 1 deletion
```diff
@@ -1,7 +1,49 @@
 tianshou.trainer
 ================
 
-.. automodule:: tianshou.trainer
+
+On-policy
+---------
+
+.. autoclass:: tianshou.trainer.OnpolicyTrainer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. autofunction:: tianshou.trainer.onpolicy_trainer
+
+.. autoclass:: tianshou.trainer.onpolicy_trainer_iter
+
+
+Off-policy
+----------
+
+.. autoclass:: tianshou.trainer.OffpolicyTrainer
+   :members:
+   :undoc-members:
+   :show-inheritance:
+
+.. autofunction:: tianshou.trainer.offpolicy_trainer
+
+.. autoclass:: tianshou.trainer.offpolicy_trainer_iter
+
+
+Offline
+-------
+
+.. autoclass:: tianshou.trainer.OfflineTrainer
    :members:
    :undoc-members:
    :show-inheritance:
+
+.. autofunction:: tianshou.trainer.offline_trainer
+
+.. autoclass:: tianshou.trainer.offline_trainer_iter
+
+
+utils
+-----
+
+.. autofunction:: tianshou.trainer.test_episode
+
+.. autofunction:: tianshou.trainer.gather_info
```

docs/spelling_wordlist.txt

Lines changed: 3 additions & 0 deletions
```diff
@@ -24,12 +24,15 @@ fqf
 iqn
 qrdqn
 rl
+offpolicy
+onpolicy
 quantile
 quantiles
 dqn
 param
 async
 subprocess
+deque
 nn
 equ
 cql
```

docs/tutorials/concepts.rst

Lines changed: 20 additions & 0 deletions
```diff
@@ -380,6 +380,26 @@ Once you have a collector and a policy, you can start writing the training metho
 
 Tianshou has three types of trainer: :func:`~tianshou.trainer.onpolicy_trainer` for on-policy algorithms such as Policy Gradient, :func:`~tianshou.trainer.offpolicy_trainer` for off-policy algorithms such as DQN, and :func:`~tianshou.trainer.offline_trainer` for offline algorithms such as BCQ. Please check out :doc:`/api/tianshou.trainer` for the usage.
 
+We also provide the corresponding iterator-based trainer classes :class:`~tianshou.trainer.OnpolicyTrainer`, :class:`~tianshou.trainer.OffpolicyTrainer`, :class:`~tianshou.trainer.OfflineTrainer` to facilitate users writing more flexible training logic:
+::
+
+    trainer = OnpolicyTrainer(...)
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+        do_something_with_policy()
+        query_something_about_policy()
+        make_a_plot_with(epoch_stat)
+        display(info)
+
+    # or even iterate on several trainers at the same time
+
+    trainer1 = OnpolicyTrainer(...)
+    trainer2 = OnpolicyTrainer(...)
+    for result1, result2, ... in zip(trainer1, trainer2, ...):
+        compare_results(result1, result2, ...)
+
 
 .. _pseudocode:
 
```
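The `zip` pattern added to the docs above works because each trainer object carries its own iteration state. A runnable sketch of it, with a hypothetical `ToyTrainer` in place of tianshou's `OnpolicyTrainer`:

```python
class ToyTrainer:
    """Hypothetical stand-in for an iterator-based trainer (not tianshou code)."""

    def __init__(self, name, max_epoch):
        self.name = name
        self.max_epoch = max_epoch

    def __iter__(self):
        # Yield one (epoch, epoch_stat, info) tuple per epoch.
        for epoch in range(1, self.max_epoch + 1):
            yield epoch, {"reward": float(epoch)}, {"trainer": self.name}


pairs = []
# zip() advances both trainers one epoch at a time and stops when the
# shorter run is exhausted (2 epochs here), keeping the runs synchronized.
for result1, result2 in zip(ToyTrainer("a", 2), ToyTrainer("b", 5)):
    pairs.append((result1[0], result2[0]))

print(pairs)  # [(1, 1), (2, 2)]
```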

test/continuous/test_ppo.py

Lines changed: 10 additions & 4 deletions
```diff
@@ -11,7 +11,7 @@
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import PPOPolicy
-from tianshou.trainer import onpolicy_trainer
+from tianshou.trainer import OnpolicyTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import ActorCritic, Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -157,7 +157,7 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
         print("Fail to restore policy and optim.")
 
     # trainer
-    result = onpolicy_trainer(
+    trainer = OnpolicyTrainer(
         policy,
         train_collector,
         test_collector,
@@ -173,10 +173,16 @@ def save_checkpoint_fn(epoch, env_step, gradient_step):
         resume_from_log=args.resume,
         save_checkpoint_fn=save_checkpoint_fn
     )
-    assert stop_fn(result['best_reward'])
+
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    assert stop_fn(info["best_reward"])
 
     if __name__ == '__main__':
-        pprint.pprint(result)
+        pprint.pprint(info)
         # Let's watch its performance!
         env = gym.make(args.task)
         policy.eval()
```

test/continuous/test_sac_with_il.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -24,7 +24,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--reward-threshold', type=float, default=None)
-    parser.add_argument('--seed', type=int, default=0)
+    parser.add_argument('--seed', type=int, default=1)
     parser.add_argument('--buffer-size', type=int, default=20000)
     parser.add_argument('--actor-lr', type=float, default=1e-3)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
```

test/continuous/test_td3.py

Lines changed: 12 additions & 7 deletions
```diff
@@ -11,7 +11,7 @@
 from tianshou.env import DummyVectorEnv
 from tianshou.exploration import GaussianNoise
 from tianshou.policy import TD3Policy
-from tianshou.trainer import offpolicy_trainer
+from tianshou.trainer import OffpolicyTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import Actor, Critic
@@ -135,8 +135,8 @@ def save_fn(policy):
     def stop_fn(mean_rewards):
         return mean_rewards >= args.reward_threshold
 
-    # trainer
-    result = offpolicy_trainer(
+    # Iterator trainer
+    trainer = OffpolicyTrainer(
         policy,
         train_collector,
         test_collector,
@@ -148,12 +148,17 @@ def stop_fn(mean_rewards):
         update_per_step=args.update_per_step,
         stop_fn=stop_fn,
         save_fn=save_fn,
-        logger=logger
+        logger=logger,
     )
-    assert stop_fn(result['best_reward'])
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
 
-    if __name__ == '__main__':
-        pprint.pprint(result)
+    assert stop_fn(info["best_reward"])
+
+    if __name__ == "__main__":
+        pprint.pprint(info)
         # Let's watch its performance!
         env = gym.make(args.task)
         policy.eval()
```
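A side benefit of the iterator form used in these tests: the caller can stop training early by simply breaking out of the loop, with no `stop_fn` callback wired into the trainer. A sketch with a hypothetical `ToyTrainer` (not tianshou's class):

```python
class ToyTrainer:
    """Hypothetical iterator-based trainer yielding (epoch, epoch_stat, info)."""

    def __iter__(self):
        for epoch in range(1, 101):  # would run 100 epochs if never stopped
            info = {"best_reward": float(epoch)}
            yield epoch, {"reward": float(epoch)}, info


last_epoch = 0
for epoch, epoch_stat, info in ToyTrainer():
    last_epoch = epoch
    if info["best_reward"] >= 10.0:  # caller-side stopping condition
        break

print(last_epoch)  # stopped at epoch 10 instead of 100
```

Because generators suspend between yields, no work for the remaining epochs is ever performed after the `break`.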

test/offline/test_cql.py

Lines changed: 11 additions & 5 deletions
```diff
@@ -12,7 +12,7 @@
 from tianshou.data import Collector, VectorReplayBuffer
 from tianshou.env import DummyVectorEnv
 from tianshou.policy import CQLPolicy
-from tianshou.trainer import offline_trainer
+from tianshou.trainer import OfflineTrainer
 from tianshou.utils import TensorboardLogger
 from tianshou.utils.net.common import Net
 from tianshou.utils.net.continuous import ActorProb, Critic
@@ -195,7 +195,7 @@ def watch():
         collector.collect(n_episode=1, render=1 / 35)
 
     # trainer
-    result = offline_trainer(
+    trainer = OfflineTrainer(
         policy,
         buffer,
         test_collector,
@@ -207,11 +207,17 @@ def watch():
         stop_fn=stop_fn,
         logger=logger,
     )
-    assert stop_fn(result['best_reward'])
+
+    for epoch, epoch_stat, info in trainer:
+        print(f"Epoch: {epoch}")
+        print(epoch_stat)
+        print(info)
+
+    assert stop_fn(info["best_reward"])
 
     # Let's watch its performance!
-    if __name__ == '__main__':
-        pprint.pprint(result)
+    if __name__ == "__main__":
+        pprint.pprint(info)
     env = gym.make(args.task)
     policy.eval()
     collector = Collector(policy, env)
```

tianshou/trainer/__init__.py

Lines changed: 24 additions & 6 deletions
```diff
@@ -1,16 +1,34 @@
 """Trainer package."""
 
-# isort:skip_file
-
-from tianshou.trainer.utils import test_episode, gather_info
-from tianshou.trainer.onpolicy import onpolicy_trainer
-from tianshou.trainer.offpolicy import offpolicy_trainer
-from tianshou.trainer.offline import offline_trainer
+from tianshou.trainer.base import BaseTrainer
+from tianshou.trainer.offline import (
+    OfflineTrainer,
+    offline_trainer,
+    offline_trainer_iter,
+)
+from tianshou.trainer.offpolicy import (
+    OffpolicyTrainer,
+    offpolicy_trainer,
+    offpolicy_trainer_iter,
+)
+from tianshou.trainer.onpolicy import (
+    OnpolicyTrainer,
+    onpolicy_trainer,
+    onpolicy_trainer_iter,
+)
+from tianshou.trainer.utils import gather_info, test_episode
 
 __all__ = [
+    "BaseTrainer",
     "offpolicy_trainer",
+    "offpolicy_trainer_iter",
+    "OffpolicyTrainer",
     "onpolicy_trainer",
+    "onpolicy_trainer_iter",
+    "OnpolicyTrainer",
     "offline_trainer",
+    "offline_trainer_iter",
+    "OfflineTrainer",
     "test_episode",
     "gather_info",
 ]
```
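The package keeps the functional entry points (`onpolicy_trainer` and friends) alongside the new classes. One way such a functional API can coexist with an iterator-based class is as a thin wrapper that drains the iterator and returns the final `info` dict. A sketch of that idea, using hypothetical `DummyTrainer`/`run_trainer` names rather than tianshou's actual implementation:

```python
class DummyTrainer:
    """Hypothetical trainer yielding (epoch, epoch_stat, info) per epoch."""

    def __init__(self, max_epoch):
        self.max_epoch = max_epoch

    def __iter__(self):
        for epoch in range(1, self.max_epoch + 1):
            yield epoch, {"loss": 1.0 / epoch}, {"best_reward": float(epoch)}


def run_trainer(trainer):
    """Drain the iterator; the last yielded info is the usual result dict."""
    info = {}
    for _epoch, _epoch_stat, info in trainer:
        pass  # run every epoch to completion, keeping only the last info
    return info


result = run_trainer(DummyTrainer(5))
print(result)  # {'best_reward': 5.0}
```

This keeps a single training loop implementation while exposing both the one-shot functional interface and the flexible per-epoch iteration.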
