Commit bce8b12

Fix accelerate init (#116)

1 parent c5a6783

trainer/trainer.py

Lines changed: 10 additions & 14 deletions

@@ -620,32 +620,28 @@ def init_accelerate(model, optimizer, training_dataloader, scheduler, grad_accum_steps):
     _precision = "bf16"
     accelerator = Accelerator(gradient_accumulation_steps=grad_accum_steps, mixed_precision=_precision)
     if isinstance(model, torch.nn.Module):
-        model = accelerator.prepare(model)
+        model = accelerator.prepare_model(model)
 
-    if isinstance(optimizer, torch.optim.Optimizer):
-        optimizer = accelerator.prepare(optimizer)
-    elif isinstance(optimizer, dict):
+    if isinstance(optimizer, dict):
         for key, optim in optimizer.items():
-            optimizer[key] = accelerator.prepare(optim)
+            optimizer[key] = accelerator.prepare_optimizer(optim)
     elif isinstance(optimizer, list):
         for i, optim in enumerate(optimizer):
-            optimizer[i] = accelerator.prepare(optim)
+            optimizer[i] = accelerator.prepare_optimizer(optim)
     elif optimizer is not None:
-        raise ValueError("Optimizer must be a dict, list or torch.optim.Optimizer")
+        optimizer = accelerator.prepare_optimizer(optimizer)
 
     if isinstance(training_dataloader, torch.utils.data.DataLoader):
-        training_dataloader = accelerator.prepare(training_dataloader)
+        training_dataloader = accelerator.prepare_data_loader(training_dataloader)
 
-    if isinstance(scheduler, torch.optim.lr_scheduler._LRScheduler):  # pylint:disable=protected-access
-        scheduler = accelerator.prepare(scheduler)
-    elif isinstance(scheduler, dict):
+    if isinstance(scheduler, dict):
         for key, sched in scheduler.items():
-            scheduler[key] = accelerator.prepare(sched)
+            scheduler[key] = accelerator.prepare_scheduler(sched)
     elif isinstance(scheduler, list):
         for i, sched in enumerate(scheduler):
-            scheduler[i] = accelerator.prepare(sched)
+            scheduler[i] = accelerator.prepare_scheduler(sched)
     elif scheduler is not None:
-        raise ValueError("Scheduler must be a dict, list or torch.optim.lr_scheduler._LRScheduler")
+        scheduler = accelerator.prepare_scheduler(scheduler)
 
     return model, optimizer, training_dataloader, scheduler, accelerator

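For reference, a minimal usage sketch of the patched init_accelerate. The toy model, data, and hyperparameters below are hypothetical stand-ins, and the import path simply mirrors the file touched by this commit; none of it comes from the commit itself.

import torch
from torch.utils.data import DataLoader, TensorDataset

from trainer.trainer import init_accelerate  # assumed import path, per the file in the diff

# Illustrative stand-ins: a toy model, one optimizer, one scheduler, one dataloader.
model = torch.nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
dataset = TensorDataset(torch.randn(64, 16), torch.randn(64, 4))
loader = DataLoader(dataset, batch_size=8)

# After this commit, dicts and lists of optimizers/schedulers are prepared
# entry by entry via prepare_optimizer / prepare_scheduler, and any other
# non-None optimizer or scheduler (not just torch.optim subclasses) is
# wrapped the same way instead of raising ValueError.
model, optimizer, loader, scheduler, accelerator = init_accelerate(
    model, optimizer, loader, scheduler, grad_accum_steps=2
)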