-
Notifications
You must be signed in to change notification settings - Fork 920
unifying all reduce memory allocation for single-node and multi-node nvlink #2955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
87b8c37
fcf5047
9212097
dd00af3
de04b18
df03bda
e47254d
03371de
1dc9910
5718a32
d118f16
1059a75
2f4309f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,34 @@ | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch.distributed._symmetric_memory as symm_mem | ||
|
|
||
|
|
||
| def _alloc_symm_buffer_bytes( | ||
| size_bytes: int, | ||
| world_size: int, | ||
| dtype: torch.dtype, | ||
| device: torch.device, | ||
| group_name: str, | ||
| ) -> tuple[list[int], torch.Tensor, Any]: | ||
| """Allocate a symmetric memory buffer and return per-peer pointers. | ||
|
|
||
| Args: | ||
| size_bytes: Total buffer size in bytes. | ||
| world_size: Number of peers in the communication group. | ||
| dtype: Element type used to interpret the buffer. | ||
| device: CUDA device for the allocation. | ||
| group_name: Process group name for the rendezvous. | ||
|
|
||
| Returns: | ||
| Tuple of (per-peer data pointers, local tensor, symmetric memory handle). | ||
| """ | ||
| elem_size = torch.empty(0, dtype=dtype).element_size() | ||
| numel = size_bytes // elem_size | ||
| tensor = symm_mem.empty(numel, dtype=dtype, device=device) | ||
|
Comment on lines
+68
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be more ergonomic if the API asks for shape or numel instead of size_bytes?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can add another api or option to use this api with shape or numel but I wanted to be least intrusive in the kernels. |
||
| handle = symm_mem.rendezvous(tensor, group=group_name) | ||
| ptrs: list[int] = [ | ||
| handle.get_buffer(peer, (numel,), dtype, storage_offset=0).data_ptr() | ||
| for peer in range(world_size) | ||
| ] | ||
| return ptrs, tensor, handle | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,7 +31,7 @@ | |
|
|
||
| logger = logging.getLogger(__name__) | ||
| from .cuda_ipc import create_shared_buffer, cudart, free_shared_buffer | ||
|
|
||
| from .torch_symmetric_memory import _alloc_symm_buffer_bytes | ||
|
coderabbitai[bot] marked this conversation as resolved.
Outdated
|
||
|
|
||
| class AllReduceStrategyType: | ||
| # NOTE: for trtllm_custom_all_reduce | ||
|
|
@@ -399,6 +399,7 @@ def trtllm_moe_finalize_allreduce_fusion( | |
| MAX_ALL_REDUCE_BLOCKS = 24 | ||
| LamportTokenNumThreshold = 16 | ||
|
|
||
| _symm_workspace_refs: dict[int, list[torch.Tensor]] = {} | ||
|
|
||
| @deprecated( | ||
| "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated and will be removed in the next major bump, use allreduce.py instead." | ||
|
|
@@ -451,27 +452,40 @@ def trtllm_create_ipc_workspace_for_all_reduce( | |
| flag_size = FLAG_SIZE * tp_size * 2 | ||
| lamport_buffer_size = tp_size * LamportTokenNumThreshold * tp_size * hidden_dim * 2 | ||
|
|
||
| device = torch.device(f"cuda:{torch.cuda.current_device()}") | ||
| group_name = group.group_name if group is not None else torch.distributed.group.WORLD.group_name | ||
| symm_refs: list[torch.Tensor] = [] | ||
| ipc_handles = list() | ||
|
|
||
| for size in [ | ||
| buffer_size, | ||
| buffer_size, | ||
| flag_size, | ||
| flag_size, | ||
| lamport_buffer_size, | ||
| lamport_buffer_size, | ||
| lamport_buffer_size, | ||
| for size, dtype in [ | ||
| (buffer_size, torch.float32), | ||
| (buffer_size, torch.float32), | ||
| (flag_size, torch.int32), | ||
| (flag_size, torch.int32), | ||
| (lamport_buffer_size, torch.float16), | ||
| (lamport_buffer_size, torch.float16), | ||
| (lamport_buffer_size, torch.float16), | ||
| ]: | ||
| # all sizes should be aligned to 1LU << 21 bytes (2MB) | ||
| aligned_size = round_up(size, 1 << 21) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why ? torch symm memory will use a mempool under the cover so that you can have smaller requests. Probably good to be make sure it is 16B aligned but not 2MB.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed! |
||
| ipc_handles.append(create_shared_buffer(aligned_size, group)) | ||
| ptrs, tensor, handle = _alloc_symm_buffer_bytes( | ||
| aligned_size, | ||
| tp_size, | ||
| dtype, | ||
| device, | ||
| group_name, | ||
| ) | ||
| symm_refs.append((tensor, handle)) | ||
| ipc_handles.append(ptrs) | ||
|
|
||
| logger.debug( | ||
| "rank %s allocated ipc_handles: %s", | ||
| rank, | ||
| [[hex(handle) for handle in sublist] for sublist in ipc_handles], | ||
| ) | ||
|
|
||
| _symm_workspace_refs[id(ipc_handles)] = symm_refs | ||
|
|
||
| trtllm_lamport_initialize_all( | ||
| ipc_handles[4][rank], | ||
| ipc_handles[5][rank], | ||
|
|
@@ -488,16 +502,16 @@ def trtllm_create_ipc_workspace_for_all_reduce( | |
| def trtllm_destroy_ipc_workspace_for_all_reduce( | ||
| workspace: List[List[int]], group: Optional[ProcessGroup] = None | ||
| ) -> None: | ||
| """ | ||
| Note: | ||
| This function is used to destroy a workspace for all reduce. | ||
| The workspace is a list of IPC handles. | ||
| The workspace should be destroyed after calling trtllm_custom_all_reduce. | ||
| The workspace can be reused for multiple all reduce calls under the same configuration. | ||
| """ | ||
| """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce. | ||
|
|
||
| for ipc_handle in workspace: | ||
| free_shared_buffer(ipc_handle, group) | ||
| Releases the symmetric memory references held internally. The workspace | ||
| list should not be used after this call. | ||
|
|
||
| Args: | ||
| workspace: The ipc_handles list returned by the create function. | ||
| group: Unused, kept for API compatibility. | ||
| """ | ||
| _symm_workspace_refs.pop(id(workspace), None) | ||
|
|
||
|
|
||
| BarrierFlagCount = 256 | ||
|
|
@@ -588,40 +602,44 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( | |
|
|
||
| lamport_buffer_size = lamport_comm_size * 3 | ||
|
|
||
| device = torch.device(f"cuda:{torch.cuda.current_device()}") | ||
| group_name = group.group_name if group is not None else torch.distributed.group.WORLD.group_name | ||
| symm_refs: list[torch.Tensor] = [] | ||
|
|
||
| # we should init 3 buffers for all reduce fusion: | ||
| # [buffer_size, flag_size, lamport_buffer_size] | ||
|
|
||
| ipc_handles: List[List[int]] = list() | ||
| mem_handles: List[SymmDeviceMemory] = list() | ||
| for size in [buffer_size, flag_size, lamport_buffer_size]: | ||
| lamport_buffer_dtype = torch.float16 if not use_fp32_lamport else torch.float32 | ||
| for size, dtype in [ | ||
| (buffer_size, torch.float32), | ||
| (flag_size, torch.int32), | ||
| (lamport_buffer_size, lamport_buffer_dtype), | ||
| ]: | ||
| # todo(review): confirm we need this alignment | ||
| # all sizes should be aligned to 1LU << 21 bytes (2MB) | ||
| aligned_size = round_up(size, 1 << 21) | ||
|
|
||
| if not use_symm_dev_mem: | ||
| ipc_handles.append(create_shared_buffer(aligned_size, group)) | ||
| else: | ||
| # Use torch.cuda.current_device() instead of tp_rank to support | ||
| # base_gpu_id != 0 scenarios where the actual CUDA device index | ||
| # differs from the TP rank. | ||
| symm_mem = SymmDeviceMemory( | ||
| aligned_size, | ||
| tp_size, | ||
| tp_rank, | ||
| torch.cuda.current_device(), | ||
| comm_backend, | ||
| enable_multicast=False, | ||
| allocate_signal_pads=False, | ||
| ) | ||
| ipc_handles.append(symm_mem.uc_ptrs) | ||
| mem_handles.append(symm_mem) | ||
| ptrs, tensor, handle = _alloc_symm_buffer_bytes( | ||
| aligned_size, | ||
| tp_size, | ||
| dtype, | ||
| device, | ||
| group_name, | ||
| ) | ||
| symm_refs.append((tensor, handle)) | ||
| ipc_handles.append(ptrs) | ||
| mem_handles.append(handle) | ||
|
|
||
| logger.debug( | ||
| "rank %s allocated ipc_handles: %s", | ||
| tp_rank, | ||
| [[hex(handle) for handle in sublist] for sublist in ipc_handles], | ||
| ) | ||
|
|
||
| _symm_workspace_refs[id(ipc_handles)] = symm_refs | ||
|
|
||
| # Initialize lamport buffer | ||
| aligned_lamport_buffer_size = round_up(lamport_buffer_size, 1 << 21) | ||
| if use_fp32_lamport: | ||
|
|
@@ -700,20 +718,16 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( | |
| def trtllm_destroy_ipc_workspace_for_all_reduce_fusion( | ||
| workspace: List[List[int]], group: Optional[ProcessGroup] = None | ||
| ) -> None: | ||
| """ | ||
| Parameters: | ||
| - workspace: the workspace to destroy. | ||
| - group: the process group to use. | ||
| """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion. | ||
|
|
||
| Note: | ||
| This function is used to destroy a workspace for all reduce fusion. | ||
| The workspace is a list of IPC handles. | ||
| The workspace should be destroyed after calling trtllm_custom_all_reduce_fusion. | ||
| The workspace can be reused for multiple all reduce fusion calls under the same configuration. | ||
| """ | ||
| Releases the symmetric memory references held internally. The workspace | ||
| list should not be used after this call. | ||
|
|
||
| for ipc_handle in workspace: | ||
| free_shared_buffer(ipc_handle, group) | ||
| Args: | ||
| workspace: The ipc_handles list returned by the create function. | ||
| group: Unused, kept for API compatibility. | ||
| """ | ||
| _symm_workspace_refs.pop(id(workspace), None) | ||
|
Comment on lines
+728
to
+737
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The create path allocates 🤖 Prompt for AI Agents
Comment on lines
725
to
+737
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Destroy function incomplete: missing This function only removes symmetric memory references but does not free the 🤖 Prompt for AI Agents |
||
|
|
||
|
|
||
| # allReduce fused quant utils | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -14,12 +14,27 @@ | |
| from typing_extensions import deprecated | ||
|
|
||
| from flashinfer.comm.mapping import Mapping | ||
| from flashinfer.comm.mnnvl import TorchDistBackend | ||
|
|
||
| from ..jit import gen_trtllm_mnnvl_comm_module | ||
| from ..utils import register_custom_op | ||
| from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend | ||
| from .mnnvl import CommBackend, MPIBackend | ||
| from .workspace_base import AllReduceFusionWorkspace | ||
|
|
||
| from .torch_symmetric_memory import _alloc_symm_buffer_bytes | ||
| from ..cuda_utils import checkCudaErrors | ||
| try: | ||
| # cuda-python >= 12.9 (has cuda.bindings.driver) | ||
| from cuda.bindings import driver as cuda | ||
| except ImportError: | ||
| try: | ||
| # cuda-python < 12.9 (no cuda.bindings.driver, use cuda as driver) | ||
| # from cuda import cuda is not available in cuda-python >= 13.0 | ||
| from cuda import cuda | ||
| except ImportError as e: | ||
| raise ImportError( | ||
| "Could not import the 'cuda' module. " | ||
| "Please install cuda-python that matches your CUDA version." | ||
| ) from e | ||
|
|
||
| def mpi_barrier(): | ||
| from mpi4py import MPI | ||
|
|
@@ -135,16 +150,23 @@ def __init__( | |
| # Use torch.cuda.current_device() instead of mapping.local_rank to | ||
| # support base_gpu_id != 0 scenarios where the actual CUDA device | ||
| # index differs from the TP rank / local_rank. | ||
| self.mcast_buffer_handle = McastGPUBuffer( | ||
| device = torch.device("cuda", torch.cuda.current_device()) | ||
| if isinstance(comm_backend, TorchDistBackend): | ||
| group = comm_backend._group if comm_backend._group is not None else torch.distributed.group.WORLD | ||
| group_name = group.group_name | ||
| else: | ||
| group_name = torch.distributed.group.WORLD.group_name | ||
|
Comment on lines
+141
to
+149
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Don't silently rendezvous on Lines 154-158 only derive 🤖 Prompt for AI Agents |
||
| self.ptrs, self.tensor, self.handle = _alloc_symm_buffer_bytes( | ||
| requested_workspace_size, | ||
| mapping.tp_size, | ||
| mapping.tp_rank, | ||
| torch.device("cuda", torch.cuda.current_device()), | ||
| comm_backend, | ||
| torch.float32, | ||
| device, | ||
| group_name, | ||
| ) | ||
|
|
||
| # Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer) | ||
| allocated_size = self.mcast_buffer_handle.buf_size | ||
| # handle.buffer_size is the usable data size. torch symmetric memory | ||
| # allocator places signal_pad on top of it, not carved from within. | ||
| allocated_size = self.handle.buffer_size | ||
| # We want the buffer size to be aligned to 16B which is the granularity for buffer management. | ||
| self.buffer_size_bytes = ( | ||
| math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16 | ||
|
|
@@ -157,7 +179,7 @@ def __init__( | |
| ) | ||
|
|
||
| # We use FP32 for sentinel value regardless of the real dtype | ||
| self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32) | ||
| self.lamport_initialize(mapping.tp_rank, torch.float32, allocated_size) | ||
| # Wait until the initialization is done | ||
| torch.cuda.synchronize() | ||
| comm_backend.barrier() | ||
|
|
@@ -173,9 +195,26 @@ def __init__( | |
| device=torch.device("cuda", torch.cuda.current_device()), | ||
| ) | ||
|
|
||
| self.uc_ptrs_dev = self.mcast_buffer_handle.get_buffer_ptrs_dev() | ||
| self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank) | ||
| self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr() | ||
| self.uc_ptrs_dev = self.handle.buffer_ptrs_dev | ||
| self.uc_ptr_local = self.handle.buffer_ptrs[self.rank] | ||
| self.mc_ptr = self.handle.multicast_ptr | ||
|
|
||
| def lamport_initialize(self, rank: int, dtype: torch.dtype, allocated_size: int): | ||
| if dtype == torch.bfloat16 or dtype == torch.float16: | ||
| neg_zero = 0x8000 | ||
| dsize = 2 | ||
| memset_func = cuda.cuMemsetD16 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why can't you use
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. fixed! thank you! I learned something new. |
||
| elif dtype == torch.float32: | ||
| neg_zero = 0x80000000 | ||
| dsize = 4 | ||
| memset_func = cuda.cuMemsetD32 | ||
| else: | ||
| raise ValueError(f"Unsupported dtype: {dtype}") | ||
|
|
||
| num_elements = (allocated_size) // dsize | ||
| checkCudaErrors( | ||
| memset_func(int(self.ptrs[rank]), neg_zero, num_elements) | ||
| ) | ||
|
|
||
| @functools.cache | ||
| def is_buffer_size_sufficient( | ||
|
|
@@ -235,11 +274,13 @@ def destroy(self) -> None: | |
| if getattr(self, "_destroyed", False): | ||
| return # Already destroyed, nothing to do | ||
|
|
||
| del self.mcast_buffer_handle | ||
| del self.buffer_flags | ||
| del self.uc_ptrs_dev | ||
| del self.uc_ptr_local | ||
| del self.mc_ptr | ||
| del self.tensor | ||
| del self.handle | ||
| del self.ptrs | ||
| self._destroyed = True | ||
|
|
||
|
|
||
|
|
@@ -506,33 +547,22 @@ def get_allreduce_mnnvl_workspace( | |
| dtype: torch.dtype, | ||
| comm_backend_for_handle_transfer: Optional[CommBackend] = None, | ||
| buffer_size_in_bytes: Optional[int] = None, | ||
| ) -> Tuple[McastGPUBuffer, torch.Tensor, int]: | ||
| ) -> Tuple[MNNVLAllReduceFusionWorkspace, torch.Tensor, int]: | ||
| """Get workspace buffers needed for multi-node NVLink all-reduce operation. | ||
|
|
||
| This function allocates and initializes the workspace buffers required for performing | ||
| multi-node NVLink all-reduce operations. It creates: | ||
| 1. A multicast GPU buffer for communication between nodes | ||
| 2. A flags tensor to track buffer state | ||
| 3. Maximum number of elements that can fit in the buffer | ||
|
|
||
| The buffer size is calculated to efficiently handle common hidden dimensions | ||
| (2048, 4096, 5120, 7168, 8192) by using their LCM of 286720. | ||
|
|
||
| Args: | ||
| mapping: Tensor parallel mapping configuration containing rank info | ||
| dtype: Data type of the tensors being reduced | ||
| comm_backend_for_handle_transfer: Communication backend for handle transfer | ||
| buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens | ||
|
|
||
| Returns: | ||
| Tuple containing: | ||
| - McastGPUBuffer: Multicast buffer for inter-node communication | ||
| - MNNVLAllReduceFusionWorkspace: The workspace object backed by torch symmetric memory | ||
| - torch.Tensor: Buffer flags tensor tracking state | ||
| - int: Maximum number of elements that can fit in buffer | ||
| """ | ||
| # buffer shape: [3, 2, buffer_tokens, hidden_dim] | ||
| stride = 3 * 2 * dtype.itemsize | ||
| # LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720 | ||
| # max_num_elements must be a multiple of 286720 | ||
| lcm_hidden_dim = 286720 | ||
| TARGET_WORKSPACE_SIZE_BYTES = ( | ||
| buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000 | ||
|
|
@@ -541,21 +571,17 @@ def get_allreduce_mnnvl_workspace( | |
| TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride) | ||
| ) * (lcm_hidden_dim * stride) | ||
|
|
||
| # Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout. | ||
| workspace = MNNVLAllReduceFusionWorkspace( | ||
| mapping, | ||
| buffer_size_in_bytes=buffer_size_in_bytes, | ||
| comm_backend=comm_backend_for_handle_transfer, | ||
| ) | ||
|
|
||
| mcast_buffer = workspace.mcast_buffer_handle | ||
| buffer_flags = workspace.buffer_flags | ||
| # this is calculated using the legacy behavior. We do not use the actual allocated size. | ||
| max_num_elements = workspace.buffer_size_bytes // stride | ||
|
|
||
| return ( | ||
| mcast_buffer, | ||
| buffer_flags, | ||
| workspace, | ||
| workspace.buffer_flags, | ||
| max_num_elements, | ||
| ) | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't silently shrink byte-sized allocations.
Line 27 floors `size_bytes` to whole elements, so a 6-byte request with `torch.float32` allocates only 4 bytes. Because callers treat `size_bytes` as the promised usable capacity, that can turn into an undersized symmetric buffer. Please round `numel` up or reject non-divisible sizes here.
🔧 Proposed fix
π€ Prompt for AI Agents