76 changes: 76 additions & 0 deletions flashinfer/comm/torch_symmetric_memory.py
@@ -0,0 +1,76 @@
import functools
from typing import Any

import torch
import torch.distributed._symmetric_memory as symm_mem
import torch.distributed.distributed_c10d as c10d

_compat_patched = False


def _patch_group_count_reset() -> None:
"""Prevent group_count from resetting to 0 on WORLD destruction (2.10 and below)."""
global _compat_patched
if _compat_patched:
return
_compat_patched = True

import torch.distributed as dist

_original_destroy = dist.destroy_process_group

@functools.wraps(_original_destroy)
def _patched_destroy(group=None):
saved_count = c10d._world.group_count
_original_destroy(group)
# WORLD destruction resets group_count to 0 – restore it so the next
Contributor:
⚠️ Potential issue | 🟑 Minor

Use ASCII punctuation in comments.

Ruff flags the EN DASH in this comment; replace it with - to keep pre-commit clean.

🧹 Proposed fix
-        # WORLD destruction resets group_count to 0 – restore it so the next
+        # WORLD destruction resets group_count to 0 - restore it so the next
πŸ“ Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
# WORLD destruction resets group_count to 0 – restore it so the next
# WORLD destruction resets group_count to 0 - restore it so the next
🧰 Tools
πŸͺ› Ruff (0.15.10)

[warning] 26-26: Comment contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?

(RUF003)


# init_process_group picks a name that is fresh in the C++ map.
if group is None:
c10d._world.group_count = saved_count

dist.destroy_process_group = _patched_destroy


def _enable_symm_mem_for_group(group_name: str) -> None:
"""Enable symmetric memory for a process group (no-op on PyTorch 2.11+, where this is handled automatically)."""
torch_version = tuple(int(x) for x in torch.__version__.split(".")[:2])
if torch_version >= (2, 11):
return
from torch.distributed._symmetric_memory import enable_symm_mem_for_group

_patch_group_count_reset()
enable_symm_mem_for_group(group_name)


def _alloc_symm_buffer_bytes(
size_bytes: int,
world_size: int,
dtype: torch.dtype,
device: torch.device,
group_name: str,
) -> tuple[list[int], torch.Tensor, Any]:
"""Allocate a symmetric memory buffer and return per-peer pointers.

Args:
size_bytes: Total buffer size in bytes.
world_size: Number of peers in the communication group.
dtype: Element type used to interpret the buffer.
device: CUDA device for the allocation.
group_name: Process group name for the rendezvous.

Returns:
Tuple of (per-peer data pointers, local tensor, symmetric memory handle).
"""
# Ensure symmetric memory is set up with the correct store before
# rendezvous on PyTorch older than 2.11.
_enable_symm_mem_for_group(group_name)

elem_size = torch.empty(0, dtype=dtype).element_size()
numel = size_bytes // elem_size
tensor = symm_mem.empty(numel, dtype=dtype, device=device)
Comment on lines +68 to +70
Contributor:
⚠️ Potential issue | 🟑 Minor

Don't silently shrink byte-sized allocations.

Line 27 floors size_bytes to whole elements, so a 6-byte request with torch.float32 allocates only 4 bytes. Because callers treat size_bytes as the promised usable capacity, that can turn into an undersized symmetric buffer. Please round numel up or reject non-divisible sizes here.

πŸ› Proposed fix
-    numel = size_bytes // elem_size
+    numel = (size_bytes + elem_size - 1) // elem_size
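To make the rounding behavior concrete, here is a minimal standalone sketch of the ceil-division pattern the fix proposes (plain integer element sizes stand in for torch dtypes; the helper name is illustrative, not part of the PR):

```python
def ceil_div_numel(size_bytes: int, elem_size: int) -> int:
    """Round a byte count up to a whole number of elements.

    Guarantees numel * elem_size >= size_bytes, so the allocation
    never silently shrinks below the requested capacity.
    """
    return (size_bytes + elem_size - 1) // elem_size

# A 6-byte request with 4-byte float32 elements:
#   floor division: 6 // 4 = 1 element -> only 4 usable bytes
#   ceil division:  2 elements -> 8 bytes, covering the full request
```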

Comment on lines +68 to +70
Would it be more ergonomic if the API asks for shape or numel instead of size_bytes?

Contributor Author:

We can add another API (or an option on this one) that takes a shape or numel, but I wanted to be least intrusive in the kernels.
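A shape-based convenience could sit on top of the byte-based API without touching the kernels; a sketch of the conversion such a wrapper would need (helper name hypothetical):

```python
import math

def shape_to_size_bytes(shape: tuple[int, ...], elem_size: int) -> int:
    """Convert a tensor shape to the byte count the byte-based API expects."""
    return math.prod(shape) * elem_size

# e.g. a (16, 128) float16 buffer: 16 * 128 * 2 = 4096 bytes
```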

handle = symm_mem.rendezvous(tensor, group=group_name)
ptrs: list[int] = [
handle.get_buffer(peer, (numel,), dtype, storage_offset=0).data_ptr()
for peer in range(world_size)
]
return ptrs, tensor, handle
nit: ptrs and handle are redundant to each other in this return. When user has the handle, they can get the ptrs themselves.
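The nit can be illustrated with a small helper: given only the handle, a caller can reconstruct the pointer list with the same `get_buffer` call used above. This is a sketch against the handle interface shown in the diff, not a proposed API:

```python
def ptrs_from_handle(handle, world_size: int, numel: int, dtype) -> list:
    """Recover per-peer data pointers from a symmetric memory handle,
    mirroring the list comprehension in _alloc_symm_buffer_bytes."""
    return [
        handle.get_buffer(peer, (numel,), dtype, storage_offset=0).data_ptr()
        for peer in range(world_size)
    ]
```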

131 changes: 76 additions & 55 deletions flashinfer/comm/trtllm_ar.py
@@ -30,7 +30,8 @@
from ..utils import register_custom_op, round_up

logger = logging.getLogger(__name__)
from .cuda_ipc import create_shared_buffer, cudart, free_shared_buffer
from .cuda_ipc import cudart
from .torch_symmetric_memory import _alloc_symm_buffer_bytes


class AllReduceStrategyType:
@@ -399,6 +400,8 @@ def trtllm_moe_finalize_allreduce_fusion(
MAX_ALL_REDUCE_BLOCKS = 24
LamportTokenNumThreshold = 16

_symm_workspace_refs: dict[int, list[torch.Tensor]] = {}


@deprecated(
"trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated and will be removed in the next major bump, use allreduce.py instead."
@@ -451,27 +454,43 @@ def trtllm_create_ipc_workspace_for_all_reduce(
flag_size = FLAG_SIZE * tp_size * 2
lamport_buffer_size = tp_size * LamportTokenNumThreshold * tp_size * hidden_dim * 2

device = torch.device(f"cuda:{torch.cuda.current_device()}")
group_name = (
group.group_name
if group is not None
else torch.distributed.group.WORLD.group_name
)
symm_refs: list[tuple[torch.Tensor, object]] = []
ipc_handles = list()

for size in [
buffer_size,
buffer_size,
flag_size,
flag_size,
lamport_buffer_size,
lamport_buffer_size,
lamport_buffer_size,
for size, dtype in [
(buffer_size, torch.float32),
(buffer_size, torch.float32),
(flag_size, torch.int32),
(flag_size, torch.int32),
(lamport_buffer_size, torch.float16),
(lamport_buffer_size, torch.float16),
(lamport_buffer_size, torch.float16),
]:
# all sizes should be aligned to 1LU << 21 bytes (2MB)
aligned_size = round_up(size, 1 << 21)
ipc_handles.append(create_shared_buffer(aligned_size, group))
aligned_size = round_up(size, 16)
ptrs, tensor, handle = _alloc_symm_buffer_bytes(
aligned_size,
tp_size,
dtype,
device,
group_name,
)
symm_refs.append((tensor, handle))
ipc_handles.append(ptrs)

logger.debug(
"rank %s allocated ipc_handles: %s",
rank,
[[hex(handle) for handle in sublist] for sublist in ipc_handles],
)

_symm_workspace_refs[id(ipc_handles)] = symm_refs

trtllm_lamport_initialize_all(
ipc_handles[4][rank],
ipc_handles[5][rank],
@@ -488,16 +507,16 @@
def trtllm_destroy_ipc_workspace_for_all_reduce(
workspace: List[List[int]], group: Optional[ProcessGroup] = None
) -> None:
"""
Note:
This function is used to destroy a workspace for all reduce.
The workspace is a list of IPC handles.
The workspace should be destroyed after calling trtllm_custom_all_reduce.
The workspace can be reused for multiple all reduce calls under the same configuration.
"""
"""Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce.

for ipc_handle in workspace:
free_shared_buffer(ipc_handle, group)
Releases the symmetric memory references held internally. The workspace
list should not be used after this call.

Args:
workspace: The ipc_handles list returned by the create function.
group: Unused, kept for API compatibility.
"""
_symm_workspace_refs.pop(id(workspace), None)


BarrierFlagCount = 256
@@ -588,42 +607,48 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(

lamport_buffer_size = lamport_comm_size * 3

device = torch.device(f"cuda:{torch.cuda.current_device()}")
group_name = (
group.group_name
if group is not None
else torch.distributed.group.WORLD.group_name
)
symm_refs: list[tuple[torch.Tensor, object]] = []

# we should init 3 buffers for all reduce fusion:
# [buffer_size, flag_size, lamport_buffer_size]

ipc_handles: List[List[int]] = list()
mem_handles: List[SymmDeviceMemory] = list()
for size in [buffer_size, flag_size, lamport_buffer_size]:
# todo(review): confirm we need this alignment
# all sizes should be aligned to 1LU << 21 bytes (2MB)
aligned_size = round_up(size, 1 << 21)
lamport_buffer_dtype = torch.float16 if not use_fp32_lamport else torch.float32
for size, dtype in [
(buffer_size, torch.float32),
(flag_size, torch.int32),
(lamport_buffer_size, lamport_buffer_dtype),
]:
aligned_size = round_up(size, 16)

if not use_symm_dev_mem:
ipc_handles.append(create_shared_buffer(aligned_size, group))
else:
# Use torch.cuda.current_device() instead of tp_rank to support
# base_gpu_id != 0 scenarios where the actual CUDA device index
# differs from the TP rank.
symm_mem = SymmDeviceMemory(
aligned_size,
tp_size,
tp_rank,
torch.cuda.current_device(),
comm_backend,
enable_multicast=False,
allocate_signal_pads=False,
)
ipc_handles.append(symm_mem.uc_ptrs)
mem_handles.append(symm_mem)
ptrs, tensor, handle = _alloc_symm_buffer_bytes(
aligned_size,
tp_size,
dtype,
device,
group_name,
)
symm_refs.append((tensor, handle))
ipc_handles.append(ptrs)
mem_handles.append(handle)

logger.debug(
"rank %s allocated ipc_handles: %s",
tp_rank,
[[hex(handle) for handle in sublist] for sublist in ipc_handles],
)

_symm_workspace_refs[id(ipc_handles)] = symm_refs

# Initialize lamport buffer
aligned_lamport_buffer_size = round_up(lamport_buffer_size, 1 << 21)
aligned_lamport_buffer_size = round_up(lamport_buffer_size, 16)
if use_fp32_lamport:
trtllm_lamport_initialize(
ipc_handles[2][tp_rank], aligned_lamport_buffer_size // 4, torch.float32
Expand Down Expand Up @@ -700,20 +725,16 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
def trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
workspace: List[List[int]], group: Optional[ProcessGroup] = None
) -> None:
"""
Parameters:
- workspace: the workspace to destroy.
- group: the process group to use.
"""Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion.

Note:
This function is used to destroy a workspace for all reduce fusion.
The workspace is a list of IPC handles.
The workspace should be destroyed after calling trtllm_custom_all_reduce_fusion.
The workspace can be reused for multiple all reduce fusion calls under the same configuration.
"""
Releases the symmetric memory references held internally. The workspace
list should not be used after this call.

for ipc_handle in workspace:
free_shared_buffer(ipc_handle, group)
Args:
workspace: The ipc_handles list returned by the create function.
group: Unused, kept for API compatibility.
"""
_symm_workspace_refs.pop(id(workspace), None)
Comment on lines +728 to +737
Contributor:
⚠️ Potential issue | 🟠 Major

destroy_*_fusion() still leaks the cudaMalloc flag buffer.

The create path allocates flag_ptr = cudart.cudaMalloc(5 * 4) at Line 672, but this destroy function only removes _symm_workspace_refs. The raw device allocation is never freed, so repeated workspace recreation leaks CUDA memory. Please track that pointer alongside the symmetric refs and release it here as well.


Comment on lines 725 to +737
Contributor:

⚠️ Potential issue | 🟠 Major

Destroy function incomplete: missing flag_ptr cleanup.

This function only removes symmetric memory references but does not free the flag_ptr CUDA allocation created in trtllm_create_ipc_workspace_for_all_reduce_fusion. See the related comment above for the suggested fix.




# allReduce fused quant utils