Skip to content
Merged
Show file tree
Hide file tree
Changes from 31 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
62e4c15
make this work with colocated
sidsingh-nvidia Mar 13, 2026
7cdd9ac
fixes for non colocated RL
sidsingh-nvidia Mar 13, 2026
342a982
cleanup
sidsingh-nvidia Mar 13, 2026
4358bab
remove old global symmetric buffer references
sidsingh-nvidia Mar 13, 2026
afab913
cleanup
sidsingh-nvidia Mar 13, 2026
0d37a1e
cleanup
sidsingh-nvidia Mar 13, 2026
9b54b5e
add back destroy
sidsingh-nvidia Mar 13, 2026
151471f
lint
sidsingh-nvidia Mar 13, 2026
3845007
update unit test
sidsingh-nvidia Mar 13, 2026
bd10366
Update megatron/core/inference/text_generation_controllers/text_gener…
sidsingh-nvidia Mar 13, 2026
33e617b
use inference column parallel linear in shared experts
sidsingh-nvidia Mar 13, 2026
7525273
config args for grouped gemm backend
sidsingh-nvidia Mar 13, 2026
77e6eb3
move ggemm resolution to megatron/core/inference
sidsingh-nvidia Mar 13, 2026
ba22786
support torch with cuda graphs
sidsingh-nvidia Mar 13, 2026
eb2ba4f
refactor
sidsingh-nvidia Mar 13, 2026
b08036f
checkpoint
sidsingh-nvidia Mar 13, 2026
e3ba13d
latest
sidsingh-nvidia Mar 13, 2026
32d89cc
delete files
sidsingh-nvidia Mar 13, 2026
ba8ee2e
remove accidentally committed file
sidsingh-nvidia Mar 13, 2026
3e569ed
restore flask server
sidsingh-nvidia Mar 13, 2026
b582a4a
remove
sidsingh-nvidia Mar 13, 2026
e3b776c
remove
sidsingh-nvidia Mar 13, 2026
540f8d4
remove
sidsingh-nvidia Mar 13, 2026
5986c34
remove
sidsingh-nvidia Mar 13, 2026
f8e8d23
remove files
sidsingh-nvidia Mar 13, 2026
c503024
remove more files
sidsingh-nvidia Mar 13, 2026
d456b3d
support torch fused moe without cuda graphs
sidsingh-nvidia Mar 13, 2026
9a7c0ea
yield control in cmq sync all reduce
sidsingh-nvidia Mar 13, 2026
b335959
Merge branch 'main' into siddharth/torch-ggemm-mxfp8
sidsingh-nvidia Mar 16, 2026
8353f34
unpermute and reduce-scatter in fp32
sidsingh-nvidia Mar 16, 2026
9ba5b26
add comments
sidsingh-nvidia Mar 16, 2026
a0b7b73
add unit tests
sidsingh-nvidia Mar 16, 2026
9f0a97d
add mxfp8 unit tests
sidsingh-nvidia Mar 16, 2026
57103d6
fused quantization kernels + multimem fp32 reduce-scatter + unit tests
sidsingh-nvidia Mar 16, 2026
f41e387
linting
sidsingh-nvidia Mar 16, 2026
6990acc
Merge branch 'main' into siddharth/torch-ggemm-mxfp8
sidsingh-nvidia Mar 16, 2026
48ee34c
disable flashinfer unit test on hopper
sidsingh-nvidia Mar 16, 2026
1e74e6f
remove pylint comment
sidsingh-nvidia Mar 16, 2026
de51ebf
add comments on swizzled layout
sidsingh-nvidia Mar 16, 2026
8f2719c
add comment
sidsingh-nvidia Mar 16, 2026
986f1fd
Merge branch 'main' into siddharth/torch-ggemm-mxfp8
sidsingh-nvidia Mar 16, 2026
c6d151c
lint
sidsingh-nvidia Mar 16, 2026
42fc570
remove comment
sidsingh-nvidia Mar 16, 2026
3738e9d
lint
sidsingh-nvidia Mar 16, 2026
c58fa2b
fix mamba moe unit test
sidsingh-nvidia Mar 16, 2026
efa4315
remove arg
sidsingh-nvidia Mar 17, 2026
deb8752
Merge branch 'main' into siddharth/torch-ggemm-mxfp8
sidsingh-nvidia Mar 17, 2026
d3c5f25
changes
sidsingh-nvidia Mar 17, 2026
725cbf8
changes
sidsingh-nvidia Mar 17, 2026
c3fd456
remove dead arg
sidsingh-nvidia Mar 17, 2026
0660fe9
Merge branch 'main' into siddharth/torch-ggemm-mxfp8
sidsingh-nvidia Mar 17, 2026
3add0f3
repair unit tests
sidsingh-nvidia Mar 17, 2026
c907257
heal test
sidsingh-nvidia Mar 17, 2026
a7d4d52
Merge branch 'main' into siddharth/torch-ggemm-mxfp8
sidsingh-nvidia Mar 17, 2026
7b2e512
small change
sidsingh-nvidia Mar 17, 2026
6e314f3
bugfix for non-cudagraphed codepath
sidsingh-nvidia Mar 17, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions megatron/core/inference/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,8 @@ class InferenceConfig:
A list of the per-request metadata types to track. Each entry is a tuple
consisting of the string label, the target dtype, and whether to store the data on GPU.
"""

use_synchronous_zmq_collectives: bool = False
"""Whether to use synchronous ZMQ collectives for inference. If True, the all_reduce_max operation
will be performed synchronously, which can help reduce performance variability for MoEs.
"""
18 changes: 15 additions & 3 deletions megatron/core/inference/engines/async_zmq_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, zmq_context: zmq.Context, process_group: dist.ProcessGroup):
self.bcast_sock.connect(bcast_socket_addr)
self.bcast_sock.setsockopt_string(zmq.SUBSCRIBE, "")

async def all_reduce_max(self, *local_vals: int) -> int | tuple[int, ...]:
async def all_reduce_max(self, *local_vals: int, async_op=True) -> int | tuple[int, ...]:
"""Element-wise all-reduce max of one or more integers.

