1717from tianshou .utils import TensorboardLogger
1818from tianshou .utils .net .common import Recurrent
1919from tianshou .utils .space_info import SpaceInfo
20+ from tianshou .utils .torch_utils import policy_within_training_step
2021
2122
2223def get_args () -> argparse .Namespace :
@@ -92,6 +93,7 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T
9293 n_step_return_horizon = args .n_step ,
9394 target_update_freq = args .target_update_freq ,
9495 )
96+
9597 # collector
9698 buffer = VectorReplayBuffer (
9799 args .buffer_size ,
@@ -102,8 +104,12 @@ def test_drqn(args: argparse.Namespace = get_args(), enable_assertions: bool = T
102104 train_collector = Collector [CollectStats ](algorithm , train_envs , buffer , exploration_noise = True )
103105 # the stack_num is for RNN training: sample framestack obs
104106 test_collector = Collector [CollectStats ](algorithm , test_envs , exploration_noise = True )
105- train_collector .reset ()
106- train_collector .collect (n_step = args .batch_size * args .training_num )
107+
108+ # initial data collection
109+ with policy_within_training_step (policy ):
110+ train_collector .reset ()
111+ train_collector .collect (n_step = args .batch_size * args .training_num )
112+
107113 # log
108114 log_path = os .path .join (args .logdir , args .task , "drqn" )
109115 writer = SummaryWriter (log_path )
0 commit comments