@@ -84,7 +84,7 @@ def test_dqn(args=get_args()):
8484 # replay buffer: `save_only_last_obs` and `stack_num` can be removed together
8585 # when you have enough RAM
8686 buffer = ReplayBuffer (args .buffer_size , ignore_obs_next = True ,
87- save_last_obs = True , stack_num = args .frames_stack )
87+ save_only_last_obs = True , stack_num = args .frames_stack )
8888 # collector
8989 train_collector = Collector (policy , train_envs , buffer )
9090 test_collector = Collector (policy , test_envs )
@@ -100,17 +100,19 @@ def stop_fn(x):
100100 return x >= env .spec .reward_threshold
101101 elif 'Pong' in args .task :
102102 return x >= 20
103+ else :
104+ return False
103105
def train_fn(x):
    """Set and log the training exploration rate for the current step.

    Nature DQN setting: epsilon decays linearly from ``args.eps_train`` to
    ``args.eps_train_final`` over the first 1M steps and stays at
    ``args.eps_train_final`` afterwards.

    ``x`` is multiplied by ``args.collect_per_step * args.step_per_epoch``
    to obtain the step count used by the schedule (presumably ``x`` is the
    epoch index passed in by the trainer -- confirm against the caller).
    """
    now = x * args.collect_per_step * args.step_per_epoch
    if now <= 1e6:
        # Linear interpolation between the initial and final epsilon.
        eps = args.eps_train - now / 1e6 * \
            (args.eps_train - args.eps_train_final)
    else:
        eps = args.eps_train_final
    # Apply the value once, after the branch, so the epsilon set on the
    # policy and the epsilon that gets logged can never diverge.
    policy.set_eps(eps)
    # NOTE(review): `writer` comes from the enclosing scope; presumably a
    # tensorboard SummaryWriter -- confirm.
    writer.add_scalar('train/eps', eps, global_step=now)
114116
115117 def test_fn (x ):
116118 policy .set_eps (args .eps_test )
0 commit comments