Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 204 additions & 0 deletions megatron/core/transformer/custom_layers/batch_invariant_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"is_batch_invariant_mode_enabled",
"disable_batch_invariant_mode",
"enable_batch_invariant_mode",
"grouped_gemm_batch_invariant",
]


Expand Down Expand Up @@ -234,6 +235,209 @@ def grid(META):
return c


@triton.jit
def _grouped_gemm_batch_invariant_kernel(
    # Pointers
    a_ptr,  # concatenated activations [total_tokens, K]
    b_ptr,  # stacked expert weights; layout abstracted by stride_be/bn/bk
    c_ptr,  # pre-allocated output [total_tokens, N]
    bias_ptr,  # optional bias [E, N]; only dereferenced when HAS_BIAS
    batch_sizes_ptr,  # tokens per expert, length E
    a_offsets_ptr,  # row offset of each expert's first token within a/c, length E
    schedule_ptr,  # flattened [num_tiles, 3] int32 rows: (expert, m_block, n_block)
    # Dimensions
    K,
    N,
    # Strides
    stride_am,
    stride_ak,
    stride_be,
    stride_bn,
    stride_bk,
    stride_cm,
    stride_cn,
    stride_bias_e,
    stride_bias_n,
    # Meta
    num_tiles,
    # Constants
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,  # NOTE(review): not referenced in this kernel body
    NUM_SMS: tl.constexpr,  # NOTE(review): not referenced in this kernel body
    HAS_BIAS: tl.constexpr,
):
    """Batch-invariant grouped GEMM Triton kernel.

    Each tile is pre-assigned to an (expert, m-block, n-block) triple via a
    CPU-built schedule tensor. Persistent-style: program IDs stride over all
    tiles so that a fixed number of SMs can service an arbitrary tile count.
    Because the tile-to-work mapping is fixed by the schedule (not by dynamic
    batching), the reduction order per output tile is deterministic.
    """
    pid = tl.program_id(axis=0)
    # Total launched programs; used as the persistent-loop stride so each
    # tile index is serviced by exactly one program.
    num_pid_groups = tl.num_programs(axis=0)
    idx = pid.to(tl.int64)

    while idx < num_tiles:
        # 1. Fetch schedule entry: (expert_idx, pid_m, pid_n)
        sched_offset = idx * 3
        expert_idx = tl.load(schedule_ptr + sched_offset).to(tl.int64)
        pid_m = tl.load(schedule_ptr + sched_offset + 1).to(tl.int64)
        pid_n = tl.load(schedule_ptr + sched_offset + 2).to(tl.int64)

        # Row count owned by this expert and where its rows start in flat A/C.
        current_expert_m = tl.load(batch_sizes_ptr + expert_idx)
        global_m_start = tl.load(a_offsets_ptr + expert_idx)

        # 2. Compute pointers
        offs_am = (pid_m * BLOCK_M) + tl.arange(0, BLOCK_M)
        offs_bn = (pid_n * BLOCK_N) + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)

        a_ptrs = a_ptr + (
            (global_m_start + offs_am[:, None]) * stride_am + offs_k[None, :] * stride_ak
        )
        b_ptrs = b_ptr + (
            expert_idx * stride_be + offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn
        )

        # 3. Matmul accumulation loop (fp32 accumulator)
        accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

        for k in range(0, tl.cdiv(K, BLOCK_K)):
            # Mask rows past the expert's token count (tail m-block) and the K tail.
            a_mask = (offs_am[:, None] < current_expert_m) & (offs_k[None, :] < (K - k * BLOCK_K))
            b_mask = (offs_k[:, None] < (K - k * BLOCK_K)) & (offs_bn[None, :] < N)

            a = tl.load(a_ptrs, mask=a_mask, other=0.0)
            b = tl.load(b_ptrs, mask=b_mask, other=0.0)

            accumulator = tl.dot(a, b, accumulator)

            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk

        # 4. Optional bias addition (bias is [Experts, N])
        if HAS_BIAS:
            bias_offset = expert_idx * stride_bias_e + offs_bn * stride_bias_n
            bias = tl.load(bias_ptr + bias_offset, mask=offs_bn < N, other=0.0).to(tl.float32)
            accumulator += bias[None, :]

        # 5. Store output, cast back to the output tensor's dtype
        c_val = accumulator.to(c_ptr.dtype.element_ty)

        offs_cm = (pid_m * BLOCK_M) + tl.arange(0, BLOCK_M)
        offs_cn = (pid_n * BLOCK_N) + tl.arange(0, BLOCK_N)

        c_ptrs = c_ptr + (
            (global_m_start + offs_cm[:, None]) * stride_cm + offs_cn[None, :] * stride_cn
        )
        c_mask = (offs_cm[:, None] < current_expert_m) & (offs_cn[None, :] < N)

        tl.store(c_ptrs, c_val, mask=c_mask)

        # Advance to the next tile serviced by this program.
        idx += num_pid_groups