Packs all values into a single message so the communication cost
Expand All @@ -88,25 +88,37 @@ async def all_reduce_max(self, *local_vals: int) -> int | tuple[int, ...]:

while len(rows) < self.world_size:
try:
msg = self.gather_sock.recv(flags=zmq.NOBLOCK)
if async_op:
msg = self.gather_sock.recv(flags=zmq.NOBLOCK)
else:
msg = self.gather_sock.recv()
rows.append(struct.unpack(fmt, msg))
except zmq.Again:
await asyncio.sleep(0.001)

maxes = tuple(max(row[i] for row in rows) for i in range(n))
self.bcast_sock.send(struct.pack(fmt, *maxes))
if not async_op:
await asyncio.sleep(0) # Yield control once to ensure that other coroutines can run. This might be needed for colocated RL.
return maxes[0] if n == 1 else maxes

else:
self.gather_sock.send(payload)

while True:
try:
msg = self.bcast_sock.recv(flags=zmq.NOBLOCK)
if async_op:
msg = self.bcast_sock.recv(flags=zmq.NOBLOCK)
else:
msg = self.bcast_sock.recv()
result = struct.unpack(fmt, msg)
if not async_op:
await asyncio.sleep(0) # Yield control once to ensure that other coroutines can run. This might be needed for colocated RL.
return result[0] if n == 1 else result
except zmq.Again:
await asyncio.sleep(0.001)



def close(self):
"""
Expand Down
5 changes: 3 additions & 2 deletions megatron/core/inference/engines/dynamic_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen
self.metrics_writer = inference_config.metrics_writer
self.logging_step_interval = inference_config.logging_step_interval
self.unified_memory_level = inference_config.unified_memory_level
self.use_synchronous_zmq_collectives = inference_config.use_synchronous_zmq_collectives
self.cuda_graph_impl = model_config.cuda_graph_impl
self.cuda_graph_scope = model_config.cuda_graph_scope
# Initialize engine.
Expand Down Expand Up @@ -2066,7 +2067,7 @@ async def _ep_establish_consensus(
# We have tried that and it blocks the event loop in megatron-rl.
global_work, global_consensus = (
await self.expert_parallel_zmq_communicator.all_reduce_max(
local_work, consensus_val
local_work, consensus_val, async_op=(not self.use_synchronous_zmq_collectives)
)
)
else:
Expand All @@ -2086,7 +2087,7 @@ async def _world_barrier(self):
"""
range_push("world_barrier")
if hasattr(self, 'world_zmq_communicator'):
await self.world_zmq_communicator.all_reduce_max(1)
await self.world_zmq_communicator.all_reduce_max(1, async_op=(not self.use_synchronous_zmq_collectives))
range_pop()

