From b40b766dedc8933971739bd370f24bbec676a290 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Tue, 10 Dec 2024 20:10:16 +0800 Subject: [PATCH 01/12] model ckpt 3D support --- .../checkpoint/vescale/_collective_utils.py | 457 +++++++++ internlm/checkpoint/vescale/_utils.py | 463 +++++++++ internlm/checkpoint/vescale/api.py | 53 + .../checkpoint/vescale/base_checkpointer.py | 77 ++ internlm/checkpoint/vescale/bfile.py | 129 +++ internlm/checkpoint/vescale/common.py | 132 +++ internlm/checkpoint/vescale/device_mesh.py | 647 +++++++++++++ internlm/checkpoint/vescale/devicemesh_api.py | 475 +++++++++ .../vescale/distributed_optimizer.py | 109 +++ internlm/checkpoint/vescale/filesystem.py | 905 ++++++++++++++++++ .../checkpoint/vescale/load_state_dict.py | 100 ++ internlm/checkpoint/vescale/mem_checkpoint.py | 396 ++++++++ .../vescale/mem_file_service_pb2.py | 66 ++ .../vescale/mem_file_service_pb2_grpc.py | 321 +++++++ internlm/checkpoint/vescale/mem_server_lib.py | 307 ++++++ internlm/checkpoint/vescale/meta_type.py | 38 + .../checkpoint/vescale/placement_types.py | 563 +++++++++++ .../checkpoint/vescale/save_state_dict.py | 203 ++++ .../vescale/vescale_checkpointer.py | 259 +++++ .../checkpoint/vescale/vescale_planner.py | 268 ++++++ .../vescale/vescale_planner_helpers.py | 288 ++++++ internlm/model/modules/embedding.py | 6 +- internlm/model/modules/linear.py | 20 +- internlm/model/modules/mha.py | 7 +- internlm/train/pipeline.py | 60 +- 25 files changed, 6341 insertions(+), 8 deletions(-) create mode 100644 internlm/checkpoint/vescale/_collective_utils.py create mode 100644 internlm/checkpoint/vescale/_utils.py create mode 100644 internlm/checkpoint/vescale/api.py create mode 100644 internlm/checkpoint/vescale/base_checkpointer.py create mode 100644 internlm/checkpoint/vescale/bfile.py create mode 100644 internlm/checkpoint/vescale/common.py create mode 100644 internlm/checkpoint/vescale/device_mesh.py create mode 100644 internlm/checkpoint/vescale/devicemesh_api.py create mode 100644 internlm/checkpoint/vescale/distributed_optimizer.py create mode 100644 internlm/checkpoint/vescale/filesystem.py create mode 100644 internlm/checkpoint/vescale/load_state_dict.py create mode 100644 internlm/checkpoint/vescale/mem_checkpoint.py create mode 100644 internlm/checkpoint/vescale/mem_file_service_pb2.py create mode 100644 internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py create mode 100644 internlm/checkpoint/vescale/mem_server_lib.py create mode 100644 internlm/checkpoint/vescale/meta_type.py create mode 100644 internlm/checkpoint/vescale/placement_types.py create mode 100644 internlm/checkpoint/vescale/save_state_dict.py create mode 100644 internlm/checkpoint/vescale/vescale_checkpointer.py create mode 100644 internlm/checkpoint/vescale/vescale_planner.py create mode 100644 internlm/checkpoint/vescale/vescale_planner_helpers.py diff --git a/internlm/checkpoint/vescale/_collective_utils.py b/internlm/checkpoint/vescale/_collective_utils.py new file mode 100644 index 000000000..98d43e482 --- /dev/null +++ b/internlm/checkpoint/vescale/_collective_utils.py @@ -0,0 +1,457 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. 
and/or its affiliates. +################################################################################ + +import logging +import math +import copy +from typing import List, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch.distributed.distributed_c10d import ( + GroupMember, + ProcessGroup, + Work, + all_to_all, + get_global_rank, + get_rank, + broadcast, + get_world_size, + scatter, +) + +from .device_mesh import DeviceMesh, mesh_resources +from .placement_types import DTensorSpec +from internlm.utils.logger import get_logger +logger = get_logger(__file__) + +TORCH_VERSION_BIGGER_THAN_2_2 = torch.__version__ >= "2.2" + + +# NOTE: upstream are working to migrate the following three collective +# apis to be functional, pay attention to it. + + +def mesh_scatter( + output: torch.Tensor, + scatter_list: List[torch.Tensor], + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, +) -> Optional[Work]: + """ + scatter a list of tensors to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will + scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank + 2 to rank 2/3. + + Args: + output (torch.Tensor): the tensor to receive the scattered list. + scatter_list (List[torch.Tensor]): the tensor list to be scattered. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Returns: + A :class:`Work` object + """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_scatter, output, scatter_list, mesh_dim, async_op) + + # if rank is not part of mesh, simply return output + if mesh.get_coordinate() is None: + return output + + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. 
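To make the scatter semantics above concrete, here is a minimal, illustrative sketch of the intended call pattern. It is an assumption-laden example: it presumes four participating ranks arranged as a 2x2 mesh, an initialized default process group, CUDA devices, and arbitrary tensor shapes.

```
import torch

from internlm.checkpoint.vescale.device_mesh import DeviceMesh
from internlm.checkpoint.vescale._collective_utils import mesh_scatter

# 2x2 mesh: scattering on mesh_dim=1 uses rank 0 as the source for row [0, 1]
# and rank 2 as the source for row [2, 3].
mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
output = torch.empty(4, 4, device="cuda")
# every rank builds a list, but it is only consumed on the source rank of its row
scatter_list = [torch.randn(4, 4, device="cuda") for _ in range(mesh.size(1))]
mesh_scatter(output, scatter_list, mesh, mesh_dim=1)  # blocking scatter along rows
```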
+ if output.is_meta: + return None + + dim_group = mesh.get_dim_groups(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + # src need to be global rank + src_for_dim = 0 + + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, 0) + + if src_for_dim == get_rank(): + fut = scatter( + output, + scatter_list=scatter_list, + src=src_for_dim, + group=dim_group, + async_op=async_op, + ) + else: + fut = scatter( + output, + scatter_list=None, + src=src_for_dim, + group=dim_group, + async_op=async_op, + ) + + return fut + + +# TODO: test uneven split on GLOO and NCCL + + +def mesh_all_to_all( + output_tensor_list: List[torch.Tensor], + input_tensor_list: List[torch.Tensor], + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, +) -> Optional[Work]: + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_all_to_all, output_tensor_list, input_tensor_list, mesh, mesh_dim, async_op) + + # if rank is not part of mesh, simply return None + if mesh.get_coordinate() is None: + return None + + dim_group = mesh.get_dim_groups(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + work = None + # no direct dist.all_to_all support on 'gloo' so we manually do scatters + if mesh.device_type == "cpu": + logger.warning("ProcessGroupGloo does not support all_to_all, falling back with scatters!") + # TODO: pull the handle of uneven case in #492 + dim_group_size = get_world_size(dim_group) + for i in range(dim_group_size): + # src need to be global rank + src_for_dim = i + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, i) + + work = scatter( + output_tensor_list[i], + input_tensor_list if mesh.get_rank() == src_for_dim else [], + group=dim_group, + src=src_for_dim, + async_op=async_op, + ) + else: + work = all_to_all( + output_tensor_list, + input_tensor_list, + dim_group, + async_op=async_op, + ) + return work + + +def mesh_all_to_all_single( + tensor: torch.Tensor, + mesh: DeviceMesh, + original_shard_dim: int, + target_shard_dim: int, + mesh_dim: int = 0, + async_op: bool = False, +): + """ + transpose the sharded tensor along a device mesh dimension. + + Args: + tensor (torch.Tensor): tensor to all-to-all. + mesh (DeviceMesh): device mesh that communication happens. + original_shard_dim (int): the dim that source tensor is sharded + target_shard_dim (int): the dim that transposed tensor is sharded + mesh_dim (int, optional): indicate which mesh dimension we want + to broadcast on, we by default choose the first rank on the + mesh dimension as source of truth. + async_op (bool, default False): unused arguments. As all-to-all will + always be sync. 
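For intuition, the reshard performed here can be sketched as follows (a hedged example assuming a 1-D mesh of 4 ranks and shapes that divide evenly): a local shard of shape (4, 8) taken from a (16, 8) global tensor sharded on dim 0 comes back as the (16, 2) local shard of the same global tensor sharded on dim 1.

```
import torch

from internlm.checkpoint.vescale.device_mesh import DeviceMesh
from internlm.checkpoint.vescale._collective_utils import mesh_all_to_all_single

mesh = DeviceMesh("cuda", [0, 1, 2, 3])
local = torch.randn(4, 8, device="cuda")  # global tensor (16, 8) sharded on dim 0
out = mesh_all_to_all_single(local, mesh, original_shard_dim=0, target_shard_dim=1)
# local shape is now (16, 2): the same global tensor, sharded on dim 1 instead
```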
+ + Returns: + A :class:`Tensor` object + """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication( + mesh_all_to_all_single, tensor, mesh, original_shard_dim, target_shard_dim, mesh_dim + ) + + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + mesh_size = mesh.size(mesh_dim) + assert tensor.size(target_shard_dim) % mesh_size == 0, "we don't support unvevn shard on ``target_shard_dim``" + input_rank = tensor.ndim + assert input_rank >= 2, "input must has at least 2 ranks" + + target_shape = copy.deepcopy(list(tensor.shape)) + target_shape[original_shard_dim] *= mesh_size + target_shape[target_shard_dim] //= mesh_size + + dim_group = mesh.get_dim_groups(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + if target_shard_dim != 0: + k_new_shape = list(tensor.shape) + k_new_shape[target_shard_dim] //= mesh_size + k_new_shape[0] *= mesh_size + new_shape = list(tensor.shape) + new_shape[target_shard_dim] //= mesh_size + new_shape.insert(target_shard_dim, mesh_size) + indices = ( + [target_shard_dim] + list(range(0, target_shard_dim)) + list(range(target_shard_dim + 1, tensor.ndim + 1)) + ) + tensor = tensor.reshape(new_shape).permute(indices).reshape(k_new_shape) + + output = funcol.all_to_all_single(tensor, output_split_sizes=None, input_split_sizes=None, group=dim_group) + if original_shard_dim == 0: + return output + + n, *out_shape = list(output.shape) + + indices = ( + list(range(1, original_shard_dim)) + + [original_shard_dim, 0] + + list(range(original_shard_dim + 1, output.ndim + 1)) + ) + + return output.reshape(mesh_size, n // mesh_size, *out_shape).permute(indices).reshape(target_shape) + + +def mesh_broadcast( + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op=False, +) -> torch.Tensor: + """ + broadcast the tensor to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will + broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 + to rank 2/3. + + Args: + tensor (torch.Tensor): tensor to broadcast. + mesh_dim (int, optional): indicate which mesh dimension we want + to broadcast on, we by default choose the first rank on the + mesh dimension as source of truth. 
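A short, hedged sketch of the row-wise broadcast described above, assuming the same 2x2 mesh layout and that the calling rank is part of it:

```
import torch

from internlm.checkpoint.vescale.device_mesh import DeviceMesh
from internlm.checkpoint.vescale._collective_utils import mesh_broadcast

mesh = DeviceMesh("cuda", [[0, 1], [2, 3]])
t = torch.randn(4, 4, device="cuda")
# after the call, ranks 0/1 hold rank 0's values and ranks 2/3 hold rank 2's values
t = mesh_broadcast(t, mesh, mesh_dim=1)
```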
+ + Returns: + A :class:`Tensor` object + """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_broadcast, tensor, mesh, mesh_dim, async_op) + + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + + dim_group = mesh.get_dim_groups(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + # src need to be global rank + src_for_dim = 0 + if dim_group is not GroupMember.WORLD: + src_for_dim = get_global_rank(dim_group, 0) + if TORCH_VERSION_BIGGER_THAN_2_2: + aysnc_tensor = funcol.broadcast(tensor, src=src_for_dim, group=dim_group) + if not async_op: + return funcol.wait_tensor(aysnc_tensor) + return aysnc_tensor + else: + work = broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op) + if not async_op: + return tensor + from torch.distributed._functional_collectives_impl import _register_tensor_work + from torch.distributed._functional_collectives import _maybe_wrap_tensor + + _register_tensor_work(tensor, work) + return _maybe_wrap_tensor(tensor) + + +def mesh_reduce_scatter( + tensor: torch.Tensor, mesh: DeviceMesh, reduce_op: c10d.ReduceOp.RedOpType, scatter_dim: int, mesh_dim: int +) -> torch.Tensor: + """ + First peform all_reduce on the tensor, then split the tensor at scatter_dim + and scatter them to a device mesh dimension. + """ + + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_reduce_scatter, tensor, mesh, reduce_op, scatter_dim, mesh_dim) + + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + + # for now, we only support that size at `scatter_dim`` is divisable by + # the mesh size at `mesh_dim` + num_chunks = mesh.size(dim=mesh_dim) + assert ( + tensor.size(scatter_dim) % num_chunks == 0 + ), f"tensor size at {scatter_dim} is not divisable by the mesh size at {mesh_dim}" + output = funcol.reduce_scatter_tensor( + tensor, reduceOp=reduce_op.name, scatter_dim=scatter_dim, group=mesh._dim_group_infos[mesh_dim][1] + ) + return output + + +def mesh_all_gather( + tensor: torch.Tensor, + global_size: torch.Size, + mesh: DeviceMesh, + scatter_dim: int, + mesh_dim: int, +) -> torch.Tensor: + """ + all_gather all shards and return a tensor that is replicated + on the previously sharded mesh dimension + """ + if DebugLogger.IS_DEBUG_MODE: + DebugLogger.log_communication(mesh_all_gather, tensor, global_size, mesh, scatter_dim, mesh_dim) + + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + + # for now, we only support that global size at `scatter_dim` is equal with + # the multuple of mesh size at `mesh_dim` and local_tensor size at `scatter_dim` + num_chunks = mesh.size(dim=mesh_dim) + assert ( + tensor.size(scatter_dim) * num_chunks == global_size[scatter_dim] + ), f"global tensor size at {scatter_dim} is not equal with the multiply of mesh size at {mesh_dim} and local_tensor size at {scatter_dim}" + tensor = tensor.contiguous() + output = funcol.all_gather_tensor(tensor, gather_dim=scatter_dim, group=mesh._dim_group_infos[mesh_dim][1]) + return output + + +def mesh_all_reduce( + tensor: torch.Tensor, + mesh: DeviceMesh, + reduce_op: c10d.ReduceOp.RedOpType, + mesh_dim: int, +) -> torch.Tensor: + # if rank is not part of mesh, simply return tensor, which should be an empty tensor + if mesh.get_coordinate() is None: + return tensor + + return funcol.all_reduce(tensor, 
reduceOp=reduce_op.name, group=mesh._dim_group_infos[mesh_dim][1]) + + +def wait(tensor: torch.Tensor) -> torch.Tensor: + if isinstance(tensor, funcol.AsyncCollectiveTensor): + return funcol.wait_tensor(tensor) + return tensor + + +def spec_to_bytes(spec: DTensorSpec) -> int: + assert spec.tensor_meta is not None, "spec should have tensor meta defined!" + return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) + + +def get_bandwidth_factor(mesh: DeviceMesh) -> List[float]: + # generate bandwidth factor for intra-host/inter-host communication pattern + factors = [1.0] * mesh.ndim + num_devices_per_host = mesh_resources.num_devices_per_host(mesh.device_type) + + num_devices = 1 + for mesh_dim in reversed(range(mesh.ndim)): + num_devices *= mesh.size(mesh_dim) + if num_devices <= num_devices_per_host: + # magic number for intra-host communication bandwidth factor + # TODO: see if we need to tweak this or offer a way for user + # to specify the bandwidths + factors[mesh_dim] = 0.2 + + return factors + + +def allgather_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh.size(mesh_dim) + bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim] + # constant latency factor + bandwidth cost + return 1 + bandwidth_factor * num_bytes * (num_devices_on_mesh_dim - 1) / num_devices_on_mesh_dim + + +def allreduce_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh.size(mesh_dim) + bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim] + # allreduce have 2x comm bytes compare to allgather/reduce_scatter + return 1 + 2 * bandwidth_factor * num_bytes * (num_devices_on_mesh_dim - 1) / num_devices_on_mesh_dim + + +def reduce_scatter_cost( + num_bytes: float, + mesh: DeviceMesh, + mesh_dim: int, +) -> float: + num_devices_on_mesh_dim = mesh.size(mesh_dim) + bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim] + # constant latency factor + bandwidth cost + return 1 + bandwidth_factor * num_bytes * (num_devices_on_mesh_dim - 1) / num_devices_on_mesh_dim + + +def redistribute_cost( + current_spec: DTensorSpec, + target_spec: DTensorSpec, +) -> float: + """ + This function returns the cost of redistribute from current to target DTensorSpec. + + NOTE: + 1. Only consider communication cost here, since computation costs for redistribute + are quite trival (i.e. we only need to narrow or simple division) + 2. Only consider redistribute cost on same mesh, cross mesh communication cost is + not quite needed for operator strategy estimation/selection. + """ + if current_spec.mesh != target_spec.mesh: + # make infinite cost if meshes are not same + # TODO: see if we want to support this once there's cross mesh communication + return float("inf") + + if current_spec.is_replicated(): + # short-cut: + # comm cost is 0 if current spec is already full replication + return 0.0 + + mesh = current_spec.mesh + cost = 0.0 + comm_bytes = spec_to_bytes(current_spec) / current_spec.num_shards + # Transformation that considered for redistribute cost: + # 1. allgather 2. alltoall + # 3. allreduce 4. 
reduce_scatter + for i, (current, target) in enumerate(zip(current_spec.placements, target_spec.placements)): + if current == target: + continue + if current.is_shard() and target.is_replicate(): + # allgather gives larger comm bytes + comm_bytes *= mesh.size(i) + # add up allgather comm cost + cost += allgather_cost(comm_bytes, current_spec.mesh, i) + elif current.is_shard() and target.is_shard(): + # should be alltoall comm, since we haven't implement it yet, add penalty + # to favor allgather instead + cost += allgather_cost(comm_bytes, current_spec.mesh, i) + 1.0 + elif current.is_partial() and target.is_replicate(): + # add up allreduce comm cost + cost += allreduce_cost(comm_bytes, current_spec.mesh, i) + elif current.is_partial() and target.is_shard(): + # add up reduce_scatter comm cost + cost += reduce_scatter_cost(comm_bytes, current_spec.mesh, i) + # after reduce_scatter the comm bytes for further collectives halved. + comm_bytes /= mesh.size(i) + elif current.is_shard() and target.is_partial(): + # ban shard/interleaved_shard -> partial as it does not make sense to perform + # this redistribute + return float("inf") + + return cost diff --git a/internlm/checkpoint/vescale/_utils.py b/internlm/checkpoint/vescale/_utils.py new file mode 100644 index 000000000..0d6f7d8c5 --- /dev/null +++ b/internlm/checkpoint/vescale/_utils.py @@ -0,0 +1,463 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +import warnings +import copy +from typing import List, Sequence, Tuple, Optional, Dict, Set, Union + +import torch +import torch.distributed._functional_collectives as funcol +from torch._prims_common import ShapeType + +from .device_mesh import DeviceMesh +from .placement_types import InterleavedShard, Partial, Placement, Replicate, Shard +from ._collective_utils import mesh_all_gather + + +def compute_local_shape(global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]) -> Tuple[int, ...]: + """ + Compute the shape of a local shard of the given DTensor on its current + coordinate of the mesh. + """ + my_coordinate = mesh.get_coordinate() + + if my_coordinate is None: + # if rank not in the mesh, return empty shape + return () + else: + local_shape = list(global_shape) # start with global shape + ndim = len(global_shape) + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, (Shard, InterleavedShard)): + shard_dim = placement.dim + assert shard_dim < ndim, f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" + local_shard_size, _ = placement._local_shard_size_on_dim( + local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] + ) + assert isinstance(local_shard_size, int) + local_shape[shard_dim] = local_shard_size + + return tuple(local_shape) + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. 
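As a quick illustration of how this helper is typically called, here is a hedged sketch; it assumes the 2x4 mesh with 8 participating ranks used in the worked example that follows.

```
from internlm.checkpoint.vescale.device_mesh import DeviceMesh
from internlm.checkpoint.vescale.placement_types import Shard
from internlm.checkpoint.vescale._utils import compute_local_shape_and_global_offset

mesh = DeviceMesh("cuda", [[0, 1, 2, 3], [4, 5, 6, 7]])
placements = [Shard(0), Shard(0)]
local_shape, global_offset = compute_local_shape_and_global_offset((8, 4), mesh, placements)
# e.g. on global rank 0: local_shape == (1, 4) and global_offset == (0, 0)
```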
This is useful for checkpointing purpose. + Example (2 host with 4GPUs each): + # Below is a DeviceMesh with mesh_shape of (2, 4) + mesh = DeviceMesh(device_type="cuda", + mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ], + ) + Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh + with a placements of [Shard(0), Shard(0)]. + The local shape and global offset will be as follows: + rank0 -- local_shape:[1, 4], global_offset:[0, 0] + rank1 -- local_shape:[1, 4], global_offset:[1, 0] + rank2 -- local_shape:[1, 4], global_offset:[2, 0] + rank5 -- local_shape:[1, 4], global_offset:[5, 0] + rank3 -- local_shape:[1, 4], global_offset:[3, 0] + rank4 -- local_shape:[1, 4], global_offset:[4, 0] + rank6 -- local_shape:[1, 4], global_offset:[6, 0] + rank7 -- local_shape:[1, 4], global_offset:[7, 0] + """ + my_coordinate = mesh.get_coordinate() + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((), ()) + else: + local_shape = list(global_shape) + global_offset = [0] * len(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + local_offset = [0] * len(global_shape) + assert shard_dim < len( + local_shape + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + # TODO: what if placement is InterleavedShard + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + + local_shape[shard_dim] = shard_size + local_offset[shard_dim] = shard_offset + + # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], + # it means that this dimension has been already sharded in previous placement. + # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. + # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. + if global_offset[shard_dim] <= local_offset[shard_dim]: + global_offset[shard_dim] = local_offset[shard_dim] + else: + global_offset[shard_dim] += local_offset[shard_dim] + + return tuple(local_shape), tuple(global_offset) + + +def is_same_shape_across_ranks(tensor_shape: ShapeType, device_mesh: DeviceMesh, placements: Sequence[Placement]): + # check if tensor shapes are the same across ranks + self_shape = torch.tensor([tuple(tensor_shape)], dtype=torch.int64, device=device_mesh.device_type) + for mesh_dim, _ in enumerate(placements): # TODO for perf: use a process group for the entire DeviceMesh + all_shapes = mesh_all_gather( + self_shape, + torch.Size([device_mesh.size(mesh_dim), self_shape.size(1)]), + device_mesh, + scatter_dim=0, + mesh_dim=mesh_dim, + ) + if not torch.all(self_shape == all_shapes): + return False + return True + + +def gather_local_tensor_shape( + self_local_tensor: Union[torch.Tensor, torch.Size], + device_mesh: DeviceMesh, + placements: Sequence[Placement], + shard_only: bool = False, +) -> Optional[Dict[int, List[List[int]]]]: + """All gather local tensor shapes per mesh dimension. + When `shard_only is True`, all gather only sharded mesh dim. 
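Concretely, the returned mapping has one entry per gathered mesh dimension; a hedged sketch, assuming a 1-D mesh of 2 ranks holding uneven `Shard(0)` pieces of a (4, 4) global tensor:

```
import torch

from internlm.checkpoint.vescale.device_mesh import DeviceMesh
from internlm.checkpoint.vescale.placement_types import Shard
from internlm.checkpoint.vescale._utils import gather_local_tensor_shape

mesh = DeviceMesh("cuda", [0, 1])
# rank 0 holds a (3, 4) shard, rank 1 holds a (1, 4) shard
local_tensor = torch.randn(3 if mesh.get_rank() == 0 else 1, 4, device="cuda")
shapes = gather_local_tensor_shape(local_tensor, mesh, [Shard(0)], shard_only=True)
# on both ranks: shapes == {0: [[3, 4], [1, 4]]}
```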
Otherwise, all gather all mesh dims.""" + if device_mesh.get_coordinate() is None: # if rank is not part of mesh + return None + + _shape: torch.Size = self_local_tensor if isinstance(self_local_tensor, torch.Size) else self_local_tensor.shape + self_local_shape = torch.tensor([list(_shape)], dtype=torch.int64, device="cpu", pin_memory=True) + self_local_shape = self_local_shape.to(device_mesh.device_type, non_blocking=True) + meshdim_localtensor_shape = {} + for mesh_dim, place in enumerate(placements): + if shard_only and not isinstance(place, (Shard, InterleavedShard)): + continue + stacked_local_shape = mesh_all_gather( + self_local_shape, + torch.Size([device_mesh.size(mesh_dim), self_local_shape.size(1)]), + device_mesh, + scatter_dim=0, + mesh_dim=mesh_dim, + ) + if type(stacked_local_shape) is funcol.AsyncCollectiveTensor: + # synchronously wait for any pending collectives to get the result tensor + stacked_local_shape = stacked_local_shape.trigger_wait() + if hasattr(stacked_local_shape, "elem"): + stacked_local_shape = stacked_local_shape.elem # type: ignore[attr-defined] + + meshdim_localtensor_shape[mesh_dim] = stacked_local_shape.detach().cpu().tolist() + return meshdim_localtensor_shape + + +def compute_global_tensor_info( + tensor: torch.Tensor, + mesh: DeviceMesh, + placements: Sequence[Placement], + meshdim_localtensor_shape: Optional[Dict[int, List[List[int]]]] = None, +) -> Tuple[List[int], List[int]]: + """ + Compute the global size and stride of a DTensor from the given local tensor. + + When `meshdim_localtensor_shape` is None (be default): + The local size is multiplited by `world_size` per Sharding dim. + The local stride is multiplited by `world_size` per Sharding dim, as long as the + dimension is outside sharding dim. + + For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8). + If the DTensor placements are [Shard(2)] and world_size is 2; + then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8). + + When `meshdim_localtensor_shape` is provided: + All local sizes are summed togather as global Sharding dim. + The local stride is scaled by global Sharding dim divided by local Sharding dim, + as long as the dimension is outside sharding dim. + + For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8) on rank0, + and a local tensor with size (4, 8, 1) and stride (8, 1, 8) on rank1. + If the DTensor placements are [Shard(2)] and world_size is 2; + then the global size is (4, 8, 3) and stride is (8 * 3, 1, 8). + + Args: + tensor (:class:`torch.Tensor`): + Local tensor which DTensor will be constructed from. + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology of devices for the DTensor. + placements (Sequence[:class:`Placement`]]): + The attribute of the DTensor that describes its layout on the mesh topology. + meshdim_localtensor_shape (:class:`Dict[int, List[List[int]]]`): + Default None. + Otherwise, a given list for local tensor shapes per device mesh dim. + + Return: + tensor_shape: A List of int which specifies the global size of DTensor which build + on top of the local tensor. + tensor_stride: A List of int which specifies the global stride of DTensor. 
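The even-sharding example above can be checked directly. This sketch is illustrative and assumes a 1-D mesh of world size 2, building a local tensor with the documented (4, 8, 2) shape and (16, 1, 8) stride:

```
import torch

from internlm.checkpoint.vescale.device_mesh import DeviceMesh
from internlm.checkpoint.vescale.placement_types import Shard
from internlm.checkpoint.vescale._utils import compute_global_tensor_info

mesh = DeviceMesh("cuda", [0, 1])
local = torch.empty(4, 2, 8).permute(0, 2, 1)  # shape (4, 8, 2), stride (16, 1, 8)
size, stride = compute_global_tensor_info(local, mesh, [Shard(2)])
# size == [4, 8, 4], stride == [32, 1, 8]
```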
+ """ + if meshdim_localtensor_shape is None: # assume even sharding (contiguous or non-contiguous) + tensor_shape = list(tensor.size()) + tensor_stride = list(tensor.stride()) + else: # support uneven sharding (contiguous only) + if not tensor.is_contiguous(): + warnings.warn( + "`from_local` take non-contiguous local tensor, which is not supported in uneven sharding. Treat as contiguous.", + UserWarning, + ) + # a meta empty tensor is created for obtaining correct local stride, + # especially when local tensor is non-contiguous or narrowed from padding or zero dimmed. + # TODO: rethink supporting non-contiguous local tensor which is narrowed from padding or zero dimmed. + tensor = torch.empty(tensor.shape, dtype=tensor.dtype, device="meta") + tensor_shape = list(tensor.size()) + tensor_stride = list(tensor.stride()) + + # record occured shard dim + shard_dim_occured: Set[int] = set() + + for idx, placement in enumerate(placements): + mesh_dim_size: int = mesh.size(idx) + + # TODO: rethink about this InterleavedShard. + if placement.is_shard() or placement.is_interleaved_shard(): + if placement.dim < 0: + placement.dim += len(tensor_shape) + shard_dim = placement.dim + + assert ( + shard_dim < tensor.ndim + ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." + + if shard_dim in shard_dim_occured: + warnings.warn( + "Sharding the same tensor dim is not supported for uneven sharding. Treat as even sharding.", + UserWarning, + ) + is_shard_same_dim = True + else: + shard_dim_occured.add(shard_dim) + is_shard_same_dim = False + + # recover global shape + local_dim_size = tensor_shape[shard_dim] + if meshdim_localtensor_shape is None or is_shard_same_dim: + # duplicate local shape at this sharded dim as global shape + tensor_shape[shard_dim] = local_dim_size * mesh_dim_size + else: + # concat local shapes at this sharded dim as global shape + global_dim_size = sum(shape[shard_dim] for shape in meshdim_localtensor_shape[idx]) + tensor_shape[shard_dim] = global_dim_size + + # recover tensor stride by modifying the stride that larger than + # the current stride on the shard_dim + is_contiguous_tensor = all(tensor_stride[i] >= tensor_stride[i + 1] for i in range(len(tensor_stride) - 1)) + for i in range(len(tensor_stride)): + if (i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim] and not is_contiguous_tensor) or ( + i < shard_dim and is_contiguous_tensor + ): + # rescale the stride by the shard size + if meshdim_localtensor_shape is None or is_shard_same_dim: + tensor_stride[i] = tensor_stride[i] * mesh_dim_size + else: + if local_dim_size == 0: + tensor_stride[i] *= max(global_dim_size, 1) + else: + assert tensor_stride[i] % local_dim_size == 0 + tensor_stride[i] = tensor_stride[i] // local_dim_size * global_dim_size + + elif not isinstance(placement, (Replicate, Partial)): + raise RuntimeError(f"placement type {type(placement)} not supported!") + + return tensor_shape, tensor_stride + + +def is_zero_out_local_shard(mesh: DeviceMesh, placements: Sequence[Placement]) -> bool: + """ + Compute whether we need to zero out the local shard of current rank, for Partial(). + + e.g. we want a bias tensor in [Partial(), Shard(0), Partial()] + [ [[b1, 0.] + [b2, 0.]] + + [[0., 0.] 
+ [0., 0.]] ] + on a 3D-DeviceMesh: + [ [[0, 1] + [2, 3]] + + [[4, 5] + [6, 7]] ] + The computed result should be: + [ [[False, True] + [False, True]] + + [[True, True] + [True, True]] ] + """ + my_coordinate = mesh.get_coordinate() + + if my_coordinate is None: # if rank not in the mesh, nothing to zero out + return False + + for idx, placement in enumerate(placements): + if not placement.is_partial(): + continue + # we zero out all other ranks of the current mesh dim + # and leave only src-of-truth rank 0 have the data, to perform a "zero cost" shard. + if my_coordinate[idx] != 0: + return True + + return False + + +def _equal_meta_data(dt1, dt2, exact_device: bool) -> bool: + if type(dt1).__name__ != "DTensor" or type(dt2).__name__ != "DTensor": + return False + # check itself + if exact_device and (dt1.device.type != dt2.device.type): + return False + if dt1.shape != dt2.shape: + return False + if dt1.dtype != dt2.dtype: + return False + if dt1.layout != dt2.layout: # torch.strided (dense) or torch.sparse_* + return False + if dt1.stride() != dt2.stride(): + return False + if dt1.requires_grad != dt2.requires_grad: + return False + # check global spec + if exact_device: + if dt1._spec.mesh != dt2._spec.mesh: + return False + else: + if not dt1._spec.mesh.mesh.equal(dt2._spec.mesh.mesh): + return False + if dt1._spec.placements != dt2._spec.placements: + return False + if dt1._spec.tensor_meta != dt2._spec.tensor_meta: + return False + # check local tensor (ref: https://github.com/pytorch/pytorch/blob/63ae1051e17b1cf4fe55ac6b6f17c16672d44150/aten/src/ATen/native/cuda/Equal.cpp#L15) + t1, t2 = dt1._local_tensor, dt2._local_tensor + if exact_device and (t1.device.type != t2.device.type): + return False + if t1.shape != t2.shape: + return False + if t1.dtype != t2.dtype: + return False + if t1.layout != t2.layout: # torch.strided (dense) or torch.sparse_* + return False + if t1.is_contiguous() != t2.is_contiguous(): + return False + if t1.stride() != t2.stride(): + return False + if t1.storage_offset() != t2.storage_offset(): + return False + if t1.requires_grad != t2.requires_grad: + return False + return True + + +def equal(dt1, dt2, exact_device: bool = True) -> bool: + """ + check if two DTensors are 'exactly' equal + """ + if not _equal_meta_data(dt1, dt2, exact_device): + return False + if dt1.is_meta and dt2.is_meta: + return True + if exact_device: + return torch.equal(dt1._local_tensor, dt2._local_tensor) # check value only + else: + return torch.equal(dt1._local_tensor.cpu(), dt2._local_tensor.cpu()) # check value only + + +def allclose( + dt1, + dt2, + rtol: float = 1e-05, + atol: float = 1e-08, + equal_nan: bool = False, + exact_device: bool = True, +) -> bool: + """ + check if two DTensors are 'allclose' + """ + if not _equal_meta_data(dt1, dt2, exact_device): + return False + if dt1.is_meta and dt2.is_meta: + return True + if exact_device: + return torch.allclose( + dt1._local_tensor, dt2._local_tensor, rtol=rtol, atol=atol, equal_nan=equal_nan + ) # check value only + else: + return torch.allclose( + dt1._local_tensor.cpu(), dt2._local_tensor.cpu(), rtol=rtol, atol=atol, equal_nan=equal_nan + ) # check value only + + +def compute_local_offset(global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]) -> Tuple[int, ...]: + """ + Compute the offsets of a local shard of the given "DTensor" on its current + global rank. This is mostly used by distributed checkpointing to know the + exact offsets of the local shard. 
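Together with `compute_local_shape` defined earlier, these offsets identify exactly which slice of the global tensor a rank owns; a hedged sketch assuming a 1-D mesh of 4 ranks and an even `Shard(0)` layout:

```
import torch

from internlm.checkpoint.vescale.device_mesh import DeviceMesh
from internlm.checkpoint.vescale.placement_types import Shard
from internlm.checkpoint.vescale._utils import compute_local_shape, compute_local_offset

mesh = DeviceMesh("cuda", [0, 1, 2, 3])
placements = [Shard(0)]
shape = compute_local_shape((8, 4), mesh, placements)    # (2, 4) on every rank
offset = compute_local_offset((8, 4), mesh, placements)  # (0, 0), (2, 0), (4, 0) or (6, 0)
global_tensor = torch.arange(32).reshape(8, 4)
local_slice = global_tensor[offset[0] : offset[0] + shape[0]]
```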
+ """ + my_coordinate = mesh.get_coordinate() + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return () + else: + local_offsets = [0] * len(global_shape) + local_shape = list(global_shape) + + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if isinstance(placement, Shard): + shard_dim = placement.dim + assert shard_dim < len( + local_shape + ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + shard_size, shard_offset = placement._local_shard_size_on_dim( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[idx], + return_offset=True, + ) + local_shape[shard_dim] = shard_size + local_offsets[shard_dim] = shard_offset + return tuple(local_offsets) + + +def compute_global_stride( + local_tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement] +) -> Tuple[int, ...]: + """ """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + return () + if not local_tensor.is_contiguous(): + raise RuntimeError("local tensor should be contiguous") + global_stride = copy.deepcopy(list(local_tensor.stride())) + for i, p in enumerate(placements): + if not p.is_shard(): + continue + shard_dim = p.dim + shard_size = mesh.size(i) + for j in range(shard_dim): + global_stride[j] *= shard_size + return tuple(global_stride) diff --git a/internlm/checkpoint/vescale/api.py b/internlm/checkpoint/vescale/api.py new file mode 100644 index 000000000..7d59db1db --- /dev/null +++ b/internlm/checkpoint/vescale/api.py @@ -0,0 +1,53 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ +# The "checkpoint" folder is ONLY USED for "open source" version veScale + +from .vescale_checkpointer import VeScaleCheckpointer +from .meta_type import CheckpointState + + +def save(path: str, checkpoint_state: CheckpointState, async_checkpoint=False): + """ + Save a checkpoint to a given path + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, + i.e. after dumping tensors from GPU memory to Host memory, + the training program can continue training immediately. + Then vescale.checkpoint will serialize tensors and dumping to the persistent storage asynchronously. + Example: + >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } + >>> vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state) + """ + VeScaleCheckpointer.save(path, checkpoint_state, async_checkpoint=async_checkpoint) + + +def load(path: str, checkpoint_state: CheckpointState, broadcast_checkpoint=False): + """ + Load a checkpoint from a given path + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. 
+ - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group + then broadcast tensors to other data parallel process group using GPUs + to reduce the file system access + For example, when data parellel size = 2, + processes with data parallel rank = 0 load model from file system + then broadcast it to processes with data parallel rank = 1 + Example: + >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } + >>> vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state) + """ + VeScaleCheckpointer.load(path, checkpoint_state, broadcast_checkpoint=broadcast_checkpoint) diff --git a/internlm/checkpoint/vescale/base_checkpointer.py b/internlm/checkpoint/vescale/base_checkpointer.py new file mode 100644 index 000000000..293f9c003 --- /dev/null +++ b/internlm/checkpoint/vescale/base_checkpointer.py @@ -0,0 +1,77 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +from .meta_type import CheckpointState +from typing import Dict, List +from concurrent.futures import Future, ProcessPoolExecutor +from torch.distributed.checkpoint.storage import WriteResult +from .meta_type import MODEL_STR, OPTIMIZER_STR + +SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} + + +class BaseCheckpointer: + """ + The Checkpointer class offers APIs that enable users to save and load state dictionarie. + It is designed for extension across various training frameworks. + """ + + # Async IO related members. + state_io_workers: Dict[str, ProcessPoolExecutor] = {} + state_write_futures: Dict[str, Future[List[WriteResult]]] = {} + + @classmethod + def save(cls, path: str, checkpoint_state: CheckpointState): + """ + A Method for saving checkpoint + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + + """ + raise NotImplementedError() + + @classmethod + def load(cls, path: str, checkpoint_state: CheckpointState): + """ + A Method for loading checkpoint + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + + """ + raise NotImplementedError() + + @classmethod + def _cleanup_futures(cls): + """ + Wait for all write futures to finish before exit, then do the cleanup works. 
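The write futures tracked here are produced by the user-facing entry points in `api.py` shown earlier; a hedged end-to-end sketch of that flow, where the path and the `distributed_model` / `distributed_optimizer` objects are placeholders:

```
from internlm.checkpoint.vescale import api as vescale_checkpoint

checkpoint_state = {"model": distributed_model, "optimizer": distributed_optimizer}

# dump tensors to host memory, then persist them from a background worker
vescale_checkpoint.save("/tmp/ckpt/step_1000", checkpoint_state, async_checkpoint=True)

# on restart: one data-parallel replica reads the files and broadcasts to the rest
vescale_checkpoint.load("/tmp/ckpt/step_1000", checkpoint_state, broadcast_checkpoint=True)
```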
+ + WARNING: this method cannot be called by the users. + """ + for key in SUPPORTED_TYPES: + if key in cls.state_write_futures: + futures = cls.state_write_futures[key] + for fut in futures: + fut.result() + cls.state_write_futures[key] = [] + if cls.state_io_workers[key] is not None: + cls.state_io_workers[key].shutdown() + cls.state_io_workers[key] = None diff --git a/internlm/checkpoint/vescale/bfile.py b/internlm/checkpoint/vescale/bfile.py new file mode 100644 index 000000000..f3ba4e331 --- /dev/null +++ b/internlm/checkpoint/vescale/bfile.py @@ -0,0 +1,129 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +# Existing APIs all follow the rule here: +# https://www.tensorflow.org/api_docs/python/tf/io/gfile + + +import os +import enum +import contextlib +import uuid +from internlm.utils.logger import get_logger +import shutil +from . import mem_server_lib + +logger = get_logger(__file__) +BFILE_DEFAULT_TIMEOUT = None + + +class FileType(enum.Enum): + LOCAL = 0 + LOCAL_MEM = 1 + + +def local_list_folder(folder_path: str, recursive: bool = False): + file_paths = [] + if recursive: + for root, _, files in os.walk(folder_path): + for file_name in files: + file_path = os.path.join(root, file_name) + file_paths.append(file_path) + else: + if os.path.isdir(folder_path): + file_paths.extend([os.path.join(folder_path, d) for d in os.listdir(folder_path)]) + elif os.path.isfile(folder_path): + file_paths.append(folder_path) + else: + logger.warning(f"Path {folder_path} is invalid") + + return file_paths + + +def get_schema(path: str): + if path.startswith(mem_server_lib.SCHEMA): + return FileType.LOCAL_MEM + return FileType.LOCAL + + +def rename(src, dst, overwrite=False): + t = get_schema(src) + if t == FileType.LOCAL_MEM: + return mem_server_lib.rename(src, dst, overwrite) + return os.rename(src, dst) + + +def listdir(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + return mem_server_lib.listdir(path) + absolute_files = local_list_folder(path) + return [f[f.rfind("/") + 1 :] for f in absolute_files] + + +def remove(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + return mem_server_lib.remove(path) + return shutil.rmtree(path, ignore_errors=True) + + +def exists(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + return mem_server_lib.exists(path) + return os.path.exists(path) + + +def makedirs(path): + t = get_schema(path) + if t == FileType.LOCAL_MEM: + # Local mem doesn't have empty folder + return + return os.makedirs(path, exist_ok=True) + + +@contextlib.contextmanager +def BFile(name, mode="r"): + t = get_schema(name) + if t == FileType.LOCAL_MEM: + with mem_server_lib.open(name, mode) as f: + yield f + else: + with open(name, mode) as f: + yield f + + +# ---- Below is some useful utilities ----- + + +def 
atomic_write(path: str, content: bytes, **kwargs): + tmp_path = path + "_tmp_" + str(uuid.uuid4()) + with BFile(tmp_path, "wb", **kwargs) as f: + f.write(content) + rename(tmp_path, path, overwrite=True) + + +def safe_atomic_write(path: str, content: bytes, **kwargs): + makedirs(os.path.dirname(path)) + atomic_write(path, content, **kwargs) + + +def is_local_path(path: str): + t = get_schema(path) + if t == FileType.LOCAL_MEM or t == FileType.LOCAL: + return True + return False diff --git a/internlm/checkpoint/vescale/common.py b/internlm/checkpoint/vescale/common.py new file mode 100644 index 000000000..2d19fef28 --- /dev/null +++ b/internlm/checkpoint/vescale/common.py @@ -0,0 +1,132 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import dataclasses +from typing import Any, Dict, List, Tuple, Hashable, Optional +from collections import OrderedDict +from torch.distributed.checkpoint.planner import SavePlan +from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata +import collections +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +@dataclasses.dataclass +class P2PTensorsInfo: + """ + Record data about tesnors which are across dp ranks + recv_tensors: A dictionary + Key: fqn + Value: a dictionary + key is the process rank, + value is a tuple with (tensor, 1d_range) + send_p2p_reqs: a list of p2p send requests to wait + recv_p2p_reqs: a list p2p receive requests to wait + """ + + recv_tensors: Dict[str, Any] + send_p2p_reqs: List[Any] + recv_p2p_reqs: List[Any] + + +def sort_rank_ranges(process_list: List[Tuple]) -> List[Tuple]: + """ + Decide which rank is receiver and writer + Let rank with most parameters receives and writes tensors + for the best communication cost + If two ranks has the same data size, choose the smaller rank + Args: + A process list with tuples, each tuple is (rank, data_size) + Returns: + A sorted list, data size are sorted in descending order, + if two ranks has the same data size, ranks are in the asceonding order + """ + sorted_process_list = sorted(process_list, key=lambda x: (-x[1], x[0])) + return sorted_process_list + + +_MAX_CACHE_SIZE = 8 + + +class PlanLRUCache: + def __init__(self) -> None: + self._cache: OrderedDict[Hashable, Tuple[SavePlan, Metadata]] = OrderedDict() + self._capacity = _MAX_CACHE_SIZE + + def get(self, key: Hashable) -> Optional[Tuple[SavePlan, Metadata]]: + if key in self._cache: + return self._cache[key] + else: + return None + + def put(self, key: Hashable, plan_value: SavePlan, metadata_value: Metadata) -> None: + if key in self._cache: + self._cache.move_to_end(key, last=False) + else: + self._cache[key] = (plan_value, metadata_value) + if len(self._cache) > self._capacity: + self._cache.popitem() + + def clear(self) -> None: + 
self._cache.clear() + self._capacity = _MAX_CACHE_SIZE + + def __repr__(self) -> str: + return f"PlanLURCache(capacity: {self._capacity}, keys: {tuple(self._cache.keys())})" + + +def custom_dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: + """ + A function to remove duplicate tensors to write + when creating global writing plan for saving checkpoint + During the deduplication, + we balance the workloads for duplicated tensors + """ + key_to_plan: Dict[MetadataIndex, List[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + # Remove duplicates by always keeping the first entry (Not balance). + # Compute the per-rank remove set. + plan_to_keys: Dict[int, List[MetadataIndex]] = {} + # Record the number of non-duplicated tensors assigned to each rank + assigned_work_load = collections.defaultdict(int) + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + if write_item.index not in replicated_items: + assigned_work_load[plan_idx] += 1 + + for key, plans in replicated_items.items(): + # For duplicated tensors, select the rank assigned with minimum number tensors so far + writer_id = min(plans, key=lambda k: assigned_work_load[k]) + assigned_work_load[writer_id] += 1 + for plan_idx in plans: + # If the rank is not writer rank, remove the key in the rank's plan + if plan_idx != writer_id: + plan_to_keys.setdefault(plan_idx, []).append(key) + logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + # Key Set contains keys to remove + key_set = set(keys) + # rewrite items and remove elements + new_items = [write_item for write_item in all_plans[plan_idx].items if write_item.index not in key_set] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/internlm/checkpoint/vescale/device_mesh.py b/internlm/checkpoint/vescale/device_mesh.py new file mode 100644 index 000000000..e578471d2 --- /dev/null +++ b/internlm/checkpoint/vescale/device_mesh.py @@ -0,0 +1,647 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
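As a small self-contained check of the ordering rule implemented by `sort_rank_ranges` in `common.py` above (the values are arbitrary): ranks carrying more data come first, and ties are broken in favor of the smaller rank.

```
from internlm.checkpoint.vescale.common import sort_rank_ranges

assert sort_rank_ranges([(3, 10), (1, 20), (2, 20)]) == [(1, 20), (2, 20), (3, 10)]
```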
+################################################################################ + +import logging +import math +import warnings +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed.distributed_c10d import ( + ProcessGroup, + _find_pg_by_ranks_and_tag, + _get_default_group, + _get_group_size, + _get_group_tag, + get_process_group_ranks, + get_rank, + get_world_size, + init_process_group, + is_initialized, + new_group, +) + +from internlm.utils.logger import get_logger +logger = get_logger(__file__) + +# only import numpy typing when type checking +if TYPE_CHECKING: + try: + from numpy.typing import ArrayLike + except ImportError: + logger.warning("DeviceMesh requires numpy >= 1.21 to be installed for type checking") + + +class _MeshEnv: + def __init__(self) -> None: + self.mesh_stack: List[DeviceMesh] = [] + self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {} + + def get_current_mesh(self) -> "DeviceMesh": + if len(self.mesh_stack) == 0: + raise RuntimeError("No device mesh is currently active!") + return self.mesh_stack[-1] + + def create_child_mesh(self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str) -> "DeviceMesh": + # swap the current dim to the last dim then reshape to flatten out other + # dims, so we can just extract the list of ranks which contains cur_rank. + cur_rank = device_mesh.get_rank() + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(-1, device_mesh.mesh.size(mesh_dim)) + + for mesh_1d in pg_ranks_by_dim: + sub_mesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + mesh_dim_names=(mesh_dim_name,), + _init_process_groups=False, + ) + if cur_rank in mesh_1d: + res_sub_mesh = sub_mesh + + res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] + # Assign the current DeviceMesh as the parent of the child DeviceMesh. + self.child_to_parent_mapping[res_sub_mesh] = device_mesh + return res_sub_mesh + + def create_submesh_along_multi_dims( + self, device_mesh: "DeviceMesh", mesh_dims: List[int], cur_rank: int = None + ) -> "DeviceMesh": + # swap the current dim to the last dim then reshape to flatten out other + # dims, so we can just extract the list of ranks which contains cur_rank. + # check dims + dim_size = [-1] + for dim in mesh_dims: + if dim >= device_mesh.ndim: + raise RuntimeError("Mesh dim in sub groups out of range!") + dim_size.append(device_mesh.mesh.size(dim)) + mesh_tensor = device_mesh.mesh + for dim in mesh_dims: + mesh_tensor = mesh_tensor.swapdims(-1, dim) + if cur_rank is None: + cur_rank = device_mesh.get_rank() + pg_ranks_by_dims = mesh_tensor.reshape(dim_size) + for mesh_nd in pg_ranks_by_dims: + sub_mesh = DeviceMesh( + device_mesh.device_type, + mesh_nd, + _init_process_groups=False, + ) + if cur_rank in mesh_nd: + res_sub_mesh = sub_mesh + res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[dim] for dim in mesh_dims] + self.child_to_parent_mapping[res_sub_mesh] = device_mesh + return res_sub_mesh + + def create_submesh_group(self, device_mesh: "DeviceMesh", mesh_dim: int) -> "DeviceMesh": + # swap the current dim to the last dim then reshape to flatten out other + # dims, so we can just extract the list of ranks which contains cur_rank. 
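To illustrate how a child mesh relates to its parent, here is a hedged sketch; it assumes 8 participating ranks and uses the module-level `mesh_resources` instance defined further below:

```
from internlm.checkpoint.vescale.device_mesh import DeviceMesh, mesh_resources

parent = DeviceMesh("cuda", [[0, 1, 2, 3], [4, 5, 6, 7]], mesh_dim_names=("dp", "tp"))
# carve out the 1-D "tp" submesh that contains the current rank
tp_mesh = mesh_resources.create_child_mesh(parent, mesh_dim=1, mesh_dim_name="tp")
assert mesh_resources.get_parent_mesh(tp_mesh) is parent
assert mesh_resources.get_parent_mesh_dim(tp_mesh) == 1
```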
+ # check dims + pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(-1, device_mesh.mesh.size(mesh_dim)) + res = [] + for mesh_1d in pg_ranks_by_dim: + sub_mesh = DeviceMesh( + device_mesh.device_type, + mesh_1d, + _init_process_groups=False, + ) + sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] + # Assign the current DeviceMesh as the parent of the child DeviceMesh. + self.child_to_parent_mapping[sub_mesh] = device_mesh + res.append(sub_mesh) + return res + + def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]: + return self.child_to_parent_mapping.get(device_mesh, None) + + def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: + """ + Return the index of the mesh dim in the parent mesh. + The device_mesh passed in needs to be sliced out from a parent mesh. + """ + parent_mesh = self.get_parent_mesh(device_mesh) + child_mesh_dim_names = device_mesh.mesh_dim_names + if parent_mesh and child_mesh_dim_names: + assert len(child_mesh_dim_names) == 1, "The child mesh can only be a 1D mesh." + child_mesh_dim_name = child_mesh_dim_names[0] + if parent_mesh.mesh_dim_names: + return parent_mesh.mesh_dim_names.index(child_mesh_dim_name) + return None + + @staticmethod + def num_devices_per_host(device_type: str) -> int: + return _get_device_handle(device_type).device_count() + + @staticmethod + def num_hosts(device_type: str) -> int: + # ProcessGroup can't tell us this info so we have to infer it, assume + # homogeneous hardware for now + return get_world_size() // _MeshEnv.num_devices_per_host(device_type) + + +mesh_resources: _MeshEnv = _MeshEnv() + + +def _get_device_handle(device_type: str = "cuda"): + """ + Get the module corresponding to the device_type which is cuda or cuda-like device. + For example, when the device_type is cuda, the module `torch.cuda` is returned. + Return None when there is no corresponding module for device_type, otherwise + return the corresponding module. + """ + return getattr(torch, device_type, None) + + +class DeviceMesh: + """ + DeviceMesh represents a mesh of devices (given by `device_type`), where layout + of devices could be represented as a n-d dimension array `mesh`, and each value + of the `mesh` is the global rank in the default process group. + + DeviceMesh could be used to describe the layout of devices across the cluster + via `mesh_dim_names`, and serves as a proxy for communication among the device lists + within the cluster. + + By default (`pg` is `None`), we use the default ProcessGroup in this DeviceMesh class + to implement proper communications. Note that we also add collective wrappers in this + class. This is used to decouple detailed communication backend with the underlying + DTensor implementation. + + By giving an existing ProcessGroup `pg`, we construct a device mesh from this `pg`, + instead of the default ProcessGroup. + + Here are the expected behaviors: + | `mesh` | `pg` | result | catch + --------------------------------------------------------------------------------------------- + | None | None | raise error! | + | EXIST | None | use `mesh` + default ProcessGroup | + | None | EXIST | use `pg`'s ranks + `pg` ProcessGroup | 1D mesh only + | EXIST | EXIST | use `pg`'s ranks + `pg` ProcessGroup | `mesh` must equal to `pg`'s ranks + + Args: + device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like, meta. 
+ mesh (ndarray): could be a multi-dimension array or an integer tensor that + describes the layout of devices, the ids are global ids of the default process group. + mesh_dim_names (Optional[Tuple[str]]): A tuple of mesh dim names to be assigned to each + dimension of the multi-dimensional array that describes the layout of devices. Its + length must match the length of `mesh_shape`. Each string in mesh_dim_names must be unique. + pg (Optional[ProcessGroup]): the given ProcessGroup. See above for expected behaviors. + + Returns: + A :class:`DeviceMesh` object + + Example (2 host with 4 GPUs each): + ``` + # The following program runs on each process/rank in SPMD manner. + # initialize device mesh as (2, 4) to represent the topology + # of cross-host(dim 0), and within-host (dim 1) + mesh = DeviceMesh(device_type="cuda", + mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ]) + ``` + A reduction over the first dimension of mesh will reduce across + columns (0, 4), .. and (3, 7), a reduction over the second dimension + of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7) + + Note: + DeviceMesh can be used as a context manager. + """ + + device_type: str + mesh: Optional[Union[torch.Tensor, "ArrayLike"]] + mesh_dim_names: Optional[Tuple[str, ...]] + + def __init__( + self, + device_type: str, + mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + pg: Optional[ProcessGroup] = None, + _validate_mesh: bool = True, + _init_process_groups: bool = True, + ) -> None: + # for performance, update debug env once here + # check args + if mesh is None and pg is None: + raise ValueError("Either `mesh` or `pg` must be provided!") + if mesh is not None and pg is not None: + pg_mesh_tensor = torch.tensor(get_process_group_ranks(pg), dtype=torch.int, device="cpu") + mesh_tensor = ( + mesh.detach().cpu() + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, dtype=torch.int, device="cpu") + ) + if not torch.equal(mesh_tensor, pg_mesh_tensor): + raise ValueError(f"mesh({mesh_tensor}) and pg({pg_mesh_tensor}) must have the same content!") + if pg is not None: + self.mesh = torch.tensor(get_process_group_ranks(pg), dtype=torch.int, device="cpu") + warnings.warn("Construction from given ProcessGroup is only supported for 1D mesh currently.") + # TO FIX: use `mesh` to reshape `pg_mesh_tensor` for nD mesh tensor + if mesh is not None: + self.mesh = ( + mesh.detach().cpu() + if isinstance(mesh, torch.Tensor) + else torch.tensor(mesh, dtype=torch.int, device="cpu") + ) + + self.device_type = device_type + self.mesh_dim_names = mesh_dim_names + + # private field to pre-generate DeviceMesh's hash + self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) + self._hash = hash((self._flatten_mesh_list, self.mesh.shape)) + + # step 1: try to create default world pg. + if pg is None: + pg = self._get_or_create_default_group() + else: + # TODO: this logic only applies when device_type is cuda + pg_world_size = get_world_size(group=pg) + device_handle = _get_device_handle(self.device_type) + num_devices_per_host = device_handle.device_count() + if pg_world_size > num_devices_per_host and pg_world_size % num_devices_per_host != 0: + raise RuntimeError( + f"DeviceMesh only support homogeneous hardware, but found " + f"{pg_world_size} ranks and {num_devices_per_host} {self.device_type} devices!" 
+ ) + if self.device_type == "cuda": + + def _get_current_device(): + try: + if torch.cuda.is_available(): + return torch.cuda.current_device() + else: + return None + except AssertionError as e: + return None + + device_handle = _get_device_handle(self.device_type) + num_devices_per_host = device_handle.device_count() + local_rank = get_rank() % num_devices_per_host + if local_rank != _get_current_device(): + warnings.warn("Remember to set cuda device id to local rank!!!") + device_handle = _get_device_handle(self.device_type) + device_handle.set_device(local_rank) + + # step 2: validate the mesh before following usage. + if _validate_mesh: + self._validate_mesh(pg) + + # step 3: get coordinate of current global rank on the mesh. + # The world pg is used for device mesh identity (rank) on each + # process (we need to know if the current global rank is in the mesh or not) + rank_coords = (self.mesh == get_rank()).nonzero() + assert rank_coords.size(0) in (0, 1) + self._coordinate_on_dim: Optional[List[int]] = rank_coords[0].tolist() if rank_coords.size(0) > 0 else None + + # step 4: init multi subprocess group for the mesh object. + if _init_process_groups: + self._init_process_groups(pg) + + def _get_or_create_default_group(self): + default_initialized = is_initialized() + if not default_initialized: + init_process_group() + + world_size = get_world_size() + if self.mesh.numel() > world_size: + raise RuntimeError( + f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!" + ) + + device_handle = _get_device_handle(self.device_type) + # TODO: if user want to pass pg_options, offer a way to do it + if not default_initialized and device_handle: + # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host + # NOTE: This device selection would only work for homogeneous hardware. + num_devices_per_host = device_handle.device_count() + if world_size > num_devices_per_host and world_size % num_devices_per_host != 0: + raise RuntimeError( + f"DeviceMesh only support homogeneous hardware, but found " + f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" + ) + device_handle.set_device(get_rank() % num_devices_per_host) + + return _get_default_group() + + def _validate_mesh(self, pg: ProcessGroup): + # validate rank uniqueness in mesh tensor + unique_mesh_values = self.mesh.unique(sorted=True) + if unique_mesh_values.numel() != self.mesh.numel(): + raise RuntimeError(f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}") + # validate size + if self.mesh.numel() > _get_group_size(pg): + raise RuntimeError( + f"DeviceMesh should not be bigger than world (group) size, but found {self.mesh.numel()} and {_get_group_size(pg)}" + ) + # validate that all calling ranks pass in the same `mesh` argument. + self_mesh = self.mesh.to(self.device_type).contiguous() + mesh_tensor = funcol.all_gather_tensor(self_mesh, gather_dim=0, group=pg) + mesh_tensor_chunked = torch.chunk(mesh_tensor, _get_group_size(pg)) + # aten.equal not supported for meta device + if self.device_type == "meta": + return + for other_rank, other_mesh in enumerate(mesh_tensor_chunked): + if not torch.equal(self_mesh, other_mesh): + raise RuntimeError( + f"DeviceMesh initialization does not allow different mesh argument:" + f"rank {get_rank()} has mesh {self_mesh} while rank {get_process_group_ranks(pg)[other_rank]}" + f"has mesh {other_mesh}!" 
+ ) + + def _init_process_groups(self, pg: ProcessGroup): + # group tag/ranks associated with each mesh dimension, each mesh dimension should + # have one sub-group per rank + dim_group_infos: List[Tuple[str, List[int]]] = [] + + if self.mesh.ndim == 1 and self.mesh.numel() == _get_group_size(pg): + # if the mesh is the same as the given group, we just append the given + # pg to the first dim groups. + dim_group_infos.append((_get_group_tag(pg), get_process_group_ranks(pg))) + else: + # create sub pgs base on the mesh argument specified + for dim in range(self.mesh.ndim): + # swap the current dim to the last dim + # then reshape to flatten out other dims + pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(-1, self.mesh.size(dim)) + # multi-dim mesh, create subgroups by looping over the pg_ranks + # for each dim and append the groups + for dim_mesh in pg_ranks_by_dim: + subgroup_ranks = dim_mesh.tolist() + # call new_group regardless of the current rank in the + # pg or not, it's required that all ranks participate + # in subgroup construction + dim_group = new_group(ranks=subgroup_ranks) + # only add to dim_groups if the current rank in the subgroup + if self.get_rank() in subgroup_ranks: + if len(dim_group_infos) > dim: + raise RuntimeError( + f"Each device mesh dimension should get only one process group, but got {self.get_rank} " + f"in {subgroup_ranks}!" + ) + dim_group_infos.append((_get_group_tag(dim_group), subgroup_ranks)) + self._dim_group_infos = dim_group_infos + + def __enter__(self) -> "DeviceMesh": + # set this mesh as the current mesh in mesh env + mesh_resources.mesh_stack.append(self) + return self + + # pyre-fixme[2]: Parameter must be annotated. + def __exit__(self, exc_type, exc_value, exc_traceback) -> None: + # pop this mesh from mesh env + mesh_resources.mesh_stack.pop() + + def __repr__(self) -> str: + return f"DeviceMesh:({self.mesh.tolist()})" + + def __hash__(self): + # ideally, we should use object id as hash, because different device mesh objects + # give different subprocess group, so different device meshes. + # in practice of sharding propagation, + # we only care about different mesh tensor (value, shape). + return self._hash + + def __eq__(self, other: object) -> bool: + if not isinstance(other, DeviceMesh): + return False + if id(self.mesh) == id(other.mesh): # short-cut eq + return True + if self.device_type != other.device_type: + return False + return self.mesh.shape == other.mesh.shape and self._flatten_mesh_list == other._flatten_mesh_list + + def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": + """ + Slice the current DeviceMesh based on the mesh_dim_name given to create a child + DeviceMesh. + + Args: + mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh + to create a child DeviceMesh for. + Returns: + A :class:`DeviceMesh` object + + Example (2 host with 4 GPUs each): + ``` + # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_name of ("dp", "tp") + mesh = DeviceMesh(device_type="cuda", + mesh=[ + [0, 1, 2, 3], + [4, 5, 6, 7] + ], + mesh_dim_names=["dp", "tp"]) + ) + ``` + Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]). + Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]). + Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]). + Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]). + Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]). 
+ Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]). + """ + if self.mesh.ndim <= 1: + raise RuntimeError(f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension.") + if self.mesh_dim_names is None: + raise KeyError( + "No `mesh_dim_names` found.", + "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.", + ) + if mesh_dim_name not in self.mesh_dim_names: + raise KeyError( + f"Mesh dimension '{mesh_dim_name}' does not exist.", + f"Available mesh dimensions are: {self.mesh_dim_names}", + ) + mesh_dim = self.mesh_dim_names.index(mesh_dim_name) + submesh = mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name) + + return submesh + + def get_dim_groups(self, mesh_dim: Optional[int] = None) -> Union[ProcessGroup, List[ProcessGroup]]: + if not hasattr(self, "_dim_group_infos"): + raise RuntimeError("DeviceMesh process groups not initialized!") + if mesh_dim is not None: + return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim]) + else: + dim_groups = [] + for mesh_dim in range(self.mesh.ndim): + dim_groups.append(_find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])) + return dim_groups + + def size(self, dim: Optional[int] = None) -> int: + return self.mesh.numel() if dim is None else self.mesh.size(dim) + + @property + def ndim(self) -> int: + return self.mesh.ndim + + @property + def ndevice(self) -> int: + return torch.numel(self.mesh) + + @property + def shape(self) -> Tuple[int, ...]: + return tuple(self.mesh.shape) + + def get_rank(self) -> int: + return get_rank() + + def get_local_rank(self, mesh_dim: Optional[int] = None) -> int: + """ + Returns the local rank of the given mesh_dim of the DeviceMesh. + + Args: + mesh_dim (int, optional): it is the index of the mesh dimension. Default is None. + + Returns: + An integer denotes the local rank. + + The following program runs on each process/rank in an SPMD manner. In this example, we have 2 + hosts with 4 GPUs each. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. + Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. + """ + if self.ndim > 1 and mesh_dim is None: + raise RuntimeError( + f"Found the DeviceMesh have {self.mesh.ndim} dimensions", + "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", + ) + elif mesh_dim is None: + mesh_dim = 0 + + mesh_dim_group = self.get_dim_groups(mesh_dim) + assert isinstance(mesh_dim_group, ProcessGroup), "We expect ProcessGroup before calling `get_rank`!" + + return get_rank(mesh_dim_group) + + def get_coordinate(self) -> Optional[List[int]]: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + """ + return self._coordinate_on_dim if self._coordinate_on_dim else None + + def enforce_cpu_mesh_tensor(self) -> None: + """ + move `mesh` tensor to cpu for deterministic device; + necessary for comparison and checkpoint loading. 
+ """ + with torch.no_grad(): + self.mesh = self.mesh.cpu() + + def get_submesh(self, mesh_dims: Union[List[int], List[str]]) -> "DeviceMesh": + dims = [] + for dim in mesh_dims: + if isinstance(dim, int): + dims.append(dim) + elif isinstance(dim, str): + assert dim in self.mesh_dim_names, f"Mesh dimension '{dim}' does not exist." + dims.append(self.mesh_dim_names.index(dim)) + return mesh_resources.create_submesh_along_multi_dims(self, dims) + + def get_all_submesh(self, dim: int or str) -> List["DeviceMesh"]: + if isinstance(dim, str): + assert dim in self.mesh_dim_names, f"Mesh dimension '{dim}' does not exist." + mesh_dim = self.mesh_dim_names.index(dim) + else: + mesh_dim = dim + return mesh_resources.create_submesh_group(self, mesh_dim) + + def get_mapping_rank(self, other: "DeviceMesh"): + """ + for cross mesh resharding + we assume that the mesh is 1,2,4,8 + the size will have gcd value + """ + mesh_list = self.mesh.view(-1).tolist() + index = mesh_list.index(self.get_rank()) + other_mesh_list = other.mesh.view(-1).tolist() + gcd_value = math.gcd(len(mesh_list), len(other_mesh_list)) + if gcd_value == 1 and len(mesh_list) != 1 and len(other_mesh_list) != 1: + raise RuntimeError(f"mesh resharding the wrong shape of device mesh {mesh_list} vs {other_mesh_list}") + + a = len(mesh_list) + b = len(other_mesh_list) + factor = max(a, b) // min(a, b) + + if a > b: # group down + data = {} + for i in range((index // factor) * factor, factor): + data.update({mesh_list[index]: other_mesh_list[index // factor]}) + return data + elif a < b: # group up + return [other_mesh_list[i] for i in range(index * factor, (index + 1) * factor)] + else: + return other_mesh_list[index] + + +def init_device_mesh( + device_type: str, + mesh_shape: Tuple[int, ...], + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, +) -> DeviceMesh: + """ + Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. + This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape) + and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is + labeled as mesh_dim_names[i]. + + + Args: + device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like. + mesh_shape: Tuple[int]: A tuple describes the dimension of the multi-dimesnion array + that describes the layout of devices. + Kwargs: + mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension + of the multi-dimensional array that describes the layout of devices. Its length must match the length + of `mesh_shape`. Each string in mesh_dim_names must be unique. + + Returns: + A :class:`DeviceMesh` object + + .. note: If no process group is found, init_device_mesh will initialize distributed process group/groups + behind the scene, which are required for distributed communications. 
+ + Example: + >>> # xdoctest: +SKIP + >>> from torch.distributed._tensor.device_mesh import init_device_mesh + >>> + >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) + >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) + """ + if mesh_dim_names is not None: + if len(set(mesh_dim_names)) != len(mesh_dim_names): + raise RuntimeError( + "Each mesh_dim_name must be uqique.", + f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", + ) + + if len(mesh_shape) != len(mesh_dim_names): + raise RuntimeError( + "mesh_shape and mesh_dim_names should have same length!", + f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", + ) + + mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape) + device_mesh = DeviceMesh( + device_type=device_type, + mesh=mesh, + mesh_dim_names=mesh_dim_names, + ) + + return device_mesh diff --git a/internlm/checkpoint/vescale/devicemesh_api.py b/internlm/checkpoint/vescale/devicemesh_api.py new file mode 100644 index 000000000..086604e31 --- /dev/null +++ b/internlm/checkpoint/vescale/devicemesh_api.py @@ -0,0 +1,475 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +import torch +import warnings +from torch.distributed import get_rank +from .device_mesh import init_device_mesh, DeviceMesh +from typing import Optional, List, Tuple, Union, Dict +from torch.distributed.distributed_c10d import ProcessGroup + +__all__ = ["VESCALE_DEVICE_MESH"] + + +class VeDeviceMesh: + _MESH_DIM_NAMES_MAPPING: Dict[int, str] = {} + _MESH_DIM_NAMES_LOOKUP: List[str] = None + _TENSOR_PARALLEL_SIZE: int = None + _DATA_PARALLEL_SIZE: int = None + _PIPELINE_PARALLEL_SIZE: int = None + _DATA_PARALLEL_GROUP: ProcessGroup = None + _TENSOR_PARALLEL_GROUP: ProcessGroup = None + _GLOBAL_MESH: DeviceMesh = None + _MESH_GRIDS: torch.Tensor = None + _DATA_PARALLEL_MESH: DeviceMesh = None + _TENSOR_PARALLEL_MESH: DeviceMesh = None + _GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES: List[DeviceMesh] = None + _GLOBAL_TENSOR_PARALLEL_MESHES: List[DeviceMesh] = None + _RANK_COORDINATE: List[int] = None + DEFAULT_DEVICE_COUNT: int = ( + torch.cuda.device_count() if torch.cuda.is_available() else 8 + ) # enables 8 ranks for CPU multi-processing + PP_DIM: int = 0 + + def init_device_mesh( + self, + device_type: str, + mesh_shape: Tuple[int, ...], + *, + mesh_dim_names: Optional[Tuple[str, ...]] = None, + check_uniqueness: bool = False, + ) -> DeviceMesh: + """Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. + This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape) + and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is + labeled as mesh_dim_names[i]. 
Inherit this utility from upstream DeviceMesh.
+
+        Syntax of the (global) DeviceMesh created by our API:
+        Dimensions follow a left-to-right, inter-instance to intra-instance fashion, i.e.
+        1. Dimensions of a 3-dimensional global DeviceMesh: [PIPELINE_PARALLEL_DIM, DATA_PARALLEL_DIM, TENSOR_PARALLEL_DIM]
+            - When PIPELINE_PARALLEL_DIM > 1, the DeviceMesh is still written as 3-dimensional even if
+              DATA_PARALLEL_DIM = 1, TENSOR_PARALLEL_DIM = 1, or both equal 1.
+        2. Dimensions of a 2-dimensional global DeviceMesh: [DATA_PARALLEL_DIM, TENSOR_PARALLEL_DIM]
+        3. Dimensions of a 1-dimensional global DeviceMesh: [DATA_PARALLEL_DIM or TENSOR_PARALLEL_DIM]
+            - A 1-dimensional DeviceMesh can be used to specify the process groups of the data parallel and
+              tensor model parallel dimensions.
+
+        Args:
+            device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like.
+            mesh_shape: Tuple[int]: A tuple that describes the dimensions of the multi-dimensional array
+                that describes the layout of devices.
+        Kwargs:
+            mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension
+                of the multi-dimensional array that describes the layout of devices. Its length must match the length
+                of `mesh_shape`. Each string in mesh_dim_names must be unique. Note that if mesh_dim_names is None,
+                the function will provide default mesh identifiers.
+
+            check_uniqueness (bool): This advanced argument is used to prevent users from spoiling the global
+                DeviceMesh API by creating multiple copies in a large code repository.
+                Set to True to let the VESCALE_DEVICE_MESH API check that the "global device mesh" is only initialized once.
+                Otherwise, users can create as many DeviceMeshes as they want, just like with the upstream DeviceMesh.
+
+        Returns:
+            A :class:`DeviceMesh` object
+
+        .. note: If no process group is found, init_device_mesh will initialize distributed process group/groups
+        behind the scene, which are required for distributed communications.
+
+        Example:
+            >>> # xdoctest: +SKIP
+            >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH
+            >>>
+            >>> # Example 1: initialize the global DeviceMesh as a one-dimensional DeviceMesh
+            >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(8,))
+            >>>
+            >>> # Example 2: re-initialize the global DeviceMesh as a two-dimensional DeviceMesh
+            >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp"))
+
+        Limitation: we currently only support a fixed-size DeviceMesh with 1 to 3 dimensions. We will loosen this constraint in the future.
+        """
+        if device_type.startswith("cuda") and device_type != "cuda":
+            warnings.warn("'cuda:<index>' is not supported; converting to plain 'cuda'.")
+            device_type = "cuda"
+        assert device_type in ("cuda", "cpu", "meta"), "Supports only three device types: cuda, cpu, meta!"
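+        # `device_type` is now one of "cuda"/"cpu"/"meta"; the global mesh is (re)initialized below
+        # unless check_uniqueness is set and a global mesh already exists, in which case an error is raised.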
+ if self._GLOBAL_MESH is None or not check_uniqueness: + self._TENSOR_PARALLEL_SIZE = self._DATA_PARALLEL_SIZE = self._PIPELINE_PARALLEL_SIZE = None + self._MESH_DIM_NAMES_MAPPING = {} + if mesh_dim_names is None: + # Support two default sets of default mesh dimensions: 2-dim [dp, tp], and 3-dim [pp, dp, tp] + mesh_dim_names = ["PP", "DP", "TP"][-len(mesh_shape) :] + if device_type is None: + device_type = "cuda" + self._GLOBAL_MESH = init_device_mesh(device_type, mesh_shape, mesh_dim_names=mesh_dim_names) + self._MESH_GRIDS = self._GLOBAL_MESH.mesh.clone().detach().cpu() + if len(mesh_shape) == 3: + self._PIPELINE_PARALLEL_SIZE, self._DATA_PARALLEL_SIZE, self._TENSOR_PARALLEL_SIZE = mesh_shape + elif len(mesh_shape) == 2: + self._DATA_PARALLEL_SIZE, self._TENSOR_PARALLEL_SIZE = mesh_shape + else: + self._DATA_PARALLEL_SIZE = self._TENSOR_PARALLEL_SIZE = mesh_shape[0] + for idx, name in enumerate(mesh_dim_names[::-1]): + self._MESH_DIM_NAMES_MAPPING[idx] = name + self._MESH_DIM_NAMES_LOOKUP = list(self._MESH_DIM_NAMES_MAPPING.values())[::-1] + self._RANK_COORDINATE = None + self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES = None + self._GLOBAL_TENSOR_PARALLEL_MESHES = None + elif check_uniqueness: + raise ValueError( + "Already initialized the global DeviceMesh! Turn 'check_uniqueness' off to remove the contraint." + ) + return self._GLOBAL_MESH + + def get( + self, + **kwargs, + ) -> Optional[DeviceMesh]: + """ + Retrieves the global device mesh. If it has not been initialized, pass in + arguments to initialize one. + + Args: + **kwargs (dict): arguments to initialize the global device mesh. + + Returns: + A :class:`DeviceMesh` object + """ + if self._GLOBAL_MESH is None and kwargs: + self.init_device_mesh(**kwargs) + return self._GLOBAL_MESH + + def _get_tensor_parallel_mesh(self) -> DeviceMesh: + """ + This function works the same as get_tensor_parallel_mesh(), but + specifies _validate_mesh=False. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + if self._TENSOR_PARALLEL_MESH is None: + assert self._TENSOR_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized" + assert self._MESH_DIM_NAMES_MAPPING + tensor_dim_name = self._MESH_DIM_NAMES_MAPPING[0] + TP_mesh = self.get()[tensor_dim_name] + self._TENSOR_PARALLEL_MESH = DeviceMesh( + device_type=TP_mesh.device_type, + mesh=TP_mesh.mesh, + pg=self._TENSOR_PARALLEL_GROUP, + _validate_mesh=False, + ) + return self._TENSOR_PARALLEL_MESH + + def _get_data_parallel_mesh(self) -> DeviceMesh: + """ + This function works the same as get_data_parallel_mesh(), but + specifies _validate_mesh=False. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + if self._DATA_PARALLEL_MESH is None: + assert self._DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" + assert len(self._MESH_DIM_NAMES_MAPPING) >= 2 + data_dim_name = self._MESH_DIM_NAMES_MAPPING[1] + DP_mesh = self.get()[data_dim_name] + self._DATA_PARALLEL_MESH = DeviceMesh( + device_type=DP_mesh.device_type, mesh=DP_mesh.mesh, pg=self._DATA_PARALLEL_GROUP, _validate_mesh=False + ) + return self._DATA_PARALLEL_MESH + + def get_strategy_coordinate(self, local_rank=None) -> List[int]: + """ + Translate current local rank to a strategy coordinate of initialized strategy dimensions. + If local_rank is not provided, return coordinate of current rank. + The only difference of this function w.r.t. upstream DeviceMesh's get_coordinate() is that + it enables users query strategy coordinate of arbitrary ranks. 
+ + Args: + local_rank (int): rank id. If local_rank is None, return the coordinate of the local rank. + + Returns: + Coordinate of local rank mapped to the global DeviceMesh's parallel dimensions. + + Example: + >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH + >>> dp_size, tp_size = 2, 2 + >>> # Initialize global device mesh of (dp_size=2, tp_size=2) + >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP")) + >>> local_rank = torch.distributed.get_rank() # local_rank is 0 + 0 + >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank) + [0, 0] + >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(3) + [1, 1] + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + if local_rank is None: + if self._RANK_COORDINATE is None: + self._RANK_COORDINATE = self.get_strategy_coordinate(self.get_local_rank()) + return self._RANK_COORDINATE + rank_coordinate = [int(item) for item in (self._MESH_GRIDS == local_rank).nonzero(as_tuple=True)] + return rank_coordinate + + def lookup_rank(self, dim: Union[int, str]) -> int: + """ + Look up the specified 'id' from a particular dimension of the + current rank's strategy coordinate. + + Args: + dim (Union[int, str]): Dimension indicator. + + Returns: + Specified parallel strategy 'rank' of a global rank. + + Example: + >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH + >>> dp_size, tp_size = 2, 2 + >>> # Initialize global device mesh of (dp_size=2, tp_size=2) + >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP")) + >>> local_rank = torch.distributed.get_rank() # local_rank = 0 + 0 + >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank) + [0, 0] + >>> index = 1 + >>> VESCALE_DEVICE_MESH.lookup_rank(index) # local_rank is 0 + 0 + >>> dim_name = "DP" + >>> VESCALE_DEVICE_MESH.lookup_rank(dim_name) # local_rank is 0 + 0 + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + if isinstance(dim, int): + assert 0 <= dim < len(self._MESH_DIM_NAMES_MAPPING) + else: + assert dim in self._MESH_DIM_NAMES_MAPPING.values() + if self._RANK_COORDINATE is None: + self.get_strategy_coordinate() + if isinstance(dim, str): + index = self._MESH_DIM_NAMES_LOOKUP.index(dim) + return self._RANK_COORDINATE[index] + else: + return self._RANK_COORDINATE[dim] + + def get_strategy_size(self, dim: Union[int, str]) -> List[int]: + """ + Return the size of a parallel strategy dimension of the global DeviceMesh. + + Args: + dim (Union[int, str]): Dimension indicator. + + Returns: + Size of a strategt dimension. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + if isinstance(dim, int): + assert 0 <= dim < len(self._MESH_DIM_NAMES_MAPPING) + else: + assert dim in self._MESH_DIM_NAMES_MAPPING.values() + if isinstance(dim, str): + index = self._MESH_DIM_NAMES_LOOKUP.index(dim) + return self.size(index) + else: + return self.size(dim) + + def get_local_rank(self) -> int: + """ + Get rank ID based on this machine. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + local_device_count = torch.cuda.device_count() if torch.cuda.is_available() else self.DEFAULT_DEVICE_COUNT + return get_rank() % local_device_count + + def get_pipeline_parallel_rank(self) -> int: + """ + Get pipeline parallel rank (stage id) of local rank id. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" 
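+        # Only a 3-D global mesh ([PP, DP, TP]) has a pipeline dimension; for 1-D and 2-D
+        # meshes every rank is treated as pipeline stage 0.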
+ num_dims = len(self._MESH_DIM_NAMES_MAPPING) + assert num_dims <= 3 + if len(self._MESH_DIM_NAMES_MAPPING) == 3: + pipe_dim_name = self._MESH_DIM_NAMES_MAPPING[2] + return self.lookup_rank(pipe_dim_name) + else: + return 0 + + def get_data_parallel_rank(self) -> int: + """ + Get data parallel rank (stage id) of local rank id. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + assert len(self._MESH_DIM_NAMES_MAPPING) >= 2 + if len(self._MESH_DIM_NAMES_MAPPING) > 1: + data_dim_name = self._MESH_DIM_NAMES_MAPPING[1] + else: + data_dim_name = self._MESH_DIM_NAMES_MAPPING[0] + return self.lookup_rank(data_dim_name) + + def get_tensor_parallel_rank(self) -> int: + """ + Get tensor parallel rank (stage id) of local rank id. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + assert self._MESH_DIM_NAMES_MAPPING + tensor_dim_name = self._MESH_DIM_NAMES_MAPPING[0] + return self.lookup_rank(tensor_dim_name) + + def get_pipeline_parallel_mesh(self) -> DeviceMesh: + """ + Return the pipeline parallel view of the global DeviceMesh. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + assert len(self._MESH_DIM_NAMES_MAPPING) == 3 + pipe_dim_name = self._MESH_DIM_NAMES_MAPPING[0] + return self.get()[pipe_dim_name] + + def get_global_pipeline_parallel_meshes(self, device_type="cuda") -> list: + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + if self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES is None: + meshes = [] + device_mesh = self.get() + for inner_group in device_mesh.mesh.tolist(): + meshes.append(DeviceMesh(device_type, inner_group, _validate_mesh=False)) + self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES = meshes + return self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES + + def get_data_parallel_mesh(self) -> DeviceMesh: # noqa: F811 + """ + Return the data parallel view of the global DeviceMesh. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + assert self._MESH_DIM_NAMES_MAPPING + dp_name = self._MESH_DIM_NAMES_MAPPING[1] if self.ndim > 1 else self._MESH_DIM_NAMES_MAPPING[0] + return self.get()[dp_name] + + def get_tensor_parallel_mesh(self) -> DeviceMesh: + """ + Return the tensor parallel view of the global DeviceMesh. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + assert self._MESH_DIM_NAMES_MAPPING + tp_name = self._MESH_DIM_NAMES_MAPPING[0] + return self.get()[tp_name] + + def get_global_tensor_parallel_meshes(self) -> list: + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + if self._GLOBAL_TENSOR_PARALLEL_MESHES is None: + assert len(self._MESH_DIM_NAMES_LOOKUP) == 3 + tp_meshes = [] + global_dm = self.get() + device_type = self.get_tensor_parallel_mesh().device_type + all_tp_list = global_dm.mesh.view(-1, global_dm.mesh.size(2)) + for tp_group in all_tp_list: + tp_mesh = DeviceMesh( + device_type, + tp_group, + _validate_mesh=False, + _init_process_groups=False, + ) + tp_meshes.append(tp_mesh) + self._GLOBAL_TENSOR_PARALLEL_MESHES = tp_meshes + return self._GLOBAL_TENSOR_PARALLEL_MESHES + + def is_first_stage(self) -> bool: + """ + Return if the current stage is the first stage, if using pipeline parallelism. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + pp_rank = self.get_pipeline_parallel_rank() + return pp_rank == 0 + + def is_last_stage(self) -> bool: + """ + Return if the current stage is the last stage, if using pipeline parallelism. 
+ """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + assert len(self._MESH_DIM_NAMES_MAPPING) == 3 + device_mesh = self.get() + num_stages = device_mesh.size(self.PP_DIM) + pp_rank = self.get_pipeline_parallel_rank() + return pp_rank == num_stages - 1 + + def __getitem__(self, mesh_dim_name: str) -> DeviceMesh: + """ + Slice the current DeviceMesh based on the mesh_dim_name given to create a child + DeviceMesh. Inherit this utility from upstream DeviceMesh. + + Args: + mesh_dim_name (str): mesh dimension name. + + Returns: + a dimension "view" of the global DeviceMesh. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + device_mesh = self.get() + return device_mesh[mesh_dim_name] + + def get_data_parallel_dim_groups(self) -> ProcessGroup: + """ + Match process groups of data parallel dimension given + sizes of DeviceMesh. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + device_mesh = self.get() + dim_size = len(device_mesh.mesh.shape) + assert 1 <= dim_size <= 3 + if dim_size <= 2: + return device_mesh.get_dim_groups(0) + return device_mesh.get_dim_groups(1) + + def get_tensor_parallel_dim_groups(self) -> ProcessGroup: + """ + Return process group of the lowest dimension as + the dimension of tensor model parallelism. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + device_mesh = self.get() + assert 1 <= len(device_mesh.mesh.shape) <= 3 + return device_mesh.get_dim_groups(0) + + def get_coordinate(self) -> Optional[List[int]]: + """ + Return the relative indices of this rank relative to all + dimensions of the mesh. If this rank is not part of the mesh, return None. + Inherit this utility from upstream DeviceMesh. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + device_mesh = self.get() + return device_mesh.get_coordinate() + + def size(self, dim: Optional[int] = None) -> int: + """ + Returns dimension size of DeviceMesh along 'dim' dimension. If dim is None, + return the total number of ranks in this DeviceMesh. + + Args: + dim (int): dimension index + + Returns: + Dimension size, or total number of ranks if None. + """ + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + device_mesh = self.get() + return device_mesh.mesh.numel() if dim is None else device_mesh.mesh.size(dim) + + @property + def ndim(self) -> int: + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + device_mesh = self.get() + return device_mesh.mesh.ndim + + @property + def shape(self) -> Tuple[int, ...]: + assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" + device_mesh = self.get() + return tuple(device_mesh.mesh.shape) + + +VESCALE_DEVICE_MESH = VeDeviceMesh() diff --git a/internlm/checkpoint/vescale/distributed_optimizer.py b/internlm/checkpoint/vescale/distributed_optimizer.py new file mode 100644 index 000000000..4f861a7f9 --- /dev/null +++ b/internlm/checkpoint/vescale/distributed_optimizer.py @@ -0,0 +1,109 @@ +import math +import inspect +from dataclasses import dataclass +from typing import Dict, Sequence, Tuple, Optional, Any +import torch +import torch.distributed as dist + + + +class Range: + """ + A range represents a start and end points for indexing a shard + from a full tensor. 
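+    e.g. Range(0, 4) covers elements [0, 4) of the flattened tensor, so len(Range(0, 4)) == 4.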
+    """
+
+    def __init__(self, start, end):
+        self.start = start
+        self.end = end
+        self.size = end - start
+
+    def normalize(self, start=0):
+        return Range(start, start + self.size)
+
+    def __str__(self):
+        return "%d,%d [%d]" % (self.start, self.end, self.size)
+
+    def __len__(self):
+        return self.end - self.start
+
+    def __repr__(self) -> str:
+        return "Range(%d,%d [%d])" % (self.start, self.end, self.size)
+
+@dataclass
+class OptimizerStateSpec:
+    """This class represents the mapping between a local flattened 1D tensor
+    and the global original DTensor in DOptimizer. It is used for
+    loading or saving optimizer states via vescale.checkpoint (PyTorch DCP)
+    and for load-time checkpoint resharding when changing tp size or dp size.
+
+    For example, a linear layer in veScale is DTensor(size=[1024, 1024]).
+    It is first divided into two parts along dim=0 with tensor parallel size = 2:
+
+    tensor_part_0 = DTensor(size=[512, 1024])
+    tensor_part_1 = DTensor(size=[512, 1024])
+
+    Then each part's optimizer states are initialized in DOptimizer separately.
+
+    Assume dp=2.
+    For the process with dp=0 tp=0, the flattened tensor is torch.Tensor(size=[262144]):
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(0, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+
+    For the process with dp=1 tp=0, the flattened tensor is torch.Tensor(size=[262144]):
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(256, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+
+    For the process with dp=0 tp=1, the flattened tensor is torch.Tensor(size=[262144]),
+    mapping to [512:768, 0:1024] in the original DTensor:
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(512, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+
+    For the process with dp=1 tp=1, the flattened tensor is torch.Tensor(size=[262144]):
+    global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(768, 0) local=torch.Tensor(size=[262144]).view(local_shape)
+    """
+
+    # The original DTensor shape
+    global_shape: Tuple[int]
+    # The local tensor shape ***before being flattened into a 1D tensor***
+    local_shape: Tuple[int]
+    # The local tensor's offset with respect to the original DTensor
+    global_offset: Tuple[int]
+    # The unflattened local tensor, i.e. the view created with local_shape on the flattened 1D tensor in DOptimizer.
+    # NOTE: In order to support TP resharding and states sharded across dp ranks, we defer the reshaping from 1D to
+    # local_shape until the saving plan is generated by vescale.checkpoint (PyTorch DCP).
+    local_tensor: torch.Tensor
+    # If the current optimizer state is sharded by multiple dp ranks,
+    # we record all ranks and their ranges.
+    dp_ranks_ranges: Optional[Dict[int, Range]]
+
+def convert_dict_with_sharded(
+    param_state: dict,
+    global_shape: Tuple[int],
+    local_shape: Tuple[int],
+    global_offset: Tuple[int],
+    dp_ranks_ranges: Optional[Dict[int, Range]],
+):
+    new_param_state = {}
+    for k, v in param_state.items():
+        if isinstance(v, torch.Tensor) and v.dim() >= 1:
+            # Don't unflatten the tensor here, see the comments above.
+            if not dp_ranks_ranges:
+                if math.prod(local_shape) != math.prod(v.shape):
+                    print(
+                        f"rank={dist.get_rank()} name={k} global shape={global_shape} "
+                        f"local_shape={local_shape} global_offset={global_offset} real shape={v.shape}"
+                    )
+                    raise AssertionError()
+            new_param_state[k] = OptimizerStateSpec(
+                global_shape, local_shape, global_offset, v, dp_ranks_ranges
+            )  # , process_group)
+        else:
+            new_param_state[k] = v
+    return new_param_state
+
+def convert_dict_sharded_to_tensor(param_state: dict, range_1d:
Optional[Range]): + for k, v in param_state.items(): + if isinstance(v, OptimizerStateSpec): + # If the state is distributed on multiple dp ranks + # Get my parts + if range_1d: + param_state[k] = v.local_tensor.flatten()[range_1d.start : range_1d.end] + else: + param_state[k] = v.local_tensor.flatten() + return param_state \ No newline at end of file diff --git a/internlm/checkpoint/vescale/filesystem.py b/internlm/checkpoint/vescale/filesystem.py new file mode 100644 index 000000000..5fd8030a1 --- /dev/null +++ b/internlm/checkpoint/vescale/filesystem.py @@ -0,0 +1,905 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ +from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor +from .mem_checkpoint import copy_gpu_tensor_to_cpu_pinned_mem_pool, deallocate_cpu_tensor_in_pinned_mem_pool +from abc import ABC, abstractmethod +import collections +from dataclasses import dataclass +import os +import dataclasses +import io +import torch.distributed as dist +import pickle +from typing import List, Tuple, Union, Dict, cast, Any +from internlm.utils.logger import get_logger +import time +import torch +from torch import Tensor +from torch.futures import Future +from pathlib import Path +from internlm.core.context import global_context as gpc +from internlm.core.context import ParallelMode +from internlm.train.pipeline import map_fqn_global_to_local + + +from torch.distributed.checkpoint.metadata import ( + Metadata, + MetadataIndex, +) +from torch.distributed.checkpoint.storage import ( + StorageReader, + StorageWriter, + WriteResult, +) + +from torch.distributed.checkpoint.planner import ( + LoadItemType, + LoadPlanner, + LoadPlan, + SavePlan, + SavePlanner, + WriteItem, + ReadItem, + WriteItemType, +) + +from torch.distributed.checkpoint.utils import _create_file_view + +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch._utils import _get_device_module + +logger = get_logger(__file__) +from .common import P2PTensorsInfo + +__all__ = [ + "FileSystemWriter", + "FileSystemReader", +] + + +@dataclass +class _StorageInfo: + """ + This is the per entry storage info + """ + + relative_path: str + offset: int + length: int + + +@dataclass +class _StoragePrefix: + prefix: str + + +DEFAULT_SUFFIX = ".distcp" + + +def _trim(tensor: torch.Tensor) -> torch.Tensor: + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor.detach()) + # Comment the original DCP code + # When dumping to pinned memory, + # the memory layout for tensor has been contiguous + # if tensor._typed_storage()._size() != tensor.numel(): + # tensor = tensor.clone() + return tensor + + +def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult: + return WriteResult(index=item.index, size_in_bytes=size_in_bytes, storage_data=storage_data) + + +class _TensorLoader(ABC): + @abstractmethod + def add(self, fqn, size, obj): + pass + + @abstractmethod + def start_loading(self): + pass + + @abstractmethod + def values(self): + pass + + +def collect_optim_state_across_dp_ranks( + tensor: torch.Tensor, 
rank_ranges: Dict[int, Any], p2p_reqs: Dict[int, Any] +) -> torch.Tensor: + orignal_shape = tensor.shape + tensor = tensor.flatten() + logger.debug("DEBUG: Start receiving p2p tensor") + recv_start = time.time() + for req in p2p_reqs: + req.wait() + recv_end = time.time() - recv_start + logger.debug(f"DEBUG: Finish receiving p2p tensor. Time cost: {recv_end}s") + for v in rank_ranges.values(): + received_tensor, param_range = v + tensor[param_range.start : param_range.end] = received_tensor + tensor = tensor.reshape(orignal_shape) + return tensor + + +class _SerialCpuLoader(_TensorLoader): + def __init__(self, resolve_fun, p2p_tensors_info: P2PTensorsInfo = None): + self.resolve_fun = resolve_fun + self.items = [] + self.p2p_tensors_info = p2p_tensors_info + + def add(self, fqn, size, obj): + self.items.append((fqn, size, obj)) + + def start_loading(self): + pass + + def values(self): + for fqn, _, obj in self.items: + tensor = self.resolve_fun(obj).detach() + if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], + ) + elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], p2p_reqs=self.recv_p2p_reqs[fqn] + ) + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor) + # Comment the original DCP code + # When dumping to pinned memory, + # the memory layout for tensor has been contiguous + # if tensor.storage().size() != tensor.numel(): + # tensor = tensor.clone() + yield ( + tensor, + obj, + ) + + +class _OverlappingCpuLoader(_TensorLoader): + def __init__( + self, + resolve_fun, + p2p_tensors_info: P2PTensorsInfo = None, + stream=None, + inflight_threshhold=1_000_000, + ): + self.resolve_fun = resolve_fun + self.items = [] + self.inflight_threshhold = inflight_threshhold + self.in_flight_data = 0 + self.current_items: collections.deque = collections.deque() + self.idx = 0 + self.started = False + self.device_type = stream.device_type if stream else torch.device("cuda").type + self.device_module = _get_device_module(self.device_type) + self.p2p_tensors_info = p2p_tensors_info + self.stream = stream or self.device_module.current_stream() + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) + + @property + def _done(self): + return self.idx >= len(self.items) + + def _drain(self): + drained = [] + if self.in_flight_data >= self.inflight_threshhold: + self.stream.synchronize() + while self.in_flight_data >= self.inflight_threshhold: + val = self.current_items.popleft() + self.in_flight_data -= val[0].numel() * val[0].element_size() + drained.append(val) + return drained + + def _refill(self): + with self.device_module.stream(self.stream): + while not self._done and self.in_flight_data < self.inflight_threshhold: + fqn, _, obj = self.items[self.idx] + self.idx += 1 + tensor = self.resolve_fun(obj).detach() + if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, 
obj.index.offset)], + ) + elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], + ) + if tensor.device.type == self.device_type: + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) + # Comment the original DCP code + # When dumping to pinned memory, the memory layout for tensor has been contiguous + # elif tensor.device == torch.device("cpu"): + # if tensor.storage().size() != tensor.numel(): + # # this forces the tensor to be both contiguous and with minimal storage + # tensor = tensor.clone() + + self.current_items.append( + ( + tensor, + obj, + ) + ) + self.in_flight_data += tensor.numel() * tensor.element_size() + + def _finish(self): + assert self._done + if len(self.current_items) > 0: + self.stream.synchronize() + return self.current_items + + def add(self, fqn, size, obj): + if self.started: + raise RuntimeError("cannot add items after loading started") + self.items.append((fqn, size, obj)) + + def start_loading(self): + if self.started: + return + self.started = True + self.items.sort(key=lambda x: x[1]) + self._refill() + + def values(self): + self.start_loading() + while not self._done: + drained = self._drain() + self._refill() + yield from drained + + yield from self._finish() + + +def _item_fqn(item: WriteItem) -> str: + return item.index.fqn + + +def _item_size(item: WriteItem) -> int: + size = 1 + assert item.tensor_data is not None + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.size: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]: + if bins == 1: + return [items] + + bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] + + buckets: List[List[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + tensor_w.sort(key=_item_size, reverse=True) + + for i, wi in enumerate(bytes_w): + buckets[i % bins].append(wi) + + for wi in tensor_w: + idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0] + buckets[idx].append(wi) + bucket_sizes[idx] += _item_size(wi) + + return buckets + + +def _write_item(stream, data, write_item, storage_key): + offset = stream.tell() + + if write_item.type == WriteItemType.BYTE_IO: + assert isinstance(data, io.BytesIO) + stream.write(data.getbuffer()) + else: + assert isinstance(data, torch.Tensor) + assert data.device == torch.device("cpu") + torch.save(data, stream) + length = stream.tell() - offset + + return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) + + +def _write_files_from_queue( + file_name, + storage_key, + write_items, + planner: SavePlanner, + inflight_threshhold: int, + use_fsync: bool, + p2p_tensors_info: P2PTensorsInfo = None, +): + loader: _TensorLoader + + if torch.cuda.is_available() and inflight_threshhold > 0: + loader = _OverlappingCpuLoader( + lambda x: planner.resolve_data(x), + inflight_threshhold=inflight_threshhold, + p2p_tensors_info=p2p_tensors_info, + ) + else: + loader = _SerialCpuLoader(lambda x: planner.resolve_data(x), p2p_tensors_info=p2p_tensors_info) + + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + for write_item in tensor_w: + 
loader.add(_item_fqn(write_item), _item_size(write_item), write_item) + loader.start_loading() + + bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + if len(bytes_w) != 0: + assert False + write_results = [] + + stream = open(file_name, "wb") + logger.debug("Start writing byte io data.") + byte_io_write_start = time.time() + for write_item in bytes_w: + data = planner.resolve_data(write_item) + write_results.append(_write_item(stream, data, write_item, storage_key)) + byte_io_write_time = time.time() - byte_io_write_start + logger.debug(f"Finish writing byte io data. Time cost: {byte_io_write_time}s") + + logger.debug("Start writing tensor data.") + tensor_write_start = time.time() + for tensor, write_item in loader.values(): + assert tensor.is_cpu + write_results.append(_write_item(stream, tensor, write_item, storage_key)) + # WARNING: Call deallocate_cpu_tensor_in_pinned_mem_pooltensor + # when the reference to CPU tensor goes to zero + # so the memory pool will reuse the memory if possbile + # Othterwise, the memory pool will allocate memory on the used memory range, + # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + tensor_write_time = time.time() - tensor_write_start + logger.debug(f"Finish writing tensor data. Time cost: {tensor_write_time}s") + + if use_fsync: + os.fsync(stream.fileno()) + + file_stream_close_start = time.time() + stream.close() + file_stream_close_time = time.time() - file_stream_close_start + logger.debug(f"Finish closing file stream. Time cost: {file_stream_close_time}s") + return write_results + + +def _write_files_per_proc( + file_path: Path, + storage_key: str, + byte_data_item: List[Tuple[io.BytesIO, WriteItem]], + tensor_data_item: List[Tuple[torch.Tensor, WriteItem]], + use_fsync: bool, +) -> List[WriteResult]: + write_results = [] + stream = open(file_path, "wb") + # First write byte data. + for write_data, write_item in byte_data_item: + write_results.append(_write_item(stream, write_data, write_item, storage_key)) + # Then write tensor data. + # NOTE: the pinned memory occupied by each tensor have been reallocated. + for write_data, write_item in tensor_data_item: + write_results.append(_write_item(stream, write_data, write_item, storage_key)) + + if use_fsync: + os.fsync(stream.fileno()) + + return write_results + + +def _serialize_tensor(tensor: torch.Tensor) -> bytes: + bio = io.BytesIO() + # NOTE: currently use torch.save() to do the serialization. + torch.save(tensor, bio) + return bio.getbuffer() + + +def _write_to_file(stream, content: bytes, write_item: WriteItem, storage_key: str) -> WriteResult: + offset = stream.tell() + stream.write(content) + length = stream.tell() - offset + return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) + + +def _write_files_per_proc_pipe( + file_path: Path, + storage_key: str, + byte_data_item: List[Tuple[io.BytesIO, WriteItem]], + tensor_data_item: List[Tuple[torch.Tensor, WriteItem]], + use_fsync: bool, +) -> List[WriteResult]: + write_futures = [] + write_results = [] + stream = open(file_path, "wb") + executor = ThreadPoolExecutor(max_workers=1) + # For byte data, directly write byte data. 
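+    # Each write is submitted to the single-worker executor so that tensor serialization,
+    # which is done below in this process, can overlap with the file writes.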
+ for write_data, write_item in byte_data_item: + content = write_data.getbuffer() + write_futures.append( + executor.submit( + _write_to_file, + stream, + content, + write_item, + storage_key, + ) + ) + # write_results.append(_write_to_file(stream, content, write_item, storage_key)) + # For tensor data, perform serialization in process then do saving in threadpool. + for write_data, write_item in tensor_data_item: + content = _serialize_tensor(write_data) + write_futures.append( + executor.submit( + _write_to_file, + stream, + content, + write_item, + storage_key, + ) + ) + # write_results.append(_write_to_file(stream, content, write_item, storage_key)) + + for fut in write_futures: + write_results.append(fut.result()) + if use_fsync: + os.fsync(stream.fileno()) + executor.shutdown(wait=False) + return write_results + + +def stat_analysis(tasks, planner, p2p_tensors_info, use_fsync=True) -> List[WriteResult]: + """ + Analyzing the overhead of D2H transfer, serialization, and save operations. Assume that + all items are written into one file. + """ + # Step1, aysnc D2H, dumping objects to pinned share memory. + assert len(tasks) == 1, "please generate one write task for analysis" + loader = _SerialCpuLoader(lambda x: planner.resolve_data(x), p2p_tensors_info=p2p_tensors_info) + # Add Bytes. + byte_item_to_write = [] + for task in tasks: + _, _, write_items = task + byte_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + byte_item_to_write.extend(byte_w) + if len(byte_item_to_write) != 0: + assert False + # Add tenosrs. + tensor_item_to_write = [] + for task in tasks: + _, _, write_items = task + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + tensor_item_to_write.extend(tensor_w) + for write_item in tensor_w: + loader.add(_item_fqn(write_item), _item_size(write_item), write_item) + loader.start_loading() + # Step1: dump to pinned memory pool. + d2h_dump_wait_start = time.time() + tensor_to_serialize: List[torch.Tensor] = [] + for tensor, write_item in loader.values(): + assert tensor.is_cpu + tensor_to_serialize.append(tensor) + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + d2h_dump_wait_time = torch.tensor(time.time() - d2h_dump_wait_start).cuda() + dist.all_reduce(d2h_dump_wait_time) + d2h_dump_wait_time = d2h_dump_wait_time.item() / dist.get_world_size() + if dist.get_rank() == 0: + logger.critical(f"End waiting for D2H tensors dumping Time: {d2h_dump_wait_time:.4f}s") + # Step2: call serialization workers to serialize objects. + serialize_wait_start = time.time() + tensor_data_to_write = [] + bio = io.BytesIO() + for tensor in tensor_to_serialize: + bio.seek(0) + bio.truncate(0) + torch.save(tensor, bio) + dump_b = bio.getvalue() + assert isinstance(dump_b, bytes) + tensor_data_to_write.append(dump_b) + serialize_wait_time = torch.tensor(time.time() - serialize_wait_start).cuda() + dist.all_reduce(serialize_wait_time) + serialize_wait_time = serialize_wait_time.item() / dist.get_world_size() + if dist.get_rank() == 0: + logger.critical(f"End waiting for serialization Time: {serialize_wait_time:.4f}s") + # Step3: save/upload the objects from memory to disk. 
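+    # As in the previous steps, the elapsed time is all-reduced and averaged across ranks
+    # so that rank 0 logs a representative saving/uploading latency.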
+ file_path = tasks[0][0] + storage_key = tasks[0][1] + write_results = [] + assert isinstance(file_path, Path) + save_upload_wait_start = time.time() + with open(file_path, "wb") as stream: + for write_item in byte_item_to_write: + offset = stream.tell() + data = planner.resolve_data(write_item) + stream.write(data.getbuffer()) + length = stream.tell() - offset + write_results.append(_result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length))) + for tensor_data, write_item in zip(tensor_data_to_write, tensor_item_to_write): + offset = stream.tell() + stream.write(tensor_data) + length = stream.tell() - offset + write_results.append(_result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length))) + if use_fsync: + os.fsync(stream.fileno()) + save_upload_wait_time = torch.tensor(time.time() - save_upload_wait_start).cuda() + dist.all_reduce(save_upload_wait_time) + save_upload_wait_time = save_upload_wait_time.item() / dist.get_world_size() + if dist.get_rank() == 0: + logger.critical(f"End waiting for tensors saving/uploading Time: {save_upload_wait_time:.4f}s") + return write_results + + +class FileSystemWriter(StorageWriter): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + worker_count: int = 1, + per_process_copy_ahead: int = 10_000_000, + ) -> None: + """ + Initialize the writer pointing to `path` + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + worker_count: Number of IO workers (processes) to use to write. Default to 1. + per_process_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. + """ + super().__init__() + self.path = Path(path) + self.single_file_per_rank = single_file_per_rank + self.sync_files = sync_files + self.worker_count = worker_count + self.per_process_copy_ahead = per_process_copy_ahead + + def set_up_storage_writer(self, is_coordinator: bool) -> None: + pass + + def prepare_local_plan(self, plan: SavePlan, p2p_tensors_info: P2PTensorsInfo = None) -> SavePlan: + self.path.mkdir(parents=True, exist_ok=True) + self.p2p_tensors_info = p2p_tensors_info + return plan + + def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]: + new_plans = [ + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan) + ] + return new_plans + + def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], planner: SavePlanner): + """ + First stage of saving, Perform Copy data to CPU (D2H). + + Args: + tasks: partitoned tasks for workers to conduct serialization and the actual saving. + planner: save planner used to resolve the bytes and tensor data. + async_io: whether do asynchrous D2H. + + NOTE: Currently we do D2H synchronously. 
+ """ + + byte_data_item_writes: List[List[Tuple[io.BytesIO, WriteItem]]] = [] + tensor_data_item_writes: List[List[Tuple[torch.Tensor, WriteItem]]] = [] + file_path_names: List[Tuple[Path, str]] = [] + + # Perform D2H in copy stream. + d2h_dump_start = time.time() + for task in tasks: + file_path, file_name, write_items = task + byte_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + byte_data_item = [(planner.resolve_data(wi), wi) for wi in byte_w] + if len(byte_data_item) != 0: + assert False + tensor_data_item = [] + # Async copy to pinned CPU memory pool. + for item in tensor_w: + # att + fqn = _item_fqn(item) + if fqn in map_fqn_global_to_local: + print(f"_item_fqn: {fqn}, {map_fqn_global_to_local[fqn]}", flush=True) + fqn = map_fqn_global_to_local[fqn] + + tensor = planner.resolve_data(item, fqn).detach() + + + if self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + tensor = collect_optim_state_across_dp_ranks( + tensor=tensor, + rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], + p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], + ) + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) + tensor_data_item.append((tensor, item)) + byte_data_item_writes.append(byte_data_item) + tensor_data_item_writes.append(tensor_data_item) + file_path_names.append((file_path, file_name)) + + d2h_dump_time = time.time() - d2h_dump_start + logger.debug(f"End waiting for D2H copy. Time cost: {d2h_dump_time}s") + + # Deallocate pinned memory. + # NOTE: when prepare_write_data() is called next time, make sure the previous save event is completed. + # Otherwise, tensors in pinned memory pool may be overwritten. + for tensor_data_item in tensor_data_item_writes: + for tensor, _ in tensor_data_item: + assert tensor.is_cpu + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + + return byte_data_item_writes, tensor_data_item_writes, file_path_names + + def write_data( + self, plan: SavePlan, planner: SavePlanner, async_io: bool = False, io_workers=False + ) -> Future[List[WriteResult]]: + storage_plan: _StoragePrefix = plan.storage_data + file_count = 0 + + def gen_file(): + nonlocal file_count + file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + tasks: List[Tuple[Path, str, List[WriteItem]]] = [] + # Generate K tasks where K is the number of worker_count. + # print(f"self.single_file_per_rank: {self.single_file_per_rank}", flush=True) + if self.single_file_per_rank: + for bucket in _split_by_size_and_type(self.worker_count, plan.items): + file_name = gen_file() + tasks.append((self.path / file_name, file_name, bucket)) + # Generate K tasks where K is the number of write items. 
+ else: + for item in plan.items: + file_name = gen_file() + tasks.append((self.path / file_name, file_name, [item])) + logger.debug(f"Rank {dist.get_rank()} writes its checkpoint into {len(tasks)} files") + # Make sure the optimizer states across dp ranks + # has been sending to other ranks + # So the receiver can get it when writing tensors to local path + + if self.p2p_tensors_info: + logger.debug("Start waiting for sending p2p tensors futures") + p2p_tensor_send_wait_start = time.time() + for req in self.p2p_tensors_info.send_p2p_reqs: + req.wait() + p2p_tensor_send_wait_time = time.time() - p2p_tensor_send_wait_start + logger.debug(f"End waiting for sending p2p tensors futures Time: {p2p_tensor_send_wait_time}s") + + futures = [] + if not io_workers: + executor = ProcessPoolExecutor(max_workers=self.worker_count) + # executor = torch.multiprocessing.get_context("spawn").Pool(self.worker_count) + else: + executor = io_workers + + # ProcessPool VERSION. + if isinstance(executor, ProcessPoolExecutor): + # print(f"executor: ProcessPoolExecutor", flush=True) + byte_data_item_writes, tensor_data_item_writes, file_path_names = self.prepare_write_data(tasks, planner) + # print(f"byte_data_item_writes {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {byte_data_item_writes}", flush=True) + # print(f"tensor_data_item_writes {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {tensor_data_item_writes}", flush=True) + # print(f"file_path_names {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {file_path_names}", flush=True) + for byte_data_item, tensor_data_item, file_path_name in zip( + byte_data_item_writes, tensor_data_item_writes, file_path_names + ): + file_path, storage_key = file_path_name + worker_args = (file_path, storage_key, byte_data_item, tensor_data_item, self.sync_files) + futures.append(executor.submit(_write_files_per_proc_pipe, *worker_args)) + # futures.append(self._serialize_workers.apply_async(_write_files_per_proc, worker_args)) + if async_io: + return futures + else: + logger.debug("Start waiting for writing futures (serilization + save)") + future_wait_start = time.time() + for fut in futures: + fut.result() + # fut.wait() + future_wait_time = time.time() - future_wait_start + logger.debug(f"End waiting for writing futures. Time cost: {future_wait_time}s") + return futures + else: + # print(f"executor: {executor}", flush=True) + # ThreadPool VERSION. + for task in tasks: + # print(f"task {gpc.get_global_rank()}: {task}", flush=True) + futures.append( + executor.submit( + _write_files_from_queue, + *task, + planner, + self.per_process_copy_ahead, + self.sync_files, + self.p2p_tensors_info, + ) + ) + if async_io: + return futures + else: + logger.debug("Start waiting for writing futures") + future_wait_start = time.time() + for fut in futures: + fut.result() + future_wait_time = time.time() - future_wait_start + logger.debug(f"End waiting for writing futures. 
Time cost: {future_wait_time}s") + return futures + + def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + storage_md = dict() + for wr_list in results: + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) + metadata.storage_data = storage_md + with (self.path / ".metadata.tmp").open("wb") as metadata_file: + pickle.dump(metadata, metadata_file) + os.fsync(metadata_file.fileno()) + + (self.path / ".metadata.tmp").rename(self.path / ".metadata") + + +class FileSystemReader(StorageReader): + def __init__( + self, + path: Union[str, os.PathLike], + broadcast_tensors=False, + data_parallel_process_group=None, + ) -> None: + super().__init__() + self.path = path + self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict() + self.broadcast_tensors = broadcast_tensors + self.data_parallel_process_group = data_parallel_process_group + + # If broadcast_tensors is enabled, the data_parallel_process_group is not none + if self.broadcast_tensors: + assert self.data_parallel_process_group + + def _slice_file(self, file, sinfo: _StorageInfo): + return _create_file_view(file, sinfo.offset, sinfo.length) + + def _get_file_path(self, relative_path): + file_path = os.path.join(self.path, relative_path) + return file_path + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + # group requests by file + per_file: Dict[str, List[ReadItem]] = dict() + for read_item in plan.items: + item_md = self.storage_data[read_item.storage_index] + path = item_md.relative_path + per_file.setdefault(path, []).append(read_item) + + # If broadcasting model tensors is enabled, + # let processes with dp_rank=0 load models and broadcast them to other processes + if self.broadcast_tensors: + self.read_data_with_broadcast(per_file=per_file, planner=planner) + else: + # Otherwise, let all ranks load tensors from files + self.read_from_files(per_file=per_file, planner=planner) + + fut: Future = Future() + fut.set_result(None) + + return fut + + def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): + print(f"debugg per_file {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}: {len(per_file)}, {per_file}", flush=True) + for relative_path, reqs in per_file.items(): + file_path = self._get_file_path(relative_path) + print(f"debugg file_path {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}: {file_path}, {reqs}", flush=True) + with open(file_path, "rb") as file: + reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) + for req in reqs: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(file, item_md) + if req.type == LoadItemType.BYTE_IO: + bytes = io.BytesIO(file_slice.read(item_md.length)) + bytes.seek(0) + planner.load_bytes(req, bytes) + else: + tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + + assert ( + target_tensor.size() == tensor.size() + ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): + for relative_path, reqs in per_file.items(): + if dist.get_rank(self.data_parallel_process_group) == 0: + file_path = self._get_file_path(relative_path) + file = open(file_path, "rb") + 
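
# A minimal sketch of how a single record is recovered from a per-rank checkpoint file,
# given the (relative_path, offset, length) triple stored in _StorageInfo and used by
# _slice_file()/read_from_files() above. The _create_file_view helper is not reproduced
# here; torch.load accepts any seekable file-like object, so a BytesIO over the exact
# byte range is sufficient for illustration.
import io
import torch

def load_record(file_path: str, offset: int, length: int):
    with open(file_path, "rb") as f:
        f.seek(offset)
        buf = io.BytesIO(f.read(length))
    return torch.load(buf, map_location="cpu")
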
dist.barrier(self.data_parallel_process_group) + reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) + for req in reqs: + if dist.get_rank(self.data_parallel_process_group) == 0: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(file, item_md) + + if req.type == LoadItemType.BYTE_IO: + if dist.get_rank(self.data_parallel_process_group) == 0: + object_list = [io.BytesIO(file_slice.read(item_md.length))] + else: + object_list = [None] + + dist.broadcast_object_list( + object_list, + src=dist.get_global_rank(self.data_parallel_process_group, 0), + group=self.data_parallel_process_group, + device=f"cuda:{torch.cuda.current_device()}", + ) + bytes = object_list[0] + bytes.seek(0) + planner.load_bytes(req, bytes) + else: + if dist.get_rank(self.data_parallel_process_group) == 0: + object_list = [cast(Tensor, torch.load(file_slice, map_location="cuda"))] + else: + object_list = [None] + dist.broadcast_object_list( + object_list, + src=dist.get_global_rank(self.data_parallel_process_group, 0), + group=self.data_parallel_process_group, + device=f"cuda:{torch.cuda.current_device()}", + ) + tensor = object_list[0].cpu() + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + + assert ( + target_tensor.size() == tensor.size() + ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + # Implementing the abstract function in StorageReader + def read_metadata(self) -> Metadata: + metadata_path = self._get_file_path(".metadata") + with open(metadata_path, "rb") as metadata_file: + metadata = pickle.load(metadata_file) + return metadata + + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + self.storage_data = metadata.storage_data + assert self.storage_data is not None + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + return plan + + def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: + return global_plan diff --git a/internlm/checkpoint/vescale/load_state_dict.py b/internlm/checkpoint/vescale/load_state_dict.py new file mode 100644 index 000000000..186d92983 --- /dev/null +++ b/internlm/checkpoint/vescale/load_state_dict.py @@ -0,0 +1,100 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
+################################################################################ + +from typing import Optional +import torch.distributed as dist +from torch.distributed.checkpoint.planner import LoadPlanner +from torch.distributed.checkpoint.utils import _DistWrapper +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner +from .filesystem import FileSystemReader +from .meta_type import STATE_DICT_TYPE +import time +from internlm.utils.logger import get_logger +from .vescale_planner import VeScaleLoadPlanner +from internlm.core.context import global_context as gpc +from internlm.core.context import ParallelMode + +logger = get_logger(__file__) + +META_DATA_FILE = ".metadata" + + +def load_state_dict( + state_dict: STATE_DICT_TYPE, + path: str, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, + broadcast_tensors=False, +) -> None: + load_start_time = time.time() + """ + [veScale version] Loads a distributed ``state_dict`` in SPMD style. Fix sub-group storage. + """ + storage_reader = FileSystemReader( + path, + broadcast_tensors=broadcast_tensors, + data_parallel_process_group=process_group, + ) + + # Step 0: create distributed world based on process group and coordinator rank + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if process_group: + distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) + if planner is None: + planner = DefaultLoadPlanner() + plan_start_time = time.time() + + # Step 1: all processes create local read plan, + # then coordinator gathers all local plans and create global plan. + def local_step(): + assert planner is not None + meta_read_start_time = time.time() + metadata = storage_reader.read_metadata() + meat_read_cost_time = time.time() - meta_read_start_time + logger.info(f"Finish read meta file. Cost time: {meat_read_cost_time}s") + planner.set_up_planner(state_dict, metadata, distW.is_coordinator) + storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) + + local_plan = planner.create_local_plan() + local_plan = storage_reader.prepare_local_plan(local_plan) + return local_plan + + def global_step(all_local_plans): + assert planner is not None + all_local_plans = planner.create_global_plan(all_local_plans) + all_local_plans = storage_reader.prepare_global_plan(all_local_plans) + return all_local_plans + + if isinstance(planner, VeScaleLoadPlanner): + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + raise AssertionError("Unsupported planner for saving checkpoint") + load_ckpt_plan_cost_time = time.time() - plan_start_time + logger.info(f"Finish planning. Cost time: {load_ckpt_plan_cost_time}s") + + read_start_time = time.time() + + # Step 2: all processes read data from the given path + def read_data(): + assert planner is not None + # print(f"central_plan {gpc.get_local_rank(ParallelMode.PIPELINE)}: {central_plan}", flush=True) + final_local_plan = planner.finish_plan(central_plan) + all_reads = storage_reader.read_data(final_local_plan, planner) + all_reads.wait() + return None + + _ = distW.all_gather("read", read_data) + read_cost_time = time.time() - read_start_time + logger.info(f"Finish reading. Cost time: {read_cost_time}s") + + load_ckpt_cost_time = time.time() - load_start_time + logger.info(f"Finish loading. 
Cost time: {load_ckpt_cost_time}s") diff --git a/internlm/checkpoint/vescale/mem_checkpoint.py b/internlm/checkpoint/vescale/mem_checkpoint.py new file mode 100644 index 000000000..9f482550e --- /dev/null +++ b/internlm/checkpoint/vescale/mem_checkpoint.py @@ -0,0 +1,396 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import io +import dataclasses +import os +import torch +from torch import multiprocessing +import threading +from typing import Callable, Dict, Any, DefaultDict, List, Optional +import pickle + +from . import bfile +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + +if hasattr(torch.storage, "TypedStorage"): + TypedStorage = torch.storage.TypedStorage +elif hasattr(torch.storage, "_TypedStorage"): + TypedStorage = torch.storage._TypedStorage + +# TypedStorage changes in pytorch 2. +if torch.__version__ >= "2": + + def untyped_storage(o): + return o.untyped_storage() + + def location_caster(o): + return o +elif torch.__version__ >= "1.11": + + def untyped_storage(o): + return o.storage()._storage + + def location_caster(o): + return o._storage if isinstance(o, TypedStorage) else o + + +try: + lib = torch.cuda.cudart() +except: + lib = None + + +def _bytes_to_tensor(b: bytes): + # Copied from `_object_to_tensor` in + # https://pytorch.org/docs/2.0/_modules/torch/distributed/distributed_c10d.html + byte_storage = torch.ByteStorage.from_buffer(b) + return torch.ByteTensor(byte_storage) + + +class PinnedStoragePool: + def __init__(self): + self._l = threading.Lock() + self._m = DefaultDict(set) + + def allocate(self, nbytes: int): + with self._l: + # We don't really need storage to have the exact size. So in theory we can find a + # bigger storage that may suit here. But so far we keep everything simple here. + s = self._m[nbytes] + if not s: + t = torch.empty([nbytes], dtype=torch.uint8) + t = t.share_memory_() + if lib is not None and nbytes != 0: + err = lib.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) + assert err == 0, err + storage = untyped_storage(t) + s.add(storage) + return s.pop() + + def deallocate(self, s): + # WARNING: Call deallocate when the reference to CPU tensor goes to zero + # so the memory pool will reuse the memory if possbile + # Othterwise, the memory pool will allocate memory on the used memory range, + # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered + with self._l: + self._m[s.nbytes()].add(s) + + +GLOBAL_POOL = PinnedStoragePool() + +TID = threading.get_ident() + + +def copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor: torch.Tensor, non_blocking=False) -> torch.Tensor: + """ + Copy a tensor on GPU to pinned memory pool (host CPU memory). 
+ The input tensor will not be modified + Args: + tensor: a tensor on cuda device + Return: + a tensor on cpu, whose data is the same as input tensor + """ + m = {} + _old_warning = getattr(torch.storage, "_warn_typed_storage_removal", None) + torch.storage._warn_typed_storage_removal = lambda *args, **kwags: None + + def persistent_id(o): + if torch.is_storage(o) or isinstance(o, TypedStorage): + storage = o + if storage._cdata in m: + return storage._cdata + if storage.device.type != "cpu": + copied = GLOBAL_POOL.allocate(storage.nbytes()) + copied.copy_(storage, non_blocking=non_blocking) + if isinstance(storage, TypedStorage): + copied = storage._new_wrapped_storage(copied) + else: + copied = storage.clone() + m[storage._cdata] = copied + return storage._cdata + return + + b = io.BytesIO() + p = pickle.Pickler(b) + p.persistent_id = persistent_id + p.dump(tensor) + b.seek(0) + up = pickle.Unpickler(b) + up.persistent_load = lambda i: m[i] + cpu_tensor = up.load() + """ + assert type(tensor) == torch.Tensor + storage_obj = tensor.storage() + cpu_storage = GLOBAL_POOL.allocate(storage_obj.nbytes()) + + cpu_storage.copy_(storage_obj, non_blocking=non_blocking) + cpu_tensor = torch.tensor(cpu_storage) + """ + torch.storage._warn_typed_storage_removal = _old_warning + return cpu_tensor + + +def deallocate_cpu_tensor_in_pinned_mem_pool(tensor: torch.Tensor): + "Deallocate CPU tensor in the global pinned memory pool" + GLOBAL_POOL.deallocate(tensor.untyped_storage()) + + +class _CalledOnce: + def __init__(self, func): + self._l = threading.Lock() + self._func = func + self._res = None + self._called = False + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + with self._l: + if self._called: + return self._res + self._called = True + self._res = self._func(*args, **kwargs) + return self._res + + +_LOCATION_TAG_LOCK = threading.Lock() + + +@dataclasses.dataclass +class _SaveArgs: + obj: object + storage_tags: list + pickle_module: __module__ + args: list + kwargs: dict + + +def _serialize_obj_with_map(a: _SaveArgs, as_shared_tensor=False): + """Called to serialize an object to a byte stream or a shared tensor. + + Args: + a (_SaveArgs): The save args consist of the original tensor to serialize, + the location tags, the pickle module and other args + as_shared_tensor (bool): Whether to serialize to a shared tensor or a byte stream. 
+ Set False if no inter process communication will happen subsequently + + Returns: + byte stream or shared tensor: The serialized object + + """ + lm = {} + for storage, tag in a.storage_tags: + lm[storage._cdata] = tag + + def location_tag(storage): + loc = lm.get(storage._cdata, None) + if loc is None: + if storage.nbytes() == 0: + # if return None, save will succeed, but load will fail afterwards + return "cpu" + raise ValueError("Unknown storage") + return loc + + with _LOCATION_TAG_LOCK: + old_location_tag = torch.serialization.location_tag + torch.serialization.location_tag = location_tag + + bio = io.BytesIO() + pickle_module = a.pickle_module or pickle + torch.save(a.obj, bio, pickle_module=pickle_module, *a.args, **a.kwargs) + + torch.serialization.location_tag = old_location_tag + b = bio.getvalue() + if not as_shared_tensor: + return b + else: + return _bytes_to_tensor(b).share_memory_() + + +def _write(f, sa): + # Serialize tensor obj directly to a byte stream, no need to convert it + # back to a shared tensor because the whole procedure happens in the same + # process + b = _serialize_obj_with_map(sa) + bfile.safe_atomic_write(f, b) + + +@dataclasses.dataclass +class _PoolArgs: + pinned_pool: PinnedStoragePool + pooled_storages: list + + +class _WriteFunc: + def __init__(self, sa: _SaveArgs, pa: _PoolArgs, async_worker): + self._sa = sa + if self._sa.pickle_module == pickle: + # This makes wa serializable. + self._sa.pickle_module = None + self._pa = pa + self._async_worker = async_worker + + self._enable_mp = async_worker is not None and sa.pickle_module is None + self._des = _CalledOnce(self._des_do_not_call_directly) + self._l = threading.RLock() + self._serialized = None + self._bytes = None + + def _des_do_not_call_directly(self): + for s in self._pa.pooled_storages: + self._pa.pinned_pool.deallocate(s) + + def __del__(self): + self._des() + + @property + def serialized(self): + with self._l: + if self._serialized is None: + if self._enable_mp: + self._serialized = self._async_worker.apply(_serialize_obj_with_map, (self._sa, True)) + else: + self._serialized = _serialize_obj_with_map(self._sa) + self._des() + return self._serialized + + @property + def bytes(self): + if self._bytes is None: + with self._l: + if self._enable_mp: + self._bytes = self.serialized.numpy().tobytes() + else: + self._bytes = self.serialized + return self._bytes + + def __call__(self, file: str = None): + if file is None: + return self.bytes + + if self._async_worker: + self._async_worker.apply(_write, (file, self._sa)) + else: + _write(file, self._sa) + self._des() + + +class TorchCheckpointRecorder: + def __init__( + self, + fast_mode=None, + async_worker: multiprocessing.Pool = None, + pinned_pool=GLOBAL_POOL, + ): + self._thread_id = threading.get_ident() + self._m = {} + + # After 1.11, typed storage is publicly accessible. + condition = torch.__version__ >= "1.11" + self._fast_mode = fast_mode if fast_mode is not None else condition + # Safety check. 
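
# A minimal sketch of the idea behind TorchCheckpointRecorder, stripped of the pinned-memory
# copy and multiprocessing paths: torch.save is temporarily replaced so that calls targeting
# a path from the owning thread are captured in memory instead of hitting disk, and the
# caller decides later when and where the captured bytes are written. The helper name
# record_torch_saves is illustrative only.
import io
import os
import pickle
import threading
from contextlib import contextmanager

import torch

@contextmanager
def record_torch_saves(captured: dict):
    original_save = torch.save
    owner = threading.get_ident()

    def save_wrapper(obj, f, pickle_module=pickle, *args, **kwargs):
        # Only intercept path-like targets issued from the owning thread.
        if threading.get_ident() != owner or not isinstance(f, (str, os.PathLike)):
            return original_save(obj, f, pickle_module, *args, **kwargs)
        buf = io.BytesIO()
        original_save(obj, buf, pickle_module, *args, **kwargs)
        captured[str(f)] = buf.getvalue()

    torch.save = save_wrapper
    try:
        yield captured
    finally:
        torch.save = original_save

# Usage: with record_torch_saves({}) as files: torch.save(state, "model.pt")
# afterwards files["model.pt"] holds the serialized bytes, to be flushed on demand.
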
+ assert not self._fast_mode or condition + + self._async_worker = async_worker + self._pinned_pool = pinned_pool + + def __enter__(self): + self._old_save = torch.save + torch.save = self._save_wrapper + if self._fast_mode: + self._old_warning = getattr(torch.storage, "_warn_typed_storage_removal", None) + torch.storage._warn_typed_storage_removal = lambda *args, **kwags: None + return self + + def __exit__(self, *args): + torch.save = self._old_save + if self._fast_mode: + if self._old_warning: + torch.storage._warn_typed_storage_removal = self._old_warning + + def _save_wrapper(self, obj, f, pickle_module=pickle, *args, **kwargs): + if threading.get_ident() != self._thread_id or not isinstance(f, (str, os.PathLike)): + return self._old_save(obj, f, pickle_module, *args, **kwargs) + + if self._fast_mode: + func = self._copy_to_buffer(obj, pickle_module, *args, **kwargs) + else: + func = self._save_to_buffer(obj, pickle_module, *args, **kwargs) + + self._m[str(f)] = func + + def _save_to_buffer(self, obj, *args, **kwags): + b = io.BytesIO() + self._old_save(obj, b, *args, **kwags) + + def gen_func(b): + def func(f: str = None): + if f: + return bfile.safe_atomic_write(f, b.getvalue()) + return b.getvalue() + + return func + + return gen_func(b) + + def _copy_to_buffer(self, obj, pickle_module, *args, **kwargs): + m = {} + storage_tags = [] + pooled_storages = [] + + def persistent_id(o): + if torch.is_storage(o) or isinstance(o, TypedStorage): + storage = o + if storage._cdata in m: + return storage._cdata + if storage.device.type != "cpu": + copied = self._pinned_pool.allocate(storage.nbytes()) + pooled_storages.append(copied) + copied.copy_(storage, non_blocking=False) + if isinstance(storage, TypedStorage): + copied = storage._new_wrapped_storage(copied) + else: + copied = storage.clone() + m[storage._cdata] = copied + tag = torch.serialization.location_tag(location_caster(storage)) + storage_tags.append((copied, tag)) + return storage._cdata + return + + b = io.BytesIO() + p = pickle_module.Pickler(b) + p.persistent_id = persistent_id + p.dump(obj) + b.seek(0) + up = pickle_module.Unpickler(b) + up.persistent_load = lambda i: m[i] + nobj = up.load() + + sa = _SaveArgs( + obj=nobj, + storage_tags=storage_tags, + pickle_module=pickle_module, + args=args, + kwargs=kwargs, + ) + pa = _PoolArgs(pinned_pool=self._pinned_pool, pooled_storages=pooled_storages) + + return _WriteFunc(sa, pa, self._async_worker) + + @property + def files(self) -> Dict[str, Callable[[Optional[List[str]]], Optional[bytes]]]: + return self._m diff --git a/internlm/checkpoint/vescale/mem_file_service_pb2.py b/internlm/checkpoint/vescale/mem_file_service_pb2.py new file mode 100644 index 000000000..5966cd239 --- /dev/null +++ b/internlm/checkpoint/vescale/mem_file_service_pb2.py @@ -0,0 +1,66 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +################################################################################ +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: checkpoint/utilities/server/mem_file_service.proto +# Protobuf Python Version: 4.25.1 +"""Generated protocol buffer code.""" + +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n2checkpoint/utilities/server/mem_file_service.proto">\n\x1dVeScaleCheckpointWriteRequest\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c\x12\x0c\n\x04name\x18\x08 \x01(\t" \n\x1eVeScaleCheckpointWriteResponse",\n\x1cVeScaleCheckpointReadRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"0\n\x1dVeScaleCheckpointReadResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"M\n\x1eVeScaleCheckpointRenameRequest\x12\x0b\n\x03src\x18\x01 \x01(\t\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x11\n\toverwrite\x18\x03 \x01(\x08"!\n\x1fVeScaleCheckpointRenameResponse".\n\x1eVeScaleCheckpointRemoveRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"!\n\x1fVeScaleCheckpointRemoveResponse"/\n\x1fVeScaleCheckpointListdirRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"1\n VeScaleCheckpointListdirResponse\x12\r\n\x05names\x18\x01 \x03(\t".\n\x1eVeScaleCheckpointExistsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"1\n\x1fVeScaleCheckpointExistsResponse\x12\x0e\n\x06\x65xists\x18\x01 \x01(\x08\x32\xf9\x03\n\x1fVeScaleCheckpointMemFileService\x12L\n\x05Write\x12\x1e.VeScaleCheckpointWriteRequest\x1a\x1f.VeScaleCheckpointWriteResponse"\x00(\x01\x12I\n\x04Read\x12\x1d.VeScaleCheckpointReadRequest\x1a\x1e.VeScaleCheckpointReadResponse"\x00\x30\x01\x12M\n\x06Rename\x12\x1f.VeScaleCheckpointRenameRequest\x1a .VeScaleCheckpointRenameResponse"\x00\x12M\n\x06Remove\x12\x1f.VeScaleCheckpointRemoveRequest\x1a .VeScaleCheckpointRemoveResponse"\x00\x12P\n\x07Listdir\x12 .VeScaleCheckpointListdirRequest\x1a!.VeScaleCheckpointListdirResponse"\x00\x12M\n\x06\x45xists\x12\x1f.VeScaleCheckpointExistsRequest\x1a .VeScaleCheckpointExistsResponse"\x00\x62\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "checkpoint.utilities.server.mem_file_service_pb2", _globals) +if _descriptor._USE_C_DESCRIPTORS is False: + DESCRIPTOR._options = None + _globals["_VESCALECHECKPOINTWRITEREQUEST"]._serialized_start = 54 + _globals["_VESCALECHECKPOINTWRITEREQUEST"]._serialized_end = 116 + _globals["_VESCALECHECKPOINTWRITERESPONSE"]._serialized_start = 118 + _globals["_VESCALECHECKPOINTWRITERESPONSE"]._serialized_end = 150 + _globals["_VESCALECHECKPOINTREADREQUEST"]._serialized_start = 152 + _globals["_VESCALECHECKPOINTREADREQUEST"]._serialized_end = 196 + _globals["_VESCALECHECKPOINTREADRESPONSE"]._serialized_start = 198 + _globals["_VESCALECHECKPOINTREADRESPONSE"]._serialized_end = 246 + _globals["_VESCALECHECKPOINTRENAMEREQUEST"]._serialized_start = 248 + _globals["_VESCALECHECKPOINTRENAMEREQUEST"]._serialized_end = 325 + _globals["_VESCALECHECKPOINTRENAMERESPONSE"]._serialized_start = 327 + _globals["_VESCALECHECKPOINTRENAMERESPONSE"]._serialized_end = 360 + _globals["_VESCALECHECKPOINTREMOVEREQUEST"]._serialized_start = 362 + _globals["_VESCALECHECKPOINTREMOVEREQUEST"]._serialized_end = 408 + 
_globals["_VESCALECHECKPOINTREMOVERESPONSE"]._serialized_start = 410 + _globals["_VESCALECHECKPOINTREMOVERESPONSE"]._serialized_end = 443 + _globals["_VESCALECHECKPOINTLISTDIRREQUEST"]._serialized_start = 445 + _globals["_VESCALECHECKPOINTLISTDIRREQUEST"]._serialized_end = 492 + _globals["_VESCALECHECKPOINTLISTDIRRESPONSE"]._serialized_start = 494 + _globals["_VESCALECHECKPOINTLISTDIRRESPONSE"]._serialized_end = 543 + _globals["_VESCALECHECKPOINTEXISTSREQUEST"]._serialized_start = 545 + _globals["_VESCALECHECKPOINTEXISTSREQUEST"]._serialized_end = 591 + _globals["_VESCALECHECKPOINTEXISTSRESPONSE"]._serialized_start = 593 + _globals["_VESCALECHECKPOINTEXISTSRESPONSE"]._serialized_end = 642 + _globals["_VESCALECHECKPOINTMEMFILESERVICE"]._serialized_start = 645 + _globals["_VESCALECHECKPOINTMEMFILESERVICE"]._serialized_end = 1150 +# @@protoc_insertion_point(module_scope) diff --git a/internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py b/internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py new file mode 100644 index 000000000..4558bfa65 --- /dev/null +++ b/internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py @@ -0,0 +1,321 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" + +import grpc + +from . import ( + mem_file_service_pb2 as checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2, +) + + +class VeScaleCheckpointMemFileServiceStub: + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. 
+ """ + self.Write = channel.stream_unary( + "/VeScaleCheckpointMemFileService/Write", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.FromString, + ) + self.Read = channel.unary_stream( + "/VeScaleCheckpointMemFileService/Read", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.FromString, + ) + self.Rename = channel.unary_unary( + "/VeScaleCheckpointMemFileService/Rename", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.FromString, + ) + self.Remove = channel.unary_unary( + "/VeScaleCheckpointMemFileService/Remove", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.FromString, + ) + self.Listdir = channel.unary_unary( + "/VeScaleCheckpointMemFileService/Listdir", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.FromString, + ) + self.Exists = channel.unary_unary( + "/VeScaleCheckpointMemFileService/Exists", + request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.SerializeToString, + response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.FromString, + ) + + +class VeScaleCheckpointMemFileServiceServicer: + """Missing associated documentation comment in .proto file.""" + + def Write(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Read(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Rename(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Remove(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def Listdir(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + def 
Exists(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + + +def add_VeScaleCheckpointMemFileServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + "Write": grpc.stream_unary_rpc_method_handler( + servicer.Write, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.SerializeToString, + ), + "Read": grpc.unary_stream_rpc_method_handler( + servicer.Read, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.SerializeToString, + ), + "Rename": grpc.unary_unary_rpc_method_handler( + servicer.Rename, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.SerializeToString, + ), + "Remove": grpc.unary_unary_rpc_method_handler( + servicer.Remove, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.SerializeToString, + ), + "Listdir": grpc.unary_unary_rpc_method_handler( + servicer.Listdir, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.SerializeToString, + ), + "Exists": grpc.unary_unary_rpc_method_handler( + servicer.Exists, + request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.FromString, + response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler("VeScaleCheckpointMemFileService", rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + + +# This class is part of an EXPERIMENTAL API. 
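
# A minimal sketch of how a client can talk to this in-memory file service, assuming a
# server is already listening on a unix-domain socket (see mem_server_lib below for the
# actual client/server plumbing). The socket path here is illustrative only.
import grpc

from internlm.checkpoint.vescale import mem_file_service_pb2, mem_file_service_pb2_grpc

def mem_file_exists(sock_path: str, name: str) -> bool:
    channel = grpc.insecure_channel(f"unix:{sock_path}")
    stub = mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceStub(channel)
    resp = stub.Exists(mem_file_service_pb2.VeScaleCheckpointExistsRequest(name=name))
    return resp.exists

# e.g. mem_file_exists("/var/tmp/mem_server_demo.sock", "/ckpt/.metadata")
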
+class VeScaleCheckpointMemFileService: + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def Write( + request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.stream_unary( + request_iterator, + target, + "/VeScaleCheckpointMemFileService/Write", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Read( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_stream( + request, + target, + "/VeScaleCheckpointMemFileService/Read", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Rename( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/VeScaleCheckpointMemFileService/Rename", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Remove( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/VeScaleCheckpointMemFileService/Remove", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Listdir( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/VeScaleCheckpointMemFileService/Listdir", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) + + @staticmethod + def Exists( + request, + target, 
+ options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/VeScaleCheckpointMemFileService/Exists", + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.SerializeToString, + checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) diff --git a/internlm/checkpoint/vescale/mem_server_lib.py b/internlm/checkpoint/vescale/mem_server_lib.py new file mode 100644 index 000000000..2bf3af242 --- /dev/null +++ b/internlm/checkpoint/vescale/mem_server_lib.py @@ -0,0 +1,307 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ +import dataclasses +import io +import grpc +from typing import Tuple +import os +import threading +import contextlib +import pathlib +import subprocess +import time +import queue +from concurrent import futures + +from . import mem_file_service_pb2 +from . 
import mem_file_service_pb2_grpc + + +class _Directory(dict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lock = threading.RLock() + + +@dataclasses.dataclass +class _File: + content: bytes = b"" + + +_CHUNK_SIZE = 2 * 1024 * 1024 + + +def get_mem_server_sock_file(name: str): + return f"/var/tmp/mem_server_{name}.sock" + + +class MemFileServicer(mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceServicer): + def __init__(self): + self._d = _Directory() + + def Write(self, request_iterator, ctx: grpc.ServicerContext): + b = io.BytesIO() + name = None + for req in request_iterator: + if name is None: + if not req.name: + ctx.abort(grpc.StatusCode.INVALID_ARGUMENT, "Name must be specified.") + name = req.name + d, bn = self._iterate_dir(name, ctx, create=True) + b.write(req.content) + if name: + with d.lock: + d[bn] = _File(content=b.getvalue()) + return mem_file_service_pb2.VeScaleCheckpointWriteResponse() + + def Read(self, req, ctx: grpc.ServicerContext): + d, bn = self._iterate_dir(req.name, ctx) + with d.lock: + if bn not in d or not isinstance(d[bn], _File): + ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.name} not found.") + f: _File = d[bn] + cur = 0 + while cur < len(f.content): + yield mem_file_service_pb2.VeScaleCheckpointReadResponse(content=f.content[cur : cur + _CHUNK_SIZE]) + cur += _CHUNK_SIZE + + def Rename(self, req, ctx: grpc.ServicerContext): + src_dir, src_bn = self._iterate_dir(req.src, ctx) + dst_dir, dst_bn = self._iterate_dir(req.dst, ctx) + if src_dir != dst_dir: + ctx.abort(grpc.StatusCode.UNIMPLEMENTED, "Rename across dir is not supported.") + d = src_dir + with d.lock: + if src_bn not in src_bn: + ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.src} is not found.") + if not req.overwrite and dst_bn in d: + ctx.abort(grpc.StatusCode.ALREADY_EXISTS, f"{req.dst} already exists.") + d[dst_bn] = d[src_bn] + del d[src_bn] + return mem_file_service_pb2.VeScaleCheckpointRenameResponse() + + def Remove(self, req, ctx: grpc.ServicerContext): + d, bn = self._iterate_dir(req.name, ctx) + if bn not in d: + ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.name} not found.") + with d.lock: + del d[bn] + return mem_file_service_pb2.VeScaleCheckpointRemoveResponse() + + def Listdir(self, req, ctx: grpc.ServicerContext): + d, _ = self._iterate_dir(os.path.join(req.name, "*")) + if d is None: + return mem_file_service_pb2.VeScaleCheckpointListdirResponse() + + resp = mem_file_service_pb2.VeScaleCheckpointListdirResponse() + with d.lock: + for name in d: + resp.names.append(name) + return resp + + def Exists(self, req, ctx: grpc.ServicerContext): + d, bn = self._iterate_dir(req.name) + if d is None: + return mem_file_service_pb2.VeScaleCheckpointExistsResponse(exists=False) + with d.lock: + return mem_file_service_pb2.VeScaleCheckpointExistsResponse(exists=bn in d) + + def _iterate_dir(self, name: str, ctx: grpc.ServicerContext = None, create=False) -> Tuple[_Directory, str]: + if ctx is None: + + class FakeCtx: + def abort(*args, **kwargs): + return None, None + + ctx = FakeCtx() + name = str(pathlib.Path(name).absolute())[1:] + parts = name.split("/") + cur = self._d + for part in parts[:-1]: + with cur.lock: + if part not in cur: + if not create: + return ctx.abort(grpc.StatusCode.NOT_FOUND, f"{part} doesn't exist.") + else: + cur[part] = _Directory() + cur = cur[part] + if not isinstance(cur, _Directory): + return ctx.abort( + grpc.StatusCode.ALREADY_EXISTS, + f"{part} already exist as a file.", + ) + return cur, parts[-1] + + +def 
start_server(name: str, force=False): + sock = get_mem_server_sock_file(name) + if os.path.exists(sock) and not force: + raise OSError("Mem server is already running.") + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + mem_file_service_pb2_grpc.add_VeScaleCheckpointMemFileServiceServicer_to_server(MemFileServicer(), server) + server.add_insecure_port(f"unix:{sock}") + server.start() + return server + + +# --- Below is general file interface --- + +_STUB_CACHE = {} +_STUB_CACHE_LOCK = threading.Lock() +SCHEMA = "/local_mem/" + + +def get_prefix(name: str): + return SCHEMA + name + + +def _get_mem_name_and_name(path: str): + path = path[len(SCHEMA) :] + pos = path.find("/") + if pos == -1: + return path, "/" + else: + return path[:pos], path[pos:] + + +def _get_stub_and_name( + path: str, +) -> Tuple[mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceStub, str]: + mem_name, name = _get_mem_name_and_name(path) + if mem_name not in _STUB_CACHE: + c = grpc.insecure_channel(f"unix:{get_mem_server_sock_file(mem_name)}") + with _STUB_CACHE_LOCK: + _STUB_CACHE[mem_name] = mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceStub(c) + return _STUB_CACHE[mem_name], name + + +class _FileLike: + def __init__(self, name: str, mode: str): + if mode not in ["rb", "wb"]: + raise NotImplementedError(f"{mode} is not implemented.") + self._stub, self._name = _get_stub_and_name(name) + self._mode = mode + self._is_write = "w" in mode + if self._is_write: + self._write_async() + self._read_buf = None + + @property + def read_buf(self): + if self._read_buf is None: + self._read_buf = io.BytesIO() + for resp in self._stub.Read(mem_file_service_pb2.VeScaleCheckpointReadRequest(name=self._name)): + self._read_buf.write(resp.content) + self._read_buf.seek(0) + return self._read_buf + + def __getattr__(self, name): + if not self._is_write: + return getattr(self.read_buf, name) + + def _write_async(self): + self._q = queue.Queue() + + def streaming(): + while True: + content, eof = self._q.get() + if eof: + break + cur = 0 + while cur < len(content): + req = mem_file_service_pb2.VeScaleCheckpointWriteRequest(content=content[cur : cur + _CHUNK_SIZE]) + if cur == 0: + req.name = self._name + yield req + cur += _CHUNK_SIZE + + self._write_future = self._stub.Write.future(streaming()) + + def write(self, content): + self._q.put((content, False)) + + def close(self): + if self._is_write: + self._q.put((None, True)) + self._write_future.result() + + +@contextlib.contextmanager +def open(name, mode) -> io.FileIO: + f = _FileLike(name, mode) + try: + yield f + finally: + f.close() + + +def rename(src, dst, overwrite=False): + stub, src_name = _get_stub_and_name(src) + dst_stub, dst_name = _get_stub_and_name(dst) + if stub != dst_stub: + raise ValueError(f"Rename across mem file system is not supported. 
{src} {dst}") + stub.Rename(mem_file_service_pb2.VeScaleCheckpointRenameRequest(src=src_name, dst=dst_name, overwrite=overwrite)) + + +def remove(name): + stub, subname = _get_stub_and_name(name) + stub.Remove(mem_file_service_pb2.VeScaleCheckpointRemoveRequest(name=subname)) + + +def listdir(name): + try: + stub, subname = _get_stub_and_name(name) + resp = stub.Listdir(mem_file_service_pb2.VeScaleCheckpointListdirRequest(name=subname)) + return list(resp.names) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + return [] + raise + + +def exists(name): + try: + stub, subname = _get_stub_and_name(name) + resp = stub.Exists(mem_file_service_pb2.VeScaleCheckpointExistsRequest(name=subname)) + return resp.exists + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + return False + raise + + +# --- interface done --- + + +def start_server_in_new_process(name: str): + filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), "detached_mem_server.py") + return subprocess.Popen(["python3", filename, f"--name={name}"]) + + +def wait_until_fs_ready(name: str, timeout=120): + stub, _ = _get_stub_and_name(os.path.join(SCHEMA, name)) + t0 = time.time() + while time.time() < t0 + timeout: + try: + stub.Listdir(mem_file_service_pb2.VeScaleCheckpointListdirRequest(name="/")) + return True + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.UNAVAILABLE: + time.sleep(0.1) + continue + raise + return False diff --git a/internlm/checkpoint/vescale/meta_type.py b/internlm/checkpoint/vescale/meta_type.py new file mode 100644 index 000000000..a669efba0 --- /dev/null +++ b/internlm/checkpoint/vescale/meta_type.py @@ -0,0 +1,38 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ +# meta_type.py saves all constants and data types commonly used in vescale.checkpoint + +from enum import Enum +from typing import Dict, Any, TypeVar +from typing_extensions import Protocol, runtime_checkable + + +STATE_DICT_TYPE = Dict[str, Any] + +MODEL_STR = "model" +OPTIMIZER_STR = "optimizer" +STATE_DICT_STR = "state_dict" + + +class SupportedStrategy(Enum): + Megatron = 0 + FSDP = 1 + VeScale = 2 + + +@runtime_checkable +class Stateful(Protocol): + def state_dict(self) -> Dict[str, Any]: ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... + + +T = TypeVar("T", bound=Stateful) +CheckpointState = Dict[str, T] diff --git a/internlm/checkpoint/vescale/placement_types.py b/internlm/checkpoint/vescale/placement_types.py new file mode 100644 index 000000000..d5f612704 --- /dev/null +++ b/internlm/checkpoint/vescale/placement_types.py @@ -0,0 +1,563 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. 
and/or its affiliates. +################################################################################ + +from dataclasses import dataclass +from typing import List, NamedTuple, Optional, Tuple, cast + +import torch +import torch.distributed.distributed_c10d as c10d + +from .device_mesh import DeviceMesh + + +class Placement: + # base class Placement type + + # convenient utils to check for placement types + def is_shard(self, dim: Optional[int] = None) -> bool: + if dim is not None and isinstance(self, Shard): + return self.dim == dim + else: + return isinstance(self, Shard) + + def is_interleaved_shard(self, dim: Optional[int] = None) -> bool: + if dim is not None and isinstance(self, InterleavedShard): + return self.dim == dim + else: + return isinstance(self, InterleavedShard) + + def is_replicate(self) -> bool: + return isinstance(self, Replicate) + + def is_partial(self) -> bool: + return isinstance(self, Partial) + + def serialize_to_tensor(self, device) -> torch.Tensor: + if self.is_replicate(): + return torch.tensor([0, 0, 0], device=device, dtype=torch.int64) + elif self.is_partial(): + return torch.tensor([1, 0, 0], device=device, dtype=torch.int64) + elif self.is_shard(): + return torch.tensor([2, self.dim, 0], device=device, dtype=torch.int64) + elif self.is_interleaved_shard(): + return torch.tensor([3, self.dim, self.interleaved_size], device=device, dtype=torch.int64) + + @staticmethod + def serialize_from_tensor(tensor: torch.Tensor): + if tensor[0] == 0: + return Replicate() + elif tensor[0] == 1: + return Partial() + elif tensor[0] == 2: + return Shard(dim=tensor[1].item()) + elif tensor[0] == 3: + return InterleavedShard(dim=tensor[1].item(), interleaved_size=tensor[2].item()) + + +class Shard(Placement): + # shard placement, shard on a dim + def __init__(self, dim: int): + self.dim = dim + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor], List[int]]: + """ + This function uses torch.chunk to split a tensor into num_chunks shards along + the Shard placement dimension, and return a list of shards with their pad sizes. + + Keyword args: + with_padding (bool, optional): when True, we pad the tensor on the last + few ranks before calling the collectives (i.e. scatter/all_gather, etc.). + This is because collectives usually require equal size tensor inputs + + Example: + >>> Given a 2D global tensor with Shard(0) + >>> Run this method: + >>> torch.chunk(torch.tensor([[i] * 2 for i in range(13)]), num_chunks=6, dim=0) + + tensor1([[0, 0], + [1, 1], + [2, 2]]) + + tensor2([[3, 3], + [4, 4], + [5, 5]]) + + tensor3([[6, 6], + [7, 7], + [8, 8]]) + + tensor4([[ 9, 9], + [10, 10], + [11, 11]]) + + tensor5([[12, 12], + [, ], + [, ]]) + + ([[, ], + [, ], + [, ]]) + """ + assert self.dim <= tensor.ndim, f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + assert tensor.size(self.dim) > 0, f"Tensor size along dim{self.dim} is 0. There is nothing to be sharded." 
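+        # Illustrative example (hypothetical sizes): for tensor.size(self.dim) == 13 and
+        # num_chunks == 6, torch.chunk returns 5 pieces of sizes [3, 3, 3, 3, 1];
+        # full_chunk_size = (13 + 6 - 1) // 6 = 3, chunk_sizes becomes [3, 3, 3, 3, 1, 0]
+        # (a sixth, empty chunk is appended below), and pad_sizes = [0, 0, 0, 0, 2, 3].
+        # With with_padding=True the short shards are padded up to 3 rows on self.dim so
+        # that collectives such as scatter/all_gather receive equally sized inputs.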
+ + # chunk tensor over dimension `dim` into n slices with padding if necessary + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) + # compute the chunk size inline with ``torch.chunk`` (round up to int) + full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks + + # Compute chunk size for each chunk for ``self.dim`` + chunk_sizes = [tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 for idx in range(num_chunks)] + # Compute pad size on each chunk + pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] + + # Reuse tensor to fill empty chunk with empty tensor + num_empty_tensors = num_chunks - len(tensor_list) + if num_empty_tensors > 0: + tensor_size = list(tensor_list[0].size()) + tensor_size = [size if idx != self.dim else 0 for idx, size in enumerate(tensor_size)] + tensor = tensor.new_zeros(tensor_size) # (allocate empty chunk) + for _ in range(num_empty_tensors): + tensor_list.append(tensor) + + if with_padding or contiguous: + shard_list = [] + for shard, pad_size in zip(tensor_list, pad_sizes): + # Fill the empty tensor with zeroes with padding. + if with_padding and pad_size > 0: + shard = self._pad_tensor(shard, pad_size) + shard = shard.contiguous() if contiguous else shard + shard_list.append(shard) + return shard_list, pad_sizes + else: + return tensor_list, pad_sizes + + def _pad_tensor( + self, + tensor: torch.Tensor, + pad_size: int, + ) -> torch.Tensor: + pad = [0, 0] * (tensor.ndim - self.dim) + pad[-1] = pad_size + return torch.nn.functional.pad(tensor, pad) + + def _unpad_tensor( + self, + tensor: torch.Tensor, + pad_size: int, + ) -> torch.Tensor: + return tensor.narrow( + self.dim, + start=0, + length=tensor.size(self.dim) - pad_size, + ) + + def _local_shard_size_on_dim( + self, + size_on_dim: int, + num_chunks: int, + rank: int, + return_offset: bool = False, + ) -> Tuple[int, int]: + """ + returns the local shard size and offset on a given tensor dim + """ + assert ( + size_on_dim >= num_chunks + ), f"Size to be sharded on dim {self.dim} must be at least as large as the number of devices in that dimension {num_chunks}" + + # Compute the chunk size inline with ``torch.chunk`` + full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks + + # Compute chunk size for each chunk on the dimension. + chunk_sizes = [ + max( + min(size_on_dim, full_chunk_size * (idx + 1)) - full_chunk_size * idx, + 0, + ) + for idx in range(num_chunks) + ] + local_shard_size = chunk_sizes[rank] + + local_offset_on_dim = -1 + if return_offset: + # Return global tensor dim size of current dimension if for empty shard + # to represent the end of the corresponding tensor dim. 
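+            # Illustrative example (hypothetical sizes): with size_on_dim=13 and
+            # num_chunks=6, chunk_sizes is [3, 3, 3, 3, 1, 0]; rank 4 gets (1, 12), while
+            # rank 5 (an empty shard) gets (0, 13): its offset equals the global dim size,
+            # marking the end of the tensor dimension as described above.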
+ local_offset_on_dim = sum(chunk_sizes[:rank]) + + return (local_shard_size, local_offset_on_dim) + + def __hash__(self) -> int: + ret = self.dim + 128 # restrict sharding dim in [-128, +128]; should be sufficient + assert ret >= 0 + return ret + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Shard): + return False + return self.dim == other.dim + + def __repr__(self) -> str: + """ + machine readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + def __str__(self) -> str: + """human readable representation of the Shard placement""" + return f"S({self.dim})" + + +class Replicate(Placement): + # replicate placement + def __hash__(self) -> int: + # every replicate placement is the same + return -1 + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Replicate): + return False + return True + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + def __str__(self) -> str: + """ + human readable representation of the Replicate placement + """ + return "R" + + +class Partial(Placement): + # This is a default partial placement with element-wise reduce op + # when doing reduction it follows the contract of `_to_replicate` + # and `_to_shard` to do the reduction and convert the local tensor + # to the corresponding state (replicate or shard) + # + # We can implement custom reductions as needed by subclassing this + # class and override those contracts. + + def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM): + self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op + + def __hash__(self) -> int: + ret = -3 - hash(self.reduce_op) # hash(reduce_op) gives 0~8 + assert ret <= -3 + return ret + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Partial): + return False + return self.reduce_op == other.reduce_op + + def __repr__(self) -> str: + """ + machine readable representation of the Partial placement + """ + return f"Partial(reduce_op={self.reduce_op})" + + def __str__(self) -> str: + """ + human readable representation of the Partial placement + """ + return "P" + + +class InterleavedShard(Shard): + """ + The major difference between this placement and Shard is that the global + tensor with a `InterleavedShard` placement is not contiguous. But you can + always treat a InterleavedShard(dim=x, interleaved_size=y) as a + Shard(dim=x+1)) on a tensor by reshaping the original one from + ``[..., size(x), ...]`` to ``[..., y, size(x) // y, ...]`` + + NOTE: We currently don't support padding in InterleavedShard, which means + we cannot interleaved shard a tensor when it's size is not divisible by + the multiply of interleaved_size and corresponding mesh size. + """ + + def __init__(self, dim: int, interleaved_size: int): + self.dim = dim + # TODO: make this attribute a list to support multi interleaved shard + self.interleaved_size = interleaved_size + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + contiguous: bool = True, + ) -> Tuple[List[torch.Tensor]]: + assert self.dim <= tensor.ndim, f"Interleaved Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + assert tensor.size(self.dim) > 0, f"Tensor size along dim {self.dim} is 0. There is nothing to be sharded." + assert ( + tensor.size(self.dim) % self.interleaved_size == 0 + ), f"Tensor size along dim {self.dim} is not a multiple of interleaved size {self.interleaved_size}." 
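+        # Illustrative example (hypothetical sizes): the view/chunk/reshape steps below
+        # treat InterleavedShard(dim, interleaved_size=2) on a dim of size 12 as a Shard
+        # on the folded view [..., 2, 6, ...]; with num_chunks=2 each piece is
+        # [..., 2, 3, ...] and is reshaped back to [..., 6, ...], so rank 0 ends up with
+        # every other group of 3 elements of the original, non-contiguous global dim.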
+ assert ( + tensor.size(self.dim) // self.interleaved_size + ) % num_chunks == 0, "InterleavedShard doesn't allow padding" + + # step 1: reshape tensor + tensor = tensor.view(tensor.shape[: self.dim] + (self.interleaved_size, -1) + tensor.shape[self.dim + 1 :]) + + # step 2: split tensor + # chunk tensor over dimension `dim` into n slices with padding if necessary + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim + 1)) + + # step 3: reshape back + result_list = [] + for t in tensor_list: + if contiguous: + t = t.contiguous() + # NOTE: view op might be not okay here, because tensor returned by chunk op + # is not contiguous. + shard = t.reshape(tensor.shape[: self.dim] + (-1,) + tensor.shape[self.dim + 2 :]) + result_list.append(shard) + + return result_list + + def _local_shard_size_on_dim( + self, + size_on_dim: int, + num_chunks: int, + rank: int, + return_offset: bool = False, + ) -> Tuple[int, int]: + """ + returns the local shard size and offset on a given tensor dim. + NOTE: argument ``rank`` and ``return_offset`` is useless here. The reason for + keeping them is to align this API with the one of ``Shard`` placement. + """ + assert ( + size_on_dim >= num_chunks + ), f"Size to be sharded on dim {self.dim} must be at least as large as the number of devices in that dimension {num_chunks}" + + # Compute the chunk size inline with ``torch.chunk`` + full_chunk_size = size_on_dim // num_chunks + return (full_chunk_size, None) + + def __hash__(self) -> int: + assert self.dim >= 0 and self.interleaved_size >= 0, "negatives (-1 & -2) can result in hash collison" + return hash((self.dim, self.interleaved_size)) + + def __repr__(self) -> str: + return f"InterleavedShard(dim={self.dim}, interleaved_size={self.interleaved_size})" + + def __str__(self) -> str: + return f"IS({self.dim}, {self.interleaved_size})" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, InterleavedShard): + return False + return self.dim == other.dim and self.interleaved_size == other.interleaved_size + + +class TensorMeta(NamedTuple): + # simple named tuple to represent tensor metadata + # intentionally to stay simple only for sharding + # propagation purposes. + shape: torch.Size + stride: Tuple[int, ...] + dtype: torch.dtype + + def __hash__(self) -> int: + assert isinstance(self.stride, Tuple) + return hash((self.shape, self.stride, self.dtype)) + + def __eq__(self, __o: object) -> bool: + if not isinstance(__o, TensorMeta): + return False + return ( + self.shape == __o.shape # type: ignore[union-attr] + and self.stride == __o.stride # type: ignore[union-attr] + and self.dtype == __o.dtype # type: ignore[union-attr] + ) + + +# used internally to propagate the placements + + +@dataclass +class DTensorSpec: + mesh: DeviceMesh + placements: Tuple[Placement, ...] + + # tensor meta will only be set during sharding propagation + tensor_meta: Optional[TensorMeta] = None + + def __hash__(self) -> int: + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements, tensor_meta. 
+ assert isinstance(self.placements, Tuple) and all(isinstance(p, Placement) for p in self.placements) + + return hash( + ( + self.mesh, + tuple(self.placements), + self.tensor_meta, # None is hashable + ) + ) + + def __eq__(self, __o: object) -> bool: + if not ( + isinstance(__o, DTensorSpec) and self.mesh == __o.mesh and tuple(self.placements) == tuple(__o.placements) + ): + return False + return self.tensor_meta == __o.tensor_meta # None included + + def __str__(self) -> str: + """ + human readable representation of the DTensorSpec + """ + if len(self.placements) == 1: + placement_str = str(self.placements[0]) + else: + placement_str = str(self.placements) + + if self.tensor_meta is not None: + tensor_shape = str(tuple(self.tensor_meta.shape)) + else: + tensor_shape = "unknown shape" + + return f"Spec({placement_str} on {tensor_shape})" + + @property + def shape(self) -> torch.Size: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.shape + + @property + def ndim(self) -> int: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return len(self.tensor_meta.shape) + + @property + def num_shards(self) -> int: + num_shards = 1 + for i, placement in enumerate(self.placements): + if placement.is_shard() or placement.is_interleaved_shard(): + num_shards *= self.mesh.size(i) + return num_shards + + @property + def dim_map(self) -> List[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 0, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. + """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = placement.dim + # NOTE: this might lead to other problems, pay attention. + # relax this check, allow shard one tensor dim twice. + # if r[shard_dim] > -1: + # raise ValueError( + # f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + # " DTensor operator implementation does not support things like hybrid" + # " sharding strategies yet (i.e. [Shard(0), Shard(0)])" + # ) + r[shard_dim] = i + return r + + @property + def sums(self) -> List[int]: + """ + sums is a property we derive from `placements` of the + distributed tensor. It simply return a list of ints where + sums[i] denotes the pending sum (partial) on mesh dim i + """ + return [idx for idx, placement in enumerate(self.placements) if placement.is_partial()] + + @classmethod + def from_dim_map( + cls, + mesh: DeviceMesh, + dim_map: List[int], + sums: List[int], + tensor_meta: Optional[TensorMeta] = None, + ) -> "DTensorSpec": + """ + Construct a DTensorSpec from dim_map list and pending sum. 
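+        For example (illustrative), on a 2-D mesh, dim_map=[-1, 0, -1] with sums=[1]
+        yields placements (Shard(1), Partial()): tensor dim 1 is sharded over mesh dim 0
+        and mesh dim 1 carries a pending sum.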
+ + Args: + mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec + dim_map (List[int]): a list of integer that represents sharding on each + tensor dimension, see `dim_map` property doc for details + sums (List[int]): a list of integer that represents the dist tensor have + pending sum on which device mesh dimension. + tensor meta (TensorMeta): DTensor metadata + + Return: + a class:`DTensorSpec` object + """ + # by default replicate on device mesh dims + placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)] + + # find all mesh dims that need pending reductions + for s in sums: + placements[s] = Partial() + + for i, m in enumerate(dim_map): + if m >= 0: + placement = placements[m] + if placement.is_shard(): + placement = cast(Shard, placement) + raise RuntimeError( + f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + ) + elif placement.is_partial(): + raise RuntimeError(f"DeviceMesh dimension {m} cannot be both shard and partial!") + placements[m] = Shard(i) + + return cls(mesh, tuple(placements), tensor_meta=tensor_meta) + + def is_replicated(self): + """ + return True if the current DTensorSpec replicates on all mesh dims (devices) + """ + return all(placement.is_replicate() for placement in self.placements) + + def is_partial(self): + """ + return True if the current DTensorSpec is partial on all mesh dims (devices) + """ + return len(self.placements) == 1 and self.placements[0].is_partial() diff --git a/internlm/checkpoint/vescale/save_state_dict.py b/internlm/checkpoint/vescale/save_state_dict.py new file mode 100644 index 000000000..7abf79878 --- /dev/null +++ b/internlm/checkpoint/vescale/save_state_dict.py @@ -0,0 +1,203 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +import os +import pickle +from typing import Optional, Tuple, List + + +import torch.distributed as dist +from torch.distributed.checkpoint.utils import _DistWrapper +from .filesystem import FileSystemWriter + + +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.storage import WriteResult +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from .meta_type import STATE_DICT_TYPE +from internlm.utils.logger import get_logger +import time +from concurrent.futures import Future +from internlm.core.context import global_context as gpc +from internlm.core.context import ParallelMode + + + + +logger = get_logger(__file__) + +from .vescale_planner import VeScaleSavePlanner + + +def save_state_dict( + state_dict: STATE_DICT_TYPE, + path: str, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, + async_io: bool = True, + last_write_futures: Future[List[WriteResult]] = None, + io_workers=None, +) -> Tuple[Metadata, Future[List[WriteResult]]]: + """ + [veScale version] Saves a distributed model in SPMD style. Fix sub-group storage. 
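+    Overview of the flow below: every rank builds a local SavePlan (local_step), the
+    coordinator merges all local plans into a global plan and Metadata (global_step)
+    inside distW.reduce_scatter, every rank writes its shards through FileSystemWriter
+    (write_data), and the coordinator persists the ``.metadata`` file (finish_checkpoint)
+    inside distW.all_reduce.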
+ Args and usage is the same as `torch.distributed.checkpoint.save_state_dict`. + """ + + # Step 0: create distributed world based on process group and coordinator rank + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if process_group: + distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metatadata = None + + storage_writer = FileSystemWriter(path) + + # Step 1: all processes create local write plan, + # then coordinator gathers all local plans and create global plan. + def local_step(): + logger.debug("Start local step of planning") + if isinstance(planner, VeScaleSavePlanner): + local_plan, p2p_tensors_info = planner.create_local_plan() + local_plan = storage_writer.prepare_local_plan(local_plan, p2p_tensors_info) + if gpc.get_local_rank(ParallelMode.PIPELINE) == 1: + print(f"save local_plan: {local_plan}", flush=True) + # if dist.get_rank() in [0, 1, 2, 3]: + # print(f"save local_plan {dist.get_rank()}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {local_plan}", flush=True) + else: + raise AssertionError("Unsupported planner for planning") + logger.debug("Finish local step of planning") + # print(f"local_plan {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {local_plan}", flush=True) + return local_plan + + def global_step(all_local_plans): + logger.debug("Start global step of planning") + nonlocal global_metatadata + assert planner is not None + # print(f"planner.coordinator_rank {gpc.get_global_rank()}: {planner.is_coordinator }", flush=True) + all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) + # print(f"save global_metatadata {dist.get_rank()}: {global_metatadata}", flush=True) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + logger.debug("End global step of planning") + # print(f"global_plan {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {all_local_plans}", flush=True) + return all_local_plans + + # Step 2: all processes write data from GPUs to pinned memory pool, then dump to local path + # then coordinator write meta-data to local path. + def write_data(async_io: bool = False, io_workers=io_workers): + logger.debug("Start writing data") + assert planner is not None + # print(f"central_plan {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {central_plan}", flush=True) + final_local_plan = planner.finish_plan(central_plan) + if isinstance(planner, VeScaleSavePlanner): + # Use pinned memory pool and mult_processing for dumping ckpt to local directory efficiently + all_write_futures = storage_writer.write_data(final_local_plan, planner, async_io, io_workers) + logger.debug("Finish writing data") + if async_io: + return all_write_futures + else: + # Gather write results. 
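+                # NOTE: these futures come from the ProcessPoolExecutor passed in as
+                # io_workers (see VeScaleCheckpointer), hence fut.result(); the
+                # commented-out fut.get() appears to be left over from an earlier
+                # multiprocessing-pool based implementation.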
+ values = [] + for fut in all_write_futures: + # values += fut.get() + values += fut.result() + return values + else: + raise AssertionError("Unsupported planner for writing data") + + def finish_checkpoint(all_results): + logger.debug("Start writing metadata") + assert global_metatadata is not None, f"rank: {distW.get_rank()} has no global_metadata" + storage_writer.finish(metadata=global_metatadata, results=all_results) + logger.debug("Finish writing metadata") + print(f"global_metatadata: {global_metatadata}", flush=True) + return global_metatadata + + assert planner is not None + planner.set_up_planner(state_dict, distW.is_coordinator) + storage_writer.set_up_storage_writer(distW.is_coordinator) + + # Wait for last write futures to finish. + if last_write_futures: + logger.info("Start waiting for last write events.") + last_write_start_time = time.time() + for fut in last_write_futures: + fut.result() + last_write_time = time.time() - last_write_start_time + logger.info(f"Finish waiting for last write events. Time cost: {last_write_time}s") + + # Each worker bypass the `reduce_scatter()` and `all_reduce()` if finding cached central_plan and metadata. + # NOTE: it fails when the plans of partial workers change while others keep the same. + logger.info("Start planning.") + plan_start_time = time.time() + cached_data = None + + + if isinstance(planner, VeScaleSavePlanner): + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + raise AssertionError("Unsupported planner for saving checkpoint") + # if isinstance(planner, VeScaleSavePlanner): #attn + # cached_data = planner.lookup_plan_meta() + # if cached_data: + # logger.debug("Plan cache hit. Reuse existing plan") + # central_plan, _ = cached_data + # _ = local_step() + # else: + # logger.debug("Plan cache miss. The model/optimizer appears for the first time.") + + # central_plan = distW.reduce_scatter("plan", local_step, global_step) + # else: + # raise AssertionError("Unsupported planner for saving checkpoint") + + + + plan_cost_time = time.time() - plan_start_time + logger.info(f"Finish planning. Time cost: {plan_cost_time}s") + + logger.info("Start storing") + store_local_start_time = time.time() + write_futures = [] + if isinstance(planner, VeScaleSavePlanner): + if cached_data: + logger.debug("Metdata cache hit. Reuse existing metadata") + _, final_storage_metadata = cached_data + write_results = write_data(async_io=async_io) + # Be sure to write cache metadata to .metadata file + # Otherwises only the first checkpoint has .metadata + # which leads to error when loading other checkpoints + if distW.is_coordinator: + with (storage_writer.path / ".metadata.tmp").open("wb") as metadata_file: + pickle.dump(final_storage_metadata, metadata_file) + os.fsync(metadata_file.fileno()) + + (storage_writer.path / ".metadata.tmp").rename(storage_writer.path / ".metadata") + + if async_io: + write_futures = write_results + else: + logger.debug("Metadata cache miss. The model/optimizer appears for the first time.") + # First time do synchronous storing to get final_storage_metatdata. + # Determine which communication topology to use. 
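+            # NOTE: distW.all_reduce runs write_data on every rank, gathers the
+            # WriteResults on the coordinator, lets finish_checkpoint persist the
+            # ``.metadata`` file there, and broadcasts the resulting Metadata back so
+            # every rank returns the same final_storage_metadata.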
+ final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint) + assert central_plan is not None + assert final_storage_metadata is not None + # planner.cache_plan_meta(central_plan, final_storage_metadata) #attn + else: + raise AssertionError("Unsupported planner for writing data and metadata") + store_local_cost_time = time.time() - store_local_start_time + logger.info(f"Finish storing. Time cost: {store_local_cost_time}s") + + return final_storage_metadata, write_futures diff --git a/internlm/checkpoint/vescale/vescale_checkpointer.py b/internlm/checkpoint/vescale/vescale_checkpointer.py new file mode 100644 index 000000000..db19e284d --- /dev/null +++ b/internlm/checkpoint/vescale/vescale_checkpointer.py @@ -0,0 +1,259 @@ +################################################################################ +# +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +################################################################################ + +from concurrent.futures import ProcessPoolExecutor +from .base_checkpointer import BaseCheckpointer +from .meta_type import CheckpointState, MODEL_STR, OPTIMIZER_STR +from .save_state_dict import save_state_dict +from .load_state_dict import load_state_dict +from .vescale_planner import VeScaleSavePlanner, VeScaleLoadPlanner +from .devicemesh_api import VESCALE_DEVICE_MESH +from . 
import bfile +import os +# from .distributed_optimizer import initialize_optimizer_state +import torch.distributed as dist +from internlm.utils.logger import get_logger +import atexit +from internlm.core.context import global_context as gpc +from internlm.core.context import ParallelMode + + +logger = get_logger(__file__) + +VESCALE_SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} +NUM_IO_WORKER = 1 + + +def deduplicate_2d_list(lst): + seen = set() + deduplicated_list = [] + for item in lst: + # Convert the inner list to a tuple for hashing + tuple_item = tuple(sorted(item)) # Sorting to treat [1, 2] and [2, 1] as the same + if tuple_item not in seen: + seen.add(tuple_item) + # Convert back to list to preserve original type + deduplicated_list.append(item) + return deduplicated_list + + +def get_optim_ckpt_process_group(): + # Get the process group based on current rank + # The processes with same pipeline stage ID + # are in the same process group + device_mesh = VESCALE_DEVICE_MESH.get() + sub_mesh = device_mesh.get_submesh(mesh_dims=["TP", "DP"]) + two_dim_list = sub_mesh.mesh.tolist() + flatten_rank_list = [item for sublist in two_dim_list for item in sublist] + all_flatten_lists = [[] for _ in range(dist.get_world_size())] + dist.all_gather_object(all_flatten_lists, flatten_rank_list) + all_flatten_lists = deduplicate_2d_list(all_flatten_lists) + my_rank = dist.get_rank() + pg = None + for rank_list in all_flatten_lists: + new_pg = dist.new_group(ranks=flatten_rank_list) + if my_rank in rank_list: + pg = new_pg + return pg + + +class VeScaleCheckpointer(BaseCheckpointer): + """ + The Checkpointer class for VeScale, A PyTorch Native Auto Parallelism Framework + """ + + save_planner = VeScaleSavePlanner() + load_planner = VeScaleLoadPlanner() + + optim_ckpt_proces_group = None + for key in VESCALE_SUPPORTED_TYPES: + BaseCheckpointer.state_io_workers[key] = ProcessPoolExecutor(max_workers=NUM_IO_WORKER) + BaseCheckpointer.state_write_futures[key] = [] + + @classmethod + def save( + cls, + path: str, + checkpoint_state: CheckpointState, + async_checkpoint: bool = False, + ): + """ + async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, + i.e. after dumping tensors from GPU memory to Host memory, + the training program can continue training immediately. + Then vescale.checkpoint will serialize tensors and dumping to the persistent storage asynchronously. + """ + # Check if we support saving the components + for key in checkpoint_state.keys(): + if key not in VESCALE_SUPPORTED_TYPES: + raise ValueError(f"{key} is not supported by VeScaleCheckpointer") + + # Start saving checkpoint + for key, value in checkpoint_state.items(): + if key == MODEL_STR: + # Get model path + model_path = os.path.join(path, MODEL_STR) + # Create a "model" folder on under root path + if dist.get_rank() == 0: + bfile.makedirs(model_path) + dist.barrier() + # Save model. + _, new_write_futures = save_state_dict( + state_dict=value.state_dict(), + path=model_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.save_planner, + async_io=async_checkpoint, + last_write_futures=cls.state_write_futures[MODEL_STR], + io_workers=cls.state_io_workers[MODEL_STR], + ) + # Record new write futures. 
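+                # NOTE: keeping these futures on the class lets the next save() call pass
+                # them as last_write_futures, so save_state_dict waits for the previous
+                # asynchronous write to finish before planning a new one.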
+ cls.state_write_futures[MODEL_STR] = new_write_futures + dist.barrier() + elif key == OPTIMIZER_STR: + # Create a "optimizer" folder on under root path + # to save different parts of optimizer + optim_root_path = os.path.join(path, OPTIMIZER_STR) + if dist.get_rank() == 0: + bfile.makedirs(optim_root_path) + dist.barrier() + # Get process group for saving optimizer, + # All processes with the same pipeline rank are in the same pg + if not cls.optim_ckpt_proces_group: + cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() + + # Get optimizer path based on PP rank + # pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() #attn + pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + optimizer_path = os.path.join(optim_root_path, f"pp_{pp_rank}") + # Create optimizer folder on under root path + if dist.get_rank(cls.optim_ckpt_proces_group) == 0: + bfile.makedirs(optimizer_path) + dist.barrier() + # Save optimizer + _, new_write_futures = save_state_dict( + state_dict=value.state_dict(), + path=optimizer_path, + process_group=cls.optim_ckpt_proces_group, + coordinator_rank=0, + no_dist=False, + planner=cls.save_planner, + async_io=async_checkpoint, + last_write_futures=cls.state_write_futures[OPTIMIZER_STR], + io_workers=cls.state_io_workers[OPTIMIZER_STR], + ) + # Record new write futures. + cls.state_write_futures[OPTIMIZER_STR] = new_write_futures + + @classmethod + def load( + cls, + path: str, + checkpoint_state: CheckpointState, + broadcast_checkpoint: bool = False, + ): + """ + broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group + then broadcast tensors to other data parallel process group using GPUs + to reduce the file system access + For example, when data parellel size = 2, + processes with data parallel rank = 0 load model from file system + then broadcast it to processes with data parallel rank = 1 + """ + # Add warning + if bfile.is_local_path(path): + logger.warning( + "The local path for checkpointing should be accessible to all ranks. It can be a NFS/FUSE path" + ) + # Check if we support loading the component. 
+ for key in checkpoint_state.keys(): + if key not in VESCALE_SUPPORTED_TYPES: + raise ValueError(f"{key} is not supported by VeScaleCheckpointer") + + # Start loading checkpoint + for key, value in checkpoint_state.items(): + if key == MODEL_STR: + # Get model path + model_path = os.path.join(path, MODEL_STR) + # Get model state dictionary + model_state = value.state_dict() + p = ", ".join(map(lambda item: f"Key: {item[0]}, Shape: {item[1].shape}", model_state.items())) + # print(f"model_state {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {p})", flush=True) + # Set process group + if broadcast_checkpoint: + assert False + model_load_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups() + else: + model_load_process_group = None + # Load model + load_state_dict( + state_dict=model_state, + path=model_path, + process_group=model_load_process_group, + coordinator_rank=0, + no_dist=False, + planner=cls.load_planner, + broadcast_tensors=broadcast_checkpoint, + ) + # Load back to model + value.load_state_dict(model_state) #att + elif key == OPTIMIZER_STR: + assert False + # Get process group for loading optimizer, + # All processes with the same pipeline rank are in the same pg + if not cls.optim_ckpt_proces_group: + cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() + # Get optimizer path based on TP and PP ranks + # pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() #attn + pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + optimizer_path = os.path.join(path, f"{OPTIMIZER_STR}", f"pp_{pp_rank}") + # Initialize optimizer states + # initialize_optimizer_state(value) + # Get optimizer state + optimizer_state = value.state_dict() + # Load optimizer state dictionary + load_state_dict( + state_dict=optimizer_state, + path=optimizer_path, + process_group=cls.optim_ckpt_proces_group, + coordinator_rank=0, + no_dist=False, + planner=cls.load_planner, + broadcast_tensors=False, + ) + # Load back to optimizer + value.load_state_dict(optimizer_state) + dist.barrier() + + @classmethod + def __cleanup(cls): + """ + Wait for all write futures to finish before exit, then do the cleanup works. + + WARNING: this method cannot be called by the users. + """ + cls.save_planner.clear_cache() + BaseCheckpointer._cleanup_futures() + + @classmethod + def _register_cleanup(cls): + atexit.register(VeScaleCheckpointer.__cleanup) + + +VeScaleCheckpointer._register_cleanup() diff --git a/internlm/checkpoint/vescale/vescale_planner.py b/internlm/checkpoint/vescale/vescale_planner.py new file mode 100644 index 000000000..91fce5e91 --- /dev/null +++ b/internlm/checkpoint/vescale/vescale_planner.py @@ -0,0 +1,268 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
+################################################################################ +import io +import dataclasses +import torch +from typing import Any, Dict, Union, List, Tuple, Optional +from torch.distributed.checkpoint.default_planner import ( + DefaultSavePlanner, + DefaultLoadPlanner, +) + +import mmh3 + +from .common import P2PTensorsInfo, sort_rank_ranges, PlanLRUCache, custom_dedup_tensors +import math +import torch.distributed as dist +from torch.distributed.checkpoint.planner import SavePlan, LoadPlan, WriteItem, ReadItem +from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata +from .distributed_optimizer import OptimizerStateSpec +# from vescale.dtensor import DTensor + +from .vescale_planner_helpers import _create_write_items, _create_read_items, find_state_dict_object +from .devicemesh_api import VESCALE_DEVICE_MESH +from .meta_type import STATE_DICT_STR +from internlm.utils.logger import get_logger +from internlm.core.context import global_context as gpc +from internlm.core.context import ParallelMode +from torch.distributed.checkpoint.planner import WriteItemType + + +logger = get_logger(__file__) +__all__ = [ + "VeScaleSavePlanner", + "VeScaleLoadPlanner", + "create_default_local_load_plan", + "create_default_local_save_plan", +] + +class DTensor: + pass + + +class VeScaleLoadPlanner(DefaultLoadPlanner): + """ + A planner class for loading vescale checkpoint using PyTorch DCP + """ + + def __init__(self): + super().__init__() + + def create_local_plan(self) -> LoadPlan: + return create_default_local_load_plan(self.state_dict, self.metadata) + + def resolve_tensor(self, read_item: ReadItem): + tensor = self.lookup_tensor(read_item.dest_index) + return self.transform_tensor(read_item, tensor) + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + """ + This is an extension from the planner interface to make it easy to extend the default planner + """ + return find_state_dict_object(self.state_dict, index) + +from internlm.train.pipeline import map_fqn_local_to_global + +def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadata) -> LoadPlan: + """ + A function for creating local loading plan for loading checkpoint + """ + # print(f"metadata: {metadata}", flush=True) + requests = [] + for fqn, obj in state_dict.items(): + md_fqn = fqn + if fqn in map_fqn_local_to_global: + md_fqn = map_fqn_local_to_global[fqn] + print(f"map_fqn_local_to_global {gpc.get_local_rank(ParallelMode.PIPELINE)}: {fqn}, {md_fqn}", flush=True) + md = metadata.state_dict_metadata[md_fqn] + if isinstance(obj, DTensor): + assert False + if obj.device_mesh.get_coordinate() is not None: + requests += _create_read_items(fqn, md, obj) + elif isinstance(obj, OptimizerStateSpec): + assert False + # If the state is distributed on multiple dp ranks + # Read with local_shape, then in DOptimizer then + # get flaaten to 1D and get the part belonging to current dp rank + if obj.dp_ranks_ranges: + obj.local_tensor = torch.zeros( + obj.local_shape, dtype=obj.local_tensor.dtype, device=obj.local_tensor.device + ) + requests += _create_read_items(fqn, md, obj) + else: + # If the state is owned by only one dp rank + # Read directly + obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) + requests += _create_read_items(fqn, md, obj) + else: + item = _create_read_items(fqn, md_fqn, md, obj) + # print(f"_create_read_items {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {item}", flush=True) + requests += item + return LoadPlan(requests) + + 
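+# NOTE on the FQN remapping above (hypothetical example names): with pipeline parallelism,
+# every stage numbers its decoder blocks locally from 0, while the checkpoint metadata is
+# keyed by global layer indices. map_fqn_local_to_global, built in
+# internlm/train/pipeline.py from model.first_layer, bridges the two; e.g. on a stage whose
+# first global layer is 4, the local key "layers.0.attention.wq.weight" is looked up in the
+# metadata under "layers.4.attention.wq.weight" before _create_read_items() is called.
+
+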
+class VeScaleSavePlanner(DefaultSavePlanner): + """ + A planner class for saving vescale checkpoint using PyTorch DCP + """ + + def __init__(self): + super().__init__() + self._plan_cache = PlanLRUCache() + + def resolve_data(self, write_item: WriteItem, fqn=None) -> Union[torch.Tensor, io.BytesIO]: + assert write_item.type != WriteItemType.BYTE_IO + object = self.lookup_object(write_item.index, fqn) + return self.transform_object(write_item, object) + + def create_local_plan(self) -> Tuple[SavePlan, P2PTensorsInfo]: + plan, p2p_tensors_info = create_default_local_save_plan(self.state_dict, self.is_coordinator) + # print(f"save before replace local_plan {dist.get_rank()}: {plan}", flush=True) + if self.flatten_state_dict: + plan = dataclasses.replace(plan, planner_data=self.mappings) + self.plan = plan + return self.plan, p2p_tensors_info + + def lookup_object(self, index: MetadataIndex, fqn=None) -> Any: + return find_state_dict_object(self.state_dict, index, fqn) + + def lookup_plan_meta(self) -> Optional[Tuple[SavePlan, Metadata]]: + if not hasattr(self, STATE_DICT_STR): + return None + else: + device_mesh = VESCALE_DEVICE_MESH.get() + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + return self._plan_cache.get(plan_key) + + def cache_plan_meta(self, new_plan: SavePlan, new_metadata: Metadata) -> None: + device_mesh = VESCALE_DEVICE_MESH.get() + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + self._plan_cache.put(plan_key, new_plan, new_metadata) + + def clear_cache(self) -> None: + self._plan_cache.clear() + + def dedup_plans(self, all_plans: List[SavePlan]) -> List[SavePlan]: + # Use customized deduplicate function for load balance + all_plans = custom_dedup_tensors(all_plans) + return all_plans + + def create_dedup_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + # Disable DCP's dedup replicated tensors function + self.dedup_replicated_tensors = False + rst_value = super().create_global_plan(all_plans) + return rst_value + + def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + # Disable DCP's dedup replicated tensors function + self.dedup_replicated_tensors = False + # Use customized deduplicate function for load balance + all_plans = custom_dedup_tensors(all_plans) #att + rst_value = super().create_global_plan(all_plans) + return rst_value + + +def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan: + """ + A function for creating local saving plan for saving checkpoint. + """ + requests = [] + # Key: fqn + # Value: dictionary (Key is the process rank, value is tensor to receive) + recv_tensors = {} + + send_p2p_reqs = [] + recv_p2p_reqs = {} + + for fqn, obj in state_dict.items(): + # Since DTensor supports submesh, adding extra check to ensure _create_write_items() + # gets called only when the current rank is part of the mesh for the corresponding DTensor. 
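+        # NOTE: in this InternLM port the DTensor and OptimizerStateSpec branches below
+        # are effectively disabled (assert False); model states arrive as plain
+        # torch.Tensor objects, and their shard offset/complete_size are taken from
+        # map_layer_attr inside _create_write_item_for_tensor().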
+ if isinstance(obj, DTensor): + assert False + if obj.device_mesh.get_coordinate() is not None: + requests += _create_write_items(fqn, obj) + elif isinstance(obj, OptimizerStateSpec): + assert False #attn + # Create write requests if the process is the real writer + if obj.dp_ranks_ranges: + process_list = [] + for rank, param_range in obj.dp_ranks_ranges.items(): + process_list.append((rank, len(param_range))) + sorted_list = sort_rank_ranges(process_list) + writer_rank = sorted_list[mmh3.hash(fqn) % len(sorted_list)][0] + send_ops_to_start = [] + recv_ops_to_start = {} + # Case 1: I am writer + # Receive tensors + logger.debug(f"fqn={fqn} is a tensor across dp ranks. writer rank={writer_rank}") + if dist.get_rank() == writer_rank: + recv_tensors[fqn] = {} + for k, param_range in obj.dp_ranks_ranges.items(): + if k != dist.get_rank(): + recv_tensor = torch.zeros( + (len(param_range),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device + ) + recv_op = dist.P2POp( + op=dist.irecv, + tensor=recv_tensor, + peer=k, + group=gpc.get_group(ParallelMode.DATA), #att + ) + recv_tensors[fqn][k] = (recv_tensor, param_range) + recv_ops_to_start[k] = recv_op + else: + # Case 2: I am not writer + # Send my tensor + send_op = dist.P2POp( + op=dist.isend, + tensor=obj.local_tensor, + peer=writer_rank, + group=gpc.get_group(ParallelMode.DATA), #att + ) + send_ops_to_start.append(send_op) + + send_reqs = [] + recv_reqs = [] + if send_ops_to_start: + send_reqs = dist.batch_isend_irecv(send_ops_to_start) + if recv_ops_to_start: + recv_reqs = dist.batch_isend_irecv(list(recv_ops_to_start.values())) + + if send_reqs: + send_p2p_reqs.extend(send_reqs) + + if recv_reqs: + recv_p2p_reqs[fqn] = recv_reqs + else: + obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) + requests += _create_write_items(fqn, obj) + elif isinstance(obj, (torch.Tensor)) or is_coordinator: + item = _create_write_items(fqn, obj) + # print(f"_create_write_items {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {item}", flush=True) + requests += item + # if dist.get_rank() == 1: + # print(f"requests: {requests}", flush=True) + + # Padding the states across DP ranks + # Merge the tensors later + writer_rank = dist.get_rank() + for fqn in recv_tensors.keys(): + obj = state_dict[fqn] + new_local_tensor = torch.zeros( + (math.prod(obj.local_shape),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device + ) + new_local_tensor[obj.dp_ranks_ranges[writer_rank].start : obj.dp_ranks_ranges[writer_rank].end] = ( + obj.local_tensor + ) + obj.local_tensor = new_local_tensor + + obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) + requests += _create_write_items(fqn, obj) + return SavePlan(requests), P2PTensorsInfo(recv_tensors, send_p2p_reqs, recv_p2p_reqs) diff --git a/internlm/checkpoint/vescale/vescale_planner_helpers.py b/internlm/checkpoint/vescale/vescale_planner_helpers.py new file mode 100644 index 000000000..9c4e8903e --- /dev/null +++ b/internlm/checkpoint/vescale/vescale_planner_helpers.py @@ -0,0 +1,288 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
+################################################################################ + +from typing import Any, List +import torch +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed.checkpoint.planner import WriteItem, ReadItem, WriteItemType, LoadItemType, TensorWriteData +from torch.distributed.checkpoint.metadata import ( + STATE_DICT_TYPE, + STORAGE_TYPES, + MetadataIndex, + ChunkStorageMetadata, + BytesStorageMetadata, + TensorStorageMetadata, +) +from torch.distributed._shard.sharded_tensor import TensorProperties +from torch.distributed.checkpoint.resharding import ( + _check_shard_metadata_pair_overlap, + _shards_get_overlap_region_wrt_saved_tensor, +) +from internlm.core.context import global_context as gpc +from internlm.core.context import ParallelMode +# from vescale.dtensor import DTensor +from ._utils import compute_local_shape, compute_local_offset +from .distributed_optimizer import OptimizerStateSpec +from internlm.utils.logger import get_logger +logger = get_logger(__file__) + +class DTensor: + pass + + +def _create_write_items_for_dtensor(fqn, tensor: DTensor) -> WriteItem: + assert False + sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) + offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) + + return WriteItem( + index=MetadataIndex(fqn=fqn, offset=offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=sizes), + properties=TensorProperties.create_from_tensor(tensor._local_tensor), # keep out of autograd + size=tensor.size(), + ), + ) + + +def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: + assert False + sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) + offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) + return ChunkStorageMetadata(offsets=offsets, sizes=sizes) + +from internlm.train.pipeline import map_layer_attr, map_fqn_local_to_global +def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: + offsets = torch.Size([0] * len(tensor.size())) + size = tensor.size() + if "norm" not in fqn: + assert fqn in map_layer_attr, f"{fqn}" + if fqn in map_layer_attr: + offsets = torch.Size(map_layer_attr[fqn]['offset']) + size = torch.Size(map_layer_attr[fqn]['complete_size']) + if fqn in map_fqn_local_to_global: + fqn = map_fqn_local_to_global[fqn] + + # if offsets[0] != 0: + # k = offsets[0] // gpc.get_local_rank(ParallelMode.TENSOR) + # assert child.weight.complete_size[0] % k == 0 and child.weight.complete_size[0] // k == gpc.get_world_size(ParallelMode.TENSOR), f"{child.weight.complete_size}, {child.weight.offset}" + # if child.weight.offset[1] != 0: + # k = child.weight.offset[1] // gpc.get_local_rank(ParallelMode.TENSOR) + # assert child.weight.complete_size[1] % k == 0 and child.weight.complete_size[1] // k == gpc.get_world_size(ParallelMode.TENSOR), f"{child.weight.complete_size}, {child.weight.offset}" + + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), + properties=TensorProperties.create_from_tensor(tensor), + size=size, + ), + ) + + +def _create_write_item_for_optimizer_state(fqn, object: OptimizerStateSpec) -> WriteItem: + sizes = object.local_shape + offsets = object.global_offset + + return WriteItem( + 
index=MetadataIndex(fqn=fqn, offset=offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=sizes), + properties=TensorProperties.create_from_tensor(object.local_tensor), + size=object.global_shape, + ), + ) + + +def _create_write_item_for_bytesio(fqn: str, bytes: Any): + return WriteItem( + index=MetadataIndex(fqn), + type=WriteItemType.BYTE_IO, + ) + + +def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: + if isinstance(object, DTensor): + # assert False + return [_create_write_items_for_dtensor(fqn, object)] + elif isinstance(object, torch.Tensor): + return [_create_write_item_for_tensor(fqn, object)] + elif isinstance(object, OptimizerStateSpec): + return [_create_write_item_for_optimizer_state(fqn, object)] + else: + return [_create_write_item_for_bytesio(fqn, object)] + + +def _create_read_item_for_tensor(dest_index, dest_offsets, storage_index, storage_offsets, lengths): + return ReadItem( + type=LoadItemType.TENSOR, + dest_index=dest_index, + dest_offsets=torch.Size(dest_offsets), + storage_index=storage_index, + storage_offsets=torch.Size(storage_offsets), + lengths=torch.Size(lengths), + ) + + +def create_read_items_for_chunk_list( + fqn: str, + md_fqn: str, + checkpoint_md: TensorStorageMetadata, + local_chunks: List[ChunkStorageMetadata], +) -> List[ReadItem]: + """ + Creates a list of ``ReadItem`` based on the checkpoint and local chunks. + + This applies the resharding algorithm and computes the reads needed + to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. + + Args: + fqn (str) : The state_dict FQN to pass to ``ReadItem``. + checkpoint_md (TensorStorageMetadata): metadata for a given tensor + from a checkpoint. + local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be + loaded. + + Returns: + A list of ``ReadItem`` that will satisfy all input chunks. 
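+
+    Example (illustrative, hypothetical shapes): if the checkpoint stores an [8, 4] tensor
+    as two chunks with offsets (0, 0) and (4, 0), sizes [4, 4] each, and the local chunk
+    to load has offset (2, 0) and size [4, 4], two ``ReadItem`` entries are produced: one
+    reads rows 2..3 of the first saved chunk into local rows 0..1, the other reads rows
+    0..1 of the second saved chunk into local rows 2..3.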
+ """ + read_items = [] + # this is a naive quadratic algo that can be optimized later + for idx, shard in enumerate(local_chunks): + for storage_idx, storage_md in enumerate(checkpoint_md.chunks): + if not _check_shard_metadata_pair_overlap(shard, storage_md): + continue + + storage_offsets = [] + dest_offsets = [] + lengths = [] + for ( + dim, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor(saved_shard=storage_md, current_shard=shard): + storage_offsets.append(offset_for_saved_tensor) + dest_offsets.append(offset_for_current_tensor) + lengths.append(length) + + read_items.append( + _create_read_item_for_tensor( + dest_index=MetadataIndex(fqn, shard.offsets, idx), + dest_offsets=dest_offsets, + storage_index=MetadataIndex(md_fqn, storage_md.offsets, storage_idx), + storage_offsets=storage_offsets, + lengths=lengths, + ) + ) + return read_items + + +def _create_chunk_from_tensor(fqn, tensor: torch.Tensor) -> ChunkStorageMetadata: + + # sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) + # offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) + # return ChunkStorageMetadata(offsets=offsets, sizes=sizes) + + offsets = torch.Size([0] * len(tensor.size())) + if "norm" not in fqn: + assert fqn in map_layer_attr, f"{fqn}" + offsets = torch.Size(map_layer_attr[fqn]['offset']) + + + # return ChunkStorageMetadata(offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()) + return ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()) + + + +def _create_read_item_for_byteio(dest_index, dest_offset, storage_index, storage_offset, length): + return ReadItem( + type=LoadItemType.BYTE_IO, + dest_index=dest_index, + dest_offsets=torch.Size((dest_offset,)), + storage_index=storage_index, + storage_offsets=torch.Size((storage_offset,)), + lengths=torch.Size((length,)), + ) + + +def _create_chunk_from_optimizer_spec(obj: OptimizerStateSpec) -> ChunkStorageMetadata: + return ChunkStorageMetadata(offsets=obj.global_offset, sizes=obj.local_shape) + + +def _create_read_items(fqn: str, md_fqn, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: + if not isinstance(md, BytesStorageMetadata): + if isinstance(obj, DTensor): + assert False + local_chunks = [_create_chunk_from_dtensor(obj)] + elif isinstance(obj, torch.Tensor): + local_chunks = [_create_chunk_from_tensor(fqn, obj)]#att + # print(f"local_chunks {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)} {fqn}: {local_chunks}", flush=True) + elif isinstance(obj, OptimizerStateSpec): + local_chunks = [_create_chunk_from_optimizer_spec(obj)] + else: + raise ValueError( + f"Invalid checkpoint metadata for {fqn}, " + f"expected BytesStorageMetadata but found {type(md)}" + ) + return create_read_items_for_chunk_list(fqn, md_fqn, md, local_chunks) + else: + assert False + return [ + _create_read_item_for_byteio( + dest_index=MetadataIndex(fqn), + dest_offset=0, + storage_index=MetadataIndex(fqn), + storage_offset=0, + length=0, + ) + ] + + +def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: + return ChunkStorageMetadata( + offsets=torch.Size(shard_md.shard_offsets), + sizes=torch.Size(shard_md.shard_sizes), + ) + + +def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: + if isinstance(tensor, DTensor): + assert False + return tensor.to_local() + if index.offset is not None: + # special case looking up a tensor by origin + if index.offset == 
torch.Size([0] * len(tensor.size())): + return tensor + raise ValueError(f"FQN: '{index.fqn}' is not a DTensor, can't find by offset: '{index.offset}'") + return tensor + + +def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex, fqn=None) -> Any: + # Called when real writing happened + # The filesystem writer calls resolve_data , then it will + # call find_state_dict_object + if fqn is None: + fqn = index.fqn + if fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{fqn}'") + obj = state_dict[fqn] + + # if isinstance(obj, torch.Tensor): #att + # return find_tensor_shard(obj, index) + if isinstance(obj, OptimizerStateSpec): + return obj.local_tensor + # elif index.offset is not None: + # raise ValueError( + # f"FQN: '{index.fqn}' is not a DTensor, it is a {type(obj)} can't find by offset: '{index.offset}'" + # ) + return obj diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 93fcd6b23..3585e2e46 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -55,6 +55,7 @@ def __init__( self.embed_dim_per_partition = embedding_dim self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition + self.offset = [self.vocab_start_index, 0] else: assert embedding_dim % parallel_size == 0, f"{embedding_dim} is not divisible by {parallel_size}" @@ -62,12 +63,15 @@ def __init__( self.embed_dim_per_partition = embedding_dim // parallel_size self.vocab_start_index = 0 self.vocab_end_index = self.num_embeddings_per_partition + self.offset = [0, self.embed_dim_per_partition * gpc.get_local_rank(ParallelMode.TENSOR)] self.weight = nn.Parameter( torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype) ) - + self.complete_size = [num_embeddings, embedding_dim] setattr(self.weight, "is_embedding_param", True) + setattr(self.weight, "offset", self.offset) + setattr(self.weight, "complete_size", [num_embeddings, embedding_dim]) def forward(self, input_: Tensor) -> Tensor: if self.vocab_parallel and not is_using_isp(): diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 29070b429..db181e05d 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -597,6 +597,7 @@ def __init__( world_size = gpc.get_world_size(parallel_mode) rank = gpc.get_local_rank(parallel_mode) + self.offset = None if split_mode != "none": split_features = out_features if split_mode == "column" else in_features @@ -606,14 +607,29 @@ def __init__( mod = multiple % world_size # The first @mod ranks get @div + 1 copies, the rest get @div copies local_multiple = div + int(rank < mod) - + # if parallel_mode == ParallelMode.TENSOR: + # print(f"ParallelLinearWithCommExt {split_mode}: infe={in_features}, outfe={out_features}, split={local_multiple * multiple_of}, local_multiple={local_multiple}, multiple_of={multiple_of}", flush=True) if split_mode == "column": super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) + self.offset = [rank * local_multiple * multiple_of, 0] elif split_mode == "row": super().__init__(local_multiple * multiple_of, out_features, bias=bias, device=device, dtype=dtype) + self.offset = [0, rank * local_multiple * multiple_of] else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) - + self.complete_size = [out_features, 
in_features] + setattr(self.weight, "offset", self.offset) + setattr(self.weight, "complete_size", [out_features, in_features]) + if self.weight.offset[0] != 0: + k = self.weight.offset[0] // rank + assert self.weight.complete_size[0] % k == 0 and self.weight.complete_size[0] // k == gpc.get_world_size(parallel_mode), f"{self.weight.complete_size}, {self.weight.offset}" + else: + assert rank == 0 or self.weight.size()[0] == self.weight.complete_size[0], f"{rank}, {self.weight.size()}, {self.weight.complete_size}, {self.weight.offset} \n split_mode={split_mode}, in_features={in_features}, out_features={out_features}, multiple_of={multiple_of}, multiple={multiple}, world_size={world_size}, div={div}, mod={mod}, local_multiple={local_multiple}" + if self.weight.offset[1] != 0: + k = self.weight.offset[1] // rank + assert self.weight.complete_size[1] % k == 0 and self.weight.complete_size[1] // k == gpc.get_world_size(parallel_mode), f"{self.weight.complete_size}, {self.weight.offset}" + else: + assert rank == 0 or self.weight.size()[1] == self.weight.complete_size[1], f"{rank}, {self.weight.size()}, {self.weight.complete_size}, {self.weight.offset}" def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index a8ef77bc1..a3ad146db 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -72,6 +72,10 @@ def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) - f"{prefix}wv.weight", f"{prefix}wqkv.weight", ) + from internlm.train.pipeline import map_layer_attr + map_layer_attr[wq_name] = map_layer_attr[fused_name] + map_layer_attr[wk_name] = map_layer_attr[fused_name] + map_layer_attr[wv_name] = map_layer_attr[fused_name] if module.enable_qkv_fusion: state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight( @@ -465,8 +469,9 @@ def __init__( self._register_load_state_dict_pre_hook( partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True ) - self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim)) + # self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim)) else: + assert False self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 5907a4e30..2b85b452c 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -93,6 +93,7 @@ ) from internlm.utils.timeout import llm_timeout from internlm.utils.utils import TensorParallelMode +from internlm.model.ops.norm import RMSNorm try: import torch_npu @@ -116,17 +117,31 @@ logger = get_logger(__file__) internlm_accelerator = get_accelerator() +map_layer_attr = {} +map_fqn_local_to_global = {} +map_fqn_global_to_local = {} + + +def recover_pipeline_idx_for_layers(model, idx): + start_id = model.first_layer def set_param_unique_tracking_name(model): + # print(f"first_layer {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {model.first_layer}, {model.last_layer}", flush=True) for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): # Important: only 
works for llama-class models childrens = chunk.named_children() - for _, children in childrens: + for children_name, children in childrens: if isinstance(children, nn.ModuleList): for idx, block in enumerate(children): for name, child in block.named_modules(): + if name == "": + continue + full_name = f"{chunk_id}.{idx}.{name}" + parts = f"{full_name}.weight".split('.', 2) + original_id = model.first_layer + idx + map_fqn = f"{children_name}.{original_id}." + '.'.join(parts[2:]) + result = f"{children_name}." + '.'.join(parts[1:]) if isinstance(child, (ParallelLinearWithCommExt)): - full_name = f"{chunk_id}.{idx}.{name}" setattr( child.weight, "tracking_name", @@ -138,19 +153,56 @@ def set_param_unique_tracking_name(model): "tracking_name", f"{full_name}.bias", ) + + + # print(result, flush=True) + assert hasattr(child, "offset"), f"{child}" + + # print(f"layer_name {gpc.get_local_rank(ParallelMode.PIPELINE)}: {children_name}, {result}", flush=True) + + # recover_pipeline_idx_for_layers + + # print(f"original_id {gpc.get_local_rank(ParallelMode.PIPELINE)}: {model.first_layer}, {idx}, {original_id}, {name, child}, {map_fqn}") + + if "4.attention_norm" in map_fqn: + assert False + + map_fqn_local_to_global[result] = map_fqn + map_fqn_global_to_local[map_fqn] = result + # print(f"map_pp_layer_fqn {gpc.get_local_rank(ParallelMode.PIPELINE)}: {map_pp_layer_fqn}", flush=True) + + assert result not in map_layer_attr, f"{map_layer_attr} exists" + + map_layer_attr[result] = {'offset': getattr(child, "offset", [0] * len(child.weight.size())), 'complete_size': getattr(child, "complete_size", child.weight.size())} + # print(f"child.weight {gpc.get_local_rank(ParallelMode.TENSOR)}: {result}, {child.offset}, {child.weight.shape}", flush=True) + elif isinstance(child, (RMSNorm)): + print(f"map_fqn {gpc.get_local_rank(ParallelMode.PIPELINE)}: {map_fqn}", flush=True) + map_fqn_local_to_global[result] = map_fqn + map_fqn_global_to_local[map_fqn] = result + else: + full_name = f"{chunk_id}.{children_name}" + result = f"{children_name}.weight" + # print(f"result: {result}", flush=True) if isinstance(children, Embedding1D): setattr( children.weight, "tracking_name", - f"{chunk_id}_embedding.weight", + f"{chunk_id}_embeddings.weight", ) + assert result not in map_layer_attr, f"{map_layer_attr} exists" + # map_layer_attr[result] = {'offset': children.offset, 'complete_size': children.weight.complete_size} else: setattr( children.weight, "tracking_name", - f"{chunk_id}_head.weight", + f"{full_name}.weight", ) + assert result not in map_layer_attr, f"{map_layer_attr} exists" + # map_layer_attr[result] = {'offset': getattr(children, "offset", [0] * len(children.weight.size())), 'complete_size': getattr(children, "complete_size", children.weight.size())} + map_layer_attr[result] = {'offset': getattr(children, "offset", [0] * len(children.weight.size())), 'complete_size': getattr(children, "complete_size", children.weight.size())} + + # print(f"map_layer_attr global={gpc.get_global_rank()}, pp={gpc.get_local_rank(ParallelMode.PIPELINE)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}: {map_layer_attr}", flush=True) def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): From 774d32ff5d717b2e61e49879c02cde408119202f Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 20 Dec 2024 20:28:01 +0800 Subject: [PATCH 02/12] support optm --- internlm/checkpoint/checkpoint_manager.py | 53 ++- internlm/checkpoint/vescale/filesystem.py | 143 +++++-- .../checkpoint/vescale/load_state_dict.py | 4 +- 
.../checkpoint/vescale/save_state_dict.py | 16 +- .../vescale/vescale_checkpointer.py | 68 ++- .../checkpoint/vescale/vescale_planner.py | 30 +- .../vescale/vescale_planner_helpers.py | 60 ++- internlm/core/trainer_builder.py | 112 ++++- internlm/initialize/launch.py | 3 + internlm/model/modules/embedding.py | 5 +- .../solver/optimizer/hybrid_zero_optim.py | 401 +++++++++++++++--- internlm/train/pipeline.py | 56 ++- 12 files changed, 765 insertions(+), 186 deletions(-) diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 2f7f5d4ed..826a8b859 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -43,6 +43,7 @@ ) from .load_funcs import LOAD_FUNC_DICT from .utils import process_load_info +from internlm.checkpoint.vescale.api import save as vescale_save logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -60,7 +61,7 @@ class CheckpointLoadContent: SCHEDULAER = "scheduler" -def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None): +def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, universal_ckpt=False): """Tries to load a checkpoint from the given folder. Args: @@ -82,8 +83,13 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None): and the checkpoint manager ckpt_mm and train state objects """ load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) + + if universal_ckpt: + from internlm.checkpoint.vescale.api import load as vescale_load + checkpoint_state = {"model": ckpt_mm.model, "optimizer": ckpt_mm.optimizer} + vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=False) - if load_content.need_load(CheckpointLoadContent.MODEL): + if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL): load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) load_content_str += f"{CheckpointLoadContent.MODEL}, " @@ -93,7 +99,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None): load_context(load_ckpt_folder, train_state) # load optimizer states. - if load_content.need_load(CheckpointLoadContent.OPIMIZER): + if not universal_ckpt and load_content.need_load(CheckpointLoadContent.OPIMIZER): load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer) load_content_str += f"{CheckpointLoadContent.OPIMIZER}, " else: @@ -110,6 +116,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None): logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!") if not load_content.need_load(CheckpointLoadContent.OPIMIZER): + assert False if ckpt_mm.lr_scheduler and train_state: gpc.config.only_load_lr = True load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer) @@ -419,10 +426,11 @@ def is_now_to_save_ckpt(self, train_state, force=False) -> (bool, CheckpointSave def try_save_checkpoint(self, train_state, force=False): if not self.enable_save_ckpt: return False - + save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state, force=force) if save_ckpts: + begin = time.time() # Wait for the previous round of asynchronous upload storage to complete. 
self.storage_manager.wait() if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT: @@ -440,6 +448,7 @@ def try_save_checkpoint(self, train_state, force=False): train_state=train_state, model_config=self.model_config, model_config_file=self.model_config_file, + universal_ckpt=gpc.config.ckpt.universal_ckpt, ) if ( @@ -460,6 +469,8 @@ def try_save_checkpoint(self, train_state, force=False): f"Finish to convert internevo2hf checkpoint from {save_ckpt_folder} to {save_hf_ckpt_folder}." ) torch.distributed.barrier() + end = time.time() - begin + print(f"finsh save time {gpc.get_global_rank()}: {end}", flush=True) return now_break @@ -576,12 +587,19 @@ def try_resume_training(self, train_state: TrainState, current_time=""): f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) else: + begin = time.time() load_path = self.load_ckpt_info["path"] load_content = self.load_ckpt_info["content"] load_type = self.load_ckpt_info["ckpt_type"] + universal_ckpt = gpc.config.ckpt.universal_ckpt + kwargs = {} + + if universal_ckpt: + assert load_type == "internevo", "Only internevo ckpt support universal ckpt." + kwargs = {"universal_ckpt": universal_ckpt} load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type) - load_content_str = load_func(self, self.load_ckpt_info, train_state) + load_content_str = load_func(self, self.load_ckpt_info, train_state, **kwargs) # If we only load model weight, we need rewrite zero optim's fp32 buffer. if ( @@ -598,6 +616,8 @@ def try_resume_training(self, train_state: TrainState, current_time=""): ) if load_content_str: logger.info(f"===========Load contents are: {load_content_str}") + end = time.time() - begin + print(f"finsh load time {gpc.get_global_rank()}: {end}", flush=True) @llm_timeout(func_name="save_checkpoint") def save_checkpoint( @@ -609,6 +629,7 @@ def save_checkpoint( train_state: TrainState, model_config: Dict = None, model_config_file: str = None, + universal_ckpt=False, ): """ Save checkpoint to the given folder path. 
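The hunk below threads the new `universal_ckpt` switch into `save_checkpoint`. As a rough sketch of the intended routing (the wrapper `save_ckpt` and the `save_legacy` callable are hypothetical stand-ins for the existing per-rank save calls; the vescale call mirrors the arguments used in this patch):

from internlm.checkpoint.vescale.api import save as vescale_save

def save_ckpt(folder, model, optimizer, save_legacy, universal_ckpt=False):
    if not universal_ckpt:
        # Legacy path: per-rank model and optimizer files.
        save_legacy(folder, model, optimizer)
    else:
        # Universal path: model and optimizer go through the vescale planner/writer.
        vescale_save(
            path=folder,
            checkpoint_state={"model": model, "optimizer": optimizer},
            async_checkpoint=False,
        )
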
@@ -621,13 +642,23 @@ def save_checkpoint( if gpc.is_rank_for_log(): logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...") - timer("save-model").start() - save_model_checkpoint(folder=folder, model=model) - timer("save-model").stop() + if not universal_ckpt: + print(f"save ckpt: base", flush=True) + timer("save-model").start() + save_model_checkpoint(folder=folder, model=model) + timer("save-model").stop() - timer("save-optimizer").start() - save_optimizer_checkpoint(optim=optimizer, state_path=folder) - timer("save-optimizer").stop() + timer("save-optimizer").start() + save_optimizer_checkpoint(optim=optimizer, state_path=folder) + timer("save-optimizer").stop() + else: + print(f"save ckpt: universal", flush=True) + vescale_save( + path=folder, + checkpoint_state={"model": model, "optimizer": optimizer}, + async_checkpoint=False, + ) + if ( hasattr(train_state, "data_state_dict") diff --git a/internlm/checkpoint/vescale/filesystem.py b/internlm/checkpoint/vescale/filesystem.py index 5fd8030a1..ca8dc3ff4 100644 --- a/internlm/checkpoint/vescale/filesystem.py +++ b/internlm/checkpoint/vescale/filesystem.py @@ -26,7 +26,7 @@ from pathlib import Path from internlm.core.context import global_context as gpc from internlm.core.context import ParallelMode -from internlm.train.pipeline import map_fqn_global_to_local +from internlm.train.pipeline import map_fqn_global_to_local, map_layer_attr from torch.distributed.checkpoint.metadata import ( @@ -144,16 +144,16 @@ def start_loading(self): def values(self): for fqn, _, obj in self.items: tensor = self.resolve_fun(obj).detach() - if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: - tensor = collect_optim_state_across_dp_ranks( - tensor=tensor, - rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], - p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], - ) - elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: - tensor = collect_optim_state_across_dp_ranks( - tensor=tensor, rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], p2p_reqs=self.recv_p2p_reqs[fqn] - ) + # if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: + # tensor = collect_optim_state_across_dp_ranks( + # tensor=tensor, + # rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], + # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], + # ) + # elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + # tensor = collect_optim_state_across_dp_ranks( + # tensor=tensor, rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], p2p_reqs=self.recv_p2p_reqs[fqn] + # ) tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor) # Comment the original DCP code # When dumping to pinned memory, @@ -208,18 +208,18 @@ def _refill(self): fqn, _, obj = self.items[self.idx] self.idx += 1 tensor = self.resolve_fun(obj).detach() - if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: - tensor = collect_optim_state_across_dp_ranks( - tensor=tensor, - rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], - p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], - ) - elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: - tensor = collect_optim_state_across_dp_ranks( - tensor=tensor, - 
rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], - p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], - ) + # if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: + # tensor = collect_optim_state_across_dp_ranks( + # tensor=tensor, + # rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], + # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], + # ) + # elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + # tensor = collect_optim_state_across_dp_ranks( + # tensor=tensor, + # rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], + # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], + # ) if tensor.device.type == self.device_type: tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) # Comment the original DCP code @@ -430,6 +430,7 @@ def _write_files_per_proc_pipe( stream = open(file_path, "wb") executor = ThreadPoolExecutor(max_workers=1) # For byte data, directly write byte data. + assert len(byte_data_item) == 0 for write_data, write_item in byte_data_item: content = write_data.getbuffer() write_futures.append( @@ -443,8 +444,16 @@ def _write_files_per_proc_pipe( ) # write_results.append(_write_to_file(stream, content, write_item, storage_key)) # For tensor data, perform serialization in process then do saving in threadpool. + # print(f"tensor_data_item: {tensor_data_item}", flush=True) + data_memory = 0 + content_momory = 0 + print(f"tensor_data_item {gpc.get_global_rank()}: {len(tensor_data_item)}", flush=True) for write_data, write_item in tensor_data_item: + data_memory += (write_data.numel() * write_data.element_size()) + # print(f"write_item: {write_item}", flush=True) + # print(f"write_data: {write_data}", flush=True) content = _serialize_tensor(write_data) + content_momory += len(content) write_futures.append( executor.submit( _write_to_file, @@ -455,6 +464,7 @@ def _write_files_per_proc_pipe( ) ) # write_results.append(_write_to_file(stream, content, write_item, storage_key)) + print(f"data_memory {gpc.get_global_rank()}: {data_memory / (1024 * 1024 * 1024)}, {content_momory / (1024 * 1024 * 1024) }", flush=True) for fut in write_futures: write_results.append(fut.result()) @@ -582,6 +592,7 @@ def __init__( super().__init__() self.path = Path(path) self.single_file_per_rank = single_file_per_rank + # self.single_file_per_rank = False self.sync_files = sync_files self.worker_count = worker_count self.per_process_copy_ahead = per_process_copy_ahead @@ -600,7 +611,7 @@ def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]: ] return new_plans - def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], planner: SavePlanner): + def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], planner: SavePlanner, is_optimizer): """ First stage of saving, Perform Copy data to CPU (D2H). @@ -615,8 +626,12 @@ def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], pla byte_data_item_writes: List[List[Tuple[io.BytesIO, WriteItem]]] = [] tensor_data_item_writes: List[List[Tuple[torch.Tensor, WriteItem]]] = [] file_path_names: List[Tuple[Path, str]] = [] + + item_list_all = [] + fqn_list_all = [] # Perform D2H in copy stream. 
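For orientation, the loop below performs the D2H stage per write item: resolve the tensor through the planner (after remapping global FQNs back to local ones), detach and clone it, then stage it in the pinned-memory pool so the file writers never touch GPU memory. A minimal sketch under these assumptions; `resolve_fqn` and `copy_to_pinned` are hypothetical stand-ins for the inline FQN remapping and for `copy_gpu_tensor_to_cpu_pinned_mem_pool` used in this file:

def d2h_stage(write_items, planner, resolve_fqn, copy_to_pinned):
    staged = []
    for item in write_items:
        fqn = resolve_fqn(item)  # e.g. map a global fqn back to the pipeline-local one
        tensor = planner.resolve_data(item, fqn).detach().clone()
        # Asynchronous copy into pinned CPU memory; the copy stream is waited on
        # later, before the write futures are submitted.
        tensor = copy_to_pinned(tensor, non_blocking=True)
        staged.append((tensor, item))
    return staged
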
+ flag = 0 d2h_dump_start = time.time() for task in tasks: file_path, file_name, write_items = task @@ -626,28 +641,53 @@ def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], pla if len(byte_data_item) != 0: assert False tensor_data_item = [] + + item_list = [] + fqn_list = [] # Async copy to pinned CPU memory pool. for item in tensor_w: # att fqn = _item_fqn(item) - if fqn in map_fqn_global_to_local: - print(f"_item_fqn: {fqn}, {map_fqn_global_to_local[fqn]}", flush=True) - fqn = map_fqn_global_to_local[fqn] + + # map_fqn = fqn + # if fqn.endswith("exp_avg") or fqn.endswith("exp_avg_sq"): + # # os exp_avg, exp_avg_sq + # map_fqn = fqn.rsplit('.', 1)[0] + fqn_list.append((fqn, map_fqn_global_to_local[fqn] if fqn in map_fqn_global_to_local else None)) + if not is_optimizer: + if 'layer' in fqn: + assert fqn in map_fqn_global_to_local + if fqn in map_fqn_global_to_local: + fqn = map_fqn_global_to_local[fqn] + + # if fqn in map_fqn_global_to_local: + # print(f"_item_fqn: {fqn}, {map_fqn_global_to_local[fqn]}", flush=True) + # fqn = map_fqn_global_to_local[fqn] - tensor = planner.resolve_data(item, fqn).detach() + tensor = planner.resolve_data(item, fqn).detach().clone() - if self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: - tensor = collect_optim_state_across_dp_ranks( - tensor=tensor, - rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], - p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], - ) + # if self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: + # tensor = collect_optim_state_across_dp_ranks( + # tensor=tensor, + # rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], + # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], + # ) tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) + # print(f"item: {item.index.fqn}", flush=True) tensor_data_item.append((tensor, item)) + + item_list.append(item.index.fqn) + flag += 1 byte_data_item_writes.append(byte_data_item) tensor_data_item_writes.append(tensor_data_item) file_path_names.append((file_path, file_name)) + + fqn_list_all.append(fqn_list) + item_list_all.append(item_list) + # print(f"fqn_list_all {gpc.get_global_rank()}: {flag}, {fqn_list_all}", flush=True) + # print(f"item_list_all {gpc.get_global_rank()}: {flag}, {item_list_all}", flush=True) + # print(f"tensor_data_item_writes {gpc.get_global_rank()}: {byte_data_item_writes}, {file_path_names}, {tensor_data_item_writes}", flush=True) d2h_dump_time = time.time() - d2h_dump_start logger.debug(f"End waiting for D2H copy. Time cost: {d2h_dump_time}s") @@ -663,7 +703,7 @@ def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], pla return byte_data_item_writes, tensor_data_item_writes, file_path_names def write_data( - self, plan: SavePlan, planner: SavePlanner, async_io: bool = False, io_workers=False + self, plan: SavePlan, planner: SavePlanner, async_io: bool = False, io_workers=False, is_optimizer=False ) -> Future[List[WriteResult]]: storage_plan: _StoragePrefix = plan.storage_data file_count = 0 @@ -676,10 +716,11 @@ def gen_file(): tasks: List[Tuple[Path, str, List[WriteItem]]] = [] # Generate K tasks where K is the number of worker_count. 
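As the comment above notes, with `single_file_per_rank` the plan items are packed into `worker_count` buckets so that each worker serializes exactly one file. A simplified sketch of that bucketing (a sketch only: `split_fn` stands in for `_split_by_size_and_type`, and the per-item branch is elided here just as it is unchanged by this patch):

def build_write_tasks(path, plan_items, worker_count, single_file_per_rank, gen_file, split_fn):
    tasks = []  # (file_path, file_name, [WriteItem, ...]) -- one entry per output file
    if single_file_per_rank:
        # Pack all items into worker_count buckets, one output file per worker.
        for bucket in split_fn(worker_count, plan_items):
            file_name = gen_file()
            tasks.append((path / file_name, file_name, bucket))
    else:
        # Otherwise one task (and one file) per write item; unchanged, elided.
        ...
    return tasks
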
- # print(f"self.single_file_per_rank: {self.single_file_per_rank}", flush=True) + print(f"self.single_file_per_rank: {self.single_file_per_rank}", flush=True) if self.single_file_per_rank: for bucket in _split_by_size_and_type(self.worker_count, plan.items): file_name = gen_file() + print(f"file_name {gpc.get_global_rank()}: {file_name}, {self.worker_count}, {bucket}", flush=True) tasks.append((self.path / file_name, file_name, bucket)) # Generate K tasks where K is the number of write items. else: @@ -690,14 +731,15 @@ def gen_file(): # Make sure the optimizer states across dp ranks # has been sending to other ranks # So the receiver can get it when writing tensors to local path - - if self.p2p_tensors_info: - logger.debug("Start waiting for sending p2p tensors futures") - p2p_tensor_send_wait_start = time.time() - for req in self.p2p_tensors_info.send_p2p_reqs: - req.wait() - p2p_tensor_send_wait_time = time.time() - p2p_tensor_send_wait_start - logger.debug(f"End waiting for sending p2p tensors futures Time: {p2p_tensor_send_wait_time}s") + # print(f"p2p_tensors_info: {self.p2p_tensors_info}", flush=True) + # if self.p2p_tensors_info: + # assert False + # logger.debug("Start waiting for sending p2p tensors futures") + # p2p_tensor_send_wait_start = time.time() + # for req in self.p2p_tensors_info.send_p2p_reqs: + # req.wait() + # p2p_tensor_send_wait_time = time.time() - p2p_tensor_send_wait_start + # logger.debug(f"End waiting for sending p2p tensors futures Time: {p2p_tensor_send_wait_time}s") futures = [] if not io_workers: @@ -709,7 +751,7 @@ def gen_file(): # ProcessPool VERSION. if isinstance(executor, ProcessPoolExecutor): # print(f"executor: ProcessPoolExecutor", flush=True) - byte_data_item_writes, tensor_data_item_writes, file_path_names = self.prepare_write_data(tasks, planner) + byte_data_item_writes, tensor_data_item_writes, file_path_names = self.prepare_write_data(tasks, planner, is_optimizer) # print(f"byte_data_item_writes {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {byte_data_item_writes}", flush=True) # print(f"tensor_data_item_writes {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {tensor_data_item_writes}", flush=True) # print(f"file_path_names {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {file_path_names}", flush=True) @@ -737,13 +779,20 @@ def gen_file(): for task in tasks: # print(f"task {gpc.get_global_rank()}: {task}", flush=True) futures.append( + # executor.submit( + # _write_files_from_queue, + # *task, + # planner, + # self.per_process_copy_ahead, + # self.sync_files, + # self.p2p_tensors_info, + # ) executor.submit( _write_files_from_queue, *task, planner, self.per_process_copy_ahead, self.sync_files, - self.p2p_tensors_info, ) ) if async_io: @@ -824,7 +873,9 @@ def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlan for req in reqs: item_md = self.storage_data[req.storage_index] file_slice = self._slice_file(file, item_md) + # print(f"debugg file_slice {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}: {file_slice}", flush=True) if req.type == LoadItemType.BYTE_IO: + assert False bytes = io.BytesIO(file_slice.read(item_md.length)) bytes.seek(0) planner.load_bytes(req, bytes) diff --git a/internlm/checkpoint/vescale/load_state_dict.py b/internlm/checkpoint/vescale/load_state_dict.py index 186d92983..6ab14868c 100644 --- a/internlm/checkpoint/vescale/load_state_dict.py +++ b/internlm/checkpoint/vescale/load_state_dict.py @@ -34,11 +34,13 @@ def 
load_state_dict( no_dist: bool = False, planner: Optional[LoadPlanner] = None, broadcast_tensors=False, + is_optimizer=False ) -> None: load_start_time = time.time() """ [veScale version] Loads a distributed ``state_dict`` in SPMD style. Fix sub-group storage. """ + print(f"load_state_dict: {path}", flush=True) storage_reader = FileSystemReader( path, broadcast_tensors=broadcast_tensors, @@ -64,7 +66,7 @@ def local_step(): planner.set_up_planner(state_dict, metadata, distW.is_coordinator) storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) - local_plan = planner.create_local_plan() + local_plan = planner.create_local_plan(is_optimizer=is_optimizer) local_plan = storage_reader.prepare_local_plan(local_plan) return local_plan diff --git a/internlm/checkpoint/vescale/save_state_dict.py b/internlm/checkpoint/vescale/save_state_dict.py index 7abf79878..6772bce0b 100644 --- a/internlm/checkpoint/vescale/save_state_dict.py +++ b/internlm/checkpoint/vescale/save_state_dict.py @@ -47,6 +47,7 @@ def save_state_dict( async_io: bool = True, last_write_futures: Future[List[WriteResult]] = None, io_workers=None, + is_optimizer=False, ) -> Tuple[Metadata, Future[List[WriteResult]]]: """ [veScale version] Saves a distributed model in SPMD style. Fix sub-group storage. @@ -70,10 +71,9 @@ def save_state_dict( def local_step(): logger.debug("Start local step of planning") if isinstance(planner, VeScaleSavePlanner): - local_plan, p2p_tensors_info = planner.create_local_plan() + local_plan, p2p_tensors_info = planner.create_local_plan(is_optimizer=is_optimizer) local_plan = storage_writer.prepare_local_plan(local_plan, p2p_tensors_info) - if gpc.get_local_rank(ParallelMode.PIPELINE) == 1: - print(f"save local_plan: {local_plan}", flush=True) + print(f"save local_plan {dist.get_rank()}, zero={gpc.get_local_rank(ParallelMode.ZERO1)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}: {local_plan}", flush=True) # if dist.get_rank() in [0, 1, 2, 3]: # print(f"save local_plan {dist.get_rank()}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {local_plan}", flush=True) else: @@ -88,7 +88,7 @@ def global_step(all_local_plans): assert planner is not None # print(f"planner.coordinator_rank {gpc.get_global_rank()}: {planner.is_coordinator }", flush=True) all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) - # print(f"save global_metatadata {dist.get_rank()}: {global_metatadata}", flush=True) + print(f"save global_metatadata {dist.get_rank()}, zero={gpc.get_local_rank(ParallelMode.ZERO1)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}: {global_metatadata}", flush=True) all_local_plans = storage_writer.prepare_global_plan(all_local_plans) logger.debug("End global step of planning") # print(f"global_plan {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {all_local_plans}", flush=True) @@ -103,7 +103,8 @@ def write_data(async_io: bool = False, io_workers=io_workers): final_local_plan = planner.finish_plan(central_plan) if isinstance(planner, VeScaleSavePlanner): # Use pinned memory pool and mult_processing for dumping ckpt to local directory efficiently - all_write_futures = storage_writer.write_data(final_local_plan, planner, async_io, io_workers) + print(f"write_data: {write_data}", flush=True) + all_write_futures = storage_writer.write_data(final_local_plan, planner, async_io, io_workers, is_optimizer) logger.debug("Finish writing data") if async_io: return all_write_futures @@ -144,7 +145,6 @@ def finish_checkpoint(all_results): plan_start_time = time.time() 
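The logging changes further below sit inside the metadata-cache branch of `save_state_dict`. For orientation, a condensed outline of that control flow, with the surrounding callables passed in as parameters (names as in this function, details elided):

def plan_and_write(distW, local_step, global_step, write_data, finish_checkpoint,
                   cached_data, async_io=False):
    # 1) Exchange local plans and receive the merged central plan.
    central_plan = distW.reduce_scatter("plan", local_step, global_step)
    if cached_data:
        # 2a) Metadata cache hit: reuse the stored global metadata, only write data.
        _, final_storage_metadata = cached_data
        write_data(async_io=async_io)
    else:
        # 2b) First save of this state_dict: write data and assemble the global
        #     metadata across ranks.
        final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint)
    return central_plan, final_storage_metadata
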
cached_data = None - if isinstance(planner, VeScaleSavePlanner): central_plan = distW.reduce_scatter("plan", local_step, global_step) else: @@ -172,7 +172,7 @@ def finish_checkpoint(all_results): write_futures = [] if isinstance(planner, VeScaleSavePlanner): if cached_data: - logger.debug("Metdata cache hit. Reuse existing metadata") + logger.info("Metdata cache hit. Reuse existing metadata") _, final_storage_metadata = cached_data write_results = write_data(async_io=async_io) # Be sure to write cache metadata to .metadata file @@ -188,7 +188,7 @@ def finish_checkpoint(all_results): if async_io: write_futures = write_results else: - logger.debug("Metadata cache miss. The model/optimizer appears for the first time.") + logger.info("Metadata cache miss. The model/optimizer appears for the first time.") # First time do synchronous storing to get final_storage_metatdata. # Determine which communication topology to use. final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint) diff --git a/internlm/checkpoint/vescale/vescale_checkpointer.py b/internlm/checkpoint/vescale/vescale_checkpointer.py index db19e284d..c28b23e90 100644 --- a/internlm/checkpoint/vescale/vescale_checkpointer.py +++ b/internlm/checkpoint/vescale/vescale_checkpointer.py @@ -26,10 +26,14 @@ import os # from .distributed_optimizer import initialize_optimizer_state import torch.distributed as dist +import torch from internlm.utils.logger import get_logger import atexit from internlm.core.context import global_context as gpc from internlm.core.context import ParallelMode +from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.train.pipeline import map_fqn_local_to_global + logger = get_logger(__file__) @@ -102,14 +106,20 @@ def save( if key not in VESCALE_SUPPORTED_TYPES: raise ValueError(f"{key} is not supported by VeScaleCheckpointer") + if path.startswith("local:"): + path = path.split(':')[1] + assert ':' not in path, f"{path} is not valid for universal checkpoint!" # Start saving checkpoint for key, value in checkpoint_state.items(): if key == MODEL_STR: # Get model path model_path = os.path.join(path, MODEL_STR) + print(f"model_path: {path}, {model_path}", flush=True) # Create a "model" folder on under root path if dist.get_rank() == 0: bfile.makedirs(model_path) + # if not os.path.exists(path): + # os.makedirs(path, exist_ok=True) dist.barrier() # Save model. _, new_write_futures = save_state_dict( @@ -122,44 +132,59 @@ def save( async_io=async_checkpoint, last_write_futures=cls.state_write_futures[MODEL_STR], io_workers=cls.state_io_workers[MODEL_STR], + is_optimizer=False, ) # Record new write futures. 
cls.state_write_futures[MODEL_STR] = new_write_futures - dist.barrier() + dist.barrier() #att elif key == OPTIMIZER_STR: + # adamW hybrid zero optim + assert isinstance(value, HybridZeroOptimizer) + optimizer_state = value.state_dict() + # Create a "optimizer" folder on under root path # to save different parts of optimizer optim_root_path = os.path.join(path, OPTIMIZER_STR) + print(f"optim_root_path: {optim_root_path}", flush=True) if dist.get_rank() == 0: bfile.makedirs(optim_root_path) dist.barrier() # Get process group for saving optimizer, # All processes with the same pipeline rank are in the same pg - if not cls.optim_ckpt_proces_group: - cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() + # if not cls.optim_ckpt_proces_group: + # cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() # Get optimizer path based on PP rank # pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() #attn - pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - optimizer_path = os.path.join(optim_root_path, f"pp_{pp_rank}") + # pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + # optimizer_path = os.path.join(optim_root_path, f"pp_{pp_rank}") + optimizer_path = optim_root_path # Create optimizer folder on under root path - if dist.get_rank(cls.optim_ckpt_proces_group) == 0: - bfile.makedirs(optimizer_path) - dist.barrier() + # if dist.get_rank(cls.optim_ckpt_proces_group) == 0: + # bfile.makedirs(optimizer_path) + # dist.barrier() # Save optimizer _, new_write_futures = save_state_dict( - state_dict=value.state_dict(), + state_dict=optimizer_state["sharded_optimizer_state"], path=optimizer_path, - process_group=cls.optim_ckpt_proces_group, + process_group=None, coordinator_rank=0, no_dist=False, planner=cls.save_planner, async_io=async_checkpoint, last_write_futures=cls.state_write_futures[OPTIMIZER_STR], io_workers=cls.state_io_workers[OPTIMIZER_STR], + is_optimizer=True, ) # Record new write futures. cls.state_write_futures[OPTIMIZER_STR] = new_write_futures + + optimizer_state.pop("sharded_optimizer_state") + # print(f"after_pop {gpc.get_global_rank}: {optimizer_state}", flush=True) + if gpc.get_global_rank() == 0: + print(f"global_optimizer_state: {os.path.join(path, 'global_optimizer_state.pt')}", flush=True) + torch.save(optimizer_state, os.path.join(path, "global_optimizer_state.pt")) + dist.barrier() @classmethod def load( @@ -176,6 +201,9 @@ def load( processes with data parallel rank = 0 load model from file system then broadcast it to processes with data parallel rank = 1 """ + if path.startswith("local:"): + path = path.split(':')[1] + assert ':' not in path, f"{path} is not valid for universal checkpoint!" 
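The optimizer handling added above splits `HybridZeroOptimizer.state_dict()` into a sharded part, written through the distributed planner under the "optimizer" subfolder, and a small replicated remainder dumped once as global_optimizer_state.pt; the load counterpart below reverses this. A condensed sketch of that round trip (the wrapper functions are hypothetical; planner, process group and coordinator arguments are passed through the kwargs as in the real calls):

import os

import torch

from internlm.core.context import global_context as gpc
from internlm.checkpoint.vescale.load_state_dict import load_state_dict
from internlm.checkpoint.vescale.save_state_dict import save_state_dict


def save_optimizer(optimizer, path, **save_kwargs):
    state = optimizer.state_dict()
    sharded = state.pop("sharded_optimizer_state")
    # Per-rank shards go through the distributed planner/writer ...
    save_state_dict(state_dict=sharded, path=os.path.join(path, "optimizer"),
                    is_optimizer=True, **save_kwargs)
    # ... while the replicated remainder is written once by rank 0.
    if gpc.get_global_rank() == 0:
        torch.save(state, os.path.join(path, "global_optimizer_state.pt"))


def load_optimizer(optimizer, path, **load_kwargs):
    state = optimizer.state_dict()
    load_state_dict(state_dict=state["sharded_optimizer_state"],
                    path=os.path.join(path, "optimizer"),
                    is_optimizer=True, **load_kwargs)
    global_state = torch.load(os.path.join(path, "global_optimizer_state.pt"))
    optimizer.load_state_dict(state["sharded_optimizer_state"], global_state)
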
# Add warning if bfile.is_local_path(path): logger.warning( @@ -214,31 +242,37 @@ def load( # Load back to model value.load_state_dict(model_state) #att elif key == OPTIMIZER_STR: - assert False # Get process group for loading optimizer, # All processes with the same pipeline rank are in the same pg - if not cls.optim_ckpt_proces_group: - cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() + # if not cls.optim_ckpt_proces_group: + # cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() # Get optimizer path based on TP and PP ranks # pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() #attn pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - optimizer_path = os.path.join(path, f"{OPTIMIZER_STR}", f"pp_{pp_rank}") + optimizer_path = os.path.join(path, OPTIMIZER_STR) + print(f"optimizer_path: {optimizer_path}", flush=True) # Initialize optimizer states # initialize_optimizer_state(value) # Get optimizer state optimizer_state = value.state_dict() # Load optimizer state dictionary load_state_dict( - state_dict=optimizer_state, + state_dict=optimizer_state["sharded_optimizer_state"], path=optimizer_path, - process_group=cls.optim_ckpt_proces_group, + process_group=None, coordinator_rank=0, no_dist=False, planner=cls.load_planner, broadcast_tensors=False, + is_optimizer=True, ) + #att check len equal + # print(f"optimizer_state {gpc.get_global_rank()}: {len(optimizer_state)}, {optimizer_state.keys()}", flush=True) + # fqn_list = checkpoint_state['optimizer'].fqn_list + # print(f"fqn_list {gpc.get_global_rank()}: {len(fqn_list[0]) + len(fqn_list[1])}, {fqn_list[0], fqn_list[1]}", flush=True) # Load back to optimizer - value.load_state_dict(optimizer_state) + global_optimizer_state = torch.load(os.path.join(path, "global_optimizer_state.pt")) + value.load_state_dict(optimizer_state["sharded_optimizer_state"], global_optimizer_state) dist.barrier() @classmethod diff --git a/internlm/checkpoint/vescale/vescale_planner.py b/internlm/checkpoint/vescale/vescale_planner.py index 91fce5e91..d1b60fbc7 100644 --- a/internlm/checkpoint/vescale/vescale_planner.py +++ b/internlm/checkpoint/vescale/vescale_planner.py @@ -55,8 +55,8 @@ class VeScaleLoadPlanner(DefaultLoadPlanner): def __init__(self): super().__init__() - def create_local_plan(self) -> LoadPlan: - return create_default_local_load_plan(self.state_dict, self.metadata) + def create_local_plan(self, is_optimizer=False) -> LoadPlan: + return create_default_local_load_plan(self.state_dict, self.metadata, is_optimizer) def resolve_tensor(self, read_item: ReadItem): tensor = self.lookup_tensor(read_item.dest_index) @@ -70,18 +70,19 @@ def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: from internlm.train.pipeline import map_fqn_local_to_global -def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadata) -> LoadPlan: +def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadata, is_optimizer) -> LoadPlan: """ A function for creating local loading plan for loading checkpoint """ # print(f"metadata: {metadata}", flush=True) requests = [] for fqn, obj in state_dict.items(): - md_fqn = fqn - if fqn in map_fqn_local_to_global: - md_fqn = map_fqn_local_to_global[fqn] - print(f"map_fqn_local_to_global {gpc.get_local_rank(ParallelMode.PIPELINE)}: {fqn}, {md_fqn}", flush=True) - md = metadata.state_dict_metadata[md_fqn] + global_fqn = fqn + if not is_optimizer: + if fqn in map_fqn_local_to_global: + global_fqn = map_fqn_local_to_global[fqn] + # print(f"map_fqn_local_to_global 
{gpc.get_local_rank(ParallelMode.PIPELINE)}: {fqn}, {global_fqn}", flush=True) + md = metadata.state_dict_metadata[global_fqn] if isinstance(obj, DTensor): assert False if obj.device_mesh.get_coordinate() is not None: @@ -102,7 +103,7 @@ def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadat obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) requests += _create_read_items(fqn, md, obj) else: - item = _create_read_items(fqn, md_fqn, md, obj) + item = _create_read_items(fqn, global_fqn, md, obj) # print(f"_create_read_items {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {item}", flush=True) requests += item return LoadPlan(requests) @@ -122,8 +123,8 @@ def resolve_data(self, write_item: WriteItem, fqn=None) -> Union[torch.Tensor, i object = self.lookup_object(write_item.index, fqn) return self.transform_object(write_item, object) - def create_local_plan(self) -> Tuple[SavePlan, P2PTensorsInfo]: - plan, p2p_tensors_info = create_default_local_save_plan(self.state_dict, self.is_coordinator) + def create_local_plan(self, is_optimizer=False) -> Tuple[SavePlan, P2PTensorsInfo]: + plan, p2p_tensors_info = create_default_local_save_plan(self.state_dict, self.is_coordinator, is_optimizer) # print(f"save before replace local_plan {dist.get_rank()}: {plan}", flush=True) if self.flatten_state_dict: plan = dataclasses.replace(plan, planner_data=self.mappings) @@ -169,7 +170,7 @@ def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], return rst_value -def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool) -> SavePlan: +def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool, is_optimizer=False) -> SavePlan: """ A function for creating local saving plan for saving checkpoint. 
""" @@ -180,6 +181,8 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b send_p2p_reqs = [] recv_p2p_reqs = {} + # if is_optimizer: + # state_dict = state_dict["unflatten_fp32_weights"] for fqn, obj in state_dict.items(): # Since DTensor supports submesh, adding extra check to ensure _create_write_items() @@ -244,7 +247,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) requests += _create_write_items(fqn, obj) elif isinstance(obj, (torch.Tensor)) or is_coordinator: - item = _create_write_items(fqn, obj) + item = _create_write_items(fqn, obj, is_optimizer=is_optimizer) # print(f"_create_write_items {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {item}", flush=True) requests += item # if dist.get_rank() == 1: @@ -254,6 +257,7 @@ def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: b # Merge the tensors later writer_rank = dist.get_rank() for fqn in recv_tensors.keys(): + assert False obj = state_dict[fqn] new_local_tensor = torch.zeros( (math.prod(obj.local_shape),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device diff --git a/internlm/checkpoint/vescale/vescale_planner_helpers.py b/internlm/checkpoint/vescale/vescale_planner_helpers.py index 9c4e8903e..5c8316ada 100644 --- a/internlm/checkpoint/vescale/vescale_planner_helpers.py +++ b/internlm/checkpoint/vescale/vescale_planner_helpers.py @@ -60,16 +60,30 @@ def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: return ChunkStorageMetadata(offsets=offsets, sizes=sizes) from internlm.train.pipeline import map_layer_attr, map_fqn_local_to_global -def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem: +def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor, is_optimizer=False) -> WriteItem: offsets = torch.Size([0] * len(tensor.size())) size = tensor.size() - if "norm" not in fqn: - assert fqn in map_layer_attr, f"{fqn}" - if fqn in map_layer_attr: - offsets = torch.Size(map_layer_attr[fqn]['offset']) - size = torch.Size(map_layer_attr[fqn]['complete_size']) - if fqn in map_fqn_local_to_global: - fqn = map_fqn_local_to_global[fqn] + + if not is_optimizer: + if 'layer' in fqn: + assert fqn in map_fqn_local_to_global, f"{fqn}" + # print(f"map_fqn: {fqn}, {map_fqn_local_to_global[fqn]}", flush=True) + fqn = map_fqn_local_to_global[fqn] + + map_fqn = fqn + if map_fqn not in map_layer_attr: + # os exp_avg, exp_avg_sq + map_fqn = fqn.rsplit('.', 1)[0] + assert map_fqn in map_layer_attr, f"{gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {is_optimizer}, {fqn}" + offsets = torch.Size(map_layer_attr[map_fqn]['offset']) + size = torch.Size(map_layer_attr[map_fqn]['complete_size']) + + # if map_fqn in map_fqn_local_to_global: + # global_fqn = map_fqn_local_to_global[map_fqn] + # if fqn != map_fqn: + # fqn = global_fqn + f".{fqn.rsplit('.', 1)[1]}" + # else: + # fqn = global_fqn # if offsets[0] != 0: # k = offsets[0] // gpc.get_local_rank(ParallelMode.TENSOR) @@ -111,15 +125,17 @@ def _create_write_item_for_bytesio(fqn: str, bytes: Any): ) -def _create_write_items(fqn: str, object: Any) -> List[WriteItem]: +def _create_write_items(fqn: str, object: Any, is_optimizer=False) -> List[WriteItem]: if isinstance(object, DTensor): - # assert False + assert False return [_create_write_items_for_dtensor(fqn, object)] elif isinstance(object, torch.Tensor): - return 
[_create_write_item_for_tensor(fqn, object)] + return [_create_write_item_for_tensor(fqn, object, is_optimizer=is_optimizer)] elif isinstance(object, OptimizerStateSpec): + assert False return [_create_write_item_for_optimizer_state(fqn, object)] else: + assert False return [_create_write_item_for_bytesio(fqn, object)] @@ -136,7 +152,7 @@ def _create_read_item_for_tensor(dest_index, dest_offsets, storage_index, storag def create_read_items_for_chunk_list( fqn: str, - md_fqn: str, + global_fqn: str, checkpoint_md: TensorStorageMetadata, local_chunks: List[ChunkStorageMetadata], ) -> List[ReadItem]: @@ -180,7 +196,7 @@ def create_read_items_for_chunk_list( _create_read_item_for_tensor( dest_index=MetadataIndex(fqn, shard.offsets, idx), dest_offsets=dest_offsets, - storage_index=MetadataIndex(md_fqn, storage_md.offsets, storage_idx), + storage_index=MetadataIndex(global_fqn, storage_md.offsets, storage_idx), storage_offsets=storage_offsets, lengths=lengths, ) @@ -188,16 +204,17 @@ def create_read_items_for_chunk_list( return read_items -def _create_chunk_from_tensor(fqn, tensor: torch.Tensor) -> ChunkStorageMetadata: +def _create_chunk_from_tensor(global_fqn, tensor: torch.Tensor) -> ChunkStorageMetadata: # sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) # offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) # return ChunkStorageMetadata(offsets=offsets, sizes=sizes) - offsets = torch.Size([0] * len(tensor.size())) - if "norm" not in fqn: - assert fqn in map_layer_attr, f"{fqn}" - offsets = torch.Size(map_layer_attr[fqn]['offset']) + if global_fqn not in map_layer_attr: + # os exp_avg, exp_avg_sq + global_fqn = global_fqn.rsplit('.', 1)[0] + assert global_fqn in map_layer_attr, f"{global_fqn}" + offsets = torch.Size(map_layer_attr[global_fqn]['offset']) # return ChunkStorageMetadata(offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()) @@ -220,21 +237,22 @@ def _create_chunk_from_optimizer_spec(obj: OptimizerStateSpec) -> ChunkStorageMe return ChunkStorageMetadata(offsets=obj.global_offset, sizes=obj.local_shape) -def _create_read_items(fqn: str, md_fqn, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: +def _create_read_items(fqn: str, global_fqn, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: if not isinstance(md, BytesStorageMetadata): if isinstance(obj, DTensor): assert False local_chunks = [_create_chunk_from_dtensor(obj)] elif isinstance(obj, torch.Tensor): - local_chunks = [_create_chunk_from_tensor(fqn, obj)]#att + local_chunks = [_create_chunk_from_tensor(global_fqn, obj)]#att # print(f"local_chunks {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)} {fqn}: {local_chunks}", flush=True) elif isinstance(obj, OptimizerStateSpec): + assert False local_chunks = [_create_chunk_from_optimizer_spec(obj)] else: raise ValueError( f"Invalid checkpoint metadata for {fqn}, " + f"expected BytesStorageMetadata but found {type(md)}" ) - return create_read_items_for_chunk_list(fqn, md_fqn, md, local_chunks) + return create_read_items_for_chunk_list(fqn, global_fqn, md, local_chunks) else: assert False return [ diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index d0ef284d4..8a7a381b6 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -45,10 +45,13 @@ from internlm.utils.utils import DataType from internlm.utils.writer import Writer + + +import os + # global llm logger logger = logging.getLogger(__file__) - class 
TrainerBuilder(Trainer): """ A class for building and managing InternEvo training workflow. @@ -126,7 +129,15 @@ def __init__( # initialize checkpoint manager and try resume training self.ckpt_manager = self._initialize_checkpoint_manager(model, optimizer, lr_scheduler, train_dl, config_lines) + self.ckpt_manager.try_resume_training(train_state, self.current_time) + + + + # from internlm.checkpoint.vescale.api import load as vescale_load + # checkpoint_state = {"model": model.model} + # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test/iter_10", checkpoint_state, broadcast_checkpoint=False) + # print("finish loading", flush=True) # initialize customed llm writer self.writer = self._initialize_writer(train_state, config_lines) @@ -258,19 +269,39 @@ def fit(self): """ self.train() train_iter = iter(self.train_dl) - + + from internlm.checkpoint.vescale.api import load as vescale_load + checkpoint_state = {"model": self.ckpt_manager.model} + checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} + checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} + # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test_dp2_tp2_pp2/20", checkpoint_state, broadcast_checkpoint=False) + + + print("finish loading", flush=True) + print(f"rank_log: rank={gpc.get_global_rank()}, dp={gpc.get_local_rank(ParallelMode.DATA)}, pp={gpc.get_local_rank(ParallelMode.PIPELINE)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}", flush=True) + print(f"train_state: {self.train_state}", flush=True) with initialize_llm_profile(profiling=self.profiling, start_time=self.current_time) as prof: gc.disable() for batch_count in range(self.train_state.batch_count, gpc.config.data.total_steps): + # print(f"norm_weight {gpc.get_global_rank()}: {self.ckpt_manager.model.norm.weight.shape}, {self.ckpt_manager.model.norm.weight}") + if batch_count == 35: + break if self._process_batch(batch_count, train_iter, prof): break self.ckpt_manager.wait_async_upload_finish() - + def _process_batch(self, batch_count: int, train_iter, prof) -> bool: empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) start_time = time.time() timer("one-batch").start() + + # data_path = "/mnt/petrelfs/lijiaxing/InternEvo/test_data_batch.pt" + # batch = torch.load(data_path) + + # if batch[0].get("type_ids", None) is not None: + # self.metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) + batch, train_iter = self._load_and_prepare_batch(batch_count, train_iter) if self.batch_skipper(batch_count): @@ -278,6 +309,14 @@ def _process_batch(self, batch_count: int, train_iter, prof) -> bool: logger.info(f"Skip batch count:`{batch_count}`...") timer("one-batch").stop() return False + + # if batch_count <= 10: + # if gpc.is_rank_for_log(): + # print(f"skip {batch_count}", flush=True) + # return False + + if gpc.is_rank_for_log(): + print(f"start trainging {batch_count}", flush=True) timer("fwd-bwd").start() loss, moe_loss = self._forward_backward(batch) @@ -289,15 +328,80 @@ def _process_batch(self, batch_count: int, train_iter, prof) -> bool: if self._should_evaluate(): self._evaluate() - + # print(f"self.ckpt_manager.hybrid {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.state_dict()}", flush=True) + # print(f"self.ckpt_manager.optimizer {gpc.get_global_rank()}, 
{gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.optim.state_dict()}", flush=True) + # print(f"self.ckpt_manager.grad_scaler {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.grad_scaler.state_dict()}", flush=True) + + # checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} + # checkpoint_state = {"model": self.ckpt_manager.model} + # checkpoint_state = {"optimizer": self.ckpt_manager.optimizer} + # assert 'output.weight' in self.ckpt_manager.model.state_dict() + fqn = 'layers.14.feed_forward.w2.weight' + # print(f"cp_length {gpc.get_local_rank(ParallelMode.ZERO1)}: {len(self.ckpt_manager.optimizer.state_dict()['unflatten_fp32_weights'])}, {len(self.ckpt_manager.model.state_dict())}") + # if fqn in self.ckpt_manager.optimizer.state_dict()['unflatten_fp32_weights']: + # tensor1 = self.ckpt_manager.optimizer.state_dict()['unflatten_fp32_weights'][fqn] + # tensor2 = self.ckpt_manager.model.state_dict()[fqn] + # print(f"self.ckpt_manager.optimizer.state_dict() {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)} {gpc.get_local_rank(ParallelMode.ZERO1)}: {tensor1.dtype}, {tensor1.shape}, {tensor1}", flush=True) + # print(f"self.ckpt_manager.model.state_dict() {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)} {gpc.get_local_rank(ParallelMode.ZERO1)}: {tensor2.shape}, {tensor2}", flush=True) + + # print(f"model_state {dist.get_rank()}: {self.ckpt_manager.model.state_dict()['layers.0.attention.wo.weight'].shape}", flush=True) + # print(f"optimizer_state {dist.get_rank()}: {self.ckpt_manager.optimizer.state_dict()['master_current_weights'][0][0].shape}", flush=True) + # print(f"self.ckpt_manager.optimizer {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.optim.state_dict()}", flush=True) + + # if batch_count == 1: + # from internlm.checkpoint.vescale.api import load as vescale_load + # checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} + # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test_dp2_tp2_pp2/iter_20", checkpoint_state, broadcast_checkpoint=False) + + if batch_count == 20: + from internlm.checkpoint.vescale.devicemesh_api import VESCALE_DEVICE_MESH + from internlm.checkpoint.vescale.device_mesh import init_device_mesh + from internlm.checkpoint.vescale.api import save as vescale_save + from internlm.checkpoint.vescale.api import load as vescale_load + from internlm.checkpoint.deepspeed.save_checkpoint import save_checkpoint + # device_mesh = init_device_mesh( + # "cuda", + # ( + # 2, + # 4, + # ), + # mesh_dim_names=("DP", "TP"), + # ) + + # VESCALE_DEVICE_MESH._GLOBAL_MESH = device_mesh + # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test/iter_1", checkpoint_state, broadcast_checkpoint=False) + + checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} + # checkpoint_state = {"optimizer": self.ckpt_manager.optimizer} + # checkpoint_state = {"model": self.ckpt_manager.model} + # vescale_save( + # os.path.join("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test_dp2_tp2_pp2", f"iter_{batch_count}"), + # checkpoint_state, + # 
async_checkpoint=False, + # ) + # print("finish save", flush=True) + + # save_checkpoint( + # save_dir="/mnt/petrelfs/lijiaxing/InternEvo/deepspeed_ckpt", + # model=self.ckpt_manager.model, + # optimizer=self.ckpt_manager.optimizer, + # lr_scheduler=self.ckpt_manager.lr_scheduler, + # train_state=self.train_state, + # ) + + + if self.ckpt_manager.try_save_checkpoint(self.train_state): return True + self._update_profilers(batch_count, prof) return False def _load_and_prepare_batch(self, batch_count: int, train_iter): batch, train_iter = load_new_batch(train_dl=self.train_dl, train_iter=train_iter, train_state=self.train_state) + # torch.save(batch, "/mnt/petrelfs/lijiaxing/InternEvo/test_data_batch.py") + self.train_state.batch_count = batch_count self.train_state.num_consumed_samples_in_epoch += len(batch[1]) if batch[0].get("type_ids", None) is not None: diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index fc63b8a23..d3c66665c 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -259,6 +259,9 @@ def args_sanity_check(): # If 'auto_resume' is not given, we set it to True, so internlm can have opportunity # to auto-load latest checkpoint. ckpt._add_item("auto_resume", True) + + if "universal_ckpt" not in ckpt: + ckpt._add_item("universal_ckpt", False) if gpc.is_rank_for_log(): logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201 diff --git a/internlm/model/modules/embedding.py b/internlm/model/modules/embedding.py index 3585e2e46..8e0632f16 100644 --- a/internlm/model/modules/embedding.py +++ b/internlm/model/modules/embedding.py @@ -47,13 +47,14 @@ def __init__( self.vocab_parallel = vocab_parallel parallel_size = gpc.weight_parallel_size if is_using_isp() else gpc.tensor_parallel_size + rank = gpc.get_local_rank(ParallelMode.WEIGHT) if is_using_isp() else gpc.get_local_rank(ParallelMode.TENSOR) if vocab_parallel: assert num_embeddings % parallel_size == 0, f"{num_embeddings} is not divisible by {parallel_size}" self.num_embeddings_per_partition = num_embeddings // parallel_size self.embed_dim_per_partition = embedding_dim - self.vocab_start_index = gpc.get_local_rank(ParallelMode.TENSOR) * self.num_embeddings_per_partition + self.vocab_start_index = rank * self.num_embeddings_per_partition self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.offset = [self.vocab_start_index, 0] else: @@ -63,7 +64,7 @@ def __init__( self.embed_dim_per_partition = embedding_dim // parallel_size self.vocab_start_index = 0 self.vocab_end_index = self.num_embeddings_per_partition - self.offset = [0, self.embed_dim_per_partition * gpc.get_local_rank(ParallelMode.TENSOR)] + self.offset = [0, self.embed_dim_per_partition * rank] self.weight = nn.Parameter( torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition), dtype=dtype) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 49f3fbcf0..ac410cd7d 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -34,6 +34,7 @@ from internlm.solver.optimizer.utils import ( DynamicGradScaler, flatten, + unflatten, get_grad_accumulate_object, has_inf_or_nan, reduce_tensor, @@ -50,6 +51,7 @@ from .base_optimizer import BaseOptimizer from .utils import compute_norm + inf = math.inf logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -149,12 +151,19 @@ def __init__( assert self._param_bcast_sync_handler is 
not None self._isp_communicator = isp_communicator - + + self.param_global_shape_info = {} + self.param_local_shape_info = {} + self.param_global_offset_info = {} + self.param_across_dp_ranks_info = {} + self.fqn_list = {} + shape_list = {} # iterate over the param group in the optimizer # partition these param groups for data parallel training # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = param_group["params"] + self.fqn_list[group_id] = [] # set the dtype for each param group param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None @@ -194,6 +203,11 @@ def __init__( for param in params: setattr(param, "group_id", group_id) self._param_store.set_param_to_rank(param, rank) + + # for param in params_per_rank[gpc.get_local_rank(ParallelMode.ZERO1)]: + # print(f"fp16_fqn {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.ZERO1)}: {param.fqn}", flush=True) + + # move to cpu to make room to create the flat tensor for param in group_params: @@ -206,12 +220,22 @@ def __init__( # No flat fp16 buffer is allocated if the process has no parameters. if rank not in self.param_group_no_params_ranks[group_id]: tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + with torch.no_grad(): flat_tensor = flatten(tensor_list) flat_tensor = flat_tensor.data.to(get_current_device()) self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor) sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) - + + if rank == gpc.get_local_rank(ParallelMode.ZERO1): + offset = 0 + for tensor in tensor_list: + shape_list[tensor.fqn] = tensor.shape + self.fqn_list[group_id].append(tensor.fqn) + offset += tensor.numel() + # print(f"fqn_list {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.ZERO1)}: {len(params_per_rank[rank])}, {len(tensor_list)}, {len(self.fqn_list)}, {self.fqn_list}", flush=True) + + # print(f"shape_list {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)} {gpc.get_local_rank(ParallelMode.ZERO1)}: {shape_list}", flush=True) # create a copy of fp32 weights of the parameters for which this rank is responsible # No flat fp32 buffer is allocated if the process has no parameters. 
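Note on the hunk above: for the fp16 shards owned by the current ZeRO-1 rank, the constructor flattens them into one contiguous buffer and records, per fqn, the tensor shape and its running offset inside that buffer. A minimal standalone sketch of this bookkeeping, assuming InternEvo's flatten/unflatten helpers behave like the torch._utils pair used here, and with made-up fqn keys:

    import torch
    from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

    # Hypothetical fp16 params owned by one ZeRO-1 rank, keyed by their fqn.
    params = {"layers.0.wqkv.weight": torch.randn(6, 4), "layers.0.norm.weight": torch.randn(4)}

    flat = _flatten_dense_tensors(tuple(params.values()))   # one contiguous buffer for this rank

    shape_list, offset = {}, 0
    for fqn, p in params.items():                           # same bookkeeping as the hunk above
        shape_list[fqn] = (offset, tuple(p.shape))
        offset += p.numel()

    # Round trip: unflattening recovers views with the original shapes.
    views = _unflatten_dense_tensors(flat, tuple(params.values()))
    assert all(v.shape == p.shape for v, p in zip(views, params.values()))
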
if self.param_group_has_params[group_id]: @@ -223,6 +247,14 @@ def __init__( fp32_flat_current_rank = fp32_flat_current_rank.to(device) fp32_flat_current_rank.requires_grad = True self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank + print(f"fp32_flat_current_rank {gpc.get_global_rank()}: {fp32_flat_current_rank.shape}", flush=True) + + + + self.param_global_shape_info[fp32_flat_current_rank] = fp32_flat_current_rank.shape + self.param_local_shape_info[fp32_flat_current_rank] = fp32_flat_current_rank.shape + self.param_global_offset_info[fp32_flat_current_rank] = [0, 0] + # need to replace the params in the `params` field in the optimizer # so that when the optimizer calls step(), it only updates the tensors @@ -241,6 +273,9 @@ def __init__( self.skip_grad_reduce = False self._attach_reduction_hook() + + # print(f"self.grad_scaler.state_dict() {gpc.get_global_rank()}: {self.grad_scaler.state_dict()}", flush=True) + # print(f"self.optm.state_dict() {gpc.get_global_rank()}: {self.optim.state_dict()}", flush=True) @property def zero_local_rank(self): @@ -266,6 +301,7 @@ def _partition_param_list(self, group_id, param_group): param_list = param_group["params"] sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) + for i, param in enumerate(sorted_params): if param.requires_grad is False: continue @@ -290,7 +326,7 @@ def _partition_param_list(self, group_id, param_group): logger.info( # pylint: disable=W1203 f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}" ) - + # print(f"sorted_params: {len(sorted_params)}, {len(params_per_rank)}, {params_per_rank}, {no_params_ranks}", flush=True) return params_per_rank, set(no_params_ranks) def _is_moe_group(self, param_group): @@ -961,59 +997,322 @@ def clip_grad_norm(self, model, max_norm): def state_dict(self): states = {} + optim_states = self.optim.state_dict() grad_scaler = self.grad_scaler.state_dict() states["grad_scaler"] = grad_scaler - optim_states = self.optim.state_dict() - states["base_optim_states"] = optim_states - - flat_fp32_weights = {} - for group_id, param in self._fp32_flat_param_groups_of_current_rank.items(): - if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: - assert param.grad is None - flat_fp32_weights[group_id] = param - states["flat_fp32_weights"] = flat_fp32_weights - states["zero_devide_optim_plan"] = self.params_per_rank_id_dict + if not gpc.config.ckpt.universal_ckpt: + states["base_optim_states"] = optim_states + flat_fp32_weights = {} + for group_id, param in self._fp32_flat_param_groups_of_current_rank.items(): + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: + assert param.grad is None + flat_fp32_weights[group_id] = param + states["flat_fp32_weights"] = flat_fp32_weights + states["zero_devide_optim_plan"] = self.params_per_rank_id_dict + else: + import time + if gpc.get_global_rank() == 0: + start = time.time() + print(f"optim_states: {optim_states}", flush=True) + empty_states = False + if len(optim_states['state']) == 0: + empty_states = True + optim_states['state'] = {0:{'step':0}} + + sharded_optimizer_state = {} + temp = [] + for group_id, flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): + rank = self._zero_local_rank[group_id] + fqn_list = self.fqn_list[group_id] + if rank not in self.param_group_no_params_ranks[group_id]: + # fp32 param + assert len(fqn_list) > 0 + tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, 
group_id) + if len(tensor_list) > 0: + unflatten_tensor_list = unflatten(flatten_fp32_param, tensor_list) + # base optimizer state + # notice: we assume that one param group corresponds to one flattened tensor. + # we only save unflattened otimizer state. + # print(type(group_id), flush=True) + # print(optim_states['state'], flush=True) + if not empty_states: + flatten_exp_avg = optim_states['state'][group_id]['exp_avg'] + flatten_exp_avg_sq = optim_states['state'][group_id]['exp_avg_sq'] + assert flatten_exp_avg.shape == flatten_fp32_param.shape == flatten_exp_avg_sq.shape + unflatten_exp_avg = unflatten(flatten_exp_avg, tensor_list) + unflatten_exp_avg_sq = unflatten(flatten_exp_avg_sq, tensor_list) + assert len(unflatten_tensor_list) == len(tensor_list) == len(unflatten_exp_avg) == len(unflatten_exp_avg_sq) + from internlm.train.pipeline import map_fqn_local_to_global + for i in range(len(tensor_list)): + assert tensor_list[i].fqn == fqn_list[i] + assert tensor_list[i].fqn not in sharded_optimizer_state + fqn = tensor_list[i].fqn + if fqn in map_fqn_local_to_global: + fqn = map_fqn_local_to_global[fqn] + temp.append(fqn) + sharded_optimizer_state[fqn] = unflatten_tensor_list[i] + + if not empty_states: + sharded_optimizer_state[f"{fqn}.exp_avg"] = unflatten_exp_avg[i] + sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = unflatten_exp_avg_sq[i] + else: + sharded_optimizer_state[f"{fqn}.exp_avg"] = torch.empty_like(unflatten_tensor_list[i]) + sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = torch.empty_like(unflatten_tensor_list[i]) + + else: + assert len(fqn_list) == 0 + + states["step"] = optim_states['state'][0]['step'] + states["param_groups"] = optim_states["param_groups"] + states["sharded_optimizer_state"] = sharded_optimizer_state + if gpc.get_global_rank() == 0: + end = time.time() - start + print(f"create_state_dict: {end}", flush=True) return states + + # def state_dict(self): + # import time + # if gpc.get_global_rank() == 0: + # start = time.time() + # states = {} + # optim_states = self.optim.state_dict() + # print(f"optim_states: {optim_states}", flush=True) + # empty_states = False + # if len(optim_states['state']) == 0: + # empty_states = True + # optim_states['state'] = {0:{'step':0}} + + + # sharded_optimizer_state = {} + # temp = [] + # for group_id, flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): + # rank = self._zero_local_rank[group_id] + # fqn_list = self.fqn_list[group_id] + # if rank not in self.param_group_no_params_ranks[group_id]: + # # fp32 param + # assert len(fqn_list) > 0 + # tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + # if len(tensor_list) > 0: + # unflatten_tensor_list = unflatten(flatten_fp32_param, tensor_list) + # # base optimizer state + # # notice: we assume that one param group corresponds to one flattened tensor. + # # we only save unflattened otimizer state. 
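The universal-checkpoint branch of state_dict() above splits the rank's flat fp32 master weights and the flat Adam moments back into per-parameter tensors and stores them under their (globally remapped) fqn, with ".exp_avg" and ".exp_avg_sq" suffixes. A rough sketch of that reshaping, again with torch._utils standing in for the unflatten helper and with invented fqns and shapes:

    import torch
    from torch._utils import _unflatten_dense_tensors

    # Invented fqn names and shapes standing in for the fp16 shards of one ZeRO rank.
    shapes = {"layers.0.wqkv.weight": (6, 4), "layers.0.norm.weight": (4,)}
    templates = [torch.empty(s) for s in shapes.values()]
    numel = sum(t.numel() for t in templates)

    flat_fp32 = torch.randn(numel)                            # flat master weights for this rank
    flat_m, flat_v = torch.randn(numel), torch.randn(numel)   # flat Adam moments

    sharded = {}
    for fqn, w, m, v in zip(shapes,
                            _unflatten_dense_tensors(flat_fp32, templates),
                            _unflatten_dense_tensors(flat_m, templates),
                            _unflatten_dense_tensors(flat_v, templates)):
        sharded[fqn] = w                          # per-parameter fp32 master weight
        sharded[fqn + ".exp_avg"] = m             # Adam first moment
        sharded[fqn + ".exp_avg_sq"] = v          # Adam second moment

    print(sorted(sharded.keys()))                 # fqn, fqn.exp_avg, fqn.exp_avg_sq per parameter
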
+ # # print(type(group_id), flush=True) + # # print(optim_states['state'], flush=True) + # if not empty_states: + # flatten_exp_avg = optim_states['state'][group_id]['exp_avg'] + # flatten_exp_avg_sq = optim_states['state'][group_id]['exp_avg_sq'] + # assert flatten_exp_avg.shape == flatten_fp32_param.shape == flatten_exp_avg_sq.shape + # unflatten_exp_avg = unflatten(flatten_exp_avg, tensor_list) + # unflatten_exp_avg_sq = unflatten(flatten_exp_avg_sq, tensor_list) + # assert len(unflatten_tensor_list) == len(tensor_list) == len(unflatten_exp_avg) == len(unflatten_exp_avg_sq) + # from internlm.train.pipeline import map_fqn_local_to_global + # for i in range(len(tensor_list)): + # assert tensor_list[i].fqn == fqn_list[i] + # assert tensor_list[i].fqn not in sharded_optimizer_state + # fqn = tensor_list[i].fqn + # if fqn in map_fqn_local_to_global: + # fqn = map_fqn_local_to_global[fqn] + # temp.append(fqn) + # sharded_optimizer_state[fqn] = unflatten_tensor_list[i] + + # if not empty_states: + # sharded_optimizer_state[f"{fqn}.exp_avg"] = unflatten_exp_avg[i] + # sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = unflatten_exp_avg_sq[i] + # else: + # sharded_optimizer_state[f"{fqn}.exp_avg"] = torch.empty_like(unflatten_tensor_list[i]) + # sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = torch.empty_like(unflatten_tensor_list[i]) + + + # # print(f"debug_before: {self.optim.state_dict()['state'][group_id]['exp_avg']}", flush=True) + # # optim_states['state'][group_id]['exp_avg'] = '' + # # optim_states['state'][group_id]['exp_avg_sq'] = '' + # # print(f"debug_after: {self.optim.state_dict()['state'][group_id]['exp_avg']}", flush=True) + + # else: + # assert len(fqn_list) == 0 + + # grad_scaler = self.grad_scaler.state_dict() + # # print(f"optim_states['state'][0]['step']: {optim_states['state'][0]}, {optim_states['state'][0]['step']}, {type(optim_states['state'][0]['step'])}", flush=True) + # states["grad_scaler"] = grad_scaler + # states["step"] = optim_states['state'][0]['step'] + # states["param_groups"] = optim_states["param_groups"] + # states["sharded_optimizer_state"] = sharded_optimizer_state + # if gpc.get_global_rank() == 0: + # end = time.time() - start + # print(f"create_state_dict: {end}", flush=True) + + # return states + + # def load_state_dict(self, states, global_optimizer_state): + # assert "grad_scaler" in global_optimizer_state, "Not found grad_scaler state!" + # assert "step" in global_optimizer_state, "Not found step state!" + # assert "param_groups" in global_optimizer_state, "Not found param_groups state!" 
+ # print(f"states {gpc.get_global_rank()}: {states.keys()}", flush=True) + + # grad_scaler = global_optimizer_state["grad_scaler"] + # print(f"load_state_dict grad_scaler {gpc.get_global_rank()}: {grad_scaler}", flush=True) + # self.grad_scaler.load_state_dict(grad_scaler) + + # step = global_optimizer_state["step"] + # print(f"load_state_dict step {gpc.get_global_rank()}: {step}, {global_optimizer_state}", flush=True) + # param_groups = global_optimizer_state["param_groups"] + # print(f"load_state_dict param_groups {gpc.get_global_rank()}: {param_groups}", flush=True) + + # if gpc.config.get("only_load_lr", False): + # if gpc.is_rank_for_log(): + # logger.info("Only load lr in param_groups, skip loading weights in optimizer...") + # for pg1, pg2 in zip(self.optim.param_groups, param_groups): + # pg1["lr"] = pg2["lr"] + # return + + # optim_states = {'state': {}, 'param_groups': param_groups} + # for group_id, self_flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): + # print(f"group_id: {type(group_id)}", flush=True) + # rank = self._zero_local_rank[group_id] + # if rank not in self.param_group_no_params_ranks[group_id]: + # # self fp16 unflatten param list + # self_tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + + # if len(self_tensor_list) > 0: + # optim_states['state'][group_id] = {"step": step} + # ckpt_fp32_params = [] + # ckpt_exp_avg_list = [] + # ckpt_exp_avg_sq_list = [] + # for tensor in self_tensor_list: + # fqn = tensor.fqn + # from internlm.train.pipeline import map_fqn_local_to_global + # if fqn in map_fqn_local_to_global: + # fqn = map_fqn_local_to_global[fqn] + # assert tensor.shape == states[fqn].shape == states[f"{fqn}.exp_avg"].shape == states[f"{fqn}.exp_avg_sq"].shape + # ckpt_fp32_params.append(states[fqn]) + # ckpt_exp_avg_list.append(states[f"{fqn}.exp_avg"]) + # ckpt_exp_avg_sq_list.append(states[f"{fqn}.exp_avg_sq"]) + + # ckpt_flatten_fp32_param = flatten(ckpt_fp32_params) + # ckpt_flatten_exp_avg = flatten(ckpt_exp_avg_list) + # ckpt_flatten_exp_avg_sq = flatten(ckpt_exp_avg_sq_list) + + # assert ( + # self_flatten_fp32_param.shape == ckpt_flatten_fp32_param.shape + # ), f"The loaded parameter shape is inconsistent, {self_flatten_fp32_param.shape} != {ckpt_flatten_fp32_param.shape}" + # self_flatten_fp32_param.data.copy_(ckpt_flatten_fp32_param.data) + # optim_states['state'][group_id]['exp_avg'] = ckpt_flatten_exp_avg + # optim_states['state'][group_id]['exp_avg_sq'] = ckpt_flatten_exp_avg_sq + + # self.optim.load_state_dict(optim_states) + + + # # Load the fp16 model weights. + # for group_id in range(len(self._fp16_param_groups)): + # if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: + # fp16_param = self._param_store.get_flat_fp16_param_by_rank_group( + # rank=self._zero_local_rank[group_id], group_id=group_id + # ) + # fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] + # fp16_param.data.copy_(fp32_param) + + + + + def load_state_dict(self, states, global_optimizer_state=None): + if not gpc.config.ckpt.universal_ckpt: + # TODO: Need to take into account the change in the number of DP. + assert "grad_scaler" in states, "Not found grad_scaler state!" 
+ grad_scaler = states["grad_scaler"] + self.grad_scaler.load_state_dict(grad_scaler) + optim_states = states["base_optim_states"] + + if gpc.config.get("only_load_lr", False): + if gpc.is_rank_for_log(): + logger.info("Only load lr in param_groups, skip loading weights in optimizer...") + for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]): + pg1["lr"] = pg2["lr"] + return + + self.optim.load_state_dict(optim_states) + + # load fp32 model weight. + flat_fp32_weights = states["flat_fp32_weights"] + assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank) + for group_id, param in flat_fp32_weights.items(): + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: + self_param = self._fp32_flat_param_groups_of_current_rank[group_id] + assert ( + self_param.shape == param.shape + ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}" + self_param.data.copy_(param.data) + + # Load the fp16 model weights. + for group_id in range(len(self._fp16_param_groups)): + if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: + fp16_param = self._param_store.get_flat_fp16_param_by_rank_group( + rank=self._zero_local_rank[group_id], group_id=group_id + ) + fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] + fp16_param.data.copy_(fp32_param) - def load_state_dict(self, states): - # TODO: Need to take into account the change in the number of DP. - assert "grad_scaler" in states, "Not found grad_scaler state!" - grad_scaler = states["grad_scaler"] - self.grad_scaler.load_state_dict(grad_scaler) - optim_states = states["base_optim_states"] - - if gpc.config.get("only_load_lr", False): - if gpc.is_rank_for_log(): - logger.info("Only load lr in param_groups, skip loading weights in optimizer...") - for pg1, pg2 in zip(self.optim.param_groups, optim_states["param_groups"]): - pg1["lr"] = pg2["lr"] - return - - self.optim.load_state_dict(optim_states) - - # load fp32 model weight. - flat_fp32_weights = states["flat_fp32_weights"] - assert set(flat_fp32_weights.keys()) == set(self._fp32_flat_param_groups_of_current_rank) - for group_id, param in flat_fp32_weights.items(): - if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: - self_param = self._fp32_flat_param_groups_of_current_rank[group_id] - assert ( - self_param.shape == param.shape - ), f"The loaded parameter shape is inconsistent, {self_param.shape} != {param.shape}" - self_param.data.copy_(param.data) - - # Load the fp16 model weights. - for group_id in range(len(self._fp16_param_groups)): - if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: - fp16_param = self._param_store.get_flat_fp16_param_by_rank_group( - rank=self._zero_local_rank[group_id], group_id=group_id - ) - fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] - fp16_param.data.copy_(fp32_param) - - if "zero_devide_optim_plan" in states: - self.params_per_rank_id_dict = states["zero_devide_optim_plan"] + if "zero_devide_optim_plan" in states: + self.params_per_rank_id_dict = states["zero_devide_optim_plan"] + else: + assert global_optimizer_state is not None + assert "grad_scaler" in global_optimizer_state, "Not found grad_scaler state!" + assert "step" in global_optimizer_state, "Not found step state!" + assert "param_groups" in global_optimizer_state, "Not found param_groups state!" 
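The universal-checkpoint branch that continues below does the inverse at load time: it walks this rank's fp16 tensor list in order, looks up each fqn (and its .exp_avg / .exp_avg_sq companions) in the loaded state, and flattens the pieces back into the rank's flat fp32 buffer and optimizer moments. A simplified sketch of that inverse step, with hypothetical fqns and torch._utils in place of the flatten helper:

    import torch
    from torch._utils import _flatten_dense_tensors

    # Loaded per-parameter entries (as produced on save), keyed by fqn.
    states = {
        "layers.0.wqkv.weight": torch.randn(6, 4),
        "layers.0.wqkv.weight.exp_avg": torch.randn(6, 4),
        "layers.0.wqkv.weight.exp_avg_sq": torch.randn(6, 4),
        "layers.0.norm.weight": torch.randn(4),
        "layers.0.norm.weight.exp_avg": torch.randn(4),
        "layers.0.norm.weight.exp_avg_sq": torch.randn(4),
    }
    rank_fqns = ["layers.0.wqkv.weight", "layers.0.norm.weight"]   # order of this rank's fp16 tensors

    flat_fp32   = _flatten_dense_tensors([states[f] for f in rank_fqns])
    flat_exp_av = _flatten_dense_tensors([states[f + ".exp_avg"] for f in rank_fqns])
    flat_exp_sq = _flatten_dense_tensors([states[f + ".exp_avg_sq"] for f in rank_fqns])

    # The real code copies flat_fp32 into the rank's master-weight buffer and hands the
    # moments to the wrapped optimizer via load_state_dict; here we only check shapes line up.
    assert flat_fp32.shape == flat_exp_av.shape == flat_exp_sq.shape
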
+ print(f"states {gpc.get_global_rank()}: {states.keys()}", flush=True) + + grad_scaler = global_optimizer_state["grad_scaler"] + print(f"load_state_dict grad_scaler {gpc.get_global_rank()}: {grad_scaler}", flush=True) + self.grad_scaler.load_state_dict(grad_scaler) + + step = global_optimizer_state["step"] + print(f"load_state_dict step {gpc.get_global_rank()}: {step}, {global_optimizer_state}", flush=True) + param_groups = global_optimizer_state["param_groups"] + print(f"load_state_dict param_groups {gpc.get_global_rank()}: {param_groups}", flush=True) + + if gpc.config.get("only_load_lr", False): + if gpc.is_rank_for_log(): + logger.info("Only load lr in param_groups, skip loading weights in optimizer...") + for pg1, pg2 in zip(self.optim.param_groups, param_groups): + pg1["lr"] = pg2["lr"] + return + + optim_states = {'state': {}, 'param_groups': param_groups} + for group_id, self_flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): + print(f"group_id: {type(group_id)}", flush=True) + rank = self._zero_local_rank[group_id] + if rank not in self.param_group_no_params_ranks[group_id]: + # self fp16 unflatten param list + self_tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) + + if len(self_tensor_list) > 0: + optim_states['state'][group_id] = {"step": step} + ckpt_fp32_params = [] + ckpt_exp_avg_list = [] + ckpt_exp_avg_sq_list = [] + for tensor in self_tensor_list: + fqn = tensor.fqn + from internlm.train.pipeline import map_fqn_local_to_global + if fqn in map_fqn_local_to_global: + fqn = map_fqn_local_to_global[fqn] + assert tensor.shape == states[fqn].shape == states[f"{fqn}.exp_avg"].shape == states[f"{fqn}.exp_avg_sq"].shape + ckpt_fp32_params.append(states[fqn]) + ckpt_exp_avg_list.append(states[f"{fqn}.exp_avg"]) + ckpt_exp_avg_sq_list.append(states[f"{fqn}.exp_avg_sq"]) + + ckpt_flatten_fp32_param = flatten(ckpt_fp32_params) + ckpt_flatten_exp_avg = flatten(ckpt_exp_avg_list) + ckpt_flatten_exp_avg_sq = flatten(ckpt_exp_avg_sq_list) + + assert ( + self_flatten_fp32_param.shape == ckpt_flatten_fp32_param.shape + ), f"The loaded parameter shape is inconsistent, {self_flatten_fp32_param.shape} != {ckpt_flatten_fp32_param.shape}" + self_flatten_fp32_param.data.copy_(ckpt_flatten_fp32_param.data) + optim_states['state'][group_id]['exp_avg'] = ckpt_flatten_exp_avg + optim_states['state'][group_id]['exp_avg_sq'] = ckpt_flatten_exp_avg_sq + + self.optim.load_state_dict(optim_states) def reload_zero_fp32_buff(self): # If we use AMP optimizer, we need to update its fp32 buffer as newly loaded weights value. diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 2b85b452c..cb8f6f48b 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -138,9 +138,10 @@ def set_param_unique_tracking_name(model): continue full_name = f"{chunk_id}.{idx}.{name}" parts = f"{full_name}.weight".split('.', 2) - original_id = model.first_layer + idx - map_fqn = f"{children_name}.{original_id}." + '.'.join(parts[2:]) + global_id = model.first_layer + idx result = f"{children_name}." + '.'.join(parts[1:]) + global_fqn = f"{children_name}.{global_id}." 
+ '.'.join(parts[2:]) + if isinstance(child, (ParallelLinearWithCommExt)): setattr( child.weight, @@ -154,6 +155,19 @@ def set_param_unique_tracking_name(model): f"{full_name}.bias", ) + + setattr( + child.weight, + "fqn", + f"{result}", + ) + if child.bias is not None: + setattr( + child.bias, + "fqn", + f"{result}", + ) + # print(result, flush=True) assert hasattr(child, "offset"), f"{child}" @@ -164,25 +178,29 @@ def set_param_unique_tracking_name(model): # print(f"original_id {gpc.get_local_rank(ParallelMode.PIPELINE)}: {model.first_layer}, {idx}, {original_id}, {name, child}, {map_fqn}") - if "4.attention_norm" in map_fqn: - assert False - - map_fqn_local_to_global[result] = map_fqn - map_fqn_global_to_local[map_fqn] = result + map_fqn_local_to_global[result] = global_fqn + map_fqn_global_to_local[global_fqn] = result # print(f"map_pp_layer_fqn {gpc.get_local_rank(ParallelMode.PIPELINE)}: {map_pp_layer_fqn}", flush=True) - assert result not in map_layer_attr, f"{map_layer_attr} exists" + assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists" - map_layer_attr[result] = {'offset': getattr(child, "offset", [0] * len(child.weight.size())), 'complete_size': getattr(child, "complete_size", child.weight.size())} + map_layer_attr[global_fqn] = {'offset': getattr(child, "offset", [0] * len(child.weight.size())), 'complete_size': getattr(child, "complete_size", child.weight.size())} # print(f"child.weight {gpc.get_local_rank(ParallelMode.TENSOR)}: {result}, {child.offset}, {child.weight.shape}", flush=True) elif isinstance(child, (RMSNorm)): - print(f"map_fqn {gpc.get_local_rank(ParallelMode.PIPELINE)}: {map_fqn}", flush=True) - map_fqn_local_to_global[result] = map_fqn - map_fqn_global_to_local[map_fqn] = result + # print(f"global_fqn {gpc.get_local_rank(ParallelMode.PIPELINE)}: {global_fqn}", flush=True) + map_fqn_local_to_global[result] = global_fqn + map_fqn_global_to_local[global_fqn] = result + setattr( + child.weight, + "fqn", + f"{result}", + ) + map_layer_attr[global_fqn] = {'offset': getattr(child, "offset", [0] * len(child.weight.size())), 'complete_size': getattr(child, "complete_size", child.weight.size())} else: full_name = f"{chunk_id}.{children_name}" result = f"{children_name}.weight" + assert getattr(children, "bias", None) is None # print(f"result: {result}", flush=True) if isinstance(children, Embedding1D): setattr( @@ -200,6 +218,20 @@ def set_param_unique_tracking_name(model): ) assert result not in map_layer_attr, f"{map_layer_attr} exists" # map_layer_attr[result] = {'offset': getattr(children, "offset", [0] * len(children.weight.size())), 'complete_size': getattr(children, "complete_size", children.weight.size())} + + setattr( + children.weight, + "fqn", + f"{result}", + ) + if getattr(children, "bias", None) is not None: + if children.bias is not None: + setattr( + children.bias, + "fqn", + f"{result}", + ) + map_layer_attr[result] = {'offset': getattr(children, "offset", [0] * len(children.weight.size())), 'complete_size': getattr(children, "complete_size", children.weight.size())} # print(f"map_layer_attr global={gpc.get_global_rank()}, pp={gpc.get_local_rank(ParallelMode.PIPELINE)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}: {map_layer_attr}", flush=True) From 364c5949f66620fe26fa09717943e4d98db674ad Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Mon, 23 Dec 2024 16:19:32 +0800 Subject: [PATCH 03/12] support save cache and load broadcast --- internlm/checkpoint/checkpoint_manager.py | 8 ++--- internlm/checkpoint/vescale/common.py | 2 +- 
internlm/checkpoint/vescale/filesystem.py | 20 +++++++----- .../checkpoint/vescale/save_state_dict.py | 31 ++++++++++--------- .../vescale/vescale_checkpointer.py | 4 +-- .../checkpoint/vescale/vescale_planner.py | 24 +++++++++++--- .../vescale/vescale_planner_helpers.py | 1 + internlm/initialize/launch.py | 2 +- .../solver/optimizer/hybrid_zero_optim.py | 4 +-- 9 files changed, 59 insertions(+), 37 deletions(-) diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 826a8b859..5698d3182 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -87,7 +87,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, if universal_ckpt: from internlm.checkpoint.vescale.api import load as vescale_load checkpoint_state = {"model": ckpt_mm.model, "optimizer": ckpt_mm.optimizer} - vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=False) + vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=gpc.config.ckpt.universal_ckpt.broadcast_load) if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL): load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) @@ -448,7 +448,7 @@ def try_save_checkpoint(self, train_state, force=False): train_state=train_state, model_config=self.model_config, model_config_file=self.model_config_file, - universal_ckpt=gpc.config.ckpt.universal_ckpt, + universal_ckpt=gpc.config.ckpt.universal_ckpt.enable, ) if ( @@ -591,7 +591,7 @@ def try_resume_training(self, train_state: TrainState, current_time=""): load_path = self.load_ckpt_info["path"] load_content = self.load_ckpt_info["content"] load_type = self.load_ckpt_info["ckpt_type"] - universal_ckpt = gpc.config.ckpt.universal_ckpt + universal_ckpt = gpc.config.ckpt.universal_ckpt.enable kwargs = {} if universal_ckpt: @@ -656,7 +656,7 @@ def save_checkpoint( vescale_save( path=folder, checkpoint_state={"model": model, "optimizer": optimizer}, - async_checkpoint=False, + async_checkpoint=gpc.config.ckpt.universal_ckpt.aysnc_save, ) diff --git a/internlm/checkpoint/vescale/common.py b/internlm/checkpoint/vescale/common.py index 2d19fef28..d69f29a95 100644 --- a/internlm/checkpoint/vescale/common.py +++ b/internlm/checkpoint/vescale/common.py @@ -59,7 +59,7 @@ def sort_rank_ranges(process_list: List[Tuple]) -> List[Tuple]: return sorted_process_list -_MAX_CACHE_SIZE = 8 +_MAX_CACHE_SIZE = 2 # model ckpt + optm ckpt class PlanLRUCache: diff --git a/internlm/checkpoint/vescale/filesystem.py b/internlm/checkpoint/vescale/filesystem.py index ca8dc3ff4..71f21024c 100644 --- a/internlm/checkpoint/vescale/filesystem.py +++ b/internlm/checkpoint/vescale/filesystem.py @@ -27,6 +27,7 @@ from internlm.core.context import global_context as gpc from internlm.core.context import ParallelMode from internlm.train.pipeline import map_fqn_global_to_local, map_layer_attr +from internlm.utils.common import get_current_device from torch.distributed.checkpoint.metadata import ( @@ -880,8 +881,9 @@ def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlan bytes.seek(0) planner.load_bytes(req, bytes) else: - tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) + tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) #att tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + print(f"req: {req.dest_index.fqn}, {req}", flush=True) target_tensor = planner.resolve_tensor(req).detach() assert ( @@ -892,18 
+894,20 @@ def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlan def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): for relative_path, reqs in per_file.items(): - if dist.get_rank(self.data_parallel_process_group) == 0: + # if dist.get_rank(self.data_parallel_process_group) == 0: + if gpc.get_local_rank(ParallelMode.DATA) == 0: file_path = self._get_file_path(relative_path) file = open(file_path, "rb") dist.barrier(self.data_parallel_process_group) reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) for req in reqs: - if dist.get_rank(self.data_parallel_process_group) == 0: + if gpc.get_local_rank(ParallelMode.DATA)== 0: item_md = self.storage_data[req.storage_index] file_slice = self._slice_file(file, item_md) if req.type == LoadItemType.BYTE_IO: - if dist.get_rank(self.data_parallel_process_group) == 0: + assert False + if gpc.get_local_rank(ParallelMode.DATA) == 0: object_list = [io.BytesIO(file_slice.read(item_md.length))] else: object_list = [None] @@ -912,13 +916,13 @@ def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: object_list, src=dist.get_global_rank(self.data_parallel_process_group, 0), group=self.data_parallel_process_group, - device=f"cuda:{torch.cuda.current_device()}", + device=get_current_device(), ) bytes = object_list[0] bytes.seek(0) planner.load_bytes(req, bytes) else: - if dist.get_rank(self.data_parallel_process_group) == 0: + if gpc.get_local_rank(ParallelMode.DATA) == 0: object_list = [cast(Tensor, torch.load(file_slice, map_location="cuda"))] else: object_list = [None] @@ -926,9 +930,9 @@ def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: object_list, src=dist.get_global_rank(self.data_parallel_process_group, 0), group=self.data_parallel_process_group, - device=f"cuda:{torch.cuda.current_device()}", + device=get_current_device(), ) - tensor = object_list[0].cpu() + tensor = object_list[0].cpu() #att tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) target_tensor = planner.resolve_tensor(req).detach() diff --git a/internlm/checkpoint/vescale/save_state_dict.py b/internlm/checkpoint/vescale/save_state_dict.py index 6772bce0b..20d5883d9 100644 --- a/internlm/checkpoint/vescale/save_state_dict.py +++ b/internlm/checkpoint/vescale/save_state_dict.py @@ -53,7 +53,6 @@ def save_state_dict( [veScale version] Saves a distributed model in SPMD style. Fix sub-group storage. Args and usage is the same as `torch.distributed.checkpoint.save_state_dict`. """ - # Step 0: create distributed world based on process group and coordinator rank distW = _DistWrapper(process_group, not no_dist, coordinator_rank) if process_group: @@ -132,6 +131,7 @@ def finish_checkpoint(all_results): # Wait for last write futures to finish. 
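The read_data_with_broadcast changes above make only data-parallel rank 0 open the checkpoint file, then broadcast the loaded object to the rest of the DP group on the current device. The pattern in isolation (a sketch only; the real reader additionally slices the file per read request and narrows the tensor by its storage offsets):

    import torch
    import torch.distributed as dist

    def broadcast_load(path, dp_group, device):
        """Only DP-rank 0 touches the filesystem; everyone else receives a copy."""
        if dist.get_rank(dp_group) == 0:
            objs = [torch.load(path, map_location=device)]
        else:
            objs = [None]
        dist.broadcast_object_list(
            objs,
            src=dist.get_global_rank(dp_group, 0),
            group=dp_group,
            device=device,                      # stage the pickled payload on this device
        )
        return objs[0]
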
if last_write_futures: + print(f"last_write_futures: {last_write_futures}", flush=True) logger.info("Start waiting for last write events.") last_write_start_time = time.time() for fut in last_write_futures: @@ -145,22 +145,23 @@ def finish_checkpoint(all_results): plan_start_time = time.time() cached_data = None + # if isinstance(planner, VeScaleSavePlanner): + # central_plan = distW.reduce_scatter("plan", local_step, global_step) + # else: + # raise AssertionError("Unsupported planner for saving checkpoint") + if isinstance(planner, VeScaleSavePlanner): - central_plan = distW.reduce_scatter("plan", local_step, global_step) + cached_data = planner.lookup_plan_meta() + if cached_data: + logger.info("Plan cache hit. Reuse existing plan") + central_plan, _ = cached_data + # _ = local_step() #attn + else: + logger.info("Plan cache miss. The model/optimizer appears for the first time.") + + central_plan = distW.reduce_scatter("plan", local_step, global_step) else: raise AssertionError("Unsupported planner for saving checkpoint") - # if isinstance(planner, VeScaleSavePlanner): #attn - # cached_data = planner.lookup_plan_meta() - # if cached_data: - # logger.debug("Plan cache hit. Reuse existing plan") - # central_plan, _ = cached_data - # _ = local_step() - # else: - # logger.debug("Plan cache miss. The model/optimizer appears for the first time.") - - # central_plan = distW.reduce_scatter("plan", local_step, global_step) - # else: - # raise AssertionError("Unsupported planner for saving checkpoint") @@ -194,7 +195,7 @@ def finish_checkpoint(all_results): final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint) assert central_plan is not None assert final_storage_metadata is not None - # planner.cache_plan_meta(central_plan, final_storage_metadata) #attn + planner.cache_plan_meta(central_plan, final_storage_metadata) #attn else: raise AssertionError("Unsupported planner for writing data and metadata") store_local_cost_time = time.time() - store_local_start_time diff --git a/internlm/checkpoint/vescale/vescale_checkpointer.py b/internlm/checkpoint/vescale/vescale_checkpointer.py index c28b23e90..0327aee27 100644 --- a/internlm/checkpoint/vescale/vescale_checkpointer.py +++ b/internlm/checkpoint/vescale/vescale_checkpointer.py @@ -225,8 +225,8 @@ def load( # print(f"model_state {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {p})", flush=True) # Set process group if broadcast_checkpoint: - assert False - model_load_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups() + # model_load_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups() + model_load_process_group = gpc.get_group(ParallelMode.DATA) else: model_load_process_group = None # Load model diff --git a/internlm/checkpoint/vescale/vescale_planner.py b/internlm/checkpoint/vescale/vescale_planner.py index d1b60fbc7..598440d0d 100644 --- a/internlm/checkpoint/vescale/vescale_planner.py +++ b/internlm/checkpoint/vescale/vescale_planner.py @@ -135,17 +135,33 @@ def lookup_object(self, index: MetadataIndex, fqn=None) -> Any: return find_state_dict_object(self.state_dict, index, fqn) def lookup_plan_meta(self) -> Optional[Tuple[SavePlan, Metadata]]: + # if not hasattr(self, STATE_DICT_STR): + # return None + # else: + # device_mesh = VESCALE_DEVICE_MESH.get() + # plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + # return self._plan_cache.get(plan_key) + if not hasattr(self, STATE_DICT_STR): return None else: - device_mesh = 
VESCALE_DEVICE_MESH.get() - plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator)) return self._plan_cache.get(plan_key) def cache_plan_meta(self, new_plan: SavePlan, new_metadata: Metadata) -> None: - device_mesh = VESCALE_DEVICE_MESH.get() - plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + # device_mesh = VESCALE_DEVICE_MESH.get() + # plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) + # self._plan_cache.put(plan_key, new_plan, new_metadata) + + print(f"new_plan {gpc.get_global_rank()}: {new_plan}", flush=True) + print(f"new_metadata {gpc.get_global_rank()}: {new_metadata}", flush=True) + + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator)) + print(f"Before GPU Memory Allocated {gpc.get_global_rank()}: {torch.cuda.memory_allocated() /1024/1024} bytes", flush=True) + print(f"Before GPU Memory Cached {gpc.get_global_rank()}: {torch.cuda.memory_reserved() /1024/1024} bytes", flush=True) self._plan_cache.put(plan_key, new_plan, new_metadata) + print(f"After GPU Memory Allocated {gpc.get_global_rank()}: {torch.cuda.memory_allocated() /1024/1024} bytes", flush=True) + print(f"After GPU Memory Cached {gpc.get_global_rank()}: {torch.cuda.memory_reserved() /1024/1024} bytes", flush=True) def clear_cache(self) -> None: self._plan_cache.clear() diff --git a/internlm/checkpoint/vescale/vescale_planner_helpers.py b/internlm/checkpoint/vescale/vescale_planner_helpers.py index 5c8316ada..513123634 100644 --- a/internlm/checkpoint/vescale/vescale_planner_helpers.py +++ b/internlm/checkpoint/vescale/vescale_planner_helpers.py @@ -298,6 +298,7 @@ def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex, fq # if isinstance(obj, torch.Tensor): #att # return find_tensor_shard(obj, index) if isinstance(obj, OptimizerStateSpec): + assert False return obj.local_tensor # elif index.offset is not None: # raise ValueError( diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index d3c66665c..b593b0788 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -261,7 +261,7 @@ def args_sanity_check(): ckpt._add_item("auto_resume", True) if "universal_ckpt" not in ckpt: - ckpt._add_item("universal_ckpt", False) + ckpt._add_item("universal_ckpt", dict(enable=False, aysnc_save=False, broadcast_load=False)) if gpc.is_rank_for_log(): logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201 diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index ac410cd7d..c3a6acc10 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -1000,7 +1000,7 @@ def state_dict(self): optim_states = self.optim.state_dict() grad_scaler = self.grad_scaler.state_dict() states["grad_scaler"] = grad_scaler - if not gpc.config.ckpt.universal_ckpt: + if not gpc.config.ckpt.universal_ckpt.enable: states["base_optim_states"] = optim_states flat_fp32_weights = {} for group_id, param in self._fp32_flat_param_groups_of_current_rank.items(): @@ -1217,7 +1217,7 @@ def state_dict(self): def load_state_dict(self, states, global_optimizer_state=None): - if not gpc.config.ckpt.universal_ckpt: + if not gpc.config.ckpt.universal_ckpt.enable: # TODO: Need to take into account the change in the number of DP. 
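The planner changes above drop the device mesh from the cache key, so a save plan is cached per (state-dict keys, coordinator flag) pair and reused on the next save of the same model or optimizer. A simplified stand-in for that cache (not the exact PlanLRUCache eviction logic), showing the keying and the reuse:

    from collections import OrderedDict

    class TinyPlanCache:
        """Keyed like the patched planner: state-dict keys plus coordinator flag."""
        def __init__(self, capacity=2):                 # model plan + optimizer plan
            self._cache, self._capacity = OrderedDict(), capacity

        @staticmethod
        def key(state_dict, is_coordinator):
            return hash((frozenset(state_dict.keys()), is_coordinator))

        def get(self, key):
            return self._cache.get(key)

        def put(self, key, plan, metadata):
            self._cache[key] = (plan, metadata)
            if len(self._cache) > self._capacity:
                self._cache.popitem(last=False)         # evict the oldest entry

    cache = TinyPlanCache()
    k = TinyPlanCache.key({"w.weight": 0, "w.bias": 1}, is_coordinator=True)
    cache.put(k, plan="central_plan", metadata="storage_md")
    assert cache.get(k) == ("central_plan", "storage_md")   # second save reuses the plan
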
assert "grad_scaler" in states, "Not found grad_scaler state!" grad_scaler = states["grad_scaler"] From 05bbd8f1560e3bf560a662567c188cdce8b80913 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 25 Dec 2024 13:54:54 +0800 Subject: [PATCH 04/12] univseral ckpt --- configs/7B_internlm2.py | 4 + internlm/checkpoint/checkpoint_manager.py | 33 +- .../universal_checkpoint/__init__.py | 6 + .../checkpoint/universal_checkpoint/api.py | 48 ++ .../universal_checkpoint/checkpointer.py | 269 +++++++ .../checkpoint/universal_checkpoint/common.py | 110 +++ .../universal_checkpoint/filesystem.py | 691 ++++++++++++++++++ .../universal_checkpoint/load_state_dict.py | 85 +++ .../universal_checkpoint/mem_checkpoint.py | 131 ++++ .../universal_checkpoint/planner.py | 150 ++++ .../universal_checkpoint/planner_helpers.py | 157 ++++ .../universal_checkpoint/save_state_dict.py | 163 +++++ internlm/core/trainer_builder.py | 107 +-- internlm/model/modules/linear.py | 15 +- internlm/model/modules/mha.py | 29 +- .../solver/optimizer/hybrid_zero_optim.py | 292 ++------ internlm/train/pipeline.py | 129 ++-- 17 files changed, 1956 insertions(+), 463 deletions(-) create mode 100644 internlm/checkpoint/universal_checkpoint/__init__.py create mode 100644 internlm/checkpoint/universal_checkpoint/api.py create mode 100644 internlm/checkpoint/universal_checkpoint/checkpointer.py create mode 100644 internlm/checkpoint/universal_checkpoint/common.py create mode 100644 internlm/checkpoint/universal_checkpoint/filesystem.py create mode 100644 internlm/checkpoint/universal_checkpoint/load_state_dict.py create mode 100644 internlm/checkpoint/universal_checkpoint/mem_checkpoint.py create mode 100644 internlm/checkpoint/universal_checkpoint/planner.py create mode 100644 internlm/checkpoint/universal_checkpoint/planner_helpers.py create mode 100644 internlm/checkpoint/universal_checkpoint/save_state_dict.py diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 97758bba4..2ac67bdce 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -38,6 +38,10 @@ async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. + # control universal ckpt. 
INFO: Not compatible with the original ckpt + # Default to use async_save and not use broadcast_load + # as broadcast_load may cause loading performance degradation + universal_ckpt=dict(enable=False, aysnc_save=True, broadcast_load=False), ) TRAIN_FOLDER = None diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index 5698d3182..dd5f7b4e5 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -8,6 +8,7 @@ import torch from internlm.accelerator import get_accelerator +from internlm.checkpoint.universal_checkpoint.api import universal_load, universal_save from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.trainer import TrainState @@ -43,7 +44,6 @@ ) from .load_funcs import LOAD_FUNC_DICT from .utils import process_load_info -from internlm.checkpoint.vescale.api import save as vescale_save logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -83,11 +83,16 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, and the checkpoint manager ckpt_mm and train state objects """ load_content_str, load_ckpt_folder, load_content = process_load_info(load_info) - + if universal_ckpt: - from internlm.checkpoint.vescale.api import load as vescale_load - checkpoint_state = {"model": ckpt_mm.model, "optimizer": ckpt_mm.optimizer} - vescale_load(load_ckpt_folder, checkpoint_state, broadcast_checkpoint=gpc.config.ckpt.universal_ckpt.broadcast_load) + checkpoint_state = {} + if load_content.need_load(CheckpointLoadContent.MODEL): + checkpoint_state["model"] = ckpt_mm.model + if load_content.need_load(CheckpointLoadContent.OPIMIZER): + checkpoint_state["optimizer"] = ckpt_mm.optimizer + universal_load( + load_ckpt_folder, checkpoint_state, broadcast_checkpoint=gpc.config.ckpt.universal_ckpt.broadcast_load + ) if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL): load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) @@ -115,8 +120,7 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, if gpc.is_rank_for_log(): logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!") - if not load_content.need_load(CheckpointLoadContent.OPIMIZER): - assert False + if not universal_ckpt and not load_content.need_load(CheckpointLoadContent.OPIMIZER): if ckpt_mm.lr_scheduler and train_state: gpc.config.only_load_lr = True load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer) @@ -426,11 +430,10 @@ def is_now_to_save_ckpt(self, train_state, force=False) -> (bool, CheckpointSave def try_save_checkpoint(self, train_state, force=False): if not self.enable_save_ckpt: return False - + save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state, force=force) if save_ckpts: - begin = time.time() # Wait for the previous round of asynchronous upload storage to complete. self.storage_manager.wait() if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT: @@ -469,8 +472,6 @@ def try_save_checkpoint(self, train_state, force=False): f"Finish to convert internevo2hf checkpoint from {save_ckpt_folder} to {save_hf_ckpt_folder}." 
) torch.distributed.barrier() - end = time.time() - begin - print(f"finsh save time {gpc.get_global_rank()}: {end}", flush=True) return now_break @@ -587,13 +588,12 @@ def try_resume_training(self, train_state: TrainState, current_time=""): f"dp={gpc.get_local_rank(ParallelMode.DATA)}===========" ) else: - begin = time.time() load_path = self.load_ckpt_info["path"] load_content = self.load_ckpt_info["content"] load_type = self.load_ckpt_info["ckpt_type"] universal_ckpt = gpc.config.ckpt.universal_ckpt.enable kwargs = {} - + if universal_ckpt: assert load_type == "internevo", "Only internevo ckpt support universal ckpt." kwargs = {"universal_ckpt": universal_ckpt} @@ -616,8 +616,6 @@ def try_resume_training(self, train_state: TrainState, current_time=""): ) if load_content_str: logger.info(f"===========Load contents are: {load_content_str}") - end = time.time() - begin - print(f"finsh load time {gpc.get_global_rank()}: {end}", flush=True) @llm_timeout(func_name="save_checkpoint") def save_checkpoint( @@ -643,7 +641,6 @@ def save_checkpoint( logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...") if not universal_ckpt: - print(f"save ckpt: base", flush=True) timer("save-model").start() save_model_checkpoint(folder=folder, model=model) timer("save-model").stop() @@ -652,13 +649,11 @@ def save_checkpoint( save_optimizer_checkpoint(optim=optimizer, state_path=folder) timer("save-optimizer").stop() else: - print(f"save ckpt: universal", flush=True) - vescale_save( + universal_save( path=folder, checkpoint_state={"model": model, "optimizer": optimizer}, async_checkpoint=gpc.config.ckpt.universal_ckpt.aysnc_save, ) - if ( hasattr(train_state, "data_state_dict") diff --git a/internlm/checkpoint/universal_checkpoint/__init__.py b/internlm/checkpoint/universal_checkpoint/__init__.py new file mode 100644 index 000000000..6d144c821 --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/__init__.py @@ -0,0 +1,6 @@ +from .api import universal_load, universal_save + +__all__ = [ + "universal_save", + "universal_load", +] diff --git a/internlm/checkpoint/universal_checkpoint/api.py b/internlm/checkpoint/universal_checkpoint/api.py new file mode 100644 index 000000000..bffa75dd4 --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/api.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# adopted from https://github.com/volcengine/veScale/blob/main/vescale/checkpoint + +from .checkpointer import UniversalCheckpointer +from .common import CheckpointState + + +def universal_save(path: str, checkpoint_state: CheckpointState, async_checkpoint=False): + """ + Save a checkpoint to a given path + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, + i.e. after dumping tensors from GPU memory to Host memory, + the training program can continue training immediately. + Then universal_checkpoint will serialize tensors and dumping to + the persistent storage asynchronously. 
+ Example: + >>> checkpoint_state = { "model": nn.Module, "optimizer": HybridZeroOptimizer } + >>> UniversalCheckpointer.save("/ckpt", checkpoint_state) + """ + UniversalCheckpointer.save(path, checkpoint_state, async_checkpoint=async_checkpoint) + + +def universal_load(path: str, checkpoint_state: CheckpointState, broadcast_checkpoint=False): + """ + Load a checkpoint from a given path + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group + then broadcast tensors to other data parallel process group using GPUs + to reduce the file system access + For example, when data parellel size = 2, + processes with data parallel rank = 0 load model from file system + then broadcast it to processes with data parallel rank = 1 + Example: + >>> checkpoint_state = { "model": nn.Module, "optimizer": HybridZeroOptimizer } + >>> UniversalCheckpointer.load("/ckpt", checkpoint_state) + """ + UniversalCheckpointer.load(path, checkpoint_state, broadcast_checkpoint=broadcast_checkpoint) diff --git a/internlm/checkpoint/universal_checkpoint/checkpointer.py b/internlm/checkpoint/universal_checkpoint/checkpointer.py new file mode 100644 index 000000000..16838ff61 --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/checkpointer.py @@ -0,0 +1,269 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# adopted from https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/api + +import atexit +import os +from concurrent.futures import Future, ProcessPoolExecutor +from typing import Dict, List + +import torch +import torch.distributed as dist +from torch.distributed.checkpoint.storage import WriteResult + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.solver.optimizer import HybridZeroOptimizer +from internlm.utils.common import get_current_device +from internlm.utils.logger import get_logger +from internlm.utils.megatron_timers import megatron_timer as timer + +from .common import MODEL_STR, OPTIMIZER_STR, CheckpointState +from .load_state_dict import load_state_dict +from .planner import UniversalLoadPlanner, UniversalSavePlanner +from .save_state_dict import save_state_dict + +logger = get_logger(__file__) + +NUM_IO_WORKER = 1 +SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} + + +class BaseCheckpointer: + """ + The Checkpointer class offers APIs that enable users to save and load state dictionarie. + It is designed for extension across various training frameworks. + """ + + # Async IO related members. + state_io_workers: Dict[str, ProcessPoolExecutor] = {} + state_write_futures: Dict[str, Future[List[WriteResult]]] = {} + + @classmethod + def save(cls, path: str, checkpoint_state: CheckpointState): + """ + A Method for saving checkpoint + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. 
+ + """ + raise NotImplementedError() + + @classmethod + def load(cls, path: str, checkpoint_state: CheckpointState): + """ + A Method for loading checkpoint + Args: + path: Defines the storage path for checkpoint. + checkpoint_state: A dictionary contains key-value pairs for model and optimizer. + - Model: Identified by 'model' key, value should be a model instance. + - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. + + """ + raise NotImplementedError() + + @classmethod + def _cleanup_futures(cls): + """ + Wait for all write futures to finish before exit, then do the cleanup works. + + WARNING: this method cannot be called by the users. + """ + for key in SUPPORTED_TYPES: + if key in cls.state_write_futures: + futures = cls.state_write_futures[key] + for fut in futures: + fut.result() + cls.state_write_futures[key] = [] + if cls.state_io_workers[key] is not None: + cls.state_io_workers[key].shutdown() + cls.state_io_workers[key] = None + + +class UniversalCheckpointer(BaseCheckpointer): + """ + The Checkpointer class for universal checkpoint, A PyTorch Native Auto Parallelism Framework + """ + + save_planner = UniversalSavePlanner() + load_planner = UniversalLoadPlanner() + + optim_ckpt_proces_group = None + for key in SUPPORTED_TYPES: + BaseCheckpointer.state_io_workers[key] = ProcessPoolExecutor(max_workers=NUM_IO_WORKER) + BaseCheckpointer.state_write_futures[key] = [] + + @classmethod + def save( + cls, + path: str, + checkpoint_state: CheckpointState, + async_checkpoint: bool = False, + ): + """ + async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, + i.e. after dumping tensors from GPU memory to Host memory, + the training program can continue training immediately. + Then checkpoint will serialize tensors and dumping to + the persistent storage asynchronously. + """ + # Check if we support saving the components + for key in checkpoint_state.keys(): + if key not in SUPPORTED_TYPES: + raise ValueError(f"{key} is not supported by UniversalCheckpointer") + + # Preprocess saving path + if path.startswith("local:"): + path = path.split(":")[1] + assert ":" not in path, f"{path} is not valid for universal checkpoint!" + + # Start saving checkpoint + for key, value in checkpoint_state.items(): + if key == MODEL_STR: + # Get model path + model_path = os.path.join(path, MODEL_STR) + # Create a "model" folder on under root path + if gpc.get_global_rank() == 0: + os.makedirs(model_path, exist_ok=True) + dist.barrier() + # Save model. + timer("save-model").start() + _, new_write_futures = save_state_dict( + state_dict=value.state_dict(), + path=model_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.save_planner, + async_io=async_checkpoint, + last_write_futures=cls.state_write_futures[MODEL_STR], + io_workers=cls.state_io_workers[MODEL_STR], + is_optimizer=False, + ) + # Record new write futures. 
+ cls.state_write_futures[MODEL_STR] = new_write_futures + dist.barrier() + timer("save-model").stop() + elif key == OPTIMIZER_STR: + # adamW hybrid zero optim + assert isinstance(value, HybridZeroOptimizer), "unsupported optimizer for universal ckpt" + optimizer_state = value.state_dict() + # Create a "optimizer" folder on under root path + # to save different parts of optimizer + optimizer_path = os.path.join(path, OPTIMIZER_STR) + if gpc.get_global_rank() == 0: + os.makedirs(optimizer_path, exist_ok=True) + dist.barrier() + # Save optimizer + timer("save-optimizer").start() + _, new_write_futures = save_state_dict( + state_dict=optimizer_state["sharded_optimizer_state"], + path=optimizer_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.save_planner, + async_io=async_checkpoint, + last_write_futures=cls.state_write_futures[OPTIMIZER_STR], + io_workers=cls.state_io_workers[OPTIMIZER_STR], + is_optimizer=True, + ) + # Record new write futures. + cls.state_write_futures[OPTIMIZER_STR] = new_write_futures + # Save the global part of optimizer state + optimizer_state.pop("sharded_optimizer_state") + if gpc.get_global_rank() == 0: + torch.save(optimizer_state, os.path.join(path, "global_optimizer_state.pt")) + dist.barrier() + timer("save-optimizer").stop() + + @classmethod + def load( + cls, + path: str, + checkpoint_state: CheckpointState, + broadcast_checkpoint: bool = False, + ): + """ + broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group + then broadcast tensors to other data parallel process group using GPUs + to reduce the file system access + For example, when data parellel size = 2, + processes with data parallel rank = 0 load model from file system + then broadcast it to processes with data parallel rank = 1 + """ + # Check if we support loading the component. + for key in checkpoint_state.keys(): + if key not in SUPPORTED_TYPES: + raise ValueError(f"{key} is not supported by UniversalCheckpointer") + + # Preprocess loading path + if path.startswith("local:"): + path = path.split(":")[1] + assert ":" not in path, f"{path} is not valid for universal checkpoint!" 
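For async_checkpoint, the checkpointer above keeps one IO worker pool and a list of outstanding write futures per component, and only blocks on them at the next save or at exit. The same pattern in miniature, with a thread pool standing in for the ProcessPoolExecutor and an illustrative output path:

    from concurrent.futures import ThreadPoolExecutor  # the checkpointer uses a ProcessPoolExecutor

    import torch

    def _write_shard(tensor, path):
        torch.save(tensor, path)                        # the slow, blocking part
        return path

    pool, pending = ThreadPoolExecutor(max_workers=1), []

    # "async save": hand the already-materialized CPU tensors to the worker and return.
    pending.append(pool.submit(_write_shard, torch.randn(4), "/tmp/shard_0.pt"))

    # Before the next save (or at exit), wait for the previous round to finish.
    for fut in pending:
        print("wrote", fut.result())
    pending.clear()
    pool.shutdown()
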
+ + # Start loading checkpoint + for key, value in checkpoint_state.items(): + if key == MODEL_STR: + # Get model path and state dictionary + model_path = os.path.join(path, MODEL_STR) + model_state = value.state_dict() + # Set process group + if broadcast_checkpoint: + model_load_process_group = gpc.get_group(ParallelMode.DATA) + else: + model_load_process_group = None + # Load model + load_state_dict( + state_dict=model_state, + path=model_path, + process_group=model_load_process_group, + coordinator_rank=0, + no_dist=False, + planner=cls.load_planner, + broadcast_tensors=broadcast_checkpoint, + ) + # Load back to model + value.load_state_dict(model_state) + elif key == OPTIMIZER_STR: + # Get optimizer path and state dictionary + optimizer_path = os.path.join(path, OPTIMIZER_STR) + optimizer_state = value.state_dict() + # Load optimizer state dictionary + load_state_dict( + state_dict=optimizer_state["sharded_optimizer_state"], + path=optimizer_path, + process_group=None, + coordinator_rank=0, + no_dist=False, + planner=cls.load_planner, + broadcast_tensors=False, + is_optimizer=True, + ) + # Load back to optimizer + global_optimizer_state = torch.load( + os.path.join(path, "global_optimizer_state.pt"), map_location=get_current_device() + ) + value.load_state_dict(optimizer_state["sharded_optimizer_state"], global_optimizer_state) + dist.barrier() + + @classmethod + def __cleanup(cls): + """ + Wait for all write futures to finish before exit, then do the cleanup works. + + WARNING: this method cannot be called by the users. + """ + cls.save_planner.clear_cache() + BaseCheckpointer._cleanup_futures() + + @classmethod + def _register_cleanup(cls): + atexit.register(UniversalCheckpointer.__cleanup) + + +UniversalCheckpointer._register_cleanup() diff --git a/internlm/checkpoint/universal_checkpoint/common.py b/internlm/checkpoint/universal_checkpoint/common.py new file mode 100644 index 000000000..011077a4f --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/common.py @@ -0,0 +1,110 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# adopted from https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/planner + +import collections +import dataclasses +from collections import OrderedDict +from typing import Any, Dict, Hashable, List, Optional, Tuple, TypeVar + +from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex +from torch.distributed.checkpoint.planner import SavePlan +from typing_extensions import Protocol, runtime_checkable + +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + + +MODEL_STR = "model" +OPTIMIZER_STR = "optimizer" +STATE_DICT_STR = "state_dict" +_MAX_CACHE_SIZE = 2 # model ckpt + optm ckpt + + +@runtime_checkable +class Stateful(Protocol): + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any], *args) -> None: + ... + + +T = TypeVar("T", bound=Stateful) +CheckpointState = Dict[str, T] + + +class PlanLRUCache: + """ + For saving cache. 
+ """ + + def __init__(self) -> None: + self._cache: OrderedDict[Hashable, Tuple[SavePlan, Metadata]] = OrderedDict() + self._capacity = _MAX_CACHE_SIZE + + def get(self, key: Hashable) -> Optional[Tuple[SavePlan, Metadata]]: + if key in self._cache: + return self._cache[key] + else: + return None + + def put(self, key: Hashable, plan_value: SavePlan, metadata_value: Metadata) -> None: + if key in self._cache: + self._cache.move_to_end(key, last=False) + else: + self._cache[key] = (plan_value, metadata_value) + if len(self._cache) > self._capacity: + self._cache.popitem() + + def clear(self) -> None: + self._cache.clear() + self._capacity = _MAX_CACHE_SIZE + + def __repr__(self) -> str: + return f"PlanLURCache(capacity: {self._capacity}, keys: {tuple(self._cache.keys())})" + + +def custom_dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: + """ + A function to remove duplicate tensors to write + when creating global writing plan for saving checkpoint + During the deduplication, + we balance the workloads for duplicated tensors + """ + key_to_plan: Dict[MetadataIndex, List[int]] = {} + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + key_to_plan.setdefault(write_item.index, []).append(plan_idx) + + replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} + # Remove duplicates by always keeping the first entry (Not balance). + # Compute the per-rank remove set. + plan_to_keys: Dict[int, List[MetadataIndex]] = {} + # Record the number of non-duplicated tensors assigned to each rank + assigned_work_load = collections.defaultdict(int) + for plan_idx, plan in enumerate(all_plans): + for write_item in plan.items: + if write_item.index not in replicated_items: + assigned_work_load[plan_idx] += 1 + + for key, plans in replicated_items.items(): + # For duplicated tensors, select the rank assigned with minimum number tensors so far + writer_id = min(plans, key=lambda k: assigned_work_load[k]) + assigned_work_load[writer_id] += 1 + for plan_idx in plans: + # If the rank is not writer rank, remove the key in the rank's plan + if plan_idx != writer_id: + plan_to_keys.setdefault(plan_idx, []).append(key) + # logger.info("Duplicate keys to remove: %s", plan_to_keys) + + for plan_idx, keys in plan_to_keys.items(): + # Key Set contains keys to remove + key_set = set(keys) + # rewrite items and remove elements + new_items = [write_item for write_item in all_plans[plan_idx].items if write_item.index not in key_set] + all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) + + return all_plans diff --git a/internlm/checkpoint/universal_checkpoint/filesystem.py b/internlm/checkpoint/universal_checkpoint/filesystem.py new file mode 100644 index 000000000..1129dd302 --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/filesystem.py @@ -0,0 +1,691 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# adopted from https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/storage + +import collections +import dataclasses +import io +import os +import pickle +from abc import ABC, abstractmethod +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, List, Tuple, Union, cast + +import torch +import torch.distributed as dist +from torch import Tensor +from torch._utils import _get_device_module +from torch.distributed._shard._utils import narrow_tensor_by_index +from torch.distributed.checkpoint.metadata import Metadata, 
MetadataIndex +from torch.distributed.checkpoint.planner import ( + LoadPlan, + LoadPlanner, + ReadItem, + SavePlan, + SavePlanner, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.storage import ( + StorageReader, + StorageWriter, + WriteResult, +) +from torch.distributed.checkpoint.utils import _create_file_view +from torch.futures import Future + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.train.pipeline import map_fqn_global_to_local +from internlm.utils.common import get_current_device +from internlm.utils.logger import get_logger + +from .mem_checkpoint import ( + copy_gpu_tensor_to_cpu_pinned_mem_pool, + deallocate_cpu_tensor_in_pinned_mem_pool, +) + +logger = get_logger(__file__) + +__all__ = [ + "FileSystemWriter", + "FileSystemReader", +] + + +@dataclass +class _StorageInfo: + """ + This is the per entry storage info + """ + + relative_path: str + offset: int + length: int + + +@dataclass +class _StoragePrefix: + prefix: str + + +DEFAULT_SUFFIX = ".distcp" + + +def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult: + return WriteResult(index=item.index, size_in_bytes=size_in_bytes, storage_data=storage_data) + + +class _TensorLoader(ABC): + """ + Abstract class + """ + + @abstractmethod + def add(self, fqn, size, obj): + pass + + @abstractmethod + def start_loading(self): + pass + + @abstractmethod + def values(self): + pass + + +class _SerialCpuLoader(_TensorLoader): + """ + Abstract class + currently no use + """ + + def __init__(self, resolve_fun, p2p_tensors_info=None): + self.resolve_fun = resolve_fun + self.items = [] + # For HybridZeroOptimizer, p2p_tensors_info is always none. + self.p2p_tensors_info = p2p_tensors_info + + def add(self, fqn, size, obj): + self.items.append((fqn, size, obj)) + + def start_loading(self): + pass + + def values(self): + for _, _, obj in self.items: + tensor = self.resolve_fun(obj).detach() + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor) + # Comment the original DCP code + # When dumping to pinned memory, + # the memory layout for tensor has been contiguous + # if tensor.storage().size() != tensor.numel(): + # tensor = tensor.clone() + yield ( + tensor, + obj, + ) + + +class _OverlappingCpuLoader(_TensorLoader): + """ + currently no use + """ + + def __init__( + self, + resolve_fun, + p2p_tensors_info=None, + stream=None, + inflight_threshhold=1_000_000, + ): + self.resolve_fun = resolve_fun + self.items = [] + self.inflight_threshhold = inflight_threshhold + self.in_flight_data = 0 + self.current_items: collections.deque = collections.deque() + self.idx = 0 + self.started = False + self.device_type = stream.device_type if stream else torch.device("cuda").type + self.device_module = _get_device_module(self.device_type) + # For HybridZeroOptimizer, p2p_tensors_info is always none. 
+ self.p2p_tensors_info = p2p_tensors_info + self.stream = stream or self.device_module.current_stream() + if self.stream != self.device_module.current_stream(): + self.stream.wait_stream(self.device_module.current_stream()) + + @property + def _done(self): + return self.idx >= len(self.items) + + def _drain(self): + drained = [] + if self.in_flight_data >= self.inflight_threshhold: + self.stream.synchronize() + while self.in_flight_data >= self.inflight_threshhold: + val = self.current_items.popleft() + self.in_flight_data -= val[0].numel() * val[0].element_size() + drained.append(val) + return drained + + def _refill(self): + with self.device_module.stream(self.stream): + while not self._done and self.in_flight_data < self.inflight_threshhold: + _, _, obj = self.items[self.idx] + self.idx += 1 + tensor = self.resolve_fun(obj).detach() + if tensor.device.type == self.device_type: + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) + + self.current_items.append( + ( + tensor, + obj, + ) + ) + self.in_flight_data += tensor.numel() * tensor.element_size() + + def _finish(self): + assert self._done + if len(self.current_items) > 0: + self.stream.synchronize() + return self.current_items + + def add(self, fqn, size, obj): + if self.started: + raise RuntimeError("cannot add items after loading started") + self.items.append((fqn, size, obj)) + + def start_loading(self): + if self.started: + return + self.started = True + self.items.sort(key=lambda x: x[1]) + self._refill() + + def values(self): + self.start_loading() + while not self._done: + drained = self._drain() + self._refill() + yield from drained + + yield from self._finish() + + +def _item_fqn(item: WriteItem) -> str: + return item.index.fqn + + +def _item_size(item: WriteItem) -> int: + size = 1 + assert item.tensor_data is not None + # can't use math.prod as PT needs to support older python + for s in item.tensor_data.size: + size *= s + + dtype = item.tensor_data.properties.dtype + return size * torch._utils._element_size(dtype) + + +def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]: + if bins == 1: + return [items] + + bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] + assert len(bytes_w) == 0, "currently no WriteItemType.BYTE_IO" + + buckets: List[List[WriteItem]] = [[] for _ in range(bins)] + bucket_sizes = [0 for _ in range(bins)] + + tensor_w.sort(key=_item_size, reverse=True) + + for i, wi in enumerate(bytes_w): + buckets[i % bins].append(wi) + + for wi in tensor_w: + idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0] + buckets[idx].append(wi) + bucket_sizes[idx] += _item_size(wi) + + return buckets + + +def _write_item(stream, data, write_item, storage_key): + offset = stream.tell() + assert isinstance(data, torch.Tensor) + assert data.device == torch.device("cpu") + torch.save(data, stream) + length = stream.tell() - offset + + return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) + + +def _write_files_from_queue( + file_name, + storage_key, + write_items, + planner: SavePlanner, + inflight_threshhold: int, + use_fsync: bool, + p2p_tensors_info=None, # currently no use +): + loader: _TensorLoader + + if torch.cuda.is_available() and inflight_threshhold > 0: + loader = _OverlappingCpuLoader( + lambda x: planner.resolve_data(x), # pylint: disable=W0108 + inflight_threshhold=inflight_threshhold, + p2p_tensors_info=p2p_tensors_info, + ) + 
else: + loader = _SerialCpuLoader( + lambda x: planner.resolve_data(x), p2p_tensors_info=p2p_tensors_info # pylint: disable=W0108 + ) + + tensor_w = list(write_items) + for write_item in tensor_w: + loader.add(_item_fqn(write_item), _item_size(write_item), write_item) + loader.start_loading() + + write_results = [] + stream = open(file_name, "wb") # pylint: disable=R1732 + + for tensor, write_item in loader.values(): + assert tensor.is_cpu + write_results.append(_write_item(stream, tensor, write_item, storage_key)) + # WARNING: Call deallocate_cpu_tensor_in_pinned_mem_pooltensor + # when the reference to CPU tensor goes to zero + # so the memory pool will reuse the memory if possbile + # Othterwise, the memory pool will allocate memory on the used memory range, + # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + + if use_fsync: + os.fsync(stream.fileno()) + + stream.close() + return write_results + + +def _serialize_tensor(tensor: torch.Tensor) -> bytes: + bio = io.BytesIO() + # NOTE: currently use torch.save() to do the serialization. + torch.save(tensor, bio) + return bio.getbuffer() + + +def _write_to_file(stream, content: bytes, write_item: WriteItem, storage_key: str) -> WriteResult: + offset = stream.tell() + stream.write(content) + length = stream.tell() - offset + return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) + + +def _write_files_per_proc_pipe( + file_path: Path, + storage_key: str, + byte_data_item: List[Tuple[io.BytesIO, WriteItem]], + tensor_data_item: List[Tuple[torch.Tensor, WriteItem]], + use_fsync: bool, +) -> List[WriteResult]: + write_futures = [] + write_results = [] + stream = open(file_path, "wb") # pylint: disable=R1732 + executor = ThreadPoolExecutor(max_workers=1) + # For byte data, directly write byte data. + assert len(byte_data_item) == 0 + for write_data, write_item in byte_data_item: + content = write_data.getbuffer() + write_futures.append( + executor.submit( + _write_to_file, + stream, + content, + write_item, + storage_key, + ) + ) + + # For tensor data, perform serialization in process then do saving in threadpool. + for write_data, write_item in tensor_data_item: + content = _serialize_tensor(write_data) + write_futures.append( + executor.submit( + _write_to_file, + stream, + content, + write_item, + storage_key, + ) + ) + + for fut in write_futures: + write_results.append(fut.result()) + if use_fsync: + os.fsync(stream.fileno()) + executor.shutdown(wait=False) + return write_results + + +class FileSystemWriter(StorageWriter): + """ + Basic implementation of StorageWriter using file IO. + + This implementation makes the following assumptions and simplifications: + + * The checkpoint path is an empty or non-existing directory. + * File creation is atomic + + The checkpoint consist of one file per write request plus + a `.metadata` file with the serialized metadata. + + """ + + def __init__( + self, + path: Union[str, os.PathLike], + single_file_per_rank: bool = True, + sync_files: bool = True, + worker_count: int = 1, + per_process_copy_ahead: int = 10_000_000, + ) -> None: + """ + Initialize the writer pointing to `path` + + Args: + path: directory where the checkpoint will be written to. + single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. + sync_files : force files to be synced to permanent storage. Default to True. + worker_count: Number of IO workers (processes) to use to write. 
Default to 1. + per_process_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. + + N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be + consistent in the case of a failure. + """ + super().__init__() + self.path = Path(path) + self.single_file_per_rank = single_file_per_rank + # self.single_file_per_rank = False + self.sync_files = sync_files + self.worker_count = worker_count + self.per_process_copy_ahead = per_process_copy_ahead + + def set_up_storage_writer(self, is_coordinator: bool) -> None: + pass + + def prepare_local_plan(self, plan: SavePlan, p2p_tensors_info=None) -> SavePlan: + self.path.mkdir(parents=True, exist_ok=True) + # For HybridZeroOptimizer, p2p_tensors_info is always none. + self.p2p_tensors_info = p2p_tensors_info + return plan + + def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]: # pylint: disable=W0237 + new_plans = [ + dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan) + ] + return new_plans + + def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], planner: SavePlanner, is_optimizer): + """ + First stage of saving, Perform Copy data to CPU (D2H). + + Args: + tasks: partitoned tasks for workers to conduct serialization and the actual saving. + planner: save planner used to resolve the bytes and tensor data. + + NOTE: Currently we do D2H synchronously. + """ + + byte_data_item_writes: List[List[Tuple[io.BytesIO, WriteItem]]] = [] + tensor_data_item_writes: List[List[Tuple[torch.Tensor, WriteItem]]] = [] + file_path_names: List[Tuple[Path, str]] = [] + + # Perform D2H in copy stream. + for task in tasks: + file_path, file_name, write_items = task + byte_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] + tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] + byte_data_item = [(planner.resolve_data(wi), wi) for wi in byte_w] + assert len(byte_data_item) == 0, "currentlu no WriteItemType.BYTE_IO" + tensor_data_item = [] + + item_list = [] + # Map global fqn to local fqn to search local model state data. + for item in tensor_w: + fqn = _item_fqn(item) + if not is_optimizer: + if fqn in map_fqn_global_to_local: # pylint: disable=R1715 + fqn = map_fqn_global_to_local[fqn] + + # Use tensor.clone() when saving slices to avoid unexpected memory issues. + tensor = planner.resolve_data(item, fqn).detach().clone() + tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) + tensor_data_item.append((tensor, item)) + item_list.append(item.index.fqn) + byte_data_item_writes.append(byte_data_item) + tensor_data_item_writes.append(tensor_data_item) + file_path_names.append((file_path, file_name)) + + # Deallocate pinned memory. + # NOTE: when prepare_write_data() is called next time, make sure the previous save event is completed. + # Otherwise, tensors in pinned memory pool may be overwritten. 
+ for tensor_data_item in tensor_data_item_writes: + for tensor, _ in tensor_data_item: + assert tensor.is_cpu + deallocate_cpu_tensor_in_pinned_mem_pool(tensor) + + return byte_data_item_writes, tensor_data_item_writes, file_path_names + + def write_data( + self, plan: SavePlan, planner: SavePlanner, async_io: bool = False, io_workers=False, is_optimizer=False + ) -> Future[List[WriteResult]]: + storage_plan: _StoragePrefix = plan.storage_data + file_count = 0 + + def gen_file(): + nonlocal file_count + file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" + file_count += 1 + return file_name + + tasks: List[Tuple[Path, str, List[WriteItem]]] = [] + # Generate K tasks where K is the number of worker_count. + if self.single_file_per_rank: + for bucket in _split_by_size_and_type(self.worker_count, plan.items): + file_name = gen_file() + tasks.append((self.path / file_name, file_name, bucket)) + # Generate K tasks where K is the number of write items. + else: + for item in plan.items: + file_name = gen_file() + tasks.append((self.path / file_name, file_name, [item])) + + futures = [] + if not io_workers: + executor = ProcessPoolExecutor(max_workers=self.worker_count) + else: + executor = io_workers + + # ProcessPool VERSION. + if isinstance(executor, ProcessPoolExecutor): + byte_data_item_writes, tensor_data_item_writes, file_path_names = self.prepare_write_data( + tasks, planner, is_optimizer + ) + for byte_data_item, tensor_data_item, file_path_name in zip( + byte_data_item_writes, tensor_data_item_writes, file_path_names + ): + file_path, storage_key = file_path_name + worker_args = (file_path, storage_key, byte_data_item, tensor_data_item, self.sync_files) + futures.append(executor.submit(_write_files_per_proc_pipe, *worker_args)) + if async_io: + return futures + else: + for fut in futures: + fut.result() + # fut.wait() + return futures + else: + # ThreadPool VERSION. + assert False, "unavailable version" + for task in tasks: + futures.append( + executor.submit( + _write_files_from_queue, + *task, + planner, + self.per_process_copy_ahead, + self.sync_files, + ) + ) + if async_io: + return futures + else: + for fut in futures: + fut.result() + return futures + + def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: + storage_md = dict() + for wr_list in results: + storage_md.update({wr.index: wr.storage_data for wr in wr_list}) + metadata.storage_data = storage_md + with (self.path / ".metadata.tmp").open("wb") as metadata_file: + pickle.dump(metadata, metadata_file) + os.fsync(metadata_file.fileno()) + + (self.path / ".metadata.tmp").rename(self.path / ".metadata") + + def reset(self, checkpoint_id): + pass + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id): + pass + + +class FileSystemReader(StorageReader): + """ + Basic implementation of StorageReader using file IO. 
+ """ + + def __init__( + self, + path: Union[str, os.PathLike], + broadcast_tensors=False, + data_parallel_process_group=None, + ) -> None: + super().__init__() + self.path = path + self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict() + self.broadcast_tensors = broadcast_tensors + self.data_parallel_process_group = data_parallel_process_group + + # If broadcast_tensors is enabled, the data_parallel_process_group is not none + if self.broadcast_tensors: + assert self.data_parallel_process_group + + def _slice_file(self, file, sinfo: _StorageInfo): + return _create_file_view(file, sinfo.offset, sinfo.length) + + def _get_file_path(self, relative_path): + file_path = os.path.join(self.path, relative_path) + return file_path + + def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: + # group requests by file + per_file: Dict[str, List[ReadItem]] = dict() + for read_item in plan.items: + item_md = self.storage_data[read_item.storage_index] + path = item_md.relative_path + per_file.setdefault(path, []).append(read_item) + + # If broadcasting model tensors is enabled, + # let processes with dp_rank=0 load models and broadcast them to other processes + if self.broadcast_tensors: + self.read_data_with_broadcast(per_file=per_file, planner=planner) + else: + # Otherwise, let all ranks load tensors from files + self.read_from_files(per_file=per_file, planner=planner) + + fut: Future = Future() + fut.set_result(None) + + return fut + + def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): + for relative_path, reqs in per_file.items(): + file_path = self._get_file_path(relative_path) + with open(file_path, "rb") as file: + reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) + for req in reqs: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(file, item_md) + tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) # att + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + + assert ( + target_tensor.size() == tensor.size() + ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + + target_tensor.copy_(tensor) + planner.commit_tensor(req, target_tensor) + + def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): + for relative_path, reqs in per_file.items(): + if gpc.get_local_rank(ParallelMode.DATA) == 0: + file_path = self._get_file_path(relative_path) + file = open(file_path, "rb") # pylint: disable=R1732 + dist.barrier(self.data_parallel_process_group) + reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) + for req in reqs: + if gpc.get_local_rank(ParallelMode.DATA) == 0: + item_md = self.storage_data[req.storage_index] + file_slice = self._slice_file(file, item_md) + object_list = [cast(Tensor, torch.load(file_slice, map_location="cuda"))] + else: + object_list = [None] + dist.broadcast_object_list( + object_list, + src=dist.get_global_rank(self.data_parallel_process_group, 0), + group=self.data_parallel_process_group, + device=get_current_device(), + ) + tensor = object_list[0].cpu() + tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) + target_tensor = planner.resolve_tensor(req).detach() + + assert ( + target_tensor.size() == tensor.size() + ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" + target_tensor.copy_(tensor) + 
planner.commit_tensor(req, target_tensor) + + # Implementing the abstract function in StorageReader + def read_metadata(self) -> Metadata: + metadata_path = self._get_file_path(".metadata") + with open(metadata_path, "rb") as metadata_file: + metadata = pickle.load(metadata_file) + return metadata + + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + self.storage_data = metadata.storage_data + assert self.storage_data is not None + + def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: + return plan + + def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: # pylint: disable=W0237 + return global_plan + + def reset(self, checkpoint_id): + pass + + @classmethod + def validate_checkpoint_id(cls, checkpoint_id): + pass diff --git a/internlm/checkpoint/universal_checkpoint/load_state_dict.py b/internlm/checkpoint/universal_checkpoint/load_state_dict.py new file mode 100644 index 000000000..65c9f413d --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/load_state_dict.py @@ -0,0 +1,85 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. +################################################################################ + +from typing import Optional + +import torch.distributed as dist +from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner +from torch.distributed.checkpoint.planner import LoadPlanner +from torch.distributed.checkpoint.utils import _DistWrapper + +from internlm.utils.logger import get_logger + +from .filesystem import FileSystemReader +from .planner import UniversalLoadPlanner + +logger = get_logger(__file__) + +META_DATA_FILE = ".metadata" + + +def load_state_dict( + state_dict: dict, + path: str, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[LoadPlanner] = None, + broadcast_tensors=False, + is_optimizer=False, +) -> None: + """ + Loads a distributed ``state_dict`` in SPMD style. Fix sub-group storage. + """ + storage_reader = FileSystemReader( + path, + broadcast_tensors=broadcast_tensors, + data_parallel_process_group=process_group, + ) + + # Step 0: create distributed world based on process group and coordinator rank + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if process_group: + distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) + if planner is None: + planner = DefaultLoadPlanner() + + # Step 1: all processes create local read plan, + # then coordinator gathers all local plans and create global plan. 
+ def local_step(): + assert planner is not None + metadata = storage_reader.read_metadata() + planner.set_up_planner(state_dict, metadata, distW.is_coordinator) + storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) + + local_plan = planner.create_local_plan(is_optimizer=is_optimizer) + local_plan = storage_reader.prepare_local_plan(local_plan) + return local_plan + + def global_step(all_local_plans): + assert planner is not None + all_local_plans = planner.create_global_plan(all_local_plans) + all_local_plans = storage_reader.prepare_global_plan(all_local_plans) + return all_local_plans + + if isinstance(planner, UniversalLoadPlanner): + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + raise AssertionError("Unsupported planner for saving checkpoint") + + # Step 2: all processes read data from the given path + def read_data(): # pylint: disable=R1711 + assert planner is not None + final_local_plan = planner.finish_plan(central_plan) + all_reads = storage_reader.read_data(final_local_plan, planner) + all_reads.wait() + return None + + _ = distW.all_gather("read", read_data) diff --git a/internlm/checkpoint/universal_checkpoint/mem_checkpoint.py b/internlm/checkpoint/universal_checkpoint/mem_checkpoint.py new file mode 100644 index 000000000..a5c4923eb --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/mem_checkpoint.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# adopted from https://github.com/volcengine/veScale/tree/main/vescale/checkpoint/utilities + +import io +import pickle +import threading +from typing import DefaultDict + +import torch + +from internlm.utils.logger import get_logger + +logger = get_logger(__file__) + +if hasattr(torch.storage, "TypedStorage"): + TypedStorage = torch.storage.TypedStorage +elif hasattr(torch.storage, "_TypedStorage"): + TypedStorage = torch.storage._TypedStorage + +# TypedStorage changes in pytorch 2. +if torch.__version__ >= "2": + + def untyped_storage(o): + return o.untyped_storage() + + def location_caster(o): + return o + +elif torch.__version__ >= "1.11": + + def untyped_storage(o): + return o.storage()._storage + + def location_caster(o): + return o._storage if isinstance(o, TypedStorage) else o + + +try: + lib = torch.cuda.cudart() +except Exception: + lib = None + + +class PinnedStoragePool: # pylint: disable=C0115 + def __init__(self): + self._l = threading.Lock() + self._m = DefaultDict(set) + + def allocate(self, nbytes: int): + with self._l: + # We don't really need storage to have the exact size. So in theory we can find a + # bigger storage that may suit here. But so far we keep everything simple here. 
+ s = self._m[nbytes] + if not s: + t = torch.empty([nbytes], dtype=torch.uint8) + t = t.share_memory_() + if lib is not None and nbytes != 0: + err = lib.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) + assert err == 0, err + storage = untyped_storage(t) + s.add(storage) + return s.pop() + + def deallocate(self, s): + # WARNING: Call deallocate when the reference to CPU tensor goes to zero + # so the memory pool will reuse the memory if possbile + # Othterwise, the memory pool will allocate memory on the used memory range, + # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered + with self._l: + self._m[s.nbytes()].add(s) + + +GLOBAL_POOL = PinnedStoragePool() + +TID = threading.get_ident() + + +def copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor: torch.Tensor, non_blocking=False) -> torch.Tensor: + """ + Copy a tensor on GPU to pinned memory pool (host CPU memory). + The input tensor will not be modified + Args: + tensor: a tensor on cuda device + Return: + a tensor on cpu, whose data is the same as input tensor + """ + m = {} + _old_warning = getattr(torch.storage, "_warn_typed_storage_removal", None) + torch.storage._warn_typed_storage_removal = lambda *args, **kwags: None + + def persistent_id(o): + if torch.is_storage(o) or isinstance(o, TypedStorage): + storage = o + if storage._cdata in m: + return storage._cdata + if storage.device.type != "cpu": + copied = GLOBAL_POOL.allocate(storage.nbytes()) + copied.copy_(storage, non_blocking=non_blocking) + if isinstance(storage, TypedStorage): + copied = storage._new_wrapped_storage(copied) + else: + copied = storage.clone() + m[storage._cdata] = copied + return storage._cdata + return + + b = io.BytesIO() + p = pickle.Pickler(b) + p.persistent_id = persistent_id + p.dump(tensor) + b.seek(0) + up = pickle.Unpickler(b) + up.persistent_load = lambda i: m[i] + cpu_tensor = up.load() + """ + assert type(tensor) == torch.Tensor + storage_obj = tensor.storage() + cpu_storage = GLOBAL_POOL.allocate(storage_obj.nbytes()) + + cpu_storage.copy_(storage_obj, non_blocking=non_blocking) + cpu_tensor = torch.tensor(cpu_storage) + """ + torch.storage._warn_typed_storage_removal = _old_warning + return cpu_tensor + + +def deallocate_cpu_tensor_in_pinned_mem_pool(tensor: torch.Tensor): + "Deallocate CPU tensor in the global pinned memory pool" + GLOBAL_POOL.deallocate(tensor.untyped_storage()) diff --git a/internlm/checkpoint/universal_checkpoint/planner.py b/internlm/checkpoint/universal_checkpoint/planner.py new file mode 100644 index 000000000..b9dc5bece --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/planner.py @@ -0,0 +1,150 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# adopted from https://github.com/volcengine/veScale/tree/main/vescale/checkpoint/planner/vescale + +import dataclasses +import io +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch.distributed.checkpoint.default_planner import ( + DefaultLoadPlanner, + DefaultSavePlanner, +) +from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex +from torch.distributed.checkpoint.planner import ( + LoadPlan, + ReadItem, + SavePlan, + WriteItem, + WriteItemType, +) + +from internlm.train.pipeline import map_fqn_local_to_global +from internlm.utils.logger import get_logger + +from .common import STATE_DICT_STR, PlanLRUCache, custom_dedup_tensors +from .planner_helpers import ( + _create_read_items, + _create_write_items, + find_state_dict_object, +) + +logger = get_logger(__file__) +__all__ = [ + 
"UniversalSavePlanner", + "UniversalLoadPlanner", + "create_default_local_load_plan", + "create_default_local_save_plan", +] + + +class UniversalLoadPlanner(DefaultLoadPlanner): + """ + A planner class for loading checkpoint using PyTorch DCP + """ + + def __init__(self): + super().__init__() + + def create_local_plan(self, is_optimizer=False) -> LoadPlan: + return create_default_local_load_plan(self.state_dict, self.metadata, is_optimizer) + + def resolve_tensor(self, read_item: ReadItem): + tensor = self.lookup_tensor(read_item.dest_index) + return self.transform_tensor(read_item, tensor) + + def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: + """ + This is an extension from the planner interface to make it easy to extend the default planner + """ + return find_state_dict_object(self.state_dict, index) + + +def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadata, is_optimizer) -> LoadPlan: + """ + A function for creating local loading plan for loading checkpoint + """ + requests = [] + for fqn, obj in state_dict.items(): + global_fqn = fqn + # For local model state_dict, the default is to use local fqn. + # We need to map it to global fqn for pipeline parallel. + # As saving ckpt is always using global fqn. + if not is_optimizer: + if fqn in map_fqn_local_to_global: # pylint: disable=R1715 + global_fqn = map_fqn_local_to_global[fqn] + + md = metadata.state_dict_metadata[global_fqn] + item = _create_read_items(fqn, global_fqn, md, obj) + requests += item + return LoadPlan(requests) + + +class UniversalSavePlanner(DefaultSavePlanner): + """ + A planner class for saving checkpoint using PyTorch DCP + """ + + def __init__(self): + super().__init__() + self._plan_cache = PlanLRUCache() + + def resolve_data(self, write_item: WriteItem, fqn=None) -> Union[torch.Tensor, io.BytesIO]: + assert write_item.type != WriteItemType.BYTE_IO + local_object = self.lookup_object(write_item.index, fqn) + return self.transform_object(write_item, local_object) + + def create_local_plan(self, is_optimizer=False) -> Tuple[SavePlan, None]: + plan, p2p_tensors_info = create_default_local_save_plan(self.state_dict, self.is_coordinator, is_optimizer) + if self.flatten_state_dict: + plan = dataclasses.replace(plan, planner_data=self.mappings) + self.plan = plan + return self.plan, p2p_tensors_info + + def lookup_object(self, index: MetadataIndex, fqn=None) -> Any: + return find_state_dict_object(self.state_dict, index, fqn) + + def lookup_plan_meta(self) -> Optional[Tuple[SavePlan, Metadata]]: + if not hasattr(self, STATE_DICT_STR): + return None + else: + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator)) + return self._plan_cache.get(plan_key) + + def cache_plan_meta(self, new_plan: SavePlan, new_metadata: Metadata) -> None: + plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator)) + self._plan_cache.put(plan_key, new_plan, new_metadata) + + def clear_cache(self) -> None: + self._plan_cache.clear() + + def create_dedup_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + # Disable DCP's dedup replicated tensors function + self.dedup_replicated_tensors = False + rst_value = super().create_global_plan(all_plans) + return rst_value + + def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: + # Disable DCP's dedup replicated tensors function + self.dedup_replicated_tensors = False + # Use customized deduplicate function for load balance + all_plans = 
custom_dedup_tensors(all_plans) + rst_value = super().create_global_plan(all_plans) + return rst_value + + +def create_default_local_save_plan( + state_dict: Dict[str, Any], is_coordinator: bool, is_optimizer=False # pylint: disable=W0613 +) -> SavePlan: + """ + A function for creating local saving plan for saving checkpoint. + """ + requests = [] + for fqn, obj in state_dict.items(): + assert isinstance(obj, (torch.Tensor)) + item = _create_write_items(fqn, obj, is_optimizer=is_optimizer) + requests += item + + return SavePlan(requests), None diff --git a/internlm/checkpoint/universal_checkpoint/planner_helpers.py b/internlm/checkpoint/universal_checkpoint/planner_helpers.py new file mode 100644 index 000000000..076d21bce --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/planner_helpers.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +# adopted from https://github.com/volcengine/veScale/tree/main/vescale/checkpoint/planner/vescale + +from typing import Any, List + +import torch +from torch.distributed._shard.sharded_tensor import TensorProperties +from torch.distributed.checkpoint.metadata import ( + STORAGE_TYPES, + ChunkStorageMetadata, + MetadataIndex, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + LoadItemType, + ReadItem, + TensorWriteData, + WriteItem, + WriteItemType, +) +from torch.distributed.checkpoint.resharding import ( + _check_shard_metadata_pair_overlap, + _shards_get_overlap_region_wrt_saved_tensor, +) + +from internlm.train.pipeline import map_fqn_local_to_global, map_layer_attr + + +def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor, is_optimizer=False) -> WriteItem: + offsets = torch.Size([0] * len(tensor.size())) + size = tensor.size() + + # We always save ckpt using global fqn. + # For optim state_dict, it originally uses global fqn. + if not is_optimizer: + if fqn in map_fqn_local_to_global: # pylint: disable=R1715 + fqn = map_fqn_local_to_global[fqn] + + map_fqn = fqn + if map_fqn not in map_layer_attr: + # Deal with exp_avg and exp_avg_sq in base optim + map_fqn = fqn.rsplit(".", 1)[0] + assert map_fqn in map_layer_attr + offsets = torch.Size(map_layer_attr[map_fqn]["offset"]) + size = torch.Size(map_layer_attr[map_fqn]["complete_size"]) + + return WriteItem( + index=MetadataIndex(fqn, offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), + properties=TensorProperties.create_from_tensor(tensor), + size=size, + ), + ) + + +def _create_write_items(fqn: str, object: Any, is_optimizer=False) -> List[WriteItem]: # pylint: disable=W0622 + assert isinstance(object, torch.Tensor) + return [_create_write_item_for_tensor(fqn, object, is_optimizer=is_optimizer)] + + +def _create_read_item_for_tensor(dest_index, dest_offsets, storage_index, storage_offsets, lengths): + return ReadItem( + type=LoadItemType.TENSOR, + dest_index=dest_index, + dest_offsets=torch.Size(dest_offsets), + storage_index=storage_index, + storage_offsets=torch.Size(storage_offsets), + lengths=torch.Size(lengths), + ) + + +def create_read_items_for_chunk_list( + fqn: str, + global_fqn: str, + checkpoint_md: TensorStorageMetadata, + local_chunks: List[ChunkStorageMetadata], +) -> List[ReadItem]: + """ + Creates a list of ``ReadItem`` based on the checkpoint and local chunks. + + This applies the resharding algorithm and computes the reads needed + to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. 
+ + Args: + fqn (str): The local state_dict FQN to pass to ``ReadItem``. + global_fqn (str): The global FQN in the checkpoint. + checkpoint_md (TensorStorageMetadata): metadata for a given tensor + from a checkpoint. + local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be + loaded. + + Returns: + A list of ``ReadItem`` that will satisfy all input chunks. + """ + read_items = [] + # this is a naive quadratic algo that can be optimized later + for idx, shard in enumerate(local_chunks): + for storage_idx, storage_md in enumerate(checkpoint_md.chunks): + if not _check_shard_metadata_pair_overlap(shard, storage_md): + continue + + storage_offsets = [] + dest_offsets = [] + lengths = [] + for ( + _, + offset_for_saved_tensor, + offset_for_current_tensor, + length, + ) in _shards_get_overlap_region_wrt_saved_tensor(saved_shard=storage_md, current_shard=shard): + storage_offsets.append(offset_for_saved_tensor) + dest_offsets.append(offset_for_current_tensor) + lengths.append(length) + + read_items.append( + _create_read_item_for_tensor( + dest_index=MetadataIndex(fqn, shard.offsets, idx), + dest_offsets=dest_offsets, + storage_index=MetadataIndex(global_fqn, storage_md.offsets, storage_idx), + storage_offsets=storage_offsets, + lengths=lengths, + ) + ) + return read_items + + +def _create_chunk_from_tensor(global_fqn, tensor: torch.Tensor) -> ChunkStorageMetadata: + if global_fqn not in map_layer_attr: + # Deal with exp_avg and exp_avg_sq in base optim + global_fqn = global_fqn.rsplit(".", 1)[0] + assert global_fqn in map_layer_attr, f"{global_fqn}" + offsets = torch.Size(map_layer_attr[global_fqn]["offset"]) + + return ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()) + + +def _create_read_items(fqn: str, global_fqn, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: + assert isinstance(obj, torch.Tensor) + local_chunks = [_create_chunk_from_tensor(global_fqn, obj)] + return create_read_items_for_chunk_list(fqn, global_fqn, md, local_chunks) + + +def find_state_dict_object(state_dict: dict, index: MetadataIndex, fqn=None) -> Any: + # Called when real writing happened + # The filesystem writer calls resolve_data , then it will + # call find_state_dict_object + if fqn is None: + fqn = index.fqn + if fqn not in state_dict: + raise ValueError(f"Could not find FQN: '{fqn}'") + obj = state_dict[fqn] + + return obj diff --git a/internlm/checkpoint/universal_checkpoint/save_state_dict.py b/internlm/checkpoint/universal_checkpoint/save_state_dict.py new file mode 100644 index 000000000..2646a9281 --- /dev/null +++ b/internlm/checkpoint/universal_checkpoint/save_state_dict.py @@ -0,0 +1,163 @@ +################################################################################ +# Copyright (c) Meta Platforms, Inc. and affiliates +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +################################################################################ +# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
+################################################################################ + +import os +import pickle +import time +from concurrent.futures import Future +from typing import List, Optional, Tuple + +import torch.distributed as dist +from torch.distributed.checkpoint.default_planner import DefaultSavePlanner +from torch.distributed.checkpoint.metadata import Metadata +from torch.distributed.checkpoint.planner import SavePlanner +from torch.distributed.checkpoint.storage import WriteResult +from torch.distributed.checkpoint.utils import _DistWrapper + +from internlm.core.context import global_context as gpc +from internlm.utils.logger import get_logger + +from .filesystem import FileSystemWriter +from .planner import UniversalSavePlanner + +logger = get_logger(__file__) + + +def save_state_dict( + state_dict: dict, + path: str, + process_group: Optional[dist.ProcessGroup] = None, + coordinator_rank: int = 0, + no_dist: bool = False, + planner: Optional[SavePlanner] = None, + async_io: bool = True, + last_write_futures: Future[List[WriteResult]] = None, + io_workers=None, + is_optimizer=False, +) -> Tuple[Metadata, Future[List[WriteResult]]]: + """ + Saves a model in SPMD style. Fix sub-group storage. + """ + # Step 0: create distributed world based on process group and coordinator rank + distW = _DistWrapper(process_group, not no_dist, coordinator_rank) + if process_group: + distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) + if planner is None: + planner = DefaultSavePlanner() + assert planner is not None + + global_metatadata = None + + storage_writer = FileSystemWriter(path) + + # Step 1: all processes create local write plan, + # then coordinator gathers all local plans and create global plan. + def local_step(): + if isinstance(planner, UniversalSavePlanner): + local_plan, p2p_tensors_info = planner.create_local_plan(is_optimizer=is_optimizer) + local_plan = storage_writer.prepare_local_plan(local_plan, p2p_tensors_info) + else: + raise AssertionError("Unsupported planner for planning") + return local_plan + + def global_step(all_local_plans): + nonlocal global_metatadata + all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) + all_local_plans = storage_writer.prepare_global_plan(all_local_plans) + return all_local_plans + + # Step 2: all processes write data from GPUs to pinned memory pool, then dump to local path + # then coordinator write meta-data to local path. + def write_data(async_io: bool = False, io_workers=io_workers): + final_local_plan = planner.finish_plan(central_plan) + if isinstance(planner, UniversalSavePlanner): + # Use pinned memory pool and mult_processing for dumping ckpt to local directory efficiently + all_write_futures = storage_writer.write_data(final_local_plan, planner, async_io, io_workers, is_optimizer) + if async_io: + return all_write_futures + else: + # Gather write results. + values = [] + for fut in all_write_futures: + # values += fut.get() + values += fut.result() + return values + else: + raise AssertionError("Unsupported planner for writing data") + + def finish_checkpoint(all_results): + assert global_metatadata is not None, f"rank: {distW.get_rank()} has no global_metadata" + storage_writer.finish(metadata=global_metatadata, results=all_results) + return global_metatadata + + assert planner is not None + planner.set_up_planner(state_dict, distW.is_coordinator) + storage_writer.set_up_storage_writer(distW.is_coordinator) + + # Wait for last write futures to finish. 
+ if last_write_futures: + if gpc.is_rank_for_log(): + logger.info("Start waiting for last write events.") + last_write_start_time = time.time() + for fut in last_write_futures: + fut.result() + last_write_time = time.time() - last_write_start_time + if gpc.is_rank_for_log(): + logger.info(f"Finish waiting for last write events. Time cost: {last_write_time}s") + + # Each worker bypass the `reduce_scatter()` and `all_reduce()` if finding cached central_plan and metadata. + # NOTE: it fails when the plans of partial workers change while others keep the same. + cached_data = None + if isinstance(planner, UniversalSavePlanner): + cached_data = planner.lookup_plan_meta() + if cached_data: + if gpc.is_rank_for_log(): + logger.info("Plan cache hit. Reuse existing plan") + central_plan, _ = cached_data + # _ = local_step() # backup for origin + else: + if gpc.is_rank_for_log(): + logger.info("Plan cache miss. The model/optimizer appears for the first time.") + central_plan = distW.reduce_scatter("plan", local_step, global_step) + else: + raise AssertionError("Unsupported planner for saving checkpoint") + + write_futures = [] + if isinstance(planner, UniversalSavePlanner): + if cached_data: + if gpc.is_rank_for_log(): + logger.info("Metdata cache hit. Reuse existing metadata") + _, final_storage_metadata = cached_data + write_results = write_data(async_io=async_io) + # Be sure to write cache metadata to .metadata file + # Otherwises only the first checkpoint has .metadata + # which leads to error when loading other checkpoints + if distW.is_coordinator: + with (storage_writer.path / ".metadata.tmp").open("wb") as metadata_file: + pickle.dump(final_storage_metadata, metadata_file) + os.fsync(metadata_file.fileno()) + + (storage_writer.path / ".metadata.tmp").rename(storage_writer.path / ".metadata") + + if async_io: + write_futures = write_results + else: + if gpc.is_rank_for_log(): + logger.info("Metadata cache miss. The model/optimizer appears for the first time.") + # First time do synchronous storing to get final_storage_metatdata. + # Determine which communication topology to use. + final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint) + assert central_plan is not None + assert final_storage_metadata is not None + planner.cache_plan_meta(central_plan, final_storage_metadata) # attn + else: + raise AssertionError("Unsupported planner for writing data and metadata") + + return final_storage_metadata, write_futures diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index 8a7a381b6..c9e9d4ab8 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -45,13 +45,10 @@ from internlm.utils.utils import DataType from internlm.utils.writer import Writer - - -import os - # global llm logger logger = logging.getLogger(__file__) + class TrainerBuilder(Trainer): """ A class for building and managing InternEvo training workflow. 
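As a concrete illustration of the plan cache consulted by save_state_dict() above: UniversalSavePlanner keys the cache on a hash of the state_dict's FQN set plus the coordinator flag, so a cached central plan and metadata are reused only while every rank keeps exactly the same set of tensors, which is why the comment above warns that the cache breaks when only some workers' plans change. A minimal sketch, assuming this patch is applied so the new module import resolves; the string stubs stand in for a real SavePlan and Metadata:

    from internlm.checkpoint.universal_checkpoint.common import PlanLRUCache

    cache = PlanLRUCache()

    # Same keying scheme as UniversalSavePlanner.lookup_plan_meta()/cache_plan_meta():
    # the key sees only the FQN set and is_coordinator, never tensor values or shapes.
    plan_key = hash((frozenset({"layers.0.attention.wqkv.weight", "norm.weight"}), True))

    cache.put(plan_key, plan_value="central_plan_stub", metadata_value="metadata_stub")
    assert cache.get(plan_key) == ("central_plan_stub", "metadata_stub")

    # A different FQN set (e.g. a parameter added or removed) misses the cache.
    assert cache.get(hash((frozenset({"norm.weight"}), True))) is None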
@@ -129,15 +126,7 @@ def __init__( # initialize checkpoint manager and try resume training self.ckpt_manager = self._initialize_checkpoint_manager(model, optimizer, lr_scheduler, train_dl, config_lines) - self.ckpt_manager.try_resume_training(train_state, self.current_time) - - - - # from internlm.checkpoint.vescale.api import load as vescale_load - # checkpoint_state = {"model": model.model} - # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test/iter_10", checkpoint_state, broadcast_checkpoint=False) - # print("finish loading", flush=True) # initialize customed llm writer self.writer = self._initialize_writer(train_state, config_lines) @@ -165,7 +154,6 @@ def __init__( self._set_attributes( kwargs["profiling"], train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator ) - super().__init__(engine, scheduler) def _setup_time_and_logging(self) -> str: @@ -269,39 +257,18 @@ def fit(self): """ self.train() train_iter = iter(self.train_dl) - - from internlm.checkpoint.vescale.api import load as vescale_load - checkpoint_state = {"model": self.ckpt_manager.model} - checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} - checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} - # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test_dp2_tp2_pp2/20", checkpoint_state, broadcast_checkpoint=False) - - - print("finish loading", flush=True) - print(f"rank_log: rank={gpc.get_global_rank()}, dp={gpc.get_local_rank(ParallelMode.DATA)}, pp={gpc.get_local_rank(ParallelMode.PIPELINE)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}", flush=True) - print(f"train_state: {self.train_state}", flush=True) with initialize_llm_profile(profiling=self.profiling, start_time=self.current_time) as prof: gc.disable() for batch_count in range(self.train_state.batch_count, gpc.config.data.total_steps): - # print(f"norm_weight {gpc.get_global_rank()}: {self.ckpt_manager.model.norm.weight.shape}, {self.ckpt_manager.model.norm.weight}") - if batch_count == 35: - break if self._process_batch(batch_count, train_iter, prof): break self.ckpt_manager.wait_async_upload_finish() - + def _process_batch(self, batch_count: int, train_iter, prof) -> bool: empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval) start_time = time.time() timer("one-batch").start() - - # data_path = "/mnt/petrelfs/lijiaxing/InternEvo/test_data_batch.pt" - # batch = torch.load(data_path) - - # if batch[0].get("type_ids", None) is not None: - # self.metric.set_current_type_ids(type_ids=batch[0].pop("type_ids", None)) - batch, train_iter = self._load_and_prepare_batch(batch_count, train_iter) if self.batch_skipper(batch_count): @@ -309,11 +276,6 @@ def _process_batch(self, batch_count: int, train_iter, prof) -> bool: logger.info(f"Skip batch count:`{batch_count}`...") timer("one-batch").stop() return False - - # if batch_count <= 10: - # if gpc.is_rank_for_log(): - # print(f"skip {batch_count}", flush=True) - # return False if gpc.is_rank_for_log(): print(f"start trainging {batch_count}", flush=True) @@ -328,72 +290,9 @@ def _process_batch(self, batch_count: int, train_iter, prof) -> bool: if self._should_evaluate(): self._evaluate() - # print(f"self.ckpt_manager.hybrid {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.state_dict()}", flush=True) - # 
print(f"self.ckpt_manager.optimizer {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.optim.state_dict()}", flush=True) - # print(f"self.ckpt_manager.grad_scaler {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.grad_scaler.state_dict()}", flush=True) - - # checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} - # checkpoint_state = {"model": self.ckpt_manager.model} - # checkpoint_state = {"optimizer": self.ckpt_manager.optimizer} - # assert 'output.weight' in self.ckpt_manager.model.state_dict() - fqn = 'layers.14.feed_forward.w2.weight' - # print(f"cp_length {gpc.get_local_rank(ParallelMode.ZERO1)}: {len(self.ckpt_manager.optimizer.state_dict()['unflatten_fp32_weights'])}, {len(self.ckpt_manager.model.state_dict())}") - # if fqn in self.ckpt_manager.optimizer.state_dict()['unflatten_fp32_weights']: - # tensor1 = self.ckpt_manager.optimizer.state_dict()['unflatten_fp32_weights'][fqn] - # tensor2 = self.ckpt_manager.model.state_dict()[fqn] - # print(f"self.ckpt_manager.optimizer.state_dict() {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)} {gpc.get_local_rank(ParallelMode.ZERO1)}: {tensor1.dtype}, {tensor1.shape}, {tensor1}", flush=True) - # print(f"self.ckpt_manager.model.state_dict() {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)} {gpc.get_local_rank(ParallelMode.ZERO1)}: {tensor2.shape}, {tensor2}", flush=True) - - # print(f"model_state {dist.get_rank()}: {self.ckpt_manager.model.state_dict()['layers.0.attention.wo.weight'].shape}", flush=True) - # print(f"optimizer_state {dist.get_rank()}: {self.ckpt_manager.optimizer.state_dict()['master_current_weights'][0][0].shape}", flush=True) - # print(f"self.ckpt_manager.optimizer {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.ZERO1)}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {self.ckpt_manager.optimizer.optim.state_dict()}", flush=True) - - # if batch_count == 1: - # from internlm.checkpoint.vescale.api import load as vescale_load - # checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} - # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test_dp2_tp2_pp2/iter_20", checkpoint_state, broadcast_checkpoint=False) - - if batch_count == 20: - from internlm.checkpoint.vescale.devicemesh_api import VESCALE_DEVICE_MESH - from internlm.checkpoint.vescale.device_mesh import init_device_mesh - from internlm.checkpoint.vescale.api import save as vescale_save - from internlm.checkpoint.vescale.api import load as vescale_load - from internlm.checkpoint.deepspeed.save_checkpoint import save_checkpoint - # device_mesh = init_device_mesh( - # "cuda", - # ( - # 2, - # 4, - # ), - # mesh_dim_names=("DP", "TP"), - # ) - - # VESCALE_DEVICE_MESH._GLOBAL_MESH = device_mesh - # vescale_load("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test/iter_1", checkpoint_state, broadcast_checkpoint=False) - - checkpoint_state = {"model": self.ckpt_manager.model, "optimizer": self.ckpt_manager.optimizer} - # checkpoint_state = {"optimizer": self.ckpt_manager.optimizer} - # checkpoint_state = {"model": self.ckpt_manager.model} - # vescale_save( - # os.path.join("/mnt/petrelfs/lijiaxing/InternEvo/vescale_ckpt_test_dp2_tp2_pp2", 
f"iter_{batch_count}"), - # checkpoint_state, - # async_checkpoint=False, - # ) - # print("finish save", flush=True) - - # save_checkpoint( - # save_dir="/mnt/petrelfs/lijiaxing/InternEvo/deepspeed_ckpt", - # model=self.ckpt_manager.model, - # optimizer=self.ckpt_manager.optimizer, - # lr_scheduler=self.ckpt_manager.lr_scheduler, - # train_state=self.train_state, - # ) - - - + if self.ckpt_manager.try_save_checkpoint(self.train_state): return True - self._update_profilers(batch_count, prof) return False diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index db181e05d..7759757bb 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -607,8 +607,7 @@ def __init__( mod = multiple % world_size # The first @mod ranks get @div + 1 copies, the rest get @div copies local_multiple = div + int(rank < mod) - # if parallel_mode == ParallelMode.TENSOR: - # print(f"ParallelLinearWithCommExt {split_mode}: infe={in_features}, outfe={out_features}, split={local_multiple * multiple_of}, local_multiple={local_multiple}, multiple_of={multiple_of}", flush=True) + if split_mode == "column": super().__init__(in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype) self.offset = [rank * local_multiple * multiple_of, 0] @@ -617,19 +616,11 @@ def __init__( self.offset = [0, rank * local_multiple * multiple_of] else: super().__init__(in_features, out_features, bias=bias, device=device, dtype=dtype) + self.complete_size = [out_features, in_features] setattr(self.weight, "offset", self.offset) setattr(self.weight, "complete_size", [out_features, in_features]) - if self.weight.offset[0] != 0: - k = self.weight.offset[0] // rank - assert self.weight.complete_size[0] % k == 0 and self.weight.complete_size[0] // k == gpc.get_world_size(parallel_mode), f"{self.weight.complete_size}, {self.weight.offset}" - else: - assert rank == 0 or self.weight.size()[0] == self.weight.complete_size[0], f"{rank}, {self.weight.size()}, {self.weight.complete_size}, {self.weight.offset} \n split_mode={split_mode}, in_features={in_features}, out_features={out_features}, multiple_of={multiple_of}, multiple={multiple}, world_size={world_size}, div={div}, mod={mod}, local_multiple={local_multiple}" - if self.weight.offset[1] != 0: - k = self.weight.offset[1] // rank - assert self.weight.complete_size[1] % k == 0 and self.weight.complete_size[1] // k == gpc.get_world_size(parallel_mode), f"{self.weight.complete_size}, {self.weight.offset}" - else: - assert rank == 0 or self.weight.size()[1] == self.weight.complete_size[1], f"{rank}, {self.weight.size()}, {self.weight.complete_size}, {self.weight.offset}" + def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622 _class_name = self.__class__.__name__ assert self._communicator is not None, f"{_class_name} should register with a communicator first." 
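The offset and complete_size attributes attached to the weight above are what the universal-checkpoint planner later reads (through map_layer_attr) to describe each local shard as a chunk of the full, unsharded parameter. A toy sketch of what that metadata encodes for a column-parallel weight split across two tensor-parallel ranks; the sizes are made up for illustration:

    import torch
    from torch.distributed.checkpoint.metadata import ChunkStorageMetadata

    out_features, in_features, world_size = 8, 4, 2
    local_out = out_features // world_size              # rows owned by each rank
    complete_size = torch.Size([out_features, in_features])

    # Rank r holds rows [r * local_out, (r + 1) * local_out) of the global weight,
    # which is what offset = [rank * local_multiple * multiple_of, 0] records above.
    chunks = [
        ChunkStorageMetadata(
            offsets=torch.Size([rank * local_out, 0]),
            sizes=torch.Size([local_out, in_features]),
        )
        for rank in range(world_size)
    ]
    assert sum(c.sizes[0] for c in chunks) == complete_size[0]

Row-parallel layers are the transposed picture: the second dimension is sharded and the offset becomes [0, rank * local_multiple * multiple_of].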
diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index a3ad146db..704da3072 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -65,26 +65,6 @@ def _qkv_pre_load_convert(module: "GQA", state_dict, prefix: str, *args, **kwarg ) -def _qkv_save_convert(module: "GQA", state_dict, prefix: str, *args, **kwargs) -> Dict: # pylint: disable=W0613 - wq_name, wk_name, wv_name, fused_name = ( - f"{prefix}wq.weight", - f"{prefix}wk.weight", - f"{prefix}wv.weight", - f"{prefix}wqkv.weight", - ) - from internlm.train.pipeline import map_layer_attr - map_layer_attr[wq_name] = map_layer_attr[fused_name] - map_layer_attr[wk_name] = map_layer_attr[fused_name] - map_layer_attr[wv_name] = map_layer_attr[fused_name] - - if module.enable_qkv_fusion: - state_dict[wq_name], state_dict[wk_name], state_dict[wv_name] = split_fused_wqkv_weight( - state_dict.pop(fused_name), *args, **kwargs - ) - - return state_dict - - class MHA(nn.Module): """ Multi-head self-attention and cross-attention. @@ -466,16 +446,15 @@ def __init__( if enable_qkv_fusion: assert bias is False, "Fuesd wqkv only support bias is False." self.wqkv = new_linear("wqkv", embed_dim, q_dim + 2 * self.kv_dim, bias, **factory_kwargs) - self._register_load_state_dict_pre_hook( - partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True - ) - # self._register_state_dict_hook(partial(_qkv_save_convert, q_dim=q_dim, kv_dim=self.kv_dim)) else: - assert False self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) + self._register_load_state_dict_pre_hook( + partial(_qkv_pre_load_convert, q_dim=q_dim, kv_dim=self.kv_dim), with_module=True + ) + self.inner_attn = SelfAttention( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx ) diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index c3a6acc10..31b927979 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -34,13 +34,13 @@ from internlm.solver.optimizer.utils import ( DynamicGradScaler, flatten, - unflatten, get_grad_accumulate_object, has_inf_or_nan, reduce_tensor, release_param_grad, split_half_float_double, sync_param, + unflatten, ) from internlm.utils.common import get_current_device from internlm.utils.logger import get_logger @@ -51,7 +51,6 @@ from .base_optimizer import BaseOptimizer from .utils import compute_norm - inf = math.inf logger = get_logger(__file__) internlm_accelerator = get_accelerator() @@ -151,19 +150,12 @@ def __init__( assert self._param_bcast_sync_handler is not None self._isp_communicator = isp_communicator - - self.param_global_shape_info = {} - self.param_local_shape_info = {} - self.param_global_offset_info = {} - self.param_across_dp_ranks_info = {} - self.fqn_list = {} - shape_list = {} + # iterate over the param group in the optimizer # partition these param groups for data parallel training # and add buffers to parameter store for future access for group_id, param_group in enumerate(self.optim.param_groups): group_params = param_group["params"] - self.fqn_list[group_id] = [] # set the dtype for each param group param_group["dtype"] = group_params[0].dtype if len(group_params) != 0 else None @@ -203,11 +195,6 @@ def __init__( for param in params: setattr(param, "group_id", 
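The mha.py hunk above keeps _qkv_pre_load_convert registered for both the fused and unfused code paths but does not show its body. Purely as an illustration of the load_state_dict pre-hook mechanics (not InternEvo's actual conversion, which may split or interleave heads differently), a hook that fuses separate wq/wk/wv checkpoint entries into a single wqkv key could look roughly like the sketch below; qkv_pre_load_hook and the plain dim-0 concatenation are assumptions of this sketch.

import torch

def qkv_pre_load_hook(module, state_dict, prefix, *args, **kwargs):  # illustration only
    # If the incoming checkpoint stores wq/wk/wv separately but the module uses a
    # fused wqkv projection, concatenate them along the output dimension so that
    # load_state_dict() finds the key it expects.
    wq, wk, wv = f"{prefix}wq.weight", f"{prefix}wk.weight", f"{prefix}wv.weight"
    fused = f"{prefix}wqkv.weight"
    if getattr(module, "enable_qkv_fusion", False) and wq in state_dict:
        state_dict[fused] = torch.cat(
            [state_dict.pop(wq), state_dict.pop(wk), state_dict.pop(wv)], dim=0
        )

# Registered in the same spirit as the hunk above:
#   self._register_load_state_dict_pre_hook(
#       partial(qkv_pre_load_hook, q_dim=q_dim, kv_dim=kv_dim), with_module=True
#   )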
group_id) self._param_store.set_param_to_rank(param, rank) - - # for param in params_per_rank[gpc.get_local_rank(ParallelMode.ZERO1)]: - # print(f"fp16_fqn {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.ZERO1)}: {param.fqn}", flush=True) - - # move to cpu to make room to create the flat tensor for param in group_params: @@ -220,22 +207,13 @@ def __init__( # No flat fp16 buffer is allocated if the process has no parameters. if rank not in self.param_group_no_params_ranks[group_id]: tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) - + with torch.no_grad(): flat_tensor = flatten(tensor_list) flat_tensor = flat_tensor.data.to(get_current_device()) self._param_store.add_flat_fp16_param_by_rank_group(rank, group_id, flat_tensor) sync_param(flat_tensor=flat_tensor, tensor_list=tensor_list) - - if rank == gpc.get_local_rank(ParallelMode.ZERO1): - offset = 0 - for tensor in tensor_list: - shape_list[tensor.fqn] = tensor.shape - self.fqn_list[group_id].append(tensor.fqn) - offset += tensor.numel() - # print(f"fqn_list {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.ZERO1)}: {len(params_per_rank[rank])}, {len(tensor_list)}, {len(self.fqn_list)}, {self.fqn_list}", flush=True) - - # print(f"shape_list {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)} {gpc.get_local_rank(ParallelMode.ZERO1)}: {shape_list}", flush=True) + # create a copy of fp32 weights of the parameters for which this rank is responsible # No flat fp32 buffer is allocated if the process has no parameters. if self.param_group_has_params[group_id]: @@ -247,14 +225,6 @@ def __init__( fp32_flat_current_rank = fp32_flat_current_rank.to(device) fp32_flat_current_rank.requires_grad = True self._fp32_flat_param_groups_of_current_rank[group_id] = fp32_flat_current_rank - print(f"fp32_flat_current_rank {gpc.get_global_rank()}: {fp32_flat_current_rank.shape}", flush=True) - - - - self.param_global_shape_info[fp32_flat_current_rank] = fp32_flat_current_rank.shape - self.param_local_shape_info[fp32_flat_current_rank] = fp32_flat_current_rank.shape - self.param_global_offset_info[fp32_flat_current_rank] = [0, 0] - # need to replace the params in the `params` field in the optimizer # so that when the optimizer calls step(), it only updates the tensors @@ -273,9 +243,6 @@ def __init__( self.skip_grad_reduce = False self._attach_reduction_hook() - - # print(f"self.grad_scaler.state_dict() {gpc.get_global_rank()}: {self.grad_scaler.state_dict()}", flush=True) - # print(f"self.optm.state_dict() {gpc.get_global_rank()}: {self.optim.state_dict()}", flush=True) @property def zero_local_rank(self): @@ -301,7 +268,7 @@ def _partition_param_list(self, group_id, param_group): param_list = param_group["params"] sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) - + for i, param in enumerate(sorted_params): if param.requires_grad is False: continue @@ -326,7 +293,6 @@ def _partition_param_list(self, group_id, param_group): logger.info( # pylint: disable=W1203 f"Number of elements on ranks: {numel_per_rank}, rank:{gpc.get_global_rank()}" ) - # print(f"sorted_params: {len(sorted_params)}, {len(params_per_rank)}, {params_per_rank}, {no_params_ranks}", flush=True) return params_per_rank, set(no_params_ranks) def _is_moe_group(self, param_group): @@ -1001,6 +967,7 @@ def state_dict(self): grad_scaler = self.grad_scaler.state_dict() states["grad_scaler"] = grad_scaler if not gpc.config.ckpt.universal_ckpt.enable: + # original ckpt states["base_optim_states"] = optim_states 
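For readers unfamiliar with the flat-buffer layout this __init__ builds, here is a self-contained sketch of the flatten / sync_param idea, assuming InternEvo's flatten and unflatten helpers behave like torch's _flatten_dense_tensors / _unflatten_dense_tensors (an assumption of this sketch, not something the patch states); the parameter shapes below are arbitrary.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Two fake fp16 params owned by this ZeRO rank.
params = [torch.randn(4, 8, dtype=torch.float16), torch.randn(16, dtype=torch.float16)]

# Flatten into one contiguous buffer, as the hunk above does with flatten().
flat = _flatten_dense_tensors(params)              # numel: 4*8 + 16 = 48

# sync_param in spirit: point each param at its slice of the flat buffer so that
# the flat tensor and the individual params share storage.
for p, view in zip(params, _unflatten_dense_tensors(flat, params)):
    p.data = view

# The per-rank fp32 master copy is then just a float clone of the flat buffer.
master = flat.clone().float().requires_grad_()
assert master.numel() == sum(p.numel() for p in params)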
flat_fp32_weights = {} for group_id, param in self._fp32_flat_param_groups_of_current_rank.items(): @@ -1010,214 +977,64 @@ def state_dict(self): states["flat_fp32_weights"] = flat_fp32_weights states["zero_devide_optim_plan"] = self.params_per_rank_id_dict else: - import time - if gpc.get_global_rank() == 0: - start = time.time() - print(f"optim_states: {optim_states}", flush=True) + # universal ckpt. INFO: currently only adapt to AdamW + # empty_states is used as the initialization of state_dict when loading ckpt empty_states = False - if len(optim_states['state']) == 0: + if len(optim_states["state"]) == 0: empty_states = True - optim_states['state'] = {0:{'step':0}} - + optim_states["state"] = {0: {"step": 0}} + + # To save tensor that needs to be sharded sharded_optimizer_state = {} - temp = [] for group_id, flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): rank = self._zero_local_rank[group_id] - fqn_list = self.fqn_list[group_id] if rank not in self.param_group_no_params_ranks[group_id]: - # fp32 param - assert len(fqn_list) > 0 tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) if len(tensor_list) > 0: unflatten_tensor_list = unflatten(flatten_fp32_param, tensor_list) - # base optimizer state # notice: we assume that one param group corresponds to one flattened tensor. - # we only save unflattened otimizer state. - # print(type(group_id), flush=True) - # print(optim_states['state'], flush=True) + # we will save unflattened otimizer state for universal ckpt. if not empty_states: - flatten_exp_avg = optim_states['state'][group_id]['exp_avg'] - flatten_exp_avg_sq = optim_states['state'][group_id]['exp_avg_sq'] + flatten_exp_avg = optim_states["state"][group_id]["exp_avg"] + flatten_exp_avg_sq = optim_states["state"][group_id]["exp_avg_sq"] assert flatten_exp_avg.shape == flatten_fp32_param.shape == flatten_exp_avg_sq.shape unflatten_exp_avg = unflatten(flatten_exp_avg, tensor_list) unflatten_exp_avg_sq = unflatten(flatten_exp_avg_sq, tensor_list) - assert len(unflatten_tensor_list) == len(tensor_list) == len(unflatten_exp_avg) == len(unflatten_exp_avg_sq) + assert ( + len(unflatten_tensor_list) + == len(tensor_list) + == len(unflatten_exp_avg) + == len(unflatten_exp_avg_sq) + ) + from internlm.train.pipeline import map_fqn_local_to_global + for i in range(len(tensor_list)): - assert tensor_list[i].fqn == fqn_list[i] assert tensor_list[i].fqn not in sharded_optimizer_state fqn = tensor_list[i].fqn - if fqn in map_fqn_local_to_global: + # For optim ckpt, we directly save global_fqn + if fqn in map_fqn_local_to_global: # pylint: disable=consider-using-get fqn = map_fqn_local_to_global[fqn] - temp.append(fqn) + sharded_optimizer_state[fqn] = unflatten_tensor_list[i] - if not empty_states: sharded_optimizer_state[f"{fqn}.exp_avg"] = unflatten_exp_avg[i] sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = unflatten_exp_avg_sq[i] else: sharded_optimizer_state[f"{fqn}.exp_avg"] = torch.empty_like(unflatten_tensor_list[i]) - sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = torch.empty_like(unflatten_tensor_list[i]) - - else: - assert len(fqn_list) == 0 - - states["step"] = optim_states['state'][0]['step'] + sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = torch.empty_like( + unflatten_tensor_list[i] + ) + + states["step"] = optim_states["state"][0]["step"] states["param_groups"] = optim_states["param_groups"] states["sharded_optimizer_state"] = sharded_optimizer_state - if gpc.get_global_rank() == 0: - end = time.time() - start - 
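A toy sketch of the key layout the universal-checkpoint branch above produces, assuming AdamW moments (exp_avg / exp_avg_sq) and using torch's dense-flatten helpers in place of InternEvo's flatten/unflatten (an assumption of this sketch); the fqn strings and tensor shapes are examples only.

import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

# Hypothetical fqns and shapes standing in for the fp16 tensor_list of one param group.
fqns = ["layers.14.feed_forward.w2.weight", "layers.14.attention.wo.weight"]
tensor_list = [torch.randn(8, 4), torch.randn(4, 4)]

flat_fp32 = _flatten_dense_tensors(tensor_list).float()
flat_exp_avg = torch.zeros_like(flat_fp32)       # AdamW state has the same flat shape
flat_exp_avg_sq = torch.zeros_like(flat_fp32)

sharded = {}
for fqn, w, m, v in zip(
    fqns,
    _unflatten_dense_tensors(flat_fp32, tensor_list),
    _unflatten_dense_tensors(flat_exp_avg, tensor_list),
    _unflatten_dense_tensors(flat_exp_avg_sq, tensor_list),
):
    sharded[fqn] = w
    sharded[f"{fqn}.exp_avg"] = m
    sharded[f"{fqn}.exp_avg_sq"] = v

# sharded now holds one fp32 weight plus two AdamW moments per (global) fqn, which is
# the shape of states["sharded_optimizer_state"] built above.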
print(f"create_state_dict: {end}", flush=True) return states - - # def state_dict(self): - # import time - # if gpc.get_global_rank() == 0: - # start = time.time() - # states = {} - # optim_states = self.optim.state_dict() - # print(f"optim_states: {optim_states}", flush=True) - # empty_states = False - # if len(optim_states['state']) == 0: - # empty_states = True - # optim_states['state'] = {0:{'step':0}} - - - # sharded_optimizer_state = {} - # temp = [] - # for group_id, flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): - # rank = self._zero_local_rank[group_id] - # fqn_list = self.fqn_list[group_id] - # if rank not in self.param_group_no_params_ranks[group_id]: - # # fp32 param - # assert len(fqn_list) > 0 - # tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) - # if len(tensor_list) > 0: - # unflatten_tensor_list = unflatten(flatten_fp32_param, tensor_list) - # # base optimizer state - # # notice: we assume that one param group corresponds to one flattened tensor. - # # we only save unflattened otimizer state. - # # print(type(group_id), flush=True) - # # print(optim_states['state'], flush=True) - # if not empty_states: - # flatten_exp_avg = optim_states['state'][group_id]['exp_avg'] - # flatten_exp_avg_sq = optim_states['state'][group_id]['exp_avg_sq'] - # assert flatten_exp_avg.shape == flatten_fp32_param.shape == flatten_exp_avg_sq.shape - # unflatten_exp_avg = unflatten(flatten_exp_avg, tensor_list) - # unflatten_exp_avg_sq = unflatten(flatten_exp_avg_sq, tensor_list) - # assert len(unflatten_tensor_list) == len(tensor_list) == len(unflatten_exp_avg) == len(unflatten_exp_avg_sq) - # from internlm.train.pipeline import map_fqn_local_to_global - # for i in range(len(tensor_list)): - # assert tensor_list[i].fqn == fqn_list[i] - # assert tensor_list[i].fqn not in sharded_optimizer_state - # fqn = tensor_list[i].fqn - # if fqn in map_fqn_local_to_global: - # fqn = map_fqn_local_to_global[fqn] - # temp.append(fqn) - # sharded_optimizer_state[fqn] = unflatten_tensor_list[i] - - # if not empty_states: - # sharded_optimizer_state[f"{fqn}.exp_avg"] = unflatten_exp_avg[i] - # sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = unflatten_exp_avg_sq[i] - # else: - # sharded_optimizer_state[f"{fqn}.exp_avg"] = torch.empty_like(unflatten_tensor_list[i]) - # sharded_optimizer_state[f"{fqn}.exp_avg_sq"] = torch.empty_like(unflatten_tensor_list[i]) - - - # # print(f"debug_before: {self.optim.state_dict()['state'][group_id]['exp_avg']}", flush=True) - # # optim_states['state'][group_id]['exp_avg'] = '' - # # optim_states['state'][group_id]['exp_avg_sq'] = '' - # # print(f"debug_after: {self.optim.state_dict()['state'][group_id]['exp_avg']}", flush=True) - - # else: - # assert len(fqn_list) == 0 - - # grad_scaler = self.grad_scaler.state_dict() - # # print(f"optim_states['state'][0]['step']: {optim_states['state'][0]}, {optim_states['state'][0]['step']}, {type(optim_states['state'][0]['step'])}", flush=True) - # states["grad_scaler"] = grad_scaler - # states["step"] = optim_states['state'][0]['step'] - # states["param_groups"] = optim_states["param_groups"] - # states["sharded_optimizer_state"] = sharded_optimizer_state - # if gpc.get_global_rank() == 0: - # end = time.time() - start - # print(f"create_state_dict: {end}", flush=True) - - # return states - - # def load_state_dict(self, states, global_optimizer_state): - # assert "grad_scaler" in global_optimizer_state, "Not found grad_scaler state!" 
- # assert "step" in global_optimizer_state, "Not found step state!" - # assert "param_groups" in global_optimizer_state, "Not found param_groups state!" - # print(f"states {gpc.get_global_rank()}: {states.keys()}", flush=True) - - # grad_scaler = global_optimizer_state["grad_scaler"] - # print(f"load_state_dict grad_scaler {gpc.get_global_rank()}: {grad_scaler}", flush=True) - # self.grad_scaler.load_state_dict(grad_scaler) - - # step = global_optimizer_state["step"] - # print(f"load_state_dict step {gpc.get_global_rank()}: {step}, {global_optimizer_state}", flush=True) - # param_groups = global_optimizer_state["param_groups"] - # print(f"load_state_dict param_groups {gpc.get_global_rank()}: {param_groups}", flush=True) - - # if gpc.config.get("only_load_lr", False): - # if gpc.is_rank_for_log(): - # logger.info("Only load lr in param_groups, skip loading weights in optimizer...") - # for pg1, pg2 in zip(self.optim.param_groups, param_groups): - # pg1["lr"] = pg2["lr"] - # return - - # optim_states = {'state': {}, 'param_groups': param_groups} - # for group_id, self_flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): - # print(f"group_id: {type(group_id)}", flush=True) - # rank = self._zero_local_rank[group_id] - # if rank not in self.param_group_no_params_ranks[group_id]: - # # self fp16 unflatten param list - # self_tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) - - # if len(self_tensor_list) > 0: - # optim_states['state'][group_id] = {"step": step} - # ckpt_fp32_params = [] - # ckpt_exp_avg_list = [] - # ckpt_exp_avg_sq_list = [] - # for tensor in self_tensor_list: - # fqn = tensor.fqn - # from internlm.train.pipeline import map_fqn_local_to_global - # if fqn in map_fqn_local_to_global: - # fqn = map_fqn_local_to_global[fqn] - # assert tensor.shape == states[fqn].shape == states[f"{fqn}.exp_avg"].shape == states[f"{fqn}.exp_avg_sq"].shape - # ckpt_fp32_params.append(states[fqn]) - # ckpt_exp_avg_list.append(states[f"{fqn}.exp_avg"]) - # ckpt_exp_avg_sq_list.append(states[f"{fqn}.exp_avg_sq"]) - - # ckpt_flatten_fp32_param = flatten(ckpt_fp32_params) - # ckpt_flatten_exp_avg = flatten(ckpt_exp_avg_list) - # ckpt_flatten_exp_avg_sq = flatten(ckpt_exp_avg_sq_list) - - # assert ( - # self_flatten_fp32_param.shape == ckpt_flatten_fp32_param.shape - # ), f"The loaded parameter shape is inconsistent, {self_flatten_fp32_param.shape} != {ckpt_flatten_fp32_param.shape}" - # self_flatten_fp32_param.data.copy_(ckpt_flatten_fp32_param.data) - # optim_states['state'][group_id]['exp_avg'] = ckpt_flatten_exp_avg - # optim_states['state'][group_id]['exp_avg_sq'] = ckpt_flatten_exp_avg_sq - - # self.optim.load_state_dict(optim_states) - - - # # Load the fp16 model weights. - # for group_id in range(len(self._fp16_param_groups)): - # if self._zero_local_rank[group_id] not in self.param_group_no_params_ranks[group_id]: - # fp16_param = self._param_store.get_flat_fp16_param_by_rank_group( - # rank=self._zero_local_rank[group_id], group_id=group_id - # ) - # fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id] - # fp16_param.data.copy_(fp32_param) - - - def load_state_dict(self, states, global_optimizer_state=None): if not gpc.config.ckpt.universal_ckpt.enable: + # original ckpt # TODO: Need to take into account the change in the number of DP. assert "grad_scaler" in states, "Not found grad_scaler state!" 
grad_scaler = states["grad_scaler"] @@ -1256,20 +1073,17 @@ def load_state_dict(self, states, global_optimizer_state=None): if "zero_devide_optim_plan" in states: self.params_per_rank_id_dict = states["zero_devide_optim_plan"] else: + # universal ckpt assert global_optimizer_state is not None assert "grad_scaler" in global_optimizer_state, "Not found grad_scaler state!" assert "step" in global_optimizer_state, "Not found step state!" assert "param_groups" in global_optimizer_state, "Not found param_groups state!" - print(f"states {gpc.get_global_rank()}: {states.keys()}", flush=True) - + grad_scaler = global_optimizer_state["grad_scaler"] - print(f"load_state_dict grad_scaler {gpc.get_global_rank()}: {grad_scaler}", flush=True) self.grad_scaler.load_state_dict(grad_scaler) - + step = global_optimizer_state["step"] - print(f"load_state_dict step {gpc.get_global_rank()}: {step}, {global_optimizer_state}", flush=True) param_groups = global_optimizer_state["param_groups"] - print(f"load_state_dict param_groups {gpc.get_global_rank()}: {param_groups}", flush=True) if gpc.config.get("only_load_lr", False): if gpc.is_rank_for_log(): @@ -1278,40 +1092,48 @@ def load_state_dict(self, states, global_optimizer_state=None): pg1["lr"] = pg2["lr"] return - optim_states = {'state': {}, 'param_groups': param_groups} + optim_states = {"state": {}, "param_groups": param_groups} for group_id, self_flatten_fp32_param in self._fp32_flat_param_groups_of_current_rank.items(): - print(f"group_id: {type(group_id)}", flush=True) rank = self._zero_local_rank[group_id] if rank not in self.param_group_no_params_ranks[group_id]: # self fp16 unflatten param list self_tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) - + if len(self_tensor_list) > 0: - optim_states['state'][group_id] = {"step": step} + optim_states["state"][group_id] = {"step": step} ckpt_fp32_params = [] ckpt_exp_avg_list = [] ckpt_exp_avg_sq_list = [] + for tensor in self_tensor_list: fqn = tensor.fqn from internlm.train.pipeline import map_fqn_local_to_global - if fqn in map_fqn_local_to_global: + + if fqn in map_fqn_local_to_global: # pylint: disable=consider-using-get fqn = map_fqn_local_to_global[fqn] - assert tensor.shape == states[fqn].shape == states[f"{fqn}.exp_avg"].shape == states[f"{fqn}.exp_avg_sq"].shape + assert ( + tensor.shape + == states[fqn].shape + == states[f"{fqn}.exp_avg"].shape + == states[f"{fqn}.exp_avg_sq"].shape + ) ckpt_fp32_params.append(states[fqn]) ckpt_exp_avg_list.append(states[f"{fqn}.exp_avg"]) ckpt_exp_avg_sq_list.append(states[f"{fqn}.exp_avg_sq"]) - + ckpt_flatten_fp32_param = flatten(ckpt_fp32_params) ckpt_flatten_exp_avg = flatten(ckpt_exp_avg_list) ckpt_flatten_exp_avg_sq = flatten(ckpt_exp_avg_sq_list) - - assert ( - self_flatten_fp32_param.shape == ckpt_flatten_fp32_param.shape - ), f"The loaded parameter shape is inconsistent, {self_flatten_fp32_param.shape} != {ckpt_flatten_fp32_param.shape}" + + assert self_flatten_fp32_param.shape == ckpt_flatten_fp32_param.shape, ( + "The loaded parameter shape is inconsistent," + f"{self_flatten_fp32_param.shape} != {ckpt_flatten_fp32_param.shape}" + ) + self_flatten_fp32_param.data.copy_(ckpt_flatten_fp32_param.data) - optim_states['state'][group_id]['exp_avg'] = ckpt_flatten_exp_avg - optim_states['state'][group_id]['exp_avg_sq'] = ckpt_flatten_exp_avg_sq - + optim_states["state"][group_id]["exp_avg"] = ckpt_flatten_exp_avg + optim_states["state"][group_id]["exp_avg_sq"] = ckpt_flatten_exp_avg_sq + self.optim.load_state_dict(optim_states) 
def reload_zero_fp32_buff(self): diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index cb8f6f48b..1bf78c646 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -93,7 +93,6 @@ ) from internlm.utils.timeout import llm_timeout from internlm.utils.utils import TensorParallelMode -from internlm.model.ops.norm import RMSNorm try: import torch_npu @@ -117,16 +116,15 @@ logger = get_logger(__file__) internlm_accelerator = get_accelerator() +# For universal checkpoint +# record offset and complete_size of param in each layer map_layer_attr = {} map_fqn_local_to_global = {} map_fqn_global_to_local = {} -def recover_pipeline_idx_for_layers(model, idx): - start_id = model.first_layer - def set_param_unique_tracking_name(model): - # print(f"first_layer {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {model.first_layer}, {model.last_layer}", flush=True) + uc_enable = gpc.config.ckpt.universal_ckpt.enable for chunk_id, chunk in enumerate(unwrap_naive_amp(model)): # Important: only works for llama-class models childrens = chunk.named_children() @@ -136,12 +134,14 @@ def set_param_unique_tracking_name(model): for name, child in block.named_modules(): if name == "": continue + full_name = f"{chunk_id}.{idx}.{name}" - parts = f"{full_name}.weight".split('.', 2) + name_parts = f"{full_name}.weight".split(".", 2) + # global_id for pipeline parallel case global_id = model.first_layer + idx - result = f"{children_name}." + '.'.join(parts[1:]) - global_fqn = f"{children_name}.{global_id}." + '.'.join(parts[2:]) - + local_fqn = f"{children_name}." + ".".join(name_parts[1:]) + global_fqn = f"{children_name}.{global_id}." + ".".join(name_parts[2:]) + if isinstance(child, (ParallelLinearWithCommExt)): setattr( child.weight, @@ -154,87 +154,80 @@ def set_param_unique_tracking_name(model): "tracking_name", f"{full_name}.bias", ) - - - setattr( - child.weight, - "fqn", - f"{result}", - ) - if child.bias is not None: + + if uc_enable: setattr( - child.bias, + child.weight, "fqn", - f"{result}", + f"{local_fqn}", ) - - - # print(result, flush=True) - assert hasattr(child, "offset"), f"{child}" - - # print(f"layer_name {gpc.get_local_rank(ParallelMode.PIPELINE)}: {children_name}, {result}", flush=True) - - # recover_pipeline_idx_for_layers - - # print(f"original_id {gpc.get_local_rank(ParallelMode.PIPELINE)}: {model.first_layer}, {idx}, {original_id}, {name, child}, {map_fqn}") - - map_fqn_local_to_global[result] = global_fqn - map_fqn_global_to_local[global_fqn] = result - # print(f"map_pp_layer_fqn {gpc.get_local_rank(ParallelMode.PIPELINE)}: {map_pp_layer_fqn}", flush=True) - - assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists" - - map_layer_attr[global_fqn] = {'offset': getattr(child, "offset", [0] * len(child.weight.size())), 'complete_size': getattr(child, "complete_size", child.weight.size())} - # print(f"child.weight {gpc.get_local_rank(ParallelMode.TENSOR)}: {result}, {child.offset}, {child.weight.shape}", flush=True) - elif isinstance(child, (RMSNorm)): - # print(f"global_fqn {gpc.get_local_rank(ParallelMode.PIPELINE)}: {global_fqn}", flush=True) - map_fqn_local_to_global[result] = global_fqn - map_fqn_global_to_local[global_fqn] = result + if child.bias is not None: + setattr( + child.bias, + "fqn", + f"{local_fqn}", + ) + + assert hasattr(child, "offset"), f"{child}" + map_fqn_local_to_global[local_fqn] = global_fqn + map_fqn_global_to_local[global_fqn] = local_fqn + + assert global_fqn not in map_layer_attr, f"{map_layer_attr} 
exists" + map_layer_attr[global_fqn] = { + "offset": getattr(child, "offset", [0] * len(child.weight.size())), + "complete_size": getattr(child, "complete_size", child.weight.size()), + } + + elif isinstance(child, (RMSNorm)) and uc_enable: + map_fqn_local_to_global[local_fqn] = global_fqn + map_fqn_global_to_local[global_fqn] = local_fqn setattr( child.weight, "fqn", - f"{result}", + f"{local_fqn}", ) - map_layer_attr[global_fqn] = {'offset': getattr(child, "offset", [0] * len(child.weight.size())), 'complete_size': getattr(child, "complete_size", child.weight.size())} - + map_layer_attr[global_fqn] = { + "offset": getattr(child, "offset", [0] * len(child.weight.size())), + "complete_size": getattr(child, "complete_size", child.weight.size()), + } + else: full_name = f"{chunk_id}.{children_name}" - result = f"{children_name}.weight" + local_fqn = f"{children_name}.weight" assert getattr(children, "bias", None) is None - # print(f"result: {result}", flush=True) if isinstance(children, Embedding1D): setattr( children.weight, "tracking_name", f"{chunk_id}_embeddings.weight", ) - assert result not in map_layer_attr, f"{map_layer_attr} exists" - # map_layer_attr[result] = {'offset': children.offset, 'complete_size': children.weight.complete_size} + assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists" else: setattr( children.weight, "tracking_name", f"{full_name}.weight", ) - assert result not in map_layer_attr, f"{map_layer_attr} exists" - # map_layer_attr[result] = {'offset': getattr(children, "offset", [0] * len(children.weight.size())), 'complete_size': getattr(children, "complete_size", children.weight.size())} - - setattr( - children.weight, - "fqn", - f"{result}", - ) - if getattr(children, "bias", None) is not None: - if children.bias is not None: - setattr( - children.bias, - "fqn", - f"{result}", - ) - - map_layer_attr[result] = {'offset': getattr(children, "offset", [0] * len(children.weight.size())), 'complete_size': getattr(children, "complete_size", children.weight.size())} - - # print(f"map_layer_attr global={gpc.get_global_rank()}, pp={gpc.get_local_rank(ParallelMode.PIPELINE)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}: {map_layer_attr}", flush=True) + assert local_fqn not in map_layer_attr, f"{map_layer_attr} exists" + + if uc_enable: + setattr( + children.weight, + "fqn", + f"{local_fqn}", + ) + if getattr(children, "bias", None) is not None: + if children.bias is not None: + setattr( + children.bias, + "fqn", + f"{local_fqn}", + ) + + map_layer_attr[local_fqn] = { + "offset": getattr(children, "offset", [0] * len(children.weight.size())), + "complete_size": getattr(children, "complete_size", children.weight.size()), + } def set_fp32_attr_for_model(model: Union[nn.Module, nn.ModuleList]): From 2a0bb72f02bb1300a1d768ddc94229a574ebc3c1 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 25 Dec 2024 13:56:17 +0800 Subject: [PATCH 05/12] remove old --- .../checkpoint/vescale/_collective_utils.py | 457 --------- internlm/checkpoint/vescale/_utils.py | 463 --------- internlm/checkpoint/vescale/api.py | 53 - .../checkpoint/vescale/base_checkpointer.py | 77 -- internlm/checkpoint/vescale/bfile.py | 129 --- internlm/checkpoint/vescale/common.py | 132 --- internlm/checkpoint/vescale/device_mesh.py | 647 ------------ internlm/checkpoint/vescale/devicemesh_api.py | 475 --------- .../vescale/distributed_optimizer.py | 109 -- internlm/checkpoint/vescale/filesystem.py | 960 ------------------ .../checkpoint/vescale/load_state_dict.py | 102 -- 
internlm/checkpoint/vescale/mem_checkpoint.py | 396 -------- .../vescale/mem_file_service_pb2.py | 66 -- .../vescale/mem_file_service_pb2_grpc.py | 321 ------ internlm/checkpoint/vescale/mem_server_lib.py | 307 ------ internlm/checkpoint/vescale/meta_type.py | 38 - .../checkpoint/vescale/placement_types.py | 563 ---------- .../checkpoint/vescale/save_state_dict.py | 204 ---- .../vescale/vescale_checkpointer.py | 293 ------ .../checkpoint/vescale/vescale_planner.py | 288 ------ .../vescale/vescale_planner_helpers.py | 307 ------ 21 files changed, 6387 deletions(-) delete mode 100644 internlm/checkpoint/vescale/_collective_utils.py delete mode 100644 internlm/checkpoint/vescale/_utils.py delete mode 100644 internlm/checkpoint/vescale/api.py delete mode 100644 internlm/checkpoint/vescale/base_checkpointer.py delete mode 100644 internlm/checkpoint/vescale/bfile.py delete mode 100644 internlm/checkpoint/vescale/common.py delete mode 100644 internlm/checkpoint/vescale/device_mesh.py delete mode 100644 internlm/checkpoint/vescale/devicemesh_api.py delete mode 100644 internlm/checkpoint/vescale/distributed_optimizer.py delete mode 100644 internlm/checkpoint/vescale/filesystem.py delete mode 100644 internlm/checkpoint/vescale/load_state_dict.py delete mode 100644 internlm/checkpoint/vescale/mem_checkpoint.py delete mode 100644 internlm/checkpoint/vescale/mem_file_service_pb2.py delete mode 100644 internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py delete mode 100644 internlm/checkpoint/vescale/mem_server_lib.py delete mode 100644 internlm/checkpoint/vescale/meta_type.py delete mode 100644 internlm/checkpoint/vescale/placement_types.py delete mode 100644 internlm/checkpoint/vescale/save_state_dict.py delete mode 100644 internlm/checkpoint/vescale/vescale_checkpointer.py delete mode 100644 internlm/checkpoint/vescale/vescale_planner.py delete mode 100644 internlm/checkpoint/vescale/vescale_planner_helpers.py diff --git a/internlm/checkpoint/vescale/_collective_utils.py b/internlm/checkpoint/vescale/_collective_utils.py deleted file mode 100644 index 98d43e482..000000000 --- a/internlm/checkpoint/vescale/_collective_utils.py +++ /dev/null @@ -1,457 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -import logging -import math -import copy -from typing import List, Optional - -import torch -import torch.distributed._functional_collectives as funcol -import torch.distributed.distributed_c10d as c10d -from torch.distributed.distributed_c10d import ( - GroupMember, - ProcessGroup, - Work, - all_to_all, - get_global_rank, - get_rank, - broadcast, - get_world_size, - scatter, -) - -from .device_mesh import DeviceMesh, mesh_resources -from .placement_types import DTensorSpec -from internlm.utils.logger import get_logger -logger = get_logger(__file__) - -TORCH_VERSION_BIGGER_THAN_2_2 = torch.__version__ >= "2.2" - - -# NOTE: upstream are working to migrate the following three collective -# apis to be functional, pay attention to it. 
- - -def mesh_scatter( - output: torch.Tensor, - scatter_list: List[torch.Tensor], - mesh: DeviceMesh, - mesh_dim: int = 0, - async_op: bool = False, -) -> Optional[Work]: - """ - scatter a list of tensors to a device mesh dimension. We by default - use the first rank of the mesh dimension as the source of truth, i.e - for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will - scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank - 2 to rank 2/3. - - Args: - output (torch.Tensor): the tensor to receive the scattered list. - scatter_list (List[torch.Tensor]): the tensor list to be scattered. - mesh_dim (int, optional): indicate which mesh dimension we want - to scatter on, we by default choose the first rank on the - mesh dimension as source of truth. - - Returns: - A :class:`Work` object - """ - if DebugLogger.IS_DEBUG_MODE: - DebugLogger.log_communication(mesh_scatter, output, scatter_list, mesh_dim, async_op) - - # if rank is not part of mesh, simply return output - if mesh.get_coordinate() is None: - return output - - # TODO: Ideally we should use the meta tensor way - # (to register a meta kernel for the collective op) - # so that it would avoid the communication. Need to - # remove the check below once that is done. - if output.is_meta: - return None - - dim_group = mesh.get_dim_groups(mesh_dim) - assert isinstance(dim_group, ProcessGroup) - # src need to be global rank - src_for_dim = 0 - - if dim_group is not GroupMember.WORLD: - src_for_dim = get_global_rank(dim_group, 0) - - if src_for_dim == get_rank(): - fut = scatter( - output, - scatter_list=scatter_list, - src=src_for_dim, - group=dim_group, - async_op=async_op, - ) - else: - fut = scatter( - output, - scatter_list=None, - src=src_for_dim, - group=dim_group, - async_op=async_op, - ) - - return fut - - -# TODO: test uneven split on GLOO and NCCL - - -def mesh_all_to_all( - output_tensor_list: List[torch.Tensor], - input_tensor_list: List[torch.Tensor], - mesh: DeviceMesh, - mesh_dim: int = 0, - async_op: bool = False, -) -> Optional[Work]: - if DebugLogger.IS_DEBUG_MODE: - DebugLogger.log_communication(mesh_all_to_all, output_tensor_list, input_tensor_list, mesh, mesh_dim, async_op) - - # if rank is not part of mesh, simply return None - if mesh.get_coordinate() is None: - return None - - dim_group = mesh.get_dim_groups(mesh_dim) - assert isinstance(dim_group, ProcessGroup) - - work = None - # no direct dist.all_to_all support on 'gloo' so we manually do scatters - if mesh.device_type == "cpu": - logger.warning("ProcessGroupGloo does not support all_to_all, falling back with scatters!") - # TODO: pull the handle of uneven case in #492 - dim_group_size = get_world_size(dim_group) - for i in range(dim_group_size): - # src need to be global rank - src_for_dim = i - if dim_group is not GroupMember.WORLD: - src_for_dim = get_global_rank(dim_group, i) - - work = scatter( - output_tensor_list[i], - input_tensor_list if mesh.get_rank() == src_for_dim else [], - group=dim_group, - src=src_for_dim, - async_op=async_op, - ) - else: - work = all_to_all( - output_tensor_list, - input_tensor_list, - dim_group, - async_op=async_op, - ) - return work - - -def mesh_all_to_all_single( - tensor: torch.Tensor, - mesh: DeviceMesh, - original_shard_dim: int, - target_shard_dim: int, - mesh_dim: int = 0, - async_op: bool = False, -): - """ - transpose the sharded tensor along a device mesh dimension. - - Args: - tensor (torch.Tensor): tensor to all-to-all. - mesh (DeviceMesh): device mesh that communication happens. 
- original_shard_dim (int): the dim that source tensor is sharded - target_shard_dim (int): the dim that transposed tensor is sharded - mesh_dim (int, optional): indicate which mesh dimension we want - to broadcast on, we by default choose the first rank on the - mesh dimension as source of truth. - async_op (bool, default False): unused arguments. As all-to-all will - always be sync. - - Returns: - A :class:`Tensor` object - """ - if DebugLogger.IS_DEBUG_MODE: - DebugLogger.log_communication( - mesh_all_to_all_single, tensor, mesh, original_shard_dim, target_shard_dim, mesh_dim - ) - - # if rank is not part of mesh, simply return tensor, which should be an empty tensor - if mesh.get_coordinate() is None: - return tensor - mesh_size = mesh.size(mesh_dim) - assert tensor.size(target_shard_dim) % mesh_size == 0, "we don't support unvevn shard on ``target_shard_dim``" - input_rank = tensor.ndim - assert input_rank >= 2, "input must has at least 2 ranks" - - target_shape = copy.deepcopy(list(tensor.shape)) - target_shape[original_shard_dim] *= mesh_size - target_shape[target_shard_dim] //= mesh_size - - dim_group = mesh.get_dim_groups(mesh_dim) - assert isinstance(dim_group, ProcessGroup) - - if target_shard_dim != 0: - k_new_shape = list(tensor.shape) - k_new_shape[target_shard_dim] //= mesh_size - k_new_shape[0] *= mesh_size - new_shape = list(tensor.shape) - new_shape[target_shard_dim] //= mesh_size - new_shape.insert(target_shard_dim, mesh_size) - indices = ( - [target_shard_dim] + list(range(0, target_shard_dim)) + list(range(target_shard_dim + 1, tensor.ndim + 1)) - ) - tensor = tensor.reshape(new_shape).permute(indices).reshape(k_new_shape) - - output = funcol.all_to_all_single(tensor, output_split_sizes=None, input_split_sizes=None, group=dim_group) - if original_shard_dim == 0: - return output - - n, *out_shape = list(output.shape) - - indices = ( - list(range(1, original_shard_dim)) - + [original_shard_dim, 0] - + list(range(original_shard_dim + 1, output.ndim + 1)) - ) - - return output.reshape(mesh_size, n // mesh_size, *out_shape).permute(indices).reshape(target_shape) - - -def mesh_broadcast( - tensor: torch.Tensor, - mesh: DeviceMesh, - mesh_dim: int = 0, - async_op=False, -) -> torch.Tensor: - """ - broadcast the tensor to a device mesh dimension. We by default - use the first rank of the mesh dimension as the source of truth, i.e - for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will - broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 - to rank 2/3. - - Args: - tensor (torch.Tensor): tensor to broadcast. - mesh_dim (int, optional): indicate which mesh dimension we want - to broadcast on, we by default choose the first rank on the - mesh dimension as source of truth. 
- - Returns: - A :class:`Tensor` object - """ - if DebugLogger.IS_DEBUG_MODE: - DebugLogger.log_communication(mesh_broadcast, tensor, mesh, mesh_dim, async_op) - - # if rank is not part of mesh, simply return tensor, which should be an empty tensor - if mesh.get_coordinate() is None: - return tensor - - dim_group = mesh.get_dim_groups(mesh_dim) - assert isinstance(dim_group, ProcessGroup) - # src need to be global rank - src_for_dim = 0 - if dim_group is not GroupMember.WORLD: - src_for_dim = get_global_rank(dim_group, 0) - if TORCH_VERSION_BIGGER_THAN_2_2: - aysnc_tensor = funcol.broadcast(tensor, src=src_for_dim, group=dim_group) - if not async_op: - return funcol.wait_tensor(aysnc_tensor) - return aysnc_tensor - else: - work = broadcast(tensor, src=src_for_dim, group=dim_group, async_op=async_op) - if not async_op: - return tensor - from torch.distributed._functional_collectives_impl import _register_tensor_work - from torch.distributed._functional_collectives import _maybe_wrap_tensor - - _register_tensor_work(tensor, work) - return _maybe_wrap_tensor(tensor) - - -def mesh_reduce_scatter( - tensor: torch.Tensor, mesh: DeviceMesh, reduce_op: c10d.ReduceOp.RedOpType, scatter_dim: int, mesh_dim: int -) -> torch.Tensor: - """ - First peform all_reduce on the tensor, then split the tensor at scatter_dim - and scatter them to a device mesh dimension. - """ - - if DebugLogger.IS_DEBUG_MODE: - DebugLogger.log_communication(mesh_reduce_scatter, tensor, mesh, reduce_op, scatter_dim, mesh_dim) - - # if rank is not part of mesh, simply return tensor, which should be an empty tensor - if mesh.get_coordinate() is None: - return tensor - - # for now, we only support that size at `scatter_dim`` is divisable by - # the mesh size at `mesh_dim` - num_chunks = mesh.size(dim=mesh_dim) - assert ( - tensor.size(scatter_dim) % num_chunks == 0 - ), f"tensor size at {scatter_dim} is not divisable by the mesh size at {mesh_dim}" - output = funcol.reduce_scatter_tensor( - tensor, reduceOp=reduce_op.name, scatter_dim=scatter_dim, group=mesh._dim_group_infos[mesh_dim][1] - ) - return output - - -def mesh_all_gather( - tensor: torch.Tensor, - global_size: torch.Size, - mesh: DeviceMesh, - scatter_dim: int, - mesh_dim: int, -) -> torch.Tensor: - """ - all_gather all shards and return a tensor that is replicated - on the previously sharded mesh dimension - """ - if DebugLogger.IS_DEBUG_MODE: - DebugLogger.log_communication(mesh_all_gather, tensor, global_size, mesh, scatter_dim, mesh_dim) - - # if rank is not part of mesh, simply return tensor, which should be an empty tensor - if mesh.get_coordinate() is None: - return tensor - - # for now, we only support that global size at `scatter_dim` is equal with - # the multuple of mesh size at `mesh_dim` and local_tensor size at `scatter_dim` - num_chunks = mesh.size(dim=mesh_dim) - assert ( - tensor.size(scatter_dim) * num_chunks == global_size[scatter_dim] - ), f"global tensor size at {scatter_dim} is not equal with the multiply of mesh size at {mesh_dim} and local_tensor size at {scatter_dim}" - tensor = tensor.contiguous() - output = funcol.all_gather_tensor(tensor, gather_dim=scatter_dim, group=mesh._dim_group_infos[mesh_dim][1]) - return output - - -def mesh_all_reduce( - tensor: torch.Tensor, - mesh: DeviceMesh, - reduce_op: c10d.ReduceOp.RedOpType, - mesh_dim: int, -) -> torch.Tensor: - # if rank is not part of mesh, simply return tensor, which should be an empty tensor - if mesh.get_coordinate() is None: - return tensor - - return funcol.all_reduce(tensor, 
reduceOp=reduce_op.name, group=mesh._dim_group_infos[mesh_dim][1]) - - -def wait(tensor: torch.Tensor) -> torch.Tensor: - if isinstance(tensor, funcol.AsyncCollectiveTensor): - return funcol.wait_tensor(tensor) - return tensor - - -def spec_to_bytes(spec: DTensorSpec) -> int: - assert spec.tensor_meta is not None, "spec should have tensor meta defined!" - return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) - - -def get_bandwidth_factor(mesh: DeviceMesh) -> List[float]: - # generate bandwidth factor for intra-host/inter-host communication pattern - factors = [1.0] * mesh.ndim - num_devices_per_host = mesh_resources.num_devices_per_host(mesh.device_type) - - num_devices = 1 - for mesh_dim in reversed(range(mesh.ndim)): - num_devices *= mesh.size(mesh_dim) - if num_devices <= num_devices_per_host: - # magic number for intra-host communication bandwidth factor - # TODO: see if we need to tweak this or offer a way for user - # to specify the bandwidths - factors[mesh_dim] = 0.2 - - return factors - - -def allgather_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float: - num_devices_on_mesh_dim = mesh.size(mesh_dim) - bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim] - # constant latency factor + bandwidth cost - return 1 + bandwidth_factor * num_bytes * (num_devices_on_mesh_dim - 1) / num_devices_on_mesh_dim - - -def allreduce_cost(num_bytes: float, mesh: DeviceMesh, mesh_dim: int) -> float: - num_devices_on_mesh_dim = mesh.size(mesh_dim) - bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim] - # allreduce have 2x comm bytes compare to allgather/reduce_scatter - return 1 + 2 * bandwidth_factor * num_bytes * (num_devices_on_mesh_dim - 1) / num_devices_on_mesh_dim - - -def reduce_scatter_cost( - num_bytes: float, - mesh: DeviceMesh, - mesh_dim: int, -) -> float: - num_devices_on_mesh_dim = mesh.size(mesh_dim) - bandwidth_factor = get_bandwidth_factor(mesh)[mesh_dim] - # constant latency factor + bandwidth cost - return 1 + bandwidth_factor * num_bytes * (num_devices_on_mesh_dim - 1) / num_devices_on_mesh_dim - - -def redistribute_cost( - current_spec: DTensorSpec, - target_spec: DTensorSpec, -) -> float: - """ - This function returns the cost of redistribute from current to target DTensorSpec. - - NOTE: - 1. Only consider communication cost here, since computation costs for redistribute - are quite trival (i.e. we only need to narrow or simple division) - 2. Only consider redistribute cost on same mesh, cross mesh communication cost is - not quite needed for operator strategy estimation/selection. - """ - if current_spec.mesh != target_spec.mesh: - # make infinite cost if meshes are not same - # TODO: see if we want to support this once there's cross mesh communication - return float("inf") - - if current_spec.is_replicated(): - # short-cut: - # comm cost is 0 if current spec is already full replication - return 0.0 - - mesh = current_spec.mesh - cost = 0.0 - comm_bytes = spec_to_bytes(current_spec) / current_spec.num_shards - # Transformation that considered for redistribute cost: - # 1. allgather 2. alltoall - # 3. allreduce 4. 
reduce_scatter - for i, (current, target) in enumerate(zip(current_spec.placements, target_spec.placements)): - if current == target: - continue - if current.is_shard() and target.is_replicate(): - # allgather gives larger comm bytes - comm_bytes *= mesh.size(i) - # add up allgather comm cost - cost += allgather_cost(comm_bytes, current_spec.mesh, i) - elif current.is_shard() and target.is_shard(): - # should be alltoall comm, since we haven't implement it yet, add penalty - # to favor allgather instead - cost += allgather_cost(comm_bytes, current_spec.mesh, i) + 1.0 - elif current.is_partial() and target.is_replicate(): - # add up allreduce comm cost - cost += allreduce_cost(comm_bytes, current_spec.mesh, i) - elif current.is_partial() and target.is_shard(): - # add up reduce_scatter comm cost - cost += reduce_scatter_cost(comm_bytes, current_spec.mesh, i) - # after reduce_scatter the comm bytes for further collectives halved. - comm_bytes /= mesh.size(i) - elif current.is_shard() and target.is_partial(): - # ban shard/interleaved_shard -> partial as it does not make sense to perform - # this redistribute - return float("inf") - - return cost diff --git a/internlm/checkpoint/vescale/_utils.py b/internlm/checkpoint/vescale/_utils.py deleted file mode 100644 index 0d6f7d8c5..000000000 --- a/internlm/checkpoint/vescale/_utils.py +++ /dev/null @@ -1,463 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -import warnings -import copy -from typing import List, Sequence, Tuple, Optional, Dict, Set, Union - -import torch -import torch.distributed._functional_collectives as funcol -from torch._prims_common import ShapeType - -from .device_mesh import DeviceMesh -from .placement_types import InterleavedShard, Partial, Placement, Replicate, Shard -from ._collective_utils import mesh_all_gather - - -def compute_local_shape(global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]) -> Tuple[int, ...]: - """ - Compute the shape of a local shard of the given DTensor on its current - coordinate of the mesh. 
- """ - my_coordinate = mesh.get_coordinate() - - if my_coordinate is None: - # if rank not in the mesh, return empty shape - return () - else: - local_shape = list(global_shape) # start with global shape - ndim = len(global_shape) - for idx, placement in enumerate(placements): - mesh_dim_size = mesh.size(idx) - if isinstance(placement, (Shard, InterleavedShard)): - shard_dim = placement.dim - assert shard_dim < ndim, f"Sharding dim {shard_dim} greater than tensor ndim {ndim}" - local_shard_size, _ = placement._local_shard_size_on_dim( - local_shape[shard_dim], mesh_dim_size, my_coordinate[idx] - ) - assert isinstance(local_shard_size, int) - local_shape[shard_dim] = local_shard_size - - return tuple(local_shape) - - -def compute_local_shape_and_global_offset( - global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] -) -> Tuple[Tuple[int, ...], Tuple[int, ...]]: - """ - Compute the local tensor shape and the global offsets into the original tensor - of a DTensor on its current global rank. This is useful for checkpointing purpose. - Example (2 host with 4GPUs each): - # Below is a DeviceMesh with mesh_shape of (2, 4) - mesh = DeviceMesh(device_type="cuda", - mesh=[ - [0, 1, 2, 3], - [4, 5, 6, 7] - ], - ) - Let's say we distribute a global_tensor of shape (8,4) over the above DeviceMesh - with a placements of [Shard(0), Shard(0)]. - The local shape and global offset will be as follows: - rank0 -- local_shape:[1, 4], global_offset:[0, 0] - rank1 -- local_shape:[1, 4], global_offset:[1, 0] - rank2 -- local_shape:[1, 4], global_offset:[2, 0] - rank5 -- local_shape:[1, 4], global_offset:[5, 0] - rank3 -- local_shape:[1, 4], global_offset:[3, 0] - rank4 -- local_shape:[1, 4], global_offset:[4, 0] - rank6 -- local_shape:[1, 4], global_offset:[6, 0] - rank7 -- local_shape:[1, 4], global_offset:[7, 0] - """ - my_coordinate = mesh.get_coordinate() - - if my_coordinate is None: - # if rank not in the mesh, return empty offset - return ((), ()) - else: - local_shape = list(global_shape) - global_offset = [0] * len(global_shape) - - for idx, placement in enumerate(placements): - mesh_dim_size = mesh.size(idx) - if isinstance(placement, Shard): - shard_dim = placement.dim - local_offset = [0] * len(global_shape) - assert shard_dim < len( - local_shape - ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" - # TODO: what if placement is InterleavedShard - shard_size, shard_offset = placement._local_shard_size_on_dim( - local_shape[shard_dim], - mesh_dim_size, - my_coordinate[idx], - return_offset=True, - ) - - local_shape[shard_dim] = shard_size - local_offset[shard_dim] = shard_offset - - # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], - # it means that this dimension has been already sharded in previous placement. - # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. - # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. 
- if global_offset[shard_dim] <= local_offset[shard_dim]: - global_offset[shard_dim] = local_offset[shard_dim] - else: - global_offset[shard_dim] += local_offset[shard_dim] - - return tuple(local_shape), tuple(global_offset) - - -def is_same_shape_across_ranks(tensor_shape: ShapeType, device_mesh: DeviceMesh, placements: Sequence[Placement]): - # check if tensor shapes are the same across ranks - self_shape = torch.tensor([tuple(tensor_shape)], dtype=torch.int64, device=device_mesh.device_type) - for mesh_dim, _ in enumerate(placements): # TODO for perf: use a process group for the entire DeviceMesh - all_shapes = mesh_all_gather( - self_shape, - torch.Size([device_mesh.size(mesh_dim), self_shape.size(1)]), - device_mesh, - scatter_dim=0, - mesh_dim=mesh_dim, - ) - if not torch.all(self_shape == all_shapes): - return False - return True - - -def gather_local_tensor_shape( - self_local_tensor: Union[torch.Tensor, torch.Size], - device_mesh: DeviceMesh, - placements: Sequence[Placement], - shard_only: bool = False, -) -> Optional[Dict[int, List[List[int]]]]: - """All gather local tensor shapes per mesh dimension. - When `shard_only is True`, all gather only sharded mesh dim. Otherwise, all gather all mesh dims.""" - if device_mesh.get_coordinate() is None: # if rank is not part of mesh - return None - - _shape: torch.Size = self_local_tensor if isinstance(self_local_tensor, torch.Size) else self_local_tensor.shape - self_local_shape = torch.tensor([list(_shape)], dtype=torch.int64, device="cpu", pin_memory=True) - self_local_shape = self_local_shape.to(device_mesh.device_type, non_blocking=True) - meshdim_localtensor_shape = {} - for mesh_dim, place in enumerate(placements): - if shard_only and not isinstance(place, (Shard, InterleavedShard)): - continue - stacked_local_shape = mesh_all_gather( - self_local_shape, - torch.Size([device_mesh.size(mesh_dim), self_local_shape.size(1)]), - device_mesh, - scatter_dim=0, - mesh_dim=mesh_dim, - ) - if type(stacked_local_shape) is funcol.AsyncCollectiveTensor: - # synchronously wait for any pending collectives to get the result tensor - stacked_local_shape = stacked_local_shape.trigger_wait() - if hasattr(stacked_local_shape, "elem"): - stacked_local_shape = stacked_local_shape.elem # type: ignore[attr-defined] - - meshdim_localtensor_shape[mesh_dim] = stacked_local_shape.detach().cpu().tolist() - return meshdim_localtensor_shape - - -def compute_global_tensor_info( - tensor: torch.Tensor, - mesh: DeviceMesh, - placements: Sequence[Placement], - meshdim_localtensor_shape: Optional[Dict[int, List[List[int]]]] = None, -) -> Tuple[List[int], List[int]]: - """ - Compute the global size and stride of a DTensor from the given local tensor. - - When `meshdim_localtensor_shape` is None (be default): - The local size is multiplited by `world_size` per Sharding dim. - The local stride is multiplited by `world_size` per Sharding dim, as long as the - dimension is outside sharding dim. - - For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8). - If the DTensor placements are [Shard(2)] and world_size is 2; - then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8). - - When `meshdim_localtensor_shape` is provided: - All local sizes are summed togather as global Sharding dim. - The local stride is scaled by global Sharding dim divided by local Sharding dim, - as long as the dimension is outside sharding dim. 
- - For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8) on rank0, - and a local tensor with size (4, 8, 1) and stride (8, 1, 8) on rank1. - If the DTensor placements are [Shard(2)] and world_size is 2; - then the global size is (4, 8, 3) and stride is (8 * 3, 1, 8). - - Args: - tensor (:class:`torch.Tensor`): - Local tensor which DTensor will be constructed from. - mesh (:class:`DeviceMesh`): - Object which describes the mesh topology of devices for the DTensor. - placements (Sequence[:class:`Placement`]]): - The attribute of the DTensor that describes its layout on the mesh topology. - meshdim_localtensor_shape (:class:`Dict[int, List[List[int]]]`): - Default None. - Otherwise, a given list for local tensor shapes per device mesh dim. - - Return: - tensor_shape: A List of int which specifies the global size of DTensor which build - on top of the local tensor. - tensor_stride: A List of int which specifies the global stride of DTensor. - """ - if meshdim_localtensor_shape is None: # assume even sharding (contiguous or non-contiguous) - tensor_shape = list(tensor.size()) - tensor_stride = list(tensor.stride()) - else: # support uneven sharding (contiguous only) - if not tensor.is_contiguous(): - warnings.warn( - "`from_local` take non-contiguous local tensor, which is not supported in uneven sharding. Treat as contiguous.", - UserWarning, - ) - # a meta empty tensor is created for obtaining correct local stride, - # especially when local tensor is non-contiguous or narrowed from padding or zero dimmed. - # TODO: rethink supporting non-contiguous local tensor which is narrowed from padding or zero dimmed. - tensor = torch.empty(tensor.shape, dtype=tensor.dtype, device="meta") - tensor_shape = list(tensor.size()) - tensor_stride = list(tensor.stride()) - - # record occured shard dim - shard_dim_occured: Set[int] = set() - - for idx, placement in enumerate(placements): - mesh_dim_size: int = mesh.size(idx) - - # TODO: rethink about this InterleavedShard. - if placement.is_shard() or placement.is_interleaved_shard(): - if placement.dim < 0: - placement.dim += len(tensor_shape) - shard_dim = placement.dim - - assert ( - shard_dim < tensor.ndim - ), f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." - - if shard_dim in shard_dim_occured: - warnings.warn( - "Sharding the same tensor dim is not supported for uneven sharding. 
Treat as even sharding.", - UserWarning, - ) - is_shard_same_dim = True - else: - shard_dim_occured.add(shard_dim) - is_shard_same_dim = False - - # recover global shape - local_dim_size = tensor_shape[shard_dim] - if meshdim_localtensor_shape is None or is_shard_same_dim: - # duplicate local shape at this sharded dim as global shape - tensor_shape[shard_dim] = local_dim_size * mesh_dim_size - else: - # concat local shapes at this sharded dim as global shape - global_dim_size = sum(shape[shard_dim] for shape in meshdim_localtensor_shape[idx]) - tensor_shape[shard_dim] = global_dim_size - - # recover tensor stride by modifying the stride that larger than - # the current stride on the shard_dim - is_contiguous_tensor = all(tensor_stride[i] >= tensor_stride[i + 1] for i in range(len(tensor_stride) - 1)) - for i in range(len(tensor_stride)): - if (i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim] and not is_contiguous_tensor) or ( - i < shard_dim and is_contiguous_tensor - ): - # rescale the stride by the shard size - if meshdim_localtensor_shape is None or is_shard_same_dim: - tensor_stride[i] = tensor_stride[i] * mesh_dim_size - else: - if local_dim_size == 0: - tensor_stride[i] *= max(global_dim_size, 1) - else: - assert tensor_stride[i] % local_dim_size == 0 - tensor_stride[i] = tensor_stride[i] // local_dim_size * global_dim_size - - elif not isinstance(placement, (Replicate, Partial)): - raise RuntimeError(f"placement type {type(placement)} not supported!") - - return tensor_shape, tensor_stride - - -def is_zero_out_local_shard(mesh: DeviceMesh, placements: Sequence[Placement]) -> bool: - """ - Compute whether we need to zero out the local shard of current rank, for Partial(). - - e.g. we want a bias tensor in [Partial(), Shard(0), Partial()] - [ [[b1, 0.] - [b2, 0.]] - - [[0., 0.] - [0., 0.]] ] - on a 3D-DeviceMesh: - [ [[0, 1] - [2, 3]] - - [[4, 5] - [6, 7]] ] - The computed result should be: - [ [[False, True] - [False, True]] - - [[True, True] - [True, True]] ] - """ - my_coordinate = mesh.get_coordinate() - - if my_coordinate is None: # if rank not in the mesh, nothing to zero out - return False - - for idx, placement in enumerate(placements): - if not placement.is_partial(): - continue - # we zero out all other ranks of the current mesh dim - # and leave only src-of-truth rank 0 have the data, to perform a "zero cost" shard. 
- if my_coordinate[idx] != 0: - return True - - return False - - -def _equal_meta_data(dt1, dt2, exact_device: bool) -> bool: - if type(dt1).__name__ != "DTensor" or type(dt2).__name__ != "DTensor": - return False - # check itself - if exact_device and (dt1.device.type != dt2.device.type): - return False - if dt1.shape != dt2.shape: - return False - if dt1.dtype != dt2.dtype: - return False - if dt1.layout != dt2.layout: # torch.strided (dense) or torch.sparse_* - return False - if dt1.stride() != dt2.stride(): - return False - if dt1.requires_grad != dt2.requires_grad: - return False - # check global spec - if exact_device: - if dt1._spec.mesh != dt2._spec.mesh: - return False - else: - if not dt1._spec.mesh.mesh.equal(dt2._spec.mesh.mesh): - return False - if dt1._spec.placements != dt2._spec.placements: - return False - if dt1._spec.tensor_meta != dt2._spec.tensor_meta: - return False - # check local tensor (ref: https://github.com/pytorch/pytorch/blob/63ae1051e17b1cf4fe55ac6b6f17c16672d44150/aten/src/ATen/native/cuda/Equal.cpp#L15) - t1, t2 = dt1._local_tensor, dt2._local_tensor - if exact_device and (t1.device.type != t2.device.type): - return False - if t1.shape != t2.shape: - return False - if t1.dtype != t2.dtype: - return False - if t1.layout != t2.layout: # torch.strided (dense) or torch.sparse_* - return False - if t1.is_contiguous() != t2.is_contiguous(): - return False - if t1.stride() != t2.stride(): - return False - if t1.storage_offset() != t2.storage_offset(): - return False - if t1.requires_grad != t2.requires_grad: - return False - return True - - -def equal(dt1, dt2, exact_device: bool = True) -> bool: - """ - check if two DTensors are 'exactly' equal - """ - if not _equal_meta_data(dt1, dt2, exact_device): - return False - if dt1.is_meta and dt2.is_meta: - return True - if exact_device: - return torch.equal(dt1._local_tensor, dt2._local_tensor) # check value only - else: - return torch.equal(dt1._local_tensor.cpu(), dt2._local_tensor.cpu()) # check value only - - -def allclose( - dt1, - dt2, - rtol: float = 1e-05, - atol: float = 1e-08, - equal_nan: bool = False, - exact_device: bool = True, -) -> bool: - """ - check if two DTensors are 'allclose' - """ - if not _equal_meta_data(dt1, dt2, exact_device): - return False - if dt1.is_meta and dt2.is_meta: - return True - if exact_device: - return torch.allclose( - dt1._local_tensor, dt2._local_tensor, rtol=rtol, atol=atol, equal_nan=equal_nan - ) # check value only - else: - return torch.allclose( - dt1._local_tensor.cpu(), dt2._local_tensor.cpu(), rtol=rtol, atol=atol, equal_nan=equal_nan - ) # check value only - - -def compute_local_offset(global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement]) -> Tuple[int, ...]: - """ - Compute the offsets of a local shard of the given "DTensor" on its current - global rank. This is mostly used by distributed checkpointing to know the - exact offsets of the local shard. 
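# Worked example (hypothetical values, assuming even Shard(0) sharding over a
# 1-D mesh of 4 ranks): for global_shape=(8, 4), each rank owns 8 // 4 = 2 rows,
# so the rank at coordinate [2] starts at row 4 and the offset is (4, 0).
def _even_shard_offset_sketch(global_dim_size, num_shards, shard_coord):
    return (global_dim_size // num_shards) * shard_coord

assert _even_shard_offset_sketch(8, 4, 2) == 4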
- """ - my_coordinate = mesh.get_coordinate() - - if my_coordinate is None: - # if rank not in the mesh, return empty offset - return () - else: - local_offsets = [0] * len(global_shape) - local_shape = list(global_shape) - - for idx, placement in enumerate(placements): - mesh_dim_size = mesh.size(idx) - if isinstance(placement, Shard): - shard_dim = placement.dim - assert shard_dim < len( - local_shape - ), f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" - shard_size, shard_offset = placement._local_shard_size_on_dim( - local_shape[shard_dim], - mesh_dim_size, - my_coordinate[idx], - return_offset=True, - ) - local_shape[shard_dim] = shard_size - local_offsets[shard_dim] = shard_offset - return tuple(local_offsets) - - -def compute_global_stride( - local_tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement] -) -> Tuple[int, ...]: - """ """ - my_coordinate = mesh.get_coordinate() - if my_coordinate is None: - return () - if not local_tensor.is_contiguous(): - raise RuntimeError("local tensor should be contiguous") - global_stride = copy.deepcopy(list(local_tensor.stride())) - for i, p in enumerate(placements): - if not p.is_shard(): - continue - shard_dim = p.dim - shard_size = mesh.size(i) - for j in range(shard_dim): - global_stride[j] *= shard_size - return tuple(global_stride) diff --git a/internlm/checkpoint/vescale/api.py b/internlm/checkpoint/vescale/api.py deleted file mode 100644 index 7d59db1db..000000000 --- a/internlm/checkpoint/vescale/api.py +++ /dev/null @@ -1,53 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ -# The "checkpoint" folder is ONLY USED for "open source" version veScale - -from .vescale_checkpointer import VeScaleCheckpointer -from .meta_type import CheckpointState - - -def save(path: str, checkpoint_state: CheckpointState, async_checkpoint=False): - """ - Save a checkpoint to a given path - Args: - path: Defines the storage path for checkpoint. - checkpoint_state: A dictionary contains key-value pairs for model and optimizer. - - Model: Identified by 'model' key, value should be a model instance. - - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. - async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, - i.e. after dumping tensors from GPU memory to Host memory, - the training program can continue training immediately. - Then vescale.checkpoint will serialize tensors and dumping to the persistent storage asynchronously. - Example: - >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } - >>> vescale.checkpoint.save("/user/vescale/gpt/", checkpoint_state) - """ - VeScaleCheckpointer.save(path, checkpoint_state, async_checkpoint=async_checkpoint) - - -def load(path: str, checkpoint_state: CheckpointState, broadcast_checkpoint=False): - """ - Load a checkpoint from a given path - Args: - path: Defines the storage path for checkpoint. - checkpoint_state: A dictionary contains key-value pairs for model and optimizer. 
- - Model: Identified by 'model' key, value should be a model instance. - - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. - broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group - then broadcast tensors to other data parallel process group using GPUs - to reduce the file system access - For example, when data parellel size = 2, - processes with data parallel rank = 0 load model from file system - then broadcast it to processes with data parallel rank = 1 - Example: - >>> checkpoint_state = { "model": distributd_model, "optimizer": distributed_optimizer } - >>> vescale.checkpoint.load("/user/vescale/gpt/", checkpoint_state) - """ - VeScaleCheckpointer.load(path, checkpoint_state, broadcast_checkpoint=broadcast_checkpoint) diff --git a/internlm/checkpoint/vescale/base_checkpointer.py b/internlm/checkpoint/vescale/base_checkpointer.py deleted file mode 100644 index 293f9c003..000000000 --- a/internlm/checkpoint/vescale/base_checkpointer.py +++ /dev/null @@ -1,77 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ -from .meta_type import CheckpointState -from typing import Dict, List -from concurrent.futures import Future, ProcessPoolExecutor -from torch.distributed.checkpoint.storage import WriteResult -from .meta_type import MODEL_STR, OPTIMIZER_STR - -SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} - - -class BaseCheckpointer: - """ - The Checkpointer class offers APIs that enable users to save and load state dictionarie. - It is designed for extension across various training frameworks. - """ - - # Async IO related members. - state_io_workers: Dict[str, ProcessPoolExecutor] = {} - state_write_futures: Dict[str, Future[List[WriteResult]]] = {} - - @classmethod - def save(cls, path: str, checkpoint_state: CheckpointState): - """ - A Method for saving checkpoint - Args: - path: Defines the storage path for checkpoint. - checkpoint_state: A dictionary contains key-value pairs for model and optimizer. - - Model: Identified by 'model' key, value should be a model instance. - - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. - - """ - raise NotImplementedError() - - @classmethod - def load(cls, path: str, checkpoint_state: CheckpointState): - """ - A Method for loading checkpoint - Args: - path: Defines the storage path for checkpoint. - checkpoint_state: A dictionary contains key-value pairs for model and optimizer. - - Model: Identified by 'model' key, value should be a model instance. - - Optimizer: Identified by 'optimizer' key, value should be an optimizer instance. - - """ - raise NotImplementedError() - - @classmethod - def _cleanup_futures(cls): - """ - Wait for all write futures to finish before exit, then do the cleanup works. 
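# Minimal sketch of the wait-then-shutdown pattern described above (hypothetical
# locals `futures` / `executor`; the real state lives in cls.state_write_futures
# and cls.state_io_workers):
def _drain_and_shutdown_sketch(futures, executor):
    for fut in futures:
        fut.result()           # block until each async write has finished
    if executor is not None:
        executor.shutdown()    # then release the ProcessPoolExecutor workers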
- - WARNING: this method cannot be called by the users. - """ - for key in SUPPORTED_TYPES: - if key in cls.state_write_futures: - futures = cls.state_write_futures[key] - for fut in futures: - fut.result() - cls.state_write_futures[key] = [] - if cls.state_io_workers[key] is not None: - cls.state_io_workers[key].shutdown() - cls.state_io_workers[key] = None diff --git a/internlm/checkpoint/vescale/bfile.py b/internlm/checkpoint/vescale/bfile.py deleted file mode 100644 index f3ba4e331..000000000 --- a/internlm/checkpoint/vescale/bfile.py +++ /dev/null @@ -1,129 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ -# Existing APIs all follow the rule here: -# https://www.tensorflow.org/api_docs/python/tf/io/gfile - - -import os -import enum -import contextlib -import uuid -from internlm.utils.logger import get_logger -import shutil -from . import mem_server_lib - -logger = get_logger(__file__) -BFILE_DEFAULT_TIMEOUT = None - - -class FileType(enum.Enum): - LOCAL = 0 - LOCAL_MEM = 1 - - -def local_list_folder(folder_path: str, recursive: bool = False): - file_paths = [] - if recursive: - for root, _, files in os.walk(folder_path): - for file_name in files: - file_path = os.path.join(root, file_name) - file_paths.append(file_path) - else: - if os.path.isdir(folder_path): - file_paths.extend([os.path.join(folder_path, d) for d in os.listdir(folder_path)]) - elif os.path.isfile(folder_path): - file_paths.append(folder_path) - else: - logger.warning(f"Path {folder_path} is invalid") - - return file_paths - - -def get_schema(path: str): - if path.startswith(mem_server_lib.SCHEMA): - return FileType.LOCAL_MEM - return FileType.LOCAL - - -def rename(src, dst, overwrite=False): - t = get_schema(src) - if t == FileType.LOCAL_MEM: - return mem_server_lib.rename(src, dst, overwrite) - return os.rename(src, dst) - - -def listdir(path): - t = get_schema(path) - if t == FileType.LOCAL_MEM: - return mem_server_lib.listdir(path) - absolute_files = local_list_folder(path) - return [f[f.rfind("/") + 1 :] for f in absolute_files] - - -def remove(path): - t = get_schema(path) - if t == FileType.LOCAL_MEM: - return mem_server_lib.remove(path) - return shutil.rmtree(path, ignore_errors=True) - - -def exists(path): - t = get_schema(path) - if t == FileType.LOCAL_MEM: - return mem_server_lib.exists(path) - return os.path.exists(path) - - -def makedirs(path): - t = get_schema(path) - if t == FileType.LOCAL_MEM: - # Local mem doesn't have empty folder - return - return os.makedirs(path, exist_ok=True) - - -@contextlib.contextmanager -def BFile(name, mode="r"): - t = get_schema(name) - if t == FileType.LOCAL_MEM: - with mem_server_lib.open(name, mode) as f: - yield f - else: - with open(name, mode) as f: - yield f - - -# ---- Below is some useful utilities ----- - - -def 
atomic_write(path: str, content: bytes, **kwargs): - tmp_path = path + "_tmp_" + str(uuid.uuid4()) - with BFile(tmp_path, "wb", **kwargs) as f: - f.write(content) - rename(tmp_path, path, overwrite=True) - - -def safe_atomic_write(path: str, content: bytes, **kwargs): - makedirs(os.path.dirname(path)) - atomic_write(path, content, **kwargs) - - -def is_local_path(path: str): - t = get_schema(path) - if t == FileType.LOCAL_MEM or t == FileType.LOCAL: - return True - return False diff --git a/internlm/checkpoint/vescale/common.py b/internlm/checkpoint/vescale/common.py deleted file mode 100644 index d69f29a95..000000000 --- a/internlm/checkpoint/vescale/common.py +++ /dev/null @@ -1,132 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ -import dataclasses -from typing import Any, Dict, List, Tuple, Hashable, Optional -from collections import OrderedDict -from torch.distributed.checkpoint.planner import SavePlan -from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata -import collections -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - - -@dataclasses.dataclass -class P2PTensorsInfo: - """ - Record data about tesnors which are across dp ranks - recv_tensors: A dictionary - Key: fqn - Value: a dictionary - key is the process rank, - value is a tuple with (tensor, 1d_range) - send_p2p_reqs: a list of p2p send requests to wait - recv_p2p_reqs: a list p2p receive requests to wait - """ - - recv_tensors: Dict[str, Any] - send_p2p_reqs: List[Any] - recv_p2p_reqs: List[Any] - - -def sort_rank_ranges(process_list: List[Tuple]) -> List[Tuple]: - """ - Decide which rank is receiver and writer - Let rank with most parameters receives and writes tensors - for the best communication cost - If two ranks has the same data size, choose the smaller rank - Args: - A process list with tuples, each tuple is (rank, data_size) - Returns: - A sorted list, data size are sorted in descending order, - if two ranks has the same data size, ranks are in the asceonding order - """ - sorted_process_list = sorted(process_list, key=lambda x: (-x[1], x[0])) - return sorted_process_list - - -_MAX_CACHE_SIZE = 2 # model ckpt + optm ckpt - - -class PlanLRUCache: - def __init__(self) -> None: - self._cache: OrderedDict[Hashable, Tuple[SavePlan, Metadata]] = OrderedDict() - self._capacity = _MAX_CACHE_SIZE - - def get(self, key: Hashable) -> Optional[Tuple[SavePlan, Metadata]]: - if key in self._cache: - return self._cache[key] - else: - return None - - def put(self, key: Hashable, plan_value: SavePlan, metadata_value: Metadata) -> None: - if key in self._cache: - self._cache.move_to_end(key, last=False) - else: - self._cache[key] = (plan_value, metadata_value) - if len(self._cache) > self._capacity: - self._cache.popitem() - - def 
clear(self) -> None: - self._cache.clear() - self._capacity = _MAX_CACHE_SIZE - - def __repr__(self) -> str: - return f"PlanLURCache(capacity: {self._capacity}, keys: {tuple(self._cache.keys())})" - - -def custom_dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]: - """ - A function to remove duplicate tensors to write - when creating global writing plan for saving checkpoint - During the deduplication, - we balance the workloads for duplicated tensors - """ - key_to_plan: Dict[MetadataIndex, List[int]] = {} - for plan_idx, plan in enumerate(all_plans): - for write_item in plan.items: - key_to_plan.setdefault(write_item.index, []).append(plan_idx) - - replicated_items = {k: v for k, v in key_to_plan.items() if len(v) > 1} - # Remove duplicates by always keeping the first entry (Not balance). - # Compute the per-rank remove set. - plan_to_keys: Dict[int, List[MetadataIndex]] = {} - # Record the number of non-duplicated tensors assigned to each rank - assigned_work_load = collections.defaultdict(int) - for plan_idx, plan in enumerate(all_plans): - for write_item in plan.items: - if write_item.index not in replicated_items: - assigned_work_load[plan_idx] += 1 - - for key, plans in replicated_items.items(): - # For duplicated tensors, select the rank assigned with minimum number tensors so far - writer_id = min(plans, key=lambda k: assigned_work_load[k]) - assigned_work_load[writer_id] += 1 - for plan_idx in plans: - # If the rank is not writer rank, remove the key in the rank's plan - if plan_idx != writer_id: - plan_to_keys.setdefault(plan_idx, []).append(key) - logger.info("Duplicate keys to remove: %s", plan_to_keys) - - for plan_idx, keys in plan_to_keys.items(): - # Key Set contains keys to remove - key_set = set(keys) - # rewrite items and remove elements - new_items = [write_item for write_item in all_plans[plan_idx].items if write_item.index not in key_set] - all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) - - return all_plans diff --git a/internlm/checkpoint/vescale/device_mesh.py b/internlm/checkpoint/vescale/device_mesh.py deleted file mode 100644 index e578471d2..000000000 --- a/internlm/checkpoint/vescale/device_mesh.py +++ /dev/null @@ -1,647 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -import logging -import math -import warnings -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union - -import torch -import torch.distributed._functional_collectives as funcol -from torch.distributed.distributed_c10d import ( - ProcessGroup, - _find_pg_by_ranks_and_tag, - _get_default_group, - _get_group_size, - _get_group_tag, - get_process_group_ranks, - get_rank, - get_world_size, - init_process_group, - is_initialized, - new_group, -) - -from internlm.utils.logger import get_logger -logger = get_logger(__file__) - -# only import numpy typing when type checking -if TYPE_CHECKING: - try: - from numpy.typing import ArrayLike - except ImportError: - logger.warning("DeviceMesh requires numpy >= 1.21 to be installed for type checking") - - -class _MeshEnv: - def __init__(self) -> None: - self.mesh_stack: List[DeviceMesh] = [] - self.child_to_parent_mapping: Dict[DeviceMesh, DeviceMesh] = {} - - def get_current_mesh(self) -> "DeviceMesh": - if len(self.mesh_stack) == 0: - raise RuntimeError("No device mesh is currently active!") - return self.mesh_stack[-1] - - def create_child_mesh(self, device_mesh: "DeviceMesh", mesh_dim: int, mesh_dim_name: str) -> "DeviceMesh": - # swap the current dim to the last dim then reshape to flatten out other - # dims, so we can just extract the list of ranks which contains cur_rank. - cur_rank = device_mesh.get_rank() - pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(-1, device_mesh.mesh.size(mesh_dim)) - - for mesh_1d in pg_ranks_by_dim: - sub_mesh = DeviceMesh( - device_mesh.device_type, - mesh_1d, - mesh_dim_names=(mesh_dim_name,), - _init_process_groups=False, - ) - if cur_rank in mesh_1d: - res_sub_mesh = sub_mesh - - res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] - # Assign the current DeviceMesh as the parent of the child DeviceMesh. - self.child_to_parent_mapping[res_sub_mesh] = device_mesh - return res_sub_mesh - - def create_submesh_along_multi_dims( - self, device_mesh: "DeviceMesh", mesh_dims: List[int], cur_rank: int = None - ) -> "DeviceMesh": - # swap the current dim to the last dim then reshape to flatten out other - # dims, so we can just extract the list of ranks which contains cur_rank. - # check dims - dim_size = [-1] - for dim in mesh_dims: - if dim >= device_mesh.ndim: - raise RuntimeError("Mesh dim in sub groups out of range!") - dim_size.append(device_mesh.mesh.size(dim)) - mesh_tensor = device_mesh.mesh - for dim in mesh_dims: - mesh_tensor = mesh_tensor.swapdims(-1, dim) - if cur_rank is None: - cur_rank = device_mesh.get_rank() - pg_ranks_by_dims = mesh_tensor.reshape(dim_size) - for mesh_nd in pg_ranks_by_dims: - sub_mesh = DeviceMesh( - device_mesh.device_type, - mesh_nd, - _init_process_groups=False, - ) - if cur_rank in mesh_nd: - res_sub_mesh = sub_mesh - res_sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[dim] for dim in mesh_dims] - self.child_to_parent_mapping[res_sub_mesh] = device_mesh - return res_sub_mesh - - def create_submesh_group(self, device_mesh: "DeviceMesh", mesh_dim: int) -> "DeviceMesh": - # swap the current dim to the last dim then reshape to flatten out other - # dims, so we can just extract the list of ranks which contains cur_rank. 
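# Illustration of the swapdims + reshape trick used just below, on a hypothetical
# 2x4 mesh: each resulting row is the group of ranks along the chosen mesh dim,
# holding all other coordinates fixed.
import torch
_demo_mesh = torch.arange(8).reshape(2, 4)     # [[0, 1, 2, 3], [4, 5, 6, 7]]
_demo_groups = _demo_mesh.swapdims(-1, 0).reshape(-1, _demo_mesh.size(0))
# _demo_groups -> [[0, 4], [1, 5], [2, 6], [3, 7]], i.e. the dim-0 process groups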
- # check dims - pg_ranks_by_dim = device_mesh.mesh.swapdims(-1, mesh_dim).reshape(-1, device_mesh.mesh.size(mesh_dim)) - res = [] - for mesh_1d in pg_ranks_by_dim: - sub_mesh = DeviceMesh( - device_mesh.device_type, - mesh_1d, - _init_process_groups=False, - ) - sub_mesh._dim_group_infos = [device_mesh._dim_group_infos[mesh_dim]] - # Assign the current DeviceMesh as the parent of the child DeviceMesh. - self.child_to_parent_mapping[sub_mesh] = device_mesh - res.append(sub_mesh) - return res - - def get_parent_mesh(self, device_mesh: "DeviceMesh") -> Optional["DeviceMesh"]: - return self.child_to_parent_mapping.get(device_mesh, None) - - def get_parent_mesh_dim(self, device_mesh: "DeviceMesh") -> Optional[int]: - """ - Return the index of the mesh dim in the parent mesh. - The device_mesh passed in needs to be sliced out from a parent mesh. - """ - parent_mesh = self.get_parent_mesh(device_mesh) - child_mesh_dim_names = device_mesh.mesh_dim_names - if parent_mesh and child_mesh_dim_names: - assert len(child_mesh_dim_names) == 1, "The child mesh can only be a 1D mesh." - child_mesh_dim_name = child_mesh_dim_names[0] - if parent_mesh.mesh_dim_names: - return parent_mesh.mesh_dim_names.index(child_mesh_dim_name) - return None - - @staticmethod - def num_devices_per_host(device_type: str) -> int: - return _get_device_handle(device_type).device_count() - - @staticmethod - def num_hosts(device_type: str) -> int: - # ProcessGroup can't tell us this info so we have to infer it, assume - # homogeneous hardware for now - return get_world_size() // _MeshEnv.num_devices_per_host(device_type) - - -mesh_resources: _MeshEnv = _MeshEnv() - - -def _get_device_handle(device_type: str = "cuda"): - """ - Get the module corresponding to the device_type which is cuda or cuda-like device. - For example, when the device_type is cuda, the module `torch.cuda` is returned. - Return None when there is no corresponding module for device_type, otherwise - return the corresponding module. - """ - return getattr(torch, device_type, None) - - -class DeviceMesh: - """ - DeviceMesh represents a mesh of devices (given by `device_type`), where layout - of devices could be represented as a n-d dimension array `mesh`, and each value - of the `mesh` is the global rank in the default process group. - - DeviceMesh could be used to describe the layout of devices across the cluster - via `mesh_dim_names`, and serves as a proxy for communication among the device lists - within the cluster. - - By default (`pg` is `None`), we use the default ProcessGroup in this DeviceMesh class - to implement proper communications. Note that we also add collective wrappers in this - class. This is used to decouple detailed communication backend with the underlying - DTensor implementation. - - By giving an existing ProcessGroup `pg`, we construct a device mesh from this `pg`, - instead of the default ProcessGroup. - - Here are the expected behaviors: - | `mesh` | `pg` | result | catch - --------------------------------------------------------------------------------------------- - | None | None | raise error! | - | EXIST | None | use `mesh` + default ProcessGroup | - | None | EXIST | use `pg`'s ranks + `pg` ProcessGroup | 1D mesh only - | EXIST | EXIST | use `pg`'s ranks + `pg` ProcessGroup | `mesh` must equal to `pg`'s ranks - - Args: - device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like, meta. 
- mesh (ndarray): could be a multi-dimension array or an integer tensor that - describes the layout of devices, the ids are global ids of the default process group. - mesh_dim_names (Optional[Tuple[str]]): A tuple of mesh dim names to be assigned to each - dimension of the multi-dimensional array that describes the layout of devices. Its - length must match the length of `mesh_shape`. Each string in mesh_dim_names must be unique. - pg (Optional[ProcessGroup]): the given ProcessGroup. See above for expected behaviors. - - Returns: - A :class:`DeviceMesh` object - - Example (2 host with 4 GPUs each): - ``` - # The following program runs on each process/rank in SPMD manner. - # initialize device mesh as (2, 4) to represent the topology - # of cross-host(dim 0), and within-host (dim 1) - mesh = DeviceMesh(device_type="cuda", - mesh=[ - [0, 1, 2, 3], - [4, 5, 6, 7] - ]) - ``` - A reduction over the first dimension of mesh will reduce across - columns (0, 4), .. and (3, 7), a reduction over the second dimension - of mesh reduces across rows (0, 1, 2, 3) and (4, 5, 6, 7) - - Note: - DeviceMesh can be used as a context manager. - """ - - device_type: str - mesh: Optional[Union[torch.Tensor, "ArrayLike"]] - mesh_dim_names: Optional[Tuple[str, ...]] - - def __init__( - self, - device_type: str, - mesh: Optional[Union[torch.Tensor, "ArrayLike"]] = None, - *, - mesh_dim_names: Optional[Tuple[str, ...]] = None, - pg: Optional[ProcessGroup] = None, - _validate_mesh: bool = True, - _init_process_groups: bool = True, - ) -> None: - # for performance, update debug env once here - # check args - if mesh is None and pg is None: - raise ValueError("Either `mesh` or `pg` must be provided!") - if mesh is not None and pg is not None: - pg_mesh_tensor = torch.tensor(get_process_group_ranks(pg), dtype=torch.int, device="cpu") - mesh_tensor = ( - mesh.detach().cpu() - if isinstance(mesh, torch.Tensor) - else torch.tensor(mesh, dtype=torch.int, device="cpu") - ) - if not torch.equal(mesh_tensor, pg_mesh_tensor): - raise ValueError(f"mesh({mesh_tensor}) and pg({pg_mesh_tensor}) must have the same content!") - if pg is not None: - self.mesh = torch.tensor(get_process_group_ranks(pg), dtype=torch.int, device="cpu") - warnings.warn("Construction from given ProcessGroup is only supported for 1D mesh currently.") - # TO FIX: use `mesh` to reshape `pg_mesh_tensor` for nD mesh tensor - if mesh is not None: - self.mesh = ( - mesh.detach().cpu() - if isinstance(mesh, torch.Tensor) - else torch.tensor(mesh, dtype=torch.int, device="cpu") - ) - - self.device_type = device_type - self.mesh_dim_names = mesh_dim_names - - # private field to pre-generate DeviceMesh's hash - self._flatten_mesh_list = tuple(self.mesh.flatten().tolist()) - self._hash = hash((self._flatten_mesh_list, self.mesh.shape)) - - # step 1: try to create default world pg. - if pg is None: - pg = self._get_or_create_default_group() - else: - # TODO: this logic only applies when device_type is cuda - pg_world_size = get_world_size(group=pg) - device_handle = _get_device_handle(self.device_type) - num_devices_per_host = device_handle.device_count() - if pg_world_size > num_devices_per_host and pg_world_size % num_devices_per_host != 0: - raise RuntimeError( - f"DeviceMesh only support homogeneous hardware, but found " - f"{pg_world_size} ranks and {num_devices_per_host} {self.device_type} devices!" 
- ) - if self.device_type == "cuda": - - def _get_current_device(): - try: - if torch.cuda.is_available(): - return torch.cuda.current_device() - else: - return None - except AssertionError as e: - return None - - device_handle = _get_device_handle(self.device_type) - num_devices_per_host = device_handle.device_count() - local_rank = get_rank() % num_devices_per_host - if local_rank != _get_current_device(): - warnings.warn("Remember to set cuda device id to local rank!!!") - device_handle = _get_device_handle(self.device_type) - device_handle.set_device(local_rank) - - # step 2: validate the mesh before following usage. - if _validate_mesh: - self._validate_mesh(pg) - - # step 3: get coordinate of current global rank on the mesh. - # The world pg is used for device mesh identity (rank) on each - # process (we need to know if the current global rank is in the mesh or not) - rank_coords = (self.mesh == get_rank()).nonzero() - assert rank_coords.size(0) in (0, 1) - self._coordinate_on_dim: Optional[List[int]] = rank_coords[0].tolist() if rank_coords.size(0) > 0 else None - - # step 4: init multi subprocess group for the mesh object. - if _init_process_groups: - self._init_process_groups(pg) - - def _get_or_create_default_group(self): - default_initialized = is_initialized() - if not default_initialized: - init_process_group() - - world_size = get_world_size() - if self.mesh.numel() > world_size: - raise RuntimeError( - f"Mesh should not be bigger than default world size, but found {self.mesh.numel()} ranks!" - ) - - device_handle = _get_device_handle(self.device_type) - # TODO: if user want to pass pg_options, offer a way to do it - if not default_initialized and device_handle: - # automatically set the current cuda/cuda-like device base on num of gpu devices available in each host - # NOTE: This device selection would only work for homogeneous hardware. - num_devices_per_host = device_handle.device_count() - if world_size > num_devices_per_host and world_size % num_devices_per_host != 0: - raise RuntimeError( - f"DeviceMesh only support homogeneous hardware, but found " - f"{world_size} ranks and {num_devices_per_host} {self.device_type} devices!" - ) - device_handle.set_device(get_rank() % num_devices_per_host) - - return _get_default_group() - - def _validate_mesh(self, pg: ProcessGroup): - # validate rank uniqueness in mesh tensor - unique_mesh_values = self.mesh.unique(sorted=True) - if unique_mesh_values.numel() != self.mesh.numel(): - raise RuntimeError(f"DeviceMesh cannot have duplicate values, but found {self.mesh.tolist()}") - # validate size - if self.mesh.numel() > _get_group_size(pg): - raise RuntimeError( - f"DeviceMesh should not be bigger than world (group) size, but found {self.mesh.numel()} and {_get_group_size(pg)}" - ) - # validate that all calling ranks pass in the same `mesh` argument. - self_mesh = self.mesh.to(self.device_type).contiguous() - mesh_tensor = funcol.all_gather_tensor(self_mesh, gather_dim=0, group=pg) - mesh_tensor_chunked = torch.chunk(mesh_tensor, _get_group_size(pg)) - # aten.equal not supported for meta device - if self.device_type == "meta": - return - for other_rank, other_mesh in enumerate(mesh_tensor_chunked): - if not torch.equal(self_mesh, other_mesh): - raise RuntimeError( - f"DeviceMesh initialization does not allow different mesh argument:" - f"rank {get_rank()} has mesh {self_mesh} while rank {get_process_group_ranks(pg)[other_rank]}" - f"has mesh {other_mesh}!" 
- ) - - def _init_process_groups(self, pg: ProcessGroup): - # group tag/ranks associated with each mesh dimension, each mesh dimension should - # have one sub-group per rank - dim_group_infos: List[Tuple[str, List[int]]] = [] - - if self.mesh.ndim == 1 and self.mesh.numel() == _get_group_size(pg): - # if the mesh is the same as the given group, we just append the given - # pg to the first dim groups. - dim_group_infos.append((_get_group_tag(pg), get_process_group_ranks(pg))) - else: - # create sub pgs base on the mesh argument specified - for dim in range(self.mesh.ndim): - # swap the current dim to the last dim - # then reshape to flatten out other dims - pg_ranks_by_dim = self.mesh.swapdims(-1, dim).reshape(-1, self.mesh.size(dim)) - # multi-dim mesh, create subgroups by looping over the pg_ranks - # for each dim and append the groups - for dim_mesh in pg_ranks_by_dim: - subgroup_ranks = dim_mesh.tolist() - # call new_group regardless of the current rank in the - # pg or not, it's required that all ranks participate - # in subgroup construction - dim_group = new_group(ranks=subgroup_ranks) - # only add to dim_groups if the current rank in the subgroup - if self.get_rank() in subgroup_ranks: - if len(dim_group_infos) > dim: - raise RuntimeError( - f"Each device mesh dimension should get only one process group, but got {self.get_rank} " - f"in {subgroup_ranks}!" - ) - dim_group_infos.append((_get_group_tag(dim_group), subgroup_ranks)) - self._dim_group_infos = dim_group_infos - - def __enter__(self) -> "DeviceMesh": - # set this mesh as the current mesh in mesh env - mesh_resources.mesh_stack.append(self) - return self - - # pyre-fixme[2]: Parameter must be annotated. - def __exit__(self, exc_type, exc_value, exc_traceback) -> None: - # pop this mesh from mesh env - mesh_resources.mesh_stack.pop() - - def __repr__(self) -> str: - return f"DeviceMesh:({self.mesh.tolist()})" - - def __hash__(self): - # ideally, we should use object id as hash, because different device mesh objects - # give different subprocess group, so different device meshes. - # in practice of sharding propagation, - # we only care about different mesh tensor (value, shape). - return self._hash - - def __eq__(self, other: object) -> bool: - if not isinstance(other, DeviceMesh): - return False - if id(self.mesh) == id(other.mesh): # short-cut eq - return True - if self.device_type != other.device_type: - return False - return self.mesh.shape == other.mesh.shape and self._flatten_mesh_list == other._flatten_mesh_list - - def __getitem__(self, mesh_dim_name: str) -> "DeviceMesh": - """ - Slice the current DeviceMesh based on the mesh_dim_name given to create a child - DeviceMesh. - - Args: - mesh_dim_name (str): the name of the mesh dimension of the parent DeviceMesh - to create a child DeviceMesh for. - Returns: - A :class:`DeviceMesh` object - - Example (2 host with 4 GPUs each): - ``` - # Below is a DeviceMesh with mesh_shape of (2, 4) and mesh_dim_name of ("dp", "tp") - mesh = DeviceMesh(device_type="cuda", - mesh=[ - [0, 1, 2, 3], - [4, 5, 6, 7] - ], - mesh_dim_names=["dp", "tp"]) - ) - ``` - Calling mesh["tp"] on rank 0, 1, 2, 3 would return a 1D child DeviceMesh:([0, 1, 2, 3]). - Calling mesh["tp"] on rank 4, 5, 6, 7 would return a 1D child DeviceMesh:([4, 5, 6, 7]). - Calling mesh["dp"] on rank 0, 4 would return a 1D child DeviceMesh:([0, 4]). - Calling mesh["dp"] on rank 1, 5 would return a 1D child DeviceMesh:([1, 5]). - Calling mesh["dp"] on rank 2, 6 would return a 1D child DeviceMesh:([2, 6]). 
- Calling mesh["dp"] on rank 3, 7 would return a 1D child DeviceMesh:([3, 7]). - """ - if self.mesh.ndim <= 1: - raise RuntimeError(f"Cannot slice a DeviceMesh with {self.mesh.ndim} dimension.") - if self.mesh_dim_names is None: - raise KeyError( - "No `mesh_dim_names` found.", - "To slice the device mesh, please call `init_device_mesh` with `mesh_dim_names`.", - ) - if mesh_dim_name not in self.mesh_dim_names: - raise KeyError( - f"Mesh dimension '{mesh_dim_name}' does not exist.", - f"Available mesh dimensions are: {self.mesh_dim_names}", - ) - mesh_dim = self.mesh_dim_names.index(mesh_dim_name) - submesh = mesh_resources.create_child_mesh(self, mesh_dim, mesh_dim_name) - - return submesh - - def get_dim_groups(self, mesh_dim: Optional[int] = None) -> Union[ProcessGroup, List[ProcessGroup]]: - if not hasattr(self, "_dim_group_infos"): - raise RuntimeError("DeviceMesh process groups not initialized!") - if mesh_dim is not None: - return _find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim]) - else: - dim_groups = [] - for mesh_dim in range(self.mesh.ndim): - dim_groups.append(_find_pg_by_ranks_and_tag(*self._dim_group_infos[mesh_dim])) - return dim_groups - - def size(self, dim: Optional[int] = None) -> int: - return self.mesh.numel() if dim is None else self.mesh.size(dim) - - @property - def ndim(self) -> int: - return self.mesh.ndim - - @property - def ndevice(self) -> int: - return torch.numel(self.mesh) - - @property - def shape(self) -> Tuple[int, ...]: - return tuple(self.mesh.shape) - - def get_rank(self) -> int: - return get_rank() - - def get_local_rank(self, mesh_dim: Optional[int] = None) -> int: - """ - Returns the local rank of the given mesh_dim of the DeviceMesh. - - Args: - mesh_dim (int, optional): it is the index of the mesh dimension. Default is None. - - Returns: - An integer denotes the local rank. - - The following program runs on each process/rank in an SPMD manner. In this example, we have 2 - hosts with 4 GPUs each. - Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 0, 1, 2, 3 would return 0. - Calling mesh_2d.get_local_rank(mesh_dim=0) on rank 4, 5, 6, 7 would return 1. - Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 0, 4 would return 0. - Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 1, 5 would return 1. - Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 2, 6 would return 2. - Calling mesh_2d.get_local_rank(mesh_dim=1) on rank 3, 7 would return 3. - """ - if self.ndim > 1 and mesh_dim is None: - raise RuntimeError( - f"Found the DeviceMesh have {self.mesh.ndim} dimensions", - "Optional kwarg `mesh_dim` needs to be specified when device_mesh.ndim > 1.", - ) - elif mesh_dim is None: - mesh_dim = 0 - - mesh_dim_group = self.get_dim_groups(mesh_dim) - assert isinstance(mesh_dim_group, ProcessGroup), "We expect ProcessGroup before calling `get_rank`!" - - return get_rank(mesh_dim_group) - - def get_coordinate(self) -> Optional[List[int]]: - """ - Return the relative indices of this rank relative to all - dimensions of the mesh. If this rank is not part of the mesh, return None. - """ - return self._coordinate_on_dim if self._coordinate_on_dim else None - - def enforce_cpu_mesh_tensor(self) -> None: - """ - move `mesh` tensor to cpu for deterministic device; - necessary for comparison and checkpoint loading. 
- """ - with torch.no_grad(): - self.mesh = self.mesh.cpu() - - def get_submesh(self, mesh_dims: Union[List[int], List[str]]) -> "DeviceMesh": - dims = [] - for dim in mesh_dims: - if isinstance(dim, int): - dims.append(dim) - elif isinstance(dim, str): - assert dim in self.mesh_dim_names, f"Mesh dimension '{dim}' does not exist." - dims.append(self.mesh_dim_names.index(dim)) - return mesh_resources.create_submesh_along_multi_dims(self, dims) - - def get_all_submesh(self, dim: int or str) -> List["DeviceMesh"]: - if isinstance(dim, str): - assert dim in self.mesh_dim_names, f"Mesh dimension '{dim}' does not exist." - mesh_dim = self.mesh_dim_names.index(dim) - else: - mesh_dim = dim - return mesh_resources.create_submesh_group(self, mesh_dim) - - def get_mapping_rank(self, other: "DeviceMesh"): - """ - for cross mesh resharding - we assume that the mesh is 1,2,4,8 - the size will have gcd value - """ - mesh_list = self.mesh.view(-1).tolist() - index = mesh_list.index(self.get_rank()) - other_mesh_list = other.mesh.view(-1).tolist() - gcd_value = math.gcd(len(mesh_list), len(other_mesh_list)) - if gcd_value == 1 and len(mesh_list) != 1 and len(other_mesh_list) != 1: - raise RuntimeError(f"mesh resharding the wrong shape of device mesh {mesh_list} vs {other_mesh_list}") - - a = len(mesh_list) - b = len(other_mesh_list) - factor = max(a, b) // min(a, b) - - if a > b: # group down - data = {} - for i in range((index // factor) * factor, factor): - data.update({mesh_list[index]: other_mesh_list[index // factor]}) - return data - elif a < b: # group up - return [other_mesh_list[i] for i in range(index * factor, (index + 1) * factor)] - else: - return other_mesh_list[index] - - -def init_device_mesh( - device_type: str, - mesh_shape: Tuple[int, ...], - *, - mesh_dim_names: Optional[Tuple[str, ...]] = None, -) -> DeviceMesh: - """ - Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. - This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape) - and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is - labeled as mesh_dim_names[i]. - - - Args: - device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like. - mesh_shape: Tuple[int]: A tuple describes the dimension of the multi-dimesnion array - that describes the layout of devices. - Kwargs: - mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension - of the multi-dimensional array that describes the layout of devices. Its length must match the length - of `mesh_shape`. Each string in mesh_dim_names must be unique. - - Returns: - A :class:`DeviceMesh` object - - .. note: If no process group is found, init_device_mesh will initialize distributed process group/groups - behind the scene, which are required for distributed communications. 
- - Example: - >>> # xdoctest: +SKIP - >>> from torch.distributed._tensor.device_mesh import init_device_mesh - >>> - >>> mesh_1d = init_device_mesh("cuda", mesh_shape=(8,)) - >>> mesh_2d = init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) - """ - if mesh_dim_names is not None: - if len(set(mesh_dim_names)) != len(mesh_dim_names): - raise RuntimeError( - "Each mesh_dim_name must be uqique.", - f"Found repeated mesh_dim_name in mesh_dim_names {mesh_dim_names}", - ) - - if len(mesh_shape) != len(mesh_dim_names): - raise RuntimeError( - "mesh_shape and mesh_dim_names should have same length!", - f"Found len(mesh_dim_names): {len(mesh_dim_names)} and len(mesh_shape):{len(mesh_shape)}.", - ) - - mesh = torch.arange(math.prod(mesh_shape)).view(mesh_shape) - device_mesh = DeviceMesh( - device_type=device_type, - mesh=mesh, - mesh_dim_names=mesh_dim_names, - ) - - return device_mesh diff --git a/internlm/checkpoint/vescale/devicemesh_api.py b/internlm/checkpoint/vescale/devicemesh_api.py deleted file mode 100644 index 086604e31..000000000 --- a/internlm/checkpoint/vescale/devicemesh_api.py +++ /dev/null @@ -1,475 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ - -import torch -import warnings -from torch.distributed import get_rank -from .device_mesh import init_device_mesh, DeviceMesh -from typing import Optional, List, Tuple, Union, Dict -from torch.distributed.distributed_c10d import ProcessGroup - -__all__ = ["VESCALE_DEVICE_MESH"] - - -class VeDeviceMesh: - _MESH_DIM_NAMES_MAPPING: Dict[int, str] = {} - _MESH_DIM_NAMES_LOOKUP: List[str] = None - _TENSOR_PARALLEL_SIZE: int = None - _DATA_PARALLEL_SIZE: int = None - _PIPELINE_PARALLEL_SIZE: int = None - _DATA_PARALLEL_GROUP: ProcessGroup = None - _TENSOR_PARALLEL_GROUP: ProcessGroup = None - _GLOBAL_MESH: DeviceMesh = None - _MESH_GRIDS: torch.Tensor = None - _DATA_PARALLEL_MESH: DeviceMesh = None - _TENSOR_PARALLEL_MESH: DeviceMesh = None - _GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES: List[DeviceMesh] = None - _GLOBAL_TENSOR_PARALLEL_MESHES: List[DeviceMesh] = None - _RANK_COORDINATE: List[int] = None - DEFAULT_DEVICE_COUNT: int = ( - torch.cuda.device_count() if torch.cuda.is_available() else 8 - ) # enables 8 ranks for CPU multi-processing - PP_DIM: int = 0 - - def init_device_mesh( - self, - device_type: str, - mesh_shape: Tuple[int, ...], - *, - mesh_dim_names: Optional[Tuple[str, ...]] = None, - check_uniqueness: bool = False, - ) -> DeviceMesh: - """Initializes a `DeviceMesh` based on `device_type`, `mesh_shape`, and `mesh_dim_names` parameters. - This creates a DeviceMesh with a mesh layout of n-d dimensional array, n being the len(mesh_shape) - and ith dimension being in size mesh_shape[i]. If mesh_dim_names is provided, each dimension is - labeled as mesh_dim_names[i]. 
Inherit this utility from upstream DeviceMesh. - - Syntax of (global) DeviceMesh created by our API: - Dimensions follow a left-to-right, inter-instance to intra-instance fashion: i.e. - 1. Dimensions of 3-dimensional global DeviceMesh: [PIPELINE_PARALLEL_DIM, DATA_PARALLEL_DIM, TENSOR_PARALLEL_DIM] - - When PIPELINE_PARALLEL_DIM > 1, 1). DATA_PARALLEL_DIM=1, or 2). TENSOR_PARALLEL_DIM=1, or - 3). DATA_PARALLEL_DIM=1, or 2). TENSOR_PARALLEL_DIM=1, DeviceMesh is written in 3-dimensional - 2. Dimensions of 2-dimensional global DeviceMesh: [DATA_PARALLEL_DIM, TENSOR_PARALLEL_DIM] - 3. Dimensions of 1-dimensional global DeviceMesh: [DATA_PARALLEL_DIM or TENSOR_PARALLEL_DIM] - - 1-dimensional DeviceMesh can be used to specify process groups of data parallel and tensor model parallel dimensions - - Args: - device_type (str): device type of the mesh. Currently supports: cpu, cuda/cuda-like. - mesh_shape: Tuple[int]: A tuple describes the dimension of the multi-dimesnion array - that describes the layout of devices. - Kwargs: - mesh_dim_names: Optional[Tuple[str]]: A tuple of mesh dim names to be assigned to each dimension - of the multi-dimensional array that describes the layout of devices. Its length must match the length - of `mesh_shape`. Each string in mesh_dim_names must be unique. Note that if mesh_dim_names is None, - the function will provide a default mesh identifiers. - - check_uniqueness (bool): This advanced argument is used to prevent users from spoiling global - DeviceMesh API by creating multiple copies in a large code repository. - Set to True to allow VESCALE_DEVICE_MESH API to check the "global device mesh" is only initialized once. - Otherwise, users can create as many DeviceMeshes as they want just like with upstream Devicemesh. - - Returns: - A :class:`DeviceMesh` object - - .. note: If no process group is found, init_device_mesh will initialize distributed process group/groups - behind the scene, which are required for distributed communications. - - Example: - >>> # xdoctest: +SKIP - >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH - >>> - >>> # Example 1: initialize the global DeviceMesh as a one-dimensional DeviceMesh - >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(8,)) - >>> - >>> # Example 2: re-initialize the global DeviceMesh as a two-dimensional DeviceMesh - >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", mesh_shape=(2, 8), mesh_dim_names=("dp", "tp")) - - Limitation: we currently only support fixed sized DeviceMesh with 1 to 3 dimensions. We will loosen this constraint in future. - """ - if device_type.startswith("cuda") and device_type != "cuda": - warnings.warn("'cuda:' is invalid ! Convert to pure 'cuda'!") - device_type = "cuda" - assert device_type in ("cuda", "cpu", "meta"), "Supports only three device types: cuda, cpu, meta!" 
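# Note on the default naming applied a few lines below (illustrative sketch):
# when mesh_dim_names is None, the last len(mesh_shape) entries of
# ["PP", "DP", "TP"] are used, so the rightmost dim is always "TP".
assert ["PP", "DP", "TP"][-2:] == ["DP", "TP"]          # 2-D mesh -> (DP, TP)
assert ["PP", "DP", "TP"][-3:] == ["PP", "DP", "TP"]    # 3-D mesh -> (PP, DP, TP)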
- if self._GLOBAL_MESH is None or not check_uniqueness: - self._TENSOR_PARALLEL_SIZE = self._DATA_PARALLEL_SIZE = self._PIPELINE_PARALLEL_SIZE = None - self._MESH_DIM_NAMES_MAPPING = {} - if mesh_dim_names is None: - # Support two default sets of default mesh dimensions: 2-dim [dp, tp], and 3-dim [pp, dp, tp] - mesh_dim_names = ["PP", "DP", "TP"][-len(mesh_shape) :] - if device_type is None: - device_type = "cuda" - self._GLOBAL_MESH = init_device_mesh(device_type, mesh_shape, mesh_dim_names=mesh_dim_names) - self._MESH_GRIDS = self._GLOBAL_MESH.mesh.clone().detach().cpu() - if len(mesh_shape) == 3: - self._PIPELINE_PARALLEL_SIZE, self._DATA_PARALLEL_SIZE, self._TENSOR_PARALLEL_SIZE = mesh_shape - elif len(mesh_shape) == 2: - self._DATA_PARALLEL_SIZE, self._TENSOR_PARALLEL_SIZE = mesh_shape - else: - self._DATA_PARALLEL_SIZE = self._TENSOR_PARALLEL_SIZE = mesh_shape[0] - for idx, name in enumerate(mesh_dim_names[::-1]): - self._MESH_DIM_NAMES_MAPPING[idx] = name - self._MESH_DIM_NAMES_LOOKUP = list(self._MESH_DIM_NAMES_MAPPING.values())[::-1] - self._RANK_COORDINATE = None - self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES = None - self._GLOBAL_TENSOR_PARALLEL_MESHES = None - elif check_uniqueness: - raise ValueError( - "Already initialized the global DeviceMesh! Turn 'check_uniqueness' off to remove the contraint." - ) - return self._GLOBAL_MESH - - def get( - self, - **kwargs, - ) -> Optional[DeviceMesh]: - """ - Retrieves the global device mesh. If it has not been initialized, pass in - arguments to initialize one. - - Args: - **kwargs (dict): arguments to initialize the global device mesh. - - Returns: - A :class:`DeviceMesh` object - """ - if self._GLOBAL_MESH is None and kwargs: - self.init_device_mesh(**kwargs) - return self._GLOBAL_MESH - - def _get_tensor_parallel_mesh(self) -> DeviceMesh: - """ - This function works the same as get_tensor_parallel_mesh(), but - specifies _validate_mesh=False. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - if self._TENSOR_PARALLEL_MESH is None: - assert self._TENSOR_PARALLEL_GROUP is not None, "tensor model parallel group is not initialized" - assert self._MESH_DIM_NAMES_MAPPING - tensor_dim_name = self._MESH_DIM_NAMES_MAPPING[0] - TP_mesh = self.get()[tensor_dim_name] - self._TENSOR_PARALLEL_MESH = DeviceMesh( - device_type=TP_mesh.device_type, - mesh=TP_mesh.mesh, - pg=self._TENSOR_PARALLEL_GROUP, - _validate_mesh=False, - ) - return self._TENSOR_PARALLEL_MESH - - def _get_data_parallel_mesh(self) -> DeviceMesh: - """ - This function works the same as get_data_parallel_mesh(), but - specifies _validate_mesh=False. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - if self._DATA_PARALLEL_MESH is None: - assert self._DATA_PARALLEL_GROUP is not None, "data parallel group is not initialized" - assert len(self._MESH_DIM_NAMES_MAPPING) >= 2 - data_dim_name = self._MESH_DIM_NAMES_MAPPING[1] - DP_mesh = self.get()[data_dim_name] - self._DATA_PARALLEL_MESH = DeviceMesh( - device_type=DP_mesh.device_type, mesh=DP_mesh.mesh, pg=self._DATA_PARALLEL_GROUP, _validate_mesh=False - ) - return self._DATA_PARALLEL_MESH - - def get_strategy_coordinate(self, local_rank=None) -> List[int]: - """ - Translate current local rank to a strategy coordinate of initialized strategy dimensions. - If local_rank is not provided, return coordinate of current rank. - The only difference of this function w.r.t. upstream DeviceMesh's get_coordinate() is that - it enables users query strategy coordinate of arbitrary ranks. 
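# Illustrative sketch (hypothetical 2x2 [DP, TP] grid) of the lookup performed
# on _MESH_GRIDS below:
import torch
_grid = torch.arange(4).view(2, 2)    # [[0, 1], [2, 3]]
_coord = [int(i) for i in (_grid == 3).nonzero(as_tuple=True)]
assert _coord == [1, 1]               # global rank 3 -> DP rank 1, TP rank 1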
- - Args: - local_rank (int): rank id. If local_rank is None, return the coordinate of the local rank. - - Returns: - Coordinate of local rank mapped to the global DeviceMesh's parallel dimensions. - - Example: - >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH - >>> dp_size, tp_size = 2, 2 - >>> # Initialize global device mesh of (dp_size=2, tp_size=2) - >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP")) - >>> local_rank = torch.distributed.get_rank() # local_rank is 0 - 0 - >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank) - [0, 0] - >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(3) - [1, 1] - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - if local_rank is None: - if self._RANK_COORDINATE is None: - self._RANK_COORDINATE = self.get_strategy_coordinate(self.get_local_rank()) - return self._RANK_COORDINATE - rank_coordinate = [int(item) for item in (self._MESH_GRIDS == local_rank).nonzero(as_tuple=True)] - return rank_coordinate - - def lookup_rank(self, dim: Union[int, str]) -> int: - """ - Look up the specified 'id' from a particular dimension of the - current rank's strategy coordinate. - - Args: - dim (Union[int, str]): Dimension indicator. - - Returns: - Specified parallel strategy 'rank' of a global rank. - - Example: - >>> from vescale.devicemesh_api import VESCALE_DEVICE_MESH - >>> dp_size, tp_size = 2, 2 - >>> # Initialize global device mesh of (dp_size=2, tp_size=2) - >>> VESCALE_DEVICE_MESH.init_device_mesh("cuda", (dp_size, tp_size), mesh_dim_names=("DP", "TP")) - >>> local_rank = torch.distributed.get_rank() # local_rank = 0 - 0 - >>> VESCALE_DEVICE_MESH.get_strategy_coordinate(local_rank) - [0, 0] - >>> index = 1 - >>> VESCALE_DEVICE_MESH.lookup_rank(index) # local_rank is 0 - 0 - >>> dim_name = "DP" - >>> VESCALE_DEVICE_MESH.lookup_rank(dim_name) # local_rank is 0 - 0 - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - if isinstance(dim, int): - assert 0 <= dim < len(self._MESH_DIM_NAMES_MAPPING) - else: - assert dim in self._MESH_DIM_NAMES_MAPPING.values() - if self._RANK_COORDINATE is None: - self.get_strategy_coordinate() - if isinstance(dim, str): - index = self._MESH_DIM_NAMES_LOOKUP.index(dim) - return self._RANK_COORDINATE[index] - else: - return self._RANK_COORDINATE[dim] - - def get_strategy_size(self, dim: Union[int, str]) -> List[int]: - """ - Return the size of a parallel strategy dimension of the global DeviceMesh. - - Args: - dim (Union[int, str]): Dimension indicator. - - Returns: - Size of a strategt dimension. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - if isinstance(dim, int): - assert 0 <= dim < len(self._MESH_DIM_NAMES_MAPPING) - else: - assert dim in self._MESH_DIM_NAMES_MAPPING.values() - if isinstance(dim, str): - index = self._MESH_DIM_NAMES_LOOKUP.index(dim) - return self.size(index) - else: - return self.size(dim) - - def get_local_rank(self) -> int: - """ - Get rank ID based on this machine. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - local_device_count = torch.cuda.device_count() if torch.cuda.is_available() else self.DEFAULT_DEVICE_COUNT - return get_rank() % local_device_count - - def get_pipeline_parallel_rank(self) -> int: - """ - Get pipeline parallel rank (stage id) of local rank id. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" 
- num_dims = len(self._MESH_DIM_NAMES_MAPPING) - assert num_dims <= 3 - if len(self._MESH_DIM_NAMES_MAPPING) == 3: - pipe_dim_name = self._MESH_DIM_NAMES_MAPPING[2] - return self.lookup_rank(pipe_dim_name) - else: - return 0 - - def get_data_parallel_rank(self) -> int: - """ - Get data parallel rank (stage id) of local rank id. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - assert len(self._MESH_DIM_NAMES_MAPPING) >= 2 - if len(self._MESH_DIM_NAMES_MAPPING) > 1: - data_dim_name = self._MESH_DIM_NAMES_MAPPING[1] - else: - data_dim_name = self._MESH_DIM_NAMES_MAPPING[0] - return self.lookup_rank(data_dim_name) - - def get_tensor_parallel_rank(self) -> int: - """ - Get tensor parallel rank (stage id) of local rank id. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - assert self._MESH_DIM_NAMES_MAPPING - tensor_dim_name = self._MESH_DIM_NAMES_MAPPING[0] - return self.lookup_rank(tensor_dim_name) - - def get_pipeline_parallel_mesh(self) -> DeviceMesh: - """ - Return the pipeline parallel view of the global DeviceMesh. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - assert len(self._MESH_DIM_NAMES_MAPPING) == 3 - pipe_dim_name = self._MESH_DIM_NAMES_MAPPING[0] - return self.get()[pipe_dim_name] - - def get_global_pipeline_parallel_meshes(self, device_type="cuda") -> list: - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - if self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES is None: - meshes = [] - device_mesh = self.get() - for inner_group in device_mesh.mesh.tolist(): - meshes.append(DeviceMesh(device_type, inner_group, _validate_mesh=False)) - self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES = meshes - return self._GLOBAL_PIPELINE_MODEL_PARALLEL_MESHES - - def get_data_parallel_mesh(self) -> DeviceMesh: # noqa: F811 - """ - Return the data parallel view of the global DeviceMesh. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - assert self._MESH_DIM_NAMES_MAPPING - dp_name = self._MESH_DIM_NAMES_MAPPING[1] if self.ndim > 1 else self._MESH_DIM_NAMES_MAPPING[0] - return self.get()[dp_name] - - def get_tensor_parallel_mesh(self) -> DeviceMesh: - """ - Return the tensor parallel view of the global DeviceMesh. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - assert self._MESH_DIM_NAMES_MAPPING - tp_name = self._MESH_DIM_NAMES_MAPPING[0] - return self.get()[tp_name] - - def get_global_tensor_parallel_meshes(self) -> list: - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - if self._GLOBAL_TENSOR_PARALLEL_MESHES is None: - assert len(self._MESH_DIM_NAMES_LOOKUP) == 3 - tp_meshes = [] - global_dm = self.get() - device_type = self.get_tensor_parallel_mesh().device_type - all_tp_list = global_dm.mesh.view(-1, global_dm.mesh.size(2)) - for tp_group in all_tp_list: - tp_mesh = DeviceMesh( - device_type, - tp_group, - _validate_mesh=False, - _init_process_groups=False, - ) - tp_meshes.append(tp_mesh) - self._GLOBAL_TENSOR_PARALLEL_MESHES = tp_meshes - return self._GLOBAL_TENSOR_PARALLEL_MESHES - - def is_first_stage(self) -> bool: - """ - Return if the current stage is the first stage, if using pipeline parallelism. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - pp_rank = self.get_pipeline_parallel_rank() - return pp_rank == 0 - - def is_last_stage(self) -> bool: - """ - Return if the current stage is the last stage, if using pipeline parallelism. 
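# Sketch of the check implemented below (stages are numbered 0 .. pp_size - 1
# along the PP dim; hypothetical helper for illustration only):
def _is_last_stage_sketch(pp_rank, pp_size):
    return pp_rank == pp_size - 1

assert _is_last_stage_sketch(3, 4) and not _is_last_stage_sketch(0, 4)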
- """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - assert len(self._MESH_DIM_NAMES_MAPPING) == 3 - device_mesh = self.get() - num_stages = device_mesh.size(self.PP_DIM) - pp_rank = self.get_pipeline_parallel_rank() - return pp_rank == num_stages - 1 - - def __getitem__(self, mesh_dim_name: str) -> DeviceMesh: - """ - Slice the current DeviceMesh based on the mesh_dim_name given to create a child - DeviceMesh. Inherit this utility from upstream DeviceMesh. - - Args: - mesh_dim_name (str): mesh dimension name. - - Returns: - a dimension "view" of the global DeviceMesh. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - device_mesh = self.get() - return device_mesh[mesh_dim_name] - - def get_data_parallel_dim_groups(self) -> ProcessGroup: - """ - Match process groups of data parallel dimension given - sizes of DeviceMesh. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - device_mesh = self.get() - dim_size = len(device_mesh.mesh.shape) - assert 1 <= dim_size <= 3 - if dim_size <= 2: - return device_mesh.get_dim_groups(0) - return device_mesh.get_dim_groups(1) - - def get_tensor_parallel_dim_groups(self) -> ProcessGroup: - """ - Return process group of the lowest dimension as - the dimension of tensor model parallelism. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - device_mesh = self.get() - assert 1 <= len(device_mesh.mesh.shape) <= 3 - return device_mesh.get_dim_groups(0) - - def get_coordinate(self) -> Optional[List[int]]: - """ - Return the relative indices of this rank relative to all - dimensions of the mesh. If this rank is not part of the mesh, return None. - Inherit this utility from upstream DeviceMesh. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - device_mesh = self.get() - return device_mesh.get_coordinate() - - def size(self, dim: Optional[int] = None) -> int: - """ - Returns dimension size of DeviceMesh along 'dim' dimension. If dim is None, - return the total number of ranks in this DeviceMesh. - - Args: - dim (int): dimension index - - Returns: - Dimension size, or total number of ranks if None. - """ - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - device_mesh = self.get() - return device_mesh.mesh.numel() if dim is None else device_mesh.mesh.size(dim) - - @property - def ndim(self) -> int: - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - device_mesh = self.get() - return device_mesh.mesh.ndim - - @property - def shape(self) -> Tuple[int, ...]: - assert self._GLOBAL_MESH, "Must initialize global DeviceMesh first!" - device_mesh = self.get() - return tuple(device_mesh.mesh.shape) - - -VESCALE_DEVICE_MESH = VeDeviceMesh() diff --git a/internlm/checkpoint/vescale/distributed_optimizer.py b/internlm/checkpoint/vescale/distributed_optimizer.py deleted file mode 100644 index 4f861a7f9..000000000 --- a/internlm/checkpoint/vescale/distributed_optimizer.py +++ /dev/null @@ -1,109 +0,0 @@ -import math -import inspect -from dataclasses import dataclass -from typing import Dict, Sequence, Tuple, Optional, Any -import torch -import torch.distributed as dist - - - -class Range: - """ - A range represents a start and end points for indexing a shard - from a full tensor. 
- """ - - def __init__(self, start, end): - self.start = start - self.end = end - self.size = end - start - - def normalize(self, start=0): - return Range(start, start + self.size) - - def __str__(self): - return "%d,%d [%d]" % (self.start, self.end, self.size) - - def __len__(self): - return self.end - self.start - - def __repr__(self) -> str: - return "Range(%d,%d [%d])" % (self.start, self.end, self.size) - -@dataclass -class OptimizerStateSpec: - """This class represents mapping between local flattened 1D tensor - and global original DTensor in DOptimzier, it is used for - loading or saving optimizer states using vescale.checkpoint (PyTorch DCP) - and load-time checkpoint resharding when changing tp size or dp size. - - For example, a linear layer in Vescale is DTensor(size=[1024, 1024]) - It first divides into two parts along dim=0 with tensor parallel size = 2 - - tensor_part_0 = DTensor(size=[512, 1024]) - tensor_part_1 = DTensor(size=[512, 1024]) - - Then each part's optimizer states are initalized in DOptimizer sepearately - - Assume dp=2 - For process with dp=0 tp=0, the flatten tensor is torch.Tensor(size=[262144]) - global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(0, 0) local=torch.Tensor(size=[262144]).view(local_shape) - - For process with dp=1 tp=0, the flatten tensor is torch.Tensor(size=[262144]) - global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(256, 0) local=torch.Tensor(size=[262144]).view(local_shape) - - For process with dp=0 tp=1, the flatten tensor is torch.Tensor(size=[262144]) - mapping to [512:768, 0:1024] in original DTensor - global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(512, 0) local=torch.Tensor(size=[262144]).view(local_shape) - - For process with dp=1 tp=1, the flatten tensor is torch.Tensor(size=[262144]) - global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(768, 0) local=torch.Tensor(size=[262144]).view(local_shape) - """ - - # The original DTensor shape - global_shape: Tuple[int] - # The local tensor shape ***before flattened into 1D tensor*** - local_shape: Tuple[int] - # The local tensor's offset with respect to origianl DTensor - global_offset: Tuple[int] - # The unflattened local tensor after create view using local_shape on the flattened 1D Tensor in DOptimizer - # NOTE: In order to support TP resharding and state cross dp ranks, we defer the reshaping from 1D to local_shape - # to generate saving plan using vescale.checkpoint (PyTorch DCP) - local_tensor: torch.Tensor - # If the current optimizer state is sharded by multiple dp ranks, - # we should record all ranks and their ranges - dp_ranks_ranges: Optional[Dict[int, Range]] - -def convert_dict_with_sharded( - param_state: dict, - global_shape: Tuple[int], - local_shape: Tuple[int], - global_offset: Tuple[int], - dp_ranks_ranges: Optional[Dict[int, Range]], -): - new_param_state = {} - for k, v in param_state.items(): - if isinstance(v, torch.Tensor) and v.dim() >= 1: - # Don't unflatten tensor here, see the comments above - if not dp_ranks_ranges: - if math.prod(local_shape) != math.prod(v.shape): - print(f"rank={dist.get_rank()} name={k} global shape={global_shape}\ - local_shape={local_shape} global_offset={global_offset} real shape={v.shape}") - raise AssertionError() - new_param_state[k] = OptimizerStateSpec( - global_shape, local_shape, global_offset, v, dp_ranks_ranges - ) # , process_group) - else: - new_param_state[k] = v - return new_param_state - -def convert_dict_sharded_to_tensor(param_state: dict, range_1d: 
Optional[Range]): - for k, v in param_state.items(): - if isinstance(v, OptimizerStateSpec): - # If the state is distributed on multiple dp ranks - # Get my parts - if range_1d: - param_state[k] = v.local_tensor.flatten()[range_1d.start : range_1d.end] - else: - param_state[k] = v.local_tensor.flatten() - return param_state \ No newline at end of file diff --git a/internlm/checkpoint/vescale/filesystem.py b/internlm/checkpoint/vescale/filesystem.py deleted file mode 100644 index 71f21024c..000000000 --- a/internlm/checkpoint/vescale/filesystem.py +++ /dev/null @@ -1,960 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ -from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor -from .mem_checkpoint import copy_gpu_tensor_to_cpu_pinned_mem_pool, deallocate_cpu_tensor_in_pinned_mem_pool -from abc import ABC, abstractmethod -import collections -from dataclasses import dataclass -import os -import dataclasses -import io -import torch.distributed as dist -import pickle -from typing import List, Tuple, Union, Dict, cast, Any -from internlm.utils.logger import get_logger -import time -import torch -from torch import Tensor -from torch.futures import Future -from pathlib import Path -from internlm.core.context import global_context as gpc -from internlm.core.context import ParallelMode -from internlm.train.pipeline import map_fqn_global_to_local, map_layer_attr -from internlm.utils.common import get_current_device - - -from torch.distributed.checkpoint.metadata import ( - Metadata, - MetadataIndex, -) -from torch.distributed.checkpoint.storage import ( - StorageReader, - StorageWriter, - WriteResult, -) - -from torch.distributed.checkpoint.planner import ( - LoadItemType, - LoadPlanner, - LoadPlan, - SavePlan, - SavePlanner, - WriteItem, - ReadItem, - WriteItemType, -) - -from torch.distributed.checkpoint.utils import _create_file_view - -from torch.distributed._shard._utils import narrow_tensor_by_index -from torch._utils import _get_device_module - -logger = get_logger(__file__) -from .common import P2PTensorsInfo - -__all__ = [ - "FileSystemWriter", - "FileSystemReader", -] - - -@dataclass -class _StorageInfo: - """ - This is the per entry storage info - """ - - relative_path: str - offset: int - length: int - - -@dataclass -class _StoragePrefix: - prefix: str - - -DEFAULT_SUFFIX = ".distcp" - - -def _trim(tensor: torch.Tensor) -> torch.Tensor: - tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor.detach()) - # Comment the original DCP code - # When dumping to pinned memory, - # the memory layout for tensor has been contiguous - # if tensor._typed_storage()._size() != tensor.numel(): - # tensor = tensor.clone() - return tensor - - -def _result_from_write_item(item: WriteItem, size_in_bytes, storage_data) -> WriteResult: - return WriteResult(index=item.index, size_in_bytes=size_in_bytes, storage_data=storage_data) - - -class _TensorLoader(ABC): - @abstractmethod - def add(self, fqn, size, obj): - pass - - @abstractmethod - def start_loading(self): - pass - - @abstractmethod - def values(self): - pass - - 
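To make the mapping concrete, here is a minimal usage sketch of the OptimizerStateSpec / Range helpers defined in distributed_optimizer.py above, using the shapes from that docstring's dp=2, tp=2 example (global_shape=(1024, 1024), local_shape=(256, 1024), global_offset=(256, 0) for the dp=1, tp=0 process). The "exp_avg" key and the zero-filled buffer are illustrative only, and the import simply points at the module shown above; this is not part of the patch itself.

import torch

from internlm.checkpoint.vescale.distributed_optimizer import (
    convert_dict_with_sharded,
    convert_dict_sharded_to_tensor,
)

# Flattened 1-D optimizer state owned by the dp=1, tp=0 process in the example:
# it covers rows [256:512) of the original (1024, 1024) DTensor and stays
# flattened inside DOptimizer until a checkpoint plan needs the 2-D view.
flat_exp_avg = torch.zeros(256 * 1024)
param_state = {"step": 10, "exp_avg": flat_exp_avg}

wrapped = convert_dict_with_sharded(
    param_state,
    global_shape=(1024, 1024),
    local_shape=(256, 1024),
    global_offset=(256, 0),
    dp_ranks_ranges=None,  # this state is not split further across dp ranks
)
# wrapped["exp_avg"] is now an OptimizerStateSpec; a save planner can later call
# wrapped["exp_avg"].local_tensor.view((256, 1024)) and place it at row offset
# 256 of the global tensor, while the scalar "step" entry passes through as-is.

# Converting back for the optimizer (no extra 1-D slicing, so range_1d=None);
# note the helper rewrites the dict entries in place and returns the same dict.
restored = convert_dict_sharded_to_tensor(wrapped, range_1d=None)
assert restored["exp_avg"].shape == (256 * 1024,)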
-def collect_optim_state_across_dp_ranks( - tensor: torch.Tensor, rank_ranges: Dict[int, Any], p2p_reqs: Dict[int, Any] -) -> torch.Tensor: - orignal_shape = tensor.shape - tensor = tensor.flatten() - logger.debug("DEBUG: Start receiving p2p tensor") - recv_start = time.time() - for req in p2p_reqs: - req.wait() - recv_end = time.time() - recv_start - logger.debug(f"DEBUG: Finish receiving p2p tensor. Time cost: {recv_end}s") - for v in rank_ranges.values(): - received_tensor, param_range = v - tensor[param_range.start : param_range.end] = received_tensor - tensor = tensor.reshape(orignal_shape) - return tensor - - -class _SerialCpuLoader(_TensorLoader): - def __init__(self, resolve_fun, p2p_tensors_info: P2PTensorsInfo = None): - self.resolve_fun = resolve_fun - self.items = [] - self.p2p_tensors_info = p2p_tensors_info - - def add(self, fqn, size, obj): - self.items.append((fqn, size, obj)) - - def start_loading(self): - pass - - def values(self): - for fqn, _, obj in self.items: - tensor = self.resolve_fun(obj).detach() - # if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: - # tensor = collect_optim_state_across_dp_ranks( - # tensor=tensor, - # rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], - # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], - # ) - # elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: - # tensor = collect_optim_state_across_dp_ranks( - # tensor=tensor, rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], p2p_reqs=self.recv_p2p_reqs[fqn] - # ) - tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor) - # Comment the original DCP code - # When dumping to pinned memory, - # the memory layout for tensor has been contiguous - # if tensor.storage().size() != tensor.numel(): - # tensor = tensor.clone() - yield ( - tensor, - obj, - ) - - -class _OverlappingCpuLoader(_TensorLoader): - def __init__( - self, - resolve_fun, - p2p_tensors_info: P2PTensorsInfo = None, - stream=None, - inflight_threshhold=1_000_000, - ): - self.resolve_fun = resolve_fun - self.items = [] - self.inflight_threshhold = inflight_threshhold - self.in_flight_data = 0 - self.current_items: collections.deque = collections.deque() - self.idx = 0 - self.started = False - self.device_type = stream.device_type if stream else torch.device("cuda").type - self.device_module = _get_device_module(self.device_type) - self.p2p_tensors_info = p2p_tensors_info - self.stream = stream or self.device_module.current_stream() - if self.stream != self.device_module.current_stream(): - self.stream.wait_stream(self.device_module.current_stream()) - - @property - def _done(self): - return self.idx >= len(self.items) - - def _drain(self): - drained = [] - if self.in_flight_data >= self.inflight_threshhold: - self.stream.synchronize() - while self.in_flight_data >= self.inflight_threshhold: - val = self.current_items.popleft() - self.in_flight_data -= val[0].numel() * val[0].element_size() - drained.append(val) - return drained - - def _refill(self): - with self.device_module.stream(self.stream): - while not self._done and self.in_flight_data < self.inflight_threshhold: - fqn, _, obj = self.items[self.idx] - self.idx += 1 - tensor = self.resolve_fun(obj).detach() - # if self.p2p_tensors_info and (obj.index.fqn, obj.index.offset) in self.p2p_tensors_info.recv_tensors: - # tensor = collect_optim_state_across_dp_ranks( - # tensor=tensor, - # 
rank_ranges=self.p2p_tensors_info.recv_tensors[(obj.index.fqn, obj.index.offset)], - # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[(obj.index.fqn, obj.index.offset)], - # ) - # elif self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: - # tensor = collect_optim_state_across_dp_ranks( - # tensor=tensor, - # rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], - # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], - # ) - if tensor.device.type == self.device_type: - tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) - # Comment the original DCP code - # When dumping to pinned memory, the memory layout for tensor has been contiguous - # elif tensor.device == torch.device("cpu"): - # if tensor.storage().size() != tensor.numel(): - # # this forces the tensor to be both contiguous and with minimal storage - # tensor = tensor.clone() - - self.current_items.append( - ( - tensor, - obj, - ) - ) - self.in_flight_data += tensor.numel() * tensor.element_size() - - def _finish(self): - assert self._done - if len(self.current_items) > 0: - self.stream.synchronize() - return self.current_items - - def add(self, fqn, size, obj): - if self.started: - raise RuntimeError("cannot add items after loading started") - self.items.append((fqn, size, obj)) - - def start_loading(self): - if self.started: - return - self.started = True - self.items.sort(key=lambda x: x[1]) - self._refill() - - def values(self): - self.start_loading() - while not self._done: - drained = self._drain() - self._refill() - yield from drained - - yield from self._finish() - - -def _item_fqn(item: WriteItem) -> str: - return item.index.fqn - - -def _item_size(item: WriteItem) -> int: - size = 1 - assert item.tensor_data is not None - # can't use math.prod as PT needs to support older python - for s in item.tensor_data.size: - size *= s - - dtype = item.tensor_data.properties.dtype - return size * torch._utils._element_size(dtype) - - -def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]: - if bins == 1: - return [items] - - bytes_w = [wi for wi in items if wi.type == WriteItemType.BYTE_IO] - tensor_w = [wi for wi in items if wi.type != WriteItemType.BYTE_IO] - - buckets: List[List[WriteItem]] = [[] for _ in range(bins)] - bucket_sizes = [0 for _ in range(bins)] - - tensor_w.sort(key=_item_size, reverse=True) - - for i, wi in enumerate(bytes_w): - buckets[i % bins].append(wi) - - for wi in tensor_w: - idx = min(enumerate(bucket_sizes), key=lambda x: x[1])[0] - buckets[idx].append(wi) - bucket_sizes[idx] += _item_size(wi) - - return buckets - - -def _write_item(stream, data, write_item, storage_key): - offset = stream.tell() - - if write_item.type == WriteItemType.BYTE_IO: - assert isinstance(data, io.BytesIO) - stream.write(data.getbuffer()) - else: - assert isinstance(data, torch.Tensor) - assert data.device == torch.device("cpu") - torch.save(data, stream) - length = stream.tell() - offset - - return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) - - -def _write_files_from_queue( - file_name, - storage_key, - write_items, - planner: SavePlanner, - inflight_threshhold: int, - use_fsync: bool, - p2p_tensors_info: P2PTensorsInfo = None, -): - loader: _TensorLoader - - if torch.cuda.is_available() and inflight_threshhold > 0: - loader = _OverlappingCpuLoader( - lambda x: planner.resolve_data(x), - inflight_threshhold=inflight_threshhold, - p2p_tensors_info=p2p_tensors_info, - ) - else: - loader = _SerialCpuLoader(lambda x: 
planner.resolve_data(x), p2p_tensors_info=p2p_tensors_info) - - tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] - for write_item in tensor_w: - loader.add(_item_fqn(write_item), _item_size(write_item), write_item) - loader.start_loading() - - bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] - if len(bytes_w) != 0: - assert False - write_results = [] - - stream = open(file_name, "wb") - logger.debug("Start writing byte io data.") - byte_io_write_start = time.time() - for write_item in bytes_w: - data = planner.resolve_data(write_item) - write_results.append(_write_item(stream, data, write_item, storage_key)) - byte_io_write_time = time.time() - byte_io_write_start - logger.debug(f"Finish writing byte io data. Time cost: {byte_io_write_time}s") - - logger.debug("Start writing tensor data.") - tensor_write_start = time.time() - for tensor, write_item in loader.values(): - assert tensor.is_cpu - write_results.append(_write_item(stream, tensor, write_item, storage_key)) - # WARNING: Call deallocate_cpu_tensor_in_pinned_mem_pooltensor - # when the reference to CPU tensor goes to zero - # so the memory pool will reuse the memory if possbile - # Othterwise, the memory pool will allocate memory on the used memory range, - # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered - deallocate_cpu_tensor_in_pinned_mem_pool(tensor) - tensor_write_time = time.time() - tensor_write_start - logger.debug(f"Finish writing tensor data. Time cost: {tensor_write_time}s") - - if use_fsync: - os.fsync(stream.fileno()) - - file_stream_close_start = time.time() - stream.close() - file_stream_close_time = time.time() - file_stream_close_start - logger.debug(f"Finish closing file stream. Time cost: {file_stream_close_time}s") - return write_results - - -def _write_files_per_proc( - file_path: Path, - storage_key: str, - byte_data_item: List[Tuple[io.BytesIO, WriteItem]], - tensor_data_item: List[Tuple[torch.Tensor, WriteItem]], - use_fsync: bool, -) -> List[WriteResult]: - write_results = [] - stream = open(file_path, "wb") - # First write byte data. - for write_data, write_item in byte_data_item: - write_results.append(_write_item(stream, write_data, write_item, storage_key)) - # Then write tensor data. - # NOTE: the pinned memory occupied by each tensor have been reallocated. - for write_data, write_item in tensor_data_item: - write_results.append(_write_item(stream, write_data, write_item, storage_key)) - - if use_fsync: - os.fsync(stream.fileno()) - - return write_results - - -def _serialize_tensor(tensor: torch.Tensor) -> bytes: - bio = io.BytesIO() - # NOTE: currently use torch.save() to do the serialization. - torch.save(tensor, bio) - return bio.getbuffer() - - -def _write_to_file(stream, content: bytes, write_item: WriteItem, storage_key: str) -> WriteResult: - offset = stream.tell() - stream.write(content) - length = stream.tell() - offset - return _result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length)) - - -def _write_files_per_proc_pipe( - file_path: Path, - storage_key: str, - byte_data_item: List[Tuple[io.BytesIO, WriteItem]], - tensor_data_item: List[Tuple[torch.Tensor, WriteItem]], - use_fsync: bool, -) -> List[WriteResult]: - write_futures = [] - write_results = [] - stream = open(file_path, "wb") - executor = ThreadPoolExecutor(max_workers=1) - # For byte data, directly write byte data. 
- assert len(byte_data_item) == 0 - for write_data, write_item in byte_data_item: - content = write_data.getbuffer() - write_futures.append( - executor.submit( - _write_to_file, - stream, - content, - write_item, - storage_key, - ) - ) - # write_results.append(_write_to_file(stream, content, write_item, storage_key)) - # For tensor data, perform serialization in process then do saving in threadpool. - # print(f"tensor_data_item: {tensor_data_item}", flush=True) - data_memory = 0 - content_momory = 0 - print(f"tensor_data_item {gpc.get_global_rank()}: {len(tensor_data_item)}", flush=True) - for write_data, write_item in tensor_data_item: - data_memory += (write_data.numel() * write_data.element_size()) - # print(f"write_item: {write_item}", flush=True) - # print(f"write_data: {write_data}", flush=True) - content = _serialize_tensor(write_data) - content_momory += len(content) - write_futures.append( - executor.submit( - _write_to_file, - stream, - content, - write_item, - storage_key, - ) - ) - # write_results.append(_write_to_file(stream, content, write_item, storage_key)) - print(f"data_memory {gpc.get_global_rank()}: {data_memory / (1024 * 1024 * 1024)}, {content_momory / (1024 * 1024 * 1024) }", flush=True) - - for fut in write_futures: - write_results.append(fut.result()) - if use_fsync: - os.fsync(stream.fileno()) - executor.shutdown(wait=False) - return write_results - - -def stat_analysis(tasks, planner, p2p_tensors_info, use_fsync=True) -> List[WriteResult]: - """ - Analyzing the overhead of D2H transfer, serialization, and save operations. Assume that - all items are written into one file. - """ - # Step1, aysnc D2H, dumping objects to pinned share memory. - assert len(tasks) == 1, "please generate one write task for analysis" - loader = _SerialCpuLoader(lambda x: planner.resolve_data(x), p2p_tensors_info=p2p_tensors_info) - # Add Bytes. - byte_item_to_write = [] - for task in tasks: - _, _, write_items = task - byte_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] - byte_item_to_write.extend(byte_w) - if len(byte_item_to_write) != 0: - assert False - # Add tenosrs. - tensor_item_to_write = [] - for task in tasks: - _, _, write_items = task - tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] - tensor_item_to_write.extend(tensor_w) - for write_item in tensor_w: - loader.add(_item_fqn(write_item), _item_size(write_item), write_item) - loader.start_loading() - # Step1: dump to pinned memory pool. - d2h_dump_wait_start = time.time() - tensor_to_serialize: List[torch.Tensor] = [] - for tensor, write_item in loader.values(): - assert tensor.is_cpu - tensor_to_serialize.append(tensor) - deallocate_cpu_tensor_in_pinned_mem_pool(tensor) - d2h_dump_wait_time = torch.tensor(time.time() - d2h_dump_wait_start).cuda() - dist.all_reduce(d2h_dump_wait_time) - d2h_dump_wait_time = d2h_dump_wait_time.item() / dist.get_world_size() - if dist.get_rank() == 0: - logger.critical(f"End waiting for D2H tensors dumping Time: {d2h_dump_wait_time:.4f}s") - # Step2: call serialization workers to serialize objects. 
- serialize_wait_start = time.time() - tensor_data_to_write = [] - bio = io.BytesIO() - for tensor in tensor_to_serialize: - bio.seek(0) - bio.truncate(0) - torch.save(tensor, bio) - dump_b = bio.getvalue() - assert isinstance(dump_b, bytes) - tensor_data_to_write.append(dump_b) - serialize_wait_time = torch.tensor(time.time() - serialize_wait_start).cuda() - dist.all_reduce(serialize_wait_time) - serialize_wait_time = serialize_wait_time.item() / dist.get_world_size() - if dist.get_rank() == 0: - logger.critical(f"End waiting for serialization Time: {serialize_wait_time:.4f}s") - # Step3: save/upload the objects from memory to disk. - file_path = tasks[0][0] - storage_key = tasks[0][1] - write_results = [] - assert isinstance(file_path, Path) - save_upload_wait_start = time.time() - with open(file_path, "wb") as stream: - for write_item in byte_item_to_write: - offset = stream.tell() - data = planner.resolve_data(write_item) - stream.write(data.getbuffer()) - length = stream.tell() - offset - write_results.append(_result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length))) - for tensor_data, write_item in zip(tensor_data_to_write, tensor_item_to_write): - offset = stream.tell() - stream.write(tensor_data) - length = stream.tell() - offset - write_results.append(_result_from_write_item(write_item, length, _StorageInfo(storage_key, offset, length))) - if use_fsync: - os.fsync(stream.fileno()) - save_upload_wait_time = torch.tensor(time.time() - save_upload_wait_start).cuda() - dist.all_reduce(save_upload_wait_time) - save_upload_wait_time = save_upload_wait_time.item() / dist.get_world_size() - if dist.get_rank() == 0: - logger.critical(f"End waiting for tensors saving/uploading Time: {save_upload_wait_time:.4f}s") - return write_results - - -class FileSystemWriter(StorageWriter): - """ - Basic implementation of StorageWriter using file IO. - - This implementation makes the following assumptions and simplifications: - - * The checkpoint path is an empty or non-existing directory. - * File creation is atomic - - The checkpoint consist of one file per write request plus - a `.metadata` file with the serialized metadata. - - """ - - def __init__( - self, - path: Union[str, os.PathLike], - single_file_per_rank: bool = True, - sync_files: bool = True, - worker_count: int = 1, - per_process_copy_ahead: int = 10_000_000, - ) -> None: - """ - Initialize the writer pointing to `path` - - Args: - path: directory where the checkpoint will be written to. - single_file_per_rank: Produce one file per rank instead of one file per tensor/blob. Default to True. - sync_files : force files to be synced to permanent storage. Default to True. - worker_count: Number of IO workers (processes) to use to write. Default to 1. - per_process_copy_ahead: How many bytes to copy from the GPU ahead of saving then. Default 10Mb. - - N. B. If sync_files is disabled, there's no guarantee that the checkpoint will be consistent in the case of a failure. 
- """ - super().__init__() - self.path = Path(path) - self.single_file_per_rank = single_file_per_rank - # self.single_file_per_rank = False - self.sync_files = sync_files - self.worker_count = worker_count - self.per_process_copy_ahead = per_process_copy_ahead - - def set_up_storage_writer(self, is_coordinator: bool) -> None: - pass - - def prepare_local_plan(self, plan: SavePlan, p2p_tensors_info: P2PTensorsInfo = None) -> SavePlan: - self.path.mkdir(parents=True, exist_ok=True) - self.p2p_tensors_info = p2p_tensors_info - return plan - - def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]: - new_plans = [ - dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_")) for i, plan in enumerate(global_plan) - ] - return new_plans - - def prepare_write_data(self, tasks: List[Tuple[Path, str, List[WriteItem]]], planner: SavePlanner, is_optimizer): - """ - First stage of saving, Perform Copy data to CPU (D2H). - - Args: - tasks: partitoned tasks for workers to conduct serialization and the actual saving. - planner: save planner used to resolve the bytes and tensor data. - async_io: whether do asynchrous D2H. - - NOTE: Currently we do D2H synchronously. - """ - - byte_data_item_writes: List[List[Tuple[io.BytesIO, WriteItem]]] = [] - tensor_data_item_writes: List[List[Tuple[torch.Tensor, WriteItem]]] = [] - file_path_names: List[Tuple[Path, str]] = [] - - item_list_all = [] - fqn_list_all = [] - - # Perform D2H in copy stream. - flag = 0 - d2h_dump_start = time.time() - for task in tasks: - file_path, file_name, write_items = task - byte_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO] - tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO] - byte_data_item = [(planner.resolve_data(wi), wi) for wi in byte_w] - if len(byte_data_item) != 0: - assert False - tensor_data_item = [] - - item_list = [] - fqn_list = [] - # Async copy to pinned CPU memory pool. 
- for item in tensor_w: - # att - fqn = _item_fqn(item) - - # map_fqn = fqn - # if fqn.endswith("exp_avg") or fqn.endswith("exp_avg_sq"): - # # os exp_avg, exp_avg_sq - # map_fqn = fqn.rsplit('.', 1)[0] - fqn_list.append((fqn, map_fqn_global_to_local[fqn] if fqn in map_fqn_global_to_local else None)) - if not is_optimizer: - if 'layer' in fqn: - assert fqn in map_fqn_global_to_local - if fqn in map_fqn_global_to_local: - fqn = map_fqn_global_to_local[fqn] - - # if fqn in map_fqn_global_to_local: - # print(f"_item_fqn: {fqn}, {map_fqn_global_to_local[fqn]}", flush=True) - # fqn = map_fqn_global_to_local[fqn] - - tensor = planner.resolve_data(item, fqn).detach().clone() - - - # if self.p2p_tensors_info and fqn in self.p2p_tensors_info.recv_tensors: - # tensor = collect_optim_state_across_dp_ranks( - # tensor=tensor, - # rank_ranges=self.p2p_tensors_info.recv_tensors[fqn], - # p2p_reqs=self.p2p_tensors_info.recv_p2p_reqs[fqn], - # ) - tensor = copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor, non_blocking=True) - # print(f"item: {item.index.fqn}", flush=True) - tensor_data_item.append((tensor, item)) - - item_list.append(item.index.fqn) - flag += 1 - byte_data_item_writes.append(byte_data_item) - tensor_data_item_writes.append(tensor_data_item) - file_path_names.append((file_path, file_name)) - - fqn_list_all.append(fqn_list) - item_list_all.append(item_list) - # print(f"fqn_list_all {gpc.get_global_rank()}: {flag}, {fqn_list_all}", flush=True) - # print(f"item_list_all {gpc.get_global_rank()}: {flag}, {item_list_all}", flush=True) - # print(f"tensor_data_item_writes {gpc.get_global_rank()}: {byte_data_item_writes}, {file_path_names}, {tensor_data_item_writes}", flush=True) - - d2h_dump_time = time.time() - d2h_dump_start - logger.debug(f"End waiting for D2H copy. Time cost: {d2h_dump_time}s") - - # Deallocate pinned memory. - # NOTE: when prepare_write_data() is called next time, make sure the previous save event is completed. - # Otherwise, tensors in pinned memory pool may be overwritten. - for tensor_data_item in tensor_data_item_writes: - for tensor, _ in tensor_data_item: - assert tensor.is_cpu - deallocate_cpu_tensor_in_pinned_mem_pool(tensor) - - return byte_data_item_writes, tensor_data_item_writes, file_path_names - - def write_data( - self, plan: SavePlan, planner: SavePlanner, async_io: bool = False, io_workers=False, is_optimizer=False - ) -> Future[List[WriteResult]]: - storage_plan: _StoragePrefix = plan.storage_data - file_count = 0 - - def gen_file(): - nonlocal file_count - file_name = f"{storage_plan.prefix}{file_count}{DEFAULT_SUFFIX}" - file_count += 1 - return file_name - - tasks: List[Tuple[Path, str, List[WriteItem]]] = [] - # Generate K tasks where K is the number of worker_count. - print(f"self.single_file_per_rank: {self.single_file_per_rank}", flush=True) - if self.single_file_per_rank: - for bucket in _split_by_size_and_type(self.worker_count, plan.items): - file_name = gen_file() - print(f"file_name {gpc.get_global_rank()}: {file_name}, {self.worker_count}, {bucket}", flush=True) - tasks.append((self.path / file_name, file_name, bucket)) - # Generate K tasks where K is the number of write items. 
- else: - for item in plan.items: - file_name = gen_file() - tasks.append((self.path / file_name, file_name, [item])) - logger.debug(f"Rank {dist.get_rank()} writes its checkpoint into {len(tasks)} files") - # Make sure the optimizer states across dp ranks - # has been sending to other ranks - # So the receiver can get it when writing tensors to local path - # print(f"p2p_tensors_info: {self.p2p_tensors_info}", flush=True) - # if self.p2p_tensors_info: - # assert False - # logger.debug("Start waiting for sending p2p tensors futures") - # p2p_tensor_send_wait_start = time.time() - # for req in self.p2p_tensors_info.send_p2p_reqs: - # req.wait() - # p2p_tensor_send_wait_time = time.time() - p2p_tensor_send_wait_start - # logger.debug(f"End waiting for sending p2p tensors futures Time: {p2p_tensor_send_wait_time}s") - - futures = [] - if not io_workers: - executor = ProcessPoolExecutor(max_workers=self.worker_count) - # executor = torch.multiprocessing.get_context("spawn").Pool(self.worker_count) - else: - executor = io_workers - - # ProcessPool VERSION. - if isinstance(executor, ProcessPoolExecutor): - # print(f"executor: ProcessPoolExecutor", flush=True) - byte_data_item_writes, tensor_data_item_writes, file_path_names = self.prepare_write_data(tasks, planner, is_optimizer) - # print(f"byte_data_item_writes {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {byte_data_item_writes}", flush=True) - # print(f"tensor_data_item_writes {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {tensor_data_item_writes}", flush=True) - # print(f"file_path_names {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {file_path_names}", flush=True) - for byte_data_item, tensor_data_item, file_path_name in zip( - byte_data_item_writes, tensor_data_item_writes, file_path_names - ): - file_path, storage_key = file_path_name - worker_args = (file_path, storage_key, byte_data_item, tensor_data_item, self.sync_files) - futures.append(executor.submit(_write_files_per_proc_pipe, *worker_args)) - # futures.append(self._serialize_workers.apply_async(_write_files_per_proc, worker_args)) - if async_io: - return futures - else: - logger.debug("Start waiting for writing futures (serilization + save)") - future_wait_start = time.time() - for fut in futures: - fut.result() - # fut.wait() - future_wait_time = time.time() - future_wait_start - logger.debug(f"End waiting for writing futures. Time cost: {future_wait_time}s") - return futures - else: - # print(f"executor: {executor}", flush=True) - # ThreadPool VERSION. - for task in tasks: - # print(f"task {gpc.get_global_rank()}: {task}", flush=True) - futures.append( - # executor.submit( - # _write_files_from_queue, - # *task, - # planner, - # self.per_process_copy_ahead, - # self.sync_files, - # self.p2p_tensors_info, - # ) - executor.submit( - _write_files_from_queue, - *task, - planner, - self.per_process_copy_ahead, - self.sync_files, - ) - ) - if async_io: - return futures - else: - logger.debug("Start waiting for writing futures") - future_wait_start = time.time() - for fut in futures: - fut.result() - future_wait_time = time.time() - future_wait_start - logger.debug(f"End waiting for writing futures. 
Time cost: {future_wait_time}s") - return futures - - def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None: - storage_md = dict() - for wr_list in results: - storage_md.update({wr.index: wr.storage_data for wr in wr_list}) - metadata.storage_data = storage_md - with (self.path / ".metadata.tmp").open("wb") as metadata_file: - pickle.dump(metadata, metadata_file) - os.fsync(metadata_file.fileno()) - - (self.path / ".metadata.tmp").rename(self.path / ".metadata") - - -class FileSystemReader(StorageReader): - def __init__( - self, - path: Union[str, os.PathLike], - broadcast_tensors=False, - data_parallel_process_group=None, - ) -> None: - super().__init__() - self.path = path - self.storage_data: Dict[MetadataIndex, _StorageInfo] = dict() - self.broadcast_tensors = broadcast_tensors - self.data_parallel_process_group = data_parallel_process_group - - # If broadcast_tensors is enabled, the data_parallel_process_group is not none - if self.broadcast_tensors: - assert self.data_parallel_process_group - - def _slice_file(self, file, sinfo: _StorageInfo): - return _create_file_view(file, sinfo.offset, sinfo.length) - - def _get_file_path(self, relative_path): - file_path = os.path.join(self.path, relative_path) - return file_path - - def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: - # group requests by file - per_file: Dict[str, List[ReadItem]] = dict() - for read_item in plan.items: - item_md = self.storage_data[read_item.storage_index] - path = item_md.relative_path - per_file.setdefault(path, []).append(read_item) - - # If broadcasting model tensors is enabled, - # let processes with dp_rank=0 load models and broadcast them to other processes - if self.broadcast_tensors: - self.read_data_with_broadcast(per_file=per_file, planner=planner) - else: - # Otherwise, let all ranks load tensors from files - self.read_from_files(per_file=per_file, planner=planner) - - fut: Future = Future() - fut.set_result(None) - - return fut - - def read_from_files(self, per_file: Dict[str, List[ReadItem]], planner: LoadPlanner): - print(f"debugg per_file {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}: {len(per_file)}, {per_file}", flush=True) - for relative_path, reqs in per_file.items(): - file_path = self._get_file_path(relative_path) - print(f"debugg file_path {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}: {file_path}, {reqs}", flush=True) - with open(file_path, "rb") as file: - reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) - for req in reqs: - item_md = self.storage_data[req.storage_index] - file_slice = self._slice_file(file, item_md) - # print(f"debugg file_slice {gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}: {file_slice}", flush=True) - if req.type == LoadItemType.BYTE_IO: - assert False - bytes = io.BytesIO(file_slice.read(item_md.length)) - bytes.seek(0) - planner.load_bytes(req, bytes) - else: - tensor = cast(Tensor, torch.load(file_slice, map_location="cpu")) #att - tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) - print(f"req: {req.dest_index.fqn}, {req}", flush=True) - target_tensor = planner.resolve_tensor(req).detach() - - assert ( - target_tensor.size() == tensor.size() - ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" - target_tensor.copy_(tensor) - planner.commit_tensor(req, target_tensor) - - def read_data_with_broadcast(self, per_file: Dict[str, List[ReadItem]], planner: 
LoadPlanner): - for relative_path, reqs in per_file.items(): - # if dist.get_rank(self.data_parallel_process_group) == 0: - if gpc.get_local_rank(ParallelMode.DATA) == 0: - file_path = self._get_file_path(relative_path) - file = open(file_path, "rb") - dist.barrier(self.data_parallel_process_group) - reqs = sorted(reqs, key=lambda req: self.storage_data[req.storage_index].offset) - for req in reqs: - if gpc.get_local_rank(ParallelMode.DATA)== 0: - item_md = self.storage_data[req.storage_index] - file_slice = self._slice_file(file, item_md) - - if req.type == LoadItemType.BYTE_IO: - assert False - if gpc.get_local_rank(ParallelMode.DATA) == 0: - object_list = [io.BytesIO(file_slice.read(item_md.length))] - else: - object_list = [None] - - dist.broadcast_object_list( - object_list, - src=dist.get_global_rank(self.data_parallel_process_group, 0), - group=self.data_parallel_process_group, - device=get_current_device(), - ) - bytes = object_list[0] - bytes.seek(0) - planner.load_bytes(req, bytes) - else: - if gpc.get_local_rank(ParallelMode.DATA) == 0: - object_list = [cast(Tensor, torch.load(file_slice, map_location="cuda"))] - else: - object_list = [None] - dist.broadcast_object_list( - object_list, - src=dist.get_global_rank(self.data_parallel_process_group, 0), - group=self.data_parallel_process_group, - device=get_current_device(), - ) - tensor = object_list[0].cpu() #att - tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths) - target_tensor = planner.resolve_tensor(req).detach() - - assert ( - target_tensor.size() == tensor.size() - ), f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" - target_tensor.copy_(tensor) - planner.commit_tensor(req, target_tensor) - - # Implementing the abstract function in StorageReader - def read_metadata(self) -> Metadata: - metadata_path = self._get_file_path(".metadata") - with open(metadata_path, "rb") as metadata_file: - metadata = pickle.load(metadata_file) - return metadata - - def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: - self.storage_data = metadata.storage_data - assert self.storage_data is not None - - def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan: - return plan - - def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]: - return global_plan diff --git a/internlm/checkpoint/vescale/load_state_dict.py b/internlm/checkpoint/vescale/load_state_dict.py deleted file mode 100644 index 6ab14868c..000000000 --- a/internlm/checkpoint/vescale/load_state_dict.py +++ /dev/null @@ -1,102 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -from typing import Optional -import torch.distributed as dist -from torch.distributed.checkpoint.planner import LoadPlanner -from torch.distributed.checkpoint.utils import _DistWrapper -from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner -from .filesystem import FileSystemReader -from .meta_type import STATE_DICT_TYPE -import time -from internlm.utils.logger import get_logger -from .vescale_planner import VeScaleLoadPlanner -from internlm.core.context import global_context as gpc -from internlm.core.context import ParallelMode - -logger = get_logger(__file__) - -META_DATA_FILE = ".metadata" - - -def load_state_dict( - state_dict: STATE_DICT_TYPE, - path: str, - process_group: Optional[dist.ProcessGroup] = None, - coordinator_rank: int = 0, - no_dist: bool = False, - planner: Optional[LoadPlanner] = None, - broadcast_tensors=False, - is_optimizer=False -) -> None: - load_start_time = time.time() - """ - [veScale version] Loads a distributed ``state_dict`` in SPMD style. Fix sub-group storage. - """ - print(f"load_state_dict: {path}", flush=True) - storage_reader = FileSystemReader( - path, - broadcast_tensors=broadcast_tensors, - data_parallel_process_group=process_group, - ) - - # Step 0: create distributed world based on process group and coordinator rank - distW = _DistWrapper(process_group, not no_dist, coordinator_rank) - if process_group: - distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) - if planner is None: - planner = DefaultLoadPlanner() - plan_start_time = time.time() - - # Step 1: all processes create local read plan, - # then coordinator gathers all local plans and create global plan. - def local_step(): - assert planner is not None - meta_read_start_time = time.time() - metadata = storage_reader.read_metadata() - meat_read_cost_time = time.time() - meta_read_start_time - logger.info(f"Finish read meta file. Cost time: {meat_read_cost_time}s") - planner.set_up_planner(state_dict, metadata, distW.is_coordinator) - storage_reader.set_up_storage_reader(metadata, distW.is_coordinator) - - local_plan = planner.create_local_plan(is_optimizer=is_optimizer) - local_plan = storage_reader.prepare_local_plan(local_plan) - return local_plan - - def global_step(all_local_plans): - assert planner is not None - all_local_plans = planner.create_global_plan(all_local_plans) - all_local_plans = storage_reader.prepare_global_plan(all_local_plans) - return all_local_plans - - if isinstance(planner, VeScaleLoadPlanner): - central_plan = distW.reduce_scatter("plan", local_step, global_step) - else: - raise AssertionError("Unsupported planner for saving checkpoint") - load_ckpt_plan_cost_time = time.time() - plan_start_time - logger.info(f"Finish planning. Cost time: {load_ckpt_plan_cost_time}s") - - read_start_time = time.time() - - # Step 2: all processes read data from the given path - def read_data(): - assert planner is not None - # print(f"central_plan {gpc.get_local_rank(ParallelMode.PIPELINE)}: {central_plan}", flush=True) - final_local_plan = planner.finish_plan(central_plan) - all_reads = storage_reader.read_data(final_local_plan, planner) - all_reads.wait() - return None - - _ = distW.all_gather("read", read_data) - read_cost_time = time.time() - read_start_time - logger.info(f"Finish reading. Cost time: {read_cost_time}s") - - load_ckpt_cost_time = time.time() - load_start_time - logger.info(f"Finish loading. 
Cost time: {load_ckpt_cost_time}s") diff --git a/internlm/checkpoint/vescale/mem_checkpoint.py b/internlm/checkpoint/vescale/mem_checkpoint.py deleted file mode 100644 index 9f482550e..000000000 --- a/internlm/checkpoint/vescale/mem_checkpoint.py +++ /dev/null @@ -1,396 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ -import io -import dataclasses -import os -import torch -from torch import multiprocessing -import threading -from typing import Callable, Dict, Any, DefaultDict, List, Optional -import pickle - -from . import bfile -from internlm.utils.logger import get_logger - -logger = get_logger(__file__) - -if hasattr(torch.storage, "TypedStorage"): - TypedStorage = torch.storage.TypedStorage -elif hasattr(torch.storage, "_TypedStorage"): - TypedStorage = torch.storage._TypedStorage - -# TypedStorage changes in pytorch 2. -if torch.__version__ >= "2": - - def untyped_storage(o): - return o.untyped_storage() - - def location_caster(o): - return o -elif torch.__version__ >= "1.11": - - def untyped_storage(o): - return o.storage()._storage - - def location_caster(o): - return o._storage if isinstance(o, TypedStorage) else o - - -try: - lib = torch.cuda.cudart() -except: - lib = None - - -def _bytes_to_tensor(b: bytes): - # Copied from `_object_to_tensor` in - # https://pytorch.org/docs/2.0/_modules/torch/distributed/distributed_c10d.html - byte_storage = torch.ByteStorage.from_buffer(b) - return torch.ByteTensor(byte_storage) - - -class PinnedStoragePool: - def __init__(self): - self._l = threading.Lock() - self._m = DefaultDict(set) - - def allocate(self, nbytes: int): - with self._l: - # We don't really need storage to have the exact size. So in theory we can find a - # bigger storage that may suit here. But so far we keep everything simple here. - s = self._m[nbytes] - if not s: - t = torch.empty([nbytes], dtype=torch.uint8) - t = t.share_memory_() - if lib is not None and nbytes != 0: - err = lib.cudaHostRegister(t.data_ptr(), t.numel() * t.element_size(), 0) - assert err == 0, err - storage = untyped_storage(t) - s.add(storage) - return s.pop() - - def deallocate(self, s): - # WARNING: Call deallocate when the reference to CPU tensor goes to zero - # so the memory pool will reuse the memory if possbile - # Othterwise, the memory pool will allocate memory on the used memory range, - # leading to cuda error 712 cudaErrorHostMemoryAlreadyRegistered - with self._l: - self._m[s.nbytes()].add(s) - - -GLOBAL_POOL = PinnedStoragePool() - -TID = threading.get_ident() - - -def copy_gpu_tensor_to_cpu_pinned_mem_pool(tensor: torch.Tensor, non_blocking=False) -> torch.Tensor: - """ - Copy a tensor on GPU to pinned memory pool (host CPU memory). 
- The input tensor will not be modified - Args: - tensor: a tensor on cuda device - Return: - a tensor on cpu, whose data is the same as input tensor - """ - m = {} - _old_warning = getattr(torch.storage, "_warn_typed_storage_removal", None) - torch.storage._warn_typed_storage_removal = lambda *args, **kwags: None - - def persistent_id(o): - if torch.is_storage(o) or isinstance(o, TypedStorage): - storage = o - if storage._cdata in m: - return storage._cdata - if storage.device.type != "cpu": - copied = GLOBAL_POOL.allocate(storage.nbytes()) - copied.copy_(storage, non_blocking=non_blocking) - if isinstance(storage, TypedStorage): - copied = storage._new_wrapped_storage(copied) - else: - copied = storage.clone() - m[storage._cdata] = copied - return storage._cdata - return - - b = io.BytesIO() - p = pickle.Pickler(b) - p.persistent_id = persistent_id - p.dump(tensor) - b.seek(0) - up = pickle.Unpickler(b) - up.persistent_load = lambda i: m[i] - cpu_tensor = up.load() - """ - assert type(tensor) == torch.Tensor - storage_obj = tensor.storage() - cpu_storage = GLOBAL_POOL.allocate(storage_obj.nbytes()) - - cpu_storage.copy_(storage_obj, non_blocking=non_blocking) - cpu_tensor = torch.tensor(cpu_storage) - """ - torch.storage._warn_typed_storage_removal = _old_warning - return cpu_tensor - - -def deallocate_cpu_tensor_in_pinned_mem_pool(tensor: torch.Tensor): - "Deallocate CPU tensor in the global pinned memory pool" - GLOBAL_POOL.deallocate(tensor.untyped_storage()) - - -class _CalledOnce: - def __init__(self, func): - self._l = threading.Lock() - self._func = func - self._res = None - self._called = False - - def __call__(self, *args: Any, **kwargs: Any) -> Any: - with self._l: - if self._called: - return self._res - self._called = True - self._res = self._func(*args, **kwargs) - return self._res - - -_LOCATION_TAG_LOCK = threading.Lock() - - -@dataclasses.dataclass -class _SaveArgs: - obj: object - storage_tags: list - pickle_module: __module__ - args: list - kwargs: dict - - -def _serialize_obj_with_map(a: _SaveArgs, as_shared_tensor=False): - """Called to serialize an object to a byte stream or a shared tensor. - - Args: - a (_SaveArgs): The save args consist of the original tensor to serialize, - the location tags, the pickle module and other args - as_shared_tensor (bool): Whether to serialize to a shared tensor or a byte stream. 
- Set False if no inter process communication will happen subsequently - - Returns: - byte stream or shared tensor: The serialized object - - """ - lm = {} - for storage, tag in a.storage_tags: - lm[storage._cdata] = tag - - def location_tag(storage): - loc = lm.get(storage._cdata, None) - if loc is None: - if storage.nbytes() == 0: - # if return None, save will succeed, but load will fail afterwards - return "cpu" - raise ValueError("Unknown storage") - return loc - - with _LOCATION_TAG_LOCK: - old_location_tag = torch.serialization.location_tag - torch.serialization.location_tag = location_tag - - bio = io.BytesIO() - pickle_module = a.pickle_module or pickle - torch.save(a.obj, bio, pickle_module=pickle_module, *a.args, **a.kwargs) - - torch.serialization.location_tag = old_location_tag - b = bio.getvalue() - if not as_shared_tensor: - return b - else: - return _bytes_to_tensor(b).share_memory_() - - -def _write(f, sa): - # Serialize tensor obj directly to a byte stream, no need to convert it - # back to a shared tensor because the whole procedure happens in the same - # process - b = _serialize_obj_with_map(sa) - bfile.safe_atomic_write(f, b) - - -@dataclasses.dataclass -class _PoolArgs: - pinned_pool: PinnedStoragePool - pooled_storages: list - - -class _WriteFunc: - def __init__(self, sa: _SaveArgs, pa: _PoolArgs, async_worker): - self._sa = sa - if self._sa.pickle_module == pickle: - # This makes wa serializable. - self._sa.pickle_module = None - self._pa = pa - self._async_worker = async_worker - - self._enable_mp = async_worker is not None and sa.pickle_module is None - self._des = _CalledOnce(self._des_do_not_call_directly) - self._l = threading.RLock() - self._serialized = None - self._bytes = None - - def _des_do_not_call_directly(self): - for s in self._pa.pooled_storages: - self._pa.pinned_pool.deallocate(s) - - def __del__(self): - self._des() - - @property - def serialized(self): - with self._l: - if self._serialized is None: - if self._enable_mp: - self._serialized = self._async_worker.apply(_serialize_obj_with_map, (self._sa, True)) - else: - self._serialized = _serialize_obj_with_map(self._sa) - self._des() - return self._serialized - - @property - def bytes(self): - if self._bytes is None: - with self._l: - if self._enable_mp: - self._bytes = self.serialized.numpy().tobytes() - else: - self._bytes = self.serialized - return self._bytes - - def __call__(self, file: str = None): - if file is None: - return self.bytes - - if self._async_worker: - self._async_worker.apply(_write, (file, self._sa)) - else: - _write(file, self._sa) - self._des() - - -class TorchCheckpointRecorder: - def __init__( - self, - fast_mode=None, - async_worker: multiprocessing.Pool = None, - pinned_pool=GLOBAL_POOL, - ): - self._thread_id = threading.get_ident() - self._m = {} - - # After 1.11, typed storage is publicly accessible. - condition = torch.__version__ >= "1.11" - self._fast_mode = fast_mode if fast_mode is not None else condition - # Safety check. 
- assert not self._fast_mode or condition - - self._async_worker = async_worker - self._pinned_pool = pinned_pool - - def __enter__(self): - self._old_save = torch.save - torch.save = self._save_wrapper - if self._fast_mode: - self._old_warning = getattr(torch.storage, "_warn_typed_storage_removal", None) - torch.storage._warn_typed_storage_removal = lambda *args, **kwags: None - return self - - def __exit__(self, *args): - torch.save = self._old_save - if self._fast_mode: - if self._old_warning: - torch.storage._warn_typed_storage_removal = self._old_warning - - def _save_wrapper(self, obj, f, pickle_module=pickle, *args, **kwargs): - if threading.get_ident() != self._thread_id or not isinstance(f, (str, os.PathLike)): - return self._old_save(obj, f, pickle_module, *args, **kwargs) - - if self._fast_mode: - func = self._copy_to_buffer(obj, pickle_module, *args, **kwargs) - else: - func = self._save_to_buffer(obj, pickle_module, *args, **kwargs) - - self._m[str(f)] = func - - def _save_to_buffer(self, obj, *args, **kwags): - b = io.BytesIO() - self._old_save(obj, b, *args, **kwags) - - def gen_func(b): - def func(f: str = None): - if f: - return bfile.safe_atomic_write(f, b.getvalue()) - return b.getvalue() - - return func - - return gen_func(b) - - def _copy_to_buffer(self, obj, pickle_module, *args, **kwargs): - m = {} - storage_tags = [] - pooled_storages = [] - - def persistent_id(o): - if torch.is_storage(o) or isinstance(o, TypedStorage): - storage = o - if storage._cdata in m: - return storage._cdata - if storage.device.type != "cpu": - copied = self._pinned_pool.allocate(storage.nbytes()) - pooled_storages.append(copied) - copied.copy_(storage, non_blocking=False) - if isinstance(storage, TypedStorage): - copied = storage._new_wrapped_storage(copied) - else: - copied = storage.clone() - m[storage._cdata] = copied - tag = torch.serialization.location_tag(location_caster(storage)) - storage_tags.append((copied, tag)) - return storage._cdata - return - - b = io.BytesIO() - p = pickle_module.Pickler(b) - p.persistent_id = persistent_id - p.dump(obj) - b.seek(0) - up = pickle_module.Unpickler(b) - up.persistent_load = lambda i: m[i] - nobj = up.load() - - sa = _SaveArgs( - obj=nobj, - storage_tags=storage_tags, - pickle_module=pickle_module, - args=args, - kwargs=kwargs, - ) - pa = _PoolArgs(pinned_pool=self._pinned_pool, pooled_storages=pooled_storages) - - return _WriteFunc(sa, pa, self._async_worker) - - @property - def files(self) -> Dict[str, Callable[[Optional[List[str]]], Optional[bytes]]]: - return self._m diff --git a/internlm/checkpoint/vescale/mem_file_service_pb2.py b/internlm/checkpoint/vescale/mem_file_service_pb2.py deleted file mode 100644 index 5966cd239..000000000 --- a/internlm/checkpoint/vescale/mem_file_service_pb2.py +++ /dev/null @@ -1,66 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# -################################################################################ -# Generated by the protocol buffer compiler. DO NOT EDIT! -# source: checkpoint/utilities/server/mem_file_service.proto -# Protobuf Python Version: 4.25.1 -"""Generated protocol buffer code.""" - -from google.protobuf import descriptor as _descriptor -from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import symbol_database as _symbol_database -from google.protobuf.internal import builder as _builder -# @@protoc_insertion_point(imports) - -_sym_db = _symbol_database.Default() - - -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n2checkpoint/utilities/server/mem_file_service.proto">\n\x1dVeScaleCheckpointWriteRequest\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c\x12\x0c\n\x04name\x18\x08 \x01(\t" \n\x1eVeScaleCheckpointWriteResponse",\n\x1cVeScaleCheckpointReadRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"0\n\x1dVeScaleCheckpointReadResponse\x12\x0f\n\x07\x63ontent\x18\x01 \x01(\x0c"M\n\x1eVeScaleCheckpointRenameRequest\x12\x0b\n\x03src\x18\x01 \x01(\t\x12\x0b\n\x03\x64st\x18\x02 \x01(\t\x12\x11\n\toverwrite\x18\x03 \x01(\x08"!\n\x1fVeScaleCheckpointRenameResponse".\n\x1eVeScaleCheckpointRemoveRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"!\n\x1fVeScaleCheckpointRemoveResponse"/\n\x1fVeScaleCheckpointListdirRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"1\n VeScaleCheckpointListdirResponse\x12\r\n\x05names\x18\x01 \x03(\t".\n\x1eVeScaleCheckpointExistsRequest\x12\x0c\n\x04name\x18\x01 \x01(\t"1\n\x1fVeScaleCheckpointExistsResponse\x12\x0e\n\x06\x65xists\x18\x01 \x01(\x08\x32\xf9\x03\n\x1fVeScaleCheckpointMemFileService\x12L\n\x05Write\x12\x1e.VeScaleCheckpointWriteRequest\x1a\x1f.VeScaleCheckpointWriteResponse"\x00(\x01\x12I\n\x04Read\x12\x1d.VeScaleCheckpointReadRequest\x1a\x1e.VeScaleCheckpointReadResponse"\x00\x30\x01\x12M\n\x06Rename\x12\x1f.VeScaleCheckpointRenameRequest\x1a .VeScaleCheckpointRenameResponse"\x00\x12M\n\x06Remove\x12\x1f.VeScaleCheckpointRemoveRequest\x1a .VeScaleCheckpointRemoveResponse"\x00\x12P\n\x07Listdir\x12 .VeScaleCheckpointListdirRequest\x1a!.VeScaleCheckpointListdirResponse"\x00\x12M\n\x06\x45xists\x12\x1f.VeScaleCheckpointExistsRequest\x1a .VeScaleCheckpointExistsResponse"\x00\x62\x06proto3' -) - -_globals = globals() -_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, "checkpoint.utilities.server.mem_file_service_pb2", _globals) -if _descriptor._USE_C_DESCRIPTORS is False: - DESCRIPTOR._options = None - _globals["_VESCALECHECKPOINTWRITEREQUEST"]._serialized_start = 54 - _globals["_VESCALECHECKPOINTWRITEREQUEST"]._serialized_end = 116 - _globals["_VESCALECHECKPOINTWRITERESPONSE"]._serialized_start = 118 - _globals["_VESCALECHECKPOINTWRITERESPONSE"]._serialized_end = 150 - _globals["_VESCALECHECKPOINTREADREQUEST"]._serialized_start = 152 - _globals["_VESCALECHECKPOINTREADREQUEST"]._serialized_end = 196 - _globals["_VESCALECHECKPOINTREADRESPONSE"]._serialized_start = 198 - _globals["_VESCALECHECKPOINTREADRESPONSE"]._serialized_end = 246 - _globals["_VESCALECHECKPOINTRENAMEREQUEST"]._serialized_start = 248 - _globals["_VESCALECHECKPOINTRENAMEREQUEST"]._serialized_end = 325 - _globals["_VESCALECHECKPOINTRENAMERESPONSE"]._serialized_start = 327 - _globals["_VESCALECHECKPOINTRENAMERESPONSE"]._serialized_end = 360 - _globals["_VESCALECHECKPOINTREMOVEREQUEST"]._serialized_start = 362 - _globals["_VESCALECHECKPOINTREMOVEREQUEST"]._serialized_end = 408 - 
_globals["_VESCALECHECKPOINTREMOVERESPONSE"]._serialized_start = 410 - _globals["_VESCALECHECKPOINTREMOVERESPONSE"]._serialized_end = 443 - _globals["_VESCALECHECKPOINTLISTDIRREQUEST"]._serialized_start = 445 - _globals["_VESCALECHECKPOINTLISTDIRREQUEST"]._serialized_end = 492 - _globals["_VESCALECHECKPOINTLISTDIRRESPONSE"]._serialized_start = 494 - _globals["_VESCALECHECKPOINTLISTDIRRESPONSE"]._serialized_end = 543 - _globals["_VESCALECHECKPOINTEXISTSREQUEST"]._serialized_start = 545 - _globals["_VESCALECHECKPOINTEXISTSREQUEST"]._serialized_end = 591 - _globals["_VESCALECHECKPOINTEXISTSRESPONSE"]._serialized_start = 593 - _globals["_VESCALECHECKPOINTEXISTSRESPONSE"]._serialized_end = 642 - _globals["_VESCALECHECKPOINTMEMFILESERVICE"]._serialized_start = 645 - _globals["_VESCALECHECKPOINTMEMFILESERVICE"]._serialized_end = 1150 -# @@protoc_insertion_point(module_scope) diff --git a/internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py b/internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py deleted file mode 100644 index 4558bfa65..000000000 --- a/internlm/checkpoint/vescale/mem_file_service_pb2_grpc.py +++ /dev/null @@ -1,321 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ -# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! -"""Client and server classes corresponding to protobuf-defined services.""" - -import grpc - -from . import ( - mem_file_service_pb2 as checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2, -) - - -class VeScaleCheckpointMemFileServiceStub: - """Missing associated documentation comment in .proto file.""" - - def __init__(self, channel): - """Constructor. - - Args: - channel: A grpc.Channel. 
- """ - self.Write = channel.stream_unary( - "/VeScaleCheckpointMemFileService/Write", - request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.SerializeToString, - response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.FromString, - ) - self.Read = channel.unary_stream( - "/VeScaleCheckpointMemFileService/Read", - request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.SerializeToString, - response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.FromString, - ) - self.Rename = channel.unary_unary( - "/VeScaleCheckpointMemFileService/Rename", - request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.SerializeToString, - response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.FromString, - ) - self.Remove = channel.unary_unary( - "/VeScaleCheckpointMemFileService/Remove", - request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.SerializeToString, - response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.FromString, - ) - self.Listdir = channel.unary_unary( - "/VeScaleCheckpointMemFileService/Listdir", - request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.SerializeToString, - response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.FromString, - ) - self.Exists = channel.unary_unary( - "/VeScaleCheckpointMemFileService/Exists", - request_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.SerializeToString, - response_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.FromString, - ) - - -class VeScaleCheckpointMemFileServiceServicer: - """Missing associated documentation comment in .proto file.""" - - def Write(self, request_iterator, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def Read(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def Rename(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def Remove(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def Listdir(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - def 
Exists(self, request, context): - """Missing associated documentation comment in .proto file.""" - context.set_code(grpc.StatusCode.UNIMPLEMENTED) - context.set_details("Method not implemented!") - raise NotImplementedError("Method not implemented!") - - -def add_VeScaleCheckpointMemFileServiceServicer_to_server(servicer, server): - rpc_method_handlers = { - "Write": grpc.stream_unary_rpc_method_handler( - servicer.Write, - request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.FromString, - response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.SerializeToString, - ), - "Read": grpc.unary_stream_rpc_method_handler( - servicer.Read, - request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.FromString, - response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.SerializeToString, - ), - "Rename": grpc.unary_unary_rpc_method_handler( - servicer.Rename, - request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.FromString, - response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.SerializeToString, - ), - "Remove": grpc.unary_unary_rpc_method_handler( - servicer.Remove, - request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.FromString, - response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.SerializeToString, - ), - "Listdir": grpc.unary_unary_rpc_method_handler( - servicer.Listdir, - request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.FromString, - response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.SerializeToString, - ), - "Exists": grpc.unary_unary_rpc_method_handler( - servicer.Exists, - request_deserializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.FromString, - response_serializer=checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.SerializeToString, - ), - } - generic_handler = grpc.method_handlers_generic_handler("VeScaleCheckpointMemFileService", rpc_method_handlers) - server.add_generic_rpc_handlers((generic_handler,)) - - -# This class is part of an EXPERIMENTAL API. 
-class VeScaleCheckpointMemFileService: - """Missing associated documentation comment in .proto file.""" - - @staticmethod - def Write( - request_iterator, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.stream_unary( - request_iterator, - target, - "/VeScaleCheckpointMemFileService/Write", - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteRequest.SerializeToString, - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointWriteResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def Read( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_stream( - request, - target, - "/VeScaleCheckpointMemFileService/Read", - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadRequest.SerializeToString, - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointReadResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def Rename( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/VeScaleCheckpointMemFileService/Rename", - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameRequest.SerializeToString, - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRenameResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def Remove( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/VeScaleCheckpointMemFileService/Remove", - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveRequest.SerializeToString, - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointRemoveResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def Listdir( - request, - target, - options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/VeScaleCheckpointMemFileService/Listdir", - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirRequest.SerializeToString, - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointListdirResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) - - @staticmethod - def Exists( - request, - target, 
- options=(), - channel_credentials=None, - call_credentials=None, - insecure=False, - compression=None, - wait_for_ready=None, - timeout=None, - metadata=None, - ): - return grpc.experimental.unary_unary( - request, - target, - "/VeScaleCheckpointMemFileService/Exists", - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsRequest.SerializeToString, - checkpoint_dot_utilities_dot_server_dot_mem__file__service__pb2.VeScaleCheckpointExistsResponse.FromString, - options, - channel_credentials, - insecure, - call_credentials, - compression, - wait_for_ready, - timeout, - metadata, - ) diff --git a/internlm/checkpoint/vescale/mem_server_lib.py b/internlm/checkpoint/vescale/mem_server_lib.py deleted file mode 100644 index 2bf3af242..000000000 --- a/internlm/checkpoint/vescale/mem_server_lib.py +++ /dev/null @@ -1,307 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ -import dataclasses -import io -import grpc -from typing import Tuple -import os -import threading -import contextlib -import pathlib -import subprocess -import time -import queue -from concurrent import futures - -from . import mem_file_service_pb2 -from . 
import mem_file_service_pb2_grpc - - -class _Directory(dict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lock = threading.RLock() - - -@dataclasses.dataclass -class _File: - content: bytes = b"" - - -_CHUNK_SIZE = 2 * 1024 * 1024 - - -def get_mem_server_sock_file(name: str): - return f"/var/tmp/mem_server_{name}.sock" - - -class MemFileServicer(mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceServicer): - def __init__(self): - self._d = _Directory() - - def Write(self, request_iterator, ctx: grpc.ServicerContext): - b = io.BytesIO() - name = None - for req in request_iterator: - if name is None: - if not req.name: - ctx.abort(grpc.StatusCode.INVALID_ARGUMENT, "Name must be specified.") - name = req.name - d, bn = self._iterate_dir(name, ctx, create=True) - b.write(req.content) - if name: - with d.lock: - d[bn] = _File(content=b.getvalue()) - return mem_file_service_pb2.VeScaleCheckpointWriteResponse() - - def Read(self, req, ctx: grpc.ServicerContext): - d, bn = self._iterate_dir(req.name, ctx) - with d.lock: - if bn not in d or not isinstance(d[bn], _File): - ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.name} not found.") - f: _File = d[bn] - cur = 0 - while cur < len(f.content): - yield mem_file_service_pb2.VeScaleCheckpointReadResponse(content=f.content[cur : cur + _CHUNK_SIZE]) - cur += _CHUNK_SIZE - - def Rename(self, req, ctx: grpc.ServicerContext): - src_dir, src_bn = self._iterate_dir(req.src, ctx) - dst_dir, dst_bn = self._iterate_dir(req.dst, ctx) - if src_dir != dst_dir: - ctx.abort(grpc.StatusCode.UNIMPLEMENTED, "Rename across dir is not supported.") - d = src_dir - with d.lock: - if src_bn not in src_bn: - ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.src} is not found.") - if not req.overwrite and dst_bn in d: - ctx.abort(grpc.StatusCode.ALREADY_EXISTS, f"{req.dst} already exists.") - d[dst_bn] = d[src_bn] - del d[src_bn] - return mem_file_service_pb2.VeScaleCheckpointRenameResponse() - - def Remove(self, req, ctx: grpc.ServicerContext): - d, bn = self._iterate_dir(req.name, ctx) - if bn not in d: - ctx.abort(grpc.StatusCode.NOT_FOUND, f"{req.name} not found.") - with d.lock: - del d[bn] - return mem_file_service_pb2.VeScaleCheckpointRemoveResponse() - - def Listdir(self, req, ctx: grpc.ServicerContext): - d, _ = self._iterate_dir(os.path.join(req.name, "*")) - if d is None: - return mem_file_service_pb2.VeScaleCheckpointListdirResponse() - - resp = mem_file_service_pb2.VeScaleCheckpointListdirResponse() - with d.lock: - for name in d: - resp.names.append(name) - return resp - - def Exists(self, req, ctx: grpc.ServicerContext): - d, bn = self._iterate_dir(req.name) - if d is None: - return mem_file_service_pb2.VeScaleCheckpointExistsResponse(exists=False) - with d.lock: - return mem_file_service_pb2.VeScaleCheckpointExistsResponse(exists=bn in d) - - def _iterate_dir(self, name: str, ctx: grpc.ServicerContext = None, create=False) -> Tuple[_Directory, str]: - if ctx is None: - - class FakeCtx: - def abort(*args, **kwargs): - return None, None - - ctx = FakeCtx() - name = str(pathlib.Path(name).absolute())[1:] - parts = name.split("/") - cur = self._d - for part in parts[:-1]: - with cur.lock: - if part not in cur: - if not create: - return ctx.abort(grpc.StatusCode.NOT_FOUND, f"{part} doesn't exist.") - else: - cur[part] = _Directory() - cur = cur[part] - if not isinstance(cur, _Directory): - return ctx.abort( - grpc.StatusCode.ALREADY_EXISTS, - f"{part} already exist as a file.", - ) - return cur, parts[-1] - - -def 
start_server(name: str, force=False): - sock = get_mem_server_sock_file(name) - if os.path.exists(sock) and not force: - raise OSError("Mem server is already running.") - server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) - mem_file_service_pb2_grpc.add_VeScaleCheckpointMemFileServiceServicer_to_server(MemFileServicer(), server) - server.add_insecure_port(f"unix:{sock}") - server.start() - return server - - -# --- Below is general file interface --- - -_STUB_CACHE = {} -_STUB_CACHE_LOCK = threading.Lock() -SCHEMA = "/local_mem/" - - -def get_prefix(name: str): - return SCHEMA + name - - -def _get_mem_name_and_name(path: str): - path = path[len(SCHEMA) :] - pos = path.find("/") - if pos == -1: - return path, "/" - else: - return path[:pos], path[pos:] - - -def _get_stub_and_name( - path: str, -) -> Tuple[mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceStub, str]: - mem_name, name = _get_mem_name_and_name(path) - if mem_name not in _STUB_CACHE: - c = grpc.insecure_channel(f"unix:{get_mem_server_sock_file(mem_name)}") - with _STUB_CACHE_LOCK: - _STUB_CACHE[mem_name] = mem_file_service_pb2_grpc.VeScaleCheckpointMemFileServiceStub(c) - return _STUB_CACHE[mem_name], name - - -class _FileLike: - def __init__(self, name: str, mode: str): - if mode not in ["rb", "wb"]: - raise NotImplementedError(f"{mode} is not implemented.") - self._stub, self._name = _get_stub_and_name(name) - self._mode = mode - self._is_write = "w" in mode - if self._is_write: - self._write_async() - self._read_buf = None - - @property - def read_buf(self): - if self._read_buf is None: - self._read_buf = io.BytesIO() - for resp in self._stub.Read(mem_file_service_pb2.VeScaleCheckpointReadRequest(name=self._name)): - self._read_buf.write(resp.content) - self._read_buf.seek(0) - return self._read_buf - - def __getattr__(self, name): - if not self._is_write: - return getattr(self.read_buf, name) - - def _write_async(self): - self._q = queue.Queue() - - def streaming(): - while True: - content, eof = self._q.get() - if eof: - break - cur = 0 - while cur < len(content): - req = mem_file_service_pb2.VeScaleCheckpointWriteRequest(content=content[cur : cur + _CHUNK_SIZE]) - if cur == 0: - req.name = self._name - yield req - cur += _CHUNK_SIZE - - self._write_future = self._stub.Write.future(streaming()) - - def write(self, content): - self._q.put((content, False)) - - def close(self): - if self._is_write: - self._q.put((None, True)) - self._write_future.result() - - -@contextlib.contextmanager -def open(name, mode) -> io.FileIO: - f = _FileLike(name, mode) - try: - yield f - finally: - f.close() - - -def rename(src, dst, overwrite=False): - stub, src_name = _get_stub_and_name(src) - dst_stub, dst_name = _get_stub_and_name(dst) - if stub != dst_stub: - raise ValueError(f"Rename across mem file system is not supported. 
{src} {dst}") - stub.Rename(mem_file_service_pb2.VeScaleCheckpointRenameRequest(src=src_name, dst=dst_name, overwrite=overwrite)) - - -def remove(name): - stub, subname = _get_stub_and_name(name) - stub.Remove(mem_file_service_pb2.VeScaleCheckpointRemoveRequest(name=subname)) - - -def listdir(name): - try: - stub, subname = _get_stub_and_name(name) - resp = stub.Listdir(mem_file_service_pb2.VeScaleCheckpointListdirRequest(name=subname)) - return list(resp.names) - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.UNAVAILABLE: - return [] - raise - - -def exists(name): - try: - stub, subname = _get_stub_and_name(name) - resp = stub.Exists(mem_file_service_pb2.VeScaleCheckpointExistsRequest(name=subname)) - return resp.exists - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.UNAVAILABLE: - return False - raise - - -# --- interface done --- - - -def start_server_in_new_process(name: str): - filename = os.path.join(os.path.dirname(os.path.abspath(__file__)), "detached_mem_server.py") - return subprocess.Popen(["python3", filename, f"--name={name}"]) - - -def wait_until_fs_ready(name: str, timeout=120): - stub, _ = _get_stub_and_name(os.path.join(SCHEMA, name)) - t0 = time.time() - while time.time() < t0 + timeout: - try: - stub.Listdir(mem_file_service_pb2.VeScaleCheckpointListdirRequest(name="/")) - return True - except grpc.RpcError as e: - if e.code() == grpc.StatusCode.UNAVAILABLE: - time.sleep(0.1) - continue - raise - return False diff --git a/internlm/checkpoint/vescale/meta_type.py b/internlm/checkpoint/vescale/meta_type.py deleted file mode 100644 index a669efba0..000000000 --- a/internlm/checkpoint/vescale/meta_type.py +++ /dev/null @@ -1,38 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ -# meta_type.py saves all constants and data types commonly used in vescale.checkpoint - -from enum import Enum -from typing import Dict, Any, TypeVar -from typing_extensions import Protocol, runtime_checkable - - -STATE_DICT_TYPE = Dict[str, Any] - -MODEL_STR = "model" -OPTIMIZER_STR = "optimizer" -STATE_DICT_STR = "state_dict" - - -class SupportedStrategy(Enum): - Megatron = 0 - FSDP = 1 - VeScale = 2 - - -@runtime_checkable -class Stateful(Protocol): - def state_dict(self) -> Dict[str, Any]: ... - - def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... - - -T = TypeVar("T", bound=Stateful) -CheckpointState = Dict[str, T] diff --git a/internlm/checkpoint/vescale/placement_types.py b/internlm/checkpoint/vescale/placement_types.py deleted file mode 100644 index d5f612704..000000000 --- a/internlm/checkpoint/vescale/placement_types.py +++ /dev/null @@ -1,563 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. 
and/or its affiliates. -################################################################################ - -from dataclasses import dataclass -from typing import List, NamedTuple, Optional, Tuple, cast - -import torch -import torch.distributed.distributed_c10d as c10d - -from .device_mesh import DeviceMesh - - -class Placement: - # base class Placement type - - # convenient utils to check for placement types - def is_shard(self, dim: Optional[int] = None) -> bool: - if dim is not None and isinstance(self, Shard): - return self.dim == dim - else: - return isinstance(self, Shard) - - def is_interleaved_shard(self, dim: Optional[int] = None) -> bool: - if dim is not None and isinstance(self, InterleavedShard): - return self.dim == dim - else: - return isinstance(self, InterleavedShard) - - def is_replicate(self) -> bool: - return isinstance(self, Replicate) - - def is_partial(self) -> bool: - return isinstance(self, Partial) - - def serialize_to_tensor(self, device) -> torch.Tensor: - if self.is_replicate(): - return torch.tensor([0, 0, 0], device=device, dtype=torch.int64) - elif self.is_partial(): - return torch.tensor([1, 0, 0], device=device, dtype=torch.int64) - elif self.is_shard(): - return torch.tensor([2, self.dim, 0], device=device, dtype=torch.int64) - elif self.is_interleaved_shard(): - return torch.tensor([3, self.dim, self.interleaved_size], device=device, dtype=torch.int64) - - @staticmethod - def serialize_from_tensor(tensor: torch.Tensor): - if tensor[0] == 0: - return Replicate() - elif tensor[0] == 1: - return Partial() - elif tensor[0] == 2: - return Shard(dim=tensor[1].item()) - elif tensor[0] == 3: - return InterleavedShard(dim=tensor[1].item(), interleaved_size=tensor[2].item()) - - -class Shard(Placement): - # shard placement, shard on a dim - def __init__(self, dim: int): - self.dim = dim - - def _split_tensor( - self, - tensor: torch.Tensor, - num_chunks: int, - *, - with_padding: bool = True, - contiguous: bool = True, - ) -> Tuple[List[torch.Tensor], List[int]]: - """ - This function uses torch.chunk to split a tensor into num_chunks shards along - the Shard placement dimension, and return a list of shards with their pad sizes. - - Keyword args: - with_padding (bool, optional): when True, we pad the tensor on the last - few ranks before calling the collectives (i.e. scatter/all_gather, etc.). - This is because collectives usually require equal size tensor inputs - - Example: - >>> Given a 2D global tensor with Shard(0) - >>> Run this method: - >>> torch.chunk(torch.tensor([[i] * 2 for i in range(13)]), num_chunks=6, dim=0) - - tensor1([[0, 0], - [1, 1], - [2, 2]]) - - tensor2([[3, 3], - [4, 4], - [5, 5]]) - - tensor3([[6, 6], - [7, 7], - [8, 8]]) - - tensor4([[ 9, 9], - [10, 10], - [11, 11]]) - - tensor5([[12, 12], - [, ], - [, ]]) - - ([[, ], - [, ], - [, ]]) - """ - assert self.dim <= tensor.ndim, f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" - assert tensor.size(self.dim) > 0, f"Tensor size along dim{self.dim} is 0. There is nothing to be sharded." 
- - # chunk tensor over dimension `dim` into n slices with padding if necessary - tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) - # compute the chunk size inline with ``torch.chunk`` (round up to int) - full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks - - # Compute chunk size for each chunk for ``self.dim`` - chunk_sizes = [tensor_list[idx].size(self.dim) if idx < len(tensor_list) else 0 for idx in range(num_chunks)] - # Compute pad size on each chunk - pad_sizes = [full_chunk_size - chunk_size for chunk_size in chunk_sizes] - - # Reuse tensor to fill empty chunk with empty tensor - num_empty_tensors = num_chunks - len(tensor_list) - if num_empty_tensors > 0: - tensor_size = list(tensor_list[0].size()) - tensor_size = [size if idx != self.dim else 0 for idx, size in enumerate(tensor_size)] - tensor = tensor.new_zeros(tensor_size) # (allocate empty chunk) - for _ in range(num_empty_tensors): - tensor_list.append(tensor) - - if with_padding or contiguous: - shard_list = [] - for shard, pad_size in zip(tensor_list, pad_sizes): - # Fill the empty tensor with zeroes with padding. - if with_padding and pad_size > 0: - shard = self._pad_tensor(shard, pad_size) - shard = shard.contiguous() if contiguous else shard - shard_list.append(shard) - return shard_list, pad_sizes - else: - return tensor_list, pad_sizes - - def _pad_tensor( - self, - tensor: torch.Tensor, - pad_size: int, - ) -> torch.Tensor: - pad = [0, 0] * (tensor.ndim - self.dim) - pad[-1] = pad_size - return torch.nn.functional.pad(tensor, pad) - - def _unpad_tensor( - self, - tensor: torch.Tensor, - pad_size: int, - ) -> torch.Tensor: - return tensor.narrow( - self.dim, - start=0, - length=tensor.size(self.dim) - pad_size, - ) - - def _local_shard_size_on_dim( - self, - size_on_dim: int, - num_chunks: int, - rank: int, - return_offset: bool = False, - ) -> Tuple[int, int]: - """ - returns the local shard size and offset on a given tensor dim - """ - assert ( - size_on_dim >= num_chunks - ), f"Size to be sharded on dim {self.dim} must be at least as large as the number of devices in that dimension {num_chunks}" - - # Compute the chunk size inline with ``torch.chunk`` - full_chunk_size = (size_on_dim + num_chunks - 1) // num_chunks - - # Compute chunk size for each chunk on the dimension. - chunk_sizes = [ - max( - min(size_on_dim, full_chunk_size * (idx + 1)) - full_chunk_size * idx, - 0, - ) - for idx in range(num_chunks) - ] - local_shard_size = chunk_sizes[rank] - - local_offset_on_dim = -1 - if return_offset: - # Return global tensor dim size of current dimension if for empty shard - # to represent the end of the corresponding tensor dim. 
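A small, self-contained check of the padding behaviour documented for `Shard._split_tensor` above (13 rows chunked 6 ways along dim 0), assuming the `Shard` class is importable from the placement_types module being removed here:

import torch
from internlm.checkpoint.vescale.placement_types import Shard  # import path assumed

t = torch.tensor([[i] * 2 for i in range(13)])
shards, pad_sizes = Shard(dim=0)._split_tensor(t, num_chunks=6, with_padding=True)
print([s.shape for s in shards])   # six shards, each padded to torch.Size([3, 2])
print(pad_sizes)                   # [0, 0, 0, 0, 2, 3]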
- local_offset_on_dim = sum(chunk_sizes[:rank]) - - return (local_shard_size, local_offset_on_dim) - - def __hash__(self) -> int: - ret = self.dim + 128 # restrict sharding dim in [-128, +128]; should be sufficient - assert ret >= 0 - return ret - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Shard): - return False - return self.dim == other.dim - - def __repr__(self) -> str: - """ - machine readable representation of the Shard placement - """ - return f"Shard(dim={self.dim})" - - def __str__(self) -> str: - """human readable representation of the Shard placement""" - return f"S({self.dim})" - - -class Replicate(Placement): - # replicate placement - def __hash__(self) -> int: - # every replicate placement is the same - return -1 - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Replicate): - return False - return True - - def __repr__(self) -> str: - """ - machine readable representation of the Replicate placement - """ - return "Replicate()" - - def __str__(self) -> str: - """ - human readable representation of the Replicate placement - """ - return "R" - - -class Partial(Placement): - # This is a default partial placement with element-wise reduce op - # when doing reduction it follows the contract of `_to_replicate` - # and `_to_shard` to do the reduction and convert the local tensor - # to the corresponding state (replicate or shard) - # - # We can implement custom reductions as needed by subclassing this - # class and override those contracts. - - def __init__(self, reduce_op: c10d.ReduceOp.RedOpType = c10d.ReduceOp.SUM): - self.reduce_op: c10d.ReduceOp.RedOpType = reduce_op - - def __hash__(self) -> int: - ret = -3 - hash(self.reduce_op) # hash(reduce_op) gives 0~8 - assert ret <= -3 - return ret - - def __eq__(self, other: object) -> bool: - if not isinstance(other, Partial): - return False - return self.reduce_op == other.reduce_op - - def __repr__(self) -> str: - """ - machine readable representation of the Partial placement - """ - return f"Partial(reduce_op={self.reduce_op})" - - def __str__(self) -> str: - """ - human readable representation of the Partial placement - """ - return "P" - - -class InterleavedShard(Shard): - """ - The major difference between this placement and Shard is that the global - tensor with a `InterleavedShard` placement is not contiguous. But you can - always treat a InterleavedShard(dim=x, interleaved_size=y) as a - Shard(dim=x+1)) on a tensor by reshaping the original one from - ``[..., size(x), ...]`` to ``[..., y, size(x) // y, ...]`` - - NOTE: We currently don't support padding in InterleavedShard, which means - we cannot interleaved shard a tensor when it's size is not divisible by - the multiply of interleaved_size and corresponding mesh size. - """ - - def __init__(self, dim: int, interleaved_size: int): - self.dim = dim - # TODO: make this attribute a list to support multi interleaved shard - self.interleaved_size = interleaved_size - - def _split_tensor( - self, - tensor: torch.Tensor, - num_chunks: int, - *, - contiguous: bool = True, - ) -> Tuple[List[torch.Tensor]]: - assert self.dim <= tensor.ndim, f"Interleaved Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" - assert tensor.size(self.dim) > 0, f"Tensor size along dim {self.dim} is 0. There is nothing to be sharded." - assert ( - tensor.size(self.dim) % self.interleaved_size == 0 - ), f"Tensor size along dim {self.dim} is not a multiple of interleaved size {self.interleaved_size}." 
- assert ( - tensor.size(self.dim) // self.interleaved_size - ) % num_chunks == 0, "InterleavedShard doesn't allow padding" - - # step 1: reshape tensor - tensor = tensor.view(tensor.shape[: self.dim] + (self.interleaved_size, -1) + tensor.shape[self.dim + 1 :]) - - # step 2: split tensor - # chunk tensor over dimension `dim` into n slices with padding if necessary - tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim + 1)) - - # step 3: reshape back - result_list = [] - for t in tensor_list: - if contiguous: - t = t.contiguous() - # NOTE: view op might be not okay here, because tensor returned by chunk op - # is not contiguous. - shard = t.reshape(tensor.shape[: self.dim] + (-1,) + tensor.shape[self.dim + 2 :]) - result_list.append(shard) - - return result_list - - def _local_shard_size_on_dim( - self, - size_on_dim: int, - num_chunks: int, - rank: int, - return_offset: bool = False, - ) -> Tuple[int, int]: - """ - returns the local shard size and offset on a given tensor dim. - NOTE: argument ``rank`` and ``return_offset`` is useless here. The reason for - keeping them is to align this API with the one of ``Shard`` placement. - """ - assert ( - size_on_dim >= num_chunks - ), f"Size to be sharded on dim {self.dim} must be at least as large as the number of devices in that dimension {num_chunks}" - - # Compute the chunk size inline with ``torch.chunk`` - full_chunk_size = size_on_dim // num_chunks - return (full_chunk_size, None) - - def __hash__(self) -> int: - assert self.dim >= 0 and self.interleaved_size >= 0, "negatives (-1 & -2) can result in hash collison" - return hash((self.dim, self.interleaved_size)) - - def __repr__(self) -> str: - return f"InterleavedShard(dim={self.dim}, interleaved_size={self.interleaved_size})" - - def __str__(self) -> str: - return f"IS({self.dim}, {self.interleaved_size})" - - def __eq__(self, other: object) -> bool: - if not isinstance(other, InterleavedShard): - return False - return self.dim == other.dim and self.interleaved_size == other.interleaved_size - - -class TensorMeta(NamedTuple): - # simple named tuple to represent tensor metadata - # intentionally to stay simple only for sharding - # propagation purposes. - shape: torch.Size - stride: Tuple[int, ...] - dtype: torch.dtype - - def __hash__(self) -> int: - assert isinstance(self.stride, Tuple) - return hash((self.shape, self.stride, self.dtype)) - - def __eq__(self, __o: object) -> bool: - if not isinstance(__o, TensorMeta): - return False - return ( - self.shape == __o.shape # type: ignore[union-attr] - and self.stride == __o.stride # type: ignore[union-attr] - and self.dtype == __o.dtype # type: ignore[union-attr] - ) - - -# used internally to propagate the placements - - -@dataclass -class DTensorSpec: - mesh: DeviceMesh - placements: Tuple[Placement, ...] - - # tensor meta will only be set during sharding propagation - tensor_meta: Optional[TensorMeta] = None - - def __hash__(self) -> int: - # hashing and equality check for DTensorSpec are used to cache the sharding - # propagation results. We only need to consider the mesh, placements, tensor_meta. 
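And a quick illustration of the reshape trick behind `InterleavedShard._split_tensor` above: `InterleavedShard(dim=0, interleaved_size=2)` on a length-8 tensor behaves like viewing it as (2, 4) and sharding the second dim, so the resulting shards interleave instead of staying contiguous (import path again assumed):

import torch
from internlm.checkpoint.vescale.placement_types import InterleavedShard  # import path assumed

t = torch.arange(8)
shards = InterleavedShard(dim=0, interleaved_size=2)._split_tensor(t, num_chunks=2)
print(shards)   # [tensor([0, 1, 4, 5]), tensor([2, 3, 6, 7])]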
- assert isinstance(self.placements, Tuple) and all(isinstance(p, Placement) for p in self.placements) - - return hash( - ( - self.mesh, - tuple(self.placements), - self.tensor_meta, # None is hashable - ) - ) - - def __eq__(self, __o: object) -> bool: - if not ( - isinstance(__o, DTensorSpec) and self.mesh == __o.mesh and tuple(self.placements) == tuple(__o.placements) - ): - return False - return self.tensor_meta == __o.tensor_meta # None included - - def __str__(self) -> str: - """ - human readable representation of the DTensorSpec - """ - if len(self.placements) == 1: - placement_str = str(self.placements[0]) - else: - placement_str = str(self.placements) - - if self.tensor_meta is not None: - tensor_shape = str(tuple(self.tensor_meta.shape)) - else: - tensor_shape = "unknown shape" - - return f"Spec({placement_str} on {tensor_shape})" - - @property - def shape(self) -> torch.Size: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return self.tensor_meta.shape - - @property - def ndim(self) -> int: - if self.tensor_meta is None: - raise ValueError("tensor_meta is not set") - return len(self.tensor_meta.shape) - - @property - def num_shards(self) -> int: - num_shards = 1 - for i, placement in enumerate(self.placements): - if placement.is_shard() or placement.is_interleaved_shard(): - num_shards *= self.mesh.size(i) - return num_shards - - @property - def dim_map(self) -> List[int]: - """ - dim_map is a property we derive from `placements` of - the distributed tensor. It simply return a list of ints - where dim_map[i] denotes the sharding mapping to the mesh - dimension, and len(dim_map) == dist_tensor.ndim - dim_map[i] = -1: means tensor dim i replicate on mesh - dim_map[i] = j: means tensor dim i shard on mesh dim j - - For example, we have a dist tensor that have the shape of - [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: - [Shard(1)], the dim_map of this placement would be: - [-1, 0, -1]. This representation is pretty helpful during - sharding propagation where we could know exactly each - tensor dimension is sharded or not. - - Note that if placements contains `Partial`, we have to - explicitly deal with it, so that when we create a DTensorSpec - with dim_map, we could properly record the pending sums. - """ - # dims mapping of dist tensor sharding - # return size of tensor ndim, -1 represent replicate - # and int >=0 represent shard on that device mesh dim - r = [-1] * self.ndim - for i, placement in enumerate(self.placements): - if placement.is_shard(): - shard_dim = placement.dim - # NOTE: this might lead to other problems, pay attention. - # relax this check, allow shard one tensor dim twice. - # if r[shard_dim] > -1: - # raise ValueError( - # f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," - # " DTensor operator implementation does not support things like hybrid" - # " sharding strategies yet (i.e. [Shard(0), Shard(0)])" - # ) - r[shard_dim] = i - return r - - @property - def sums(self) -> List[int]: - """ - sums is a property we derive from `placements` of the - distributed tensor. It simply return a list of ints where - sums[i] denotes the pending sum (partial) on mesh dim i - """ - return [idx for idx, placement in enumerate(self.placements) if placement.is_partial()] - - @classmethod - def from_dim_map( - cls, - mesh: DeviceMesh, - dim_map: List[int], - sums: List[int], - tensor_meta: Optional[TensorMeta] = None, - ) -> "DTensorSpec": - """ - Construct a DTensorSpec from dim_map list and pending sum. 
- - Args: - mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec - dim_map (List[int]): a list of integer that represents sharding on each - tensor dimension, see `dim_map` property doc for details - sums (List[int]): a list of integer that represents the dist tensor have - pending sum on which device mesh dimension. - tensor meta (TensorMeta): DTensor metadata - - Return: - a class:`DTensorSpec` object - """ - # by default replicate on device mesh dims - placements: List[Placement] = [Replicate() for _ in range(mesh.ndim)] - - # find all mesh dims that need pending reductions - for s in sums: - placements[s] = Partial() - - for i, m in enumerate(dim_map): - if m >= 0: - placement = placements[m] - if placement.is_shard(): - placement = cast(Shard, placement) - raise RuntimeError( - f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" - ) - elif placement.is_partial(): - raise RuntimeError(f"DeviceMesh dimension {m} cannot be both shard and partial!") - placements[m] = Shard(i) - - return cls(mesh, tuple(placements), tensor_meta=tensor_meta) - - def is_replicated(self): - """ - return True if the current DTensorSpec replicates on all mesh dims (devices) - """ - return all(placement.is_replicate() for placement in self.placements) - - def is_partial(self): - """ - return True if the current DTensorSpec is partial on all mesh dims (devices) - """ - return len(self.placements) == 1 and self.placements[0].is_partial() diff --git a/internlm/checkpoint/vescale/save_state_dict.py b/internlm/checkpoint/vescale/save_state_dict.py deleted file mode 100644 index 20d5883d9..000000000 --- a/internlm/checkpoint/vescale/save_state_dict.py +++ /dev/null @@ -1,204 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -import os -import pickle -from typing import Optional, Tuple, List - - -import torch.distributed as dist -from torch.distributed.checkpoint.utils import _DistWrapper -from .filesystem import FileSystemWriter - - -from torch.distributed.checkpoint.planner import SavePlanner -from torch.distributed.checkpoint.metadata import Metadata -from torch.distributed.checkpoint.storage import WriteResult -from torch.distributed.checkpoint.default_planner import DefaultSavePlanner -from .meta_type import STATE_DICT_TYPE -from internlm.utils.logger import get_logger -import time -from concurrent.futures import Future -from internlm.core.context import global_context as gpc -from internlm.core.context import ParallelMode - - - - -logger = get_logger(__file__) - -from .vescale_planner import VeScaleSavePlanner - - -def save_state_dict( - state_dict: STATE_DICT_TYPE, - path: str, - process_group: Optional[dist.ProcessGroup] = None, - coordinator_rank: int = 0, - no_dist: bool = False, - planner: Optional[SavePlanner] = None, - async_io: bool = True, - last_write_futures: Future[List[WriteResult]] = None, - io_workers=None, - is_optimizer=False, -) -> Tuple[Metadata, Future[List[WriteResult]]]: - """ - [veScale version] Saves a distributed model in SPMD style. 
Fix sub-group storage. - Args and usage is the same as `torch.distributed.checkpoint.save_state_dict`. - """ - # Step 0: create distributed world based on process group and coordinator rank - distW = _DistWrapper(process_group, not no_dist, coordinator_rank) - if process_group: - distW.coordinator_rank = dist.get_global_rank(process_group, distW.coordinator_rank) - if planner is None: - planner = DefaultSavePlanner() - assert planner is not None - - global_metatadata = None - - storage_writer = FileSystemWriter(path) - - # Step 1: all processes create local write plan, - # then coordinator gathers all local plans and create global plan. - def local_step(): - logger.debug("Start local step of planning") - if isinstance(planner, VeScaleSavePlanner): - local_plan, p2p_tensors_info = planner.create_local_plan(is_optimizer=is_optimizer) - local_plan = storage_writer.prepare_local_plan(local_plan, p2p_tensors_info) - print(f"save local_plan {dist.get_rank()}, zero={gpc.get_local_rank(ParallelMode.ZERO1)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}: {local_plan}", flush=True) - # if dist.get_rank() in [0, 1, 2, 3]: - # print(f"save local_plan {dist.get_rank()}, {gpc.get_local_rank(ParallelMode.TENSOR)}: {local_plan}", flush=True) - else: - raise AssertionError("Unsupported planner for planning") - logger.debug("Finish local step of planning") - # print(f"local_plan {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {local_plan}", flush=True) - return local_plan - - def global_step(all_local_plans): - logger.debug("Start global step of planning") - nonlocal global_metatadata - assert planner is not None - # print(f"planner.coordinator_rank {gpc.get_global_rank()}: {planner.is_coordinator }", flush=True) - all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans) - print(f"save global_metatadata {dist.get_rank()}, zero={gpc.get_local_rank(ParallelMode.ZERO1)}, tp={gpc.get_local_rank(ParallelMode.TENSOR)}: {global_metatadata}", flush=True) - all_local_plans = storage_writer.prepare_global_plan(all_local_plans) - logger.debug("End global step of planning") - # print(f"global_plan {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {all_local_plans}", flush=True) - return all_local_plans - - # Step 2: all processes write data from GPUs to pinned memory pool, then dump to local path - # then coordinator write meta-data to local path. - def write_data(async_io: bool = False, io_workers=io_workers): - logger.debug("Start writing data") - assert planner is not None - # print(f"central_plan {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {central_plan}", flush=True) - final_local_plan = planner.finish_plan(central_plan) - if isinstance(planner, VeScaleSavePlanner): - # Use pinned memory pool and mult_processing for dumping ckpt to local directory efficiently - print(f"write_data: {write_data}", flush=True) - all_write_futures = storage_writer.write_data(final_local_plan, planner, async_io, io_workers, is_optimizer) - logger.debug("Finish writing data") - if async_io: - return all_write_futures - else: - # Gather write results. 
- values = [] - for fut in all_write_futures: - # values += fut.get() - values += fut.result() - return values - else: - raise AssertionError("Unsupported planner for writing data") - - def finish_checkpoint(all_results): - logger.debug("Start writing metadata") - assert global_metatadata is not None, f"rank: {distW.get_rank()} has no global_metadata" - storage_writer.finish(metadata=global_metatadata, results=all_results) - logger.debug("Finish writing metadata") - print(f"global_metatadata: {global_metatadata}", flush=True) - return global_metatadata - - assert planner is not None - planner.set_up_planner(state_dict, distW.is_coordinator) - storage_writer.set_up_storage_writer(distW.is_coordinator) - - # Wait for last write futures to finish. - if last_write_futures: - print(f"last_write_futures: {last_write_futures}", flush=True) - logger.info("Start waiting for last write events.") - last_write_start_time = time.time() - for fut in last_write_futures: - fut.result() - last_write_time = time.time() - last_write_start_time - logger.info(f"Finish waiting for last write events. Time cost: {last_write_time}s") - - # Each worker bypass the `reduce_scatter()` and `all_reduce()` if finding cached central_plan and metadata. - # NOTE: it fails when the plans of partial workers change while others keep the same. - logger.info("Start planning.") - plan_start_time = time.time() - cached_data = None - - # if isinstance(planner, VeScaleSavePlanner): - # central_plan = distW.reduce_scatter("plan", local_step, global_step) - # else: - # raise AssertionError("Unsupported planner for saving checkpoint") - - if isinstance(planner, VeScaleSavePlanner): - cached_data = planner.lookup_plan_meta() - if cached_data: - logger.info("Plan cache hit. Reuse existing plan") - central_plan, _ = cached_data - # _ = local_step() #attn - else: - logger.info("Plan cache miss. The model/optimizer appears for the first time.") - - central_plan = distW.reduce_scatter("plan", local_step, global_step) - else: - raise AssertionError("Unsupported planner for saving checkpoint") - - - - plan_cost_time = time.time() - plan_start_time - logger.info(f"Finish planning. Time cost: {plan_cost_time}s") - - logger.info("Start storing") - store_local_start_time = time.time() - write_futures = [] - if isinstance(planner, VeScaleSavePlanner): - if cached_data: - logger.info("Metdata cache hit. Reuse existing metadata") - _, final_storage_metadata = cached_data - write_results = write_data(async_io=async_io) - # Be sure to write cache metadata to .metadata file - # Otherwises only the first checkpoint has .metadata - # which leads to error when loading other checkpoints - if distW.is_coordinator: - with (storage_writer.path / ".metadata.tmp").open("wb") as metadata_file: - pickle.dump(final_storage_metadata, metadata_file) - os.fsync(metadata_file.fileno()) - - (storage_writer.path / ".metadata.tmp").rename(storage_writer.path / ".metadata") - - if async_io: - write_futures = write_results - else: - logger.info("Metadata cache miss. The model/optimizer appears for the first time.") - # First time do synchronous storing to get final_storage_metatdata. - # Determine which communication topology to use. 
- final_storage_metadata = distW.all_reduce("write", write_data, finish_checkpoint) - assert central_plan is not None - assert final_storage_metadata is not None - planner.cache_plan_meta(central_plan, final_storage_metadata) #attn - else: - raise AssertionError("Unsupported planner for writing data and metadata") - store_local_cost_time = time.time() - store_local_start_time - logger.info(f"Finish storing. Time cost: {store_local_cost_time}s") - - return final_storage_metadata, write_futures diff --git a/internlm/checkpoint/vescale/vescale_checkpointer.py b/internlm/checkpoint/vescale/vescale_checkpointer.py deleted file mode 100644 index 0327aee27..000000000 --- a/internlm/checkpoint/vescale/vescale_checkpointer.py +++ /dev/null @@ -1,293 +0,0 @@ -################################################################################ -# -# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -################################################################################ - -from concurrent.futures import ProcessPoolExecutor -from .base_checkpointer import BaseCheckpointer -from .meta_type import CheckpointState, MODEL_STR, OPTIMIZER_STR -from .save_state_dict import save_state_dict -from .load_state_dict import load_state_dict -from .vescale_planner import VeScaleSavePlanner, VeScaleLoadPlanner -from .devicemesh_api import VESCALE_DEVICE_MESH -from . 
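For orientation, a hedged sketch of calling the `save_state_dict` routine above directly (in practice it is driven by the VeScaleCheckpointer that follows). `model`, the target path, and the planner instance are placeholders, and an initialized torch.distributed process group is assumed.

from internlm.checkpoint.vescale.save_state_dict import save_state_dict      # import paths assumed
from internlm.checkpoint.vescale.vescale_planner import VeScaleSavePlanner

metadata, write_futures = save_state_dict(
    state_dict=model.state_dict(),       # placeholder model
    path="/mnt/ckpts/step_1000/model",   # placeholder path
    planner=VeScaleSavePlanner(),
    async_io=True,
    is_optimizer=False,
)
for fut in write_futures:                # may be empty on the first, synchronous save
    fut.result()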
import bfile -import os -# from .distributed_optimizer import initialize_optimizer_state -import torch.distributed as dist -import torch -from internlm.utils.logger import get_logger -import atexit -from internlm.core.context import global_context as gpc -from internlm.core.context import ParallelMode -from internlm.solver.optimizer import HybridZeroOptimizer -from internlm.train.pipeline import map_fqn_local_to_global - - - -logger = get_logger(__file__) - -VESCALE_SUPPORTED_TYPES = {MODEL_STR, OPTIMIZER_STR} -NUM_IO_WORKER = 1 - - -def deduplicate_2d_list(lst): - seen = set() - deduplicated_list = [] - for item in lst: - # Convert the inner list to a tuple for hashing - tuple_item = tuple(sorted(item)) # Sorting to treat [1, 2] and [2, 1] as the same - if tuple_item not in seen: - seen.add(tuple_item) - # Convert back to list to preserve original type - deduplicated_list.append(item) - return deduplicated_list - - -def get_optim_ckpt_process_group(): - # Get the process group based on current rank - # The processes with same pipeline stage ID - # are in the same process group - device_mesh = VESCALE_DEVICE_MESH.get() - sub_mesh = device_mesh.get_submesh(mesh_dims=["TP", "DP"]) - two_dim_list = sub_mesh.mesh.tolist() - flatten_rank_list = [item for sublist in two_dim_list for item in sublist] - all_flatten_lists = [[] for _ in range(dist.get_world_size())] - dist.all_gather_object(all_flatten_lists, flatten_rank_list) - all_flatten_lists = deduplicate_2d_list(all_flatten_lists) - my_rank = dist.get_rank() - pg = None - for rank_list in all_flatten_lists: - new_pg = dist.new_group(ranks=flatten_rank_list) - if my_rank in rank_list: - pg = new_pg - return pg - - -class VeScaleCheckpointer(BaseCheckpointer): - """ - The Checkpointer class for VeScale, A PyTorch Native Auto Parallelism Framework - """ - - save_planner = VeScaleSavePlanner() - load_planner = VeScaleLoadPlanner() - - optim_ckpt_proces_group = None - for key in VESCALE_SUPPORTED_TYPES: - BaseCheckpointer.state_io_workers[key] = ProcessPoolExecutor(max_workers=NUM_IO_WORKER) - BaseCheckpointer.state_write_futures[key] = [] - - @classmethod - def save( - cls, - path: str, - checkpoint_state: CheckpointState, - async_checkpoint: bool = False, - ): - """ - async_checkpoint: A boolean value indicating if saving checkpoint asynchronously, - i.e. after dumping tensors from GPU memory to Host memory, - the training program can continue training immediately. - Then vescale.checkpoint will serialize tensors and dumping to the persistent storage asynchronously. - """ - # Check if we support saving the components - for key in checkpoint_state.keys(): - if key not in VESCALE_SUPPORTED_TYPES: - raise ValueError(f"{key} is not supported by VeScaleCheckpointer") - - if path.startswith("local:"): - path = path.split(':')[1] - assert ':' not in path, f"{path} is not valid for universal checkpoint!" - # Start saving checkpoint - for key, value in checkpoint_state.items(): - if key == MODEL_STR: - # Get model path - model_path = os.path.join(path, MODEL_STR) - print(f"model_path: {path}, {model_path}", flush=True) - # Create a "model" folder on under root path - if dist.get_rank() == 0: - bfile.makedirs(model_path) - # if not os.path.exists(path): - # os.makedirs(path, exist_ok=True) - dist.barrier() - # Save model. 
- _, new_write_futures = save_state_dict( - state_dict=value.state_dict(), - path=model_path, - process_group=None, - coordinator_rank=0, - no_dist=False, - planner=cls.save_planner, - async_io=async_checkpoint, - last_write_futures=cls.state_write_futures[MODEL_STR], - io_workers=cls.state_io_workers[MODEL_STR], - is_optimizer=False, - ) - # Record new write futures. - cls.state_write_futures[MODEL_STR] = new_write_futures - dist.barrier() #att - elif key == OPTIMIZER_STR: - # adamW hybrid zero optim - assert isinstance(value, HybridZeroOptimizer) - optimizer_state = value.state_dict() - - # Create a "optimizer" folder on under root path - # to save different parts of optimizer - optim_root_path = os.path.join(path, OPTIMIZER_STR) - print(f"optim_root_path: {optim_root_path}", flush=True) - if dist.get_rank() == 0: - bfile.makedirs(optim_root_path) - dist.barrier() - # Get process group for saving optimizer, - # All processes with the same pipeline rank are in the same pg - # if not cls.optim_ckpt_proces_group: - # cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() - - # Get optimizer path based on PP rank - # pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() #attn - # pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - # optimizer_path = os.path.join(optim_root_path, f"pp_{pp_rank}") - optimizer_path = optim_root_path - # Create optimizer folder on under root path - # if dist.get_rank(cls.optim_ckpt_proces_group) == 0: - # bfile.makedirs(optimizer_path) - # dist.barrier() - # Save optimizer - _, new_write_futures = save_state_dict( - state_dict=optimizer_state["sharded_optimizer_state"], - path=optimizer_path, - process_group=None, - coordinator_rank=0, - no_dist=False, - planner=cls.save_planner, - async_io=async_checkpoint, - last_write_futures=cls.state_write_futures[OPTIMIZER_STR], - io_workers=cls.state_io_workers[OPTIMIZER_STR], - is_optimizer=True, - ) - # Record new write futures. - cls.state_write_futures[OPTIMIZER_STR] = new_write_futures - - optimizer_state.pop("sharded_optimizer_state") - # print(f"after_pop {gpc.get_global_rank}: {optimizer_state}", flush=True) - if gpc.get_global_rank() == 0: - print(f"global_optimizer_state: {os.path.join(path, 'global_optimizer_state.pt')}", flush=True) - torch.save(optimizer_state, os.path.join(path, "global_optimizer_state.pt")) - dist.barrier() - - @classmethod - def load( - cls, - path: str, - checkpoint_state: CheckpointState, - broadcast_checkpoint: bool = False, - ): - """ - broadcast_checkpoint: A boolean value decides if load a model replica from one data parallel process group - then broadcast tensors to other data parallel process group using GPUs - to reduce the file system access - For example, when data parellel size = 2, - processes with data parallel rank = 0 load model from file system - then broadcast it to processes with data parallel rank = 1 - """ - if path.startswith("local:"): - path = path.split(':')[1] - assert ':' not in path, f"{path} is not valid for universal checkpoint!" - # Add warning - if bfile.is_local_path(path): - logger.warning( - "The local path for checkpointing should be accessible to all ranks. It can be a NFS/FUSE path" - ) - # Check if we support loading the component. 
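The optimizer branch above splits the state in two: the bulky sharded_optimizer_state goes through the sharded save path on every rank, while global rank 0 dumps whatever remains into global_optimizer_state.pt. A single-process sketch of that split; the planner, process-group, and barrier calls are elided, and every key except sharded_optimizer_state is illustrative:

    import os
    import tempfile

    import torch


    def split_optimizer_state(optimizer_state: dict):
        """Separate the large per-rank shards from the small replicated remainder."""
        sharded = optimizer_state.pop("sharded_optimizer_state")
        return sharded, optimizer_state


    if __name__ == "__main__":
        state = {
            "sharded_optimizer_state": {"layers.0.w.exp_avg": torch.zeros(8)},
            "base_optim_states": {"step": 10},   # hypothetical replicated metadata
        }
        sharded, global_part = split_optimizer_state(state)
        root = tempfile.mkdtemp()
        # In the checkpointer, `sharded` is written through save_state_dict() on every
        # rank, while only rank 0 writes the replicated remainder below.
        torch.save(global_part, os.path.join(root, "global_optimizer_state.pt"))
        print(sorted(sharded), os.listdir(root))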
- for key in checkpoint_state.keys(): - if key not in VESCALE_SUPPORTED_TYPES: - raise ValueError(f"{key} is not supported by VeScaleCheckpointer") - - # Start loading checkpoint - for key, value in checkpoint_state.items(): - if key == MODEL_STR: - # Get model path - model_path = os.path.join(path, MODEL_STR) - # Get model state dictionary - model_state = value.state_dict() - p = ", ".join(map(lambda item: f"Key: {item[0]}, Shape: {item[1].shape}", model_state.items())) - # print(f"model_state {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {p})", flush=True) - # Set process group - if broadcast_checkpoint: - # model_load_process_group = VESCALE_DEVICE_MESH.get_data_parallel_dim_groups() - model_load_process_group = gpc.get_group(ParallelMode.DATA) - else: - model_load_process_group = None - # Load model - load_state_dict( - state_dict=model_state, - path=model_path, - process_group=model_load_process_group, - coordinator_rank=0, - no_dist=False, - planner=cls.load_planner, - broadcast_tensors=broadcast_checkpoint, - ) - # Load back to model - value.load_state_dict(model_state) #att - elif key == OPTIMIZER_STR: - # Get process group for loading optimizer, - # All processes with the same pipeline rank are in the same pg - # if not cls.optim_ckpt_proces_group: - # cls.optim_ckpt_proces_group = get_optim_ckpt_process_group() - # Get optimizer path based on TP and PP ranks - # pp_rank = VESCALE_DEVICE_MESH.get_pipeline_parallel_rank() #attn - pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - optimizer_path = os.path.join(path, OPTIMIZER_STR) - print(f"optimizer_path: {optimizer_path}", flush=True) - # Initialize optimizer states - # initialize_optimizer_state(value) - # Get optimizer state - optimizer_state = value.state_dict() - # Load optimizer state dictionary - load_state_dict( - state_dict=optimizer_state["sharded_optimizer_state"], - path=optimizer_path, - process_group=None, - coordinator_rank=0, - no_dist=False, - planner=cls.load_planner, - broadcast_tensors=False, - is_optimizer=True, - ) - #att check len equal - # print(f"optimizer_state {gpc.get_global_rank()}: {len(optimizer_state)}, {optimizer_state.keys()}", flush=True) - # fqn_list = checkpoint_state['optimizer'].fqn_list - # print(f"fqn_list {gpc.get_global_rank()}: {len(fqn_list[0]) + len(fqn_list[1])}, {fqn_list[0], fqn_list[1]}", flush=True) - # Load back to optimizer - global_optimizer_state = torch.load(os.path.join(path, "global_optimizer_state.pt")) - value.load_state_dict(optimizer_state["sharded_optimizer_state"], global_optimizer_state) - dist.barrier() - - @classmethod - def __cleanup(cls): - """ - Wait for all write futures to finish before exit, then do the cleanup works. - - WARNING: this method cannot be called by the users. - """ - cls.save_planner.clear_cache() - BaseCheckpointer._cleanup_futures() - - @classmethod - def _register_cleanup(cls): - atexit.register(VeScaleCheckpointer.__cleanup) - - -VeScaleCheckpointer._register_cleanup() diff --git a/internlm/checkpoint/vescale/vescale_planner.py b/internlm/checkpoint/vescale/vescale_planner.py deleted file mode 100644 index 598440d0d..000000000 --- a/internlm/checkpoint/vescale/vescale_planner.py +++ /dev/null @@ -1,288 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. 
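The load() method above can restore the model on one data-parallel rank and broadcast it to the rest of the DATA group, trading network traffic for fewer filesystem reads. A rough sketch of that idea, assuming the script is launched with torchrun so the default env:// rendezvous works; the real path broadcasts sharded tensors on GPU rather than one pickled object:

    import torch
    import torch.distributed as dist


    def broadcast_load(path: str):
        """Rank 0 reads the checkpoint; every other rank receives it over the wire."""
        state = torch.load(path, map_location="cpu") if dist.get_rank() == 0 else None
        box = [state]
        dist.broadcast_object_list(box, src=0)   # ships keys and tensors to all ranks
        return box[0]


    if __name__ == "__main__":
        # e.g. torchrun --nproc-per-node=2 this_script.py  (the path below is illustrative)
        dist.init_process_group("gloo")
        if dist.get_rank() == 0:
            torch.save({"w": torch.randn(2, 2)}, "/tmp/demo_ckpt.pt")
        dist.barrier()   # make sure the file exists before anyone loads it
        print(dist.get_rank(), sorted(broadcast_load("/tmp/demo_ckpt.pt")))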
-################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ -import io -import dataclasses -import torch -from typing import Any, Dict, Union, List, Tuple, Optional -from torch.distributed.checkpoint.default_planner import ( - DefaultSavePlanner, - DefaultLoadPlanner, -) - -import mmh3 - -from .common import P2PTensorsInfo, sort_rank_ranges, PlanLRUCache, custom_dedup_tensors -import math -import torch.distributed as dist -from torch.distributed.checkpoint.planner import SavePlan, LoadPlan, WriteItem, ReadItem -from torch.distributed.checkpoint.metadata import MetadataIndex, Metadata -from .distributed_optimizer import OptimizerStateSpec -# from vescale.dtensor import DTensor - -from .vescale_planner_helpers import _create_write_items, _create_read_items, find_state_dict_object -from .devicemesh_api import VESCALE_DEVICE_MESH -from .meta_type import STATE_DICT_STR -from internlm.utils.logger import get_logger -from internlm.core.context import global_context as gpc -from internlm.core.context import ParallelMode -from torch.distributed.checkpoint.planner import WriteItemType - - -logger = get_logger(__file__) -__all__ = [ - "VeScaleSavePlanner", - "VeScaleLoadPlanner", - "create_default_local_load_plan", - "create_default_local_save_plan", -] - -class DTensor: - pass - - -class VeScaleLoadPlanner(DefaultLoadPlanner): - """ - A planner class for loading vescale checkpoint using PyTorch DCP - """ - - def __init__(self): - super().__init__() - - def create_local_plan(self, is_optimizer=False) -> LoadPlan: - return create_default_local_load_plan(self.state_dict, self.metadata, is_optimizer) - - def resolve_tensor(self, read_item: ReadItem): - tensor = self.lookup_tensor(read_item.dest_index) - return self.transform_tensor(read_item, tensor) - - def lookup_tensor(self, index: MetadataIndex) -> torch.Tensor: - """ - This is an extension from the planner interface to make it easy to extend the default planner - """ - return find_state_dict_object(self.state_dict, index) - -from internlm.train.pipeline import map_fqn_local_to_global - -def create_default_local_load_plan(state_dict: Dict[str, Any], metadata: Metadata, is_optimizer) -> LoadPlan: - """ - A function for creating local loading plan for loading checkpoint - """ - # print(f"metadata: {metadata}", flush=True) - requests = [] - for fqn, obj in state_dict.items(): - global_fqn = fqn - if not is_optimizer: - if fqn in map_fqn_local_to_global: - global_fqn = map_fqn_local_to_global[fqn] - # print(f"map_fqn_local_to_global {gpc.get_local_rank(ParallelMode.PIPELINE)}: {fqn}, {global_fqn}", flush=True) - md = metadata.state_dict_metadata[global_fqn] - if isinstance(obj, DTensor): - assert False - if obj.device_mesh.get_coordinate() is not None: - requests += _create_read_items(fqn, md, obj) - elif isinstance(obj, OptimizerStateSpec): - assert False - # If the state is distributed on multiple dp ranks - # Read with local_shape, then in DOptimizer then - # get flaaten to 1D and get the part belonging to current dp rank - if obj.dp_ranks_ranges: - obj.local_tensor = torch.zeros( - obj.local_shape, dtype=obj.local_tensor.dtype, device=obj.local_tensor.device - ) - requests += _create_read_items(fqn, md, obj) - else: - # If the state is owned by only one dp rank - # Read directly - obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) - requests += _create_read_items(fqn, 
md, obj) - else: - item = _create_read_items(fqn, global_fqn, md, obj) - # print(f"_create_read_items {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)}: {item}", flush=True) - requests += item - return LoadPlan(requests) - - -class VeScaleSavePlanner(DefaultSavePlanner): - """ - A planner class for saving vescale checkpoint using PyTorch DCP - """ - - def __init__(self): - super().__init__() - self._plan_cache = PlanLRUCache() - - def resolve_data(self, write_item: WriteItem, fqn=None) -> Union[torch.Tensor, io.BytesIO]: - assert write_item.type != WriteItemType.BYTE_IO - object = self.lookup_object(write_item.index, fqn) - return self.transform_object(write_item, object) - - def create_local_plan(self, is_optimizer=False) -> Tuple[SavePlan, P2PTensorsInfo]: - plan, p2p_tensors_info = create_default_local_save_plan(self.state_dict, self.is_coordinator, is_optimizer) - # print(f"save before replace local_plan {dist.get_rank()}: {plan}", flush=True) - if self.flatten_state_dict: - plan = dataclasses.replace(plan, planner_data=self.mappings) - self.plan = plan - return self.plan, p2p_tensors_info - - def lookup_object(self, index: MetadataIndex, fqn=None) -> Any: - return find_state_dict_object(self.state_dict, index, fqn) - - def lookup_plan_meta(self) -> Optional[Tuple[SavePlan, Metadata]]: - # if not hasattr(self, STATE_DICT_STR): - # return None - # else: - # device_mesh = VESCALE_DEVICE_MESH.get() - # plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) - # return self._plan_cache.get(plan_key) - - if not hasattr(self, STATE_DICT_STR): - return None - else: - plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator)) - return self._plan_cache.get(plan_key) - - def cache_plan_meta(self, new_plan: SavePlan, new_metadata: Metadata) -> None: - # device_mesh = VESCALE_DEVICE_MESH.get() - # plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator, device_mesh)) - # self._plan_cache.put(plan_key, new_plan, new_metadata) - - print(f"new_plan {gpc.get_global_rank()}: {new_plan}", flush=True) - print(f"new_metadata {gpc.get_global_rank()}: {new_metadata}", flush=True) - - plan_key = hash((frozenset(self.state_dict.keys()), self.is_coordinator)) - print(f"Before GPU Memory Allocated {gpc.get_global_rank()}: {torch.cuda.memory_allocated() /1024/1024} bytes", flush=True) - print(f"Before GPU Memory Cached {gpc.get_global_rank()}: {torch.cuda.memory_reserved() /1024/1024} bytes", flush=True) - self._plan_cache.put(plan_key, new_plan, new_metadata) - print(f"After GPU Memory Allocated {gpc.get_global_rank()}: {torch.cuda.memory_allocated() /1024/1024} bytes", flush=True) - print(f"After GPU Memory Cached {gpc.get_global_rank()}: {torch.cuda.memory_reserved() /1024/1024} bytes", flush=True) - - def clear_cache(self) -> None: - self._plan_cache.clear() - - def dedup_plans(self, all_plans: List[SavePlan]) -> List[SavePlan]: - # Use customized deduplicate function for load balance - all_plans = custom_dedup_tensors(all_plans) - return all_plans - - def create_dedup_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: - # Disable DCP's dedup replicated tensors function - self.dedup_replicated_tensors = False - rst_value = super().create_global_plan(all_plans) - return rst_value - - def create_global_plan(self, all_plans: List[SavePlan]) -> Tuple[List[SavePlan], Metadata]: - # Disable DCP's dedup replicated tensors function - self.dedup_replicated_tensors = False - # Use customized deduplicate 
function for load balance - all_plans = custom_dedup_tensors(all_plans) #att - rst_value = super().create_global_plan(all_plans) - return rst_value - - -def create_default_local_save_plan(state_dict: Dict[str, Any], is_coordinator: bool, is_optimizer=False) -> SavePlan: - """ - A function for creating local saving plan for saving checkpoint. - """ - requests = [] - # Key: fqn - # Value: dictionary (Key is the process rank, value is tensor to receive) - recv_tensors = {} - - send_p2p_reqs = [] - recv_p2p_reqs = {} - # if is_optimizer: - # state_dict = state_dict["unflatten_fp32_weights"] - - for fqn, obj in state_dict.items(): - # Since DTensor supports submesh, adding extra check to ensure _create_write_items() - # gets called only when the current rank is part of the mesh for the corresponding DTensor. - if isinstance(obj, DTensor): - assert False - if obj.device_mesh.get_coordinate() is not None: - requests += _create_write_items(fqn, obj) - elif isinstance(obj, OptimizerStateSpec): - assert False #attn - # Create write requests if the process is the real writer - if obj.dp_ranks_ranges: - process_list = [] - for rank, param_range in obj.dp_ranks_ranges.items(): - process_list.append((rank, len(param_range))) - sorted_list = sort_rank_ranges(process_list) - writer_rank = sorted_list[mmh3.hash(fqn) % len(sorted_list)][0] - send_ops_to_start = [] - recv_ops_to_start = {} - # Case 1: I am writer - # Receive tensors - logger.debug(f"fqn={fqn} is a tensor across dp ranks. writer rank={writer_rank}") - if dist.get_rank() == writer_rank: - recv_tensors[fqn] = {} - for k, param_range in obj.dp_ranks_ranges.items(): - if k != dist.get_rank(): - recv_tensor = torch.zeros( - (len(param_range),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device - ) - recv_op = dist.P2POp( - op=dist.irecv, - tensor=recv_tensor, - peer=k, - group=gpc.get_group(ParallelMode.DATA), #att - ) - recv_tensors[fqn][k] = (recv_tensor, param_range) - recv_ops_to_start[k] = recv_op - else: - # Case 2: I am not writer - # Send my tensor - send_op = dist.P2POp( - op=dist.isend, - tensor=obj.local_tensor, - peer=writer_rank, - group=gpc.get_group(ParallelMode.DATA), #att - ) - send_ops_to_start.append(send_op) - - send_reqs = [] - recv_reqs = [] - if send_ops_to_start: - send_reqs = dist.batch_isend_irecv(send_ops_to_start) - if recv_ops_to_start: - recv_reqs = dist.batch_isend_irecv(list(recv_ops_to_start.values())) - - if send_reqs: - send_p2p_reqs.extend(send_reqs) - - if recv_reqs: - recv_p2p_reqs[fqn] = recv_reqs - else: - obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) - requests += _create_write_items(fqn, obj) - elif isinstance(obj, (torch.Tensor)) or is_coordinator: - item = _create_write_items(fqn, obj, is_optimizer=is_optimizer) - # print(f"_create_write_items {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.TENSOR)}: {item}", flush=True) - requests += item - # if dist.get_rank() == 1: - # print(f"requests: {requests}", flush=True) - - # Padding the states across DP ranks - # Merge the tensors later - writer_rank = dist.get_rank() - for fqn in recv_tensors.keys(): - assert False - obj = state_dict[fqn] - new_local_tensor = torch.zeros( - (math.prod(obj.local_shape),), dtype=obj.local_tensor.dtype, device=obj.local_tensor.device - ) - new_local_tensor[obj.dp_ranks_ranges[writer_rank].start : obj.dp_ranks_ranges[writer_rank].end] = ( - obj.local_tensor - ) - obj.local_tensor = new_local_tensor - - obj.local_tensor = obj.local_tensor.reshape(obj.local_shape) - requests += 
_create_write_items(fqn, obj) - return SavePlan(requests), P2PTensorsInfo(recv_tensors, send_p2p_reqs, recv_p2p_reqs) diff --git a/internlm/checkpoint/vescale/vescale_planner_helpers.py b/internlm/checkpoint/vescale/vescale_planner_helpers.py deleted file mode 100644 index 513123634..000000000 --- a/internlm/checkpoint/vescale/vescale_planner_helpers.py +++ /dev/null @@ -1,307 +0,0 @@ -################################################################################ -# Copyright (c) Meta Platforms, Inc. and affiliates -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -from typing import Any, List -import torch -from torch.distributed._shard.metadata import ShardMetadata -from torch.distributed.checkpoint.planner import WriteItem, ReadItem, WriteItemType, LoadItemType, TensorWriteData -from torch.distributed.checkpoint.metadata import ( - STATE_DICT_TYPE, - STORAGE_TYPES, - MetadataIndex, - ChunkStorageMetadata, - BytesStorageMetadata, - TensorStorageMetadata, -) -from torch.distributed._shard.sharded_tensor import TensorProperties -from torch.distributed.checkpoint.resharding import ( - _check_shard_metadata_pair_overlap, - _shards_get_overlap_region_wrt_saved_tensor, -) -from internlm.core.context import global_context as gpc -from internlm.core.context import ParallelMode -# from vescale.dtensor import DTensor -from ._utils import compute_local_shape, compute_local_offset -from .distributed_optimizer import OptimizerStateSpec -from internlm.utils.logger import get_logger -logger = get_logger(__file__) - -class DTensor: - pass - - -def _create_write_items_for_dtensor(fqn, tensor: DTensor) -> WriteItem: - assert False - sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) - offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) - - return WriteItem( - index=MetadataIndex(fqn=fqn, offset=offsets), - type=WriteItemType.SHARD, - tensor_data=TensorWriteData( - chunk=ChunkStorageMetadata(offsets=offsets, sizes=sizes), - properties=TensorProperties.create_from_tensor(tensor._local_tensor), # keep out of autograd - size=tensor.size(), - ), - ) - - -def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata: - assert False - sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) - offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) - return ChunkStorageMetadata(offsets=offsets, sizes=sizes) - -from internlm.train.pipeline import map_layer_attr, map_fqn_local_to_global -def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor, is_optimizer=False) -> WriteItem: - offsets = torch.Size([0] * len(tensor.size())) - size = tensor.size() - - if not is_optimizer: - if 'layer' in fqn: - assert fqn in map_fqn_local_to_global, f"{fqn}" - # print(f"map_fqn: {fqn}, {map_fqn_local_to_global[fqn]}", flush=True) - fqn = map_fqn_local_to_global[fqn] - - map_fqn = fqn - if map_fqn not in map_layer_attr: - # os exp_avg, exp_avg_sq - map_fqn = fqn.rsplit('.', 1)[0] - assert map_fqn in map_layer_attr, f"{gpc.get_global_rank()}, {gpc.get_local_rank(ParallelMode.PIPELINE)}, 
{gpc.get_local_rank(ParallelMode.ZERO1)}, {is_optimizer}, {fqn}" - offsets = torch.Size(map_layer_attr[map_fqn]['offset']) - size = torch.Size(map_layer_attr[map_fqn]['complete_size']) - - # if map_fqn in map_fqn_local_to_global: - # global_fqn = map_fqn_local_to_global[map_fqn] - # if fqn != map_fqn: - # fqn = global_fqn + f".{fqn.rsplit('.', 1)[1]}" - # else: - # fqn = global_fqn - - # if offsets[0] != 0: - # k = offsets[0] // gpc.get_local_rank(ParallelMode.TENSOR) - # assert child.weight.complete_size[0] % k == 0 and child.weight.complete_size[0] // k == gpc.get_world_size(ParallelMode.TENSOR), f"{child.weight.complete_size}, {child.weight.offset}" - # if child.weight.offset[1] != 0: - # k = child.weight.offset[1] // gpc.get_local_rank(ParallelMode.TENSOR) - # assert child.weight.complete_size[1] % k == 0 and child.weight.complete_size[1] // k == gpc.get_world_size(ParallelMode.TENSOR), f"{child.weight.complete_size}, {child.weight.offset}" - - return WriteItem( - index=MetadataIndex(fqn, offsets), - type=WriteItemType.SHARD, - tensor_data=TensorWriteData( - chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()), - properties=TensorProperties.create_from_tensor(tensor), - size=size, - ), - ) - - -def _create_write_item_for_optimizer_state(fqn, object: OptimizerStateSpec) -> WriteItem: - sizes = object.local_shape - offsets = object.global_offset - - return WriteItem( - index=MetadataIndex(fqn=fqn, offset=offsets), - type=WriteItemType.SHARD, - tensor_data=TensorWriteData( - chunk=ChunkStorageMetadata(offsets=offsets, sizes=sizes), - properties=TensorProperties.create_from_tensor(object.local_tensor), - size=object.global_shape, - ), - ) - - -def _create_write_item_for_bytesio(fqn: str, bytes: Any): - return WriteItem( - index=MetadataIndex(fqn), - type=WriteItemType.BYTE_IO, - ) - - -def _create_write_items(fqn: str, object: Any, is_optimizer=False) -> List[WriteItem]: - if isinstance(object, DTensor): - assert False - return [_create_write_items_for_dtensor(fqn, object)] - elif isinstance(object, torch.Tensor): - return [_create_write_item_for_tensor(fqn, object, is_optimizer=is_optimizer)] - elif isinstance(object, OptimizerStateSpec): - assert False - return [_create_write_item_for_optimizer_state(fqn, object)] - else: - assert False - return [_create_write_item_for_bytesio(fqn, object)] - - -def _create_read_item_for_tensor(dest_index, dest_offsets, storage_index, storage_offsets, lengths): - return ReadItem( - type=LoadItemType.TENSOR, - dest_index=dest_index, - dest_offsets=torch.Size(dest_offsets), - storage_index=storage_index, - storage_offsets=torch.Size(storage_offsets), - lengths=torch.Size(lengths), - ) - - -def create_read_items_for_chunk_list( - fqn: str, - global_fqn: str, - checkpoint_md: TensorStorageMetadata, - local_chunks: List[ChunkStorageMetadata], -) -> List[ReadItem]: - """ - Creates a list of ``ReadItem`` based on the checkpoint and local chunks. - - This applies the resharding algorithm and computes the reads needed - to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``. - - Args: - fqn (str) : The state_dict FQN to pass to ``ReadItem``. - checkpoint_md (TensorStorageMetadata): metadata for a given tensor - from a checkpoint. - local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be - loaded. - - Returns: - A list of ``ReadItem`` that will satisfy all input chunks. 
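_create_write_item_for_tensor above tags each local shard with its offset inside the full tensor and with the tensor's complete size, looked up from map_layer_attr. For an evenly split tensor-parallel weight those two values can be derived as in this sketch; the helper name is invented, and the real offsets come from the attributes attached in pipeline.py:

    import torch


    def shard_metadata(full_size, tp_rank, tp_world_size, dim=0):
        """Offset and local size of an even split of `full_size` along `dim`."""
        assert full_size[dim] % tp_world_size == 0, "uneven split not handled here"
        local = list(full_size)
        local[dim] //= tp_world_size
        offset = [0] * len(full_size)
        offset[dim] = tp_rank * local[dim]
        return torch.Size(offset), torch.Size(local)


    # A hypothetical (8, 4) weight split along dim 0 across 4 tensor-parallel ranks:
    for rank in range(4):
        off, size = shard_metadata((8, 4), rank, 4, dim=0)
        print(rank, list(off), list(size))   # offsets 0, 2, 4, 6 on dim 0; local size (2, 4)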
- """ - read_items = [] - # this is a naive quadratic algo that can be optimized later - for idx, shard in enumerate(local_chunks): - for storage_idx, storage_md in enumerate(checkpoint_md.chunks): - if not _check_shard_metadata_pair_overlap(shard, storage_md): - continue - - storage_offsets = [] - dest_offsets = [] - lengths = [] - for ( - dim, - offset_for_saved_tensor, - offset_for_current_tensor, - length, - ) in _shards_get_overlap_region_wrt_saved_tensor(saved_shard=storage_md, current_shard=shard): - storage_offsets.append(offset_for_saved_tensor) - dest_offsets.append(offset_for_current_tensor) - lengths.append(length) - - read_items.append( - _create_read_item_for_tensor( - dest_index=MetadataIndex(fqn, shard.offsets, idx), - dest_offsets=dest_offsets, - storage_index=MetadataIndex(global_fqn, storage_md.offsets, storage_idx), - storage_offsets=storage_offsets, - lengths=lengths, - ) - ) - return read_items - - -def _create_chunk_from_tensor(global_fqn, tensor: torch.Tensor) -> ChunkStorageMetadata: - - # sizes = torch.Size(compute_local_shape(tensor.shape, tensor.device_mesh, tensor.placements)) - # offsets = torch.Size(compute_local_offset(tensor.shape, tensor.device_mesh, tensor.placements)) - # return ChunkStorageMetadata(offsets=offsets, sizes=sizes) - - if global_fqn not in map_layer_attr: - # os exp_avg, exp_avg_sq - global_fqn = global_fqn.rsplit('.', 1)[0] - assert global_fqn in map_layer_attr, f"{global_fqn}" - offsets = torch.Size(map_layer_attr[global_fqn]['offset']) - - - # return ChunkStorageMetadata(offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()) - return ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()) - - - -def _create_read_item_for_byteio(dest_index, dest_offset, storage_index, storage_offset, length): - return ReadItem( - type=LoadItemType.BYTE_IO, - dest_index=dest_index, - dest_offsets=torch.Size((dest_offset,)), - storage_index=storage_index, - storage_offsets=torch.Size((storage_offset,)), - lengths=torch.Size((length,)), - ) - - -def _create_chunk_from_optimizer_spec(obj: OptimizerStateSpec) -> ChunkStorageMetadata: - return ChunkStorageMetadata(offsets=obj.global_offset, sizes=obj.local_shape) - - -def _create_read_items(fqn: str, global_fqn, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]: - if not isinstance(md, BytesStorageMetadata): - if isinstance(obj, DTensor): - assert False - local_chunks = [_create_chunk_from_dtensor(obj)] - elif isinstance(obj, torch.Tensor): - local_chunks = [_create_chunk_from_tensor(global_fqn, obj)]#att - # print(f"local_chunks {gpc.get_global_rank()} {gpc.get_local_rank(ParallelMode.PIPELINE)} {fqn}: {local_chunks}", flush=True) - elif isinstance(obj, OptimizerStateSpec): - assert False - local_chunks = [_create_chunk_from_optimizer_spec(obj)] - else: - raise ValueError( - f"Invalid checkpoint metadata for {fqn}, " + f"expected BytesStorageMetadata but found {type(md)}" - ) - return create_read_items_for_chunk_list(fqn, global_fqn, md, local_chunks) - else: - assert False - return [ - _create_read_item_for_byteio( - dest_index=MetadataIndex(fqn), - dest_offset=0, - storage_index=MetadataIndex(fqn), - storage_offset=0, - length=0, - ) - ] - - -def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata: - return ChunkStorageMetadata( - offsets=torch.Size(shard_md.shard_offsets), - sizes=torch.Size(shard_md.shard_sizes), - ) - - -def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor: - if isinstance(tensor, DTensor): - assert False - return tensor.to_local() - 
if index.offset is not None: - # special case looking up a tensor by origin - if index.offset == torch.Size([0] * len(tensor.size())): - return tensor - raise ValueError(f"FQN: '{index.fqn}' is not a DTensor, can't find by offset: '{index.offset}'") - return tensor - - -def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex, fqn=None) -> Any: - # Called when real writing happened - # The filesystem writer calls resolve_data , then it will - # call find_state_dict_object - if fqn is None: - fqn = index.fqn - if fqn not in state_dict: - raise ValueError(f"Could not find FQN: '{fqn}'") - obj = state_dict[fqn] - - # if isinstance(obj, torch.Tensor): #att - # return find_tensor_shard(obj, index) - if isinstance(obj, OptimizerStateSpec): - assert False - return obj.local_tensor - # elif index.offset is not None: - # raise ValueError( - # f"FQN: '{index.fqn}' is not a DTensor, it is a {type(obj)} can't find by offset: '{index.offset}'" - # ) - return obj From 7ed87ab7490f0697413b10ecd58e39ee0072cc86 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 25 Dec 2024 14:04:56 +0800 Subject: [PATCH 06/12] fix --- internlm/core/trainer_builder.py | 7 ++----- internlm/solver/optimizer/hybrid_zero_optim.py | 2 -- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/internlm/core/trainer_builder.py b/internlm/core/trainer_builder.py index c9e9d4ab8..d0ef284d4 100644 --- a/internlm/core/trainer_builder.py +++ b/internlm/core/trainer_builder.py @@ -154,6 +154,7 @@ def __init__( self._set_attributes( kwargs["profiling"], train_dl, val_dls, train_state, optimizer, beta2_scheduler, isp_communicator ) + super().__init__(engine, scheduler) def _setup_time_and_logging(self) -> str: @@ -257,6 +258,7 @@ def fit(self): """ self.train() train_iter = iter(self.train_dl) + with initialize_llm_profile(profiling=self.profiling, start_time=self.current_time) as prof: gc.disable() for batch_count in range(self.train_state.batch_count, gpc.config.data.total_steps): @@ -277,9 +279,6 @@ def _process_batch(self, batch_count: int, train_iter, prof) -> bool: timer("one-batch").stop() return False - if gpc.is_rank_for_log(): - print(f"start trainging {batch_count}", flush=True) - timer("fwd-bwd").start() loss, moe_loss = self._forward_backward(batch) timer("fwd-bwd").stop() @@ -299,8 +298,6 @@ def _process_batch(self, batch_count: int, train_iter, prof) -> bool: def _load_and_prepare_batch(self, batch_count: int, train_iter): batch, train_iter = load_new_batch(train_dl=self.train_dl, train_iter=train_iter, train_state=self.train_state) - # torch.save(batch, "/mnt/petrelfs/lijiaxing/InternEvo/test_data_batch.py") - self.train_state.batch_count = batch_count self.train_state.num_consumed_samples_in_epoch += len(batch[1]) if batch[0].get("type_ids", None) is not None: diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 31b927979..dc9ae7c57 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -207,7 +207,6 @@ def __init__( # No flat fp16 buffer is allocated if the process has no parameters. 
if rank not in self.param_group_no_params_ranks[group_id]: tensor_list = self._param_store.get_fp16_params_by_rank_group(rank, group_id) - with torch.no_grad(): flat_tensor = flatten(tensor_list) flat_tensor = flat_tensor.data.to(get_current_device()) @@ -268,7 +267,6 @@ def _partition_param_list(self, group_id, param_group): param_list = param_group["params"] sorted_params = sorted(param_list, key=lambda x: x.numel(), reverse=True) - for i, param in enumerate(sorted_params): if param.requires_grad is False: continue From 92e19594d0e9bff0370a2d53e2d260a47806304c Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 25 Dec 2024 15:36:04 +0800 Subject: [PATCH 07/12] fix ci --- .github/workflows/lint_check.yaml | 2 +- internlm/checkpoint/universal_checkpoint/filesystem.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/lint_check.yaml b/.github/workflows/lint_check.yaml index ab1a532e2..fe86bd05a 100644 --- a/.github/workflows/lint_check.yaml +++ b/.github/workflows/lint_check.yaml @@ -10,7 +10,7 @@ on: jobs: # lint check can be auto-executed by the workflow lint-check: - runs-on: ubuntu-latest + runs-on: ubuntu-20.04 steps: - uses: actions/checkout@v3 diff --git a/internlm/checkpoint/universal_checkpoint/filesystem.py b/internlm/checkpoint/universal_checkpoint/filesystem.py index 1129dd302..6efdc454c 100644 --- a/internlm/checkpoint/universal_checkpoint/filesystem.py +++ b/internlm/checkpoint/universal_checkpoint/filesystem.py @@ -410,7 +410,6 @@ def __init__( super().__init__() self.path = Path(path) self.single_file_per_rank = single_file_per_rank - # self.single_file_per_rank = False self.sync_files = sync_files self.worker_count = worker_count self.per_process_copy_ahead = per_process_copy_ahead From 7b29800f772b9baddd5aaf6d4a82cd728ab842ed Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 25 Dec 2024 15:37:43 +0800 Subject: [PATCH 08/12] fix ci --- internlm/initialize/launch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py index b593b0788..8578eb514 100644 --- a/internlm/initialize/launch.py +++ b/internlm/initialize/launch.py @@ -259,7 +259,7 @@ def args_sanity_check(): # If 'auto_resume' is not given, we set it to True, so internlm can have opportunity # to auto-load latest checkpoint. 
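args_sanity_check above back-fills missing ckpt options with safe defaults, and the hunk continuing below adds universal_ckpt the same way. A dict-based sketch of the pattern; the real Config class in internlm does much more than this stand-in:

    class Config(dict):
        """Minimal stand-in for internlm's Config; only the defaulting hook is mimicked."""

        def _add_item(self, key, value):
            self[key] = value


    ckpt = Config(checkpoint_every=50)
    if "auto_resume" not in ckpt:
        ckpt._add_item("auto_resume", True)
    if "universal_ckpt" not in ckpt:
        # key spelling below mirrors the source config
        ckpt._add_item("universal_ckpt", dict(enable=False, aysnc_save=False, broadcast_load=False))
    print(ckpt["auto_resume"], ckpt["universal_ckpt"]["enable"])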
ckpt._add_item("auto_resume", True) - + if "universal_ckpt" not in ckpt: ckpt._add_item("universal_ckpt", dict(enable=False, aysnc_save=False, broadcast_load=False)) From 8aaddaffd43aa5f46bf3c4c298b23697ec5f73a9 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Wed, 25 Dec 2024 15:49:12 +0800 Subject: [PATCH 09/12] fix ci --- internlm/checkpoint/universal_checkpoint/common.py | 2 +- internlm/checkpoint/universal_checkpoint/filesystem.py | 2 +- internlm/checkpoint/universal_checkpoint/save_state_dict.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/internlm/checkpoint/universal_checkpoint/common.py b/internlm/checkpoint/universal_checkpoint/common.py index 011077a4f..a2e6d6bb3 100644 --- a/internlm/checkpoint/universal_checkpoint/common.py +++ b/internlm/checkpoint/universal_checkpoint/common.py @@ -42,7 +42,7 @@ class PlanLRUCache: """ def __init__(self) -> None: - self._cache: OrderedDict[Hashable, Tuple[SavePlan, Metadata]] = OrderedDict() + self._cache: OrderedDict[Hashable, Tuple[SavePlan, Metadata]] = OrderedDict() # pylint: disable=E1136 self._capacity = _MAX_CACHE_SIZE def get(self, key: Hashable) -> Optional[Tuple[SavePlan, Metadata]]: diff --git a/internlm/checkpoint/universal_checkpoint/filesystem.py b/internlm/checkpoint/universal_checkpoint/filesystem.py index 6efdc454c..2db95137a 100644 --- a/internlm/checkpoint/universal_checkpoint/filesystem.py +++ b/internlm/checkpoint/universal_checkpoint/filesystem.py @@ -672,7 +672,7 @@ def read_metadata(self) -> Metadata: metadata = pickle.load(metadata_file) return metadata - def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: + def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: # pylint: disable=W0613 self.storage_data = metadata.storage_data assert self.storage_data is not None diff --git a/internlm/checkpoint/universal_checkpoint/save_state_dict.py b/internlm/checkpoint/universal_checkpoint/save_state_dict.py index 2646a9281..8552c4acf 100644 --- a/internlm/checkpoint/universal_checkpoint/save_state_dict.py +++ b/internlm/checkpoint/universal_checkpoint/save_state_dict.py @@ -38,10 +38,10 @@ def save_state_dict( no_dist: bool = False, planner: Optional[SavePlanner] = None, async_io: bool = True, - last_write_futures: Future[List[WriteResult]] = None, + last_write_futures: Future[List[WriteResult]] = None, # pylint: disable=E1136 io_workers=None, is_optimizer=False, -) -> Tuple[Metadata, Future[List[WriteResult]]]: +) -> Tuple[Metadata, Future[List[WriteResult]]]: # pylint: disable=E1136 """ Saves a model in SPMD style. Fix sub-group storage. """ From 4a0429829bf2709475c8663069600f5c6e0b6655 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Thu, 26 Dec 2024 10:41:09 +0800 Subject: [PATCH 10/12] little fix --- configs/7B_internlm2.py | 10 ++++++---- internlm/checkpoint/checkpoint_manager.py | 8 +++++--- internlm/train/pipeline.py | 6 +++--- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 2ac67bdce..7993479f1 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -38,10 +38,12 @@ async_upload=True, # async ckpt upload. (only work for boto3 ckpt) async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. - # control universal ckpt. 
INFO: Not compatible with the original ckpt - # Default to use async_save and not use broadcast_load - # as broadcast_load may cause loading performance degradation - universal_ckpt=dict(enable=False, aysnc_save=True, broadcast_load=False), + # INFO: Universal ckpt is not compatible with the original ckpt. + # Default is to use async_save and not use broadcast_load + # as broadcast_load may cause loading performance degradation. + # NOTE: If using aysnc_save, there is a risk of losing the latest ckpt + # when there is a sudden training interruption. + universal_ckpt=dict(enable=True, aysnc_save=True, broadcast_load=False), ) TRAIN_FOLDER = None diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index dd5f7b4e5..a935d05ec 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -93,6 +93,8 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, universal_load( load_ckpt_folder, checkpoint_state, broadcast_checkpoint=gpc.config.ckpt.universal_ckpt.broadcast_load ) + if gpc.is_rank_for_log(): + logger.warning("Finsh loading universal model checkpoint and optimizer checkpoint.") if not universal_ckpt and load_content.need_load(CheckpointLoadContent.MODEL): load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model) @@ -107,9 +109,9 @@ def try_load_internevo_ckpt(ckpt_mm, load_info, train_state: TrainState = None, if not universal_ckpt and load_content.need_load(CheckpointLoadContent.OPIMIZER): load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer) load_content_str += f"{CheckpointLoadContent.OPIMIZER}, " - else: - if gpc.is_rank_for_log(): - logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!") + + if not load_content.need_load(CheckpointLoadContent.OPIMIZER) and gpc.is_rank_for_log(): + logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!") # load lr scheduler states. 
if load_content.need_load(CheckpointLoadContent.SCHEDULAER): diff --git a/internlm/train/pipeline.py b/internlm/train/pipeline.py index 1bf78c646..f34ced17b 100644 --- a/internlm/train/pipeline.py +++ b/internlm/train/pipeline.py @@ -175,7 +175,7 @@ def set_param_unique_tracking_name(model): assert global_fqn not in map_layer_attr, f"{map_layer_attr} exists" map_layer_attr[global_fqn] = { "offset": getattr(child, "offset", [0] * len(child.weight.size())), - "complete_size": getattr(child, "complete_size", child.weight.size()), + "complete_size": getattr(child, "complete_size", list(child.weight.size())), } elif isinstance(child, (RMSNorm)) and uc_enable: @@ -188,7 +188,7 @@ def set_param_unique_tracking_name(model): ) map_layer_attr[global_fqn] = { "offset": getattr(child, "offset", [0] * len(child.weight.size())), - "complete_size": getattr(child, "complete_size", child.weight.size()), + "complete_size": getattr(child, "complete_size", list(child.weight.size())), } else: @@ -226,7 +226,7 @@ def set_param_unique_tracking_name(model): map_layer_attr[local_fqn] = { "offset": getattr(children, "offset", [0] * len(children.weight.size())), - "complete_size": getattr(children, "complete_size", children.weight.size()), + "complete_size": getattr(children, "complete_size", list(children.weight.size())), } From 7336f4873615542480e8e569fd58d8c97b492bf6 Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Thu, 26 Dec 2024 21:45:03 +0800 Subject: [PATCH 11/12] fix bug --- configs/7B_internlm2.py | 2 +- tests/test_training/train_CI.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/configs/7B_internlm2.py b/configs/7B_internlm2.py index 7993479f1..cedb881aa 100644 --- a/configs/7B_internlm2.py +++ b/configs/7B_internlm2.py @@ -43,7 +43,7 @@ # as broadcast_load may cause loading performance degradation. # NOTE: If using aysnc_save, there is a risk of losing the latest ckpt # when there is a sudden training interruption. - universal_ckpt=dict(enable=True, aysnc_save=True, broadcast_load=False), + universal_ckpt=dict(enable=False, aysnc_save=True, broadcast_load=False), ) TRAIN_FOLDER = None diff --git a/tests/test_training/train_CI.py b/tests/test_training/train_CI.py index b33cf4c38..be8e9847a 100644 --- a/tests/test_training/train_CI.py +++ b/tests/test_training/train_CI.py @@ -20,7 +20,7 @@ from internlm.checkpoint import CheckpointManager # noqa: E402 from internlm.core.context import ParallelMode # noqa: E402 from internlm.core.context import global_context as gpc # noqa: E402 -from internlm.core.trainer import TrainState, Trainer # noqa: E402 +from internlm.core.trainer import Trainer, TrainState # noqa: E402 from internlm.data import ( # noqa: E402 build_train_loader_with_data_type, build_valid_loader_with_data_type, @@ -70,7 +70,8 @@ def check_model_weights(model, ckpt_path, total_equal=False): model2_dict[key.replace("wqkv", "Wqkv")] = model2_dict.pop(key) key = key.replace("wqkv", "Wqkv") if key not in model1_dict: - assert False, f"Error: The key {key} for current model dose not exist in standard ckpt!" + if "Wqkv" not in key: + assert False, f"Error: The key {key} for current model dose not exist in standard ckpt!" 
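The pipeline.py hunks above change the complete_size fallback from a torch.Size to a plain list so the recorded shard metadata keeps a uniform type. A small sketch of that getattr-with-default pattern for a module that carries no sharding attributes:

    import torch
    import torch.nn as nn


    def shard_attr(module: nn.Module) -> dict:
        """Record where this module's weight sits in the full parameter."""
        # Unsharded modules carry no offset/complete_size attributes, so fall back to
        # "starts at 0 and is already complete"; list() keeps the metadata type uniform.
        weight = module.weight
        return {
            "offset": getattr(module, "offset", [0] * weight.dim()),
            "complete_size": getattr(module, "complete_size", list(weight.size())),
        }


    print(shard_attr(nn.Linear(4, 8)))   # {'offset': [0, 0], 'complete_size': [8, 4]}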
for key in model1_dict.keys(): if key in model2_dict: From 00f4c6d9ca0c2548f652517f69561aad1cc686ee Mon Sep 17 00:00:00 2001 From: lijiaxing Date: Fri, 27 Dec 2024 10:30:58 +0800 Subject: [PATCH 12/12] fix ci --- tests/test_utils/test_model_checkpoint.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_utils/test_model_checkpoint.py b/tests/test_utils/test_model_checkpoint.py index 5fe8b3c49..06e364fff 100644 --- a/tests/test_utils/test_model_checkpoint.py +++ b/tests/test_utils/test_model_checkpoint.py @@ -1,7 +1,4 @@ import multiprocessing - -backup_ForkingPickler = multiprocessing.reduction.ForkingPickler -backup_dump = multiprocessing.reduction.dump import os from functools import partial @@ -25,6 +22,9 @@ reset_singletons, ) +backup_ForkingPickler = multiprocessing.reduction.ForkingPickler +backup_dump = multiprocessing.reduction.dump + # (TOTAL_STEP, CKPT_EVERY, SNPASHOT_EVERY) step_info_list = [(8, 4, 2), (3, 4, 2), (1, 6, 3)] ckpt_config_list = [ @@ -201,8 +201,8 @@ def return_latest_save_path(save_ckpt_folder, total_step, snapshot_freq, ckpt_fr @pytest.mark.parametrize("step_info", step_info_list) @pytest.mark.parametrize("ckpt_config", ckpt_config_list) def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import - from internlm.core.context import global_context as gpc from internlm.checkpoint.checkpoint_manager import CheckpointLoadMask + from internlm.core.context import global_context as gpc ckpt_config = Config(ckpt_config) total_step, checkpoint_every, oss_snapshot_freq = step_info @@ -222,6 +222,8 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: ) model, opim = init_dist_and_model + gpc.config._add_item("ckpt", dict()) + gpc.config.ckpt._add_item("universal_ckpt", dict(enable=False, aysnc_save=True, broadcast_load=False)) train_state = TrainState(gpc.config, None) if isinstance(opim, HybridZeroOptimizer): print("Is HybridZeroOptimizer!", flush=True) @@ -297,9 +299,9 @@ def test_ckpt_mm(step_info, ckpt_config, init_dist_and_model): # noqa # pylint: def query_quit_file(rank, world_size=2): + from internlm.checkpoint.checkpoint_manager import CheckpointSaveType from internlm.core.context import global_context as gpc from internlm.initialize import initialize_distributed_env - from internlm.checkpoint.checkpoint_manager import CheckpointSaveType ckpt_config = Config( dict( @@ -348,8 +350,6 @@ def query_quit_file(rank, world_size=2): def test_quit_siganl_handler(): # noqa # pylint: disable=unused-import - import multiprocessing - # we do hack here to workaround the bug of 3rd party library dill, which only occurs in this unittest: # https://github.com/uqfoundation/dill/issues/380 multiprocessing.reduction.ForkingPickler = backup_ForkingPickler