def _build_grouped_gemm_schedule(batch_sizes_cpu, BLOCK_M, BLOCK_N, N, device):
"""Build the (expert, m_block, n_block) tile schedule on CPU.

Returns:
schedule: int32 tensor of shape [num_tiles, 3] on ``device``.
num_tiles: total number of tiles.
"""
m_blocks_per_expert = (batch_sizes_cpu + BLOCK_M - 1) // BLOCK_M
n_blocks = (N + BLOCK_N - 1) // BLOCK_N
num_experts = len(batch_sizes_cpu)

schedule_list = []
for e in range(num_experts):
m_blks = int(m_blocks_per_expert[e])
if m_blks > 0:
ms = torch.arange(m_blks, device='cpu')
ns = torch.arange(n_blocks, device='cpu')
grid_m, grid_n = torch.meshgrid(ms, ns, indexing='ij')
expert_col = torch.full_like(grid_m, e)
schedule_list.append(
torch.stack([expert_col.flatten(), grid_m.flatten(), grid_n.flatten()], dim=1)
)

if not schedule_list:
return None, 0

schedule = torch.cat(schedule_list, dim=0).to(device=device, dtype=torch.int32)
return schedule, schedule.size(0)


def grouped_gemm_batch_invariant(a, b, c, batch_sizes, bias=None, trans_b=False):
    """Launch the batch-invariant grouped GEMM Triton kernel.

    Args:
        a: Concatenated activations, shape ``[total_tokens, K]``.
        b: Stacked expert weights. ``[E, K, N]`` when *trans_b* is False,
            ``[E, N, K]`` when *trans_b* is True.
        c: Pre-allocated output tensor, shape ``[total_tokens, N]``.
        batch_sizes: 1-D tensor of length ``E`` with token counts per expert.
        bias: Optional bias of shape ``[E, N]``.
        trans_b: If True, ``b`` is ``[E, N, K]`` (standard PyTorch Linear layout).

    Returns:
        ``c``, filled in place (returned unchanged when there is no work).
    """
    K = a.size(1)
    if trans_b:
        # b is [E, N, K]: swap the roles of the last two strides.
        N = b.size(1)
        stride_be, stride_bn, stride_bk = b.stride(0), b.stride(1), b.stride(2)
    else:
        # b is [E, K, N].
        N = b.size(2)
        stride_be, stride_bn, stride_bk = b.stride(0), b.stride(2), b.stride(1)

    BLOCK_M, BLOCK_N, BLOCK_K = 128, 128, 64

    # NOTE(review): this device->host copy synchronizes and makes the schedule
    # data-dependent, so this path is incompatible with CUDA graph capture.
    host_batch_sizes = batch_sizes.cpu()
    schedule, num_tiles = _build_grouped_gemm_schedule(
        host_batch_sizes, BLOCK_M, BLOCK_N, N, a.device
    )
    if schedule is None:
        # No expert has any tokens: nothing to launch.
        return c

    # Exclusive prefix sum of token counts = start row of each expert in a/c.
    expert_count = len(batch_sizes)
    a_offsets = torch.zeros(expert_count, device=a.device, dtype=torch.int64)
    if expert_count > 1:
        a_offsets[1:] = torch.cumsum(batch_sizes[:-1], dim=0)

    # Persistent launch: a few waves per SM, never more programs than tiles.
    NUM_SMS = get_compute_units()
    grid_size = min(NUM_SMS * 4, num_tiles)

    if bias is None:
        stride_bias_e = stride_bias_n = 0
    else:
        stride_bias_e, stride_bias_n = bias.stride(0), bias.stride(1)

    _grouped_gemm_batch_invariant_kernel[(grid_size,)](
        a,
        b,
        c,
        bias,
        batch_sizes,
        a_offsets,
        schedule,
        K,
        N,
        a.stride(0),
        a.stride(1),
        stride_be,
        stride_bn,
        stride_bk,
        c.stride(0),
        c.stride(1),
        stride_bias_e,
        stride_bias_n,
        num_tiles,
        BLOCK_M=BLOCK_M,
        BLOCK_N=BLOCK_N,
        BLOCK_K=BLOCK_K,
        GROUP_SIZE_M=8,
        NUM_SMS=NUM_SMS,
        HAS_BIAS=(bias is not None),
        num_warps=4,
        num_stages=3,
    )
    return c