@trace_async_exceptions
Expand Down
48 changes: 48 additions & 0 deletions megatron/core/inference/moe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import enum
import torch
from .fused_moe import ActivationType, mcore_fused_moe


class InferenceGroupedGemmBackend(enum.Enum):
    """Resolved backend for grouped GEMM operations during inference."""

    # FlashInfer-backed grouped GEMM (used for CUDA-graphed iterations in 'auto').
    FLASHINFER = "flashinfer"
    # torch.nn.functional.grouped_mm.
    TORCH = "torch"
    # Transformer Engine grouped GEMM.
    TE = "te"


# Backends the user may request explicitly via config.
_EXPLICIT_BACKENDS = {
    'torch': InferenceGroupedGemmBackend.TORCH,
    'te': InferenceGroupedGemmBackend.TE,
}


def resolve_inference_grouped_gemm_backend(
    backend: str,
    is_cuda_graphed: bool,
) -> InferenceGroupedGemmBackend:
    """Resolve the grouped GEMM backend to use for the current iteration.

    Prerequisites are validated at init time in MoELayer; this function
    simply maps (backend, is_cuda_graphed) to the concrete backend enum.

    Args:
        backend: One of 'auto', 'torch', 'te'.
        is_cuda_graphed: Whether this is a CUDA-graphed iteration.

    Returns:
        An InferenceGroupedGemmBackend enum value.

    Raises:
        ValueError: If ``backend`` is not one of 'auto', 'torch', or 'te'.
    """
    explicit = _EXPLICIT_BACKENDS.get(backend)
    if explicit is not None:
        return explicit
    if backend != 'auto':
        raise ValueError(
            f"Unknown inference_grouped_gemm_backend: '{backend}'. "
            "Must be 'auto', 'torch', or 'te'."
        )
    # 'auto' policy: CUDA-graphed iterations go through FlashInfer; otherwise
    # prefer torch's native grouped_mm when this torch build provides it,
    # falling back to Transformer Engine.
    if is_cuda_graphed:
        return InferenceGroupedGemmBackend.FLASHINFER
    if hasattr(torch.nn.functional, 'grouped_mm'):
        return InferenceGroupedGemmBackend.TORCH
    return InferenceGroupedGemmBackend.TE
91 changes: 91 additions & 0 deletions megatron/core/inference/moe/activations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Padding-aware activation kernels for fused MoE.

