@@ -1588,33 +1588,39 @@ def _prepare_tp(self, *args):
15881588
15891589 old_named_params = self ._get_named_parameters (* tuple (result ), drop_refs = True )
15901590
1591- for arg in result :
1592- if not isinstance (arg , torch .nn .Module ):
1593- continue
1591+ from torch .distributed .tensor import DTensor
15941592
1595- from torch .distributed .tensor import DTensor , Replicate
1596- from transformers .integrations .tensor_parallel import ReplicateParallel
1593+ if self .is_fsdp2 :
1594+ for arg in result :
1595+ if not isinstance (arg , torch .nn .Module ):
1596+ continue
15971597
1598- model : torch .nn . Module = arg
1599- tp_plan = ReplicateParallel
1598+ from torch .distributed . tensor import Replicate
1599+ from transformers . integrations . tensor_parallel import ReplicateParallel
16001600
1601- for name , param in model .named_parameters ():
1602- if isinstance (param , DTensor ):
1603- continue
1601+ model : torch .nn .Module = arg
1602+ tp_plan = ReplicateParallel
16041603
1605- dp = DTensor .from_local (param , device_mesh = device_mesh ["tp" ], placements = [Replicate ()])
1606- param_name , param_type = name .rsplit ("." , 1 )
1607- module_to_tp = model .get_submodule (param_name )
1604+ for name , param in model .named_parameters ():
1605+ if isinstance (param , DTensor ):
1606+ continue
1607+
1608+ dp = DTensor .from_local (param , device_mesh = device_mesh ["tp" ], placements = [Replicate ()])
1609+ param_name , param_type = name .rsplit ("." , 1 )
1610+ module_to_tp = model .get_submodule (param_name )
16081611
1609- tp_plan ().prepare_module_tp (module_to_tp , device_mesh ["tp" ])
1610- if not isinstance (dp , torch .nn .Parameter ):
1611- dp = torch .nn .Parameter (dp , requires_grad = param .requires_grad )
1612- setattr (module_to_tp , param_type , dp )
1612+ tp_plan ().prepare_module_tp (module_to_tp , device_mesh ["tp" ])
1613+ if not isinstance (dp , torch .nn .Parameter ):
1614+ dp = torch .nn .Parameter (dp , requires_grad = param .requires_grad )
1615+ setattr (module_to_tp , param_type , dp )
16131616
16141617 new_named_params = self ._get_named_parameters (* tuple (result ), drop_refs = False )
16151618 # Build a map from old to new params
16161619 mapping = {p : new_named_params [n ] for n , p in old_named_params .items ()}
16171620
1621+ if not mapping :
1622+ return result
1623+
16181624 def _get_tensor_address (p ):
16191625 if isinstance (p , DTensor ):
16201626 return p ._local_tensor .data_ptr ()
0 commit comments