unifying all reduce memory allocation for single-node and multi-node nvlink #2955
Amir-19 wants to merge 13 commits into flashinfer-ai:main from
Conversation
📝 Walkthrough

Adds a new PyTorch symmetric-CUDA-memory helper and migrates TRT-LLM and MNNVL IPC/all-reduce workspace allocation from the legacy shared-buffer/McastGPUBuffer APIs to rendezvous-backed symmetric tensors, updating allocation, tracking, and destruction flows plus related tests and logging.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant TRTLLM as TRT-LLM code
    participant Torch as torch.distributed
    participant CUDA as CUDA symmetric memory
    TRTLLM->>Torch: _enable_symm_mem_for_group(group_name)
    Torch-->>Torch: patch destroy_process_group (if needed)
    TRTLLM->>CUDA: request symmetric allocation (size, dtype, device)
    CUDA-->>Torch: rendezvous enable_symm_mem_for_group + empty tensor
    Torch-->>CUDA: rendezvous barrier / handle
    CUDA-->>TRTLLM: return (ptrs list, tensor, handle)
    TRTLLM->>TRTLLM: store refs in _symm_workspace_refs[id(ipc_handles)]
    TRTLLM->>CUDA: on destroy -> teardown references (delete tensor, handle, ptrs)
```

Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks: ✅ Passed checks (5 passed)
Code Review
This pull request transitions the All-Reduce workspace management to use torch.distributed._symmetric_memory, replacing previous custom IPC and multicast buffer implementations. Key changes include the introduction of a symmetric buffer allocation utility and updates to the creation and destruction logic for both standard and fused All-Reduce workspaces. Review feedback highlights the need for safer handling of optional process groups to prevent AttributeError, the importance of using torch.cuda.current_device() for device consistency, and the correction of return type hints and hardcoded data types.
Signed-off-by: Amir Samani <asamani@nvidia.com>
Signed-off-by: Amir Samani <asamani@nvidia.com>
Actionable comments posted: 5
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/comm/torch_symmetric_memory.py`:
- Around line 26-28: The current allocation computes numel = size_bytes //
elem_size which floors the element count and can under-allocate (e.g., 6 bytes
for float32 becomes 4 bytes). Update the allocation in the block using
elem_size/numel/tensor/symm_mem.empty to either round up numel (e.g., ceil
division: numel = (size_bytes + elem_size - 1) // elem_size) so the buffer has
at least size_bytes capacity (note actual allocated bytes will be numel *
elem_size), or explicitly raise an error when size_bytes % elem_size != 0 to
reject non-divisible requests; apply the chosen behavior consistently where
tensor = symm_mem.empty(numel, dtype=dtype, device=device) is created and ensure
any callers expecting exact usable capacity are adjusted accordingly.
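The ceil-division behavior described above can be sketched in isolation. A minimal example (the helper name is illustrative, not a flashinfer API):

```python
def alloc_numel(size_bytes: int, elem_size: int) -> int:
    """Smallest element count whose total byte size is at least size_bytes."""
    return (size_bytes + elem_size - 1) // elem_size

# Floor division would under-allocate: 6 bytes // 4-byte float32 -> 1 element (4 bytes).
# Ceil division guarantees capacity: 6 bytes -> 2 elements (8 bytes).
print(alloc_numel(6, 4), alloc_numel(8, 4))  # prints "2 2"
```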
In `@flashinfer/comm/trtllm_ar.py`:
- Around line 33-34: The imports create_shared_buffer and free_shared_buffer
from .cuda_ipc are dead and should be removed; update the import statement in
trtllm_ar.py to only import symbols that are actually used (e.g., keep cudart if
referenced, otherwise remove the entire .cuda_ipc import), ensuring that
references to create_shared_buffer and free_shared_buffer are not left
elsewhere; verify _alloc_symm_buffer_bytes from .torch_symmetric_memory remains
imported if used.
- Around line 721-730: The destroy function for the fusion workspace currently
only pops _symm_workspace_refs and thus leaks the device flag buffer allocated
in trtllm_create_ipc_workspace_for_all_reduce_fusion (flag_ptr =
cudart.cudaMalloc(5 * 4)). Modify the create path to store the flag_ptr
alongside the workspace refs (e.g., in a dict like _symm_flag_ptrs keyed by
id(workspace) or by storing a tuple in _symm_workspace_refs), and update the
destroy function (trtllm_destroy_ipc_workspace_for_all_reduce_fusion) to
retrieve and free that device pointer with cudart.cudaFree(flag_ptr) before
removing entries from _symm_workspace_refs (and _symm_flag_ptrs if used); ensure
you handle missing keys safely (pop with default None) to avoid exceptions.
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 154-158: The current branch silently uses
torch.distributed.group.WORLD when comm_backend is not a TorchDistBackend, which
incorrectly rendezvous non-Torch backends; update the logic around comm_backend
/ TorchDistBackend and group_name so that only backends that can surface a
matching torch process-group identity are allowed: if comm_backend is a
TorchDistBackend use its _group.group_name, otherwise require the CommBackend to
provide a group identifier (e.g., a new method/property on the backend
interface) and, if it cannot, raise an explicit error (RuntimeError) explaining
that the backend does not expose a process-group and rendezvous on WORLD is not
permitted. Ensure references to comm_backend, TorchDistBackend, and group_name
are updated accordingly.
In `@tests/comm/test_trtllm_mnnvl_allreduce.py`:
- Around line 478-486: The test presently sets legacy_explicit_workspace_bytes
with a hardcoded factor (3 * 2) before world size is known, which can undersize
the workspace; instead, compute the legacy override inside run_mnnvl_ar_full
after querying dist.get_world_size() (or after the test instantiates the process
group) and call
MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes(...) (pass dtype,
hidden_size, seq_lens/max_seq_len, and the discovered world_size/tp_size) to
derive the correct explicit_workspace_bytes, then pass that value into the
legacy_explicit_workspace_bytes argument so sizing uses the actual world_size
rather than the hardcoded factor.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 2fc6d2f2-38d9-469e-9ce3-2137a91440e4
📒 Files selected for processing (6)
- flashinfer/comm/torch_symmetric_memory.py
- flashinfer/comm/trtllm_ar.py
- flashinfer/comm/trtllm_mnnvl_ar.py
- tests/comm/test_trtllm_allreduce_fusion.py
- tests/comm/test_trtllm_mnnvl_allreduce.py
- tests/comm/test_trtllm_mnnvl_allreduce_custom_comm.py
```python
elem_size = torch.empty(0, dtype=dtype).element_size()
numel = size_bytes // elem_size
tensor = symm_mem.empty(numel, dtype=dtype, device=device)
```
Don't silently shrink byte-sized allocations.
Line 27 floors size_bytes to whole elements, so a 6-byte request with torch.float32 allocates only 4 bytes. Because callers treat size_bytes as the promised usable capacity, that can turn into an undersized symmetric buffer. Please round numel up or reject non-divisible sizes here.
🐛 Proposed fix

```diff
-numel = size_bytes // elem_size
+numel = (size_bytes + elem_size - 1) // elem_size
```

🤖 Prompt for AI Agents
+ numel = (size_bytes + elem_size - 1) // elem_size🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/comm/torch_symmetric_memory.py` around lines 26 - 28, The current
allocation computes numel = size_bytes // elem_size which floors the element
count and can under-allocate (e.g., 6 bytes for float32 becomes 4 bytes). Update
the allocation in the block using elem_size/numel/tensor/symm_mem.empty to
either round up numel (e.g., ceil division: numel = (size_bytes + elem_size - 1)
// elem_size) so the buffer has at least size_bytes capacity (note actual
allocated bytes will be numel * elem_size), or explicitly raise an error when
size_bytes % elem_size != 0 to reject non-divisible requests; apply the chosen
behavior consistently where tensor = symm_mem.empty(numel, dtype=dtype,
device=device) is created and ensure any callers expecting exact usable capacity
are adjusted accordingly.
| """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion. | ||
|
|
||
| Note: | ||
| This function is used to destroy a workspace for all reduce fusion. | ||
| The workspace is a list of IPC handles. | ||
| The workspace should be destroyed after calling trtllm_custom_all_reduce_fusion. | ||
| The workspace can be reused for multiple all reduce fusion calls under the same configuration. | ||
| """ | ||
| Releases the symmetric memory references held internally. The workspace | ||
| list should not be used after this call. | ||
|
|
||
| for ipc_handle in workspace: | ||
| free_shared_buffer(ipc_handle, group) | ||
| Args: | ||
| workspace: The ipc_handles list returned by the create function. | ||
| group: Unused, kept for API compatibility. | ||
| """ | ||
| _symm_workspace_refs.pop(id(workspace), None) |
destroy_*_fusion() still leaks the cudaMalloc flag buffer.
The create path allocates flag_ptr = cudart.cudaMalloc(5 * 4) at Line 672, but this destroy function only removes _symm_workspace_refs. The raw device allocation is never freed, so repeated workspace recreation leaks CUDA memory. Please track that pointer alongside the symmetric refs and release it here as well.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/comm/trtllm_ar.py` around lines 721 - 730, The destroy function
for the fusion workspace currently only pops _symm_workspace_refs and thus leaks
the device flag buffer allocated in
trtllm_create_ipc_workspace_for_all_reduce_fusion (flag_ptr =
cudart.cudaMalloc(5 * 4)). Modify the create path to store the flag_ptr
alongside the workspace refs (e.g., in a dict like _symm_flag_ptrs keyed by
id(workspace) or by storing a tuple in _symm_workspace_refs), and update the
destroy function (trtllm_destroy_ipc_workspace_for_all_reduce_fusion) to
retrieve and free that device pointer with cudart.cudaFree(flag_ptr) before
removing entries from _symm_workspace_refs (and _symm_flag_ptrs if used); ensure
you handle missing keys safely (pop with default None) to avoid exceptions.
```python
if isinstance(comm_backend, TorchDistBackend):
    group = comm_backend._group if comm_backend._group is not None else torch.distributed.group.WORLD
    group_name = group.group_name
else:
    group_name = torch.distributed.group.WORLD.group_name
```
Don't silently rendezvous on WORLD for non-TorchDistBackend backends.
Lines 154-158 only derive group_name from TorchDistBackend; every other CommBackend falls back to torch.distributed.group.WORLD. That breaks subgroup communicators and also makes the documented default MPIBackend() path depend on an initialized torch process group. Please require a backend that can surface the matching process-group identity, or fail explicitly instead of using the wrong peer set.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/comm/trtllm_mnnvl_ar.py` around lines 154 - 158, The current
branch silently uses torch.distributed.group.WORLD when comm_backend is not a
TorchDistBackend, which incorrectly rendezvous non-Torch backends; update the
logic around comm_backend / TorchDistBackend and group_name so that only
backends that can surface a matching torch process-group identity are allowed:
if comm_backend is a TorchDistBackend use its _group.group_name, otherwise
require the CommBackend to provide a group identifier (e.g., a new
method/property on the backend interface) and, if it cannot, raise an explicit
error (RuntimeError) explaining that the backend does not expose a process-group
and rendezvous on WORLD is not permitted. Ensure references to comm_backend,
TorchDistBackend, and group_name are updated accordingly.
```diff
+explicit_workspace_bytes = 3 * 2 * dtype.itemsize * hidden_size * max(seq_lens)
 run_mnnvl_ar_full(
-    monkeypatch, seq_lens, fusion, dtype, hidden_size, legacy_api=True
+    monkeypatch,
+    seq_lens,
+    fusion,
+    dtype,
+    hidden_size,
+    legacy_explicit_workspace_bytes=explicit_workspace_bytes,
+    legacy_api=True,
```
Derive the explicit legacy workspace size after world_size is known.
Line 478 hardcodes a 3 * 2 factor and ignores tp_size, but the actual required buffer size scales with world_size. On larger MPI jobs this override can undersize the workspace and make the legacy path fail for sizing reasons instead of kernel correctness. Please compute the override inside run_mnnvl_ar_full() after dist.get_world_size() is available, ideally via MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes(...).
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@tests/comm/test_trtllm_mnnvl_allreduce.py` around lines 478 - 486, The test
presently sets legacy_explicit_workspace_bytes with a hardcoded factor (3 * 2)
before world size is known, which can undersize the workspace; instead, compute
the legacy override inside run_mnnvl_ar_full after querying
dist.get_world_size() (or after the test instantiates the process group) and
call MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes(...) (pass
dtype, hidden_size, seq_lens/max_seq_len, and the discovered world_size/tp_size)
to derive the correct explicit_workspace_bytes, then pass that value into the
legacy_explicit_workspace_bytes argument so sizing uses the actual world_size
rather than the hardcoded factor.
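To illustrate why the override must wait for the world size, here is a toy sizing function. The real formula lives in `MNNVLAllReduceFusionWorkspace.get_required_buffer_size_bytes` and may differ; only the dependency on `world_size` is the point:

```python
def toy_workspace_bytes(world_size: int, itemsize: int, hidden_size: int,
                        max_seq_len: int, num_buffers: int = 3) -> int:
    # Hypothetical layout: num_buffers staging buffers, each holding one
    # max-size shard per rank. The per-rank term is what the hardcoded
    # "3 * 2" factor in the test fails to capture for world_size != 2.
    return num_buffers * world_size * itemsize * hidden_size * max_seq_len

# The hardcoded factor matches only when world_size == 2:
print(toy_workspace_bytes(2, 2, 1024, 128) == 3 * 2 * 2 * 1024 * 128)  # prints "True"
```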
/bot run
Actionable comments posted: 1
♻️ Duplicate comments (1)
flashinfer/comm/trtllm_ar.py (1)
680-692: ⚠️ Potential issue | 🟠 Major — CUDA memory leak: `flag_ptr` is never freed.

The `flag_ptr` allocated at line 682 via `cudart.cudaMalloc(5 * 4)` is added to the workspace at line 692, but `trtllm_destroy_ipc_workspace_for_all_reduce_fusion` (lines 728-740) only removes `_symm_workspace_refs` entries. The raw CUDA allocation is never freed, causing memory leaks on repeated workspace creation/destruction cycles.

🛠️ Suggested fix approach

Track `flag_ptr` alongside the symmetric refs:

```diff
+_symm_flag_ptrs: dict[int, int] = {}  # id(ipc_handles) -> flag_ptr value

 # In trtllm_create_ipc_workspace_for_all_reduce_fusion, after line 692:
+    _symm_flag_ptrs[id(ipc_handles)] = flag_ptr.value

 # In trtllm_destroy_ipc_workspace_for_all_reduce_fusion:
 def trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
     workspace: List[List[int]], group: Optional[ProcessGroup] = None
 ) -> None:
     _symm_workspace_refs.pop(id(workspace), None)
+    flag_ptr = _symm_flag_ptrs.pop(id(workspace), None)
+    if flag_ptr is not None:
+        cudart.cudaFree(flag_ptr)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_ar.py` around lines 680 - 692, The allocated CUDA pointer flag_ptr created in the workspace setup (see flag_ptr = cudart.cudaMalloc(...) and workspace.append(flag_ptr.value)) is never freed, causing memory leaks; modify the workspace tracking so flag_ptr is stored alongside the workspace/symmetric refs (e.g., append a tuple or push into a dedicated list such as _flag_ptrs) when created in the routine that allocates it, and update trtllm_destroy_ipc_workspace_for_all_reduce_fusion to iterate over those stored flag pointers and call cudart.cudaFree(flag_ptr) (or cudart.cudaFree(c_void_p(flag_ptr))) before removing entries from _symm_workspace_refs/workspace to ensure proper CUDA memory deallocation.
🧹 Nitpick comments (1)
flashinfer/comm/trtllm_ar.py (1)
403-404: Type annotation mismatch: stores tuples but annotated as `list[torch.Tensor]`.

The dict stores `(tensor, handle)` tuples (see lines 484, 641), but the type annotation says `list[torch.Tensor]`. This should be updated for accuracy.

🔧 Suggested type fix

```diff
-_symm_workspace_refs: dict[int, list[torch.Tensor]] = {}
+_symm_workspace_refs: dict[int, list[tuple[torch.Tensor, Any]]] = {}
```

You'll also need to add `Any` to the imports from `typing`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_ar.py` around lines 403 - 404, The _symm_workspace_refs dictionary is annotated as dict[int, list[torch.Tensor]] but actually stores (tensor, handle) tuples; update the annotation to reflect list[tuple[torch.Tensor, Any]] (or list[tuple[torch.Tensor, HandleType]] if a concrete handle type exists) and add Any to the typing imports so the annotation is valid; ensure any other occurrences or type checks that reference _symm_workspace_refs are adjusted to expect tuples rather than bare tensors.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/comm/trtllm_ar.py`:
- Around line 728-740: The destroy function
trtllm_destroy_ipc_workspace_for_all_reduce_fusion currently only pops
_symm_workspace_refs but fails to free the CUDA allocation pointed to by
flag_ptr created in trtllm_create_ipc_workspace_for_all_reduce_fusion; update
the function to lookup the stored record in _symm_workspace_refs (using
id(workspace)), if present free the CUDA allocation referenced by its flag_ptr
(using the same CUDA/free API used when allocating it), then remove the entry
from _symm_workspace_refs and handle missing entries gracefully so no dangling
GPU memory remains.
---
Duplicate comments:
In `@flashinfer/comm/trtllm_ar.py`:
- Around line 680-692: The allocated CUDA pointer flag_ptr created in the
workspace setup (see flag_ptr = cudart.cudaMalloc(...) and
workspace.append(flag_ptr.value)) is never freed, causing memory leaks; modify
the workspace tracking so flag_ptr is stored alongside the workspace/symmetric
refs (e.g., append a tuple or push into a dedicated list such as _flag_ptrs)
when created in the routine that allocates it, and update
trtllm_destroy_ipc_workspace_for_all_reduce_fusion to iterate over those stored
flag pointers and call cudart.cudaFree(flag_ptr) (or
cudart.cudaFree(c_void_p(flag_ptr))) before removing entries from
_symm_workspace_refs/workspace to ensure proper CUDA memory deallocation.
---
Nitpick comments:
In `@flashinfer/comm/trtllm_ar.py`:
- Around line 403-404: The _symm_workspace_refs dictionary is annotated as
dict[int, list[torch.Tensor]] but actually stores (tensor, handle) tuples;
update the annotation to reflect list[tuple[torch.Tensor, Any]] (or
list[tuple[torch.Tensor, HandleType]] if a concrete handle type exists) and add
Any to the typing imports so the annotation is valid; ensure any other
occurrences or type checks that reference _symm_workspace_refs are adjusted to
expect tuples rather than bare tensors.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: f7f709fc-7f42-4acf-9099-457780b2932c
📒 Files selected for processing (2)
- flashinfer/comm/trtllm_ar.py
- flashinfer/comm/trtllm_mnnvl_ar.py
```diff
 def trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
     workspace: List[List[int]], group: Optional[ProcessGroup] = None
 ) -> None:
-    """
-    Parameters:
-    - workspace: the workspace to destroy.
-    - group: the process group to use.
-
-    Note:
-        This function is used to destroy a workspace for all reduce fusion.
-        The workspace is a list of IPC handles.
-        The workspace should be destroyed after calling trtllm_custom_all_reduce_fusion.
-        The workspace can be reused for multiple all reduce fusion calls under the same configuration.
-    """
-    for ipc_handle in workspace:
-        free_shared_buffer(ipc_handle, group)
+    """Destroy a workspace created by trtllm_create_ipc_workspace_for_all_reduce_fusion.
+
+    Releases the symmetric memory references held internally. The workspace
+    list should not be used after this call.
+
+    Args:
+        workspace: The ipc_handles list returned by the create function.
+        group: Unused, kept for API compatibility.
+    """
+    _symm_workspace_refs.pop(id(workspace), None)
```
Destroy function incomplete: missing flag_ptr cleanup.
This function only removes symmetric memory references but does not free the flag_ptr CUDA allocation created in trtllm_create_ipc_workspace_for_all_reduce_fusion. See the related comment above for the suggested fix.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@flashinfer/comm/trtllm_ar.py` around lines 728 - 740, The destroy function
trtllm_destroy_ipc_workspace_for_all_reduce_fusion currently only pops
_symm_workspace_refs but fails to free the CUDA allocation pointed to by
flag_ptr created in trtllm_create_ipc_workspace_for_all_reduce_fusion; update
the function to lookup the stored record in _symm_workspace_refs (using
id(workspace)), if present free the CUDA allocation referenced by its flag_ptr
(using the same CUDA/free API used when allocating it), then remove the entry
from _symm_workspace_refs and handle missing entries gracefully so no dangling
GPU memory remains.
[FAILED] Pipeline #47672454: 12/20 passed
Signed-off-by: Amir Samani <asamani@nvidia.com>
/bot run

[FAILED] Pipeline #47877295: 10/20 passed
```python
def _patch_group_count_reset() -> None:
    """Prevent group_count from resetting to 0 on WORLD destruction (2.10 only).
```
@kwen2501 Is this hack necessary for pytorch 2.10?
@kwen2501 Could you have a look at this PR?
Hmm, do we need to support the case of in-process restart (hence calling init_process_group twice)?
```python
# all sizes should be aligned to 1LU << 21 bytes (2MB)
aligned_size = round_up(size, 1 << 21)
```
Why? torch symmetric memory will use a mempool under the cover, so you can make smaller requests. It's probably good to make sure it is 16B aligned, but not 2MB.
```python
if dtype == torch.bfloat16 or dtype == torch.float16:
    neg_zero = 0x8000
    dsize = 2
    memset_func = cuda.cuMemsetD16
```
why can't you use tensor.fill_(-0.0) ?
fixed! thank you! I learned something new.
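For reference, `tensor.fill_(-0.0)` writes exactly the bit patterns the manual memset used, since -0.0 is just the sign bit. This can be checked with the `struct` module (bfloat16 is the top 16 bits of the float32 encoding, so its pattern is 0x8000 as well):

```python
import struct

half_bits = struct.unpack("<H", struct.pack("<e", -0.0))[0]   # float16 encoding of -0.0
float_bits = struct.unpack("<I", struct.pack("<f", -0.0))[0]  # float32 encoding of -0.0

print(hex(half_bits), hex(float_bits))  # prints "0x8000 0x80000000"
# bfloat16 keeps the high half of the float32 encoding:
assert float_bits >> 16 == half_bits == 0x8000
```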
```python
def _patch_group_count_reset() -> None:
    """Prevent group_count from resetting to 0 on WORLD destruction (2.10 only).
```
Hmm, do we need to support the case of in-process restart (hence calling init_process_group twice)?
```python
    This helper mimics the 2.11 behaviour: it calls ``set_group_info`` with the
    group's native store (no extra prefix) and populates the Python-side guard
    dict so that ``enable_symm_mem_for_group`` becomes a no-op for this group.
```
Sorry I am a bit confused.
torch 2.11 purposely deprecates the enable_symm_mem_for_group API.
Should the user just check the torch version and call enable_symm_mem_for_group when the version is lower than 2.11? That is it.
I haven't investigated deeply what exactly is happening, but if I do

```python
torch_version = tuple(int(x) for x in torch.__version__.split(".")[:2])
if torch_version >= (2, 11):
    return
from torch.distributed._symmetric_memory import enable_symm_mem_for_group
enable_symm_mem_for_group(group_name)
```

then `mpirun -np 2 pytest tests/comm/test_allreduce_unified_api.py -vv -s` hangs after the first test case. Adding `_patch_group_count_reset()` fixes the issue.
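The version check in the snippet above can be factored so the parsing is testable without torch installed (a sketch; flashinfer's actual gate may differ):

```python
def needs_enable_symm_mem(torch_version: str) -> bool:
    """True when torch is older than 2.11, i.e. enable_symm_mem_for_group
    must still be called explicitly (2.11 deprecates that API)."""
    major, minor = (int(x) for x in torch_version.split(".")[:2])
    return (major, minor) < (2, 11)

print(needs_enable_symm_mem("2.10.0"), needs_enable_symm_mem("2.11.0"))  # prints "True False"
```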
```python
elem_size = torch.empty(0, dtype=dtype).element_size()
numel = size_bytes // elem_size
tensor = symm_mem.empty(numel, dtype=dtype, device=device)
```
Would it be more ergonomic if the API asked for a shape or numel instead of size_bytes?
We can add another API, or an option to use this API with shape or numel, but I wanted to be the least intrusive in the kernels.
```python
    handle.get_buffer(peer, (numel,), dtype, storage_offset=0).data_ptr()
    for peer in range(world_size)
]
return ptrs, tensor, handle
```
nit: ptrs and handle are redundant to each other in this return. When the user has the handle, they can get the ptrs themselves.
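The redundancy the nit points at can be shown with a stand-in handle: a caller holding the handle can rebuild the pointer list on demand. `FakeHandle` is purely illustrative; the real object is the torch symm_mem rendezvous handle:

```python
class FakeHandle:
    """Stand-in for a symmetric-memory handle; getting a peer buffer is
    reduced to returning a per-peer base address."""
    def __init__(self, base_addrs: list[int]):
        self._base = base_addrs

    def buffer_ptr(self, peer: int) -> int:
        return self._base[peer]

def peer_ptrs(handle: FakeHandle, world_size: int) -> list[int]:
    # What a caller can do on demand instead of receiving ptrs up front.
    return [handle.buffer_ptr(p) for p in range(world_size)]
```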
One general thought I have: this call uses the internal pooling of torch symm_mem, and is thus reusable.
Correction:
Signed-off-by: Amir Samani <asamani@nvidia.com>
/bot run
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/comm/trtllm_mnnvl_ar.py (1)

150-173: ⚠️ Potential issue | 🔴 Critical — Fix dtype-dependent Lamport initialization in symmetric buffer allocation.

Line 153 allocates the symmetric buffer with hardcoded `torch.float32`, ignoring the actual `dtype` parameter. For fp16/bf16 reductions, the Lamport sentinel must be the 16-bit negative-zero pattern (0xBC00 for float16, 0xBF80 for bfloat16), not the float32 pattern (0xBF800000). This causes incorrect synchronization values and silent data corruption.

Additionally, line 552 calls `MNNVLAllReduceFusionWorkspace` without passing the `dtype` parameter when `buffer_size_in_bytes` is provided, allowing the workspace to be initialized with float32 regardless of the actual data type in use.

Fix:
- Add a `dtype` requirement when `buffer_size_in_bytes` is provided
- Pass the actual `dtype` to `_alloc_symm_buffer_bytes` instead of hardcoding `torch.float32`
- Pass the `dtype` parameter in the workspace constructor call at line 552

Proposed diff

```diff
     else:
         logging.debug(
             f"[MNNVL Allreduce] Using provided buffer size override in bytes: {buffer_size_in_bytes} bytes."
         )
+        if dtype is None:
+            raise ValueError(
+                "dtype must be provided when buffer_size_in_bytes is provided; "
+                "Lamport initialization is dtype-dependent."
+            )
@@
         self.ptrs, self.tensor, self.handle = _alloc_symm_buffer_bytes(
             requested_workspace_size,
             mapping.tp_size,
-            torch.float32,
+            dtype,
             device,
             group_name,
         )
@@
     workspace = MNNVLAllReduceFusionWorkspace(
         mapping,
+        dtype=dtype,
         buffer_size_in_bytes=buffer_size_in_bytes,
         comm_backend=comm_backend_for_handle_transfer,
     )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_mnnvl_ar.py` around lines 150 - 173, The symmetric buffer allocation currently hardcodes torch.float32; update the allocation and workspace construction to use the actual dtype: change the callsite that invokes _alloc_symm_buffer_bytes(...) to pass the real dtype (not torch.float32) so the Lamport sentinel matches fp16/bf16 bit patterns, ensure MNNVLAllReduceFusionWorkspace requires/accepts a dtype when buffer_size_in_bytes is provided and propagate that dtype into its constructor call (the call at the other location that currently omits dtype), and ensure the Lamport initialization (self.tensor.fill_(-0.0)) semantics remain correct for the supplied dtype; reference _alloc_symm_buffer_bytes, MNNVLAllReduceFusionWorkspace, and self.tensor.fill_(-0.0) when making the changes.
♻️ Duplicate comments (3)
flashinfer/comm/trtllm_ar.py (1)
403-403: ⚠️ Potential issue | 🟠 Major — Free the fusion `flag_ptr` allocation on destroy.

Line 679 still allocates a raw CUDA flag buffer, but the destroy path only drops symmetric-memory refs. Track this pointer with the workspace and release it in `trtllm_destroy_ipc_workspace_for_all_reduce_fusion`.

🐛 Proposed fix

```diff
-_symm_workspace_refs: dict[int, list[torch.Tensor]] = {}
+_symm_workspace_refs: dict[int, list[tuple[torch.Tensor, object]]] = {}
+_symm_flag_ptrs: dict[int, int] = {}
@@
     # add flag_ptr to workspace
     workspace.append(flag_ptr.value)
+    _symm_flag_ptrs[id(ipc_handles)] = flag_ptr.value
@@
 def trtllm_destroy_ipc_workspace_for_all_reduce_fusion(
     workspace: List[List[int]], group: Optional[ProcessGroup] = None
 ) -> None:
@@
-    _symm_workspace_refs.pop(id(workspace), None)
+    flag_ptr = _symm_flag_ptrs.pop(id(workspace), None)
+    if flag_ptr is not None:
+        cudart.cudaFree(c_void_p(flag_ptr))
+    _symm_workspace_refs.pop(id(workspace), None)
```

Also applies to: 678-689, 725-737

🤖 Prompt for AI Agents
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_ar.py` at line 403, The code currently stores symmetric-memory refs in _symm_workspace_refs but does not free the raw CUDA allocation used for the fusion flag (flag_ptr); update the workspace bookkeeping to track the flag_ptr alongside the symmetric refs (e.g., add it into the workspace struct or map entry created where flag_ptr is allocated) and ensure trtllm_destroy_ipc_workspace_for_all_reduce_fusion frees the CUDA buffer (cudaFree or the equivalent used elsewhere) when destroying the workspace; make symmetric references and flag_ptr lifetime tied so both are released in the same destroy path (references: _symm_workspace_refs and trtllm_destroy_ipc_workspace_for_all_reduce_fusion).

flashinfer/comm/torch_symmetric_memory.py (1)

68-70: ⚠️ Potential issue | 🟡 Minor — Don't floor byte-sized allocations.

`size_bytes // elem_size` can allocate fewer bytes than requested when the size is not divisible by the dtype size. Since callers treat `size_bytes` as capacity, round up or reject non-divisible requests.

🐛 Proposed fix

```diff
 elem_size = torch.empty(0, dtype=dtype).element_size()
-numel = size_bytes // elem_size
+numel = (size_bytes + elem_size - 1) // elem_size
 tensor = symm_mem.empty(numel, dtype=dtype, device=device)
```

🤖 Prompt for AI Agents

Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/torch_symmetric_memory.py` around lines 68 - 70, The current calculation numel = size_bytes // elem_size can under-allocate when size_bytes isn't divisible by elem_size; update the logic in torch_symmetric_memory.py (around elem_size, numel, and tensor = symm_mem.empty...) to either (a) validate divisibility and raise a clear error if size_bytes % elem_size != 0, or (b) round up using ceiling division (numel = (size_bytes + elem_size - 1) // elem_size) so the returned tensor has at least the requested capacity; ensure the chosen behavior is documented in the function's contract and used consistently where symm_mem.empty is called.

flashinfer/comm/trtllm_mnnvl_ar.py (1)
141-149: ⚠️ Potential issue | 🟠 Major — Don't rendezvous non-Torch backends on WORLD.

`MPIBackend` is still the default, but this path derives the symmetric-memory rendezvous group from torch `WORLD`, which can be uninitialized or the wrong peer set for non-`TorchDistBackend` communicators. Please require a backend-provided process-group identity or fail explicitly instead of silently using `WORLD`.

```shell
#!/bin/bash
# Verify whether non-Torch CommBackend implementations expose a torch process-group/group_name
# that can be used instead of falling back to torch.distributed.group.WORLD.
rg -n -C3 'class .*Backend|def .*group|group_name|TorchDistBackend|CommBackend' --iglob '*.py'
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@flashinfer/comm/trtllm_mnnvl_ar.py` around lines 141 - 149, The code currently falls back to torch.distributed.group.WORLD for non-TorchDistBackend instances, which can be uninitialized or incorrect for other backends; update the logic around comm_backend, TorchDistBackend, comm_backend._group and group_name so that for non-TorchDistBackend you query a backend-provided process-group identity (e.g., a method or attribute on the CommBackend interface) and use that value, and if the backend does not expose a valid group/group_name then raise an explicit error instead of silently using torch.distributed.group.WORLD; ensure you reference and validate comm_backend._group and group_name and fail fast with a clear message when no backend-provided group is available.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@flashinfer/comm/torch_symmetric_memory.py`:
- Line 26: The comment containing an EN DASH should be changed to use an ASCII
hyphen; locate the comment near the WORLD handling that reads "WORLD destruction
resets group_count to 0 – restore it so the next" and replace the EN DASH with a
regular hyphen ("-") so it reads "WORLD destruction resets group_count to 0 -
restore it so the next", ensuring the comment uses ASCII punctuation to satisfy
Ruff/EN DASH linting.
---
Outside diff comments:
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 150-173: The symmetric buffer allocation currently hardcodes
torch.float32; update the allocation and workspace construction to use the
actual dtype: change the callsite that invokes _alloc_symm_buffer_bytes(...) to
pass the real dtype (not torch.float32) so the Lamport sentinel matches
fp16/bf16 bit patterns, ensure MNNVLAllReduceFusionWorkspace requires/accepts a
dtype when buffer_size_in_bytes is provided and propagate that dtype into its
constructor call (the call at the other location that currently omits dtype),
and ensure the Lamport initialization (self.tensor.fill_(-0.0)) semantics remain
correct for the supplied dtype; reference _alloc_symm_buffer_bytes,
MNNVLAllReduceFusionWorkspace, and self.tensor.fill_(-0.0) when making the
changes.
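The width sensitivity of the Lamport sentinel can be checked at the byte level without torch: `-0.0` only sets the sign bit, so its encoding depends on the element size, and a buffer typed as float32 but read back as fp16 misses the sentinel in every other lane. This illustration checks the IEEE-754 encodings with the stdlib `struct` module, not the actual kernel:

```python
import struct

# -0.0 in fp32: only the sign bit set -> bytes 00 00 00 80 (little-endian)
fp32_neg_zero = struct.pack("<f", -0.0)
assert fp32_neg_zero == b"\x00\x00\x00\x80"

# -0.0 in fp16 is the two-byte pattern 0x8000. Filling an fp32-typed
# buffer with -0.0 instead writes 0x80000000 once per four bytes, so
# only every second fp16 lane would carry the sign bit.
fp16_neg_zero = (0x8000).to_bytes(2, "little")
assert fp16_neg_zero == b"\x00\x80"

# Interpreting the fp32 pattern as two fp16 lanes: the first lane has
# no sign bit, i.e. it reads as +0.0 and is not the sentinel.
lane0 = int.from_bytes(fp32_neg_zero[:2], "little")
assert lane0 == 0x0000
```

This is why the review asks for the real dtype to be threaded through to the allocation rather than hardcoding `torch.float32`.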
---
Duplicate comments:
In `@flashinfer/comm/torch_symmetric_memory.py`:
- Around line 68-70: The current calculation numel = size_bytes // elem_size can
under-allocate when size_bytes isn't divisible by elem_size; update the logic in
torch_symmetric_memory.py (around elem_size, numel, and tensor =
symm_mem.empty...) to either (a) validate divisibility and raise a clear error
if size_bytes % elem_size != 0, or (b) round up using ceiling division (numel =
(size_bytes + elem_size - 1) // elem_size) so the returned tensor has at least
the requested capacity; ensure the chosen behavior is documented in the
function's contract and used consistently where symm_mem.empty is called.
In `@flashinfer/comm/trtllm_ar.py`:
- Line 403: The code currently stores symmetric-memory refs in
_symm_workspace_refs but does not free the raw CUDA allocation used for the
fusion flag (flag_ptr); update the workspace bookkeeping to track the flag_ptr
alongside the symmetric refs (e.g., add it into the workspace struct or map
entry created where flag_ptr is allocated) and ensure
trtllm_destroy_ipc_workspace_for_all_reduce_fusion frees the CUDA buffer
(cudaFree or the equivalent used elsewhere) when destroying the workspace; make
symmetric references and flag_ptr lifetime tied so both are released in the same
destroy path (references: _symm_workspace_refs and
trtllm_destroy_ipc_workspace_for_all_reduce_fusion).
In `@flashinfer/comm/trtllm_mnnvl_ar.py`:
- Around line 141-149: The code currently falls back to
torch.distributed.group.WORLD for non-TorchDistBackend instances, which can be
uninitialized or incorrect for other backends; update the logic around
comm_backend, TorchDistBackend, comm_backend._group and group_name so that for
non-TorchDistBackend you query a backend-provided process-group identity (e.g.,
a method or attribute on the CommBackend interface) and use that value, and if
the backend does not expose a valid group/group_name then raise an explicit
error instead of silently using torch.distributed.group.WORLD; ensure you
reference and validate comm_backend._group and group_name and fail fast with a
clear message when no backend-provided group is available.
📒 Files selected for processing (3)

- flashinfer/comm/torch_symmetric_memory.py
- flashinfer/comm/trtllm_ar.py
- flashinfer/comm/trtllm_mnnvl_ar.py
```python
def _patched_destroy(group=None):
    saved_count = c10d._world.group_count
    _original_destroy(group)
    # WORLD destruction resets group_count to 0 – restore it so the next
```
Use ASCII punctuation in comments.
Ruff flags the EN DASH in this comment; replace it with - to keep pre-commit clean.
🧹 Proposed fix

```diff
- # WORLD destruction resets group_count to 0 – restore it so the next
+ # WORLD destruction resets group_count to 0 - restore it so the next
```
🧰 Tools
🪛 Ruff (0.15.10)
[warning] 26-26: Comment contains ambiguous – (EN DASH). Did you mean - (HYPHEN-MINUS)?
(RUF003)
📌 Description
The goal of this PR is to unify memory allocation for all-reduce to use torch symmetric memory instead of custom allocators in flashinfer.
🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed pre-commit by running pip install pre-commit (or used your preferred method).
- I have installed the hooks with pre-commit install.
- I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests
- Tests have been added or updated as needed (unittest, etc.).

Reviewer Notes
Summary by CodeRabbit
Improvements
Bug Fixes
Tests