These kernels skip padding rows (where permutation_map == -1) to avoid
wasted computation on aligned-but-empty expert slots.
"""

from unittest.mock import MagicMock

import torch
from packaging import version

from megatron.core.utils import null_decorator

try:
import triton
import triton.language as tl

if version.parse(triton.__version__) < version.parse("3.4.0") and not torch.cuda.is_available():
HAVE_TRITON = False
else:
HAVE_TRITON = tl.constexpr(version.parse(triton.__version__) >= version.parse("2.0.0"))
except ImportError:
HAVE_TRITON = False

if not HAVE_TRITON:
triton = MagicMock()
triton.jit = null_decorator
tl = MagicMock()


@triton.jit
def _squared_relu_kernel(
    input_ptr, output_ptr, src_idx_ptr, M, N,
    BLOCK_N: tl.constexpr,
):
    """Squared ReLU that skips padding rows (permutation_map == -1)."""
    # One program instance per row; src_idx_ptr is the permutation map,
    # where a negative entry marks an aligned-but-empty padding slot.
    row = tl.program_id(0)
    if tl.load(src_idx_ptr + row) < 0:
        # Padding row: leave the output untouched (the launcher pre-zeroes it).
        return
    # Sweep the row in BLOCK_N-wide tiles; the mask handles the ragged tail.
    for n in tl.range(0, N, BLOCK_N):
        o = n + tl.arange(0, BLOCK_N)
        m = o < N
        x = tl.load(input_ptr + row * N + o, mask=m)
        r = tl.maximum(x, 0.0)  # ReLU
        tl.store(output_ptr + row * N + o, r * r, mask=m)  # squared


def padded_squared_relu(
    x: torch.Tensor, permutation_map: torch.Tensor
) -> torch.Tensor:
    """Squared ReLU activation that skips padding rows.

    Args:
        x: [M, N] input tensor.
        permutation_map: [M] integer map; entries < 0 mark padding rows
            that the kernel leaves untouched.

    Returns:
        [M, N] tensor with relu(x)**2 on real rows and zeros on padding rows.
    """
    num_rows, num_cols = x.shape
    # Zero-init so that rows skipped by the kernel (padding) stay zero.
    result = torch.zeros(num_rows, num_cols, dtype=x.dtype, device=x.device)
    tile_width = min(triton.next_power_of_2(num_cols), 1024)
    _squared_relu_kernel[(num_rows,)](
        x, result, permutation_map, num_rows, num_cols, BLOCK_N=tile_width,
    )
    return result


@triton.jit
def _swiglu_kernel(
    input_ptr, output_ptr, src_idx_ptr, M, N,
    BLOCK_N: tl.constexpr,
):
    """SwiGLU activation that skips padding rows (permutation_map == -1)."""
    # One program instance per row. Each input row holds the gate half at
    # columns [0, N) and the up half at columns [N, 2N); output rows are width N.
    row = tl.program_id(0)
    if tl.load(src_idx_ptr + row) < 0:
        # Padding row: leave the output untouched (the launcher pre-zeroes it).
        return
    # Sweep the row in BLOCK_N-wide tiles; the mask handles the ragged tail.
    for n in tl.range(0, N, BLOCK_N):
        o = n + tl.arange(0, BLOCK_N)
        m = o < N
        gate = tl.load(input_ptr + row * 2 * N + o, mask=m)
        up = tl.load(input_ptr + row * 2 * N + N + o, mask=m)
        # silu(gate) * up: sigmoid is evaluated in fp32 and cast back to the
        # input dtype before the multiplies.
        tl.store(output_ptr + row * N + o,
                 tl.sigmoid(gate.to(tl.float32)).to(gate.dtype) * gate * up, mask=m)


def padded_swiglu(
    x: torch.Tensor, permutation_map: torch.Tensor
) -> torch.Tensor:
    """SwiGLU activation that skips padding rows.

    Args:
        x: [M, 2*N] input with gate and up halves concatenated along dim 1.
        permutation_map: [M] integer map; entries < 0 mark padding rows
            that the kernel leaves untouched.

    Returns:
        [M, N] activated tensor, zero on padding rows.
    """
    num_rows = x.shape[0]
    half_width = x.shape[1] // 2
    # Zero-init so that rows skipped by the kernel (padding) stay zero.
    result = torch.zeros(num_rows, half_width, dtype=x.dtype, device=x.device)
    tile_width = min(triton.next_power_of_2(half_width), 1024)
    _swiglu_kernel[(num_rows,)](
        x, result, permutation_map, num_rows, half_width, BLOCK_N=tile_width,
    )
    return result
168 changes: 168 additions & 0 deletions megatron/core/inference/moe/fused_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Fused MoE: permute -> FC1 -> activation -> FC2 -> unpermute.

Supports BF16 weights with torch.nn.functional.grouped_mm.
All permutation logic is handled internally — callers invoke a single function.
"""

from enum import Enum
from typing import Callable

import torch

from typing import Optional

from megatron.core.inference.moe.activations import padded_squared_relu, padded_swiglu
from megatron.core.inference.moe.pad import pad_to_alignment, unpad_from_alignment
from megatron.core.inference.moe.permute import permute_tokens, unpermute_tokens
from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor

