Commit 4bf7a91

Benchmarking (#1276)

Refactors the scripts, reducing the parametrization of the high-level scripts to the minimum and restoring the same default configuration as in v0.5.0 (except for the MuJoCo task versions, which have been bumped from 3 to 4). Makes various improvements to the rliable evaluation code. This PR also adds the ability to run and evaluate multiple experiments directly from an ExperimentBuilder; this is used to establish a benchmarking script that runs multiple experiments in parallel in tmux sessions, evaluates them with rliable, and aggregates the statistics so that they can be displayed in the benchmarking section of the logs.

2 parents: c310b5a + 1e9aac0
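The parallel run-and-aggregate flow described in the commit message can be sketched in plain Python. This is a hypothetical stand-in, not Tianshou's actual API: `run_experiment`, the dummy scoring, and the thread pool (in place of tmux sessions) are all assumptions for illustration.

```python
import concurrent.futures
import random
import statistics


def run_experiment(seed: int) -> float:
    # Stand-in for one training run; a real run would build and execute
    # an experiment here. Returns a deterministic dummy final score.
    rng = random.Random(seed)
    return 100.0 + rng.uniform(-5.0, 5.0)


def run_benchmark(seeds):
    # Launch one run per seed in parallel (the commit uses tmux sessions;
    # a thread pool stands in for that here) and aggregate the scores.
    with concurrent.futures.ThreadPoolExecutor() as pool:
        scores = list(pool.map(run_experiment, seeds))
    return {"mean": statistics.mean(scores), "stdev": statistics.stdev(scores)}


stats = run_benchmark(range(5))
print(stats)
```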

312 files changed: +4923 −4522 lines

CHANGELOG.md

2 additions, 2 deletions

```diff
@@ -320,7 +320,7 @@ Our main test environment remains Python 3.11-based for the time being (see `poe
 - `highlevel`:
   - Change the way in which seeding is handled: The mechanism introduced in v1.1.0
     was completely revised:
-    - The `train_seed` and `test_seed` attributes were removed from `SamplingConfig`.
+    - The `training_seed` and `test_seed` attributes were removed from `SamplingConfig`.
       Instead, the seeds are derived from the seed defined in `ExperimentConfig`.
     - Seed attributes of `EnvFactory` classes were removed.
       Instead, seeds are passed to methods of `EnvFactory`.
@@ -555,7 +555,7 @@ A detailed list of changes can be found below.
   #1194 #1195
 - `env`:
   - `EnvFactoryRegistered`: parameter `seed` has been replaced by the pair
-    of parameters `train_seed` and `test_seed`
+    of parameters `training_seed` and `test_seed`
     Persisted instances will continue to work correctly.
     Subclasses such as `AtariEnvFactory` are also affected and require
     explicit train and test seeds. #1074
```
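The changelog entry above says all seeds are now derived from the single seed in `ExperimentConfig`. A minimal sketch of what such a derivation can look like follows; the scheme shown is an assumption for illustration, not Tianshou's actual implementation.

```python
import random


def derive_seeds(experiment_seed: int, num_training_envs: int, num_test_envs: int):
    # Illustrative derivation scheme (not Tianshou's actual code): a single
    # experiment seed deterministically yields one seed per training and
    # test environment, so no separate train/test seed attributes are needed.
    rng = random.Random(experiment_seed)
    training_seeds = [rng.randrange(2**31) for _ in range(num_training_envs)]
    test_seeds = [rng.randrange(2**31) for _ in range(num_test_envs)]
    return training_seeds, test_seeds


a = derive_seeds(42, 10, 100)
b = derive_seeds(42, 10, 100)
assert a == b  # fully determined by the experiment seed
```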

README.md

47 additions, 47 deletions

````diff
@@ -235,53 +235,53 @@ almost exclusively concerned with configuration that controls what to do
 ```python
 from tianshou.highlevel.config import OffPolicyTrainingConfig
 from tianshou.highlevel.env import (
-    EnvFactoryRegistered,
-    VectorEnvType,
+    EnvFactoryRegistered,
+    VectorEnvType,
 )
 from tianshou.highlevel.experiment import DQNExperimentBuilder, ExperimentConfig
 from tianshou.highlevel.params.algorithm_params import DQNParams
 from tianshou.highlevel.trainer import (
-    EpochStopCallbackRewardThreshold,
+    EpochStopCallbackRewardThreshold,
 )
 
 experiment = (
-    DQNExperimentBuilder(
-        EnvFactoryRegistered(
-            task="CartPole-v1",
-            venv_type=VectorEnvType.DUMMY,
-            train_seed=0,
-            test_seed=10,
-        ),
-        ExperimentConfig(
-            persistence_enabled=False,
-            watch=True,
-            watch_render=1 / 35,
-            watch_num_episodes=100,
-        ),
-        OffPolicyTrainingConfig(
-            max_epochs=10,
-            epoch_num_steps=10000,
-            batch_size=64,
-            num_train_envs=10,
-            num_test_envs=100,
-            buffer_size=20000,
-            collection_step_num_env_steps=10,
-            update_step_num_gradient_steps_per_sample=1 / 10,
-        ),
-    )
-    .with_dqn_params(
-        DQNParams(
-            lr=1e-3,
-            gamma=0.9,
-            n_step_return_horizon=3,
-            target_update_freq=320,
-            eps_training=0.3,
-            eps_inference=0.0,
-        ),
-    )
-    .with_model_factory_default(hidden_sizes=(64, 64))
-    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
-    .build()
+    DQNExperimentBuilder(
+        EnvFactoryRegistered(
+            task="CartPole-v1",
+            venv_type=VectorEnvType.DUMMY,
+            training_seed=0,
+            test_seed=10,
+        ),
+        ExperimentConfig(
+            persistence_enabled=False,
+            watch=True,
+            watch_render=1 / 35,
+            watch_num_episodes=100,
+        ),
+        OffPolicyTrainingConfig(
+            max_epochs=10,
+            epoch_num_steps=10000,
+            batch_size=64,
+            num_training_envs=10,
+            num_test_envs=100,
+            buffer_size=20000,
+            collection_step_num_env_steps=10,
+            update_step_num_gradient_steps_per_sample=1 / 10,
+        ),
+    )
+    .with_dqn_params(
+        DQNParams(
+            lr=1e-3,
+            gamma=0.9,
+            n_step_return_horizon=3,
+            target_update_freq=320,
+            eps_training=0.3,
+            eps_inference=0.0,
+        ),
+    )
+    .with_model_factory_default(hidden_sizes=(64, 64))
+    .with_epoch_stop_callback(EpochStopCallbackRewardThreshold(195))
+    .build()
 )
 experiment.run()
 ```
````
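The README example above configures an experiment through a fluent builder (`DQNExperimentBuilder`). The pattern itself can be shown with a toy sketch; `MiniExperimentBuilder`, `MiniExperiment`, and `with_params` are invented names for illustration, not Tianshou classes.

```python
from dataclasses import dataclass, field


@dataclass
class MiniExperiment:
    params: dict


@dataclass
class MiniExperimentBuilder:
    # Toy illustration of the fluent-builder pattern (hypothetical names,
    # not Tianshou's actual classes).
    _params: dict = field(default_factory=dict)

    def with_params(self, **kwargs):
        self._params.update(kwargs)
        return self  # returning self enables method chaining

    def build(self) -> MiniExperiment:
        # Copy so the built experiment is decoupled from the builder.
        return MiniExperiment(dict(self._params))


exp = MiniExperimentBuilder().with_params(lr=1e-3).with_params(gamma=0.9).build()
print(exp.params)
```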
````diff
@@ -352,7 +352,7 @@ Define hyper-parameters:
 ```python
 task = 'CartPole-v1'
 lr, epoch, batch_size = 1e-3, 10, 64
-train_num, test_num = 10, 100
+num_training_envs, num_test_envs = 10, 100
 gamma, n_step, target_freq = 0.9, 3, 320
 buffer_size = 20000
 eps_train, eps_test = 0.1, 0.05
````
````diff
@@ -369,8 +369,8 @@ Create the environments:
 
 ```python
 # You can also try SubprocVectorEnv, which will use parallelization
-train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
-test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
+training_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_training_envs)])
+test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
 ```
 
 Create the network, policy, and algorithm:
````
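The `DummyVectorEnv` calls in the diff above take a list of factory callables, one per environment. A minimal stdlib sketch of that idea follows; `MiniVectorEnv` and `CounterEnv` are toy stand-ins, not Tianshou's implementation.

```python
class MiniVectorEnv:
    # Minimal illustration of the vector-env idea (not Tianshou's
    # DummyVectorEnv): hold one factory callable per environment and
    # fan out reset calls to all of them.
    def __init__(self, factories):
        self.envs = [make() for make in factories]

    def reset(self):
        return [env.reset() for env in self.envs]


class CounterEnv:
    # Toy environment standing in for gym.make(task).
    def reset(self):
        self.t = 0
        return 0


num_training_envs = 10
training_envs = MiniVectorEnv([lambda: CounterEnv() for _ in range(num_training_envs)])
print(len(training_envs.envs))
```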
````diff
@@ -408,10 +408,10 @@ algorithm = DQN(
 Set up the collectors:
 
 ```python
-train_collector = ts.data.Collector[CollectStats](
+training_collector = ts.data.Collector[CollectStats](
     algorithm,
-    train_envs,
-    ts.data.VectorReplayBuffer(buffer_size, num_train_envs),
+    training_envs,
+    ts.data.VectorReplayBuffer(buffer_size, num_training_envs),
     exploration_noise=True,
 )
 test_collector = ts.data.Collector[CollectStats](
````
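`VectorReplayBuffer(buffer_size, num_training_envs)` in the hunk above splits the total capacity across the parallel environments. A rough sketch of that partitioning idea, illustrative only and not the real implementation:

```python
class MiniVectorReplayBuffer:
    # Sketch of the idea behind a vectorized replay buffer (not Tianshou's
    # actual VectorReplayBuffer): total capacity is divided evenly across
    # the parallel envs, and each env writes to its own slice.
    def __init__(self, total_size: int, buffer_num: int):
        self.per_env = total_size // buffer_num
        self.buffers = [[] for _ in range(buffer_num)]

    def add(self, env_id: int, transition):
        buf = self.buffers[env_id]
        buf.append(transition)
        if len(buf) > self.per_env:
            buf.pop(0)  # overwrite oldest once this env's slice is full


buffer = MiniVectorReplayBuffer(20000, 10)
buffer.add(0, ("obs", "act", 1.0))
```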
````diff
@@ -426,7 +426,7 @@ Let's train the model using the algorithm:
 ```python
 result = algorithm.run_training(
     OffPolicyTrainerParams(
-        train_collector=train_collector,
+        training_collector=training_collector,
         test_collector=test_collector,
         max_epochs=epoch,
         epoch_num_steps=epoch_num_steps,
````
0 commit comments