Skip to content

Commit 8bb8ecb

Browse files
authored
set policy.eval() before collector.collect (#204)
* fix #203
* add `no_grad` argument in `collector.collect`
1 parent 34f714a commit 8bb8ecb

File tree

6 files changed

+23
-25
lines changed

6 files changed

+23
-25
lines changed

README.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,8 @@ $ pip install tianshou
5656
You can also install with the newest version through GitHub:
5757

5858
```bash
59-
# latest release
59+
# latest version
6060
$ pip install git+https://github.com/thu-ml/tianshou.git@master
61-
# develop version
62-
$ pip install git+https://github.com/thu-ml/tianshou.git@dev
6361
```
6462

6563
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:

docs/contributing.rst

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ To install Tianshou in an "editable" mode, run
88

99
.. code-block:: bash
1010
11-
$ git checkout dev
1211
$ pip install -e ".[dev]"
1312
1413
in the main directory. This installation is removable by
@@ -70,9 +69,4 @@ To compile documentation into webpages, run
7069
7170
under the ``docs/`` directory. The generated webpages are in ``docs/_build`` and can be viewed with browsers.
7271

73-
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/, and the develop version of documentation is in https://tianshou.readthedocs.io/en/dev/.
74-
75-
Pull Request
76-
------------
77-
78-
All of the commits should merge through the pull request to the ``dev`` branch. The pull request must have 2 approvals before merging.
72+
Chinese documentation is in https://tianshou.readthedocs.io/zh/latest/.

docs/index.rst

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,10 +46,8 @@ You can also install with the newest version through GitHub:
4646

4747
.. code-block:: bash
4848
49-
# latest release
49+
# latest version
5050
$ pip install git+https://github.com/thu-ml/tianshou.git@master
51-
# develop version
52-
$ pip install git+https://github.com/thu-ml/tianshou.git@dev
5351
5452
If you use Anaconda or Miniconda, you can install Tianshou through the following command lines:
5553

@@ -70,7 +68,7 @@ After installation, open your python console and type
7068

7169
If no error occurs, you have successfully installed Tianshou.
7270

73-
Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_ and the develop version through `tianshou.readthedocs.io/en/dev/ <https://tianshou.readthedocs.io/en/dev/>`_.
71+
Tianshou is still under development, you can also check out the documents in stable version through `tianshou.readthedocs.io/en/stable/ <https://tianshou.readthedocs.io/en/stable/>`_.
7472

7573
.. toctree::
7674
:maxdepth: 1

tianshou/data/collector.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,7 @@ def collect(self,
173173
n_episode: Optional[Union[int, List[int]]] = None,
174174
random: bool = False,
175175
render: Optional[float] = None,
176+
no_grad: bool = True,
176177
) -> Dict[str, float]:
177178
"""Collect a specified number of step or episode.
178179
@@ -185,6 +186,8 @@ def collect(self,
185186
defaults to ``False``.
186187
:param float render: the sleep time between rendering consecutive
187188
frames, defaults to ``None`` (no rendering).
189+
:param bool no_grad: whether to retain gradient in policy.forward,
190+
defaults to ``True`` (no gradient retaining).
188191
189192
.. note::
190193
@@ -252,7 +255,10 @@ def collect(self,
252255
result = Batch(
253256
act=[spaces[i].sample() for i in self._ready_env_ids])
254257
else:
255-
with torch.no_grad():
258+
if no_grad:
259+
with torch.no_grad(): # faster than retain_grad version
260+
result = self.policy(self.data, last_state)
261+
else:
256262
result = self.policy(self.data, last_state)
257263

258264
state = result.get('state', Batch())

tianshou/trainer/offpolicy.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ def offpolicy_trainer(
7676
start_time = time.time()
7777
test_in_train = test_in_train and train_collector.policy == policy
7878
for epoch in range(1, 1 + max_epoch):
79-
# train
80-
policy.train()
81-
if train_fn:
82-
train_fn(epoch)
8379
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
8480
**tqdm_config) as t:
8581
while t.n < t.total:
82+
# collect
83+
if train_fn:
84+
train_fn(epoch)
85+
policy.eval()
8686
result = train_collector.collect(n_step=collect_per_step)
8787
data = {}
8888
if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@ def offpolicy_trainer(
9999
start_time, train_collector, test_collector,
100100
test_result['rew'])
101101
else:
102-
policy.train()
103102
if train_fn:
104103
train_fn(epoch)
104+
# train
105+
policy.train()
105106
for i in range(update_per_step * min(
106107
result['n/st'] // collect_per_step, t.total - t.n)):
107108
global_step += collect_per_step

tianshou/trainer/onpolicy.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -76,13 +76,13 @@ def onpolicy_trainer(
7676
start_time = time.time()
7777
test_in_train = test_in_train and train_collector.policy == policy
7878
for epoch in range(1, 1 + max_epoch):
79-
# train
80-
policy.train()
81-
if train_fn:
82-
train_fn(epoch)
8379
with tqdm.tqdm(total=step_per_epoch, desc=f'Epoch #{epoch}',
8480
**tqdm_config) as t:
8581
while t.n < t.total:
82+
# collect
83+
if train_fn:
84+
train_fn(epoch)
85+
policy.eval()
8686
result = train_collector.collect(n_episode=collect_per_step)
8787
data = {}
8888
if test_in_train and stop_fn and stop_fn(result['rew']):
@@ -99,9 +99,10 @@ def onpolicy_trainer(
9999
start_time, train_collector, test_collector,
100100
test_result['rew'])
101101
else:
102-
policy.train()
103102
if train_fn:
104103
train_fn(epoch)
104+
# train
105+
policy.train()
105106
losses = policy.update(
106107
0, train_collector.buffer, batch_size, repeat_per_collect)
107108
train_collector.reset_buffer()

0 commit comments

Comments (0)