diff --git a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_35b.yaml b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_35b.yaml index 1a86f46262c..0870b07a3e0 100644 --- a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_35b.yaml +++ b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_35b.yaml @@ -1,38 +1,31 @@ runtime: trtllm compile_backend: torch-cudagraph -max_seq_len: 4096 +attn_backend: trtllm +max_seq_len: 8192 max_num_tokens: 4096 max_batch_size: 512 -world_size: 2 +world_size: 4 +cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] enable_chunked_prefill: true model_factory: AutoModelForCausalLM kv_cache_config: enable_block_reuse: false - free_gpu_memory_fraction: 0.95 - tokens_per_block: 64 + free_gpu_memory_fraction: 0.8 + tokens_per_block: 32 model_kwargs: torch_dtype: bfloat16 - # text_config: - # num_hidden_layers: 6 - # vision_config: - # depth: 2 transforms: + export_to_gm: + num_moe_experts_for_export: 2 + fuse_gemms_mixed_children: + enabled: true detect_sharding: - sharding_dims: ['tp','ep', 'bmm'] - # use only manual config for TP sharding - sharding_source: ['manual'] - manual_config: - tp_plan: - # GDN layer - "in_proj_qkv": "delta" - # attention layer - "q_proj": "colwise" - "k_proj": "colwise" - "v_proj": "colwise" - "o_proj": "rowwise" - # replicating shared experts (keep them commented out) - # "shared_expert_gate_proj": "colwise" - # "shared_expert_up_proj": "colwise" - # "shared_expert_down_proj": "rowwise" - # gating layer should be replicated as well - # "gate": "gather" + allreduce_strategy: SYMM_MEM + multi_stream_moe: + stage: compile + enabled: true + multi_stream_gemm: + stage: compile + enabled: true + gather_logits_before_lm_head: + enabled: true diff --git a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml index 250ee830e54..49ced4dbc50 100644 --- 
a/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml +++ b/examples/auto_deploy/model_registry/configs/qwen3.5_moe_400b.yaml @@ -1,16 +1,17 @@ runtime: trtllm compile_backend: torch-cudagraph -max_seq_len: 2048 -max_num_tokens: 2048 -max_batch_size: 512 -cuda_graph_batch_sizes: [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] +attn_backend: trtllm +max_seq_len: 262144 +max_num_tokens: 8192 +max_batch_size: 32 +cuda_graph_batch_sizes: [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32] world_size: 8 enable_chunked_prefill: true model_factory: AutoModelForCausalLM kv_cache_config: - enable_block_reuse: false - free_gpu_memory_fraction: 0.95 - tokens_per_block: 64 + enable_block_reuse: true + free_gpu_memory_fraction: 0.8 + tokens_per_block: 32 model_kwargs: torch_dtype: bfloat16 transforms: @@ -19,21 +20,12 @@ transforms: fuse_gemms_mixed_children: enabled: true detect_sharding: - sharding_dims: ['tp','ep', 'bmm'] - # use only manual config for TP sharding - sharding_source: ['manual'] - manual_config: - tp_plan: - # GDN layer - "in_proj_qkv": "delta" - # attention layer - "q_proj": "colwise" - "k_proj": "colwise" - "v_proj": "colwise" - "o_proj": "rowwise" - # replicating shared experts (keep them commented out) - # "shared_expert_gate_proj": "colwise" - # "shared_expert_up_proj": "colwise" - # "shared_expert_down_proj": "rowwise" - # gating layer should be replicated as well - # "gate": "gather" + allreduce_strategy: SYMM_MEM + multi_stream_moe: + stage: compile + enabled: true + multi_stream_gemm: + stage: compile + enabled: true + gather_logits_before_lm_head: + enabled: true diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index a2082e99695..3cd5df76b49 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -55,6 +55,8 @@ transforms: run_shape_prop: 
true match_l2norm_pattern: stage: pattern_matcher + match_moe_routing_pattern: + stage: pattern_matcher ############################################################################################ # RUN TRANSFORMATIONS ON STANDARDIZED GRAPH REPRESENTATION ############################################################################################ @@ -89,15 +91,20 @@ transforms: # proceeds normally. match_swiglu_pattern: stage: pattern_matcher - enabled: false + enabled: true match_nvfp4_swiglu_pattern: stage: pattern_matcher requires_shape_prop: true - enabled: false + enabled: true + match_finegrained_fp8_swiglu_pattern: + stage: pattern_matcher + requires_shape_prop: true + enabled: true quantize_fp8_moe: stage: pattern_matcher quantize_nvfp4_moe: stage: pattern_matcher + run_shape_prop: true quantize_mxfp4_moe: stage: pattern_matcher detect_hidden_states_for_capture: @@ -156,6 +163,8 @@ transforms: enabled: true fuse_nvfp4_swiglu: stage: post_load_fusion + fuse_finegrained_fp8_swiglu: + stage: post_load_fusion fuse_finegrained_fp8_linear: stage: post_load_fusion backend: trtllm @@ -185,6 +194,8 @@ transforms: rmsnorm_backend: flashinfer gated_rmsnorm_backend: triton requires_shape_prop: true + fuse_gdn_gating: + stage: post_load_fusion fuse_l2norm: stage: post_load_fusion backend: fla @@ -249,6 +260,9 @@ transforms: multi_stream_mla_attn: stage: compile enabled: false + multi_stream_gemm: + stage: compile + enabled: false compile_model: stage: compile expect_mem_change: true diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py index b6b36446212..9d0a2f887f9 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py @@ -119,7 +119,7 @@ def reset(self, device: torch.device) -> None: # NOTE (lucaslie): avoid OOM for many cudagraphs, # see 
https://github.com/NVIDIA/TensorRT-LLM/pull/3686 - self.workspace_buffer = torch.empty(320 * 1024 * 1024, device=device, dtype=torch.uint8) + self.workspace_buffer = torch.empty(1024 * 1024 * 1024, device=device, dtype=torch.uint8) # NOTE (lucaslie): flashinfer fa3 backend has accuracy issue + illegal memory access issues # on H100 PCIe, see https://github.com/NVIDIA/TensorRT-LLM/issues/4504 diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py index 6580ef6ce18..534e19b3cef 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py @@ -431,6 +431,9 @@ class SequenceInfo: Total sequence length including cached tokens for each sequence (input_pos + seq_len). - use_initial_states: [bool_0, bool_1, ..., bool_{b-1}] Per-sequence boolean indicating whether initial states should be used (True if input_pos > 0). + - any_prefill_use_initial_states: [bool] + Scalar boolean indicating whether any prefill sequence needs initial states. Precomputed on + the host to avoid GPU->CPU sync from torch.any() on the device tensor per layer. 
### OTHER ARGUMENTS USED BY THE RUNTIME ######################################################## - extra_page_per_seq: [ep_0, ep_1, ..., ep_{b-1}] @@ -527,6 +530,7 @@ def __init__( ("last_page_len", self.max_batch_size, torch.int), ("slot_idx", self.max_batch_size, torch.long), ### INFO OBJECTS THAT ARE AVAILABLE TO DESCRIBE THE INPUTS IN A MORE COMPACT WAY ####### + ("any_prefill_use_initial_states", 1, torch.bool), ("batch_info", 3, torch.int), ("max_seq_info", 4, torch.int), ### ADDITIONAL ARGUMENTS AVAILABLE THAT ARE DERIVED FROM THE BASIC ARGUMENTS ########### @@ -1037,6 +1041,17 @@ def nest_sequences( use_initial_states = ip_host > 0 self._stage_arg("use_initial_states", use_initial_states) + # precompute any(use_initial_states[:num_prefill]) on the host to avoid + # per-layer GPU->CPU sync from torch.any() inside cached ops + if self._is_required("any_prefill_use_initial_states"): + bi_host = self.get_arg("batch_info_host") + num_prefill = bi_host[0].item() + uis = self.get_arg("use_initial_states_host", truncate=True) + self._stage_arg( + "any_prefill_use_initial_states", + [bool(uis[:num_prefill].any())], + ) + ### UPDATE LOGITS GATHERING METADATA using heuristic if not provided ####################### # default is to gather all logits if token_gather_indices is None: diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py index 6026dfe4d52..5e019519000 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py @@ -53,6 +53,7 @@ def fla_cached_delta_rule( cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, + any_prefill_use_initial_states_host: torch.Tensor, # EXTRA METADATA # # CACHES @@ -82,7 +83,8 @@ def fla_cached_delta_rule( if num_prefill > 0: initial_states = None - if torch.any(use_initial_states[:num_prefill]): + # Use 
precomputed host flag to avoid GPU->CPU sync from torch.any() + if any_prefill_use_initial_states_host.item(): initial_states = torch.where( use_initial_states[:num_prefill, None, None, None], delta_cache[slot_idx[:num_prefill]], @@ -138,6 +140,7 @@ def fla_cached_delta_rule_fake( cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, + any_prefill_use_initial_states_host: torch.Tensor, # EXTRA METADATA # # CACHES @@ -169,7 +172,13 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] + return [ + "batch_info_host", + "cu_seqlen", + "slot_idx", + "use_initial_states", + "any_prefill_use_initial_states_host", + ] @classmethod def get_cache_initializers( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py index a0d635828c1..a514767736e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py @@ -18,17 +18,24 @@ Gated Delta Rule is based on this paper: https://arxiv.org/abs/2412.06464 Kernels are based on this repo: https://github.com/fla-org/flash-linear-attention + +This op accepts raw (un-normalized, un-expanded) q/k and raw gating projections +(a, b) together with per-head parameters (A_log, dt_bias). 
L2 normalization, +GQA repeat-interleave, and gating computation are performed internally: + - Decode: fully fused in fused_sigmoid_gating_delta_rule_update (L2 norm, GQA, gating) + - Prefill: explicit repeat-interleave + chunk_gated_delta_rule(use_qk_l2norm_in_kernel=True) """ from typing import List import torch +import torch.nn.functional as F from torch._ops import OpOverloadPacket from torch.fx import Node from .....llmapi.llm_args import KvCacheConfig from ....modules.fla.chunk import chunk_gated_delta_rule -from ....modules.fla.fused_recurrent import fused_recurrent_gated_delta_rule_update_fwd +from ....modules.fla.fused_sigmoid_gating_recurrent import fused_sigmoid_gating_delta_rule_update from ...utils.node_utils import extract_op_args from ..attention_interface import ( AttentionDescriptor, @@ -43,87 +50,115 @@ @torch.library.custom_op("auto_deploy::fla_cached_gated_delta_rule", mutates_args=("delta_cache",)) def fla_cached_gated_delta_rule( - # INPUTS (dense but may be flattened across sequences) + # INPUTS (raw, un-normalized, un-expanded) q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, # STANDARD METADATA batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, - # EXTRA METADATA - # + any_prefill_use_initial_states_host: torch.Tensor, # CACHES - delta_cache: torch.Tensor, # [max_batch_size, H, K, V] + delta_cache: torch.Tensor, # [max_batch_size, HV, K, V] # CONSTANTS scale: float, ) -> torch.Tensor: - b, s, num_heads, _ = q.shape + bsz, s, H_k, K = q.shape + HV = v.shape[2] + interleave = HV // H_k - # flatten batch and sequence dims - q_flat = q.view(b * s, num_heads, -1) - k_flat = k.view(b * s, num_heads, -1) - v_flat = v.view(b * s, num_heads, -1) - g_flat = g.view(b * s, num_heads) - beta_flat = beta.view(b * s, num_heads) + # Flatten batch and sequence dims + 
q_flat = q.view(bsz * s, H_k, K) + k_flat = k.view(bsz * s, H_k, K) + v_flat = v.view(bsz * s, HV, -1) + a_flat = a.view(bsz * s, HV) + b_flat = b.view(bsz * s, HV) - # pre-allocate output y = torch.empty_like(v, memory_format=torch.contiguous_format) - y_flat = y.view(b * s, num_heads, -1) + y_flat = y.view(bsz * s, HV, -1) num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode - # clean up metadata cu_seqlen_prefill = cu_seqlen[: num_prefill + 1] slot_idx = slot_idx[:num_seq].to(torch.long) use_initial_states = use_initial_states[:num_seq] if num_prefill > 0: initial_states = None - if torch.any(use_initial_states[:num_prefill]): + # Use precomputed host flag to avoid GPU->CPU sync from torch.any() + if any_prefill_use_initial_states_host.item(): initial_states = torch.where( use_initial_states[:num_prefill, None, None, None], delta_cache[slot_idx[:num_prefill]], 0, ) + q_pf = q_flat[None, :num_prefill_tokens] + k_pf = k_flat[None, :num_prefill_tokens] + v_pf = v_flat[None, :num_prefill_tokens] + a_pf = a_flat[None, :num_prefill_tokens] + b_pf = b_flat[None, :num_prefill_tokens] + + # GQA expand for chunk kernel (it does not handle H != HV natively) + if interleave > 1: + q_pf = q_pf.repeat_interleave(interleave, dim=2) + k_pf = k_pf.repeat_interleave(interleave, dim=2) + + # Compute g and beta from raw parameters + g_pf = -A_log.float().exp() * F.softplus(a_pf.float() + dt_bias) + beta_pf = b_pf.float().sigmoid() + y_prefill, final_state = chunk_gated_delta_rule( - q=q_flat[None, :num_prefill_tokens], - k=k_flat[None, :num_prefill_tokens], - v=v_flat[None, :num_prefill_tokens], - g=g_flat[None, :num_prefill_tokens], - beta=beta_flat[None, :num_prefill_tokens], + q=q_pf, + k=k_pf, + v=v_pf, + g=g_pf, + beta=beta_pf, scale=scale, initial_state=initial_states, output_final_state=True, cu_seqlens=cu_seqlen_prefill, + use_qk_l2norm_in_kernel=True, ) y_flat[None, :num_prefill_tokens] = y_prefill.to(y_flat.dtype) 
delta_cache.index_copy_(0, slot_idx[:num_prefill], final_state.to(delta_cache.dtype)) - del y_prefill, initial_states, final_state if num_decode > 0: cu_seqlen_decode = torch.arange(0, num_decode + 1, device=q.device, dtype=torch.long) - y_decode = fused_recurrent_gated_delta_rule_update_fwd( - q=q_flat[None, num_prefill_tokens:].contiguous(), - k=k_flat[None, num_prefill_tokens:].contiguous(), - v=v_flat[None, num_prefill_tokens:].contiguous(), - g=g_flat[None, num_prefill_tokens:].contiguous(), - beta=beta_flat[None, num_prefill_tokens:].contiguous(), - scale=scale, + + q_dec = q_flat[None, num_prefill_tokens:].contiguous() + k_dec = k_flat[None, num_prefill_tokens:].contiguous() + v_dec = v_flat[None, num_prefill_tokens:].contiguous() + a_dec = a_flat[None, num_prefill_tokens:].contiguous() + b_dec = b_flat[None, num_prefill_tokens:].contiguous() + + y_decode = fused_sigmoid_gating_delta_rule_update( + A_log=A_log, + a=a_dec, + dt_bias=dt_bias, + softplus_beta=1.0, + softplus_threshold=20.0, + q=q_dec, + k=k_dec, + v=v_dec, + b=b_dec, initial_state_source=delta_cache, initial_state_indices=slot_idx[num_prefill:].contiguous(), + scale=scale, + use_qk_l2norm_in_kernel=True, cu_seqlens=cu_seqlen_decode, ) y_flat[None, num_prefill_tokens:] = y_decode.to(y_flat.dtype) - del y_decode return y @@ -131,22 +166,19 @@ def fla_cached_gated_delta_rule( @fla_cached_gated_delta_rule.register_fake def fla_cached_gated_delta_rule_fake( - # INPUTS (dense but may be flattened across sequences) q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - # STANDARD METADATA + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, - # EXTRA METADATA - # - # CACHES - delta_cache: torch.Tensor, # [max_batch_size, H, K, V] - # CONSTANTS + any_prefill_use_initial_states_host: torch.Tensor, + delta_cache: 
torch.Tensor, scale: float, ) -> torch.Tensor: return torch.empty_like(v) @@ -160,8 +192,8 @@ def get_attention_layout(cls) -> AttentionLayout: @classmethod def get_num_qkv_args(cls) -> int: - # q, k, v, g, beta - return 5 + # q, k, v, a, b, A_log, dt_bias + return 7 @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: @@ -173,7 +205,13 @@ def get_cached_attention_op(cls) -> MHACallable: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info_host", "cu_seqlen", "slot_idx", "use_initial_states"] + return [ + "batch_info_host", + "cu_seqlen", + "slot_idx", + "use_initial_states", + "any_prefill_use_initial_states_host", + ] @classmethod def get_cache_initializers( @@ -181,18 +219,22 @@ def get_cache_initializers( ) -> ResourceHandlerDict: key_node = source_attn_node.args[1] value_node = source_attn_node.args[2] - num_heads = key_node.meta["val"].shape[-2] + # Cache shape is [max_batch_size, HV, K, V] where HV = num_v_heads (state per value-head). + # With GVA, q/k may have fewer heads (H_k) than v (HV), so read num_heads from value_node. + num_heads = value_node.meta["val"].shape[-2] key_dim = key_node.meta["val"].shape[-1] value_dim = value_node.meta["val"].shape[-1] - key_dtype = key_node.meta["val"].dtype return { "delta_cache": StateResourceHandler( num_heads, key_dim, value_dim, - # NOTE: not configurable at the moment, using auto to match the key dtype - dtype=cls.resolve_cache_dtype("auto", key_dtype), + # GDN state is a running recurrence (unlike KV caches which store + # independent per-token values). Bfloat16 quantization errors + # compound at every decode step through the recurrence update, so + # we always use float32 to preserve accuracy over long sequences. 
+ dtype=torch.float32, ) } diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py index ac017421b8b..75db3d4ed45 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_gated_delta.py @@ -19,7 +19,9 @@ Basic: S = S + k * (v - S*k) * beta Gated: S = S * exp(g) + k * (v - S*k) * beta -This op is used by Qwen3Next's GatedDeltaNet layers. +This op accepts raw (un-normalized, un-expanded) q/k and raw gating projections +(a, b) together with the per-head parameters (A_log, dt_bias). L2 normalization, +GQA repeat-interleave, and gating computation are all performed internally. Reference: - HF transformers v4.57.1 `torch_chunk_gated_delta_rule`: @@ -33,6 +35,16 @@ import torch.nn.functional as F +def _l2norm(x: torch.Tensor, dim: int = -1, eps: float = 1e-6) -> torch.Tensor: + """L2 normalization matching the HF/FLA convention. + + Uses ``rsqrt(sum(x^2) + eps)`` rather than ``x / max(||x||, eps)`` + (the ``F.normalize`` convention). The difference matters for small-norm + vectors because eps is added *inside* the square root here. + """ + return x * torch.rsqrt((x * x).sum(dim=dim, keepdim=True) + eps) + + def _torch_chunk_gated_delta_rule_impl( query: torch.Tensor, key: torch.Tensor, @@ -47,8 +59,8 @@ def _torch_chunk_gated_delta_rule_impl( Adapted from HF transformers v4.57.1 modeling_qwen3_next.py `torch_chunk_gated_delta_rule`. 
Args: - query: [B, H, S, K] - query states (already l2-normalized externally) - key: [B, H, S, K] - key states (already l2-normalized externally) + query: [B, H, S, K] - query states (l2-normalized, GQA-expanded) + key: [B, H, S, K] - key states (l2-normalized, GQA-expanded) value: [B, H, S, V] - value states g: [B, H, S] - gating/decay values (negative log-space) beta: [B, H, S] - beta scaling values (sigmoid-activated) @@ -70,7 +82,6 @@ def _torch_chunk_gated_delta_rule_impl( scale = 1.0 / (k_head_dim**0.5) query = query * scale - # Pad sequence to be divisible by chunk_size pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size query = F.pad(query, (0, 0, 0, pad_size)) key = F.pad(key, (0, 0, 0, pad_size)) @@ -82,7 +93,6 @@ def _torch_chunk_gated_delta_rule_impl( v_beta = value * beta.unsqueeze(-1) k_beta = key * beta.unsqueeze(-1) - # Reshape to chunks: [B, H, num_chunks, chunk_size, D] query, key, value, k_beta, v_beta = [ x.reshape(x.shape[0], x.shape[1], -1, chunk_size, x.shape[-1]) for x in (query, key, value, k_beta, v_beta) @@ -92,7 +102,6 @@ def _torch_chunk_gated_delta_rule_impl( torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=0 ) - # Chunk decay g = g.cumsum(dim=-1) decay_mask = ((g.unsqueeze(-1) - g.unsqueeze(-2)).tril().exp().float()).tril() attn = -((k_beta @ key.transpose(-1, -2)) * decay_mask).masked_fill(mask, 0) @@ -110,7 +119,6 @@ def _torch_chunk_gated_delta_rule_impl( torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=query.device), diagonal=1 ) - # Process each chunk recurrently for i in range(0, total_sequence_length // chunk_size): q_i, k_i, v_i = query[:, :, i], key[:, :, i], value[:, :, i] attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) @@ -123,7 +131,6 @@ def _torch_chunk_gated_delta_rule_impl( + (k_i * (g[:, :, i, -1, None] - g[:, :, i]).exp()[..., None]).transpose(-1, -2) @ v_new ) - # Remove padding and reshape back core_attn_out = 
core_attn_out.reshape( core_attn_out.shape[0], core_attn_out.shape[1], -1, core_attn_out.shape[-1] ) @@ -136,35 +143,56 @@ def torch_gated_delta_rule( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, scale: Optional[float] = None, ) -> torch.Tensor: """Gated Delta Rule custom op for linear attention (torch reference implementation). - All inputs use the autodeploy [B, S, H, D] (bsnd) layout convention. + Performs L2 normalization, GQA repeat-interleave, gating computation, and the + gated delta rule recurrence internally. All inputs use the autodeploy [B, S, H, D] + (bsnd) layout convention. Args: - q: [B, S, H, K] - query states (should be l2-normalized before calling) - k: [B, S, H, K] - key states (should be l2-normalized before calling) - v: [B, S, H, V] - value states - g: [B, S, H] - gating/decay values - beta: [B, S, H] - beta scaling values - scale: optional query scaling factor (defaults to K^-0.5) + q: [B, S, H_k, K] - raw query states (un-normalized, un-expanded) + k: [B, S, H_k, K] - raw key states (un-normalized, un-expanded) + v: [B, S, HV, V] - value states + a: [B, S, HV] - raw gating projection (before softplus) + b: [B, S, HV] - raw beta projection (before sigmoid) + A_log: [HV] - log of decay base per value head + dt_bias: [HV] - bias added to gating projection + scale: optional query scaling factor (defaults to K^-0.5) Returns: - output: [B, S, H, V] + output: [B, S, HV, V] """ + H_k = q.shape[2] + HV = v.shape[2] + + # L2 normalize q and k (must match HF/FLA l2norm convention) + q_norm = _l2norm(q.float()).to(q.dtype) + k_norm = _l2norm(k.float()).to(k.dtype) + + # GQA expand if num_v_heads > num_k_heads + if HV > H_k: + q_norm = q_norm.repeat_interleave(HV // H_k, dim=2) + k_norm = k_norm.repeat_interleave(HV // H_k, dim=2) + + # Compute gating: g = -exp(A_log) * softplus(a + dt_bias) + g = -A_log.float().exp() * 
F.softplus(a.float() + dt_bias) + beta = b.float().sigmoid() + # Transpose from bsnd -> bhsd for internal computation - q_t = q.transpose(1, 2) - k_t = k.transpose(1, 2) + q_t = q_norm.transpose(1, 2) + k_t = k_norm.transpose(1, 2) v_t = v.transpose(1, 2) g_t = g.transpose(1, 2) beta_t = beta.transpose(1, 2) out = _torch_chunk_gated_delta_rule_impl(q_t, k_t, v_t, g_t, beta_t, scale=scale) - # Transpose back from bhsd -> bsnd return out.transpose(1, 2).contiguous() @@ -173,8 +201,11 @@ def torch_gated_delta_rule_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, scale: Optional[float] = None, ) -> torch.Tensor: + # Output shape is [B, S, H, V] matching v (not q/k which may have fewer heads) return torch.empty_like(v) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/gdn_gating.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/gdn_gating.py new file mode 100644 index 00000000000..bea8eaea22e --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/gdn_gating.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Custom ops for fused GDN gating computation. 
+ +Computes g = -exp(A_log) * softplus(a + dt_bias) in a single kernel, +collapsing 5-7 separate kernel launches into one. + +Two ops are provided: +- torch_fused_gdn_gating: pure-torch source op (used in model forward) +- triton_fused_gdn_gating: Triton kernel op (swapped in via fusion transform) +""" + +import torch +import torch.nn.functional as F +import triton +import triton.language as tl + + +# --------------------------------------------------------------------------- +# Triton kernel (adapted from tensorrt_llm/_torch/models/modeling_qwen3_next.py) +# --------------------------------------------------------------------------- +@triton.jit +def _fused_gdn_gating_kernel( + g_ptr, + A_log_ptr, + a_ptr, + dt_bias_ptr, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + """Triton kernel that computes g = -exp(A_log) * softplus(a + dt_bias). + + Grid: (batch, seq_len, cdiv(NUM_HEADS, BLK_HEADS)) + """ + i_b = tl.program_id(0) + i_s = tl.program_id(1) + i_d = tl.program_id(2) + + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + + blk_A_log = tl.load(A_log_ptr + head_off, mask=mask) + blk_a = tl.load(a_ptr + off, mask=mask) + blk_bias = tl.load(dt_bias_ptr + head_off, mask=mask) + + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where( + beta * x <= threshold, + (1 / beta) * tl.log(1 + tl.exp(beta * x)), + x, + ) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g_ptr + off, blk_g.to(g_ptr.dtype.element_ty), mask=mask) + + +# --------------------------------------------------------------------------- +# torch source op (used in model forward, later replaced by fusion transform) +# --------------------------------------------------------------------------- +@torch.library.custom_op("auto_deploy::torch_fused_gdn_gating", mutates_args=()) +def 
torch_fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + """Pure-torch fused GDN gating: g = -exp(A_log) * softplus(a + dt_bias). + + Args: + A_log: [H] - log of the decay parameter + a: [B, S, H] - gating activation + dt_bias: [H] - bias added before softplus + beta: softplus beta parameter (default 1.0) + threshold: softplus threshold for numerical stability (default 20.0) + + Returns: + g: [B, S, H] in float32 + """ + g = -torch.exp(A_log.float()) * F.softplus(a.float() + dt_bias.float(), beta, threshold) + return g + + +@torch_fused_gdn_gating.register_fake +def _torch_fused_gdn_gating_fake( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + """Fake implementation for torch.compile / export shape propagation. + + Returns: + g: [B, S, H] in float32 (same shape as a, always float32) + """ + return torch.empty_like(a, dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# Triton fused op (swapped in by FuseGdnGating transform) +# --------------------------------------------------------------------------- +@torch.library.custom_op("auto_deploy::triton_fused_gdn_gating", mutates_args=()) +def triton_fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + """Triton-fused GDN gating: g = -exp(A_log) * softplus(a + dt_bias). + + Handles both 2D [B*S, H] and 3D [B, S, H] inputs for ``a``. 
+ + Args: + A_log: [H] - log of the decay parameter + a: [B, S, H] - gating activation (3D) + dt_bias: [H] - bias added before softplus + beta: softplus beta parameter (default 1.0) + threshold: softplus threshold for numerical stability (default 20.0) + + Returns: + g: [B, S, H] in float32 + """ + orig_shape = a.shape + if a.dim() == 2: + # 2D input: treat as [B*S, 1, H] + batch_size = a.shape[0] + seq_len = 1 + num_heads = a.shape[1] + a_flat = a.contiguous() + else: + batch_size, seq_len, num_heads = a.shape + a_flat = a.reshape(batch_size * seq_len, num_heads).contiguous() + + g = torch.empty(batch_size * seq_len, num_heads, device=a.device, dtype=torch.float32) + + BLK_HEADS = 8 + grid = (batch_size, seq_len, triton.cdiv(num_heads, BLK_HEADS)) + + _fused_gdn_gating_kernel[grid]( + g, + A_log, + a_flat, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + BLK_HEADS, + num_warps=1, + ) + + return g.reshape(orig_shape[:-1] + (num_heads,)) + + +@triton_fused_gdn_gating.register_fake +def _triton_fused_gdn_gating_fake( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, +) -> torch.Tensor: + """Fake implementation for torch.compile / export shape propagation. 
+ + Returns: + g: same shape as ``a``, in float32 + """ + return torch.empty_like(a, dtype=torch.float32) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py index 3f932795ca5..02c0245bd04 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py @@ -18,6 +18,10 @@ The Gated Delta Rule extends the basic Delta Rule with an exponential decay gate ``g``: S = S * exp(g) + k * (v - S^T @ k) * beta +This op accepts raw (un-normalized, un-expanded) q/k and raw gating projections +(a, b) together with per-head parameters (A_log, dt_bias). L2 normalization, +GQA repeat-interleave, and gating computation are performed internally. + This module provides: - ``_torch_gated_delta_step``: single-token recurrence (decode) - ``_torch_gated_delta_prefill``: loop-based prefill over the sequence dimension @@ -31,6 +35,7 @@ from typing import List, Tuple import torch +import torch.nn.functional as F from torch._ops import OpOverloadPacket from torch.fx import Node @@ -45,6 +50,7 @@ ResourceHandlerDict, StateResourceHandler, ) +from .fla_gated_delta import _l2norm # --------------------------------------------------------------------------- # Core recurrence helpers @@ -65,8 +71,8 @@ def _torch_gated_delta_step( All computation is performed in float32 for numerical stability. 
Args: - q: [B, H, K] - k: [B, H, K] + q: [B, H, K] (already l2-normalized and GQA-expanded) + k: [B, H, K] (already l2-normalized and GQA-expanded) v: [B, H, V] g: [B, H] gating / decay values (negative, log-space) beta: [B, H] beta scaling values @@ -84,18 +90,11 @@ def _torch_gated_delta_step( beta = beta.float() state = state.float() - # Apply decay gate to state - state = state * torch.exp(g[..., None, None]) # [B, H, K, V] - - # Delta update: v' = v - S^T @ k - v_prime = v - torch.einsum("bhk,bhkv->bhv", k, state) # [B, H, V] - v_prime = v_prime * beta[..., None] # [B, H, V] - - # Update state: S = S + k outer v' - state = state + torch.einsum("bhk,bhv->bhkv", k, v_prime) # [B, H, K, V] - - # Output: o = (q * scale) @ S - output = torch.einsum("bhk,bhkv->bhv", q * scale, state) # [B, H, V] + state = state * torch.exp(g[..., None, None]) + v_prime = v - torch.einsum("bhk,bhkv->bhv", k, state) + v_prime = v_prime * beta[..., None] + state = state + torch.einsum("bhk,bhv->bhkv", k, v_prime) + output = torch.einsum("bhk,bhkv->bhv", q * scale, state) return output, state @@ -114,7 +113,7 @@ def _torch_gated_delta_prefill( Iterates ``_torch_gated_delta_step`` over the sequence dimension. 
Args: - q: [B, S, H, K] (bsnd layout) + q: [B, S, H, K] (bsnd layout, already l2-normalized and GQA-expanded) k: [B, S, H, K] v: [B, S, H, V] g: [B, S, H] @@ -131,23 +130,68 @@ def _torch_gated_delta_prefill( outputs = [] for t in range(S): - # Slice at time t from bsnd: q[:, t] gives [B, H, K] o_t, state = _torch_gated_delta_step( - q[:, t], # [B, H, K] - k[:, t], # [B, H, K] - v[:, t], # [B, H, V] - g[:, t], # [B, H] - beta[:, t], # [B, H] - state, # [B, H, K, V] + q[:, t], + k[:, t], + v[:, t], + g[:, t], + beta[:, t], + state, scale, ) - outputs.append(o_t) # [B, H, V] + outputs.append(o_t) - # Stack along seq dim: [B, S, H, V] output = torch.stack(outputs, dim=1) return output, state +# --------------------------------------------------------------------------- +# Preprocessing: L2 norm + GQA expand + gating computation +# --------------------------------------------------------------------------- + + +def _preprocess_raw_inputs( + q: torch.Tensor, + k: torch.Tensor, + a: torch.Tensor, + b_proj: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + HV: int, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Apply L2 normalization, GQA expansion, and gating computation. 
+ + Args: + q: [..., H_k, K] + k: [..., H_k, K] + a: [..., HV] raw gating projection + b_proj: [..., HV] raw beta projection + A_log: [HV] log of decay base + dt_bias: [HV] gating bias + HV: number of value heads + + Returns: + q_out: [..., HV, K] (l2-normed, expanded) + k_out: [..., HV, K] (l2-normed, expanded) + g: [..., HV] (decay gate in log-space) + beta: [..., HV] (sigmoid-activated scaling) + """ + H_k = q.shape[-2] + interleave = HV // H_k + + q_out = _l2norm(q.float()).to(q.dtype) + k_out = _l2norm(k.float()).to(k.dtype) + + if interleave > 1: + q_out = q_out.repeat_interleave(interleave, dim=-2) + k_out = k_out.repeat_interleave(interleave, dim=-2) + + g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) + beta = b_proj.float().sigmoid() + + return q_out, k_out, g, beta + + # --------------------------------------------------------------------------- # Cached custom op # --------------------------------------------------------------------------- @@ -157,62 +201,48 @@ def _torch_gated_delta_prefill( "auto_deploy::torch_cached_gated_delta_rule", mutates_args=("delta_cache",) ) def torch_cached_gated_delta_rule( - # INPUTS (dense but may be flattened across sequences) + # INPUTS (raw, un-normalized, un-expanded) q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, # STANDARD METADATA batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, # CACHES - delta_cache: torch.Tensor, # [max_batch_size, H, K, V] + delta_cache: torch.Tensor, # [max_batch_size, HV, K, V] # CONSTANTS scale: float, ) -> torch.Tensor: """Cached gated delta rule using pure-torch recurrence. Handles mixed prefill + decode batches. Inputs use the autodeploy bsnd layout. 
- - Args: - q: [B, S, H, K] - k: [B, S, H, K] - v: [B, S, H, V] - g: [B, S, H] - beta: [B, S, H] - batch_info_host: [num_prefill, num_prefill_tokens, num_decode] on host - cu_seqlen: cumulative sequence lengths for prefill sequences - slot_idx: per-sequence slot indices into delta_cache - use_initial_states: per-sequence bool (True if cache history exists) - delta_cache: [max_slots, H, K, V] recurrent state cache - scale: query scaling factor - - Returns: - output: [B, S, H, V] + L2 normalization, GQA expansion, and gating (g/beta) are computed internally. """ - b, s, num_heads, _ = q.shape + bsz, s, H_k, _ = q.shape + HV = v.shape[2] - # Pre-allocate output y = torch.empty_like(v, memory_format=torch.contiguous_format) num_prefill, num_prefill_tokens, num_decode = batch_info_host.tolist() num_seq = num_prefill + num_decode - # Clean up metadata cu_seqlen_prefill = cu_seqlen[: num_prefill + 1] slot_idx = slot_idx[:num_seq].to(torch.long) use_initial_states = use_initial_states[:num_seq] - # Flatten for indexing: [B*S, H, D] - q_flat = q.reshape(b * s, num_heads, -1) - k_flat = k.reshape(b * s, num_heads, -1) - v_flat = v.reshape(b * s, num_heads, -1) - g_flat = g.reshape(b * s, num_heads) - beta_flat = beta.reshape(b * s, num_heads) - y_flat = y.reshape(b * s, num_heads, -1) + # Flatten for indexing: [B*S, ...] 
+ q_flat = q.reshape(bsz * s, H_k, -1) + k_flat = k.reshape(bsz * s, H_k, -1) + v_flat = v.reshape(bsz * s, HV, -1) + a_flat = a.reshape(bsz * s, HV) + b_flat = b.reshape(bsz * s, HV) + y_flat = y.reshape(bsz * s, HV, -1) key_dim = q.shape[-1] value_dim = v.shape[-1] @@ -224,20 +254,22 @@ def torch_cached_gated_delta_rule( end = cu_seqlen_prefill[seq_idx + 1].item() slot = slot_idx[seq_idx] - # Gather per-sequence tensors: [1, seq_len, H, D] - q_seq = q_flat[start:end].unsqueeze(0) # [1, S, H, K] - k_seq = k_flat[start:end].unsqueeze(0) # [1, S, H, K] - v_seq = v_flat[start:end].unsqueeze(0) # [1, S, H, V] - g_seq = g_flat[start:end].unsqueeze(0) # [1, S, H] - beta_seq = beta_flat[start:end].unsqueeze(0) # [1, S, H] + q_seq = q_flat[start:end].unsqueeze(0) # [1, S_i, H_k, K] + k_seq = k_flat[start:end].unsqueeze(0) # [1, S_i, H_k, K] + v_seq = v_flat[start:end].unsqueeze(0) # [1, S_i, HV, V] + a_seq = a_flat[start:end].unsqueeze(0) # [1, S_i, HV] + b_seq = b_flat[start:end].unsqueeze(0) # [1, S_i, HV] + + q_proc, k_proc, g_seq, beta_seq = _preprocess_raw_inputs( + q_seq, k_seq, a_seq, b_seq, A_log, dt_bias, HV + ) - # Initial state for this sequence if use_initial_states[seq_idx]: - init_state = delta_cache[slot].unsqueeze(0).clone() # [1, H, K, V] + init_state = delta_cache[slot].unsqueeze(0).clone() else: init_state = torch.zeros( 1, - num_heads, + HV, key_dim, value_dim, dtype=torch.float32, @@ -245,8 +277,8 @@ def torch_cached_gated_delta_rule( ) y_seq, final_state = _torch_gated_delta_prefill( - q_seq, - k_seq, + q_proc, + k_proc, v_seq, g_seq, beta_seq, @@ -254,10 +286,7 @@ def torch_cached_gated_delta_rule( init_state, ) - # Write output y_flat[start:end] = y_seq.squeeze(0).to(y_flat.dtype) - - # Write final state back to cache delta_cache[slot] = final_state.squeeze(0).to(delta_cache.dtype) # ---- DECODE ---- @@ -267,30 +296,29 @@ def torch_cached_gated_delta_rule( seq_idx = num_prefill + i slot = slot_idx[seq_idx] - # Single token: [H, D] - q_tok = 
q_flat[token_idx] # [H, K] - k_tok = k_flat[token_idx] # [H, K] - v_tok = v_flat[token_idx] # [H, V] - g_tok = g_flat[token_idx] # [H] - beta_tok = beta_flat[token_idx] # [H] + q_tok = q_flat[token_idx].unsqueeze(0) # [1, H_k, K] + k_tok = k_flat[token_idx].unsqueeze(0) # [1, H_k, K] + v_tok = v_flat[token_idx].unsqueeze(0) # [1, HV, V] + a_tok = a_flat[token_idx].unsqueeze(0) # [1, HV] + b_tok = b_flat[token_idx].unsqueeze(0) # [1, HV] + + q_proc, k_proc, g_tok, beta_tok = _preprocess_raw_inputs( + q_tok, k_tok, a_tok, b_tok, A_log, dt_bias, HV + ) - # Load state from cache - state = delta_cache[slot].unsqueeze(0).clone() # [1, H, K, V] + state = delta_cache[slot].unsqueeze(0).clone() o_tok, new_state = _torch_gated_delta_step( - q_tok.unsqueeze(0), # [1, H, K] - k_tok.unsqueeze(0), # [1, H, K] - v_tok.unsqueeze(0), # [1, H, V] - g_tok.unsqueeze(0), # [1, H] - beta_tok.unsqueeze(0), # [1, H] - state, # [1, H, K, V] + q_proc, + k_proc, + v_tok, + g_tok, + beta_tok, + state, scale, ) - # Write output y_flat[token_idx] = o_tok.squeeze(0).to(y_flat.dtype) - - # Write state back to cache delta_cache[slot] = new_state.squeeze(0).to(delta_cache.dtype) return y @@ -301,8 +329,10 @@ def torch_cached_gated_delta_rule_fake( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, batch_info_host: torch.Tensor, cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, @@ -331,8 +361,8 @@ def get_attention_layout(cls) -> AttentionLayout: @classmethod def get_num_qkv_args(cls) -> int: - # q, k, v, g, beta - return 5 + # q, k, v, a, b, A_log, dt_bias + return 7 @classmethod def get_source_attention_op(cls) -> OpOverloadPacket: @@ -352,7 +382,9 @@ def get_cache_initializers( ) -> ResourceHandlerDict: key_node = source_attn_node.args[1] value_node = source_attn_node.args[2] - num_heads = key_node.meta["val"].shape[-2] + # Cache shape is [max_batch_size, HV, K, V] where 
HV = num_v_heads (state per value-head). + # With GVA, q/k may have fewer heads (H_k) than v (HV), so read num_heads from value_node. + num_heads = value_node.meta["val"].shape[-2] key_dim = key_node.meta["val"].shape[-1] value_dim = value_node.meta["val"].shape[-1] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/benchmark_routing.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/benchmark_routing.py new file mode 100644 index 00000000000..3781ee8198a --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/benchmark_routing.py @@ -0,0 +1,195 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Benchmark: Fused Triton top-k + softmax routing vs. baseline PyTorch. + +Compares the original 3-op MoE routing pattern used in Qwen3.5 + (softmax -> topk -> renormalize) +against the fused Triton kernel that exploits the equivalence + topk(logits) -> softmax(topk_logits) + +Usage (standalone, avoids heavy tensorrt_llm imports): + python tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/benchmark_routing.py +""" + +import os +import sys + +import torch +import torch.nn.functional as F + +# Allow running as a standalone script without triggering the full +# tensorrt_llm import chain. We import only the triton_routing module. 
+_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +if _THIS_DIR not in sys.path: + sys.path.insert(0, _THIS_DIR) + +from triton_routing import triton_fused_topk_softmax_fn # noqa: E402 + +# ============================================================================ +# Baseline: 3-op PyTorch implementation (softmax -> topk -> renormalize) +# ============================================================================ + + +def baseline_routing( + router_logits: torch.Tensor, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Original Qwen3.5 MoE routing: softmax -> topk -> renormalize.""" + routing_weights = F.softmax(router_logits, dtype=torch.float, dim=-1) + routing_weights, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True) + return routing_weights, selected_experts + + +# ============================================================================ +# Correctness check +# ============================================================================ + + +def check_correctness( + router_logits: torch.Tensor, + top_k: int, + atol: float = 1e-5, + rtol: float = 1e-4, +) -> bool: + """Verify fused vs baseline. NOTE(review): ref weights are fp32 but fused weights match the input dtype; torch.allclose rejects mixed dtypes — cast fused to float32 (and loosen atol/rtol for bf16/fp16, whose ulp far exceeds 1e-5) before comparing; confirm on target torch version.""" + ref_weights, ref_indices = baseline_routing(router_logits, top_k) + fused_weights, fused_indices = triton_fused_topk_softmax_fn(router_logits, top_k) + + # Sort both by expert index within each token so order doesn't matter + ref_sort = ref_indices.sort(dim=-1) + fused_sort = fused_indices.sort(dim=-1) + + ref_weights_sorted = ref_weights.gather(-1, ref_sort.indices) + fused_weights_sorted = fused_weights.gather(-1, fused_sort.indices) + + indices_match = torch.equal(ref_sort.values.to(torch.int32), fused_sort.values) + weights_close = torch.allclose(ref_weights_sorted, fused_weights_sorted, atol=atol, rtol=rtol) + + if not indices_match: + mismatched = (ref_sort.values.to(torch.int32) != fused_sort.values).sum().item() + 
total = ref_sort.values.numel() + print(f" WARNING: Index mismatch in {mismatched}/{total} elements") + if not weights_close: + max_diff = (ref_weights_sorted - fused_weights_sorted).abs().max().item() + print(f" WARNING: Weight mismatch, max diff = {max_diff:.6e}") + + return indices_match and weights_close + + +# ============================================================================ +# Timing utilities +# ============================================================================ + + +def benchmark_fn( + fn, + *args, + warmup: int = 50, + iters: int = 200, +) -> float: + """Benchmark a GPU function, return median time in microseconds.""" + # Warmup + for _ in range(warmup): + fn(*args) + torch.cuda.synchronize() + + # Timed iterations using CUDA events + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(iters)] + + for i in range(iters): + start_events[i].record() + fn(*args) + end_events[i].record() + + torch.cuda.synchronize() + + times_ms = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + times_ms.sort() + # Return median in microseconds + median_ms = times_ms[len(times_ms) // 2] + return median_ms * 1000.0 + + +# ============================================================================ +# Main benchmark +# ============================================================================ + + +def run_benchmark(): + device = torch.device("cuda") + + # Qwen3.5 MoE parameters + num_experts = 256 + top_k = 8 + + token_counts = [1, 32, 128, 512, 1024, 4096] + dtypes = [torch.bfloat16, torch.float16] + + print("=" * 80) + print("MoE Routing Kernel Benchmark: Baseline (3-op) vs Fused Triton") + print(f" num_experts = {num_experts}, top_k = {top_k}") + print("=" * 80) + + # Run correctness check first + print("\n--- Correctness Checks ---") + all_correct = True + for dtype in dtypes: + for num_tokens in token_counts: + router_logits = 
torch.randn(num_tokens, num_experts, dtype=dtype, device=device) + dtype_str = str(dtype).replace("torch.", "") + passed = check_correctness(router_logits, top_k) + status = "PASS" if passed else "FAIL" + print(f" {dtype_str:>8s} tokens={num_tokens:<6d} {status}") + all_correct = all_correct and passed + + if not all_correct: + print("\nWARNING: Some correctness checks failed. See details above.") + else: + print("\nAll correctness checks passed.") + + # Performance benchmark + print("\n--- Performance ---") + header = ( + f"{'dtype':>8s} {'tokens':>8s} {'baseline(us)':>14s} {'fused(us)':>14s} {'speedup':>8s}" + ) + print(header) + print("-" * len(header)) + + for dtype in dtypes: + for num_tokens in token_counts: + router_logits = torch.randn(num_tokens, num_experts, dtype=dtype, device=device) + + # Benchmark baseline + t_baseline = benchmark_fn(baseline_routing, router_logits, top_k) + + # Benchmark fused Triton kernel + t_fused = benchmark_fn(triton_fused_topk_softmax_fn, router_logits, top_k) + + speedup = t_baseline / t_fused if t_fused > 0 else float("inf") + dtype_str = str(dtype).replace("torch.", "") + print( + f"{dtype_str:>8s} {num_tokens:>8d} {t_baseline:>14.2f} " + f"{t_fused:>14.2f} {speedup:>7.2f}x" + ) + + print("=" * 80) + + +if __name__ == "__main__": + run_benchmark() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_routing.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_routing.py new file mode 100644 index 00000000000..9e8b987d1c0 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/fused_moe/triton_routing.py @@ -0,0 +1,214 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Triton kernel for MoE top-k routing with softmax. + +Leverages the mathematical equivalence: + renormalize(topk(softmax(x))) ≡ softmax(topk(x)) + +Instead of computing softmax over ALL experts (e.g. 256), then selecting top-k, +then renormalizing, this kernel: + 1. Finds top-k from raw logits (softmax is monotonic, preserves ordering) + 2. Computes softmax only over the k selected logits (e.g. k=8) + +This fuses three separate kernel launches into one and avoids intermediate +global memory traffic. +""" + +import math + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fused_topk_softmax_kernel( + logits_ptr, # Input: (T, E) router logits + weights_ptr, # Output: (T, K) routing weights (dtype of the output tensor; softmax computed in fp32, auto-cast on store) + indices_ptr, # Output: (T, K) expert indices (int32) + num_tokens, # number of tokens (T) + num_experts, # number of experts (E), e.g. 256 + stride_lt, # logits stride along token dim + stride_le, # logits stride along expert dim + stride_wt, # weights stride along token dim + stride_wk, # weights stride along topk dim + stride_it, # indices stride along token dim + stride_ik, # indices stride along topk dim + BLOCK_E: tl.constexpr, # >= num_experts, must be power of 2 + TOP_K: tl.constexpr, # number of top experts to select (any positive int) + BLOCK_K: tl.constexpr, # >= TOP_K, must be power of 2 (for Triton tensor ops) +): + """Fused top-k selection + softmax routing kernel. + + Each Triton program processes one token (row).
It loads all expert logits, + iteratively finds the top-k values/indices via repeated argmax, then + computes a numerically-stable softmax over only the k selected logits. + """ + # Each program handles one token + token_id = tl.program_id(0) + if token_id >= num_tokens: + return + + # Load all expert logits for this token into registers + offs_e = tl.arange(0, BLOCK_E) + mask_e = offs_e < num_experts + logits = tl.load( + logits_ptr + token_id * stride_lt + offs_e * stride_le, + mask=mask_e, + other=float("-inf"), + ).to(tl.float32) + + # --- Iterative top-k: find k largest values and their indices --- + # Allocate with BLOCK_K (power of 2) so Triton tensor ops work for any TOP_K. + # Unused padding slots stay at -inf → exp(-inf) = 0, so softmax is unaffected. + topk_vals = tl.full([BLOCK_K], float("-inf"), dtype=tl.float32) + topk_idxs = tl.zeros([BLOCK_K], dtype=tl.int32) + offs_k = tl.arange(0, BLOCK_K) + + for k_i in tl.static_range(TOP_K): + # Find the current maximum value across all experts + max_val = tl.max(logits, axis=0) + + # Find the index of the maximum (pick smallest index on ties) + is_max = logits == max_val + # For non-max positions, substitute a large index so tl.min ignores them + candidate = tl.where(is_max, offs_e, BLOCK_E) + max_idx = tl.min(candidate, axis=0) + + # Store into the k-th slot of our top-k arrays + ki_mask = offs_k == k_i + topk_vals = tl.where(ki_mask, max_val, topk_vals) + topk_idxs = tl.where(ki_mask, max_idx.to(tl.int32), topk_idxs) + + # Mask out the found maximum so it is not selected again + logits = tl.where(offs_e == max_idx, float("-inf"), logits) + + # --- Numerically-stable softmax over only the top-k values --- + max_topk = tl.max(topk_vals, axis=0) + exp_vals = tl.exp(topk_vals - max_topk) + sum_exp = tl.sum(exp_vals, axis=0) + softmax_vals = exp_vals / sum_exp + + # --- Store results (only the valid TOP_K entries, not the BLOCK_K padding) --- + mask_k = offs_k < TOP_K + tl.store( + weights_ptr + token_id * 
stride_wt + offs_k * stride_wk, + softmax_vals, + mask=mask_k, + ) + tl.store( + indices_ptr + token_id * stride_it + offs_k * stride_ik, + topk_idxs, + mask=mask_k, + ) + + +def _next_power_of_2(n: int) -> int: + """Return the smallest power of 2 >= n.""" + return 1 << math.ceil(math.log2(max(n, 1))) + + +def triton_fused_topk_softmax_fn( + router_logits: torch.Tensor, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused top-k + softmax routing using a single Triton kernel. + + Args: + router_logits: (T, E) float tensor of router logits. + top_k: Number of experts to select per token. + + Returns: + routing_weights: (T, top_k) tensor of softmax routing weights (same dtype as input). + selected_experts: (T, top_k) int32 tensor of expert indices. + """ + assert router_logits.ndim == 2, "router_logits must be 2-D (T, E)" + num_tokens, num_experts = router_logits.shape + + # Allocate outputs — use input dtype to avoid downstream FP32→BF16 cast kernels. + # The Triton kernel computes softmax in FP32 internally and auto-casts on store. 
+ routing_weights = torch.empty( + (num_tokens, top_k), dtype=router_logits.dtype, device=router_logits.device + ) + selected_experts = torch.empty( + (num_tokens, top_k), dtype=torch.int32, device=router_logits.device + ) + + # Determine compile-time constants + BLOCK_E = _next_power_of_2(num_experts) + BLOCK_K = _next_power_of_2(top_k) + + # Launch grid: one program per token + grid = (num_tokens,) + + _fused_topk_softmax_kernel[grid]( + router_logits, + routing_weights, + selected_experts, + num_tokens, + num_experts, + router_logits.stride(0), + router_logits.stride(1), + routing_weights.stride(0), + routing_weights.stride(1), + selected_experts.stride(0), + selected_experts.stride(1), + BLOCK_E=BLOCK_E, + TOP_K=top_k, + BLOCK_K=BLOCK_K, + ) + + return routing_weights, selected_experts + + +# --------------------------------------------------------------------------- +# Register as a torch custom op for graph tracing / export compatibility +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("auto_deploy::triton_fused_topk_softmax", mutates_args=()) +def triton_fused_topk_softmax( + router_logits: torch.Tensor, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fused top-k + softmax routing custom op. + + Computes ``softmax(topk(router_logits))`` in a single fused Triton kernel. + This is mathematically equivalent to the 3-step sequence + ``softmax → topk → renormalize`` used in standard MoE routers (e.g. Qwen3.5). + + Args: + router_logits: (T, E) tensor of raw router logits. + top_k: Number of top experts to select per token. + + Returns: + A tuple of: + - routing_weights: (T, top_k) tensor (same dtype as router_logits). + - selected_experts: (T, top_k) int32 tensor. 
+ """ + return triton_fused_topk_softmax_fn(router_logits, top_k) + + +@triton_fused_topk_softmax.register_fake +def _triton_fused_topk_softmax_fake( + router_logits: torch.Tensor, + top_k: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Fake (meta) implementation for tracing / export.""" + num_tokens = router_logits.shape[0] + routing_weights = router_logits.new_empty((num_tokens, top_k), dtype=router_logits.dtype) + selected_experts = router_logits.new_empty((num_tokens, top_k), dtype=torch.int32) + return routing_weights, selected_experts diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/swiglu.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/swiglu.py index e13acb6f79f..b373387379b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/swiglu.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear/swiglu.py @@ -309,3 +309,140 @@ def _( # Output shape: [..., hidden_size] where hidden_size = down_weight.shape[0] output_shape = list(input.shape[:-1]) + [down_weight.shape[0]] return input.new_empty(output_shape, dtype=input.dtype) + + +# ── FineGrained FP8 quantized SwiGLU ops ──────────────────────────────────── + + +@torch.library.custom_op("auto_deploy::torch_finegrained_fp8_swiglu_mlp", mutates_args=()) +def torch_finegrained_fp8_swiglu_mlp( + input: torch.Tensor, + gate_weight: torch.Tensor, + up_weight: torch.Tensor, + down_weight: torch.Tensor, + gate_weight_scale: torch.Tensor, + up_weight_scale: torch.Tensor, + down_weight_scale: torch.Tensor, +) -> torch.Tensor: + """FineGrained FP8 quantized SwiGLU MLP operation (intermediate representation). + + Computes: silu(fp8_linear(x, gate)) * fp8_linear(x, up) -> fp8_linear(down) + + This is the intermediate representation used after pattern matching for FineGrained + FP8 quantized checkpoints, before gate+up weight fusion is applied. + + Args: + input: Input tensor of shape [..., hidden_size] in bfloat16. 
+ gate_weight: FP8 gate weight [intermediate_size, hidden_size] float8_e4m3fn. + up_weight: FP8 up weight [intermediate_size, hidden_size] float8_e4m3fn. + down_weight: FP8 down weight [hidden_size, intermediate_size] float8_e4m3fn. + gate_weight_scale: Per-block weight scale for gate [N/128, K/128] float32. + up_weight_scale: Per-block weight scale for up [N/128, K/128] float32. + down_weight_scale: Per-block weight scale for down [N/128, K/128] float32. + + Returns: + Output tensor of shape [..., hidden_size]. + """ + gate_out = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + input, + gate_weight, + None, + input_scale=[], + weight_scale=[gate_weight_scale], + input_zp=[], + weight_zp=[], + ) + up_out = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + input, + up_weight, + None, + input_scale=[], + weight_scale=[up_weight_scale], + input_zp=[], + weight_zp=[], + ) + hidden = F.silu(gate_out) * up_out + return torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + hidden, + down_weight, + None, + input_scale=[], + weight_scale=[down_weight_scale], + input_zp=[], + weight_zp=[], + ) + + +@torch_finegrained_fp8_swiglu_mlp.register_fake +def _( + input: torch.Tensor, + gate_weight: torch.Tensor, + up_weight: torch.Tensor, + down_weight: torch.Tensor, + gate_weight_scale: torch.Tensor, + up_weight_scale: torch.Tensor, + down_weight_scale: torch.Tensor, +) -> torch.Tensor: + """Fake implementation for tracing.""" + # Output shape: [..., hidden_size] where hidden_size = down_weight.shape[0] + output_shape = list(input.shape[:-1]) + [down_weight.shape[0]] + return input.new_empty(output_shape, dtype=input.dtype) + + +@torch.library.custom_op("auto_deploy::fused_finegrained_fp8_swiglu_mlp", mutates_args=()) +def fused_finegrained_fp8_swiglu_mlp( + input: torch.Tensor, + gate_up_weight: torch.Tensor, + down_weight: torch.Tensor, + gate_up_weight_scale: torch.Tensor, + down_weight_scale: torch.Tensor, +) -> torch.Tensor: + """Fused 
FineGrained FP8 SwiGLU MLP with concatenated gate+up weights. + + Performs a single FP8 matmul for gate and up projections, then splits, + applies SwiGLU activation, and does the down FP8 matmul. + + Args: + input: Input tensor of shape [..., hidden_size] in bfloat16. + gate_up_weight: Concatenated FP8 gate+up weight + [2*intermediate_size, hidden_size] float8_e4m3fn. + down_weight: FP8 down weight [hidden_size, intermediate_size] float8_e4m3fn. + gate_up_weight_scale: Concatenated per-block weight scale for gate+up + [2*N/128, K/128] float32. + down_weight_scale: Per-block weight scale for down [N/128, K/128] float32. + + Returns: + Output tensor of shape [..., hidden_size]. + """ + # Single FP8 linear for both gate and up projections + gate_up_out = torch.ops.auto_deploy.trtllm_finegrained_fp8_linear( + input, + gate_up_weight, + None, + gate_up_weight_scale, + ) + + # Apply SwiGLU activation: split, silu(gate) * up (uses FlashInfer when available) + hidden = _silu_and_mul(gate_up_out) + + # Down projection + return torch.ops.auto_deploy.trtllm_finegrained_fp8_linear( + hidden, + down_weight, + None, + down_weight_scale, + ) + + +@fused_finegrained_fp8_swiglu_mlp.register_fake +def _( + input: torch.Tensor, + gate_up_weight: torch.Tensor, + down_weight: torch.Tensor, + gate_up_weight_scale: torch.Tensor, + down_weight_scale: torch.Tensor, +) -> torch.Tensor: + """Fake implementation for tracing.""" + # Output shape: [..., hidden_size] where hidden_size = down_weight.shape[0] + output_shape = list(input.shape[:-1]) + [down_weight.shape[0]] + return input.new_empty(output_shape, dtype=input.dtype) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py index 349e5de9913..9fa7b31c90e 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py 
@@ -43,6 +43,7 @@ def _flashinfer_cached_ssm( cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, + any_prefill_use_initial_states_host: torch.Tensor, # EXTRA METADATA chunk_indices: torch.Tensor, # [num_logical_chunks] chunk_offsets: torch.Tensor, # [num_logical_chunks] @@ -80,6 +81,7 @@ def _flashinfer_cached_ssm( cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, chunk_indices, chunk_offsets, seq_idx_prefill, @@ -162,6 +164,7 @@ def _flashinfer_cached_ssm_fake( cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, + any_prefill_use_initial_states_host: torch.Tensor, # EXTRA METADATA chunk_indices: torch.Tensor, # [num_logical_chunks] chunk_offsets: torch.Tensor, # [num_logical_chunks] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py index fe9d832e387..873fb9c3329 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/mamba_backend_common.py @@ -122,6 +122,7 @@ def _run_ssm_prefill( cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, + any_prefill_use_initial_states_host: torch.Tensor, chunk_indices: torch.Tensor, chunk_offsets: torch.Tensor, seq_idx_prefill: torch.Tensor, @@ -145,7 +146,8 @@ def _run_ssm_prefill( seq_idx_prefill = seq_idx_prefill[:, :num_prefill_tokens] initial_states = None - if torch.any(use_initial_states[:num_prefill]): + # Use precomputed host flag to avoid GPU->CPU sync from torch.any() + if any_prefill_use_initial_states_host.item(): initial_states = torch.where( use_initial_states[:num_prefill, None, None, None], ssm_state_cache[slot_idx[:num_prefill]], @@ -248,7 +250,13 @@ def get_source_attention_op(cls) -> OpOverloadPacket: @classmethod def get_standard_metadata_args(cls) -> List[str]: - return ["batch_info_host", 
"cu_seqlen", "slot_idx", "use_initial_states"] + return [ + "batch_info_host", + "cu_seqlen", + "slot_idx", + "use_initial_states", + "any_prefill_use_initial_states_host", + ] @classmethod def get_prepare_extra_metadata_info( diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py index 93e2b16d965..5488d7db623 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py @@ -43,6 +43,7 @@ def _triton_cached_ssm( cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, + any_prefill_use_initial_states_host: torch.Tensor, # EXTRA METADATA chunk_indices: torch.Tensor, # [num_logical_chunks] chunk_offsets: torch.Tensor, # [num_logical_chunks] @@ -81,6 +82,7 @@ def _triton_cached_ssm( cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, chunk_indices, chunk_offsets, seq_idx_prefill, @@ -159,6 +161,7 @@ def _triton_cached_ssm_fake( cu_seqlen: torch.Tensor, slot_idx: torch.Tensor, use_initial_states: torch.Tensor, + any_prefill_use_initial_states_host: torch.Tensor, # EXTRA METADATA chunk_indices: torch.Tensor, # [num_logical_chunks] chunk_offsets: torch.Tensor, # [num_logical_chunks] diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py index 45a7080d5ac..d1a74ea2909 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/rms_norm.py @@ -235,7 +235,7 @@ def _triton_rmsnorm_gated_meta( if gate is not None: assert gate.shape == x.shape, "gate must match x shape" - return x.new_empty(x.shape, dtype=torch.float32) + return x.new_empty(x.shape, dtype=x.dtype) # Forked from: diff --git 
a/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py index 695254191c2..d0e2d270205 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/torch_quant.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import List, Optional import torch @@ -464,7 +465,9 @@ def _safe_act_quant(x: torch.Tensor, block_size: int = 128) -> tuple: assert x.is_contiguous() assert x.shape[-1] % block_size == 0 y = torch.empty_like(x, dtype=torch.float8_e4m3fn) - s = x.new_empty(*x.shape[:-1], x.shape[-1] // block_size, dtype=torch.float32) + # Keep scale metadata in the model dtype to avoid FP32->BF16 cast kernels + # when the tensor is consumed by downstream MoE/quantized paths. + s = x.new_empty(*x.shape[:-1], x.shape[-1] // block_size, dtype=x.dtype) grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),) # noqa: E731 _act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) @@ -474,8 +477,14 @@ def _safe_act_quant(x: torch.Tensor, block_size: int = 128) -> tuple: def _dequant_block_fp8_weight(weight_fp8, weight_scale, block_n, block_k, dtype=torch.bfloat16): """Dequantize block-scaled FP8 weight to BF16 for tiny projections.""" N, K = weight_fp8.shape - scale_expanded = weight_scale.repeat_interleave(block_n, dim=0).repeat_interleave( - block_k, dim=1 + scale_n, scale_k = weight_scale.shape + # Use ceil division so the expanded scale covers the full weight dimension + # even when N or K is not exactly divisible by the block size (e.g. 576 / 5 + # scales → ceil=116, giving 580 rows after repeat, then sliced to 576). 
+ actual_block_n = math.ceil(N / scale_n) if scale_n > 0 else block_n + actual_block_k = math.ceil(K / scale_k) if scale_k > 0 else block_k + scale_expanded = weight_scale.repeat_interleave(actual_block_n, dim=0).repeat_interleave( + actual_block_k, dim=1 ) scale_expanded = scale_expanded[:N, :K] return weight_fp8.to(dtype) * scale_expanded.to(dtype) @@ -581,6 +590,8 @@ def trtllm_finegrained_fp8_linear( # For small layers where a dimension < 128 (e.g. N=64), the derived block # size will be < 128. Fall back to BF16 dequant + cuBLAS. if block_n != 128 or block_k != 128: + # BF16 fallback: the Triton FP8 kernel launches Grid=1x1x1 for tiny N, + # wasting 99% of SM capacity. Dequantize weight + cuBLAS is faster. weight_dequant = _dequant_block_fp8_weight( weight, weight_scale, block_n, block_k, dtype=input.dtype ) diff --git a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py index d4925c19b0e..a609afc4a9e 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/models/custom/modeling_qwen3_5_moe.py @@ -284,7 +284,7 @@ def _apply_interleaved_mrope(self, freqs: torch.Tensor) -> torch.Tensor: # ============================================================================= # Adapted from the Qwen3Next GDN patch: # tensorrt_llm/_torch/auto_deploy/models/patches/qwen3_next.py -# Uses autodeploy custom ops: torch_causal_conv1d, torch_l2norm, torch_gated_delta_rule +# Uses autodeploy custom ops: torch_causal_conv1d, torch_gated_delta_rule class Qwen3_5MoeGatedDeltaNet(nn.Module): @@ -333,11 +333,9 @@ def __init__(self, config: Qwen3_5MoeTextConfig, layer_idx: int): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, seq_len, _ = hidden_states.shape - # 1. 
Projections (separate, unlike Qwen3Next which uses combined in_proj_qkvz) mixed_qkv = self.in_proj_qkv(hidden_states) # [B, S, conv_dim] z = self.in_proj_z(hidden_states) # [B, S, value_dim] - z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) # [B, S, num_v_heads, head_v_dim] b = self.in_proj_b(hidden_states) # [B, S, num_v_heads] a = self.in_proj_a(hidden_states) # [B, S, num_v_heads] @@ -367,33 +365,18 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) - # 3. L2 normalize Q and K via autodeploy op - query = torch.ops.auto_deploy.torch_l2norm(query) - key = torch.ops.auto_deploy.torch_l2norm(key) - - # 4. Compute beta and gating - beta = b.sigmoid() # [B, S, num_v_heads] - # If the model is loaded in fp16, without the .float() here, A might be -inf - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) # [B, S, num_v_heads] - - # Repeat-interleave Q, K if num_v_heads > num_k_heads (GQA for linear attention) - if self.num_v_heads // self.num_k_heads > 1: - query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - - # 5. Gated Delta Rule via autodeploy custom op - # Op expects [B, S, H, D] layout (bsnd convention) - core_attn_out = torch.ops.auto_deploy.torch_gated_delta_rule(query, key, value, g, beta) + # 3. Gated Delta Rule via autodeploy custom op + # L2 norm, GQA repeat-interleave, and g/beta computation are handled inside the op. + core_attn_out = torch.ops.auto_deploy.torch_gated_delta_rule( + query, key, value, a, b, self.A_log, self.dt_bias + ) - # 6. Gated RMSNorm - z_shape_og = z.shape - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) + # 5. 
Gated RMSNorm + merge heads + z = z.reshape(batch_size, seq_len, -1, self.head_v_dim) # [B, S, num_v_heads, head_v_dim] core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) - # 7. Output projection + # 6. Output projection output = self.out_proj(core_attn_out) return output @@ -619,6 +602,12 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: w2_weights = [self.experts[i].down_proj.weight for i in range(len(self.experts))] w3_weights = [self.experts[i].up_proj.weight for i in range(len(self.experts))] + # Shared expert with sigmoid gating + shared_expert_output = self.shared_expert(hidden_states_flat) + shared_expert_output = ( + F.sigmoid(self.shared_expert_gate(hidden_states_flat)) * shared_expert_output + ) + expert_output = torch.ops.auto_deploy.torch_moe( hidden_states_flat, selected_experts, @@ -629,11 +618,6 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: is_gated_mlp=True, ) - # Shared expert with sigmoid gating - shared_expert_output = self.shared_expert(hidden_states_flat) - shared_expert_output = ( - F.sigmoid(self.shared_expert_gate(hidden_states_flat)) * shared_expert_output - ) expert_output = expert_output + shared_expert_output expert_output = expert_output.reshape(batch_size, sequence_length, hidden_dim) @@ -1515,7 +1499,7 @@ def forward( Steps: 1. Embed input_ids -> inputs_embeds 2. Run vision tower on pixel_values -> masked_scatter into embeds - 3. Compute mRoPE position_ids via get_rope_index + 3. Compute mRoPE position_ids via get_rope_index (or use external ones) 4. Compute (cos, sin) from rotary_emb 5. 
Call language_model (TextModel) with (inputs_embeds, position_embeddings) """ diff --git a/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3_next.py b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3_next.py index 874877083cc..c459e65d8d0 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3_next.py +++ b/tensorrt_llm/_torch/auto_deploy/models/patches/qwen3_next.py @@ -3,7 +3,7 @@ Includes: - MoE patch: replaces Qwen3NextSparseMoeBlock.forward with torch_moe op - GDN patch: replaces Qwen3NextGatedDeltaNet.forward with autodeploy custom ops - (torch_causal_conv1d, torch_l2norm, torch_gated_delta_rule) + (torch_causal_conv1d, torch_gated_delta_rule) - Mask/cache patches: simplify _update_linear_attn_mask and DynamicCache.__bool__ for torch.export compatibility @@ -99,8 +99,8 @@ def _patched_gdn_forward( Removes cache-dependent control flow and uses autodeploy custom ops: - torch_causal_conv1d for the depthwise causal convolution - - torch_l2norm for L2 normalization of Q and K - torch_gated_delta_rule for the core gated delta rule computation + (L2 norm, GQA expansion, and gating are handled inside the op) """ hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) batch_size, seq_len, _ = hidden_states.shape @@ -143,25 +143,13 @@ def _patched_gdn_forward( key = key.reshape(batch_size, seq_len, -1, self.head_k_dim) value = value.reshape(batch_size, seq_len, -1, self.head_v_dim) - # 3. L2 normalize Q and K via autodeploy op - query = torch.ops.auto_deploy.torch_l2norm(query) - key = torch.ops.auto_deploy.torch_l2norm(key) - - # 4. 
Compute beta and gating - beta = b.sigmoid() # [B, S, num_v_heads] - # If the model is loaded in fp16, without the .float() here, A might be -inf - g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) # [B, S, num_v_heads] - - # Repeat-interleave Q, K if num_v_heads > num_k_heads (GQA for linear attention) - if self.num_v_heads // self.num_k_heads > 1: - query = query.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - key = key.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - - # 5. Gated Delta Rule via autodeploy custom op - # Op expects [B, S, H, D] layout (bsnd convention) - core_attn_out = torch.ops.auto_deploy.torch_gated_delta_rule(query, key, value, g, beta) + # 3. Gated Delta Rule via autodeploy custom op + # L2 norm, GQA repeat-interleave, and g/beta computation are handled inside the op. + core_attn_out = torch.ops.auto_deploy.torch_gated_delta_rule( + query, key, value, a, b, self.A_log, self.dt_bias + ) - # 6. Gated RMSNorm + # 5. Gated RMSNorm z_shape_og = z.shape core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) z = z.reshape(-1, z.shape[-1]) @@ -169,7 +157,7 @@ def _patched_gdn_forward( core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(batch_size, seq_len, -1) - # 7. Output projection + # 6. Output projection output = self.out_proj(core_attn_out) return output diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_gdn_gating.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_gdn_gating.py new file mode 100644 index 00000000000..92620e21620 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_gdn_gating.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Graph transform to fuse GDN gating ops from torch source to Triton kernel.""" + +from typing import Tuple, Type + +import torch +from torch.fx import GraphModule, Node + +from ...custom_ops.fla import gdn_gating as _gdn_gating_ops # noqa: F401 (registers ops) +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.node_utils import is_op +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +@TransformRegistry.register("fuse_gdn_gating") +class FuseGdnGating(BaseTransform): + """Replaces torch_fused_gdn_gating ops with triton_fused_gdn_gating. + + This transform runs in the post_load_fusion stage and swaps the pure-torch + source op with a single-kernel Triton implementation, eliminating ~5 kernel + launches per GDN layer. + + Args: + gm: Input graph module to transform. + + Returns: + Transformed graph module with Triton-fused GDN gating operations. 
+ """ + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return TransformConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + target_op = torch.ops.auto_deploy.triton_fused_gdn_gating.default + cnt = 0 + + for node in list(graph.nodes): + if is_op(node, torch.ops.auto_deploy.torch_fused_gdn_gating): + with graph.inserting_after(node): + new_node: Node = graph.call_function( + target_op, + args=node.args, + kwargs=node.kwargs, + ) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + graph.erase_node(node) + cnt += 1 + + info = TransformInfo( + skipped=False, + num_matches=cnt, + is_clean=cnt == 0, + has_valid_shapes=cnt == 0, + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py index 268cdad6f58..7ebbc612b9f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_swiglu.py @@ -616,3 +616,271 @@ def _apply( ) return gm, info + + +# ── FineGrained FP8 quantized SwiGLU pattern matching and fusion ───────────── + +from ...custom_ops.linear.swiglu import torch_finegrained_fp8_swiglu_mlp # noqa: E402 + + +def _finegrained_fp8_swiglu_pattern_no_bias( + x, + gate_weight, + gate_weight_scale, + up_weight, + up_weight_scale, + down_weight, + down_weight_scale, +): + """Pattern for FineGrained FP8 quantized SwiGLU MLP without biases. 
+ + Matches: silu(fp8_linear(x, gate)) * fp8_linear(x, up) -> fp8_linear(down) + """ + gate_out = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear.default( + x, + gate_weight, + None, + input_scale=[], + weight_scale=[gate_weight_scale], + input_zp=[], + weight_zp=[], + ) + up_out = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear.default( + x, + up_weight, + None, + input_scale=[], + weight_scale=[up_weight_scale], + input_zp=[], + weight_zp=[], + ) + silu_out = torch.ops.aten.silu.default(gate_out) + mul_out = torch.ops.aten.mul.Tensor(silu_out, up_out) + down_out = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear.default( + mul_out, + down_weight, + None, + input_scale=[], + weight_scale=[down_weight_scale], + input_zp=[], + weight_zp=[], + ) + return down_out + + +def _finegrained_fp8_swiglu_replacement_no_bias( + x, + gate_weight, + gate_weight_scale, + up_weight, + up_weight_scale, + down_weight, + down_weight_scale, +): + """Replacement for FineGrained FP8 quantized SwiGLU pattern without biases.""" + return torch_finegrained_fp8_swiglu_mlp( + x, + gate_weight, + up_weight, + down_weight, + gate_weight_scale, + up_weight_scale, + down_weight_scale, + ) + + +@TransformRegistry.register("match_finegrained_fp8_swiglu_pattern") +class MatchFineGrainedFP8SwiGLUPattern(BaseTransform): + """Matches FineGrained FP8 quantized SwiGLU MLP patterns. + + This transform runs in the pattern_matcher stage AFTER + quantize_finegrained_fp8_linear_from_config has converted torch_linear_simple ops + to torch_fake_quant_finegrained_fp8_linear ops. + + It detects the following FineGrained FP8 pattern: + silu(fp8_linear(x, gate)) * fp8_linear(x, up) -> fp8_linear(down) + + And replaces it with a single torch_finegrained_fp8_swiglu_mlp op that can be + fused later. + + Note: This transform runs before sharding. The composite SwiGLU op will NOT be + sharded by the sharding transform. 
Enable only when sharding is not needed or + when sharding-aware handling is added separately. + """ + + config: TransformConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return TransformConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + patterns = ADPatternMatcherPass() + + # FP8 shape params for dummy args (shapes don't matter for matching, + # but must be multiples of 128 for block quantization) + N = 256 # intermediate_size + K = 256 # hidden_size + N_down = K # hidden_size (output of down proj) + + x = torch.randn(2, K, device="meta", dtype=torch.bfloat16) + + # Gate args + gate_w = torch.randn(N, K, device="meta", dtype=torch.float8_e4m3fn) + gate_ws = torch.randn(N // 128, K // 128, device="meta", dtype=torch.float32) + + # Up args (same shapes as gate) + up_w = torch.randn(N, K, device="meta", dtype=torch.float8_e4m3fn) + up_ws = torch.randn(N // 128, K // 128, device="meta", dtype=torch.float32) + + # Down args + down_w = torch.randn(N_down, N, device="meta", dtype=torch.float8_e4m3fn) + down_ws = torch.randn(N_down // 128, N // 128, device="meta", dtype=torch.float32) + + dummy_args = [ + x, + gate_w, + gate_ws, + up_w, + up_ws, + down_w, + down_ws, + ] + + register_ad_pattern( + search_fn=_finegrained_fp8_swiglu_pattern_no_bias, + replace_fn=_finegrained_fp8_swiglu_replacement_no_bias, + patterns=patterns, + dummy_args=dummy_args, + ) + + num_matches = patterns.apply(gm.graph) + + if num_matches > 0: + gm.recompile() + + info = TransformInfo( + skipped=False, + num_matches=num_matches, + is_clean=num_matches == 0, + has_valid_shapes=num_matches == 0, + ) + + return gm, info + + +@TransformRegistry.register("fuse_finegrained_fp8_swiglu") +class FuseFineGrainedFP8SwiGLU(BaseTransform): + """Fuses torch_finegrained_fp8_swiglu_mlp ops by concatenating gate and up FP8 weights. 
+ + This transform runs in the post_load_fusion stage and replaces + torch_finegrained_fp8_swiglu_mlp ops with fused_finegrained_fp8_swiglu_mlp ops + that use a single concatenated gate+up weight matrix. + + FP8 weight fusion: + - gate+up FP8 weights are concatenated along dim=0: [N, K] -> [2N, K] + - gate+up per-block weight scales are concatenated along dim=0: + [N/128, K/128] -> [2N/128, K/128] + """ + + config: TransformConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return TransformConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + cnt = 0 + fused_weight_idx = 0 + + for node in list(graph.nodes): + if not is_op(node, torch.ops.auto_deploy.torch_finegrained_fp8_swiglu_mlp.default): + continue + + # Extract args: + # (input, gate_weight, up_weight, down_weight, + # gate_weight_scale, up_weight_scale, down_weight_scale) + input_node = node.args[0] + gate_weight_node = node.args[1] + up_weight_node = node.args[2] + down_weight_node = node.args[3] + gate_weight_scale_node = node.args[4] + up_weight_scale_node = node.args[5] + down_weight_scale_node = node.args[6] + + # Get the actual weight tensors + gate_weight = get_attr_by_name(gm, gate_weight_node.target) + up_weight = get_attr_by_name(gm, up_weight_node.target) + + # Concatenate gate and up FP8 weights along dim=0: [N, K] -> [2N, K] + gate_up_weight = torch.cat([gate_weight, up_weight], dim=0) + + # Get and concatenate weight scales along dim=0: + # [N/128, K/128] -> [2N/128, K/128] + gate_weight_scale = get_attr_by_name(gm, gate_weight_scale_node.target) + up_weight_scale = get_attr_by_name(gm, up_weight_scale_node.target) + gate_up_weight_scale = torch.cat([gate_weight_scale, up_weight_scale], dim=0) + + # Register fused buffers + prefix = f"fused_finegrained_fp8_swiglu_{fused_weight_idx}" + 
gm.register_buffer(f"{prefix}_gate_up_weight", gate_up_weight) + gm.register_buffer(f"{prefix}_gate_up_weight_scale", gate_up_weight_scale) + + # Create get_attr nodes for fused weights/scales + with graph.inserting_before(node): + fused_gate_up_weight_node = graph.get_attr(f"{prefix}_gate_up_weight") + fused_gate_up_weight_scale_node = graph.get_attr(f"{prefix}_gate_up_weight_scale") + + # Create the fused_finegrained_fp8_swiglu_mlp node + with graph.inserting_after(node): + fused_node: Node = graph.call_function( + torch.ops.auto_deploy.fused_finegrained_fp8_swiglu_mlp.default, + args=( + input_node, + fused_gate_up_weight_node, + down_weight_node, + fused_gate_up_weight_scale_node, + down_weight_scale_node, + ), + ) + + # Replace uses and erase old node + node.replace_all_uses_with(fused_node) + graph.erase_node(node) + + # Eagerly free unfused weight/scale tensors that are no longer referenced + # to avoid a temporary memory spike from holding both fused and unfused + # copies simultaneously across all layers. + _try_free_attr_node(gm, graph, gate_weight_node) + _try_free_attr_node(gm, graph, up_weight_node) + _try_free_attr_node(gm, graph, gate_weight_scale_node) + _try_free_attr_node(gm, graph, up_weight_scale_node) + + fused_weight_idx += 1 + cnt += 1 + + if cnt > 0: + gm.recompile() + + # Clean up any remaining dead code and unused submodules + eliminate_dead_code(gm) + delete_all_unused_submodules(gm) + + info = TransformInfo( + skipped=False, num_matches=cnt, is_clean=cnt == 0, has_valid_shapes=cnt == 0 + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/moe_routing.py b/tensorrt_llm/_torch/auto_deploy/transform/library/moe_routing.py new file mode 100644 index 00000000000..63b192feeae --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/moe_routing.py @@ -0,0 +1,225 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Graph transform to fuse MoE softmax → top-k → renormalize routing. + +Detects the standard MoE routing pattern used by Qwen3.5 (and similar models): + routing_weights = softmax(router_logits, dtype=float32) + routing_weights, indices = topk(routing_weights, k) + routing_weights = routing_weights / routing_weights.sum(keepdim=True) + +and replaces it with a single fused Triton kernel: + routing_weights, indices = triton_fused_topk_softmax(router_logits, k) + +This leverages the mathematical equivalence: + topk(softmax(x)); x /= x.sum() ≡ softmax(topk(x)) + +The fused kernel avoids computing softmax over all experts (e.g. 256), instead +finding top-k from raw logits and computing softmax only over the k selected values. +""" + +import operator +from typing import Optional, Tuple, Type + +import torch +from torch.fx import GraphModule, Node + +# Importing this module registers the torch.ops.auto_deploy.triton_fused_topk_softmax op. 
+import tensorrt_llm._torch.auto_deploy.custom_ops.fused_moe.triton_routing # noqa: F401 +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory +from tensorrt_llm._torch.auto_deploy.shim.interface import CachedSequenceInterface +from tensorrt_llm._torch.auto_deploy.transform.interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) +from tensorrt_llm._torch.auto_deploy.utils._graph import eliminate_dead_code +from tensorrt_llm._torch.auto_deploy.utils.logger import ad_logger +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + +# --------------------------------------------------------------------------- +# Pattern-detection helpers +# --------------------------------------------------------------------------- + + +def _get_single_getitem_user(node: Node, index: int) -> Optional[Node]: + """Return the unique ``operator.getitem(node, index)`` user, or *None*.""" + for user in node.users: + if user.op == "call_function" and user.target is operator.getitem and user.args[1] == index: + return user + return None + + +def _trace_back_through_softmax(node: Node) -> Optional[Node]: + """Trace backwards from *node* to find the raw logits before softmax. + + Handles the common aten decompositions produced by ``torch.export``: + + 1. ``aten.softmax.int(logits, dim)`` — no dtype cast + 2. ``aten.softmax.int(logits, dim, dtype)`` — with dtype cast + 3. ``aten._softmax.default(logits, dim, half_to_float)`` + 4. ``aten._to_copy(logits, dtype=float32) → aten._softmax.default(…)`` + + Returns the original logits tensor (before any softmax / dtype cast) or + *None* if the input does not originate from a softmax. 
+ """ + _softmax_ops = ( + torch.ops.aten.softmax.int, + torch.ops.aten._softmax.default, + ) + if not is_op(node, _softmax_ops): + return None + + softmax_input = node.args[0] + + # For aten._softmax.default, check for a preceding dtype cast + if is_op(node, torch.ops.aten._softmax.default): + if isinstance(softmax_input, Node) and is_op( + softmax_input, torch.ops.aten._to_copy.default + ): + if len(softmax_input.users) == 1: + return softmax_input.args[0] + + # For both variants the first arg is the logits + return softmax_input + + +def _find_renormalization_node(values_node: Node) -> Optional[Node]: + """Return the ``aten.div`` node that renormalizes *values_node*, or *None*. + + Looks for the pattern:: + + sum_val = aten.sum.dim_IntList(values, [dim], keepdim=True) + renorm = aten.div.Tensor(values, sum_val) + """ + for user in values_node.users: + if is_op(user, torch.ops.aten.div.Tensor) and user.args[0] is values_node: + divisor = user.args[1] + if ( + isinstance(divisor, Node) + and is_op(divisor, torch.ops.aten.sum.dim_IntList) + and divisor.args[0] is values_node + ): + return user + elif is_op(user, torch.ops.aten.sum.dim_IntList) and user.args[0] is values_node: + for sum_user in user.users: + if ( + is_op(sum_user, torch.ops.aten.div.Tensor) + and sum_user.args[0] is values_node + and sum_user.args[1] is user + ): + return sum_user + return None + + +# --------------------------------------------------------------------------- +# Transform +# --------------------------------------------------------------------------- + + +@TransformRegistry.register("match_moe_routing_pattern") +class MatchMoeRoutingPattern(BaseTransform): + """Match softmax → topk → renormalize and replace with a fused Triton op. 
+ + This transform detects the 3-op MoE routing pattern:: + + routing_weights = softmax(logits, dtype=float32) + routing_weights, indices = topk(routing_weights, k) + routing_weights /= routing_weights.sum(keepdim=True) + + and replaces it with:: + + routing_weights, indices = triton_fused_topk_softmax(logits, k) + + The fused kernel exploits the equivalence + ``topk(softmax(x)) / Σ ≡ softmax(topk(x))`` and avoids computing + softmax over all experts. + """ + + config: TransformConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return TransformConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + graph = gm.graph + num_matches = 0 + + for node in list(graph.nodes): + # ---- Step 1: find an aten.topk node ---------------------------- + if not is_op(node, torch.ops.aten.topk.default): + continue + + topk_node = node + topk_input = topk_node.args[0] + top_k = topk_node.args[1] # int literal + + if not isinstance(topk_input, Node) or not isinstance(top_k, int): + continue + + # ---- Step 2: verify that topk input is a softmax --------------- + original_logits = _trace_back_through_softmax(topk_input) + if original_logits is None: + continue + + # ---- Step 3: locate getitem[0] (values) and getitem[1] (indices) + values_node = _get_single_getitem_user(topk_node, 0) + indices_node = _get_single_getitem_user(topk_node, 1) + if values_node is None or indices_node is None: + continue + + # ---- Step 4: verify values are renormalized -------------------- + renorm_node = _find_renormalization_node(values_node) + if renorm_node is None: + continue + + # ---- Step 5: all conditions met — insert fused op -------------- + ad_logger.info(f"Matched MoE routing pattern: softmax → topk(k={top_k}) → renormalize") + + with graph.inserting_before(topk_node): + fused_node = graph.call_function( + 
torch.ops.auto_deploy.triton_fused_topk_softmax, + args=(original_logits, top_k), + ) + fused_weights = graph.call_function(operator.getitem, args=(fused_node, 0)) + fused_indices = graph.call_function(operator.getitem, args=(fused_node, 1)) + + # Replace all downstream uses + renorm_node.replace_all_uses_with(fused_weights) + indices_node.replace_all_uses_with(fused_indices) + + num_matches += 1 + + if num_matches > 0: + eliminate_dead_code(gm) + gm.recompile() + ad_logger.info(f"Fused {num_matches} MoE routing pattern(s).") + + info = TransformInfo( + skipped=False, + num_matches=num_matches, + is_clean=num_matches == 0, + has_valid_shapes=num_matches == 0, + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_gemm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_gemm.py new file mode 100644 index 00000000000..8c3f276a8c1 --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_gemm.py @@ -0,0 +1,348 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generalized multi-stream transform for parallelizing fp8 GEMMs. + +When multiple fp8 linear (GEMM) operations share the same input tensor, they +can execute concurrently on separate CUDA streams. 
This transform identifies +such *fork points* in the FX graph and moves the **largest** GEMM (estimated by +weight shape) to the auxiliary CUDA stream while the remaining GEMMs stay on the +main stream. + +The overlap benefit comes from the GPU pipeline: the main-stream GEMMs and the +aux-stream GEMM execute concurrently on the GPU, reducing the total wall-clock +time compared to sequential execution. + +This is a generalization of the pattern used in ``multi_stream_mla_attn.py`` +(which is MLA-specific) and can handle arbitrary fork-and-join patterns of +fp8 linear ops. + +Example fork points that benefit from this transform: + - **Linear attention layers** (4 fp8 linears: in_proj_qkv, z, b, a) + - **Standard MHA layers** (3 fp8 linears: q_proj, k_proj, v_proj) +""" + +import math +from typing import Callable, Dict, List, Tuple + +import torch +from torch.fx import GraphModule, Node + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils._graph import create_derived_custom_op, get_attr_by_name +from ...utils.logger import ad_logger +from ...utils.multi_stream_utils import ( + _make_aux_stream_impl, + begin_aux_stream_passthrough, + cuda_stream_manager, + end_aux_stream_passthrough, + record_event_passthrough, + wait_aux_stream_passthrough, +) +from ...utils.node_utils import is_op +from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry + +# --------------------------------------------------------------------------- +# Supported linear op targets. Extend this list to cover additional +# quantised or unquantised linear variants. +# --------------------------------------------------------------------------- +_SUPPORTED_LINEAR_OPS: List[Callable] = [ + torch.ops.auto_deploy.trtllm_finegrained_fp8_linear, +] + +# Multi-stream passthrough functions used by other transforms. If any user of +# a fork point is one of these, we skip the fork point to avoid conflicts. 
+_MULTI_STREAM_OPS = [ + begin_aux_stream_passthrough, + end_aux_stream_passthrough, + wait_aux_stream_passthrough, + record_event_passthrough, +] + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _is_supported_linear(node: Node) -> bool: + """Return ``True`` if *node* is a call to one of the supported linear ops.""" + return is_op(node, _SUPPORTED_LINEAR_OPS) + + +def _is_multi_stream_op(node: Node) -> bool: + """Return ``True`` if *node* is a multi-stream passthrough function call.""" + if node.op != "call_function": + return False + return node.target in _MULTI_STREAM_OPS + + +def _estimate_weight_size(gm: GraphModule, linear_node: Node) -> int: + """Estimate the GEMM cost of a linear node from its weight shape. + + For a linear with weight ``[N, K]``, the cost is proportional to ``N * K`` + (since the M dimension is shared across all linears at the same fork point). + + The weight is ``args[1]`` for all supported linear ops. We first try the + node's meta information (``node.meta["val"].shape``), falling back to + accessing the actual tensor from the graph module. + + Returns: + An integer proportional to the GEMM cost (product of weight dimensions). + """ + weight_node = linear_node.args[1] + + # Try meta shape first (available after shape propagation). + val = weight_node.meta.get("val") if hasattr(weight_node, "meta") else None + if val is not None and hasattr(val, "shape") and len(val.shape) >= 2: + return math.prod(val.shape) + + # Fallback: access the actual tensor via the get_attr path. + if weight_node.op == "get_attr": + try: + weight_tensor = get_attr_by_name(gm, weight_node.target) + return weight_tensor.numel() + except AttributeError: + pass + + # If we cannot determine the size, return 0 (this linear will not be + # selected as the largest). 
+ ad_logger.warning( + f"Could not estimate weight size for linear node {linear_node.name}; " + "it will not be considered as the largest GEMM." + ) + return 0 + + +def _create_aux_op(base_op: Callable) -> Callable: + """Create an ``_aux`` variant of a linear op that runs on the auxiliary CUDA stream. + + Uses a custom ``make_fake`` that delegates to the base op's registered fake + so that output shapes are computed correctly (linear output shape != input shape). + """ + return create_derived_custom_op( + base_op, + "_aux", + _make_aux_stream_impl, + make_fake=lambda base: lambda *a, **kw: base(*a, **kw), + ) + + +def _find_gemm_fork_points( + gm: GraphModule, + supported_ops: List[Callable], +) -> List[Tuple[Node, List[Node]]]: + """Find fork points where 2+ supported linear ops share the same input. + + Returns a list of ``(fork_point, [linear_users])`` tuples. Fork points + that already have multi-stream ops among their users are skipped to avoid + conflicts with other multi-stream transforms (e.g. ``multi_stream_moe``). + """ + results: List[Tuple[Node, List[Node]]] = [] + + for node in gm.graph.nodes: + # Collect direct supported-linear users of this node. + linear_users = [u for u in node.users if is_op(u, supported_ops)] + if len(linear_users) < 2: + continue + + # Skip if any user of this fork point is already a multi-stream op. + if any(_is_multi_stream_op(u) for u in node.users): + ad_logger.debug(f"Skipping fork point {node.name}: already has multi-stream ops.") + continue + + results.append((node, linear_users)) + + return results + + +def _move_users_after(graph, target_node: Node) -> None: + """Move any transitive users of *target_node* that precede it to after it. + + After inserting an aux node at a late position in the graph and replacing + uses of the original node, some former downstream nodes may violate + topological order (they reference *target_node* but appear before it in + the graph's linked list). 
This function restores topological order by + moving those nodes to just after *target_node* while preserving their + relative order. + + This is safe because: + - Moved nodes originally depended on the (now-erased) largest linear, so + their non-aux inputs all appear before the original linear position, + which is before *target_node*. + - Moving them forward (to a later position) cannot place them before any + of their other inputs. + """ + node_order = {n: i for i, n in enumerate(graph.nodes)} + target_pos = node_order[target_node] + + # BFS to find all transitive users that appear before target_node. + nodes_to_move: List[Node] = [] + visited: set = set() + queue = list(target_node.users.keys()) + + while queue: + n = queue.pop(0) + if n in visited or n.op == "output": + continue + visited.add(n) + if node_order.get(n, float("inf")) < target_pos: + nodes_to_move.append(n) + queue.extend(n.users.keys()) + + if not nodes_to_move: + return + + # Sort by original order to maintain relative dependencies. + nodes_to_move.sort(key=lambda n: node_order[n]) + + # Move each node to right after target_node (or the previously moved node). + anchor = target_node + for n in nodes_to_move: + anchor.append(n) + anchor = n + + +def _parallelize_largest_gemm( + gm: GraphModule, + supported_ops: List[Callable], +) -> Tuple[GraphModule, int]: + """Move the largest GEMM at each fork point to the auxiliary CUDA stream. + + For each fork point with 2+ supported linear users: + + 1. Estimate weight size for each linear to identify the largest. + 2. Insert ``record_event_passthrough(fork_point)`` before the earliest + non-largest linear to record the main-stream event (data is ready). + 3. Create an ``_aux`` variant of the largest linear's op. + 4. Insert the aux node **after** the latest non-largest linear in graph + order so that the GPU pipeline can overlap the main-stream GEMMs with + the aux-stream GEMM. + 5. 
Wire the aux node's hidden-state input to the ``record_event_passthrough`` + output (data dependency ensures event recording precedes aux dispatch). + 6. Replace all uses of the original largest linear with the aux node and + erase the original. + 7. Move any downstream nodes of the original largest linear that now + precede the aux node in graph order to after it (restoring topological + order without sacrificing GPU overlap). + """ + fork_points = _find_gemm_fork_points(gm, supported_ops) + if not fork_points: + return gm, 0 + + graph = gm.graph + node_order = {n: i for i, n in enumerate(graph.nodes)} + + # Create aux ops lazily for whatever linear op types are found. + op_dict: Dict[Callable, Callable] = {} + + num_replaced = 0 + + for fork_point, linear_users in fork_points: + # ---- Step 1: Identify the largest linear by weight size. ---- + sizes = {ln: _estimate_weight_size(gm, ln) for ln in linear_users} + largest = max(linear_users, key=lambda ln: sizes[ln]) + remaining = [ln for ln in linear_users if ln is not largest] + + if not remaining: + # Shouldn't happen (we require 2+ linears), but guard anyway. + continue + + # Sort remaining linears by their position in the graph. + remaining.sort(key=lambda n: node_order.get(n, 0)) + earliest_remaining = remaining[0] + latest_remaining = remaining[-1] + + ad_logger.info( + f"Fork point {fork_point.name}: moving {largest.name} " + f"(weight size {sizes[largest]}) to aux stream; " + f"{len(remaining)} linear(s) stay on main stream." + ) + + # ---- Step 2: Insert record_event_passthrough. ---- + # Placed before the earliest remaining linear so the main-stream event + # is recorded *before* any main-stream GEMMs are dispatched. + with graph.inserting_before(earliest_remaining): + rec_node = graph.call_function( + record_event_passthrough, + args=(fork_point,), + ) + + # ---- Step 3: Create aux op lazily. 
---- + if largest.target not in op_dict: + op_dict[largest.target] = _create_aux_op(largest.target) + + # ---- Step 4: Insert aux node after the latest remaining linear. ---- + # This ensures all main-stream GEMMs are dispatched to the GPU before + # the aux node submits its work + wait, enabling overlap. + new_args = tuple(rec_node if arg is fork_point else arg for arg in largest.args) + + with graph.inserting_after(latest_remaining): + aux_node = graph.call_function( + op_dict[largest.target], + args=new_args, + kwargs=largest.kwargs, + ) + + # ---- Step 5 & 6: Replace uses and erase original. ---- + largest.replace_all_uses_with(aux_node) + graph.erase_node(largest) + + # ---- Step 7: Restore topological order. ---- + # The downstream nodes of the original largest linear (e.g. view, + # reshape, split) may now appear *before* aux_node in graph order + # because aux_node was inserted after the latest remaining linear. + # Move those nodes to after aux_node so the graph is valid. + _move_users_after(graph, aux_node) + + num_replaced += 1 + + return gm, num_replaced + + +# --------------------------------------------------------------------------- +# Transform class +# --------------------------------------------------------------------------- + + +@TransformRegistry.register("multi_stream_gemm") +class MultiStreamGemm(BaseTransform): + """Multi-stream parallelization of fp8 GEMMs sharing the same input. + + For each fork point where 2+ fp8 linear ops share the same input tensor, + the largest GEMM (by weight shape) is moved to the auxiliary CUDA stream + so it executes concurrently with the remaining GEMMs on the main stream. + """ + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # Ensure aux stream and events are set up for the current device. 
+ cuda_stream_manager.add_device(torch.cuda.current_device()) + + gm, num_matches = _parallelize_largest_gemm(gm, _SUPPORTED_LINEAR_OPS) + + info = TransformInfo( + skipped=False, + num_matches=num_matches, + is_clean=num_matches == 0, + has_valid_shapes=num_matches == 0, + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py index 712eed42d6a..d0aed4ef96c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/multi_stream_moe.py @@ -212,12 +212,12 @@ def _apply( torch.ops.auto_deploy.trtllm_quant_fp8_moe_fused, torch.ops.auto_deploy.trtllm_quant_nvfp4_moe_fused, torch.ops.auto_deploy.trtllm_nvfp4_trtllm_gen_moe_fused, + torch.ops.auto_deploy.trtllm_quant_finegrained_fp8_moe_fused, ] # Ensure that aux stream and events for the current device are added to the CudaStreamManager. cuda_stream_manager.add_device(torch.cuda.current_device()) gm, num_matches = _execute_shared_expert_in_aux_stream(gm, base_ops) - info = TransformInfo( skipped=False, num_matches=num_matches, diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py index 9c5c5247f4f..de5267b5309 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/rms_norm.py @@ -266,6 +266,94 @@ def _apply( target_op = _BACKEND_OPS[backend] cnt = 0 + # First, fuse the norm-before-gate decomposition: + # torch_rmsnorm(x, w, eps) * silu(gate.to(fp32)) -> triton_rmsnorm_gated(x, w, gate, ...) + # This avoids a separate fp32 mul + cast before downstream GEMM. + for node in list(graph.nodes): + if not is_op(node, torch.ops.auto_deploy.torch_rmsnorm): + continue + + # torch_rmsnorm output should only feed the mul in this pattern. 
+ if len(node.users) != 1: + continue + mul_node = next(iter(node.users)) + if not is_op(mul_node, torch.ops.aten.mul.Tensor): + continue + + lhs, rhs = mul_node.args + if lhs is node: + other = rhs + elif rhs is node: + other = lhs + else: + continue + + if not isinstance(other, Node) or not is_op(other, torch.ops.aten.silu.default): + continue + + gate_input = other.args[0] + if isinstance(gate_input, Node) and is_op(gate_input, torch.ops.aten.to.dtype): + if len(gate_input.args) < 2 or gate_input.args[1] != torch.float32: + continue + gate = gate_input.args[0] + gate_cast_node = gate_input + else: + gate = gate_input + gate_cast_node = None + + # Optional trailing cast back to bf16/fp16. + output_node = mul_node + trailing_cast_node = None + if len(mul_node.users) == 1: + only_user = next(iter(mul_node.users)) + if ( + is_op(only_user, torch.ops.aten.to.dtype) + and len(only_user.args) >= 2 + and only_user.args[0] is mul_node + and only_user.args[1] in (torch.bfloat16, torch.float16) + ): + output_node = only_user + trailing_cast_node = only_user + + # Infer group_size from normalized dimension (fallback-safe). 
+ x, weight, eps = node.args + group_size = None + if isinstance(weight, Node): + w_meta = weight.meta.get("val") if hasattr(weight, "meta") else None + if w_meta is not None and hasattr(w_meta, "numel"): + group_size = int(w_meta.numel()) + if group_size is None and isinstance(x, Node): + x_meta = x.meta.get("val") if hasattr(x, "meta") else None + if x_meta is not None and hasattr(x_meta, "shape"): + group_size = int(x_meta.shape[-1]) + if group_size is None: + continue + + with graph.inserting_after(output_node): + fused_node: Node = graph.call_function( + torch.ops.auto_deploy.triton_rmsnorm_gated, + args=(x, weight, gate, eps, group_size, True), + ) + + output_node.replace_all_uses_with(fused_node) + graph.erase_node(output_node) + cnt += 1 + + if ( + trailing_cast_node is not None + and trailing_cast_node is not output_node + and len(trailing_cast_node.users) == 0 + ): + graph.erase_node(trailing_cast_node) + if mul_node is not output_node and len(mul_node.users) == 0: + graph.erase_node(mul_node) + if len(other.users) == 0: + graph.erase_node(other) + if gate_cast_node is not None and len(gate_cast_node.users) == 0: + graph.erase_node(gate_cast_node) + if len(node.users) == 0: + graph.erase_node(node) + # Replace torch_rmsnorm ops with the selected backend for node in list(graph.nodes): if is_op(node, torch.ops.auto_deploy.torch_rmsnorm): diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 0f14c2434fb..7bd638c55da 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -53,6 +53,7 @@ is_any_lin_op, is_any_moe_op, is_any_ssm_op, + is_fake_quantized_linear_op, is_op, is_weight_node, num_users_of_weight_node, @@ -77,6 +78,26 @@ ######################################################## +######################################################## +# Helper functions 
+######################################################## +def is_quantized_linear_scale_tensor(node: "Node", weight_node_key: str) -> bool: + """Check if a weight node is a scale tensor for a quantized linear op. + + Scale tensors (e.g., weight_scale_inv for FineGrained FP8) are in "block space" and should + not be sharded with the same min_local_shape as the actual weight tensor. + They are handled separately by quantization_cb. + + Args: + node: The linear operation node + weight_node_key: The parameter key of the weight node (e.g., "model.layers.0.self_attn.v_proj.weight_scale_inv") + + Returns: + True if this is a scale tensor for a quantized linear op, False otherwise + """ + return is_fake_quantized_linear_op(node) and "_scale" in weight_node_key + + ######################################################## # Helper enums ######################################################## @@ -206,11 +227,6 @@ def _init_mapping(self): moe_cluster_size=1, ) - enable_attention_dp: bool = Field( - default=False, - description="When True, skip TP sharding as attention data parallelism is enabled.", - ) - def validate_config(self, sources: Union[ShardingSource, List[ShardingSource]] = None) -> bool: if sources is None: sources = [ShardingSource.FACTORY, ShardingSource.MANUAL] @@ -1547,6 +1563,14 @@ def _shard_parameter_node( ) for weight_node in weight_nodes.weights: + if is_quantized_linear_scale_tensor(node, weight_node.node_key): + # Scale tensors (e.g. weight_scale_inv) are sharded by + # quantization_cb (via QuantizationShardingMixin.shard_scales + + # shard_load_hook) when processing the main weight. Calling + # shard_weight_tensor here would register a SECOND load hook + # that double-shards the scale during checkpoint loading. 
+ continue + _, weight_new_shape = shard_weight_tensor( gm=gm, weight_tensor=weight_node.tensor, @@ -1808,9 +1832,10 @@ def get_partition(lst, world_size, rank): # Standard TP with all_reduce: # No attention-DP, so tokens are NOT distributed across ranks. # Just add all_reduce after MoE to sum TP partial results. + _, all_reduce_op = _get_dist_ops(config.dist_backend) with gm.graph.inserting_after(node): dist_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_reduce.default, + all_reduce_op, args=(node, allreduce_strategy), ) node.replace_all_uses_with(dist_node) @@ -1898,10 +1923,9 @@ def _insert_sharded_mxfp4_mlp_ep( node.args = args_ep # Add a dist all-reduce after the op (sum partial results across EP ranks) + _, all_reduce_op = _get_dist_ops(config.dist_backend) with gm.graph.inserting_after(node): - red = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_reduce, args=(node, config.allreduce_strategy.name) - ) + red = gm.graph.call_function(all_reduce_op, args=(node, config.allreduce_strategy.name)) node.replace_all_uses_with(red) # keep dataflow: red(input=node) red.replace_input_with(red, node) diff --git a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py index 66c93f075a7..a967f72d597 100644 --- a/tests/integration/defs/accuracy/test_llm_api_autodeploy.py +++ b/tests/integration/defs/accuracy/test_llm_api_autodeploy.py @@ -628,33 +628,19 @@ class TestQwen3_5_MoE(LlmapiAccuracyTestHarness): """ MODEL_NAME = "Qwen/Qwen3.5-397B-A17B" + MODEL_NAME_SMALL = "Qwen/Qwen3.5-35B-A3B" MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN, GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN) - def get_default_kwargs(self): - return { - "skip_tokenizer_init": False, - "trust_remote_code": True, - "enable_chunked_prefill": True, - "compile_backend": "torch-cudagraph", - "max_batch_size": 128, - "max_seq_len": self.MAX_SEQ_LEN, - "max_num_tokens": self.MAX_SEQ_LEN, - 
"cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64, 128], - "kv_cache_config": { - "enable_block_reuse": False, - "free_gpu_memory_fraction": 0.5, - "tokens_per_block": 64, - }, - "model_kwargs": { - "torch_dtype": "bfloat16", - }, - "transforms": { - "export_to_gm": { - "num_moe_experts_for_export": 2, - }, - }, - } + def _load_config(self): + """Load config from qwen3.5_moe_400b.yaml with test-specific overrides.""" + config = _load_ad_config('qwen3.5_moe_400b.yaml') + config.pop('world_size', None) + config['max_seq_len'] = self.MAX_SEQ_LEN + config['max_num_tokens'] = self.MAX_SEQ_LEN + config.setdefault('skip_tokenizer_init', False) + config.setdefault('trust_remote_code', True) + return config def get_default_sampling_params(self): eos_id = -1 @@ -669,7 +655,7 @@ def get_default_sampling_params(self): def test_bf16(self, world_size): if get_device_count() < world_size: pytest.skip("Not enough devices for world size, skipping test") - kwargs = self.get_default_kwargs() + kwargs = self._load_config() sampling_params = self.get_default_sampling_params() with AutoDeployLLM(model=self.MODEL_NAME, tokenizer=self.MODEL_NAME, @@ -681,6 +667,28 @@ def test_bf16(self, world_size): task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + @staticmethod + def _load_small_config(): + config = _load_ad_config('qwen3.5_moe_35b.yaml') + world_size = config.pop('world_size', 1) + return config, world_size + + @pytest.mark.skip_less_device_memory(80000) + def test_bf16_small(self): + config, world_size = self._load_small_config() + if get_device_count() < world_size: + pytest.skip("Not enough devices for world size, skipping test") + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_NAME_SMALL, + tokenizer=self.MODEL_NAME_SMALL, + dtype="bfloat16", + world_size=world_size, + **config) as llm: + task = MMLU(self.MODEL_NAME_SMALL) + task.evaluate(llm, sampling_params=sampling_params) + task = GSM8K(self.MODEL_NAME_SMALL) + task.evaluate(llm) + 
class TestMiniMaxM2(LlmapiAccuracyTestHarness): """Accuracy regression tests for MiniMax M2. @@ -733,6 +741,60 @@ def test_finegrained_fp8(self): task.evaluate(llm) +class TestKimiK2_5(LlmapiAccuracyTestHarness): + """Accuracy regression tests for Kimi-K2.5 (moonshotai/Kimi-K2.5) via AutoDeploy. + + Runs the NVFP4 model via AutoDeploy and verifies benchmark performance on MMLU and GSM8K. + Configuration from examples/auto_deploy/model_registry/configs/kimi_k2.yaml. + """ + + MODEL_NAME = "moonshotai/Kimi-K2.5" + MODEL_PATH = f"{llm_models_root()}/Kimi-K2.5-NVFP4" + CONFIG_YAML = str(_AD_CONFIGS_DIR / "kimi_k2.yaml") + + def get_default_sampling_params(self): + eos_id = -1 + beam_width = 1 + return SamplingParams(end_id=eos_id, + pad_id=eos_id, + n=beam_width, + use_beam_search=beam_width > 1) + + @skip_pre_blackwell + @pytest.mark.skip_less_device_memory(120000) + @pytest.mark.skip_less_device(8) + @pytest.mark.parametrize( + "ep_size,attention_dp", + [(1, False), (1, True), (8, False), (8, True)], + ids=["tp8", "tp8_attn_dp", "ep8", "dep8"], + ) + def test_nvfp4(self, ep_size, attention_dp): + if get_device_count() < 8: + pytest.skip("Not enough devices for world size 8, skipping test") + config = _load_ad_config("kimi_k2.yaml") + config["world_size"] = 8 + kwargs = {k: v for k, v in config.items() if k != "world_size"} + kwargs.setdefault("transforms", {})["detect_sharding"] = { + "enable_attention_dp": attention_dp, + "dist_mapping": { + "tp": 8, + "moe_ep": ep_size + }, + } + sampling_params = self.get_default_sampling_params() + with AutoDeployLLM(model=self.MODEL_PATH, + tokenizer=self.MODEL_PATH, + world_size=8, + yaml_extra=[self.CONFIG_YAML], + trust_remote_code=True, + **kwargs) as llm: + _set_quant_config(llm, "nvfp4") + task = MMLU(self.MODEL_NAME) + task.evaluate(llm, sampling_params=sampling_params) + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) + + class TestModelRegistryAccuracy(LlmapiAccuracyTestHarness): """Accuracy tests for models from the 
AutoDeploy model registry. @@ -822,69 +884,3 @@ def test_autodeploy_from_registry(self, model_name, config_overrides, tasks, task.evaluate(llm, sampling_params=sampling_params) except (AssertionError, RuntimeError, ValueError) as e: raise type(e)(f"[{task_cls.__name__}] {e}") from None - - -class TestKimiK2_5(LlmapiAccuracyTestHarness): - """Accuracy regression tests for Kimi-K2.5 via AutoDeploy. - - Runs the model via AutoDeploy and verifies benchmark performance on MMLU and GSM8K. - Configuration derived from examples/auto_deploy/model_registry/configs/kimi_k2.yaml. - """ - - MODEL_NAME = "nvidia/Kimi-K2.5-NVFP4" - MAX_SEQ_LEN = max(MMLU.MAX_INPUT_LEN + MMLU.MAX_OUTPUT_LEN, - GSM8K.MAX_INPUT_LEN + GSM8K.MAX_OUTPUT_LEN) - - def get_default_kwargs(self): - return { - "skip_tokenizer_init": False, - "trust_remote_code": True, - "enable_chunked_prefill": True, - "compile_backend": "torch-cudagraph", - "max_batch_size": 64, - "max_seq_len": self.MAX_SEQ_LEN, - "max_num_tokens": self.MAX_SEQ_LEN, - "cuda_graph_batch_sizes": [1, 2, 4, 8, 16, 32, 64], - "kv_cache_config": { - "dtype": "bfloat16", - "enable_block_reuse": False, - "free_gpu_memory_fraction": 0.7, - "tokens_per_block": 64, - }, - "model_kwargs": { - "torch_dtype": "bfloat16", - }, - "transforms": { - "export_to_gm": { - "num_moe_experts_for_export": 2, - }, - "fuse_nvfp4_moe": { - "allow_different_input_scales": True, - }, - }, - } - - def get_default_sampling_params(self): - eos_id = -1 - beam_width = 1 - return SamplingParams(end_id=eos_id, - pad_id=eos_id, - n=beam_width, - use_beam_search=beam_width > 1) - - @pytest.mark.skip_less_device_memory(180000) - @pytest.mark.parametrize("world_size", [8]) - def test_nvfp4(self, world_size): - if get_device_count() < world_size: - pytest.skip("Not enough devices for world size, skipping test") - kwargs = self.get_default_kwargs() - sampling_params = self.get_default_sampling_params() - with AutoDeployLLM(model=self.MODEL_NAME, - tokenizer=self.MODEL_NAME, - 
dtype="bfloat16", - world_size=world_size, - **kwargs) as llm: - task = MMLU(self.MODEL_NAME) - task.evaluate(llm, sampling_params=sampling_params) - task = GSM8K(self.MODEL_NAME) - task.evaluate(llm) diff --git a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py index d37dc522bad..c7346e79e06 100644 --- a/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/auto_deploy/_utils_test/_model_test_utils.py @@ -465,8 +465,8 @@ def apply_rotary_pos_emb_ds(q, k, cos, sin, position_ids, unsqueeze_dim=1): "model_kwargs": { "first_k_dense_replace": 1, "num_hidden_layers": 2, - "hidden_size": 32, - "intermediate_size": 64, + "hidden_size": 128, + "intermediate_size": 128, "kv_lora_rank": 512, # NOTE: must be 512 (default) for flashinfer_mla to work "qk_rope_head_dim": 64, # NOTE: must be 64 (default) for flashinfer_mla to work "moe_intermediate_size": 128, diff --git a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py index 9e50ac3f106..5ec9ce91900 100644 --- a/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py +++ b/tests/unittest/auto_deploy/multigpu/transformations/library/test_tp_sharding.py @@ -292,14 +292,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k_conv = k_conv.reshape(b, s, k_conv.shape[-1] // self.head_k_dim, self.head_k_dim) v_conv = v_conv.reshape(b, s, v_conv.shape[-1] // self.head_v_dim, self.head_v_dim) - # Repeat q, k to match num_v_heads when num_v_heads > num_k_heads (GQA) - if self.num_v_heads // self.num_k_heads > 1: - q_conv = q_conv.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - k_conv = k_conv.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - - beta = b_tensor.sigmoid() - g = -self.A_log.exp() * F.softplus(a + self.dt_bias) - attn_out = 
torch.ops.auto_deploy.torch_gated_delta_rule(q_conv, k_conv, v_conv, g, beta) + # L2 norm, GQA repeat-interleave, and gating are handled inside the op + attn_out = torch.ops.auto_deploy.torch_gated_delta_rule( + q_conv, k_conv, v_conv, a, b_tensor, self.A_log, self.dt_bias + ) # Gated norm on head_v_dim, then project attn_out_flat = attn_out.reshape(-1, self.head_v_dim) @@ -378,17 +374,12 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k_conv = k_conv.reshape(b, s, k_conv.shape[-1] // self.head_k_dim, self.head_k_dim) v_conv = v_conv.reshape(b, s, v_conv.shape[-1] // self.head_v_dim, self.head_v_dim) - # 4. Repeat q, k to match num_v_heads when num_v_heads > num_k_heads (GQA) - if self.num_v_heads // self.num_k_heads > 1: - q_conv = q_conv.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - k_conv = k_conv.repeat_interleave(self.num_v_heads // self.num_k_heads, dim=2) - - # 5. Gated delta rule - beta = b_tensor.sigmoid() - g = -self.A_log.exp() * F.softplus(a + self.dt_bias) - attn_out = torch.ops.auto_deploy.torch_gated_delta_rule(q_conv, k_conv, v_conv, g, beta) + # 4. Gated delta rule (L2 norm, GQA expand, gating handled inside the op) + attn_out = torch.ops.auto_deploy.torch_gated_delta_rule( + q_conv, k_conv, v_conv, a, b_tensor, self.A_log, self.dt_bias + ) - # 6. Gated norm on head_v_dim, then project + # 5. 
Gated norm on head_v_dim, then project attn_out_flat = attn_out.reshape(-1, self.head_v_dim) z_flat = z.reshape(-1, self.head_v_dim) normed = self.norm(attn_out_flat) * z_flat @@ -510,7 +501,6 @@ def _run_sharding_execution_job( model = model_cls(num_features, num_features, bias=bias).to( device="cuda", dtype=torch.float16 ) - # update the tp_plan in predefined_config to force simple sharding of the single linear layer predefined_config = {"tp_plan": {"*": "gather"}} elif model_cls == FineGrainedFP8MLP: # FineGrainedFP8MLP needs features divisible by 128 (block size) @@ -733,7 +723,6 @@ def _run_pattern_detection_job( model = model_cls(num_features, num_features, bias=bias).to( device="cuda", dtype=torch.float16 ) - # update the tp_plan in predefined_config to force simple sharding of the single linear layer predefined_config = {"tp_plan": {"*": "gather"}} elif model_cls == FineGrainedFP8MLP: # FineGrainedFP8MLP needs features divisible by 128 (block size) diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_fla_cached_gated_delta_rule.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_fla_cached_gated_delta_rule.py index def1f18c326..f30fbb2eb64 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_fla_cached_gated_delta_rule.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_fla_cached_gated_delta_rule.py @@ -3,16 +3,22 @@ Covers: - Decode-only path: batch of single tokens, verify output and final state match ``fused_recurrent_gated_delta_rule_fwd`` called directly with gathered - initial states. + initial states and manually preprocessed inputs. - Prefill-only path: batch of multi-token sequences (variable length), verify output and final state match per-sequence ``chunk_gated_delta_rule``. - Prefill with initial state: same as prefill but with ``use_initial_states=True`` and non-zero initial cache, verifying the cache history is correctly loaded and passed to the kernel. 
+- GVA (Grouped Value Attention): q/k have fewer heads than v (a/b/A_log/dt_bias). + +The cached op accepts raw (un-normalized, un-expanded) q/k with raw gating +projections (a, b, A_log, dt_bias). L2 normalization, GQA expansion, and +gating are performed internally. """ import pytest import torch +import torch.nn.functional as F # Register all auto_deploy custom ops import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 @@ -30,25 +36,45 @@ def gdr_env(): return {"device": device, "dtype": dtype} -def _random_inputs(device, dtype, batch, seq, num_heads, key_dim, value_dim): - """Generate random gated delta rule inputs.""" - q = torch.randn(batch, seq, num_heads, key_dim, device=device, dtype=dtype) - k = torch.randn(batch, seq, num_heads, key_dim, device=device, dtype=dtype) - v = torch.randn(batch, seq, num_heads, value_dim, device=device, dtype=dtype) - g = -torch.rand(batch, seq, num_heads, device=device, dtype=dtype) # negative (decay) - beta = torch.sigmoid(torch.randn(batch, seq, num_heads, device=device, dtype=dtype)) +def _random_inputs(device, dtype, batch, seq, num_k_heads, num_v_heads, key_dim, value_dim): + """Generate random gated delta rule inputs (raw, un-preprocessed). + + q/k have num_k_heads, v/a/b have num_v_heads (GVA when num_v_heads > num_k_heads). + A_log, dt_bias have shape [num_v_heads]. 
+ """ + q = torch.randn(batch, seq, num_k_heads, key_dim, device=device, dtype=dtype) + k = torch.randn(batch, seq, num_k_heads, key_dim, device=device, dtype=dtype) + v = torch.randn(batch, seq, num_v_heads, value_dim, device=device, dtype=dtype) + a = torch.randn(batch, seq, num_v_heads, device=device, dtype=dtype) + b = torch.randn(batch, seq, num_v_heads, device=device, dtype=dtype) + A_log = torch.zeros(num_v_heads, device=device, dtype=dtype) + dt_bias = torch.zeros(num_v_heads, device=device, dtype=dtype) + return q, k, v, a, b, A_log, dt_bias - # L2 normalize Q and K as the patched forward does - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) - return q, k, v, g, beta +def _preprocess_for_reference(q, k, a, b_proj, A_log, dt_bias, num_v_heads=None): + """Manually preprocess raw inputs for use with reference kernels. -def test_decode_only(gdr_env): + If num_v_heads is given and greater than q.shape[2], GQA-expands q and k + to num_v_heads. Returns q_norm, k_norm, g, beta ready for chunk/fused_recurrent. + """ + q_norm = F.normalize(q.float(), dim=-1).to(q.dtype) + k_norm = F.normalize(k.float(), dim=-1).to(k.dtype) + g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) + beta = b_proj.float().sigmoid() + if num_v_heads is not None and num_v_heads > q.shape[2]: + interleave = num_v_heads // q.shape[2] + q_norm = q_norm.repeat_interleave(interleave, dim=2) + k_norm = k_norm.repeat_interleave(interleave, dim=2) + return q_norm, k_norm, g, beta + + +@pytest.mark.parametrize("num_k_heads,num_v_heads", [(2, 2), (2, 4)]) +def test_decode_only(gdr_env, num_k_heads, num_v_heads): """Decode-only: batch of single tokens through the cached op. Verifies output and cache state match fused_recurrent_gated_delta_rule_fwd - called directly with gathered initial states. + called directly with gathered initial states and preprocessed inputs. 
""" device = gdr_env["device"] dtype = gdr_env["dtype"] @@ -57,79 +83,80 @@ def test_decode_only(gdr_env): batch = 3 seq = 1 - num_heads = 2 key_dim = 8 value_dim = 8 max_batch_size = 6 scale = key_dim**-0.5 - q, k, v, g, beta = _random_inputs(device, dtype, batch, seq, num_heads, key_dim, value_dim) + q, k, v, a, b, A_log, dt_bias = _random_inputs( + device, dtype, batch, seq, num_k_heads, num_v_heads, key_dim, value_dim + ) - # Slot mapping with arbitrary order slot_idx = torch.tensor([4, 1, 3], device=device, dtype=torch.int32) # Initialize cache with random state (simulating existing history) + # Cache shape uses num_v_heads (HV), not num_k_heads delta_cache = torch.randn( max_batch_size, - num_heads, + num_v_heads, key_dim, value_dim, device=device, dtype=dtype, ) - # Metadata for decode-only: no prefill batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) cu_seqlen = torch.zeros(1, device=device, dtype=torch.int32) use_initial_states = torch.ones(batch, device=device, dtype=torch.bool) + any_prefill_use_initial_states_host = torch.tensor( + [False], device=device, dtype=torch.bool + ) # decode-only, no prefill - # Snapshot cache before mutation for reference gathered_before = delta_cache.clone().index_select(0, slot_idx.long()) - # Run cached op y = torch.ops.auto_deploy.fla_cached_gated_delta_rule( q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, delta_cache, scale, ) - assert y.shape == (batch, seq, num_heads, value_dim) + assert y.shape == (batch, seq, num_v_heads, value_dim) assert torch.isfinite(y).all() - # Reference: call fused_recurrent_gated_delta_rule_fwd directly - # The cached op internally does: - # q_flat[num_prefill_tokens:, None] -> [num_decode, 1, H, K] - # So we reshape our inputs similarly - q_flat = q.view(batch, num_heads, -1) # [B, H, K] - k_flat = k.view(batch, num_heads, -1) # [B, H, K] - v_flat = v.view(batch, 
num_heads, -1) # [B, H, V] - g_flat = g.view(batch, num_heads) # [B, H] - beta_flat = beta.view(batch, num_heads) # [B, H] + # Reference: preprocess (L2 + GQA expand) and call fused_recurrent directly + q_norm, k_norm, g, beta = _preprocess_for_reference(q, k, a, b, A_log, dt_bias, num_v_heads) + q_flat = q_norm.view(batch, num_v_heads, -1) + k_flat = k_norm.view(batch, num_v_heads, -1) + v_flat = v.view(batch, num_v_heads, -1) + g_flat = g.view(batch, num_v_heads) + beta_flat = beta.view(batch, num_v_heads) y_ref, final_state_ref = fused_recurrent_gated_delta_rule_fwd( - q=q_flat[:, None], # [B, 1, H, K] - k=k_flat[:, None], # [B, 1, H, K] - v=v_flat[:, None], # [B, 1, H, V] - g=g_flat[:, None], # [B, 1, H] - beta=beta_flat[:, None], # [B, 1, H] + q=q_flat[:, None], + k=k_flat[:, None], + v=v_flat[:, None], + g=g_flat[:, None], + beta=beta_flat[:, None], scale=scale, initial_state=gathered_before.clone(), output_final_state=True, + use_qk_l2norm_in_kernel=False, # already normalized in _preprocess_for_reference ) - # y_ref shape: [B, 1, H, V] -> compare against y: [B, 1, H, V] y_ref_reshaped = y_ref.to(dtype) torch.testing.assert_close(y, y_ref_reshaped, atol=atol, rtol=rtol) - # Compare updated cache states after = delta_cache.index_select(0, slot_idx.long()) torch.testing.assert_close( after, @@ -139,10 +166,12 @@ def test_decode_only(gdr_env): ) -def test_prefill_only(gdr_env): +@pytest.mark.parametrize("num_k_heads,num_v_heads", [(2, 2), (2, 4)]) +def test_prefill_only(gdr_env, num_k_heads, num_v_heads): """Prefill-only: two sequences of different lengths, flattened. - Verifies output and final state match per-sequence chunk_gated_delta_rule. + Verifies output and final state match per-sequence chunk_gated_delta_rule + with use_qk_l2norm_in_kernel=True and manually computed g/beta. 
""" device = gdr_env["device"] dtype = gdr_env["dtype"] @@ -151,34 +180,33 @@ def test_prefill_only(gdr_env): seq_lens = [5, 3] total_tokens = sum(seq_lens) - num_heads = 2 key_dim = 8 value_dim = 8 max_batch_size = 4 scale = key_dim**-0.5 - # Create flattened inputs: [1, total_tokens, H, D] - q, k, v, g, beta = _random_inputs( + q, k, v, a, b, A_log, dt_bias = _random_inputs( device, dtype, 1, total_tokens, - num_heads, + num_k_heads, + num_v_heads, key_dim, value_dim, ) slot_idx = torch.tensor([2, 0], device=device, dtype=torch.int32) + # Cache shape uses num_v_heads delta_cache = torch.zeros( max_batch_size, - num_heads, + num_v_heads, key_dim, value_dim, device=device, dtype=dtype, ) - # Metadata for prefill-only num_prefill = len(seq_lens) batch_info_host = torch.tensor( [num_prefill, total_tokens, 0], @@ -187,46 +215,52 @@ def test_prefill_only(gdr_env): ) cu_seqlen = torch.tensor([0, seq_lens[0], total_tokens], device=device, dtype=torch.int32) use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool) + any_prefill_use_initial_states_host = torch.tensor( + [False], device=device, dtype=torch.bool + ) # no prefill seq uses initial state - # Run cached op y = torch.ops.auto_deploy.fla_cached_gated_delta_rule( q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, delta_cache, scale, ) - assert y.shape == (1, total_tokens, num_heads, value_dim) + assert y.shape == (1, total_tokens, num_v_heads, value_dim) assert torch.isfinite(y).all() - # Reference: call chunk_gated_delta_rule per sequence + # Reference: compute g/beta and call chunk per sequence with l2norm + q_norm, k_norm, g, beta = _preprocess_for_reference(q, k, a, b, A_log, dt_bias, num_v_heads) + y_ref = torch.empty_like(y) for i, sl in enumerate(seq_lens): start = sum(seq_lens[:i]) end = start + sl - # chunk_gated_delta_rule expects [B, T, H, D] layout y_seq, final_state = 
chunk_gated_delta_rule( - q=q[:, start:end], - k=k[:, start:end], + q=q_norm[:, start:end], + k=k_norm[:, start:end], v=v[:, start:end], g=g[:, start:end], beta=beta[:, start:end], scale=scale, initial_state=None, output_final_state=True, + use_qk_l2norm_in_kernel=False, # already normalized ) y_ref[:, start:end] = y_seq.to(dtype) - # Verify cache was updated for this slot torch.testing.assert_close( delta_cache[slot_idx[i].long()], final_state.squeeze(0).to(delta_cache.dtype), @@ -237,7 +271,8 @@ def test_prefill_only(gdr_env): torch.testing.assert_close(y, y_ref, atol=atol, rtol=rtol) -def test_prefill_with_initial_state(gdr_env): +@pytest.mark.parametrize("num_k_heads,num_v_heads", [(2, 2), (2, 4)]) +def test_prefill_with_initial_state(gdr_env, num_k_heads, num_v_heads): """Prefill with initial state: verifies cache history is correctly loaded. Sets use_initial_states=True and a non-zero initial cache, then checks that @@ -250,57 +285,65 @@ def test_prefill_with_initial_state(gdr_env): rtol = 5e-3 seq_len = 4 - num_heads = 2 key_dim = 8 value_dim = 8 max_batch_size = 2 scale = key_dim**-0.5 - q, k, v, g, beta = _random_inputs(device, dtype, 1, seq_len, num_heads, key_dim, value_dim) + q, k, v, a, b, A_log, dt_bias = _random_inputs( + device, dtype, 1, seq_len, num_k_heads, num_v_heads, key_dim, value_dim + ) slot_idx = torch.tensor([1], device=device, dtype=torch.int32) - # Non-zero initial state in cache + # Non-zero initial state in cache (uses num_v_heads) delta_cache = torch.randn( max_batch_size, - num_heads, + num_v_heads, key_dim, value_dim, device=device, dtype=dtype, ) - initial_state = delta_cache[1].clone() # snapshot + initial_state = delta_cache[1].clone() - # Metadata: one prefill sequence with initial state batch_info_host = torch.tensor([1, seq_len, 0], device=device, dtype=torch.int32) cu_seqlen = torch.tensor([0, seq_len], device=device, dtype=torch.int32) use_initial_states = torch.tensor([True], device=device, dtype=torch.bool) + 
any_prefill_use_initial_states_host = torch.tensor( + [True], device=device, dtype=torch.bool + ) # one prefill seq with initial state - # Run cached op WITH initial state y_with_init = torch.ops.auto_deploy.fla_cached_gated_delta_rule( q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, delta_cache, scale, ) - # Reference: chunk_gated_delta_rule with the same initial state + # Reference: compute g/beta and call chunk with l2norm and initial state + q_norm, k_norm, g, beta = _preprocess_for_reference(q, k, a, b, A_log, dt_bias, num_v_heads) + y_ref, final_ref = chunk_gated_delta_rule( - q=q, - k=k, + q=q_norm, + k=k_norm, v=v, g=g, beta=beta, scale=scale, initial_state=initial_state.unsqueeze(0), output_final_state=True, + use_qk_l2norm_in_kernel=False, # already normalized in _preprocess_for_reference ) torch.testing.assert_close(y_with_init, y_ref.to(dtype), atol=atol, rtol=rtol) @@ -311,25 +354,30 @@ def test_prefill_with_initial_state(gdr_env): rtol=rtol, ) - # Also verify it's different from running WITHOUT initial state (zero state) + # Verify it differs from running WITHOUT initial state delta_cache_zero = torch.zeros_like(delta_cache) use_initial_states_false = torch.tensor([False], device=device, dtype=torch.bool) + any_prefill_use_initial_states_host_false = torch.tensor( + [False], device=device, dtype=torch.bool + ) y_without_init = torch.ops.auto_deploy.fla_cached_gated_delta_rule( q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, use_initial_states_false, + any_prefill_use_initial_states_host_false, delta_cache_zero, scale, ) - # The results should differ when there's a non-zero initial state assert not torch.allclose(y_with_init, y_without_init, atol=1e-3, rtol=1e-3), ( "Output with initial state should differ from output without initial state" ) diff --git 
a/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_fused_gdn_gating.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_fused_gdn_gating.py new file mode 100644 index 00000000000..6e5cc747058 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_fused_gdn_gating.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the fused GDN gating custom ops (torch + triton).""" + +import pytest +import torch +import torch.nn.functional as F + +from tensorrt_llm._torch.auto_deploy.custom_ops.fla import ( + gdn_gating as _gdn_gating_ops, # noqa: F401 +) + + +def _reference_gdn_gating(A_log, a, dt_bias, beta=1.0, threshold=20.0): + """Pure-torch reference: g = -exp(A_log) * softplus(a + dt_bias).""" + return -torch.exp(A_log.float()) * F.softplus(a.float() + dt_bias.float(), beta, threshold) + + +@pytest.mark.parametrize("batch_size", [1, 4]) +@pytest.mark.parametrize("seq_len", [1, 8, 32]) +@pytest.mark.parametrize("num_heads", [8, 16, 64]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +class TestFusedGdnGating: + """Tests for torch_fused_gdn_gating and triton_fused_gdn_gating ops.""" + + def test_torch_op_matches_reference(self, batch_size, seq_len, num_heads, dtype): + """torch_fused_gdn_gating matches the inline reference computation.""" + A_log = torch.randn(num_heads, device="cuda", dtype=dtype) + a = torch.randn(batch_size, seq_len, num_heads, device="cuda", dtype=dtype) + dt_bias = torch.randn(num_heads, device="cuda", dtype=dtype) + + ref = _reference_gdn_gating(A_log, a, dt_bias) + out = torch.ops.auto_deploy.torch_fused_gdn_gating(A_log, a, dt_bias) + + assert out.dtype == torch.float32 + assert out.shape == (batch_size, seq_len, num_heads) + torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5) + + def test_triton_op_matches_reference(self, batch_size, seq_len, num_heads, dtype): + """triton_fused_gdn_gating matches the inline reference computation.""" + A_log = torch.randn(num_heads, device="cuda", dtype=dtype) + a = torch.randn(batch_size, seq_len, num_heads, device="cuda", dtype=dtype) + dt_bias = torch.randn(num_heads, device="cuda", dtype=dtype) + + ref = _reference_gdn_gating(A_log, a, dt_bias) + out = torch.ops.auto_deploy.triton_fused_gdn_gating(A_log, a, dt_bias) + + assert out.dtype == torch.float32 + assert out.shape == 
(batch_size, seq_len, num_heads) + torch.testing.assert_close(out, ref, atol=1e-5, rtol=1e-5) + + def test_torch_and_triton_match(self, batch_size, seq_len, num_heads, dtype): + """Torch and triton ops produce identical results.""" + A_log = torch.randn(num_heads, device="cuda", dtype=dtype) + a = torch.randn(batch_size, seq_len, num_heads, device="cuda", dtype=dtype) + dt_bias = torch.randn(num_heads, device="cuda", dtype=dtype) + + out_torch = torch.ops.auto_deploy.torch_fused_gdn_gating(A_log, a, dt_bias) + out_triton = torch.ops.auto_deploy.triton_fused_gdn_gating(A_log, a, dt_bias) + + torch.testing.assert_close(out_torch, out_triton, atol=1e-5, rtol=1e-5) + + +@pytest.mark.parametrize("beta,threshold", [(1.0, 20.0), (2.0, 10.0)]) +def test_softplus_parameters(beta, threshold): + """Verify that non-default softplus beta/threshold are respected.""" + num_heads = 16 + A_log = torch.randn(num_heads, device="cuda", dtype=torch.float16) + a = torch.randn(2, 4, num_heads, device="cuda", dtype=torch.float16) + dt_bias = torch.randn(num_heads, device="cuda", dtype=torch.float16) + + ref = _reference_gdn_gating(A_log, a, dt_bias, beta, threshold) + out_torch = torch.ops.auto_deploy.torch_fused_gdn_gating(A_log, a, dt_bias, beta, threshold) + out_triton = torch.ops.auto_deploy.triton_fused_gdn_gating(A_log, a, dt_bias, beta, threshold) + + torch.testing.assert_close(out_torch, ref, atol=1e-5, rtol=1e-5) + torch.testing.assert_close(out_triton, ref, atol=1e-5, rtol=1e-5) diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_torch_cached_gated_delta_rule.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_torch_cached_gated_delta_rule.py index fe0b4c0030a..e5c2028da51 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_torch_cached_gated_delta_rule.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/fla/test_torch_cached_gated_delta_rule.py @@ -2,15 +2,20 @@ Covers: - Decode-only path: batch of single tokens, verify output 
and final state - match ``_torch_gated_delta_step`` called directly. + match ``_torch_gated_delta_step`` called directly with preprocessed inputs. - Prefill-only path: batch of multi-token sequences, verify output and final - state match ``_torch_gated_delta_prefill``. + state match ``_torch_gated_delta_prefill`` with preprocessed inputs. - Prefill with initial state: same as prefill but with ``use_initial_states=True`` and non-zero initial cache, verifying the cache history is correctly loaded. + +The cached op accepts raw (un-normalized, un-expanded) q/k with raw gating +projections (a, b, A_log, dt_bias). L2 normalization, GQA expansion, and +gating are performed internally. """ import pytest import torch +import torch.nn.functional as F # Register all auto_deploy custom ops import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 @@ -30,23 +35,31 @@ def gated_delta_env(): def _random_inputs(device, dtype, batch, seq, num_heads, key_dim, value_dim): - """Generate random gated delta rule inputs.""" + """Generate random gated delta rule inputs (raw, un-preprocessed).""" q = torch.randn(batch, seq, num_heads, key_dim, device=device, dtype=dtype) k = torch.randn(batch, seq, num_heads, key_dim, device=device, dtype=dtype) v = torch.randn(batch, seq, num_heads, value_dim, device=device, dtype=dtype) - g = -torch.rand(batch, seq, num_heads, device=device, dtype=dtype) # negative (decay) - beta = torch.sigmoid(torch.randn(batch, seq, num_heads, device=device, dtype=dtype)) + a = torch.randn(batch, seq, num_heads, device=device, dtype=dtype) + b = torch.randn(batch, seq, num_heads, device=device, dtype=dtype) + A_log = torch.zeros(num_heads, device=device, dtype=dtype) + dt_bias = torch.zeros(num_heads, device=device, dtype=dtype) + return q, k, v, a, b, A_log, dt_bias + - # L2 normalize Q and K as the patched forward does - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) - return q, k, v, g, beta +def 
_preprocess_for_reference(q, k, a, b_proj, A_log, dt_bias): + """Manually preprocess raw inputs for use with reference helpers.""" + q_norm = F.normalize(q.float(), dim=-1).to(q.dtype) + k_norm = F.normalize(k.float(), dim=-1).to(k.dtype) + g = -A_log.float().exp() * F.softplus(a.float() + dt_bias) + beta = b_proj.float().sigmoid() + return q_norm, k_norm, g, beta def test_decode_only(gated_delta_env): """Decode-only: batch of single tokens through the cached op. - Verifies output and cache state match _torch_gated_delta_step. + Verifies output and cache state match _torch_gated_delta_step with + manually preprocessed inputs. """ device = gated_delta_env["device"] dtype = gated_delta_env["dtype"] @@ -59,12 +72,18 @@ def test_decode_only(gated_delta_env): max_batch_size = 6 scale = key_dim**-0.5 - q, k, v, g, beta = _random_inputs(device, dtype, batch, seq, num_heads, key_dim, value_dim) + q, k, v, a, b, A_log, dt_bias = _random_inputs( + device, + dtype, + batch, + seq, + num_heads, + key_dim, + value_dim, + ) - # Slot mapping with arbitrary order slot_idx = torch.tensor([5, 1, 3, 0], device=device, dtype=torch.int32) - # Initialize cache with random state (simulating existing history) delta_cache = torch.randn( max_batch_size, num_heads, @@ -74,21 +93,20 @@ def test_decode_only(gated_delta_env): dtype=torch.float32, ) - # Metadata for decode-only: no prefill batch_info_host = torch.tensor([0, 0, batch], device=device, dtype=torch.int32) cu_seqlen = torch.zeros(1, device=device, dtype=torch.int32) use_initial_states = torch.ones(batch, device=device, dtype=torch.bool) - # Snapshot cache before mutation for reference gathered_before = delta_cache.clone().index_select(0, slot_idx.long()) - # Run cached op y = torch.ops.auto_deploy.torch_cached_gated_delta_rule( q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, @@ -100,29 +118,29 @@ def test_decode_only(gated_delta_env): assert y.shape == (batch, seq, num_heads, value_dim) 
assert torch.isfinite(y).all() - # Reference: call _torch_gated_delta_step directly per sequence + # Reference: preprocess and call _torch_gated_delta_step directly + q_norm, k_norm, g, beta = _preprocess_for_reference(q, k, a, b, A_log, dt_bias) + y_ref_list = [] final_state_ref_list = [] for i in range(batch): o_ref, s_ref = _torch_gated_delta_step( - q[i, 0].unsqueeze(0), # [1, H, K] - k[i, 0].unsqueeze(0), # [1, H, K] - v[i, 0].unsqueeze(0), # [1, H, V] - g[i, 0].unsqueeze(0), # [1, H] - beta[i, 0].unsqueeze(0), # [1, H] - gathered_before[i].unsqueeze(0), # [1, H, K, V] + q_norm[i, 0].unsqueeze(0), + k_norm[i, 0].unsqueeze(0), + v[i, 0].unsqueeze(0), + g[i, 0].unsqueeze(0), + beta[i, 0].unsqueeze(0), + gathered_before[i].unsqueeze(0), scale, ) - y_ref_list.append(o_ref.squeeze(0)) # [H, V] - final_state_ref_list.append(s_ref.squeeze(0)) # [H, K, V] + y_ref_list.append(o_ref.squeeze(0)) + final_state_ref_list.append(s_ref.squeeze(0)) - y_ref = torch.stack(y_ref_list, dim=0).unsqueeze(1).to(dtype) # [B, 1, H, V] - final_state_ref = torch.stack(final_state_ref_list, dim=0) # [B, H, K, V] + y_ref = torch.stack(y_ref_list, dim=0).unsqueeze(1).to(dtype) + final_state_ref = torch.stack(final_state_ref_list, dim=0) - # Compare outputs torch.testing.assert_close(y, y_ref, atol=1e-3, rtol=1e-3) - # Compare updated cache states after = delta_cache.index_select(0, slot_idx.long()) torch.testing.assert_close( after, @@ -135,7 +153,8 @@ def test_decode_only(gated_delta_env): def test_prefill_only(gated_delta_env): """Prefill-only: two sequences of different lengths, flattened. - Verifies output and final state match _torch_gated_delta_prefill. + Verifies output and final state match _torch_gated_delta_prefill with + manually preprocessed inputs. 
""" device = gated_delta_env["device"] dtype = gated_delta_env["dtype"] @@ -148,8 +167,7 @@ def test_prefill_only(gated_delta_env): max_batch_size = 4 scale = key_dim**-0.5 - # Create flattened inputs: [1, total_tokens, H, D] - q, k, v, g, beta = _random_inputs( + q, k, v, a, b, A_log, dt_bias = _random_inputs( device, dtype, 1, @@ -169,7 +187,6 @@ def test_prefill_only(gated_delta_env): dtype=torch.float32, ) - # Metadata for prefill-only num_prefill = len(seq_lens) batch_info_host = torch.tensor( [num_prefill, total_tokens, 0], @@ -179,13 +196,14 @@ def test_prefill_only(gated_delta_env): cu_seqlen = torch.tensor([0, seq_lens[0], total_tokens], device=device, dtype=torch.int32) use_initial_states = torch.zeros(num_prefill, device=device, dtype=torch.bool) - # Run cached op y = torch.ops.auto_deploy.torch_cached_gated_delta_rule( q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, @@ -197,18 +215,14 @@ def test_prefill_only(gated_delta_env): assert y.shape == (1, total_tokens, num_heads, value_dim) assert torch.isfinite(y).all() - # Reference: run _torch_gated_delta_prefill per sequence + # Reference: preprocess and call _torch_gated_delta_prefill per sequence + q_norm, k_norm, g, beta = _preprocess_for_reference(q, k, a, b, A_log, dt_bias) + y_ref = torch.empty_like(y) for i, sl in enumerate(seq_lens): start = sum(seq_lens[:i]) end = start + sl - q_seq = q[:, start:end] - k_seq = k[:, start:end] - v_seq = v[:, start:end] - g_seq = g[:, start:end] - beta_seq = beta[:, start:end] - init_state = torch.zeros( 1, num_heads, @@ -219,18 +233,17 @@ def test_prefill_only(gated_delta_env): ) y_seq, final_state = _torch_gated_delta_prefill( - q_seq, - k_seq, - v_seq, - g_seq, - beta_seq, + q_norm[:, start:end], + k_norm[:, start:end], + v[:, start:end], + g[:, start:end], + beta[:, start:end], scale, init_state, ) y_ref[:, start:end] = y_seq.to(dtype) - # Verify cache was updated for this slot torch.testing.assert_close( 
delta_cache[slot_idx[i].long()], final_state.squeeze(0).to(delta_cache.dtype), @@ -257,11 +270,18 @@ def test_prefill_with_initial_state(gated_delta_env): max_batch_size = 2 scale = key_dim**-0.5 - q, k, v, g, beta = _random_inputs(device, dtype, 1, seq_len, num_heads, key_dim, value_dim) + q, k, v, a, b, A_log, dt_bias = _random_inputs( + device, + dtype, + 1, + seq_len, + num_heads, + key_dim, + value_dim, + ) slot_idx = torch.tensor([1], device=device, dtype=torch.int32) - # Non-zero initial state in cache delta_cache = torch.randn( max_batch_size, num_heads, @@ -270,20 +290,20 @@ def test_prefill_with_initial_state(gated_delta_env): device=device, dtype=torch.float32, ) - initial_state = delta_cache[1].clone() # snapshot + initial_state = delta_cache[1].clone() - # Metadata: one prefill sequence with initial state batch_info_host = torch.tensor([1, seq_len, 0], device=device, dtype=torch.int32) cu_seqlen = torch.tensor([0, seq_len], device=device, dtype=torch.int32) use_initial_states = torch.tensor([True], device=device, dtype=torch.bool) - # Run cached op WITH initial state y_with_init = torch.ops.auto_deploy.torch_cached_gated_delta_rule( q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, @@ -292,10 +312,12 @@ def test_prefill_with_initial_state(gated_delta_env): scale, ) - # Reference: _torch_gated_delta_prefill with the same initial state + # Reference: preprocess and call _torch_gated_delta_prefill + q_norm, k_norm, g, beta = _preprocess_for_reference(q, k, a, b, A_log, dt_bias) + y_ref, final_ref = _torch_gated_delta_prefill( - q, - k, + q_norm, + k_norm, v, g, beta, @@ -311,7 +333,7 @@ def test_prefill_with_initial_state(gated_delta_env): rtol=1e-3, ) - # Also verify it's different from running WITHOUT initial state (zero state) + # Verify it differs from running WITHOUT initial state delta_cache_zero = torch.zeros_like(delta_cache) use_initial_states_false = torch.tensor([False], device=device, dtype=torch.bool) 
@@ -319,8 +341,10 @@ def test_prefill_with_initial_state(gated_delta_env): q, k, v, - g, - beta, + a, + b, + A_log, + dt_bias, batch_info_host, cu_seqlen, slot_idx, @@ -329,7 +353,6 @@ def test_prefill_with_initial_state(gated_delta_env): scale, ) - # The results should differ when there's a non-zero initial state assert not torch.allclose(y_with_init, y_without_init, atol=1e-3, rtol=1e-3), ( "Output with initial state should differ from output without initial state" ) diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py index f0dea7e343a..f963754ba4f 100644 --- a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_flashinfer_mamba_cached_op.py @@ -41,6 +41,7 @@ def test_flashinfer_decode_matches_triton(mamba_env): cu_seqlen = torch.zeros(batch + 1, device=device, dtype=torch.int32) use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) + any_prefill_use_initial_states_host = torch.tensor([False], device=device, dtype=torch.bool) y_triton = torch.ops.auto_deploy.triton_cached_ssm( hidden_states, A, @@ -54,6 +55,7 @@ def test_flashinfer_decode_matches_triton(mamba_env): cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, # EXTRA METADATA None, # chunk indices None, # chunk offsets @@ -78,6 +80,7 @@ def test_flashinfer_decode_matches_triton(mamba_env): cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, # EXTRA METADATA None, # chunk indices None, # chunk offsets diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py index b102bd5796c..604291fc2b2 100644 --- 
a/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/mamba/test_triton_mamba_cached_op.py @@ -56,6 +56,7 @@ def test_triton_generate_only_with_slot_mapping(mamba_env): seq_len = torch.ones(batch, device=device, dtype=torch.int32) cu_seqlen = torch.zeros(batch + 1, device=device, dtype=torch.int32) use_initial_states = torch.zeros(batch, device=device, dtype=torch.bool) + any_prefill_use_initial_states_host = torch.tensor([False], device=device, dtype=torch.bool) # Torch reference y_torch = torch.ops.auto_deploy.torch_cached_ssm( @@ -93,6 +94,7 @@ def test_triton_generate_only_with_slot_mapping(mamba_env): cu_seqlen, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, # EXTRA METADATA None, # chunk indices None, # chunk offsets @@ -141,6 +143,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): seq_len = torch.tensor(lens, device=device, dtype=torch.int32) cu_seqlen = torch.tensor([0, lens[0]], device=device, dtype=torch.int32) use_initial_states = torch.tensor([0] * batch, device=device).to(torch.bool) + any_prefill_use_initial_states_host = torch.tensor([False], device=device, dtype=torch.bool) cu_seqlens = torch.cat( [ torch.zeros(1, dtype=torch.int32, device=device), @@ -190,6 +193,7 @@ def test_triton_context_flattened_and_state_writeback(mamba_env): cu_seqlens, slot_idx, use_initial_states, + any_prefill_use_initial_states_host, # EXTRA METADATA None, # chunk indices None, # chunk offsets diff --git a/tests/unittest/auto_deploy/singlegpu/custom_ops/test_multi_stream_gemm.py b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_multi_stream_gemm.py new file mode 100644 index 00000000000..c5509b303ef --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/custom_ops/test_multi_stream_gemm.py @@ -0,0 +1,690 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the generalized multi-stream GEMM parallelization transform. + +The transform targets fork points where 2+ fp8 linear ops share the same input +and moves the **largest** (by weight shape) to the auxiliary CUDA stream. + +Architecture patterns tested: + + **Two-branch fork** -- Two linears of different sizes sharing one input. + The larger one should be moved to the aux stream. + + **Three-branch fork (MHA-like)** -- Three linears (q_proj, k_proj, v_proj) + with q_proj being the largest. + + **Four-branch fork (linear attention-like)** -- Four linears sharing one + input, with one being significantly larger than the others. + + **Skip already-handled** -- Fork points that already have multi-stream ops + are skipped to avoid conflicts. + + **Single linear** -- No fork point, should produce zero matches. + + **Equal-weight linears** -- All linears at a fork point have the same weight + size; one should still be selected deterministically. + +Each pattern is tested for: + 1. Pattern matching -- correct number of fork points found. + 2. Largest GEMM identification -- correct linear moved to aux stream. + 3. Graph structure -- ``record_event_passthrough`` and aux nodes present. + 4. Numerical correctness -- output matches eager reference. + 5. CUDA graph compatibility -- capture + replay produces correct output. + 6. 
Multi-layer stacking -- multiple fork points handled independently. +""" + +from typing import Optional + +import torch +import torch.nn as nn + +from tensorrt_llm._torch.auto_deploy.transform.library.multi_stream_gemm import ( + _estimate_weight_size, + _find_gemm_fork_points, + _parallelize_largest_gemm, +) +from tensorrt_llm._torch.auto_deploy.utils.multi_stream_utils import ( + cuda_stream_manager, + record_event_passthrough, +) + +# --------------------------------------------------------------------------- +# Mock fp8 linear custom op -- mimics trtllm_finegrained_fp8_linear signature +# --------------------------------------------------------------------------- + + +@torch.library.custom_op("auto_deploy::mock_fp8_linear_gemm_test", mutates_args=()) +def mock_fp8_linear( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + weight_scale: torch.Tensor, +) -> torch.Tensor: + """Mock fp8 linear: simple matmul standing in for the real kernel.""" + out = input @ weight.to(input.dtype).t() + if bias is not None: + out = out + bias + return out + + +@mock_fp8_linear.register_fake +def _mock_fp8_linear_fake( + input: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor], + weight_scale: torch.Tensor, +) -> torch.Tensor: + out_features = weight.shape[0] + return torch.empty((*input.shape[:-1], out_features), dtype=input.dtype, device=input.device) + + +_MOCK_OPS = [torch.ops.auto_deploy.mock_fp8_linear_gemm_test] + + +# --------------------------------------------------------------------------- +# Helper: wrap nn.Linear as our mock fp8 linear in forward +# --------------------------------------------------------------------------- + + +class MockFP8Linear(nn.Module): + """Wrapper that calls the mock fp8 linear custom op with real weight tensors.""" + + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.weight = nn.Parameter(torch.randn(out_features, in_features)) + self.weight_scale = 
nn.Parameter( + torch.ones(max(1, out_features // 128), max(1, in_features // 128)), + requires_grad=False, + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.ops.auto_deploy.mock_fp8_linear_gemm_test( + x, self.weight, None, self.weight_scale + ) + + +# --------------------------------------------------------------------------- +# Mock model architectures +# --------------------------------------------------------------------------- + + +class TwoBranchFork(nn.Module): + """Two fp8 linears sharing one input, different weight sizes. + + branch_a: [large_out, hidden] -- larger GEMM + branch_b: [small_out, hidden] -- smaller GEMM + Merge: concat -> project back + """ + + def __init__(self, hidden_dim: int, large_out: int, small_out: int): + super().__init__() + self.branch_a = MockFP8Linear(hidden_dim, large_out) # Larger + self.branch_b = MockFP8Linear(hidden_dim, small_out) # Smaller + self.proj = nn.Linear(large_out + small_out, hidden_dim, bias=False) + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = self.branch_a(x) + b = self.branch_b(x) + combined = torch.cat([a, b], dim=-1) + return self.norm(self.proj(combined)) + + +class ThreeBranchFork(nn.Module): + """Three fp8 linears sharing one input (MHA-like: q, k, v projections). 
+ + q_proj: [large_dim, hidden] -- largest + k_proj: [small_dim, hidden] -- smaller + v_proj: [small_dim, hidden] -- smaller + Merge: sum all outputs (simplified from real attention) + """ + + def __init__(self, hidden_dim: int, large_dim: int, small_dim: int): + super().__init__() + self.q_proj = MockFP8Linear(hidden_dim, large_dim) # Largest + self.k_proj = MockFP8Linear(hidden_dim, small_dim) + self.v_proj = MockFP8Linear(hidden_dim, small_dim) + self.o_proj = nn.Linear(large_dim, hidden_dim, bias=False) + self.k_down = nn.Linear(small_dim, large_dim, bias=False) + self.v_down = nn.Linear(small_dim, large_dim, bias=False) + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + q = self.q_proj(x) + k = self.k_down(self.k_proj(x)) + v = self.v_down(self.v_proj(x)) + combined = q + k + v + return self.norm(self.o_proj(combined)) + + +class FourBranchFork(nn.Module): + """Four fp8 linears sharing one input (linear attention-like). + + in_proj_qkv: [large_dim, hidden] -- largest + in_proj_z: [mid_dim, hidden] + in_proj_b: [small_dim, hidden] + in_proj_a: [small_dim, hidden] + Merge: sum all (simplified) + """ + + def __init__(self, hidden_dim: int, large_dim: int, mid_dim: int, small_dim: int): + super().__init__() + self.in_proj_qkv = MockFP8Linear(hidden_dim, large_dim) # Largest + self.in_proj_z = MockFP8Linear(hidden_dim, mid_dim) + self.in_proj_b = MockFP8Linear(hidden_dim, small_dim) + self.in_proj_a = MockFP8Linear(hidden_dim, small_dim) + # Project all down to hidden_dim for summation + self.down_qkv = nn.Linear(large_dim, hidden_dim, bias=False) + self.down_z = nn.Linear(mid_dim, hidden_dim, bias=False) + self.down_b = nn.Linear(small_dim, hidden_dim, bias=False) + self.down_a = nn.Linear(small_dim, hidden_dim, bias=False) + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + qkv = self.down_qkv(self.in_proj_qkv(x)) + z = self.down_z(self.in_proj_z(x)) + b = 
self.down_b(self.in_proj_b(x)) + a = self.down_a(self.in_proj_a(x)) + return self.norm(qkv + z + b + a) + + +class EqualWeightFork(nn.Module): + """Two fp8 linears with identical weight shapes sharing one input. + + Tests that the transform handles ties deterministically. + """ + + def __init__(self, hidden_dim: int, out_dim: int): + super().__init__() + self.branch_a = MockFP8Linear(hidden_dim, out_dim) + self.branch_b = MockFP8Linear(hidden_dim, out_dim) + self.norm = nn.LayerNorm(out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.norm(self.branch_a(x) + self.branch_b(x)) + + +class SingleLinearModel(nn.Module): + """Only one fp8 linear -- no fork point, no match expected.""" + + def __init__(self, hidden_dim: int, out_dim: int): + super().__init__() + self.fc = MockFP8Linear(hidden_dim, out_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.fc(x) + + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +def _build_gm(model, example_input): + """Export *model* to an FX ``GraphModule``.""" + return torch.export.export(model, (example_input,)).module() + + +def _get_targets(gm): + """Return the set of ``call_function`` targets in *gm*.""" + return {n.target for n in gm.graph.nodes if n.op == "call_function"} + + +def _count_aux_ops(gm): + """Count how many aux-stream variant ops are in the graph.""" + count = 0 + for n in gm.graph.nodes: + if n.op == "call_function" and hasattr(n.target, "__name__"): + if n.target.__name__.endswith("_aux"): + count += 1 + return count + + +def _assert_numerical_correctness(gm, model, test_x, *, atol=1e-4): + """Assert that *gm* and *model* produce the same output on *test_x*.""" + ref = model(test_x) + out = gm(test_x) + assert torch.allclose(out, ref, atol=atol), ( + f"Output mismatch: max diff = {(out - ref).abs().max().item()}" + ) + + +def 
_assert_cuda_graph_correctness(gm, model, test_x, *, atol=1e-4): + """Assert correctness under CUDA graph capture + replay.""" + ref = model(test_x) + static_x = torch.randn_like(test_x) + static_out = torch.empty_like(ref) + + # Warm up + for _ in range(3): + static_out.copy_(gm(static_x)) + + cuda_graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(cuda_graph): + static_out.copy_(gm(static_x)) + + static_x.copy_(test_x) + cuda_graph.replay() + + assert torch.allclose(static_out, ref, atol=atol), ( + f"CUDA graph output mismatch: max diff = {(static_out - ref).abs().max().item()}" + ) + + +# =================================================================== +# Tests -- Two-branch fork +# =================================================================== + + +def test_two_branch_pattern_matching(): + """A fork with two fp8 linears should produce exactly one fork point.""" + model = TwoBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + fork_points = _find_gemm_fork_points(gm, _MOCK_OPS) + assert len(fork_points) == 1, f"Expected 1 fork point, got {len(fork_points)}" + # Should have 2 linear users at the fork point + assert len(fork_points[0][1]) == 2 + + +def test_two_branch_largest_moved_to_aux(): + """The larger linear (256 out-features) should be the one on the aux stream.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = TwoBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + + assert num == 1, f"Expected 1 replacement, got {num}" + targets = _get_targets(gm) + assert record_event_passthrough in targets, "record_event_passthrough not in graph" + + +def test_two_branch_numerical_correctness(): + """Numerical output must match the original model after the transform.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + 
model = TwoBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_numerical_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +def test_two_branch_cuda_graph(): + """CUDA graph capture + replay for two-branch fork.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = TwoBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_cuda_graph_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +def test_two_branch_multi_layer(): + """Two stacked layers -- both fork points should be transformed.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = ( + nn.Sequential( + TwoBranchFork(128, 256, 64), + TwoBranchFork(128, 256, 64), + ) + .eval() + .to("cuda") + ) + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 2, f"Expected 2 replacements, got {num}" + + _assert_numerical_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +# =================================================================== +# Tests -- Three-branch fork (MHA-like) +# =================================================================== + + +def test_three_branch_pattern_matching(): + """A fork with three fp8 linears should produce exactly one fork point.""" + model = ThreeBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + fork_points = _find_gemm_fork_points(gm, _MOCK_OPS) + assert len(fork_points) == 1, f"Expected 1 fork point, got {len(fork_points)}" + assert len(fork_points[0][1]) == 3 + + +def test_three_branch_numerical_correctness(): + """Numerical 
correctness for three-branch fork.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = ThreeBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_numerical_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +def test_three_branch_cuda_graph(): + """CUDA graph capture + replay for three-branch fork.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = ThreeBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_cuda_graph_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +# =================================================================== +# Tests -- Four-branch fork (linear attention-like) +# =================================================================== + + +def test_four_branch_pattern_matching(): + """A fork with four fp8 linears should produce exactly one fork point.""" + model = FourBranchFork(128, 512, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + fork_points = _find_gemm_fork_points(gm, _MOCK_OPS) + assert len(fork_points) == 1, f"Expected 1 fork point, got {len(fork_points)}" + assert len(fork_points[0][1]) == 4 + + +def test_four_branch_numerical_correctness(): + """Numerical correctness for four-branch fork.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = FourBranchFork(128, 512, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_numerical_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +def test_four_branch_cuda_graph(): + """CUDA 
graph capture + replay for four-branch fork.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = FourBranchFork(128, 512, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_cuda_graph_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +def test_four_branch_multi_layer(): + """Two stacked four-branch layers -- both should be transformed.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = ( + nn.Sequential( + FourBranchFork(128, 512, 256, 64), + FourBranchFork(128, 512, 256, 64), + ) + .eval() + .to("cuda") + ) + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 2, f"Expected 2 replacements, got {num}" + + _assert_numerical_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +# =================================================================== +# Tests -- Edge cases +# =================================================================== + + +def test_single_linear_no_match(): + """A single fp8 linear (no fork) should produce zero matches.""" + model = SingleLinearModel(128, 256).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + fork_points = _find_gemm_fork_points(gm, _MOCK_OPS) + assert len(fork_points) == 0, f"Expected 0 fork points, got {len(fork_points)}" + + cuda_stream_manager.add_device(torch.cuda.current_device()) + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 0 + + # Graph should NOT contain any stream-management nodes. 
+ targets = _get_targets(gm) + assert record_event_passthrough not in targets + + +def test_equal_weight_deterministic(): + """Even with equal weights, the transform should succeed (one moved to aux).""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = EqualWeightFork(128, 128).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1, f"Expected 1 replacement with equal weights, got {num}" + + _assert_numerical_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +def test_weight_size_estimation(): + """_estimate_weight_size should return the product of weight dimensions.""" + model = TwoBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + fork_points = _find_gemm_fork_points(gm, _MOCK_OPS) + assert len(fork_points) == 1 + + _, linears = fork_points[0] + sizes = [_estimate_weight_size(gm, ln) for ln in linears] + + # One should be 256*128=32768, the other 64*128=8192 + assert max(sizes) > min(sizes), f"Expected different sizes but got {sizes}" + assert 256 * 128 in sizes, f"Expected 256*128=32768 in sizes, got {sizes}" + assert 64 * 128 in sizes, f"Expected 64*128=8192 in sizes, got {sizes}" + + +def test_skip_already_handled_fork_point(): + """Fork points with existing multi-stream ops should be skipped. + + We manually insert a ``record_event_passthrough`` into the graph before + running the transform and verify that the fork point is skipped. + """ + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = TwoBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + # Manually inject a record_event_passthrough at the fork point. 
+ fork_points = _find_gemm_fork_points(gm, _MOCK_OPS) + assert len(fork_points) == 1 + fork_node = fork_points[0][0] + + graph = gm.graph + # Insert a dummy record_event_passthrough as a user of the fork point. + with graph.inserting_after(fork_node): + graph.call_function(record_event_passthrough, args=(fork_node,)) + + # Now the transform should skip this fork point. + gm2, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 0, f"Expected 0 replacements (fork point already handled), got {num}" + + +def test_idempotent_double_application(): + """Applying the transform twice should be idempotent (second pass is no-op).""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = TwoBranchFork(128, 256, 64).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num1 = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num1 == 1 + + # Second application -- the fork point now has a record_event_passthrough + # user, so it should be skipped. + gm, num2 = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num2 == 0, f"Expected 0 on second pass, got {num2}" + + # Should still be numerically correct after double application. + _assert_numerical_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +# =================================================================== +# Tests -- Topological ordering (regression test for downstream users +# of the largest linear appearing before remaining linears) +# =================================================================== + + +class InterleavedDownstreamFork(nn.Module): + """Model where the largest linear's downstream ops appear BEFORE the remaining linears in graph order. + + This mirrors real MHA patterns where: + q_proj(x) -> reshape_q -> ... (largest, downstream appears early) + k_proj(x) -> reshape_k -> ... (remaining, appears later) + v_proj(x) -> reshape_v -> ... 
(remaining, appears later) + + Without proper topological fix, the aux node (replacing q_proj) would be + inserted after v_proj, but reshape_q would still reference it from its + original early position, causing a 'used before defined' error. + """ + + def __init__(self, hidden_dim: int, large_dim: int, small_dim: int): + super().__init__() + self.q_proj = MockFP8Linear(hidden_dim, large_dim) # Largest + self.k_proj = MockFP8Linear(hidden_dim, small_dim) + self.v_proj = MockFP8Linear(hidden_dim, small_dim) + self.norm = nn.LayerNorm(hidden_dim) + # Project q through extra ops to simulate reshape/view chain + self.q_down = nn.Linear(large_dim, hidden_dim, bias=False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # q_proj and its downstream appear first in graph order + q = self.q_down(torch.relu(self.q_proj(x))) + # k_proj, v_proj appear later + k = self.k_proj(x) + v = self.v_proj(x) + # Join point: all three branches merge + # Project k, v to hidden_dim for addition + return self.norm(q + k[..., : x.shape[-1]] + v[..., : x.shape[-1]]) + + +class DeepDownstreamFork(nn.Module): + """Model where the largest linear has a deep chain of downstream ops before the remaining linears. + + Tests transitive movement. + + largest -> relu -> linear -> relu -> linear -> ... 
+ remaining_1, remaining_2 appear after the chain + """ + + def __init__(self, hidden_dim: int, large_dim: int, small_dim: int): + super().__init__() + self.large_proj = MockFP8Linear(hidden_dim, large_dim) # Largest + self.small_proj_1 = MockFP8Linear(hidden_dim, small_dim) + self.small_proj_2 = MockFP8Linear(hidden_dim, small_dim) + # Deep chain on the largest branch + self.chain_1 = nn.Linear(large_dim, large_dim, bias=False) + self.chain_2 = nn.Linear(large_dim, hidden_dim, bias=False) + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Deep chain: large_proj -> relu -> chain_1 -> relu -> chain_2 + big = self.chain_2(torch.relu(self.chain_1(torch.relu(self.large_proj(x))))) + # Remaining linears appear after the chain in graph order + s1 = self.small_proj_1(x) + s2 = self.small_proj_2(x) + return self.norm(big + s1[..., : x.shape[-1]] + s2[..., : x.shape[-1]]) + + +def test_interleaved_downstream_topological_order(): + """Regression: aux node must not violate topological order. + + This checks when the largest linear's downstream ops precede the remaining linears. + """ + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = InterleavedDownstreamFork(128, 256, 128).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + # This should NOT raise 'used before defined' RuntimeError. + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1, f"Expected 1 replacement, got {num}" + + # Verify graph is topologically valid by running it. 
+ test_x = torch.randn(4, 128, device="cuda") + _assert_numerical_correctness(gm, model, test_x) + + +def test_interleaved_downstream_cuda_graph(): + """CUDA graph capture + replay for interleaved downstream pattern.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = InterleavedDownstreamFork(128, 256, 128).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_cuda_graph_correctness(gm, model, torch.randn(4, 128, device="cuda")) + + +def test_deep_downstream_topological_order(): + """Regression: deep chain of ops downstream of the largest linear must all be moved after the aux node.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = DeepDownstreamFork(128, 256, 128).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1, f"Expected 1 replacement, got {num}" + + test_x = torch.randn(4, 128, device="cuda") + _assert_numerical_correctness(gm, model, test_x) + + +def test_deep_downstream_cuda_graph(): + """CUDA graph for deep downstream chain pattern.""" + cuda_stream_manager.add_device(torch.cuda.current_device()) + model = DeepDownstreamFork(128, 256, 128).eval().to("cuda") + example = torch.randn(4, 128, device="cuda") + gm = _build_gm(model, example) + + gm, num = _parallelize_largest_gemm(gm, _MOCK_OPS) + assert num == 1 + + _assert_cuda_graph_correctness(gm, model, torch.randn(4, 128, device="cuda")) diff --git a/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py b/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py index 564d7967d36..8a956c533e0 100644 --- a/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py +++ b/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_5_moe.py @@ -291,7 +291,7 @@ def ref_gdn_forward(module, 
hidden_states): Uses module.conv1d directly (nn.Conv1d forward) instead of torch_causal_conv1d, and ref_chunk_gated_delta_rule (with internal l2norm) instead of - torch_l2norm + torch_gated_delta_rule. + torch_gated_delta_rule (which handles l2norm, GQA expand, and gating internally). """ batch_size, seq_len, _ = hidden_states.shape diff --git a/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_next_gdn_patches.py b/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_next_gdn_patches.py index b30044c7e14..d425e453cd9 100644 --- a/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_next_gdn_patches.py +++ b/tests/unittest/auto_deploy/singlegpu/models/test_qwen3_next_gdn_patches.py @@ -19,16 +19,19 @@ torch_chunk_gated_delta_rule as hf_torch_chunk_gated_delta_rule, ) -# Register all auto_deploy custom ops (torch_gated_delta_rule, torch_causal_conv1d, torch_l2norm) +# Register all auto_deploy custom ops (torch_gated_delta_rule, torch_causal_conv1d, etc.) import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.custom_ops.fla.fla_gated_delta import _l2norm from tensorrt_llm._torch.auto_deploy.models.patches.qwen3_next import _patched_gdn_forward def test_torch_gated_delta_rule_op(): """Verify `torch_gated_delta_rule` custom op matches HF `torch_chunk_gated_delta_rule`. - Both operate on pure-torch math (no FLA kernels). We compare with - `use_qk_l2norm_in_kernel=False` so L2 norm is excluded from both paths. + The custom op accepts raw (un-normalized, un-expanded) q/k and raw gating + projections (a, b, A_log, dt_bias), performing L2 norm, GQA expansion, and + gating computation internally. The reference HF function takes pre-processed + inputs, so we manually apply the same preprocessing before calling it. 
""" torch.manual_seed(42) @@ -38,26 +41,32 @@ def test_torch_gated_delta_rule_op(): k_head_dim = 16 v_head_dim = 16 - # Inputs in [B, S, H, D] layout (bsnd convention) - q = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.float32) - k = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.float32) + # Raw inputs in [B, S, H, D] layout (bsnd convention) + q_raw = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.float32) + k_raw = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.float32) v = torch.randn(batch_size, seq_len, num_heads, v_head_dim, dtype=torch.float32) - g = -torch.rand(batch_size, seq_len, num_heads, dtype=torch.float32) # negative (decay) - beta = torch.sigmoid(torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32)) - - # L2 normalize Q and K (as our patched forward does externally) - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) - - # Reference: HF torch implementation (no l2norm inside, since we did it externally) + a = torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32) + b = torch.randn(batch_size, seq_len, num_heads, dtype=torch.float32) + A_log = torch.randn(num_heads, dtype=torch.float32) + dt_bias = torch.randn(num_heads, dtype=torch.float32) + + # Preprocess for HF reference: l2 norm + gating (must match _l2norm convention) + q_norm = _l2norm(q_raw.float()) + k_norm = _l2norm(k_raw.float()) + g = -A_log.float().exp() * torch.nn.functional.softplus(a.float() + dt_bias) + beta = b.float().sigmoid() + + # Reference: HF torch implementation with pre-processed inputs with torch.no_grad(): ref_output, _ = hf_torch_chunk_gated_delta_rule( - q, k, v, g=g, beta=beta, use_qk_l2norm_in_kernel=False + q_norm, k_norm, v, g=g, beta=beta, use_qk_l2norm_in_kernel=False ) - # Test: our custom op + # Test: our custom op with raw inputs with torch.no_grad(): - test_output = torch.ops.auto_deploy.torch_gated_delta_rule(q, 
k, v, g, beta) + test_output = torch.ops.auto_deploy.torch_gated_delta_rule( + q_raw, k_raw, v, a, b, A_log, dt_bias + ) torch.testing.assert_close( ref_output, @@ -78,22 +87,29 @@ def test_torch_gated_delta_rule_op_bfloat16(): k_head_dim = 8 v_head_dim = 8 - q = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.bfloat16) - k = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.bfloat16) + q_raw = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.bfloat16) + k_raw = torch.randn(batch_size, seq_len, num_heads, k_head_dim, dtype=torch.bfloat16) v = torch.randn(batch_size, seq_len, num_heads, v_head_dim, dtype=torch.bfloat16) - g = -torch.rand(batch_size, seq_len, num_heads, dtype=torch.bfloat16) - beta = torch.sigmoid(torch.randn(batch_size, seq_len, num_heads, dtype=torch.bfloat16)) + a = torch.randn(batch_size, seq_len, num_heads, dtype=torch.bfloat16) + b = torch.randn(batch_size, seq_len, num_heads, dtype=torch.bfloat16) + A_log = torch.randn(num_heads, dtype=torch.bfloat16) + dt_bias = torch.randn(num_heads, dtype=torch.bfloat16) - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) + q_norm = _l2norm(q_raw.float()).to(torch.bfloat16) + k_norm = _l2norm(k_raw.float()).to(torch.bfloat16) + g = -A_log.float().exp() * torch.nn.functional.softplus(a.float() + dt_bias) + g = g.to(torch.bfloat16) + beta = b.float().sigmoid().to(torch.bfloat16) with torch.no_grad(): ref_output, _ = hf_torch_chunk_gated_delta_rule( - q, k, v, g=g, beta=beta, use_qk_l2norm_in_kernel=False + q_norm, k_norm, v, g=g, beta=beta, use_qk_l2norm_in_kernel=False ) with torch.no_grad(): - test_output = torch.ops.auto_deploy.torch_gated_delta_rule(q, k, v, g, beta) + test_output = torch.ops.auto_deploy.torch_gated_delta_rule( + q_raw, k_raw, v, a, b, A_log, dt_bias + ) torch.testing.assert_close( ref_output, @@ -172,8 +188,9 @@ def test_qwen3_next_gdn_patch(): """Verify patched 
Qwen3NextGatedDeltaNet.forward matches the original HF implementation. The patch replaces the forward with autodeploy custom ops - (torch_causal_conv1d, torch_l2norm, torch_gated_delta_rule) while - maintaining numerical equivalence. + (torch_causal_conv1d, torch_gated_delta_rule) while maintaining + numerical equivalence. L2 norm, GQA expansion, and gating are + handled inside torch_gated_delta_rule. """ torch.manual_seed(42) diff --git a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_build_small_single.py b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_build_small_single.py index dd56511748d..5f2ae4d0075 100644 --- a/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_build_small_single.py +++ b/tests/unittest/auto_deploy/singlegpu/smoke/test_ad_build_small_single.py @@ -7,6 +7,9 @@ from tensorrt_llm._torch.auto_deploy.llm_args import LlmArgs, _ParallelConfig from tensorrt_llm._torch.auto_deploy.shim.ad_executor import ADEngine +# When a run uses FP8 block scaling GEMM on a GPU that doesn't support it, skip only that run. 
+_FP8_BLOCK_SCALING_GEMM_ERR = "Unsupported SM version for FP8 block scaling GEMM" + def _check_ad_config(experiment_config: ExperimentConfig, llm_args: LlmArgs): # Verify that llm_args was captured @@ -219,7 +222,15 @@ def check_and_original_build(cls, ad_config): ADEngine.build_from_config = check_and_original_build try: - main(experiment_config) + try: + main(experiment_config) + except RuntimeError as e: + if _FP8_BLOCK_SCALING_GEMM_ERR in str(e): + pytest.skip( + "This run uses FP8 block scaling GEMM, which requires SM 89 (Ada), " + "90 (Hopper), 100/103 (Blackwell), or 120 (RTX 6000)" + ) + raise finally: # Restore original build_from_config ADEngine.build_from_config = original_build_from_config diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_finegrained_fp8_swiglu.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_finegrained_fp8_swiglu.py new file mode 100644 index 00000000000..65275e88b08 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_finegrained_fp8_swiglu.py @@ -0,0 +1,397 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for FineGrained FP8 quantized SwiGLU pattern matching and fusion transforms. + +Tests the FineGrained FP8 SwiGLU path: +1. 
match_finegrained_fp8_swiglu_pattern: Matches torch_fake_quant_finegrained_fp8_linear + SwiGLU -> torch_finegrained_fp8_swiglu_mlp +2. fuse_finegrained_fp8_swiglu: Fuses gate+up FP8 weights -> fused_finegrained_fp8_swiglu_mlp +""" + +import pytest +import torch +import torch.nn as nn +from _torch_test_utils import fp8_compatible, trtllm_ops_available +from torch.export import Dim + +import tensorrt_llm._torch.auto_deploy.custom_ops # noqa: F401 +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + +_skip_reason = "Requires FP8 (Hopper+) and TRT-LLM ops" +_skip_condition = not (fp8_compatible() and trtllm_ops_available()) + +# Block size for finegrained FP8 quantization (128x128) +_BLOCK_SIZE = 128 + + +def _quantize_fp8_block(weight: torch.Tensor) -> tuple: + """Quantize a weight tensor to FP8 with per-128x128-block scales. + + Args: + weight: Float weight tensor of shape [N, K] where N and K are multiples of 128. 
+ + Returns: + (weight_fp8, weight_scale_inv) where: + - weight_fp8: [N, K] float8_e4m3fn + - weight_scale_inv: [N/128, K/128] float32 per-block scale + """ + N, K = weight.shape + assert N % _BLOCK_SIZE == 0 and K % _BLOCK_SIZE == 0, ( + f"Dimensions must be multiples of {_BLOCK_SIZE}, got ({N}, {K})" + ) + FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + + # Reshape into blocks of [128, 128] + blocks = weight.reshape(N // _BLOCK_SIZE, _BLOCK_SIZE, K // _BLOCK_SIZE, _BLOCK_SIZE) + blocks = blocks.permute(0, 2, 1, 3) # [N/128, K/128, 128, 128] + + # Compute per-block scale: amax / FP8_MAX + block_amax = blocks.abs().amax(dim=(-2, -1)) # [N/128, K/128] + scale = (block_amax / FP8_MAX).clamp(min=1e-12).to(torch.float32) + + # Quantize: divide by scale, clamp, cast to fp8 + scale_expanded = scale[:, :, None, None] # [N/128, K/128, 1, 1] + blocks_scaled = (blocks.float() / scale_expanded).clamp(-FP8_MAX, FP8_MAX) + blocks_fp8 = blocks_scaled.to(torch.float8_e4m3fn) + + # Reshape back to [N, K] + weight_fp8 = blocks_fp8.permute(0, 2, 1, 3).reshape(N, K) + + return weight_fp8, scale + + +class FineGrainedFP8SwiGLUMLP(nn.Module): + """SwiGLU MLP using FineGrained FP8 quantized linear ops. + + Mimics the graph structure produced by quantize_finegrained_fp8_linear_from_config + applied to a standard SwiGLU MLP: silu(gate(x)) * up(x) -> down(hidden). 
+ """ + + def __init__(self, hidden_size: int = 256, intermediate_size: int = 256): + super().__init__() + device = torch.device("cuda") + + # Create random weights and quantize them to FP8 + gate_weight = ( + torch.randn(intermediate_size, hidden_size, dtype=torch.bfloat16, device=device) * 0.05 + ) + up_weight = ( + torch.randn(intermediate_size, hidden_size, dtype=torch.bfloat16, device=device) * 0.05 + ) + down_weight = ( + torch.randn(hidden_size, intermediate_size, dtype=torch.bfloat16, device=device) * 0.05 + ) + + # Quantize each projection + gate_fp8, gate_scale = _quantize_fp8_block(gate_weight) + up_fp8, up_scale = _quantize_fp8_block(up_weight) + down_fp8, down_scale = _quantize_fp8_block(down_weight) + + self.register_buffer("gate_weight", gate_fp8) + self.register_buffer("gate_weight_scale", gate_scale) + self.register_buffer("up_weight", up_fp8) + self.register_buffer("up_weight_scale", up_scale) + self.register_buffer("down_weight", down_fp8) + self.register_buffer("down_weight_scale", down_scale) + + def forward(self, x): + gate_out = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + x, + self.gate_weight, + None, + [], + [self.gate_weight_scale], + [], + [], + ) + up_out = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + x, + self.up_weight, + None, + [], + [self.up_weight_scale], + [], + [], + ) + hidden = torch.nn.functional.silu(gate_out) * up_out + return torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + hidden, + self.down_weight, + None, + [], + [self.down_weight_scale], + [], + [], + ) + + +class FineGrainedFP8SwiGLUTestModel(nn.Module): + """Test model wrapping FineGrained FP8 SwiGLU MLP between linear layers.""" + + def __init__(self, hidden_size: int = 256, intermediate_size: int = 256): + super().__init__() + device = torch.device("cuda") + self.linear_in = nn.Linear(hidden_size, hidden_size, device=device, dtype=torch.bfloat16) + self.mlp = FineGrainedFP8SwiGLUMLP(hidden_size, 
intermediate_size) + self.linear_out = nn.Linear(hidden_size, hidden_size, device=device, dtype=torch.bfloat16) + + def forward(self, x): + x = self.linear_in(x) + x = self.mlp(x) + x = self.linear_out(x) + return x + + +class FineGrainedFP8SwiGLUMultiLayerModel(nn.Module): + """Test model with multiple FineGrained FP8 SwiGLU MLP layers.""" + + def __init__( + self, + hidden_size: int = 256, + intermediate_size: int = 256, + num_layers: int = 2, + ): + super().__init__() + device = torch.device("cuda") + self.layers = nn.ModuleList() + for _ in range(num_layers): + self.layers.append( + nn.ModuleDict( + { + "linear": nn.Linear( + hidden_size, + hidden_size, + device=device, + dtype=torch.bfloat16, + ), + "mlp": FineGrainedFP8SwiGLUMLP(hidden_size, intermediate_size), + } + ) + ) + + def forward(self, x): + for layer in self.layers: + x = layer["linear"](x) + x = layer["mlp"](x) + return x + + +# -- Test helpers -------------------------------------------------------------- + + +def _count_ops(gm, op): + """Count how many nodes in the graph match the given op.""" + return sum(1 for n in gm.graph.nodes if is_op(n, op)) + + +def _has_no_fake_quant_finegrained_fp8(gm): + """Verify no torch_fake_quant_finegrained_fp8_linear ops remain.""" + return _count_ops(gm, torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear) == 0 + + +# -- Tests --------------------------------------------------------------------- + + +@pytest.mark.skipif(_skip_condition, reason=_skip_reason) +def test_finegrained_fp8_swiglu_pattern_match_only(): + """Test that match_finegrained_fp8_swiglu_pattern produces torch_finegrained_fp8_swiglu_mlp.""" + torch.manual_seed(0) + model = FineGrainedFP8SwiGLUMLP().to("cuda") + x = torch.randn(2, 256, device="cuda", dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Verify the graph has torch_fake_quant_finegrained_fp8_linear ops before transform + assert _count_ops(gm, 
torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear) == 3, ( + "Expected 3 torch_fake_quant_finegrained_fp8_linear ops (gate, up, down) before transform" + ) + + # Apply only pattern matching + gm_matched = InferenceOptimizer( + None, + { + "match_finegrained_fp8_swiglu_pattern": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) + + # Check the intermediate op is present + fp8_swiglu_count = _count_ops( + gm_matched, torch.ops.auto_deploy.torch_finegrained_fp8_swiglu_mlp.default + ) + assert fp8_swiglu_count == 1, ( + f"Expected 1 torch_finegrained_fp8_swiglu_mlp op, got {fp8_swiglu_count}" + ) + + # All 3 fake_quant_finegrained_fp8 ops should be consumed + assert _has_no_fake_quant_finegrained_fp8(gm_matched), ( + "torch_fake_quant_finegrained_fp8_linear ops should be consumed by pattern matcher" + ) + + # Verify numerical correctness + gm_matched = gm_matched.to("cuda") + y_matched = gm_matched(x) + y_model = model(x) + torch.testing.assert_close(y_matched, y_model, atol=1e-3, rtol=1e-3) + + +@pytest.mark.skipif(_skip_condition, reason=_skip_reason) +def test_finegrained_fp8_swiglu_full_fusion(): + """Test full pipeline: pattern match -> fuse -> fused_finegrained_fp8_swiglu_mlp.""" + torch.manual_seed(0) + model = FineGrainedFP8SwiGLUTestModel().to("cuda") + x = torch.randn(2, 256, device="cuda", dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True, dynamic_shapes=({0: Dim.DYNAMIC},)) + + # Apply pattern matching + fusion + gm_fused = InferenceOptimizer( + None, + { + "match_finegrained_fp8_swiglu_pattern": { + "stage": "pattern_matcher", + }, + "fuse_finegrained_fp8_swiglu": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + + gm_fused = gm_fused.to("cuda") + + # Check the fused op is present + fused_count = _count_ops( + gm_fused, torch.ops.auto_deploy.fused_finegrained_fp8_swiglu_mlp.default + ) + assert fused_count == 1, f"Expected 1 fused_finegrained_fp8_swiglu_mlp op, got {fused_count}" + + # No intermediate 
or unfused ops should remain + assert ( + _count_ops(gm_fused, torch.ops.auto_deploy.torch_finegrained_fp8_swiglu_mlp.default) == 0 + ), "Intermediate torch_finegrained_fp8_swiglu_mlp should be replaced by fused version" + assert _has_no_fake_quant_finegrained_fp8(gm_fused), ( + "No torch_fake_quant_finegrained_fp8_linear ops should remain after fusion" + ) + + # Verify numerical correctness (fused uses TRT-LLM kernel, allow wider tolerance) + y_fused = gm_fused(x) + y_model = model(x) + torch.testing.assert_close(y_fused, y_model, atol=0.15, rtol=0.05) + + # Test with a different batch size to verify dynamic shapes work + x2 = torch.randn(4, 256, device="cuda", dtype=torch.bfloat16) + y_fused_2 = gm_fused(x2) + y_model_2 = model(x2) + torch.testing.assert_close(y_fused_2, y_model_2, atol=0.15, rtol=0.05) + + +@pytest.mark.skipif(_skip_condition, reason=_skip_reason) +@pytest.mark.parametrize("num_layers", [2, 3]) +def test_finegrained_fp8_swiglu_fusion_multiple_layers(num_layers): + """Test that multiple FineGrained FP8 SwiGLU patterns are fused correctly.""" + torch.manual_seed(0) + model = FineGrainedFP8SwiGLUMultiLayerModel(num_layers=num_layers).to("cuda") + x = torch.randn(2, 256, device="cuda", dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + + # Apply pattern matching + fusion + gm_fused = InferenceOptimizer( + None, + { + "match_finegrained_fp8_swiglu_pattern": { + "stage": "pattern_matcher", + }, + "fuse_finegrained_fp8_swiglu": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + + gm_fused = gm_fused.to("cuda") + + # Check that all layers are fused + fused_count = _count_ops( + gm_fused, torch.ops.auto_deploy.fused_finegrained_fp8_swiglu_mlp.default + ) + assert fused_count == num_layers, ( + f"Expected {num_layers} fused_finegrained_fp8_swiglu_mlp ops, got {fused_count}" + ) + + # Verify numerical correctness + y_fused = gm_fused(x) + y_model = model(x) + torch.testing.assert_close(y_fused, y_model, atol=0.2, 
rtol=0.1) + + +@pytest.mark.skipif(_skip_condition, reason=_skip_reason) +def test_finegrained_fp8_swiglu_does_not_match_non_swiglu(): + """Test that the FP8 SwiGLU matcher does not match non-SwiGLU FP8 linears.""" + torch.manual_seed(0) + device = torch.device("cuda") + hidden_size = 256 + + # Model with two sequential FP8 linears + relu (NOT a SwiGLU pattern) + class NonSwiGLUModel(nn.Module): + def __init__(self): + super().__init__() + w1 = torch.randn(hidden_size, hidden_size, dtype=torch.bfloat16, device=device) * 0.05 + w2 = torch.randn(hidden_size, hidden_size, dtype=torch.bfloat16, device=device) * 0.05 + + w1_fp8, w1_scale = _quantize_fp8_block(w1) + w2_fp8, w2_scale = _quantize_fp8_block(w2) + + self.register_buffer("w1", w1_fp8) + self.register_buffer("w1_scale", w1_scale) + self.register_buffer("w2", w2_fp8) + self.register_buffer("w2_scale", w2_scale) + + def forward(self, x): + # Sequential linears without SwiGLU pattern + y = torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + x, self.w1, None, [], [self.w1_scale], [], [] + ) + y = torch.nn.functional.relu(y) + return torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear( + y, self.w2, None, [], [self.w2_scale], [], [] + ) + + model = NonSwiGLUModel().to("cuda") + x = torch.randn(2, hidden_size, device="cuda", dtype=torch.bfloat16) + + gm = torch_export_to_gm(model, args=(x,), clone=True) + + gm_result = InferenceOptimizer( + None, + { + "match_finegrained_fp8_swiglu_pattern": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) + + # No SwiGLU ops should be found + assert ( + _count_ops(gm_result, torch.ops.auto_deploy.torch_finegrained_fp8_swiglu_mlp.default) == 0 + ), "Non-SwiGLU FP8 pattern should not match" + + # Original FP8 linear ops should still be present + assert ( + _count_ops(gm_result, torch.ops.auto_deploy.torch_fake_quant_finegrained_fp8_linear) == 2 + ), "Original FP8 linear ops should be unchanged" diff --git 
a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_gdn_gating.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_gdn_gating.py new file mode 100644 index 00000000000..9c00ddfae12 --- /dev/null +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_gdn_gating.py @@ -0,0 +1,112 @@ +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for the FuseGdnGating graph transform.""" + +import torch +from _graph_test_helpers import run_test_transformed_gm +from torch.export import Dim + +from tensorrt_llm._torch.auto_deploy.custom_ops.fla import ( + gdn_gating as _gdn_gating_ops, # noqa: F401 +) +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op + + +class GdnGatingModel(torch.nn.Module): + """Minimal model that uses torch_fused_gdn_gating followed by a linear.""" + + def __init__(self, num_heads: int = 16, hidden: int = 64): + super().__init__() + self.proj = torch.nn.Linear(hidden, num_heads, device="cuda", dtype=torch.float16) + self.A_log = torch.nn.Parameter(torch.randn(num_heads, device="cuda", dtype=torch.float16)) + self.dt_bias = torch.nn.Parameter( + torch.randn(num_heads, device="cuda", dtype=torch.float16) + ) + self.out = torch.nn.Linear(num_heads, hidden, device="cuda", dtype=torch.float16) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + a = self.proj(x) # [B, S, H] + g = torch.ops.auto_deploy.torch_fused_gdn_gating(self.A_log, a, self.dt_bias) + return self.out(g.to(x.dtype)) + + +def test_fuse_gdn_gating(): + """Verify FuseGdnGating replaces torch source op with triton op.""" + model = GdnGatingModel() + + def checker(gm): + return any(is_op(n, torch.ops.auto_deploy.triton_fused_gdn_gating) for n in gm.graph.nodes) + + x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16) + dynamic_shapes = {0: Dim.DYNAMIC} + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) + gm_transformed = InferenceOptimizer( + None, + { + "fuse_gdn_gating": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + + run_test_transformed_gm( + model, + x, + gm_transformed, + checker, + lambda num_p_og: num_p_og, + dynamic_shapes=dynamic_shapes, + ) + + # Also verify with different 
batch size (dynamic shapes) + new_input = torch.randn(4, 8, 64, device="cuda", dtype=torch.float16) + y_transformed = gm_transformed(new_input) + y_model = model(new_input) + torch.testing.assert_close(y_transformed, y_model, atol=1e-3, rtol=1e-3) + + +def test_no_match_without_source_op(): + """FuseGdnGating should be a no-op when no torch_fused_gdn_gating ops exist.""" + + class PlainModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(64, 64, device="cuda", dtype=torch.float16) + + def forward(self, x): + return self.linear(x) + + model = PlainModel() + x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16) + dynamic_shapes = {0: Dim.DYNAMIC} + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) + + gm_transformed = InferenceOptimizer( + None, + { + "fuse_gdn_gating": { + "stage": "post_load_fusion", + }, + }, + )(None, gm) + + # No triton ops should appear + has_triton_op = any( + is_op(n, torch.ops.auto_deploy.triton_fused_gdn_gating) for n in gm_transformed.graph.nodes + ) + assert not has_triton_op, "triton_fused_gdn_gating should not appear when there is no source op" diff --git a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py index 5daf348f01d..57259bddf7e 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_gated_delta_rule_cache.py @@ -21,6 +21,7 @@ from typing import List, Optional +import pytest import torch import torch.nn as nn from _torch_test_utils import all_close @@ -72,54 +73,60 @@ def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: class GatedDeltaRuleModel(nn.Module): """Minimal model that projects embeddings through torch_gated_delta_rule. 
+ Supports GVA: q/k use num_k_heads, v/g/beta use num_v_heads. + L2 normalization and repeat_interleave are handled inside the op. + Architecture: input_ids -> embedding -> linear projections -> torch_gated_delta_rule -> output proj + + The op internally performs L2 normalization, GQA repeat-interleave, and gating + computation from raw a/b projections and per-head A_log/dt_bias parameters. """ def __init__( self, vocab_size: int, hidden_size: int, - num_heads: int, + num_k_heads: int, + num_v_heads: int, key_dim: int, value_dim: int, ): super().__init__() - self.num_heads = num_heads + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads self.key_dim = key_dim self.value_dim = value_dim self.embed_tokens = nn.Embedding(vocab_size, hidden_size) - self.q_proj = nn.Linear(hidden_size, num_heads * key_dim, bias=False) - self.k_proj = nn.Linear(hidden_size, num_heads * key_dim, bias=False) - self.v_proj = nn.Linear(hidden_size, num_heads * value_dim, bias=False) - self.g_proj = nn.Linear(hidden_size, num_heads, bias=False) - self.beta_proj = nn.Linear(hidden_size, num_heads, bias=False) - self.o_proj = nn.Linear(num_heads * value_dim, hidden_size, bias=False) + self.q_proj = nn.Linear(hidden_size, num_k_heads * key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, num_k_heads * key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, num_v_heads * value_dim, bias=False) + self.a_proj = nn.Linear(hidden_size, num_v_heads, bias=False) + self.b_proj = nn.Linear(hidden_size, num_v_heads, bias=False) + self.A_log = nn.Parameter(torch.zeros(num_v_heads)) + self.dt_bias = nn.Parameter(torch.zeros(num_v_heads)) + self.o_proj = nn.Linear(num_v_heads * value_dim, hidden_size, bias=False) @torch.no_grad() def forward( self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: x = self.embed_tokens(input_ids) # [B, S, hidden] - b, s, _ = x.shape - - q = self.q_proj(x).view(b, s, self.num_heads, self.key_dim) - k = 
self.k_proj(x).view(b, s, self.num_heads, self.key_dim) - v = self.v_proj(x).view(b, s, self.num_heads, self.value_dim) + bsz, s, _ = x.shape - # L2 normalize Q and K - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) + q = self.q_proj(x).view(bsz, s, self.num_k_heads, self.key_dim) + k = self.k_proj(x).view(bsz, s, self.num_k_heads, self.key_dim) + v = self.v_proj(x).view(bsz, s, self.num_v_heads, self.value_dim) + a = self.a_proj(x) # [B, S, HV] + b = self.b_proj(x) # [B, S, HV] - # g should be negative (decay), beta should be in (0, 1) - g = -torch.nn.functional.softplus(self.g_proj(x)) # [B, S, H] - beta = torch.sigmoid(self.beta_proj(x)) # [B, S, H] - - attn_out = torch.ops.auto_deploy.torch_gated_delta_rule(q, k, v, g, beta) - # attn_out: [B, S, H, V] + attn_out = torch.ops.auto_deploy.torch_gated_delta_rule( + q, k, v, a, b, self.A_log, self.dt_bias + ) + # attn_out: [B, S, HV, V] - attn_out = attn_out.reshape(b, s, -1) + attn_out = attn_out.reshape(bsz, s, -1) return self.o_proj(attn_out) @@ -128,8 +135,9 @@ def forward( # --------------------------------------------------------------------------- +@pytest.mark.parametrize("num_k_heads,num_v_heads", [(2, 2), (2, 4)]) @torch.inference_mode() -def test_gated_delta_rule_with_cache(): +def test_gated_delta_rule_with_cache(num_k_heads, num_v_heads): """Test the insert_cached_gated_delta_rule transform with fla_gated_delta backend.""" # Configuration dtype = torch.bfloat16 @@ -143,7 +151,6 @@ def test_gated_delta_rule_with_cache(): seq_len = 16 vocab_size = 100 hidden_size = 32 - num_heads = 2 key_dim = 8 value_dim = 8 max_position_embeddings = 64 @@ -165,7 +172,8 @@ def test_gated_delta_rule_with_cache(): model = GatedDeltaRuleModel( vocab_size=vocab_size, hidden_size=hidden_size, - num_heads=num_heads, + num_k_heads=num_k_heads, + num_v_heads=num_v_heads, key_dim=key_dim, value_dim=value_dim, ).to(dtype=dtype, device="cuda") diff --git 
a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py index 8b72833eb0f..9db23105488 100644 --- a/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py +++ b/tests/unittest/auto_deploy/singlegpu/transformations/library/test_torch_gated_delta_rule_cache.py @@ -13,6 +13,7 @@ from typing import List, Optional +import pytest import torch import torch.nn as nn from _torch_test_utils import all_close @@ -64,54 +65,60 @@ def get_export_infos(self, model: nn.Module) -> List[SubModuleExportInfo]: class GatedDeltaRuleModel(nn.Module): """Minimal model that projects embeddings through torch_gated_delta_rule. + Supports GVA: q/k use num_k_heads, v/g/beta use num_v_heads. + L2 normalization and repeat_interleave are handled inside the op. + Architecture: input_ids -> embedding -> linear projections -> torch_gated_delta_rule -> output proj + + The op internally performs L2 normalization, GQA repeat-interleave, and gating + computation from raw a/b projections and per-head A_log/dt_bias parameters. 
""" def __init__( self, vocab_size: int, hidden_size: int, - num_heads: int, + num_k_heads: int, + num_v_heads: int, key_dim: int, value_dim: int, ): super().__init__() - self.num_heads = num_heads + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads self.key_dim = key_dim self.value_dim = value_dim self.embed_tokens = nn.Embedding(vocab_size, hidden_size) - self.q_proj = nn.Linear(hidden_size, num_heads * key_dim, bias=False) - self.k_proj = nn.Linear(hidden_size, num_heads * key_dim, bias=False) - self.v_proj = nn.Linear(hidden_size, num_heads * value_dim, bias=False) - self.g_proj = nn.Linear(hidden_size, num_heads, bias=False) - self.beta_proj = nn.Linear(hidden_size, num_heads, bias=False) - self.o_proj = nn.Linear(num_heads * value_dim, hidden_size, bias=False) + self.q_proj = nn.Linear(hidden_size, num_k_heads * key_dim, bias=False) + self.k_proj = nn.Linear(hidden_size, num_k_heads * key_dim, bias=False) + self.v_proj = nn.Linear(hidden_size, num_v_heads * value_dim, bias=False) + self.a_proj = nn.Linear(hidden_size, num_v_heads, bias=False) + self.b_proj = nn.Linear(hidden_size, num_v_heads, bias=False) + self.A_log = nn.Parameter(torch.zeros(num_v_heads)) + self.dt_bias = nn.Parameter(torch.zeros(num_v_heads)) + self.o_proj = nn.Linear(num_v_heads * value_dim, hidden_size, bias=False) @torch.no_grad() def forward( self, input_ids: torch.Tensor, position_ids: Optional[torch.Tensor] = None ) -> torch.Tensor: x = self.embed_tokens(input_ids) # [B, S, hidden] - b, s, _ = x.shape - - q = self.q_proj(x).view(b, s, self.num_heads, self.key_dim) - k = self.k_proj(x).view(b, s, self.num_heads, self.key_dim) - v = self.v_proj(x).view(b, s, self.num_heads, self.value_dim) + bsz, s, _ = x.shape - # L2 normalize Q and K - q = torch.nn.functional.normalize(q, dim=-1) - k = torch.nn.functional.normalize(k, dim=-1) + q = self.q_proj(x).view(bsz, s, self.num_k_heads, self.key_dim) + k = self.k_proj(x).view(bsz, s, self.num_k_heads, self.key_dim) + v = 
self.v_proj(x).view(bsz, s, self.num_v_heads, self.value_dim) + a = self.a_proj(x) # [B, S, HV] + b = self.b_proj(x) # [B, S, HV] - # g should be negative (decay), beta should be in (0, 1) - g = -torch.nn.functional.softplus(self.g_proj(x)) # [B, S, H] - beta = torch.sigmoid(self.beta_proj(x)) # [B, S, H] - - attn_out = torch.ops.auto_deploy.torch_gated_delta_rule(q, k, v, g, beta) - # attn_out: [B, S, H, V] + attn_out = torch.ops.auto_deploy.torch_gated_delta_rule( + q, k, v, a, b, self.A_log, self.dt_bias + ) + # attn_out: [B, S, HV, V] - attn_out = attn_out.reshape(b, s, -1) + attn_out = attn_out.reshape(bsz, s, -1) return self.o_proj(attn_out) @@ -120,8 +127,9 @@ def forward( # --------------------------------------------------------------------------- +@pytest.mark.parametrize("num_k_heads,num_v_heads", [(2, 2), (2, 4)]) @torch.inference_mode() -def test_torch_gated_delta_rule_cache(): +def test_torch_gated_delta_rule_cache(num_k_heads, num_v_heads): """Test the insert_cached_gated_delta_rule transform with torch_gated_delta backend.""" # Configuration dtype = torch.float32 @@ -131,7 +139,6 @@ def test_torch_gated_delta_rule_cache(): seq_len = 16 vocab_size = 100 hidden_size = 32 - num_heads = 2 key_dim = 8 value_dim = 8 max_position_embeddings = 64 @@ -153,7 +160,8 @@ def test_torch_gated_delta_rule_cache(): model = GatedDeltaRuleModel( vocab_size=vocab_size, hidden_size=hidden_size, - num_heads=num_heads, + num_k_heads=num_k_heads, + num_v_heads=num_v_heads, key_dim=key_dim, value_dim=value_dim, ).to(dtype=dtype, device="cuda")