Skip to content

Commit 00d6bbc

Browse files
authored
[bugfix] fix optimizer deepspeed (#9173)
1 parent b52b2d4 commit 00d6bbc

2 files changed

Lines changed: 15 additions & 4 deletions

File tree

swift/optimizers/base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from torch.optim import Optimizer
2+
from transformers.trainer import Trainer as HfTrainer
23
from typing import TYPE_CHECKING
34

45
try:
@@ -48,7 +49,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int) -> None:
4849
trainer.scheduler = self.create_scheduler(num_training_steps, trainer.optimizer)
4950

5051
def create_optimizer(self) -> Optimizer:
51-
return self.trainer.create_optimizer()
52+
return HfTrainer.create_optimizer(self.trainer)
5253

5354
def create_scheduler(self, num_training_steps: int, optimizer: Optimizer) -> LRScheduler:
54-
return self.trainer.create_scheduler(num_training_steps, optimizer)
55+
return HfTrainer.create_scheduler(self.trainer, num_training_steps, optimizer)

swift/trainers/mixin.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ def __init__(self,
7777
self.padding_free = self.template.padding_free
7878
self.task_type = self.template.task_type
7979
self.problem_type = getattr(model.config, 'problem_type', None)
80+
self.optimizer_callback = optimizers_map[args.optimizer or 'default'](args, self)
8081
if args.check_model and hasattr(model, 'model_dir'):
8182
with ms_logger_context(logging.CRITICAL), patch_modelscope_hub_timeout():
8283
config_info = self._collect_config_info()
@@ -983,8 +984,17 @@ def create_loss_and_eval_metric(self, args):
983984
return res
984985

985986
def create_optimizer_and_scheduler(self, num_training_steps: int):
986-
optimizer_callback: OptimizerCallback = optimizers_map[self.args.optimizer or 'default'](self.args, self)
987-
optimizer_callback.create_optimizer_and_scheduler(num_training_steps)
987+
self.optimizer_callback.create_optimizer_and_scheduler(num_training_steps)
988+
989+
def create_optimizer(self):
990+
self.optimizer = self.optimizer_callback.create_optimizer()
991+
if self.optimizer is not None:
992+
self.optimizer.param_groups = [pg for pg in self.optimizer.param_groups if len(pg['params']) > 0]
993+
return self.optimizer
994+
995+
def create_scheduler(self, num_training_steps: int, optimizer=None):
996+
self.lr_scheduler = self.optimizer_callback.create_scheduler(num_training_steps, optimizer)
997+
return self.lr_scheduler
988998

989999
@staticmethod
9901000
def _get_listwise_reranker_preds(logits, labels):

0 commit comments

Comments
 (0)