Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions flashinfer/comm/torch_symmetric_memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Any

import torch
import torch.distributed._symmetric_memory as symm_mem


def _alloc_symm_buffer_bytes(
    size_bytes: int,
    world_size: int,
    dtype: torch.dtype,
    device: torch.device,
    group_name: str,
) -> tuple[list[int], torch.Tensor, Any]:
    """Allocate a symmetric memory buffer and return per-peer pointers.

    Args:
        size_bytes: Minimum usable buffer size in bytes. The actual
            allocation is rounded up to a whole number of ``dtype`` elements.
        world_size: Number of peers in the communication group.
        dtype: Element type used to interpret the buffer.
        device: CUDA device for the allocation.
        group_name: Process group name for the rendezvous.

    Returns:
        Tuple of (per-peer data pointers, local tensor, symmetric memory handle).
    """
    elem_size = torch.empty(0, dtype=dtype).element_size()
    # Ceil division: flooring would silently shrink non-divisible requests
    # (e.g. a 6-byte request with float32 would allocate only 4 bytes),
    # handing callers an undersized buffer. Round up so the buffer holds
    # at least size_bytes.
    numel = (size_bytes + elem_size - 1) // elem_size
    tensor = symm_mem.empty(numel, dtype=dtype, device=device)
    handle = symm_mem.rendezvous(tensor, group=group_name)
    ptrs: list[int] = [
        handle.get_buffer(peer, (numel,), dtype, storage_offset=0).data_ptr()
        for peer in range(world_size)
    ]
    return ptrs, tensor, handle
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: ptrs and handle are redundant to each other in this return. When user has the handle, they can get the ptrs themselves.

112 changes: 63 additions & 49 deletions flashinfer/comm/trtllm_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@

logger = logging.getLogger(__name__)
from .cuda_ipc import create_shared_buffer, cudart, free_shared_buffer

from .torch_symmetric_memory import _alloc_symm_buffer_bytes
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