try:
from torch.nn.functional import grouped_mm

HAVE_GROUPED_MM = True
except ImportError:
HAVE_GROUPED_MM = False

try:
from torch.nn.functional import scaled_grouped_mm, ScalingType, SwizzleType
HAVE_SCALED_GMM = True
except ImportError:
HAVE_SCALED_GMM = False


class ActivationType(Enum):
    """Activation functions supported by mcore_fused_moe."""
    # relu(x) ** 2, applied element-wise (see padded_squared_relu).
    SQUARED_RELU = "squared_relu"
    # silu(gate) * up over a gated [gate | up] FC1 output (see padded_swiglu).
    SWIGLU = "swiglu"


def _bf16_grouped_mm(
    x_bf16: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor,
) -> torch.Tensor:
    """BF16 grouped GEMM using torch.nn.functional.grouped_mm.

    Args:
        x_bf16: BF16 activations, rows grouped per expert.
        weight: stacked expert weights; the last two dims are transposed to
            match grouped_mm's expected operand layout.
        offs: cumulative per-expert row offsets delimiting each group.

    Returns:
        The grouped matmul output.
    """
    assert x_bf16.dtype == torch.bfloat16, f"Expected bf16 input, got {x_bf16.dtype}"
    weight_t = weight.transpose(1, 2)
    return grouped_mm(x_bf16, weight_t, offs=offs)


def _mxfp8_grouped_mm(
    act: MXFP8Tensor, weight: MXFP8Tensor, offs: torch.Tensor,
) -> torch.Tensor:
    """MXFP8 scaled_grouped_mm with pre-quantized activations and weights.

    Args:
        act: pre-quantized activations (uses the .data payload and 2D scales).
        weight: pre-quantized stacked expert weights; the last two dims are
            transposed to match the GEMM's expected operand layout.
        offs: cumulative per-expert row offsets delimiting each group.

    Returns:
        BF16 output tensor (output_dtype=torch.bfloat16).
    """
    # Both operands carry 1x32 block-wise scales in a 32_4_4 swizzled layout.
    # NOTE(review): assumes act.scale_2d()/weight.scale are already swizzled
    # to match SWIZZLE_32_4_4 — produced by the MXFP8Tensor quantizer; confirm.
    return scaled_grouped_mm(
        act.data, weight.data.transpose(1, 2),
        act.scale_2d(), ScalingType.BlockWise1x32,
        weight.scale, ScalingType.BlockWise1x32,
        swizzle_a=SwizzleType.SWIZZLE_32_4_4, swizzle_b=SwizzleType.SWIZZLE_32_4_4,
        offs=offs, output_dtype=torch.bfloat16,
    )

def _get_activation_func(
    activation_type: ActivationType
) -> Callable:
    """Resolve ActivationType enum to a concrete kernel.

    Args:
        activation_type: the activation requested by the caller.

    Returns:
        The padding-aware activation callable.

    Raises:
        ValueError: if the activation type has no registered kernel.
    """
    kernel_by_type = {
        ActivationType.SWIGLU: padded_swiglu,
        ActivationType.SQUARED_RELU: padded_squared_relu,
    }
    if activation_type not in kernel_by_type:
        raise ValueError(f"Unsupported activation type: {activation_type}")
    return kernel_by_type[activation_type]


