src/accelerate/accelerator.py: 17 changes (15 additions & 2 deletions)
```diff
@@ -1570,7 +1570,13 @@ def prepare(self, *args, device_placement=None):
         return result if len(result) > 1 else result[0]
 
     def _prepare_tp(self, *args):
-        result = list(args)
+        # First pass: prepare everything except schedulers (and model, which is prepared separately below)
+        result = [
+            self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args
+        ]
+
+        # Second pass: prepare schedulers
+        result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]
 
         device_mesh = self.torch_device_mesh
 
```
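The split into two passes matters because a scheduler holds a reference to its optimizer: the optimizer has to be prepared on the first pass (`first_pass=True`) before the scheduler is prepared against it on the second. Below is a minimal, runnable sketch of that ordering; `prepare_stub` is a hypothetical stand-in for `Accelerator._prepare_one`, not accelerate's actual API.

```python
import torch

def prepare_stub(obj, first_pass=False):
    # Hypothetical stand-in for _prepare_one: the first pass handles
    # optimizers (and dataloaders), the second pass handles schedulers.
    if first_pass:
        if isinstance(obj, torch.optim.Optimizer):
            print("pass 1: preparing optimizer")
        return obj
    if isinstance(obj, torch.optim.lr_scheduler.LRScheduler):
        print("pass 2: preparing scheduler (its optimizer is already prepared)")
    return obj

model = torch.nn.Linear(8, 8)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)

# Mirrors the two comprehensions in the diff: nn.Module instances pass
# through untouched, schedulers wait for the second pass.
result = [o if isinstance(o, torch.nn.Module) else prepare_stub(o, first_pass=True)
          for o in (model, optimizer, scheduler)]
result = [o if isinstance(o, torch.nn.Module) else prepare_stub(o) for o in result]
```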
```diff
@@ -1617,7 +1623,14 @@ def _get_tensor_address(p):
                 # so that the optimizer can correctly update the model parameters.
                 param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
 
-        return args
+        for item in result:
+            if any(
+                item in container
+                for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
+            ):
+                item._is_accelerate_prepared = True
+
+        return result
 
     def _prepare_cp(self, *args):
         from torch.distributed.tensor.experimental import context_parallel
```
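Returning `result` (instead of the untouched `args`) and tagging each tracked object with `_is_accelerate_prepared` means a repeated `prepare` call can recognize and skip objects it has already handled. A minimal sketch of that guard, assuming `prepare` consults the flag before wrapping (the `wrap` helper below is hypothetical):

```python
import torch

def wrap(obj):
    # Hypothetical wrapping step; accelerate would return e.g. a wrapped
    # optimizer or dataloader here.
    return obj

def prepare_once(obj):
    # Objects already tagged are returned untouched, so calling
    # prepare_once twice on the same object is a no-op the second time.
    if getattr(obj, "_is_accelerate_prepared", False):
        return obj
    prepared = wrap(obj)
    prepared._is_accelerate_prepared = True
    return prepared

opt = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)
assert prepare_once(prepare_once(opt)) is opt  # second call changes nothing
```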