@@ -518,9 +518,12 @@ def add_algo_arguments(parser):
518518 reset_arg (parser , "--seed" , type = int , default = 1234 )
519519 reset_arg (parser , "--clip-grad" , type = float , default = 1.0 )
520520 reset_arg (parser , "--calculate-per-token-loss" , action = "store_true" )
521+ reset_arg (parser , "--lr" , type = float , default = 1e-6 )
521522
523+ parser .add_argument ("--num-critic-only-steps" , type = int , default = 0 , help = "Number of critic only steps" )
522524 parser .add_argument ("--critic-load" , type = str , default = None , help = "The checkpoint for critic model." )
523525 parser .add_argument ("--critic-save" , type = str , default = None , help = "The checkpoint for critic model." )
526+ parser .add_argument ("--critic-lr" , type = float , default = None , help = "The lr for critic model" )
524527
525528 parser .add_argument ("--eps-clip" , type = float , default = 0.2 , help = "PPO clip range" )
526529 parser .add_argument ("--eps-clip-high" , type = float , default = None , help = "PPO clip upper range" )
@@ -984,9 +987,6 @@ def slime_validate_args(args):
984987 args .ckpt_step = args .ref_ckpt_step
985988 args .start_rollout_id = 0
986989
987- if args .critic_load is None :
988- args .critic_load = args .load
989-
990990 if args .eval_interval is not None :
991991 assert args .eval_prompt_data is not None , "eval_prompt_data must be set when eval_interval is set"
992992 if len (args .eval_prompt_data ) == 1 :
@@ -1032,6 +1032,10 @@ def slime_validate_args(args):
10321032 args .critic_num_gpus_per_node = args .actor_num_gpus_per_node
10331033 if args .critic_num_nodes is None :
10341034 args .critic_num_nodes = args .actor_num_nodes
1035+ if args .critic_load is None :
1036+ args .critic_load = args .load
1037+ if args .critic_lr is None :
1038+ args .critic_lr = args .lr
10351039
10361040 if args .debug_rollout_only :
10371041 if args .colocate and args .rollout_num_gpus is None :
0 commit comments