-
Notifications
You must be signed in to change notification settings - Fork 920
unifying all reduce memory allocation for single-node and multi-node nvlink #2955
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We'll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
87b8c37
fcf5047
9212097
dd00af3
de04b18
df03bda
e47254d
03371de
1dc9910
5718a32
d118f16
1059a75
2f4309f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| import functools | ||
| from typing import Any | ||
|
|
||
| import torch | ||
| import torch.distributed._symmetric_memory as symm_mem | ||
| import torch.distributed.distributed_c10d as c10d | ||
|
|
||
| _compat_patched = False | ||
|
|
||
|
|
||
| def _patch_group_count_reset() -> None: | ||
| """Prevent group_count from resetting to 0 on WORLD destruction (2.10 and below).""" | ||
| global _compat_patched | ||
| if _compat_patched: | ||
| return | ||
| _compat_patched = True | ||
|
|
||
| import torch.distributed as dist | ||
|
|
||
| _original_destroy = dist.destroy_process_group | ||
|
|
||
| @functools.wraps(_original_destroy) | ||
| def _patched_destroy(group=None): | ||
| saved_count = c10d._world.group_count | ||
| _original_destroy(group) | ||
| # WORLD destruction resets group_count to 0 - restore it so the next | ||
| # init_process_group picks a name that is fresh in the C++ map. | ||
| if group is None: | ||
| c10d._world.group_count = saved_count | ||
|
|
||
| dist.destroy_process_group = _patched_destroy | ||
|
|
||
|
|
||
| def _enable_symm_mem_for_group(group_name: str) -> None: | ||
| """Enable symmetric memory for a process group (PyTorch 2.11+).""" | ||
| torch_version = tuple(int(x) for x in torch.__version__.split(".")[:2]) | ||
| if torch_version >= (2, 11): | ||
| return | ||
| from torch.distributed._symmetric_memory import enable_symm_mem_for_group | ||
|
|
||
| _patch_group_count_reset() | ||
| enable_symm_mem_for_group(group_name) | ||
|
|
||
|
|
||
| def _alloc_symm_buffer_bytes( | ||
| size_bytes: int, | ||
| world_size: int, | ||
| dtype: torch.dtype, | ||
| device: torch.device, | ||
| group_name: str, | ||
| ) -> tuple[list[int], torch.Tensor, Any]: | ||
| """Allocate a symmetric memory buffer and return per-peer pointers. | ||
|
|
||
| Args: | ||
| size_bytes: Total buffer size in bytes. | ||
| world_size: Number of peers in the communication group. | ||
| dtype: Element type used to interpret the buffer. | ||
| device: CUDA device for the allocation. | ||
| group_name: Process group name for the rendezvous. | ||
|
|
||
| Returns: | ||
| Tuple of (per-peer data pointers, local tensor, symmetric memory handle). | ||
| """ | ||
| # Ensure symmetric memory is set up with the correct store before | ||
| # rendezvous on PyTorch older than 2.11. | ||
| _enable_symm_mem_for_group(group_name) | ||
|
|
||
| elem_size = torch.empty(0, dtype=dtype).element_size() | ||
| numel = size_bytes // elem_size | ||
| tensor = symm_mem.empty(numel, dtype=dtype, device=device) | ||
|
Comment on lines
+68
to
+70
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't silently shrink byte-sized allocations. Line 27 floors π Proposed fix- numel = size_bytes // elem_size
+ numel = (size_bytes + elem_size - 1) // elem_sizeπ€ Prompt for AI Agents
Comment on lines
+68
to
+70
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be more ergonomic if the API asks for shape or numel instead of size_bytes?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we can add another api or option to use this api with shape or numel but I wanted to be least intrusive in the kernels. |
||
| handle = symm_mem.rendezvous(tensor, group=group_name) | ||
| ptrs: list[int] = [ | ||
| handle.get_buffer(peer, (numel,), dtype, storage_offset=0).data_ptr() | ||
| for peer in range(world_size) | ||
| ] | ||
| return ptrs, tensor, handle | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: |
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,8 @@ | |
| from ..utils import register_custom_op, round_up | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| from .cuda_ipc import create_shared_buffer, cudart, free_shared_buffer | ||
| from .cuda_ipc import cudart | ||
| from .torch_symmetric_memory import _alloc_symm_buffer_bytes | ||
|
|
||
|
|
||
| class AllReduceStrategyType: | ||
|
|
@@ -399,6 +400,8 @@ def trtllm_moe_finalize_allreduce_fusion( | |
| MAX_ALL_REDUCE_BLOCKS = 24 | ||
| LamportTokenNumThreshold = 16 | ||
|
|
||
| _symm_workspace_refs: dict[int, list[torch.Tensor]] = {} | ||
|
|
||
|
|
||
| @deprecated( | ||
| "trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated and will be removed in the next major bump, use allreduce.py instead." | ||
|
|
@@ -451,27 +454,43 @@ def trtllm_create_ipc_workspace_for_all_reduce( | |
| flag_size = FLAG_SIZE * tp_size * 2 | ||
| lamport_buffer_size = tp_size * LamportTokenNumThreshold * tp_size * hidden_dim * 2 | ||
|
|
||
| device = torch.device(f"cuda:{torch.cuda.current_device()}") | ||
| group_name = ( | ||
| group.group_name | ||
| if group is not None | ||
| else torch.distributed.group.WORLD.group_name | ||
| ) | ||
| symm_refs: list[torch.Tensor] = [] | ||
| ipc_handles = list() | ||
|
|
||
| for size in [ | ||
| buffer_size, | ||
| buffer_size, | ||
| flag_size, | ||
| flag_size, | ||
| lamport_buffer_size, | ||
| lamport_buffer_size, | ||
| lamport_buffer_size, | ||
| for size, dtype in [ | ||
| (buffer_size, torch.float32), | ||
| (buffer_size, torch.float32), | ||
| (flag_size, torch.int32), | ||
| (flag_size, torch.int32), | ||
| (lamport_buffer_size, torch.float16), | ||
| (lamport_buffer_size, torch.float16), | ||
| (lamport_buffer_size, torch.float16), | ||
| ]: | ||
| # all sizes should be aligned to 1LU << 21 bytes (2MB) | ||
| aligned_size = round_up(size, 1 << 21) | ||
| ipc_handles.append(create_shared_buffer(aligned_size, group)) | ||
| aligned_size = round_up(size, 16) | ||
| ptrs, tensor, handle = _alloc_symm_buffer_bytes( | ||
| aligned_size, | ||
| tp_size, | ||
| dtype, | ||
| device, | ||
| group_name, | ||
| ) | ||
| symm_refs.append((tensor, handle)) | ||
| ipc_handles.append(ptrs) | ||
|
|
||
| logger.debug( | ||
| "rank %s allocated ipc_handles: %s", | ||
| rank, | ||
| [[hex(handle) for handle in sublist] for sublist in ipc_handles], | ||
| ) | ||
|
|
||
| _symm_workspace_refs[id(ipc_handles)] = symm_refs | ||
|
|
||
| trtllm_lamport_initialize_all( | ||
| ipc_handles[4][rank], | ||
| ipc_handles[5][rank], | ||
|
|
@@ -488,16 +507,16 @@ def trtllm_create_ipc_workspace_for_all_reduce( | |
| def trtllm_destroy_ipc_workspace_for_all_reduce( | ||
| workspace: List[List[int]], group: Optional[ProcessGroup] = None | ||
| ) -> None: | ||
| """ | ||
| Note: | ||
| This function is used to destroy a workspace for all reduce. | ||
| The workspace is a list of IPC handles. | ||
| The workspace should be destroyed after calling trtllm_custom_all_reduce. | ||
| The workspace can be reused for multiple all reduce calls under the same configuration. | ||
| """ | ||
| """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce. | ||
|
|
||
| for ipc_handle in workspace: | ||
| free_shared_buffer(ipc_handle, group) | ||
| Releases the symmetric memory references held internally. The workspace | ||
| list should not be used after this call. | ||
|
|
||
| Args: | ||
| workspace: The ipc_handles list returned by the create function. | ||
| group: Unused, kept for API compatibility. | ||
| """ | ||
| _symm_workspace_refs.pop(id(workspace), None) | ||
|
|
||
|
|
||
| BarrierFlagCount = 256 | ||
|
|
@@ -588,42 +607,48 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( | |
|
|
||
| lamport_buffer_size = lamport_comm_size * 3 | ||
|
|
||
| device = torch.device(f"cuda:{torch.cuda.current_device()}") | ||
| group_name = ( | ||
| group.group_name | ||
| if group is not None | ||
| else torch.distributed.group.WORLD.group_name | ||
| ) | ||
| symm_refs: list[torch.Tensor] = [] | ||
|
|
||
| # we should init 3 buffers for all reduce fusion: | ||
| # [buffer_size, flag_size, lamport_buffer_size] | ||
|
|
||
| ipc_handles: List[List[int]] = list() | ||
| mem_handles: List[SymmDeviceMemory] = list() | ||
| for size in [buffer_size, flag_size, lamport_buffer_size]: | ||
| # todo(review): confirm we need this alignment | ||
| # all sizes should be aligned to 1LU << 21 bytes (2MB) | ||
| aligned_size = round_up(size, 1 << 21) | ||
| lamport_buffer_dtype = torch.float16 if not use_fp32_lamport else torch.float32 | ||
| for size, dtype in [ | ||
| (buffer_size, torch.float32), | ||
| (flag_size, torch.int32), | ||
| (lamport_buffer_size, lamport_buffer_dtype), | ||
| ]: | ||
| aligned_size = round_up(size, 16) | ||
|
|
||
| if not use_symm_dev_mem: | ||
| ipc_handles.append(create_shared_buffer(aligned_size, group)) | ||
| else: | ||
| # Use torch.cuda.current_device() instead of tp_rank to support | ||
| # base_gpu_id != 0 scenarios where the actual CUDA device index | ||
| # differs from the TP rank. | ||
| symm_mem = SymmDeviceMemory( | ||
| aligned_size, | ||
| tp_size, | ||
| tp_rank, | ||
| torch.cuda.current_device(), | ||
| comm_backend, | ||
| enable_multicast=False, | ||
| allocate_signal_pads=False, | ||
| ) | ||
| ipc_handles.append(symm_mem.uc_ptrs) | ||
| mem_handles.append(symm_mem) | ||
| ptrs, tensor, handle = _alloc_symm_buffer_bytes( | ||
| aligned_size, | ||
| tp_size, | ||
| dtype, | ||
| device, | ||
| group_name, | ||
| ) | ||
| symm_refs.append((tensor, handle)) | ||
| ipc_handles.append(ptrs) | ||
| mem_handles.append(handle) | ||
|
|
||
| logger.debug( | ||
| "rank %s allocated ipc_handles: %s", | ||
| tp_rank, | ||
| [[hex(handle) for handle in sublist] for sublist in ipc_handles], | ||
| ) | ||
|
|
||
| _symm_workspace_refs[id(ipc_handles)] = symm_refs | ||
|
|
||
| # Initialize lamport buffer | ||
| aligned_lamport_buffer_size = round_up(lamport_buffer_size, 1 << 21) | ||
| aligned_lamport_buffer_size = round_up(lamport_buffer_size, 16) | ||
| if use_fp32_lamport: | ||
| trtllm_lamport_initialize( | ||
| ipc_handles[2][tp_rank], aligned_lamport_buffer_size // 4, torch.float32 | ||
|
|
@@ -700,20 +725,16 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion( | |
| def trtllm_destroy_ipc_workspace_for_all_reduce_fusion( | ||
| workspace: List[List[int]], group: Optional[ProcessGroup] = None | ||
| ) -> None: | ||
| """ | ||
| Parameters: | ||
| - workspace: the workspace to destroy. | ||
| - group: the process group to use. | ||
| """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion. | ||
|
|
||
| Note: | ||
| This function is used to destroy a workspace for all reduce fusion. | ||
| The workspace is a list of IPC handles. | ||
| The workspace should be destroyed after calling trtllm_custom_all_reduce_fusion. | ||
| The workspace can be reused for multiple all reduce fusion calls under the same configuration. | ||
| """ | ||
| Releases the symmetric memory references held internally. The workspace | ||
| list should not be used after this call. | ||
|
|
||
| for ipc_handle in workspace: | ||
| free_shared_buffer(ipc_handle, group) | ||
| Args: | ||
| workspace: The ipc_handles list returned by the create function. | ||
| group: Unused, kept for API compatibility. | ||
| """ | ||
| _symm_workspace_refs.pop(id(workspace), None) | ||
|
Comment on lines
+728
to
+737
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The create path allocates symmetric-memory buffers. Prompt for AI Agents
Comment on lines
725
to
+737
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Destroy function incomplete: missing cleanup. This function only removes symmetric memory references but does not free the underlying allocations. Prompt for AI Agents |
||
|
|
||
|
|
||
| # allReduce fused quant utils | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use ASCII punctuation in comments.
Ruff flags the EN DASH in this comment; replace it with
- to keep pre-commit clean. Proposed fix
Committable suggestion
Tools
Ruff (0.15.10)
[warning] 26-26: Comment contains ambiguous
– (EN DASH). Did you mean - (HYPHEN-MINUS)? (RUF003)
Prompt for AI Agents