diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py
index 5b8f8acb5a60..e788321b49bd 100644
--- a/src/transformers/integrations/tensor_parallel.py
+++ b/src/transformers/integrations/tensor_parallel.py
@@ -61,6 +61,22 @@ def _blocks_to_block_sizes(total_size: int, blocks: Union[int, List[int]]) -> Li
         return [single_size] * blocks
 
 
+def _get_parameter_tp_plan(parameter_name: str, tp_plan: dict[str, str]) -> Optional[str]:
+    """
+    Get the TP style for a parameter from the TP plan.
+
+    The TP plan is a dictionary that maps parameter names to TP styles.
+    The parameter name can be a generic name with wildcards (e.g. "*.weight") or a specific name (e.g. "layer_1.weight").
+    """
+    generic_param_name = re.sub(r"\d+", "*", parameter_name)
+    if generic_param_name in tp_plan:
+        return tp_plan[generic_param_name]
+    elif "." in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
+        return tp_plan[generic_param_name.rsplit(".", 1)[0]]
+    else:
+        return None
+
+
 str_to_torch_dtype = {
     "BOOL": torch.bool,
     "U8": torch.uint8,
@@ -138,6 +154,71 @@ def get_packed_weights(param, empty_param, device_mesh, rank, dim):
     return tensor.to(str_to_torch_dtype[slice_dtype])
 
 
+def repack_weights(
+    packed_parameter: torch.Tensor,
+    sharded_dim: int,  # The dimension index in the global tensor that was sharded
+    world_size: int,
+    num_blocks: int = 2,
+) -> torch.Tensor:
+    """
+    Reorders a tensor that was reconstructed from sharded packed weights into its canonical packed format.
+
+    For example, if a weight was packed (e.g., gate_proj and up_proj) and then sharded,
+    DTensor.full_tensor() might produce an interleaved layout like [G0, U0, G1, U1, ...]
+    along the sharded dimension. This function reorders it to [G0, G1, ..., U0, U1, ...].
+    This is the inverse operation of `get_packed_weights`.
+
+    Args:
+        packed_parameter: The tensor reconstructed from a DTensor (e.g., via `.full_tensor().contiguous()`).
+        sharded_dim: The dimension index in `packed_parameter` that was originally sharded.
+        world_size: The tensor parallel world size.
+        num_blocks: The number of projections that were packed together (e.g., 2 for gate_up_proj).
+
+    Returns:
+        The reordered tensor in canonical packed format.
+    """
+
+    if num_blocks != 2:
+        raise ValueError(
+            "`num_blocks` different from 2 is not supported yet. This is most likely a bug in your implementation, as we only pack gate and up projections together."
+        )
+
+    actual_sharded_dim = sharded_dim if sharded_dim >= 0 else sharded_dim + packed_parameter.ndim
+    total_size_on_sharded_dim = packed_parameter.shape[actual_sharded_dim]
+    original_block_size_on_dim = total_size_on_sharded_dim // num_blocks
+    shard_chunk_size = original_block_size_on_dim // world_size
+
+    prefix_shape = packed_parameter.shape[:actual_sharded_dim]
+    suffix_shape = packed_parameter.shape[actual_sharded_dim + 1 :]
+
+    tensor_view = packed_parameter.view(
+        *prefix_shape,
+        world_size,
+        num_blocks,
+        shard_chunk_size,
+        *suffix_shape,
+    )
+
+    # Permute to bring num_blocks first, then world_size, then shard_chunk_size.
+    # This groups all chunks of G together, then all chunks of U together.
+    # Target order of these middle dimensions: (num_blocks, world_size, shard_chunk_size)
+    # Current order of the view's middle dimensions: (world_size, num_blocks, shard_chunk_size)
+    # Absolute indices of the dimensions to be permuted (world_size, num_blocks)
+    axis_ws_abs = len(prefix_shape)
+    axis_npp_abs = len(prefix_shape) + 1
+
+    permute_order = list(range(tensor_view.ndim))
+    permute_order[axis_ws_abs], permute_order[axis_npp_abs] = permute_order[axis_npp_abs], permute_order[axis_ws_abs]
+
+    tensor_permuted = tensor_view.permute(*permute_order)
+
+    # Reshape back to the original tensor's ndim, with the sharded dimension now correctly ordered as [G_all, U_all].
+    # The final shape is the same as that of `packed_parameter`.
+    final_ordered_tensor = tensor_permuted.reshape_as(packed_parameter)
+
+    return final_ordered_tensor
+
+
 def get_tensor_shard(param, empty_param, device_mesh, rank, dim):
     if dim == 0:
         size_ = empty_param.shape[0]
@@ -578,6 +659,49 @@ def translate_to_torch_parallel_style(style: str):
         raise ValueError(f"Unsupported parallel style value: {style}")
 
 
+def convert_local_tensor_to_dtensor(
+    parameter: torch.Tensor, parameter_name: str, device_mesh, tp_plan: dict[str, str]
+) -> torch.Tensor:
+    """
+    Converts a locally sharded tensor into a DTensor with the corresponding placements. This should only be done right before saving the model.
+    """
+    param_type = parameter_name.rsplit(".", 1)[1] if "." in parameter_name else parameter_name
+    tp_style = _get_parameter_tp_plan(parameter_name, tp_plan)
+    if not tp_style:
+        return parameter
+
+    if tp_style not in ["local_packed_rowwise", "local_rowwise", "local_colwise"]:
+        return parameter
+    # TODO: this logic should be wrapped in a function; it is copied from the corresponding TP classes.
+    if tp_style == "local_packed_rowwise":
+        placements = [Shard(-1)]
+    elif tp_style == "local_rowwise":
+        if param_type == "bias":
+            placements = [Replicate()]
+        else:
+            placements = [Shard(-1)]
+    elif tp_style == "local_colwise":
+        if param_type == "bias":
+            placements = [Shard(-1)]
+        else:
+            placements = [Shard(-2)]
+    return DTensor.from_local(parameter, device_mesh, placements, run_check=False)
+
+
+def replace_state_dict_local_with_dtensor(
+    state_dict: dict[str, torch.Tensor],
+    tp_plan: dict[str, str],
+    device_mesh,
+) -> dict[str, torch.Tensor]:
+    """
+    Replaces all tensors that were sharded with a `local_*` strategy by equivalently sharded DTensors, so that their full size can be determined.
+    """
+    for key, value in state_dict.items():
+        if isinstance(value, torch.Tensor) and not isinstance(value, DTensor):
+            state_dict[key] = convert_local_tensor_to_dtensor(value, key, device_mesh, tp_plan)
+    return state_dict
+
+
 def add_tensor_parallel_hooks_to_module(model, module, tp_plan, layer_name, current_module_plan, device_mesh):
     """
     Add hooks to the module holding the layer. Meaning:
@@ -632,13 +756,9 @@ def shard_and_distribute_module(
     param_name, param_type = parameter_name.rsplit(".", 1) if "." in parameter_name else parameter_name
     tp_plan = model._tp_plan
     module_to_tp = model.get_submodule(param_name)
-    current_module_plan = None
     rank = int(rank)
-    generic_param_name = re.sub(r"\d+", "*", parameter_name)
-    if generic_param_name in tp_plan:
-        current_module_plan = tp_plan[generic_param_name]
-    elif "." 
in generic_param_name and generic_param_name.rsplit(".", 1)[0] in tp_plan:
-        current_module_plan = tp_plan[generic_param_name.rsplit(".", 1)[0]]
+
+    current_module_plan = _get_parameter_tp_plan(parameter_name, tp_plan)
 
     # Add hooks to the module if not done yet
     # add_tensor_parallel_hooks_to_module(model, module_to_tp, tp_plan, param_name, current_module_plan, device_mesh)
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 6e4c631d481d..b40b7cd2b388 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -63,6 +63,9 @@
 from .integrations.sdpa_attention import sdpa_attention_forward
 from .integrations.tensor_parallel import (
     SUPPORTED_TP_STYLES,
+    _get_parameter_tp_plan,
+    repack_weights,
+    replace_state_dict_local_with_dtensor,
     shard_and_distribute_module,
     verify_tp_plan,
 )
@@ -123,6 +126,7 @@
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
 from .utils.import_utils import (
     ENV_VARS_TRUE_VALUES,
+    is_huggingface_hub_greater_or_equal,
     is_sagemaker_mp_enabled,
     is_torch_fx_proxy,
     is_torchdynamo_compiling,
@@ -168,6 +172,9 @@
 _is_ds_init_called = False
 _torch_distributed_available = torch.distributed.is_available()
 
+if _torch_distributed_available and is_torch_greater_or_equal("2.5"):
+    from torch.distributed.tensor import DTensor
+
 
 def is_fsdp_enabled():
     return (
@@ -3413,6 +3420,12 @@ def save_pretrained(
         if safe_serialization and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
 
+        # we need to check against tp_size, not tp_plan, as tp_plan is overwritten with the class-level plan
+        if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
+            raise ImportError(
+                "Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
+            )
+
         if os.path.isfile(save_directory):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
@@ -3540,6 +3553,10 @@
         # Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
         # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
         state_dict = self._fix_state_dict_keys_on_save(state_dict)
+        # If the model was sharded, we cannot properly determine the sizes of tensors sharded with a `local_*` strategy,
+        # therefore we replace them with equivalently sharded DTensors before saving
+        if self._tp_size is not None:
+            state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)
 
         if safe_serialization:
             # Safetensors does not allow tensor aliasing.
@@ -3548,7 +3565,7 @@
             for name, tensor in state_dict.items():
                 # Sometimes in the state_dict we have non-tensor objects.
                 # e.g. 
in bitsandbytes we have some `str` objects in the state_dict - if isinstance(tensor, torch.Tensor): + if isinstance(tensor, torch.Tensor) or isinstance(tensor, DTensor): ptrs[id_tensor_storage(tensor)].append(name) else: # In the non-tensor case, fall back to the pointer of the object itself @@ -3658,7 +3675,14 @@ def save_pretrained( for shard_file, tensors in filename_to_tensors: shard = {} for tensor in tensors: - shard[tensor] = state_dict[tensor].contiguous() + if isinstance(state_dict[tensor], DTensor): + full_tensor = state_dict[tensor].full_tensor() + # to get the correctly ordered tensor we need to repack if packed + if _get_parameter_tp_plan(tensor, self._tp_plan) in ("local_packed_rowwise",): + full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2) + shard[tensor] = full_tensor.contiguous() # only do contiguous after it's permuted correctly + else: + shard[tensor] = state_dict[tensor].contiguous() # delete reference, see https://github.com/huggingface/transformers/pull/34890 del state_dict[tensor] @@ -4606,6 +4630,7 @@ def _assign_original_dtype(module): # record tp degree the model sharded to model._tp_size = tp_size + model._device_mesh = device_mesh # make sure token embedding weights are still tied if needed model.tie_weights() diff --git a/src/transformers/pytorch_utils.py b/src/transformers/pytorch_utils.py index b2fe91253576..ca60a05a2655 100644 --- a/src/transformers/pytorch_utils.py +++ b/src/transformers/pytorch_utils.py @@ -296,6 +296,13 @@ def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]: guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id. """ + if is_torch_greater_or_equal_than_2_0: + from torch.distributed.tensor import DTensor + + if isinstance(tensor, DTensor): + local_tensor = tensor.to_local() + return tensor.device, local_tensor.storage().data_ptr(), tensor.nbytes + if tensor.device.type == "xla" and is_torch_xla_available(): # NOTE: xla tensors dont have storage # use some other unique id to distinguish. diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 7837278f6f3d..dd0a2a02d1f1 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -97,6 +97,7 @@ is_grokadamw_available, is_hadamard_available, is_hqq_available, + is_huggingface_hub_greater_or_equal, is_ipex_available, is_jieba_available, is_jinja_available, @@ -542,6 +543,21 @@ def decorator(test_case): return decorator +def require_huggingface_hub_greater_or_equal(version: str): + """ + Decorator marking a test that requires huggingface_hub version >= `version`. + + These tests are skipped when huggingface_hub version is less than `version`. + """ + + def decorator(test_case): + return unittest.skipUnless( + is_huggingface_hub_greater_or_equal(version), f"test requires huggingface_hub version >= {version}" + )(test_case) + + return decorator + + def require_flash_attn(test_case): """ Decorator marking a test that requires Flash Attention. 
diff --git a/src/transformers/utils/__init__.py b/src/transformers/utils/__init__.py index 480e74c69e01..957662f92011 100644 --- a/src/transformers/utils/__init__.py +++ b/src/transformers/utils/__init__.py @@ -167,6 +167,7 @@ is_habana_gaudi1, is_hadamard_available, is_hqq_available, + is_huggingface_hub_greater_or_equal, is_in_notebook, is_ipex_available, is_jieba_available, diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 3c0079459b10..58b8f9fad9f2 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -1077,6 +1077,19 @@ def is_torch_greater_or_equal(library_version: str, accept_dev: bool = False): return version.parse(importlib.metadata.version("torch")) >= version.parse(library_version) +@lru_cache() +def is_huggingface_hub_greater_or_equal(library_version: str, accept_dev: bool = False): + if not _is_package_available("huggingface_hub"): + return False + + if accept_dev: + return version.parse( + version.parse(importlib.metadata.version("huggingface_hub")).base_version + ) >= version.parse(library_version) + else: + return version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(library_version) + + def is_torchdistx_available(): return _torchdistx_available diff --git a/tests/tensor_parallel/test_tensor_parallel.py b/tests/tensor_parallel/test_tensor_parallel.py index 7276869d7642..a8ca73263587 100644 --- a/tests/tensor_parallel/test_tensor_parallel.py +++ b/tests/tensor_parallel/test_tensor_parallel.py @@ -12,14 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import subprocess import tempfile import textwrap from transformers import is_torch_available +from transformers.integrations.tensor_parallel import get_packed_weights, repack_weights from transformers.testing_utils import ( TestCasePlus, get_torch_dist_unique_port, + require_huggingface_hub_greater_or_equal, require_torch_multi_gpu, ) @@ -28,19 +31,51 @@ import torch +class TestTensorParallelUtils(TestCasePlus): + def test_packed_unpacked_conversion(self): + WORLD_SIZE = 2 + PACKED_BLOCK_SIZE = 800 + SHARDING_DIM = 2 + NUM_BLOCKS = 2 + + original_packed_weights = torch.randn(4, 512, 2 * PACKED_BLOCK_SIZE) + original_packed_weights.get_dtype = lambda: "F32" # get_packed_weights expects PySlice object + empty_param = torch.empty(4, 512, 2 * PACKED_BLOCK_SIZE) + + class MockDeviceMesh: + def size(self): + return WORLD_SIZE + + mock_mesh = ( + MockDeviceMesh() + ) # get_packed_weights only calls `.size()`, do this to avoid doing actual distributed run + + packed_weights_0 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 0, SHARDING_DIM) + packed_weights_1 = get_packed_weights(original_packed_weights, empty_param, mock_mesh, 1, SHARDING_DIM) + + # simulate all gather of sharded weights + packed_weights = torch.cat([packed_weights_0, packed_weights_1], dim=SHARDING_DIM) + unpacked_weights = repack_weights(packed_weights, SHARDING_DIM, WORLD_SIZE, NUM_BLOCKS) + + assert torch.allclose(unpacked_weights, original_packed_weights) + + # RUN_SLOW=1 pytest -sv tests/tensor_parallel/test_tensor_parallel.py class TestTensorParallel(TestCasePlus): nproc_per_node = 2 - def torchrun(self, script: str): + def torchrun(self, script: str, is_torchrun: bool = True): """Run the `script` using `torchrun` command for multi-processing in a subprocess. 
Captures errors as necessary.""" with tempfile.NamedTemporaryFile(mode="w+", suffix=".py") as tmp: tmp.write(script) tmp.flush() tmp.seek(0) - cmd = ( - f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}" - ).split() + if is_torchrun: + cmd = ( + f"torchrun --nproc_per_node {self.nproc_per_node} --master_port {get_torch_dist_unique_port()} {tmp.name}" + ).split() + else: + cmd = ["python", tmp.name] # Note that the subprocess will be waited for here, and raise an error if not successful try: @@ -88,6 +123,48 @@ def test_model_forward(self): ) self.torchrun(script_to_run) + @require_huggingface_hub_greater_or_equal("0.31.4") + def test_model_save(self): + from safetensors import safe_open + + with tempfile.TemporaryDirectory() as tmp_dir: + for is_torchrun in [True, False]: + script_to_run = textwrap.dedent( + f""" + import torch + import os + from transformers import AutoModelForCausalLM + + model_id = "JackFram/llama-68m" + kwargs = dict() + + if os.environ.get("RANK", None) is not None: + kwargs["tp_plan"] = "auto" + result_dir = "{tmp_dir}/tp" + else: + result_dir = "{tmp_dir}/nontp" + + model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs) + model.save_pretrained(result_dir) + """ + ) + self.torchrun(script_to_run, is_torchrun=is_torchrun) + + non_tp_model_path = os.path.join(tmp_dir, "nontp") + tp_model_path = os.path.join(tmp_dir, "tp") + + for filename in os.listdir(non_tp_model_path): + if not filename.endswith(".safetensors"): + continue + + non_tp_model = safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt") + tp_model = safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt") + for non_tp_key in non_tp_model.keys(): + non_tp_tensor = non_tp_model.get_tensor(non_tp_key) + tp_tensor = tp_model.get_tensor(non_tp_key) + assert torch.allclose(non_tp_tensor, tp_tensor), f"Tensor with key: {non_tp_key} does not match" + del non_tp_tensor, tp_tensor + @require_torch_multi_gpu class TestTensorParallelCuda(TestTensorParallel):
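
Note on the layout that `repack_weights` undoes: the short, self-contained Python sketch below (illustration only, not part of the patch; the 2-rank setup, block size, and variable names are made up) builds the interleaved order that gathering `local_packed_rowwise` shards produces, then restores the canonical packed order with the same view/permute/reshape steps used in the patch.

import torch

world_size = 2   # tensor-parallel degree (assumed for the example)
block = 4        # per-projection size along the packed dimension (made up)
num_blocks = 2   # gate and up packed together
chunk = block // world_size

# Canonical packed layout along the last dim: [G0 G1 G2 G3 | U0 U1 U2 U3]
canonical = torch.arange(num_blocks * block, dtype=torch.float32).unsqueeze(0)

# Each rank holds [G_rank | U_rank], so gathering the shards along the same dim
# yields the interleaved layout [G_rank0 | U_rank0 | G_rank1 | U_rank1].
per_rank = [
    torch.cat(
        [canonical[:, b * block + r * chunk : b * block + (r + 1) * chunk] for b in range(num_blocks)],
        dim=-1,
    )
    for r in range(world_size)
]
interleaved = torch.cat(per_rank, dim=-1)

# The reordering: view the sharded dim as (world_size, num_blocks, chunk),
# swap the first two factor axes, then flatten back to the original shape.
restored = interleaved.view(1, world_size, num_blocks, chunk).permute(0, 2, 1, 3).reshape_as(interleaved)

assert torch.equal(restored, canonical)

In `repack_weights` the same three steps are applied around the originally sharded dimension, with `prefix_shape` and `suffix_shape` carrying any leading or trailing dimensions through the view.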