Skip to content

Commit b73093f

Browse files
committed
v2: Update parameter names (mainly test_num -> num_test_envs)
1 parent c264917 commit b73093f

17 files changed: +42 −42 lines changed

examples/atari/atari_dqn_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def main(
4343
update_per_step: float = 0.1,
4444
batch_size: int = 32,
4545
num_train_envs: int = 10,
46-
test_num: int = 10,
46+
num_test_envs: int = 10,
4747
frames_stack: int = 4,
4848
icm_lr_scale: float = 0.0,
4949
icm_reward_scale: float = 0.01,
@@ -56,7 +56,7 @@ def main(
5656
epoch_num_steps=epoch_num_steps,
5757
batch_size=batch_size,
5858
num_train_envs=num_train_envs,
59-
num_test_envs=test_num,
59+
num_test_envs=num_test_envs,
6060
buffer_size=buffer_size,
6161
collection_step_num_env_steps=collection_step_num_env_steps,
6262
update_step_num_gradient_steps_per_sample=update_per_step,

examples/atari/atari_iqn_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def main(
4444
update_per_step: float = 0.1,
4545
batch_size: int = 32,
4646
num_train_envs: int = 10,
47-
test_num: int = 10,
47+
num_test_envs: int = 10,
4848
frames_stack: int = 4,
4949
) -> None:
5050
log_name = os.path.join(task, "iqn", str(experiment_config.seed), datetime_tag())
@@ -54,7 +54,7 @@ def main(
5454
epoch_num_steps=epoch_num_steps,
5555
batch_size=batch_size,
5656
num_train_envs=num_train_envs,
57-
num_test_envs=test_num,
57+
num_test_envs=num_test_envs,
5858
buffer_size=buffer_size,
5959
collection_step_num_env_steps=collection_step_num_env_steps,
6060
update_step_num_gradient_steps_per_sample=update_per_step,

examples/atari/atari_ppo_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def main(
3737
batch_size: int = 256,
3838
hidden_sizes: Sequence[int] = (512,),
3939
num_train_envs: int = 10,
40-
test_num: int = 10,
40+
num_test_envs: int = 10,
4141
return_scaling: bool = False,
4242
vf_coef: float = 0.25,
4343
ent_coef: float = 0.01,
@@ -62,7 +62,7 @@ def main(
6262
epoch_num_steps=epoch_num_steps,
6363
batch_size=batch_size,
6464
num_train_envs=num_train_envs,
65-
num_test_envs=test_num,
65+
num_test_envs=num_test_envs,
6666
buffer_size=buffer_size,
6767
collection_step_num_env_steps=collection_step_num_env_steps,
6868
update_step_num_repetitions=update_step_num_repetitions,

examples/atari/atari_sac_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def main(
4343
batch_size: int = 64,
4444
hidden_sizes: Sequence[int] = (512,),
4545
num_train_envs: int = 10,
46-
test_num: int = 10,
46+
num_test_envs: int = 10,
4747
frames_stack: int = 4,
4848
icm_lr_scale: float = 0.0,
4949
icm_reward_scale: float = 0.01,
@@ -57,7 +57,7 @@ def main(
5757
update_step_num_gradient_steps_per_sample=update_per_step,
5858
batch_size=batch_size,
5959
num_train_envs=num_train_envs,
60-
num_test_envs=test_num,
60+
num_test_envs=num_test_envs,
6161
buffer_size=buffer_size,
6262
collection_step_num_env_steps=collection_step_num_env_steps,
6363
replay_buffer_stack_num=frames_stack,

examples/discrete/discrete_dqn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
def main() -> None:
1313
task = "CartPole-v1"
1414
lr, epoch, batch_size = 1e-3, 10, 64
15-
train_num, test_num = 10, 100
15+
num_train_envs, num_test_envs = 10, 100
1616
gamma, n_step, target_freq = 0.9, 3, 320
1717
buffer_size = 20000
1818
eps_train, eps_test = 0.1, 0.05
@@ -22,8 +22,8 @@ def main() -> None:
2222
# For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html
2323

2424
# You can also try SubprocVectorEnv, which will use parallelization
25-
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(train_num)])
26-
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(test_num)])
25+
train_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_train_envs)])
26+
test_envs = ts.env.DummyVectorEnv([lambda: gym.make(task) for _ in range(num_test_envs)])
2727

2828
from tianshou.utils.net.common import Net
2929

