@@ -620,32 +620,28 @@ def init_accelerate(model, optimizer, training_dataloader, scheduler, grad_accum
620620 _precision = "bf16"
621621 accelerator = Accelerator (gradient_accumulation_steps = grad_accum_steps , mixed_precision = _precision )
622622 if isinstance (model , torch .nn .Module ):
623- model = accelerator .prepare (model )
623+ model = accelerator .prepare_model (model )
624624
625- if isinstance (optimizer , torch .optim .Optimizer ):
626- optimizer = accelerator .prepare (optimizer )
627- elif isinstance (optimizer , dict ):
625+ if isinstance (optimizer , dict ):
628626 for key , optim in optimizer .items ():
629- optimizer [key ] = accelerator .prepare (optim )
627+ optimizer [key ] = accelerator .prepare_optimizer (optim )
630628 elif isinstance (optimizer , list ):
631629 for i , optim in enumerate (optimizer ):
632- optimizer [i ] = accelerator .prepare (optim )
630+ optimizer [i ] = accelerator .prepare_optimizer (optim )
633631 elif optimizer is not None :
634- raise ValueError ( "Optimizer must be a dict, list or torch.optim.Optimizer" )
632+ optimizer = accelerator . prepare_optimizer ( optimizer )
635633
636634 if isinstance (training_dataloader , torch .utils .data .DataLoader ):
637- training_dataloader = accelerator .prepare (training_dataloader )
635+ training_dataloader = accelerator .prepare_data_loader (training_dataloader )
638636
639- if isinstance (scheduler , torch .optim .lr_scheduler ._LRScheduler ): # pylint:disable=protected-access
640- scheduler = accelerator .prepare (scheduler )
641- elif isinstance (scheduler , dict ):
637+ if isinstance (scheduler , dict ):
642638 for key , sched in scheduler .items ():
643- scheduler [key ] = accelerator .prepare (sched )
639+ scheduler [key ] = accelerator .prepare_scheduler (sched )
644640 elif isinstance (scheduler , list ):
645641 for i , sched in enumerate (scheduler ):
646- scheduler [i ] = accelerator .prepare (sched )
642+ scheduler [i ] = accelerator .prepare_scheduler (sched )
647643 elif scheduler is not None :
648- raise ValueError ( "Scheduler must be a dict, list or torch.optim.lr_scheduler._LRScheduler" )
644+ scheduler = accelerator . prepare_scheduler ( scheduler )
649645
650646 return model , optimizer , training_dataloader , scheduler , accelerator
651647
0 commit comments