@triton.jit
def _log_softmax_kernel(
input_ptr, output_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr
Expand Down
84 changes: 83 additions & 1 deletion megatron/core/transformer/moe/experts.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,20 @@
except ImportError:
HAVE_FLASHINFER = False

try:
    # Optional dependency: batch-invariant Triton kernels. When available,
    # they enable a deterministic grouped-GEMM inference path.
    from megatron.core.transformer.custom_layers.batch_invariant_kernels import (
        grouped_gemm_batch_invariant,
        is_batch_invariant_mode_enabled,
    )

    HAVE_BATCH_INVARIANT = True
except ImportError:
    HAVE_BATCH_INVARIANT = False

    def is_batch_invariant_mode_enabled():
        """Fallback: batch-invariant mode is never enabled without the kernels."""
        return False

    def grouped_gemm_batch_invariant(*args, **kwargs):
        """Stub that fails loudly instead of raising a bare NameError.

        The batch-invariant forward path is unreachable when
        ``is_batch_invariant_mode_enabled`` returns False, but a direct call
        should produce an actionable error rather than an undefined name.
        """
        raise RuntimeError(
            "grouped_gemm_batch_invariant requires the batch-invariant Triton "
            "kernels, which failed to import."
        )

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -621,19 +635,82 @@ def _torch_grouped_mm_forward(

return fc2_output, None

def _triton_batch_invariant_forward(
self, permuted_local_hidden_states, tokens_per_expert, permuted_probs
):
"""Batch-invariant grouped GEMM forward using Triton kernel.

Provides deterministic results regardless of batch composition by using
a pre-scheduled tile assignment that is independent of dynamic batching.
"""
permuted_probs = permuted_probs.unsqueeze(-1)
if not tokens_per_expert.is_cuda:
tokens_per_expert = tokens_per_expert.to(permuted_local_hidden_states.device)

if self.config.moe_apply_probs_on_input:
assert (
self.config.moe_router_topk == 1
), "`moe_apply_probs_on_input` only works with `moe_router_topk`=1."
original_dtype = permuted_local_hidden_states.dtype
permuted_local_hidden_states = permuted_probs * permuted_local_hidden_states
permuted_local_hidden_states = permuted_local_hidden_states.to(original_dtype)
permuted_probs = torch.ones_like(permuted_probs)

if permuted_local_hidden_states.nelement() != 0:
batch_sizes = tokens_per_expert.to(torch.int64)

# fc1: _fc1_weight is [E, out_features, in_features] (TE layout [N, K])
# Use trans_b=True since weights are [E, N, K]
total_tokens = permuted_local_hidden_states.size(0)
fc1_out_features = self._fc1_weight.size(1)
fc1_output = torch.empty(
total_tokens,
fc1_out_features,
device=permuted_local_hidden_states.device,
dtype=permuted_local_hidden_states.dtype,
)
grouped_gemm_batch_invariant(
permuted_local_hidden_states,
self._fc1_weight,
fc1_output,
batch_sizes,
trans_b=True,
)

# Activation with routing probabilities
bias_act_output = self.bias_act_func(fc1_output, None, permuted_probs)

# fc2: _fc2_weight is [E, out_features, in_features] (TE layout [N, K])
fc2_out_features = self._fc2_weight.size(1)
fc2_output = torch.empty(
total_tokens,
fc2_out_features,
device=bias_act_output.device,
dtype=bias_act_output.dtype,
)
grouped_gemm_batch_invariant(
bias_act_output, self._fc2_weight, fc2_output, batch_sizes, trans_b=True
)
else:
fc2_output = permuted_local_hidden_states

return fc2_output, None

def forward(
self,
permuted_local_hidden_states: torch.Tensor,
tokens_per_expert: Optional[torch.Tensor],
permuted_probs: torch.Tensor,
routing_map: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Forward pass with three modes:
"""Forward pass with four modes:

- Training: delegates to parent TEGroupedMLP.
- Inference + CUDA graphed: FlashInfer cutlass_fused_moe. tokens_per_expert
is not used in this path; the FlashInfer kernel operates directly on
routing_map.
- Inference + batch invariant: Triton grouped GEMM with deterministic
tile scheduling for bitwise-reproducible results.
- Inference + eager: torch._grouped_mm with GPU-resident cumsum offsets.

Args:
Expand All @@ -656,6 +733,11 @@ def forward(
permuted_local_hidden_states, routing_map, permuted_probs
)

elif is_batch_invariant_mode_enabled():
return self._triton_batch_invariant_forward(
permuted_local_hidden_states, tokens_per_expert, permuted_probs
)

elif self._torch_grouped_mm_available:
return self._torch_grouped_mm_forward(
permuted_local_hidden_states, tokens_per_expert, permuted_probs
Expand Down
Loading