diff --git a/examples/gptoss/01_convert_from_hf.py b/examples/gptoss/01_convert_from_hf.py new file mode 100644 index 00000000000..adee3358ec3 --- /dev/null +++ b/examples/gptoss/01_convert_from_hf.py @@ -0,0 +1,55 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Convert HuggingFace checkpoints to Megatron format.""" + +import os +import argparse + +from megatron.bridge import AutoBridge + +def _parse_args(): + parser = argparse.ArgumentParser(description="Convert HF LLMs to Megatron format") + parser.add_argument( + "--hf-model", + type=str, + required=True, + help="HuggingFace model identifier or path", + ) + parser.add_argument( + "--save-path", + type=str, + default=None, + help="Path to save the converted Megatron checkpoint", + ) + parser.add_argument('--local-rank', '--local_rank', type=int, default=0) + return parser.parse_args() + +if __name__ == "__main__": + args = _parse_args() + HF_MODEL = args.hf_model + SAVE_PATH = args.save_path + WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1)) + + if SAVE_PATH is None: + SAVE_PATH = f"./megatron_checkpoints/{HF_MODEL.replace('/', '_')}" + + print(f"Converting {HF_MODEL} to Megatron format...") + print(f"Save path: {SAVE_PATH}") + + bridge = AutoBridge.from_hf_pretrained(HF_MODEL, trust_remote_code=True) + provider = bridge.to_megatron_provider() + # Update these configs as needed + provider.expert_tensor_parallel_size = 1 + provider.tensor_model_parallel_size = 1 + provider.pipeline_model_parallel_size = WORLD_SIZE + provider.finalize() + + model = provider.provide_distributed_model(wrap_with_ddp=False) + + bridge.save_megatron_model( + model, + SAVE_PATH, + hf_tokenizer_path=HF_MODEL + ) + + print(f"Saved Megatron checkpoint to {SAVE_PATH}") diff --git a/examples/gptoss/02_train.sh b/examples/gptoss/02_train.sh new file mode 100644 index 00000000000..51369d8005d --- /dev/null +++ b/examples/gptoss/02_train.sh @@ -0,0 +1,259 @@ +#!/bin/bash + +export 
CUDA_DEVICE_MAX_CONNECTIONS=${CUDA_DEVICE_MAX_CONNECTIONS:-1} + + +# Setup arguments with defaults +CHECKPOINT_PATH="NO_VALUE_PROVIDED" +TENSORBOARD_LOGS_PATH="./tensorboard_logs/" +TOKENIZER_ARG="MOCK" +DATA_ARG="MOCK" +DISTRIBUTED_CONFIG_FILE="" + +# Parse command line arguments +while [[ $# -gt 0 ]]; do + case $1 in + --checkpoint-path) + CHECKPOINT_PATH="$2" + shift 2 + ;; + --tensorboard-logs-path) + TENSORBOARD_LOGS_PATH="$2" + shift 2 + ;; + --tokenizer) + TOKENIZER_ARG="$2" + shift 2 + ;; + --data) + DATA_ARG="$2" + shift 2 + ;; + --distributed-config-file) + DISTRIBUTED_CONFIG_FILE="$2" + shift 2 + ;; + -h|--help) + echo "Usage: $0 [OPTIONS]" + echo "Options:" + echo " --checkpoint-path PATH Path to Megatron checkpoint" + echo " --tensorboard-logs-path PATH Path to TensorBoard logs" + echo " --tokenizer PATH|MOCK Path to tokenizer model, or 'MOCK' (default: MOCK)" + echo " --data PATH|MOCK Data prefix, or 'MOCK' (default: MOCK)" + echo " --distributed-config-file FILE Path to distributed training config file" + echo " -h, --help Show this help message" + exit 0 + ;; + *) + echo "Unknown option: $1" + echo "Use --help for usage information" + exit 1 + ;; + esac +done + +# Check if checkpoint path exists +if [ ! -d "$CHECKPOINT_PATH" ]; then + echo "Error: Checkpoint path does not exist: $CHECKPOINT_PATH" + exit 1 +fi +echo "Checkpoint path exists: $CHECKPOINT_PATH" + +# Check if tensorboard logs path exists +if [ ! -d "$TENSORBOARD_LOGS_PATH" ]; then + echo "Warning: TensorBoard logs path does not exist. Creating: $TENSORBOARD_LOGS_PATH" + mkdir -p "$TENSORBOARD_LOGS_PATH" +fi +echo "TensorBoard logs path exists: $TENSORBOARD_LOGS_PATH" + +# NOTE: by default we use 8 GPUs +# These values will be over-written below with environmental variables +GPUS_PER_NODE=8 +NUM_NODES=1 +MASTER_ADDR="localhost" +MASTER_PORT=6000 +NODE_RANK=0 + +# Load distributed config from file if provided +if [ -n "$DISTRIBUTED_CONFIG_FILE" ]; then + if [ ! 
# Re-apply defaults for any value still unset (the assignments above, or values set by the sourced config file, take precedence; plain environment variables cannot override them at this point)
+TP_SIZE=1 +EP_SIZE=1 +PP_SIZE=${WORLD_SIZE} +MICRO_BATCH_SIZE=1 +GLOBAL_BATCH_SIZE=128 +NUM_LAYERS=12 +DTYPE="fp8" +SEQ_LENGTH=8192 +MAX_POSITION_EMBEDDINGS=8192 +TRAIN_SAMPLES=1953125000 +LR_DECAY_SAMPLES=1949218748 + +MODEL_ARGS=( + --no-masked-softmax-fusion + --transformer-impl transformer_engine + --disable-bias-linear + --untie-embeddings-and-output-weights + --no-rope-fusion + --normalization RMSNorm + --num-layers ${NUM_LAYERS} + --hidden-size 512 + --ffn-hidden-size 2048 + --num-attention-heads 64 + --group-query-attention + --num-query-groups 8 + --seq-length ${SEQ_LENGTH} + --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} + --use-mcore-models + --rotary-percent 1.0 + --rope-type rope + --position-embedding-type rope + --rotary-base 10000 + --no-bias-gelu-fusion + --export-force-local-attention + --no-bias-dropout-fusion + --quick-geglu + --glu-linear-offset 1.0 + --softmax-type learnable + --window-attn-skip-freq 2 + --activation-func-clamp-value 7.0 + --window-size 128,0 + --enable-gpt-oss +) + +MOE_ARGS=( + --num-experts 4 + --moe-router-topk 2 + --moe-router-load-balancing-type aux_loss + --moe-aux-loss-coeff 1e-3 + --moe-grouped-gemm + --moe-token-dispatcher-type alltoall + --overlap-param-gather + --overlap-grad-reduce + --moe-ffn-hidden-size 2048 + --moe-router-dtype fp32 + --moe-z-loss-coeff 1e-3 + --moe-permute-fusion +) + +DATA_ARGS_LIST=() +if [[ "$TOKENIZER_ARG" == "MOCK" ]] || [[ "$DATA_ARG" == "MOCK" ]] || [[ -z "$TOKENIZER_ARG" ]]; then + DATA_ARGS_LIST+=( + "--mock-data" + "--tokenizer-type NullTokenizer" + "--vocab-size 128256" + "--data-cache-path ${DATA_CACHE_PATH}" + "--tiktoken-pattern v2" + "--split '99,1,0'" + "--no-create-attention-mask-in-dataloader" + "--no-mmap-bin-files" + "--num-workers 1" + ) +else + # Settings for real data + DATA_ARGS_LIST+=( + "--data-path $DATA_ARG" + "--tokenizer-type HuggingFaceTokenizer" + "--tokenizer-model $TOKENIZER_ARG" + "--data-cache-path ${DATA_CACHE_PATH}" + "--split '99,1,0'" + 
"--no-create-attention-mask-in-dataloader" + "--no-mmap-bin-files" + "--num-workers 1" + # Note: --vocab-size might be inferred by HuggingFaceTokenizer or might need to be explicit. + "--vocab-size 128256" + ) +fi + +TRAINING_ARGS=( + --micro-batch-size ${MICRO_BATCH_SIZE} + --global-batch-size ${GLOBAL_BATCH_SIZE} + --lr 1.0e-5 + --train-samples ${TRAIN_SAMPLES} + --lr-decay-samples ${LR_DECAY_SAMPLES} + --lr-decay-style cosine + --min-lr 1.0e-6 + --weight-decay 0.1 + --lr-warmup-fraction 0.05 + --clip-grad 1.0 + --bf16 + --use-flash-attn + --attention-softmax-in-fp32 + --accumulate-allreduce-grads-in-fp32 + --disable-bf16-reduced-precision-matmul + --recompute-activations +) + +MODEL_PARALLEL_ARGS=( + --tensor-model-parallel-size ${TP_SIZE} + --pipeline-model-parallel-size ${PP_SIZE} + --expert-model-parallel-size ${EP_SIZE} + --sequence-parallel + --context-parallel-size 1 + --use-distributed-optimizer + --fp8-format hybrid + --fp8-param-gather + --fp8-amax-compute-algo max + --fp8-amax-history-len 1024 +) + +LOGGING_ARGS=( + --log-interval 1 + --save-interval 10000 + --eval-interval 50000000 + --eval-iters 0 + --save $CHECKPOINT_PATH + --tensorboard-dir "${CHECKPOINT_PATH}/tensorboard" + --moe-per-layer-logging + --no-load-optim + --no-load-rng + --log-throughput +) + +# Ensure pretrain_gpt.py is found +if [ ! -f "$PRETRAIN_SCRIPT_PATH" ]; then + echo "Error: pretrain_gpt.py not found at $PRETRAIN_SCRIPT_PATH" + echo "Please ensure you are running this script from the root of the Megatron-LM repository, and pretrain_gpt.py is present." 
"""Convert Megatron checkpoints to HuggingFace format."""
a/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16.sh b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16.sh new file mode 100644 index 00000000000..977be033df0 --- /dev/null +++ b/examples/post_training/modelopt/conf/nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16.sh @@ -0,0 +1,62 @@ +#!/bin/bash + +if [ -z ${HF_MODEL_CKPT} ]; then + HF_MODEL_CKPT=nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16 + TOKENIZER_MODEL=nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-BF16 +else + TOKENIZER_MODEL=${HF_MODEL_CKPT} +fi + + + +MODEL_ARGS=" \ + --trust-remote-code \ + --save-interval 100000 \ + --micro-batch-size 1 \ + --enable-experimental \ + --use-fused-weighted-squared-relu \ + --cross-entropy-loss-fusion \ + --cross-entropy-fusion-impl native \ + --num-experts 512 \ + --moe-router-score-function sigmoid \ + --moe-grouped-gemm \ + --moe-aux-loss-coeff 1e-4 \ + --moe-router-topk 22 \ + --moe-permute-fusion \ + --moe-router-topk-scaling-factor 5.0 \ + --moe-router-enable-expert-bias \ + --moe-router-dtype fp32 \ + --moe-router-load-balancing-type seq_aux_loss \ + --moe-shared-expert-intermediate-size 5376 \ + --moe-token-dispatcher-type allgather \ + --moe-latent-size 1024 \ + \ + --attention-backend flash \ + --disable-gloo-process-groups \ + --is-hybrid-model \ + --mamba-num-heads 128 \ + --mamba-head-dim 64 \ + --hybrid-layer-pattern MEMEMEM*EMEMEMEM*EMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEMEM*EMEMEMEM*EMEMEMEME \ + \ + --use-mcore-models \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --init-method-std 0.014 \ + --position-embedding-type none \ + --squared-relu \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --group-query-attention \ + --num-query-groups 2 \ + --ffn-hidden-size 2688 \ + --kv-channels 128 \ + --normalization RMSNorm \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + \ + --tokenizer-type HuggingFaceTokenizer \ + --bf16 \ + --seq-length 8192 \ + 
Must be less + than `total_count - 1`, since one block is reserved as a dummy block.
+ """ + + def __init__( + self, + context: "DynamicInferenceContext", + total_count: int, + paused_count: int, + enable_prefix_caching: bool = False, + prefix_caching_eviction_policy: PrefixCachingEvictionPolicy = ( + PrefixCachingEvictionPolicy.REF_ZERO + ), + ): + + self.context = context + self.enable_prefix_caching = enable_prefix_caching + self.prefix_caching_eviction_policy = prefix_caching_eviction_policy + self.on_blocks_deregistered: Optional[Callable] = None + + self.total_count = total_count + self.total_avail = total_count - 1 # -1 for dummy_block_idx (see below) + self.paused_count = paused_count + self.active_count = total_count - paused_count - 1 # -1 for dummy_block_idx + assert self.active_count >= 1 # ensures paused_count < total_count - 1 + self.dummy_block_idx = self.total_count - 1 + + # Initialize block pool as a "stack" data structure + self.block_bag = torch.arange( + self.total_count, dtype=torch.int32, device=torch.cuda.current_device() + ) + + if self.enable_prefix_caching: + # Block hash tracking for prefix caching: -1 = uncomputed, positive = valid hash + self.block_hashes = torch.full( + (self.total_count,), -1, dtype=torch.int64, device=torch.cuda.current_device() + ) + + # Hash-to-block mapping for O(1) prefix lookup + self.kv_hash_to_block_id: Dict[int, int] = {} + + # Reference count per block: 0 = cached (evictable), >0 = actively used + self.block_ref_counts = torch.zeros( + (self.total_count,), dtype=torch.int32, device=torch.cuda.current_device() + ) + + # LRU timestamps for eviction ordering (higher = more recently used) + # Only needed in LRU mode; RZ mode evicts immediately on ref_count==0 + if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: + self.block_timestamps = torch.zeros( + (self.total_count,), dtype=torch.int64, device=torch.cuda.current_device() + ) + + def __str__(self): + return ( + f"using: total {self.get_total_used()}/{self.total_count - 1}" + f"; active 
{self.get_active_used()}/{self.active_count}" + f"; paused {self.get_paused_used()}/{self.paused_count}" + ) + + def get_total_used(self): + """Compute number of total blocks used.""" + return self.total_count - self.total_avail - 1 + + def get_active_used(self): + """Compute number of active blocks used.""" + if not self.enable_prefix_caching: + return ( + self.context.request_kv_block_counts[ + self.context.paused_request_count : self.context.total_request_count + ] + .sum() + .item() + ) + + active_start = self.context.paused_request_count + active_end = self.context.total_request_count + if active_end > active_start: + active_rows = self.context.request_to_kv_block_ids[active_start:active_end] + valid_ids = active_rows[active_rows >= 0] + if valid_ids.numel() > 0: + return int(torch.unique(valid_ids).numel()) + return 0 + + def get_paused_used(self): + """Compute number of paused blocks used.""" + if not self.enable_prefix_caching: + return ( + self.context.request_kv_block_counts[: self.context.paused_request_count] + .sum() + .item() + ) + + if self.context.paused_request_count > 0: + paused_rows = self.context.request_to_kv_block_ids[: self.context.paused_request_count] + valid_ids = paused_rows[paused_rows >= 0] + if valid_ids.numel() > 0: + return int(torch.unique(valid_ids).numel()) + return 0 + + def get_active_avail(self): + """Compute number of active blocks available.""" + return self.active_count - self.get_active_used() + + def get_paused_avail(self): + """Compute number of paused blocks available.""" + return self.paused_count - self.get_paused_used() + + def is_memory_available(self, num_blocks: int) -> bool: + """Check if memory blocks are available. + + Includes both free pool blocks and evictable cached blocks (ref_count == 0). + + Args: + num_blocks (int): Number of blocks to check. + + Return: + (bool) Is memory available? 
+ """ + # Fast path: avoid expensive evictable count computation when free pool suffices + if self.total_avail >= num_blocks: + return True + if not self.enable_prefix_caching: + return False + if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.REF_ZERO: + return False # RZ: no cached blocks to evict + # Also count evictable cached blocks + evictable_count = self.get_evictable_block_count() + return (self.total_avail + evictable_count) >= num_blocks + + def allocate_memory_blocks(self, num_blocks: int) -> Optional[Tensor]: + """Allocate memory blocks if available, else return None. + + Will attempt LRU eviction of cached blocks if the free pool is insufficient. + + Args: + num_blocks (int): Number of blocks to allocate. + + Return: + (Optional[Tensor]) Allocated block IDs. + """ + # Try to evict cached blocks if free pool is insufficient + if self.total_avail < num_blocks: + if ( + not self.enable_prefix_caching + or self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.REF_ZERO + ): + return None # RZ: no eviction path; disabled: no cached blocks + blocks_needed_from_eviction = num_blocks - self.total_avail + if not self.evict_lru_blocks(blocks_needed_from_eviction): + return None # Not enough blocks even after eviction + + # Now allocate from the free pool + self.total_avail -= num_blocks + block_ids = self.block_bag[self.total_avail : (self.total_avail + num_blocks)] + assert num_blocks == block_ids.numel() + + if self.enable_prefix_caching: + # Initialize ref counts for newly allocated blocks + self.block_ref_counts[block_ids] = 1 + if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: + self.update_timestamps(block_ids) + + return block_ids + + def release_memory_blocks(self, blocks: Tensor) -> None: + """Release memory blocks by decrementing reference counts. + + Blocks with ref_count == 0 remain cached (in hash map) for potential reuse. + They will be evicted via LRU when space is needed. 
# Reset block bag so that we start consuming from the beginning of the pool
+ self.block_bag = torch.arange( + self.total_count, dtype=torch.int32, device=torch.cuda.current_device() + ) + + self.total_avail = self.total_count - 1 + + if self.enable_prefix_caching: + # Reset all block hashes + self.block_hashes.fill_(-1) + + # Reset prefix caching state + self.kv_hash_to_block_id.clear() + self.block_ref_counts.fill_(0) + if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: + self.block_timestamps.fill_(0) + + # ========================================================================= + # Prefix caching methods + # ========================================================================= + + def register_kv_block_hashes(self, block_ids: list[int], block_hashes: list[int]) -> None: + """Register blocks in the hash-to-block mapping for discovery (batch). + + Args: + block_ids: List of block IDs. + block_hashes: List of computed hash values (same length as block_ids). + """ + if not block_ids: + return + id_tensor = torch.tensor(block_ids, dtype=torch.int64, device=self.block_hashes.device) + hash_tensor = torch.tensor(block_hashes, dtype=torch.int64, device=self.block_hashes.device) + self.block_hashes[id_tensor] = hash_tensor + self.kv_hash_to_block_id.update(zip(block_hashes, block_ids)) + + def _deregister_blocks(self, block_ids: Tensor) -> None: + """Remove blocks from prefix caching state and return to free pool. + + Shared cleanup logic for both LRU eviction and RZ proactive eviction. + + Args: + block_ids: Tensor of block IDs to deregister. 
+ """ + num_blocks = block_ids.numel() + if num_blocks == 0: + return + + # Gather hashes via batched tensor indexing + block_ids_i64 = block_ids.to(torch.int64) + hashes = self.block_hashes[block_ids_i64].tolist() + + # Remove from kv_hash_to_block_id dict (set ops + C-level map, no Python loop) + keys_to_delete = set(hashes) - {-1} + deque( + map(self.kv_hash_to_block_id.pop, keys_to_delete & self.kv_hash_to_block_id.keys()), + maxlen=0, + ) + + # Notify Mamba slot allocator (if wired) to clean up its state + if self.on_blocks_deregistered is not None: + self.on_blocks_deregistered(block_ids.tolist(), keys_to_delete) + + # Reset block state (batched tensor ops) + self.block_hashes[block_ids] = -1 + self.block_ref_counts[block_ids] = 0 + if self.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: + self.block_timestamps[block_ids] = 0 + + # Return blocks to free pool + self.block_bag[self.total_avail : self.total_avail + num_blocks] = block_ids + self.total_avail += num_blocks + + def update_timestamps(self, block_ids: Tensor) -> None: + """Update LRU timestamps for accessed blocks. No-op in RZ mode. + + Args: + block_ids: Tensor of block IDs that were accessed. + """ + if ( + self.prefix_caching_eviction_policy != PrefixCachingEvictionPolicy.LRU + or block_ids.numel() == 0 + ): + return + self.block_timestamps[block_ids] = self.context.prefix_cache_lru_clock + + def get_evictable_block_count(self) -> Tensor: + """Get count of cached blocks that can be evicted (ref_count == 0, hash set). + + Returns: + Scalar tensor with the number of evictable cached blocks. + """ + cached_mask = (self.block_ref_counts == 0) & (self.block_hashes != -1) + return cached_mask.sum() + + def evict_lru_blocks(self, num_blocks_needed: int) -> bool: + """Evict LRU cached blocks to free up space in the pool. + + Evicts blocks with ref_count == 0, starting with oldest timestamps. + + Args: + num_blocks_needed: Number of blocks to evict. 
+ + Returns: + True if enough blocks were evicted, False otherwise. + """ + # Find all cached blocks (ref_count == 0, hash != -1) + cached_mask = (self.block_ref_counts == 0) & (self.block_hashes != -1) + cached_block_ids = torch.nonzero(cached_mask, as_tuple=True)[0] + + if cached_block_ids.numel() < num_blocks_needed: + return False # Not enough cached blocks to evict + + # Sort by timestamp (ascending = oldest first) + cached_timestamps = self.block_timestamps[cached_block_ids] + sorted_indices = torch.argsort(cached_timestamps) + blocks_to_evict = cached_block_ids[sorted_indices[:num_blocks_needed]] + + self._deregister_blocks(blocks_to_evict) + + return True diff --git a/megatron/core/inference/contexts/mamba_slot_allocator.py b/megatron/core/inference/contexts/mamba_slot_allocator.py new file mode 100644 index 00000000000..538bd760523 --- /dev/null +++ b/megatron/core/inference/contexts/mamba_slot_allocator.py @@ -0,0 +1,457 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from typing import TYPE_CHECKING, Dict, List, Optional + +import torch +from torch import Tensor + +from megatron.core.inference.config import PrefixCachingEvictionPolicy + +if TYPE_CHECKING: + from .dynamic_context import DynamicInferenceContext + + +class MambaSlotAllocator: + """Manages Mamba state caching for prefix caching in hybrid models. + + Owns the Mamba cache slot pool, block-to-slot mappings, hash-to-block + mapping, and intermediate state tracking. Accesses KV allocator state + (ref counts, timestamps, block hashes) via the parent context. + + Args: + context: The DynamicInferenceContext that owns this allocator. + max_slots: Maximum number of cache slots. + num_mamba_layers: Number of Mamba layers in the model. + conv_states_shape: Shape of per-slot conv state (excluding layer/slot dims). + ssm_states_shape: Shape of per-slot SSM state (excluding layer/slot dims). + conv_states_dtype: Dtype for conv state tensors. 
+ ssm_states_dtype: Dtype for SSM state tensors. + """ + + def __init__( + self, + context: "DynamicInferenceContext", + max_slots: int, + num_mamba_layers: int, + conv_states_shape: tuple, + ssm_states_shape: tuple, + conv_states_dtype: torch.dtype, + ssm_states_dtype: torch.dtype, + ): + self.context = context + self.max_slots = max_slots + self.num_mamba_layers = num_mamba_layers + + device = torch.cuda.current_device() + num_blocks = context.kv_block_allocator.total_count + + # Block <-> slot mappings + self.block_to_slot = torch.full((num_blocks,), -1, dtype=torch.int32, device=device) + self.slot_to_block = torch.full((max_slots,), -1, dtype=torch.int32, device=device) + + # Free slot pool (stack) + self.free_slots = torch.arange(max_slots, dtype=torch.int32, device=device) + self.free_count = max_slots + + # State tensors + self.conv_states = torch.zeros( + (num_mamba_layers, max_slots) + conv_states_shape, + dtype=conv_states_dtype, + device=device, + ) + self.ssm_states = torch.zeros( + (num_mamba_layers, max_slots) + ssm_states_shape, dtype=ssm_states_dtype, device=device + ) + + # Hash-to-block mapping: only blocks with cached Mamba state + self.hash_to_block_id: Dict[int, int] = {} + + # Per-request intermediate state storage + self._intermediate_offsets: list = [None] * context.max_requests + self._intermediate_block_ids: list = [None] * context.max_requests + self._eos_cache_block_id: list = [None] * context.max_requests + self._intermediate_buffer: dict = {} + + # ========================================================================= + # Slot management + # ========================================================================= + + def allocate_slot(self, block_id: int) -> int: + """Get a free Mamba cache slot for a block, evicting if necessary. + + Args: + block_id: The KV block ID to associate with this slot. + + Returns: + The allocated slot index. 
+ """ + # Check if block already has a slot + existing = self.block_to_slot[block_id].item() + if existing >= 0: + return existing + + # Try free pool + if self.free_count > 0: + self.free_count -= 1 + slot = self.free_slots[self.free_count].item() + else: + slot = self._evict_lru_slot() + + self.block_to_slot[block_id] = slot + self.slot_to_block[slot] = block_id + return slot + + def _evict_lru_slot(self) -> int: + """Evict the least recently used Mamba cache slot. + + Returns: + The freed slot index. + """ + kv_alloc = self.context.kv_block_allocator + # Find blocks that have mamba slots and ref_count == 0 + has_slot_mask = self.block_to_slot[: kv_alloc.total_count] >= 0 + ref_zero_mask = kv_alloc.block_ref_counts[: kv_alloc.total_count] == 0 + candidates = has_slot_mask & ref_zero_mask + candidate_ids = torch.nonzero(candidates, as_tuple=True)[0] + + if candidate_ids.numel() == 0: + raise RuntimeError("No evictable Mamba cache slots available") + + # Pick block with oldest timestamp if LRU, otherwise just pick first + if self.context.prefix_caching_eviction_policy == PrefixCachingEvictionPolicy.LRU: + timestamps = kv_alloc.block_timestamps[candidate_ids] + evict_idx = candidate_ids[torch.argmin(timestamps)].item() + else: + evict_idx = candidate_ids[0].item() + + slot = self.block_to_slot[evict_idx].item() + block_hash = kv_alloc.block_hashes[evict_idx].item() + + # Clean up mappings + self.block_to_slot[evict_idx] = -1 + self.slot_to_block[slot] = -1 + if block_hash > 0 and block_hash in self.hash_to_block_id: + del self.hash_to_block_id[block_hash] + + return slot + + def get_slot(self, block_id: int) -> int: + """Return the cache slot for a block, or -1 if none. + + Args: + block_id: The KV block ID. + + Returns: + Slot index or -1. 
+ """ + return self.block_to_slot[block_id].item() + + def has_state(self, block_id: int) -> bool: + """Check if a block has cached Mamba state.""" + return self.block_to_slot[block_id].item() >= 0 + + def invalidate_block(self, block_id: int) -> None: + """Free cache slot and clear mappings for a block. + + Called when KV blocks are evicted/deregistered. + + Args: + block_id: The KV block ID. + """ + slot = self.block_to_slot[block_id].item() + if slot < 0: + return + self.block_to_slot[block_id] = -1 + self.slot_to_block[slot] = -1 + # Return slot to free pool + self.free_slots[self.free_count] = slot + self.free_count += 1 + + # ========================================================================= + # State store/restore + # ========================================================================= + + def store_from_tensors( + self, block_id: int, layer_idx: int, ssm_state: Tensor, conv_state: Tensor + ) -> None: + """Write provided state tensors to a cache slot for a specific layer. + + Args: + block_id: The KV block ID. + layer_idx: The Mamba layer index. + ssm_state: SSM state tensor to store. + conv_state: Conv state tensor to store. + """ + slot = self.block_to_slot[block_id].item() + assert slot >= 0, f"Block {block_id} has no Mamba cache slot" + self.ssm_states[layer_idx, slot].copy_(ssm_state) + self.conv_states[layer_idx, slot].copy_(conv_state) + + def store_from_live(self, block_id: int, request_idx: int) -> None: + """Copy all layers from live per-request buffer to cache slot. + + Used for block-aligned EOS case where the final kernel state + is in the live buffer. + + Args: + block_id: The KV block ID. + request_idx: The context request index. 
        """
        slot = self.block_to_slot[block_id].item()
        assert slot >= 0, f"Block {block_id} has no Mamba cache slot"
        # Map the context request index to its row in the live Mamba buffers.
        mamba_idx = self.context.mamba_metadata.request_to_mamba_state_idx[request_idx].item()
        self.conv_states[:, slot].copy_(self.context.mamba_conv_states[:, mamba_idx])
        self.ssm_states[:, slot].copy_(self.context.mamba_ssm_states[:, mamba_idx])

    def restore_to_live(self, request_idx: int, block_id: int) -> bool:
        """Copy all layers from cache slot to live request state.

        Args:
            request_idx: The context request index.
            block_id: The KV block ID.

        Returns:
            True if state was restored, False if block has no cached state.
        """
        slot = self.block_to_slot[block_id].item()
        if slot < 0:
            return False
        mamba_idx = self.context.mamba_metadata.request_to_mamba_state_idx[request_idx].item()
        self.context.mamba_conv_states[:, mamba_idx].copy_(self.conv_states[:, slot])
        self.context.mamba_ssm_states[:, mamba_idx].copy_(self.ssm_states[:, slot])
        return True

    # =========================================================================
    # Hash registration
    # =========================================================================

    def register_block_hash(self, block_id: int, block_hash: int) -> None:
        """Register a block as having cached Mamba state.

        Args:
            block_id: The block ID.
            block_hash: The block's hash value.
        """
        self.hash_to_block_id[block_hash] = block_id

    # =========================================================================
    # Deregistration callback
    # =========================================================================

    def on_kv_blocks_deregistered(self, block_ids_list: list, hashes_to_delete: set) -> None:
        """Handle KV block deregistration by cleaning up Mamba state.

        Called by KVBlockAllocator._deregister_blocks via callback.

        Args:
            block_ids_list: List of deregistered block IDs.
            hashes_to_delete: Set of hashes being deregistered (excludes -1).
        """
        if self.hash_to_block_id:
            # Only touch hashes this cache actually tracks.
            mamba_keys = hashes_to_delete & self.hash_to_block_id.keys()
            if mamba_keys:
                from collections import deque

                # Consume-the-iterator idiom: deque(..., maxlen=0) drains
                # map() to pop all keys without building a list.
                deque(map(self.hash_to_block_id.pop, mamba_keys), maxlen=0)
        for bid in block_ids_list:
            self.invalidate_block(bid)

    # =========================================================================
    # Intermediate offset tracking
    # =========================================================================

    def compute_and_store_offsets(
        self,
        req,
        current_id: int,
        skip_tokens: int,
        prefill_chunk_length: int,
        num_matched_blocks: int,
        matched_block_ids: list,
        overall_required_blocks: int,
    ) -> None:
        """Compute intermediate state extraction offsets and store per-request.

        Args:
            req: The inference request.
            current_id: Context request index.
            skip_tokens: Number of tokens being skipped (mamba match).
            prefill_chunk_length: Total prefill chunk length before skipping.
            num_matched_blocks: Number of KV-matched blocks.
            matched_block_ids: List of matched KV block IDs.
            overall_required_blocks: Total blocks needed for this request.
        """
        ctx = self.context
        prompt_len = len(req.prompt_tokens)
        num_kv_matched = num_matched_blocks
        # Absolute token positions of interest (block-aligned by construction):
        # end of KV-matched prefix, and last block-aligned position in the prompt.
        kv_div_abs = num_kv_matched * ctx.block_size_tokens
        last_aligned_abs = (prompt_len // ctx.block_size_tokens) * ctx.block_size_tokens
        seq_len = prefill_chunk_length - skip_tokens  # effective prefill length

        # Compute relative offsets (relative to prefill start after skip)
        kv_div_rel = kv_div_abs - skip_tokens
        last_aligned_rel = last_aligned_abs - skip_tokens
        penultimate_abs = (overall_required_blocks - 1) * ctx.block_size_tokens
        penultimate_rel = penultimate_abs - skip_tokens

        # Determine mamba_chunk_size from mamba config (128 is the standard SSM kernel chunk size)
        # TODO(review): hard-coded; presumably should come from the Mamba config
        # if a non-default chunk size is ever used.
        mamba_chunk_size = 128

        # Build offset list: include if > 0, < seq_len, and % mamba_chunk_size == 0
        offsets_set = set()
        for offset in [kv_div_rel, last_aligned_rel, penultimate_rel]:
            if offset > 0 and offset < seq_len and offset % mamba_chunk_size == 0:
                offsets_set.add(offset)

        offsets = sorted(offsets_set)

        # Map each offset back to block index and block ID
        block_ids_for_offsets = []
        for offset in offsets:
            abs_token = skip_tokens + offset
            # Index of the last fully-completed block at this position.
            block_idx = abs_token // ctx.block_size_tokens - 1
            bid = ctx.request_to_kv_block_ids[current_id][block_idx].item()
            block_ids_for_offsets.append(bid)

        self._intermediate_offsets[current_id] = offsets if offsets else None
        self._intermediate_block_ids[current_id] = (
            block_ids_for_offsets if block_ids_for_offsets else None
        )

        # Block-aligned EOS: prompt_len is exactly block-aligned
        if last_aligned_abs == prompt_len and prompt_len > 0:
            last_block_idx = prompt_len // ctx.block_size_tokens - 1
            if last_block_idx >= 0:
                eos_bid = ctx.request_to_kv_block_ids[current_id][last_block_idx].item()
                self._eos_cache_block_id[current_id] = eos_bid
            else:
                self._eos_cache_block_id[current_id] = None
        else:
            self._eos_cache_block_id[current_id] = None

    def get_intermediate_offsets(self) -> Optional[List[List[int]]]:
        """Get intermediate token offsets for all prefill requests in the current batch.

        Returns:
            List of offset lists (one per prefill request), or None if no
            request has intermediate offsets.
        """
        ctx = self.context
        prefill_count = ctx.batch_dimensions.prefill_req_count
        if prefill_count == 0:
            return None

        # Prefill requests are the last `prefill_count` active requests
        active_start = ctx.paused_request_count
        decode_count = ctx.batch_dimensions.decode_req_count
        prefill_start = active_start + decode_count

        result = []
        has_any = False
        for i in range(prefill_start, prefill_start + prefill_count):
            offsets = self._intermediate_offsets[i]
            if offsets is not None:
                has_any = True
                result.append(offsets)
            else:
                # Keep positional alignment with the batch by appending an
                # empty list for requests without offsets.
                result.append([])

        return result if has_any else None

    def buffer_intermediate_states(
        self, mamba_layer_idx: int, intermediate_states_per_request: list
    ) -> None:
        """Buffer intermediate states from a single Mamba layer's forward pass.

        Args:
            mamba_layer_idx: The Mamba layer index.
            intermediate_states_per_request: Per-request list of
                (ssm_states, conv_states) tuples or None.
        """
        self._intermediate_buffer[mamba_layer_idx] = intermediate_states_per_request

    def commit_intermediate_states(self) -> None:
        """Commit buffered intermediate states to the Mamba cache.

        Called after the forward pass completes. For each prefill request:
        - Intermediate states at kv_divergence/last_aligned: allocate cache slot,
          write state, register hash in hash_to_block_id.
        - Block-aligned EOS: copy final state from live buffer to cache slot.
        """
        ctx = self.context
        prefill_count = ctx.batch_dimensions.prefill_req_count
        if prefill_count == 0:
            self._clear_intermediate_state()
            return

        active_start = ctx.paused_request_count
        decode_count = ctx.batch_dimensions.decode_req_count
        prefill_start = active_start + decode_count
        has_buffer = bool(self._intermediate_buffer)

        for req_batch_idx in range(prefill_count):
            ctx_idx = prefill_start + req_batch_idx
            offsets = self._intermediate_offsets[ctx_idx]
            block_ids = self._intermediate_block_ids[ctx_idx]

            # Commit intermediate states from forward pass
            if offsets is not None and block_ids is not None and has_buffer:
                for offset_idx in range(len(offsets)):
                    bid = block_ids[offset_idx]
                    slot = self.allocate_slot(bid)

                    # Write states from each mamba layer
                    for layer_idx, states_list in self._intermediate_buffer.items():
                        if states_list[req_batch_idx] is not None:
                            ssm_states, conv_states = states_list[req_batch_idx]
                            self.ssm_states[layer_idx, slot].copy_(ssm_states[offset_idx])
                            self.conv_states[layer_idx, slot].copy_(conv_states[offset_idx])

                    # Register in mamba hash map
                    block_hash = ctx.kv_block_allocator.block_hashes[bid].item()
                    if block_hash > 0:
                        self.register_block_hash(bid, block_hash)

            # Handle block-aligned EOS: copy final state from live buffer
            eos_bid = self._eos_cache_block_id[ctx_idx]
            if eos_bid is not None:
                # allocate_slot ensures the block has a slot; store_from_live
                # re-resolves it internally, so `slot` itself is unused here.
                slot = self.allocate_slot(eos_bid)
                self.store_from_live(eos_bid, ctx_idx)
                block_hash = ctx.kv_block_allocator.block_hashes[eos_bid].item()
                if block_hash > 0:
                    self.register_block_hash(eos_bid, block_hash)

        self._clear_intermediate_state()

    def _clear_intermediate_state(self) -> None:
        """Clear all per-request intermediate state tracking."""
        self._intermediate_buffer.clear()
        ctx = self.context
        prefill_count = ctx.batch_dimensions.prefill_req_count
        if prefill_count > 0:
            active_start = ctx.paused_request_count
            decode_count = ctx.batch_dimensions.decode_req_count
            prefill_start = active_start + decode_count
            for i in range(prefill_start, prefill_start + prefill_count):
                self._intermediate_offsets[i] = None
                self._intermediate_block_ids[i] = None
                self._eos_cache_block_id[i] = None

    # =========================================================================
    # Reset
    # =========================================================================

    def reset(self) -> None:
        """Reset all state (mappings, free pool, cache, intermediate tracking)."""
        self.block_to_slot.fill_(-1)
        self.slot_to_block.fill_(-1)
        # Rebuild the free pool as the identity permutation [0, max_slots).
        self.free_slots = torch.arange(
            self.max_slots, dtype=torch.int32, device=torch.cuda.current_device()
        )
        self.free_count = self.max_slots
        self.hash_to_block_id.clear()
        self._intermediate_buffer.clear()
        for i in range(self.context.max_requests):
            self._intermediate_offsets[i] = None
            self._intermediate_block_ids[i] = None
            self._eos_cache_block_id[i] = None
diff --git a/megatron/core/inference/moe/__init__.py b/megatron/core/inference/moe/__init__.py
new file mode 100644
index 00000000000..ea716b5fbd5
--- /dev/null
+++ b/megatron/core/inference/moe/__init__.py
@@ -0,0 +1,57 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.

import enum

import torch

from .fused_moe import ActivationType, mcore_fused_moe


class InferenceGroupedGemmBackend(enum.Enum):
    """Resolved backend for grouped GEMM operations during inference."""

    FLASHINFER = "flashinfer"
    TORCH = "torch"
    TE = "te"


def resolve_inference_grouped_gemm_backend(
    backend: str, is_cuda_graphed: bool, is_mxfp8: bool = False
) -> InferenceGroupedGemmBackend:
    """Resolve the grouped GEMM backend to use for the current iteration.

    Prerequisites are validated at init time in MoELayer; this function
    simply maps (backend, is_cuda_graphed) to the concrete backend enum.

    Args:
        backend: One of 'auto', 'torch', 'te'.
        is_cuda_graphed: Whether this is a CUDA-graphed iteration.
        is_mxfp8: Whether the model is using MXFP8 quantization (affects auto backend choice).

    Returns:
        An InferenceGroupedGemmBackend enum value.

    Raises:
        ValueError: if `backend` is not one of 'auto', 'torch', 'te'.
    """
    if backend == 'auto':
        if is_cuda_graphed:
            if is_mxfp8:
                assert hasattr(torch.nn.functional, 'scaled_grouped_mm'), (
                    "Auto backend selection for MXFP8 requires "
                    "torch.nn.functional.scaled_grouped_mm. "
                    "Please install PyTorch 2.10+."
                )
                return InferenceGroupedGemmBackend.TORCH
            else:
                return InferenceGroupedGemmBackend.FLASHINFER
        else:
            # Non-graphed auto: prefer torch grouped_mm when available,
            # otherwise fall back to Transformer Engine.
            if hasattr(torch.nn.functional, 'grouped_mm'):
                return InferenceGroupedGemmBackend.TORCH
            else:
                return InferenceGroupedGemmBackend.TE
    elif backend == 'torch':
        return InferenceGroupedGemmBackend.TORCH
    elif backend == 'te':
        return InferenceGroupedGemmBackend.TE
    else:
        raise ValueError(
            f"Unknown inference_grouped_gemm_backend: '{backend}'. "
            "Must be 'auto', 'torch', or 'te'."
        )
diff --git a/megatron/core/inference/moe/activations.py b/megatron/core/inference/moe/activations.py
new file mode 100644
index 00000000000..169d8499116
--- /dev/null
+++ b/megatron/core/inference/moe/activations.py
@@ -0,0 +1,166 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Padding-aware activation kernels for fused MoE.

These kernels skip padding rows (where permutation_map == -1) to avoid
wasted computation on aligned-but-empty expert slots.
"""

from unittest.mock import MagicMock

import torch

from megatron.core.utils import null_decorator

try:
    import triton
    import triton.language as tl

    HAVE_TRITON = True
except ImportError:
    HAVE_TRITON = False

if not HAVE_TRITON:
    # Stub out triton so module import (and decorator application) still works.
    triton = MagicMock()
    triton.jit = null_decorator
    tl = MagicMock()


def _ceil_div(a, b):
    """Integer ceiling division."""
    return (a + b - 1) // b


@triton.jit
def _squared_relu_kernel(input_ptr, output_ptr, src_idx_ptr, M, N, BLOCK_N: tl.constexpr):
    """Squared ReLU that skips padding rows (permutation_map == -1).

    Grid: (M,) — one program per row.
    """
    row = tl.program_id(0)
    # Padding rows carry -1 in the permutation map; leave their output zeros.
    if tl.load(src_idx_ptr + row) < 0:
        return
    for n in tl.range(0, N, BLOCK_N):
        o = n + tl.arange(0, BLOCK_N)
        m = o < N
        x = tl.load(input_ptr + row * N + o, mask=m).to(tl.float32)
        r = tl.maximum(x, 0.0)
        tl.store(output_ptr + row * N + o, (r * r).to(tl.bfloat16), mask=m)


def padded_squared_relu(x: torch.Tensor, permutation_map: torch.Tensor) -> torch.Tensor:
    """Squared ReLU activation that skips padding rows.

    NOTE(review): the kernel always stores bf16 while `out` inherits x.dtype —
    assumes x is bf16 (the fused-MoE contract); confirm before other uses.
    """
    M, N = x.shape
    out = torch.zeros(M, N, dtype=x.dtype, device=x.device)
    BLOCK_N = min(triton.next_power_of_2(N), 1024)
    _squared_relu_kernel[(M,)](x, out, permutation_map, M, N, BLOCK_N=BLOCK_N)
    return out


@triton.jit
def _squared_relu_quantize_kernel(
    input_ptr,
    out_fp8_ptr,
    out_scale_ptr,
    src_idx_ptr,
    K,
    n_col_blocks,
    skip_padding: tl.constexpr,
    REAL_GROUPS: tl.constexpr,
    BLOCK_K: tl.constexpr,
    BLOCK_GROUPS: tl.constexpr,
):
    """Fused squared ReLU + MXFP8 quantize + swizzle in one kernel.

    Grid: (M,) — one program per row.
    Reads BF16 FC1 output, applies squared ReLU, quantizes to FP8,
    writes FP8 data + swizzled scales in place.
    """
    row = tl.program_id(0)
    if skip_padding:
        # Padding rows carry -1 in the permutation map; skip them entirely.
        if tl.load(src_idx_ptr + row) < 0:
            return

    offs = tl.arange(0, BLOCK_K)
    mask = offs < K

    # Load and apply squared ReLU
    x = tl.load(input_ptr + row * K + offs, mask=mask, other=0.0).to(tl.float32)
    relu = tl.maximum(x, 0.0)
    activated = relu * relu

    # Per-group-of-32 quantization
    x_grouped = tl.reshape(activated, [BLOCK_GROUPS, 32])
    abs_grouped = tl.abs(x_grouped)
    max_vals = tl.max(abs_grouped, axis=1)

    # Round the dequant scale up to a power of two (e8m0): adding 0x007FFFFF
    # to the float bits before masking the exponent rounds the mantissa up.
    dequant_scale = max_vals / 448.0
    dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000
    dequant_rounded = dequant_exp.to(tl.float32, bitcast=True)
    quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded)

    quantized = x_grouped * quant_scale[:, None]
    quantized_flat = tl.reshape(quantized, [BLOCK_K])
    out_fp8 = quantized_flat.to(tl.float8e4nv)

    # Store FP8 data
    tl.store(out_fp8_ptr + row * K + offs, out_fp8, mask=mask)

    # Store swizzled scales (e8m0 exponent byte) in the 128x4-tile layout
    # expected by scaled_grouped_mm's SWIZZLE_32_4_4 format.
    scale_exp = (dequant_exp >> 23).to(tl.uint8)
    col_offs = tl.arange(0, BLOCK_GROUPS)
    col_mask = col_offs < REAL_GROUPS

    macro_row_block = row // 128
    macro_col_block = col_offs // 4
    local_row = row % 128
    local_col = col_offs % 4
    group = local_row // 32
    sub_row = local_row % 32
    tile_idx = macro_row_block * n_col_blocks + macro_col_block
    swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col

    tl.store(out_scale_ptr + swizzled_offs, scale_exp, mask=col_mask)


def squared_relu_and_quantize_mxfp8(
    x: torch.Tensor, permutation_map: torch.Tensor, skip_padding: bool = True
):
    """Fused squared ReLU + MXFP8 quantize + swizzle.

    Reads BF16 FC1 output, applies squared ReLU, quantizes to FP8 with
    swizzled scales. Single kernel replaces padded_squared_relu + mxfp8_quantize.

    Args:
        x: [M, K] BF16 FC1 output.
        permutation_map: [M] int32, original token index or -1 for padding.
        skip_padding: if True, skip rows where permutation_map == -1.

    Returns:
        MXFP8Tensor with .data [M, K] float8_e4m3fn and .scale (swizzled e8m0).
    """
    # Local import avoids a circular dependency at module import time.
    from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor

    M, K = x.shape
    # MXFP8 quantizes in groups of 32 along K.
    assert K % 32 == 0

    scale_cols = K // 32
    # Scale buffer is tiled in 128-row x 4-col blocks of 512 bytes each.
    n_row_blocks = _ceil_div(M, 128)
    n_col_blocks = _ceil_div(scale_cols, 4)
    total_scale_bytes = n_row_blocks * n_col_blocks * 512

    out_fp8 = torch.empty(M, K, dtype=torch.float8_e4m3fn, device=x.device)
    out_scale = torch.zeros(total_scale_bytes, dtype=torch.uint8, device=x.device)

    BLOCK_K = triton.next_power_of_2(K)
    BLOCK_GROUPS = BLOCK_K // 32

    _squared_relu_quantize_kernel[(M,)](
        x,
        out_fp8,
        out_scale,
        permutation_map,
        K,
        n_col_blocks,
        skip_padding,
        REAL_GROUPS=scale_cols,
        BLOCK_K=BLOCK_K,
        BLOCK_GROUPS=BLOCK_GROUPS,
    )

    return MXFP8Tensor(data=out_fp8, scale=out_scale.view(torch.float8_e8m0fnu), backend="triton")
diff --git a/megatron/core/inference/moe/fused_moe.py b/megatron/core/inference/moe/fused_moe.py
new file mode 100644
index 00000000000..39382eee079
--- /dev/null
+++ b/megatron/core/inference/moe/fused_moe.py
@@ -0,0 +1,204 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Fused MoE: permute -> FC1 -> activation -> FC2 -> unpermute.

Supports BF16 weights with torch.nn.functional.grouped_mm.
All permutation logic is handled internally — callers invoke a single function.
"""

from enum import Enum
from typing import Callable, Optional

import torch

from megatron.core.inference.moe.activations import (
    padded_squared_relu,
    squared_relu_and_quantize_mxfp8,
)
from megatron.core.inference.moe.pad import pad_to_alignment, unpad_from_alignment
from megatron.core.inference.moe.permute import (
    permute_and_quantize_mxfp8,
    permute_tokens,
    unpermute_tokens,
)
from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor

try:
    from torch.nn.functional import grouped_mm

    HAVE_GROUPED_MM = True
except ImportError:
    HAVE_GROUPED_MM = False

try:
    from torch.nn.functional import ScalingType, SwizzleType, scaled_grouped_mm

    HAVE_SCALED_GMM = True
except ImportError:
    HAVE_SCALED_GMM = False


class ActivationType(Enum):
    """Activation functions supported by mcore_fused_moe."""

    SQUARED_RELU = "squared_relu"


def _bf16_grouped_mm(
    x_bf16: torch.Tensor, weight: torch.Tensor, offs: torch.Tensor
) -> torch.Tensor:
    """BF16 grouped GEMM using torch.nn.functional.grouped_mm."""
    assert x_bf16.dtype == torch.bfloat16, f"Expected bf16 input, got {x_bf16.dtype}"
    # Weights are stacked [num_experts, out, in]; transpose to [e, in, out].
    return grouped_mm(x_bf16, weight.transpose(1, 2), offs=offs)


def _mxfp8_grouped_mm(act: MXFP8Tensor, weight: MXFP8Tensor, offs: torch.Tensor) -> torch.Tensor:
    """MXFP8 scaled_grouped_mm with pre-quantized activations and weights."""
    return scaled_grouped_mm(
        act.data,
        weight.data.transpose(1, 2),
        act.scale_2d(),
        ScalingType.BlockWise1x32,
        weight.scale,
        ScalingType.BlockWise1x32,
        swizzle_a=SwizzleType.SWIZZLE_32_4_4,
        swizzle_b=SwizzleType.SWIZZLE_32_4_4,
        offs=offs,
        output_dtype=torch.bfloat16,
    )


def _get_activation_func(activation_type: ActivationType, fused_quant: bool = False) -> Callable:
    """Resolve ActivationType enum to a concrete kernel.

    If fused_quant=True, returns the fused activation + MXFP8 quantize kernel.
    """
    if activation_type == ActivationType.SQUARED_RELU:
        return squared_relu_and_quantize_mxfp8 if fused_quant else padded_squared_relu
    else:
        raise ValueError(f"Unsupported activation type: {activation_type}")


def mcore_fused_moe(
    hidden_states: torch.Tensor,
    probs: torch.Tensor,
    fc1_weight,
    fc2_weight,
    activation_type: ActivationType,
    num_local_experts: int,
    local_expert_start: int,
    routing_map: Optional[torch.Tensor] = None,
    tokens_per_expert: Optional[torch.Tensor] = None,
    skip_permute: bool = False,
    disable_fused_quant_kernels: bool = False,
) -> torch.Tensor:
    """Fused MoE: [permute ->] pad -> FC1 -> activation -> FC2 -> unpad [-> unpermute].

    Two modes:
    - skip_permute=False (default): tokens are unpermuted. Requires routing_map.
      Performs full permute -> compute -> unpermute.
    - skip_permute=True: tokens are already permuted by the dispatcher. Requires
      tokens_per_expert. Pads to alignment, computes, then unpads. Probs are
      applied during unpad.

    Unless disable_fused_quant_kernels=True, when weights are MXFP8, uses fused
    kernels that combine permute/activation with MXFP8 quantization into single
    kernel launches.

    Args:
        hidden_states: [num_tokens, hidden_size] BF16 input.
        probs: routing probabilities. Shape is [num_tokens, topk] when
            skip_permute=False, or [num_tokens] (already gathered) when
            skip_permute=True.
        fc1_weight: stacked weight for FC1 (torch.Tensor for BF16, MXFP8Tensor for MXFP8).
        fc2_weight: stacked weight for FC2 (same type as fc1_weight).
        activation_type: ActivationType enum (SQUARED_RELU).
        num_local_experts: number of experts on this rank.
        local_expert_start: first global expert index on this rank.
        routing_map: [num_tokens, topk] int expert assignments. Required when skip_permute=False.
        tokens_per_expert: [num_local_experts] int32 token counts. Required when skip_permute=True.
        skip_permute: if True, skip permute/unpermute (tokens already in expert order).
        disable_fused_quant_kernels: if True, disable fused permute+quantize and
            activation+quantize kernels for MXFP8, using separate launches instead.
            Useful for debugging. Ignored when weights are BF16.

    Returns:
        [num_tokens, hidden_size] BF16 output.
    """
    assert (
        hidden_states.dtype == torch.bfloat16
    ), f"mcore_fused_moe requires bf16 input, got {hidden_states.dtype}"

    num_tokens = hidden_states.shape[0]
    # Weight type selects the compute path: MXFP8Tensor -> scaled_grouped_mm.
    use_mxfp8 = isinstance(fc1_weight, MXFP8Tensor)
    # Fused quant kernels only apply to MXFP8 path
    use_fused_quant = use_mxfp8 and not disable_fused_quant_kernels

    if use_mxfp8:
        assert (
            HAVE_SCALED_GMM
        ), "torch.nn.functional.scaled_grouped_mm not available. Install PyTorch 2.10+."
        mm_fn = _mxfp8_grouped_mm
        # scaled_grouped_mm requires each expert's token count aligned to 32,
        # but swizzled MXFP8 scales require alignment to 128. Use 128 to
        # satisfy both constraints.
        expert_alignment = 128
    else:
        assert (
            HAVE_GROUPED_MM
        ), "torch.nn.functional.grouped_mm not available. Install PyTorch 2.10+."
        mm_fn = _bf16_grouped_mm
        expert_alignment = 16

    activation_func = _get_activation_func(activation_type, fused_quant=use_fused_quant)

    # --- Pre-processing: permute or pad ---
    if skip_permute:
        assert tokens_per_expert is not None, "tokens_per_expert is required when skip_permute=True"
        tokens_per_expert = tokens_per_expert.cuda().int()
        assert routing_map is None, "routing_map must be None when skip_permute=True"
        hidden_states, permutation_map, offs = pad_to_alignment(
            hidden_states, tokens_per_expert, expert_alignment
        )
        permuted_probs = None

    else:
        assert routing_map is not None, "routing_map is required when skip_permute=False"
        if use_fused_quant:
            # Fused permute + MXFP8 quantize: single kernel produces MXFP8Tensor
            hidden_states, permuted_probs, permutation_map, offs = permute_and_quantize_mxfp8(
                hidden_states,
                probs,
                routing_map,
                local_expert_start,
                num_local_experts,
                alignment=expert_alignment,
            )
        else:
            hidden_states, permuted_probs, permutation_map, offs = permute_tokens(
                hidden_states,
                probs,
                routing_map,
                local_expert_start,
                num_local_experts,
                alignment=expert_alignment,
            )

    # --- FC1 -> activation -> FC2 ---
    # Quantize if MXFP8 path and hidden_states not already quantized (fused permute+quant
    # produces MXFP8Tensor directly; skip_permute path always needs separate quant).
    needs_quant = use_mxfp8 and not isinstance(hidden_states, MXFP8Tensor)
    if needs_quant:
        hidden_states = MXFP8Tensor.from_bf16(hidden_states, backend="triton")
    fc1_output = mm_fn(hidden_states, fc1_weight, offs)

    activation_out = activation_func(fc1_output, permutation_map)
    # Fused activation+quant returns MXFP8Tensor; otherwise quantize separately.
    if use_mxfp8 and not isinstance(activation_out, MXFP8Tensor):
        activation_out = MXFP8Tensor.from_bf16(activation_out, backend="triton")
    fc2_output = mm_fn(activation_out, fc2_weight, offs)
    # --- Post-processing: unpermute or unpad ---
    if skip_permute:
        # Probs may arrive as [num_tokens, 1]; flatten to [num_tokens].
        probs_1d = probs.squeeze(-1) if probs.dim() > 1 else probs
        return unpad_from_alignment(fc2_output, permutation_map, num_tokens, probs=probs_1d)
    else:
        return unpermute_tokens(fc2_output, permuted_probs, permutation_map, num_tokens)
diff --git a/megatron/core/inference/moe/pad.py b/megatron/core/inference/moe/pad.py
new file mode 100644
index 00000000000..656953b691c
--- /dev/null
+++ b/megatron/core/inference/moe/pad.py
@@ -0,0 +1,201 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Pad / unpad utilities for already-permuted expert tokens.

When the token dispatcher has already permuted tokens into expert-grouped
order, these functions insert/remove alignment padding so that each expert's
token block satisfies the alignment requirements of grouped_mm /
scaled_grouped_mm.
"""

from unittest.mock import MagicMock

import torch
from packaging import version

from megatron.core.utils import null_decorator

try:
    import triton
    import triton.language as tl

    # NOTE(review): gating on `not torch.cuda.is_available()` alongside the
    # version check is unusual — confirm the intended condition; as written,
    # an old triton WITH CUDA available still takes the constexpr branch.
    if version.parse(triton.__version__) < version.parse("3.4.0") and not torch.cuda.is_available():
        HAVE_TRITON = False
    else:
        HAVE_TRITON = tl.constexpr(version.parse(triton.__version__) >= version.parse("2.0.0"))
except ImportError:
    HAVE_TRITON = False

if not HAVE_TRITON:
    # Stub out triton so module import (and decorator application) still works.
    triton = MagicMock()
    triton.jit = null_decorator
    tl = MagicMock()

from megatron.core.inference.moe.permute import compute_expert_offsets


@triton.jit
def _pad_tokens_kernel(
    src_ptr,
    dst_ptr,
    perm_map_ptr,
    tpe_ptr,  # tokens_per_expert [num_experts]
    hidden_dim,
    num_experts: tl.constexpr,
    alignment: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Copy one input row into the padded output buffer.

    Computes unpadded and padded cumulative offsets inline from
    tokens_per_expert, avoiding a separate cumsum kernel launch.

    Grid: (total_tokens,) — one program per input row.
    """
    row = tl.program_id(0)

    # Walk tokens_per_expert to find which expert this row belongs to
    # and compute both unpadded and padded start offsets on the fly.
    unpadded_start = tl.zeros([], dtype=tl.int32)
    padded_start = tl.zeros([], dtype=tl.int32)
    expert_id = -1
    for e in tl.static_range(0, num_experts):
        count = tl.load(tpe_ptr + e).to(tl.int32)
        if expert_id < 0 and row < unpadded_start + count:
            expert_id = e
        if expert_id < 0:
            unpadded_start += count
            # Round each earlier expert's count up to the alignment boundary;
            # empty experts contribute zero.
            aligned = tl.where(
                count > 0,
                ((count + alignment - 1) // alignment) * alignment,
                tl.zeros([], dtype=tl.int32),
            )
            padded_start += aligned

    if expert_id < 0:
        return

    local_idx = row - unpadded_start
    dst_row = padded_start + local_idx

    # Write permutation_map: padded row → original unpadded row
    tl.store(perm_map_ptr + dst_row, row)

    # Copy hidden state
    for h in tl.range(0, hidden_dim, BLOCK_H):
        o = h + tl.arange(0, BLOCK_H)
        m = o < hidden_dim
        tl.store(
            dst_ptr + dst_row * hidden_dim + o,
            tl.load(src_ptr + row * hidden_dim + o, mask=m),
            mask=m,
        )


def pad_to_alignment(
    hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor, alignment: int
) -> tuple:
    """Pad already-permuted tokens so each expert's block is aligned.

    Args:
        hidden_states: [total_tokens, hidden_size] already permuted by dispatcher.
        tokens_per_expert: [num_local_experts] int32 token counts.
        alignment: per-expert alignment.

    Returns:
        (padded_hidden, permutation_map, inclusive_offsets)
        - padded_hidden: [padded_total, hidden_size]
        - permutation_map: [padded_total] int32, original row index or -1 for padding.
        - inclusive_offsets: [num_local_experts] int32 cumulative aligned offsets for grouped_mm.
    """
    num_experts = tokens_per_expert.shape[0]
    total_tokens = hidden_states.shape[0]
    hidden_dim = hidden_states.shape[1]

    # We still need padded_inc for the return value (used as offs by grouped_mm)
    _, padded_inc = compute_expert_offsets(tokens_per_expert, alignment=alignment)
    # .item() forces a device->host sync to size the output buffer.
    padded_total = int(padded_inc[-1].item())

    # zeros (not empty): padding rows must read as 0 downstream.
    padded_hidden = torch.zeros(
        padded_total, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device
    )
    permutation_map = torch.full(
        (padded_total,), -1, dtype=torch.int32, device=hidden_states.device
    )

    if total_tokens > 0:
        BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024)
        _pad_tokens_kernel[(total_tokens,)](
            hidden_states,
            padded_hidden,
            permutation_map,
            tokens_per_expert,
            hidden_dim,
            num_experts,
            alignment,
            BLOCK_H=BLOCK_H,
        )

    return padded_hidden, permutation_map, padded_inc


@triton.jit
def _unpad_tokens_kernel(
    src_ptr,
    dst_ptr,
    perm_map_ptr,
    probs_ptr,
    hidden_dim,
    has_probs: tl.constexpr,
    BLOCK_H: tl.constexpr,
):
    """Copy one real (non-padding) row from padded to unpadded layout.

    Optionally multiplies each row by its routing probability.

    Grid: (padded_total,) — one program per padded row.
    """
    row = tl.program_id(0)
    dst_row = tl.load(perm_map_ptr + row)
    if dst_row < 0:
        return
    if has_probs:
        prob = tl.load(probs_ptr + dst_row)
    for h in tl.range(0, hidden_dim, BLOCK_H):
        o = h + tl.arange(0, BLOCK_H)
        m = o < hidden_dim
        v = tl.load(src_ptr + row * hidden_dim + o, mask=m)
        if has_probs:
            v = v * prob
        tl.store(dst_ptr + dst_row * hidden_dim + o, v, mask=m)


def unpad_from_alignment(
    padded_output: torch.Tensor,
    permutation_map: torch.Tensor,
    original_size: int,
    probs: torch.Tensor = None,
) -> torch.Tensor:
    """Remove alignment padding, scattering results back to original positions.

    Args:
        padded_output: [padded_total, hidden_size] output from expert computation.
        permutation_map: [padded_total] int32, original row index or -1 for padding.
        original_size: number of rows in the unpadded output.
        probs: optional [original_size] routing probabilities to multiply during unpad.

    Returns:
        [original_size, hidden_size] unpadded output.
    """
    hidden_dim = padded_output.shape[1]
    output = torch.zeros(
        original_size, hidden_dim, dtype=padded_output.dtype, device=padded_output.device
    )
    has_probs = probs is not None
    if padded_output.shape[0] > 0:
        BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024)
        _unpad_tokens_kernel[(padded_output.shape[0],)](
            padded_output,
            output,
            permutation_map,
            probs if has_probs else padded_output,  # dummy pointer when no probs
            hidden_dim,
            has_probs,
            BLOCK_H=BLOCK_H,
        )
    return output
diff --git a/megatron/core/inference/moe/permute.py b/megatron/core/inference/moe/permute.py
new file mode 100644
index 00000000000..b14d0b3dbd0
--- /dev/null
+++ b/megatron/core/inference/moe/permute.py
@@ -0,0 +1,458 @@
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
"""Triton kernels for token permutation and unpermutation in fused MoE.

Includes:
- Token counting per expert
- Expert offset computation (aligned prefix sums)
- Permute tokens into expert-grouped order
- Unpermute expert outputs back to original token order
"""

from unittest.mock import MagicMock

import torch

from megatron.core.utils import null_decorator

try:
    import triton
    import triton.language as tl

    HAVE_TRITON = True
except ImportError:
    HAVE_TRITON = False

if not HAVE_TRITON:
    # Stub out triton so module import (and decorator application) still works.
    triton = MagicMock()
    triton.jit = null_decorator
    tl = MagicMock()


def _ceil_div(a, b):
    """Integer ceiling division."""
    return (a + b - 1) // b


@triton.jit
def _count_local_tokens_kernel(
    routing_map_ptr,  # [num_tokens * topk] flattened expert assignments
    tokens_per_expert_ptr,  # [num_local_experts] output counters (zeroed by caller)
    total_pairs,  # num_tokens * topk — total (token, topk) pairs
    local_expert_start,  # first global expert index owned by this rank
    num_local_experts: tl.constexpr,  # number of experts on this rank
    BLOCK_SIZE: tl.constexpr,  # number of pairs processed per program
):
    """Count tokens routed to experts on this rank, ignoring tokens routed elsewhere.

    Each program processes BLOCK_SIZE (token, topk) pairs. Tokens assigned to
    experts outside [local_expert_start, local_expert_start + num_local_experts)
    are silently skipped.
+ """ + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < total_pairs + expert_ids = tl.load(routing_map_ptr + offsets, mask=mask, other=-1) + # Map global expert IDs to local indices; non-local experts become negative + local_ids = expert_ids - local_expert_start + is_local = (local_ids >= 0) & (local_ids < num_local_experts) & mask + tl.atomic_add(tokens_per_expert_ptr + local_ids, 1, mask=is_local) + + +def compute_local_tokens_per_expert( + routing_map: torch.Tensor, local_expert_start: int, num_local_experts: int +) -> torch.Tensor: + """Count tokens routed to each local expert.""" + total_pairs = routing_map.numel() + tokens_per_expert = torch.zeros(num_local_experts, dtype=torch.int32, device=routing_map.device) + BLOCK = 256 + _count_local_tokens_kernel[(_ceil_div(total_pairs, BLOCK),)]( + routing_map, + tokens_per_expert, + total_pairs, + local_expert_start, + num_local_experts, + BLOCK_SIZE=BLOCK, + ) + return tokens_per_expert + + +@triton.jit +def _prefix_sum_kernel( + tokens_per_expert_ptr, # [num_local_experts] raw token counts + exclusive_offsets_ptr, # [num_local_experts] output: exclusive prefix sum of aligned counts + inclusive_offsets_ptr, # [num_local_experts] output: inclusive prefix sum of aligned counts + num_local_experts, # number of experts on this rank + alignment: tl.constexpr, # per-expert alignment (counts rounded up to this multiple) + BLOCK_SIZE: tl.constexpr, # next_power_of_2(num_local_experts) for tl.cumsum +): + """Exclusive and inclusive prefix sums of aligned token counts. + + Each expert's token count is rounded up to the nearest multiple of + `alignment` (experts with 0 tokens stay at 0). The inclusive offsets + are used as `offs` by grouped_mm / scaled_grouped_mm. 
+ """ + r = tl.arange(0, BLOCK_SIZE) + mask = r < num_local_experts + h = tl.load(tokens_per_expert_ptr + r, mask=mask, other=0) + # Round up non-zero counts to alignment boundary + if alignment > 1: + h = tl.where(h > 0, ((h + alignment - 1) // alignment) * alignment, h) + inc = tl.cumsum(h, axis=0) + tl.store(exclusive_offsets_ptr + r, inc - h, mask=mask) + tl.store(inclusive_offsets_ptr + r, inc, mask=mask) + + +def compute_expert_offsets(tokens_per_expert: torch.Tensor, alignment: int = 1) -> tuple: + """Compute exclusive and inclusive prefix sums of aligned token counts.""" + n = tokens_per_expert.shape[0] + exclusive_cumsum = torch.empty_like(tokens_per_expert) + inclusive_cumsum = torch.empty_like(tokens_per_expert) + _prefix_sum_kernel[(1,)]( + tokens_per_expert, + exclusive_cumsum, + inclusive_cumsum, + n, + alignment, + BLOCK_SIZE=triton.next_power_of_2(n), + ) + return exclusive_cumsum, inclusive_cumsum + + +@triton.jit +def _permute_tokens_kernel( + hidden_ptr, # [num_tokens, hidden_dim] input hidden states + probs_ptr, # [num_tokens, topk] routing probabilities + routing_map_ptr, # [num_tokens, topk] expert assignments (global IDs) + out_hidden_ptr, # [output_size, hidden_dim] output: permuted hidden states + out_probs_ptr, # [output_size] output: permuted probabilities + out_src_idx_ptr, # [output_size] output: permutation_map (original token index, -1 for padding) + counters_ptr, # [num_local_experts] exclusive offsets, + # atomically incremented to assign positions + num_tokens, # number of input tokens + hidden_dim, # hidden dimension + topk: tl.constexpr, # number of expert choices per token + local_expert_start, # first global expert index on this rank + num_local_experts: tl.constexpr, # number of experts on this rank + BLOCK_H: tl.constexpr, # tile size for copying hidden_dim +): + """Permute tokens into expert-grouped order. + + Grid: one program per (token, topk) pair. 
Each program looks up the assigned + expert, skips non-local experts, then atomically claims a position within + that expert's block and copies the hidden state + prob + source index. + """ + # Each program handles one (token, topk) pair + pair = tl.program_id(0) + tok = pair // topk + k = pair % topk + if tok >= num_tokens: + return + eid = tl.load(routing_map_ptr + tok * topk + k) + lid = eid - local_expert_start + # Skip tokens routed to non-local experts + if lid < 0 or lid >= num_local_experts: + return + # Atomically claim a position within this expert's aligned block + pos = tl.atomic_add(counters_ptr + lid, 1) + # Copy hidden state row + for h in tl.range(0, hidden_dim, BLOCK_H): + o = h + tl.arange(0, BLOCK_H) + m = o < hidden_dim + tl.store( + out_hidden_ptr + pos * hidden_dim + o, + tl.load(hidden_ptr + tok * hidden_dim + o, mask=m), + mask=m, + ) + tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) + # Record source token index for unpermute + tl.store(out_src_idx_ptr + pos, tok) + + +def permute_tokens( + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, + alignment: int = 1, +) -> tuple: + """Permute tokens into expert-grouped order. + + Computes token counts, aligned expert offsets, output sizing, and + permutation in a single call. + + Args: + hidden_states: [num_tokens, hidden_size] input. + probs: [num_tokens, topk] routing probabilities. + routing_map: [num_tokens, topk] expert assignments. + local_expert_start: first global expert index on this rank. + num_local_experts: number of experts on this rank. + alignment: per-expert token alignment (default 1). + + Returns: + (permuted_hidden, permuted_probs, permutation_map, inclusive_offsets) + - permuted_hidden: [output_size, hidden_size] + - permuted_probs: [output_size] + - permutation_map: [output_size] int32, maps each permuted row back to + its original token index. 
Used by unpermute_tokens to scatter expert + outputs back and by activation kernels to skip padding rows (-1). + - inclusive_offsets: [num_local_experts] int32 cumulative offsets for grouped_mm + """ + num_tokens, hidden_dim = hidden_states.shape + topk = probs.shape[1] + + # Count how many (token, topk) pairs are routed to each local expert. + # Non-local experts are ignored. Result is [num_local_experts] int32. + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts + ) + + # exclusive_expert_offsets[i] = start of expert i's block in the padded output. + # Used as the initial counter for atomic position assignment in the permute kernel. + # inclusive_expert_offsets[i] = end of expert i's block (= start of expert i+1). + # Passed as `offs` to grouped_mm / scaled_grouped_mm to delimit expert boundaries. + exclusive_expert_offsets, inclusive_expert_offsets = compute_expert_offsets( + tokens_per_expert, alignment=alignment + ) + output_size = num_tokens * min(topk, num_local_experts) + alignment * num_local_experts + + permuted_hidden = torch.empty( + output_size, hidden_dim, dtype=hidden_states.dtype, device=hidden_states.device + ) + permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) + permutation_map = torch.full((output_size,), -1, dtype=torch.int32, device=probs.device) + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _permute_tokens_kernel[(num_tokens * topk,)]( + hidden_states, + probs, + routing_map, + permuted_hidden, + permuted_probs, + permutation_map, + exclusive_expert_offsets, + num_tokens, + hidden_dim, + topk, + local_expert_start, + num_local_experts, + BLOCK_H=BLOCK_H, + ) + return permuted_hidden, permuted_probs, permutation_map, inclusive_expert_offsets + + +@triton.jit +def _unpermute_tokens_kernel( + expert_out_ptr, # [output_size, hidden_dim] expert outputs in permuted order + probs_ptr, # [output_size] fp32 routing probabilities (permuted) + 
src_idx_ptr, # [output_size] permutation_map: original token index, or -1 for padding + output_ptr, # [num_tokens, hidden_dim] fp32 output buffer (zeroed by caller) + hidden_dim, # hidden dimension + BLOCK_H: tl.constexpr, # tile size for processing hidden_dim +): + """Scatter weighted expert outputs back to original token positions. + + Grid: one program per row of expert_out. Padding rows (src_idx == -1) are + skipped. Multiple topk selections for the same token are accumulated via + atomic adds. All arithmetic is in fp32 to avoid precision loss. + """ + row = tl.program_id(0) + source_idx = tl.load(src_idx_ptr + row) + # Skip padding rows + if source_idx < 0: + return + prob = tl.load(probs_ptr + row) # fp32 + for h in tl.range(0, hidden_dim, BLOCK_H): + offsets = h + tl.arange(0, BLOCK_H) + m = offsets < hidden_dim + # Upcast bf16 expert output to fp32 before multiply + accumulate + v = tl.load(expert_out_ptr + row * hidden_dim + offsets, mask=m).to(tl.float32) + tl.atomic_add(output_ptr + source_idx * hidden_dim + offsets, v * prob, mask=m) + + +def unpermute_tokens( + expert_output: torch.Tensor, + permuted_probs: torch.Tensor, + permutation_map: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """Unpermute expert outputs back to original token order. + + Accumulates in fp32 to avoid precision loss from multiple topk atomic adds. + Returns fp32 output. 
+ """ + assert ( + permuted_probs.dtype == torch.float32 + ), f"permuted_probs must be fp32, got {permuted_probs.dtype}" + output_size, hidden_dim = expert_output.shape + output = torch.zeros(num_tokens, hidden_dim, dtype=torch.float32, device=expert_output.device) + BLOCK_H = min(triton.next_power_of_2(hidden_dim), 1024) + _unpermute_tokens_kernel[(output_size,)]( + expert_output, permuted_probs, permutation_map, output, hidden_dim, BLOCK_H=BLOCK_H + ) + return output + + +@triton.jit +def _permute_quantize_mxfp8_kernel( + hidden_ptr, + probs_ptr, + routing_map_ptr, + out_fp8_ptr, + out_scale_ptr, + out_probs_ptr, + out_src_idx_ptr, + counters_ptr, + num_tokens, + K, + n_col_blocks, + topk: tl.constexpr, + local_expert_start, + num_local_experts: tl.constexpr, + REAL_GROUPS: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_GROUPS: tl.constexpr, +): + """Fused permute + MXFP8 quantize + swizzle in one kernel. + + Grid: (num_tokens * topk,) — one program per (token, k) pair. + Reads BF16 from source token, quantizes to FP8 e4m3, writes FP8 data + + swizzled e8m0 scales to the permuted write position. 
+ """ + pair = tl.program_id(0) + tok = pair // topk + k = pair % topk + if tok >= num_tokens: + return + eid = tl.load(routing_map_ptr + tok * topk + k) + lid = eid - local_expert_start + if lid < 0 or lid >= num_local_experts: + return + + pos = tl.atomic_add(counters_ptr + lid, 1) + + # Load full row from source token + offs = tl.arange(0, BLOCK_K) + mask = offs < K + x = tl.load(hidden_ptr + tok * K + offs, mask=mask, other=0.0).to(tl.float32) + + # Per-group-of-32 quantization + x_grouped = tl.reshape(x, [BLOCK_GROUPS, 32]) + abs_grouped = tl.abs(x_grouped) + max_vals = tl.max(abs_grouped, axis=1) + + dequant_scale = max_vals / 448.0 + dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) + quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) + + quantized = x_grouped * quant_scale[:, None] + quantized_flat = tl.reshape(quantized, [BLOCK_K]) + out_fp8 = quantized_flat.to(tl.float8e4nv) + + # Store FP8 data at permuted position + tl.store(out_fp8_ptr + pos * K + offs, out_fp8, mask=mask) + + # Store swizzled scales at permuted position + scale_exp = (dequant_exp >> 23).to(tl.uint8) + col_offs = tl.arange(0, BLOCK_GROUPS) + col_mask = col_offs < REAL_GROUPS + + macro_row_block = pos // 128 + macro_col_block = col_offs // 4 + local_row = pos % 128 + local_col = col_offs % 4 + group = local_row // 32 + sub_row = local_row % 32 + tile_idx = macro_row_block * n_col_blocks + macro_col_block + swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + + tl.store(out_scale_ptr + swizzled_offs, scale_exp, mask=col_mask) + + # Store prob and source index + tl.store(out_probs_ptr + pos, tl.load(probs_ptr + tok * topk + k)) + tl.store(out_src_idx_ptr + pos, tok) + + +def permute_and_quantize_mxfp8( + hidden_states: torch.Tensor, + probs: torch.Tensor, + routing_map: torch.Tensor, + local_expert_start: int, + num_local_experts: int, + alignment: int = 
128, +) -> tuple: + """Fused permute + MXFP8 quantize + swizzle. + + Self-contained API matching permute_tokens: computes token counts, aligned + expert offsets, output sizing, permutation, and MXFP8 quantization in a + single kernel launch. + + Args: + hidden_states: [num_tokens, hidden_size] BF16 input. + probs: [num_tokens, topk] routing probabilities. + routing_map: [num_tokens, topk] expert assignments. + local_expert_start: first global expert index on this rank. + num_local_experts: number of experts on this rank. + alignment: per-expert token alignment (default 128, required for MXFP8 swizzle). + + Returns: + (permuted_mxfp8, permuted_probs, permutation_map, inclusive_offsets) + - permuted_mxfp8: MXFP8Tensor with .data [output_size, K] and .scale (swizzled) + - permuted_probs: [output_size] routing probs + - permutation_map: [output_size] int32, original token index or -1 for padding + - inclusive_offsets: [num_local_experts] int32 cumulative offsets for scaled_grouped_mm + """ + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + num_tokens, K = hidden_states.shape + topk = probs.shape[1] + assert K % 32 == 0 + + # Count how many (token, topk) pairs are routed to each local expert. + tokens_per_expert = compute_local_tokens_per_expert( + routing_map, local_expert_start, num_local_experts + ) + + # exclusive_expert_offsets[i] = start of expert i's block in the padded output. + # inclusive_expert_offsets[i] = end of expert i's block (= start of expert i+1). 
+ exclusive_expert_offsets, inclusive_expert_offsets = compute_expert_offsets( + tokens_per_expert, alignment=alignment + ) + output_size = num_tokens * min(topk, num_local_experts) + alignment * num_local_experts + + scale_cols = K // 32 + n_row_blocks = _ceil_div(output_size, 128) + n_col_blocks = _ceil_div(scale_cols, 4) + total_scale_bytes = n_row_blocks * n_col_blocks * 512 + + out_fp8 = torch.empty(output_size, K, dtype=torch.float8_e4m3fn, device=hidden_states.device) + out_scale = torch.zeros(total_scale_bytes, dtype=torch.uint8, device=hidden_states.device) + permuted_probs = torch.empty(output_size, dtype=probs.dtype, device=probs.device) + permutation_map = torch.full((output_size,), -1, dtype=torch.int32, device=probs.device) + + BLOCK_K = triton.next_power_of_2(K) + BLOCK_GROUPS = BLOCK_K // 32 + + _permute_quantize_mxfp8_kernel[(num_tokens * topk,)]( + hidden_states, + probs, + routing_map, + out_fp8, + out_scale, + permuted_probs, + permutation_map, + exclusive_expert_offsets, + num_tokens, + K, + n_col_blocks, + topk, + local_expert_start, + num_local_experts, + REAL_GROUPS=scale_cols, + BLOCK_K=BLOCK_K, + BLOCK_GROUPS=BLOCK_GROUPS, + ) + + permuted_mxfp8 = MXFP8Tensor( + data=out_fp8, scale=out_scale.view(torch.float8_e8m0fnu), backend="triton" + ) + return permuted_mxfp8, permuted_probs, permutation_map, inclusive_expert_offsets diff --git a/megatron/core/inference/quantization/mxfp8_quantize.py b/megatron/core/inference/quantization/mxfp8_quantize.py new file mode 100644 index 00000000000..73f2ac974b3 --- /dev/null +++ b/megatron/core/inference/quantization/mxfp8_quantize.py @@ -0,0 +1,185 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Standalone MXFP8 quantization kernel with fused scale swizzle. + +One block per token. Quantizes BF16 → FP8 e4m3 and writes scales directly +in cuBLAS 2D blocked (swizzled) layout. No FP4, no triton_kernels dependency. 
+ +Usage: + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + data, swizzled_scales, total_scale_bytes = mxfp8_quantize(x_bf16) + # data: [M, K] float8_e4m3fn + # swizzled_scales: 1D uint8 in cuBLAS blocked layout +""" + +import torch +import triton +import triton.language as tl + + +def _ceil_div(a, b): + return (a + b - 1) // b + + +@triton.jit +def _mxfp8_quant_swizzle_kernel( + out_ptr, # [M, K] output buffer for float8_e4m3fn quantized data + scale_ptr, # 1D output buffer for swizzled uint8 scales (e8m0 exponents) + src_ptr, # [M, K] input tensor in bf16/fp16/fp32 + K, # number of columns in the input (must be divisible by 32) + n_col_blocks, # ceil(K/32 / 4) — number of macro-tile columns in the swizzle layout + REAL_GROUPS: tl.constexpr, # actual number of scale groups per row (K // 32) + BLOCK_K: tl.constexpr, # next_power_of_2(K) — padded column count for tl.reshape + BLOCK_GROUPS: tl.constexpr, # BLOCK_K // 32 — padded group count (must be power of 2) +): + """Each triton block quantizes one row → FP8 e4m3, write scales directly in swizzled layout. + + We use round up in scale calculation. see: Mishra et al., + Recipes for Pre-training LLMs with MXFP8 (https://arxiv.org/pdf/2506.08027) + + The implementation borrows code from the triton upstream MXFP downcast kernel: + https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/numerics_details/mxfp_details/_downcast_to_mxfp.py + + Note on swizzled scale layout (torch.nn.functional.SwizzleType.SWIZZLE_32_4_4): + + Background: In MXFP8, every group of 32 elements shares one 1-byte scale + (an e8m0 exponent). For an [M, K] matrix, this gives an [M, K//32] scale + matrix. cuBLAS doesn't read these scales in simple row-major order — it + expects a "swizzled" layout optimized for its internal access patterns. + + Step 1 — Divide into macro-tiles: + The scale matrix is partitioned into 128-row x 4-col macro-tiles. 
+ Each tile is stored as a contiguous 512-byte (128 x 4) block. + + Step 2 — Interleave within each tile: + Within a macro-tile, the 128 rows are NOT stored sequentially. + Instead, they are split into 4 groups of 32 rows: + group 0: rows 0- 31 + group 1: rows 32- 63 + group 2: rows 64- 95 + group 3: rows 96-127 + + Rows with the same position within their group (same "sub_row") + are placed next to each other. So the memory layout is: + + Concretely, for sub_row=0: + byte 0: row 0, col 0 + byte 1: row 0, col 1 + byte 2: row 0, col 2 + byte 3: row 0, col 3 + byte 4: row 32, col 0 + byte 5: row 32, col 1 + byte 6: row 32, col 2 + byte 7: row 32, col 3 + byte 8: row 64, col 0 + ... + byte 15: row 96, col 3 + + The formula to map logical (row, col) → byte offset: + tile_idx = (row // 128) * n_col_blocks + (col // 4) + sub_row = row % 32 + group = (row % 128) // 32 + local_col = col % 4 + offset = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + + """ + row = tl.program_id(0) + src_row = src_ptr + row * K + out_row = out_ptr + row * K + + offs = tl.arange(0, BLOCK_K) + mask = offs < K + + # Load full row + x = tl.load(src_row + offs, mask=mask, other=0.0).to(tl.float32) + + # Per-group-of-32 max + x_grouped = tl.reshape(x, [BLOCK_GROUPS, 32]) + abs_grouped = tl.abs(x_grouped) + max_vals = tl.max(abs_grouped, axis=1) + + # 448 is the max representable value in FP8 e4m3. + # dequant_scale = min scale s.t. max_val / scale <= 448. + dequant_scale = max_vals / 448.0 + # Round up to next power of 2 via integer bit manipulation: + # Adding 0x007FFFFF (mantissa mask) before masking with 0x7F800000 + # (exponent-only mask) bumps the exponent if any mantissa bits are set. + # Result: 2^ceil(log2(max/448)) as a uint32-encoded float. + dequant_exp = (dequant_scale.to(tl.uint32, bitcast=True) + 0x007FFFFF) & 0x7F800000 + # Reinterpret uint32 back as float32 — now a power-of-2 dequantization scale. 
+ dequant_rounded = dequant_exp.to(tl.float32, bitcast=True) + # Quantization scale is the reciprocal; guard against div-by-zero for all-zero groups. + quant_scale = tl.where(dequant_rounded == 0, 0.0, 1.0 / dequant_rounded) + + # Quantize + quantized = x_grouped * quant_scale[:, None] + quantized_flat = tl.reshape(quantized, [BLOCK_K]) + out_fp8 = quantized_flat.to(tl.float8e4nv) + + # Store FP8 data + tl.store(out_row + offs, out_fp8, mask=mask) + + # Store swizzled scales + scale_exp = (dequant_exp >> 23).to(tl.uint8) + col_offs = tl.arange(0, BLOCK_GROUPS) + col_mask = col_offs < REAL_GROUPS + + # Compute swizzled offsets for each scale element. + # + # The scale matrix [M, K//32] is divided into 128×4 macro-tiles. + # Within each tile, rows are split into 4 groups of 32 (group = local_row // 32). + # Rather than flattening row-major, the layout interleaves groups so that + # rows 32 apart are adjacent in memory: + # + # offset = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + macro_row_block = row // 128 + macro_col_block = col_offs // 4 + local_row = row % 128 + local_col = col_offs % 4 + group = local_row // 32 + sub_row = local_row % 32 + tile_idx = macro_row_block * n_col_blocks + macro_col_block + swizzled_offs = tile_idx * 512 + sub_row * 16 + group * 4 + local_col + + tl.store(scale_ptr + swizzled_offs, scale_exp, mask=col_mask) + + +def mxfp8_quantize(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a 2D tensor to MXFP8 with fused scale swizzle. + + Args: + x: [M, K] tensor in bf16/fp16/fp32. K must be divisible by 32. 
+ + Returns: + (data, swizzled_scales): + data: [M, K] float8_e4m3fn + swizzled_scales: 1D tensor in cuBLAS blocked layout (uint8/e8m0) + """ + assert x.is_cuda and x.dim() == 2 + assert x.dtype in (torch.bfloat16, torch.float16, torch.float32) + M, K = x.shape + assert K % 32 == 0, f"K ({K}) must be divisible by 32" + + scale_cols = K // 32 + n_row_blocks = _ceil_div(M, 128) + n_col_blocks = _ceil_div(scale_cols, 4) + total_scale_bytes = n_row_blocks * n_col_blocks * 512 + + out_data = torch.empty(M, K, dtype=torch.float8_e4m3fn, device=x.device) + out_scale = torch.zeros(total_scale_bytes, dtype=torch.uint8, device=x.device) + + BLOCK_K = triton.next_power_of_2(K) + BLOCK_GROUPS = BLOCK_K // 32 + + _mxfp8_quant_swizzle_kernel[(M,)]( + out_data, + out_scale, + x, + K, + n_col_blocks, + REAL_GROUPS=scale_cols, + BLOCK_K=BLOCK_K, + BLOCK_GROUPS=BLOCK_GROUPS, + ) + + return out_data, out_scale.view(torch.float8_e8m0fnu) diff --git a/megatron/core/inference/symmetric_memory.py b/megatron/core/inference/symmetric_memory.py new file mode 100644 index 00000000000..254d41ce294 --- /dev/null +++ b/megatron/core/inference/symmetric_memory.py @@ -0,0 +1,182 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Lazy-initialized symmetric memory manager for inference. + +Provides a registry of SymmetricMemoryBuffer instances keyed by a +user-supplied identifier (e.g. "tp", "ep"). Buffers are created on first +access so that callers never need to worry about initialization ordering +relative to the inference context. 
+""" + +from __future__ import annotations + +import operator +from functools import reduce +from typing import Optional + +import torch + +try: + import torch.distributed._symmetric_memory as symm_mem + + HAVE_TORCH_SYMM_MEM = True +except ImportError: + HAVE_TORCH_SYMM_MEM = False + +try: + import triton # pylint: disable=unused-import + + HAVE_TRITON = True +except ImportError: + HAVE_TRITON = False + + +class SymmetricMemoryBuffer: + """ + symmetric memory buffer used in inference. + This buffer is used by mcore-inference's low-latency + NVLS all-gather and reduce-scatter collectives. + """ + + def __init__(self, size_in_mb, process_group): + if not HAVE_TORCH_SYMM_MEM or not HAVE_TRITON: + # This should be hit if the user is running an older + # version of torch, or if they do not have triton + # installed. + self.symm_buffer = None + self.symm_mem_hdl = None + else: + numel = int(size_in_mb * 1024 * 1024) # size in bytes + try: + symm_mem.enable_symm_mem_for_group(process_group.group_name) + self.symm_buffer = symm_mem.empty(numel, dtype=torch.uint8, device='cuda') + self.symm_mem_hdl = symm_mem.rendezvous(self.symm_buffer, process_group) + except RuntimeError as e: + # If symmetric memory initialization fails, set buffer and handle to None + # This should happen if the process group is not contained within NVlink + self.symm_buffer = None + self.symm_mem_hdl = None + + def _can_allocate(self, numel, dtype) -> bool: + """ + Returns whether enough symmetric memory is available + for the given tensor shape and dtype. 
+ """ + if self.symm_mem_hdl is None: + return False + size_of_dtype = torch.tensor([], dtype=dtype).element_size() + required_len = numel * size_of_dtype + return required_len <= self.symm_buffer.numel() + + def _allocate(self, numel, dtype) -> torch.Tensor: + """ + Allocates a sub-tensor from the self.symm_buffer for the given numel and dtype""" + required_bytes = numel * torch.tensor([], dtype=dtype).element_size() + return self.symm_buffer[0:required_bytes].view(dtype).view(numel) + + def maybe_get_tensors(self, tensor_specs, alignment=16): + """ + Pack multiple tensors contiguously in the symmetric buffer with alignment. + + Each tensor's starting offset is aligned to `alignment` bytes (default 16 + for 128-bit multimem access). + + Args: + tensor_specs: list of (numel, dtype) tuples. + alignment: byte alignment for each tensor's start offset (default 16). + + Returns: + {"handle": None, "tensors": None} if unavailable or insufficient space. + {"handle": symm_mem_hdl, "tensors": [(raw_byte_view, byte_offset), ...]} + on success, where raw_byte_view is a uint8 slice of the buffer. + """ + _NONE_RESULT = {"handle": None, "tensors": None} + if self.symm_mem_hdl is None: + return _NONE_RESULT + + # Compute aligned byte sizes and running offsets + slices = [] + current_offset = 0 + for numel, dtype in tensor_specs: + nbytes = numel * torch.tensor([], dtype=dtype).element_size() + aligned_nbytes = ((nbytes + alignment - 1) // alignment) * alignment + slices.append((current_offset, nbytes)) + current_offset += aligned_nbytes + + if not self._can_allocate(current_offset, torch.uint8): + return _NONE_RESULT + + tensors = [] + for offset, nbytes in slices: + tensors.append((self.symm_buffer[offset : offset + nbytes], offset)) + + return {"handle": self.symm_mem_hdl, "tensors": tensors} + + def maybe_get_tensor(self, tensor_shape, dtype): + """ + Returns (potentially) a sub-tensor from the self.symm_buffer for the given shape. 
+ If enough symmetric memory is not available, returns None. + """ + if self.symm_mem_hdl is None: + return {"tensor": None, "handle": None} + numel = reduce(operator.mul, tensor_shape, 1) + if not self._can_allocate(numel, dtype): + return {"tensor": None, "handle": None} + return { + "tensor": self._allocate(numel, dtype).view(*tensor_shape), + "handle": self.symm_mem_hdl, + } + + +class SymmetricMemoryManager: + """Registry of lazily-initialized symmetric memory buffers. + + Usage:: + + buf = SymmetricMemoryManager.get_buffer("tp", process_group=tp_group) + result = buf.maybe_get_tensor(shape, dtype) + """ + + _buffers: dict[str, SymmetricMemoryBuffer] = {} + _default_size_mb: int = 256 + + @classmethod + def get_buffer( + cls, + key: str, + process_group: Optional[torch.distributed.ProcessGroup] = None, + size_mb: Optional[int] = None, + ) -> SymmetricMemoryBuffer: + """Return the buffer for *key*, creating it on first call. + + Args: + key: Unique identifier (e.g. "tp", "ep"). + process_group: Required on the first call for a given key. + Subsequent calls may omit it. + size_mb: Buffer size in MiB (default 256). + """ + if key not in cls._buffers: + assert ( + process_group is not None + ), f"SymmetricMemoryManager: process_group is required on first access for key='{key}'" + cls._buffers[key] = SymmetricMemoryBuffer( + size_in_mb=size_mb or cls._default_size_mb, process_group=process_group + ) + return cls._buffers[key] + + @classmethod + def destroy(cls, key: Optional[str] = None) -> None: + """Destroy one or all buffers. + + Args: + key: If provided, destroy only that buffer. Otherwise destroy all. 
+ """ + if key is not None: + cls._buffers.pop(key, None) + else: + cls._buffers.clear() + + @classmethod + def is_initialized(cls, key: str) -> bool: + """Check whether a buffer has been created for *key*.""" + return key in cls._buffers diff --git a/megatron/core/inference/text_generation_server/dynamic_text_gen_server/text_generation_server.py b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/text_generation_server.py new file mode 100644 index 00000000000..adf5c39a9b4 --- /dev/null +++ b/megatron/core/inference/text_generation_server/dynamic_text_gen_server/text_generation_server.py @@ -0,0 +1,210 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +import asyncio +import logging +import multiprocessing as mp +import socket +from contextlib import contextmanager +from typing import List, Optional + +try: + from hypercorn.asyncio import serve + from hypercorn.config import Config + from quart import Quart + + HAS_BACKEND = True +except ImportError as e: + HAS_BACKEND = False + +import megatron.core.inference.text_generation_server.dynamic_text_gen_server.endpoints as endpoints +from megatron.core.inference.inference_client import InferenceClient +from megatron.core.utils import trace_async_exceptions + +logger = logging.getLogger(__name__) + +# Global reference to manage the background server processes +_SERVER_PROCESSES: List[mp.Process] = [] +_SHARED_SOCKET = None + + +@contextmanager +def temp_log_level(level, logger=None): + """Enables temporarily overriding the logging level.""" + logger = logger or logging.getLogger() + old_level = logger.level + logger.setLevel(level) + try: + yield + finally: + logger.setLevel(old_level) + + +@trace_async_exceptions +async def _run_text_gen_server( + coordinator_addr: str, + tokenizer, + rank: int, + server_port: int, + parsers: Optional[List[str]] = None, + verbose: bool = False, + fd: Optional[int] = None, +): + """ + Initializes and runs the async web server. 
Automatically starts and + manages its own InferenceClient connected to the provided coordinator address. + """ + if not HAS_BACKEND: + raise RuntimeError(f"Web backend framework (Quart) not available") + + # Create and start the client locally inside this process + inference_client = InferenceClient(coordinator_addr) + inference_client.start() + logger.info(f"Rank {rank}: InferenceClient connected.") + + try: + try: + hostname = socket.gethostname() + except Exception as e: + logger.warning(f"Could not get hostname: {e}") + hostname = "0.0.0.0" + + app = Quart(__name__) + + # Quart native way to handle max body size (1 GB; needed for large prompts) + app.config['MAX_CONTENT_LENGTH'] = 2**30 + + # Store client and tokenizer in app config for Blueprints to use + app.config['client'] = inference_client + app.config['tokenizer'] = tokenizer + app.config['parsers'] = parsers + app.config['verbose'] = verbose + + # Register all blueprints from the 'endpoints' package + for endpoint in endpoints.__all__: + app.register_blueprint(endpoint) + + config = Config() + config.keep_alive_timeout = 30.0 # Keep connection alive between long-running requests. + config.backlog = 2**14 # Expect high load; ensure we do not drop connections. + config.h2_max_concurrent_streams = ( + 2**14 + ) # Allow many concurrent streams for HTTP/2 clients. 
+ + if fd is not None: + config.bind = [f"fd://{fd}"] + else: + config.bind = [f"0.0.0.0:{server_port}"] + + with temp_log_level(logging.INFO, logger): + logger.info(f"Starting text generation server on http://{hostname}:{server_port}") + logger.info(f"Using tokenizer: {type(tokenizer)}") + logger.info(f"Using parsers: {parsers}") + + # Quart is natively ASGI, so we can serve the app directly + await serve(app, config) + + finally: + # Gracefully shut down the client when the server stops + inference_client.stop() + logger.info(f"Rank {rank}: Web server and client shut down.") + + +def _server_process_worker( + coordinator_addr: str, + tokenizer, + rank: int, + server_port: int, + parsers: Optional[List[str]] = None, + verbose: bool = False, + fd: Optional[int] = None, +): + """Synchronous worker function that sets up a new event loop for the separate process.""" + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + try: + loop.run_until_complete( + _run_text_gen_server( + coordinator_addr, tokenizer, rank, server_port, parsers, verbose, fd + ) + ) + except KeyboardInterrupt: + logger.info(f"Rank {rank}: text gen server process interrupted.") + finally: + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + if pending: + loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + loop.close() + + +def start_text_gen_server( + coordinator_addr: str, + tokenizer, + rank: int, + server_port: int, + parsers: Optional[List[str]] = None, + verbose: bool = False, + num_replicas: int = 4, +): + """Start the text generation server.""" + global _SERVER_PROCESSES + global _SHARED_SOCKET + + if _SERVER_PROCESSES: + logger.warning("Text gen server processes are already running.") + return + + _SHARED_SOCKET = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + _SHARED_SOCKET.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + + if hasattr(socket, 'SO_REUSEPORT'): + try: + _SHARED_SOCKET.setsockopt(socket.SOL_SOCKET, 
socket.SO_REUSEPORT, 1) + except OSError: + pass + + _SHARED_SOCKET.bind(("0.0.0.0", server_port)) + _SHARED_SOCKET.setblocking(False) + + _SHARED_SOCKET.set_inheritable(True) + fd = _SHARED_SOCKET.fileno() + + for i in range(num_replicas): + p = mp.Process( + target=_server_process_worker, + args=(coordinator_addr, tokenizer, rank, server_port, parsers, verbose, fd), + daemon=True, + ) + p.start() + _SERVER_PROCESSES.append(p) + logger.info(f"Started text gen frontend replica {i+1}/{num_replicas} (PID: {p.pid})") + + +def stop_text_gen_server(): + """Stop the text generation server.""" + global _SERVER_PROCESSES + global _SHARED_SOCKET + + if not _SERVER_PROCESSES: + return + + logger.info(f"Terminating {len(_SERVER_PROCESSES)} Text Gen frontend processes...") + + for p in _SERVER_PROCESSES: + if p.is_alive(): + p.terminate() + + for p in _SERVER_PROCESSES: + p.join(timeout=3) + if p.is_alive(): + p.kill() + p.join() + + # Clean up the master socket + if _SHARED_SOCKET is not None: + _SHARED_SOCKET.close() + _SHARED_SOCKET = None + + _SERVER_PROCESSES = [] + logger.info("All text gen frontend processes terminated.") diff --git a/megatron/core/models/mimo/partition/utils.py b/megatron/core/models/mimo/partition/utils.py new file mode 100644 index 00000000000..0b43e5548ff --- /dev/null +++ b/megatron/core/models/mimo/partition/utils.py @@ -0,0 +1,260 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +"""Token and weight partitioning helper (CP, TP, SP). + +The adapter slices sequences across *context-parallel* ranks and can further +scatter them across *sequence-parallel* ranks when sequence-parallelism is +enabled. 
+""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch # type: ignore[import-not-found] +from torch.distributed import ProcessGroup # type: ignore[import-not-found] + +from megatron.core import tensor_parallel +from megatron.core.model_parallel_config import ModelParallelConfig +from megatron.core.packed_seq_params import PackedSeqParams +from megatron.core.parallel_state import get_context_parallel_group, get_tensor_model_parallel_group +from megatron.core.utils import ( + get_batch_on_this_cp_rank, + get_pg_rank, + get_pg_size, + is_te_min_version, +) + +try: + import transformer_engine_torch as tex # type: ignore + + _HAVE_TEX = True +except ModuleNotFoundError: # pragma: no cover + tex = None # type: ignore + _HAVE_TEX = False + + +@dataclass(frozen=True) +class PartitionConfig: + """Minimal runtime information needed to shard inputs. + + NOTE: Always construct PartitionConfig using the provided classmethod + (from_mp_config) to ensure all fields, including cp_group and tp_group, + are set correctly. + """ + + seq_parallel: bool + use_cp: bool + tp_comm_overlap: bool + max_seq_len: int + kv_format: str = "sbhd" # "sbhd" | "thd" + cp_group: Optional[ProcessGroup] = None + tp_group: Optional[ProcessGroup] = None + + @property + def is_partitioning_enabled(self) -> bool: + """Returns True if context parallelism or sequence parallelism is active.""" + return self.use_cp or self.seq_parallel + + @classmethod + def from_mp_config( + cls, + mp: ModelParallelConfig, + *, + max_seq_len: int, + kv_format: str = "sbhd", + cp_group: Optional[ProcessGroup] = None, + tp_group: Optional[ProcessGroup] = None, + ) -> "PartitionConfig": + """ + Creates a PartitionConfig from a ModelParallelConfig. 
+ """ + if not isinstance(mp, ModelParallelConfig): + raise TypeError("mp must be a ModelParallelConfig instance") + + if mp.context_parallel_size > 1 and cp_group is None: + cp_group = get_context_parallel_group() + + if mp.sequence_parallel and tp_group is None: + tp_group = get_tensor_model_parallel_group() + + return cls( + seq_parallel=mp.sequence_parallel, + use_cp=get_pg_size(cp_group) > 1, + tp_comm_overlap=mp.tp_comm_overlap, + max_seq_len=max_seq_len, + kv_format=kv_format, + cp_group=cp_group, + tp_group=tp_group, + ) + + +class PartitionAdapter: + """Shard batch-first embeddings & label tensors for Context and Sequence Parallelism.""" + + def __init__(self, cfg: PartitionConfig): + """Initialize the partition adapter. + Args: + cfg: PartitionConfig, the configuration for the partition adapter. + """ + self.cfg = cfg + + def shard( + self, + embeddings: torch.Tensor, + labels: torch.Tensor, + loss_mask: torch.Tensor, + attention_mask: torch.Tensor, + packed_seq_params: Optional[PackedSeqParams] = None, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Optional[PackedSeqParams]]: + """ + Apply context parallel (CP) and sequence parallel (SP) sharding to input tensors. + + All input tensors must be in batch-first layout: + - embeddings: (B, S, H) + - labels / loss_mask / attention_mask: (B, S) + + After this call embeddings are still in (B, S/cp, H) batch-first layout. + The caller is responsible for transposing to (S/cp, B, H) if the language model + requires sequence-first tensors. + + Args: + embeddings (torch.Tensor): + Input embeddings tensor. Shape: (B, S, H) + labels (torch.Tensor): + Labels tensor. Shape: (B, S) + loss_mask (torch.Tensor): + Loss mask tensor. Shape: (B, S) + attention_mask (torch.Tensor): + Attention mask tensor. Shape: (B, S) + packed_seq_params (PackedSeqParams, optional): + Packed sequence parameters. Defaults to None. + + Returns: + Tuple containing: + - embeddings (torch.Tensor): Sharded embeddings. 
Shape: (B, S/cp, H) + - labels (torch.Tensor): Possibly sharded labels. Shape: (B, S/cp) + - loss_mask (torch.Tensor): Possibly sharded loss mask. Shape: (B, S/cp) + - attention_mask (torch.Tensor): Possibly sharded attention mask. Shape: (B, S/cp) + - packed_seq_params (PackedSeqParams, optional): Updated packed sequence parameters. + """ + if not (self.cfg.use_cp or self.cfg.seq_parallel): + return embeddings, labels, loss_mask, attention_mask, packed_seq_params + + # Sanity-check the sequence length before any sharding happens. + if embeddings is not None: + shard_factor = None + seq_dim = None # which dimension holds the token sequence + + if self.cfg.use_cp and self.cfg.seq_parallel: + shard_factor = get_pg_size(self.cfg.tp_group) * get_pg_size(self.cfg.cp_group) * 2 + seq_dim = 1 # embeddings shape: [B, S, H] + elif self.cfg.use_cp: + shard_factor = get_pg_size(self.cfg.cp_group) * 2 + seq_dim = 1 + elif self.cfg.seq_parallel: + shard_factor = get_pg_size(self.cfg.tp_group) + seq_dim = 0 # embeddings shape: [S, B, H] + + if shard_factor is not None and ( + packed_seq_params is None + or getattr(packed_seq_params, 'qkv_format', 'sbhd') == 'sbhd' + ): + assert embeddings.shape[seq_dim] % shard_factor == 0, ( + f"Sequence length should be divisible by {shard_factor} " + "for Sequence/Context parallelism" + ) + + if self.cfg.seq_parallel and self.cfg.tp_comm_overlap: + assert embeddings.shape[seq_dim] == self.cfg.max_seq_len, ( + "TP Comm overlap requires Vision+Text token length " + "== language_max_sequence_length" + ) + + if self.cfg.use_cp: + embeddings, labels, loss_mask, attention_mask, packed_seq_params = ( + self._apply_context_parallel( + embeddings, labels, loss_mask, attention_mask, packed_seq_params + ) + ) + + if self.cfg.seq_parallel and embeddings is not None: + embeddings = tensor_parallel.scatter_to_sequence_parallel_region(embeddings) + + return embeddings, labels, loss_mask, attention_mask, packed_seq_params + + def _apply_context_parallel( + 
self, + embeddings: Optional[torch.Tensor], + labels: Optional[torch.Tensor], + loss_mask: Optional[torch.Tensor], + attention_mask: Optional[torch.Tensor], + packed_seq_params: Optional[PackedSeqParams], + ) -> Tuple[ + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[torch.Tensor], + Optional[PackedSeqParams], + ]: + """ + Apply context parallel (CP) sharding to input tensors. + + Args: + embeddings (Optional[torch.Tensor]): + Input embeddings tensor. Shape: (B, S, H) + labels (Optional[torch.Tensor]): + Labels tensor. Shape: (B, S) + loss_mask (Optional[torch.Tensor]): + Loss mask tensor. Shape: (B, S) + attention_mask (Optional[torch.Tensor]): + Attention mask tensor. Shape: (B, S) + packed_seq_params (PackedSeqParams, optional): + Packed sequence parameters. Defaults to None. + + Returns: + Tuple containing: + - embeddings (Optional[torch.Tensor]): Sharded embeddings. Shape: (B, S/cp, H) + - labels (Optional[torch.Tensor]): Possibly sharded labels. Shape: (B, S/cp) + - loss_mask (Optional[torch.Tensor]): Possibly sharded loss mask. Shape: (B, S/cp) + - attention_mask (Optional[torch.Tensor]): Possibly sharded attention mask. + Shape: (B, S/cp) + - packed_seq_params (PackedSeqParams, optional): Updated packed sequence parameters. 
+ """ + if not self.cfg.use_cp: + return embeddings, labels, loss_mask, attention_mask, packed_seq_params + + # Distribute sequence across CP ranks + batch = dict() + if embeddings is not None: + batch["embeddings"] = embeddings + if labels is not None: + batch["labels"] = labels + if loss_mask is not None: + batch["loss_mask"] = loss_mask + if attention_mask is not None: + batch["attention_mask"] = attention_mask + + if packed_seq_params is None or getattr(packed_seq_params, 'qkv_format', 'sbhd') == 'sbhd': + batch = get_batch_on_this_cp_rank(batch) + else: + assert _HAVE_TEX and is_te_min_version("1.10.0"), ( + "Please update Transformer Engine to >= 1.10 " + "to use Context Parallel with THD format data" + ) + assert self.cfg.cp_group is not None + cp_size = get_pg_size(self.cfg.cp_group) + cp_rank = get_pg_rank(self.cfg.cp_group) + for key, data in batch.items(): + index = tex.thd_get_partitioned_indices( + packed_seq_params.cu_seqlens_q_padded, data.size(1), cp_size, cp_rank + ) + batch[key] = data.index_select(1, index) + + # Extract sharded tensors; embeddings remain in [B, S/cp, H] — the caller + # is responsible for transposing to [S/cp, B, H] for the language model. + embeddings = batch.get("embeddings", None) + labels = batch.get("labels", None) + loss_mask = batch.get("loss_mask", None) + attention_mask = batch.get("attention_mask", None) + + return embeddings, labels, loss_mask, attention_mask, packed_seq_params diff --git a/megatron/core/resharding/nvshmem_copy_service/compat.py b/megatron/core/resharding/nvshmem_copy_service/compat.py new file mode 100644 index 00000000000..be624c3ba26 --- /dev/null +++ b/megatron/core/resharding/nvshmem_copy_service/compat.py @@ -0,0 +1,49 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +""" +Compatibility layer for cuda-core version differences. 
+ +cuda-core >=0.5 removed the ``cuda.core.experimental._memory`` and +``cuda.core.experimental._stream`` private submodules, but nvshmem4py +still imports from them. We register ``sys.modules`` shims so those +imports resolve to the new ``cuda.core._memory`` / ``cuda.core._stream`` +paths. + +This module should be imported before any nvshmem.core usage. +""" + +import importlib +import sys + + +def _patch_cuda_core_experimental(): + """Register cuda.core._memory / _stream as cuda.core.experimental._memory / _stream.""" + for submod in ("_memory", "_stream"): + exp_key = f"cuda.core.experimental.{submod}" + new_key = f"cuda.core.{submod}" + if exp_key not in sys.modules: + try: + sys.modules[exp_key] = importlib.import_module(new_key) + except ImportError: + pass # old cuda-core that still has experimental._memory + + +def get_cuda_core_device_class(): + """Return the ``Device`` class from whichever cuda-core location is available. + + cuda-core <0.5: ``cuda.core.experimental.Device`` + cuda-core >=0.5: ``cuda.core.Device`` + """ + try: + from cuda.core import Device + + return Device + except ImportError: + from cuda.core.experimental import Device + + return Device + + +def ensure_nvshmem_compat(): + """Apply all compatibility patches. Safe to call multiple times.""" + _patch_cuda_core_experimental() diff --git a/megatron/core/resharding/transforms.py b/megatron/core/resharding/transforms.py new file mode 100644 index 00000000000..45a250d3541 --- /dev/null +++ b/megatron/core/resharding/transforms.py @@ -0,0 +1,261 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +from __future__ import annotations + +""" +Reshard transforms for custom send/recv/writeback during weight transfer. + +- ReshardTransform: base class for pluggable format conversion hooks. +- MXFP8ReshardTransform: writes received BF16 data into persistent FlashInfer + MXFP8Tensor buffers so CUDA-graph device-pointer captures remain valid. 
+""" + +import torch + +from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + +class ReshardTransform: + """Hook for custom send/recv/writeback during reshard execution. + + Implementations override the four methods below. When an instance is + passed to ``execute_reshard_plan``, each ``TransferOp`` is checked via + ``should_transform(param_name)``; if True the transform methods are used + instead of the default send/recv/writeback logic. + + The transform may change the wire format (e.g. send MXFP8 data+scale + instead of BF16) **or** keep the same wire format and only post-process + on the receive side (e.g. receive BF16, convert to MXFP8 in + ``finalize_recv``). The only constraint is that ``prepare_send`` and + ``prepare_recv`` must return the same number of tensors for a given + parameter so that send/recv pairs match. + """ + + def should_transform(self, param_name: str) -> bool: + """Return True if *param_name* should use the transform path.""" + return False + + def prepare_send( + self, param_name: str, src_slice: tuple[slice, ...], src_param: torch.nn.Parameter + ) -> list[torch.Tensor]: + """Produce tensor(s) to send for *param_name*. + + May return multiple tensors (e.g. data + scale when converting to + MXFP8 on the sender side). The default implementation sends the + BF16 slice unchanged (single tensor). + """ + raise NotImplementedError + + def prepare_recv(self, param_name: str, dst_slice: tuple[slice, ...]) -> list[torch.Tensor]: + """Allocate receive buffer(s). Count must match ``prepare_send`` output.""" + raise NotImplementedError + + def finalize_recv( + self, param_name: str, dst_slice: tuple[slice, ...], recv_buffers: list[torch.Tensor] + ) -> None: + """Write received data into final destination (e.g. persistent buffers). + + This is where receiver-side format conversion can happen (e.g. + converting a BF16 recv buffer to MXFP8 before writing into + persistent storage). 
+ """ + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# MXFP8 transform helpers +# --------------------------------------------------------------------------- + + +def _scale_slice_from_data_slice( + data_slice: tuple[slice, ...], block_size: int = 32 +) -> tuple[slice, ...]: + """Convert an MXFP8 data slice to the corresponding scale slice. + + In MXFP8, each group of ``block_size`` elements along the last (K) + dimension shares a single scale value. All dimensions except the last + are passed through unchanged; the last ``slice`` has its start/stop + divided by ``block_size``. Integer index on the last dim is converted + to scale index as idx // block_size. + """ + adjusted = list(data_slice) + last = adjusted[-1] + if isinstance(last, slice): + if last.start is not None and last.start % block_size != 0: + raise AssertionError( + f"MXFP8 data slice last dim ({last}) must be aligned to block_size={block_size}" + ) + if last.stop is not None and last.stop % block_size != 0: + raise AssertionError( + f"MXFP8 data slice last dim ({last}) must be aligned to block_size={block_size}" + ) + scale_start = (last.start // block_size) if last.start is not None else None + scale_stop = (last.stop // block_size) if last.stop is not None else None + # Scale has one value per block; do not use last.step (would index scale wrong). + adjusted[-1] = slice(scale_start, scale_stop) + elif isinstance(last, int): + adjusted[-1] = last // block_size + return tuple(adjusted) + + +def _ensure_sendable(param: torch.Tensor) -> torch.Tensor: + """Return a standard-dtype tensor suitable for wire transmission. + + Quantized parameter types (e.g., Transformer Engine MXFP8Tensor) are + dequantized to their original precision (usually BF16). Standard + parameters are returned via ``.data`` (unwrapped from autograd). 
+ """ + try: + from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor as _TEMXFP8 + + if isinstance(param, _TEMXFP8): + return param.dequantize() + except ImportError: + pass + return param.data + + +class MXFP8ReshardTransform(ReshardTransform): + """MXFP8 format-conversion transform for reshard. + + Writes received weight data directly into persistent ``MXFP8Tensor`` + buffers so that CUDA-graph device-pointer captures remain valid across + refits. + + Two modes are supported, controlled by *convert_on_send*: + + ``convert_on_send=False`` (default — **receiver-side conversion**): + The sender transmits plain BF16 (one tensor per op, identical to the + default reshard path). The receiver allocates a BF16 receive buffer, + then ``finalize_recv`` converts BF16 → MXFP8 and writes into the + persistent buffers. Because the wire format is unchanged the sender + does **not** need a transform — only the receiver creates one. This + is the simplest mode and avoids any sender/receiver coordination. + + ``convert_on_send=True`` (**sender-side conversion**): + The sender converts each BF16 slice to MXFP8 and sends **two** + tensors (data + scale) per op. The receiver allocates matching + MXFP8 buffers and ``finalize_recv`` copies them directly. Both + sender and receiver must use the transform so that tensor counts + match. This mode halves wire bandwidth (~1 byte/elem vs 2). + + **Caveat**: CopyService backends that match local (same-rank) + transfers by ``task_id`` (Gloo, NVSHMEM) will break if multiple + tensors share the same ``task_id``. This mode is therefore only + safe for non-colocated setups where sender and receiver are on + different ranks. A future fix could generate unique sub-IDs. + + Args: + convertible_params: set of fully-qualified parameter names that + should use this transform. 
+ persistent_buffers: dict mapping parameter names (without + *buffer_key_prefix*) to ``MXFP8Tensor`` objects that hold the + receiver's persistent data/scale storage. Empty on the sender + when using ``convert_on_send=True``. + buffer_key_prefix: prefix to strip from ``param_name`` when looking + up entries in *persistent_buffers* (e.g. ``"decoder."``). + convert_on_send: if True, convert BF16 → MXFP8 on the sender and + transmit two tensors (data + scale). If False (default), + transmit BF16 and convert on the receiver in ``finalize_recv``. + """ + + def __init__( + self, + convertible_params: set[str], + persistent_buffers: dict, + buffer_key_prefix: str = "", + convert_on_send: bool = False, + ): + self.convertible_params = convertible_params + self.persistent_buffers = persistent_buffers + self.buffer_key_prefix = buffer_key_prefix + self.convert_on_send = convert_on_send + # Accumulation buffers for 1D-scale params that arrive in partial slices. + # The 1D swizzled FlashInfer scale can't be updated partially; we collect + # all BF16 slices here and quantize the full weight once it's assembled. + # Maps buf_key -> (full-size BF16 accumulation tensor, elements written so far). + self._pending_1d: dict = {} + + def should_transform(self, param_name: str) -> bool: + return param_name in self.convertible_params + + # -- send ---------------------------------------------------------------- + + def prepare_send(self, param_name, src_slice, src_param): + src_data = _ensure_sendable(src_param) + if self.convert_on_send: + + bf16_data = src_data[src_slice].contiguous().to(torch.bfloat16) + mxfp8 = MXFP8Tensor.from_bf16(bf16_data) + return [mxfp8.data.contiguous(), mxfp8.scale.contiguous()] + else: + # BF16 on the wire — same as the default reshard path. 
+ return [src_data[src_slice].contiguous()] + + # -- recv ---------------------------------------------------------------- + + def prepare_recv(self, param_name, dst_slice): + buf_key = param_name.removeprefix(self.buffer_key_prefix) + buf = self.persistent_buffers[buf_key] + + if self.convert_on_send: + # Receive MXFP8 data + scale (2 buffers). + if buf.scale.ndim == 1: + # 1D swizzled scale can't be partially reconstructed from sender-quantized + # slices. Use convert_on_send=False for models with 1D-scale params. + raise NotImplementedError( + f"convert_on_send=True is not supported for parameters with 1D swizzled " + f"scale (param={param_name!r}). Use convert_on_send=False instead, which " + f"receives BF16 and quantizes the full weight on the receiver." + ) + scale_slice = _scale_slice_from_data_slice(dst_slice) + return [ + torch.empty_like(buf.data[dst_slice].contiguous()), + torch.empty_like(buf.scale[scale_slice].contiguous()), + ] + else: + # Receive BF16 (1 buffer, same shape as the MXFP8 data slice). + shape = buf.data[dst_slice].shape + return [torch.empty(shape, dtype=torch.bfloat16, device=buf.data.device)] + + def finalize_recv(self, param_name, dst_slice, recv_buffers): + buf_key = param_name.removeprefix(self.buffer_key_prefix) + buf = self.persistent_buffers[buf_key] + + if self.convert_on_send: + # Already MXFP8 on the wire — copy data and 2D scale slices directly. + # (1D scale is rejected at prepare_recv time, so only 2D reaches here.) + buf.data[dst_slice].copy_(recv_buffers[0]) + scale_slice = _scale_slice_from_data_slice(dst_slice) + buf.scale[scale_slice].copy_(recv_buffers[1]) + elif buf.scale.ndim == 1: + # 1D swizzled scale (FlashInfer format) encodes scale values across the + # full weight tensor; partial updates would corrupt the swizzle layout. + # Accumulate BF16 slices and quantize once all slices are assembled. + if buf_key not in self._pending_1d: + # Use zeros so that any un-filled slice produces zeros rather than garbage. 
+ self._pending_1d[buf_key] = [ + torch.zeros_like(buf.data, dtype=torch.bfloat16), + 0, # elements written so far + ] + accum, written = self._pending_1d[buf_key] + accum[dst_slice].copy_(recv_buffers[0]) + written += recv_buffers[0].numel() + if written >= buf.data.numel(): + if written != buf.data.numel(): + raise AssertionError( + f"1D-scale param {param_name!r}: received {written} elements, " + f"expected {buf.data.numel()} (duplicate or missing slices?)" + ) + mxfp8 = MXFP8Tensor.from_bf16(accum) + buf.data.copy_(mxfp8.data) + buf.scale.copy_(mxfp8.scale) + del self._pending_1d[buf_key] + else: + self._pending_1d[buf_key][1] = written + else: + # 2D scale: each scale row covers exactly one data row, so partial + # row-wise updates are independent and can be applied immediately. + mxfp8 = MXFP8Tensor.from_bf16(recv_buffers[0]) + buf.data[dst_slice].copy_(mxfp8.data) + scale_slice = _scale_slice_from_data_slice(dst_slice) + buf.scale[scale_slice].copy_(mxfp8.scale) diff --git a/megatron/core/ssm/ops/__init__.py b/megatron/core/ssm/ops/__init__.py new file mode 100644 index 00000000000..3e4afde2e29 --- /dev/null +++ b/megatron/core/ssm/ops/__init__.py @@ -0,0 +1 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. diff --git a/megatron/core/ssm/ops/causal_conv1d_triton.py b/megatron/core/ssm/ops/causal_conv1d_triton.py new file mode 100644 index 00000000000..36d14a1d91b --- /dev/null +++ b/megatron/core/ssm/ops/causal_conv1d_triton.py @@ -0,0 +1,274 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +# Some of this code was adopted from https://github.com/Dao-AILab/causal-conv1d/ +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 

import torch
import triton
import triton.language as tl


@triton.jit
def causal_conv1d_update_kernel(
    x_ptr,
    x_b_stride,
    x_s_stride,
    x_c_stride,
    conv_state_ptr,
    conv_state_b_stride,
    conv_state_c_stride,
    conv_state_l_stride,
    int_state_ptr,
    int_state_b_stride,
    int_state_s_stride,
    int_state_c_stride,
    int_state_l_stride,
    weight_ptr,
    weight_c_stride,
    weight_width_stride,
    bias_ptr,
    bias_stride,
    out_ptr,
    out_b_stride,
    out_s_stride,
    out_c_stride,
    conv_state_indices_ptr,
    batch,
    seq_len,
    dim,
    state_len,
    WIDTH: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    HAS_STATE_INDICES: tl.constexpr,
    HAS_INT_STATE: tl.constexpr,
    SILU_ACTIVATION: tl.constexpr,
):
    """Triton implementation of causal_conv1d_update (kernel).

    Grid: (batch, cdiv(dim, BLOCK_DIM)). Each program handles one batch
    element and one block of channels, shifting the per-channel conv state
    left by one per token and producing one output token per sequence step.
    Only WIDTH in {2, 3, 4} is handled by the unrolled weight loads below.
    """
    batch_id = tl.program_id(0)
    channel_block_id = tl.program_id(1)

    channel_offsets = channel_block_id * BLOCK_DIM + tl.arange(0, BLOCK_DIM)
    mask = channel_offsets < dim

    # State batch coordinate mapping
    if HAS_STATE_INDICES:
        state_batch_coord = tl.load(conv_state_indices_ptr + batch_id)
    else:
        state_batch_coord = batch_id

    # Base Pointers
    conv_state_ptrs = (
        conv_state_ptr
        + state_batch_coord * conv_state_b_stride
        + channel_offsets * conv_state_c_stride
    )
    weight_ptrs = weight_ptr + channel_offsets * weight_c_stride

    # Skip padding tokens (block-level uniform condition)
    if state_batch_coord < 0:
        # Negative state index marks a padding slot: emit zeros, touch no state.
        for s in range(seq_len):
            out_ptrs = (
                out_ptr
                + batch_id * out_b_stride
                + s * out_s_stride
                + channel_offsets * out_c_stride
            )
            tl.store(out_ptrs, 0.0, mask=mask)
        return

    # Load Bias
    if HAS_BIAS:
        bias_val = tl.load(bias_ptr + channel_offsets * bias_stride, mask=mask).to(tl.float32)
    else:
        bias_val = tl.zeros([BLOCK_DIM], dtype=tl.float32)

    # Load Weights
    if WIDTH == 2:
        w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32)
        w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32)
    elif WIDTH == 3:
        w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32)
        w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32)
        w2 = tl.load(weight_ptrs + 2 * weight_width_stride, mask=mask).to(tl.float32)
    elif WIDTH == 4:
        w0 = tl.load(weight_ptrs + 0 * weight_width_stride, mask=mask).to(tl.float32)
        w1 = tl.load(weight_ptrs + 1 * weight_width_stride, mask=mask).to(tl.float32)
        w2 = tl.load(weight_ptrs + 2 * weight_width_stride, mask=mask).to(tl.float32)
        w3 = tl.load(weight_ptrs + 3 * weight_width_stride, mask=mask).to(tl.float32)

    # Initialize independent x_vals to match unrolled float array
    x_val_0 = tl.zeros([BLOCK_DIM], dtype=tl.float32)
    x_val_1 = tl.zeros([BLOCK_DIM], dtype=tl.float32)
    x_val_2 = tl.zeros([BLOCK_DIM], dtype=tl.float32)
    x_val_3 = tl.zeros([BLOCK_DIM], dtype=tl.float32)

    # Loop over the sequence dimension (e.g., speculative tokens)
    for s in range(seq_len):
        x_ptrs = x_ptr + batch_id * x_b_stride + s * x_s_stride + channel_offsets * x_c_stride
        out_ptrs = (
            out_ptr + batch_id * out_b_stride + s * out_s_stride + channel_offsets * out_c_stride
        )

        # Load the last (WIDTH - 1) elements to use them BEFORE they are overwritten
        # by the shift
        if WIDTH >= 2:
            x_val_0 = tl.load(
                conv_state_ptrs + (state_len - WIDTH + 1) * conv_state_l_stride, mask=mask
            ).to(tl.float32)
        if WIDTH >= 3:
            x_val_1 = tl.load(
                conv_state_ptrs + (state_len - WIDTH + 2) * conv_state_l_stride, mask=mask
            ).to(tl.float32)
        if WIDTH >= 4:
            x_val_2 = tl.load(
                conv_state_ptrs + (state_len - WIDTH + 3) * conv_state_l_stride, mask=mask
            ).to(tl.float32)

        # Shift the linear state buffer left by 1
        i = 0
        while i < state_len - 1:
            val = tl.load(conv_state_ptrs + (i + 1) * conv_state_l_stride, mask=mask)
            tl.store(conv_state_ptrs + i * conv_state_l_stride, val, mask=mask)
            i += 1

        # Process the single token for the current sequence step
        x_val = tl.load(x_ptrs, mask=mask)

        # Store the new token at the end of the linear state buffer
        tl.store(conv_state_ptrs + (state_len - 1) * conv_state_l_stride, x_val, mask=mask)

        # Write out to the intermediate state buffer if requested
        if HAS_INT_STATE:
            i = 0
            while i < state_len:
                val = tl.load(conv_state_ptrs + i * conv_state_l_stride, mask=mask)
                int_ptr = (
                    int_state_ptr
                    + state_batch_coord * int_state_b_stride
                    + s * int_state_s_stride
                    + channel_offsets * int_state_c_stride
                    + i * int_state_l_stride
                )
                tl.store(int_ptr, val, mask=mask)
                i += 1

        # Advance registers for calculation
        x_val_f32 = x_val.to(tl.float32)
        if WIDTH == 2:
            x_val_1 = x_val_f32
        elif WIDTH == 3:
            x_val_2 = x_val_f32
        elif WIDTH == 4:
            x_val_3 = x_val_f32

        # Compute output
        out_val = bias_val
        if WIDTH == 2:
            out_val += w0 * x_val_0 + w1 * x_val_1
        elif WIDTH == 3:
            out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2
        elif WIDTH == 4:
            out_val += w0 * x_val_0 + w1 * x_val_1 + w2 * x_val_2 + w3 * x_val_3

        if SILU_ACTIVATION:
            out_val = out_val * tl.sigmoid(out_val)

        tl.store(out_ptrs, out_val.to(out_ptrs.dtype.element_ty), mask=mask)


def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor | None,
    # NOTE(review): annotated bool, but compared with == "silu" at the kernel
    # launch below — a True/False value would always disable the activation.
    # Confirm whether callers pass the activation *string* here.
    silu_activation: bool,
    conv_state_indices: torch.Tensor | None,
    intermediate_conv_states: torch.Tensor | None = None,
) -> torch.Tensor:
    """Triton implementation of causal_conv1d_update (entrypoint).

    Args:
        x: input of shape (batch, seq_len, dim) or (batch, dim); a 2D input
            is treated as seq_len == 1.
        conv_state: per-batch rolling state; last dim is the state length.
        weight: depthwise filter; last dim is the filter width (2, 3 or 4).
        bias: optional per-channel bias.
        silu_activation: see NOTE(review) above on its effective type.
        conv_state_indices: optional mapping from batch row to state row;
            negative entries mark padding rows (output zeros).
        intermediate_conv_states: optional (state_rows, seq_len, dim, state_len)
            buffer that receives a snapshot of the state after each step.

    Returns:
        Output tensor with the same shape as ``x``.
    """

    # Check if input is 2D, temporarily treat as 3D for uniform processing
    is_2d = x.dim() == 2
    if is_2d:
        x = x.unsqueeze(1)

    batch, seq_len, dim = x.shape
    out = torch.empty_like(x)
    state_len = conv_state.shape[-1]
    width = weight.shape[-1]

    if bias is not None:
        bias_stride = bias.stride(0)
        has_bias = True
    else:
        bias = x  # Dummy pointer
        bias_stride = 0
        has_bias = False

    if conv_state_indices is not None:
        has_state_indices = True
    else:
        conv_state_indices = x  # Dummy pointer
        has_state_indices = False

    # Extract intermediate state strides if provided
    if intermediate_conv_states is not None:
        has_int_state = True
        int_state_ptr = intermediate_conv_states
        int_state_b_stride = intermediate_conv_states.stride(0)
        int_state_s_stride = intermediate_conv_states.stride(1)
        int_state_c_stride = intermediate_conv_states.stride(2)
        int_state_l_stride = intermediate_conv_states.stride(3)
    else:
        has_int_state = False
        int_state_ptr = x  # Dummy pointer
        int_state_b_stride = 0
        int_state_s_stride = 0
        int_state_c_stride = 0
        int_state_l_stride = 0

    BLOCK_DIM = 64
    grid = (batch, triton.cdiv(dim, BLOCK_DIM))

    causal_conv1d_update_kernel[grid](
        x,
        x.stride(0),
        x.stride(1),
        x.stride(2),
        conv_state,
        conv_state.stride(0),
        conv_state.stride(1),
        conv_state.stride(2),
        int_state_ptr,
        int_state_b_stride,
        int_state_s_stride,
        int_state_c_stride,
        int_state_l_stride,
        weight,
        weight.stride(0),
        weight.stride(1),
        bias,
        bias_stride,
        out,
        out.stride(0),
        out.stride(1),
        out.stride(2),
        conv_state_indices,
        batch,
        seq_len,
        dim,
        state_len,
        WIDTH=width,
        BLOCK_DIM=BLOCK_DIM,
        HAS_BIAS=has_bias,
        HAS_STATE_INDICES=has_state_indices,
        HAS_INT_STATE=has_int_state,
        # NOTE(review): see the parameter annotation above — bool vs "silu" mismatch.
        SILU_ACTIVATION=silu_activation == "silu",
    )

    if is_2d:
        out = out.squeeze(1)

    return out
diff --git a/megatron/core/ssm/ops/causal_conv1d_varlen.py b/megatron/core/ssm/ops/causal_conv1d_varlen.py new file mode 100644 index 00000000000..327ce1bfc48 --- /dev/null +++ b/megatron/core/ssm/ops/causal_conv1d_varlen.py @@ -0,0 +1,259 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Triton varlen depthwise causal 1D convolution with per-sequence initial states and fused SiLU. + +Supports packed variable-length sequences where `causal_conv1d_fn` cannot accept +both `seq_idx` and `initial_states` simultaneously.
"""Triton varlen depthwise causal 1D convolution with per-sequence initial states and fused SiLU.

Supports packed variable-length sequences where `causal_conv1d_fn` cannot accept
both `seq_idx` and `initial_states` simultaneously.
"""

import torch
import triton
import triton.language as tl

from megatron.core.ssm.ops.determinism import autotune_configs


@triton.autotune(
    configs=autotune_configs(
        [
            triton.Config({"BLOCK_T": 128, "BLOCK_C": 64}, num_warps=4),
            triton.Config({"BLOCK_T": 128, "BLOCK_C": 128}, num_warps=4),
            triton.Config({"BLOCK_T": 256, "BLOCK_C": 64}, num_warps=4),
            triton.Config({"BLOCK_T": 256, "BLOCK_C": 128}, num_warps=8),
        ]
    ),
    key=["conv_dim"],
)
@triton.jit
def _causal_conv1d_varlen_kernel(
    x_ptr,                        # (total_tokens, conv_dim), channels-last packed input
    weight_ptr,                   # (conv_dim, WIDTH), depthwise filter taps
    bias_ptr,                     # (conv_dim,); always dereferenced, so bias is required
    seq_idx_ptr,                  # (total_tokens,) per-token request id
    seq_start_ptr,                # (total_tokens,) per-token request start offset
    initial_states_ptr,           # (num_requests, conv_dim, WIDTH - 1); dummy if unused
    out_ptr,                      # (total_tokens, conv_dim)
    total_tokens,
    conv_dim: tl.constexpr,
    initial_states_stride_req,
    initial_states_stride_dim,
    WIDTH: tl.constexpr,
    BLOCK_T: tl.constexpr,
    BLOCK_C: tl.constexpr,
    HAS_INITIAL_STATES: tl.constexpr,
):
    """Depthwise causal conv1d over packed varlen sequences with initial states and SiLU.

    Fully vectorized over BLOCK_T tokens x BLOCK_C channels per thread block.
    Grid: (cdiv(conv_dim, BLOCK_C), cdiv(total_tokens, BLOCK_T)).
    NOTE: the state-length axis of initial_states is indexed with an implicit
    stride of 1 (no stride argument is passed for it), so the caller must
    guarantee that axis is contiguous.
    """
    pid_c = tl.program_id(0)
    pid_t = tl.program_id(1)

    c_off = pid_c * BLOCK_C + tl.arange(0, BLOCK_C)  # (BLOCK_C,)
    c_mask = c_off < conv_dim
    t_off = pid_t * BLOCK_T + tl.arange(0, BLOCK_T)  # (BLOCK_T,)
    t_mask = t_off < total_tokens

    # Load bias: (BLOCK_C,) broadcast to (BLOCK_T, BLOCK_C); accumulate in fp32.
    bias = tl.load(bias_ptr + c_off, mask=c_mask, other=0.0).to(tl.float32)
    acc = tl.zeros((BLOCK_T, BLOCK_C), dtype=tl.float32) + bias[None, :]

    # Load per-token request ID and request start position
    req_id = tl.load(seq_idx_ptr + t_off, mask=t_mask, other=0)  # (BLOCK_T,)
    req_start = tl.load(seq_start_ptr + t_off, mask=t_mask, other=0)  # (BLOCK_T,)

    # Unrolled convolution over WIDTH taps (typically 4)
    for j in tl.static_range(WIDTH):
        # Load weight column j: (BLOCK_C,)
        w_j = tl.load(weight_ptr + c_off * WIDTH + j, mask=c_mask, other=0.0).to(tl.float32)

        # Source position for this tap
        src = t_off - (WIDTH - 1) + j  # (BLOCK_T,)
        in_seq = src >= req_start  # (BLOCK_T,) — True if source is within the sequence

        # Load from x for in-sequence positions (mask out out-of-bounds).
        # src_safe clamps negatives so the address is valid even when masked off.
        src_safe = tl.maximum(src, 0)
        x_val = tl.load(
            x_ptr + src_safe[:, None] * conv_dim + c_off[None, :],
            mask=t_mask[:, None] & c_mask[None, :] & in_seq[:, None],
            other=0.0,
        ).to(tl.float32)  # (BLOCK_T, BLOCK_C)

        if HAS_INITIAL_STATES:
            # For tokens where src < req_start, load from initial_states.
            # state_col maps a pre-sequence offset to a column of the
            # (WIDTH - 1)-wide initial state buffer.
            state_col = (WIDTH - 1) - (req_start - src)  # (BLOCK_T,)
            valid_state = (~in_seq) & (state_col >= 0)  # (BLOCK_T,)
            state_col_safe = tl.maximum(state_col, 0)

            state_val = tl.load(
                initial_states_ptr
                + req_id[:, None] * initial_states_stride_req
                + c_off[None, :] * initial_states_stride_dim
                + state_col_safe[:, None],
                mask=t_mask[:, None] & c_mask[None, :] & valid_state[:, None],
                other=0.0,
            ).to(tl.float32)  # (BLOCK_T, BLOCK_C)

            tap = tl.where(in_seq[:, None], x_val, state_val)
        else:
            tap = x_val

        acc += tap * w_j[None, :]

    # SiLU activation: x * sigmoid(x)
    sigmoid_acc = 1.0 / (1.0 + tl.exp(-acc))
    result = acc * sigmoid_acc

    # Store output (cast back to input dtype)
    tl.store(
        out_ptr + t_off[:, None] * conv_dim + c_off[None, :],
        result,
        mask=t_mask[:, None] & c_mask[None, :],
    )
def causal_conv1d_varlen_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    cu_seqlens: torch.Tensor,
    initial_states: torch.Tensor = None,
    activation: str = "silu",
    precomputed_seq_idx: torch.Tensor = None,
    precomputed_seq_start: torch.Tensor = None,
) -> torch.Tensor:
    """Depthwise causal 1D convolution over packed variable-length sequences.

    Supports both `cu_seqlens` (sequence boundaries) and `initial_states`
    simultaneously, unlike `causal_conv1d_fn` which requires mutual exclusivity
    between `seq_idx` and `initial_states`.

    Args:
        x: Input tensor of shape (total_tokens, conv_dim), channels-last packed.
        weight: Convolution weights of shape (conv_dim, d_conv).
        bias: Bias of shape (conv_dim,).
        cu_seqlens: Cumulative sequence lengths of shape (num_requests + 1,), int32.
        initial_states: Per-request initial conv states of shape
            (num_requests, conv_dim, d_conv - 1). If None, uses zeros. The last
            (state-length) dimension must be contiguous — the kernel addresses
            it with an implicit stride of 1.
        activation: Activation function, must be "silu".
        precomputed_seq_idx: Precomputed per-token request ID of shape
            (total_tokens,). If provided, skips repeat_interleave (CUDA
            graph compatible). Padding tokens should use 0 as sentinel.
        precomputed_seq_start: Precomputed per-token request start position
            of shape (total_tokens,). Must be provided together with
            precomputed_seq_idx.

    Returns:
        Output tensor of shape (total_tokens, conv_dim).

    Raises:
        AssertionError: On unsupported activation, non-contiguous inputs, or an
            `initial_states` layout the kernel cannot address.
    """
    assert activation == "silu", f"Only silu activation is supported, got {activation}"
    assert x.is_contiguous(), "x must be contiguous"
    assert weight.is_contiguous(), "weight must be contiguous"

    total_tokens, conv_dim = x.shape
    d_conv = weight.shape[1]
    num_requests = cu_seqlens.shape[0] - 1

    out = torch.empty_like(x)

    # Use precomputed per-token metadata if provided (CUDA graph compatible),
    # otherwise compute from cu_seqlens via repeat_interleave.
    if precomputed_seq_idx is not None:
        assert precomputed_seq_start is not None
        seq_idx = precomputed_seq_idx
        seq_start = precomputed_seq_start
    else:
        seq_lengths = cu_seqlens[1:] - cu_seqlens[:-1]
        seq_idx = torch.repeat_interleave(
            torch.arange(num_requests, device=x.device, dtype=torch.int32), seq_lengths
        )
        seq_start = torch.repeat_interleave(cu_seqlens[:-1], seq_lengths).to(torch.int32)

    has_initial_states = initial_states is not None
    if not has_initial_states:
        initial_states = torch.empty((1, 1, 1), dtype=x.dtype, device=x.device)  # dummy
        is_stride_req = 1
        is_stride_dim = 1
    else:
        if precomputed_seq_idx is None:
            # Under CUDA graphs the request count may be padded, so the shape
            # check only applies when metadata is derived from cu_seqlens.
            assert initial_states.shape == (num_requests, conv_dim, d_conv - 1)
        # BUGFIX: the kernel indexes the state-length axis with an implicit
        # stride of 1 (no stride argument is passed for it). A non-contiguous
        # last dim would previously read garbage silently; fail loudly instead.
        # (A zero/one-sized last dim has an arbitrary stride and is harmless.)
        assert initial_states.shape[-1] <= 1 or initial_states.stride(-1) == 1, (
            "initial_states must be contiguous along its last (state-length) dimension"
        )
        is_stride_req = initial_states.stride(0)
        is_stride_dim = initial_states.stride(1)

    grid = lambda meta: (
        triton.cdiv(conv_dim, meta["BLOCK_C"]),
        triton.cdiv(total_tokens, meta["BLOCK_T"]),
    )

    _causal_conv1d_varlen_kernel[grid](
        x,
        weight,
        bias,
        seq_idx,
        seq_start,
        initial_states,
        out,
        total_tokens,
        conv_dim,
        is_stride_req,
        is_stride_dim,
        WIDTH=d_conv,
        HAS_INITIAL_STATES=has_initial_states,
    )

    return out
+ """ + total_tokens, conv_dim = x.shape + d_conv = weight.shape[1] + num_requests = cu_seqlens.shape[0] - 1 + + for r in range(num_requests): + start = cu_seqlens[r].item() + end = cu_seqlens[r + 1].item() + seq_len = end - start + + if seq_len == 0: + continue + + if initial_states is not None: + init_state = initial_states[r] # (conv_dim, d_conv - 1) + else: + init_state = torch.zeros((conv_dim, d_conv - 1), dtype=x.dtype, device=x.device) + + x_seq = x[start:end] # (seq_len, conv_dim) + + for t in range(seq_len): + acc = bias.float() # (conv_dim,) + for j in range(d_conv): + src_pos = t - (d_conv - 1) + j + if src_pos < 0: + state_col = (d_conv - 1) + src_pos + if state_col >= 0 and state_col < d_conv - 1: + tap = init_state[:, state_col].float() + else: + tap = torch.zeros(conv_dim, dtype=torch.float32, device=x.device) + else: + tap = x_seq[src_pos].float() + + acc = acc + tap * weight[:, j].float() + + result = acc * torch.sigmoid(acc) + out[start + t] = result.to(out.dtype) diff --git a/megatron/core/ssm/ops/determinism.py b/megatron/core/ssm/ops/determinism.py new file mode 100644 index 00000000000..d642e8bba01 --- /dev/null +++ b/megatron/core/ssm/ops/determinism.py @@ -0,0 +1,123 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2024, Tri Dao, Albert Gu. + +# Some of this code was adopted from https://github.com/state-spaces/mamba/ +# This source code is licensed under the Apache license found in the +# LICENSE file in the root directory of this source tree. 
import os
import warnings

import torch
from packaging import version

try:
    import triton

    TRITON_VERSION = version.parse(triton.__version__)
except ImportError:
    TRITON_VERSION = version.parse("0.0.0")

# Triton >= 3.4.0 can cache autotune results across runs (TRITON_CACHE_AUTOTUNING).
TRITON_HAS_CACHE_RESULTS = TRITON_VERSION >= version.parse("3.4.0")
_autotune_warning_issued = False

# Tri-state: None -> defer to env / torch global flag; True/False -> forced.
_deterministic_override = None


def use_deterministic_mode():
    """Use torch deterministic mode.

    Precedence: explicit override set via ``set_deterministic_mode``, then the
    ``MAMBA_DETERMINISTIC`` env var (first character '1' means on), then
    torch's global deterministic-algorithms flag.
    """
    if _deterministic_override is not None:
        return _deterministic_override
    flag = os.environ.get('MAMBA_DETERMINISTIC')
    if not flag:
        return torch.are_deterministic_algorithms_enabled()
    return flag.startswith('1')


def set_deterministic_mode(value):
    """Set torch deterministic mode (None restores env/torch-flag behavior)."""
    global _deterministic_override
    _deterministic_override = value


def _estimate_config_cost(cfg):
    """Estimate shared memory cost of a config. Lower is cheaper.

    Returns a tuple (block_cost, num_warps) so that ties in block cost
    are broken deterministically by warp count (fewer warps = cheaper).
    """
    block_sizes = [
        val for key, val in cfg.kwargs.items() if key.startswith('BLOCK') and isinstance(val, int)
    ]
    tile_volume = 1
    for size in block_sizes:
        tile_volume *= size
    n_stages = getattr(cfg, 'num_stages', 1) or 1
    n_warps = getattr(cfg, 'num_warps', 1) or 1
    return (tile_volume * n_stages, n_warps)


def _filter_configs_by_block_sizes(configs):
    """Filter configs by TRITON_AUTOTUNE_BLOCK_* env vars.

    Scans environment for any variable matching TRITON_AUTOTUNE_BLOCK_*
    (e.g. TRITON_AUTOTUNE_BLOCK_SIZE_M, TRITON_AUTOTUNE_BLOCK_SIZE_H,
    TRITON_AUTOTUNE_BLOCK_T, TRITON_AUTOTUNE_BLOCK_C, TRITON_AUTOTUNE_BLOCK_SIZE)
    and maps them to the corresponding kernel kwarg (BLOCK_SIZE_M, BLOCK_SIZE_H,
    BLOCK_T, BLOCK_C, BLOCK_SIZE). Returns a single-config list, or None when
    no filter is set or nothing matches.
    """
    prefix = "TRITON_AUTOTUNE_"
    wanted = {
        name[len(prefix) :]: int(value)
        for name, value in os.environ.items()
        if name.startswith(prefix + "BLOCK") and value
    }
    if not wanted:
        return None
    candidates = list(configs)
    # Sorted iteration keeps the narrowing order deterministic.
    for kwarg_name, target in sorted(wanted.items()):
        candidates = [cfg for cfg in candidates if cfg.kwargs.get(kwarg_name) == target]
    return candidates[:1] or None


def autotune_configs(configs):
    """Select autotune configs for deterministic mode.

    Uses cached autotuning (TRITON_CACHE_AUTOTUNING=1) if Triton >= 3.4.0,
    otherwise auto-selects the cheapest config by block size * stages.
    """
    if not configs or not use_deterministic_mode():
        return configs
    if TRITON_HAS_CACHE_RESULTS and os.environ.get("TRITON_CACHE_AUTOTUNING") == "1":
        return configs
    global _autotune_warning_issued
    if not _autotune_warning_issued:
        _autotune_warning_issued = True
        if TRITON_HAS_CACHE_RESULTS:
            hint = "Deterministic mode: set TRITON_CACHE_AUTOTUNING=1 for cached autotuning."
        else:
            hint = "Deterministic mode: upgrade to Triton >= 3.4.0 for cached autotuning."
        warnings.warn(hint)
    pinned = _filter_configs_by_block_sizes(configs)
    if pinned:
        return pinned
    return [min(configs, key=_estimate_config_cost)]


def alloc_tile_workspace(base_shape, tile_dim, dtype, device, deterministic, *, zero_init=True):
    """Allocate buffer for deterministic per-program reductions.

    Deterministic mode gets an extra trailing ``tile_dim`` axis (one slot per
    program) and returns its stride; otherwise a plain buffer with stride 0.
    """
    if base_shape is None:
        return None, 0
    if not deterministic:
        return torch.empty(*base_shape, device=device, dtype=dtype), 0
    make = torch.zeros if zero_init else torch.empty
    workspace = make(*base_shape, tile_dim, device=device, dtype=dtype)
    return workspace, workspace.stride(-1)


def finalize_tile_workspace(tensor, deterministic):
    """Finalize tile workspace (reduce the per-tile axis in deterministic mode)."""
    if tensor is None:
        return None
    return tensor.sum(dim=-1) if deterministic else tensor
import torch
import triton
import triton.language as tl
from packaging import version

TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")


if TRITON3:

    @triton.jit
    def softplus(dt):
        """Optimized softplus."""
        # Triton 3.x path: log(exp(dt) + 1) via tl.math.
        return tl.math.log(tl.math.exp(dt) + 1)

else:

    @triton.jit
    def softplus(dt):
        """Optimized softplus."""
        return tl.math.log1p(tl.exp(dt))


# Optional-feature flags are derived from whether the corresponding pointer is
# None, so a single kernel covers every argument combination.
@triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None})
@triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None})
@triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None})
@triton.heuristics(
    {"HAS_STATE_BATCH_INDICES": lambda args: args["state_batch_indices_ptr"] is not None}
)
@triton.heuristics({"HAS_INT_STATE": lambda args: args["int_state_ptr"] is not None})
@triton.heuristics({"BLOCK_SIZE_DSTATE": lambda args: triton.next_power_of_2(args["dstate"])})
@triton.jit
def _selective_scan_update_kernel(
    # Pointers to matrices
    state_ptr,
    x_ptr,
    dt_ptr,
    dt_bias_ptr,
    A_ptr,
    B_ptr,
    C_ptr,
    D_ptr,
    z_ptr,
    out_ptr,
    state_batch_indices_ptr,
    int_state_ptr,
    # Matrix dimensions
    batch,
    seq_len,
    nheads,
    dim,
    dstate,
    nheads_ngroups_ratio,
    # Strides
    stride_state_batch,
    stride_state_head,
    stride_state_dim,
    stride_state_dstate,
    stride_x_batch,
    stride_x_seq,
    stride_x_head,
    stride_x_dim,
    stride_dt_batch,
    stride_dt_seq,
    stride_dt_head,
    stride_dt_dim,
    stride_dt_bias_head,
    stride_dt_bias_dim,
    stride_A_head,
    stride_A_dim,
    stride_A_dstate,
    stride_B_batch,
    stride_B_seq,
    stride_B_group,
    stride_B_dstate,
    stride_C_batch,
    stride_C_seq,
    stride_C_group,
    stride_C_dstate,
    stride_D_head,
    stride_D_dim,
    stride_z_batch,
    stride_z_seq,
    stride_z_head,
    stride_z_dim,
    stride_out_batch,
    stride_out_seq,
    stride_out_head,
    stride_out_dim,
    stride_int_batch,
    stride_int_seq,
    stride_int_head,
    stride_int_dim,
    stride_int_dstate,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    TIE_HDIM: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    HAS_D: tl.constexpr,
    HAS_Z: tl.constexpr,
    HAS_STATE_BATCH_INDICES: tl.constexpr,
    HAS_INT_STATE: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
):
    """Selective SSM state update: h_t = dA * h_{t-1} + dB * x_t, y_t = sum(h_t * C).

    One program handles BLOCK_SIZE_M channels of one (batch, head) pair across
    all seq_len steps; the final state is written back in place.
    Grid: (cdiv(dim, BLOCK_SIZE_M), batch, nheads).
    """
    pid_m = tl.program_id(axis=0)
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = tl.arange(0, BLOCK_SIZE_DSTATE)

    out_ptr += pid_b * stride_out_batch + pid_h * stride_out_head
    out_ptrs = out_ptr + offs_m * stride_out_dim

    # 1. State Mapping (handles dynamic batching slot allocation)
    if HAS_STATE_BATCH_INDICES:
        state_batch_indices_ptr += pid_b
        state_batch_idx = tl.load(state_batch_indices_ptr)
        # Skip padding tokens (e.g. from graph capture or inactive slots)
        if state_batch_idx < 0:
            # Output is zeroed so downstream consumers see deterministic values.
            for s in range(seq_len):
                out_s_ptrs = out_ptrs + s * stride_out_seq
                tl.store(out_s_ptrs, 0.0, mask=offs_m < dim)
            return
        state_ptr += state_batch_idx * stride_state_batch + pid_h * stride_state_head
        if HAS_INT_STATE:
            int_state_ptr += state_batch_idx * stride_int_batch + pid_h * stride_int_head
    else:
        state_ptr += pid_b * stride_state_batch + pid_h * stride_state_head
        if HAS_INT_STATE:
            int_state_ptr += pid_b * stride_int_batch + pid_h * stride_int_head

    # Base Pointers for Sequence iteration
    x_ptr += pid_b * stride_x_batch + pid_h * stride_x_head
    dt_ptr += pid_b * stride_dt_batch + pid_h * stride_dt_head
    if HAS_DT_BIAS:
        dt_bias_ptr += pid_h * stride_dt_bias_head

    A_ptr += pid_h * stride_A_head
    # B/C are shared across nheads_ngroups_ratio heads (grouped heads).
    B_ptr += pid_b * stride_B_batch + (pid_h // nheads_ngroups_ratio) * stride_B_group
    C_ptr += pid_b * stride_C_batch + (pid_h // nheads_ngroups_ratio) * stride_C_group
    if HAS_Z:
        z_ptr += pid_b * stride_z_batch + pid_h * stride_z_head

    # Constant offsets (A, D, and bias do not have a sequence dimension)
    state_ptrs = state_ptr + (
        offs_m[:, None] * stride_state_dim + offs_n[None, :] * stride_state_dstate
    )
    if HAS_INT_STATE:
        int_state_ptrs = int_state_ptr + (
            offs_m[:, None] * stride_int_dim + offs_n[None, :] * stride_int_dstate
        )

    x_ptrs = x_ptr + offs_m * stride_x_dim
    dt_ptrs = dt_ptr + offs_m * stride_dt_dim

    if HAS_DT_BIAS:
        dt_bias_ptrs = dt_bias_ptr + offs_m * stride_dt_bias_dim
    if HAS_D:
        D_ptr += pid_h * stride_D_head
        D_ptrs = D_ptr + offs_m * stride_D_dim

    A_ptrs = A_ptr + (offs_m[:, None] * stride_A_dim + offs_n[None, :] * stride_A_dstate)
    B_ptrs = B_ptr + offs_n * stride_B_dstate
    C_ptrs = C_ptr + offs_n * stride_C_dstate

    if HAS_Z:
        z_ptrs = z_ptr + offs_m * stride_z_dim

    # Load initial historical state and constant parameters
    state = tl.load(
        state_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
    ).to(tl.float32)

    if not TIE_HDIM:
        A = tl.load(
            A_ptrs, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate), other=0.0
        ).to(tl.float32)
    else:
        # TIE_HDIM: A (and dt) are broadcast across dim/dstate, so load a scalar.
        A = tl.load(A_ptr).to(tl.float32)

    if HAS_D:
        D = tl.load(D_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

    # ----------------------------------------------------
    # Sequence Loop (Processes Main Token + Speculative Drafts)
    # ----------------------------------------------------
    for s in range(seq_len):
        x_s_ptrs = x_ptrs + s * stride_x_seq
        dt_s_ptrs = dt_ptrs + s * stride_dt_seq
        B_s_ptrs = B_ptrs + s * stride_B_seq
        C_s_ptrs = C_ptrs + s * stride_C_seq
        if HAS_Z:
            z_s_ptrs = z_ptrs + s * stride_z_seq

        x = tl.load(x_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

        # Calculate dt and dA
        if not TIE_HDIM:
            dt = tl.load(dt_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)
            if DT_SOFTPLUS:
                # Pass-through above 20.0 where softplus(x) ~= x, avoiding exp overflow.
                dt = tl.where(dt <= 20.0, softplus(dt), dt)
            dA = tl.exp(A * dt[:, None])
        else:
            dt = tl.load(dt_ptr + s * stride_dt_seq).to(tl.float32)
            if HAS_DT_BIAS:
                dt += tl.load(dt_bias_ptr).to(tl.float32)
            if DT_SOFTPLUS:
                dt = tl.where(dt <= 20.0, softplus(dt), dt)
            dA = tl.exp(A * dt)

        # Load B and C
        B = tl.load(B_s_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
        C = tl.load(C_s_ptrs, mask=offs_n < dstate, other=0.0).to(tl.float32)
        if HAS_Z:
            z = tl.load(z_s_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32)

        if not TIE_HDIM:
            dB = B[None, :] * dt[:, None]
        else:
            dB = B * dt

        # ----------------------------------------------------
        # The Core State Recurrence (h_t = dA * h_{t-1} + dB * x_t)
        # ----------------------------------------------------
        state = state * dA + dB * x[:, None]

        # ----------------------------------------------------
        # Dump Intermediate Speculative State Snapshot
        # ----------------------------------------------------
        if HAS_INT_STATE:
            int_state_s_ptrs = int_state_ptrs + s * stride_int_seq
            tl.store(
                int_state_s_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate)
            )

        # Calculate Output
        out = tl.sum(state * C[None, :], axis=1)
        if HAS_D:
            out += x * D
        if HAS_Z:
            # SiLU gating: out *= z * sigmoid(z)
            out *= z * tl.sigmoid(z)

        out_s_ptrs = out_ptrs + s * stride_out_seq
        tl.store(out_s_ptrs, out, mask=offs_m < dim)

    # After processing all sequence steps, persist the final state back to HBM
    tl.store(state_ptrs, state, mask=(offs_m[:, None] < dim) & (offs_n[None, :] < dstate))
def selective_state_update(
    state,
    x,
    dt,
    A,
    B,
    C,
    D=None,
    z=None,
    dt_bias=None,
    dt_softplus=False,
    state_batch_indices=None,
    intermediate_ssm_states=None,
):
    """
    Argument:
        state: (batch, dim, dstate) or (batch, nheads, dim, dstate)
        x: (batch, dim), (batch, seqlen, dim), (batch, nheads, dim) or (batch, seqlen, nheads, dim)
        dt: Matches x
        A: (dim, dstate) or (nheads, dim, dstate)
        B: (batch, dstate), (batch, seqlen, dstate), (batch, ngroups, dstate) or
            (batch, seqlen, ngroups, dstate)
        C: Matches B
        D: (dim,) or (nheads, dim)
        z: Matches x
        dt_bias: (dim,) or (nheads, dim)
        intermediate_ssm_states: Optional buffer of shape (batch, seqlen, nheads, dim, dstate)
            or (batch, seqlen, dim, dstate)
    Return:
        out: shape matches x

    Updates `state` in place and returns the per-step SSM outputs. The various
    input layouts are normalized to (batch, seqlen, nheads, dim) before a single
    kernel launch, and the added dims are squeezed off the output afterwards.
    """
    has_heads = state.dim() > 3
    if not has_heads:
        # unsqueeze returns a view, so the kernel's in-place state write persists.
        state = state.unsqueeze(1)

    # Standardize inputs to explicit sequence and head dimensions: (batch, seq_len, nheads, dim)
    is_seq_unsq = False
    if has_heads:
        if x.dim() == 3:  # (batch, nheads, dim) -> (batch, 1, nheads, dim)
            x = x.unsqueeze(1)
            dt = dt.unsqueeze(1)
            B = B.unsqueeze(1)
            C = C.unsqueeze(1)
            if z is not None:
                z = z.unsqueeze(1)
            is_seq_unsq = True
    else:
        if x.dim() == 2:  # (batch, dim) -> (batch, 1, 1, dim)
            x = x.unsqueeze(1).unsqueeze(2)
            dt = dt.unsqueeze(1).unsqueeze(2)
            B = B.unsqueeze(1).unsqueeze(2)
            C = C.unsqueeze(1).unsqueeze(2)
            if z is not None:
                z = z.unsqueeze(1).unsqueeze(2)
            is_seq_unsq = True
        elif x.dim() == 3:  # (batch, seqlen, dim) -> (batch, seqlen, 1, dim)
            x = x.unsqueeze(2)
            dt = dt.unsqueeze(2)
            B = B.unsqueeze(2)
            C = C.unsqueeze(2)
            if z is not None:
                z = z.unsqueeze(2)

    # Give A / D / dt_bias an explicit head axis so strides line up with the kernel.
    if A.dim() == 2:
        A = A.unsqueeze(0)
    if D is not None and D.dim() == 1:
        D = D.unsqueeze(0)
    if dt_bias is not None and dt_bias.dim() == 1:
        dt_bias = dt_bias.unsqueeze(0)

    # Set up Intermediate State standardization
    if intermediate_ssm_states is not None:
        if not has_heads and intermediate_ssm_states.dim() == 4:
            intermediate_ssm_states = intermediate_ssm_states.unsqueeze(
                2
            )  # (batch, seqlen, 1, dim, dstate)
        int_state_strides = (
            intermediate_ssm_states.stride(0),
            intermediate_ssm_states.stride(1),
            intermediate_ssm_states.stride(2),
            intermediate_ssm_states.stride(3),
            intermediate_ssm_states.stride(4),
        )
    else:
        intermediate_ssm_states = x  # Dummy pointer
        int_state_strides = (0, 0, 0, 0, 0)

    batch, seq_len, nheads, dim = x.shape
    dstate = state.shape[-1]
    ngroups = B.shape[-2]

    out = torch.empty_like(x)
    grid = lambda META: (triton.cdiv(dim, META['BLOCK_SIZE_M']), batch, nheads)
    z_strides = (
        (z.stride(0), z.stride(1), z.stride(2), z.stride(3)) if z is not None else (0, 0, 0, 0)
    )

    # Heuristic block size / warp count: larger dstate -> fewer rows per program.
    BLOCK_SIZE_M, num_warps = (
        (32, 4)
        if dstate <= 16
        else (
            (16, 4)
            if dstate <= 32
            else ((8, 4) if dstate <= 64 else ((4, 4) if dstate <= 128 else ((4, 8))))
        )
    )

    # Zero strides mean A/dt/dt_bias are broadcast across dim/dstate -> the
    # kernel can load them as scalars (TIE_HDIM fast path).
    tie_hdim = (
        A.stride(-1) == 0
        and A.stride(-2) == 0
        and dt.stride(-1) == 0
        and (dt_bias is None or dt_bias.stride(-1) == 0)
    )

    with torch.cuda.device(x.device.index):
        _selective_scan_update_kernel[grid](
            state,
            x,
            dt,
            dt_bias,
            A,
            B,
            C,
            D,
            z,
            out,
            state_batch_indices,
            intermediate_ssm_states,
            batch,
            seq_len,
            nheads,
            dim,
            dstate,
            nheads // ngroups,
            state.stride(0),
            state.stride(1),
            state.stride(2),
            state.stride(3),
            x.stride(0),
            x.stride(1),
            x.stride(2),
            x.stride(3),
            dt.stride(0),
            dt.stride(1),
            dt.stride(2),
            dt.stride(3),
            *(dt_bias.stride(0), dt_bias.stride(1)) if dt_bias is not None else (0, 0),
            A.stride(0),
            A.stride(1),
            A.stride(2),
            B.stride(0),
            B.stride(1),
            B.stride(2),
            B.stride(3),
            C.stride(0),
            C.stride(1),
            C.stride(2),
            C.stride(3),
            *(D.stride(0), D.stride(1)) if D is not None else (0, 0),
            z_strides[0],
            z_strides[1],
            z_strides[2],
            z_strides[3],
            out.stride(0),
            out.stride(1),
            out.stride(2),
            out.stride(3),
            *int_state_strides,
            dt_softplus,
            tie_hdim,
            BLOCK_SIZE_M,
            num_warps=num_warps,
        )

    # Revert dimensions back to match original x format
    if not has_heads:
        out = out.squeeze(2)
    if is_seq_unsq:
        out = out.squeeze(1)

    return out
# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_bmm.py
# Adapted from vLLM project (Apache-2.0).

import torch
import triton
import triton.language as tl

from megatron.core.ssm.ops.determinism import autotune_configs


@triton.autotune(
    configs=autotune_configs(
        [
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=2,
            ),
        ]
    ),
    key=["chunk_size", "K", "IS_CAUSAL"],
)
@triton.jit
def _bmm_chunk_fwd_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    out_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    seqlen,
    chunk_size: tl.constexpr,
    K: tl.constexpr,
    ngroups: tl.constexpr,
    stride_a_seqlen: tl.int64,
    stride_a_head: tl.int64,
    stride_ak: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_bk: tl.constexpr,
    stride_out_chunk: tl.int64,
    stride_out_head: tl.int64,
    stride_outm: tl.int64,
    stride_outn: tl.constexpr,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    dot_dtype: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    """Per-chunk batched matmul out[c] = a[c] @ b[c].T over varlen chunks.

    Grid axis 0 enumerates (m, n) output tiles; axis 1 enumerates
    (chunk, group) pairs. Chunk boundaries come from cu_chunk_seqlens.
    """
    # int64 program id keeps chunk*ngroups indexing safe for large launches.
    pid_ch = tl.program_id(axis=1).to(tl.int64)
    pid_c = pid_ch // ngroups
    pid_h = pid_ch - pid_c * ngroups
    num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    if IS_CAUSAL:
        # Tiles strictly above the diagonal are never read by causal consumers.
        if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:
            return

    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)

    a_ptr += chunk_seqlen_start * stride_a_seqlen + pid_h * stride_a_head
    b_ptr += chunk_seqlen_start * stride_b_seqlen + pid_h * stride_b_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)
    b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)
    # Actual token count of this chunk (the last chunk may be short).
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    # compute a * b.T
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(
            a_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        ).to(dot_dtype)
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit),
            other=0.0,
        ).to(dot_dtype)
        acc += tl.dot(a, b)
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    # Recomputed (unchanged) offsets for the store address.
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    out = acc.to(out_ptr.dtype.element_ty)
    out_ptr += pid_c * stride_out_chunk + pid_h * stride_out_head
    out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)
    # Store the full chunk_size x chunk_size tile; rows/cols past
    # chunk_size_limit carry the zero-masked partial products.
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))


def _bmm_chunk_fwd(a, b, chunk_size, cu_chunk_seqlens, causal=False, output_dtype=None):
    """
    Argument:
        a: (seqlen, ngroups, k)
        b: (seqlen, ngroups, k)
        chunk_size: int
        cu_chunk_seq_lens: (nchunks+1,)
        causal: if True, then out[i, j] for i > j will be arbitrary, only out[i, j] for i <= j are
            guaranteed to be correct.
    Return:
        out: (nchunks, ngroups, chunk_size, chunk_size)
    """
    seqlen, ngroups, k = a.shape
    assert b.shape == a.shape
    # Make sure at least one of the first/last strides is 1 so tl.dot operands
    # have a contiguous axis; otherwise force a contiguous copy.
    if a.stride(-1) != 1 and a.stride(0) != 1:
        a = a.contiguous()
    if b.stride(-1) != 1 and b.stride(0) != 1:
        b = b.contiguous()

    nchunks = len(cu_chunk_seqlens) - 1
    # Allocates output.
    out_dtype = a.dtype if output_dtype is None else output_dtype
    out = torch.empty((nchunks, ngroups, chunk_size, chunk_size), device=a.device, dtype=out_dtype)
    # Accumulate tl.dot in the widest input dtype (fp32 if neither half type).
    dot_dtype = (
        tl.bfloat16
        if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16
        else (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32)
    )
    grid = lambda META: (
        triton.cdiv(chunk_size, META["BLOCK_SIZE_M"])
        * triton.cdiv(chunk_size, META["BLOCK_SIZE_N"]),
        nchunks * ngroups,
    )
    with torch.cuda.device(a.device.index):
        _bmm_chunk_fwd_kernel[grid](
            a_ptr=a,
            b_ptr=b,
            out_ptr=out,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            seqlen=seqlen,
            chunk_size=chunk_size,
            K=k,
            ngroups=ngroups,
            stride_a_seqlen=a.stride(0),
            stride_a_head=a.stride(1),
            stride_ak=a.stride(2),
            stride_b_seqlen=b.stride(0),
            stride_b_head=b.stride(1),
            stride_bk=b.stride(2),
            stride_out_chunk=out.stride(0),
            stride_out_head=out.stride(1),
            stride_outm=out.stride(-2),
            stride_outn=out.stride(-1),
            IS_CAUSAL=causal,
            dot_dtype=dot_dtype,
        )
    return out
# Forward chunk-scan kernel. Each program computes one (BLOCK_SIZE_M, BLOCK_SIZE_N)
# output tile for one (chunk pid_c, head pid_h):
#   acc  = (C @ prev_state) * exp(dA_cs_m)                      # inter-chunk term
#   acc += sum_k cb[m,k] * exp(dA_cs_m - dA_cs_k) * dt_k * x[k]  # intra-chunk (causal) term
# then optionally adds the D skip connection (HAS_D) and the z * sigmoid(z)
# gate (HAS_Z) before storing the tile into out.
@triton.autotune(
    configs=autotune_configs(
        [
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 64},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 64},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=2,
            ),
        ]
    ),
    key=["chunk_size", "hdim", "dstate", "IS_CAUSAL"],
)
@triton.jit
def _chunk_scan_fwd_kernel(
    # Pointers to matrices
    cb_ptr,
    x_ptr,
    z_ptr,
    out_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    seq_idx_ptr,
    C_ptr,
    states_ptr,
    D_ptr,
    initstates_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    chunk_size: tl.constexpr,
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    seqlen,
    nheads_ngroups_ratio: tl.constexpr,
    # Strides
    stride_cb_chunk: tl.int64,
    stride_cb_head: tl.int64,
    stride_cb_csize_m: tl.int64,
    stride_cb_csize_k: tl.constexpr,
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_z_seqlen: tl.int64,
    stride_z_head: tl.int64,
    stride_z_hdim: tl.constexpr,
    stride_out_seqlen: tl.int64,
    stride_out_head: tl.int64,
    stride_out_hdim: tl.constexpr,
    stride_dt_chunk: tl.int64,
    stride_dt_head: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_seq_idx_chunk: tl.constexpr,
    stride_C_seqlen: tl.int64,
    stride_C_head: tl.int64,
    stride_C_dstate: tl.constexpr,
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_init_states_batch: tl.int64,
    stride_init_states_head: tl.int64,
    stride_init_states_hdim: tl.int64,
    stride_init_states_dstate: tl.constexpr,
    stride_D_head: tl.constexpr,
    # Meta-parameters
    IS_CAUSAL: tl.constexpr,
    HAS_D: tl.constexpr,
    D_HAS_HDIM: tl.constexpr,
    HAS_Z: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    BLOCK_SIZE_DSTATE: tl.constexpr,
    IS_TRITON_22: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
):
    # Grid: axis 0 = flattened (M, N) tile index, axis 1 = chunk, axis 2 = head.
    # pid_c widened to int64 so pointer-offset products cannot overflow int32.
    pid_c = tl.program_id(axis=1).to(tl.int64)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(hdim, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # cb/C are per-group, not per-head; collapse head -> group via the ratio.
    cb_ptr += pid_c * stride_cb_chunk + (pid_h // nheads_ngroups_ratio) * stride_cb_head
    # cu_chunk_seqlens delimits this (possibly shortened, varlen) chunk in the sequence.
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
    x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    C_ptr += chunk_seqlen_start * stride_C_seqlen + (pid_h // nheads_ngroups_ratio) * stride_C_head

    # M-block offsets and prev states
    # - logic in next block may override these if there is an active offset
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)

    # seq_idx is per chunk; `other=-1` makes the very first chunk (pid_c == 0)
    # look like a sequence boundary.
    seq_idx_ptr += pid_c * stride_seq_idx_chunk
    seq_idx = tl.load(seq_idx_ptr)
    seq_idx_prev = tl.load(seq_idx_ptr - stride_seq_idx_chunk, mask=pid_c >= 1, other=-1)

    # Select where the "previous state" comes from: a supplied initial state at
    # a sequence boundary, otherwise the state carried over from chunk pid_c-1.
    if HAS_INITSTATES and (seq_idx != seq_idx_prev):
        prev_states_ptr = (
            initstates_ptr + seq_idx * stride_init_states_batch + pid_h * stride_init_states_head
        )
        prev_states_hdim = stride_init_states_hdim
        prev_states_dstate = stride_init_states_dstate
    else:
        prev_states_ptr = (
            states_ptr + (pid_c - 1) * stride_states_chunk + pid_h * stride_states_head
        )
        prev_states_hdim = stride_states_hdim
        prev_states_dstate = stride_states_dstate

    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    dA_cs_m = tl.load(
        dA_cumsum_ptr + offs_m * stride_dA_cs_csize, mask=offs_m < chunk_size, other=0.0
    ).to(tl.float32)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    # Faster to just do 1 iteration with larger BLOCK_SIZE_K, up to block size 128
    offs_k_dstate = tl.arange(0, BLOCK_SIZE_DSTATE if BLOCK_SIZE_DSTATE <= 128 else BLOCK_SIZE_K)
    C_ptrs = C_ptr + (offs_m[:, None] * stride_C_seqlen + offs_k_dstate[None, :] * stride_C_dstate)

    scale_m = tl.exp(dA_cs_m)
    # Inter-chunk contribution: C @ prev_state, scaled by exp(dA_cumsum).
    if BLOCK_SIZE_DSTATE <= 128:
        # Single-iteration path: the whole dstate axis fits in one K block.
        C = tl.load(
            C_ptrs,
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate),
            other=0.0,
        )

        if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
            # if no init states AND starting a new sequence, we need zeros
            prev_states = tl.zeros((BLOCK_SIZE_DSTATE, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty)
        else:
            # otherwise read the previous state
            prev_states_ptrs = (
                prev_states_ptr
                + offs_n[None, :] * prev_states_hdim
                + offs_k_dstate[:, None] * prev_states_dstate
            )
            prev_states = tl.load(
                prev_states_ptrs,
                mask=(offs_k_dstate[:, None] < dstate) & (offs_n[None, :] < hdim),
                other=0.0,
            )
            prev_states = prev_states.to(C_ptr.dtype.element_ty)

        acc = tl.dot(C, prev_states) * scale_m[:, None]

    else:
        # Looped path over dstate in BLOCK_SIZE_K steps.
        prev_states_ptrs = (
            prev_states_ptr
            + offs_n[None, :] * prev_states_hdim
            + offs_k_dstate[:, None] * prev_states_dstate
        )
        for k in range(0, dstate, BLOCK_SIZE_K):
            C = tl.load(
                C_ptrs,
                mask=(offs_m[:, None] < chunk_size_limit) & (offs_k_dstate[None, :] < dstate - k),
                other=0.0,
            )
            if not HAS_INITSTATES and (seq_idx != seq_idx_prev):
                prev_states = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_N), dtype=C_ptr.dtype.element_ty)
            else:
                prev_states = tl.load(
                    prev_states_ptrs,
                    mask=(offs_k_dstate[:, None] < dstate - k) & (offs_n[None, :] < hdim),
                    other=0.0,
                )
                prev_states = prev_states.to(C_ptr.dtype.element_ty)
            acc += tl.dot(C, prev_states)
            C_ptrs += BLOCK_SIZE_K
            prev_states_ptrs += BLOCK_SIZE_K
        acc *= scale_m[:, None]

    # Intra-chunk contribution: accumulate cb-weighted x over positions k.
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    cb_ptrs = cb_ptr + (offs_m[:, None] * stride_cb_csize_m + offs_k[None, :] * stride_cb_csize_k)
    x_ptrs = x_ptr + (offs_k[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize
    # Causal kernels only need k < end of this M block.
    K_MAX = chunk_size_limit if not IS_CAUSAL else min((pid_m + 1) * BLOCK_SIZE_M, chunk_size_limit)
    for k in range(0, K_MAX, BLOCK_SIZE_K):
        cb = tl.load(
            cb_ptrs,
            mask=(offs_m[:, None] < chunk_size) & (offs_k[None, :] < chunk_size - k),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        # If there's seq_idx, we already set cb[i, j] = 0 for seq_idx[i] != seq_idx[j].
        # So we don't need masking wrt seq_idx here.
        cb *= tl.exp(tl.minimum(dA_cs_m[:, None] - dA_cs_k[None, :], 0.0))
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size - k, other=0.0).to(tl.float32)
        cb *= dt_k
        if IS_CAUSAL:
            # Zero out future positions within the diagonal block.
            mask = offs_m[:, None] >= k + offs_k[None, :]
            cb = tl.where(mask, cb, 0.0)
        cb = cb.to(x_ptr.dtype.element_ty)
        x = tl.load(
            x_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < hdim),
            other=0.0,
        )
        acc += tl.dot(cb, x)
        cb_ptrs += BLOCK_SIZE_K * stride_cb_csize_k
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    offs_out_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_out_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)

    if HAS_D:
        # Skip connection: out += x * D (D either per-(head, hdim) or per-head scalar).
        if D_HAS_HDIM:
            D = tl.load(D_ptr + pid_h * stride_D_head + offs_n, mask=offs_n < hdim, other=0.0).to(
                tl.float32
            )
        else:
            D = tl.load(D_ptr + pid_h * stride_D_head).to(tl.float32)
        x_residual = tl.load(
            x_ptr + (offs_m[:, None] * stride_x_seqlen + offs_n[None, :] * stride_x_hdim),
            mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < hdim),
            other=0.0,
        ).to(tl.float32)
        acc += x_residual * D

    if HAS_Z:
        # Gate: out *= z * sigmoid(z).
        z_ptr += chunk_seqlen_start * stride_z_seqlen + pid_h * stride_z_head
        z_ptrs = z_ptr + (
            stride_z_seqlen * offs_out_m[:, None] + stride_z_hdim * offs_out_n[None, :]
        )
        z = tl.load(
            z_ptrs,
            mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim),
            other=0.0,
        ).to(tl.float32)
        acc *= z * tl.sigmoid(z)

    out_ptr += chunk_seqlen_start * stride_out_seqlen + pid_h * stride_out_head
    out_ptrs = out_ptr + (
        stride_out_seqlen * offs_out_m[:, None] + offs_out_n[None, :] * stride_out_hdim
    )
    tl.store(
        out_ptrs, acc, mask=(offs_out_m[:, None] < chunk_size_limit) & (offs_out_n[None, :] < hdim)
    )
def _chunk_scan_fwd(
    cb,
    x,
    dt,
    dA_cumsum,
    C,
    states,
    cu_chunk_seqlens,
    out,
    seq_idx,
    D=None,
    z=None,
    initial_states=None,
):
    """Launch the chunk-scan forward kernel, writing results into ``out`` in place.

    Per chunk, the kernel combines the carried-over SSM state with the
    intra-chunk causal term, optionally adding the ``D`` skip connection and
    applying the ``z * sigmoid(z)`` gate (both handled inside the kernel).
    Returns ``None``; ``out`` is the output buffer.
    """
    assert seq_idx is not None, "this implementation requires seq_idx"

    seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = C.shape

    # Validate every operand's shape before touching device memory.
    assert nheads % ngroups == 0
    assert C.shape == (seqlen, ngroups, dstate)
    assert cb.shape == (nchunks, ngroups, chunk_size, chunk_size)
    if D is not None:
        assert D.shape == (nheads, headdim) or D.shape == (nheads,)
    if z is not None:
        assert z.shape == x.shape
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == (nheads, nchunks, chunk_size)
    assert states.shape == (nchunks, nheads, headdim, dstate)
    assert seq_idx.shape == (nchunks,)

    def grid(META):
        # Axis 0 flattens the (M over chunk_size, N over headdim) tile grid.
        tiles = triton.cdiv(chunk_size, META["BLOCK_SIZE_M"]) * triton.cdiv(
            headdim, META["BLOCK_SIZE_N"]
        )
        return (tiles, nchunks, nheads)

    # Optional tensors pass zero strides so the kernel signature stays uniform.
    if z is not None:
        z_strides = (z.stride(0), z.stride(1), z.stride(2))
    else:
        z_strides = (0, 0, 0)
    if initial_states is not None:
        initial_states_strides = (
            initial_states.stride(0),
            initial_states.stride(1),
            initial_states.stride(2),
            initial_states.stride(3),
        )
    else:
        initial_states_strides = (0, 0, 0, 0)

    _chunk_scan_fwd_kernel[grid](
        cb_ptr=cb,
        x_ptr=x,
        z_ptr=z,
        out_ptr=out,
        dt_ptr=dt,
        dA_cumsum_ptr=dA_cumsum,
        seq_idx_ptr=seq_idx,
        C_ptr=C,
        states_ptr=states,
        D_ptr=D,
        initstates_ptr=initial_states,
        cu_chunk_seqlens_ptr=cu_chunk_seqlens,
        chunk_size=chunk_size,
        hdim=headdim,
        dstate=dstate,
        seqlen=seqlen,
        nheads_ngroups_ratio=nheads // ngroups,
        stride_cb_chunk=cb.stride(0),
        stride_cb_head=cb.stride(1),
        stride_cb_csize_m=cb.stride(2),
        stride_cb_csize_k=cb.stride(3),
        stride_x_seqlen=x.stride(0),
        stride_x_head=x.stride(1),
        stride_x_hdim=x.stride(2),
        stride_z_seqlen=z_strides[0],
        stride_z_head=z_strides[1],
        stride_z_hdim=z_strides[2],
        stride_out_seqlen=out.stride(0),
        stride_out_head=out.stride(1),
        stride_out_hdim=out.stride(2),
        stride_dt_chunk=dt.stride(1),
        stride_dt_head=dt.stride(0),
        stride_dt_csize=dt.stride(2),
        stride_dA_cs_chunk=dA_cumsum.stride(1),
        stride_dA_cs_head=dA_cumsum.stride(0),
        stride_dA_cs_csize=dA_cumsum.stride(2),
        stride_seq_idx_chunk=seq_idx.stride(0),
        stride_C_seqlen=C.stride(0),
        stride_C_head=C.stride(1),
        stride_C_dstate=C.stride(2),
        stride_states_chunk=states.stride(0),
        stride_states_head=states.stride(1),
        stride_states_hdim=states.stride(2),
        stride_states_dstate=states.stride(3),
        stride_init_states_batch=initial_states_strides[0],
        stride_init_states_head=initial_states_strides[1],
        stride_init_states_hdim=initial_states_strides[2],
        stride_init_states_dstate=initial_states_strides[3],
        stride_D_head=D.stride(0) if D is not None else 0,
        IS_CAUSAL=True,
        HAS_D=D is not None,
        D_HAS_HDIM=D.dim() == 2 if D is not None else True,
        HAS_Z=z is not None,
        BLOCK_SIZE_DSTATE=max(triton.next_power_of_2(dstate), 16),
        IS_TRITON_22=TRITON_22,
        HAS_INITSTATES=initial_states is not None,
    )
    return
try:
    TRITON3 = version.parse(triton.__version__) >= version.parse("3.0.0")
except:
    # NOTE(review): this bare except only fires when the version string cannot
    # be parsed; an installed Triton < 3.0.0 sets TRITON3 = False without
    # raising, so the message "3.0.0 or higher is required" is only accurate
    # for unparseable versions — confirm the intended behavior.
    raise ImportError("Triton version 3.0.0 or higher is required")

if TRITON3:

    @triton.jit
    def softplus(dt):  # pylint: disable=C0116
        # log(exp(dt) + 1) for small dt; pass dt through above 20.0,
        # presumably to avoid overflow in exp for large inputs.
        dt = tl.where(dt <= 20.0, tl.math.log(tl.math.exp(dt) + 1), dt)
        return dt

else:

    @triton.jit
    def softplus(dt):  # pylint: disable=C0116
        # Same contract as the Triton-3 branch, using the pre-3.0 math API.
        dt = tl.where(dt <= 20.0, tl.math.log1p(tl.exp(dt)), dt)
        return dt


# Per-chunk dt preprocessing kernel. For one chunk (axis 0) and one block of
# heads (axis 1) it: loads dt, adds the optional per-head bias, optionally
# applies softplus, clamps to [dt_min, dt_max], zeroes padding positions,
# stores the processed dt, and stores the within-chunk cumulative sum of
# dt * A.
@triton.autotune(
    configs=autotune_configs(
        [
            triton.Config({"BLOCK_SIZE_H": 2}),
            triton.Config({"BLOCK_SIZE_H": 4}),
            triton.Config({"BLOCK_SIZE_H": 8}),
            triton.Config({"BLOCK_SIZE_H": 16}),
            triton.Config({"BLOCK_SIZE_H": 32}),
            triton.Config({"BLOCK_SIZE_H": 64}),
        ]
    ),
    key=["chunk_size", "nheads"],
)
@triton.jit
def _chunk_cumsum_fwd_kernel(
    # Pointers to matrices
    dt_ptr,
    A_ptr,
    dt_bias_ptr,
    dt_out_ptr,
    dA_cumsum_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimension
    seqlen,
    nheads: tl.constexpr,
    chunk_size: tl.constexpr,
    dt_min: tl.constexpr,
    dt_max: tl.constexpr,
    # Strides
    stride_dt_seqlen: tl.int64,
    stride_dt_head: tl.constexpr,
    stride_A_head: tl.constexpr,
    stride_dt_bias_head: tl.constexpr,
    stride_dt_out_head: tl.int64,
    stride_dt_out_chunk: tl.int64,
    stride_dt_out_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    # Meta-parameters
    DT_SOFTPLUS: tl.constexpr,
    HAS_DT_BIAS: tl.constexpr,
    BLOCK_SIZE_H: tl.constexpr,
    BLOCK_SIZE_CHUNK: tl.constexpr,
):
    # if dt is long, may cause problems, so use 64 bit
    # https://github.com/triton-lang/triton/issues/1058
    pid_c = tl.program_id(axis=0).to(tl.int64)
    pid_h = tl.program_id(axis=1)

    # Varlen chunk boundaries within the flattened sequence.
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)

    dt_ptr += chunk_seqlen_start * stride_dt_seqlen
    dt_out_ptr += pid_c * stride_dt_out_chunk
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk

    offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)
    offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)
    dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)
    A_ptrs = A_ptr + offs_h * stride_A_head
    dt_out_ptrs = dt_out_ptr + (
        offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize
    )
    dA_cs_ptrs = dA_cumsum_ptr + (
        offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize
    )
    # Actual number of valid positions in this (possibly short) chunk.
    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    dt = tl.load(
        dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0
    ).to(tl.float32)
    if HAS_DT_BIAS:
        dt_bias = tl.load(
            dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0
        ).to(tl.float32)
        dt += dt_bias[:, None]
    if DT_SOFTPLUS:
        # softplus itself already passes dt > 20.0 through, so the outer
        # tl.where is a redundant (but harmless) re-application of that guard.
        dt = tl.where(dt <= 20.0, softplus(dt), dt)

    dt = tl.clamp(dt, dt_min, dt_max)
    # Zero out positions beyond the valid chunk length so the cumsum below
    # is unaffected by padding.
    dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)
    tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
    A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)
    dA = dt * A[:, None]
    # Inclusive cumulative sum of dt * A along the chunk axis.
    dA_cs = tl.cumsum(dA, axis=1)
    tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))
# Chunk-state forward kernel. For one (chunk, head) pair and one
# (hdim-block, dstate-block) output tile it computes the chunk's local state:
#   states[c] = sum_k x[k] * (B[k] * exp(min(dA_cs_last - dA_cs_k, 0)) * dt_k)
# i.e. an x @ (scaled B) matmul over the positions of the chunk.
@triton.autotune(
    configs=autotune_configs(
        [
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=2,
            ),
        ]
    ),
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_fwd_kernel(
    # Pointers to matrices
    x_ptr,
    b_ptr,
    states_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    cu_chunk_seqlens_ptr,
    # Matrix dimensions
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    chunk_size: tl.constexpr,
    seqlen,
    nheads_ngroups_ratio: tl.constexpr,
    # Strides
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_b_dstate: tl.constexpr,
    stride_states_chunk: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_dt_head: tl.int64,
    stride_dt_chunk: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
):
    # Grid: axis 0 = flattened (hdim, dstate) tile, axis 1 = chunk, axis 2 = head.
    pid_c = tl.program_id(axis=1).to(tl.int64)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # Varlen chunk boundaries; B is per-group, so map head -> group.
    chunk_seqlen_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
    chunk_seqlen_end = tl.load(cu_chunk_seqlens_ptr + pid_c + 1)
    b_ptr += chunk_seqlen_start * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    x_ptr += chunk_seqlen_start * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # Cumsum at the chunk's last position, used to discount each position k.
    dA_cs_last = tl.load(dA_cumsum_ptr + (chunk_size - 1) * stride_dA_cs_csize).to(tl.float32)
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    chunk_size_limit = chunk_seqlen_end - chunk_seqlen_start

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim) & (offs_k[None, :] < chunk_size_limit - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k) & (offs_n[None, :] < dstate),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        # Decay from position k to end of chunk, times the step size dt_k.
        scale = tl.exp(tl.minimum(dA_cs_last - dA_cs_k, 0.0)) * dt_k
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)

        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    states = acc.to(states_ptr.dtype.element_ty)

    states_ptr += pid_c * stride_states_chunk + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)
def _chunk_cumsum_fwd(
    dt,
    A,
    chunk_size,
    cu_chunk_seqlens,
    dt_bias=None,
    dt_softplus=False,
    dt_limit=(0.0, float("inf")),
):
    """Preprocess dt per chunk and compute the cumulative sum of dt * A.

    Returns ``(dA_cumsum, dt_out)``, both float32 tensors of shape
    ``(nheads, nchunks, chunk_size)``; dt_out holds dt after optional bias,
    softplus and clamping to ``dt_limit``.
    """
    seqlen, nheads = dt.shape
    assert A.shape == (nheads,)
    if dt_bias is not None:
        assert dt_bias.shape == (nheads,)
    nchunks = cu_chunk_seqlens.shape[0] - 1

    # Both outputs share the same dense float32 layout.
    out_shape = (nheads, nchunks, chunk_size)
    dt_out = torch.empty(out_shape, device=dt.device, dtype=torch.float32)
    dA_cumsum = torch.empty(out_shape, device=dt.device, dtype=torch.float32)

    def grid_chunk_cs(META):
        # One program per chunk, heads tiled by BLOCK_SIZE_H.
        return (nchunks, triton.cdiv(nheads, META["BLOCK_SIZE_H"]))

    with torch.cuda.device(dt.device.index):
        _chunk_cumsum_fwd_kernel[grid_chunk_cs](
            dt_ptr=dt,
            A_ptr=A,
            dt_bias_ptr=dt_bias,
            dt_out_ptr=dt_out,
            dA_cumsum_ptr=dA_cumsum,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            seqlen=seqlen,
            nheads=nheads,
            chunk_size=chunk_size,
            dt_min=dt_limit[0],
            dt_max=dt_limit[1],
            stride_dt_seqlen=dt.stride(0),
            stride_dt_head=dt.stride(1),
            stride_A_head=A.stride(0),
            stride_dt_bias_head=dt_bias.stride(0) if dt_bias is not None else 0,
            stride_dt_out_head=dt_out.stride(0),
            stride_dt_out_chunk=dt_out.stride(1),
            stride_dt_out_csize=dt_out.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            DT_SOFTPLUS=dt_softplus,
            HAS_DT_BIAS=dt_bias is not None,
            BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),
        )
    return dA_cumsum, dt_out
def _chunk_state_fwd(B, x, dt, dA_cumsum, cu_chunk_seqlens, states=None, states_in_fp32=True):
    """Compute each chunk's local SSM state via the chunk-state kernel.

    Returns ``states`` of shape ``(nchunks, nheads, headdim, dstate)``.
    A fresh buffer (float32 when ``states_in_fp32``, else B's dtype) is
    allocated unless a pre-allocated ``states`` tensor is supplied.
    """
    seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape

    # Shape sanity checks.
    assert nheads % ngroups == 0
    assert B.shape == (seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape

    if states is None:
        out_dtype = torch.float32 if states_in_fp32 else B.dtype
        states = torch.empty(
            (nchunks, nheads, headdim, dstate), device=x.device, dtype=out_dtype
        )
    else:
        assert states.shape == (nchunks, nheads, headdim, dstate)

    def grid(META):
        # Axis 0 flattens the (headdim, dstate) tile grid.
        tiles = triton.cdiv(headdim, META["BLOCK_SIZE_M"]) * triton.cdiv(
            dstate, META["BLOCK_SIZE_N"]
        )
        return (tiles, nchunks, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_state_fwd_kernel[grid](
            x_ptr=x,
            b_ptr=B,
            states_ptr=states,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            seqlen=seqlen,
            nheads_ngroups_ratio=nheads // ngroups,
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_states_chunk=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
        )
    return states
# Varlen final-state kernel. For one sequence (pid_b) and head (pid_h) it
# accumulates the state over the sequence's *last* chunk (positions >= the
# sequence start within that chunk), then folds in either the carried chunk
# state (when the sequence spans multiple chunks) or the supplied initial
# state, with the appropriate exponential decay.
@triton.autotune(
    configs=autotune_configs(
        [
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64},
                num_stages=3,
                num_warps=8,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=4,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 32, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=5,
                num_warps=2,
            ),
            triton.Config(
                {"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 32},
                num_stages=4,
                num_warps=2,
            ),
        ]
    ),
    key=["hdim", "dstate", "chunk_size"],
)
@triton.jit
def _chunk_state_varlen_kernel(
    x_ptr,
    b_ptr,
    dt_ptr,
    dA_cumsum_ptr,
    chunk_states_ptr,
    cu_seqlens_ptr,
    last_chunk_indices_ptr,
    cu_chunk_seqlens_ptr,
    states_ptr,
    initstates_ptr,
    hdim: tl.constexpr,
    dstate: tl.constexpr,
    chunk_size: tl.constexpr,
    nheads_ngroups_ratio: tl.constexpr,
    stride_x_seqlen: tl.int64,
    stride_x_head: tl.int64,
    stride_x_hdim: tl.constexpr,
    stride_b_seqlen: tl.int64,
    stride_b_head: tl.int64,
    stride_b_dstate: tl.constexpr,
    stride_dt_head: tl.int64,
    stride_dt_chunk: tl.int64,
    stride_dt_csize: tl.constexpr,
    stride_dA_cs_head: tl.int64,
    stride_dA_cs_chunk: tl.int64,
    stride_dA_cs_csize: tl.constexpr,
    stride_chunk_states_chunk: tl.int64,
    stride_chunk_states_head: tl.int64,
    stride_chunk_states_hdim: tl.int64,
    stride_chunk_states_dstate: tl.constexpr,
    stride_states_batch: tl.int64,
    stride_states_head: tl.int64,
    stride_states_hdim: tl.int64,
    stride_states_dstate: tl.constexpr,
    stride_init_states_batch: tl.int64,
    stride_init_states_head: tl.int64,
    stride_init_states_hdim: tl.int64,
    stride_init_states_dstate: tl.constexpr,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    HAS_INITSTATES: tl.constexpr,
    USE_LAST_CHUNK_INDICES: tl.constexpr,
):
    # Grid: axis 0 = flattened (hdim, dstate) tile, axis 1 = sequence, axis 2 = head.
    pid_b = tl.program_id(axis=1)
    pid_h = tl.program_id(axis=2)
    num_pid_n = tl.cdiv(dstate, BLOCK_SIZE_N)
    pid_m = tl.program_id(axis=0) // num_pid_n
    pid_n = tl.program_id(axis=0) % num_pid_n
    # This sequence's span in the flattened token stream.
    end_idx = tl.load(cu_seqlens_ptr + pid_b + 1)
    start_idx = tl.load(cu_seqlens_ptr + pid_b)
    if USE_LAST_CHUNK_INDICES:
        # Varlen chunk layout: the sequence's last chunk is looked up explicitly.
        pid_c = tl.load(last_chunk_indices_ptr + pid_b).to(tl.int64)
        chunk_start = tl.load(cu_chunk_seqlens_ptr + pid_c)
        chunk_size_limit = tl.load(cu_chunk_seqlens_ptr + pid_c + 1) - chunk_start
    else:
        # Uniform chunk layout: last chunk derived from fixed chunk_size.
        pid_c = (end_idx - 1) // chunk_size
        chunk_start = pid_c * chunk_size
        chunk_size_limit = end_idx - chunk_start
    b_ptr += chunk_start * stride_b_seqlen + (pid_h // nheads_ngroups_ratio) * stride_b_head
    x_ptr += chunk_start * stride_x_seqlen + pid_h * stride_x_head
    dt_ptr += pid_c * stride_dt_chunk + pid_h * stride_dt_head
    dA_cumsum_ptr += pid_c * stride_dA_cs_chunk + pid_h * stride_dA_cs_head
    chunk_states_ptr += pid_c * stride_chunk_states_chunk + pid_h * stride_chunk_states_head

    if HAS_INITSTATES:
        initstates_ptr += pid_h * stride_init_states_head

    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    x_ptrs = x_ptr + (offs_m[:, None] * stride_x_hdim + offs_k[None, :] * stride_x_seqlen)
    b_ptrs = b_ptr + (offs_n[None, :] * stride_b_dstate + offs_k[:, None] * stride_b_seqlen)
    dt_ptrs = dt_ptr + offs_k * stride_dt_csize
    # Cumsum at the sequence's final position inside this chunk.
    dA_cs_last = tl.load(dA_cumsum_ptr + (end_idx - 1 - chunk_start) * stride_dA_cs_csize).to(
        tl.float32
    )
    dA_cumsum_ptrs = dA_cumsum_ptr + offs_k * stride_dA_cs_csize

    # First position of this sequence within the chunk (0 if it starts earlier).
    start_idx_cur = tl.maximum(start_idx - chunk_start, 0)

    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, chunk_size_limit, BLOCK_SIZE_K):
        # Masks exclude positions before the sequence start and past its end.
        x = tl.load(
            x_ptrs,
            mask=(offs_m[:, None] < hdim)
            & (offs_k[None, :] < chunk_size_limit - k)
            & (offs_k[None, :] >= start_idx_cur - k),
            other=0.0,
        )
        b = tl.load(
            b_ptrs,
            mask=(offs_k[:, None] < chunk_size_limit - k)
            & (offs_n[None, :] < dstate)
            & (offs_k[:, None] >= start_idx_cur - k),
            other=0.0,
        ).to(tl.float32)
        dA_cs_k = tl.load(dA_cumsum_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(
            tl.float32
        )
        dt_k = tl.load(dt_ptrs, mask=offs_k < chunk_size_limit - k, other=0.0).to(tl.float32)
        # Decay each position to the sequence end, times dt; zero outside range.
        scale = tl.where(
            (offs_k >= start_idx_cur - k) & (offs_k < chunk_size_limit - k),
            tl.exp(tl.minimum(dA_cs_last - dA_cs_k, 0.0)) * dt_k,
            0.0,
        )
        b *= scale[:, None]
        b = b.to(x_ptr.dtype.element_ty)
        acc += tl.dot(x, b)
        x_ptrs += BLOCK_SIZE_K * stride_x_seqlen
        b_ptrs += BLOCK_SIZE_K * stride_b_seqlen
        dt_ptrs += BLOCK_SIZE_K * stride_dt_csize
        dA_cumsum_ptrs += BLOCK_SIZE_K * stride_dA_cs_csize

    # Fold in state carried from before this chunk: either the accumulated
    # chunk state (sequence started in an earlier chunk) or the initial state.
    if (start_idx < chunk_start) or (HAS_INITSTATES):
        dA_cs_boundary = 0.0
        if not HAS_INITSTATES:
            past_states_ptrs = chunk_states_ptr + (
                offs_m[:, None] * stride_chunk_states_hdim
                + offs_n[None, :] * stride_chunk_states_dstate
            )
        else:
            if start_idx < chunk_start:
                past_states_ptrs = chunk_states_ptr + (
                    offs_m[:, None] * stride_chunk_states_hdim
                    + offs_n[None, :] * stride_chunk_states_dstate
                )
            else:
                # Sequence starts inside this chunk: seed from its init state.
                past_states_ptrs = initstates_ptr + (
                    pid_b * stride_init_states_batch
                    + offs_m[:, None] * stride_init_states_hdim
                    + offs_n[None, :] * stride_init_states_dstate
                )
                if start_idx > chunk_start:
                    dA_cs_boundary = tl.load(
                        dA_cumsum_ptr + (start_idx - chunk_start - 1) * stride_dA_cs_csize
                    ).to(tl.float32)

        past_states = tl.load(
            past_states_ptrs, mask=(offs_m[:, None] < hdim) & (offs_n[None, :] < dstate), other=0.0
        ).to(tl.float32)
        scale = tl.exp(tl.minimum(dA_cs_last - dA_cs_boundary, 0.0))
        acc += past_states * scale

    states = acc.to(states_ptr.dtype.element_ty)
    states_ptr += pid_b * stride_states_batch + pid_h * stride_states_head
    offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    states_ptrs = states_ptr + (
        offs_m[:, None] * stride_states_hdim + offs_n[None, :] * stride_states_dstate
    )
    c_mask = (offs_m[:, None] < hdim) & (offs_n[None, :] < dstate)
    tl.store(states_ptrs, states, mask=c_mask)
def chunk_state_varlen(
    B,
    x,
    dt,
    dA_cumsum,
    cu_seqlens,
    chunk_states,
    initial_states=None,
    last_chunk_indices=None,
    cu_chunk_seqlens=None,
):
    """Compute each sequence's final SSM state from per-chunk states.

    Handles sequences that share chunks (continuous batching); optionally
    seeds each sequence from ``initial_states``. Returns a tensor of shape
    ``(batch, nheads, headdim, dstate)`` in ``chunk_states``' dtype.
    """
    total_seqlen, nheads, headdim = x.shape
    _, nchunks, chunk_size = dt.shape
    _, ngroups, dstate = B.shape
    batch = cu_seqlens.shape[0] - 1
    cu_seqlens = cu_seqlens.contiguous()

    # Shape sanity checks.
    assert nheads % ngroups == 0
    assert B.shape == (total_seqlen, ngroups, dstate)
    assert dt.shape == (nheads, nchunks, chunk_size)
    assert dA_cumsum.shape == dt.shape
    assert chunk_states.shape == (nchunks, nheads, headdim, dstate)
    if initial_states is not None:
        assert initial_states.shape == (batch, nheads, headdim, dstate)

    # With an explicit chunk map the kernel looks up each sequence's last
    # chunk; otherwise it derives it from the fixed chunk_size and we pass
    # harmless placeholders for the unused pointers.
    use_last_chunk = last_chunk_indices is not None and cu_chunk_seqlens is not None
    if use_last_chunk:
        last_chunk_indices = last_chunk_indices.contiguous().to(x.device)
        cu_chunk_seqlens = cu_chunk_seqlens.contiguous().to(x.device)
    else:
        last_chunk_indices = torch.zeros(1, dtype=torch.int64, device=x.device)
        cu_chunk_seqlens = cu_seqlens

    states = torch.empty(
        batch, nheads, headdim, dstate, dtype=chunk_states.dtype, device=chunk_states.device
    )

    # Zero strides stand in when no initial states are supplied.
    if initial_states is not None:
        initial_states_strides = (
            initial_states.stride(0),
            initial_states.stride(1),
            initial_states.stride(2),
            initial_states.stride(3),
        )
    else:
        initial_states_strides = (0, 0, 0, 0)

    def grid(META):
        tiles = triton.cdiv(headdim, META["BLOCK_SIZE_M"]) * triton.cdiv(
            dstate, META["BLOCK_SIZE_N"]
        )
        return (tiles, batch, nheads)

    with torch.cuda.device(x.device.index):
        _chunk_state_varlen_kernel[grid](
            x_ptr=x,
            b_ptr=B,
            dt_ptr=dt,
            dA_cumsum_ptr=dA_cumsum,
            chunk_states_ptr=chunk_states,
            cu_seqlens_ptr=cu_seqlens,
            last_chunk_indices_ptr=last_chunk_indices,
            cu_chunk_seqlens_ptr=cu_chunk_seqlens,
            states_ptr=states,
            initstates_ptr=initial_states,
            hdim=headdim,
            dstate=dstate,
            chunk_size=chunk_size,
            nheads_ngroups_ratio=nheads // ngroups,
            stride_x_seqlen=x.stride(0),
            stride_x_head=x.stride(1),
            stride_x_hdim=x.stride(2),
            stride_b_seqlen=B.stride(0),
            stride_b_head=B.stride(1),
            stride_b_dstate=B.stride(2),
            stride_dt_head=dt.stride(0),
            stride_dt_chunk=dt.stride(1),
            stride_dt_csize=dt.stride(2),
            stride_dA_cs_head=dA_cumsum.stride(0),
            stride_dA_cs_chunk=dA_cumsum.stride(1),
            stride_dA_cs_csize=dA_cumsum.stride(2),
            stride_chunk_states_chunk=chunk_states.stride(0),
            stride_chunk_states_head=chunk_states.stride(1),
            stride_chunk_states_hdim=chunk_states.stride(2),
            stride_chunk_states_dstate=chunk_states.stride(3),
            stride_states_batch=states.stride(0),
            stride_states_head=states.stride(1),
            stride_states_hdim=states.stride(2),
            stride_states_dstate=states.stride(3),
            stride_init_states_batch=initial_states_strides[0],
            stride_init_states_head=initial_states_strides[1],
            stride_init_states_hdim=initial_states_strides[2],
            stride_init_states_dstate=initial_states_strides[3],
            HAS_INITSTATES=initial_states is not None,
            USE_LAST_CHUNK_INDICES=use_last_chunk,
        )
    return states
def is_int_pow_2(n):
    """Return True if n is a positive integer power of 2."""
    # Reject non-ints first (floats, strings, ...); bool is an int subclass,
    # matching the original behavior.
    if not isinstance(n, int):
        return False
    # A power of two has exactly one bit set, so n & (n - 1) clears to zero.
    return n > 0 and not (n & (n - 1))
"last_chunk_indices must be provided" + + if initial_states is not None: + num_seqs = last_chunk_indices.shape[0] + assert initial_states.shape == (num_seqs, nheads, headdim, dstate) + + # This function executes 5 sub-functions for computing mamba + # - a good resource is the blog https://goombalab.github.io/blog/2024/mamba2-part3-algorithm/ + # which has a minimal implementation to understand the below operations + # - as explained by the blog, mamba is a special case of causal attention + # - the idea is to chunk the attention matrix and compute each + # submatrix separately using different optimizations. + # - see the blog and paper for a visualization of the submatrices + # which we refer to in the comments below + + # 1. Compute chunked cumsum of A * dt + # - here dt may go through a softplus activation + dA_cumsum, dt = _chunk_cumsum_fwd( + dt, + A, + chunk_size, + cu_chunk_seqlens, + dt_bias=dt_bias, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + ) + + # 2. Compute the state for each intra-chunk + # (right term of low-rank factorization of off-diagonal blocks; B terms) + states = _chunk_state_fwd(B, x, dt, dA_cumsum, cu_chunk_seqlens, states_in_fp32=True) + + # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries + # (middle term of factorization of off-diag blocks; A terms) + # - for handling chunked prefill, this requires i) initial_states and + # ii) seq_idx to be all specified. + # - When a new seq_idx is detected, we will stop passing the prev_state + # and switch accordingly to the init_state corresponding to the new seq_idx. + states = _state_passing_fwd( + states.flatten(-2), # ... p n -> ... 
(p n) + dA_cumsum, # (nheads, nchunks, chunk_size) + cu_chunk_seqlens, + initial_states=( + initial_states.flatten(-2) if initial_states is not None else None + ), # (batch, nheads, headdim*dstate) + seq_idx=seq_idx, + out_dtype=state_dtype if state_dtype is not None else C.dtype, + ) + states = states.unflatten(-1, (-1, dstate)) + + # 4. Compute batched matrix multiply for C_j^T B_i terms + CB = _bmm_chunk_fwd(C, B, chunk_size, cu_chunk_seqlens, output_dtype=torch.float32) + + # 5. Scan and compute the diagonal blocks, taking into + # account past causal states. + # - if initial states are provided, then states information will be + # augmented with initial_states. + # - to do this properly, we need to account for example changes in + # the continuous batch, therefore we introduce pseudo chunks, which is + # a chunk that is split up each time an example changes. + # - in each (pseudo) chunk, we detect if the previous (pseudo) chunk had + # a seq_idx change, in which case we take states information from + # init_states. 
+ _chunk_scan_fwd( + CB, + x, + dt, + dA_cumsum, + C, + states, + cu_chunk_seqlens, + out, # in-place update + seq_idx, + D=D, + z=z, + initial_states=initial_states, + ) + + if return_intermediate_states: + return states + + final_states = states[last_chunk_indices] + if intermediate_chunk_indices is not None: + intermediate_states = states[intermediate_chunk_indices] + return final_states, intermediate_states + else: + return final_states + + +def mamba_chunk_scan_combined_varlen( + x, + dt, + A, + B, + C, + chunk_size, + cu_chunk_seqlens, + last_chunk_indices, + seq_idx, + out, + D=None, + z=None, + dt_bias=None, + initial_states=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + return_intermediate_states=False, + intermediate_chunk_indices=None, + state_dtype=None, +): + """ + Argument: + x: (seqlen, nheads, headdim) + dt: (seqlen, nheads) + A: (nheads) + B: (seqlen, ngroups, dstate) + C: (seqlen, ngroups, dstate) + chunk_size: int + cu_chunk_seqlens: (nchunks + 1,) + last_chunk_indices: (batch,) + seq_idx: (nchunks,) + out: (seqlen, nheads, headdim) preallocated output tensor + D: (nheads, headdim) or (nheads,) + z: (seqlen, nheads, headdim) + dt_bias: (nheads,) + initial_states: (batch, nheads, headdim, dstate) + dt_softplus: Whether to apply softplus to dt + intermediate_chunk_indices: (N,) optional int64 tensor of chunk indices at which to + extract intermediate SSM states. When provided, returns (final_states, + intermediate_states) instead of just final_states. 
+ state_dtype: The data type of the ssm state + Return: + varlen_states: (batch, nheads, headdim, dstate), or + (varlen_states, intermediate_states) if intermediate_chunk_indices is provided + """ + + assert seq_idx is not None + + varlen_states = _mamba_chunk_scan_combined_fwd( + x, + dt, + A, + B, + C, + chunk_size, + out, + D=D, + z=z, + dt_bias=dt_bias, + initial_states=initial_states, + return_intermediate_states=return_intermediate_states, + seq_idx=seq_idx, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + intermediate_chunk_indices=intermediate_chunk_indices, + dt_softplus=dt_softplus, + dt_limit=dt_limit, + state_dtype=state_dtype, + ) + + return varlen_states diff --git a/megatron/core/ssm/ops/ssd_state_passing.py b/megatron/core/ssm/ops/ssd_state_passing.py new file mode 100644 index 00000000000..65b81a0ec31 --- /dev/null +++ b/megatron/core/ssm/ops/ssd_state_passing.py @@ -0,0 +1,149 @@ +# Copyright (c) 2024, Tri Dao, Albert Gu. +# Adapted from: +# https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/ops/triton/ssd_state_passing.py +# Adapted from vLLM project (Apache-2.0). 
+ +import torch +import triton +import triton.language as tl + +from megatron.core.ssm.ops.determinism import autotune_configs + + +@triton.autotune( + configs=autotune_configs( + [ + triton.Config({"BLOCK_SIZE": 64}), + triton.Config({"BLOCK_SIZE": 128}), + triton.Config({"BLOCK_SIZE": 256}), + triton.Config({"BLOCK_SIZE": 512}), + triton.Config({"BLOCK_SIZE": 1024}), + triton.Config({"BLOCK_SIZE": 2048}), + ] + ), + key=["dim"], +) +@triton.jit +def _state_passing_fwd_kernel( + # Pointers to matrices + states_ptr, + out_ptr, + dA_cs_ptr, + initstates_ptr, + seq_idx_ptr, + cu_chunk_seqlens_ptr, + # Matrix dimensions + dim: tl.constexpr, + nchunks, + seqlen, + chunk_size: tl.constexpr, + # Strides + stride_states_chunk: tl.int64, + stride_states_head: tl.int64, + stride_states_dim: tl.constexpr, + stride_out_chunk: tl.int64, + stride_out_head: tl.int64, + stride_out_dim: tl.constexpr, + stride_dA_cs_head: tl.int64, + stride_dA_cs_chunk: tl.int64, + stride_dA_cs_csize: tl.constexpr, + stride_initstates_batch: tl.int64, + stride_initstates_head: tl.int64, + stride_initstates_dim: tl.constexpr, + stride_seq_idx_chunk: tl.constexpr, + # Meta-parameters + HAS_INITSTATES: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + pid_h = tl.program_id(axis=1) + pid_m = tl.program_id(axis=0) + + states_ptr += pid_h * stride_states_head + dA_cs_ptr += pid_h * stride_dA_cs_head + (chunk_size - 1) * stride_dA_cs_csize + out_ptr += pid_h * stride_out_head + + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + states_ptrs = states_ptr + offs_m * stride_states_dim + out_ptrs = out_ptr + offs_m * stride_out_dim + + if HAS_INITSTATES: + initstates_ptrs = ( + initstates_ptr + pid_h * stride_initstates_head + offs_m * stride_initstates_dim + ) + + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + prev_seq_idx = 0 + for c in range(nchunks): + new_states = tl.load(states_ptrs, mask=offs_m < dim, 
other=0.0).to(tl.float32) + dA_cs = tl.load(dA_cs_ptr).to(tl.float32) + seq_idx = tl.load(seq_idx_ptr + c * stride_seq_idx_chunk) + # we have started a new sequence + if prev_seq_idx != seq_idx: + if HAS_INITSTATES: + initstates_ptrs = ( + initstates_ptr + + seq_idx * stride_initstates_batch + + pid_h * stride_initstates_head + + offs_m * stride_initstates_dim + ) + states = tl.load(initstates_ptrs, mask=offs_m < dim, other=0.0).to(tl.float32) + else: + states = tl.zeros((BLOCK_SIZE,), dtype=tl.float32) + + prev_seq_idx = seq_idx + states = tl.exp(dA_cs) * states + new_states + tl.store(out_ptrs, states, mask=offs_m < dim) + + states_ptrs += stride_states_chunk + dA_cs_ptr += stride_dA_cs_chunk + out_ptrs += stride_out_chunk + + +def _state_passing_fwd( + states, dA_cumsum, cu_chunk_seqlens, seq_idx, initial_states=None, out_dtype=None +): + nchunks, nheads, dim = states.shape + chunk_size = dA_cumsum.shape[-1] + assert dA_cumsum.shape == (nheads, nchunks, chunk_size) + seqlen = seq_idx.shape[-1] + out_dtype = states.dtype if out_dtype is None else out_dtype + out = torch.empty((nchunks, nheads, dim), device=states.device, dtype=out_dtype) + + initial_states_strides = ( + (initial_states.stride(0), initial_states.stride(1), initial_states.stride(2)) + if initial_states is not None + else (0, 0, 0) + ) + + grid = lambda META: (triton.cdiv(dim, META["BLOCK_SIZE"]), nheads) + with torch.cuda.device(states.device.index): + _state_passing_fwd_kernel[grid]( + states_ptr=states, + out_ptr=out, + dA_cs_ptr=dA_cumsum, + initstates_ptr=initial_states, + seq_idx_ptr=seq_idx, + cu_chunk_seqlens_ptr=cu_chunk_seqlens, + dim=dim, + nchunks=nchunks, + seqlen=seqlen if seq_idx is not None else 0, + chunk_size=chunk_size if seq_idx is not None else 0, + stride_states_chunk=states.stride(0), + stride_states_head=states.stride(1), + stride_states_dim=states.stride(2), + stride_out_chunk=out.stride(0), + stride_out_head=out.stride(1), + stride_out_dim=out.stride(2), + 
stride_dA_cs_head=dA_cumsum.stride(0), + stride_dA_cs_chunk=dA_cumsum.stride(1), + stride_dA_cs_csize=dA_cumsum.stride(2), + stride_initstates_batch=initial_states_strides[0], + stride_initstates_head=initial_states_strides[1], + stride_initstates_dim=initial_states_strides[2], + stride_seq_idx_chunk=seq_idx.stride(0), + HAS_INITSTATES=initial_states is not None, + ) + return out diff --git a/megatron/core/transformer/moe/token_dispatcher_inference.py b/megatron/core/transformer/moe/token_dispatcher_inference.py new file mode 100644 index 00000000000..6b851c252c5 --- /dev/null +++ b/megatron/core/transformer/moe/token_dispatcher_inference.py @@ -0,0 +1,326 @@ +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +""" +CUDA-graph-compatible token dispatcher for inference. + +This dispatcher is only used during CUDA-graphed inference iterations. It replaces +AlltoAll with AllGather/ReduceScatter for token exchange, keeping all metadata +GPU-resident to avoid host synchronizations that would break CUDA graph capture. + +Supports latency-optimized NVLS collectives (multimem all-gather/reduce-scatter) +on Hopper+ GPUs with BF16, with automatic fallback to NCCL. 
+""" + +from typing import List, Optional + +import torch + +from megatron.core.inference.communication.torch_symm_triton import ( + are_tensors_nvls_eligible, + multimem_all_gather_fused, + multimem_reduce_scatter, +) +from megatron.core.parallel_state import get_global_symmetric_memory_buffer_ep +from megatron.core.process_groups_config import ProcessGroupCollection +from megatron.core.tensor_parallel import ( + gather_from_sequence_parallel_region, + reduce_scatter_to_sequence_parallel_region, +) +from megatron.core.transformer.moe.token_dispatcher import MoEAllGatherTokenDispatcher +from megatron.core.transformer.transformer_config import TransformerConfig + + +class InferenceCUDAGraphTokenDispatcher(MoEAllGatherTokenDispatcher): + """ + CUDA-graph-compatible AllGather token dispatcher for inference. + + Only used during CUDA-graphed inference iterations. Swapped in by + MoELayer.set_inference_cuda_graphed_iteration() before graph capture + and swapped out by MoELayer.unset_inference_cuda_graphed_iteration() after. + + Key features: + - AllGather/ReduceScatter instead of AlltoAll for CUDA graph compatibility + - GPU-resident metadata (no host synchronization) + - NVLS collectives on Hopper+ with automatic NCCL fallback + """ + + def __init__( + self, + num_local_experts: int, + local_expert_indices: List[int], + config: TransformerConfig, + pg_collection: Optional[ProcessGroupCollection] = None, + ) -> None: + """ + Initialize the InferenceCUDAGraphTokenDispatcher. + + Args: + num_local_experts: Number of experts on this rank. + local_expert_indices: Global indices of experts on this rank. + config: Transformer configuration. + pg_collection: Process group collection for distributed ops. 
+ """ + super().__init__( + num_local_experts=num_local_experts, + local_expert_indices=local_expert_indices, + config=config, + pg_collection=pg_collection, + ) + self.topk = config.moe_router_topk + + self.triton_nvls_kernels_allowed = not self.config.inference_disable_triton_nvls_kernels + + def _maybe_allocate_ag_buffers( + self, routing_map: torch.Tensor, probs: torch.Tensor, hidden_states: torch.Tensor + ) -> dict: + """Allocate a single symmetric memory output buffer for fused all-gather. + + Creates one contiguous symmetric memory buffer sized for the gathered + (global) routing_map, probs, and hidden_states, then returns sliced views + into it. This allows a single fused NVLS all-gather kernel to write all + three outputs in one launch. + + Args: + routing_map (torch.Tensor): Local routing map, shape [local_tokens, topk]. + Boolean or integer tensor mapping each token to its selected experts. + probs (torch.Tensor): Local routing probabilities, shape [local_tokens, topk]. + Normalized weights for each token's selected experts. + hidden_states (torch.Tensor): Local hidden states, shape [local_tokens, hidden_dim]. + + Returns: + dict: A dictionary with the following keys: + - "handle": Symmetric memory handle for NVLS ops, or None if + symmetric memory is unavailable. + - "routing_map": Raw byte view for the gathered routing map output. + - "routing_map_offset": Byte offset of routing_map within the buffer. + - "probs": Raw byte view for the gathered probs output. + - "probs_offset": Byte offset of probs within the buffer. + - "hidden_states": Raw byte view for the gathered hidden states output. + - "hidden_states_offset": Byte offset of hidden_states within the buffer. + When allocation fails, all tensor views are None and offsets are 0. 
+ """ + _NONE = { + "handle": None, + "routing_map": None, + "routing_map_offset": 0, + "probs": None, + "probs_offset": 0, + "hidden_states": None, + "hidden_states_offset": 0, + } + + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + topk = probs.size(-1) + hidden_dim = hidden_states.size(-1) + + result = get_global_symmetric_memory_buffer_ep().maybe_get_tensors( + [ + (global_tokens * topk, routing_map.dtype), + (global_tokens * topk, probs.dtype), + (global_tokens * hidden_dim, hidden_states.dtype), + ] + ) + + if result["handle"] is None: + return _NONE + + (rm_buf, rm_off), (p_buf, p_off), (hs_buf, hs_off) = result["tensors"] + return { + "handle": result["handle"], + "routing_map": rm_buf, + "routing_map_offset": rm_off, + "probs": p_buf, + "probs_offset": p_off, + "hidden_states": hs_buf, + "hidden_states_offset": hs_off, + } + + def _maybe_allocate_rs_buffer(self, x: torch.Tensor) -> dict: + """Allocate a symmetric memory buffer for reduce-scatter input. + + The buffer has the same shape and dtype as x so that x can be copied + into it before the NVLS reduce-scatter kernel. + + Args: + x (torch.Tensor): The global hidden states to be reduce-scattered, + shape [global_tokens, hidden_dim]. + + Returns: + dict: A dictionary with keys "handle" (symmetric memory handle, or + None if unavailable) and "tensor" (the allocated buffer, or None). + """ + symm_mem_buffer = get_global_symmetric_memory_buffer_ep().maybe_get_tensor( + list(x.size()), dtype=x.dtype + ) + return symm_mem_buffer + + def token_dispatch(self, hidden_states, probs): + """Gathers tokens from all EP ranks using AllGather. + + Performs all-gather on routing_map (stored in self.routing_map), probs, + and hidden_states so that every rank holds the full global view. + Uses latency-optimized fused NVLS multimem_all_gather on Hopper+ GPUs + with BF16 when symmetric memory is available. Falls back to NCCL otherwise. 
+ + Args: + hidden_states (torch.Tensor): Local hidden states, + shape [local_tokens, hidden_dim]. + probs (torch.Tensor): Local routing probabilities, + shape [local_tokens, topk]. Normalized weights for each token's + selected experts. + + Returns: + tuple: (hidden_states, probs) gathered across all EP ranks. + - hidden_states (torch.Tensor): Shape [global_tokens, hidden_dim]. + - probs (torch.Tensor): Shape [global_tokens, topk]. + Also updates self.routing_map in-place to the gathered + shape [global_tokens, topk]. + """ + if self.ep_size == 1: + return hidden_states, probs + + # 1. Check inputs only: if inputs are 16-byte divisible, + # outputs (world_size * input) are too. + nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible( + hidden_states, probs, self.routing_map + ) + ag_buffers = None + + if nvls_eligible: + # 2. Now attempt to allocate symmetric memory buffers for + # all-gather outputs. If allocation fails, fallback to NCCL. + ag_buffers = self._maybe_allocate_ag_buffers(self.routing_map, probs, hidden_states) + + # 3. 
Can use NVLS if eligible and buffers allocated successfully (handle is not None) + can_use_nvls = nvls_eligible and ag_buffers["handle"] is not None + + if can_use_nvls: + # Capture shapes for reshaping after all-gather + # Output shape: [local_tokens * ep_size, dim] + local_tokens = probs.size(0) + global_tokens = local_tokens * self.ep_size + topk = probs.size(1) + hidden_dim = hidden_states.size(1) + routing_map_dtype = self.routing_map.dtype + probs_dtype = probs.dtype + hidden_dtype = hidden_states.dtype + + # Fused NVLS all-gather: single kernel launch + single barrier for all 3 tensors + multimem_all_gather_fused( + ag_buffers["routing_map"].view( + torch.bfloat16 + ), # .view does not change the underlying data + self.routing_map.view(torch.bfloat16), + ag_buffers["routing_map_offset"], + ag_buffers["probs"].view(torch.bfloat16), + probs.view(torch.bfloat16), + ag_buffers["probs_offset"], + ag_buffers["hidden_states"].view(torch.bfloat16), + hidden_states.view(torch.bfloat16), + ag_buffers["hidden_states_offset"], + ag_buffers["handle"], + ) + self.routing_map = ( + ag_buffers["routing_map"].view(routing_map_dtype).view(global_tokens, topk) + ) + probs = ag_buffers["probs"].view(probs_dtype).view(global_tokens, topk) + hidden_states = ( + ag_buffers["hidden_states"].view(hidden_dtype).view(global_tokens, hidden_dim) + ) + else: + # Fallback to NCCL for all tensors + with torch.no_grad(): + self.routing_map = gather_from_sequence_parallel_region( + self.routing_map, group=self.tp_ep_group + ) + probs = gather_from_sequence_parallel_region(probs, group=self.tp_ep_group) + hidden_states = gather_from_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) + + return hidden_states, probs + + def dispatch_postprocess(self, hidden_states, probs): + """Pass-through: returns inputs directly without permutation. + + Unlike the training dispatcher, this does not permute tokens or compute + tokens_per_expert. 
The downstream InferenceGroupedMLP (FlashInfer / + CUTLASS fused MoE kernel) operates directly on the routing map stored + in self.routing_map. + + Args: + hidden_states (torch.Tensor): Gathered hidden states, + shape [global_tokens, hidden_dim]. + probs (torch.Tensor): Gathered routing probabilities, + shape [global_tokens, topk]. + + Returns: + tuple: (hidden_states, tokens_per_expert, probs) where + tokens_per_expert is always None. + """ + return hidden_states, None, probs + + def combine_preprocess(self, expert_output): + """Pass-through: InferenceGroupedMLP already produces unpermuted output. + + No unpermutation is needed because dispatch_postprocess did not permute + the tokens in the first place. + + Args: + expert_output (torch.Tensor): Output from InferenceGroupedMLP, + shape [global_tokens, hidden_dim]. + + Returns: + torch.Tensor: The input tensor unchanged. + """ + return expert_output + + def token_combine(self, hidden_states): + """Combines expert outputs across EP ranks using Reduce-Scatter. + + Reduces the global expert output (summing contributions from each rank) + and scatters the result so each rank receives its local token slice. + Uses latency-optimized NVLS multimem_reduce_scatter on Hopper+ GPUs + with BF16 when symmetric memory is available. Falls back to NCCL otherwise. + + Args: + hidden_states (torch.Tensor): Combined expert output after routing + weights have been applied, shape [global_tokens, hidden_dim]. + + Returns: + torch.Tensor: Local slice of the reduced output, + shape [local_tokens, hidden_dim] where + local_tokens = global_tokens // ep_size. + """ + if self.ep_size == 1: + return hidden_states + + # Compute output shape first — check NVLS eligibility on the output, + # since if the smaller output is 16-byte divisible, the input is too. 
+ output_shape = list(hidden_states.size()) + output_shape[0] = hidden_states.size(0) // self.ep_size + output = torch.empty(output_shape, dtype=hidden_states.dtype, device=hidden_states.device) + + # Check output only: if output is 16-byte divisible, input (world_size * output) is too. + nvls_eligible = self.triton_nvls_kernels_allowed and are_tensors_nvls_eligible(output) + rs_buffer = None + + if nvls_eligible: + rs_buffer = self._maybe_allocate_rs_buffer(hidden_states) + + can_use_nvls = nvls_eligible and rs_buffer["handle"] is not None + + if can_use_nvls: + # Copy input to symmetric memory for reduce-scatter + rs_buffer["tensor"].copy_(hidden_states) + + # Use latency-optimized NVLS reduce-scatter + multimem_reduce_scatter(output, rs_buffer["tensor"], rs_buffer["handle"]) + return output + else: + # Fallback to NCCL + hidden_states = reduce_scatter_to_sequence_parallel_region( + hidden_states, group=self.tp_ep_group + ) + return hidden_states diff --git a/megatron/training/config/__init__.py b/megatron/training/config/__init__.py new file mode 100644 index 00000000000..3d346ddd8fe --- /dev/null +++ b/megatron/training/config/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +from megatron.training.config.common_config import ( + RNGConfig, + ProfilingConfig, + DistributedInitConfig, +) +from megatron.training.config.training_config import ( + TrainingConfig, + ValidationConfig, + SchedulerConfig, + LoggerConfig, + CheckpointConfig, +) +from megatron.training.config.resilience_config import ( + RerunStateMachineConfig, + StragglerDetectionConfig, +) diff --git a/megatron/training/config/common_config.py b/megatron/training/config/common_config.py new file mode 100644 index 00000000000..eb7e6313dec --- /dev/null +++ b/megatron/training/config/common_config.py @@ -0,0 +1,135 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+from dataclasses import dataclass, field +from typing import Literal +import os + +@dataclass(kw_only=True) +class RNGConfig: + """Configuration settings for random number generation.""" + + seed: int = 1234 + """Random seed used for python, numpy, pytorch, and cuda.""" + + te_rng_tracker: bool = False + """Use the Transformer Engine version of the random number generator. + Required for CUDA graphs support.""" + + inference_rng_tracker: bool = False + """Use a random number generator configured for inference.""" + + data_parallel_random_init: bool = False + """Enable random initialization of params across data parallel ranks""" + + +@dataclass(kw_only=True) +class ProfilingConfig: + """Configuration settings for profiling the training process.""" + + use_nsys_profiler: bool = field(default=False, metadata={"argparse_meta": {"arg_names": ["--profile"], "dest": "profile"}}) + """Enable nsys profiling. When using this option, nsys options should be specified in + commandline. An example nsys commandline is + `nsys profile -s none -t nvtx,cuda -o --force-overwrite true + --capture-range=cudaProfilerApi --capture-range-end=stop`. + """ + + profile_step_start: int = 10 + """Global step to start profiling.""" + + profile_step_end: int = 12 + """Global step to stop profiling.""" + + use_pytorch_profiler: bool = False + """Use the built-in pytorch profiler. 
Useful if you wish to view profiles in tensorboard.""" + + pytorch_profiler_collect_shapes: bool = False + """Collect tensor shape in pytorch profiler.""" + + pytorch_profiler_collect_callstack: bool = False + """Collect callstack in pytorch profiler.""" + + pytorch_profiler_collect_chakra: bool = False + """Collect chakra trace in pytorch profiler.""" + + profile_ranks: list[int] = field(default_factory=lambda: []) + """Global ranks to profile.""" + + record_memory_history: bool = False + """Record memory history in last rank.""" + + memory_snapshot_path: str = "snapshot.pickle" + """Specifies where to dump the memory history pickle.""" + + record_shapes: bool = False + """Record shapes of tensors.""" + + nvtx_ranges: bool = False + """Enable NVTX range annotations for profiling. When enabled, inserts NVTX markers + to categorize execution in profiler output.""" + + +@dataclass(kw_only=True) +class DistributedInitConfig: + """Configuration settings for distributed training initialization.""" + + distributed_backend: Literal["nccl", "gloo"] = "nccl" + """Which backend to use for distributed training.""" + + distributed_timeout_minutes: int = 10 + """Timeout minutes for torch.distributed.""" + + align_grad_reduce: bool = True + """If not set, all PP stages will launch gradient reduces simultaneously. + Otherwise, each PP stage will independently launch as needed. + """ + + local_rank: int = field(default_factory=lambda: int(os.getenv("LOCAL_RANK", "0"))) + """local rank passed from distributed launcher.""" + + lazy_mpu_init: bool = False + """If set to True, initialize_megatron() skips DDP initialization and returns function to complete it instead. + Also turns on --use-cpu-initialization flag. This is for external DDP manager.""" + + use_megatron_fsdp: bool = False + """Use Megatron's Fully Sharded Data Parallel. Cannot be used together with use_torch_fsdp2.""" + + use_torch_fsdp2: bool = False + """Use the torch FSDP2 implementation. 
FSDP2 is not currently working with Pipeline Parallel. + It is still not in a stable release stage, and may therefore contain bugs or other + potential issues.""" + + nccl_communicator_config_path: str | None = None + """Path to the yaml file with NCCL communicator configurations. The number of min/max thread + groups and thread group cluster size of each communicator can be configured by setting + `min_ctas`, `max_ctas`, and `cga_cluster_size`.""" + + use_tp_pp_dp_mapping: bool = False + """If set, distributed ranks initialize order is changed from tp-cp-ep-dp-pp to tp-cp-ep-pp-dp. + """ + + enable_gloo_process_groups: bool = field(default=True, metadata={"argparse_meta": {"arg_names": ["--disable-gloo-process-groups"]}}) + """If enabled, create Gloo process groups for communications.""" + + use_sharp: bool = False + """Set the use of SHARP for the collective communications of data-parallel process groups. + When `True`, run barrier within each data-parallel process group, + which specifies the SHARP application target groups. + """ + + sharp_enabled_group: Literal["dp", "dp_replica"] | None = None + """IB SHARP can be enabled from only one communication group. + By default, it is enabled from dp group if not specified and use_sharp=True. + Available options: [dp, dp_replica] + """ + + high_priority_stream_groups: list[str] | None = field(default_factory=list) + """Specify which communicator groups should use high priority streams during creation. + Assigning high priority to communication streams ensures that communication kernels + are scheduled with higher priority, minimizing the exposed communication when it is + overlapped with other computation kernels. + """ + + distributed_timeout_seconds_after_init: int | None = None + """Timeout in seconds for process groups after initialization. 
This timeout is applied to all process groups after initialization and the first iteration completes.""" + + disable_jit_fuser: bool = False + """Disable the JIT fuser.""" diff --git a/megatron/training/config/resilience_config.py b/megatron/training/config/resilience_config.py new file mode 100644 index 00000000000..dd0bd716521 --- /dev/null +++ b/megatron/training/config/resilience_config.py @@ -0,0 +1,42 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass +from typing import Literal + +@dataclass(kw_only=True) +class RerunStateMachineConfig: + """Configuration for the rerun state machine used for result validation or stats.""" + + error_injection_rate: int = 0 + """Rate at which to inject unexpected results, e.g. 1000 means + once every 1000 result validations""" + + error_injection_type: Literal["correct_result", "transient_error", "persistent_error"] = "transient_error" + """Type of error to inject. """ + + rerun_mode: Literal["disabled", "validate_results", "report_stats"] = "validate_results" + """Use re-run engine to validate results (default) or to emit stats + on variability of computations due to non-deterministic algorithms.""" + + check_for_nan_in_loss: bool = True + """Check for NaN in the loss.""" + + check_for_spiky_loss: bool = False + """Check for spiky loss.""" + + +@dataclass(kw_only=True) +class StragglerDetectionConfig: + """Configuration settings for detecting and logging GPU stragglers.""" + + log_straggler: bool = False + """If set, tracks and logs straggler per GPU.""" + + straggler_ctrlr_port: int = 65535 + """Port number to toggle StragglerDetector on/off at runtime""" + + straggler_minmax_count: int = 1 + """Number of ranks to report with high/low estimated throughput""" + + disable_straggler_on_startup: bool = False + """If set, StragglerDetector is disabled on startup.""" + diff --git a/megatron/training/config/training_config.py b/megatron/training/config/training_config.py new file 
mode 100644 index 00000000000..526a2e7ee59 --- /dev/null +++ b/megatron/training/config/training_config.py @@ -0,0 +1,517 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +from dataclasses import dataclass, field +import signal +from typing import Literal + +@dataclass(kw_only=True) +class TrainingConfig: + """Configuration settings related to the training loop.""" + + micro_batch_size: int | None = None + """Batch size per model instance (local batch size). Global batch size is local batch size times + data parallel size times number of micro batches.""" + + global_batch_size: int | None = None + """Training batch size. If set, it should be a multiple of micro-batch-size times + data-parallel-size. If this value is None, then use micro-batch-size * data-parallel-size + as the global batch size. This choice will result in 1 for number of micro-batches.""" + + rampup_batch_size: list[int] | None = field(default=None, metadata={"argparse_meta": {"nargs": 3}}) + """Batch size ramp up with the following values: , , + + For example: + rampup-batch-size = [16, 8, 300000] + global-batch-size 1024 + will start with global batch size 16 and over (1024 - 16) / 8 = 126 intervals will increase + the batch size linearly to 1024. In each interval we will use approximately + 300000 / 126 = 2380 samples. + """ + + decrease_batch_size_if_needed: bool = False + """If set, decrease batch size if microbatch_size * dp_size does not + divide batch_size. Old batch_size will be restored if training is re-started + with dp_size that divides batch_size // microbatch_size.""" + + empty_unused_memory_level: Literal[0, 1, 2] = 0 + """Call torch.cuda.empty_cache() each iteration (training and eval), to reduce fragmentation. + 0=off, 1=moderate, 2=aggressive. + """ + + check_weight_hash_across_dp_replicas_interval: int | None = None + """Interval to check weight hashes are same across DP replicas. 
If not specified, weight hashes not checked.""" + + train_sync_interval: int | None = None + """Training CPU-GPU synchronization interval, to ensure that CPU is not running too far ahead of GPU.""" + + train_iters: int | None = None + """Total number of iterations to train over all training runs. + Note that either train_iters or train_samples should be provided. + """ + + train_samples: int | None = None + """Total number of samples to train over all training runs. + Note that either train_iters or train_samples should be provided.""" + + exit_interval: int | None = None + """Exit the program after the iteration is divisible by this value.""" + + exit_duration_in_mins: int | None = None + """Exit the program after this many minutes.""" + + exit_signal_handler: bool = False + """Dynamically save the checkpoint and shutdown the training if SIGTERM is received""" + + exit_signal: signal.Signals = signal.SIGTERM + """Signal for the signal handler to detect.""" + + exit_signal_handler_for_dataloader: bool = False + """Use signal handler for dataloader workers""" + + manual_gc: bool = False + """Disable the threshold-based default garbage collector and trigger the garbage collection + manually. Manual garbage collection helps to align the timing of the collection across ranks + which mitigates the impact of CPU-associated jitters. When the manual gc is enabled, garbage + collection is performed only at the start and the end of the validation routine by default.""" + + manual_gc_interval: int = 0 + """Training step interval to trigger manual garbage collection. Values > 0 will trigger garbage + collections between training steps. + """ + + manual_gc_eval: bool = True + """When using manual garbage collection, this controls garbage collection at the start and the + end of each evaluation run. 
+ """ + + iterations_to_skip: list[int] = field(default_factory=list) + """List of iterations to skip during training, empty by default.""" + + +@dataclass(kw_only=True) +class ValidationConfig: + """Configuration settings related to validation during or after model training.""" + + eval_iters: int | None = 100 + """Number of iterations to run for evaluation. Used for both validation and test. If not set, + evaluation will not run.""" + + eval_interval: int | None = None + """Interval between running evaluation on validation set. If not set, evaluation will not run + during training. + """ + + skip_train: bool = False + """If set, bypass the training loop, perform evaluation for validation/test, and exit.""" + + test_mode: bool = False + """Run all real-time test alongside the experiment.""" + + full_validation: bool = False + """If set, each time validation occurs it uses the full validation dataset(s). This currently only works for GPT datasets!""" + + multiple_validation_sets: bool = False + """If set, multiple datasets listed in the validation split are evaluated independently with a + separate loss for each dataset in the list. This argument requires that no weights are + included in the list. + """ + + +@dataclass(kw_only=True) +class SchedulerConfig: + """Configuration settings for the learning rate scheduler and weight decay.""" + + # ---------------- Learning rate config. 
---------------- + lr_decay_style: Literal["constant", "linear", "cosine", "inverse-square-root", "WSD"] = "linear" + """Learning rate decay function.""" + + lr_wsd_decay_style: Literal["exponential", "linear", "cosine", "minus_sqrt"] = "exponential" + """Decay style for the annealing phase of WSD""" + + lr_decay_iters: int | None = None + """number of iterations to decay learning rate over, If None defaults to train iters""" + + lr_decay_samples: int | None = None + """number of samples to decay learning rate over, If None defaults to train samples""" + + lr_wsd_decay_iters: int | None = None + """number of iterations for the annealing phase in the wsd schedule""" + + lr_wsd_decay_samples: int | None = None + """number of samples for the annealing phase in the wsd schedule""" + + lr_warmup_fraction: float | None = None + """fraction of lr-warmup-(iters/samples) to use for warmup (as a float)""" + + lr_warmup_iters: int = 0 + """number of iterations to linearly warmup learning rate over.""" + + lr_warmup_samples: int = 0 + """number of samples to linearly warmup learning rate over.""" + + lr_warmup_init: float = 0.0 + """Initial value for learning rate warmup. The scheduler starts warmup from this value.""" + + lr_decay_steps: int | None = field(init=False, default=None) + """number of samples to decay learning rate over. Calculated at runtime from + lr_decay_iters or lr_decay_samples. + """ + + lr_warmup_steps: int | None = field(init=False, default=None) + """number of samples to warmup learning rate over. Calculated at runtime from + lr_warmup_fraction, lr_warmup_iters, or lr_warmup_samples. 
+ """ + + override_opt_param_scheduler: bool = field(default=False, metadata={"argparse_meta": {"arg_names": ["--override-opt_param-scheduler", "--override-opt-param-scheduler"]}}) + """Reset the values of the scheduler (learning rate, warmup iterations, minimum learning rate, + maximum number of iterations, and decay style) from input arguments and ignore values from + checkpoints. Note that all the above values will be reset.""" + + use_checkpoint_opt_param_scheduler: bool = field(default=False, metadata={"argparse_meta": {"arg_names": ["--use-checkpoint-opt_param-scheduler", "--use-checkpoint-opt-param-scheduler"]}}) + """Use checkpoint to set the values of the scheduler (learning rate, warmup iterations, + minimum learning rate, maximum number of iterations, and decay style) from checkpoint + and ignore input arguments.""" + + # ---------------- Regularization config. ---------------- + + start_weight_decay: float | None = None + """Initial weight decay coefficient for L2 regularization.""" + + end_weight_decay: float | None = None + """End of run weight decay coefficient for L2 regularization.""" + + weight_decay_incr_style: Literal["constant", "linear", "cosine"] = "constant" + """Weight decay increment function.""" + + no_weight_decay_cond_type: Literal["qwen3_next"] | None = None + """Type of no weight decay condition. Choices: + None (default): param no weight decay if and only if it is 1D; or it is bias; + or it is embedding and embedding_init_method_std is not None. + "qwen3_next": In addition to the default rules, apply weight decay to qk layernorm as a special case.""" + + wd_incr_steps: int | None = field(init=False, default=None) + """Number of samples to increment weight decay over. Calculated at runtime.""" + + wsd_decay_steps: int | None = field(init=False, default=None) + """Number of samples to decay WSD weight decay. 
Calculated at runtime.""" + + +@dataclass(kw_only=True) +class LoggerConfig: + """Configuration settings for logging, including TensorBoard and WandB.""" + + log_interval: int = 100 + """Report loss and timing interval.""" + + log_params_norm: bool = False + """If set, calculate and log parameters norm.""" + + log_throughput: bool = False + """If set, calculate and log throughput per GPU.""" + + log_throughput_to_tensorboard: bool = False + """Enable throughput logging to tensorboard.""" + + throughput_window_size: int = 100 + """Number of batches to use for a rolling average of throughput.""" + + log_progress: bool = False + """If set, log progress (in terms of number of processed tokens and number of floating-point operations) + to progress.txt file in checkpoint directory. + """ + + timing_log_level: Literal[0, 1, 2] = 0 + """Granularity level to measure and report timing. + 0: report only iteration time and make sure timing does not introduce extra overhead. + 1: report timing for operations that are executed very limited times (basically once) during each iteration + (such as gradient all-reduce) + 2: report timing for operations that might be executed numerous times during each iteration. + Note that setting the level to 1 or 2 might cause increase in iteration time. + """ + + timing_log_option: Literal["max", "minmax", "all"] = "minmax" + """Options for logging timing: + max: report the max timing across all ranks + minmax: report min and max timings across all ranks + all: report timings of all ranks. + """ + + tensorboard_dir: str | None = None + """Write TensorBoard logs to this directory.""" + + tensorboard_log_interval: int = 1 + """Report to tensorboard interval.""" + + tensorboard_queue_size: int = 1000 + """Size of the tensorboard queue for pending events and summaries + before one of the 'add' calls forces a flush to disk. 
+ """ + + log_timers_to_tensorboard: bool = False + """If set, write timers to tensorboard.""" + + log_loss_scale_to_tensorboard: bool = True + """If set, write loss scale to tensorboard (enabled by default; set to False to disable).""" + + log_validation_ppl_to_tensorboard: bool = False + """If set, write validation perplexity to tensorboard.""" + + log_memory_to_tensorboard: bool = False + """Enable memory logging to tensorboard.""" + + memory_keys: dict[str, str] | None = None + """Names of memory statistics to log from `torch.cuda.memory_stats()`""" + + log_memory_interval: int | None = None + """Report memory interval.""" + + log_device_memory_used: bool = False + """Log device memory used (as reported by nvidia-smi).""" + + log_l2_norm_grad_to_tensorboard: bool = False + """Enable gradients logging to tensorboard.""" + + log_num_zeros_in_grad: bool = False + """If set, calculate and log the number of zeros in gradient.""" + + log_max_attention_logit: bool = False + """Enable max attention logit logging to tensorboard.""" + + log_runtime_to_tensorboard: bool = False + """Enable runtime metrics logging to tensorboard.""" + + runtime_time_unit: str = "hours" + """Time unit to use for time logging. """ + + barrier_with_L1_time: bool = field(default=True, metadata={"argparse_meta": {"arg_names": ["--no-barrier-with-level-1-timing"]}}) + """If not disabled, use barrier with level 1 time measurements. Note that this is up to the user to + make sure calling barrier with their timers will not result in hangs. This can happen if for + example the user adds a level 1 timer that is not called by all ranks. + """ + + log_world_size_to_tensorboard: bool = False + """Enable world size logging to tensorboard.""" + + wandb_project: str | None = None + """The wandb project name. 
Ignore wandb by default.""" + + wandb_exp_name: str | None = None + """The wandb experiment name.""" + + wandb_save_dir: str | None = None + """Path to save the wandb results locally.""" + + wandb_entity: str | None = None + """The wandb entity name. It is useful when there are multiple sub-projects in a project.""" + + logging_level: int | None = None + """Set default logging level""" + + filter_warnings: bool = True + """Filter out warning messages""" + + modules_to_filter: list[str] | None = None + """List of modules to filter out from the logs""" + + set_level_for_all_loggers: bool = False + """Set the logging level for all loggers. If False, only level for NeMo loggers will be set.""" + + log_energy: bool = False + """If set, log energy consumption (in Joules).""" + + save_config_filepath: str | None = None + """If set, save the task configuration (ConfigContainer) to this file.""" + + +@dataclass(kw_only=True) +class CheckpointConfig: + """Configuration settings for model checkpointing (saving and loading).""" + + save: str | None = None + """Output directory to save checkpoints to.""" + + save_interval: int | None = field(default=None, metadata={"argparse_meta": {"arg_names": ["--save-interval", "--persistent-save-interval"]}}) + """Number of iterations between persistent checkpoint saves.""" + + save_wgrads_interval: int | None = None + """Number of iterations between wgrad (main_grad) saves.""" + + save_dgrads_interval: int | None = None + """Number of iterations between dgrad saves.""" + + save_retain_interval: int | None = None + """Number of iterations between retained checkpoints + (other checkpoints except the last checkpoint are automatically deleted). 
+ """ + + most_recent_k: int | None = -1 + """Number of latest checkpoints to be saved.""" + + save_optim: bool = True + """If set, save the current optimizer state in checkpoints; set to False to skip.""" + + save_rng: bool = True + """If set, save the current rng state in checkpoints; set to False to skip.""" + + load: str | None = None + """Directory containing a model checkpoint.""" + + load_optim: bool = True + """If set, load the optimizer state when loading a checkpoint; set to False to skip.""" + + load_main_params_from_ckpt: bool = False + """Load main parameters from checkpoint. When loading a model from a checkpoint without loading + the optimizer, the model parameters are updated but for fp16 optimizer with main parameters, + the main parameters need to also be updated. + """ + + load_rng: bool = True + """If set, load the rng state when loading a checkpoint; set to False to skip.""" + + non_persistent_save_interval: int | None = None + """Number of iterations between non-persistent saves.""" + + non_persistent_ckpt_type: Literal["global", "local", "in_memory"] | None = None + """Type of non-persistent model checkpoints. + "global" - Saved as a standard checkpoint (e.g., on Lustre) with old checkpoints being removed. + "local" - [TBD] Each rank saves a portion of the checkpoint locally (e.g., on SSD/ramdisk). + "in_memory" - [TBD] A special kind of local checkpoint that avoids serialization. + None - No non-persistent checkpointing (default option).""" + + non_persistent_global_ckpt_dir: str | None = None + """Directory containing global non-persistent model checkpoints.""" + + non_persistent_local_ckpt_dir: str | None = None + """Directory containing local non-persistent model checkpoints.""" + + non_persistent_local_ckpt_algo: Literal["fully_parallel", "atomic"] = "fully_parallel" + """Algorithm for local non-persistent checkpointing.""" + + finetune: bool = False + """Load model for finetuning. Do not load optimizer or rng state from checkpoint and set iteration to 0. 
+ Assumed when loading a release checkpoint.""" + + pretrained_checkpoint: str | None = None + """Directory containing a pretrained model checkpoint for finetuning.""" + + ckpt_step: int | None = None + """Checkpoint step to load model from.""" + + use_checkpoint_args: bool = False + """Override model-related command-line arguments with arguments from checkpoint""" + + use_mp_args_from_checkpoint_args: bool = False + """Copy model parallelism command-line arguments from checkpoint""" + + use_tokenizer_model_from_checkpoint_args: bool = True + """If set, use the tokenizer model path from checkpoint arguments; set to False to ignore it""" + + exit_on_missing_checkpoint: bool = False + """If 'load' is set, but checkpoint is not found (e.g., path typo), then exit instead of random initialization.""" + + ckpt_format: Literal["torch", "torch_dist", "torch_dcp", "fsdp_dtensor"] = "torch_dist" + """ Checkpoint format to use. torch is the format used by torch.save/load. + torch_dist is a megatron built-in distributed checkpointing format. + torch_dcp is the torch.distributed.checkpoint format. + fsdp_dtensor is a torch DCP native, Megatron FSDP training-specific checkpoint format. + """ + + auto_detect_ckpt_format: bool = False + """Determine if the checkpoint format is in legacy or distributed format. If False, + expects distributed checkpoint iff args.ckpt_format != "torch". Might slow down + loading a bit (double rank0 ckpt load). + """ + + ckpt_convert_format: Literal["torch", "torch_dist"] | None = None + """Checkpoint format for conversion.""" + + ckpt_convert_save: str | None = None + """Save directory for converted checkpoint.""" + + ckpt_convert_update_legacy_dist_opt_format: bool = False + """When loading a checkpoint, update the legacy format for the distributed optimizer, + which previously used a merged param/grad buffer and a different bucket mapping. + The legacy format was deprecated on Feb 13, 2024. 
+ """ + + ckpt_fully_parallel_save: bool = True + """Apply full save parallelization across DP for distributed checkpoints (enabled by default; set to False to disable). + Depending on ckpt format might decrease the number of files in the checkpoint. + Makes DistributedOptimizer checkpoint non-reshardable.""" + + async_save: bool = False + """Apply async checkpointing save. Currently works only with `torch_dist` distributed checkpoint format.""" + + use_persistent_ckpt_worker: bool = False + """Use a persistent background worker for async checkpoint saves. When enabled, creates a dedicated + worker thread/process for handling async saves. When disabled, uses temporal workers that are + created and destroyed for each save operation.""" + + ckpt_fully_parallel_load: bool = False + """Apply full load parallelization across DP for distributed checkpoints.""" + + ckpt_fully_parallel_load_exchange_algo: Literal["broadcast", "gather_rounds", "gather_object"] = "broadcast" + """Algorithm for fully parallel load of distributed checkpoints. + "broadcast"(default): Broadcast the checkpoint from rank 0 to all other ranks. + "gather_rounds": Gather the checkpoint from all ranks in rounds. + "gather_object": Gather the checkpoint from all ranks in a single operation. + """ + + ckpt_fully_parallel_save_process_group: Literal["dp", "ep_dp"] = "dp" + """Process group for fully parallel save of distributed checkpoints. + "dp"(default): Data parallel process group. + "ep_dp": Expert data parallel process group. + """ + + ckpt_fully_parallel_load_process_group: Literal["dp", "ep_dp"] = "dp" + """Process group for fully parallel load of distributed checkpoints. + "dp"(default): Data parallel process group. + "ep_dp": Expert data parallel process group. + """ + + ckpt_assume_constant_structure: bool = False + """Assume the checkpoint structure is constant across saves to enable optimizations.""" + + strict_fsdp_dtensor_load: bool = True + """Whether to enforce strict loading for FSDP DTensor checkpoints. 
When False, allows partial loading.""" + + dist_ckpt_strictness: Literal[ + "assume_ok_unexpected", + "log_unexpected", + "log_all", + "raise_unexpected", + "raise_all", + "return_unexpected", + "return_all", + "ignore_all", + ] = "assume_ok_unexpected" + """Determine handling of key mismatch during checkpoint load. Check StrictHandling docs for flags meaning. + NOTE: This flag controls only distributed checkpoint load from storage, not loading state dict into the model.""" + + dist_ckpt_save_pre_mcore_014: bool = False + """Revert checkpointing simplifications introduced in Megatron-Core v0.14. + This option affects only checkpoint saving format and will be removed soon + (checkpoint load format is determined based on checkpoint metadata).""" + + dist_ckpt_optim_fully_reshardable: bool = False + """Make optimizer distributed checkpoint fully reshardable (TP/PP/EP/DP) as opposed to plain DP reshardability.""" + + distrib_optim_fully_reshardable_mem_efficient: bool = False + """During distributed optimizer checkpoint save and load tries to use as little memory as possible + by using Gloo (instead of NCCL) and only one rank for saving. Turn on only if experiencing host or device memory + issues. Has effect only with `dist_ckpt_optim_fully_reshardable` flag.""" + + save_tokenizer_assets: bool = True + """Save tokenizer files to checkpoint directory. When enabled, saves all tokenizer artifacts + (vocab files, special tokens, tokenizer config) to make checkpoints self-contained and portable. + Set to False for performance-sensitive scenarios where tokenizer files are not needed.""" + + replication: bool = False + """If set, replication of local checkpoints is enabled. Needs to be enabled on all ranks.""" + + replication_jump: int | None = None + """Specifies `J`, the spacing between ranks storing replicas of a given rank's data. Replicas + for rank `n` may be on ranks `n+J`, `n+2J`, ..., or `n-J`, `n-2J`, etc. This flag has an + effect only if --replication is used. 
and must be consistent across all ranks.""" + + replication_factor: int = 2 + """Number of machines storing the replica of a given rank's data.""" diff --git a/tests/unit_tests/inference/test_batch_dimension_utils.py b/tests/unit_tests/inference/test_batch_dimension_utils.py new file mode 100644 index 00000000000..d67c390068a --- /dev/null +++ b/tests/unit_tests/inference/test_batch_dimension_utils.py @@ -0,0 +1,372 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +""" +Unit tests for CUDAGraphBatchDimensionBuilder.match_graph_config with expert parallelism. +""" + +import pytest +import torch +import torch.distributed as dist + +from megatron.core import parallel_state as ps +from megatron.core.inference.batch_dimensions_utils import ( + CUDAGraphBatchDimensionBuilder, + InferenceBatchDimensions, +) +from tests.unit_tests.test_utilities import Utils + +BD = InferenceBatchDimensions + +# Common config shared across tests +MAX_REQUESTS = 256 +MAX_TOKENS = 2048 +MAX_SEQ_LEN = 4096 +TP_SIZE = 1 +MIXED_PREFILL_COUNT = 16 + + +def _generate_graphs(num_cuda_graphs, use_non_decode=True): + """Generate cuda graph batch dimensions using the builder.""" + graph_list, _ = CUDAGraphBatchDimensionBuilder.generate_cuda_graph_batch_dimensions_list( + tp_size=TP_SIZE, + num_cuda_graphs=num_cuda_graphs, + cuda_graph_max_tokens=MAX_REQUESTS, + cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + max_requests=MAX_REQUESTS, + max_tokens=MAX_TOKENS, + max_sequence_length=MAX_SEQ_LEN, + use_cuda_graphs_for_non_decode_steps=use_non_decode, + ) + return graph_list + + +def _match( + real, graph_list, ep_group, strict=False, decode_only=False, explicit_chunked_prefill=False +): + return CUDAGraphBatchDimensionBuilder.match_graph_config( + real_batch_dim=real, + cuda_graph_batch_dimensions_list=graph_list, + strict=strict, + decode_only_cuda_graphs=decode_only, + explicit_chunked_prefill=explicit_chunked_prefill, + ep_group=ep_group, + 
cuda_graph_mixed_prefill_count=MIXED_PREFILL_COUNT, + ) + + +def _assert_consistent_across_ranks(result, ep_group): + """Assert that the match result is the same on every EP rank. + + Either all ranks return None, or all ranks return a config with the + same token_count (which is what the all-reduce synchronises). + """ + if result is None: + flag = torch.zeros(1, dtype=torch.int32, device="cuda") + else: + flag = torch.ones(1, dtype=torch.int32, device="cuda") + + # If any rank got None, all must get None; if any rank got a match, all must. + flag_sum = flag.clone() + dist.all_reduce(flag_sum, op=dist.ReduceOp.SUM, group=ep_group) + ep_size = dist.get_world_size(ep_group) + assert ( + flag_sum.item() == 0 or flag_sum.item() == ep_size + ), f"Inconsistent match: {flag_sum.item()}/{ep_size} ranks got a match" + + if result is not None: + tc = torch.tensor([result.token_count], dtype=torch.int32, device="cuda") + tc_max = tc.clone() + tc_min = tc.clone() + dist.all_reduce(tc_max, op=dist.ReduceOp.MAX, group=ep_group) + dist.all_reduce(tc_min, op=dist.ReduceOp.MIN, group=ep_group) + assert ( + tc_max.item() == tc_min.item() + ), f"Token count mismatch across EP ranks: min={tc_min.item()}, max={tc_max.item()}" + + +class TestCUDAGraphTokenCountAlignment: + """Verify that mixed/prefill graph token counts are a subset of decode graph token counts.""" + + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_mixed_token_counts_subset_of_decode(self, num_cuda_graphs): + """Every token count in the mixed/prefill graph pool must also appear + in the decode-only pool. 
Otherwise, when EP syncs token counts across + ranks, decode-only ranks cannot find a graph at the same token count + as prefill ranks, causing inconsistent matching.""" + graph_list = _generate_graphs(num_cuda_graphs) + + decode_token_counts = {bd.token_count for bd in graph_list if bd.prefill_req_count == 0} + mixed_token_counts = {bd.token_count for bd in graph_list if bd.prefill_req_count > 0} + + mixed_only = mixed_token_counts - decode_token_counts + assert not mixed_only, ( + f"Mixed/prefill token counts with no decode graph: {sorted(mixed_only)}. " + f"This will cause EP rank mismatch when some ranks are decode-only " + f"and others have prefill." + ) + + # Decode-only token counts not in the mixed pool are allowed, but only + # below MIXED_PREFILL_COUNT. The EP adjustment elevates token counts to + # at least MIXED_PREFILL_COUNT when any rank has prefill, so any decode + # token count >= MIXED_PREFILL_COUNT must have a mixed counterpart. + decode_only = decode_token_counts - mixed_token_counts + large_decode_only = {tc for tc in decode_only if tc >= MIXED_PREFILL_COUNT} + assert not large_decode_only, ( + f"Decode-only token counts >= MIXED_PREFILL_COUNT ({MIXED_PREFILL_COUNT}) " + f"with no mixed/prefill graph: {sorted(large_decode_only)}. " + f"The EP token count elevation cannot guarantee alignment for these." + ) + + +class TestMatchGraphConfigWithEP: + """Tests for match_graph_config with expert parallelism. + + Uses the world group as the EP group (all 8 GPUs form one EP group). 
+ """ + + def setup_method(self, method): + Utils.initialize_model_parallel( + tensor_model_parallel_size=1, + pipeline_model_parallel_size=1, + expert_model_parallel_size=Utils.world_size, + ) + + def teardown_method(self, method): + Utils.destroy_model_parallel() + + @staticmethod + def _get_ep_group(): + """Return the EP group created by initialize_model_parallel.""" + return ps.get_expert_model_parallel_group() + + # ------------------------------------------------------------------ # + # 1. All ranks same decode batch → consistent match + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_uniform_decode_batch(self, num_cuda_graphs): + """All EP ranks have the same decode-only batch → should all match the same graph.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is not None, "Should find a matching graph for uniform decode batch" + + # ------------------------------------------------------------------ # + # 2. Different token counts across EP ranks → all-reduce takes max + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_varying_decode_token_counts(self, num_cuda_graphs): + """EP ranks have different decode token counts. The all-reduce + should take the max, and all ranks should match the same graph.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Each rank gets a different token count: 8, 16, 24, ... 
+ token_count = (rank + 1) * 8 + real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=token_count) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is not None + + # ------------------------------------------------------------------ # + # 3. decode_only_cuda_graphs=True, some ranks have prefill → all None + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_decode_only_graphs_with_mixed_ranks(self, num_cuda_graphs): + """When decode_only_cuda_graphs=True and at least one EP rank has a + prefill request, ALL ranks should get None (eager mode).""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Rank 0 has a mixed batch (prefill + decode), all others decode-only + if rank == 0: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=10) + else: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + + result = _match(real, graph_list, ep_group=ep_group, decode_only=True) + _assert_consistent_across_ranks(result, ep_group) + assert ( + result is None + ), "All ranks should run eager when decode_only=True and some rank has prefill" + + # ------------------------------------------------------------------ # + # 4. 
explicit_chunked_prefill=True, some ranks prefill → all None + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_explicit_chunked_prefill_with_mixed_ranks(self, num_cuda_graphs): + """When explicit_chunked_prefill=True and some EP rank has prefill, + ALL ranks should get None (eager mode).""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + if rank == 0: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=10) + else: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + + result = _match(real, graph_list, ep_group=ep_group, explicit_chunked_prefill=True) + _assert_consistent_across_ranks(result, ep_group) + assert result is None, "All ranks should run eager with explicit_chunked_prefill" + + # ------------------------------------------------------------------ # + # 5. Mixed prefill graphs with strict matching + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_strict_matching_with_mixed_prefill(self, num_cuda_graphs): + """With strict matching, request counts are synced across EP ranks + via all-reduce. All ranks should still get a consistent result.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Varying prefill/decode split across ranks + prefill = min(rank + 1, MIXED_PREFILL_COUNT) + decode = 16 - prefill + real = BD(token_count=64, prefill_req_count=prefill, decode_req_count=decode) + + result = _match(real, graph_list, ep_group=ep_group, strict=True) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 6. 
Non-strict matching with mixed prefill + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_non_strict_matching_with_mixed_prefill(self, num_cuda_graphs): + """Non-strict matching: prefill slots can serve decode. Token count + is synced across EP ranks; result must be consistent.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + prefill = min(rank + 1, MIXED_PREFILL_COUNT) + decode = 16 - prefill + real = BD(token_count=64, prefill_req_count=prefill, decode_req_count=decode) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 7. Mixed decode/prefill across ranks — strict matching + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_mixed_decode_and_prefill_ranks_strict(self, num_cuda_graphs): + """Some EP ranks are pure decode, others have prefill requests. + With strict matching the all-reduce syncs request counts to the + max across ranks. Result must be consistent.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Even ranks: pure decode (32 tokens) + # Odd ranks: mixed prefill (64 tokens, 2 prefill + 14 decode) + if rank % 2 == 0: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + else: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=14) + + result = _match(real, graph_list, ep_group=ep_group, strict=True) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 8. 
Mixed decode/prefill across ranks — non-strict matching + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_mixed_decode_and_prefill_ranks_non_strict(self, num_cuda_graphs): + """Some EP ranks are pure decode, others have prefill requests. + Non-strict matching only syncs token counts (not request counts). + Result must be consistent.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + # Even ranks: pure decode (32 tokens) + # Odd ranks: mixed prefill (64 tokens, 2 prefill + 14 decode) + if rank % 2 == 0: + real = BD(token_count=32, prefill_req_count=0, decode_req_count=32) + else: + real = BD(token_count=64, prefill_req_count=2, decode_req_count=14) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + + # ------------------------------------------------------------------ # + # 9. All ranks decode-only with decode_only_cuda_graphs → should match + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_decode_only_graphs_all_decode(self, num_cuda_graphs): + """When all EP ranks are decode-only and decode_only_cuda_graphs=True, + a match should be found.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + token_count = (rank + 1) * 4 + real = BD(token_count=token_count, prefill_req_count=0, decode_req_count=token_count) + + result = _match(real, graph_list, ep_group=ep_group, decode_only=True) + _assert_consistent_across_ranks(result, ep_group) + assert result is not None, "All-decode batch with decode_only_cuda_graphs should match" + + # ------------------------------------------------------------------ # + # 10. 
Real batch exceeds all graphs → None on all ranks + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_oversized_batch_returns_none(self, num_cuda_graphs): + """When the real batch is larger than any available graph, all ranks + should get None.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + + # Token count exceeds MAX_TOKENS on all ranks + real = BD( + token_count=MAX_TOKENS + 100, + prefill_req_count=0, + decode_req_count=min(MAX_TOKENS + 100, MAX_REQUESTS), + ) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is None, "Oversized batch should not match any graph" + + # ------------------------------------------------------------------ # + # 11. One EP rank has huge batch → all-reduce lifts to max → no match + # ------------------------------------------------------------------ # + @pytest.mark.internal + @pytest.mark.parametrize("num_cuda_graphs", [1, 16, 32, -1]) + def test_one_rank_oversized_forces_no_match(self, num_cuda_graphs): + """If one EP rank has a batch exceeding all graph capacities, the + all-reduce max lifts everyone → no match on any rank.""" + ep_group = self._get_ep_group() + graph_list = _generate_graphs(num_cuda_graphs) + rank = dist.get_rank() + + if rank == 0: + # This rank has a batch that exceeds all graphs + real = BD( + token_count=MAX_TOKENS + 100, + prefill_req_count=0, + decode_req_count=min(MAX_TOKENS + 100, MAX_REQUESTS), + ) + else: + real = BD(token_count=8, prefill_req_count=0, decode_req_count=8) + + result = _match(real, graph_list, ep_group=ep_group) + _assert_consistent_across_ranks(result, ep_group) + assert result is None, "All-reduce max from oversized rank should cause no match" diff --git a/tests/unit_tests/inference/test_moe_permute.py b/tests/unit_tests/inference/test_moe_permute.py new 
file mode 100644 index 00000000000..4664d0fa2cd --- /dev/null +++ b/tests/unit_tests/inference/test_moe_permute.py @@ -0,0 +1,446 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for megatron.core.inference.moe.permute. + +Tests cover: +- compute_local_tokens_per_expert: token counting against PyTorch reference +- compute_expert_offsets: prefix sums with and without alignment +- permute_tokens: expert grouping, data integrity, alignment padding +- unpermute_tokens: weighted scatter-back, fp32 accumulation +- permute -> unpermute roundtrip +""" + +import pytest +import torch + + +def _ref_tokens_per_expert(routing_map, local_expert_start, num_local_experts): + """PyTorch reference for compute_local_tokens_per_expert.""" + counts = torch.zeros(num_local_experts, dtype=torch.int32, device=routing_map.device) + for eid in routing_map.flatten(): + lid = eid.item() - local_expert_start + if 0 <= lid < num_local_experts: + counts[lid] += 1 + return counts + + +def _ref_expert_offsets(tokens_per_expert, alignment): + """PyTorch reference for compute_expert_offsets.""" + aligned = tokens_per_expert.clone().to(torch.int32) + for i in range(len(aligned)): + if aligned[i] > 0 and alignment > 1: + aligned[i] = ((aligned[i] + alignment - 1) // alignment) * alignment + inc = torch.cumsum(aligned, dim=0) + exc = inc - aligned + return exc.to(torch.int32), inc.to(torch.int32) + + +def _make_inputs(num_tokens, hidden_dim, topk, num_experts, seed=42): + """Create random hidden states, probs, and routing_map.""" + torch.manual_seed(seed) + hidden = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + return hidden, probs, routing_map + + +@pytest.mark.internal +class TestComputeLocalTokensPerExpert: + + @pytest.mark.parametrize("num_tokens", [1, 4, 16, 64, 128, 256, 512]) + 
@pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + @pytest.mark.parametrize( + "num_experts,num_local,start", + [ + (4, 4, 0), # all local, small expert count + (8, 8, 0), # all local (EP=1) + (8, 4, 0), # first half local (EP=2, rank 0) + (8, 4, 4), # second half local (EP=2, rank 1) + (8, 2, 2), # middle slice (EP=4, rank 1) + (8, 1, 7), # single expert local (EP=8, last rank) + (32, 8, 0), # 32 experts, first 8 local + (32, 8, 24), # 32 experts, last 8 local + (128, 32, 0), # 128 experts, first 32 local (EP=4, rank 0) + (128, 32, 96), # 128 experts, last 32 local (EP=4, rank 3) + ], + ) + def test_matches_reference(self, num_tokens, topk, num_experts, num_local, start): + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + result = compute_local_tokens_per_expert(routing_map, start, num_local) + expected = _ref_tokens_per_expert(routing_map, start, num_local) + torch.testing.assert_close(result, expected, atol=0, rtol=0) + + def test_no_local_tokens(self): + """All tokens routed to non-local experts -> all zeros.""" + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + routing_map = torch.full((16, 4), 99, dtype=torch.int64, device="cuda") + result = compute_local_tokens_per_expert(routing_map, 0, 8) + assert result.sum().item() == 0 + + def test_single_expert_all_tokens(self): + """All token-topk pairs route to a single local expert.""" + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + num_tokens, topk, num_local = 32, 4, 8 + routing_map = torch.full((num_tokens, topk), 3, dtype=torch.int64, device="cuda") + result = compute_local_tokens_per_expert(routing_map, 0, num_local) + assert result[3].item() == num_tokens * topk + assert result.sum().item() == num_tokens * topk + + @pytest.mark.parametrize("seed", [0, 7, 42, 123, 999]) + def test_total_count_equals_local_pairs(self, seed): 
+ """Sum of tokens_per_expert equals total routing pairs hitting local experts.""" + from megatron.core.inference.moe.permute import compute_local_tokens_per_expert + + torch.manual_seed(seed) + num_tokens, topk, num_experts = 64, 6, 16 + local_start, num_local = 4, 4 + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + result = compute_local_tokens_per_expert(routing_map, local_start, num_local) + local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) + assert result.sum().item() == local_mask.sum().item() + + +@pytest.mark.internal +class TestComputeExpertOffsets: + + @pytest.mark.parametrize("alignment", [1, 8, 16, 32, 64, 128]) + @pytest.mark.parametrize( + "tpe_values", + [ + [5, 0, 12, 3, 0, 7, 1, 20], + [1, 1, 1, 1], + [0, 0, 0, 0], + [100, 0, 0, 50], + [1], + [33, 33, 33, 33, 33, 33, 33, 33], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], + [127, 0, 129, 0, 1, 0, 255, 0], + ], + ) + def test_matches_reference(self, alignment, tpe_values): + from megatron.core.inference.moe.permute import compute_expert_offsets + + tpe = torch.tensor(tpe_values, dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=alignment) + ref_exc, ref_inc = _ref_expert_offsets(tpe, alignment) + torch.testing.assert_close(exc, ref_exc, atol=0, rtol=0) + torch.testing.assert_close(inc, ref_inc, atol=0, rtol=0) + + @pytest.mark.parametrize("n_experts", [1, 2, 4, 8, 16, 32, 64, 128]) + def test_exclusive_starts_at_zero(self, n_experts): + from megatron.core.inference.moe.permute import compute_expert_offsets + + tpe = torch.randint(1, 50, (n_experts,), dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=1) + assert exc[0].item() == 0 + assert inc[-1].item() == tpe.sum().item() + + def test_zero_experts_skipped(self): + """Experts with 0 tokens should not consume any aligned space.""" + from megatron.core.inference.moe.permute import compute_expert_offsets + 
+ tpe = torch.tensor([0, 5, 0, 3], dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=32) + # Expert 0: 0 tokens -> 0 aligned -> exc=0, inc=0 + assert exc[0].item() == 0 + assert inc[0].item() == 0 + # Expert 1: 5 tokens -> 32 aligned -> exc=0, inc=32 + assert exc[1].item() == 0 + assert inc[1].item() == 32 + # Expert 2: 0 tokens -> exc=32, inc=32 + assert exc[2].item() == 32 + assert inc[2].item() == 32 + + @pytest.mark.parametrize("alignment", [16, 32, 128]) + def test_all_offsets_aligned(self, alignment): + """Every inclusive offset should be a multiple of alignment.""" + from megatron.core.inference.moe.permute import compute_expert_offsets + + tpe = torch.tensor([3, 7, 0, 15, 1, 0, 50, 2], dtype=torch.int32, device="cuda") + exc, inc = compute_expert_offsets(tpe, alignment=alignment) + for i in range(len(tpe)): + assert ( + inc[i].item() % alignment == 0 + ), f"inc[{i}]={inc[i].item()} not aligned to {alignment}" + assert ( + exc[i].item() % alignment == 0 + ), f"exc[{i}]={exc[i].item()} not aligned to {alignment}" + + +class TestPermuteTokens: + + @pytest.mark.parametrize( + "num_tokens,hidden_dim,topk,num_experts", + [ + (1, 64, 1, 4), + (1, 128, 8, 8), + (4, 64, 2, 4), + (16, 128, 2, 8), + (32, 64, 4, 8), + (64, 256, 6, 8), + (128, 128, 8, 128), + (256, 64, 2, 32), + (512, 128, 6, 16), + ], + ) + def test_data_integrity(self, num_tokens, hidden_dim, topk, num_experts): + """Every permuted row matches the original token's hidden state.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, hidden_dim, topk, num_experts) + perm_h, perm_p, perm_map, offs = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=1 + ) + + # Check every non-padding row + for i in range(perm_map.shape[0]): + src = perm_map[i].item() + if src < 0: + continue + torch.testing.assert_close( + perm_h[i], hidden[src], msg=f"Row {i} (src={src}) hidden mismatch" + ) + 
+ @pytest.mark.parametrize("alignment", [1, 16, 32, 64, 128]) + @pytest.mark.parametrize("num_tokens,topk,num_experts", [(16, 2, 4), (64, 4, 8), (128, 8, 32)]) + def test_offsets_are_aligned(self, alignment, num_tokens, topk, num_experts): + """Inclusive offsets are multiples of alignment.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 128, topk, num_experts) + _, _, _, offs = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=alignment + ) + if alignment > 1: + for i in range(offs.shape[0]): + assert ( + offs[i].item() % alignment == 0 + ), f"Offset {i}={offs[i].item()} not aligned to {alignment}" + + @pytest.mark.parametrize( + "num_tokens,topk,num_experts,alignment", + [(8, 2, 4, 128), (32, 2, 4, 128), (16, 4, 8, 64), (64, 6, 8, 32)], + ) + def test_padding_rows_have_neg1(self, num_tokens, topk, num_experts, alignment): + """Padding rows in permutation_map are -1.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) + _, _, perm_map, _ = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=alignment + ) + padding_mask = perm_map == -1 + real_mask = perm_map >= 0 + assert padding_mask.sum() > 0, "Expected some padding rows with large alignment" + assert real_mask.sum() > 0, "Expected some real rows" + + @pytest.mark.parametrize( + "num_tokens,topk,num_experts", [(16, 2, 4), (32, 4, 8), (64, 6, 16), (128, 8, 128)] + ) + @pytest.mark.parametrize("alignment", [1, 32, 128]) + def test_total_real_rows_equals_routed_pairs(self, num_tokens, topk, num_experts, alignment): + """Number of non-padding rows equals total (token, topk) pairs routed locally.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) + _, _, perm_map, _ = permute_tokens( + hidden, probs, 
routing_map, 0, num_experts, alignment=alignment + ) + real_count = (perm_map >= 0).sum().item() + # All experts are local, so every pair should appear + assert real_count == num_tokens * topk + + @pytest.mark.parametrize( + "num_tokens,topk,num_experts,local_start,num_local", + [ + (64, 4, 8, 2, 3), # experts 2, 3, 4 + (64, 4, 8, 0, 1), # only expert 0 + (64, 4, 8, 7, 1), # only expert 7 + (128, 6, 16, 4, 8), # experts 4-11 + (32, 2, 32, 16, 16), # second half of 32 + (256, 8, 128, 0, 32), # first 32 of 128 + ], + ) + def test_expert_subset(self, num_tokens, topk, num_experts, local_start, num_local): + """Only tokens routed to local experts appear in output.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(num_tokens, 64, topk, num_experts) + _, _, perm_map, _ = permute_tokens( + hidden, probs, routing_map, local_start, num_local, alignment=1 + ) + real_count = (perm_map >= 0).sum().item() + local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) + expected_count = local_mask.sum().item() + assert real_count == expected_count + + @pytest.mark.parametrize("hidden_dim", [32, 64, 128, 256, 512, 1024, 2688]) + def test_various_hidden_dims(self, hidden_dim): + """Permute works across various hidden dimensions including non-power-of-2.""" + from megatron.core.inference.moe.permute import permute_tokens + + hidden, probs, routing_map = _make_inputs(32, hidden_dim, 4, 8) + perm_h, _, perm_map, _ = permute_tokens(hidden, probs, routing_map, 0, 8, alignment=1) + # Spot-check first real row + for i in range(perm_map.shape[0]): + src = perm_map[i].item() + if src >= 0: + torch.testing.assert_close(perm_h[i], hidden[src]) + break + + +@pytest.mark.internal +class TestUnpermuteTokens: + + def test_weighted_scatter(self): + """Unpermute correctly accumulates prob-weighted expert outputs.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + num_tokens, hidden_dim = 
4, 8 + # Two entries map to token 0, one to token 2 + expert_output = torch.ones(3, hidden_dim, device="cuda", dtype=torch.bfloat16) + permuted_probs = torch.tensor([0.5, 0.3, 0.7], device="cuda", dtype=torch.float32) + perm_map = torch.tensor([0, 0, 2], dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, permuted_probs, perm_map, num_tokens) + + assert result.dtype == torch.float32 + # Token 0: 0.5 * 1.0 + 0.3 * 1.0 = 0.8 + torch.testing.assert_close( + result[0], torch.full((hidden_dim,), 0.8, device="cuda"), atol=1e-5, rtol=1e-5 + ) + # Token 1: untouched -> 0 + torch.testing.assert_close(result[1], torch.zeros(hidden_dim, device="cuda")) + # Token 2: 0.7 * 1.0 = 0.7 + torch.testing.assert_close( + result[2], torch.full((hidden_dim,), 0.7, device="cuda"), atol=1e-5, rtol=1e-5 + ) + + def test_padding_rows_ignored(self): + """Rows with permutation_map == -1 are skipped.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + expert_output = torch.ones(4, 8, device="cuda", dtype=torch.bfloat16) + permuted_probs = torch.ones(4, device="cuda", dtype=torch.float32) + perm_map = torch.tensor([0, -1, -1, 1], dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, permuted_probs, perm_map, 3) + # Only tokens 0 and 1 get values + assert result[0].sum().item() != 0 + assert result[1].sum().item() != 0 + assert result[2].sum().item() == 0 + + @pytest.mark.parametrize("hidden_dim", [8, 64, 128, 256, 512, 2688]) + def test_various_hidden_dims(self, hidden_dim): + """Unpermute works across various hidden dimensions.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + num_tokens = 8 + expert_output = torch.randn(4, hidden_dim, device="cuda", dtype=torch.bfloat16) + permuted_probs = torch.tensor([1.0, 1.0, 1.0, 1.0], device="cuda", dtype=torch.float32) + perm_map = torch.tensor([0, 1, 2, 3], dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, permuted_probs, 
perm_map, num_tokens) + assert result.shape == (num_tokens, hidden_dim) + # First 4 tokens should have values, rest should be zero + for t in range(4): + torch.testing.assert_close(result[t], expert_output[t].float(), atol=1e-5, rtol=1e-5) + for t in range(4, num_tokens): + assert result[t].sum().item() == 0 + + @pytest.mark.parametrize("topk", [1, 2, 4, 6, 8]) + def test_multiple_topk_accumulation(self, topk): + """Multiple topk entries for the same token are summed correctly.""" + from megatron.core.inference.moe.permute import unpermute_tokens + + hidden_dim = 32 + # All topk entries point to token 0 + expert_output = torch.ones(topk, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.full((topk,), 0.1, device="cuda", dtype=torch.float32) + perm_map = torch.zeros(topk, dtype=torch.int32, device="cuda") + + result = unpermute_tokens(expert_output, probs, perm_map, 1) + expected_val = 0.1 * topk + torch.testing.assert_close( + result[0], torch.full((hidden_dim,), expected_val, device="cuda"), atol=1e-4, rtol=1e-4 + ) + + +@pytest.mark.internal +class TestPermuteUnpermuteRoundtrip: + + @pytest.mark.parametrize( + "num_tokens,hidden_dim,topk,num_experts,alignment", + [ + (1, 64, 1, 4, 1), + (1, 128, 1, 4, 128), + (8, 64, 1, 4, 1), + (16, 64, 2, 4, 1), + (16, 64, 2, 4, 32), + (32, 128, 4, 8, 32), + (32, 128, 4, 8, 128), + (64, 256, 6, 8, 1), + (64, 256, 6, 8, 128), + (128, 128, 8, 32, 1), + (128, 128, 8, 32, 128), + (256, 64, 2, 128, 32), + (64, 2688, 8, 128, 128), # nanov3-like hidden_dim + ], + ) + def test_roundtrip_identity(self, num_tokens, hidden_dim, topk, num_experts, alignment): + """permute -> (identity transform) -> unpermute recovers weighted sum of inputs.""" + from megatron.core.inference.moe.permute import permute_tokens, unpermute_tokens + + torch.manual_seed(42) + hidden = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = 
torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + + perm_h, perm_p, perm_map, _ = permute_tokens( + hidden, probs, routing_map, 0, num_experts, alignment=alignment + ) + # Pass permuted hidden directly through (identity expert) + result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens) + + # Build reference: for each token, sum prob[k] * hidden[token] over topk + ref = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32) + for t in range(num_tokens): + prob_sum = probs[t].sum() + ref[t] = hidden[t].float() * prob_sum + + torch.testing.assert_close(result, ref, atol=1e-2, rtol=1e-2) + + @pytest.mark.parametrize( + "local_start,num_local,num_experts", + [(0, 4, 8), (4, 4, 8), (0, 1, 8), (0, 8, 8), (0, 32, 128), (96, 32, 128)], + ) + def test_roundtrip_with_expert_subset(self, local_start, num_local, num_experts): + """Roundtrip works when only a subset of experts are local.""" + from megatron.core.inference.moe.permute import permute_tokens, unpermute_tokens + + torch.manual_seed(42) + num_tokens, hidden_dim, topk = 64, 128, 4 + hidden = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + + perm_h, perm_p, perm_map, _ = permute_tokens( + hidden, probs, routing_map, local_start, num_local, alignment=32 + ) + result = unpermute_tokens(perm_h, perm_p, perm_map, num_tokens) + + # Reference: only accumulate probs for local experts + ref = torch.zeros(num_tokens, hidden_dim, device="cuda", dtype=torch.float32) + for t in range(num_tokens): + local_prob_sum = 0.0 + for k in range(topk): + eid = routing_map[t, k].item() + if local_start <= eid < local_start + num_local: + local_prob_sum += probs[t, k].item() + ref[t] = hidden[t].float() * local_prob_sum + + torch.testing.assert_close(result, ref, atol=1e-2, rtol=1e-2) diff --git 
a/tests/unit_tests/inference/test_mxfp8_utils.py b/tests/unit_tests/inference/test_mxfp8_utils.py new file mode 100644 index 00000000000..a137dfbc820 --- /dev/null +++ b/tests/unit_tests/inference/test_mxfp8_utils.py @@ -0,0 +1,645 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. + +"""Unit tests for MXFP8 quantization. + +Tests cover: +- mxfp8_quantize (Triton kernel): data and swizzled scales vs PyTorch reference +- MXFP8Tensor.from_bf16: both 'triton' and 'flashinfer' backends +- MXFP8Tensor.scale_2d: reshape correctness +""" + +import pytest +import torch + +pytestmark = [ + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), + pytest.mark.internal, +] + + +def ceil_div(a, b): + return (a + b - 1) // b + + +# ────────────────────────────────────────────────────────────────────── +# Reference functions from PyTorch +# https://github.com/pytorch/pytorch/blob/main/torch/testing/_internal/common_quantized.py#L578 +# ────────────────────────────────────────────────────────────────────── + + +def ref_to_mxfp(data_hp: torch.Tensor, block_size: int = 32, format: str = "mxfp8"): + if data_hp.dtype not in (torch.bfloat16, torch.float): + raise AssertionError(f"{data_hp.dtype} is not supported yet") + if data_hp.shape[-1] % block_size != 0: + raise AssertionError( + f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}" + ) + if not data_hp.is_contiguous(): + raise AssertionError("unsupported: data_hp must be contiguous") + + orig_shape = data_hp.shape + data_hp = data_hp.reshape(*orig_shape[:-1], orig_shape[-1] // block_size, block_size) + + max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1) + + data_hp = data_hp.to(torch.float32) + max_abs = max_abs.to(torch.float32) + + if format == "mxfp8": + F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max # 448.0 + max_pos = F8E4M3_MAX + elif format == "mxfp4": + F4E2M1_MAX = 6.0 + max_pos = F4E2M1_MAX + + # RCEIL + def _to_mx_rceil( + data_hp: 
torch.Tensor, max_abs: torch.Tensor, max_pos: float + ) -> tuple[torch.Tensor, torch.Tensor]: + E8M0_EXPONENT_BIAS = 127 + descale = max_abs / max_pos + exponent = torch.where( + torch.isnan(descale), + 0xFF, # Handle biased exponent for nan + ( + torch.clamp( + torch.ceil(torch.log2(descale)), min=-E8M0_EXPONENT_BIAS, max=E8M0_EXPONENT_BIAS + ) + + E8M0_EXPONENT_BIAS + ).to(torch.uint8), + ) + + descale_fp = torch.where( + exponent == 0, 1.0, torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)) + ) + + # scale and saturated cast the data elements to max of target dtype + data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos) + return exponent, data_lp + + scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos) + + # cast to target dtype + data_lp = data_lp.to(torch.float8_e4m3fn) + data_lp = data_lp.reshape(orig_shape) + + scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu) + scale_e8m0_biased = scale_e8m0_biased.squeeze(-1) + return scale_e8m0_biased, data_lp + + +def ref_swizzle(input_matrix) -> torch.Tensor: + """Rearrange a scale matrix into cuBLAS 2D blocked (swizzled) layout. + + See: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + + Args: + input_matrix: Input tensor of shape (H, W) + + Returns: + Flattened swizzled tensor. 
+ """ + rows, cols = input_matrix.shape + n_row_blocks = ceil_div(rows, 128) + n_col_blocks = ceil_div(cols, 4) + + padded_rows = n_row_blocks * 128 + padded_cols = n_col_blocks * 4 + + padded = input_matrix + if (rows, cols) != (padded_rows, padded_cols): + padded = torch.zeros( + (padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype + ) + padded[:rows, :cols] = input_matrix + + blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3) + rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16) + + return rearranged.flatten() + + +# ────────────────────────────────────────────────────────────────────── +# mxfp8_quantize (Triton kernel) +# ────────────────────────────────────────────────────────────────────── + + +class TestMxfp8Quantize: + + @pytest.mark.parametrize( + "M,K", + [ + (1, 32), + (1, 64), + (1, 128), + (4, 32), + (4, 128), + (16, 64), + (16, 256), + (32, 128), + (64, 256), + (128, 128), + (128, 512), + (128, 2688), # nanov3 hidden_size + (256, 1856), # nanov3 moe_ffn_hidden_size + (512, 2688), + ], + ) + def test_data_matches_reference(self, M, K): + """Quantized FP8 data matches PyTorch reference.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + triton_data, _ = mxfp8_quantize(x) + _, ref_data = ref_to_mxfp(x) + + assert triton_data.shape == (M, K) + assert triton_data.dtype == torch.float8_e4m3fn + torch.testing.assert_close( + triton_data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize( + "M,K", + [ + (1, 32), + (1, 64), + (4, 128), + (16, 256), + (32, 128), + (128, 128), + (128, 512), + (128, 2688), + (256, 1856), + (512, 2688), + ], + ) + def test_scales_match_reference(self, M, K): + """Swizzled scales match ref_to_mxfp scales passed through ref_swizzle.""" + from 
megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + _, triton_scales = mxfp8_quantize(x) + ref_scales_2d, _ = ref_to_mxfp(x) # [M, K//32] e8m0 + + # Swizzle the reference scales + ref_swizzled = ref_swizzle(ref_scales_2d) + + # Compare as uint8 since e8m0 is just exponent bytes + torch.testing.assert_close( + triton_scales.view(torch.uint8), ref_swizzled.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 2688)]) + def test_all_zeros_input(self, M, K): + """All-zero input produces all-zero FP8 data and zero scales.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + x = torch.zeros(M, K, device="cuda", dtype=torch.bfloat16) + data, scales = mxfp8_quantize(x) + assert (data.float() == 0).all() + assert (scales.view(torch.uint8) == 0).all() + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 256)]) + def test_constant_input(self, M, K): + """Constant input: all elements in a group have the same value.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + x = torch.full((M, K), 1.0, device="cuda", dtype=torch.bfloat16) + data, _ = mxfp8_quantize(x) + _, ref_data = ref_to_mxfp(x) + torch.testing.assert_close( + data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16, torch.float32]) + def test_input_dtypes(self, dtype): + """Kernel accepts bf16, fp16, and fp32 inputs.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + x = torch.randn(16, 128, device="cuda", dtype=dtype) + data, _ = mxfp8_quantize(x) + assert data.dtype == torch.float8_e4m3fn + assert data.shape == (16, 128) + + @pytest.mark.parametrize("M", [1, 127, 128, 129, 255, 256, 257, 512]) + def test_various_row_counts(self, M): + """Test row counts that 
are not multiples of 128 (macro tile boundary).""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + K = 128 + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + data, _ = mxfp8_quantize(x) + _, ref_data = ref_to_mxfp(x) + torch.testing.assert_close( + data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("seed", [0, 7, 42, 123, 999]) + def test_reproducible(self, seed): + """Same input always produces same output.""" + from megatron.core.inference.quantization.mxfp8_quantize import mxfp8_quantize + + torch.manual_seed(seed) + x = torch.randn(64, 256, device="cuda", dtype=torch.bfloat16) + d1, s1 = mxfp8_quantize(x) + d2, s2 = mxfp8_quantize(x) + torch.testing.assert_close(d1.view(torch.uint8), d2.view(torch.uint8), atol=0, rtol=0) + torch.testing.assert_close(s1.view(torch.uint8), s2.view(torch.uint8), atol=0, rtol=0) + + +# ────────────────────────────────────────────────────────────────────── +# MXFP8Tensor +# ────────────────────────────────────────────────────────────────────── + + +class TestMXFP8Tensor: + + @pytest.mark.parametrize("M,K", [(16, 128), (64, 256), (128, 2688)]) + def test_from_bf16_triton(self, M, K): + """from_bf16 with triton backend produces correct data and scales.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + _, ref_data = ref_to_mxfp(x) + + assert tensor.data.shape == (M, K) + assert tensor.data.dtype == torch.float8_e4m3fn + assert tensor.backend == "triton" + torch.testing.assert_close( + tensor.data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(16, 128), (64, 256), (128, 2688)]) + def test_from_bf16_flashinfer(self, M, K): + """from_bf16 with flashinfer backend produces valid output.""" + 
from megatron.core.inference.quantization.mxfp8_tensor import HAVE_FLASHINFER, MXFP8Tensor + + if not HAVE_FLASHINFER: + pytest.skip("FlashInfer not available") + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + tensor = MXFP8Tensor.from_bf16(x, backend="flashinfer") + assert tensor.data.shape == (M, K) + assert tensor.data.dtype == torch.float8_e4m3fn + assert tensor.backend == "flashinfer" + + def test_from_bf16_invalid_backend(self): + """from_bf16 with invalid backend raises ValueError.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + x = torch.randn(16, 128, device="cuda", dtype=torch.bfloat16) + with pytest.raises(ValueError, match="Unknown MXFP8 quantization backend"): + MXFP8Tensor.from_bf16(x, backend="invalid") + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 2688), (256, 1856)]) + def test_scale_2d_shape(self, M, K): + """scale_2d returns correct shape: (-1, ceil(K//32, 4)*4).""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + + scale_2d = tensor.scale_2d() + expected_cols = ceil_div(K // 32, 4) * 4 + assert scale_2d.dim() == 2 + assert scale_2d.shape[1] == expected_cols + + @pytest.mark.parametrize("M,K", [(16, 128), (128, 2688)]) + def test_scale_2d_idempotent(self, M, K): + """Calling scale_2d twice returns the same result.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + + s1 = tensor.scale_2d() + s2 = tensor.scale_2d() + torch.testing.assert_close(s1.view(torch.uint8), s2.view(torch.uint8), atol=0, rtol=0) + + def test_size_method(self): + """size() delegates to data.size().""" + from 
megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + x = torch.randn(32, 128, device="cuda", dtype=torch.bfloat16) + tensor = MXFP8Tensor.from_bf16(x, backend="triton") + assert tensor.size() == torch.Size([32, 128]) + assert tensor.size(0) == 32 + assert tensor.size(1) == 128 + + +# ────────────────────────────────────────────────────────────────────── +# Triton vs FlashInfer cross-validation +# ────────────────────────────────────────────────────────────────────── + + +@pytest.mark.skipif( + torch.cuda.get_device_capability()[0] < 10, + reason="MXFP8 FlashInfer comparison requires Blackwell (SM 100+)", +) +class TestTritonVsFlashinfer: + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (64, 256), (128, 2688), (256, 1856)]) + def test_data_matches(self, M, K): + """Triton and FlashInfer backends produce identical FP8 data.""" + from megatron.core.inference.quantization.mxfp8_tensor import HAVE_FLASHINFER, MXFP8Tensor + + if not HAVE_FLASHINFER: + pytest.skip("FlashInfer not available") + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + triton_tensor = MXFP8Tensor.from_bf16(x, backend="triton") + flashinfer_tensor = MXFP8Tensor.from_bf16(x, backend="flashinfer") + + torch.testing.assert_close( + triton_tensor.data.float(), flashinfer_tensor.data.float(), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (64, 256), (128, 2688), (256, 1856)]) + def test_scales_match(self, M, K): + """Triton and FlashInfer backends produce identical swizzled scales.""" + from megatron.core.inference.quantization.mxfp8_tensor import HAVE_FLASHINFER, MXFP8Tensor + + if not HAVE_FLASHINFER: + pytest.skip("FlashInfer not available") + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + + triton_tensor = MXFP8Tensor.from_bf16(x, backend="triton") + flashinfer_tensor = MXFP8Tensor.from_bf16(x, backend="flashinfer") + + torch.testing.assert_close( + 
triton_tensor.scale.view(torch.uint8), + flashinfer_tensor.scale.view(torch.uint8), + atol=0, + rtol=0, + ) + + +def _make_permutation_map(M, num_padding=0): + """Create a permutation_map with optional padding rows at the end.""" + real = torch.arange(M - num_padding, dtype=torch.int32, device="cuda") + pad = torch.full((num_padding,), -1, dtype=torch.int32, device="cuda") + return torch.cat([real, pad]) + + +# ────────────────────────────────────────────────────────────────────── +# squared_relu_and_quantize_mxfp8 vs PyTorch reference +# ────────────────────────────────────────────────────────────────────── + + +class TestSquaredReluAndQuantizeMxfp8: + """Compare fused squared_relu + mxfp8 quantize against PyTorch reference. + + Reference: torch.relu(x.float()).pow(2).to(bf16) -> ref_to_mxfp -> ref_swizzle. + The fused kernel computes squared ReLU in fp32 and quantizes to MXFP8 in one pass, + so the PyTorch fp32 reference is the correct baseline (not the unfused Triton path + which has an intermediate bf16 roundtrip). 
+ """ + + @pytest.mark.parametrize( + "M,K", + [ + (1, 32), + (4, 64), + (16, 128), + (32, 256), + (64, 128), + (128, 128), + (128, 256), + (128, 2688), + (256, 1856), + (512, 2688), + ], + ) + def test_data_matches_pytorch_ref(self, M, K): + """Fused FP8 data matches PyTorch squared ReLU + ref_to_mxfp.""" + from megatron.core.inference.moe.activations import squared_relu_and_quantize_mxfp8 + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + perm_map = _make_permutation_map(M, num_padding=0) + + # PyTorch reference: squared ReLU in fp32, then downcast to bf16, then quantize + activated_ref = torch.relu(x.float()).pow(2) + _, ref_data = ref_to_mxfp(activated_ref) + + # Fused kernel + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + + torch.testing.assert_close( + fused_result.data.view(torch.uint8), ref_data.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize("M,K", [(1, 32), (16, 128), (128, 128), (128, 2688), (256, 1856)]) + def test_scales_match_pytorch_ref(self, M, K): + """Fused swizzled scales match PyTorch ref_to_mxfp + ref_swizzle.""" + from megatron.core.inference.moe.activations import squared_relu_and_quantize_mxfp8 + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + perm_map = _make_permutation_map(M, num_padding=0) + + # PyTorch reference + activated_ref = torch.relu(x.float()).pow(2) + ref_scales_2d, _ = ref_to_mxfp(activated_ref) + ref_swizzled = ref_swizzle(ref_scales_2d) + + # Fused kernel + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + + torch.testing.assert_close( + fused_result.scale.view(torch.uint8), ref_swizzled.view(torch.uint8), atol=0, rtol=0 + ) + + @pytest.mark.parametrize( + "M,K,num_padding", + [(32, 128, 8), (64, 256, 16), (128, 128, 32), (128, 2688, 64), (256, 1856, 128)], + ) + def test_real_rows_match_pytorch_ref_with_padding(self, M, K, num_padding): + """Real rows match PyTorch reference even when padding rows are 
present.""" + from megatron.core.inference.moe.activations import squared_relu_and_quantize_mxfp8 + + torch.manual_seed(42) + x = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + perm_map = _make_permutation_map(M, num_padding=num_padding) + + # PyTorch reference (only real rows) + real_rows = M - num_padding + activated_ref = torch.relu(x[:real_rows].float()).pow(2) + _, ref_data = ref_to_mxfp(activated_ref) + + # Fused kernel + fused_result = squared_relu_and_quantize_mxfp8(x, perm_map) + + torch.testing.assert_close( + fused_result.data[:real_rows].view(torch.uint8), + ref_data.view(torch.uint8), + atol=0, + rtol=0, + ) + + +# ────────────────────────────────────────────────────────────────────── +# permute_and_quantize_mxfp8 +# ────────────────────────────────────────────────────────────────────── + + +class TestPermuteAndQuantizeMxfp8: + """Compare fused permute + mxfp8 quantize against PyTorch reference. + + PyTorch reference: + 1. For each real row, quantize the source token with ref_to_mxfp + 2. Compare FP8 data per source token + Structural checks (permutation_map, probs, offsets) verified independently. 
+ """ + + def _make_inputs(self, num_tokens, K, topk, num_experts, seed=42): + torch.manual_seed(seed) + hidden = torch.randn(num_tokens, K, device="cuda", dtype=torch.bfloat16) + probs = torch.rand(num_tokens, topk, device="cuda", dtype=torch.float32) + routing_map = torch.randint(0, num_experts, (num_tokens, topk), device="cuda") + return hidden, probs, routing_map + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts", + [ + (4, 128, 2, 4), + (16, 128, 2, 8), + (32, 256, 4, 8), + (64, 128, 6, 8), + (64, 2688, 8, 128), + (128, 1856, 4, 32), + ], + ) + def test_data_matches_pytorch_ref(self, num_tokens, K, topk, num_experts): + """For each real row, fused FP8 data matches ref_to_mxfp of the source token.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + fused_mxfp8, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, num_experts, alignment=128 + ) + + # For each real row, quantize the source token with PyTorch ref and compare + for i in range(fused_perm_map.shape[0]): + src = fused_perm_map[i].item() + if src < 0: + continue + _, ref_data = ref_to_mxfp(hidden[src].unsqueeze(0)) + torch.testing.assert_close( + fused_mxfp8.data[i].view(torch.uint8), + ref_data.squeeze(0).view(torch.uint8), + atol=0, + rtol=0, + msg=f"Row {i} (src={src}) FP8 data mismatch vs PyTorch ref", + ) + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts", [(16, 128, 2, 8), (32, 256, 4, 8), (64, 2688, 8, 128)] + ) + def test_batch_data_matches_pytorch_ref(self, num_tokens, K, topk, num_experts): + """Batch comparison: gather all real rows, quantize as batch, compare.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + fused_mxfp8, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 
0, num_experts, alignment=128 + ) + + real_mask = fused_perm_map >= 0 + real_indices = real_mask.nonzero(as_tuple=True)[0] + if len(real_indices) == 0: + return + + src_tokens = fused_perm_map[real_indices].long() + permuted_bf16 = hidden[src_tokens] + + _, ref_data = ref_to_mxfp(permuted_bf16) + + torch.testing.assert_close( + fused_mxfp8.data[real_indices].view(torch.uint8), + ref_data.view(torch.uint8), + atol=0, + rtol=0, + ) + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts", [(16, 128, 2, 8), (32, 256, 4, 8), (64, 2688, 8, 128)] + ) + def test_correct_token_count(self, num_tokens, K, topk, num_experts): + """Number of real rows equals total (token, topk) pairs routed to local experts.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + _, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, num_experts, alignment=128 + ) + + real_count = (fused_perm_map >= 0).sum().item() + # All experts are local, so every pair should appear + assert real_count == num_tokens * topk + + @pytest.mark.parametrize( + "num_tokens,K,topk,num_experts,local_start,num_local", + [(64, 128, 4, 8, 2, 3), (64, 256, 4, 8, 0, 4), (128, 128, 8, 128, 96, 32)], + ) + def test_expert_subset(self, num_tokens, K, topk, num_experts, local_start, num_local): + """Fused kernel correctly handles local expert subsets.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(num_tokens, K, topk, num_experts) + + _, _, fused_perm_map, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, local_start, num_local, alignment=128 + ) + + real_count = (fused_perm_map >= 0).sum().item() + local_mask = (routing_map >= local_start) & (routing_map < local_start + num_local) + expected_count = local_mask.sum().item() + assert real_count == expected_count + + def 
test_returns_mxfp8_tensor(self): + """Result is an MXFP8Tensor with correct backend.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + + hidden, probs, routing_map = self._make_inputs(16, 128, 2, 4) + result, _, _, _ = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, 4, alignment=128 + ) + assert isinstance(result, MXFP8Tensor) + assert result.backend == "triton" + assert result.data.dtype == torch.float8_e4m3fn + + @pytest.mark.parametrize("alignment", [128]) + def test_offsets_aligned(self, alignment): + """Inclusive offsets are multiples of alignment.""" + from megatron.core.inference.moe.permute import permute_and_quantize_mxfp8 + + hidden, probs, routing_map = self._make_inputs(64, 128, 4, 8) + _, _, _, offs = permute_and_quantize_mxfp8( + hidden, probs, routing_map, 0, 8, alignment=alignment + ) + for i in range(offs.shape[0]): + assert ( + offs[i].item() % alignment == 0 + ), f"Offset {i}={offs[i].item()} not aligned to {alignment}" diff --git a/tests/unit_tests/models/test_mimo_partition.py b/tests/unit_tests/models/test_mimo_partition.py new file mode 100644 index 00000000000..1527fb92935 --- /dev/null +++ b/tests/unit_tests/models/test_mimo_partition.py @@ -0,0 +1,434 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. 
+ +''' +WORLD_SIZE=1 LOCAL_RANK=0 python -m torch.distributed.run \ + --nproc_per_node=1 -m pytest \ + tests/unit_tests/models/test_mimo_partition.py -v +''' + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +from megatron.core.models.mimo.partition.utils import PartitionAdapter, PartitionConfig +from megatron.core.transformer.transformer_config import TransformerConfig + + +@pytest.mark.experimental +class TestPartitionConfig: + """Tests for PartitionConfig dataclass and factory method.""" + + def test_is_partitioning_enabled_cp_only(self): + cfg = PartitionConfig( + seq_parallel=False, use_cp=True, tp_comm_overlap=False, max_seq_len=128 + ) + assert cfg.is_partitioning_enabled is True + + def test_is_partitioning_enabled_sp_only(self): + cfg = PartitionConfig( + seq_parallel=True, use_cp=False, tp_comm_overlap=False, max_seq_len=128 + ) + assert cfg.is_partitioning_enabled is True + + def test_is_partitioning_enabled_both(self): + cfg = PartitionConfig( + seq_parallel=True, use_cp=True, tp_comm_overlap=False, max_seq_len=128 + ) + assert cfg.is_partitioning_enabled is True + + def test_is_partitioning_enabled_neither(self): + cfg = PartitionConfig( + seq_parallel=False, use_cp=False, tp_comm_overlap=False, max_seq_len=128 + ) + assert cfg.is_partitioning_enabled is False + + def test_from_mp_config_invalid_type_raises(self): + with pytest.raises(TypeError, match="mp must be a ModelParallelConfig instance"): + PartitionConfig.from_mp_config("not_a_config", max_seq_len=128) + + def test_from_mp_config_no_parallelism(self): + mp = TransformerConfig( + num_layers=1, + hidden_size=64, + num_attention_heads=4, + context_parallel_size=1, + sequence_parallel=False, + ) + with patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=1): + cfg = PartitionConfig.from_mp_config(mp, max_seq_len=512) + assert cfg.use_cp is False + assert cfg.seq_parallel is False + assert cfg.cp_group is None + assert cfg.tp_group is None + 
assert cfg.max_seq_len == 512 + + def test_from_mp_config_kv_format_thd(self): + mp = TransformerConfig(num_layers=1, hidden_size=64, num_attention_heads=4) + with patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=1): + cfg = PartitionConfig.from_mp_config(mp, max_seq_len=512, kv_format='thd') + assert cfg.kv_format == 'thd' + + def test_from_mp_config_explicit_cp_group(self): + mock_cp_group = MagicMock() + mp = TransformerConfig( + num_layers=1, hidden_size=64, num_attention_heads=4, context_parallel_size=2 + ) + with patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2): + cfg = PartitionConfig.from_mp_config(mp, max_seq_len=512, cp_group=mock_cp_group) + assert cfg.use_cp is True + assert cfg.cp_group is mock_cp_group + + def test_from_mp_config_explicit_tp_group(self): + mock_tp_group = MagicMock() + mp = TransformerConfig( + num_layers=1, + hidden_size=64, + num_attention_heads=4, + tensor_model_parallel_size=2, + sequence_parallel=True, + ) + with patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=1): + cfg = PartitionConfig.from_mp_config(mp, max_seq_len=512, tp_group=mock_tp_group) + assert cfg.seq_parallel is True + assert cfg.tp_group is mock_tp_group + + def test_from_mp_config_auto_fetch_cp_group(self): + mock_group = MagicMock() + mp = TransformerConfig( + num_layers=1, hidden_size=64, num_attention_heads=4, context_parallel_size=2 + ) + with ( + patch( + 'megatron.core.models.mimo.partition.utils.get_context_parallel_group', + return_value=mock_group, + ), + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + ): + cfg = PartitionConfig.from_mp_config(mp, max_seq_len=512) + assert cfg.cp_group is mock_group + + def test_from_mp_config_auto_fetch_tp_group(self): + mock_group = MagicMock() + mp = TransformerConfig( + num_layers=1, + hidden_size=64, + num_attention_heads=4, + tensor_model_parallel_size=2, + sequence_parallel=True, + ) + with ( + 
patch( + 'megatron.core.models.mimo.partition.utils.get_tensor_model_parallel_group', + return_value=mock_group, + ), + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=1), + ): + cfg = PartitionConfig.from_mp_config(mp, max_seq_len=512) + assert cfg.tp_group is mock_group + + +@pytest.mark.experimental +class TestPartitionAdapterShard: + """Tests for PartitionAdapter.shard().""" + + def _make_cfg( + self, + use_cp=False, + seq_parallel=False, + tp_comm_overlap=False, + max_seq_len=128, + cp_group=None, + tp_group=None, + ): + return PartitionConfig( + use_cp=use_cp, + seq_parallel=seq_parallel, + tp_comm_overlap=tp_comm_overlap, + max_seq_len=max_seq_len, + cp_group=cp_group, + tp_group=tp_group, + ) + + def _make_tensors(self, B=2, S=8, H=16): + embeddings = torch.rand(B, S, H) + labels = torch.randint(0, 100, (B, S)) + loss_mask = torch.ones(B, S) + attention_mask = torch.ones(B, S) + return embeddings, labels, loss_mask, attention_mask + + def test_noop_when_both_disabled(self): + """No sharding when neither CP nor SP is enabled — inputs returned as-is.""" + cfg = self._make_cfg(use_cp=False, seq_parallel=False) + adapter = PartitionAdapter(cfg) + embeddings, labels, loss_mask, attention_mask = self._make_tensors() + out = adapter.shard(embeddings, labels, loss_mask, attention_mask) + assert out[0] is embeddings + assert out[1] is labels + assert out[2] is loss_mask + assert out[3] is attention_mask + assert out[4] is None + + def test_cp_only_shards_sequence(self): + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, max_seq_len=8, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + embeddings, labels, loss_mask, attention_mask = self._make_tensors(B=2, S=8, H=16) + sharded = { + 'embeddings': embeddings[:, :4, :], + 'labels': labels[:, :4], + 'loss_mask': loss_mask[:, :4], + 'attention_mask': attention_mask[:, :4], + } + with ( + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + 
patch( + 'megatron.core.models.mimo.partition.utils.get_batch_on_this_cp_rank', + return_value=sharded, + ), + ): + out = adapter.shard(embeddings, labels, loss_mask, attention_mask) + assert out[0].shape == (2, 4, 16) + assert out[1].shape == (2, 4) + + def test_sp_only_scatters(self): + mock_tp_group = MagicMock() + cfg = self._make_cfg(seq_parallel=True, max_seq_len=8, tp_group=mock_tp_group) + adapter = PartitionAdapter(cfg) + # SP uses seq_dim=0: embeddings shape [S, B, H] + embeddings = torch.rand(8, 2, 16) + labels = torch.randint(0, 100, (2, 8)) + loss_mask = torch.ones(2, 8) + attention_mask = torch.ones(2, 8) + scattered = torch.rand(4, 2, 16) + with ( + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + patch( + 'megatron.core.models.mimo.partition.utils.tensor_parallel.scatter_to_sequence_parallel_region', + return_value=scattered, + ), + ): + out = adapter.shard(embeddings, labels, loss_mask, attention_mask) + assert out[0].shape == (4, 2, 16) + + def test_cp_and_sp_combined(self): + mock_cp_group = MagicMock() + mock_tp_group = MagicMock() + cfg = self._make_cfg( + use_cp=True, + seq_parallel=True, + max_seq_len=16, + cp_group=mock_cp_group, + tp_group=mock_tp_group, + ) + adapter = PartitionAdapter(cfg) + # cp_size=2, tp_size=2 → shard_factor = 2*2*2 = 8; S=16 is divisible + embeddings = torch.rand(2, 16, 16) + labels = torch.randint(0, 100, (2, 16)) + loss_mask = torch.ones(2, 16) + attention_mask = torch.ones(2, 16) + cp_sharded = { + 'embeddings': embeddings[:, :8, :], + 'labels': labels[:, :8], + 'loss_mask': loss_mask[:, :8], + 'attention_mask': attention_mask[:, :8], + } + scattered = torch.rand(2, 4, 16) + + with ( + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + patch( + 'megatron.core.models.mimo.partition.utils.get_batch_on_this_cp_rank', + return_value=cp_sharded, + ), + patch( + 'megatron.core.models.mimo.partition.utils.tensor_parallel.scatter_to_sequence_parallel_region', 
+ return_value=scattered, + ), + ): + out = adapter.shard(embeddings, labels, loss_mask, attention_mask) + assert out[0].shape == (2, 4, 16) + + def test_seq_not_divisible_raises(self): + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, max_seq_len=7, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + embeddings = torch.rand(2, 7, 16) # 7 % (2*2) != 0 + labels = torch.randint(0, 100, (2, 7)) + loss_mask = torch.ones(2, 7) + attention_mask = torch.ones(2, 7) + with ( + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + pytest.raises(AssertionError, match="divisible"), + ): + adapter.shard(embeddings, labels, loss_mask, attention_mask) + + def test_tp_comm_overlap_seq_len_assertion(self): + mock_tp_group = MagicMock() + cfg = self._make_cfg( + seq_parallel=True, tp_comm_overlap=True, max_seq_len=16, tp_group=mock_tp_group + ) + adapter = PartitionAdapter(cfg) + # S=8 but max_seq_len=16 → assertion fires + embeddings = torch.rand(8, 2, 16) # [S, B, H] for SP + labels = torch.randint(0, 100, (2, 8)) + loss_mask = torch.ones(2, 8) + attention_mask = torch.ones(2, 8) + with ( + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + pytest.raises(AssertionError, match="TP Comm overlap"), + ): + adapter.shard(embeddings, labels, loss_mask, attention_mask) + + def test_thd_format_skips_divisibility_check(self): + """PackedSeqParams with qkv_format='thd' bypasses the divisibility assertion.""" + from megatron.core.packed_seq_params import PackedSeqParams + + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, max_seq_len=7, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + embeddings = torch.rand(2, 7, 16) # seq_len=7 not divisible by cp*2, but THD skips check + labels = torch.randint(0, 100, (2, 7)) + loss_mask = torch.ones(2, 7) + attention_mask = torch.ones(2, 7) + packed_seq_params = MagicMock(spec=PackedSeqParams) + packed_seq_params.qkv_format = 'thd' + 
packed_seq_params.cu_seqlens_q_padded = torch.tensor([0, 4, 7], dtype=torch.int32) + + # THD path calls tex.thd_get_partitioned_indices — mock it to return first 4 indices + fake_index = torch.arange(4, dtype=torch.int32) + with ( + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + patch('megatron.core.models.mimo.partition.utils.get_pg_rank', return_value=0), + patch('megatron.core.models.mimo.partition.utils.tex') as mock_tex, + ): + mock_tex.thd_get_partitioned_indices.return_value = fake_index + # Should NOT raise AssertionError about divisibility + out = adapter.shard(embeddings, labels, loss_mask, attention_mask, packed_seq_params) + assert out[0] is not None + + def test_none_embeddings_skips_shard_factor_check(self): + """When embeddings is None, the divisibility check is skipped.""" + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, max_seq_len=7, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + labels = torch.randint(0, 100, (2, 7)) + loss_mask = torch.ones(2, 7) + attention_mask = torch.ones(2, 7) + cp_sharded = { + 'labels': labels[:, :4], + 'loss_mask': loss_mask[:, :4], + 'attention_mask': attention_mask[:, :4], + } + with ( + patch('megatron.core.models.mimo.partition.utils.get_pg_size', return_value=2), + patch( + 'megatron.core.models.mimo.partition.utils.get_batch_on_this_cp_rank', + return_value=cp_sharded, + ), + ): + out = adapter.shard(None, labels, loss_mask, attention_mask) + assert out[0] is None + + +@pytest.mark.experimental +class TestPartitionAdapterApplyContextParallel: + """Tests for PartitionAdapter._apply_context_parallel().""" + + def _make_cfg(self, use_cp=True, cp_group=None): + return PartitionConfig( + use_cp=use_cp, + seq_parallel=False, + tp_comm_overlap=False, + max_seq_len=128, + cp_group=cp_group, + ) + + def test_returns_unchanged_when_cp_disabled(self): + cfg = self._make_cfg(use_cp=False) + adapter = PartitionAdapter(cfg) + embeddings = torch.rand(2, 8, 16) + 
labels = torch.randint(0, 100, (2, 8)) + loss_mask = torch.ones(2, 8) + attention_mask = torch.ones(2, 8) + out = adapter._apply_context_parallel(embeddings, labels, loss_mask, attention_mask, None) + assert out[0] is embeddings + assert out[1] is labels + assert out[2] is loss_mask + assert out[3] is attention_mask + + def test_sbhd_path_calls_get_batch_on_this_cp_rank(self): + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + embeddings = torch.rand(2, 8, 16) + labels = torch.randint(0, 100, (2, 8)) + loss_mask = torch.ones(2, 8) + attention_mask = torch.ones(2, 8) + sharded = { + 'embeddings': embeddings[:, :4, :], + 'labels': labels[:, :4], + 'loss_mask': loss_mask[:, :4], + 'attention_mask': attention_mask[:, :4], + } + with patch( + 'megatron.core.models.mimo.partition.utils.get_batch_on_this_cp_rank', + return_value=sharded, + ) as mock_fn: + out = adapter._apply_context_parallel( + embeddings, labels, loss_mask, attention_mask, None + ) + mock_fn.assert_called_once() + assert out[0].shape == (2, 4, 16) + assert out[1].shape == (2, 4) + + def test_all_none_inputs_produces_none_outputs(self): + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + with patch( + 'megatron.core.models.mimo.partition.utils.get_batch_on_this_cp_rank', return_value={} + ): + out = adapter._apply_context_parallel(None, None, None, None, None) + assert all(v is None for v in out[:4]) + + def test_only_non_none_tensors_added_to_batch(self): + """None tensors must not appear in the batch dict passed to get_batch_on_this_cp_rank.""" + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + embeddings = torch.rand(2, 8, 16) + sharded = {'embeddings': embeddings[:, :4, :]} + captured = {} + + def mock_fn(batch): + captured.update(batch) + return sharded + + with patch( + 
'megatron.core.models.mimo.partition.utils.get_batch_on_this_cp_rank', + side_effect=mock_fn, + ): + out = adapter._apply_context_parallel(embeddings, None, None, None, None) + + assert 'embeddings' in captured + assert 'labels' not in captured + assert 'loss_mask' not in captured + assert out[0] is not None + assert out[1] is None + + def test_thd_path_raises_when_te_unavailable(self): + """THD format must assert when Transformer Engine is not available.""" + from megatron.core.packed_seq_params import PackedSeqParams + + mock_cp_group = MagicMock() + cfg = self._make_cfg(use_cp=True, cp_group=mock_cp_group) + adapter = PartitionAdapter(cfg) + embeddings = torch.rand(2, 5, 16) + packed_seq_params = MagicMock(spec=PackedSeqParams) + packed_seq_params.qkv_format = 'thd' + with ( + patch('megatron.core.models.mimo.partition.utils._HAVE_TEX', False), + pytest.raises(AssertionError, match="Transformer Engine"), + ): + adapter._apply_context_parallel(embeddings, None, None, None, packed_seq_params) diff --git a/tests/unit_tests/resharding/test_mxfp8_refit.py b/tests/unit_tests/resharding/test_mxfp8_refit.py new file mode 100644 index 00000000000..815d4eeedac --- /dev/null +++ b/tests/unit_tests/resharding/test_mxfp8_refit.py @@ -0,0 +1,242 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 

import pytest
import torch

_IS_BLACKWELL = torch.cuda.is_available() and (torch.cuda.get_device_properties(0).major >= 10)

try:
    from flashinfer import mxfp8_quantize

    _HAVE_FLASHINFER = True
except ImportError:
    _HAVE_FLASHINFER = False

pytestmark = [
    pytest.mark.skipif(not _IS_BLACKWELL, reason="MXFP8 tests require Blackwell GPU (SM >= 10)"),
    pytest.mark.skipif(not _HAVE_FLASHINFER, reason="MXFP8 tests require FlashInfer"),
]


# ===========================================================================
# MXFP8ReshardTransform
# ===========================================================================


class TestMXFP8ReshardTransform:
    """Tests for the core MXFP8 reshard transform (transforms.py).

    These test the receiver-side BF16→MXFP8 conversion paths that run on
    every refit iteration, including the critical 1D-scale accumulation
    logic that avoids corrupting swizzled scales from partial updates.
    """

    def _make_persistent_buffers(self, shapes):
        from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor

        buffers = {}
        for name, (M, K) in shapes.items():
            x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
            buffers[name] = MXFP8Tensor.from_bf16(x)
        return buffers

    def test_finalize_recv_bf16_2d_scale(self):
        """Receiver-side conversion with 2D scale: immediate per-slice quantization."""
        from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor
        from megatron.core.resharding.transforms import MXFP8ReshardTransform

        M, K = 64, 128
        buf = MXFP8Tensor.from_bf16(torch.randn(M, K, dtype=torch.bfloat16, device="cuda"))

        if buf.scale.ndim != 2:
            pytest.skip("FlashInfer produced 1D swizzled scale; 2D-scale test not applicable")

        t = MXFP8ReshardTransform(
            convertible_params={"decoder.weight"},
            persistent_buffers={"weight": buf},
            buffer_key_prefix="decoder.",
            convert_on_send=False,
        )

        new_data = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
        t.finalize_recv("decoder.weight", (slice(None), slice(None)), [new_data])

        expected = MXFP8Tensor.from_bf16(new_data)
        assert torch.equal(buf.data, expected.data)
        assert torch.equal(buf.scale, expected.scale)

    def test_finalize_recv_bf16_1d_scale_accumulation(self):
        """Receiver-side conversion with 1D scale: accumulate slices then quantize."""
        from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor
        from megatron.core.resharding.transforms import MXFP8ReshardTransform

        M, K = 64, 128
        buf = MXFP8Tensor.from_bf16(torch.randn(M, K, dtype=torch.bfloat16, device="cuda"))

        if buf.scale.ndim != 1:
            pytest.skip("FlashInfer produced 2D scale; 1D-scale accumulation test not applicable")

        t = MXFP8ReshardTransform(
            convertible_params={"decoder.weight"},
            persistent_buffers={"weight": buf},
            buffer_key_prefix="decoder.",
            convert_on_send=False,
        )

        full_data = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
        half = M // 2

        # First slice: should accumulate (not finalize yet)
        t.finalize_recv("decoder.weight", (slice(0, half), slice(None)), [full_data[:half]])
        assert "weight" in t._pending_1d, "Should be pending after partial slice"

        # Second slice: should trigger final quantization
        t.finalize_recv("decoder.weight", (slice(half, M), slice(None)), [full_data[half:]])
        assert "weight" not in t._pending_1d, "Should be finalized after all slices"

        expected = MXFP8Tensor.from_bf16(full_data)
        assert torch.equal(buf.data, expected.data)
        assert torch.equal(buf.scale, expected.scale)

    def test_finalize_recv_1d_scale_wrong_element_count(self):
        """1D accumulation should raise if total elements don't match (duplicate slices)."""
        from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor
        from megatron.core.resharding.transforms import MXFP8ReshardTransform

        M, K = 64, 128
        buf = MXFP8Tensor.from_bf16(torch.randn(M, K, dtype=torch.bfloat16, device="cuda"))
        if buf.scale.ndim != 1:
            pytest.skip("Need 1D scale for this test")

        t = MXFP8ReshardTransform(
            convertible_params={"decoder.weight"},
            persistent_buffers={"weight": buf},
            buffer_key_prefix="decoder.",
            convert_on_send=False,
        )

        half_data = torch.randn(M // 2, K, dtype=torch.bfloat16, device="cuda")
        t.finalize_recv("decoder.weight", (slice(0, M // 2), slice(None)), [half_data])

        with pytest.raises(AssertionError, match="duplicate or missing"):
            overlap = torch.randn(M // 2 + 1, K, dtype=torch.bfloat16, device="cuda")
            t.finalize_recv("decoder.weight", (slice(M // 2 - 1, M), slice(None)), [overlap])


# ===========================================================================
# quantize_params_to_mxfp8
# ===========================================================================


class TestQuantizeParamsToMXFP8:
    """Tests for persistent buffer quantization (quantization/utils.py).

    The persistent buffer address stability is critical for CUDA graph
    compatibility — if addresses change, captured graphs segfault.
    """

    def test_basic_quantization_replaces_param(self):
        from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor
        from megatron.core.inference.quantization.utils import quantize_params_to_mxfp8

        model = torch.nn.Linear(128, 64, bias=False).to(dtype=torch.bfloat16, device="cuda")
        buffers = quantize_params_to_mxfp8(model)

        assert "weight" in buffers
        assert isinstance(buffers["weight"], MXFP8Tensor)
        assert buffers["weight"].data.shape == (64, 128)
        assert "weight" not in model._parameters

    def test_persistent_buffer_reuse_preserves_addresses(self):
        """Second call must copy into existing buffers (CUDA graph address stability)."""
        from megatron.core.inference.quantization.utils import quantize_params_to_mxfp8

        model = torch.nn.Linear(128, 64, bias=False).to(dtype=torch.bfloat16, device="cuda")
        buffers = quantize_params_to_mxfp8(model)
        data_ptr = buffers["weight"].data.data_ptr()
        scale_ptr = buffers["weight"].scale.data_ptr()

        model2 = torch.nn.Linear(128, 64, bias=False).to(dtype=torch.bfloat16, device="cuda")
        quantize_params_to_mxfp8(model2, persistent_buffers=buffers)

        assert buffers["weight"].data.data_ptr() == data_ptr
        assert buffers["weight"].scale.data_ptr() == scale_ptr

    def test_nested_module_fqn(self):
        """Recursive quantization should produce correct fully-qualified names."""
        from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor
        from megatron.core.inference.quantization.utils import quantize_params_to_mxfp8

        model = torch.nn.Sequential(
            torch.nn.Linear(128, 64, bias=False), torch.nn.Linear(64, 32, bias=False)
        ).to(dtype=torch.bfloat16, device="cuda")
        buffers = quantize_params_to_mxfp8(model)

        assert "0.weight" in buffers and "1.weight" in buffers
        assert isinstance(buffers["0.weight"], MXFP8Tensor)


# ===========================================================================
# End-to-end MXFP8 refit integration (single-GPU)
#
=========================================================================== + + +class TestMXFP8RefitIntegration: + """Integration tests simulating the full send→recv→finalize refit flow.""" + + def test_full_transform_roundtrip_bf16_wire(self): + """Simulate sender sending BF16, receiver converting to MXFP8.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + from megatron.core.resharding.transforms import MXFP8ReshardTransform + + M, K = 64, 128 + src_weight = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + src_param = torch.nn.Parameter(src_weight.clone()) + + dst_buf = MXFP8Tensor.from_bf16(torch.randn(M, K, dtype=torch.bfloat16, device="cuda")) + t = MXFP8ReshardTransform( + convertible_params={"decoder.weight"}, + persistent_buffers={"weight": dst_buf}, + buffer_key_prefix="decoder.", + convert_on_send=False, + ) + + # Simulate: prepare_send → wire → prepare_recv → finalize_recv + sent = t.prepare_send("decoder.weight", (slice(None), slice(None)), src_param) + recv_bufs = t.prepare_recv("decoder.weight", (slice(None), slice(None))) + recv_bufs[0].copy_(sent[0]) + t.finalize_recv("decoder.weight", (slice(None), slice(None)), recv_bufs) + + expected = MXFP8Tensor.from_bf16(src_weight) + assert torch.equal(dst_buf.data, expected.data) + assert torch.equal(dst_buf.scale, expected.scale) + + def test_multi_slice_assembly(self): + """Multiple row slices should correctly assemble the full quantized weight.""" + from megatron.core.inference.quantization.mxfp8_tensor import MXFP8Tensor + from megatron.core.resharding.transforms import MXFP8ReshardTransform + + M, K = 128, 256 + full_weight = torch.randn(M, K, dtype=torch.bfloat16, device="cuda") + dst_buf = MXFP8Tensor.from_bf16(torch.zeros(M, K, dtype=torch.bfloat16, device="cuda")) + + t = MXFP8ReshardTransform( + convertible_params={"decoder.weight"}, + persistent_buffers={"weight": dst_buf}, + buffer_key_prefix="decoder.", + convert_on_send=False, + ) + + # Send in 4 
row-slices (simulates TP=4 refit) + chunk = M // 4 + for i in range(4): + row_slice = (slice(i * chunk, (i + 1) * chunk), slice(None)) + src_param = torch.nn.Parameter(full_weight.clone()) + sent = t.prepare_send("decoder.weight", row_slice, src_param) + recv = t.prepare_recv("decoder.weight", row_slice) + recv[0].copy_(sent[0]) + t.finalize_recv("decoder.weight", row_slice, recv) + + expected = MXFP8Tensor.from_bf16(full_weight) + assert torch.equal(dst_buf.data, expected.data) + assert torch.equal(dst_buf.scale, expected.scale) diff --git a/tests/unit_tests/ssm/ops/test_causal_conv1d_varlen.py b/tests/unit_tests/ssm/ops/test_causal_conv1d_varlen.py new file mode 100644 index 00000000000..b09457363bd --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_causal_conv1d_varlen.py @@ -0,0 +1,173 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + +"""Unit tests for the Triton varlen causal conv1d kernel. + +Tests correctness of `causal_conv1d_varlen_fn` against a reference implementation +that loops over requests calling `causal_conv1d_fn` with `initial_states`. 
+""" + +import pytest +import torch + +from megatron.core.ssm.ops.causal_conv1d_varlen import causal_conv1d_varlen_fn + +try: + from causal_conv1d import causal_conv1d_fn + + HAS_CAUSAL_CONV1D = True +except ImportError: + HAS_CAUSAL_CONV1D = False + + +def _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states, activation="silu"): + """Reference: per-request loop calling causal_conv1d_fn with initial_states.""" + num_requests = cu_seqlens.shape[0] - 1 + conv_dim = x.shape[1] + d_conv = weight.shape[1] + parts = [] + for r in range(num_requests): + start = cu_seqlens[r].item() + end = cu_seqlens[r + 1].item() + if end <= start: + continue + seq_len_r = end - start + if initial_states is not None: + init_r = initial_states[r : r + 1] # (1, conv_dim, d_conv-1) + # causal_conv1d_fn with initial_states requires channels-last layout + # for both x and initial_states: create as (1, L, C) then transpose + x_r = x[start:end].unsqueeze(0).transpose(1, 2) # channels-last (1, C, L) + init_r = init_r.permute(0, 2, 1).contiguous().transpose(1, 2) # channels-last + else: + init_r = None + x_r = x[start:end].T.unsqueeze(0).contiguous() # (1, conv_dim, seq_len) + out_r = causal_conv1d_fn( + x=x_r, weight=weight, bias=bias, activation=activation, initial_states=init_r + ) + parts.append(out_r.squeeze(0).T.contiguous()) # (seq_len, conv_dim) + return torch.cat(parts, dim=0) if parts else torch.empty(0, conv_dim, device=x.device) + + +@pytest.mark.skipif(not HAS_CAUSAL_CONV1D, reason="causal_conv1d not installed") +class TestCausalConv1dVarlen: + """Test causal_conv1d_varlen_fn against per-request causal_conv1d_fn reference.""" + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) + def test_single_request(self, dtype): + """Single request should match causal_conv1d_fn exactly.""" + torch.manual_seed(42) + conv_dim, d_conv, seq_len = 64, 4, 32 + device = "cuda" + + x = torch.randn(seq_len, conv_dim, dtype=dtype, device=device) + weight = 
torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device) + initial_states = torch.randn(1, conv_dim, d_conv - 1, dtype=dtype, device=device) + + out = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, initial_states) + ref = _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states) + + atol = 1e-2 if dtype == torch.bfloat16 else 1e-5 + torch.testing.assert_close(out, ref, atol=atol, rtol=1e-2) + + @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16]) + def test_multiple_requests_varying_lengths(self, dtype): + """Multiple requests with different sequence lengths.""" + torch.manual_seed(123) + conv_dim, d_conv = 128, 4 + seq_lens = [10, 25, 3, 50, 8] + device = "cuda" + + total_tokens = sum(seq_lens) + x = torch.randn(total_tokens, conv_dim, dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + + cu_seqlens_list = [0] + for sl in seq_lens: + cu_seqlens_list.append(cu_seqlens_list[-1] + sl) + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + num_requests = len(seq_lens) + initial_states = torch.randn(num_requests, conv_dim, d_conv - 1, dtype=dtype, device=device) + + out = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, initial_states) + ref = _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states) + + atol = 1e-2 if dtype == torch.bfloat16 else 1e-5 + torch.testing.assert_close(out, ref, atol=atol, rtol=1e-2) + + def test_seqlen_shorter_than_d_conv(self): + """Sequence shorter than d_conv should use initial_states for all taps.""" + torch.manual_seed(7) + conv_dim, d_conv = 32, 4 + seq_lens = [2, 1, 3] # All shorter than d_conv + device = "cuda" + dtype = torch.float32 + + total_tokens = sum(seq_lens) + x = torch.randn(total_tokens, conv_dim, 
dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + + cu_seqlens_list = [0] + for sl in seq_lens: + cu_seqlens_list.append(cu_seqlens_list[-1] + sl) + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + num_requests = len(seq_lens) + initial_states = torch.randn(num_requests, conv_dim, d_conv - 1, dtype=dtype, device=device) + + out = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, initial_states) + ref = _reference_conv1d_varlen(x, weight, bias, cu_seqlens, initial_states) + + torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5) + + def test_zero_initial_states(self): + """Zero initial_states should produce same result as None initial_states.""" + torch.manual_seed(99) + conv_dim, d_conv = 64, 4 + seq_lens = [16, 24] + device = "cuda" + dtype = torch.float32 + + total_tokens = sum(seq_lens) + x = torch.randn(total_tokens, conv_dim, dtype=dtype, device=device) + weight = torch.randn(conv_dim, d_conv, dtype=dtype, device=device) + bias = torch.randn(conv_dim, dtype=dtype, device=device) + + cu_seqlens_list = [0] + for sl in seq_lens: + cu_seqlens_list.append(cu_seqlens_list[-1] + sl) + cu_seqlens = torch.tensor(cu_seqlens_list, dtype=torch.int32, device=device) + + num_requests = len(seq_lens) + zero_states = torch.zeros(num_requests, conv_dim, d_conv - 1, dtype=dtype, device=device) + + out_zero = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, zero_states) + out_none = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, None) + + torch.testing.assert_close(out_zero, out_none, atol=1e-5, rtol=1e-5) + + def test_nonzero_vs_zero_initial_states_differ(self): + """Non-zero initial_states should produce different results from zero.""" + torch.manual_seed(55) + conv_dim, d_conv = 64, 4 + seq_len = 16 + device = "cuda" + dtype = torch.float32 + + x = torch.randn(seq_len, conv_dim, dtype=dtype, device=device) + weight = 
torch.randn(conv_dim, d_conv, dtype=dtype, device=device)
+        bias = torch.randn(conv_dim, dtype=dtype, device=device)
+        cu_seqlens = torch.tensor([0, seq_len], dtype=torch.int32, device=device)
+
+        nonzero_states = torch.randn(1, conv_dim, d_conv - 1, dtype=dtype, device=device)
+
+        out_nonzero = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, nonzero_states)
+        out_none = causal_conv1d_varlen_fn(x, weight, bias, cu_seqlens, None)
+
+        # First few tokens should differ (those that depend on initial state)
+        assert not torch.allclose(
+            out_nonzero[: d_conv - 1], out_none[: d_conv - 1], atol=1e-5
+        ), "Non-zero initial states should produce different outputs for early tokens"
diff --git a/tests/unit_tests/ssm/ops/test_ops_init.py b/tests/unit_tests/ssm/ops/test_ops_init.py
new file mode 100644
index 00000000000..9241d044dce
--- /dev/null
+++ b/tests/unit_tests/ssm/ops/test_ops_init.py
@@ -0,0 +1,33 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+
+"""Test that the megatron.core.ssm.ops package exports the public API."""
+
+import unittest
+
+try:
+    from megatron.core.ssm import ops as ssm_ops
+
+    HAVE_SSD_OPS = True
+except Exception:  # broad on purpose (Triton init can raise non-ImportError)
+    HAVE_SSD_OPS = False
+
+
+@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
+class TestOpsPackagePublicAPI(unittest.TestCase):
+    """Ensure the ops package exposes the documented public API."""
+
+    def test_all_exported(self):
+        self.assertIn("mamba_chunk_scan_combined_varlen", ssm_ops.__all__)
+        self.assertIn("causal_conv1d_varlen_fn", ssm_ops.__all__)
+
+    def test_mamba_chunk_scan_combined_varlen_importable(self):
+        self.assertTrue(hasattr(ssm_ops, "mamba_chunk_scan_combined_varlen"))
+        self.assertTrue(callable(ssm_ops.mamba_chunk_scan_combined_varlen))
+
+    def test_causal_conv1d_varlen_fn_importable(self):
+        self.assertTrue(hasattr(ssm_ops, "causal_conv1d_varlen_fn"))
+        self.assertTrue(callable(ssm_ops.causal_conv1d_varlen_fn))
+
+
+if __name__ == "__main__":
+    
unittest.main()
diff --git a/tests/unit_tests/ssm/ops/test_ssd_bmm.py b/tests/unit_tests/ssm/ops/test_ssd_bmm.py
new file mode 100644
index 00000000000..c4c4eef5404
--- /dev/null
+++ b/tests/unit_tests/ssm/ops/test_ssd_bmm.py
@@ -0,0 +1,92 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+
+import unittest
+
+import torch
+
+try:
+    from megatron.core.ssm.ops.ssd_bmm import _bmm_chunk_fwd
+
+    HAVE_SSD_OPS = True
+except Exception:  # broad on purpose (Triton init can raise non-ImportError)
+    HAVE_SSD_OPS = False
+
+
+@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
+class TestBmmChunkFwd(unittest.TestCase):
+    """Tests for _bmm_chunk_fwd (C^T @ B per chunk)."""
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.device = torch.device("cuda")
+        self.chunk_size = 16
+        self.seqlen = 32
+        self.ngroups = 2
+        self.dstate = 8  # K dimension
+        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
+
+    def test_bmm_chunk_fwd_shape(self):
+        # a: (seqlen, ngroups, k), b: (seqlen, ngroups, k) -> out: (nchunks, ngroups, chunk_size, chunk_size)
+        a = torch.randn(
+            self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32
+        )
+        b = torch.randn(
+            self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32
+        )
+
+        out = _bmm_chunk_fwd(
+            a, b, self.chunk_size, self.cu_chunk_seqlens, causal=True, output_dtype=torch.float32
+        )
+
+        nchunks = self.cu_chunk_seqlens.shape[0] - 1
+        self.assertEqual(out.shape, (nchunks, self.ngroups, self.chunk_size, self.chunk_size))
+        self.assertFalse(torch.isnan(out).any())
+
+    def test_bmm_chunk_fwd_vs_torch_per_chunk(self):
+        """Compare first chunk with explicit C^T @ B for that chunk."""
+        a = torch.randn(
+            self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32
+        )
+        b = torch.randn(
+            self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32
+        )
+
+        out = 
_bmm_chunk_fwd(
+            a, b, self.chunk_size, self.cu_chunk_seqlens, causal=False, output_dtype=torch.float32
+        )
+
+        # Chunk 0: rows 0:16 of a and b. out[0, g] = a[0:16, g] @ b[0:16, g].T
+        # Relaxed tolerances: Triton block-wise reduction order can differ from torch;
+        # atol is the main check (max abs diff was ~0.008 in practice).
+        for g in range(self.ngroups):
+            a_chunk = a[0:16, g, :].contiguous()  # (16, dstate)
+            b_chunk = b[0:16, g, :].contiguous()  # (16, dstate)
+            expected = torch.mm(a_chunk, b_chunk.T)  # (16, 16)
+            torch.testing.assert_close(out[0, g], expected, rtol=1.0, atol=0.02)
+
+    def test_bmm_chunk_fwd_causal_vs_non_causal_shape(self):
+        a = torch.randn(
+            self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32
+        )
+        b = torch.randn(
+            self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32
+        )
+
+        out_causal = _bmm_chunk_fwd(a, b, self.chunk_size, self.cu_chunk_seqlens, causal=True)
+        out_noncausal = _bmm_chunk_fwd(a, b, self.chunk_size, self.cu_chunk_seqlens, causal=False)
+
+        self.assertEqual(out_causal.shape, out_noncausal.shape)
+        # Causal: lower triangle is correct; upper can differ
+        # One masked comparison (same default allclose tolerances) instead of an
+        # O(nchunks*ngroups*chunk_size^2) Python loop of per-scalar GPU syncs.
+        mask = torch.tril(
+            torch.ones(self.chunk_size, self.chunk_size, dtype=torch.bool, device=self.device)
+        )
+        self.assertTrue(
+            torch.allclose(out_causal[:, :, mask], out_noncausal[:, :, mask]),
+            "causal/non-causal outputs differ on the lower triangle",
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py b/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py
new file mode 100644
index 00000000000..dd654743346
--- /dev/null
+++ b/tests/unit_tests/ssm/ops/test_ssd_chunk_scan.py
@@ -0,0 +1,192 @@
+# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
+
+import unittest
+
+import torch
+
+try:
+    from megatron.core.ssm.ops.ssd_chunk_scan import _chunk_scan_fwd
+
+    HAVE_SSD_OPS = True
+except Exception:  # broad on purpose (Triton init can raise non-ImportError)
+    HAVE_SSD_OPS = False
+
+
+@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available")
+@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops")
+class TestChunkScanFwd(unittest.TestCase):
+    """Tests for _chunk_scan_fwd."""
+
+    def setUp(self):
+        torch.manual_seed(42)
+        self.device = torch.device("cuda")
+        self.seqlen = 32
+        self.nheads = 4
+        self.headdim = 16
+        self.ngroups = 2
+        self.dstate = 8
+        self.chunk_size = 16
+        self.nchunks = 2
+        self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device)
+        self.seq_idx = torch.tensor([0, 1], dtype=torch.int32, device=self.device)
+
+    def test_chunk_scan_fwd_shape_and_inplace_out(self):
+        cb = torch.randn(
+            self.nchunks,
+            self.ngroups,
+            self.chunk_size,
+            self.chunk_size,
+            device=self.device,
+            dtype=torch.float32,
+        )
+        x = torch.randn(
+            self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32
+        )
+        dt = torch.randn(
+            self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32
+        )
+        dA_cumsum = torch.randn(
+            self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32
+        )
+        C = torch.randn(
+            self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32
+        )
+        states = torch.randn(
+            self.nchunks,
+            self.nheads,
+            self.headdim,
+            self.dstate,
+            device=self.device,
+            dtype=torch.float32,
+        )
+        out = torch.zeros(
+            self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32
+        )
+
+        _chunk_scan_fwd(
+            cb,
+            x,
+            dt,
+            dA_cumsum,
+            C,
+            states,
+            self.cu_chunk_seqlens,
+            out,
+            self.seq_idx,
+            D=None,
+            z=None,
+            initial_states=None,
+        )
+
+        self.assertEqual(out.shape, (self.seqlen, self.nheads, self.headdim))
+        self.assertFalse(torch.isnan(out).any())
+        # Output should be non-zero (scan 
writes to out) + self.assertGreater(out.abs().max().item(), 0.0) + + def test_chunk_scan_fwd_with_D(self): + cb = torch.randn( + self.nchunks, + self.ngroups, + self.chunk_size, + self.chunk_size, + device=self.device, + dtype=torch.float32, + ) + x = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + dt = torch.randn( + self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32 + ) + dA_cumsum = torch.randn( + self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32 + ) + C = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + states = torch.randn( + self.nchunks, + self.nheads, + self.headdim, + self.dstate, + device=self.device, + dtype=torch.float32, + ) + out = torch.zeros( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + D = torch.ones(self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + self.cu_chunk_seqlens, + out, + self.seq_idx, + D=D, + z=None, + initial_states=None, + ) + + self.assertFalse(torch.isnan(out).any()) + + def test_chunk_scan_fwd_with_z(self): + cb = torch.randn( + self.nchunks, + self.ngroups, + self.chunk_size, + self.chunk_size, + device=self.device, + dtype=torch.float32, + ) + x = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + z = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + dt = torch.randn( + self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32 + ) + dA_cumsum = torch.randn( + self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32 + ) + C = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + states = torch.randn( + self.nchunks, + self.nheads, + 
self.headdim, + self.dstate, + device=self.device, + dtype=torch.float32, + ) + out = torch.zeros( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + + _chunk_scan_fwd( + cb, + x, + dt, + dA_cumsum, + C, + states, + self.cu_chunk_seqlens, + out, + self.seq_idx, + D=None, + z=z, + initial_states=None, + ) + + self.assertFalse(torch.isnan(out).any()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssd_chunk_state.py b/tests/unit_tests/ssm/ops/test_ssd_chunk_state.py new file mode 100644 index 00000000000..b848b6d7c96 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssd_chunk_state.py @@ -0,0 +1,208 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import unittest + +import torch + +try: + from megatron.core.ssm.ops.ssd_chunk_state import ( + _chunk_cumsum_fwd, + _chunk_state_fwd, + chunk_state_varlen, + ) + + HAVE_SSD_OPS = True +except (ImportError, Exception): + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestChunkCumsumFwd(unittest.TestCase): + """Tests for _chunk_cumsum_fwd.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.seqlen = 32 + self.nheads = 4 + self.chunk_size = 16 + self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + + def test_chunk_cumsum_fwd_shape(self): + dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + + dA_cumsum, dt_out = _chunk_cumsum_fwd(dt, A, self.chunk_size, self.cu_chunk_seqlens) + + nchunks = self.cu_chunk_seqlens.shape[0] - 1 + self.assertEqual(dA_cumsum.shape, (self.nheads, nchunks, self.chunk_size)) + self.assertEqual(dt_out.shape, (self.nheads, nchunks, self.chunk_size)) + 
self.assertFalse(torch.isnan(dA_cumsum).any()) + self.assertFalse(torch.isnan(dt_out).any()) + + def test_chunk_cumsum_fwd_cumsum_per_chunk(self): + """dA_cumsum should be cumsum of dt * A along the chunk dimension.""" + dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + + dA_cumsum, dt_out = _chunk_cumsum_fwd( + dt, + A, + self.chunk_size, + self.cu_chunk_seqlens, + dt_bias=None, + dt_softplus=False, + dt_limit=(0.0, float("inf")), + ) + + nchunks = self.cu_chunk_seqlens.shape[0] - 1 + for c in range(nchunks): + start = self.cu_chunk_seqlens[c].item() + end = self.cu_chunk_seqlens[c + 1].item() + chunk_len = end - start + for h in range(self.nheads): + dA_chunk = (dt_out[h, c, :chunk_len] * A[h]).cpu() + expected_cumsum = torch.cumsum(dA_chunk, dim=0) + torch.testing.assert_close( + dA_cumsum[h, c, :chunk_len].cpu(), expected_cumsum, rtol=1e-4, atol=1e-4 + ) + + def test_chunk_cumsum_fwd_with_dt_bias(self): + dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + dt_bias = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + + dA_cumsum, dt_out = _chunk_cumsum_fwd( + dt, A, self.chunk_size, self.cu_chunk_seqlens, dt_bias=dt_bias + ) + + nchunks = self.cu_chunk_seqlens.shape[0] - 1 + self.assertEqual(dA_cumsum.shape, (self.nheads, nchunks, self.chunk_size)) + self.assertFalse(torch.isnan(dA_cumsum).any()) + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestChunkStateFwd(unittest.TestCase): + """Tests for _chunk_state_fwd.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.seqlen = 32 + self.nheads = 4 + self.headdim = 16 + self.ngroups = 2 + self.dstate = 8 + self.chunk_size = 16 + 
self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + + def test_chunk_state_fwd_shape(self): + x = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + B = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + dt = torch.randn(self.nheads, 2, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn( + self.nheads, 2, self.chunk_size, device=self.device, dtype=torch.float32 + ) + + states = _chunk_state_fwd(B, x, dt, dA_cumsum, self.cu_chunk_seqlens) + + nchunks = self.cu_chunk_seqlens.shape[0] - 1 + self.assertEqual(states.shape, (nchunks, self.nheads, self.headdim, self.dstate)) + self.assertFalse(torch.isnan(states).any()) + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestChunkStateVarlen(unittest.TestCase): + """Tests for chunk_state_varlen.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.seqlen = 32 + self.nheads = 4 + self.headdim = 16 + self.ngroups = 2 + self.dstate = 8 + self.chunk_size = 16 + self.batch = 2 + self.cu_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + self.last_chunk_indices = torch.tensor([0, 1], dtype=torch.int64, device=self.device) + + def test_chunk_state_varlen_shape(self): + x = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + B = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + dt = torch.randn(self.nheads, 2, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn( + self.nheads, 2, self.chunk_size, device=self.device, dtype=torch.float32 + ) + chunk_states = torch.randn( + 
2, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32 + ) + + states = chunk_state_varlen( + B, + x, + dt, + dA_cumsum, + self.cu_seqlens, + chunk_states, + last_chunk_indices=self.last_chunk_indices, + cu_chunk_seqlens=self.cu_chunk_seqlens, + ) + + self.assertEqual(states.shape, (self.batch, self.nheads, self.headdim, self.dstate)) + self.assertFalse(torch.isnan(states).any()) + + def test_chunk_state_varlen_with_initial_states(self): + x = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + B = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + dt = torch.randn(self.nheads, 2, self.chunk_size, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn( + self.nheads, 2, self.chunk_size, device=self.device, dtype=torch.float32 + ) + chunk_states = torch.randn( + 2, self.nheads, self.headdim, self.dstate, device=self.device, dtype=torch.float32 + ) + initial_states = torch.randn( + self.batch, + self.nheads, + self.headdim, + self.dstate, + device=self.device, + dtype=torch.float32, + ) + + states = chunk_state_varlen( + B, + x, + dt, + dA_cumsum, + self.cu_seqlens, + chunk_states, + initial_states=initial_states, + last_chunk_indices=self.last_chunk_indices, + cu_chunk_seqlens=self.cu_chunk_seqlens, + ) + + self.assertEqual(states.shape, (self.batch, self.nheads, self.headdim, self.dstate)) + self.assertFalse(torch.isnan(states).any()) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssd_combined.py b/tests/unit_tests/ssm/ops/test_ssd_combined.py new file mode 100644 index 00000000000..b5ef14f7a79 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssd_combined.py @@ -0,0 +1,361 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +import unittest + +import torch + +try: + from megatron.core.ssm.ops.ssd_combined import is_int_pow_2, mamba_chunk_scan_combined_varlen + + HAVE_SSD_OPS = True +except (ImportError, Exception): + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestIsIntPow2(unittest.TestCase): + """Tests for is_int_pow_2 utility.""" + + def test_powers_of_two(self): + for exp in range(12): + n = 2**exp + self.assertTrue(is_int_pow_2(n), f"2^{exp}={n} should be power of 2") + + def test_non_powers_of_two(self): + for n in [0, 3, 5, 6, 7, 9, 10, 12, 15, 18]: + self.assertFalse(is_int_pow_2(n), f"{n} should not be power of 2") + + def test_negative_and_float(self): + self.assertFalse(is_int_pow_2(-1)) + self.assertFalse(is_int_pow_2(-4)) + self.assertFalse(is_int_pow_2(2.0)) + self.assertFalse(is_int_pow_2(0)) + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestMambaChunkScanCombinedVarlen(unittest.TestCase): + """Tests for mamba_chunk_scan_combined_varlen.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.chunk_size = 16 + self.seqlen = 32 + self.nheads = 4 + self.headdim = 16 + self.ngroups = 2 + self.dstate = 8 + self.batch = 2 + # 2 chunks of 16 each + self.cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + # last chunk index per sequence: seq0 ends in chunk 0, seq1 ends in chunk 1 + self.last_chunk_indices = torch.tensor([0, 1], dtype=torch.int64, device=self.device) + # seq_idx: which sequence each chunk belongs to (nchunks,) + self.seq_idx = torch.tensor([0, 1], dtype=torch.int32, device=self.device) + + def test_mamba_chunk_scan_combined_varlen_shape_and_no_nan(self): + x = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, 
dtype=torch.float32 + ) + dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + B = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + C = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + out = torch.empty( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + + varlen_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=self.cu_chunk_seqlens, + last_chunk_indices=self.last_chunk_indices, + seq_idx=self.seq_idx, + out=out, + ) + + self.assertEqual(varlen_states.shape, (self.batch, self.nheads, self.headdim, self.dstate)) + self.assertEqual(out.shape, (self.seqlen, self.nheads, self.headdim)) + self.assertFalse(torch.isnan(out).any(), "output should have no NaN") + self.assertFalse(torch.isnan(varlen_states).any(), "varlen_states should have no NaN") + + def test_mamba_chunk_scan_combined_varlen_with_D_and_dt_bias(self): + x = torch.randn( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + dt = torch.randn(self.seqlen, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + B = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + C = torch.randn( + self.seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32 + ) + D = torch.ones(self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt_bias = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + out = torch.empty( + self.seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + + varlen_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + 
chunk_size=self.chunk_size, + cu_chunk_seqlens=self.cu_chunk_seqlens, + last_chunk_indices=self.last_chunk_indices, + seq_idx=self.seq_idx, + out=out, + D=D, + dt_bias=dt_bias, + ) + + self.assertEqual(varlen_states.shape, (self.batch, self.nheads, self.headdim, self.dstate)) + self.assertFalse(torch.isnan(out).any()) + + def test_mamba_chunk_scan_combined_varlen_single_sequence(self): + """Single sequence of 32 tokens, split into 2 chunks of 16.""" + cu_chunk_seqlens = torch.tensor([0, 16, 32], dtype=torch.int32, device=self.device) + last_chunk_indices = torch.tensor([1], dtype=torch.int64, device=self.device) + seq_idx = torch.tensor([0, 0], dtype=torch.int32, device=self.device) + + x = torch.randn(32, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(32, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + B = torch.randn(32, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + C = torch.randn(32, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + out = torch.empty(32, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + + varlen_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out, + ) + + self.assertEqual(varlen_states.shape, (1, self.nheads, self.headdim, self.dstate)) + self.assertFalse(torch.isnan(out).any()) + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestIntermediateStateExtraction(unittest.TestCase): + """Tests for intermediate_chunk_indices parameter.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.chunk_size = 16 + self.nheads = 4 + self.headdim = 16 + self.ngroups = 2 
+ self.dstate = 8 + + def _make_inputs(self, seqlen): + """Create random inputs for a single sequence of given length.""" + x = torch.randn(seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32) + dt = torch.randn(seqlen, self.nheads, device=self.device, dtype=torch.float32) + A = torch.randn(self.nheads, device=self.device, dtype=torch.float32) + B = torch.randn(seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + C = torch.randn(seqlen, self.ngroups, self.dstate, device=self.device, dtype=torch.float32) + out = torch.empty( + seqlen, self.nheads, self.headdim, device=self.device, dtype=torch.float32 + ) + return x, dt, A, B, C, out + + def test_intermediate_states_shape_and_no_nan(self): + """1 sequence, 4 chunks. Request intermediates at chunks [0, 1, 2].""" + seqlen = 64 # 4 chunks of 16 + nchunks = seqlen // self.chunk_size + x, dt, A, B, C, out = self._make_inputs(seqlen) + cu_chunk_seqlens = torch.arange( + 0, seqlen + 1, self.chunk_size, dtype=torch.int32, device=self.device + ) + last_chunk_indices = torch.tensor([nchunks - 1], dtype=torch.int64, device=self.device) + seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device) + intermediate_chunk_indices = torch.tensor([0, 1, 2], dtype=torch.int64, device=self.device) + + result = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out, + intermediate_chunk_indices=intermediate_chunk_indices, + ) + + self.assertIsInstance(result, tuple) + final_states, intermediate_states = result + self.assertEqual(final_states.shape, (1, self.nheads, self.headdim, self.dstate)) + self.assertEqual(intermediate_states.shape, (3, self.nheads, self.headdim, self.dstate)) + self.assertFalse(torch.isnan(final_states).any()) + self.assertFalse(torch.isnan(intermediate_states).any()) + + def 
test_intermediate_states_match_full_states(self): + """Intermediate states should match corresponding entries from full states.""" + seqlen = 64 # 4 chunks + nchunks = seqlen // self.chunk_size + x, dt, A, B, C, out = self._make_inputs(seqlen) + cu_chunk_seqlens = torch.arange( + 0, seqlen + 1, self.chunk_size, dtype=torch.int32, device=self.device + ) + last_chunk_indices = torch.tensor([nchunks - 1], dtype=torch.int64, device=self.device) + seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device) + + # Run with return_intermediate_states=True to get all states + out1 = torch.empty_like(out) + all_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out1, + return_intermediate_states=True, + ) + + # Run with intermediate_chunk_indices + indices = [0, 1, 2] + intermediate_chunk_indices = torch.tensor(indices, dtype=torch.int64, device=self.device) + out2 = torch.empty_like(out) + final_states, intermediate_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out2, + intermediate_chunk_indices=intermediate_chunk_indices, + ) + + # Intermediate states should match the corresponding all_states entries + for i, chunk_idx in enumerate(indices): + torch.testing.assert_close( + intermediate_states[i], + all_states[chunk_idx], + msg=f"intermediate state at index {i} (chunk {chunk_idx}) does not match", + ) + + # Final state should match last chunk + torch.testing.assert_close(final_states[0], all_states[nchunks - 1]) + + def test_intermediate_states_multi_sequence(self): + """2 packed sequences, verify intermediate extraction across sequence boundaries.""" + seq1_len = 32 # 2 chunks + seq2_len = 48 # 3 chunks + total_len = seq1_len + seq2_len + 
x, dt, A, B, C, out = self._make_inputs(total_len) + + # cu_chunk_seqlens: seq1 has chunks at [0, 16, 32], seq2 at [32, 48, 64, 80] + boundaries = list(range(0, seq1_len + 1, self.chunk_size)) + list( + range(seq1_len + self.chunk_size, total_len + 1, self.chunk_size) + ) + cu_chunk_seqlens = torch.tensor(boundaries, dtype=torch.int32, device=self.device) + nchunks = len(boundaries) - 1 # 5 chunks total + # Last chunk for seq1 is chunk 1, for seq2 is chunk 4 + last_chunk_indices = torch.tensor([1, 4], dtype=torch.int64, device=self.device) + # seq_idx: [0, 0, 1, 1, 1] + seq_idx = torch.tensor([0, 0, 1, 1, 1], dtype=torch.int32, device=self.device) + + # Request chunk 0 from seq1 and chunks 2, 3 from seq2 + intermediate_chunk_indices = torch.tensor([0, 2, 3], dtype=torch.int64, device=self.device) + + # Also get full states for comparison + out_full = torch.empty_like(out) + all_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out_full, + return_intermediate_states=True, + ) + + out2 = torch.empty_like(out) + final_states, intermediate_states = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out2, + intermediate_chunk_indices=intermediate_chunk_indices, + ) + + self.assertEqual(final_states.shape, (2, self.nheads, self.headdim, self.dstate)) + self.assertEqual(intermediate_states.shape, (3, self.nheads, self.headdim, self.dstate)) + + # Verify intermediate states match full states + for i, chunk_idx in enumerate([0, 2, 3]): + torch.testing.assert_close(intermediate_states[i], all_states[chunk_idx]) + + def test_no_intermediate_returns_tensor(self): + """Without intermediate_chunk_indices, result should be a plain tensor.""" + seqlen = 32 + nchunks = 
seqlen // self.chunk_size + x, dt, A, B, C, out = self._make_inputs(seqlen) + cu_chunk_seqlens = torch.arange( + 0, seqlen + 1, self.chunk_size, dtype=torch.int32, device=self.device + ) + last_chunk_indices = torch.tensor([nchunks - 1], dtype=torch.int64, device=self.device) + seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device) + + result = mamba_chunk_scan_combined_varlen( + x=x, + dt=dt, + A=A, + B=B, + C=C, + chunk_size=self.chunk_size, + cu_chunk_seqlens=cu_chunk_seqlens, + last_chunk_indices=last_chunk_indices, + seq_idx=seq_idx, + out=out, + ) + + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result.shape, (1, self.nheads, self.headdim, self.dstate)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssd_state_passing.py b/tests/unit_tests/ssm/ops/test_ssd_state_passing.py new file mode 100644 index 00000000000..f791d615078 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssd_state_passing.py @@ -0,0 +1,89 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. 
+ +import unittest + +import torch + +try: + from megatron.core.ssm.ops.ssd_state_passing import _state_passing_fwd + + HAVE_SSD_OPS = True +except (ImportError, Exception): + HAVE_SSD_OPS = False + + +@unittest.skipIf(not HAVE_SSD_OPS, "SSD ops (Triton 3+) not available") +@unittest.skipIf(not torch.cuda.is_available(), "CUDA required for SSD ops") +class TestStatePassingFwd(unittest.TestCase): + """Tests for _state_passing_fwd: recurrence out = exp(dA_cs_last) * prev + new_states.""" + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda") + self.nchunks = 4 + self.nheads = 2 + self.chunk_size = 16 + self.dim = self.chunk_size * 8 # headdim * dstate flattened + self.cu_chunk_seqlens = torch.tensor( + [0, 16, 32, 48, 64], dtype=torch.int32, device=self.device + ) + + def test_state_passing_fwd_shape(self): + states = torch.randn( + self.nchunks, self.nheads, self.dim, device=self.device, dtype=torch.float32 + ) + dA_cumsum = torch.randn( + self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32 + ) + seq_idx = torch.zeros(self.nchunks, dtype=torch.int32, device=self.device) + + out = _state_passing_fwd( + states, dA_cumsum, self.cu_chunk_seqlens, seq_idx, initial_states=None + ) + + self.assertEqual(out.shape, (self.nchunks, self.nheads, self.dim)) + self.assertFalse(torch.isnan(out).any()) + + def test_state_passing_fwd_with_initial_states(self): + states = torch.randn( + self.nchunks, self.nheads, self.dim, device=self.device, dtype=torch.float32 + ) + dA_cumsum = torch.randn( + self.nheads, self.nchunks, self.chunk_size, device=self.device, dtype=torch.float32 + ) + seq_idx = torch.tensor([0, 0, 1, 1], dtype=torch.int32, device=self.device) + initial_states = torch.randn( + 2, self.nheads, self.dim, device=self.device, dtype=torch.float32 + ) + + out = _state_passing_fwd( + states, dA_cumsum, self.cu_chunk_seqlens, seq_idx, initial_states=initial_states + ) + + self.assertEqual(out.shape, (self.nchunks, 
self.nheads, self.dim)) + self.assertFalse(torch.isnan(out).any()) + + def test_state_passing_fwd_recurrence_single_head_single_dim(self): + """Sanity: single head, small dim, check recurrence manually for first elements.""" + dim = 4 + nchunks = 2 + nheads = 1 + chunk_size = 2 + cu_chunk_seqlens = torch.tensor([0, 2, 4], dtype=torch.int32, device=self.device) + seq_idx = torch.zeros(nchunks, dtype=torch.int32, device=self.device) + + states = torch.randn(nchunks, nheads, dim, device=self.device, dtype=torch.float32) + dA_cumsum = torch.randn( + nheads, nchunks, chunk_size, device=self.device, dtype=torch.float32 + ) + + out = _state_passing_fwd(states, dA_cumsum, cu_chunk_seqlens, seq_idx) + + # Chunk 0: out[0] = exp(dA_cumsum[0,-1]) * 0 + states[0] = states[0] (no initial state) + # So out[0] should equal states[0] + torch.testing.assert_close(out[0], states[0], rtol=1e-4, atol=1e-4) + self.assertEqual(out.shape, (nchunks, nheads, dim)) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/ssm/ops/test_ssm_kernel.py b/tests/unit_tests/ssm/ops/test_ssm_kernel.py new file mode 100644 index 00000000000..a13ee59c794 --- /dev/null +++ b/tests/unit_tests/ssm/ops/test_ssm_kernel.py @@ -0,0 +1,189 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import math +import unittest +from unittest.mock import MagicMock + +import torch +import torch.nn as nn + +# Assume the provided class is in mamba_mixer.py +from megatron.core.ssm.mamba_mixer import MambaMixer + + +class MockContextParallel: + """ + Mocks the MambaContextParallel helper. 
+ """ + + def __init__(self, d_inner, ngroups, nheads, d_state, device): + self.d_inner_local_tpcp = d_inner + self.ngroups_local_tpcp = ngroups + self.nheads_local_tpcp = nheads + self.cp_size = 1 + + # Random weights for the mock + self.conv1d_weight = torch.randn(d_inner + 2 * ngroups * d_state, 1, 4, device=device) + self.conv1d_bias = torch.randn(d_inner + 2 * ngroups * d_state, device=device) + self.A_log = torch.randn(nheads, device=device) + self.D = torch.ones(nheads, device=device) + self.dt_bias = torch.randn(nheads, device=device) + + # Simple conv1d layer for the fallback path if needed + self.conv1d_layer = nn.Conv1d( + in_channels=self.conv1d_weight.shape[0], + out_channels=self.conv1d_weight.shape[0], + kernel_size=4, + groups=self.conv1d_weight.shape[0], + padding=3, + ).to(device) + + def get_A_log(self): + return self.A_log + + def get_D(self): + return self.D + + def get_dt_bias(self): + return self.dt_bias + + def get_conv1d_weight(self): + return self.conv1d_weight + + def get_conv1d_bias(self): + return self.conv1d_bias + + def conv1d(self, x): + return self.conv1d_layer(x) + + def pre_conv_ssm(self, x): + return x + + def post_conv_ssm(self, x): + return x + + +class TestMambaDynamicInference(unittest.TestCase): + + def setUp(self): + torch.manual_seed(42) + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + if self.device.type == 'cpu': + self.skipTest("Mamba Triton kernels require CUDA") + + # --- Configuration --- + self.d_model = 256 + self.d_state = 16 + self.headdim = 64 + self.d_conv = 4 + self.ngroups = 1 + self.d_inner = self.d_model * 2 # expand=2 + self.nheads = self.d_inner // self.headdim + + # Create the Mixer instance directly + self.mixer = MagicMock(spec=MambaMixer) + self.mixer.d_state = self.d_state + self.mixer.d_conv = self.d_conv + self.mixer.headdim = self.headdim + self.mixer.chunk_size = 256 + self.mixer.activation = "silu" + self.mixer.act = nn.SiLU() + self.mixer.D_has_hdim = False + 
self.mixer.rmsnorm = True + + # Mock the Context Parallel wrapper (used by ssm_prefill) + self.mixer.cp = MockContextParallel( + d_inner=self.d_inner, + ngroups=self.ngroups, + nheads=self.nheads, + d_state=self.d_state, + device=self.device, + ) + + # --- Setup for ssm_decode --- + # ssm_decode accesses attributes directly from self, not self.cp + self.mixer.d_inner_local_tp = self.d_inner + self.mixer.ngroups_local_tp = self.ngroups + self.mixer.nheads_local_tp = self.nheads + + # Create real parameters for ssm_decode to access + conv_dim = self.d_inner + 2 * self.ngroups * self.d_state + self.mixer.conv1d = nn.Conv1d( + in_channels=conv_dim, + out_channels=conv_dim, + kernel_size=self.d_conv, + groups=conv_dim, + padding=self.d_conv - 1, + bias=True, + device=self.device, + ) + self.mixer.dt_bias = nn.Parameter(torch.randn(self.nheads, device=self.device)) + self.mixer.A_log = nn.Parameter(torch.randn(self.nheads, device=self.device)) + self.mixer.D = nn.Parameter(torch.ones(self.nheads, device=self.device)) + + # Bind methods + self.mixer._ssm_prefill = MambaMixer._ssm_prefill.__get__(self.mixer, MambaMixer) + self.mixer._ssm_decode = MambaMixer._ssm_decode.__get__(self.mixer, MambaMixer) + + def test_ssm_prefill_padding_isolation(self): + """ + Tests that ssm_prefill only updates states for the real request + and that padding request states remain untouched. + + _ssm_prefill expects inputs pre-stripped to real tokens only + (stripping is done by _dynamic_inference_prefill). This test + passes only the real tokens and verifies that only the active + request's state is modified. 
+ """ + num_requests = 48 + real_seq_len = 6 + + # Inputs: only real tokens (padding is stripped upstream) + dim_inputs = self.d_inner * 2 + 2 * self.ngroups * self.d_state + self.nheads + zxBCdt = torch.randn(real_seq_len, 1, dim_inputs, device=self.device, dtype=torch.float32) + + # Metadata: single real request + seq_idx = torch.zeros((1, real_seq_len), dtype=torch.int32, device=self.device) + + cu_seqlens = torch.tensor([0, real_seq_len], dtype=torch.int32, device=self.device) + + batch_indices = torch.tensor([0], dtype=torch.long, device=self.device) + + # States + conv_dim = self.d_inner + 2 * self.ngroups * self.d_state + conv_state = torch.zeros(num_requests, conv_dim, self.d_conv, device=self.device) + ssm_state = torch.zeros( + num_requests, self.nheads, self.headdim, self.d_state, device=self.device + ) + + # Run + self.mixer.norm = MagicMock(side_effect=lambda x, z: x * z) + output = self.mixer._ssm_prefill( + zxBCdt=zxBCdt, + conv_state=conv_state, + ssm_state=ssm_state, + seq_idx=seq_idx, + cu_seqlens=cu_seqlens, + batch_indices=batch_indices, + ) + + # Output should have real_seq_len tokens + self.assertEqual(output.shape[0], real_seq_len) + self.assertTrue(conv_state[0].abs().max() > 0, "Real request conv_state should be modified") + + # Verify isolation of padding states + remaining_conv_states = conv_state[1:num_requests] + remaining_ssm_states = ssm_state[1:num_requests] + + self.assertTrue( + torch.allclose(remaining_conv_states, torch.zeros_like(remaining_conv_states)), + "Conv states for padding requests (indices 1 to N-1) should remain 0", + ) + self.assertTrue( + torch.allclose(remaining_ssm_states, torch.zeros_like(remaining_ssm_states)), + "SSM states for padding requests (indices 1 to N-1) should remain 0", + ) + + +if __name__ == '__main__': + unittest.main(argv=['first-arg-is-ignored'], exit=False) diff --git a/tests/unit_tests/ssm/test_causal_conv1d_triton.py b/tests/unit_tests/ssm/test_causal_conv1d_triton.py new file mode 100644 
index 00000000000..3015f5ed989 --- /dev/null +++ b/tests/unit_tests/ssm/test_causal_conv1d_triton.py @@ -0,0 +1,258 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. + +import pytest +import torch + +from megatron.core.ssm.ops.causal_conv1d_triton import causal_conv1d_update + + +def _requires_cuda(): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + +# ---------------------- Reference Implementations ---------------------- # + + +def causal_conv1d_update_ref(x, conv_state, weight, bias, silu_activation): + """Reference: linear (non-circular) causal conv1d update.""" + batch, seq_len, dim = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + out = torch.empty_like(x) + for b in range(batch): + for s in range(seq_len): + # Shift state left by 1 + conv_state[b, :, :-1] = conv_state[b, :, 1:].clone() + conv_state[b, :, -1] = x[b, s, :] + # Convolution over the last `width` elements + window = conv_state[b, :, state_len - width : state_len].float() + w = weight.float() + val = (window * w).sum(dim=1) + if bias is not None: + val = val + bias.float() + if silu_activation: + val = val * torch.sigmoid(val) + out[b, s, :] = val.to(x.dtype) + return out + + +# ---------------------- Tests ---------------------- # + + +@pytest.mark.internal +class TestCausalConv1dUpdate: + + def setup_method(self, method): + _requires_cuda() + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_no_bias(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 3, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state_triton, weight, bias=None, silu_activation=False, conv_state_indices=None + ) + expected = causal_conv1d_update_ref( + x, 
conv_state_ref, weight, bias=None, silu_activation=False + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(conv_state_triton, conv_state_ref, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_with_bias(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 3, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + bias = torch.randn(D, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state_triton, weight, bias=bias, silu_activation=False, conv_state_indices=None + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=bias, silu_activation=False + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_linear_with_silu(self, width): + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 1, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state_triton = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + conv_state_ref = conv_state_triton.clone() + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + bias = torch.randn(D, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state_triton, weight, bias=bias, silu_activation="silu", conv_state_indices=None + ) + expected = causal_conv1d_update_ref( + x, conv_state_ref, weight, bias=bias, silu_activation=True + ) + + torch.testing.assert_close(result, expected, atol=1e-4, rtol=1e-4) + + def test_2d_input(self): + """Test that 2D input (B, D) is handled correctly and returns 2D output.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + x = torch.randn(B, D, 
device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + result = causal_conv1d_update( + x, conv_state, weight, bias=None, silu_activation=False, conv_state_indices=None + ) + + assert result.dim() == 2 + assert result.shape == (B, D) + + def test_conv_state_indices(self): + """Test that conv_state_indices correctly maps batch to state entries.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + num_states = 4 + x = torch.randn(B, 1, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(num_states, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + # Map batch 0 -> state 2, batch 1 -> state 0 + state_indices = torch.tensor([2, 0], device="cuda", dtype=torch.int32) + + # Run with indices + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=state_indices, + ) + + # Run without indices by manually reordering + conv_state_reordered = conv_state[state_indices.long()].clone() + expected = causal_conv1d_update( + x, + conv_state_reordered, + weight, + bias=None, + silu_activation=False, + conv_state_indices=None, + ) + + torch.testing.assert_close(result, expected, atol=1e-5, rtol=1e-5) + + def test_negative_state_index_zeros_output(self): + """Padding batch entries (index < 0) should produce zero output.""" + torch.manual_seed(42) + B, D, state_len, width = 2, 64, 8, 4 + x = torch.randn(B, 1, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + state_indices = torch.tensor([-1, 0], device="cuda", dtype=torch.int32) + + result = causal_conv1d_update( + x, + conv_state, + weight, + bias=None, + 
silu_activation=False, + conv_state_indices=state_indices, + ) + + # Batch 0 (padded) should be all zeros + torch.testing.assert_close(result[0], torch.zeros(1, D, device="cuda", dtype=torch.float32)) + + @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) + def test_half_precision(self, dtype): + torch.manual_seed(42) + B, seq_len, D, state_len, width = 2, 1, 64, 8, 4 + x = torch.randn(B, seq_len, D, device="cuda", dtype=dtype) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=dtype) + weight = torch.randn(D, width, device="cuda", dtype=dtype) + + result = causal_conv1d_update( + x, conv_state, weight, bias=None, silu_activation=False, conv_state_indices=None + ) + + assert result.dtype == dtype + assert result.shape == (B, seq_len, D) + assert torch.isfinite(result).all() + + @pytest.mark.parametrize("width", [2, 3, 4]) + def test_intermediate_state(self, width): + """Test that intermediate conv states are correctly stored at each sequence step.""" + torch.manual_seed(42) + B, seq_len, D, state_len = 2, 4, 64, 8 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(B, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + + # Allocate intermediate state buffer: (B, seq_len, D, state_len) + int_states = torch.zeros(B, seq_len, D, state_len, device="cuda", dtype=torch.float32) + + # Run with intermediate state recording + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=None, + intermediate_conv_states=int_states, + ) + + # Verify by running step-by-step and checking each intermediate + conv_state_ref = conv_state.clone() + for s in range(seq_len): + conv_state_ref[:, :, :-1] = conv_state_ref[:, :, 1:].clone() + conv_state_ref[:, :, -1] = x[:, s, :] + torch.testing.assert_close(int_states[:, s, :, :], conv_state_ref, 
atol=1e-5, rtol=1e-5) + + def test_intermediate_state_with_indices(self): + """Test intermediate states work correctly with conv_state_indices mapping.""" + torch.manual_seed(42) + B, seq_len, D, state_len, width = 2, 3, 64, 8, 4 + num_states = 4 + x = torch.randn(B, seq_len, D, device="cuda", dtype=torch.float32) + conv_state = torch.randn(num_states, D, state_len, device="cuda", dtype=torch.float32) + weight = torch.randn(D, width, device="cuda", dtype=torch.float32) + state_indices = torch.tensor([2, 0], device="cuda", dtype=torch.int32) + + # Intermediate states are indexed by state_batch_coord (i.e., req index, not batch index) + int_states = torch.zeros( + num_states, seq_len, D, state_len, device="cuda", dtype=torch.float32 + ) + + conv_state_copy = conv_state.clone() + result = causal_conv1d_update( + x, + conv_state_copy, + weight, + bias=None, + silu_activation=False, + conv_state_indices=state_indices, + intermediate_conv_states=int_states, + ) + + # The final intermediate state at last seq step should match the final conv_state + for b_idx in range(B): + req_idx = state_indices[b_idx].item() + torch.testing.assert_close( + int_states[req_idx, seq_len - 1, :, :], + conv_state_copy[req_idx, :, :], + atol=1e-5, + rtol=1e-5, + ) diff --git a/tools/trigger_internal_ci.py b/tools/trigger_internal_ci.py new file mode 100644 index 00000000000..3b97c92332e --- /dev/null +++ b/tools/trigger_internal_ci.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""CLI tool to trigger the internal GitLab CI pipeline from a local branch. + +Pushes the current branch to the internal GitLab remote under the +pull-request/ naming convention and triggers a pipeline with +the specified test configuration. +""" + +import argparse +import logging +import os +import subprocess +import sys +from urllib.parse import urlparse + +import gitlab # python-gitlab + +GITLAB_PROJECT_ID = 19378 +GITLAB_BRANCH_PREFIX = "pull-request" + +PIPELINE_VARIABLES_FIXED = { + "UNIT_TEST": "no", + "INTEGRATION_TEST": "no", +} + +logger = logging.getLogger(__name__) + + +def get_remote_url(origin): + """Return the fetch URL configured for the given git remote name.""" + result = subprocess.run( + ["git", "remote", "get-url", origin], + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() + + +def get_gitlab_hostname(remote_url): + """Extract the hostname (without port) from an SSH or HTTPS remote URL.""" + if remote_url.startswith("git@"): + hostname = remote_url.split("@", 1)[1].split(":")[0] + else: + hostname = urlparse(remote_url).hostname + return hostname.split(":")[0] + + +def get_current_branch(): + """Return the name of the currently checked-out git branch.""" + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + return result.stdout.strip() + + +def git_push(origin, target_branch, dry_run=False): + """Force-push HEAD to the given branch on the named git remote.""" + if dry_run: + logger.info("[DRY RUN] Would push HEAD to remote '%s' as %s", origin, target_branch) + return + subprocess.run( + ["git", "push", origin, f"HEAD:{target_branch}", "--force"], + check=True, + ) + + +def trigger_pipeline(gitlab_url, trigger_token, ref, pipeline_vars, dry_run=False): + """Trigger a GitLab pipeline on the given ref with the provided 
variables.""" + if dry_run: + logger.info( + "[DRY RUN] Would trigger pipeline on https://%s project %s @ %s", + gitlab_url, + GITLAB_PROJECT_ID, + ref, + ) + return + logger.info( + "Triggering pipeline on https://%s project %s @ %s", gitlab_url, GITLAB_PROJECT_ID, ref + ) + gl = gitlab.Gitlab(f"https://{gitlab_url}") + project = gl.projects.get(GITLAB_PROJECT_ID, lazy=True) + pipeline = project.trigger_pipeline(ref=ref, token=trigger_token, variables=pipeline_vars) + logger.info("Pipeline triggered: %s", pipeline.web_url) + + +def main(): + """Parse arguments and orchestrate the push and pipeline trigger flow.""" + parser = argparse.ArgumentParser( + description="Trigger the internal GitLab CI pipeline for the current branch." + ) + parser.add_argument( + "--gitlab-origin", + required=True, + help="Name of the git remote pointing to the internal GitLab (e.g. gitlab)", + ) + parser.add_argument( + "--trigger-token", + default=os.environ.get("GITLAB_TRIGGER_TOKEN"), + help="GitLab pipeline trigger token (or set GITLAB_TRIGGER_TOKEN env var)", + ) + parser.add_argument( + "--functional-test-scope", + default="mr", + help="FUNCTIONAL_TEST_SCOPE pipeline variable (default: mr)", + ) + parser.add_argument( + "--functional-test-repeat", + type=int, + default=5, + help="FUNCTIONAL_TEST_REPEAT pipeline variable (default: 5)", + ) + parser.add_argument( + "--functional-test-cases", + default="all", + help="FUNCTIONAL_TEST_CASES pipeline variable (default: all)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Print actions without executing git push or pipeline trigger", + ) + args = parser.parse_args() + logging.basicConfig(level=logging.INFO, format="%(message)s") + + if not args.trigger_token: + logger.error("--trigger-token or GITLAB_TRIGGER_TOKEN not set") + sys.exit(1) + + branch = get_current_branch() + logger.info("Current branch: %s", branch) + + remote_url = get_remote_url(args.gitlab_origin) + gitlab_hostname = 
get_gitlab_hostname(remote_url) + + target_branch = f"{GITLAB_BRANCH_PREFIX}/{branch}" + + git_push(args.gitlab_origin, target_branch, dry_run=args.dry_run) + + pipeline_vars = { + **PIPELINE_VARIABLES_FIXED, + "FUNCTIONAL_TEST_SCOPE": args.functional_test_scope, + "FUNCTIONAL_TEST_REPEAT": str(args.functional_test_repeat), + "FUNCTIONAL_TEST_CASES": args.functional_test_cases, + } + + trigger_pipeline( + gitlab_hostname, args.trigger_token, target_branch, pipeline_vars, dry_run=args.dry_run + ) + + +if __name__ == "__main__": + main() \ No newline at end of file