1212def main () -> None :
1313 task = "CartPole-v1"
1414 lr , epoch , batch_size = 1e-3 , 10 , 64
15- train_num , test_num = 10 , 100
15+ num_train_envs , num_test_envs = 10 , 100
1616 gamma , n_step , target_freq = 0.9 , 3 , 320
1717 buffer_size = 20000
1818 eps_train , eps_test = 0.1 , 0.05
@@ -22,8 +22,8 @@ def main() -> None:
2222 # For other loggers, see https://tianshou.readthedocs.io/en/master/tutorials/logger.html
2323
2424 # You can also try SubprocVectorEnv, which will use parallelization
25- train_envs = ts .env .DummyVectorEnv ([lambda : gym .make (task ) for _ in range (train_num )])
26- test_envs = ts .env .DummyVectorEnv ([lambda : gym .make (task ) for _ in range (test_num )])
25+ train_envs = ts .env .DummyVectorEnv ([lambda : gym .make (task ) for _ in range (num_train_envs )])
26+ test_envs = ts .env .DummyVectorEnv ([lambda : gym .make (task ) for _ in range (num_test_envs )])
2727
2828 from tianshou .utils .net .common import Net
2929
@@ -50,7 +50,7 @@ def main() -> None:
5050 train_collector = ts .data .Collector [CollectStats ](
5151 algorithm ,
5252 train_envs ,
53- ts .data .VectorReplayBuffer (buffer_size , train_num ),
53+ ts .data .VectorReplayBuffer (buffer_size , num_train_envs ),
5454 exploration_noise = True ,
5555 )
5656 test_collector = ts .data .Collector [CollectStats ](
@@ -74,7 +74,7 @@ def stop_fn(mean_rewards: float) -> bool:
7474 max_epochs = epoch ,
7575 epoch_num_steps = epoch_num_steps ,
7676 collection_step_num_env_steps = collection_step_num_env_steps ,
77- test_step_num_episodes = test_num ,
77+ test_step_num_episodes = num_test_envs ,
7878 batch_size = batch_size ,
7979 update_step_num_gradient_steps_per_sample = 1 / collection_step_num_env_steps ,
8080 stop_fn = stop_fn ,
0 commit comments