@@ -77,6 +77,7 @@ def __init__(self,
7777 self .padding_free = self .template .padding_free
7878 self .task_type = self .template .task_type
7979 self .problem_type = getattr (model .config , 'problem_type' , None )
80+ self .optimizer_callback = optimizers_map [args .optimizer or 'default' ](args , self )
8081 if args .check_model and hasattr (model , 'model_dir' ):
8182 with ms_logger_context (logging .CRITICAL ), patch_modelscope_hub_timeout ():
8283 config_info = self ._collect_config_info ()
@@ -983,8 +984,17 @@ def create_loss_and_eval_metric(self, args):
983984 return res
984985
def create_optimizer_and_scheduler(self, num_training_steps: int):
    """Delegate optimizer and scheduler setup to ``self.optimizer_callback``.

    Args:
        num_training_steps: Forwarded unchanged to the callback's
            ``create_optimizer_and_scheduler``.

    Note:
        Whatever the callback returns is discarded; this method itself
        returns ``None``.
    """
    callback = self.optimizer_callback
    callback.create_optimizer_and_scheduler(num_training_steps)
988+
def create_optimizer(self):
    """Create the optimizer through ``self.optimizer_callback``.

    The result is stored on ``self.optimizer``. When the callback yields
    an optimizer, every entry of its ``param_groups`` whose ``'params'``
    collection is empty is removed before returning.

    Returns:
        The object stored on ``self.optimizer`` (``None`` if the callback
        returned ``None``).
    """
    self.optimizer = self.optimizer_callback.create_optimizer()
    if self.optimizer is None:
        # Nothing to prune when the callback produced no optimizer.
        return self.optimizer
    populated_groups = [
        group for group in self.optimizer.param_groups if len(group['params']) > 0
    ]
    self.optimizer.param_groups = populated_groups
    return self.optimizer
994+
def create_scheduler(self, num_training_steps: int, optimizer=None):
    """Create the learning-rate scheduler through ``self.optimizer_callback``.

    Args:
        num_training_steps: Forwarded unchanged to the callback.
        optimizer: Forwarded unchanged to the callback (``None`` by
            default; how ``None`` is handled is up to the callback).

    Returns:
        The callback's result, which is also stored on ``self.lr_scheduler``.
    """
    scheduler = self.optimizer_callback.create_scheduler(num_training_steps, optimizer)
    self.lr_scheduler = scheduler
    return self.lr_scheduler
988998
989999 @staticmethod
9901000 def _get_listwise_reranker_preds (logits , labels ):
0 commit comments