class AllReduceStrategyType:
# NOTE: for trtllm_custom_all_reduce
Expand Down Expand Up @@ -399,6 +399,7 @@ def trtllm_moe_finalize_allreduce_fusion(
MAX_ALL_REDUCE_BLOCKS = 24
LamportTokenNumThreshold = 16

_symm_workspace_refs: dict[int, list[torch.Tensor]] = {}

@deprecated(
"trtllm_create_ipc_workspace_for_all_reduce and trtllm_custom_all_reduce are deprecated and will be removed in the next major bump, use allreduce.py instead."
Expand Down Expand Up @@ -451,27 +452,40 @@ def trtllm_create_ipc_workspace_for_all_reduce(
flag_size = FLAG_SIZE * tp_size * 2
lamport_buffer_size = tp_size * LamportTokenNumThreshold * tp_size * hidden_dim * 2

device = torch.device(f"cuda:{torch.cuda.current_device()}")
group_name = group.group_name if group is not None else torch.distributed.group.WORLD.group_name
symm_refs: list[torch.Tensor] = []
ipc_handles = list()

for size in [
buffer_size,
buffer_size,
flag_size,
flag_size,
lamport_buffer_size,
lamport_buffer_size,
lamport_buffer_size,
for size, dtype in [
(buffer_size, torch.float32),
(buffer_size, torch.float32),
(flag_size, torch.int32),
(flag_size, torch.int32),
(lamport_buffer_size, torch.float16),
(lamport_buffer_size, torch.float16),
(lamport_buffer_size, torch.float16),
]:
# all sizes should be aligned to 1LU << 21 bytes (2MB)
aligned_size = round_up(size, 1 << 21)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? Torch symmetric memory uses a mempool under the cover, so smaller requests are fine. It is probably good to make sure it is 16B aligned, but 2MB alignment is unnecessary.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed!

ipc_handles.append(create_shared_buffer(aligned_size, group))
ptrs, tensor, handle = _alloc_symm_buffer_bytes(
aligned_size,
tp_size,
dtype,
device,
group_name,
)
symm_refs.append((tensor, handle))
ipc_handles.append(ptrs)

logger.debug(
"rank %s allocated ipc_handles: %s",
rank,
[[hex(handle) for handle in sublist] for sublist in ipc_handles],
)

_symm_workspace_refs[id(ipc_handles)] = symm_refs

trtllm_lamport_initialize_all(
ipc_handles[4][rank],
ipc_handles[5][rank],
Expand All @@ -488,16 +502,16 @@ def trtllm_create_ipc_workspace_for_all_reduce(
def trtllm_destroy_ipc_workspace_for_all_reduce(
    workspace: List[List[int]], group: Optional[ProcessGroup] = None
) -> None:
    """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce.

    Releases the symmetric memory references held internally by dropping the
    (tensor, handle) pairs recorded at creation time; the underlying device
    memory is reclaimed once no other references remain. The workspace list
    should not be used after this call.

    Args:
        workspace: The ipc_handles list returned by the create function.
        group: Unused, kept for API compatibility.
    """
    # pop() with a default so destroying an unknown or already-destroyed
    # workspace is a harmless no-op rather than a KeyError.
    _symm_workspace_refs.pop(id(workspace), None)


BarrierFlagCount = 256
Expand Down Expand Up @@ -588,40 +602,44 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(

lamport_buffer_size = lamport_comm_size * 3

device = torch.device(f"cuda:{torch.cuda.current_device()}")
group_name = group.group_name if group is not None else torch.distributed.group.WORLD.group_name
symm_refs: list[torch.Tensor] = []

# we should init 3 buffers for all reduce fusion:
# [buffer_size, flag_size, lamport_buffer_size]

ipc_handles: List[List[int]] = list()
mem_handles: List[SymmDeviceMemory] = list()
for size in [buffer_size, flag_size, lamport_buffer_size]:
lamport_buffer_dtype = torch.float16 if not use_fp32_lamport else torch.float32
for size, dtype in [
(buffer_size, torch.float32),
(flag_size, torch.int32),
(lamport_buffer_size, lamport_buffer_dtype),
]:
# todo(review): confirm we need this alignment
# all sizes should be aligned to 1LU << 21 bytes (2MB)
aligned_size = round_up(size, 1 << 21)

if not use_symm_dev_mem:
ipc_handles.append(create_shared_buffer(aligned_size, group))
else:
# Use torch.cuda.current_device() instead of tp_rank to support
# base_gpu_id != 0 scenarios where the actual CUDA device index
# differs from the TP rank.
symm_mem = SymmDeviceMemory(
aligned_size,
tp_size,
tp_rank,
torch.cuda.current_device(),
comm_backend,
enable_multicast=False,
allocate_signal_pads=False,
)
ipc_handles.append(symm_mem.uc_ptrs)
mem_handles.append(symm_mem)
ptrs, tensor, handle = _alloc_symm_buffer_bytes(
aligned_size,
tp_size,
dtype,
device,
group_name,
)
symm_refs.append((tensor, handle))
ipc_handles.append(ptrs)
mem_handles.append(handle)

logger.debug(
"rank %s allocated ipc_handles: %s",
tp_rank,
[[hex(handle) for handle in sublist] for sublist in ipc_handles],
)

_symm_workspace_refs[id(ipc_handles)] = symm_refs

# Initialize lamport buffer
aligned_lamport_buffer_size = round_up(lamport_buffer_size, 1 << 21)
if use_fp32_lamport:
Expand Down Expand Up @@ -700,20 +718,16 @@ def trtllm_create_ipc_workspace_for_all_reduce_fusion(
def trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
    workspace: List[List[int]], group: Optional[ProcessGroup] = None
) -> None:
    """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion.

    Releases the symmetric memory references held internally by dropping the
    (tensor, handle) pairs recorded at creation time; the underlying device
    memory is reclaimed once no other references remain. The workspace list
    should not be used after this call.

    Args:
        workspace: The ipc_handles list returned by the create function.
        group: Unused, kept for API compatibility.
    """
    # NOTE(review): the create path reportedly cudaMalloc's a small flag
    # buffer (flag_ptr) that is not tracked here, so it is never freed —
    # confirm against the create function and release it alongside these
    # refs if so.
    # pop() with a default keeps repeated/unknown destroys a safe no-op.
    _symm_workspace_refs.pop(id(workspace), None)
Comment on lines +728 to +737
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

destroy_*_fusion() still leaks the cudaMalloc flag buffer.

The create path allocates flag_ptr = cudart.cudaMalloc(5 * 4) at Line 672, but this destroy function only removes _symm_workspace_refs. The raw device allocation is never freed, so repeated workspace recreation leaks CUDA memory. Please track that pointer alongside the symmetric refs and release it here as well.

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/trtllm_ar.py` around lines 721 - 730, The destroy function
for the fusion workspace currently only pops _symm_workspace_refs and thus leaks
the device flag buffer allocated in
trtllm_create_ipc_workspace_for_all_reduce_fusion (flag_ptr =
cudart.cudaMalloc(5 * 4)). Modify the create path to store the flag_ptr
alongside the workspace refs (e.g., in a dict like _symm_flag_ptrs keyed by
id(workspace) or by storing a tuple in _symm_workspace_refs), and update the
destroy function (trtllm_destroy_ipc_workspace_for_all_reduce_fusion) to
retrieve and free that device pointer with cudart.cudaFree(flag_ptr) before
removing entries from _symm_workspace_refs (and _symm_flag_ptrs if used); ensure
you handle missing keys safely (pop with default None) to avoid exceptions.

Comment on lines 725 to +737
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Destroy function incomplete: missing flag_ptr cleanup.

This function only removes symmetric memory references but does not free the flag_ptr CUDA allocation created in trtllm_create_ipc_workspace_for_all_reduce_fusion. See the related comment above for the suggested fix.

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/trtllm_ar.py` around lines 728 - 740, The destroy function
trtllm_destroy_ipc_workspace_for_all_reduce_fusion currently only pops
_symm_workspace_refs but fails to free the CUDA allocation pointed to by
flag_ptr created in trtllm_create_ipc_workspace_for_all_reduce_fusion; update
the function to lookup the stored record in _symm_workspace_refs (using
id(workspace)), if present free the CUDA allocation referenced by its flag_ptr
(using the same CUDA/free API used when allocating it), then remove the entry
from _symm_workspace_refs and handle missing entries gracefully so no dangling
GPU memory remains.



# allReduce fused quant utils
Expand Down
92 changes: 59 additions & 33 deletions flashinfer/comm/trtllm_mnnvl_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,27 @@
from typing_extensions import deprecated

from flashinfer.comm.mapping import Mapping
from flashinfer.comm.mnnvl import TorchDistBackend

from ..jit import gen_trtllm_mnnvl_comm_module
from ..utils import register_custom_op
from .mnnvl import McastGPUBuffer, CommBackend, MPIBackend
from .mnnvl import CommBackend, MPIBackend
from .workspace_base import AllReduceFusionWorkspace

from .torch_symmetric_memory import _alloc_symm_buffer_bytes
from ..cuda_utils import checkCudaErrors
try:
# cuda-python >= 12.9 (has cuda.bindings.driver)
from cuda.bindings import driver as cuda
except ImportError:
try:
# cuda-python < 12.9 (no cuda.bindings.driver, use cuda as driver)
# from cuda import cuda is not available in cuda-python >= 13.0
from cuda import cuda
except ImportError as e:
raise ImportError(
"Could not import the 'cuda' module. "
"Please install cuda-python that matches your CUDA version."
) from e

def mpi_barrier():
from mpi4py import MPI
Expand Down Expand Up @@ -135,16 +150,23 @@ def __init__(
# Use torch.cuda.current_device() instead of mapping.local_rank to
# support base_gpu_id != 0 scenarios where the actual CUDA device
# index differs from the TP rank / local_rank.
self.mcast_buffer_handle = McastGPUBuffer(
device = torch.device("cuda", torch.cuda.current_device())
if isinstance(comm_backend, TorchDistBackend):
group = comm_backend._group if comm_backend._group is not None else torch.distributed.group.WORLD
group_name = group.group_name
else:
group_name = torch.distributed.group.WORLD.group_name
Comment on lines +141 to +149
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Don't silently rendezvous on WORLD for non-TorchDistBackend backends.

Lines 154-158 only derive group_name from TorchDistBackend; every other CommBackend falls back to torch.distributed.group.WORLD. That breaks subgroup communicators and also makes the documented default MPIBackend() path depend on an initialized torch process group. Please require a backend that can surface the matching process-group identity, or fail explicitly instead of using the wrong peer set.

πŸ€– Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@flashinfer/comm/trtllm_mnnvl_ar.py` around lines 154 - 158, The current
branch silently uses torch.distributed.group.WORLD when comm_backend is not a
TorchDistBackend, which incorrectly rendezvous non-Torch backends; update the
logic around comm_backend / TorchDistBackend and group_name so that only
backends that can surface a matching torch process-group identity are allowed:
if comm_backend is a TorchDistBackend use its _group.group_name, otherwise
require the CommBackend to provide a group identifier (e.g., a new
method/property on the backend interface) and, if it cannot, raise an explicit
error (RuntimeError) explaining that the backend does not expose a process-group
and rendezvous on WORLD is not permitted. Ensure references to comm_backend,
TorchDistBackend, and group_name are updated accordingly.

self.ptrs, self.tensor, self.handle = _alloc_symm_buffer_bytes(
requested_workspace_size,
mapping.tp_size,
mapping.tp_rank,
torch.device("cuda", torch.cuda.current_device()),
comm_backend,
torch.float32,
device,
group_name,
)

# Get the actual usable buffer size after allocation (buf_size is updated by McastGPUBuffer)
allocated_size = self.mcast_buffer_handle.buf_size
# handle.buffer_size is the usable data size. torch symmetric memory
# allocator places signal_pad on top of it, not carved from within.
allocated_size = self.handle.buffer_size
# We want the buffer size to be aligned to 16B which is the granularity for buffer management.
self.buffer_size_bytes = (
math.floor(allocated_size / self.NUM_LAMPORT_BUFFERS) // 16 * 16
Expand All @@ -157,7 +179,7 @@ def __init__(
)

# We use FP32 for sentinel value regardless of the real dtype
self.mcast_buffer_handle.lamport_initialize(mapping.tp_rank, torch.float32)
self.lamport_initialize(mapping.tp_rank, torch.float32, allocated_size)
# Wait until the initialization is done
torch.cuda.synchronize()
comm_backend.barrier()
Expand All @@ -173,9 +195,26 @@ def __init__(
device=torch.device("cuda", torch.cuda.current_device()),
)

self.uc_ptrs_dev = self.mcast_buffer_handle.get_buffer_ptrs_dev()
self.uc_ptr_local = self.mcast_buffer_handle.get_unicast_ptr(self.rank)
self.mc_ptr = self.mcast_buffer_handle.get_multicast_ptr()
self.uc_ptrs_dev = self.handle.buffer_ptrs_dev
self.uc_ptr_local = self.handle.buffer_ptrs[self.rank]
self.mc_ptr = self.handle.multicast_ptr

def lamport_initialize(self, rank: int, dtype: torch.dtype, allocated_size: int):
    """Fill this rank's buffer with the negative-zero sentinel pattern.

    The lamport protocol uses -0.0 as the "slot not yet written" marker,
    so the whole buffer is memset to that bit pattern via the CUDA driver.

    Args:
        rank: Index of this rank into ``self.ptrs``.
        dtype: Element type; selects the bit pattern and memset width.
            Only float16, bfloat16, and float32 are supported.
        allocated_size: Buffer size in bytes (floored to whole elements).

    Raises:
        ValueError: If ``dtype`` is not float16, bfloat16, or float32.
    """
    if dtype in (torch.bfloat16, torch.float16):
        # 16-bit -0.0 bit pattern.
        pattern, elem_bytes, memset = 0x8000, 2, cuda.cuMemsetD16
    elif dtype == torch.float32:
        # 32-bit -0.0 bit pattern.
        pattern, elem_bytes, memset = 0x80000000, 4, cuda.cuMemsetD32
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")

    count = allocated_size // elem_bytes
    checkCudaErrors(memset(int(self.ptrs[rank]), pattern, count))

@functools.cache
def is_buffer_size_sufficient(
Expand Down Expand Up @@ -235,11 +274,13 @@ def destroy(self) -> None:
if getattr(self, "_destroyed", False):
return # Already destroyed, nothing to do

del self.mcast_buffer_handle
del self.buffer_flags
del self.uc_ptrs_dev
del self.uc_ptr_local
del self.mc_ptr
del self.tensor
del self.handle
del self.ptrs
self._destroyed = True


Expand Down Expand Up @@ -506,33 +547,22 @@ def get_allreduce_mnnvl_workspace(
dtype: torch.dtype,
comm_backend_for_handle_transfer: Optional[CommBackend] = None,
buffer_size_in_bytes: Optional[int] = None,
) -> Tuple[McastGPUBuffer, torch.Tensor, int]:
) -> Tuple[MNNVLAllReduceFusionWorkspace, torch.Tensor, int]:
"""Get workspace buffers needed for multi-node NVLink all-reduce operation.

This function allocates and initializes the workspace buffers required for performing
multi-node NVLink all-reduce operations. It creates:
1. A multicast GPU buffer for communication between nodes
2. A flags tensor to track buffer state
3. Maximum number of elements that can fit in the buffer

The buffer size is calculated to efficiently handle common hidden dimensions
(2048, 4096, 5120, 7168, 8192) by using their LCM of 286720.

Args:
mapping: Tensor parallel mapping configuration containing rank info
dtype: Data type of the tensors being reduced
comm_backend_for_handle_transfer: Communication backend for handle transfer
buffer_size_in_bytes: Optional buffer size. Practically, assign this to 3 * 2 * dtype.itemsize * hidden_dim * max_tokens

Returns:
Tuple containing:
- McastGPUBuffer: Multicast buffer for inter-node communication
- MNNVLAllReduceFusionWorkspace: The workspace object backed by torch symmetric memory
- torch.Tensor: Buffer flags tensor tracking state
- int: Maximum number of elements that can fit in buffer
"""
# buffer shape: [3, 2, buffer_tokens, hidden_dim]
stride = 3 * 2 * dtype.itemsize
# LCM for hidden_dim: 2048, 4096, 5120, 7168, 8192 = 286720
# max_num_elements must be a multiple of 286720
lcm_hidden_dim = 286720
TARGET_WORKSPACE_SIZE_BYTES = (
buffer_size_in_bytes if buffer_size_in_bytes is not None else 12_000_000
Expand All @@ -541,21 +571,17 @@ def get_allreduce_mnnvl_workspace(
TARGET_WORKSPACE_SIZE_BYTES / (lcm_hidden_dim * stride)
) * (lcm_hidden_dim * stride)

# Redirect to the new workspace allocation logic. The new kernel needs the new flag buffer layout.
workspace = MNNVLAllReduceFusionWorkspace(
mapping,
buffer_size_in_bytes=buffer_size_in_bytes,
comm_backend=comm_backend_for_handle_transfer,
)

mcast_buffer = workspace.mcast_buffer_handle
buffer_flags = workspace.buffer_flags
# this is calculated using the legacy behavior. We do not use the actual allocated size.
max_num_elements = workspace.buffer_size_bytes // stride

return (
mcast_buffer,
buffer_flags,
workspace,
workspace.buffer_flags,
max_num_elements,
)

Expand Down
Loading
Loading