
Commit a2dc628

fix tp only bug
Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
1 parent 36479b8 commit a2dc628
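
What the diff below does: the replicate-wrapping loop previously ran unconditionally, so a TP-only run (tensor parallelism without FSDP2) also wrapped every plain parameter in a replicated DTensor. The fix gates that loop behind self.is_fsdp2, hoists the DTensor import out of the loop so the _get_tensor_address helper further down still resolves the name even when the loop is skipped, and returns early when the old-to-new parameter mapping comes back empty.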

1 file changed (+23 −17 lines)


src/accelerate/accelerator.py

Lines changed: 23 additions & 17 deletions
@@ -1588,33 +1588,39 @@ def _prepare_tp(self, *args):
 
         old_named_params = self._get_named_parameters(*tuple(result), drop_refs=True)
 
-        for arg in result:
-            if not isinstance(arg, torch.nn.Module):
-                continue
+        from torch.distributed.tensor import DTensor
 
-            from torch.distributed.tensor import DTensor, Replicate
-            from transformers.integrations.tensor_parallel import ReplicateParallel
+        if self.is_fsdp2:
+            for arg in result:
+                if not isinstance(arg, torch.nn.Module):
+                    continue
 
-            model: torch.nn.Module = arg
-            tp_plan = ReplicateParallel
+                from torch.distributed.tensor import Replicate
+                from transformers.integrations.tensor_parallel import ReplicateParallel
 
-            for name, param in model.named_parameters():
-                if isinstance(param, DTensor):
-                    continue
+                model: torch.nn.Module = arg
+                tp_plan = ReplicateParallel
 
-                dp = DTensor.from_local(param, device_mesh=device_mesh["tp"], placements=[Replicate()])
-                param_name, param_type = name.rsplit(".", 1)
-                module_to_tp = model.get_submodule(param_name)
+                for name, param in model.named_parameters():
+                    if isinstance(param, DTensor):
+                        continue
+
+                    dp = DTensor.from_local(param, device_mesh=device_mesh["tp"], placements=[Replicate()])
+                    param_name, param_type = name.rsplit(".", 1)
+                    module_to_tp = model.get_submodule(param_name)
 
-                tp_plan().prepare_module_tp(module_to_tp, device_mesh["tp"])
-                if not isinstance(dp, torch.nn.Parameter):
-                    dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
-                setattr(module_to_tp, param_type, dp)
+                    tp_plan().prepare_module_tp(module_to_tp, device_mesh["tp"])
+                    if not isinstance(dp, torch.nn.Parameter):
+                        dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
+                    setattr(module_to_tp, param_type, dp)
 
         new_named_params = self._get_named_parameters(*tuple(result), drop_refs=False)
         # Build a map from old to new params
         mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
 
+        if not mapping:
+            return result
+
         def _get_tensor_address(p):
             if isinstance(p, DTensor):
                 return p._local_tensor.data_ptr()
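
For context on the DTensor calls in the added branch, here is a minimal, runnable sketch of the replicate-wrapping step. It is not Accelerate's code: it assumes a single-process gloo group and a 1-rank cpu mesh standing in for device_mesh["tp"], and it omits the transformers ReplicateParallel.prepare_module_tp hook.

# Minimal sketch (not Accelerate's code): wrap plain parameters as DTensors
# replicated over a tensor-parallel mesh, mirroring the loop the commit gates
# behind `self.is_fsdp2`. Single-process gloo group so it runs on CPU.
import os

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Replicate

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# 1-rank stand-in for Accelerate's device_mesh["tp"] slice.
tp_mesh = init_device_mesh("cpu", (1,), mesh_dim_names=("tp",))

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
for name, param in list(model.named_parameters()):
    if isinstance(param, DTensor):
        continue  # already distributed; the patched loop skips these too
    # Re-wrap the local tensor as a DTensor replicated across the tp mesh.
    dp = DTensor.from_local(param, device_mesh=tp_mesh, placements=[Replicate()])
    param_name, param_type = name.rsplit(".", 1)  # e.g. "0", "weight"
    module_to_tp = model.get_submodule(param_name)
    setattr(module_to_tp, param_type, torch.nn.Parameter(dp, requires_grad=param.requires_grad))

for name, param in model.named_parameters():
    print(name, type(param).__name__, param.placements)  # every param is now Replicate()

dist.destroy_process_group()

The is_fsdp2 gate presumably exists because this replicate-wrapping is only needed when composing TP with FSDP2; in a TP-only run the loop is now skipped entirely, and the new early return skips the pointer-remapping logic when no parameters were swapped.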