@@ -50,7 +50,7 @@ def main() -> None:
5050
train_collector = ts.data.Collector[CollectStats](
5151
algorithm,
5252
train_envs,
53-
ts.data.VectorReplayBuffer(buffer_size, train_num),
53+
ts.data.VectorReplayBuffer(buffer_size, num_train_envs),
5454
exploration_noise=True,
5555
)
5656
test_collector = ts.data.Collector[CollectStats](
@@ -74,7 +74,7 @@ def stop_fn(mean_rewards: float) -> bool:
7474
max_epochs=epoch,
7575
epoch_num_steps=epoch_num_steps,
7676
collection_step_num_env_steps=collection_step_num_env_steps,
77-
test_step_num_episodes=test_num,
77+
test_step_num_episodes=num_test_envs,
7878
batch_size=batch_size,
7979
update_step_num_gradient_steps_per_sample=1 / collection_step_num_env_steps,
8080
stop_fn=stop_fn,

examples/mujoco/fetch_her_ddpg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,14 @@ def get_args() -> argparse.Namespace:
8383
def make_fetch_env(
8484
task: str,
8585
num_train_envs: int,
86-
test_num: int,
86+
num_test_envs: int,
8787
) -> tuple[gym.Env, BaseVectorEnv, BaseVectorEnv]:
8888
env = TruncatedAsTerminated(gym.make(task))
8989
train_envs = ShmemVectorEnv(
9090
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_train_envs)],
9191
)
9292
test_envs = ShmemVectorEnv(
93-
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(test_num)],
93+
[lambda: TruncatedAsTerminated(gym.make(task)) for _ in range(num_test_envs)],
9494
)
9595
return env, train_envs, test_envs
9696

examples/mujoco/mujoco_a2c_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ def main(
3232
update_step_num_repetitions: int = 1,
3333
batch_size: int = 16,
3434
num_train_envs: int = 16,
35-
test_num: int = 10,
35+
num_test_envs: int = 10,
3636
return_scaling: bool = True,
3737
vf_coef: float = 0.5,
3838
ent_coef: float = 0.01,
@@ -48,7 +48,7 @@ def main(
4848
epoch_num_steps=epoch_num_steps,
4949
batch_size=batch_size,
5050
num_train_envs=num_train_envs,
51-
num_test_envs=test_num,
51+
num_test_envs=num_test_envs,
5252
buffer_size=buffer_size,
5353
collection_step_num_env_steps=collection_step_num_env_steps,
5454
update_step_num_repetitions=update_step_num_repetitions,

examples/mujoco/mujoco_ddpg_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def main(
3434
n_step: int = 1,
3535
batch_size: int = 256,
3636
num_train_envs: int = 1,
37-
test_num: int = 10,
37+
num_test_envs: int = 10,
3838
) -> None:
3939
log_name = os.path.join(task, "ddpg", str(experiment_config.seed), datetime_tag())
4040

@@ -43,7 +43,7 @@ def main(
4343
epoch_num_steps=epoch_num_steps,
4444
batch_size=batch_size,
4545
num_train_envs=num_train_envs,
46-
num_test_envs=test_num,
46+
num_test_envs=num_test_envs,
4747
buffer_size=buffer_size,
4848
collection_step_num_env_steps=collection_step_num_env_steps,
4949
update_step_num_gradient_steps_per_sample=update_per_step,

examples/mujoco/mujoco_ppo_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ def main(
3131
update_step_num_repetitions: int = 10,
3232
batch_size: int = 64,
3333
num_train_envs: int = 10,
34-
test_num: int = 10,
34+
num_test_envs: int = 10,
3535
return_scaling: bool = True,
3636
vf_coef: float = 0.25,
3737
ent_coef: float = 0.0,
@@ -52,7 +52,7 @@ def main(
5252
epoch_num_steps=epoch_num_steps,
5353
batch_size=batch_size,
5454
num_train_envs=num_train_envs,
55-
num_test_envs=test_num,
55+
num_test_envs=num_test_envs,
5656
buffer_size=buffer_size,
5757
collection_step_num_env_steps=collection_step_num_env_steps,
5858
update_step_num_repetitions=update_step_num_repetitions,

examples/mujoco/mujoco_redq_hl.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def main(
4040
batch_size: int = 256,
4141
target_mode: Literal["mean", "min"] = "min",
4242
num_train_envs: int = 1,
43-
test_num: int = 10,
43+
num_test_envs: int = 10,
4444
) -> None:
4545
log_name = os.path.join(task, "redq", str(experiment_config.seed), datetime_tag())
4646

@@ -49,7 +49,7 @@ def main(
4949
epoch_num_steps=epoch_num_steps,
5050
batch_size=batch_size,
5151
num_train_envs=num_train_envs,
52-
num_test_envs=test_num,
52+
num_test_envs=num_test_envs,
5353
buffer_size=buffer_size,
5454
collection_step_num_env_steps=collection_step_num_env_steps,
5555
update_step_num_gradient_steps_per_sample=update_per_step,

0 commit comments

Comments (0)