def mcore_fused_moe(
    hidden_states: torch.Tensor,
    probs: torch.Tensor,
    fc1_weight,
    fc2_weight,
    activation_type: ActivationType,
    num_local_experts: int,
    local_expert_start: int,
    routing_map: Optional[torch.Tensor] = None,
    tokens_per_expert: Optional[torch.Tensor] = None,
    skip_permute: bool = False,
) -> torch.Tensor:
    """Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute].

    Two modes:
    - skip_permute=False (default): tokens are unpermuted. Requires routing_map.
      Performs full permute -> compute -> unpermute.
    - skip_permute=True: tokens are already permuted by the dispatcher. Requires
      tokens_per_expert. Pads to alignment, computes, then unpads. Probs are
      applied during unpad.

    Args:
        hidden_states: [num_tokens, hidden_size] BF16 input.
        probs: routing probabilities. Shape is [num_tokens, topk] when
            skip_permute=False, or [num_tokens] (already gathered) when
            skip_permute=True.
        fc1_weight: stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8).
        fc2_weight: stacked weight for FC2 (same type as fc1_weight).
        activation_type: ActivationType enum (SWIGLU or SQUARED_RELU).
        num_local_experts: number of experts on this rank.
        local_expert_start: first global expert index on this rank.
        routing_map: [num_tokens, topk] int expert assignments. Required when skip_permute=False.
        tokens_per_expert: [num_local_experts] int32 token counts. Required when skip_permute=True.
        skip_permute: if True, skip permute/unpermute (tokens already in expert order).

    Returns:
        [num_tokens, hidden_size] BF16 output.
    """
    assert hidden_states.dtype == torch.bfloat16, (
        f"mcore_fused_moe requires bf16 input, got {hidden_states.dtype}"
    )

    num_tokens = hidden_states.shape[0]
    # The weight type decides the whole numeric path: MXFP8Tensor weights
    # select the scaled grouped GEMM, plain tensors the BF16 grouped GEMM.
    use_mxfp8 = isinstance(fc1_weight, MXFP8Tensor)
    activation_func = _get_activation_func(activation_type)

    if use_mxfp8:
        assert HAVE_SCALED_GMM, "torch.nn.functional.scaled_grouped_mm not available"
        mm_fn = _mxfp8_grouped_mm
        # scaled_grouped_mm requires each expert's token count aligned to 32,
        # but swizzled MXFP8 scales require alignment to 128. Use 128 to
        # satisfy both constraints.
        expert_alignment = 128
    else:
        assert HAVE_GROUPED_MM, "torch.nn.functional.grouped_mm not available"
        mm_fn = _bf16_grouped_mm
        expert_alignment = 16

    # --- Pre-processing: permute or pad ---
    if skip_permute:
        assert tokens_per_expert is not None, (
            "tokens_per_expert is required when skip_permute=True"
        )
        # Normalize device/dtype for the padding kernel.
        # NOTE(review): .cuda() assumes single-GPU-per-process placement
        # (default current device) — confirm against the dispatcher.
        tokens_per_expert = tokens_per_expert.cuda().int()
        assert routing_map is None, (
            "routing_map must be None when skip_permute=True"
        )
        # permutation_map marks padding rows with -1; offs are the cumulative
        # per-expert row offsets consumed by the grouped GEMM.
        work_hidden, permutation_map, offs = pad_to_alignment(
            hidden_states, tokens_per_expert, expert_alignment,
        )
        # Probs are applied later during unpad, not here.
        permuted_probs = None

    else:
        assert routing_map is not None, (
            "routing_map is required when skip_permute=False"
        )
        work_hidden, permuted_probs, permutation_map, offs = permute_tokens(
            hidden_states, probs, routing_map,
            local_expert_start, num_local_experts,
            alignment=expert_alignment,
        )

    # --- FC1 -> activation -> FC2 ---
    # For MXFP8, activations are quantized just before each GEMM; the
    # activation itself runs on the BF16 GEMM output.
    if use_mxfp8:
        work_hidden = MXFP8Tensor.from_bf16(work_hidden, backend="triton")
    fc1_output = mm_fn(work_hidden, fc1_weight, offs)
    # The activation kernels skip padding rows (permutation_map == -1).
    activated = activation_func(fc1_output, permutation_map)
    if use_mxfp8:
        activated = MXFP8Tensor.from_bf16(activated, backend="triton")
    fc2_output = mm_fn(activated, fc2_weight, offs)

    # --- Post-processing: unpermute or unpad ---
    if skip_permute:
        # Accept probs as [num_tokens] or [num_tokens, 1]; unpad applies them.
        probs_1d = probs.squeeze(-1) if probs.dim() > 1 else probs
        return unpad_from_alignment(fc2_output, permutation_map, num_tokens, probs=probs_1d)
    else:
        return unpermute_tokens(fc2_output, permuted_probs, permutation_map, num_tokens)
Loading