@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> List[int]:
         return [single_size] * blocks


+def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
+    """
+    Get the TP style for a parameter from the TP plan.
+
+    The TP plan is a dictionary that maps parameter names to TP styles.
+    The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
+    """
+    generic_param_name = re.sub(r"\d+", "*", parameter_name)
+    if generic_param_name in tp_plan:
+        return tp_plan[generic_param_name]
+    elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
+        return tp_plan[generic_param_name.rsplit(".", 1)[0]]
+    else:
+        return None
+
+
 str_to_torch_dtype = {
     "BOOL": torch.bool,
     "U8": torch.uint8,
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
     return tensor.to(str_to_torch_dtype[slice_dtype])


+def repack_weights(
+    packed_parameter: torch.Tensor,
+    sharded_dim: int,  # The dimension index in the global tensor that was sharded
+    world_size: int,
+    num_blocks: int = 2,
+) -> torch.Tensor:
+    """
+    Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
+
+    For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
+    DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
+    along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
+    This is the inverse operation to get_packed_weights.
+
+    Args:
+        packed_parameter: The tensor reconstructed from DTensor (e.g., via .full_tensor().contiguous()).
+        sharded_dim: The dimension index in packed_parameter that was originally sharded.
+        world_size: The tensor parallel world size.
+        num_blocks: The number of projections that were packed together (e.g., 2 for gate_up_proj).
+
+    Returns:
+        The reordered tensor in canonical packed format.
+    """
+
+    if num_blocks != 2:
+        raise ValueError(
+            "Num blocks different from 2 is not supported yet. This is most likely a bug in your implementation, as we only pack gate and up projections together."
+        )
+
+    actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
+    total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
+    original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
+    shard_chunk_size = original_block_size_on_dim // world_size
+
+    prefix_shape = packed_parameter.shape[:actual_sharded_dim]
+    suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
+
+    tensor_view = packed_parameter.view(
+        *prefix_shape,
+        world_size,
+        num_blocks,
+        shard_chunk_size,
+        *suffix_shape,
+    )
+
+    # Permute to bring num_blocks first, then world_size, then shard_chunk_size.
+    # This groups all chunks of G together, then all chunks of U together.
+    # Target order of these middle dimensions: (num_blocks, world_size, shard_chunk_size)
+    # Current order of the view's middle dimensions: (world_size, num_blocks, shard_chunk_size)
+    # Absolute indices of the dimensions to be swapped (world_size, num_blocks):
+    axis_ws_abs = len(prefix_shape)
+    axis_npp_abs = len(prefix_shape) + 1
+
+    permute_order = list(range(tensor_view.ndim))
+    permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
+
+    tensor_permuted = tensor_view.permute(*permute_order)
+
+    # Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
+    # The final shape is the same as packed_parameter's.
+    final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
+
+    return final_ordered_tensor
+
+
 def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     if dim == 0:
         size_ = empty_param.shape[0]
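
A small sanity-check sketch of the reordering performed by `repack_weights`, assuming the function above is importable from this module; the toy values are chosen so the interleaved per-rank layout is easy to see:

```python
import torch

world_size, rows_per_block, hidden = 2, 4, 3
gate = torch.arange(rows_per_block * hidden).reshape(rows_per_block, hidden)        # rows G0..G3
up = torch.arange(rows_per_block * hidden).reshape(rows_per_block, hidden) + 100    # rows U0..U3

# Layout obtained by gathering per-rank shards of a packed gate_up weight sharded on dim 0:
# rank 0 holds [G0, G1, U0, U1], rank 1 holds [G2, G3, U2, U3].
interleaved = torch.cat([gate[:2], up[:2], gate[2:], up[2:]], dim=0)

repacked = repack_weights(interleaved, sharded_dim=0, world_size=world_size, num_blocks=2)

# Canonical packed layout: all gate rows first, then all up rows.
assert torch.equal(repacked, torch.cat([gate, up], dim=0))
```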
@@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
         raise ValueError(f"Unsupported parallel style value: {style}")


+def convert_local_tensor_to_dtensor(
+    parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
+) -> DTensor:
+    """
+    Converts a local variant of weights to a DTensor with corresponding placements. This should only ever be done right before saving the model.
+    """
+    _, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
+    tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
+    if not tp_style:
+        return parameter
+
+    if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
+        return parameter
+    # TODO: this logic should be wrapped in a function; it is copied from the corresponding TP classes.
+    if tp_style == "local_packed_rowwise":
+        placements = [Shard(-1)]
+    elif tp_style == "local_rowwise":
+        if param_type == "bias":
+            placements = [Replicate()]
+        else:
+            placements = [Shard(-1)]
+    elif tp_style == "local_colwise":
+        if param_type == "bias":
+            placements = [Shard(-1)]
+        else:
+            placements = [Shard(-2)]
+    return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
+
+
+def replace_state_dict_local_with_dtensor(
+    state_dict: dict[str, torch.Tensor],
+    tp_plan: dict[str, str],
+    device_mesh,
+) -> dict[str, torch.Tensor]:
+    """
+    Replaces all tensors that were sharded with a `local_*` strategy with DTensors, so that their proper (global) size can be determined.
+    """
+    for key, value in state_dict.items():
+        if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
+            state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
+    return state_dict
+
+
 def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
     """
     Add hooks to the module holding the layer. Meaning:
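
A hedged sketch of how the two helpers above are meant to be used right before serialization. It assumes an initialized process group, a tensor-parallel model exposing `model._tp_plan`, and a 1-D tensor-parallel device mesh; the mesh construction shown here is illustrative, not prescribed by this diff:

```python
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

# Illustrative 1-D TP mesh over all ranks (assumes dist.init_process_group already ran).
device_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

state_dict = model.state_dict()  # plain local shards for parameters sharded with a "local_*" style
state_dict = replace_state_dict_local_with_dtensor(state_dict, model._tp_plan, device_mesh)
# Parameters whose plan entry is "local_packed_rowwise", "local_rowwise" or "local_colwise"
# are now DTensors, so their global shape is recoverable when the checkpoint is written.
```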
@@ -632,13 +756,9 @@ def shard_and_distribute_module(
     param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
     tp_plan = model._tp_plan
     module_to_tp = model.get_submodule(param_name)
-    current_module_plan = None
     rank = int(rank)
-    generic_param_name = re.sub(r"\d+", "*", parameter_name)
-    if generic_param_name in tp_plan:
-        current_module_plan = tp_plan[generic_param_name]
-    elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
-        current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
+
+    current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)

     # Add hooks to the module if not done yet
     # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)