@@ -20,12 +20,14 @@ def get_args():
     parser.add_argument('--task', type=str, default='Pendulum-v0')
     parser.add_argument('--seed', type=int, default=0)
     parser.add_argument('--buffer-size', type=int, default=20000)
-    parser.add_argument('--actor-lr', type=float, default=3e-4)
+    parser.add_argument('--actor-lr', type=float, default=1e-3)
     parser.add_argument('--critic-lr', type=float, default=1e-3)
     parser.add_argument('--il-lr', type=float, default=1e-3)
     parser.add_argument('--gamma', type=float, default=0.99)
     parser.add_argument('--tau', type=float, default=0.005)
     parser.add_argument('--alpha', type=float, default=0.2)
+    parser.add_argument('--auto-alpha', type=int, default=1)
+    parser.add_argument('--alpha-lr', type=float, default=3e-4)
     parser.add_argument('--epoch', type=int, default=5)
     parser.add_argument('--step-per-epoch', type=int, default=24000)
     parser.add_argument('--il-step-per-epoch', type=int, default=500)
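Note: besides raising `--actor-lr` from 3e-4 to 1e-3, this hunk adds two flags for automatic entropy tuning: `--auto-alpha` toggles it and `--alpha-lr` sets the temperature optimizer's learning rate. For orientation, a minimal sketch (placeholder tensors, not SACPolicy's actual internals) of where the coefficient alpha enters the SAC actor objective:

import torch

# Sketch only: alpha weights the entropy bonus in the actor loss.
# log_prob stands in for log pi(a|s), q for the minimum of the twin critics.
alpha = 0.2
log_prob = torch.randn(32, 1)              # placeholder log-density
q = torch.randn(32, 1)                     # placeholder min(Q1, Q2)
actor_loss = (alpha * log_prob - q).mean() # maximize Q, keep entropy high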
@@ -41,7 +43,7 @@ def get_args():
     parser.add_argument('--logdir', type=str, default='log')
     parser.add_argument('--render', type=float, default=0.)
     parser.add_argument('--rew-norm', action="store_true", default=False)
-    parser.add_argument('--n-step', type=int, default=4)
+    parser.add_argument('--n-step', type=int, default=3)
     parser.add_argument(
         '--device', type=str,
         default='cuda' if torch.cuda.is_available() else 'cpu')
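`--n-step` drops from 4 to 3, so the critic target now bootstraps from the target network three transitions ahead instead of four. A minimal runnable sketch of such a target, with placeholder values and ignoring early termination inside the window (the real computation is Tianshou's n-step return helper):

import numpy as np

# 3-step TD target: discounted rewards plus a bootstrapped tail value.
gamma, n = 0.99, 3
r = np.array([1.0, 0.5, 0.2])   # rewards r_t .. r_{t+n-1} (placeholders)
done = 0.0                      # episode-end flag within the window
q_next = 10.0                   # target-critic value at s_{t+n} (placeholder)
target = sum(gamma**i * r[i] for i in range(n)) \
    + gamma**n * (1.0 - done) * q_next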
@@ -85,6 +87,13 @@ def test_sac_with_il(args=get_args()):
         concat=True, device=args.device)
     critic2 = Critic(net_c2, device=args.device).to(args.device)
     critic2_optim = torch.optim.Adam(critic2.parameters(), lr=args.critic_lr)
+
+    if args.auto_alpha:
+        target_entropy = -np.prod(env.action_space.shape)
+        log_alpha = torch.zeros(1, requires_grad=True, device=args.device)
+        alpha_optim = torch.optim.Adam([log_alpha], lr=args.alpha_lr)
+        args.alpha = (target_entropy, log_alpha, alpha_optim)
+
     policy = SACPolicy(
         actor, actor_optim, critic1, critic1_optim, critic2, critic2_optim,
         tau=args.tau, gamma=args.gamma, alpha=args.alpha,
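The new block wires up automatic entropy tuning: `target_entropy` uses the common heuristic of minus the action dimensionality, and `SACPolicy` accepts the `(target_entropy, log_alpha, alpha_optim)` tuple in place of a float alpha. A sketch of the temperature update this implies, with placeholder values (the exact code lives inside `SACPolicy.learn()`):

import torch

# Push the policy's entropy (-log_prob) toward target_entropy by
# adjusting log_alpha; alpha itself is recovered by exponentiation.
target_entropy = -1.0                        # -np.prod(action_shape) for Pendulum
log_alpha = torch.zeros(1, requires_grad=True)
alpha_optim = torch.optim.Adam([log_alpha], lr=3e-4)
log_prob = torch.randn(32, 1)                # placeholder sampled-action log-density

alpha_loss = -(log_alpha * (log_prob + target_entropy).detach()).mean()
alpha_optim.zero_grad()
alpha_loss.backward()
alpha_optim.step()
alpha = log_alpha.detach().exp()             # value used in the next actor update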
@@ -135,18 +144,20 @@ def stop_fn(mean_rewards):
         args.action_shape, max_action=args.max_action, device=args.device
     ).to(args.device)
     optim = torch.optim.Adam(net.parameters(), lr=args.il_lr)
-    il_policy = ImitationPolicy(net, optim, mode='continuous')
+    il_policy = ImitationPolicy(
+        net, optim, mode='continuous', action_space=env.action_space,
+        action_scaling=True, action_bound_method="clip")
     il_test_collector = Collector(
         il_policy,
-        DummyVectorEnv(
-            [lambda: gym.make(args.task) for _ in range(args.test_num)])
+        DummyVectorEnv([lambda: gym.make(args.task) for _ in range(args.test_num)])
     )
     train_collector.reset()
     result = offpolicy_trainer(
         il_policy, train_collector, il_test_collector, args.epoch,
         args.il_step_per_epoch, args.step_per_collect, args.test_num,
         args.batch_size, stop_fn=stop_fn, save_fn=save_fn, logger=logger)
     assert stop_fn(result['best_reward'])
+
     if __name__ == '__main__':
         pprint.pprint(result)
         # Let's watch its performance!
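The rewritten `ImitationPolicy` call passes the env's `action_space` and enables `action_scaling` with `action_bound_method="clip"`, so the imitation learner's raw output is bounded and rescaled the same way the SAC policy's actions are. Roughly what that mapping does (a sketch, not the library's exact code):

import numpy as np

# action_bound_method="clip" bounds the raw output to [-1, 1];
# action_scaling=True then maps it onto the Box bounds of the env.
low, high = -2.0, 2.0              # Pendulum-v0 torque bounds
raw_act = np.array([1.3, -0.4])    # placeholder network output
act = np.clip(raw_act, -1.0, 1.0)
act = low + (high - low) * (act + 1.0) / 2.0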