diff --git a/megatron/core/inference/communication/torch_symm_triton/collectives.py b/megatron/core/inference/communication/torch_symm_triton/collectives.py index cf2003c8595..3120f032812 100644 --- a/megatron/core/inference/communication/torch_symm_triton/collectives.py +++ b/megatron/core/inference/communication/torch_symm_triton/collectives.py @@ -165,9 +165,11 @@ def _multimem_reduce_scatter_kernel( NUMEL_PER_THREAD: tl.constexpr, RANK: tl.constexpr, WORLD_SIZE: tl.constexpr, + REDUCE_F32: tl.constexpr = False, ): """ Triton kernel to perform multicast reduce-scatter over nvlink using multimem instructions. + When REDUCE_F32=True, uses fp32 reduction instead of bf16x2 reduction. """ symm_mem_sync( signal_pad_ptrs, @@ -196,7 +198,7 @@ def _multimem_reduce_scatter_kernel( multicast_ptr.to(tl.pointer_type(tl.uint64)) + (RANK * numel_per_rank + offsets) * 2 ) local_ptrs = local_ptr.to(tl.pointer_type(tl.uint64)) + offsets * 2 - (x, y, z, w) = ld_128(multicast_ptrs, mask=mask, multicast_op=True) + (x, y, z, w) = ld_128(multicast_ptrs, mask=mask, multicast_op=True, reduce_f32=REDUCE_F32) st_128(local_ptrs, x, y, z, w, mask=mask, multicast_op=False) block_start += tl.num_programs(axis=0) * BLOCK_SIZE @@ -328,10 +330,16 @@ def multimem_reduce_scatter( Multicast reduce-scatter for a single tensor. Input tensor must be a symmetric memory buffer. Output tensor can be a regular torch tensor. + Supports bfloat16 and float32 dtypes. """ assert HAVE_TRITON, "Triton is required for multimem reduce-scatter." - assert input_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." - assert output_tensor.dtype == torch.bfloat16, "Only bfloat16 is supported for now." 
+ assert input_tensor.dtype in ( + torch.bfloat16, + torch.float32, + ), f"Only bfloat16 and float32 are supported, got {input_tensor.dtype}" + assert ( + input_tensor.dtype == output_tensor.dtype + ), f"Input and output dtypes must match: {input_tensor.dtype} vs {output_tensor.dtype}" assert are_tensors_nvls_eligible( output_tensor ), "Output tensor must be 16-byte divisible on Hopper+ for NVLS." @@ -340,6 +348,7 @@ def multimem_reduce_scatter( and input_tensor.numel() // output_tensor.numel() == symm_mem_hdl.world_size ), "Input numel must be exactly world_size * output numel for reduce-scatter." + reduce_f32 = input_tensor.dtype == torch.float32 numel_per_thread, num_blocks, config = _kernel_launch_config( output_tensor.element_size(), input_tensor.numel(), symm_mem_hdl.world_size, **kwargs ) @@ -353,6 +362,7 @@ def multimem_reduce_scatter( RANK=symm_mem_hdl.rank, WORLD_SIZE=symm_mem_hdl.world_size, num_warps=config["num_warps"], + REDUCE_F32=reduce_f32, ) return output_tensor diff --git a/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py b/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py index 774c3f6d2bf..859b9010aea 100644 --- a/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py +++ b/megatron/core/inference/communication/torch_symm_triton/multimem_asm.py @@ -1,7 +1,9 @@ # Copyright (c) Meta Platforms, Inc. and affiliates. +# pylint: disable=line-too-long # Adapted from https://github.com/yifuwang/symm-mem-recipes.git + from unittest.mock import MagicMock from megatron.core.utils import null_decorator @@ -16,60 +18,76 @@ @triton.jit -def ld_128(ptr, mask, multicast_op: tl.constexpr): +def ld_128(ptr, mask, multicast_op: tl.constexpr, reduce_f32: tl.constexpr = False): """ - Loads 128 bits (8 x bf16) from memory into registers. + Loads 128 bits from memory into registers. This function abstracts two distinct hardware behaviors based on `multicast_op`: 1. 
**Standard Load (`multicast_op=False`)**: - **Semantics:** Local Global Memory Load. - **Action:** Reads 128 bits from `ptr` in global memory into the local register file. - - **Use Case:** Standard tensor processing. 2. **Multicast Reduce-Load (`multicast_op=True`)**: - **Semantics:** "Pull" Reduction over NVLink. - **Action:** Simultaneously reads 128 bits from the *same* address across all peer GPUs - in the multicast group, sums them (add reduction), and loads the result into the - local register file. + in the multicast group, sums them, and loads the result into the local register file. - **Hardware:** Uses `multimem.ld_reduce` (Hopper+). - - **Use Case:** The "Reduce" step in collective operations. + - When `reduce_f32=False` (default): bf16x2 addition with f32 accumulation + (128 bits = 8 x bf16, 2 per register). + - When `reduce_f32=True`: native f32 addition + (128 bits = 4 x fp32, 1 per register). Args: ptr: Memory pointer to the source buffer. mask: Boolean predicate. If False, the operation is skipped (no-op). multicast_op (tl.constexpr): Toggles between standard load (False) - and multicast-reduce (True). + and multicast-reduce (True). + reduce_f32 (tl.constexpr): When True and multicast_op=True, uses f32 reduction + instead of bf16x2 reduction. Default False. Returns: Four 32-bit registers (tl.uint32), representing 128 bits of loaded data. - Note: When interpreting as bf16, this equates to 8 values (2 per register). """ - # PTX Assembly Logic: - # 1. @$5: Predication. Only execute if argument 5 (mask) is True (1). - # 2. Opcode Selection: - # - 'multimem.ld_reduce...add.v4.bf16x2': Hardware-accelerated reduction across peers. - # - 'ld.global...v4.u32': Standard 128-bit memory read. - # 3. Operands: - # - {$0, $1, $2, $3}: Destination registers (Output). - # - [$4]: Source memory address (Input). 
if multicast_op: - return tl.inline_asm_elementwise( - """ - { - .reg .pred %p0; - setp.ne.s32 %p0, $5, 1; - @%p0 bra end; - multimem.ld_reduce.relaxed.sys.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4]; - end: - } - """, - "=r,=r,=r,=r,l,r", - args=[ptr, mask.to(tl.int32)], - dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32), - is_pure=True, - pack=1, - ) + if reduce_f32: + # fp32 reduction: multimem.ld_reduce.add.v4.f32 + # Each 128-bit load reduces 4 x fp32 values across peers. + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $5, 1; + @%p0 bra end; + multimem.ld_reduce.relaxed.sys.global.add.v4.f32 {$0, $1, $2, $3}, [$4]; + end: + } + """, + "=r,=r,=r,=r,l,r", + args=[ptr, mask.to(tl.int32)], + dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) + else: + # bf16x2 reduction with f32 accumulation: multimem.ld_reduce.add.acc::f32.v4.bf16x2 + # Each 128-bit load reduces 8 x bf16 values (packed as 4 x bf16x2) across peers. + return tl.inline_asm_elementwise( + """ + { + .reg .pred %p0; + setp.ne.s32 %p0, $5, 1; + @%p0 bra end; + multimem.ld_reduce.relaxed.sys.global.add.acc::f32.v4.bf16x2 {$0, $1, $2, $3}, [$4]; + end: + } + """, + "=r,=r,=r,=r,l,r", + args=[ptr, mask.to(tl.int32)], + dtype=(tl.uint32, tl.uint32, tl.uint32, tl.uint32), + is_pure=True, + pack=1, + ) else: return tl.inline_asm_elementwise( """ diff --git a/megatron/core/inference/config.py b/megatron/core/inference/config.py index bc0770d450d..63ec551cb5d 100644 --- a/megatron/core/inference/config.py +++ b/megatron/core/inference/config.py @@ -281,3 +281,9 @@ class InferenceConfig: A list of the per-request metadata types to track. Each entry is a tuple consisting of the string label, the target dtype, and whether to store the data on GPU. """ + + use_synchronous_zmq_collectives: bool = False + """Whether to use synchronous ZMQ collectives for inference. 
If True, the + all_reduce_max operation will be performed synchronously, which can help reduce + performance variability for MoEs. + """ diff --git a/megatron/core/inference/engines/async_zmq_communicator.py b/megatron/core/inference/engines/async_zmq_communicator.py index 124f2d46932..6af35be7fb4 100644 --- a/megatron/core/inference/engines/async_zmq_communicator.py +++ b/megatron/core/inference/engines/async_zmq_communicator.py @@ -65,7 +65,7 @@ def __init__(self, zmq_context: zmq.Context, process_group: dist.ProcessGroup): self.bcast_sock.connect(bcast_socket_addr) self.bcast_sock.setsockopt_string(zmq.SUBSCRIBE, "") - async def all_reduce_max(self, *local_vals: int) -> int | tuple[int, ...]: + async def all_reduce_max(self, *local_vals: int, async_op=True) -> int | tuple[int, ...]: """Element-wise all-reduce max of one or more integers. Packs all values into a single message so the communication cost @@ -88,13 +88,21 @@ async def all_reduce_max(self, *local_vals: int) -> int | tuple[int, ...]: while len(rows) < self.world_size: try: - msg = self.gather_sock.recv(flags=zmq.NOBLOCK) + if async_op: + msg = self.gather_sock.recv(flags=zmq.NOBLOCK) + else: + msg = self.gather_sock.recv() rows.append(struct.unpack(fmt, msg)) except zmq.Again: await asyncio.sleep(0.001) maxes = tuple(max(row[i] for row in rows) for i in range(n)) self.bcast_sock.send(struct.pack(fmt, *maxes)) + if not async_op: + await asyncio.sleep( + 0 + ) # Yield control once to ensure that other coroutines can run. + # This might be needed for colocated RL. 
return maxes[0] if n == 1 else maxes else: @@ -102,8 +110,16 @@ async def all_reduce_max(self, *local_vals: int) -> int | tuple[int, ...]: while True: try: - msg = self.bcast_sock.recv(flags=zmq.NOBLOCK) + if async_op: + msg = self.bcast_sock.recv(flags=zmq.NOBLOCK) + else: + msg = self.bcast_sock.recv() result = struct.unpack(fmt, msg) + if not async_op: + await asyncio.sleep( + 0 + ) # Yield control once to ensure that other coroutines can run. + # This might be needed for colocated RL. return result[0] if n == 1 else result except zmq.Again: await asyncio.sleep(0.001) diff --git a/megatron/core/inference/engines/dynamic_engine.py b/megatron/core/inference/engines/dynamic_engine.py index 409c726d10c..9c2e15ce7b5 100644 --- a/megatron/core/inference/engines/dynamic_engine.py +++ b/megatron/core/inference/engines/dynamic_engine.py @@ -224,6 +224,7 @@ def __init__(self, controller: TextGenerationController, context: DynamicInferen self.metrics_writer = inference_config.metrics_writer self.logging_step_interval = inference_config.logging_step_interval self.unified_memory_level = inference_config.unified_memory_level + self.use_synchronous_zmq_collectives = inference_config.use_synchronous_zmq_collectives self.cuda_graph_impl = model_config.cuda_graph_impl self.cuda_graph_scope = model_config.cuda_graph_scope # Initialize engine. @@ -2066,7 +2067,7 @@ async def _ep_establish_consensus( # We have tried that and it blocks the event loop in megatron-rl. 
global_work, global_consensus = ( await self.expert_parallel_zmq_communicator.all_reduce_max( - local_work, consensus_val + local_work, consensus_val, async_op=(not self.use_synchronous_zmq_collectives) ) ) else: @@ -2086,7 +2087,9 @@ async def _world_barrier(self): """ range_push("world_barrier") if hasattr(self, 'world_zmq_communicator'): - await self.world_zmq_communicator.all_reduce_max(1) + await self.world_zmq_communicator.all_reduce_max( + 1, async_op=(not self.use_synchronous_zmq_collectives) + ) range_pop() @trace_async_exceptions diff --git a/megatron/core/inference/moe/__init__.py b/megatron/core/inference/moe/__init__.py new file mode 100644 index 00000000000..ea716b5fbd5 --- /dev/null +++ b/megatron/core/inference/moe/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +import enum + +import torch + +from .fused_moe import ActivationType, mcore_fused_moe + + +class InferenceGroupedGemmBackend(enum.Enum): + """Resolved backend for grouped GEMM operations during inference.""" + + FLASHINFER = "flashinfer" + TORCH = "torch" + TE = "te" + + +def resolve_inference_grouped_gemm_backend( + backend: str, is_cuda_graphed: bool, is_mxfp8: bool = False +) -> InferenceGroupedGemmBackend: + """Resolve the grouped GEMM backend to use for the current iteration. + + Prerequisites are validated at init time in MoELayer; this function + simply maps (backend, is_cuda_graphed) to the concrete backend enum. + + Args: + backend: One of 'auto', 'torch', 'te'. + is_cuda_graphed: Whether this is a CUDA-graphed iteration. + is_mxfp8: Whether the model is using MXFP8 quantization (affects auto backend choice). + Returns: + An InferenceGroupedGemmBackend enum value. + """ + if backend == 'auto': + if is_cuda_graphed: + if is_mxfp8: + assert hasattr(torch.nn.functional, 'scaled_grouped_mm'), ( + "Auto backend selection for MXFP8 requires " + "torch.nn.functional.scaled_grouped_mm. " + "Please install PyTorch 2.10+." 
+ ) + return InferenceGroupedGemmBackend.TORCH + else: + return InferenceGroupedGemmBackend.FLASHINFER + else: + if hasattr(torch.nn.functional, 'grouped_mm'): + return InferenceGroupedGemmBackend.TORCH + else: + return InferenceGroupedGemmBackend.TE + elif backend == 'torch': + return InferenceGroupedGemmBackend.TORCH + elif backend == 'te': + return InferenceGroupedGemmBackend.TE + else: + raise ValueError( + f"Unknown inference_grouped_gemm_backend: '{backend}'. " + "Must be 'auto', 'torch', or 'te'." + ) diff --git a/megatron/core/inference/moe/activations.py b/megatron/core/inference/moe/activations.py new file mode 100644 index 00000000000..169d8499116 --- /dev/null +++ b/megatron/core/inference/moe/activations.py @@ -0,0 +1,166 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +"""Padding-aware activation kernels for fused MoE. + +These kernels skip padding rows (where permutation_map == -1) to avoid +wasted computation on aligned-but-empty expert slots. 
+""" + +from unittest.mock import MagicMock + +import torch + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + HAVE_TRITON = False + +if not HAVE_TRITON: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def _squared_relu_kernel(input_ptr, output_ptr, src_idx_ptr, M, N, BLOCK_N: tl.constexpr): + """Squared ReLU that skips padding rows (permutation_map == -1).""" + row = tl.program_id(0) + if tl.load(src_idx_ptr + row) < 0: + return + for n in tl.range(0, N, BLOCK_N): + o = n + tl.arange(0, BLOCK_N) + m = o < N + x = tl.load(input_ptr + row * N + o, mask=m).to(tl.float32) + r = tl.maximum(x, 0.0) + tl.store(output_ptr + row * N + o, (r * r).to(tl.bfloat16), mask=m) + + +def padded_squared_relu(x: torch.Tensor, permutation_map: torch.Tensor) -> torch.Tensor: + """Squared ReLU activation that skips padding rows.""" + M, N = x.shape + out = torch.zeros(M, N, dtype=x.dtype, device=x.device) + BLOCK_N = min(triton.next_power_of_2(N), 1024) + _squared_relu_kernel[(M,)](x, out, permutation_map, M, N, BLOCK_N=BLOCK_N) + return out + + +@triton.jit +def _squared_relu_quantize_kernel( + input_ptr, + out_fp8_ptr, + out_scale_ptr, + src_idx_ptr, + K, + n_col_blocks, + skip_padding: tl.constexpr, + REAL_GROUPS: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_GROUPS: tl.constexpr, +): + """Fused squared ReLU + MXFP8 quantize + swizzle in one kernel. + + Grid: (M,) — one program per row. + Reads BF16 FC1 output, applies squared ReLU, quantizes to FP8, + writes FP8 data + swizzled scales in place. 
+ """ + row = tl.program_id(0) + if skip_padding: + if tl.load(src_idx_ptr + row) < 0: + return + + offs = tl.arange(0, BLOCK_K) + mask = offs < K + + # Load and apply squared ReLU + x = tl.load(input_ptr + row * K + offs, mask=mask, other=0.0).to(tl.float32) + relu = tl.maximum(x, 0.0) + activated = relu * relu + + # Per-group-of-32 quantization + x_grouped = tl.reshape(activated, [BLOCK_GROUPS, 32]) + abs_grouped = tl.abs(x_grouped) + max_vals = tl.max(abs_grouped, axis=1) + + dequant_scale = max_vals / 448.0 + dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) + + quantized = x_grouped * quant_scale[:, None] + quantized_flat = tl.reshape(quantized, [BLOCK_K]) + out_fp8 = quantized_flat.to(tl.float8e4nv) + + # Store FP8 data + tl.store(out_fp8_ptr + row * K + offs, out_fp8, mask=mask) + + # Store swizzled scales + scale_exp = (dequant_exp >> 23).to(tl.uint8) + col_offs = tl.arange(0, BLOCK_GROUPS) + col_mask = col_offs < REAL_GROUPS + + macro_row_block = row // 128 + macro_col_block = col_offs // 4 + local_row = row % 128 + local_col = col_offs % 4 + group = local_row // 32 + sub_row = local_row % 32 + tile_idx = macro_row_block * n_col_blocks + macro_col_block + swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + + tl.store(out_scale_ptr + swizzled_offs, scale_exp, mask=col_mask) + + +def squared_relu_and_quantize_mxfp8( + x: torch.Tensor, permutation_map: torch.Tensor, skip_padding: bool = True +): + """Fused squared ReLU + MXFP8 quantize + swizzle. + + Reads BF16 FC1 output, applies squared ReLU, quantizes to FP8 with + swizzled scales. Single kernel replaces padded_squared_relu + mxfp8_quantize. + + Args: + x: [M, K] BF16 FC1 output. + permutation_map: [M] int32, original token index or -1 for padding. + skip_padding: if True, skip rows where permutation_map == -1. 
+ + Returns: + MXFP8Tensor with .data [M, K] float8_e4m3fn and .scale (swizzled e8m0). + """ + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + M, K = x.shape + assert K % 32 == 0 + + scale_cols = K // 32 + n_row_blocks = _ceil_div(M, 128) + n_col_blocks = _ceil_div(scale_cols, 4) + total_scale_bytes = n_row_blocks * n_col_blocks * 512 + + out_fp8 = torch.empty(M, K, dtype=torch.float8_e4m3fn, device=x.device) + out_scale = torch.zeros(total_scale_bytes, dtype=torch.uint8, device=x.device) + + BLOCK_K = triton.next_power_of_2(K) + BLOCK_GROUPS = BLOCK_K // 32 + + _squared_relu_quantize_kernel[(M,)]( + x, + out_fp8, + out_scale, + permutation_map, + K, + n_col_blocks, + skip_padding, + REAL_GROUPS=scale_cols, + BLOCK_K=BLOCK_K, + BLOCK_GROUPS=BLOCK_GROUPS, + ) + + return MXFP8Tensor(data=out_fp8, scale=out_scale.view(torch.float8_e8m0fnu), backend="triton") diff --git a/megatron/core/inference/moe/fused_moe.py b/megatron/core/inference/moe/fused_moe.py new file mode 100644 index 00000000000..39382eee079 --- /dev/null +++ b/megatron/core/inference/moe/fused_moe.py @@ -0,0 +1,204 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +"""Fused MoE: permute -> FC1 -> activation -> FC2 -> unpermute. + +Supports BF16 weights with torch.nn.functional.grouped_mm. +All permutation logic is handled internally — callers invoke a single function. 
+""" + +from enum import Enum +from typing import Callable, Optional + +import torch + +from megatron.core.inference.moe.activations import ( + padded_squared_relu, + squared_relu_and_quantize_mxfp8, +) +from megatron.core.inference.moe.pad import pad_to_alignment, unpad_from_alignment +from megatron.core.inference.moe.permute import ( + permute_and_quantize_mxfp8, + permute_tokens, + unpermute_tokens, +) +from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + +try: + from torch.nn.functional import grouped_mm + + HAVE_GROUPED_MM = True +except ImportError: + HAVE_GROUPED_MM = False + +try: + from torch.nn.functional import ScalingType, SwizzleType, scaled_grouped_mm + + HAVE_SCALED_GMM = True +except ImportError: + HAVE_SCALED_GMM = False + + +class ActivationType(Enum): + """Activation functions supported by mcore_fused_moe.""" + + SQUARED_RELU = "squared_relu" + + +def _bf16_grouped_mm( + x_bf16: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor +) -> torch.Tensor: + """BF16 grouped GEMM using torch.nn.functional.grouped_mm.""" + assert x_bf16.dtype == torch.bfloat16, f"Expected bf16 input, got {x_bf16.dtype}" + return grouped_mm(x_bf16, weight.transpose(1, 2), offs=offs) + + +def _mxfp8_grouped_mm(act: MXFP8Tensor, weight: MXFP8Tensor, offs: torch.Tensor) -> torch.Tensor: + """MXFP8 scaled_grouped_mm with pre-quantized activations and weights.""" + return scaled_grouped_mm( + act.data, + weight.data.transpose(1, 2), + act.scale_2d(), + ScalingType.BlockWise1x32, + weight.scale, + ScalingType.BlockWise1x32, + swizzle_a=SwizzleType.SWIZZLE_32_4_4, + swizzle_b=SwizzleType.SWIZZLE_32_4_4, + offs=offs, + output_dtype=torch.bfloat16, + ) + + +def _get_activation_func(activation_type: ActivationType, fused_quant: bool = False) -> Callable: + """Resolve ActivationType enum to a concrete kernel. + + If fused_quant=True, returns the fused activation + MXFP8 quantize kernel. 
+ """ + if activation_type == ActivationType.SQUARED_RELU: + return squared_relu_and_quantize_mxfp8 if fused_quant else padded_squared_relu + else: + raise ValueError(f"Unsupported activation type: {activation_type}") + + +def mcore_fused_moe( + hidden_states: torch.Tensor, + probs: torch.Tensor, + fc1_weight, + fc2_weight, + activation_type: ActivationType, + num_local_experts: int, + local_expert_start: int, + routing_map: Optional[torch.Tensor] = None, + tokens_per_expert: Optional[torch.Tensor] = None, + skip_permute: bool = False, + disable_fused_quant_kernels: bool = False, +) -> torch.Tensor: + """Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute]. + + Two modes: + - skip_permute=False (default): tokens are unpermuted. Requires routing_map. + Performs full permute -> compute -> unpermute. + - skip_permute=True: tokens are already permuted by the dispatcher. Requires + tokens_per_expert. Pads to alignment, computes, then unpads. Probs are + applied during unpad. + + Unless disable_fused_quant_kernels=True, when weights are MXFP8, uses fused + kernels that combine permute/activation with MXFP8 quantization into single + kernel launches. + + Args: + hidden_states: [num_tokens, hidden_size] BF16 input. + probs: routing probabilities. Shape is [num_tokens, topk] when + skip_permute=False, or [num_tokens] (already gathered) when + skip_permute=True. + fc1_weight: stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8). + fc2_weight: stacked weight for FC2 (same type as fc1_weight). + activation_type: ActivationType enum (SQUARED_RELU). + num_local_experts: number of experts on this rank. + local_expert_start: first global expert index on this rank. + routing_map: [num_tokens, topk] int expert assignments. Required when skip_permute=False. + tokens_per_expert: [num_local_experts] int32 token counts. Required when skip_permute=True. + skip_permute: if True, skip permute/unpermute (tokens already in expert order). 
+ disable_fused_quant_kernels: if True, disable fused permute+quantize and + activation+quantize kernels for MXFP8, using separate launches instead. + Useful for debugging. Ignored when weights are BF16. + + Returns: + [num_tokens, hidden_size] BF16 output. + """ + assert ( + hidden_states.dtype == torch.bfloat16 + ), f"mcore_fused_moe requires bf16 input, got {hidden_states.dtype}" + + num_tokens = hidden_states.shape[0] + use_mxfp8 = isinstance(fc1_weight, MXFP8Tensor) + # Fused quant kernels only apply to MXFP8 path + use_fused_quant = use_mxfp8 and not disable_fused_quant_kernels + + if use_mxfp8: + assert ( + HAVE_SCALED_GMM + ), "torch.nn.functional.scaled_grouped_mm not available. Install PyTorch 2.10+." + mm_fn = _mxfp8_grouped_mm + # scaled_grouped_mm requires each expert's token count aligned to 32, + # but swizzled MXFP8 scales require alignment to 128. Use 128 to + # satisfy both constraints. + expert_alignment = 128 + else: + assert ( + HAVE_GROUPED_MM + ), "torch.nn.functional.grouped_mm not available. Install PyTorch 2.10+." 
+ mm_fn = _bf16_grouped_mm + expert_alignment = 16 + + activation_func = _get_activation_func(activation_type, fused_quant=use_fused_quant) + + # --- Pre-processing: permute or pad --- + if skip_permute: + assert tokens_per_expert is not None, "tokens_per_expert is required when skip_permute=True" + tokens_per_expert = tokens_per_expert.cuda().int() + assert routing_map is None, "routing_map must be None when skip_permute=True" + hidden_states, permutation_map, offs = pad_to_alignment( + hidden_states, tokens_per_expert, expert_alignment + ) + permuted_probs = None + + else: + assert routing_map is not None, "routing_map is required when skip_permute=False" + if use_fused_quant: + # Fused permute + MXFP8 quantize: single kernel produces MXFP8Tensor + hidden_states, permuted_probs, permutation_map, offs = permute_and_quantize_mxfp8( + hidden_states, + probs, + routing_map, + local_expert_start, + num_local_experts, + alignment=expert_alignment, + ) + else: + hidden_states, permuted_probs, permutation_map, offs = permute_tokens( + hidden_states, + probs, + routing_map, + local_expert_start, + num_local_experts, + alignment=expert_alignment, + ) + + # --- FC1 -> activation -> FC2 --- + # Quantize if MXFP8 path and hidden_states not already quantized (fused permute+quant + # produces MXFP8Tensor directly; skip_permute path always needs separate quant). + needs_quant = use_mxfp8 and not isinstance(hidden_states, MXFP8Tensor) + if needs_quant: + hidden_states = MXFP8Tensor.from_bf16(hidden_states, backend="triton") + fc1_output = mm_fn(hidden_states, fc1_weight, offs) + + activation_out = activation_func(fc1_output, permutation_map) + # Fused activation+quant returns MXFP8Tensor; otherwise quantize separately. 
+ if use_mxfp8 and not isinstance(activation_out, MXFP8Tensor): + activation_out = MXFP8Tensor.from_bf16(activation_out, backend="triton") + fc2_output = mm_fn(activation_out, fc2_weight, offs) + # --- Post-processing: unpermute or unpad --- + if skip_permute: + probs_1d = probs.squeeze(-1) if probs.dim() > 1 else probs + return unpad_from_alignment(fc2_output, permutation_map, num_tokens, probs=probs_1d) + else: + return unpermute_tokens(fc2_output, permuted_probs, permutation_map, num_tokens) diff --git a/megatron/core/inference/moe/pad.py b/megatron/core/inference/moe/pad.py new file mode 100644 index 00000000000..656953b691c --- /dev/null +++ b/megatron/core/inference/moe/pad.py @@ -0,0 +1,201 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +"""Pad / unpad utilities for already-permuted expert tokens. + +When the token dispatcher has already permuted tokens into expert-grouped +order, these functions insert/remove alignment padding so that each expert's +token block satisfies the alignment requirements of grouped_mm / +scaled_grouped_mm. +""" + +from unittest.mock import MagicMock + +import torch +from packaging import version + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + if version.parse(triton.__version__) < version.parse("3.4.0") and not torch.cuda.is_available(): + HAVE_TRITON = False + else: + HAVE_TRITON = tl.constexpr(version.parse(triton.__version__) >= version.parse("2.0.0")) +except ImportError: + HAVE_TRITON = False + +if not HAVE_TRITON: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + +from megatron.core.inference.moe.permute import compute_expert_offsets + + +@triton.jit +def _pad_tokens_kernel( + src_ptr, + dst_ptr, + perm_map_ptr, + tpe_ptr, # tokens_per_expert [num_experts] + hidden_dim, + num_experts: tl.constexpr, + alignment: tl.constexpr, + BLOCK_H: tl.constexpr, +): + """Copy one input row into the padded output buffer. 
+ + Computes unpadded and padded cumulative offsets inline from + tokens_per_expert, avoiding a separate cumsum kernel launch. + """ + row = tl.program_id(0) + + # Walk tokens_per_expert to find which expert this row belongs to + # and compute both unpadded and padded start offsets on the fly. + unpadded_start = tl.zeros([], dtype=tl.int32) + padded_start = tl.zeros([], dtype=tl.int32) + expert_id = -1 + for e in tl.static_range(0, num_experts): + count = tl.load(tpe_ptr + e).to(tl.int32) + if expert_id < 0 and row < unpadded_start + count: + expert_id = e + if expert_id < 0: + unpadded_start += count + aligned = tl.where( + count > 0, + ((count + alignment - 1) // alignment) * alignment, + tl.zeros([], dtype=tl.int32), + ) + padded_start += aligned + + if expert_id < 0: + return + + local_idx = row - unpadded_start + dst_row = padded_start + local_idx + + # Write permutation_map: padded row → original unpadded row + tl.store(perm_map_ptr + dst_row, row) + + # Copy hidden state + for h in tl.range(0, hidden_dim, BLOCK_H): + o = h + tl.arange(0, BLOCK_H) + m = o < hidden_dim + tl.store( + dst_ptr + dst_row * hidden_dim + o, + tl.load(src_ptr + row * hidden_dim + o, mask=m), + mask=m, + ) + + +def pad_to_alignment( + hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor, alignment: int +) -> tuple: + """Pad already-permuted tokens so each expert's block is aligned. + + Args: + hidden_states: [total_tokens, hidden_size] already permuted by dispatcher. + tokens_per_expert: [num_local_experts] int32 token counts. + alignment: per-expert alignment. + + Returns: + (padded_hidden, permutation_map, inclusive_offsets) + - padded_hidden: [padded_total, hidden_size] + - permutation_map: [padded_total] int32, original row index or -1 for padding. + - inclusive_offsets: [num_local_experts] int32 cumulative aligned offsets for grouped_mm. 
+ """ + num_experts = tokens_per_expert.shape[0] + total_tokens = hidden_states.shape[0] + hidden_dim = hidden_states.shape[1] + + # We still need padded_inc for the return value (used as offs by grouped_mm) + _, padded_inc = compute_expert_offsets(tokens_per_expert, alignment=alignment) + padded_total = int(padded_inc[-1].item()) + + padded_hidden = torch.zeros( + padded_total, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device + ) + permutation_map = torch.full( + (padded_total,), -1, dtype=torch.int32, device=hidden_states.device + ) + + if total_tokens > 0: + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _pad_tokens_kernel[(total_tokens,)]( + hidden_states, + padded_hidden, + permutation_map, + tokens_per_expert, + hidden_dim, + num_experts, + alignment, + BLOCK_H=BLOCK_H, + ) + + return padded_hidden, permutation_map, padded_inc + + +@triton.jit +def _unpad_tokens_kernel( + src_ptr, + dst_ptr, + perm_map_ptr, + probs_ptr, + hidden_dim, + has_probs: tl.constexpr, + BLOCK_H: tl.constexpr, +): + """Copy one real (non-padding) row from padded to unpadded layout. + + Optionally multiplies each row by its routing probability. + """ + row = tl.program_id(0) + dst_row = tl.load(perm_map_ptr + row) + if dst_row < 0: + return + if has_probs: + prob = tl.load(probs_ptr + dst_row) + for h in tl.range(0, hidden_dim, BLOCK_H): + o = h + tl.arange(0, BLOCK_H) + m = o < hidden_dim + v = tl.load(src_ptr + row * hidden_dim + o, mask=m) + if has_probs: + v = v * prob + tl.store(dst_ptr + dst_row * hidden_dim + o, v, mask=m) + + +def unpad_from_alignment( + padded_output: torch.Tensor, + permutation_map: torch.Tensor, + original_size: int, + probs: torch.Tensor = None, +) -> torch.Tensor: + """Remove alignment padding, scattering results back to original positions. + + Args: + padded_output: [padded_total, hidden_size] output from expert computation. + permutation_map: [padded_total] int32, original row index or -1 for padding. 
+ original_size: number of rows in the unpadded output. + probs: optional [original_size] routing probabilities to multiply during unpad. + + Returns: + [original_size, hidden_size] unpadded output. + """ + hidden_dim = padded_output.shape[1] + output = torch.zeros( + original_size, hidden_dim, dtype=padded_output.dtype, device=padded_output.device + ) + has_probs = probs is not None + if padded_output.shape[0] > 0: + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _unpad_tokens_kernel[(padded_output.shape[0],)]( + padded_output, + output, + permutation_map, + probs if has_probs else padded_output, # dummy pointer when no probs + hidden_dim, + has_probs, + BLOCK_H=BLOCK_H, + ) + return output diff --git a/megatron/core/inference/moe/permute.py b/megatron/core/inference/moe/permute.py new file mode 100644 index 00000000000..b14d0b3dbd0 --- /dev/null +++ b/megatron/core/inference/moe/permute.py @@ -0,0 +1,458 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +"""Triton kernels for token permutation and unpermutation in fused MoE. 
+ +Includes: +- Token counting per expert +- Expert offset computation (aligned prefix sums) +- Permute tokens into expert-grouped order +- Unpermute expert outputs back to original token order +""" + +from unittest.mock import MagicMock + +import torch + +from megatron.core.utils import null_decorator + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True +except ImportError: + HAVE_TRITON = False + +if not HAVE_TRITON: + triton = MagicMock() + triton.jit = null_decorator + tl = MagicMock() + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def _count_local_tokens_kernel( + routing_map_ptr, # [num_tokens * topk] flattened expert assignments + tokens_per_expert_ptr, # [num_local_experts] output counters (zeroed by caller) + total_pairs, # num_tokens * topk — total (token, topk) pairs + local_expert_start, # first global expert index owned by this rank + num_local_experts: tl.constexpr, # number of experts on this rank + BLOCK_SIZE: tl.constexpr, # number of pairs processed per program +): + """Count tokens routed to experts on this rank, ignoring tokens routed elsewhere. + + Each program processes BLOCK_SIZE (token, topk) pairs. Tokens assigned to + experts outside [local_expert_start, local_expert_start + num_local_experts) + are silently skipped. 
+ """ + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_pairs + expert_ids = tl.load(routing_map_ptr + offsets, mask=mask, other=-1) + # Map global expert IDs to local indices; non-local experts become negative + local_ids = expert_ids - local_expert_start + is_local = (local_ids >= 0) & (local_ids < num_local_experts) & mask + tl.atomic_add(tokens_per_expert_ptr + local_ids, 1, mask=is_local) + + +def compute_local_tokens_per_expert( + routing_map: torch.Tensor, local_expert_start: int, num_local_experts: int +) -> torch.Tensor: + """Count tokens routed to each local expert.""" + total_pairs = routing_map.numel() + tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32, device=routing_map.device) + BLOCK = 256 + _count_local_tokens_kernel[(_ceil_div(total_pairs, BLOCK),)]( + routing_map, + tokens_per_expert, + total_pairs, + local_expert_start, + num_local_experts, + BLOCK_SIZE=BLOCK, + ) + return tokens_per_expert + + +@triton.jit +def _prefix_sum_kernel( + tokens_per_expert_ptr, # [num_local_experts] raw token counts + exclusive_offsets_ptr, # [num_local_experts] output: exclusive prefix sum of aligned counts + inclusive_offsets_ptr, # [num_local_experts] output: inclusive prefix sum of aligned counts + num_local_experts, # number of experts on this rank + alignment: tl.constexpr, # per-expert alignment (counts rounded up to this multiple) + BLOCK_SIZE: tl.constexpr, # next_power_of_2(num_local_experts) for tl.cumsum +): + """Exclusive and inclusive prefix sums of aligned token counts. + + Each expert's token count is rounded up to the nearest multiple of + `alignment` (experts with 0 tokens stay at 0). The inclusive offsets + are used as `offs` by grouped_mm / scaled_grouped_mm. 
+ """ + r = tl.arange(0, BLOCK_SIZE) + mask = r < num_local_experts + h = tl.load(tokens_per_expert_ptr + r, mask=mask, other=0) + # Round up non-zero counts to alignment boundary + if alignment > 1: + h = tl.where(h > 0, ((h + alignment - 1) // alignment) * alignment, h) + inc = tl.cumsum(h, axis=0) + tl.store(exclusive_offsets_ptr + r, inc - h, mask=mask) + tl.store(inclusive_offsets_ptr + r, inc, mask=mask) + + +def compute_expert_offsets(tokens_per_expert: torch.Tensor, alignment: int = 1) -> tuple: + """Compute exclusive and inclusive prefix sums of aligned token counts.""" + n = tokens_per_expert.shape[0] + exclusive_cumsum = torch.empty_like(tokens_per_expert) + inclusive_cumsum = torch.empty_like(tokens_per_expert) + _prefix_sum_kernel[(1,)]( + tokens_per_expert, + exclusive_cumsum, + inclusive_cumsum, + n, + alignment, + BLOCK_SIZE=triton.next_power_of_2(n), + ) + return exclusive_cumsum, inclusive_cumsum + + +@triton.jit +def _permute_tokens_kernel( + hidden_ptr, # [num_tokens, hidden_dim] input hidden states + probs_ptr, # [num_tokens, topk] routing probabilities + routing_map_ptr, # [num_tokens, topk] expert assignments (global IDs) + out_hidden_ptr, # [output_size, hidden_dim] output: permuted hidden states + out_probs_ptr, # [output_size] output: permuted probabilities + out_src_idx_ptr, # [output_size] output: permutation_map (original token index, -1 for padding) + counters_ptr, # [num_local_experts] exclusive offsets, + # atomically incremented to assign positions + num_tokens, # number of input tokens + hidden_dim, # hidden dimension + topk: tl.constexpr, # number of expert choices per token + local_expert_start, # first global expert index on this rank + num_local_experts: tl.constexpr, # number of experts on this rank + BLOCK_H: tl.constexpr, # tile size for copying hidden_dim +): + """Permute tokens into expert-grouped order. + + Grid: one program per (token, topk) pair. 
Each program looks up the assigned + expert, skips non-local experts, then atomically claims a position within + that expert's block and copies the hidden state + prob + source index. + """ + # Each program handles one (token, topk) pair + pair = tl.program_id(0) + tok = pair // topk + k = pair % topk + if tok >= num_tokens: + return + eid = tl.load(routing_map_ptr + tok * topk + k) + lid = eid - local_expert_start + # Skip tokens routed to non-local experts + if lid < 0 or lid >= num_local_experts: + return + # Atomically claim a position within this expert's aligned block + pos = tl.atomic_add(counters_ptr + lid, 1) + # Copy hidden state row + for h in tl.range(0, hidden_dim, BLOCK_H): + o = h + tl.arange(0, BLOCK_H) + m = o < hidden_dim + tl.store( + out_hidden_ptr + pos * hidden_dim + o, + tl.load(hidden_ptr + tok * hidden_dim + o, mask=m), + mask=m, + ) + tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) + # Record source token index for unpermute + tl.store(out_src_idx_ptr + pos, tok) + + +def permute_tokens( + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, + alignment: int = 1, +) -> tuple: + """Permute tokens into expert-grouped order. + + Computes token counts, aligned expert offsets, output sizing, and + permutation in a single call. + + Args: + hidden_states: [num_tokens, hidden_size] input. + probs: [num_tokens, topk] routing probabilities. + routing_map: [num_tokens, topk] expert assignments. + local_expert_start: first global expert index on this rank. + num_local_experts: number of experts on this rank. + alignment: per-expert token alignment (default 1). + + Returns: + (permuted_hidden, permuted_probs, permutation_map, inclusive_offsets) + - permuted_hidden: [output_size, hidden_size] + - permuted_probs: [output_size] + - permutation_map: [output_size] int32, maps each permuted row back to + its original token index. 
Used by unpermute_tokens to scatter expert + outputs back and by activation kernels to skip padding rows (-1). + - inclusive_offsets: [num_local_experts] int32 cumulative offsets for grouped_mm + """ + num_tokens, hidden_dim = hidden_states.shape + topk = probs.shape[1] + + # Count how many (token, topk) pairs are routed to each local expert. + # Non-local experts are ignored. Result is [num_local_experts] int32. + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts + ) + + # exclusive_expert_offsets[i] = start of expert i's block in the padded output. + # Used as the initial counter for atomic position assignment in the permute kernel. + # inclusive_expert_offsets[i] = end of expert i's block (= start of expert i+1). + # Passed as `offs` to grouped_mm / scaled_grouped_mm to delimit expert boundaries. + exclusive_expert_offsets, inclusive_expert_offsets = compute_expert_offsets( + tokens_per_expert, alignment=alignment + ) + output_size = num_tokens * min(topk, num_local_experts) + alignment * num_local_experts + + permuted_hidden = torch.empty( + output_size, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device + ) + permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) + permutation_map = torch.full((output_size,), -1, dtype=torch.int32, device=probs.device) + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _permute_tokens_kernel[(num_tokens * topk,)]( + hidden_states, + probs, + routing_map, + permuted_hidden, + permuted_probs, + permutation_map, + exclusive_expert_offsets, + num_tokens, + hidden_dim, + topk, + local_expert_start, + num_local_experts, + BLOCK_H=BLOCK_H, + ) + return permuted_hidden, permuted_probs, permutation_map, inclusive_expert_offsets + + +@triton.jit +def _unpermute_tokens_kernel( + expert_out_ptr, # [output_size, hidden_dim] expert outputs in permuted order + probs_ptr, # [output_size] fp32 routing probabilities (permuted) + 
src_idx_ptr, # [output_size] permutation_map: original token index, or -1 for padding + output_ptr, # [num_tokens, hidden_dim] fp32 output buffer (zeroed by caller) + hidden_dim, # hidden dimension + BLOCK_H: tl.constexpr, # tile size for processing hidden_dim +): + """Scatter weighted expert outputs back to original token positions. + + Grid: one program per row of expert_out. Padding rows (src_idx == -1) are + skipped. Multiple topk selections for the same token are accumulated via + atomic adds. All arithmetic is in fp32 to avoid precision loss. + """ + row = tl.program_id(0) + source_idx = tl.load(src_idx_ptr + row) + # Skip padding rows + if source_idx < 0: + return + prob = tl.load(probs_ptr + row) # fp32 + for h in tl.range(0, hidden_dim, BLOCK_H): + offsets = h + tl.arange(0, BLOCK_H) + m = offsets < hidden_dim + # Upcast bf16 expert output to fp32 before multiply + accumulate + v = tl.load(expert_out_ptr + row * hidden_dim + offsets, mask=m).to(tl.float32) + tl.atomic_add(output_ptr + source_idx * hidden_dim + offsets, v * prob, mask=m) + + +def unpermute_tokens( + expert_output: torch.Tensor, + permuted_probs: torch.Tensor, + permutation_map: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """Unpermute expert outputs back to original token order. + + Accumulates in fp32 to avoid precision loss from multiple topk atomic adds. + Returns fp32 output. 
+ """ + assert ( + permuted_probs.dtype == torch.float32 + ), f"permuted_probs must be fp32, got {permuted_probs.dtype}" + output_size, hidden_dim = expert_output.shape + output = torch.zeros(num_tokens, hidden_dim, dtype=torch.float32, device=expert_output.device) + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _unpermute_tokens_kernel[(output_size,)]( + expert_output, permuted_probs, permutation_map, output, hidden_dim, BLOCK_H=BLOCK_H + ) + return output + + +@triton.jit +def _permute_quantize_mxfp8_kernel( + hidden_ptr, + probs_ptr, + routing_map_ptr, + out_fp8_ptr, + out_scale_ptr, + out_probs_ptr, + out_src_idx_ptr, + counters_ptr, + num_tokens, + K, + n_col_blocks, + topk: tl.constexpr, + local_expert_start, + num_local_experts: tl.constexpr, + REAL_GROUPS: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_GROUPS: tl.constexpr, +): + """Fused permute + MXFP8 quantize + swizzle in one kernel. + + Grid: (num_tokens * topk,) — one program per (token, k) pair. + Reads BF16 from source token, quantizes to FP8 e4m3, writes FP8 data + + swizzled e8m0 scales to the permuted write position. 
+ """ + pair = tl.program_id(0) + tok = pair // topk + k = pair % topk + if tok >= num_tokens: + return + eid = tl.load(routing_map_ptr + tok * topk + k) + lid = eid - local_expert_start + if lid < 0 or lid >= num_local_experts: + return + + pos = tl.atomic_add(counters_ptr + lid, 1) + + # Load full row from source token + offs = tl.arange(0, BLOCK_K) + mask = offs < K + x = tl.load(hidden_ptr + tok * K + offs, mask=mask, other=0.0).to(tl.float32) + + # Per-group-of-32 quantization + x_grouped = tl.reshape(x, [BLOCK_GROUPS, 32]) + abs_grouped = tl.abs(x_grouped) + max_vals = tl.max(abs_grouped, axis=1) + + dequant_scale = max_vals / 448.0 + dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) + + quantized = x_grouped * quant_scale[:, None] + quantized_flat = tl.reshape(quantized, [BLOCK_K]) + out_fp8 = quantized_flat.to(tl.float8e4nv) + + # Store FP8 data at permuted position + tl.store(out_fp8_ptr + pos * K + offs, out_fp8, mask=mask) + + # Store swizzled scales at permuted position + scale_exp = (dequant_exp >> 23).to(tl.uint8) + col_offs = tl.arange(0, BLOCK_GROUPS) + col_mask = col_offs < REAL_GROUPS + + macro_row_block = pos // 128 + macro_col_block = col_offs // 4 + local_row = pos % 128 + local_col = col_offs % 4 + group = local_row // 32 + sub_row = local_row % 32 + tile_idx = macro_row_block * n_col_blocks + macro_col_block + swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + + tl.store(out_scale_ptr + swizzled_offs, scale_exp, mask=col_mask) + + # Store prob and source index + tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) + tl.store(out_src_idx_ptr + pos, tok) + + +def permute_and_quantize_mxfp8( + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, + alignment: int = 
128, +) -> tuple: + """Fused permute + MXFP8 quantize + swizzle. + + Self-contained API matching permute_tokens: computes token counts, aligned + expert offsets, output sizing, permutation, and MXFP8 quantization in a + single kernel launch. + + Args: + hidden_states: [num_tokens, hidden_size] BF16 input. + probs: [num_tokens, topk] routing probabilities. + routing_map: [num_tokens, topk] expert assignments. + local_expert_start: first global expert index on this rank. + num_local_experts: number of experts on this rank. + alignment: per-expert token alignment (default 128, required for MXFP8 swizzle). + + Returns: + (permuted_mxfp8, permuted_probs, permutation_map, inclusive_offsets) + - permuted_mxfp8: MXFP8Tensor with .data [output_size, K] and .scale (swizzled) + - permuted_probs: [output_size] routing probs + - permutation_map: [output_size] int32, original token index or -1 for padding + - inclusive_offsets: [num_local_experts] int32 cumulative offsets for scaled_grouped_mm + """ + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + num_tokens, K = hidden_states.shape + topk = probs.shape[1] + assert K % 32 == 0 + + # Count how many (token, topk) pairs are routed to each local expert. + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts + ) + + # exclusive_expert_offsets[i] = start of expert i's block in the padded output. + # inclusive_expert_offsets[i] = end of expert i's block (= start of expert i+1). 
+ exclusive_expert_offsets, inclusive_expert_offsets = compute_expert_offsets( + tokens_per_expert, alignment=alignment + ) + output_size = num_tokens * min(topk, num_local_experts) + alignment * num_local_experts + + scale_cols = K // 32 + n_row_blocks = _ceil_div(output_size, 128) + n_col_blocks = _ceil_div(scale_cols, 4) + total_scale_bytes = n_row_blocks * n_col_blocks * 512 + + out_fp8 = torch.empty(output_size, K, dtype=torch.float8_e4m3fn, device=hidden_states.device) + out_scale = torch.zeros(total_scale_bytes, dtype=torch.uint8, device=hidden_states.device) + permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) + permutation_map = torch.full((output_size,), -1, dtype=torch.int32, device=probs.device) + + BLOCK_K = triton.next_power_of_2(K) + BLOCK_GROUPS = BLOCK_K // 32 + + _permute_quantize_mxfp8_kernel[(num_tokens * topk,)]( + hidden_states, + probs, + routing_map, + out_fp8, + out_scale, + permuted_probs, + permutation_map, + exclusive_expert_offsets, + num_tokens, + K, + n_col_blocks, + topk, + local_expert_start, + num_local_experts, + REAL_GROUPS=scale_cols, + BLOCK_K=BLOCK_K, + BLOCK_GROUPS=BLOCK_GROUPS, + ) + + permuted_mxfp8 = MXFP8Tensor( + data=out_fp8, scale=out_scale.view(torch.float8_e8m0fnu), backend="triton" + ) + return permuted_mxfp8, permuted_probs, permutation_map, inclusive_expert_offsets diff --git a/megatron/core/inference/quantization/mxfp8_quantize.py b/megatron/core/inference/quantization/mxfp8_quantize.py new file mode 100644 index 00000000000..73f2ac974b3 --- /dev/null +++ b/megatron/core/inference/quantization/mxfp8_quantize.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Standalone MXFP8 quantization kernel with fused scale swizzle. + +One block per token. Quantizes BF16 → FP8 e4m3 and writes scales directly +in cuBLAS 2D blocked (swizzled) layout. No FP4, no triton_kernels dependency. 
# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

"""Standalone MXFP8 quantization kernel with fused scale swizzle.

