Skip to content

Commit 4fc7f18

Browse files
Nemo-RL integration bugfixes for --transformer-impl inference_optimized (#3851)
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
1 parent f8becec commit 4fc7f18

File tree

10 files changed

+259
-231
lines changed

10 files changed

+259
-231
lines changed

megatron/core/inference/contexts/dynamic_context.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -573,7 +573,6 @@ def __init__(self, model_config: TransformerConfig, inference_config: InferenceC
573573

574574
# Allocate GPU state.
575575
self.is_tensor_state_allocated = False
576-
self.is_symmetric_memory_initialized = False
577576
self.initialize_all_tensors()
578577

579578
# Print info.
@@ -2893,11 +2892,3 @@ def get_kvcache_utilization_stats(self) -> dict:
28932892
'total_request_count': int(total_request_count),
28942893
'max_requests': int(self.max_requests),
28952894
}
2896-
2897-
def maybe_initialize_symmetric_memory(self):
2898-
"""
2899-
Initializes symmetric memory for inference, if not already initialized
2900-
"""
2901-
if not self.is_symmetric_memory_initialized:
2902-
parallel_state._set_global_symmetric_memory_buffer()
2903-
self.is_symmetric_memory_initialized = True
Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
3+
"""Lazy-initialized symmetric memory manager for inference.
4+
5+
Provides a registry of SymmetricMemoryBuffer instances keyed by a
6+
user-supplied identifier (e.g. "tp", "ep"). Buffers are created on first
7+
access so that callers never need to worry about initialization ordering
8+
relative to the inference context.
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import operator
14+
from functools import reduce
15+
from typing import Optional
16+
17+
import torch
18+
19+
try:
20+
import torch.distributed._symmetric_memory as symm_mem
21+
22+
HAVE_TORCH_SYMM_MEM = True
23+
except ImportError:
24+
HAVE_TORCH_SYMM_MEM = False
25+
26+
try:
27+
import triton # pylint: disable=unused-import
28+
29+
HAVE_TRITON = True
30+
except ImportError:
31+
HAVE_TRITON = False
32+
33+
34+
class SymmetricMemoryBuffer:
    """Symmetric memory buffer used in inference.

    This buffer is used by mcore-inference's low-latency NVLS all-gather
    and reduce-scatter collectives. When torch symmetric memory or triton
    is unavailable, or the rendezvous fails, the buffer degrades
    gracefully: both ``maybe_get_*`` accessors report unavailability
    (``None`` handle/tensors) instead of raising.
    """

    def __init__(self, size_in_mb, process_group):
        """Allocate and rendezvous a symmetric memory buffer.

        Args:
            size_in_mb: Buffer size in MiB.
            process_group: Process group over which the symmetric-memory
                rendezvous is performed.
        """
        if not HAVE_TORCH_SYMM_MEM or not HAVE_TRITON:
            # Hit when the user is running an older version of torch
            # without symmetric memory support, or triton is not installed.
            self.symm_buffer = None
            self.symm_mem_hdl = None
            return

        num_bytes = int(size_in_mb * 1024 * 1024)  # size in bytes
        try:
            symm_mem.enable_symm_mem_for_group(process_group.group_name)
            self.symm_buffer = symm_mem.empty(num_bytes, dtype=torch.uint8, device='cuda')
            self.symm_mem_hdl = symm_mem.rendezvous(self.symm_buffer, process_group)
        except RuntimeError:
            # If symmetric memory initialization fails, set buffer and handle
            # to None. This can happen if the process group is not fully
            # contained within an NVLink domain.
            self.symm_buffer = None
            self.symm_mem_hdl = None

    @staticmethod
    def _dtype_nbytes(numel, dtype) -> int:
        """Return the byte size of a tensor with ``numel`` elements of ``dtype``."""
        return numel * torch.tensor([], dtype=dtype).element_size()

    def _can_allocate(self, numel, dtype) -> bool:
        """Return whether enough symmetric memory is available for
        ``numel`` elements of ``dtype``. Always False when the buffer is
        in the disabled state."""
        if self.symm_mem_hdl is None:
            return False
        return self._dtype_nbytes(numel, dtype) <= self.symm_buffer.numel()

    def _allocate(self, numel, dtype) -> torch.Tensor:
        """Return a flat sub-tensor of ``numel`` elements of ``dtype`` viewed
        over the start of ``self.symm_buffer``.

        Caller is responsible for checking ``_can_allocate`` first.
        """
        required_bytes = self._dtype_nbytes(numel, dtype)
        return self.symm_buffer[0:required_bytes].view(dtype).view(numel)

    def maybe_get_tensors(self, tensor_specs, alignment=16):
        """Pack multiple tensors contiguously in the symmetric buffer with alignment.

        Each tensor's starting offset is aligned to ``alignment`` bytes
        (default 16 for 128-bit multimem access).

        Args:
            tensor_specs: List of ``(numel, dtype)`` tuples.
            alignment: Byte alignment for each tensor's start offset (default 16).

        Returns:
            ``{"handle": None, "tensors": None}`` if unavailable or there is
            insufficient space. On success,
            ``{"handle": symm_mem_hdl, "tensors": [(raw_byte_view, byte_offset), ...]}``
            where ``raw_byte_view`` is a uint8 slice of the buffer.
        """
        unavailable = {"handle": None, "tensors": None}
        if self.symm_mem_hdl is None:
            return unavailable

        # Compute each tensor's (start_offset, size) with aligned starts.
        slices = []
        current_offset = 0
        for numel, dtype in tensor_specs:
            nbytes = self._dtype_nbytes(numel, dtype)
            aligned_nbytes = ((nbytes + alignment - 1) // alignment) * alignment
            slices.append((current_offset, nbytes))
            current_offset += aligned_nbytes

        # current_offset is now the total packed footprint in bytes.
        if not self._can_allocate(current_offset, torch.uint8):
            return unavailable

        tensors = [
            (self.symm_buffer[offset : offset + nbytes], offset) for offset, nbytes in slices
        ]
        return {"handle": self.symm_mem_hdl, "tensors": tensors}

    def maybe_get_tensor(self, tensor_shape, dtype):
        """Return (potentially) a sub-tensor of the symmetric buffer with the
        given shape and dtype.

        Returns:
            ``{"tensor": None, "handle": None}`` when symmetric memory is
            unavailable or too small; otherwise a dict with the shaped
            ``tensor`` view and the rendezvous ``handle``.
        """
        numel = reduce(operator.mul, tensor_shape, 1)
        # _can_allocate also covers the disabled (symm_mem_hdl is None) state.
        if not self._can_allocate(numel, dtype):
            return {"tensor": None, "handle": None}
        return {
            "tensor": self._allocate(numel, dtype).view(*tensor_shape),
            "handle": self.symm_mem_hdl,
        }
129+
130+
131+
class SymmetricMemoryManager:
    """Registry of lazily-initialized symmetric memory buffers.

    Buffers are keyed by a caller-chosen identifier (e.g. ``"tp"``, ``"ep"``)
    and created the first time that key is requested.

    Usage::

        buf = SymmetricMemoryManager.get_buffer("tp", process_group=tp_group)
        result = buf.maybe_get_tensor(shape, dtype)
    """

    # Registry shared by all callers in the process.
    _buffers: dict[str, SymmetricMemoryBuffer] = {}
    # Fallback buffer size (MiB) when a caller does not pass size_mb.
    _default_size_mb: int = 256

    @classmethod
    def get_buffer(
        cls,
        key: str,
        process_group: Optional[torch.distributed.ProcessGroup] = None,
        size_mb: Optional[int] = None,
    ) -> SymmetricMemoryBuffer:
        """Return the buffer for *key*, creating it on first call.

        Args:
            key: Unique identifier (e.g. "tp", "ep").
            process_group: Required on the first call for a given key.
                Subsequent calls may omit it.
            size_mb: Buffer size in MiB (default 256).
        """
        buffer = cls._buffers.get(key)
        if buffer is None:
            # First access: we need a process group to rendezvous with.
            assert (
                process_group is not None
            ), f"SymmetricMemoryManager: process_group is required on first access for key='{key}'"
            requested_mb = size_mb if size_mb else cls._default_size_mb
            buffer = SymmetricMemoryBuffer(size_in_mb=requested_mb, process_group=process_group)
            cls._buffers[key] = buffer
        return buffer

    @classmethod
    def destroy(cls, key: Optional[str] = None) -> None:
        """Destroy one or all buffers.

        Args:
            key: If provided, destroy only that buffer. Otherwise destroy all.
        """
        if key is None:
            cls._buffers.clear()
        else:
            # Missing keys are ignored, so destroy() is always safe to call.
            cls._buffers.pop(key, None)

    @classmethod
    def is_initialized(cls, key: str) -> bool:
        """Check whether a buffer has been created for *key*."""
        return cls._buffers.get(key) is not None

megatron/core/inference/text_generation_controllers/text_generation_controller.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -580,10 +580,6 @@ def _dynamic_step_context_init(
580580
else:
581581
set_decode_expert_padding(unwrapped_model, False)
582582

583-
# initialize symmetric memory if needed
584-
if model_config.transformer_impl == "inference_optimized":
585-
context.maybe_initialize_symmetric_memory()
586-
587583
if nccl_all_reduce_for_prefill and symmetric_ar_type is not None:
588584
if context.is_decode_only():
589585
# Turn on symmetric all reduce when in decode mode
@@ -1595,11 +1591,6 @@ def dummy_forward(self):
15951591
context = self.inference_wrapped_model.inference_context
15961592
# if no cuda graphs, directly use dummy forward
15971593
if not context.cuda_graph_batch_dimensions_list:
1598-
# initialize symmetric memory if needed
1599-
unwrapped_model = unwrap_model(self.inference_wrapped_model.model)
1600-
model_config = get_model_config(unwrapped_model)
1601-
if model_config.transformer_impl == "inference_optimized":
1602-
context.maybe_initialize_symmetric_memory()
16031594
self.inference_wrapped_model.dummy_forward()
16041595

16051596
# Disable MoE padding for MTP computation

megatron/core/parallel_state.py

Lines changed: 5 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212
import numpy as np
1313
import torch
1414

15-
from .utils import GlobalMemoryBuffer, GlobalSymmetricMemoryBuffer, is_torch_min_version
15+
from megatron.core.inference.symmetric_memory import SymmetricMemoryManager
16+
17+
from .utils import GlobalMemoryBuffer, is_torch_min_version
1618

1719
logger = logging.getLogger(__name__)
1820

@@ -138,9 +140,6 @@
138140
# Memory buffers to avoid dynamic memory allocation
139141
_GLOBAL_MEMORY_BUFFER = None
140142

141-
# Global symmetric memory buffers for inference
142-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None
143-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None
144143

145144
# List of all process groups
146145
# Used for updating the timeout for all process groups
@@ -2017,62 +2016,18 @@ def _set_global_memory_buffer():
20172016
_GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()
20182017

20192018

2020-
def _set_global_symmetric_memory_buffer():
2021-
"""Initialize global buffer."""
2022-
global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP
2023-
assert (
2024-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is None
2025-
), "global symmetric memory buffer for TP is already initialized"
2026-
assert (
2027-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is None
2028-
), "global symmetric memory buffer for EP is already initialized"
2029-
2030-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = GlobalSymmetricMemoryBuffer(
2031-
size_in_mb=256, # todo: set from an argument?
2032-
process_group=get_tensor_model_parallel_group(),
2033-
)
2034-
2035-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = GlobalSymmetricMemoryBuffer(
2036-
size_in_mb=256, # todo: set from an argument?
2037-
process_group=get_expert_model_parallel_group(),
2038-
)
2039-
2040-
20412019
def get_global_memory_buffer():
20422020
"""Return the global GlobalMemoryBuffer object"""
20432021
assert _GLOBAL_MEMORY_BUFFER is not None, "global memory buffer is not initialized"
20442022
return _GLOBAL_MEMORY_BUFFER
20452023

20462024

2047-
def get_global_symmetric_memory_buffer_tp():
2048-
"""Return the global GlobalSymmetricMemoryBuffer object"""
2049-
assert (
2050-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP is not None
2051-
), "global symmetric memory buffer is not initialized"
2052-
return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP
2053-
2054-
2055-
def get_global_symmetric_memory_buffer_ep():
2056-
"""Return the global GlobalSymmetricMemoryBuffer object"""
2057-
assert (
2058-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP is not None
2059-
), "global symmetric memory buffer is not initialized"
2060-
return _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP
2061-
2062-
20632025
def destroy_global_memory_buffer():
20642026
"""Sets the global memory buffer to None"""
20652027
global _GLOBAL_MEMORY_BUFFER
20662028
_GLOBAL_MEMORY_BUFFER = None
20672029

20682030

2069-
def destroy_global_symmetric_memory_buffer():
2070-
"""Sets the global symmetric memory buffer to None"""
2071-
global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP, _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP
2072-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None
2073-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None
2074-
2075-
20762031
def get_all_ranks():
20772032
"""Get caller's rank in tensor-model-parallel, data-parallel, context-parallel,
20782033
pipeline-model-parallel and expert-model-parallel groups."""
@@ -2151,12 +2106,6 @@ def destroy_model_parallel():
21512106
global _GLOBAL_MEMORY_BUFFER
21522107
_GLOBAL_MEMORY_BUFFER = None
21532108

2154-
global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP
2155-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_TP = None
2156-
2157-
global _GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP
2158-
_GLOBAL_SYMMETRIC_MEMORY_BUFFER_EP = None
2159-
21602109
global _DATA_PARALLEL_GROUP_GLOO
21612110
if (
21622111
_DATA_PARALLEL_GROUP_GLOO is not None
@@ -2239,3 +2188,5 @@ def destroy_model_parallel():
22392188

22402189
global _global_process_group_list
22412190
_global_process_group_list = None
2191+
2192+
SymmetricMemoryManager.destroy()

0 commit comments

Comments
 (0)