One block per token. Quantizes BF16 → FP8 e4m3 and writes scales directly
in cuBLAS 2D blocked (swizzled) layout. No FP4, no triton_kernels dependency.

Usage:
    from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize
    data, swizzled_scales = mxfp8_quantize(x_bf16)
    # data: [M, K] float8_e4m3fn
    # swizzled_scales: 1D float8_e8m0fnu view in cuBLAS blocked layout
"""

from unittest.mock import MagicMock

import torch

# Guarded import so the module stays importable without Triton, matching the
# pattern used by the sibling Triton modules in this package.
try:
    import triton
    import triton.language as tl

    HAVE_TRITON = True
except ImportError:
    HAVE_TRITON = False
    triton = MagicMock()
    tl = MagicMock()


def _ceil_div(a, b):
    """Integer ceiling division."""
    return (a + b - 1) // b


@triton.jit
def _mxfp8_quant_swizzle_kernel(
    out_ptr,  # [M, K] output buffer for float8_e4m3fn quantized data
    scale_ptr,  # 1D output buffer for swizzled uint8 scales (e8m0 exponents)
    src_ptr,  # [M, K] input tensor in bf16/fp16/fp32 (row-major contiguous)
    K,  # number of columns (must be divisible by 32)
    n_col_blocks,  # ceil(K/32 / 4) — macro-tile columns in the swizzle layout
    REAL_GROUPS: tl.constexpr,  # actual scale groups per row (K // 32)
    BLOCK_K: tl.constexpr,  # next_power_of_2(K) — padded column count for tl.reshape
    BLOCK_GROUPS: tl.constexpr,  # BLOCK_K // 32 — padded group count (power of 2)
):
    """Quantize one row → FP8 e4m3, writing scales directly in swizzled layout.

    Scales are rounded UP; see Mishra et al., Recipes for Pre-training LLMs
    with MXFP8 (https://arxiv.org/pdf/2506.08027). Implementation borrows from
    the Triton upstream MXFP downcast kernel:
    https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py

    Swizzled scale layout (torch.nn.functional.SwizzleType.SWIZZLE_32_4_4):

    In MXFP8, every group of 32 elements shares one 1-byte e8m0 scale, giving
    an [M, K//32] scale matrix. cuBLAS reads these scales in a "swizzled"
    order rather than row-major:

    Step 1 — macro-tiles: the scale matrix is partitioned into 128-row x 4-col
    tiles, each stored as one contiguous 512-byte block.

    Step 2 — interleave: within a tile, the 128 rows split into 4 groups of 32
    (rows 0-31, 32-63, 64-95, 96-127). Rows with the same position inside
    their group ("sub_row") are stored adjacently; e.g. for sub_row 0 the
    bytes hold (row 0, cols 0-3), (row 32, cols 0-3), (row 64, cols 0-3),
    (row 96, cols 0-3).

    Mapping logical (row, col) → byte offset:
        tile_idx = (row // 128) * n_col_blocks + (col // 4)
        sub_row = row % 32
        group = (row % 128) // 32
        local_col = col % 4
        offset = tile_idx * 512 + sub_row * 16 + group * 4 + local_col
    """
    row = tl.program_id(0)
    src_row = src_ptr + row * K
    out_row = out_ptr + row * K

    offs = tl.arange(0, BLOCK_K)
    mask = offs < K

    # Load the full row, upcast to fp32 for the scale math.
    x = tl.load(src_row + offs, mask=mask, other=0.0).to(tl.float32)

    # Per-group-of-32 max magnitude.
    x_grouped = tl.reshape(x, [BLOCK_GROUPS, 32])
    max_vals = tl.max(tl.abs(x_grouped), axis=1)

    # 448 is the max representable value in FP8 e4m3;
    # dequant_scale = min scale s.t. max_val / scale <= 448.
    dequant_scale = max_vals / 448.0
    # Round up to the next power of 2 via integer bit manipulation: adding the
    # mantissa mask (0x007FFFFF) before masking with the exponent-only mask
    # (0x7F800000) bumps the exponent if any mantissa bits are set.
    dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
    # Reinterpret as float32 — now a power-of-2 dequantization scale.
    dequant_rounded = dequant_exp.to(tl.float32, bitcast=True)
    # Quantization scale is the reciprocal; guard all-zero groups.
    quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded)

    # Quantize and store the FP8 data.
    quantized = x_grouped * quant_scale[:, None]
    out_fp8 = tl.reshape(quantized, [BLOCK_K]).to(tl.float8e4nv)
    tl.store(out_row + offs, out_fp8, mask=mask)

    # Store the e8m0 exponents at their swizzled offsets (formula above).
    scale_exp = (dequant_exp >> 23).to(tl.uint8)
    col_offs = tl.arange(0, BLOCK_GROUPS)
    col_mask = col_offs < REAL_GROUPS
    tile_idx = (row // 128) * n_col_blocks + col_offs // 4
    local_row = row % 128
    swizzled_offs = (
        tile_idx * 512 + (local_row % 32) * 16 + (local_row // 32) * 4 + col_offs % 4
    )
    tl.store(scale_ptr + swizzled_offs, scale_exp, mask=col_mask)


def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize a 2D tensor to MXFP8 with fused scale swizzle.

    Args:
        x: [M, K] tensor in bf16/fp16/fp32 on CUDA. K must be divisible by 32.

    Returns:
        (data, swizzled_scales):
            data: [M, K] float8_e4m3fn
            swizzled_scales: 1D tensor in cuBLAS blocked layout (e8m0 view)
    """
    assert HAVE_TRITON, "Triton is required for MXFP8 quantization."
    assert x.is_cuda and x.dim() == 2
    assert x.dtype in (torch.bfloat16, torch.float16, torch.float32)
    M, K = x.shape
    assert K % 32 == 0, f"K ({K}) must be divisible by 32"
    # The kernel addresses rows with flat pointer arithmetic (row * K + col),
    # which would silently misread a non-contiguous input.
    x = x.contiguous()

    scale_cols = K // 32
    n_row_blocks = _ceil_div(M, 128)
    n_col_blocks = _ceil_div(scale_cols, 4)
    total_scale_bytes = n_row_blocks * n_col_blocks * 512

    out_data = torch.empty(M, K, dtype=torch.float8_e4m3fn, device=x.device)
    # Zero-filled so rows in the swizzle padding carry valid (zero) exponents.
    out_scale = torch.zeros(total_scale_bytes, dtype=torch.uint8, device=x.device)

    BLOCK_K = triton.next_power_of_2(K)
    BLOCK_GROUPS = BLOCK_K // 32

    _mxfp8_quant_swizzle_kernel[(M,)](
        out_data,
        out_scale,
        x,
        K,
        n_col_blocks,
        REAL_GROUPS=scale_cols,
        BLOCK_K=BLOCK_K,
        BLOCK_GROUPS=BLOCK_GROUPS,
    )

    return out_data, out_scale.view(torch.float8_e8m0fnu)
def _ceil_div(a, b):
    """Integer ceiling division."""
    return (a + b - 1) // b


@dataclass
class MXFP8Tensor:
    """MXFP8 tensor wrapper storing quantized fp8_e4m3 data and swizzled e8m0 scales."""

    data: torch.Tensor  # [M, K] fp8_e4m3fn
    scale: torch.Tensor  # 1D, swizzled cuBLAS blocked layout, e8m0
    backend: Optional[str] = None  # quantization backend: 'flashinfer' or 'triton'

    def size(self, idx: Optional[int] = None):
        """Wrapper for calling self.data.size()"""
        return self.data.size(idx)

    def scale_2d(self, K: Optional[int] = None) -> torch.Tensor:
        """Reshape the 1D swizzled scale to 2D for scaled_grouped_mm / scaled_mm.

        The swizzle pads rows to multiples of 128 and columns to multiples of
        4, so the result is (padded_M, padded_cols) with
        padded_cols = ceil(K//32 / 4) * 4. An already-2D scale is returned
        unchanged.
        """
        if self.scale.dim() == 2:
            return self.scale
        cols = K if K is not None else self.data.shape[-1]
        padded_cols = _ceil_div(cols // 32, 4) * 4
        return self.scale.reshape(-1, padded_cols)

    @classmethod
    def from_bf16(cls, x: torch.Tensor, group_size: int = 32, backend: str = "flashinfer"):
        """Quantize a BF16 tensor to MXFP8.

        Args:
            x: [M, K] BF16 tensor on CUDA.
            group_size: MXFP8 group size (default 32).
            backend: 'triton' (fused quantize + swizzle Triton kernel) or
                'flashinfer' (single fused FlashInfer CUDA kernel).
        """
        assert x.is_cuda and x.dim() == 2
        assert x.shape[-1] % group_size == 0
        if backend == "flashinfer":
            assert HAVE_FLASHINFER, "FlashInfer not available"
            return cls(*flashinfer_mxfp8_quantize(x), backend=backend)
        if backend == "triton":
            quantized, scales = mcore_mxfp8_quantize(x)
            return cls(data=quantized, scale=scales, backend=backend)
        raise ValueError(
            f"Unknown MXFP8 quantization backend: '{backend}'. "
            "Must be 'triton' or 'flashinfer'."
        )
""" assert HAVE_TE - assert HAVE_FLASHINFER + import logging + + rank = torch.distributed.get_rank() + if backend == "flashinfer": + assert HAVE_FLASHINFER, "FlashInfer not available for MXFP8 quantization" - # Recurse through child modules for child in model.children(): - quantize_model_to_mxfp8(child) + quantize_model_to_mxfp8(child, backend=backend) def replace_in_dict(attr_dict): """Helper function to replace TE MXFP8 weights.""" keys = list(attr_dict.keys()) for key in keys: val = attr_dict[key] - if isinstance(val, TEMXFP8Tensor): - # Undo the TE quantization and re-quantize with FlashInfer + is_te_mxfp8 = isinstance(val, TEMXFP8Tensor) or ( + hasattr(val, 'data') and isinstance(val.data, TEMXFP8Tensor) + ) + if is_te_mxfp8: + # Undo the TE quantization and re-quantize # Note that this introduces a one-time overhead but avoids any - # numerical differences between TE and FlashInfer MXFP8 formats + # numerical differences between TE and mcore MXFP8 formats te_dequantized = val.dequantize() - fi_quantized = MXFP8Tensor.from_bf16(te_dequantized) - - # Sanity check the numerical correctness of the TE -> FlashInfer conversion - _verify_te_to_flashinfer_mxfp8_conversion(te_dequantized, fi_quantized) - - # Remove the existing TE parameter and then replace the - # attribute with the re-quantized tensor + mcore_quantized = MXFP8Tensor.from_bf16(te_dequantized, backend=backend) + _verify_te_to_mcore_mxfp8_conversion(te_dequantized, mcore_quantized) del model._parameters[key] - setattr(model, key, fi_quantized) + setattr(model, key, mcore_quantized) if hasattr(model, '_parameters') and model._parameters: replace_in_dict(model._parameters) @@ -87,6 +100,8 @@ def _should_quantize_param(val: torch.Tensor) -> bool: return False if HAVE_TE and isinstance(val, TEMXFP8Tensor): return True + if HAVE_TE and hasattr(val, 'data') and isinstance(val.data, TEMXFP8Tensor): + return True if ( isinstance(val, torch.nn.Parameter) and val.dim() == 2 @@ -100,6 +115,8 @@ def _to_bf16(val: 
torch.Tensor) -> torch.Tensor: """Convert a parameter value to BF16 for quantization.""" if HAVE_TE and isinstance(val, TEMXFP8Tensor): return val.dequantize() + if HAVE_TE and hasattr(val, 'data') and isinstance(val.data, TEMXFP8Tensor): + return val.data.dequantize() return val.data.to(torch.bfloat16) @@ -126,8 +143,9 @@ def quantize_params_to_mxfp8( model: torch.nn.Module, persistent_buffers: Optional[Dict[str, MXFP8Tensor]] = None, _prefix: str = "", + backend: str = "flashinfer", ) -> Dict[str, MXFP8Tensor]: - """Quantize model parameters to FlashInfer MXFP8Tensor format. + """Quantize model parameters to MXFP8Tensor format. Handles both TEMXFP8Tensor (fp8_param=True) and BF16/FP16 nn.Parameter inputs. When *persistent_buffers* is provided, new quantized values are @@ -140,11 +158,13 @@ def quantize_params_to_mxfp8( parameter names to previously-created ``MXFP8Tensor`` objects. Updated in-place and returned. _prefix: Internal recursion prefix – callers should not set this. + backend: 'flashinfer' or 'triton' quantization backend. Returns: The ``persistent_buffers`` dict (created on first call if ``None``). """ - assert HAVE_FLASHINFER + if backend == "flashinfer": + assert HAVE_FLASHINFER, "FlashInfer not available for MXFP8 quantization" if persistent_buffers is None: persistent_buffers = {} @@ -152,7 +172,9 @@ def quantize_params_to_mxfp8( # Recurse through child modules for child_name, child_module in model.named_children(): child_prefix = f"{_prefix}{child_name}." if _prefix else f"{child_name}." 
def _mm_mxfp8_flashinfer(x_mxfp8: MXFP8Tensor, weight: MXFP8Tensor, out=None):
    """MXFP8 matmul via the fused FlashInfer kernel."""
    return flashinfer_mm_mxfp8(
        x_mxfp8.data, weight.data.T, x_mxfp8.scale, weight.scale, out_dtype=torch.bfloat16, out=out
    )


def _mm_mxfp8_torch(x_mxfp8: MXFP8Tensor, weight: MXFP8Tensor, out=None):
    """MXFP8 matmul via torch.nn.functional.scaled_mm.

    NOTE(review): scale_a goes through scale_2d() while scale_b is passed as
    stored — confirm scaled_mm accepts weight.scale's shape for
    SWIZZLE_32_4_4, or whether it should also be reshaped to 2D.
    """
    result = torch_scaled_mm(
        x_mxfp8.data,
        weight.data.t(),
        x_mxfp8.scale_2d(),
        ScalingType.BlockWise1x32,
        weight.scale,
        ScalingType.BlockWise1x32,
        swizzle_a=SwizzleType.SWIZZLE_32_4_4,
        swizzle_b=SwizzleType.SWIZZLE_32_4_4,
        output_dtype=torch.bfloat16,
    )
    if out is None:
        return result
    out.copy_(result)
    return out


def mm_mxfp8(x: torch.Tensor, weight: MXFP8Tensor, out: torch.Tensor = None):
    """Compute a matmul in MXFP8.

    The bf16 activation tensor is quantized on the fly; the weight must be
    pre-quantized. The implementation (FlashInfer or torch.scaled_mm) is
    selected by ``weight.backend``.
    """
    backend = weight.backend
    assert (
        backend is not None
    ), "weight.backend is None — was the weight created via MXFP8Tensor.from_bf16?"

    activations = MXFP8Tensor.from_bf16(x.squeeze(1), backend=backend)

    if backend == "flashinfer":
        assert HAVE_FLASHINFER, "FlashInfer not available for MXFP8 matmul"
        result = _mm_mxfp8_flashinfer(activations, weight, out=out)
    elif backend == "triton":
        assert (
            HAVE_TORCH_SCALED_MM
        ), "torch.nn.functional.scaled_mm with ScalingType/SwizzleType not available"
        result = _mm_mxfp8_torch(activations, weight, out=out)
    else:
        raise ValueError(f"Unknown MXFP8 backend: '{backend}'")

    return result.unsqueeze(1)
InferenceRowParallelLinear, ) @@ -152,7 +152,7 @@ def linear(self) -> type: def column_parallel_linear(self) -> type: """Which column parallel linear module TE backend uses""" - return TEColumnParallelLinear + return InferenceColumnParallelLinear def row_parallel_linear(self) -> type: """Which row parallel linear module TE backend uses""" diff --git a/megatron/core/tensor_parallel/__init__.py b/megatron/core/tensor_parallel/__init__.py index 2140ee54b37..0852014a859 100644 --- a/megatron/core/tensor_parallel/__init__.py +++ b/megatron/core/tensor_parallel/__init__.py @@ -1,7 +1,11 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. from .cross_entropy import vocab_parallel_cross_entropy from .data import broadcast_data -from .inference_layers import InferenceLayerNormColumnParallelLinear, InferenceRowParallelLinear +from .inference_layers import ( + InferenceColumnParallelLinear, + InferenceLayerNormColumnParallelLinear, + InferenceRowParallelLinear, +) from .layers import ( ColumnParallelLinear, RowParallelLinear, diff --git a/megatron/core/tensor_parallel/inference_layers.py b/megatron/core/tensor_parallel/inference_layers.py index 5946cac0ac8..17726a29e53 100644 --- a/megatron/core/tensor_parallel/inference_layers.py +++ b/megatron/core/tensor_parallel/inference_layers.py @@ -5,6 +5,7 @@ import torch.distributed as dist from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, TELayerNormColumnParallelLinear, TERowParallelLinear, ) @@ -186,6 +187,104 @@ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]: return x, None +class InferenceColumnParallelLinear(TEColumnParallelLinear): + """ + Inference optimized version of TEColumnParallelLinear. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + *, + config: TransformerConfig, + init_method: Callable, + gather_output: bool, + bias: bool, + skip_bias_add: bool, + is_expert: bool, + stride: int = 1, + skip_weight_param_allocation: bool = False, + tp_comm_buffer_name: Optional[str] = None, + tp_group: Optional[torch.distributed.ProcessGroup] = None, + ): + assert HAVE_TE, "--transformer-impl=inference_optimized requires transformer engine" + super().__init__( + input_size, + output_size, + config=config, + init_method=init_method, + gather_output=gather_output, + bias=bias, + skip_bias_add=skip_bias_add, + is_expert=is_expert, + stride=stride, + skip_weight_param_allocation=skip_weight_param_allocation, + tp_comm_buffer_name=tp_comm_buffer_name, + tp_group=tp_group, + ) + self.tp_group = get_tensor_model_parallel_group_if_none(tp_group, is_expert=is_expert) + self.tp_size = dist.get_world_size(self.tp_group) + + assert ( + output_size % self.tp_size == 0 + ), f"output_size ({output_size}) must be divisible by tp_size ({self.tp_size})" + + if self.tp_size > 1: + assert ( + config.sequence_parallel + ), "--transformer-impl=inference_optimized requires --sequence-parallel" + + self.triton_nvls_kernels_allowed = not config.inference_disable_triton_nvls_kernels + + def _maybe_allocate_symmetric_buffer(self, x: torch.Tensor): + """ + Attempt to allocate symmetric memory buffer for all-gather. + """ + symm_mem_buffer_dims = list(x.size()) + symm_mem_buffer_dims[0] *= self.tp_size + buf = SymmetricMemoryManager.get_buffer("tp", process_group=self.tp_group) + symm_mem_buffer = buf.maybe_get_tensor(symm_mem_buffer_dims, dtype=x.dtype) + return symm_mem_buffer + + def _all_gather(self, x: torch.Tensor, symm_mem_buffer: dict) -> None: + """ + Attempt an NVLS all-gather into symmetric memory. If not possible, + revert to torch dist (NCCL) all-gather. 
+ """ + if self.tp_size == 1: + return x + + can_use_nvls = ( + self.triton_nvls_kernels_allowed + and are_tensors_nvls_eligible(x) + and symm_mem_buffer["handle"] is not None + ) + if can_use_nvls: + multimem_all_gather(symm_mem_buffer["tensor"], x, symm_mem_buffer["handle"]) + return symm_mem_buffer["tensor"] + else: + x, _ = gather_along_first_dim(x, process_group=self.tp_group) + return x + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, None]: + """ + Forward pass. + """ + if self.training: + return super().forward(x) + + if self.tp_size == 1: + x = _apply_linear(x, self.weight, self.config) + return x, None + + symm_mem_buffer = self._maybe_allocate_symmetric_buffer(x) + x = self._all_gather(x, symm_mem_buffer) + x = _apply_linear(x, self.weight, self.config) + + return x, None + + class InferenceRowParallelLinear(TERowParallelLinear): """ Inference optimized version of TERowParallelLinear. diff --git a/megatron/core/transformer/moe/experts.py b/megatron/core/transformer/moe/experts.py index 62c2ec734c4..116e47fffb1 100644 --- a/megatron/core/transformer/moe/experts.py +++ b/megatron/core/transformer/moe/experts.py @@ -18,6 +18,7 @@ from megatron.core.fusions.fused_bias_geglu import quick_gelu, weighted_bias_quick_geglu_impl from megatron.core.fusions.fused_bias_swiglu import weighted_bias_swiglu_impl from megatron.core.fusions.fused_weighted_squared_relu import weighted_squared_relu_impl +from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor from megatron.core.pipeline_parallel.fine_grained_activation_offload import ( FineGrainedActivationOffloadingInterface as off_interface, ) @@ -38,7 +39,6 @@ sharded_state_dict_default, ) from megatron.core.typed_torch import apply_module, not_none -from megatron.core.utils import is_torch_min_version try: import transformer_engine as te # pylint: disable=unused-import @@ -59,6 +59,13 @@ except ImportError: HAVE_FLASHINFER = False +from megatron.core.inference.moe import ActivationType as 
McoreActivationType +from megatron.core.inference.moe import ( + InferenceGroupedGemmBackend, + mcore_fused_moe, + resolve_inference_grouped_gemm_backend, +) + logger = logging.getLogger(__name__) @@ -462,7 +469,7 @@ class InferenceGroupedMLP(TEGroupedMLP): Supports three forward paths: - Training: delegates to parent TEGroupedMLP - Inference + CUDA graphed: FlashInfer cutlass_fused_moe (fused permute + GEMM) - - Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets + - Inference + eager: torch.nn.functional.grouped_mm with GPU-resident cumsum offsets """ def __init__( @@ -486,16 +493,12 @@ def __init__( self.is_inference_cuda_graphed_iteration = False - # torch._grouped_mm requires PyTorch >= 2.10 - self._torch_grouped_mm_available = ( - is_torch_min_version("2.10") - and hasattr(torch, '_grouped_mm') - and not config.inference_disable_torch_grouped_mm - ) - if HAVE_FLASHINFER: self._flashinfer_activation_type = self._resolve_flashinfer_activation_type() + self._mcore_activation_type = self._resolve_mcore_activation_type() + self.inference_grouped_gemm_backend = config.inference_grouped_gemm_backend + def _resolve_flashinfer_activation_type(self): """Map megatron activation config to FlashInfer ActivationType.""" assert ( @@ -512,6 +515,13 @@ def _resolve_flashinfer_activation_type(self): return ActivationType.Relu2 raise ValueError(f"No FlashInfer ActivationType mapping for activation_func={func}") + def _resolve_mcore_activation_type(self): + """Map megatron activation config to mcore_fused_moe ActivationType.""" + func = self.config.activation_func + if func == squared_relu: + return McoreActivationType.SQUARED_RELU + raise ValueError(f"No mcore_fused_moe ActivationType mapping for activation_func={func}") + def set_inference_cuda_graphed_iteration(self): """Enable CUDA-graphed iteration mode.""" self.is_inference_cuda_graphed_iteration = True @@ -520,7 +530,49 @@ def unset_inference_cuda_graphed_iteration(self): """Disable CUDA-graphed 
iteration mode.""" self.is_inference_cuda_graphed_iteration = False - @torch.inference_mode(False) + def _build_concatenated_mxfp8_weights(self): + """Build stacked MXFP8 weight tensors from per-expert MXFP8Tensor attributes. + + After quantize_model_to_mxfp8, each per-expert weight (weight0, weight1, ...) + has been replaced with an MXFP8Tensor. This method stacks their data and + scales into _fc1_weight / _fc2_weight for scaled_grouped_mm. + + Note: this creates a contiguous copy since per-expert MXFP8Tensor attributes + are not contiguous across experts. This is a one-time cost at first forward. + + Unlike _build_concatenated_weights, this does not create nn.Parameter views + back into the buffer — MXFP8 weights are not nn.Parameters (they are plain + MXFP8Tensor attributes set by quantize_model_to_mxfp8). This path is only + intended for non-colocated inference. + """ + + for linear_name, buf_name in [('linear_fc1', '_fc1_weight'), ('linear_fc2', '_fc2_weight')]: + linear = getattr(self, linear_name) + q_list, s_list = [], [] + for i in range(self.num_local_experts): + w = getattr(linear, f'weight{i}') + if isinstance(w, MXFP8Tensor): + mxfp8 = w + elif hasattr(w, 'data') and isinstance(w.data, MXFP8Tensor): + mxfp8 = w.data + else: + raise RuntimeError( + f"Expected MXFP8Tensor for {linear_name}.weight{i}, " + f"got {type(w).__name__}. Was quantize_model_to_mxfp8 called?" + ) + q_list.append(mxfp8.data) + s_list.append(mxfp8.scale) + + setattr( + self, + buf_name, + MXFP8Tensor( + data=torch.stack(q_list, dim=0).contiguous(), + scale=torch.stack(s_list, dim=0).contiguous(), + ), + ) + + @torch.inference_mode(False) # needed for non-colocated inference. def _build_concatenated_weights(self): """Create big contiguous weight tensors that share storage with TE's per-expert parameters. 
@@ -533,7 +585,7 @@ def _build_concatenated_weights(self): This allows: - TE's forward to work correctly (same Parameter objects, same internal state) - Training updates to flow through (param.data is a view into the big tensor) - - torch._grouped_mm / FlashInfer to use the big tensor directly + - torch.nn.functional.grouped_mm / FlashInfer to use the big tensor directly """ # Get device/dtype from existing TE weights device = self.linear_fc1.weight0.device @@ -582,45 +634,25 @@ def _flashinfer_forward(self, hidden_states, routing_map, probs): )[0] return output, None - def _torch_grouped_mm_forward( - self, permuted_local_hidden_states, tokens_per_expert, permuted_probs + def _mcore_fused_moe_forward( + self, hidden_states, probs, routing_map=None, tokens_per_expert=None, skip_permute=False ): - permuted_probs = permuted_probs.unsqueeze(-1) - if not tokens_per_expert.is_cuda: - tokens_per_expert = tokens_per_expert.to('cuda') - - if self.config.moe_apply_probs_on_input: - assert ( - self.config.moe_router_topk == 1 - ), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1." - original_dtype = permuted_local_hidden_states.dtype - permuted_local_hidden_states = permuted_probs * permuted_local_hidden_states - permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype) - permuted_probs = torch.ones_like(permuted_probs) - - if permuted_local_hidden_states.nelement() != 0: - # Use pre-concatenated weights (built during init/load) - # _fc1_weight shape: [num_experts, ffn_hidden * (2 if gated else 1), hidden_size] - # _fc2_weight shape: [num_experts, hidden_size, ffn_hidden] - # Compute cumulative offsets on GPU (no host sync!) 
- # offs[i] = end index of expert i's tokens - offs = tokens_per_expert.cumsum(0).to(torch.int32) - - fc1_output = torch._grouped_mm( - permuted_local_hidden_states, self._fc1_weight.transpose(1, 2), offs=offs - ) - - # Activation with routing probabilities - bias_act_output = self.bias_act_func(fc1_output, None, permuted_probs) - - fc2_output = torch._grouped_mm( - bias_act_output, self._fc2_weight.transpose(1, 2), offs=offs - ) - else: - # No tokens allocated - return empty tensor with correct shape - fc2_output = permuted_local_hidden_states - - return fc2_output, None + """Torch grouped_mm fused MoE forward via mcore_fused_moe.""" + local_expert_start = self.ep_group.rank() * self.num_local_experts + output = mcore_fused_moe( + hidden_states, + probs, + self._fc1_weight, + self._fc2_weight, + activation_type=self._mcore_activation_type, + num_local_experts=self.num_local_experts, + local_expert_start=local_expert_start, + routing_map=routing_map, + tokens_per_expert=tokens_per_expert, + skip_permute=skip_permute, + disable_fused_quant_kernels=self.config.inference_moe_disable_fused_quant_kernels, + ) + return output, None def forward( self, @@ -635,7 +667,7 @@ def forward( - Inference + CUDA graphed: FlashInfer cutlass_fused_moe. tokens_per_expert is not used in this path; the FlashInfer kernel operates directly on routing_map. - - Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets. + - Inference + eager: torch.nn.functional.grouped_mm with GPU-resident cumsum offsets. Args: permuted_local_hidden_states: [num_tokens, hidden_size] input hidden states. @@ -647,29 +679,42 @@ def forward( """ if self.training: + assert ( + not self.config.fp8_recipe == "mxfp8" + ), "MXFP8 inference optimized is not compatible with training / colocated RL." 
return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) # Lazily build concatenated weights on first forward (after checkpoint load) if not self._concatenated_weights_built: - assert ( - not self.training - ), "Concatenated weights must be built before training forward pass." - self._build_concatenated_weights() + if self.config.fp8_recipe == "mxfp8": + self._build_concatenated_mxfp8_weights() + else: + self._build_concatenated_weights() self._concatenated_weights_built = True - if self.is_inference_cuda_graphed_iteration: + resolved_backend = resolve_inference_grouped_gemm_backend( + self.inference_grouped_gemm_backend, + self.is_inference_cuda_graphed_iteration, + is_mxfp8=self.config.fp8_recipe == "mxfp8", + ) + + if resolved_backend == InferenceGroupedGemmBackend.FLASHINFER: assert routing_map is not None, "routing_map is required for FlashInfer forward pass." assert ( - HAVE_FLASHINFER - ), "FlashInfer is not available; cannot use FlashInfer forward pass." + self.is_inference_cuda_graphed_iteration + ), "FlashInfer forward path is only used in CUDA-graphed inference iterations." 
return self._flashinfer_forward( permuted_local_hidden_states, routing_map, permuted_probs ) - elif self._torch_grouped_mm_available: - return self._torch_grouped_mm_forward( - permuted_local_hidden_states, tokens_per_expert, permuted_probs + elif resolved_backend == InferenceGroupedGemmBackend.TORCH: + return self._mcore_fused_moe_forward( + permuted_local_hidden_states, + permuted_probs, + routing_map=routing_map, + tokens_per_expert=tokens_per_expert, + skip_permute=(not self.is_inference_cuda_graphed_iteration), ) - else: + elif resolved_backend == InferenceGroupedGemmBackend.TE: return super().forward(permuted_local_hidden_states, tokens_per_expert, permuted_probs) diff --git a/megatron/core/transformer/moe/moe_layer.py b/megatron/core/transformer/moe/moe_layer.py index 8277486b03b..0cbf707af30 100644 --- a/megatron/core/transformer/moe/moe_layer.py +++ b/megatron/core/transformer/moe/moe_layer.py @@ -268,14 +268,23 @@ def __init__( # Inference-optimized mode setup if config.transformer_impl == "inference_optimized": - assert ( - HAVE_FLASHINFER - ), "flashinfer-python is required for inference-optimized MoE implementation." - if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: - warnings.warn( - "flashinfer-cubin and/or flashinfer-jit-cache not found. " - "The FlashInfer cutlass kernel will be JIT compiled," - "which may take a long time." + if config.inference_grouped_gemm_backend == 'auto': + assert HAVE_FLASHINFER, ( + "inference_grouped_gemm_backend='auto'" + "requires flashinfer-python. " + "Install flashinfer-python or set " + "inference_grouped_gemm_backend to 'torch' or 'te'." + ) + if not HAVE_FLASHINFER_CUBIN_AND_JIT_CACHE: + warnings.warn( + "flashinfer-cubin and/or flashinfer-jit-cache not found. " + "The FlashInfer cutlass kernel will be JIT compiled," + "which may take a long time." 
+ ) + elif config.inference_grouped_gemm_backend == 'torch': + assert hasattr(torch.nn.functional, 'grouped_mm'), ( + "inference_grouped_gemm_backend='torch' requires " + "torch.nn.functional.grouped_mm (available since PyTorch 2.10)." ) self._setup_inference_mode(pg_collection) diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py index 88a88f11e06..b3b7b06b1b8 100644 --- a/megatron/core/transformer/moe/token_dispatcher_inference.py +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -305,7 +305,11 @@ def token_combine(self, hidden_states): output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) # Check output only: if output is 16-byte divisible, input (world_size * output) is too. - nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible(output) + nvls_eligible = ( + self.triton_nvls_kernels_allowed + and output.dtype in (torch.bfloat16, torch.float32) + and are_tensors_nvls_eligible(output) + ) rs_buffer = None if nvls_eligible: @@ -319,10 +323,10 @@ def token_combine(self, hidden_states): # Use latency-optimized NVLS reduce-scatter multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) - return output + return output.to(torch.bfloat16) else: # Fallback to NCCL hidden_states = reduce_scatter_to_sequence_parallel_region( hidden_states, group=self.tp_ep_group ) - return hidden_states + return hidden_states.to(torch.bfloat16) diff --git a/megatron/core/transformer/transformer_config.py b/megatron/core/transformer/transformer_config.py index 11c60742742..c202e6e7800 100644 --- a/megatron/core/transformer/transformer_config.py +++ b/megatron/core/transformer/transformer_config.py @@ -916,9 +916,23 @@ class TransformerConfig(ModelParallelConfig): inference_disable_triton_nvls_kernels: bool = False """ If true, disables the use of Triton NVLS kernels during inference. 
""" - inference_disable_torch_grouped_mm: bool = False - """ If true, disables torch._grouped_mm in InferenceGroupedMLP, - falling back to TE GroupedGEMM. """ + inference_grouped_gemm_backend: Literal['auto', 'torch', 'te'] = "auto" + """Specifies the backend to use for grouped GEMM operations during inference. + Options: + - 'auto': Uses FlashInfer for CUDA-graphed iterations (requires flashinfer-python), + and torch.nn.functional.grouped_mm for non-CUDA-graphed iterations (falls back to TE + if unavailable). Note: the heuristic for choosing backends in 'auto' mode may change + in future releases. + - 'torch': Uses torch.nn.functional.grouped_mm. For CUDA-graphed iterations, uses + mcore_fused_moe (permute/unpermute + grouped_mm with Triton kernels). + - 'te': Uses TE GroupedGEMM only. Not supported with CUDA graphs. + """ + + inference_moe_disable_fused_quant_kernels: bool = False + """When False (default), use fused kernels that combine permute/activation with + MXFP8 quantization + swizzle into a single kernel launch. Only applies when + fp8_recipe='mxfp8'. Set to True to disable fusion and use separate kernel + launches (useful for debugging).""" mrope_section: Optional[List[int]] = None """ Multimodal rope section is for channel dimension of temporal, height and width @@ -1160,17 +1174,32 @@ def __post_init__(self): ) if self.moe_router_dtype != "fp32": raise ValueError( - "Inference-optimized MoE requires --moe-router-dtype=fp32 " + "--transformer-impl='inference_optimized' requires --moe-router-dtype=fp32 " "to avoid costly dtype conversions during decode." ) - if self.gated_linear_unit and self.cuda_graph_impl != "none": + + if self.gated_linear_unit and self.cuda_graph_impl == "local": raise ValueError( - "Inference-optimized MoE does not yet support CUDA graphs with gated " - "linear units (SwiGLU/GeGLU) due to differences in weight layouts " - "between the FlashInfer kernel and mcore. 
Either disable CUDA graphs " - "(--cuda-graph-impl=none) or use a non-gated activation (e.g. squared_relu)." + "--transformer-impl='inference_optimized' does not yet support CUDA graphs " + "with gated linear units (SwiGLU/GeGLU) due to differences in weight " + "layouts between the FlashInfer kernel and mcore. Either disable CUDA " + "graphs (--cuda-graph-impl=none) or use a non-gated activation " + "(e.g. squared_relu)." ) + assert self.inference_grouped_gemm_backend in ('auto', 'torch', 'te'), ( + f"inference_grouped_gemm_backend must be 'auto', 'torch', or 'te', " + f"got '{self.inference_grouped_gemm_backend}'" + ) + + if self.cuda_graph_impl == "local": + if self.inference_grouped_gemm_backend == "te": + raise ValueError( + "TE GroupedGEMM is not supported with CUDA graphs. Please set " + "inference_grouped_gemm_backend to 'auto' or 'torch', or disable " + "CUDA graphs (--cuda-graph-impl=none)." + ) + if self.num_moe_experts is not None and self.num_moe_experts <= 0: raise ValueError("num_moe_experts must be non-negative.") @@ -2187,12 +2216,6 @@ def __post_init__(self): "for inference_optimized transformer implementation." ) - if self.inference_disable_torch_grouped_mm: - assert self.transformer_impl == "inference_optimized", ( - "inference_disable_torch_grouped_mm is only supported " - "for inference_optimized transformer implementation." 
- ) - if self.batch_invariant_mode: assert ( self.attention_backend == AttnBackend.flash diff --git a/megatron/inference/utils.py b/megatron/inference/utils.py index ec8f1088be1..d4c8cc1fca9 100644 --- a/megatron/inference/utils.py +++ b/megatron/inference/utils.py @@ -69,8 +69,17 @@ def get_model_for_inference() -> MegatronModule: model.eval() if args.transformer_impl == "inference_optimized" and args.fp8_recipe == "mxfp8": - quantize_model_to_mxfp8(unwrap_model(model)) - + backend = args.inference_grouped_gemm_backend + if backend == "auto": + quant_backend = "flashinfer" + elif backend == "torch": + quant_backend = "triton" + elif backend == "te": + raise ValueError( + "MXFP8 quantization is not supported with " + "inference_grouped_gemm_backend='te'." + ) + quantize_model_to_mxfp8(unwrap_model(model), backend=quant_backend) return model @@ -351,6 +360,7 @@ def get_inference_config_from_model_and_args(model: MegatronModule, args): metrics_writer=metrics_writer, logging_step_interval=args.inference_logging_step_interval, num_speculative_tokens=args.num_speculative_tokens, + use_synchronous_zmq_collectives=args.inference_use_synchronous_zmq_collectives, ) diff --git a/megatron/training/arguments.py b/megatron/training/arguments.py index 0420873124c..4d9b0d40356 100644 --- a/megatron/training/arguments.py +++ b/megatron/training/arguments.py @@ -960,8 +960,12 @@ def validate_args(args, defaults={}): "--no-check-for-nan-in-loss-and-grad should be set with --cuda-graph-scope=full_iteration for training. 
Note: If you are trying to use full_iteration CUDA graphs for inference, please use --cuda-graph-scope full_iteration_inference instead" if args.cuda_graph_impl == "local" and CudaGraphScope.full_iteration_inference in args.cuda_graph_scope: - assert args.fp8 is None, \ - "fp8 is not supported with inference dynamic batching and full_iteration_inference CUDA graph" + if args.fp8 is not None: + assert args.transformer_impl == "inference_optimized", \ + "fp8 with full_iteration_inference CUDA graphs is only supported with " \ + "--transformer-impl=inference_optimized" + assert args.fp8_recipe == "mxfp8", \ + "Only --fp8-recipe=mxfp8 is supported with full_iteration_inference CUDA graphs" if args.cuda_graph_impl == 'local': assert args.inference_dynamic_batching_num_cuda_graphs > 0 or args.inference_dynamic_batching_num_cuda_graphs == -1, \ @@ -1856,6 +1860,8 @@ def _add_inference_args(parser): group.add_argument('--mamba-inference-ssm-states-dtype', type=str, choices=['bf16', 'fp16', 'fp32'], default='bf16', help='Dtype for the Mamba inference SSM states tensor') + group.add_argument('--inference-use-synchronous-zmq-collectives', action=argparse.BooleanOptionalAction, + required=False, default=False, help='Use synchronous ZMQ collectives for inference. 
Helps in reducing performance variability for MoEs.') return parser diff --git a/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py b/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py index 4110d2f69c4..33a42499203 100644 --- a/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py +++ b/tests/unit_tests/inference/test_data_parallel_inference_coordinator.py @@ -16,6 +16,7 @@ from megatron.core.inference.data_parallel_inference_coordinator import ( DataParallelInferenceCoordinator, ) +from megatron.core.inference.engines.async_zmq_communicator import AsyncZMQCommunicator from megatron.core.inference.engines.dynamic_engine import ( DynamicInferenceEngine, EngineState, @@ -118,6 +119,12 @@ def __init__(self): self.step_start_event = unittest.mock.MagicMock() self.step_end_event = unittest.mock.MagicMock() + # ZMQ-based world barrier (async-friendly, no NCCL). + self.zmq_context = zmq.Context() + total_world_size = torch.distributed.get_world_size() + self.world_zmq_communicator = AsyncZMQCommunicator(self.zmq_context, process_group=None) + self.use_synchronous_zmq_collectives = False + async def run_engine_with_coordinator(self, *, loop=None): """Override to bypass @trace_async_exceptions for testability. @@ -261,6 +268,20 @@ def initialize_model_parallel(request, monkeypatch): Utils.destroy_model_parallel() +@pytest.fixture +def test_case_communicator(): + """A separate ZMQ communicator for test sync barriers. + + Use this instead of engine._world_barrier() when the engine loop may be + calling _world_barrier() concurrently (e.g. during state transitions). + """ + ctx = zmq.Context() + comm = AsyncZMQCommunicator(ctx, process_group=None) + yield comm + comm.close() + ctx.term() + + @pytest.fixture(scope="class") def coordinator(): """Launch a single coordinator process for the entire test class. 
@@ -337,7 +358,9 @@ def build_requests(self, num_requests=NUM_REQUESTS, num_tokens=NUM_TOKENS): ], indirect=["initialize_model_parallel"], ) - async def test_parallel_configs(self, initialize_model_parallel, coordinator): + async def test_parallel_configs( + self, initialize_model_parallel, coordinator, test_case_communicator + ): """Test coordinator with various TP, PP, and EP configurations.""" dp_addr = coordinator port = int(dp_addr.rsplit(":", 1)[-1]) @@ -350,9 +373,7 @@ async def test_parallel_configs(self, initialize_model_parallel, coordinator): ) # Ensure all engines are registered before submitting requests. - await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), timeout=30.0 - ) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) client = None try: @@ -370,10 +391,7 @@ async def test_parallel_configs(self, initialize_model_parallel, coordinator): for result in results: assert result["status"] == Status.COMPLETED.name - await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), - timeout=30.0, - ) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) finally: await cleanup_engine(engine, client) @@ -381,7 +399,9 @@ async def test_parallel_configs(self, initialize_model_parallel, coordinator): @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") @pytest.mark.asyncio @pytest.mark.parametrize("deserialize", [True, False], ids=["deserialize", "raw"]) - async def test_deserialize_flag(self, initialize_model_parallel, coordinator, deserialize): + async def test_deserialize_flag( + self, initialize_model_parallel, coordinator, test_case_communicator, deserialize + ): """Test that the correct response type is returned based on the deserialize flag.""" dp_addr = coordinator port = int(dp_addr.rsplit(":", 1)[-1]) @@ -393,9 +413,7 @@ async def test_deserialize_flag(self, initialize_model_parallel, 
coordinator, de ) # Ensure all engines are registered before submitting requests. - await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), timeout=30.0 - ) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) client = None try: @@ -414,10 +432,7 @@ async def test_deserialize_flag(self, initialize_model_parallel, coordinator, de else: assert isinstance(result, dict) - await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), - timeout=30.0, - ) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) finally: await cleanup_engine(engine, client) @@ -429,7 +444,9 @@ async def test_deserialize_flag(self, initialize_model_parallel, coordinator, de [pytest.param((2, 2, 2), id="tp2-pp2-ep2")], indirect=["initialize_model_parallel"], ) - async def test_control_logic_lifecycle(self, initialize_model_parallel, coordinator): + async def test_control_logic_lifecycle( + self, initialize_model_parallel, coordinator, test_case_communicator + ): """Comprehensive lifecycle test for the engine state machine.""" # States where paused stays set: once set during PAUSE, it's only cleared by UNPAUSE. PAUSED_FAMILY = { @@ -472,10 +489,8 @@ def assert_state(eng, expected): ) # Synchronize all ranks so every engine has registered. - await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), - timeout=30.0, - ) + # Use test_case_communicator to avoid colliding with engine-internal barriers. + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) if rank == 0: client = InferenceClient(dp_addr) @@ -592,10 +607,7 @@ def assert_state(eng, expected): ] # Synchronize all ranks before STOP. 
- await asyncio.wait_for( - asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), - timeout=30.0, - ) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) if rank == 0: # Verify doomed futures are still pending. @@ -617,7 +629,7 @@ def assert_state(eng, expected): @pytest.mark.internal @pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test") @pytest.mark.asyncio - async def test_throughput(self, initialize_model_parallel, coordinator): + async def test_throughput(self, initialize_model_parallel, coordinator, test_case_communicator): """Throughput benchmark: measures ZMQ packet rate.""" _, dp, _, _, _ = initialize_model_parallel num_requests = 10**3 @@ -633,7 +645,7 @@ async def test_throughput(self, initialize_model_parallel, coordinator): ) # Ensure all engines are registered before submitting requests. - await asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) client = None try: @@ -654,6 +666,6 @@ async def test_throughput(self, initialize_model_parallel, coordinator): f"ZMQ throughput: {total / elapsed_ms:.2f} requests/ms " f"({total} reqs in {elapsed_ms:.0f} ms)" ) - await asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier) + await asyncio.wait_for(test_case_communicator.all_reduce_max(1), timeout=30.0) finally: await cleanup_engine(engine, client, timeout=60.0) diff --git a/tests/unit_tests/inference/test_moe_permute.py b/tests/unit_tests/inference/test_moe_permute.py new file mode 100644 index 00000000000..4664d0fa2cd --- /dev/null +++ b/tests/unit_tests/inference/test_moe_permute.py @@ -0,0 +1,446 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for megatron.core.inference.moe.permute. 
+ +Tests cover: +- compute_local_tokens_per_expert: token counting against PyTorch reference +- compute_expert_offsets: prefix sums with and without alignment +- permute_tokens: expert grouping, data integrity, alignment padding +- unpermute_tokens: weighted scatter-back, fp32 accumulation +- permute -> unpermute roundtrip +""" + +import pytest +import torch + + +def _ref_tokens_per_expert(routing_map, local_expert_start, num_local_experts): + """PyTorch reference for compute_local_tokens_per_expert.""" + counts = torch.zeros(num_local_experts, dtype=torch.int32, device=routing_map.device) + for eid in routing_map.flatten(): + lid = eid.item() - local_expert_start + if 0 <= lid < num_local_experts: + counts[lid] += 1 + return counts + + +def _ref_expert_offsets(tokens_per_expert, alignment): + """PyTorch reference for compute_expert_offsets.""" + aligned = tokens_per_expert.clone().to(torch.int32) + for i in range(len(aligned)): + if aligned[i] > 0 and alignment > 1: + aligned[i] = ((aligned[i] + alignment - 1) // alignment) * alignment + inc = torch.cumsum(aligned, dim=0) + exc = inc - aligned + return exc.to(torch.int32), inc.to(torch.int32) + + +def _make_inputs(num_tokens, hidden_dim, topk, num_experts, seed=42): + """Create random hidden states, probs, and routing_map.""" + torch.manual_seed(seed) + hidden = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + return hidden, probs, routing_map + + +@pytest.mark.internal +class TestComputeLocalTokensPerExpert: + + @pytest.mark.parametrize("num_tokens", [1, 4, 16, 64, 128, 256, 512]) + @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + @pytest.mark.parametrize( + "num_experts,num_local,start", + [ + (4, 4, 0), # all local, small expert count + (8, 8, 0), # all local (EP=1) + (8, 4, 0), # first half local (EP=2, rank 0) + (8, 4, 4), # 
second half local (EP=2, rank 1) + (8, 2, 2), # middle slice (EP=4, rank 1) + (8, 1, 7), # single expert local (EP=8, last rank) + (32, 8, 0), # 32 experts, first 8 local + (32, 8, 24), # 32 experts, last 8 local + (128, 32, 0), # 128 experts, first 32 local (EP=4, rank 0) + (128, 32, 96), # 128 experts, last 32 local (EP=4, rank 3) + ], + ) + def test_matches_reference(self, num_tokens, topk, num_experts, num_local, start): + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + result = compute_local_tokens_per_expert(routing_map, start, num_local) + expected = _ref_tokens_per_expert(routing_map, start, num_local) + torch.testing.assert_close(result, expected, atol=0, rtol=0) + + def test_no_local_tokens(self): + """All tokens routed to non-local experts -> all zeros.""" + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + routing_map = torch.full((16, 4), 99, dtype=torch.int64, device="cuda") + result = compute_local_tokens_per_expert(routing_map, 0, 8) + assert result.sum().item() == 0 + + def test_single_expert_all_tokens(self): + """All token-topk pairs route to a single local expert.""" + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + num_tokens, topk, num_local = 32, 4, 8 + routing_map = torch.full((num_tokens, topk), 3, dtype=torch.int64, device="cuda") + result = compute_local_tokens_per_expert(routing_map, 0, num_local) + assert result[3].item() == num_tokens * topk + assert result.sum().item() == num_tokens * topk + + @pytest.mark.parametrize("seed", [0, 7, 42, 123, 999]) + def test_total_count_equals_local_pairs(self, seed): + """Sum of tokens_per_expert equals total routing pairs hitting local experts.""" + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + torch.manual_seed(seed) + num_tokens, topk, num_experts = 64, 6, 16 + local_start, 
num_local = 4, 4 + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + result = compute_local_tokens_per_expert(routing_map, local_start, num_local) + local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) + assert result.sum().item() == local_mask.sum().item() + + +@pytest.mark.internal +class TestComputeExpertOffsets: + + @pytest.mark.parametrize("alignment", [1, 8, 16, 32, 64, 128]) + @pytest.mark.parametrize( + "tpe_values", + [ + [5, 0, 12, 3, 0, 7, 1, 20], + [1, 1, 1, 1], + [0, 0, 0, 0], + [100, 0, 0, 50], + [1], + [33, 33, 33, 33, 33, 33, 33, 33], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [127, 0, 129, 0, 1, 0, 255, 0], + ], + ) + def test_matches_reference(self, alignment, tpe_values): + from megatron.core.inference.moe.permute import compute_expert_offsets + + tpe = torch.tensor(tpe_values, dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=alignment) + ref_exc, ref_inc = _ref_expert_offsets(tpe, alignment) + torch.testing.assert_close(exc, ref_exc, atol=0, rtol=0) + torch.testing.assert_close(inc, ref_inc, atol=0, rtol=0) + + @pytest.mark.parametrize("n_experts", [1, 2, 4, 8, 16, 32, 64, 128]) + def test_exclusive_starts_at_zero(self, n_experts): + from megatron.core.inference.moe.permute import compute_expert_offsets + + tpe = torch.randint(1, 50, (n_experts,), dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=1) + assert exc[0].item() == 0 + assert inc[-1].item() == tpe.sum().item() + + def test_zero_experts_skipped(self): + """Experts with 0 tokens should not consume any aligned space.""" + from megatron.core.inference.moe.permute import compute_expert_offsets + + tpe = torch.tensor([0, 5, 0, 3], dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=32) + # Expert 0: 0 tokens -> 0 aligned -> exc=0, inc=0 + assert exc[0].item() == 0 + assert inc[0].item() == 0 + # Expert 1: 5 
tokens -> 32 aligned -> exc=0, inc=32 + assert exc[1].item() == 0 + assert inc[1].item() == 32 + # Expert 2: 0 tokens -> exc=32, inc=32 + assert exc[2].item() == 32 + assert inc[2].item() == 32 + + @pytest.mark.parametrize("alignment", [16, 32, 128]) + def test_all_offsets_aligned(self, alignment): + """Every inclusive offset should be a multiple of alignment.""" + from megatron.core.inference.moe.permute import compute_expert_offsets + + tpe = torch.tensor([3, 7, 0, 15, 1, 0, 50, 2], dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=alignment) + for i in range(len(tpe)): + assert ( + inc[i].item() % alignment == 0 + ), f"inc[{i}]={inc[i].item()} not aligned to {alignment}" + assert ( + exc[i].item() % alignment == 0 + ), f"exc[{i}]={exc[i].item()} not aligned to {alignment}" + + +class TestPermuteTokens: + + @pytest.mark.parametrize( + "num_tokens,hidden_dim,topk,num_experts", + [ + (1, 64, 1, 4), + (1, 128, 8, 8), + (4, 64, 2, 4), + (16, 128, 2, 8), + (32, 64, 4, 8), + (64, 256, 6, 8), + (128, 128, 8, 128), + (256, 64, 2, 32), + (512, 128, 6, 16), + ], + ) + def test_data_integrity(self, num_tokens, hidden_dim, topk, num_experts): + """Every permuted row matches the original token's hidden state.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, hidden_dim, topk, num_experts) + perm_h, perm_p, perm_map, offs = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=1 + ) + + # Check every non-padding row + for i in range(perm_map.shape[0]): + src = perm_map[i].item() + if src < 0: + continue + torch.testing.assert_close( + perm_h[i], hidden[src], msg=f"Row {i} (src={src}) hidden mismatch" + ) + + @pytest.mark.parametrize("alignment", [1, 16, 32, 64, 128]) + @pytest.mark.parametrize("num_tokens,topk,num_experts", [(16, 2, 4), (64, 4, 8), (128, 8, 32)]) + def test_offsets_are_aligned(self, alignment, num_tokens, topk, num_experts): + 
"""Inclusive offsets are multiples of alignment.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 128, topk, num_experts) + _, _, _, offs = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=alignment + ) + if alignment > 1: + for i in range(offs.shape[0]): + assert ( + offs[i].item() % alignment == 0 + ), f"Offset {i}={offs[i].item()} not aligned to {alignment}" + + @pytest.mark.parametrize( + "num_tokens,topk,num_experts,alignment", + [(8, 2, 4, 128), (32, 2, 4, 128), (16, 4, 8, 64), (64, 6, 8, 32)], + ) + def test_padding_rows_have_neg1(self, num_tokens, topk, num_experts, alignment): + """Padding rows in permutation_map are -1.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) + _, _, perm_map, _ = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=alignment + ) + padding_mask = perm_map == -1 + real_mask = perm_map >= 0 + assert padding_mask.sum() > 0, "Expected some padding rows with large alignment" + assert real_mask.sum() > 0, "Expected some real rows" + + @pytest.mark.parametrize( + "num_tokens,topk,num_experts", [(16, 2, 4), (32, 4, 8), (64, 6, 16), (128, 8, 128)] + ) + @pytest.mark.parametrize("alignment", [1, 32, 128]) + def test_total_real_rows_equals_routed_pairs(self, num_tokens, topk, num_experts, alignment): + """Number of non-padding rows equals total (token, topk) pairs routed locally.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) + _, _, perm_map, _ = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=alignment + ) + real_count = (perm_map >= 0).sum().item() + # All experts are local, so every pair should appear + assert real_count == num_tokens * topk + + @pytest.mark.parametrize( + 
"num_tokens,topk,num_experts,local_start,num_local", + [ + (64, 4, 8, 2, 3), # experts 2, 3, 4 + (64, 4, 8, 0, 1), # only expert 0 + (64, 4, 8, 7, 1), # only expert 7 + (128, 6, 16, 4, 8), # experts 4-11 + (32, 2, 32, 16, 16), # second half of 32 + (256, 8, 128, 0, 32), # first 32 of 128 + ], + ) + def test_expert_subset(self, num_tokens, topk, num_experts, local_start, num_local): + """Only tokens routed to local experts appear in output.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) + _, _, perm_map, _ = permute_tokens( + hidden, probs, routing_map, local_start, num_local, alignment=1 + ) + real_count = (perm_map >= 0).sum().item() + local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) + expected_count = local_mask.sum().item() + assert real_count == expected_count + + @pytest.mark.parametrize("hidden_dim", [32, 64, 128, 256, 512, 1024, 2688]) + def test_various_hidden_dims(self, hidden_dim): + """Permute works across various hidden dimensions including non-power-of-2.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(32, hidden_dim, 4, 8) + perm_h, _, perm_map, _ = permute_tokens(hidden, probs, routing_map, 0, 8, alignment=1) + # Spot-check first real row + for i in range(perm_map.shape[0]): + src = perm_map[i].item() + if src >= 0: + torch.testing.assert_close(perm_h[i], hidden[src]) + break + + +@pytest.mark.internal +class TestUnpermuteTokens: + + def test_weighted_scatter(self): + """Unpermute correctly accumulates prob-weighted expert outputs.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + num_tokens, hidden_dim = 4, 8 + # Two entries map to token 0, one to token 2 + expert_output = torch.ones(3, hidden_dim, device="cuda", dtype=torch.bfloat16) + permuted_probs = torch.tensor([0.5, 0.3, 0.7], device="cuda", dtype=torch.float32) + 
perm_map = torch.tensor([0, 0, 2], dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, permuted_probs, perm_map, num_tokens) + + assert result.dtype == torch.float32 + # Token 0: 0.5 * 1.0 + 0.3 * 1.0 = 0.8 + torch.testing.assert_close( + result[0], torch.full((hidden_dim,), 0.8, device="cuda"), atol=1e-5, rtol=1e-5 + ) + # Token 1: untouched -> 0 + torch.testing.assert_close(result[1], torch.zeros(hidden_dim, device="cuda")) + # Token 2: 0.7 * 1.0 = 0.7 + torch.testing.assert_close( + result[2], torch.full((hidden_dim,), 0.7, device="cuda"), atol=1e-5, rtol=1e-5 + ) + + def test_padding_rows_ignored(self): + """Rows with permutation_map == -1 are skipped.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + expert_output = torch.ones(4, 8, device="cuda", dtype=torch.bfloat16) + permuted_probs = torch.ones(4, device="cuda", dtype=torch.float32) + perm_map = torch.tensor([0, -1, -1, 1], dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, permuted_probs, perm_map, 3) + # Only tokens 0 and 1 get values + assert result[0].sum().item() != 0 + assert result[1].sum().item() != 0 + assert result[2].sum().item() == 0 + + @pytest.mark.parametrize("hidden_dim", [8, 64, 128, 256, 512, 2688]) + def test_various_hidden_dims(self, hidden_dim): + """Unpermute works across various hidden dimensions.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + num_tokens = 8 + expert_output = torch.randn(4, hidden_dim, device="cuda", dtype=torch.bfloat16) + permuted_probs = torch.tensor([1.0, 1.0, 1.0, 1.0], device="cuda", dtype=torch.float32) + perm_map = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, permuted_probs, perm_map, num_tokens) + assert result.shape == (num_tokens, hidden_dim) + # First 4 tokens should have values, rest should be zero + for t in range(4): + torch.testing.assert_close(result[t], expert_output[t].float(), 
atol=1e-5, rtol=1e-5) + for t in range(4, num_tokens): + assert result[t].sum().item() == 0 + + @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + def test_multiple_topk_accumulation(self, topk): + """Multiple topk entries for the same token are summed correctly.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + hidden_dim = 32 + # All topk entries point to token 0 + expert_output = torch.ones(topk, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.full((topk,), 0.1, device="cuda", dtype=torch.float32) + perm_map = torch.zeros(topk, dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, probs, perm_map, 1) + expected_val = 0.1 * topk + torch.testing.assert_close( + result[0], torch.full((hidden_dim,), expected_val, device="cuda"), atol=1e-4, rtol=1e-4 + ) + + +@pytest.mark.internal +class TestPermuteUnpermuteRoundtrip: + + @pytest.mark.parametrize( + "num_tokens,hidden_dim,topk,num_experts,alignment", + [ + (1, 64, 1, 4, 1), + (1, 128, 1, 4, 128), + (8, 64, 1, 4, 1), + (16, 64, 2, 4, 1), + (16, 64, 2, 4, 32), + (32, 128, 4, 8, 32), + (32, 128, 4, 8, 128), + (64, 256, 6, 8, 1), + (64, 256, 6, 8, 128), + (128, 128, 8, 32, 1), + (128, 128, 8, 32, 128), + (256, 64, 2, 128, 32), + (64, 2688, 8, 128, 128), # nanov3-like hidden_dim + ], + ) + def test_roundtrip_identity(self, num_tokens, hidden_dim, topk, num_experts, alignment): + """permute -> (identity transform) -> unpermute recovers weighted sum of inputs.""" + from megatron.core.inference.moe.permute import permute_tokens, unpermute_tokens + + torch.manual_seed(42) + hidden = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + + perm_h, perm_p, perm_map, _ = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=alignment + ) + # Pass permuted hidden directly 
through (identity expert) + result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens) + + # Build reference: for each token, sum prob[k] * hidden[token] over topk + ref = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32) + for t in range(num_tokens): + prob_sum = probs[t].sum() + ref[t] = hidden[t].float() * prob_sum + + torch.testing.assert_close(result, ref, atol=1e-2, rtol=1e-2) + + @pytest.mark.parametrize( + "local_start,num_local,num_experts", + [(0, 4, 8), (4, 4, 8), (0, 1, 8), (0, 8, 8), (0, 32, 128), (96, 32, 128)], + ) + def test_roundtrip_with_expert_subset(self, local_start, num_local, num_experts): + """Roundtrip works when only a subset of experts are local.""" + from megatron.core.inference.moe.permute import permute_tokens, unpermute_tokens + + torch.manual_seed(42) + num_tokens, hidden_dim, topk = 64, 128, 4 + hidden = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + + perm_h, perm_p, perm_map, _ = permute_tokens( + hidden, probs, routing_map, local_start, num_local, alignment=32 + ) + result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens) + + # Reference: only accumulate probs for local experts + ref = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32) + for t in range(num_tokens): + local_prob_sum = 0.0 + for k in range(topk): + eid = routing_map[t, k].item() + if local_start <= eid < local_start + num_local: + local_prob_sum += probs[t, k].item() + ref[t] = hidden[t].float() * local_prob_sum + + torch.testing.assert_close(result, ref, atol=1e-2, rtol=1e-2) diff --git a/tests/unit_tests/inference/test_mxfp8_utils.py b/tests/unit_tests/inference/test_mxfp8_utils.py new file mode 100644 index 00000000000..a137dfbc820 --- /dev/null +++ b/tests/unit_tests/inference/test_mxfp8_utils.py @@ -0,0 +1,645 @@ +# Copyright 
(c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for MXFP8 quantization. + +Tests cover: +- mxfp8_quantize (Triton kernel): data and swizzled scales vs PyTorch reference +- MXFP8Tensor.from_bf16: both 'triton' and 'flashinfer' backends +- MXFP8Tensor.scale_2d: reshape correctness +""" + +import pytest +import torch + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), + pytest.mark.internal, +] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +# ────────────────────────────────────────────────────────────────────── +# Reference functions from PyTorch +# https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_quantized.py#L578 +# ────────────────────────────────────────────────────────────────────── + + +def ref_to_mxfp(data_hp: torch.Tensor, block_size: int = 32, format: str = "mxfp8"): + if data_hp.dtype not in (torch.bfloat16, torch.float): + raise AssertionError(f"{data_hp.dtype} is not supported yet") + if data_hp.shape[-1] % block_size != 0: + raise AssertionError( + f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}" + ) + if not data_hp.is_contiguous(): + raise AssertionError("unsupported: data_hp must be contiguous") + + orig_shape = data_hp.shape + data_hp = data_hp.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size) + + max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1) + + data_hp = data_hp.to(torch.float32) + max_abs = max_abs.to(torch.float32) + + if format == "mxfp8": + F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + max_pos = F8E4M3_MAX + elif format == "mxfp4": + F4E2M1_MAX = 6.0 + max_pos = F4E2M1_MAX + + # RCEIL + def _to_mx_rceil( + data_hp: torch.Tensor, max_abs: torch.Tensor, max_pos: float + ) -> tuple[torch.Tensor, torch.Tensor]: + E8M0_EXPONENT_BIAS = 127 + descale = max_abs / max_pos + exponent = torch.where( + torch.isnan(descale), + 0xFF, # Handle biased exponent for nan + ( + 
torch.clamp( + torch.ceil(torch.log2(descale)), min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + ) + + E8M0_EXPONENT_BIAS + ).to(torch.uint8), + ) + + descale_fp = torch.where( + exponent == 0, 1.0, torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)) + ) + + # scale and saturated cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos) + return exponent, data_lp + + scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) + + # cast to target dtype + data_lp = data_lp.to(torch.float8_e4m3fn) + data_lp = data_lp.reshape(orig_shape) + + scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + scale_e8m0_biased = scale_e8m0_biased.squeeze(-1) + return scale_e8m0_biased, data_lp + + +def ref_swizzle(input_matrix) -> torch.Tensor: + """Rearrange a scale matrix into cuBLAS 2D blocked (swizzled) layout. + + See: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + + Returns: + Flattened swizzled tensor. 
+ """ + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype + ) + padded[:rows, :cols] = input_matrix + + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + + +# ────────────────────────────────────────────────────────────────────── +# mxfp8_quantize (Triton kernel) +# ────────────────────────────────────────────────────────────────────── + + +class TestMxfp8Quantize: + + @pytest.mark.parametrize( + "M,K", + [ + (1, 32), + (1, 64), + (1, 128), + (4, 32), + (4, 128), + (16, 64), + (16, 256), + (32, 128), + (64, 256), + (128, 128), + (128, 512), + (128, 2688), # nanov3 hidden_size + (256, 1856), # nanov3 moe_ffn_hidden_size + (512, 2688), + ], + ) + def test_data_matches_reference(self, M, K): + """Quantized FP8 data matches PyTorch reference.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + triton_data, _ = mxfp8_quantize(x) + _, ref_data = ref_to_mxfp(x) + + assert triton_data.shape == (M, K) + assert triton_data.dtype == torch.float8_e4m3fn + torch.testing.assert_close( + triton_data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize( + "M,K", + [ + (1, 32), + (1, 64), + (4, 128), + (16, 256), + (32, 128), + (128, 128), + (128, 512), + (128, 2688), + (256, 1856), + (512, 2688), + ], + ) + def test_scales_match_reference(self, M, K): + """Swizzled scales match ref_to_mxfp scales passed through ref_swizzle.""" + from 
megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + _, triton_scales = mxfp8_quantize(x) + ref_scales_2d, _ = ref_to_mxfp(x) # [M, K//32] e8m0 + + # Swizzle the reference scales + ref_swizzled = ref_swizzle(ref_scales_2d) + + # Compare as uint8 since e8m0 is just exponent bytes + torch.testing.assert_close( + triton_scales.view(torch.uint8), ref_swizzled.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 2688)]) + def test_all_zeros_input(self, M, K): + """All-zero input produces all-zero FP8 data and zero scales.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + x = torch.zeros(M, K, device="cuda", dtype=torch.bfloat16) + data, scales = mxfp8_quantize(x) + assert (data.float() == 0).all() + assert (scales.view(torch.uint8) == 0).all() + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 256)]) + def test_constant_input(self, M, K): + """Constant input: all elements in a group have the same value.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + x = torch.full((M, K), 1.0, device="cuda", dtype=torch.bfloat16) + data, _ = mxfp8_quantize(x) + _, ref_data = ref_to_mxfp(x) + torch.testing.assert_close( + data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_input_dtypes(self, dtype): + """Kernel accepts bf16, fp16, and fp32 inputs.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + x = torch.randn(16, 128, device="cuda", dtype=dtype) + data, _ = mxfp8_quantize(x) + assert data.dtype == torch.float8_e4m3fn + assert data.shape == (16, 128) + + @pytest.mark.parametrize("M", [1, 127, 128, 129, 255, 256, 257, 512]) + def test_various_row_counts(self, M): + """Test row counts that 
are not multiples of 128 (macro tile boundary).""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + K = 128 + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data, _ = mxfp8_quantize(x) + _, ref_data = ref_to_mxfp(x) + torch.testing.assert_close( + data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("seed", [0, 7, 42, 123, 999]) + def test_reproducible(self, seed): + """Same input always produces same output.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + torch.manual_seed(seed) + x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16) + d1, s1 = mxfp8_quantize(x) + d2, s2 = mxfp8_quantize(x) + torch.testing.assert_close(d1.view(torch.uint8), d2.view(torch.uint8), atol=0, rtol=0) + torch.testing.assert_close(s1.view(torch.uint8), s2.view(torch.uint8), atol=0, rtol=0) + + +# ────────────────────────────────────────────────────────────────────── +# MXFP8Tensor +# ────────────────────────────────────────────────────────────────────── + + +class TestMXFP8Tensor: + + @pytest.mark.parametrize("M,K", [(16, 128), (64, 256), (128, 2688)]) + def test_from_bf16_triton(self, M, K): + """from_bf16 with triton backend produces correct data and scales.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + _, ref_data = ref_to_mxfp(x) + + assert tensor.data.shape == (M, K) + assert tensor.data.dtype == torch.float8_e4m3fn + assert tensor.backend == "triton" + torch.testing.assert_close( + tensor.data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(16, 128), (64, 256), (128, 2688)]) + def test_from_bf16_flashinfer(self, M, K): + """from_bf16 with flashinfer backend produces valid output.""" + 
from megatron.core.inference.quantization.mxfp8_tensor import HAVE_FLASHINFER, MXFP8Tensor + + if not HAVE_FLASHINFER: + pytest.skip("FlashInfer not available") + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = MXFP8Tensor.from_bf16(x, backend="flashinfer") + assert tensor.data.shape == (M, K) + assert tensor.data.dtype == torch.float8_e4m3fn + assert tensor.backend == "flashinfer" + + def test_from_bf16_invalid_backend(self): + """from_bf16 with invalid backend raises ValueError.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="Unknown MXFP8 quantization backend"): + MXFP8Tensor.from_bf16(x, backend="invalid") + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 2688), (256, 1856)]) + def test_scale_2d_shape(self, M, K): + """scale_2d returns correct shape: (-1, ceil(K//32, 4)*4).""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + + scale_2d = tensor.scale_2d() + expected_cols = ceil_div(K // 32, 4) * 4 + assert scale_2d.dim() == 2 + assert scale_2d.shape[1] == expected_cols + + @pytest.mark.parametrize("M,K", [(16, 128), (128, 2688)]) + def test_scale_2d_idempotent(self, M, K): + """Calling scale_2d twice returns the same result.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + + s1 = tensor.scale_2d() + s2 = tensor.scale_2d() + torch.testing.assert_close(s1.view(torch.uint8), s2.view(torch.uint8), atol=0, rtol=0) + + def test_size_method(self): + """size() delegates to data.size().""" + from 
megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16) + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + assert tensor.size() == torch.Size([32, 128]) + assert tensor.size(0) == 32 + assert tensor.size(1) == 128 + + +# ────────────────────────────────────────────────────────────────────── +# Triton vs FlashInfer cross-validation +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.skipif( + torch.cuda.get_device_capability()[0] < 10, + reason="MXFP8 FlashInfer comparison requires Blackwell (SM 100+)", +) +class TestTritonVsFlashinfer: + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (64, 256), (128, 2688), (256, 1856)]) + def test_data_matches(self, M, K): + """Triton and FlashInfer backends produce identical FP8 data.""" + from megatron.core.inference.quantization.mxfp8_tensor import HAVE_FLASHINFER, MXFP8Tensor + + if not HAVE_FLASHINFER: + pytest.skip("FlashInfer not available") + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + triton_tensor = MXFP8Tensor.from_bf16(x, backend="triton") + flashinfer_tensor = MXFP8Tensor.from_bf16(x, backend="flashinfer") + + torch.testing.assert_close( + triton_tensor.data.float(), flashinfer_tensor.data.float(), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (64, 256), (128, 2688), (256, 1856)]) + def test_scales_match(self, M, K): + """Triton and FlashInfer backends produce identical swizzled scales.""" + from megatron.core.inference.quantization.mxfp8_tensor import HAVE_FLASHINFER, MXFP8Tensor + + if not HAVE_FLASHINFER: + pytest.skip("FlashInfer not available") + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + triton_tensor = MXFP8Tensor.from_bf16(x, backend="triton") + flashinfer_tensor = MXFP8Tensor.from_bf16(x, backend="flashinfer") + + torch.testing.assert_close( + 
triton_tensor.scale.view(torch.uint8), + flashinfer_tensor.scale.view(torch.uint8), + atol=0, + rtol=0, + ) + + +def _make_permutation_map(M, num_padding=0): + """Create a permutation_map with optional padding rows at the end.""" + real = torch.arange(M - num_padding, dtype=torch.int32, device="cuda") + pad = torch.full((num_padding,), -1, dtype=torch.int32, device="cuda") + return torch.cat([real, pad]) + + +# ────────────────────────────────────────────────────────────────────── +# squared_relu_and_quantize_mxfp8 vs PyTorch reference +# ────────────────────────────────────────────────────────────────────── + + +class TestSquaredReluAndQuantizeMxfp8: + """Compare fused squared_relu + mxfp8 quantize against PyTorch reference. + + Reference: torch.relu(x.float()).pow(2).to(bf16) -> ref_to_mxfp -> ref_swizzle. + The fused kernel computes squared ReLU in fp32 and quantizes to MXFP8 in one pass, + so the PyTorch fp32 reference is the correct baseline (not the unfused Triton path + which has an intermediate bf16 roundtrip). 
+ """ + + @pytest.mark.parametrize( + "M,K", + [ + (1, 32), + (4, 64), + (16, 128), + (32, 256), + (64, 128), + (128, 128), + (128, 256), + (128, 2688), + (256, 1856), + (512, 2688), + ], + ) + def test_data_matches_pytorch_ref(self, M, K): + """Fused FP8 data matches PyTorch squared ReLU + ref_to_mxfp.""" + from megatron.core.inference.moe.activations import squared_relu_and_quantize_mxfp8 + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + perm_map = _make_permutation_map(M, num_padding=0) + + # PyTorch reference: squared ReLU in fp32, then downcast to bf16, then quantize + activated_ref = torch.relu(x.float()).pow(2) + _, ref_data = ref_to_mxfp(activated_ref) + + # Fused kernel + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + + torch.testing.assert_close( + fused_result.data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 128), (128, 2688), (256, 1856)]) + def test_scales_match_pytorch_ref(self, M, K): + """Fused swizzled scales match PyTorch ref_to_mxfp + ref_swizzle.""" + from megatron.core.inference.moe.activations import squared_relu_and_quantize_mxfp8 + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + perm_map = _make_permutation_map(M, num_padding=0) + + # PyTorch reference + activated_ref = torch.relu(x.float()).pow(2) + ref_scales_2d, _ = ref_to_mxfp(activated_ref) + ref_swizzled = ref_swizzle(ref_scales_2d) + + # Fused kernel + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + + torch.testing.assert_close( + fused_result.scale.view(torch.uint8), ref_swizzled.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize( + "M,K,num_padding", + [(32, 128, 8), (64, 256, 16), (128, 128, 32), (128, 2688, 64), (256, 1856, 128)], + ) + def test_real_rows_match_pytorch_ref_with_padding(self, M, K, num_padding): + """Real rows match PyTorch reference even when padding rows are 
present.""" + from megatron.core.inference.moe.activations import squared_relu_and_quantize_mxfp8 + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + perm_map = _make_permutation_map(M, num_padding=num_padding) + + # PyTorch reference (only real rows) + real_rows = M - num_padding + activated_ref = torch.relu(x[:real_rows].float()).pow(2) + _, ref_data = ref_to_mxfp(activated_ref) + + # Fused kernel + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + + torch.testing.assert_close( + fused_result.data[:real_rows].view(torch.uint8), + ref_data.view(torch.uint8), + atol=0, + rtol=0, + ) + + +# ────────────────────────────────────────────────────────────────────── +# permute_and_quantize_mxfp8 +# ────────────────────────────────────────────────────────────────────── + + +class TestPermuteAndQuantizeMxfp8: + """Compare fused permute + mxfp8 quantize against PyTorch reference. + + PyTorch reference: + 1. For each real row, quantize the source token with ref_to_mxfp + 2. Compare FP8 data per source token + Structural checks (permutation_map, probs, offsets) verified independently. 
+ """ + + def _make_inputs(self, num_tokens, K, topk, num_experts, seed=42): + torch.manual_seed(seed) + hidden = torch.randn(num_tokens, K, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + return hidden, probs, routing_map + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts", + [ + (4, 128, 2, 4), + (16, 128, 2, 8), + (32, 256, 4, 8), + (64, 128, 6, 8), + (64, 2688, 8, 128), + (128, 1856, 4, 32), + ], + ) + def test_data_matches_pytorch_ref(self, num_tokens, K, topk, num_experts): + """For each real row, fused FP8 data matches ref_to_mxfp of the source token.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + fused_mxfp8, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, num_experts, alignment=128 + ) + + # For each real row, quantize the source token with PyTorch ref and compare + for i in range(fused_perm_map.shape[0]): + src = fused_perm_map[i].item() + if src < 0: + continue + _, ref_data = ref_to_mxfp(hidden[src].unsqueeze(0)) + torch.testing.assert_close( + fused_mxfp8.data[i].view(torch.uint8), + ref_data.squeeze(0).view(torch.uint8), + atol=0, + rtol=0, + msg=f"Row {i} (src={src}) FP8 data mismatch vs PyTorch ref", + ) + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts", [(16, 128, 2, 8), (32, 256, 4, 8), (64, 2688, 8, 128)] + ) + def test_batch_data_matches_pytorch_ref(self, num_tokens, K, topk, num_experts): + """Batch comparison: gather all real rows, quantize as batch, compare.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + fused_mxfp8, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 
0, num_experts, alignment=128 + ) + + real_mask = fused_perm_map >= 0 + real_indices = real_mask.nonzero(as_tuple=True)[0] + if len(real_indices) == 0: + return + + src_tokens = fused_perm_map[real_indices].long() + permuted_bf16 = hidden[src_tokens] + + _, ref_data = ref_to_mxfp(permuted_bf16) + + torch.testing.assert_close( + fused_mxfp8.data[real_indices].view(torch.uint8), + ref_data.view(torch.uint8), + atol=0, + rtol=0, + ) + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts", [(16, 128, 2, 8), (32, 256, 4, 8), (64, 2688, 8, 128)] + ) + def test_correct_token_count(self, num_tokens, K, topk, num_experts): + """Number of real rows equals total (token, topk) pairs routed to local experts.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + _, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, num_experts, alignment=128 + ) + + real_count = (fused_perm_map >= 0).sum().item() + # All experts are local, so every pair should appear + assert real_count == num_tokens * topk + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts,local_start,num_local", + [(64, 128, 4, 8, 2, 3), (64, 256, 4, 8, 0, 4), (128, 128, 8, 128, 96, 32)], + ) + def test_expert_subset(self, num_tokens, K, topk, num_experts, local_start, num_local): + """Fused kernel correctly handles local expert subsets.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + _, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, local_start, num_local, alignment=128 + ) + + real_count = (fused_perm_map >= 0).sum().item() + local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) + expected_count = local_mask.sum().item() + assert real_count == expected_count + + def 
test_returns_mxfp8_tensor(self): + """Result is an MXFP8Tensor with correct backend.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + hidden, probs, routing_map = self._make_inputs(16, 128, 2, 4) + result, _, _, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, 4, alignment=128 + ) + assert isinstance(result, MXFP8Tensor) + assert result.backend == "triton" + assert result.data.dtype == torch.float8_e4m3fn + + @pytest.mark.parametrize("alignment", [128]) + def test_offsets_aligned(self, alignment): + """Inclusive offsets are multiples of alignment.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(64, 128, 4, 8) + _, _, _, offs = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, 8, alignment=alignment + ) + for i in range(offs.shape[0]): + assert ( + offs[i].item() % alignment == 0 + ), f"Offset {i}={offs[i].item()} not aligned to {alignment}" diff --git a/tests/unit_tests/models/test_mamba_moe_model.py b/tests/unit_tests/models/test_mamba_moe_model.py index c3894a8cb67..02421114fe2 100644 --- a/tests/unit_tests/models/test_mamba_moe_model.py +++ b/tests/unit_tests/models/test_mamba_moe_model.py @@ -282,8 +282,9 @@ "offload_modules": [], "hybrid_context_parallel": False, "max_seqlen_per_dp_cp_rank": None, - "inference_disable_torch_grouped_mm": False, "inference_disable_triton_nvls_kernels": False, + "inference_grouped_gemm_backend": "auto", + "inference_moe_disable_fused_quant_kernels": False, } # Fields to ignore entirely (ephemeral, environment-specific, very large). SKIP_FIELDS = set()