diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml index 5908c1271e4..af6f130cefb 100644 --- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml +++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml @@ -19,3 +19,15 @@ transforms: stage: post_export cleanup_input_constraints: stage: post_export + quantize: + stage: pattern_matcher + quantize_moe: + stage: pattern_matcher + match_repeat_kv: + stage: pattern_matcher + match_eager_attention: + stage: pattern_matcher + match_grouped_attention: + stage: pattern_matcher + match_attention_layout: + stage: pattern_matcher diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py index 68175233f91..89dc59f6354 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_attention.py @@ -7,7 +7,28 @@ import torch.nn as nn import torch.nn.functional as F -# TODO (nvchenghaoz): Remove related kernels once we have a backend-specific implementation for attention. + +def _apply_logit_softcapping(attn_scores: torch.Tensor, logit_cap: Optional[float]) -> torch.Tensor: + """Apply logit softcapping using the formula: logit_cap * tanh(logits / logit_cap)""" + if logit_cap is not None and logit_cap > 0.0: + return logit_cap * torch.tanh(attn_scores / logit_cap) + return attn_scores + + +def _convert_boolean_mask_to_float(attn_mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + """Convert boolean attention mask to floating point mask. + Args: + attn_mask: Boolean tensor where True allows attention, False blocks it + dtype: Target dtype for the output mask + Returns: + Floating point mask where True -> 1.0, False -> -inf + """ + if attn_mask.dtype == torch.bool: + float_mask = torch.zeros_like(attn_mask, dtype=dtype) + float_mask = float_mask.masked_fill(attn_mask, 1.0) # True -> 1.0 + float_mask = float_mask.masked_fill(~attn_mask, float("-inf")) # False -> -inf + return float_mask + return attn_mask @torch.library.custom_op("auto_deploy::torch_attention_repeat_kv", mutates_args=()) @@ -77,19 +98,96 @@ def grouped_sdpa( dropout_p: float = 0.0, is_causal: bool = False, scale: Optional[float] = None, + sinks: Optional[torch.Tensor] = None, + sliding_window: Optional[int] = None, + logit_cap: Optional[float] = None, ) -> torch.Tensor: - """SDPA attention that can handle GQA.""" + """SDPA attention that can handle GQA. 
Expects bnsd format inputs.""" + b, n_heads, s_q, head_dim = query.shape # bnsd format: [batch, num_heads, seq_len, head_dim] + _, n_kv_heads, s_k, _ = key.shape # bnsd format: [batch, num_kv_heads, seq_len, head_dim] + + # Inputs are already in bnsd format, no need to transpose + query_t = query # [b, n_heads, s_q, head_dim] + key_t = key # [b, n_kv_heads, s_k, head_dim] + value_t = value # [b, n_kv_heads, s_k, v_head_dim] + + # Handle GQA by repeating KV if needed + if n_heads != n_kv_heads: + n_rep = n_heads // n_kv_heads + key_t = repeat_kv(key_t, n_rep) + value_t = repeat_kv(value_t, n_rep) + + # Set scale + if scale is None: + scale = 1.0 / math.sqrt(head_dim) + + # Compute attention scores: Q @ K^T + attn_scores = torch.matmul(query_t, key_t.transpose(-2, -1)) * scale # [b, n_heads, s_q, s_k] + + # Apply attention mask if provided + if attn_mask is not None: + # Convert boolean mask to float if needed + attn_mask = _convert_boolean_mask_to_float(attn_mask, attn_scores.dtype) + attn_scores = attn_scores + attn_mask + + # Apply causal mask if specified and only during the context phase + if is_causal and s_q == s_k: # Only apply causal mask during context processing + causal_mask = torch.triu( + torch.ones(s_q, s_k, device=query.device, dtype=torch.bool), + diagonal=1, # Use diagonal=1 for standard causal masking + ) + attn_scores.masked_fill_(causal_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + # Apply sliding window mask if specified + if sliding_window is not None and sliding_window > 0: + # Handle position calculation for both context and generation phases + if s_q == s_k: + # Context phase: standard position calculation + query_positions = torch.arange(s_q, device=query.device) + key_positions = torch.arange(s_k, device=query.device) + else: + # Generation phase: query is at position s_k (after the cache) + query_positions = torch.arange(s_k, s_k + s_q, device=query.device) # [s_k] for s_q=1 + key_positions = torch.arange(s_k, device=query.device) # [0,1,2,...,s_k-1] + + # Create position difference matrix: query_pos - key_pos + pos_diff = query_positions.unsqueeze(1) - key_positions.unsqueeze(0) # [s_q, s_k] + + # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size + sliding_window_mask = (pos_diff < 0) | (pos_diff >= sliding_window) # [s_q, s_k] + attn_scores.masked_fill_(sliding_window_mask.unsqueeze(0).unsqueeze(0), float("-inf")) + + # Apply logit softcapping if enabled + attn_scores = _apply_logit_softcapping(attn_scores, logit_cap) + + # Apply sinks if provided + if sinks is not None: + # Concatenate sinks to attention scores following the reference implementation + # sinks should have n_heads elements, each head gets its own sink value + # Expand sinks to [b, n_heads, s_q, 1] - one sink column per head + sinks_expanded = sinks.reshape(1, -1, 1, 1).expand( + b, n_heads, s_q, 1 + ) # [b, n_heads, s_q, 1] + + # Concatenate along the key dimension (last dimension) + logits_max = torch.max(attn_scores, dim=-1, keepdim=True).values + sinks = torch.exp(sinks_expanded - logits_max) + unnormalized_scores = torch.exp(attn_scores - logits_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks + scores = unnormalized_scores / normalizer + # Use only the non-sink portion for computing output + # We added exactly 1 column, so remove exactly 1 column + attn_out = torch.matmul(scores, value_t) # [b, n_heads, s_q, v_head_dim] + else: + attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(query.dtype) + attn_out = 
torch.matmul(attn_weights, value_t) # [b, n_heads, s_q, v_head_dim] - return F.scaled_dot_product_attention( - query.contiguous(), - key.contiguous(), - value.contiguous(), - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=True, - ) + # Apply dropout if specified + if dropout_p > 0.0: + attn_out = F.dropout(attn_out, p=dropout_p, training=False) + + # Return in bnsd format (same as input format) + return attn_out @grouped_sdpa.register_fake @@ -101,6 +199,9 @@ def grouped_sdpa_fake( dropout_p=0.0, is_causal=False, scale=None, + sinks=None, + sliding_window=None, + logit_cap=None, ): """Fake implementation of grouped SDPA.""" return query.new_empty(*query.shape[:-1], value.shape[-1]).contiguous() @@ -108,9 +209,9 @@ def grouped_sdpa_fake( @torch.library.custom_op("auto_deploy::torch_attention_bsnd_grouped_sdpa", mutates_args=()) def bsnd_grouped_sdpa( - query: torch.Tensor, # layout: [b, n, s_q, d] - key: torch.Tensor, # layout: [b, n, s_k, d] - value: torch.Tensor, # layout: [b, n, s_k, d] + query: torch.Tensor, # layout: [b, s_q, n, d] + key: torch.Tensor, # layout: [b, s_k, n, d] + value: torch.Tensor, # layout: [b, s_k, n, d] attn_mask: Optional[torch.Tensor] = None, # layout: [b, n, s_q, s_k] dropout_p: float = 0.0, is_causal: bool = False, @@ -124,14 +225,16 @@ def bsnd_grouped_sdpa( Note that attn_mask layout is still assumed to be [b, n, s_q, s_k] and is consistent with the original sdpa op! """ - # let's transpose to bnsd so we can use the grouped sdpa - query = query.transpose(1, 2).contiguous() - key = key.transpose(1, 2).contiguous() - value = value.transpose(1, 2).contiguous() - - out = grouped_sdpa(query, key, value, attn_mask, dropout_p, is_causal, scale) - - # let's transpose back to bnsd + # Transpose inputs to bnsd format for grouped_sdpa + query = query.transpose(1, 2).contiguous() # [b, s_q, n, d] -> [b, n, s_q, d] + key = key.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d] + value = value.transpose(1, 2).contiguous() # [b, s_k, n, d] -> [b, n, s_k, d] + + # Call grouped_sdpa with bnsd inputs + out = grouped_sdpa( + query, key, value, attn_mask, dropout_p, is_causal, scale, sinks, sliding_window, logit_cap + ) + # Transpose back to bsnd format return out.transpose(1, 2).contiguous() diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py index 9eccd0c83a9..f4f60bc31af 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_backend_attention.py @@ -103,7 +103,7 @@ def _torch_generate_mha( # Apply sinks if provided (following the model file pattern) if sinks is not None: # Concatenate sinks to attention scores - sinks = sinks.reshape(-1, 1, 1).expand(-1, attn_scores.shape[-2], -1) + sinks = sinks.reshape(-1, 1, 1) attn_weights = torch.cat([attn_scores, sinks], dim=-1) attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) # Use only the non-sink portion for computing output (ignore sinks) @@ -202,9 +202,7 @@ def _torch_context_mha( ) # [seq_len_i, kv_seq_len] # Sliding window mask: allow attention only if 0 <= pos_diff < sliding_window_size - sliding_window_mask = (pos_diff < 0) | ( - pos_diff >= sliding_window_size - ) # [seq_len_i, kv_seq_len] + sliding_window_mask = pos_diff >= sliding_window_size # Combine causal and sliding window masks combined_mask = causal_mask | sliding_window_mask @@ -219,14 
+217,14 @@ def _torch_context_mha( # Apply sinks if provided (following the model file pattern) if sinks is not None: # Concatenate sinks to attention scores - sinks = sinks.reshape(1, -1, 1, 1).expand( - attn_scores.shape[0], -1, attn_scores.shape[-2], -1 + new_sinks = sinks.reshape(1, -1, 1, 1).expand( + attn_scores.shape[0], -1, attn_scores.shape[2], 1 ) - attn_weights = torch.cat([attn_scores, sinks], dim=-1) + attn_weights = torch.cat([attn_scores, new_sinks], dim=-1) attn_weights = torch.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q.dtype) # Use only the non-sink portion for computing output (ignore sinks) attn_out = torch.matmul( - attn_weights[..., : -sinks.size(-1)], v_seq_t + attn_weights[..., : -new_sinks.size(-1)], v_seq_t ) # [1, n_heads, seq_len_i, v_head_dim] else: attn_weights = torch.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype) diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py index e42da002f6d..dba782bb4ac 100644 --- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py +++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py @@ -17,7 +17,8 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None): rank, world_size = get_rank_world_size() assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op." p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank) - torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO) + # Use Strategy.NCCL until https://nvbugspro.nvidia.com/bug/5331013 is fixed, then change to Strategy.AUTO + torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.NCCL) return torch_op(tensor, all_reduce_params=all_reduce_params) @torch.library.custom_op( diff --git a/tensorrt_llm/_torch/auto_deploy/models/hf.py b/tensorrt_llm/_torch/auto_deploy/models/hf.py index f407a042538..fc37c1e557a 100644 --- a/tensorrt_llm/_torch/auto_deploy/models/hf.py +++ b/tensorrt_llm/_torch/auto_deploy/models/hf.py @@ -76,6 +76,12 @@ class AutoModelForCausalLMFactory(ModelFactory): "max_position_embeddings": 1024, } + def _get_max_position_embeddings_config(self) -> Dict[str, Any]: + """Get the max position embeddings config for the model.""" + return { + "max_position_embeddings": self.max_seq_len, + } + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -83,7 +89,11 @@ def __init__(self, *args, **kwargs): # Ingest defaults for tokenizer and model kwargs self.tokenizer_kwargs = deep_merge_dicts(self._tokenizer_defaults, self.tokenizer_kwargs) - self.model_kwargs = deep_merge_dicts(self._model_defaults, self.model_kwargs) + self.model_kwargs = deep_merge_dicts( + self._model_defaults, + self.model_kwargs, + self._get_max_position_embeddings_config(), + ) # special handling for torch_dtype in model_kwargs since HF does not correctly update # torch_dtype string to an actual torch.dtype object (only with default) @@ -295,7 +305,7 @@ def _prefetch_checkpoint(self, model_name_or_path: str, skip_prefetch_weights: b # at this point it should be a directory (either the original one or the download dir) assert os.path.isdir(fetched_dir), f"Checkpoint path {fetched_dir} is not a directory." - self._load_quantization_config() + self._load_quantization_config(fetched_dir) return fetched_dir @@ -313,13 +323,13 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType): # model-transformed weights,leading to unexpected key mismatches or format issues. 
load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False) - def _load_quantization_config(self): + def _load_quantization_config(self, fetched_dir: str): """Load the quantization config from the model directory if not done already.""" if self._quant_config is not None: return assert self.model - hf_quant_config_file = os.path.join(self.model, "hf_quant_config.json") + hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json") if os.path.exists(hf_quant_config_file): with open(hf_quant_config_file, "r") as file: quantization_config = json.load(file) @@ -344,6 +354,15 @@ class AutoModelForImageTextToTextFactory(AutoModelForCausalLMFactory): }, } + def _get_max_position_embeddings_config(self) -> Dict[str, Any]: + """Get the max position embeddings config for the model.""" + return { + "max_position_embeddings": self.max_seq_len, + "text_config": { + "max_position_embeddings": self.max_seq_len, + }, + } + @property def automodel_from_config(self): return AutoModelForImageTextToText.from_config diff --git a/tensorrt_llm/_torch/auto_deploy/transform/interface.py b/tensorrt_llm/_torch/auto_deploy/transform/interface.py index 294bd0c178d..dd5bc421bb8 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/interface.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/interface.py @@ -227,18 +227,26 @@ def __call__( # run or skip the transform if self.config.enabled: # run graph pre-cleanup - self._run_pre_cleanup(gm, info_last) - - # run the transform in a error-handling wrapper - try: - gm, info = self._apply(gm, cm, factory) - except Exception as e: - error_msg = f"Transform {t_name} failed" - if self.config.skip_on_error: + is_clean_pre, has_valid_shapes_pre = self._run_pre_cleanup(gm, info_last) + + # run the transform in a error-handling wrapper if desired + if self.config.skip_on_error: + try: + gm, info = self._apply(gm, cm, factory) + except Exception as e: + error_msg = f"Transform {t_name} failed" ad_logger.warning(f"{error_msg}: {e}") info = TransformInfo(skipped=True, num_matches=0) - else: - raise TransformError(error_msg) from e + else: + # handle this here normally to improve debugging and error message + gm, info = self._apply(gm, cm, factory) + + # we cannot say it's clean if the previous wasn't clean even if this one is + # create new info object with updated cleanup status + info_dict = info.model_dump() + info_dict["is_clean"] &= is_clean_pre + info_dict["has_valid_shapes"] &= has_valid_shapes_pre + info = TransformInfo(**info_dict) # run graph post-cleanup info = self._run_post_cleanup(gm, info) @@ -279,20 +287,36 @@ def _set_autodeploy_meta(self, gm: GraphModule, autodeploy_meta: AutodeployMeta) gm.meta[self._autodeploy_meta_key] = autodeploy_meta @final - def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> None: + def _run_pre_cleanup(self, gm: GraphModule, info: TransformInfo) -> Tuple[bool, bool]: """Run graph cleanup before the transform. + Args: + gm: The graph module to run cleanup on. + info: The last transform info. + + Returns: + A tuple of (is_clean, has_valid_shapes) indicating the cleanup status after the + pre-cleanup. + This is used to ensure the transform is applied to a clean graph as needed by the transform. 
""" if not self.config.requires_clean_graph: - return + return info.is_clean, info.has_valid_shapes + + is_clean = info.is_clean + has_valid_shapes = is_clean and info.has_valid_shapes # check if run cleanup depending on the config and info - if self.config.requires_shape_prop and not (info.is_clean and info.has_valid_shapes): + if self.config.requires_shape_prop and not has_valid_shapes: with lift_to_meta(gm): canonicalize_graph(gm, shape_prop=True) - elif self.config.requires_clean_graph and not info.is_clean: + is_clean = True + has_valid_shapes = True + elif self.config.requires_clean_graph and not is_clean: canonicalize_graph(gm) + is_clean = True + + return is_clean, has_valid_shapes @final def _run_post_cleanup(self, gm: GraphModule, info: TransformInfo) -> TransformInfo: diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py new file mode 100644 index 00000000000..94da4dd514b --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/attention.py @@ -0,0 +1,562 @@ +"""Pattern matching for detecting repeat_kv, eager, grouped attention patterns from Huggingface models.""" + +from typing import Any, Callable, Dict, List, Tuple, Type + +import torch +import torch.nn.functional as F +from pydantic import Field +from torch.fx import GraphModule + +from ...custom_ops.attention_interface import AttentionDescriptor +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.logger import ad_logger +from ...utils.node_utils import is_op +from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern +from ..interface import BaseTransform, TransformConfig, TransformInfo, TransformRegistry + + +def _apply_pattern( + gm: GraphModule, + pattern_name: str, + register_fn: Callable[[ADPatternMatcherPass], None], +) -> int: + """Utility to register and apply a pattern.""" + patterns = ADPatternMatcherPass() + register_fn(patterns) + num_matches = patterns.apply(gm.graph) + return num_matches + + +def _repeat_kv_pattern(hidden_states, n_rep) -> torch.Tensor: + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = torch.unsqueeze(hidden_states, 2) + hidden_states = hidden_states.expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +def _repeat_kv_repl(hidden_states, n_rep) -> torch.Tensor: + return torch.ops.auto_deploy.torch_attention_repeat_kv(hidden_states, n_rep) + + +# with causal_mask, no division +def _sfdp_pattern_1(query, key, value, attention_mask, scaling, dropout): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value) + return attn_output + + +def _sfdp_replacement_1(query, key, value, attention_mask, scaling, dropout): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=True, + scale=scaling, + ) + + +# no causal_mask, no division +def _sfdp_pattern_2(query, key, value, scaling, dropout): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + attn_weights = F.softmax(attn_weights, dim=-1, 
dtype=torch.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value) + return attn_output + + +def _sfdp_replacement_2(query, key, value, scaling, dropout): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=False, + scale=scaling, + ) + + +# with causal_mask, with division +def _sfdp_pattern_3(query, key, value, attention_mask, scaling, dropout): + attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value) + return attn_output + + +def _sfdp_replacement_3(query, key, value, attention_mask, scaling, dropout): + scaling = 1.0 / scaling + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=True, + scale=scaling, + ) + + +# no causal_mask, with division +def _sfdp_pattern_4(query, key, value, scaling, dropout): + attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value) + return attn_output + + +def _sfdp_replacement_4(query, key, value, scaling, dropout): + scaling = 1.0 / scaling + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=False, + scale=scaling, + ) + + +# no causal_mask, with division, explicit casting model +def _sfdp_pattern_5(query, key, value, scaling, dropout): + attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling + attn_weights = attn_weights.to(torch.float32) + attn_weights = F.softmax(attn_weights, dim=-1).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value) + return attn_output + + +def _sfdp_replacement_5(query, key, value, scaling, dropout): + scaling = 1.0 / scaling + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=False, + scale=scaling, + ) + + +# with causal_mask, with division, explicit casting model +def _sfdp_pattern_6(query, key, value, attention_mask, scaling, dropout): + attn_weights = torch.matmul(query, key.transpose(2, 3)) / scaling + attn_weights = attn_weights + attention_mask + attn_weights = attn_weights.to(torch.float32) + attn_weights = F.softmax(attn_weights, dim=-1).to(query.dtype) + attn_weights = F.dropout(attn_weights, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value) + return attn_output + + +def _sfdp_replacement_6(query, key, value, attention_mask, scaling, dropout): + scaling = 1.0 / scaling + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=True, + scale=scaling, + ) + + +# Only pass in causal attention mask in downstream standardized pipeline +def _sfdp_pattern_7(query, key, value, attention_mask, scaling, dropout): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=dropout, + 
is_causal=False, + scale=scaling, + ) + + +def _sfdp_replacement_7(query, key, value, attention_mask, scaling, dropout): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=True if attention_mask is not None else False, + scale=scaling, + ) + + +# with causal_mask, no division, does not cast to fp32 for softmax +def _sfdp_pattern_8(query, key, value, attention_mask, scaling, dropout): + attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling + attn_weights = attn_weights + attention_mask + attn_weights = F.softmax(attn_weights, dim=-1) + attn_weights = F.dropout(attn_weights, p=dropout, training=False) + attn_output = torch.matmul(attn_weights, value) + return attn_output + + +def _sfdp_replacement_8(query, key, value, attention_mask, scaling, dropout): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + query, + key, + value, + attn_mask=None, + dropout_p=dropout, + is_causal=True, + scale=scaling, + ) + + +def _get_sfdp_patterns() -> List[Dict[str, Any]]: + bs, seq_len, n_heads, hidden_size = 8, 16, 8, 512 + head_dim = hidden_size // n_heads + + def common_tensor(): + return torch.randn(bs, n_heads, seq_len, head_dim, device="cuda", dtype=torch.bfloat16) + + def causal_mask(): + return torch.randn(bs, 1, 1, seq_len, device="cuda", dtype=torch.bfloat16) + + configs = [ + (_sfdp_pattern_1, _sfdp_replacement_1, True, 0.1234743, 0.85849734), + (_sfdp_pattern_2, _sfdp_replacement_2, False, 0.234743, 0.5849734), + (_sfdp_pattern_3, _sfdp_replacement_3, True, 0.34743, 0.849734), + (_sfdp_pattern_4, _sfdp_replacement_4, False, 0.74321, 0.9734), + (_sfdp_pattern_5, _sfdp_replacement_5, False, 0.874321, 0.89734), + (_sfdp_pattern_6, _sfdp_replacement_6, True, 0.634743, 0.6849734), + (_sfdp_pattern_7, _sfdp_replacement_7, True, 0.34743, 0.849734), + (_sfdp_pattern_8, _sfdp_replacement_8, True, 0.2234743, 0.95849734), + ] + + patterns = [] + for search_fn, replace_fn, has_mask, scale, dropout in configs: + dummy_args = [common_tensor(), common_tensor(), common_tensor()] + if has_mask: + dummy_args.append(causal_mask()) + dummy_args.extend([scale, dropout]) + + patterns.append( + { + "search_fn": search_fn, + "replace_fn": replace_fn, + "dummy_args": dummy_args, + "scalar_workaround": {"scaling": scale, "dropout": dropout}, + "op_ignore_types": {torch.ops.aten.to.dtype: (torch.dtype,)}, + } + ) + + return patterns + + +def _grouped_attn_pattern_1(q, k, v, n_rep, attn_mask, dropout_p, scale): + k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) + v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale + ) + + +def _grouped_attn_replacement_1(q, k, v, n_rep, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale + ) + + +# Only expose torch_attention_grouped_sdpa after the transformation +def _grouped_attn_pattern_2(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, scale=scale + ) + + +def _grouped_attn_replacement_2(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=False, 
scale=scale + ) + + +def _grouped_attn_pattern_3(q, k, v, n_rep, attn_mask, dropout_p, scale): + k = torch.ops.auto_deploy.torch_attention_repeat_kv(k, n_rep) + v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, n_rep) + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + +def _grouped_attn_replacement_3(q, k, v, n_rep, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + +# Only expose torch_attention_grouped_sdpa after the transformation +def _grouped_attn_pattern_4(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + +def _grouped_attn_replacement_4(q, k, v, attn_mask, dropout_p, scale): + return torch.ops.auto_deploy.torch_attention_grouped_sdpa.default( + q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=True, scale=scale + ) + + +@TransformRegistry.register("match_repeat_kv") +class MatchRepeatKV(BaseTransform): + """ + Match and replace the repeat_kv pattern with torch.ops.auto_deploy.torch_attention_repeat_kv. + """ + + def _apply( + self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + ) -> Tuple[GraphModule, TransformInfo]: + def register_repeat_kv(patterns: ADPatternMatcherPass): + dummy_args = [ + torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16), + 7, + ] + register_ad_pattern( + search_fn=_repeat_kv_pattern, + replace_fn=_repeat_kv_repl, + patterns=patterns, + dummy_args=dummy_args, + op_ignore_types={ + torch.ops.aten.reshape.default: (int,), + torch.ops.aten.expand.default: (int,), + }, + scalar_workaround={"n_rep": dummy_args[1]}, + ) + + num_kv_patterns = _apply_pattern(gm, "Repeat KV", register_repeat_kv) + + if num_kv_patterns > 0: + self.config.run_shape_prop = True + + info = TransformInfo( + skipped=False, + num_matches=num_kv_patterns, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info + + +@TransformRegistry.register("match_eager_attention") +class MatchEagerAttention(BaseTransform): + """ + Match and replace the eager attention pattern with torch.ops.auto_deploy.torch_attention_sdpa. + """ + + def _apply( + self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + ) -> Tuple[GraphModule, TransformInfo]: + def register_eager_attention(patterns: ADPatternMatcherPass): + for pattern_config in _get_sfdp_patterns(): + register_ad_pattern(**pattern_config, patterns=patterns) + + num_eager_patterns = _apply_pattern(gm, "Eager Attention", register_eager_attention) + + info = TransformInfo( + skipped=False, + num_matches=num_eager_patterns, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info + + +@TransformRegistry.register("match_grouped_attention") +class MatchGroupedAttention(BaseTransform): + """ + Match and replace the grouped attention pattern with + torch.ops.auto_deploy.torch_attention_grouped_sdpa. 
+ """ + + def _apply( + self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + ) -> Tuple[GraphModule, TransformInfo]: + def register_grouped_attention(patterns: ADPatternMatcherPass): + q = torch.randn(8, 8, 16, 64, device="cuda", dtype=torch.float16) + k1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) + v1 = torch.randn(8, 1, 16, 64, device="cuda", dtype=torch.float16) + attn_mask = torch.randn(8, 1, 1, 16, device="cuda", dtype=torch.float16) + dropout = 0.12345 + scale = 0.56789 + n_rep = 7 + + dummy_args_1 = [q, k1, v1, n_rep, attn_mask, dropout, scale] + dummy_args_2 = [q, k1, v1, attn_mask, dropout, scale] + + register_ad_pattern( + search_fn=_grouped_attn_pattern_1, + replace_fn=_grouped_attn_replacement_1, + patterns=patterns, + dummy_args=dummy_args_1, + scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, + ) + register_ad_pattern( + search_fn=_grouped_attn_pattern_2, + replace_fn=_grouped_attn_replacement_2, + patterns=patterns, + dummy_args=dummy_args_2, + scalar_workaround={ + "scale": scale, + "dropout_p": dropout, + }, + ) + register_ad_pattern( + search_fn=_grouped_attn_pattern_3, + replace_fn=_grouped_attn_replacement_3, + patterns=patterns, + dummy_args=dummy_args_1, + scalar_workaround={"scale": scale, "dropout_p": dropout, "n_rep": n_rep}, + ) + register_ad_pattern( + search_fn=_grouped_attn_pattern_4, + replace_fn=_grouped_attn_replacement_4, + patterns=patterns, + dummy_args=dummy_args_2, + scalar_workaround={ + "scale": scale, + "dropout_p": dropout, + }, + ) + + num_grouped_patterns = _apply_pattern(gm, "Grouped Attention", register_grouped_attention) + + info = TransformInfo( + skipped=False, + num_matches=num_grouped_patterns, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info + + +class MatchAttentionLayoutConfig(TransformConfig): + """Configuration for the insert cached attention transform.""" + + attention_op: Type[AttentionDescriptor] = Field(description="The attention descriptor to use.") + + +@TransformRegistry.register("match_attention_layout") +class MatchAttentionLayout(BaseTransform): + """ + Match and transform attention operations to match the layout expected by the attention backend. + + If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which + is the default for SDPA operations, we don't need to transform anything. + + If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert + appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa. 
+ """ + + config: MatchAttentionLayoutConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return MatchAttentionLayoutConfig + + def _apply( + self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + ) -> Tuple[GraphModule, TransformInfo]: + # Get attention layout from attention_op + attention_layout = self.config.attention_op.get_attention_layout() + + # List of SDPA operations to look for + sdpa_ops = { + torch.ops.auto_deploy.torch_attention_sdpa, + torch.ops.auto_deploy.torch_attention_grouped_sdpa, + } + + graph = gm.graph + num_bsnd_patterns = 0 + + # Look for SDPA operations + for sdpa_node in list(graph.nodes): + if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops): + continue + + ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}") + + # Extract q, k, v inputs + q, k, v = sdpa_node.args[:3] + + # Check if we need to transpose the inputs + if attention_layout == "bsnd": + # Add transposes before the node (from bnsd to bsnd) + with graph.inserting_before(sdpa_node): + q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2)) + k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2)) + v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2)) + + # Preserve fake tensor in meta["val"] for the transposed inputs + q_updated.meta["val"] = q.meta["val"].transpose(1, 2) + k_updated.meta["val"] = k.meta["val"].transpose(1, 2) + v_updated.meta["val"] = v.meta["val"].transpose(1, 2) + elif attention_layout == "bnsd": + # we don't need to do anything... + q_updated = q + k_updated = k + v_updated = v + else: + raise ValueError(f"Unsupported attention layout: {attention_layout}") + + # Create bsnd_grouped_sdpa node with the same args as the original node + # but using the transposed inputs + with graph.inserting_before(sdpa_node): + source_sdpa_node = graph.call_function( + self.config.attention_op.get_source_attention_op(), + args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:], + kwargs=sdpa_node.kwargs, + ) + + # Check if need to update the output node to match the layout + if attention_layout == "bsnd": + # Add transpose for the output (from bsnd back to bnsd) + with graph.inserting_after(source_sdpa_node): + output_updated = graph.call_function( + torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2) + ) + + # Preserve fake tensor in meta["val"] for the transposed inputs + source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous() + output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2) + elif attention_layout == "bnsd": + output_updated = source_sdpa_node + else: + raise ValueError(f"Unsupported attention layout: {attention_layout}") + + # Replace the old node with the transposed output + sdpa_node.replace_all_uses_with(output_updated) + + num_bsnd_patterns += 1 + + info = TransformInfo( + skipped=False, + num_matches=num_bsnd_patterns, + is_clean=False, + has_valid_shapes=False, + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py similarity index 68% rename from tensorrt_llm/_torch/auto_deploy/transformations/library/quantization.py rename to tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 0414ed2fe25..8cf3630b828 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/quantization.py +++ 
b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -1,11 +1,12 @@ from collections import defaultdict from functools import partial -from typing import Any, Dict +from typing import Dict, Tuple import torch.nn as nn from torch.fx import GraphModule, Node -from ...utils.logger import ad_logger +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import ( extract_param_names_from_lin_node, get_quantization_params_from_linear_node, @@ -20,7 +21,7 @@ remove_output_quantizers, should_skip_quantization, ) -from .._graph import canonicalize_graph +from ..interface import BaseTransform, TransformInfo, TransformRegistry def _insert_quantized_linear( @@ -138,12 +139,8 @@ def get_scale_name(scale_name): scale_target_module = gm # Register in root module scale_name_prefix = "" - ad_logger.info(f"Quantized BMM with dynamic weight tensor for node {node}") else: # If we can't determine the shape, skip quantization - ad_logger.warning( - f"BMM weight is dynamic tensor without shape metadata, skipping quantization for node {node}" - ) return # Common logic for both parameter and dynamic tensor cases @@ -169,53 +166,70 @@ def get_scale_name(scale_name): node.args = (*node.args, *scale_values) -def quantize(gm: GraphModule, quant_config: Dict[str, Any]) -> None: - """Quantize the GraphModule and replace linear with quantized linear.""" - # extract info from quant_config - is_quant_graph = is_quantized_graph(gm) - quant_algo = quant_config.get("quant_algo") - excluded_patterns = quant_config.get("exclude_modules", []) - - # no quantization to do - if not (is_quant_graph or quant_config): - ad_logger.info("No quantization to do.") - return +@TransformRegistry.register("quantize") +class Quantization(BaseTransform): + """Quantize the GraphModule and replace linear/BMM with quantized linear/BMM.""" - # tracking quantized operations in the graph - quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) - for n in gm.graph.nodes: - if should_skip_quantization(n, excluded_patterns): - continue - - # Process linear operations - if is_linear_op(n, include_quantization=False): - # get per-layer quantization format from the node - quant_algo_n: str = ( - get_quantization_from_linear_node(n) if is_quant_graph else quant_algo + def _apply( + self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + ) -> Tuple[GraphModule, TransformInfo]: + # extract info from quant_config + quant_config = factory.get_quant_config() + if not quant_config: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - if not quant_algo_n: - continue - # insert quantized linear node - _insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph) - quantized_nodes[quant_algo_n]["linear"] += 1 + is_quant_graph = is_quantized_graph(gm) + quant_algo = quant_config.get("quant_algo") + excluded_patterns = quant_config.get("exclude_modules", []) + if not quant_algo: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) - # Process BMM operations - elif is_bmm_op(n): - if not quant_algo: + # tracking quantized operations in the graph + quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int)) + for n in gm.graph.nodes: + if should_skip_quantization(n, excluded_patterns): continue - # insert quantized bmm node - _insert_quantized_bmm( - gm, n, QuantizationImpl.create(quant_algo, 
is_bmm=True), is_quant_graph - ) - quantized_nodes[quant_algo]["bmm"] += 1 - - if is_quant_graph: - remove_output_quantizers(gm) + # Process linear operations + if is_linear_op(n, include_quantization=False): + # get per-layer quantization format from the node + quant_algo_n: str = ( + get_quantization_from_linear_node(n) if is_quant_graph else quant_algo + ) + if not quant_algo_n: + continue + + # insert quantized linear node + _insert_quantized_linear( + gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph + ) + quantized_nodes[quant_algo_n]["linear"] += 1 + + # Process BMM operations + elif is_bmm_op(n): + if not quant_algo: + continue + + # insert quantized bmm node + _insert_quantized_bmm( + gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph + ) + quantized_nodes[quant_algo]["bmm"] += 1 + + if is_quant_graph: + remove_output_quantizers(gm) + + num_matches = 0 + for quant_algo in quantized_nodes: + for op_type, count in quantized_nodes[quant_algo].items(): + num_matches += count + + info = TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True + ) - canonicalize_graph(gm) - for quant_algo in quantized_nodes: - for op_type, count in quantized_nodes[quant_algo].items(): - ad_logger.info(f"Found {count} {quant_algo} quantized {op_type} nodes.") - ad_logger.debug("After quantization: " + str(gm)) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py similarity index 72% rename from tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py rename to tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py index 93890d1da8c..b7b24cd5d5c 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py @@ -1,14 +1,15 @@ from functools import partial -from typing import Any, Callable, Dict, List, Tuple +from typing import Callable, List, Tuple import torch import torch.nn as nn from torch.fx import GraphModule, Node -from ...utils.logger import ad_logger +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import is_op from ...utils.quantization_utils import QuantizationImpl, should_skip_quantization -from .._graph import canonicalize_graph +from ..interface import BaseTransform, TransformInfo, TransformRegistry quantized_moe_op_map = { "FP8": torch.ops.auto_deploy.torch_quant_fp8_moe, @@ -92,47 +93,10 @@ def collect_scales(index: int) -> Tuple[List[Node], List[Node], List[Node]]: quantized_op, args=tuple(args), ) - ad_logger.debug(f"Updating {node.name} args to {new_node.args}") node.replace_all_uses_with(new_node) gm.graph.erase_node(node) -def quantize_moe(gm: GraphModule, quant_config: Dict[str, Any]) -> None: - """ - Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the - quantized version using the quant_algo from quant_config. 
- """ - quant_algo = quant_config.get("quant_algo") - if not quant_algo: - ad_logger.info("No quantization to do.") - return gm - excluded_patterns = quant_config.get("exclude_modules", []) - - quant_impl = QuantizationImpl.create(quant_algo) - quantized_op = quantized_moe_op_map[quant_algo] - - count = 0 - - for node in list(gm.graph.nodes): - if is_op(node, torch.ops.auto_deploy.torch_moe): - # Check that all expert weights should be quantized - w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node) - if any( - should_skip_quantization(n, excluded_patterns) - for n in w1_names + w2_names + w3_names - ): - continue - _quantize_moe_node(gm, node, quant_impl, quantized_op) - count += 1 - - if count == 0: - return gm - - gm = canonicalize_graph(gm) - ad_logger.info(f"Found {count} {quant_algo} quantized {quantized_op} nodes.") - return - - # TODO(Fridah-nv): robust handling similar to `extract_param_names_from_lin_node` or expand it def _extract_moe_weight_param_lists(moe_node: Node) -> Tuple[List[str], List[str], List[str]]: """ @@ -165,3 +129,51 @@ def _unwrap_list(arg) -> List[str]: w3_names = _unwrap_list(w3_list) return w1_names, w2_names, w3_names + + +@TransformRegistry.register("quantize_moe") +class QuantizeMOE(BaseTransform): + """ + Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the + quantized version using the quant_algo from quant_config. + """ + + def _apply( + self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory + ) -> Tuple[GraphModule, TransformInfo]: + quant_config = factory.get_quant_config() + quant_algo = quant_config.get("quant_algo") if quant_config else None + + if not quant_config or not quant_algo: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) + excluded_patterns = quant_config.get("exclude_modules", []) + + quant_impl = QuantizationImpl.create(quant_algo) + quantized_op = quantized_moe_op_map[quant_algo] + + count = 0 + + for node in list(gm.graph.nodes): + if is_op(node, torch.ops.auto_deploy.torch_moe): + # Check that all expert weights should be quantized + w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node) + if any( + should_skip_quantization(n, excluded_patterns) + for n in w1_names + w2_names + w3_names + ): + continue + _quantize_moe_node(gm, node, quant_impl, quantized_op) + count += 1 + + if count == 0: + return gm, TransformInfo( + skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True + ) + + info = TransformInfo( + skipped=False, num_matches=count, is_clean=False, has_valid_shapes=False + ) + + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py index 5e92764079f..0babe665850 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/_graph.py @@ -96,23 +96,24 @@ def named_graphmodules(gm: fx.GraphModule) -> Iterator[Tuple[str, fx.GraphModule yield name, m -def _move_single_gm_to_device( - gm: GraphModule, device: torch.device, recompile_graph: bool = False -) -> None: +def _move_single_gm_to_device(gm: GraphModule, device: torch.device) -> None: """Move one GraphModule and its nodes to the specified device in-place. 
Partially inspired by https://github.com/pytorch/pytorch/blob/05cb98f91d49df9eadfcb3fc29bbd1b621d88860/torch/export/passes/__init__.py#L11 """ # move state dict gm.to(device) + recompile_graph = False for node in gm.graph.nodes: # move all the nodes kwargs with burnt-in device if "device" in node.kwargs: + recompile_graph = True kwargs = node.kwargs.copy() kwargs["device"] = device node.kwargs = kwargs if is_op(node, torch.ops.aten.to.device): + recompile_graph = True args = list(node.args) args[1] = device node.args = tuple(args) @@ -135,7 +136,7 @@ def move_to_device(gm: fx.GraphModule, device: DeviceLikeType) -> fx.GraphModule for _, subgm in reversed(list(named_graphmodules(gm))): # recompile graph to update self generated codes in subgraph - _move_single_gm_to_device(subgm, device, subgm is not gm) + _move_single_gm_to_device(subgm, device) def _is_impure_node(node: Node) -> bool: diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py index 7662a3d5839..4a39c7f662f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py @@ -1,13 +1,10 @@ """A library of transformation passes.""" -from .attention import * from .collectives import * from .eliminate_redundant_transposes import * from .fused_moe import * from .fusion import * from .kvcache import * -from .quantization import * -from .quantize_moe import * from .rms_norm import * from .rope import * from .sharding import * diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py deleted file mode 100644 index e6efb8e0e7f..00000000000 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/attention.py +++ /dev/null @@ -1,833 +0,0 @@ -"""Pattern matching for detecting repeat_kv pattern from Huggingface models.""" - -from typing import Dict, Optional, Type - -import torch -from torch.fx import GraphModule, Node - -from ...custom_ops.attention_interface import AttentionDescriptor -from ...utils.logger import ad_logger -from ...utils.node_utils import is_op -from .._graph import canonicalize_graph - - -def match_repeat_kv(gm: GraphModule) -> None: - """ - Match and replace the repeat_kv pattern in fx graphs. - - The pattern is: - unsqueeze -> expand -> reshape -> [optional] contiguous - - This is replaced with torch.ops.auto_deploy.torch_attention_repeat_kv. - """ - graph = gm.graph - - num_kv_patterns = 0 - - # Iterate through nodes in the graph - for node in list(graph.nodes): - # Look for reshape nodes that could be the end of our pattern - if is_op(node, torch.ops.aten.reshape): - match_info = _match_repeat_kv_pattern(node) - if match_info: - ad_logger.debug(f"Found repeat_kv pattern at {node}") - _replace_with_repeat_kv(graph, match_info) - num_kv_patterns += 1 - - # Clean up the graph if we made any replacements - if num_kv_patterns: - canonicalize_graph(gm) - ad_logger.info(f"Found {num_kv_patterns} repeat_kv patterns") - - -def match_eager_attention(gm: GraphModule) -> None: - """ - Match and replace the eager attention pattern in fx graphs. - - The pattern is: - transpose -> matmul -> mul -> (optional) add -> softmax -> to -> dropout -> matmul - - This is replaced with torch.ops.auto_deploy.torch_attention_sdpa. 
- """ - graph = gm.graph - - # Track replacements to avoid processing nodes multiple times - num_eager_patterns = 0 - - # Iterate through nodes in the graph - for node in list(graph.nodes): - # Look for the final matmul nodes that could be part of our pattern - if is_op(node, torch.ops.aten.matmul): - match_info = _match_eager_attention_pattern(node) - if match_info: - ad_logger.debug(f"Found eager attention pattern at {node}") - _replace_with_sdpa(graph, match_info) - num_eager_patterns += 1 - - # Clean up the graph if we made any replacements - if num_eager_patterns: - canonicalize_graph(gm) - ad_logger.info(f"Found {num_eager_patterns} eager attention patterns") - - -def match_grouped_attention(gm: GraphModule) -> None: - """ - Match and replace the grouped attention pattern in fx graphs. - - The pattern is: - repeat_kv(k, n_rep) -> - repeat_kv(v, n_rep) -> - sdpa(q, repeated_k, repeated_v) - - This is replaced with torch.ops.auto_deploy.torch_attention_grouped_sdpa. - """ - graph = gm.graph - - # Track replacements to avoid processing nodes multiple times - num_grouped_patterns = 0 - - # Iterate through nodes in the graph - for node in list(graph.nodes): - # Look for SDPA nodes that could be part of our pattern - if is_op(node, torch.ops.auto_deploy.torch_attention_sdpa): - match_info = _match_grouped_attention_pattern(node) - if match_info: - ad_logger.debug(f"Found grouped attention pattern at {node}") - _replace_with_grouped_sdpa(graph, match_info) - num_grouped_patterns += 1 - - # Clean up the graph if we made any replacements - if num_grouped_patterns: - canonicalize_graph(gm) - ad_logger.info(f"Found {num_grouped_patterns} grouped attention patterns") - - -def match_causal_attn_mask(gm: GraphModule) -> None: - """ - Match attention operations with causal attention masks and optimize them. - - For operations that use explicit causal masks, this replaces: - - sdpa(q, k, v, causal_mask, dropout_p, False, scale) - with: - - sdpa(q, k, v, None, dropout_p, True, scale) - - This optimization enables more efficient implementations on supported backends. 
- """ - graph = gm.graph - - # Track replacements to avoid processing nodes multiple times - num_causal_patterns = 0 - - # Iterate through nodes in the graph - for node in list(graph.nodes): - # Look for SDPA nodes or grouped SDPA nodes - if not ( - is_op(node, torch.ops.auto_deploy.torch_attention_sdpa) - or is_op(node, torch.ops.auto_deploy.torch_attention_grouped_sdpa) - ): - continue - - # Get the attention mask argument (4th argument) - if len(node.args) < 4 or node.args[3] is None: - continue - - attn_mask = node.args[3] - - # Check if this mask is a causal mask - if not _is_causal_mask(attn_mask): - ad_logger.debug(f"Found non-causal attention mask at {node=}!") - continue - - ad_logger.debug(f"Found causal attention mask at {node}") - - # construct the new args list with args provided to the node and the default values otherwise - new_args = [] - for idx, arg in enumerate(node.target._schema.arguments): - # In case arg is provided to the node, use it - if idx < len(node.args): - new_args.append(node.args[idx]) - # In case arg is not provided to the node, use the default value - elif arg.has_default_value: - new_args.append(arg.default_value) - else: - raise ValueError(f"Missing required argument: {arg.name}") - - # Create new arguments with None mask and is_causal=True - new_args[3] = None # Set mask to None - new_args[5] = True # Set is_causal to True - - # Create new node with updated arguments - with graph.inserting_before(node): - new_node = graph.call_function(node.target, args=tuple(new_args), kwargs=node.kwargs) - - # Preserve metadata - new_node.meta = node.meta.copy() - - # Replace the old node with the new one - node.replace_all_uses_with(new_node) - - num_causal_patterns += 1 - - # Clean up the graph if we made any replacements - if num_causal_patterns: - canonicalize_graph(gm) - ad_logger.info(f"Found {num_causal_patterns} causal mask attention patterns") - - -def _match_repeat_kv_pattern(reshape_node: Node) -> Optional[Dict[str, Node]]: - """ - Match the repeat_kv pattern starting from a reshape node. - - The pattern is: - unsqueeze -> expand -> reshape -> [optional] contiguous - - Returns a dictionary with information about the match or None if no match. 
- """ - # Check that reshape_node is a reshape operation - if not is_op(reshape_node, torch.ops.aten.reshape): - return None - - # The reshape should have expand as its first argument - if len(reshape_node.args) < 1: - return None - - expand_node = reshape_node.args[0] - if not is_op(expand_node, torch.ops.aten.expand): - return None - - # The expand should have unsqueeze as its first argument - if len(expand_node.args) < 1: - return None - - unsqueeze_node = expand_node.args[0] - if not is_op(unsqueeze_node, torch.ops.aten.unsqueeze): - return None - - # The unsqueeze should be inserting a dimension at position 2 - if len(unsqueeze_node.args) < 2 or unsqueeze_node.args[1] != 2: - return None - - # Get the input tensor to unsqueeze - if len(unsqueeze_node.args) < 1: - return None - - input_tensor = unsqueeze_node.args[0] - - # Check input dimensions - should be 4D (batch, num_key_value_heads, seq_len, head_dim) - input_val = input_tensor.meta.get("val", None) - if input_val is None or len(input_val.shape) != 4: - return None - - # Extract batch size, num_kv_heads, seq_len, and head_dim from the input tensor shape - batch_size, num_kv_heads, seq_len, head_dim = input_val.shape - - # Check reshape args - if len(reshape_node.args) < 2 or not isinstance(reshape_node.args[1], list): - return None - - reshape_args = reshape_node.args[1] - if len(reshape_args) != 4: - return None - - # Check expand args - if len(expand_node.args) < 2 or not isinstance(expand_node.args[1], list): - return None - - expand_args = expand_node.args[1] - if len(expand_args) != 5: - return None - - # Determine n_rep by comparing the output and input head dimensions - # In the expand args, we should have [batch, num_kv_heads, n_rep, seq_len, head_dim] - # In the reshape args, we should have [batch, num_heads, seq_len, head_dim] - # where num_heads = num_kv_heads * n_rep - _, _, n_rep, _, _ = expand_args - _, reshape_num_heads, _, _ = reshape_args - - # Check that n_rep is an integer - if not isinstance(n_rep, int): - return None - - # Check that num_heads = num_kv_heads * n_rep - # This may be a symbolic expression, so we need to compare with caution - reshape_out_val = reshape_node.meta.get("val", None) - if reshape_out_val is None or len(reshape_out_val.shape) != 4: - return None - - # Ensure output shape is correct - out_batch, out_heads, out_seq, out_dim = reshape_out_val.shape - - # Check that input batch and seq dimensions match output - if out_batch != batch_size or out_seq != seq_len or out_dim != head_dim: - return None - - # Check if reshape is followed by a contiguous node - contiguous_node = None - users = list(reshape_node.users) - - # Only consider contiguous if reshape has exactly one user - if len(users) == 1 and is_op(users[0], torch.ops.aten.contiguous): - contiguous_node = users[0] - - result = { - "input_tensor": input_tensor, - "unsqueeze_node": unsqueeze_node, - "expand_node": expand_node, - "reshape_node": reshape_node, - "n_rep": n_rep, - } - - if contiguous_node: - result["contiguous_node"] = contiguous_node - - return result - - -def _match_eager_attention_pattern(final_matmul_node: Node) -> Optional[Dict[str, Node]]: - """ - Match the eager attention pattern starting from the final matmul node. - - The pattern is: - transpose -> matmul -> mul/div -> (optional) add -> (optional) to -> softmax -> (optional) to -> dropout -> matmul - - Returns a dictionary with information about the match or None if no match. 
- """ - # Check that final_matmul_node is a matmul operation - if not is_op(final_matmul_node, torch.ops.aten.matmul): - return None - - # Check we have two arguments - if len(final_matmul_node.args) < 2: - return None - - # The first arg of final matmul should be dropout - dropout_node = final_matmul_node.args[0] - if not is_op(dropout_node, torch.ops.aten.dropout): - return None - - # The second arg of final matmul is the value tensor (possibly repeated/transformed) - value = final_matmul_node.args[1] - - # The dropout should have a to_dtype node (or directly softmax) as input - if len(dropout_node.args) < 1: - return None - - # Allow optional to_dtype node after softmax - to_dtype_after_softmax = dropout_node.args[0] - if is_op(to_dtype_after_softmax, torch.ops.aten.to): - if len(to_dtype_after_softmax.args) < 1: - return None - softmax_node = to_dtype_after_softmax.args[0] - else: - softmax_node = to_dtype_after_softmax - - # Now we should have a softmax node - if not is_op(softmax_node, torch.ops.aten.softmax): - return None - - # The softmax should have dim=-1 (may be specified in different ways) - if len(softmax_node.args) < 2 or ( - isinstance(softmax_node.args[1], int) and softmax_node.args[1] != -1 - ): - # Check kwargs if not in args - if softmax_node.kwargs.get("dim", -1) != -1: - return None - - # The softmax node's input can be: - # - direct from add/mul/div - # - or through a to_dtype node (like to_35 in the example) - if len(softmax_node.args) < 1: - return None - - # Handle optional to_dtype node before softmax - prev_node = softmax_node.args[0] - if is_op(prev_node, torch.ops.aten.to): - if len(prev_node.args) < 1: - return None - prev_node = prev_node.args[0] - - # Check for attention mask pattern (add node) - if is_op(prev_node, torch.ops.aten.add): - add_node = prev_node - attn_mask = add_node.args[1] # Second arg is the mask - - # The add should have a mul or div node as its first argument - if len(add_node.args) < 1: - return None - - scaling_node = add_node.args[0] - if not (is_op(scaling_node, torch.ops.aten.mul) or is_op(scaling_node, torch.ops.aten.div)): - return None - elif is_op(prev_node, torch.ops.aten.mul) or is_op(prev_node, torch.ops.aten.div): - # No mask case - the softmax input is directly the mul or div node - scaling_node = prev_node - attn_mask = None - else: - return None - - # Check the scaling operation and extract the scaling factor - is_division = is_op(scaling_node, torch.ops.aten.div) - - # The mul/div node should have a matmul node as input - if len(scaling_node.args) < 2: - return None - - # Extract the scaling factor, adjusting for division vs multiplication - scale = scaling_node.args[1] - # Allow for constant or tensor scale - if not isinstance(scale, (float, int, Node)): - return None - - # For division, we need to invert the scaling factor if it's a constant - if is_division and isinstance(scale, (float, int)): - scale = 1.0 / scale - - first_matmul_node = scaling_node.args[0] - if not is_op(first_matmul_node, torch.ops.aten.matmul): - return None - - # The first matmul should have the query and key transpose as inputs - if len(first_matmul_node.args) < 2: - return None - - query = first_matmul_node.args[0] - transpose_key = first_matmul_node.args[1] - - # Check for transpose, could be any dimensions - if not is_op(transpose_key, torch.ops.aten.transpose): - return None - - # The transpose should have the key as input - if len(transpose_key.args) < 1: - return None - - key = transpose_key.args[0] - - # Create the match info dictionary 
- match_info = { - "query": query, - "key": key, - "value": value, - "scale": scale, - "dropout_p": dropout_node.args[1] if len(dropout_node.args) > 1 else 0.0, - "final_matmul": final_matmul_node, - } - - # Add the attention mask if it exists - if attn_mask is not None: - match_info["attn_mask"] = attn_mask - - return match_info - - -def _match_grouped_attention_pattern(sdpa_node: Node) -> Optional[Dict[str, Node]]: - """ - Match the grouped attention pattern starting from an SDPA node. - - The pattern is: - repeat_kv(k, n_rep) -> - repeat_kv(v, n_rep) -> - sdpa(q, repeated_k, repeated_v) - - Returns a dictionary with information about the match or None if no match. - """ - # Check that sdpa_node is an SDPA operation - if not is_op(sdpa_node, torch.ops.auto_deploy.torch_attention_sdpa): - return None - - # SDPA should have query, key, value as its first three arguments - if len(sdpa_node.args) < 3: - return None - - query, key_repeated, value_repeated = sdpa_node.args[0:3] - - # Key and value should come from repeat_kv operations - if not is_op(key_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv) or not is_op( - value_repeated, torch.ops.auto_deploy.torch_attention_repeat_kv - ): - return None - - # Extract the original key, value, and n_rep - orig_key = key_repeated.args[0] - orig_value = value_repeated.args[0] - key_n_rep = key_repeated.args[1] - value_n_rep = value_repeated.args[1] - - # Both repeat_kv operations should have the same n_rep - if key_n_rep != value_n_rep: - return None - - # Return the match information - return { - "query": query, - "key": orig_key, - "value": orig_value, - "key_repeated": key_repeated, - "value_repeated": value_repeated, - "n_rep": key_n_rep, - "sdpa_node": sdpa_node, - } - - -def _replace_with_repeat_kv(graph, match_info: Dict[str, Node]) -> None: - """ - Replace the matched repeat_kv pattern with the custom op. - """ - input_tensor = match_info["input_tensor"] - reshape_node = match_info["reshape_node"] - n_rep = match_info["n_rep"] - - # Determine the node to replace (either reshape or contiguous if present) - node_to_replace = match_info.get("contiguous_node", reshape_node) - - with graph.inserting_before(node_to_replace): - repeat_kv_node = graph.call_function( - torch.ops.auto_deploy.torch_attention_repeat_kv, args=(input_tensor, n_rep) - ) - - # Preserve metadata from the original node - repeat_kv_node.meta = node_to_replace.meta.copy() - - # Replace all uses of the node with the repeat_kv node - node_to_replace.replace_all_uses_with(repeat_kv_node) - - -def _replace_with_sdpa(graph, match_info: Dict[str, Node]) -> None: - """ - Replace the matched eager attention pattern with scaled_dot_product_attention. 
- """ - # retrieve the default op for scaled_dot_product_attention - sdpa_op = torch.ops.auto_deploy.torch_attention_sdpa.default - - # construct the args for the ops based on the match_info and the op's schema - args = [] - for arg in sdpa_op._schema.arguments: - if arg.name in match_info: - args.append(match_info[arg.name]) - elif arg.has_default_value: - args.append(arg.default_value) - else: - raise ValueError(f"Missing required argument: {arg.name}") - args = tuple(args) - - # retrieve the final matmul node to know where to insert the sdpa node - final_matmul = match_info["final_matmul"] - - with graph.inserting_before(final_matmul): - sdpa_node = graph.call_function(sdpa_op, args=args) - - # Preserve metadata from the original node - sdpa_node.meta = final_matmul.meta.copy() - - # Replace all uses of the final matmul node with the sdpa node - final_matmul.replace_all_uses_with(sdpa_node) - - -def _replace_with_grouped_sdpa(graph, match_info: Dict[str, Node]) -> None: - """ - Replace the matched grouped attention pattern with torch.ops.auto_deploy.torch_attention_grouped_sdpa. - """ - sdpa_node = match_info["sdpa_node"] - query = match_info["query"] - key = match_info["key"] - value = match_info["value"] - - # Construct the new args and kwargs - args = (query, key, value) + sdpa_node.args[3:] - kwargs = sdpa_node.kwargs.copy() - - with graph.inserting_before(sdpa_node): - grouped_sdpa_node = graph.call_function( - torch.ops.auto_deploy.torch_attention_grouped_sdpa.default, args=args, kwargs=kwargs - ) - - # Preserve metadata from the original node - grouped_sdpa_node.meta = sdpa_node.meta.copy() - - # Replace all uses of the SDPA node with the grouped_sdpa node - sdpa_node.replace_all_uses_with(grouped_sdpa_node) - - -def _is_causal_mask(mask_node: Node) -> bool: - """ - Determine if a node represents a causal attention mask. - - Causal masks typically involve: - 1. Creating a matrix with very negative values (e.g., -inf or close to it) - 2. Using triu with offset 1 to create an upper triangular matrix - 3. Usually involves comparison operations (gt, lt) with position indices - - Returns True if the node appears to be a causal mask pattern. 
- """ - # Direct pattern from the test case: masked_fill with triu(ones,1) and -inf - if is_op(mask_node, torch.ops.aten.masked_fill): - mask_args = mask_node.args - if len(mask_args) >= 2: - _ = mask_args[0] # zero tensor - mask_tensor = mask_args[1] - fill_value = mask_args[2] if len(mask_args) > 2 else mask_node.kwargs.get("value", None) - - # Check if fill value is very negative (e.g., -inf) - if fill_value is not None and ( - fill_value == float("-inf") - or (isinstance(fill_value, (int, float)) and fill_value < -1e4) - ): - # Try to trace back to find a triu pattern - if _has_triu_ancestor(mask_tensor, offset=1): - return True - - # Pattern from negative_fill test case: masked_fill with ~triu(ones,1) and 0.0 - # The negative_fill pattern has a pre-filled tensor with very negative values - # and zeros in the lower triangle - if is_op(mask_node, torch.ops.aten.masked_fill): - mask_args = mask_node.args - if len(mask_args) >= 2: - negative_tensor = mask_args[0] - mask_tensor = mask_args[1] - fill_value = mask_args[2] if len(mask_args) > 2 else mask_node.kwargs.get("value", None) - - # Check if fill value is zero and the tensor is pre-filled with negative values - if fill_value == 0.0 or fill_value == 0: - # Check for the full tensor with negative values - if is_op(negative_tensor, torch.ops.aten.full): - fill_args = negative_tensor.args - if ( - len(fill_args) > 1 - and isinstance(fill_args[1], (int, float)) - and fill_args[1] < -1e4 - ): - # This is likely a negative-filled tensor - # Now check if the mask is a bitwise_not of triu - if is_op(mask_tensor, torch.ops.aten.bitwise_not): - if len(mask_tensor.args) > 0 and _has_triu_ancestor( - mask_tensor.args[0], offset=1 - ): - return True - - # Pattern for llama-3.1 style causal mask: slice of expand(unsqueeze(unsqueeze(mul_(triu, gt)))) - if is_op(mask_node, torch.ops.aten.slice): - # Follow the chain backward to the source of the slice - if len(mask_node.args) == 0: - return False - slice_source = mask_node.args[0] - - # Check for typical expand pattern - if not (slice_source and is_op(slice_source, torch.ops.aten.expand)): - return False - - # Continue tracing back through the pattern - if len(slice_source.args) == 0: - return False - expand_source = slice_source.args[0] - - # Check for first unsqueeze operation - if not (expand_source and is_op(expand_source, torch.ops.aten.unsqueeze)): - return False - - # Look for the source of first unsqueeze - if len(expand_source.args) == 0: - return False - first_unsqueeze_source = expand_source.args[0] - - # Check for second unsqueeze operation - if not (first_unsqueeze_source and is_op(first_unsqueeze_source, torch.ops.aten.unsqueeze)): - return False - - # Look for the source of the second unsqueeze - if len(first_unsqueeze_source.args) == 0: - return False - second_unsqueeze_source = first_unsqueeze_source.args[0] - - # Check for mul_ operation - if is_op(second_unsqueeze_source, torch.ops.aten.mul_): - # Check if one of the mul_ arguments is a triu operation - has_triu = False - for arg in second_unsqueeze_source.args: - if is_op(arg, torch.ops.aten.triu): - if len(arg.args) > 1 and arg.args[1] == 1: - has_triu = True - break - - if has_triu: - # Check if one of the mul_ arguments involves a full tensor with negative values - for arg in second_unsqueeze_source.args: - if is_op(arg, torch.ops.aten.full): - if ( - len(arg.args) > 1 - and isinstance(arg.args[1], (int, float)) - and arg.args[1] < -1e4 - ): - return True - - return has_triu - - # Original implementation for backward 
compatibility - if is_op(mask_node, torch.ops.aten.slice): - # Follow the chain backward to the source of the slice - if len(mask_node.args) == 0: - return False - slice_source = mask_node.args[0] - - # Check for typical expand pattern - if not (slice_source and is_op(slice_source, torch.ops.aten.expand)): - return False - - # Continue tracing back through the pattern - if len(slice_source.args) == 0: - return False - expand_source = slice_source.args[0] - - # Check for unsqueeze operations - if not (expand_source and is_op(expand_source, torch.ops.aten.unsqueeze)): - return False - - # Look for the source of the unsqueeze - if len(expand_source.args) == 0: - return False - unsqueeze_source = expand_source.args[0] - - if not unsqueeze_source: - return False - - # Check for triu pattern which is common in causal masks - if is_op(unsqueeze_source, torch.ops.aten.mul_): - for arg in unsqueeze_source.args: - if not is_op(arg, torch.ops.aten.triu): - continue - - if len(arg.args) <= 1: - continue - - triu_offset = arg.args[1] - # Causal masks typically use triu with offset 1 - if triu_offset == 1: - return True - - return False - - # Check if we have a full tensor filled with a very negative number - if not is_op(unsqueeze_source, torch.ops.aten.full): - return False - - if len(unsqueeze_source.args) <= 1: - return False - - fill_value = unsqueeze_source.args[1] - # Check if the fill value is very negative (likely -inf or close) - if isinstance(fill_value, float) and fill_value < -1e10: - return True - - # If we can't definitively identify it as causal, return False - return False - - -def _has_triu_ancestor(node: Node, offset: int = 1, depth: int = 0, max_depth: int = 5) -> bool: - """Helper function to find a triu operation in the ancestry of a node.""" - if depth > max_depth: # Prevent infinite recursion - return False - - if is_op(node, torch.ops.aten.triu): - if len(node.args) > 1 and node.args[1] == offset: - return True - - # Check if any of the arguments has a triu ancestor - for arg in node.args: - if isinstance(arg, Node) and _has_triu_ancestor(arg, offset, depth + 1, max_depth): - return True - - # Check if any of the kwargs has a triu ancestor - for value in node.kwargs.values(): - if isinstance(value, Node) and _has_triu_ancestor(value, offset, depth + 1, max_depth): - return True - - return False - - -def match_attention_layout(gm: GraphModule, attention_op: Type[AttentionDescriptor]) -> None: - """ - Match and transform attention operations to match the layout expected by the attention backend. - - If the attention backend expects 'bnsd' layout (batch, num_heads, seq_len, head_dim), which - is the default for SDPA operations, we don't need to transform anything. - - If the backend expects 'bsnd' layout (batch, seq_len, num_heads, head_dim), we insert - appropriate transposes before and after SDPA operations and replace them with bsnd_grouped_sdpa. 
- """ - # Get attention layout from attention_op - attention_layout = attention_op.get_attention_layout() - - # List of SDPA operations to look for - sdpa_ops = { - torch.ops.auto_deploy.torch_attention_sdpa, - torch.ops.auto_deploy.torch_attention_grouped_sdpa, - } - - graph = gm.graph - num_bsnd_patterns = 0 - - # Look for SDPA operations - for sdpa_node in list(graph.nodes): - if sdpa_node.op != "call_function" or not is_op(sdpa_node, sdpa_ops): - continue - - ad_logger.debug(f"Found SDPA node to transform for bsnd layout: {sdpa_node}") - - # Extract q, k, v inputs - q, k, v = sdpa_node.args[:3] - - # Check if we need to transpose the inputs - if attention_layout == "bsnd": - # Add transposes before the node (from bnsd to bsnd) - with graph.inserting_before(sdpa_node): - q_updated = graph.call_function(torch.ops.aten.transpose.int, args=(q, 1, 2)) - k_updated = graph.call_function(torch.ops.aten.transpose.int, args=(k, 1, 2)) - v_updated = graph.call_function(torch.ops.aten.transpose.int, args=(v, 1, 2)) - - # Preserve fake tensor in meta["val"] for the transposed inputs - q_updated.meta["val"] = q.meta["val"].transpose(1, 2) - k_updated.meta["val"] = k.meta["val"].transpose(1, 2) - v_updated.meta["val"] = v.meta["val"].transpose(1, 2) - elif attention_layout == "bnsd": - # we don't need to do anything... - q_updated = q - k_updated = k - v_updated = v - else: - raise ValueError(f"Unsupported attention layout: {attention_layout}") - - # Create bsnd_grouped_sdpa node with the same args as the original node - # but using the transposed inputs - with graph.inserting_before(sdpa_node): - source_sdpa_node = graph.call_function( - attention_op.get_source_attention_op(), - args=(q_updated, k_updated, v_updated) + sdpa_node.args[3:], - kwargs=sdpa_node.kwargs, - ) - - # Check if need to update the output node to match the layout - if attention_layout == "bsnd": - # Add transpose for the output (from bsnd back to bnsd) - with graph.inserting_after(source_sdpa_node): - output_updated = graph.call_function( - torch.ops.aten.transpose.int, args=(source_sdpa_node, 1, 2) - ) - - # Preserve fake tensor in meta["val"] for the transposed inputs - source_sdpa_node.meta["val"] = sdpa_node.meta["val"].transpose(1, 2).contiguous() - output_updated.meta["val"] = source_sdpa_node.meta["val"].transpose(1, 2) - elif attention_layout == "bnsd": - output_updated = source_sdpa_node - else: - raise ValueError(f"Unsupported attention layout: {attention_layout}") - - # Replace the old node with the transposed output - sdpa_node.replace_all_uses_with(output_updated) - - num_bsnd_patterns += 1 - - # Clean up the graph if we made any replacements - if num_bsnd_patterns: - canonicalize_graph(gm) - ad_logger.debug(f"Transformed graph for bsnd layout: {gm}") - - ad_logger.info(f"Found and matched {num_bsnd_patterns} attention layouts") diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py index 62a9d355602..618c8108f84 100644 --- a/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py +++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py @@ -174,7 +174,7 @@ def _get_mem_info_in_mb(): memory_for_forward_pass = free_mem_pre - free_mem_post ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}") - new_cache_size = free_mem_post * free_mem_ratio + current_cache_size + new_cache_size = free_mem_post * 1024 * 1024 * free_mem_ratio + current_cache_size new_num_pages = int(new_cache_size 
// (current_cache_size // current_num_pages))
 
     # Need to sync all the GPUs
diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py
index ae686690e8d..65e7f7f614c 100644
--- a/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py
+++ b/tensorrt_llm/_torch/auto_deploy/transformations/library/rope.py
@@ -141,6 +141,12 @@ def match_rope_pattern(gm: GraphModule) -> int:
         torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float16),
         torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float16),
     ]
+    # float32 input can change the graph when there's .float() in pattern
+    dummy_complex_2 = [
+        torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
+        torch.randn(batch_size, num_heads, seq_len, head_dim, device="meta", dtype=torch.float32),
+        torch.randn(batch_size, seq_len, head_dim // 2, device="meta", dtype=torch.float32),
+    ]
     register_ad_pattern(
         search_fn=_explicit_rope_pattern,
         replace_fn=_explicit_rope_repl,
@@ -172,6 +178,16 @@ def match_rope_pattern(gm: GraphModule) -> int:
         },
         scalar_workaround={"unsqueeze_dim": 1},
     )
+    register_ad_pattern(
+        search_fn=_complex_rope_pattern,
+        replace_fn=_complex_rope_repl,
+        patterns=patterns,
+        dummy_args=dummy_complex_2,
+        op_ignore_types={
+            torch.ops.aten.reshape.default: (int,),
+        },
+        scalar_workaround={"unsqueeze_dim": 1},
+    )
 
     num_matches = patterns.apply(graph)
     canonicalize_graph(gm)
diff --git a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py
index a2f31644d5b..3844ce4d312 100644
--- a/tensorrt_llm/_torch/auto_deploy/transformations/transform.py
+++ b/tensorrt_llm/_torch/auto_deploy/transformations/transform.py
@@ -24,17 +24,10 @@
     fuse_collectives,
     fuse_rmsnorm,
     insert_cached_attention,
-    match_attention_layout,
-    match_causal_attn_mask,
-    match_eager_attention,
-    match_grouped_attention,
     match_moe_pattern,
-    match_repeat_kv,
     match_rope_layout,
     match_rope_pattern,
     optimize_rope,
-    quantize,
-    quantize_moe,
     resize_kv_cache,
     sharding_transform_executor,
     update_in_out_nodes,
@@ -63,6 +56,12 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         ############################################################################################
         # RUN MODULAR INFERENCE OPTIMIZER FOR ALREADY-MIGRATED TRANSFORMS
         ############################################################################################
+        # TODO (hg): default values that are not representable in YAML.
+        if "match_attention_layout" in self.ad_config.transforms:
+            self.ad_config.transforms[
+                "match_attention_layout"
+            ].attention_op = AttentionRegistry.get(self.ad_config.attn_backend)
+
         new_optimizer = ModularInferenceOptimizer(self.factory, self.ad_config.transforms)
         egm = new_optimizer(cm)
 
@@ -71,28 +70,10 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         ############################################################################################
         # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
         ############################################################################################
-        # quantization
-        quantize(egm, self.factory.get_quant_config())
-        quantize_moe(egm, self.factory.get_quant_config())
 
         # Match MoE pattern
         match_moe_pattern(egm)
 
-        # Match repeat_kv pattern
-        match_repeat_kv(egm)
-
-        # Match eager attention pattern
-        match_eager_attention(egm)
-
-        # Match grouped attention pattern
-        match_grouped_attention(egm)
-
-        # Match and optimize causal attention masks
-        match_causal_attn_mask(egm)
-
-        # Match attention layout expected by our backend
-        match_attention_layout(egm, AttentionRegistry.get(self.ad_config.attn_backend))
-
         # Match rope
         match_rope_pattern(egm)
diff --git a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
index 28e195b41eb..00b535dec61 100644
--- a/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
+++ b/tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
@@ -153,6 +153,8 @@ def register_ad_pattern(
     5. register_replacement can auto-generate `search_fn_pattern` if you input `example_inputs`,
        but that approach will fail when symbolic shapes are involved. Here we explicitly trace &
        convert via `fx_to_pattern`.
+    6. The PatternMatcherPass checks num_users of the nodes, so the pattern is required to be
+       functionally isolated: no intermediate node may be shared with the rest of the graph.
""" argnames = list(inspect.signature(search_fn).parameters.keys()) diff --git a/tests/integration/defs/accuracy/references/gsm8k.yaml b/tests/integration/defs/accuracy/references/gsm8k.yaml index 44d26e42452..5c942ed41b0 100644 --- a/tests/integration/defs/accuracy/references/gsm8k.yaml +++ b/tests/integration/defs/accuracy/references/gsm8k.yaml @@ -70,6 +70,11 @@ deepseek-ai/DeepSeek-R1: - quant_algo: FP8_BLOCK_SCALES spec_dec_algo: MTP accuracy: 95.413 +Qwen3/Qwen3-8B: + - accuracy: 87.1114 + - quant_algo: FP8 + kv_cache_quant_algo: FP8 + accuracy: 87.1114 Qwen3/Qwen3-30B-A3B: - quant_algo: FP8_BLOCK_SCALES accuracy: 84.36 diff --git a/tests/integration/defs/accuracy/test_cli_flow.py b/tests/integration/defs/accuracy/test_cli_flow.py index 1553838b95a..b30ec9c9124 100644 --- a/tests/integration/defs/accuracy/test_cli_flow.py +++ b/tests/integration/defs/accuracy/test_cli_flow.py @@ -50,11 +50,13 @@ def test_context_fmha_disabled(self): def test_context_fmha_fp32_acc(self): self.run(extra_summarize_args=["--enable_context_fmha_fp32_acc"]) + @skip_post_blackwell @pytest.mark.parametrize("precision", ["int8", "int4"]) def test_weight_only(self, precision: str): quant_algo = QuantAlgo.W8A16 if precision == "int8" else QuantAlgo.W4A16 self.run(quant_algo=quant_algo) + @skip_post_blackwell def test_int8_kv_cache(self): self.run(kv_cache_quant_algo=QuantAlgo.INT8) @@ -415,6 +417,7 @@ class TestVicuna7B(CliFlowAccuracyTestHarness): EAGLE_MODEL_NAME = "yuhuili/EAGLE-Vicuna-7B-v1.3" EAGLE_MODEL_PATH = f"{llm_models_root()}/EAGLE-Vicuna-7B-v1.3" + @skip_post_blackwell def test_lookahead(self, mocker): mocker.patch.object(CnnDailymail, "MAX_BATCH_SIZE", 8) @@ -425,6 +428,7 @@ def test_lookahead(self, mocker): ], extra_summarize_args=["--lookahead_config=[7,7,7]"]) + @skip_post_blackwell @parametrize_with_ids("cuda_graph", [False, True]) def test_medusa(self, cuda_graph, mocker): mocker.patch.object(self.__class__, "EXAMPLE_FOLDER", "medusa") @@ -1104,6 +1108,7 @@ def test_fp8_tp2pp2_manage_weights(self): @pytest.mark.skip_less_device(2) @pytest.mark.skip_less_device_memory(80000) + @skip_post_blackwell def test_weight_only_int4_tp2(self): self.run(quant_algo=QuantAlgo.W4A16, tp_size=2, @@ -1111,6 +1116,7 @@ def test_weight_only_int4_tp2(self): @pytest.mark.skip_less_device(2) @pytest.mark.skip_less_device_memory(80000) + @skip_post_blackwell def test_weight_only_int8_tp2(self): self.run(quant_algo=QuantAlgo.W8A16, tp_size=2, diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 12b297aeaf3..fa9ab908976 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -533,3 +533,44 @@ def test_auto_dtype(self, overlap_scheduler): self.MODEL_PATH) as llm: task = GSM8K(self.MODEL_NAME) task.evaluate(llm) + + +@pytest.mark.timeout(3600) +class TestQwen3_8B(LlmapiAccuracyTestHarness): + MODEL_NAME = "Qwen3/Qwen3-8B" + MODEL_PATH = f"{llm_models_root()}/Qwen3/Qwen3-8B-FP8" + + @pytest.mark.parametrize("overlap_scheduler", [False, True]) + def test_auto_dtype(self, overlap_scheduler): + ctx_server_config = { + "disable_overlap_scheduler": True, + "cuda_graph_config": None, + "cache_transceiver_config": { + "backend": "default" + } + } + gen_server_config = { + "disable_overlap_scheduler": overlap_scheduler, + "cuda_graph_config": None, + "cache_transceiver_config": { + "backend": "default" + } + } + disaggregated_server_config = { + 
"hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + task = GSM8K(self.MODEL_NAME) + task.evaluate(llm) diff --git a/tests/integration/defs/conftest.py b/tests/integration/defs/conftest.py index c79f1ffe7d2..f8c5f55d09a 100644 --- a/tests/integration/defs/conftest.py +++ b/tests/integration/defs/conftest.py @@ -1892,6 +1892,10 @@ def check_device_contain(keyword_list): get_sm_version() >= 100, reason="This test is not supported in post-Blackwell architecture") +skip_post_blackwell_ultra = pytest.mark.skipif( + get_sm_version() >= 103, + reason="This test is not supported in post-Blackwell-Ultra architecture") + skip_device_contain_gb200 = pytest.mark.skipif( check_device_contain(["GB200"]), reason="This test is not supported on GB200 or GB100") diff --git a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py index 5ed5c3e2710..55971c3ad0e 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py +++ b/tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py @@ -27,21 +27,21 @@ MPI_REQUEST = MPI_TAG MPI_RESULT = MPI_TAG + 1 +MODEL_PATHS = { + "DeepSeek-V3-Lite-fp8": "DeepSeek-V3-Lite/fp8", + "TinyLlama-1.1B-Chat-v1.0": "llama-models-v2/TinyLlama-1.1B-Chat-v1.0", + "Llama-3.1-8B-Instruct": "llama-3.1-model/Llama-3.1-8B-Instruct/", + "EAGLE3-LLaMA3.1-Instruct-8B": "EAGLE3-LLaMA3.1-Instruct-8B", + "Qwen3-8B-FP8": "Qwen3/Qwen3-8B-FP8", +} + def model_path(model_name): llm_models_root = os.environ["LLM_MODELS_ROOT"] - if 'DeepSeek-V3-Lite-fp8' in model_name: - return os.path.join(llm_models_root, 'DeepSeek-V3-Lite', 'fp8') - elif 'TinyLlama-1.1B-Chat-v1.0' in model_name: - return os.path.join(llm_models_root, 'llama-models-v2', - 'TinyLlama-1.1B-Chat-v1.0') - elif 'Llama-3.1-8B-Instruct' in model_name: - return os.path.join(llm_models_root, 'llama-3.1-model', - 'Llama-3.1-8B-Instruct/') - elif 'EAGLE3-LLaMA3.1-Instruct-8B' in model_name: - return os.path.join(llm_models_root, 'EAGLE3-LLaMA3.1-Instruct-8B') - else: - raise ValueError(f"Unknown model: {model_name}") + for name, path in MODEL_PATHS.items(): + if name in model_name: + return os.path.join(llm_models_root, path) + raise ValueError(f"Unknown model: {model_name}") async def run_worker(kv_cache_config, cache_transceiver_config, pytorch_config, @@ -232,6 +232,22 @@ def test_disaggregated_simple_deepseek(model, generation_overlap, ]) +@skip_no_hopper +@pytest.mark.parametrize("model", ["Qwen3-8B-FP8"]) +@pytest.mark.parametrize("generation_overlap", [False, True]) +@pytest.mark.parametrize("enable_cuda_graph", [False, True]) +def test_disaggregated_simple_qwen3(model, generation_overlap, + enable_cuda_graph): + verify_disaggregated( + model, generation_overlap, enable_cuda_graph, + " What is the capital of China?", + " The capital of China is Beijing. 2. What is the population of China? 
The population of China is about 1", + [ + 576, 6722, 315, 5616, 374, 26549, 13, 220, 17, 13, 3555, 374, 279, + 7042, 315, 5616, 30, 576, 7042, 315, 5616, 374, 911, 220, 16 + ]) + + @pytest.mark.parametrize("model", ["DeepSeek-V3-Lite-fp8/fp8"]) @pytest.mark.parametrize("enable_cuda_graph", [False]) @pytest.mark.parametrize("generation_overlap", [False]) diff --git a/tests/integration/defs/examples/test_bert.py b/tests/integration/defs/examples/test_bert.py index f6bff3222f0..f0268325ea0 100644 --- a/tests/integration/defs/examples/test_bert.py +++ b/tests/integration/defs/examples/test_bert.py @@ -18,6 +18,12 @@ from defs.conftest import get_device_count, get_sm_version from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + # # Build parameters @pytest.mark.parametrize( diff --git a/tests/integration/defs/examples/test_bindings.py b/tests/integration/defs/examples/test_bindings.py index 39aeb9fda4a..8f3684370b7 100644 --- a/tests/integration/defs/examples/test_bindings.py +++ b/tests/integration/defs/examples/test_bindings.py @@ -18,9 +18,15 @@ import pytest from defs.common import convert_weights, venv_check_call, venv_mpi_check_call -from defs.conftest import get_device_count +from defs.conftest import get_device_count, get_sm_version from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.fixture(scope="module") def bindings_example_root(llm_root): diff --git a/tests/integration/defs/examples/test_chatglm.py b/tests/integration/defs/examples/test_chatglm.py index 311cc621f99..37ee4a1c174 100644 --- a/tests/integration/defs/examples/test_chatglm.py +++ b/tests/integration/defs/examples/test_chatglm.py @@ -18,12 +18,20 @@ import pytest from defs.common import convert_weights, venv_check_call +from defs.conftest import get_sm_version, skip_post_blackwell from defs.trt_test_alternative import check_call, exists +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + # TODO: add more test case for input_padding, paged_kv_cache, num_beams @pytest.mark.skip_less_device_memory(24000) -@pytest.mark.parametrize("use_weight_only", [True, False], +@pytest.mark.parametrize("use_weight_only", + [pytest.param(True, marks=skip_post_blackwell), False], ids=["enable_weight_only", "disable_weight_only"]) @pytest.mark.parametrize("llm_glm_4_9b_model_root", ["glm-4-9b", "glm-4-9b-chat"], diff --git a/tests/integration/defs/examples/test_commandr.py b/tests/integration/defs/examples/test_commandr.py index 2de725f5ee2..1c0fa612088 100644 --- a/tests/integration/defs/examples/test_commandr.py +++ b/tests/integration/defs/examples/test_commandr.py @@ -18,11 +18,19 @@ import pytest from defs.common import (convert_weights, generate_summary_cmd, venv_check_call, venv_mpi_check_call) -from defs.conftest import get_gpu_device_list +from defs.conftest import (get_gpu_device_list, get_sm_version, + skip_post_blackwell) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not 
supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.mark.skip_less_device_memory(80000) +@skip_post_blackwell @pytest.mark.parametrize("use_weight_only", [True, False], ids=["enable_weight_only", "disable_weight_only"]) def test_llm_commandr_v01_single_gpu_summary(commandr_example_root, @@ -79,7 +87,8 @@ def test_llm_commandr_v01_single_gpu_summary(commandr_example_root, @pytest.mark.skip_less_device(4) @pytest.mark.skip_less_device_memory(80000) @pytest.mark.skip_less_host_memory(1000000) -@pytest.mark.parametrize("use_weight_only", [True, False], +@pytest.mark.parametrize("use_weight_only", + [pytest.param(True, marks=skip_post_blackwell), False], ids=["enable_weight_only", "disable_weight_only"]) def test_llm_commandr_plus_4gpus_summary(commandr_example_root, llm_commandr_plus_model_root, diff --git a/tests/integration/defs/examples/test_draft_target_model.py b/tests/integration/defs/examples/test_draft_target_model.py index 3ee0221e94d..87f7644ebb8 100644 --- a/tests/integration/defs/examples/test_draft_target_model.py +++ b/tests/integration/defs/examples/test_draft_target_model.py @@ -19,10 +19,16 @@ import pytest from defs.common import convert_weights, venv_check_call, venv_mpi_check_call -from defs.conftest import (get_device_memory, llm_models_root, +from defs.conftest import (get_device_memory, get_sm_version, llm_models_root, skip_post_blackwell, skip_pre_hopper) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + # TODO: remove skip after enable Blackwell for Speculative Decoding @skip_post_blackwell diff --git a/tests/integration/defs/examples/test_eagle.py b/tests/integration/defs/examples/test_eagle.py index ad57339074a..fb9a4617bd2 100644 --- a/tests/integration/defs/examples/test_eagle.py +++ b/tests/integration/defs/examples/test_eagle.py @@ -18,9 +18,15 @@ import pytest from defs.common import (convert_weights, get_dummy_spec_decoding_heads, venv_check_call) -from defs.conftest import skip_post_blackwell, skip_pre_ada +from defs.conftest import get_sm_version, skip_post_blackwell, skip_pre_ada from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @skip_post_blackwell @pytest.mark.parametrize("use_dynamic_tree", [False, True], diff --git a/tests/integration/defs/examples/test_enc_dec.py b/tests/integration/defs/examples/test_enc_dec.py index 92ff672c84c..10be785b220 100644 --- a/tests/integration/defs/examples/test_enc_dec.py +++ b/tests/integration/defs/examples/test_enc_dec.py @@ -16,10 +16,16 @@ import pytest from defs.common import (convert_weights, quantize_data, venv_check_call, venv_mpi_check_call) -from defs.conftest import (get_device_count, skip_fp8_pre_ada, +from defs.conftest import (get_device_count, get_sm_version, skip_fp8_pre_ada, skip_post_blackwell) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.mark.parametrize("use_fp8", [True, False], ids=["enable_fp8", "disable_fp8"]) @@ -38,8 +44,11 @@ ids=["enable_gemm_plugin", 
"disable_gemm_plugin"]) @pytest.mark.parametrize("data_type", ['bfloat16', 'float16', 'float32']) @pytest.mark.parametrize("enc_dec_model_root", [ - 't5-small', 'flan-t5-small', 'byt5-small', 'bart-large-cnn', - 'mbart-large-50-many-to-one-mmt', 'wmt14' + pytest.param('t5-small', marks=skip_post_blackwell), + pytest.param('flan-t5-small', marks=skip_post_blackwell), + pytest.param('byt5-small', marks=skip_post_blackwell), 'bart-large-cnn', + pytest.param('mbart-large-50-many-to-one-mmt', marks=skip_post_blackwell), + 'wmt14' ], indirect=True) @pytest.mark.parametrize("compare_hf_fp32", [True, False], diff --git a/tests/integration/defs/examples/test_exaone.py b/tests/integration/defs/examples/test_exaone.py index b0b3113ed2f..1d331c460d4 100644 --- a/tests/integration/defs/examples/test_exaone.py +++ b/tests/integration/defs/examples/test_exaone.py @@ -17,10 +17,17 @@ import pytest from defs.common import (convert_weights, generate_summary_cmd, venv_check_call, venv_mpi_check_call) -from defs.conftest import skip_post_blackwell +from defs.conftest import get_sm_version, skip_post_blackwell from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + +@skip_post_blackwell @pytest.mark.parametrize("num_beams", [1, 2, 4], ids=lambda num_beams: f'nb:{num_beams}') @pytest.mark.parametrize("data_type", ['bfloat16', 'float16']) diff --git a/tests/integration/defs/examples/test_gpt.py b/tests/integration/defs/examples/test_gpt.py index 0e320a239f1..e20980c93bf 100644 --- a/tests/integration/defs/examples/test_gpt.py +++ b/tests/integration/defs/examples/test_gpt.py @@ -24,9 +24,16 @@ similarity_score, test_multi_lora_support, venv_check_call, venv_check_output, venv_mpi_check_call, venv_mpi_check_output) -from defs.conftest import get_device_memory, skip_fp8_pre_ada, skip_pre_ada +from defs.conftest import (get_device_memory, get_sm_version, skip_fp8_pre_ada, + skip_post_blackwell, skip_pre_ada) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + INPUT_TEXT_1 = "After Washington had returned to Williamsburg, " + \ "Dinwiddie ordered him to lead a larger force to assist Trent in his work. " + \ "While en route, Washington learned of Trent's retreat. 
" + \ @@ -688,6 +695,7 @@ def test_llm_gpt3_175b_1node_8gpus(gpt_example_root, llm_venv, engine_dir, ]) +@skip_post_blackwell @pytest.mark.parametrize("per_token_channel", [True, False], ids=["enable_ptpc", "disable_ptpc"]) def test_llm_gpt2_smooth_single_gpu_summary(gpt_example_root, llm_venv, @@ -732,6 +740,7 @@ def test_llm_gpt2_smooth_single_gpu_summary(gpt_example_root, llm_venv, ]) +@skip_post_blackwell def test_llm_gpt2_int8_kv_1gpu(gpt_example_root, llm_venv, llm_gpt2_model_root, llm_datasets_root, engine_dir, cmodel_dir): "gpt2 INT8 KV Cache test on 1 gpu" @@ -1360,6 +1369,7 @@ def test_llm_gpt2_starcoder_1node_4gpus(gpt_example_root, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_host_memory(250000) def test_llm_gpt2_starcoder_1gpus(gpt_example_root, llm_gpt2_starcoder_model_root, llm_venv, @@ -1401,6 +1411,7 @@ def test_llm_gpt2_starcoder_1gpus(gpt_example_root, venv_check_call(llm_venv, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_host_memory(250000) @pytest.mark.parametrize("dtype", ["float16"]) @pytest.mark.parametrize("precision", ["int8", "int4"]) @@ -1710,6 +1721,7 @@ def test_llm_gpt2_multi_lora_1gpu(gpt_example_root, llm_venv, for item in expected_output[idx]]), f"output is {output}" +@skip_post_blackwell @pytest.mark.skip_less_device_memory(50000) @pytest.mark.parametrize("data_type", ['float16', 'fp8'], ids=['base_fp16', 'base_fp8']) diff --git a/tests/integration/defs/examples/test_gptj.py b/tests/integration/defs/examples/test_gptj.py index e5337765e11..bc80a15017f 100644 --- a/tests/integration/defs/examples/test_gptj.py +++ b/tests/integration/defs/examples/test_gptj.py @@ -15,9 +15,15 @@ import pytest from defs.common import venv_check_call -from defs.conftest import get_gpu_device_list +from defs.conftest import get_gpu_device_list, get_sm_version from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + INPUT_TEXT = """ Write a Python function `find_max(words)` to solve the following problem:\nWrite a function that accepts a list of strings.\nThe list contains different words. Return the word with maximum number\nof unique characters. 
If multiple strings have maximum number of unique\ncharacters, return the one which comes first in lexicographical order.\nfind_max(["name", "of", "string"]) == "string"\nfind_max(["name", "enam", "game"]) == "enam"\nfind_max(["aaaaaaa", "bb" ,"cc"]) == ""aaaaaaa" """ diff --git a/tests/integration/defs/examples/test_granite.py b/tests/integration/defs/examples/test_granite.py index 234a07c174c..f789c406565 100644 --- a/tests/integration/defs/examples/test_granite.py +++ b/tests/integration/defs/examples/test_granite.py @@ -19,8 +19,15 @@ import pytest from defs.common import (convert_weights, test_multi_lora_support, venv_mpi_check_call) +from defs.conftest import get_sm_version from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.fixture(scope="module", autouse=True) def disable_unified_converter(): diff --git a/tests/integration/defs/examples/test_internlm.py b/tests/integration/defs/examples/test_internlm.py index e6768bca1b1..144f86b32fa 100644 --- a/tests/integration/defs/examples/test_internlm.py +++ b/tests/integration/defs/examples/test_internlm.py @@ -14,9 +14,15 @@ # limitations under the License. import pytest from defs.common import convert_weights, parse_mpi_cmd, venv_mpi_check_call -from defs.conftest import get_device_memory +from defs.conftest import get_device_memory, get_sm_version from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + # @pytest.mark.skip_less_device(2) @pytest.mark.parametrize("num_beams", [1, 2, 4], diff --git a/tests/integration/defs/examples/test_llama.py b/tests/integration/defs/examples/test_llama.py index 6063c53ed6d..debc21bf393 100644 --- a/tests/integration/defs/examples/test_llama.py +++ b/tests/integration/defs/examples/test_llama.py @@ -36,6 +36,12 @@ # yapf: enable from defs.trt_test_alternative import check_call, exists +# skip trt flow cases on post-Blackwell-Ultra +# if get_sm_version() >= 103: +# pytest.skip( +# "TRT workflow tests are not supported on post Blackwell-Ultra architecture", +# allow_module_level=True) + INPUT_TEXT_1 = "After Washington had returned to Williamsburg, " + \ "Dinwiddie ordered him to lead a larger force to assist Trent in his work. " + \ "While en route, Washington learned of Trent's retreat. 
" + \ @@ -688,6 +694,7 @@ def test_llm_llama_v2_1gpu_sparsity(llama_example_root, llama_model_root, ]) +@skip_post_blackwell @pytest.mark.parametrize("num_beams", [1], ids=lambda num_beams: f'nb:{num_beams}') @pytest.mark.parametrize("data_type", ['bfloat16', 'float16']) @@ -886,6 +893,7 @@ def test_llm_llama_v2_gather_logits_2gpu_pp2(llama_example_root, summary_cmd) +@skip_post_blackwell @pytest.mark.parametrize("llama_model_root", ['llama-v2-7b-hf'], indirect=True) def test_llm_llama_v2_1gpu_auto_parallel(llama_example_root, llama_model_root, llm_venv, cmodel_dir, engine_dir): @@ -911,6 +919,7 @@ def test_llm_llama_v2_1gpu_auto_parallel(llama_example_root, llama_model_root, check_call(" ".join(build_cmd), shell=True, env=llm_venv._new_env) +@skip_post_blackwell @pytest.mark.skip_less_device(2) @pytest.mark.parametrize("num_beams", [1, 4], ids=lambda num_beams: f'nb:{num_beams}') @@ -1622,6 +1631,7 @@ def test_llm_llama_v2_1gpu_fp8_gemv(llama_example_root, llama_model_root, venv_check_call(llm_venv, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_device_memory(50000) @pytest.mark.parametrize("data_type", ['bfloat16', 'float16']) @pytest.mark.parametrize("gemm_swiglu_plugin", ["fp8"]) @@ -1697,7 +1707,12 @@ def test_llm_llama_v2_1gpu_gemm_swiglu(llama_example_root, llama_model_root, @pytest.mark.parametrize( - "data_type", ['float16', 'fp8', 'sq_ootb', 'awq', 'int8_wo'], + "data_type", [ + 'float16', 'fp8', + pytest.param('sq_ootb', marks=skip_post_blackwell), + pytest.param('awq', marks=skip_post_blackwell), + pytest.param('int8_wo', marks=skip_post_blackwell) + ], ids=['base_fp16', 'base_fp8', 'base_sq_ootb', 'base_awq', 'base_int8_wo']) @pytest.mark.parametrize("lora_data_type", ['float16'], ids=['lora_fp16']) @pytest.mark.parametrize("llama_model_root", ['llama-v2-13b-hf'], indirect=True) @@ -2280,6 +2295,7 @@ def test_llm_llama_code_llama_multi_gpus_summary(llama_example_root, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_device_memory(30000) @pytest.mark.parametrize("num_beams", [1, 4], ids=lambda num_beams: f'nb:{num_beams}') @@ -2336,6 +2352,7 @@ def test_llm_llama_smooth_quant_1gpu_summary(llama_example_root, venv_check_call(llm_venv, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_device_memory(30000) @pytest.mark.parametrize("num_beams", [1, 4], ids=lambda num_beams: f'nb:{num_beams}') @@ -2385,6 +2402,7 @@ def test_llm_llama_int8_kv_1gpu_summary(llama_example_root, llama_model_root, venv_check_call(llm_venv, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_device_memory(30000) @pytest.mark.parametrize("num_beams", [1, 4], ids=lambda num_beams: f'nb:{num_beams}') @@ -2429,6 +2447,7 @@ def test_llm_llama_int8_sq_ootb_1gpu_summary( venv_check_call(llm_venv, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_device(2) @pytest.mark.parametrize("num_beams", [1], ids=lambda num_beams: f'nb:{num_beams}') @@ -2488,6 +2507,7 @@ def test_llm_llama_v2_int8sq_2gpu_tp2(data_type, llama_example_root, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_device_memory(30000) @pytest.mark.parametrize("num_beams", [1, 4], ids=lambda num_beams: f'nb:{num_beams}') @@ -2543,6 +2563,7 @@ def test_llm_llama_wo_1gpu_summary(llama_example_root, llama_model_root, venv_check_call(llm_venv, summary_cmd) +@skip_post_blackwell @pytest.mark.skip_less_device_memory(30000) @pytest.mark.parametrize("num_beams", [1, 4], ids=lambda num_beams: f'nb:{num_beams}') @@ -2872,7 +2893,9 @@ def test_llm_llama_v2_lora_benchmark_2gpu(llama_example_root, llama_model_root, 
@pytest.mark.skip_less_device(4) @pytest.mark.parametrize("num_beams", [1, 4], ids=lambda num_beams: f'nb:{num_beams}') -@pytest.mark.parametrize("qformat", ["fp8", "int4_awq"]) +@pytest.mark.parametrize( + "qformat", + ["fp8", pytest.param("int4_awq", marks=skip_post_blackwell)]) @pytest.mark.parametrize( "tp_pp_size", [(4, 1), (2, 2)], ids=lambda tp_pp_size: f'tp{tp_pp_size[0]}pp{tp_pp_size[1]}') @@ -3268,8 +3291,11 @@ def test_llm_llama_1gpu_streaming_llm(llama_example_root, deepseek_model_root, assert "上海人工智能实验室" in output, output -@pytest.mark.parametrize( - "fp8_quant", ['disable_fp8', 'enable_fp8', 'enable_fp8_meta_recipe']) +@pytest.mark.parametrize("fp8_quant", [ + 'disable_fp8', + pytest.param('enable_fp8', marks=skip_post_blackwell), + pytest.param('enable_fp8_meta_recipe', marks=skip_post_blackwell) +]) @pytest.mark.parametrize("llama_model_root", ['llama-3.1-8b', 'llama-3.2-1b'], indirect=True) def test_llm_llama_v3_1_1node_single_gpu(llama_example_root, llama_model_root, @@ -3581,6 +3607,7 @@ def test_llm_llama_v3_1_2nodes_8gpus(test_type, llama_example_root, check_call(" ".join(mmlu_cmd), shell=True, env=llm_venv._new_env) +@skip_post_blackwell @pytest.mark.skip_less_device_memory(50000) @pytest.mark.parametrize("low_latency_gemm_plugin", ["fp8"]) @pytest.mark.parametrize("llama_model_root", ['llama-v2-7b-hf'], indirect=True) @@ -3813,6 +3840,7 @@ def test_llm_llama_v2_fp8_2gpu_cp2(data_type, llama_example_root, @skip_pre_ada +@skip_post_blackwell @pytest.mark.parametrize("llama_model_root", ['llama-3.1-8b', 'llama-3.2-1b'], indirect=True) def test_llm_llama_lookahead_xqa_fp8_1gpu(llama_example_root, llama_model_root, @@ -4014,6 +4042,7 @@ def test_mistral_nemo_fp8_with_bf16_lora( ) +@skip_post_blackwell @pytest.mark.parametrize("llama_model_root", ['llama-3.1-8b'], indirect=True) def test_llm_llama_lookahead_single_gpu_summary(llama_example_root, llama_model_root, llm_venv, diff --git a/tests/integration/defs/examples/test_mamba.py b/tests/integration/defs/examples/test_mamba.py index 042abbcb257..c771278bdbe 100644 --- a/tests/integration/defs/examples/test_mamba.py +++ b/tests/integration/defs/examples/test_mamba.py @@ -18,16 +18,24 @@ import pytest from defs.common import (convert_weights, generate_summary_cmd, venv_check_call, venv_mpi_check_call) +from defs.conftest import get_sm_version, skip_post_blackwell from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.mark.parametrize("gemm_plugin", [True, False], ids=["enable_gemm_plugin", "disable_gemm_plugin"]) @pytest.mark.parametrize("dtype", ['bfloat16', 'float16']) @pytest.mark.parametrize("mamba_model_root", [ - 'mamba-130m', 'mamba-2.8b', 'mamba-1.4b', 'mamba-790m', 'mamba-370m', - 'mamba2-130m', 'mamba2-2.7b', 'mamba2-1.3b', 'mamba2-780m', 'mamba2-370m', - 'mamba-codestral-7B-v0.1' + pytest.param('mamba-130m', marks=skip_post_blackwell), 'mamba-2.8b', + 'mamba-1.4b', 'mamba-790m', 'mamba-370m', 'mamba2-130m', 'mamba2-2.7b', + 'mamba2-1.3b', 'mamba2-780m', 'mamba2-370m', + pytest.param('mamba-codestral-7B-v0.1', marks=skip_post_blackwell) ], indirect=True) def test_llm_mamba_1gpu(mamba_example_root, mamba_model_root, diff --git a/tests/integration/defs/examples/test_medusa.py b/tests/integration/defs/examples/test_medusa.py index f0d94ca2e1c..34a21fcd878 100644 --- a/tests/integration/defs/examples/test_medusa.py +++ 
b/tests/integration/defs/examples/test_medusa.py @@ -18,10 +18,17 @@ import pytest from defs.common import (convert_weights, get_dummy_spec_decoding_heads, venv_check_call) -from defs.conftest import skip_fp8_pre_ada +from defs.conftest import get_sm_version, skip_fp8_pre_ada, skip_post_blackwell from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + +@skip_post_blackwell @pytest.mark.parametrize("batch_size", [1, 8], ids=['bs1', 'bs8']) @pytest.mark.parametrize("data_type", ['bfloat16']) @pytest.mark.parametrize("num_medusa_heads", [4], ids=['4-heads']) @@ -79,6 +86,7 @@ def test_llm_medusa_1gpu(batch_size, data_type, medusa_model_roots, venv_check_call(llm_venv, summary_cmd) +@skip_post_blackwell @pytest.mark.parametrize("batch_size", [1, 8], ids=['bs1', 'bs8']) @pytest.mark.parametrize("data_type", ['bfloat16', 'float16']) @pytest.mark.parametrize("num_medusa_heads", [4], ids=['4-heads']) diff --git a/tests/integration/defs/examples/test_mistral.py b/tests/integration/defs/examples/test_mistral.py index 2e0314ad4ce..b80237d7326 100644 --- a/tests/integration/defs/examples/test_mistral.py +++ b/tests/integration/defs/examples/test_mistral.py @@ -20,9 +20,15 @@ import pytest from defs.common import (convert_weights, quantize_data, test_multi_lora_support, venv_check_call) -from defs.conftest import skip_post_blackwell, skip_pre_ada +from defs.conftest import get_sm_version, skip_post_blackwell, skip_pre_ada from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + def get_optimal_jobs(): cpu_count = multiprocessing.cpu_count() diff --git a/tests/integration/defs/examples/test_mixtral.py b/tests/integration/defs/examples/test_mixtral.py index 9985a586f92..c12be2c69b6 100644 --- a/tests/integration/defs/examples/test_mixtral.py +++ b/tests/integration/defs/examples/test_mixtral.py @@ -19,9 +19,16 @@ import pytest from defs.common import (convert_weights, generate_summary_cmd, quantize_data, venv_check_call, venv_mpi_check_call) -from defs.conftest import llm_models_root, skip_post_blackwell, skip_pre_ada +from defs.conftest import (get_sm_version, llm_models_root, skip_post_blackwell, + skip_pre_ada) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @skip_post_blackwell @pytest.mark.parametrize("model_name", ['mixtral-8x7b-v0.1-AWQ']) diff --git a/tests/integration/defs/examples/test_multimodal.py b/tests/integration/defs/examples/test_multimodal.py index 25b2d45d539..591205e9afd 100644 --- a/tests/integration/defs/examples/test_multimodal.py +++ b/tests/integration/defs/examples/test_multimodal.py @@ -18,9 +18,16 @@ import pytest import torch from defs.common import convert_weights, venv_check_call, venv_mpi_check_call -from defs.conftest import get_device_memory, skip_post_blackwell, skip_pre_ada +from defs.conftest import (get_device_memory, get_sm_version, + skip_post_blackwell, skip_pre_ada) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + 
pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.fixture(scope="module") def multimodal_example_root(llm_root): @@ -623,19 +630,19 @@ def _test_llm_multimodal_general(llm_venv, reason="Skip due to low memory")), 'llava-onevision-qwen2-7b-ov-hf', 'llava-onevision-qwen2-7b-ov-hf-video', - 'nougat-base', + pytest.param('nougat-base', marks=skip_post_blackwell), 'VILA1.5-3b', 'cogvlm-chat', 'fuyu-8b', - 'deplot', + pytest.param('deplot', marks=skip_post_blackwell), pytest.param('neva-22b', marks=pytest.mark.skip(reason="RCCA https://nvbugs/5220761")), 'kosmos-2', - 'video-neva', + pytest.param('video-neva', marks=skip_post_blackwell), pytest.param('Phi-3-vision-128k-instruct', marks=skip_post_blackwell), pytest.param('Phi-3.5-vision-instruct', marks=skip_post_blackwell), pytest.param('Phi-4-multimodal-instruct', marks=skip_post_blackwell), - 'Llama-3.2-11B-Vision', + pytest.param('Llama-3.2-11B-Vision', marks=skip_post_blackwell), 'Qwen2-VL-7B-Instruct', 'internlm-xcomposer2-vl-7b', 'Mistral-Small-3.1-24B-Instruct-2503', @@ -688,8 +695,8 @@ def test_llm_multimodal_general(llm_venv, llm_root, llm_datasets_root, 'Phi-3-vision-128k-instruct', 'Phi-3.5-vision-instruct', 'Phi-4-multimodal-instruct', - 'Llama-3.2-11B-Vision-Instruct', - 'Llama-3.2-11B-Vision', + pytest.param('Llama-3.2-11B-Vision-Instruct', marks=skip_post_blackwell), + pytest.param('Llama-3.2-11B-Vision', marks=skip_post_blackwell), 'Qwen2-VL-7B-Instruct', ], indirect=True) diff --git a/tests/integration/defs/examples/test_nemotron.py b/tests/integration/defs/examples/test_nemotron.py index eaaf80b4e7c..061ca111224 100644 --- a/tests/integration/defs/examples/test_nemotron.py +++ b/tests/integration/defs/examples/test_nemotron.py @@ -15,9 +15,15 @@ import pytest from defs.common import venv_check_call, venv_mpi_check_call -from defs.conftest import skip_fp8_pre_ada +from defs.conftest import get_sm_version, skip_fp8_pre_ada from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.mark.skip_less_device_memory(50000) @pytest.mark.parametrize("qformat", ["full_prec", "fp8", "int4_awq"]) diff --git a/tests/integration/defs/examples/test_nemotron_nas.py b/tests/integration/defs/examples/test_nemotron_nas.py index a469dbc0a2d..d1663eab672 100644 --- a/tests/integration/defs/examples/test_nemotron_nas.py +++ b/tests/integration/defs/examples/test_nemotron_nas.py @@ -2,9 +2,15 @@ import pytest from defs.common import convert_weights, venv_check_call, venv_mpi_check_call -from defs.conftest import get_device_memory +from defs.conftest import get_device_memory, get_sm_version from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + ROUGE1_ACCURACY_THRESHOLD = 20 diff --git a/tests/integration/defs/examples/test_ngram.py b/tests/integration/defs/examples/test_ngram.py index dec643ad5ea..2de49e8322f 100644 --- a/tests/integration/defs/examples/test_ngram.py +++ b/tests/integration/defs/examples/test_ngram.py @@ -18,9 +18,15 @@ import pytest from defs.common import convert_weights, venv_check_call -from defs.conftest import skip_post_blackwell +from defs.conftest import 
get_sm_version, skip_post_blackwell from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + # TODO: remove skip after support NGram on B200 @skip_post_blackwell diff --git a/tests/integration/defs/examples/test_phi.py b/tests/integration/defs/examples/test_phi.py index aacd8ca4988..f15bf5773c3 100644 --- a/tests/integration/defs/examples/test_phi.py +++ b/tests/integration/defs/examples/test_phi.py @@ -23,6 +23,12 @@ skip_post_blackwell, skip_pre_ada) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.fixture(scope="module") def phi_example_root(llm_root, llm_venv): @@ -36,6 +42,7 @@ def phi_example_root(llm_root, llm_venv): return example_root +@skip_post_blackwell @pytest.mark.skip_less_device_memory(40000) @pytest.mark.parametrize("num_beams", [1, 2, 4], ids=lambda num_beams: f'nb:{num_beams}') diff --git a/tests/integration/defs/examples/test_qwen.py b/tests/integration/defs/examples/test_qwen.py index c859f6192e7..44820cc7d85 100644 --- a/tests/integration/defs/examples/test_qwen.py +++ b/tests/integration/defs/examples/test_qwen.py @@ -20,10 +20,16 @@ import pytest from defs.common import (convert_weights, test_multi_lora_support, venv_check_call, venv_mpi_check_call) -from defs.conftest import (get_device_count, get_device_memory, +from defs.conftest import (get_device_count, get_device_memory, get_sm_version, skip_post_blackwell, skip_pre_ada) from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.mark.parametrize( "context_fmha_type", diff --git a/tests/integration/defs/examples/test_qwenvl.py b/tests/integration/defs/examples/test_qwenvl.py index 1ecda591387..7ff74560809 100644 --- a/tests/integration/defs/examples/test_qwenvl.py +++ b/tests/integration/defs/examples/test_qwenvl.py @@ -20,8 +20,15 @@ import pytest from defs.common import venv_check_call, venv_check_output +from defs.conftest import get_sm_version from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @pytest.fixture(scope="module") def qwenvl_example_root(llm_root, llm_venv): diff --git a/tests/integration/defs/examples/test_recurrentgemma.py b/tests/integration/defs/examples/test_recurrentgemma.py index 5febed25179..980a460cc15 100644 --- a/tests/integration/defs/examples/test_recurrentgemma.py +++ b/tests/integration/defs/examples/test_recurrentgemma.py @@ -19,10 +19,17 @@ import pytest from defs.common import (convert_weights, generate_summary_cmd, quantize_data, venv_check_call, venv_mpi_check_call) -from defs.conftest import skip_fp8_pre_ada +from defs.conftest import get_sm_version, skip_fp8_pre_ada, skip_post_blackwell from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra 
architecture", + allow_module_level=True) + +@skip_post_blackwell @pytest.mark.parametrize("gemm_plugin", [True, False], ids=["enable_gemm_plugin", "disable_gemm_plugin"]) @pytest.mark.parametrize("gpt_attention_plugin", [True, False], diff --git a/tests/integration/defs/examples/test_redrafter.py b/tests/integration/defs/examples/test_redrafter.py index d27e8772e41..ce9a62d097b 100644 --- a/tests/integration/defs/examples/test_redrafter.py +++ b/tests/integration/defs/examples/test_redrafter.py @@ -15,9 +15,17 @@ import pytest from defs.common import convert_weights, venv_check_call +from defs.conftest import get_sm_version, skip_post_blackwell from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + +@skip_post_blackwell @pytest.mark.parametrize("batch_size", [8], ids=['bs8']) @pytest.mark.parametrize("redrafter_num_beams", [5, 8], ids=['nb5', 'nb8']) @pytest.mark.parametrize("redrafter_draft_len_per_beam", [5], ids=['dl5']) diff --git a/tests/integration/defs/examples/test_whisper.py b/tests/integration/defs/examples/test_whisper.py index 9cb30b3f38d..d66ff738db5 100644 --- a/tests/integration/defs/examples/test_whisper.py +++ b/tests/integration/defs/examples/test_whisper.py @@ -15,9 +15,15 @@ import pytest from defs.common import convert_weights, venv_check_call -from defs.conftest import skip_post_blackwell +from defs.conftest import get_sm_version, skip_post_blackwell from defs.trt_test_alternative import check_call +# skip trt flow cases on post-Blackwell-Ultra +if get_sm_version() >= 103: + pytest.skip( + "TRT workflow tests are not supported on post Blackwell-Ultra architecture", + allow_module_level=True) + @skip_post_blackwell @pytest.mark.parametrize("use_cpp_runtime", [True, False], diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt index c93c81a169a..6c79e873b74 100644 --- a/tests/integration/test_lists/qa/examples_test_list.txt +++ b/tests/integration/test_lists/qa/examples_test_list.txt @@ -488,6 +488,8 @@ accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama4ScoutInstruct::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] +accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[throughput_latency] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_fp8[latency] accuracy/test_llm_api_pytorch.py::TestQwen3_235B_A22B::test_nvfp4[latency_moe_cutlass] @@ -608,6 +610,10 @@ disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_att disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0] disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0] disaggregated/test_disaggregated.py::test_disaggregated_trtllm_sampler[TinyLlama-1.1B-Chat-v1.0] +disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-False-Qwen3-8B-FP8] 
+disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-True-Qwen3-8B-FP8] +disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-False-Qwen3-8B-FP8] +disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-True-Qwen3-8B-FP8] disaggregated/test_workers.py::test_workers_conditional_disaggregation[TinyLlama-1.1B-Chat-v1.0] disaggregated/test_workers.py::test_workers_kv_cache_events[TinyLlama-1.1B-Chat-v1.0] disaggregated/test_workers.py::test_workers_kv_cache_aware_router[TinyLlama-1.1B-Chat-v1.0] diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index e99d33c1a08..fd6b7387235 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -39,6 +39,8 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[False] + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[False] + - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_auto_dtype[True] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp1pp2] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp1] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 21cf1253aaf..c8049a689be 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -61,6 +61,10 @@ l0_h100: - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_deepseek[False-True-DeepSeek-V3-Lite-fp8/fp8] - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_deepseek[True-False-DeepSeek-V3-Lite-fp8/fp8] - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_deepseek[True-True-DeepSeek-V3-Lite-fp8/fp8] TIMEOUT (90) + - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-False-Qwen3-8B-FP8] + - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-True-Qwen3-8B-FP8] + - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-False-Qwen3-8B-FP8] + - disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-True-Qwen3-8B-FP8] - disaggregated/test_disaggregated.py::test_disaggregated_load_balance[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_cache_aware_balance[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_conditional[TinyLlama-1.1B-Chat-v1.0] diff --git a/tests/integration/test_lists/waives.txt b/tests/integration/test_lists/waives.txt index 895f0c866e2..2e1ae548762 100644 --- a/tests/integration/test_lists/waives.txt +++ b/tests/integration/test_lists/waives.txt @@ -45,193 +45,41 @@ examples/test_nemotron.py::test_llm_nemotron_3_8b_1gpu[bfloat16-fp8] SKIP (https examples/test_mistral.py::test_llm_mistral_v1_1gpu[mistral-7b-v0.1-float16-max_attention_window_size_4096-chunked_summarization_long] SKIP (https://nvbugs/5321371) 
test_e2e.py::test_openai_chat_structural_tag_example SKIP (https://nvbugspro.nvidia.com/bug/5375594) cpp/test_e2e.py::test_model[fp8-chatglm-90] SKIP (https://nvbugs/5034830) -full:B200_PCIe/examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-enable_gemm_plugin] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-disable_gemm_plugin] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_mamba.py::test_llm_mamba_1gpu[mamba-codestral-7B-v0.1-float16-enable_gemm_plugin] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_mamba.py::test_llm_mamba_1gpu[mamba-codestral-7B-v0.1-float16-disable_gemm_plugin] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_medusa.py::test_llm_medusa_1gpu[use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs1] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v2_1gpu_gemm_swiglu[llama-v2-7b-hf-fp8-float16] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_int8_kv_1gpu_summary[llama-7b-enable_weight_only-nb:4] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_int8_sq_ootb_1gpu_summary[llama-7b-nb:1] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_wo_1gpu_summary[llama-7b-int4-nb:1] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_int8_kv_awq_1gpu_summary[llama-7b-nb:4] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v2_1gpu_low_latency_gemm[llama-v2-7b-hf-fp8] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_medusa.py::test_llm_medusa_1gpu[use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs8] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb5-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb5-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200_PCIe/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int8] SKIP (Disable for Blackwell) -full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int4] SKIP (Disable for Blackwell) -full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[per_token=False-per_channel=False] SKIP (Disable for Blackwell) -full:B200_PCIe/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[per_token=True-per_channel=True] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_exaone.py::test_llm_exaone_1gpu[enable_weight_only-exaone_3.0_7.8b_instruct-float16-nb:1] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_gpt.py::test_llm_gpt_starcoder_lora_1gpu[peft-lora-starcoder2-15b-unity-copilot-starcoder2-lora_fp16-base_fp16] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_gpt.py::test_llm_gpt_starcoder_lora_1gpu[peft-lora-starcoder2-15b-unity-copilot-starcoder2-lora_fp16-base_fp8] SKIP (Disable for Blackwell) 
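Most of the per-platform waive entries removed below are superseded by marks attached directly inside the test parametrizations, as done for test_multimodal.py above. A simplified sketch of that idiom, assuming skip_post_blackwell is a skipif-style marker exported by defs.conftest (the test name and signature here are illustrative, not the real ones):

import pytest
from defs.conftest import skip_post_blackwell  # assumed import, as in the files above

@pytest.mark.parametrize(
    "model_name",
    [
        "kosmos-2",                                              # runs on all architectures
        pytest.param("nougat-base", marks=skip_post_blackwell),  # skipped on post-Blackwell GPUs
        pytest.param("deplot", marks=skip_post_blackwell),
    ],
)
def test_multimodal_model(model_name):  # simplified stand-in for the real parametrized tests
    ...

Attaching the mark to the individual pytest.param keeps the skip decision next to the case it affects, instead of maintaining one waive line per platform and parametrization.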
-full:B200_PCIe/examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_awq] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_int8_wo] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3-mini-128k-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3-small-8k-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3.5-mini-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/functional SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/quantization SKIP (Disable for Blackwell) -full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph=False] SKIP (Disable for Blackwell) -full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph=True] SKIP (Disable for Blackwell) -full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_lookahead SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/attention/test_bert_attention.py SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/model/test_mamba.py SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_py_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (Disable for Blackwell) full:B200_PCIe/unittest/bindings SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/attention/test_sage_attention.py unittest/llmapi/test_llm_download.py unittest/llmapi/test_llm_kv_cache_events.py unittest/trt/model/redrafter unittest/trt/model/test_phi.py unittest/trt/model/test_unet.py unittest/trt/python_plugin unittest/tools unittest/utils unittest/others SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/quantization/test_weight_only_quant_matmul.py SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (Disable for Blackwell) full:B200_PCIe/unittest/trt/model/test_gpt.py -k "partition0" SKIP (Disable for Blackwell) full:B200_PCIe/unittest/test_model_runner_cpp.py SKIP (Disable for Blackwell) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_smooth_quant_1gpu_summary[float16-llama-7b-enable_ptpc-nb:4] SKIP (Disable for Blackwell for SQ) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_wo_1gpu_summary[llama-7b-int8-nb:1] SKIP (Disable for Blackwell for WO) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v3_int8_gptq_1gpu_summary[llama-v3-8b-instruct-hf-float16-nb:1] SKIP (Disable for Blackwell for weight only) -full:B200_PCIe/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[enable_weight_only] SKIP (Disable for Blackwell for weight only) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v3_1_quantization_1gpu_manage_weights[llama-3.1-8b-int4_wo] SKIP 
(Disable for Blackwell for weight only) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v3_1_autoq_1gpu_mmlu[llama-3.1-8b] SKIP (Disable for Blackwell for weight only) -full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[deplot-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=False-chunked_context=False-typical_acceptance=False] SKIP (Disable for Blackwell for Speculative Dec) -full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=False-typical_acceptance=False] SKIP (Disable for Blackwell for Speculative Dec) -full:B200_PCIe/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=True-typical_acceptance=False] SKIP (Disable for Blackwell for Speculative Dec) -full:B200_PCIe/examples/test_medusa.py::test_llm_medusa_1gpu[use_py_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs8] SKIP (Disable for Blackwell for Speculative Dec) full:B200_PCIe/unittest/llmapi/test_llm_models.py -m "part0" SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 80/96) -full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 96) -full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[Phi-3.5-vision-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 96) -full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v3_1_1node_single_gpu[llama-3.1-8b-enable_fp8] SKIP (Disable for Blackwell for fp8 rowwise gemm) -full:B200_PCIe/examples/test_llama.py::test_llm_llama_v3_1_1node_single_gpu[llama-3.1-8b-enable_fp8_meta_recipe] SKIP (Disable for Blackwell for fp8 rowwise gemm) -full:B200_PCIe/examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (megatron-core 0.8 is not supported in python 3.12) full:B200_PCIe/examples/test_nemotron.py::test_llm_nemotron_3_8b_1gpu[bfloat16-fp8] SKIP (megatron-core 0.8 is not supported in python 3.12) full:B200_PCIe/accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp4_plugin SKIP (Disable for Blackwell OOM) -full:B200_PCIe/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[disable_weight_only] SKIP (Disable for Blackwell OOM) full:B200_PCIe/unittest/llmapi/test_llm_models.py -m "not (part0 or part1)" SKIP (Disable for Blackwell OOM) -full:B200/examples/test_llama.py::test_llm_llama_v2_1gpu_auto_parallel[llama-v2-7b-hf] SKIP (Disable for Blackwell) -full:B200/examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-enable_gemm_plugin] SKIP (Disable for Blackwell) 
-full:B200/examples/test_mamba.py::test_llm_mamba_1gpu[mamba2-130m-float16-disable_gemm_plugin] SKIP (Disable for Blackwell) -full:B200/examples/test_mamba.py::test_llm_mamba_1gpu[mamba-codestral-7B-v0.1-float16-enable_gemm_plugin] SKIP (Disable for Blackwell) -full:B200/examples/test_mamba.py::test_llm_mamba_1gpu[mamba-codestral-7B-v0.1-float16-disable_gemm_plugin] SKIP (Disable for Blackwell) -full:B200/examples/test_medusa.py::test_llm_medusa_1gpu[use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs1] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_v2_1gpu_gemm_swiglu[llama-v2-7b-hf-fp8-float16] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_int8_kv_1gpu_summary[llama-7b-enable_weight_only-nb:4] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_int8_sq_ootb_1gpu_summary[llama-7b-nb:1] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_wo_1gpu_summary[llama-7b-int4-nb:1] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_int8_kv_awq_1gpu_summary[llama-7b-nb:4] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_v2_1gpu_low_latency_gemm[llama-v2-7b-hf-fp8] SKIP (Disable for Blackwell) -full:B200/examples/test_medusa.py::test_llm_medusa_1gpu[use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs8] SKIP (Disable for Blackwell) -full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb5-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_cpp_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb5-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200/examples/test_redrafter.py::test_llm_redrafter_1gpu[use_py_session-redrafter-vicuna-7b-v1.3-bfloat16-dl5-nb8-bs8] SKIP (Disable for Blackwell spec decoding) -full:B200/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int8] SKIP (Disable for Blackwell) -full:B200/accuracy/test_cli_flow.py::TestGpt2::test_weight_only[int4] SKIP (Disable for Blackwell) -full:B200/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[per_token=False-per_channel=False] SKIP (Disable for Blackwell) -full:B200/accuracy/test_cli_flow.py::TestGpt2::test_smooth_quant[per_token=True-per_channel=True] SKIP (Disable for Blackwell) -full:B200/examples/test_exaone.py::test_llm_exaone_1gpu[enable_weight_only-exaone_3.0_7.8b_instruct-float16-nb:1] SKIP (Disable for Blackwell) -full:B200/examples/test_gpt.py::test_llm_gpt_starcoder_lora_1gpu[peft-lora-starcoder2-15b-unity-copilot-starcoder2-lora_fp16-base_fp16] SKIP (Disable for Blackwell) -full:B200/examples/test_gpt.py::test_llm_gpt_starcoder_lora_1gpu[peft-lora-starcoder2-15b-unity-copilot-starcoder2-lora_fp16-base_fp8] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_awq] SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_int8_wo] SKIP (Disable for Blackwell) -full:B200/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3-mini-128k-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell) 
-full:B200/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3-small-8k-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell) -full:B200/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3-small-128k-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell) -full:B200/examples/test_phi.py::test_llm_phi_single_gpu_summary[Phi-3.5-mini-instruct-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_fmha_with_fp32_acc-nb:1] SKIP (Disable for Blackwell) -full:B200/examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3-mini-128k-instruct-fp8-float16] SKIP (Disable for Blackwell) -full:B200/examples/test_phi.py::test_llm_phi_quantization_1gpu[Phi-3.5-mini-instruct-fp8-float16] SKIP (Disable for Blackwell) full:B200/unittest/trt/functional SKIP (Disable for Blackwell) full:B200/unittest/trt/quantization SKIP (Disable for Blackwell) -full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph=False] SKIP (Disable for Blackwell) -full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_medusa[cuda_graph=True] SKIP (Disable for Blackwell) -full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_lookahead SKIP (Disable for Blackwell) full:B200/unittest/trt/attention/test_bert_attention.py SKIP (Disable for Blackwell) full:B200/unittest/trt/model/test_mamba.py SKIP (Disable for Blackwell) -full:B200/examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_cpp_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (Disable for Blackwell) -full:B200/examples/test_medusa.py::test_llm_medusa_with_qaunt_base_model_1gpu[fp8-use_py_session-medusa-vicuna-7b-v1.3-4-heads-float16-bs1] SKIP (Disable for Blackwell) full:B200/unittest/bindings SKIP (Disable for Blackwell) full:B200/unittest/trt/attention/test_sage_attention.py unittest/llmapi/test_llm_download.py unittest/llmapi/test_llm_kv_cache_events.py unittest/trt/model/redrafter unittest/trt/model/test_phi.py unittest/trt/model/test_unet.py unittest/trt/python_plugin unittest/tools unittest/utils unittest/others SKIP (Disable for Blackwell) full:B200/unittest/trt/quantization/test_weight_only_quant_matmul.py SKIP (Disable for Blackwell) full:B200/unittest/trt/quantization/test_weight_only_groupwise_quant_matmul.py SKIP (Disable for Blackwell) -full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int8-float16] SKIP (Disable for Blackwell) full:B200/unittest/trt/model/test_gpt.py -k "partition0" SKIP (Disable for Blackwell) full:B200/unittest/test_model_runner_cpp.py SKIP (Disable for Blackwell) -full:B200/examples/test_llama.py::test_llm_llama_smooth_quant_1gpu_summary[float16-llama-7b-enable_ptpc-nb:4] SKIP (Disable for Blackwell for SQ) -full:B200/examples/test_llama.py::test_llm_llama_wo_1gpu_summary[llama-7b-int8-nb:1] SKIP (Disable for Blackwell for WO) -full:B200/examples/test_llama.py::test_llm_llama_v3_int8_gptq_1gpu_summary[llama-v3-8b-instruct-hf-float16-nb:1] SKIP (Disable for Blackwell for weight only) -full:B200/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[enable_weight_only] SKIP (Disable for Blackwell for weight only) -full:B200/examples/test_llama.py::test_llm_llama_v3_1_quantization_1gpu_manage_weights[llama-3.1-8b-int4_wo] SKIP (Disable for Blackwell for weight only) -full:B200/examples/test_llama.py::test_llm_llama_v3_1_autoq_1gpu_mmlu[llama-3.1-8b] SKIP (Disable for Blackwell for weight only) 
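The "Disable for Blackwell" waives deleted in this block are instead handled by markers evaluated at run time. The real markers live in defs.conftest; a hypothetical equivalent is shown here only to illustrate the mechanism (the thresholds and reasons are assumptions):

import pytest
import torch

def _sm() -> int:
    # Assumed helper mirroring get_sm_version(): (major, minor) -> two-digit SM number.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor

# Hypothetical marker definitions; the actual skip_post_blackwell / skip_pre_ada
# markers in defs.conftest may be implemented differently.
skip_post_blackwell = pytest.mark.skipif(
    torch.cuda.is_available() and _sm() >= 100,
    reason="not supported on Blackwell and newer architectures",
)
skip_pre_ada = pytest.mark.skipif(
    torch.cuda.is_available() and _sm() < 89,
    reason="requires Ada (SM 89) or newer",
)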
-full:B200/examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200/examples/test_multimodal.py::test_llm_multimodal_general[deplot-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=False-chunked_context=False-typical_acceptance=False] SKIP (Disable for Blackwell for Speculative Dec) -full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=False-typical_acceptance=False] SKIP (Disable for Blackwell for Speculative Dec) -full:B200/accuracy/test_cli_flow.py::TestVicuna7B::test_eagle[cuda_graph=True-chunked_context=True-typical_acceptance=False] SKIP (Disable for Blackwell for Speculative Dec) -full:B200/examples/test_medusa.py::test_llm_medusa_1gpu[use_py_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs8] SKIP (Disable for Blackwell for Speculative Dec) full:B200/unittest/llmapi/test_llm_models.py -m "part0" SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 80/96) -full:B200/examples/test_multimodal.py::test_llm_multimodal_general[Phi-3-vision-128k-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 96) -full:B200/examples/test_multimodal.py::test_llm_multimodal_general[Phi-3.5-vision-instruct-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support when headsize is 96) -full:B200/examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200/examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (Disable for Blackwell for context fmha doesn't support custom mask) -full:B200/examples/test_llama.py::test_llm_llama_v3_1_1node_single_gpu[llama-3.1-8b-enable_fp8] SKIP (Disable for Blackwell for fp8 rowwise gemm) -full:B200/examples/test_llama.py::test_llm_llama_v3_1_1node_single_gpu[llama-3.1-8b-enable_fp8_meta_recipe] SKIP (Disable for Blackwell for fp8 rowwise gemm) full:B200/examples/test_multimodal.py::test_llm_multimodal_general[video-neva-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (megatron-core 0.8 is not supported in python 3.12) full:B200/examples/test_nemotron.py::test_llm_nemotron_3_8b_1gpu[bfloat16-fp8] SKIP (megatron-core 0.8 is not supported in python 3.12) full:B200/accuracy/test_cli_flow.py::TestMixtral8x7B::test_fp4_plugin SKIP (Disable for Blackwell OOM) -full:B200/examples/test_commandr.py::test_llm_commandr_v01_single_gpu_summary[disable_weight_only] SKIP (Disable for Blackwell OOM) full:B200/unittest/llmapi/test_llm_models.py -m "not (part0 or part1)" SKIP (Disable for Blackwell OOM) -full:B200/examples/test_llama.py::test_llm_llama_code_llama_quantization_4gpus_summary[CodeLlama-34b-Instruct-tp2pp2-int4_awq-nb:4] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_code_llama_quantization_4gpus_summary[CodeLlama-70b-hf-tp2pp2-int4_awq-nb:1] SKIP (not support on B200) -full:B200/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:1-nb:1-enable_fp8] SKIP (not support on 
B200) -full:B200/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-flan-t5-small-bfloat16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-disable_fp8] SKIP (not support on B200) -full:B200/examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-mbart-large-50-many-to-one-mmt-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8] SKIP (not support on B200) -full:B200/examples/test_enc_dec.py::test_llm_enc_dec_general[no_compare_hf-byt5-small-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:1-nb:1-disable_fp8] SKIP (not support on B200) -full:B200/examples/test_chatglm.py::test_llm_glm_4_9b_single_gpu_summary[glm-4-9b-enable_weight_only] SKIP (not support on B200) -full:B200/examples/test_chatglm.py::test_llm_glm_4_9b_single_gpu_summary[glm-4-9b-chat-enable_weight_only] SKIP (not support on B200) -full:B200/examples/test_commandr.py::test_llm_commandr_plus_4gpus_summary[enable_weight_only] SKIP (not support on B200) -full:B200/examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2_7b_instruct-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha_fp32_acc] SKIP (not support on B200) -full:B200/examples/test_qwen.py::test_llm_qwen_single_gpu_summary[qwen2.5_1.5b_instruct-enable_paged_kv_cache-enable_remove_input_padding-enable_weight_only-enable_fmha_fp32_acc] SKIP (not support on B200) -full:B200/examples/test_qwen.py::test_llm_qwen_7b_int8_kv_1node_1gpus[qwen2_7b_instruct-enable_gemm_plugin-enable_weight_only] SKIP (not support on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_smooth_single_gpu_summary[enable_ptpc] SKIP (not support on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_smooth_single_gpu_summary[disable_ptpc] SKIP (not support on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_int8_kv_1gpu SKIP (not support on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder-int8-float16] SKIP (not support on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder-int4-float16] SKIP (not support on B200) -full:B200/examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2b-int8_sq-bfloat16-8] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_v2_awq_2gpu_summary[llama-v2-7b-hf-nb:1] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_v2_awq_2gpu_summary[Llama-2-7B-AWQ-nb:1] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_v2_awq_2gpu_summary[Llama-2-7B-GPTQ-nb:4] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_v2_lora_1gpu[chinese-llama-2-lora-13b-llama-v2-13b-hf-lora_fp16-base_sq_ootb] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_smooth_quant_1gpu_summary[float16-llama-7b-disable_ptpc-nb:1] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_int8_kv_1gpu_summary[llama-7b-enable_weight_only-nb:1] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_int8_kv_1gpu_summary[llama-7b-disable_weight_only-nb:1] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_v2_int8sq_2gpu_tp2[llama-v2-7b-hf-bfloat16-nb:1] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_int8_kv_awq_1gpu_summary[llama-7b-nb:1] SKIP (not support on B200) -full:B200/accuracy/test_cli_flow.py::TestMixtral8x7B::test_weight_only_int4_tp2 SKIP (not 
support on B200) -full:B200/accuracy/test_cli_flow.py::TestMixtral8x7B::test_weight_only_int8_tp2 SKIP (not support on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoderplus-int8-float16] SKIP (not support on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoderplus-int4-float16] SKIP (not support on B200) -full:B200/examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.1-8b] SKIP (No available XQA kernels are found for speculative decoding mode) -full:B200/examples/test_llama.py::test_llm_llama_lookahead_xqa_fp8_1gpu[llama-3.2-1b] SKIP (No available XQA kernels are found for speculative decoding mode) -full:B200/examples/test_medusa.py::test_llm_medusa_1gpu[use_py_session-medusa-vicuna-7b-v1.3-4-heads-bfloat16-bs1] SKIP (No available XQA kernels are found for speculative decoding mode) -full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_weight_only[starcoder2-int4-float16] SKIP (not support on B200) full:B200/examples/test_mixtral.py::test_llm_mixtral_moe_plugin_fp8_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora] SKIP (https://nvbugs/5064768) -full:B200/accuracy/test_cli_flow.py::TestGpt2::test_int8_kv_cache SKIP (not support on B200) -full:B200/examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int8_sq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (not support on B200) -full:B200/examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-int4_awq-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (not support on B200) -full:B200/examples/test_recurrentgemma.py::test_llm_recurrentgemma_1gpu[use_cpp_session-recurrentgemma-2b-use_paged_cache-fp8-float16-enable_attn_plugin-enable_gemm_plugin] SKIP (not support on B200) examples/test_qwen.py::test_llm_qwen_moe_multi_gpu_summary[qwen2_57b_a14b-tp4pp1-context_fmha] SKIP (https://nvbugs/5063469) examples/test_qwen.py::test_llm_qwen_moe_multi_gpu_summary[qwen2_57b_a14b-tp2pp2-context_fmha_fp32_acc] SKIP (https://nvbugs/5063469) examples/test_mixtral.py::test_llm_mixtral_moe_plugin_fp8_lora_4gpus[Mixtral-8x7B-v0.1-chinese-mixtral-lora] SKIP (https://nvbugs/5064768) llmapi/test_llm_e2e.py::test_llmapi_build_command_parameters_align[llama-llama-models-v2/TinyLlama-1.1B-Chat-v1.0] SKIP (https://nvbugs/5061624) test_e2e.py::test_openai_consistent_chat SKIP (https://nvbugs/5112075) -full:B200/examples/test_gemma.py::test_llm_hf_gemma_quantization_1gpu[gemma-2-9b-it-fp8-bfloat16-8] SKIP (not supported on B200) -full:B200/examples/test_gpt.py::test_llm_gpt2_starcoder_1gpus SKIP (not supported on B200) examples/test_eagle.py::test_qwen_eagle_1gpu[qwen_7b_chat-eagle1] SKIP (https://nvbugs/5206383) examples/test_eagle.py::test_qwen_eagle_1gpu[qwen1.5_7b_chat-eagle1] SKIP (https://nvbugs/5206383) examples/test_eagle.py::test_qwen_eagle_1gpu[qwen2_7b_instruct-eagle1] SKIP (https://nvbugs/5206383) @@ -250,12 +98,6 @@ examples/test_eagle.py::test_phi_eagle_1gpu[phi-2-eagle2] SKIP (https://nvbugs/5 examples/test_eagle.py::test_phi_eagle_1gpu[Phi-3-mini-128k-instruct-eagle2] SKIP (https://nvbugs/5206383) examples/test_eagle.py::test_phi_eagle_1gpu[Phi-3-small-128k-instruct-eagle2] SKIP (https://nvbugs/5206383) examples/test_eagle.py::test_phi_eagle_1gpu[Phi-3.5-mini-instruct-eagle2] SKIP (https://nvbugs/5206383) -full:B200/examples/test_llama.py::test_llm_llama_lookahead_single_gpu_summary[llama-3.1-8b] SKIP (not supported on B200) 
-full:B200/examples/test_multimodal.py::test_llm_multimodal_general[nougat-base-pp:1-tp:1-bfloat16-bs:8-cpp_e2e:False-nb:1] SKIP (TRTLLM-GEN does not support custom mask) -full:B200/examples/test_multimodal.py::test_llm_multimodal_general[deplot-pp:1-tp:1-float16-bs:8-cpp_e2e:False-nb:1] SKIP (TRTLLM-GEN does not support custom mask) -full:B200/examples/test_multimodal.py::test_llm_multimodal_general[Llama-3.2-11B-Vision-pp:1-tp:2-bfloat16-bs:1-cpp_e2e:False-nb:1] SKIP (TRTLLM-GEN does not support custom mask) -full:B200/examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-scienceqa-Llama-3.2-11B-Vision-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False] SKIP (TRTLLM-GEN does not support custom mask) -full:B200/examples/test_chatglm.py::test_llm_glm_4_9b_single_gpu_summary[glm-4-9b-chat-disable_weight_only] SKIP (https://nvbugs/5114743) examples/test_gpt.py::test_llm_gpt_starcoder_lora_1gpu[peft-lora-starcoder2-15b-unity-copilot-starcoder2-lora_fp16-base_fp16] SKIP (https://nvbugs/5114678) examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-mbart-large-50-many-to-one-mmt-float16-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8] SKIP (https://nvbugs/5135328) examples/test_multimodal.py::test_llm_multimodal_general[kosmos-2-pp:1-tp:1-float16-bs:8-cpp_e2e:True-nb:1] SKIP (https://nvbugs/5141288) diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index d0753c3cf28..ef3bf35a431 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -8,16 +8,49 @@ from torch.export import export from torch.fx import GraphModule +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ShardingTransformInfo -class FakeFactory: - def __init__(self, model: nn.Module): - self.model = model +class FakeFactory(ModelFactory): + """Dummy factory to pass cache_config for testing.""" - def build_model(self, device: str) -> nn.Module: - return self.model.to(device=device) + def __init__(self, model=None, cache_config=None, quant_config=None): + self._model = model + self.cache_config = cache_config + self.quant_config = quant_config + + def build_model(self, device: str): + return self._model.to(device=device) if self._model else None + + def _build_model(self, device: str): + return + + def _load_checkpoint(self, model, device): + return + + def get_cache_config(self): + return self.cache_config + + def get_quant_config(self): + return self.quant_config + + +class SequenceEmbeddingInfo(SequenceInfo): + hidden_size: int + dtype: torch.dtype + + def set_example_sequence(self) -> None: + super().set_example_sequence() + # set input ids to a 3D tensor (actually input embeddings) + self.input_ids = torch.rand( + *self.input_ids.shape, + self.hidden_size, + device=self.input_ids.device, + dtype=self.dtype, + ) def count_parameters(model: torch.nn.Module): @@ -32,6 +65,79 @@ def count_buffers(model: torch.nn.Module): return sum(np.prod(b.shape) for b in model.buffers()) +def run_test_transformed_gm( + model: nn.Module, + x: torch.Tensor, + gm_transformed: GraphModule, + check_transformed_graph: 
Callable[[GraphModule], bool], + _get_expected_num_params: Callable[[int], int], + atol: float = 1e-3, + rtol: float = 1e-3, + test_load_hook: bool = True, + strict_loading: bool = True, + dynamic_shapes: Dict = None, + skip_output_assert: bool = False, + *args, # Additional arguments for transform +) -> GraphModule: + # run model once + y_model = model(x) + + # num params + num_params_model = count_parameters(model) + print(num_params_model) + + # export + check (we clone the state dict to have a bit more freedom in testing below) + gm_ref = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) + print(gm_ref) + y_gm = gm_ref(x) + num_params_gm = count_parameters(gm_ref) + + assert num_params_model == num_params_gm + if not skip_output_assert: + torch.testing.assert_close(y_model, y_gm, atol=atol, rtol=rtol) + + print(gm_transformed) + # in case buffers or other tensors were added during the transform + gm_transformed = gm_transformed.to("cuda") + y_transformed = gm_transformed(x) + n_p_transformed = count_parameters(gm_transformed) + + n_p_t_expected = _get_expected_num_params(num_params_model) + assert n_p_transformed == n_p_t_expected, ( + f"actual params {n_p_transformed} != expected params {n_p_t_expected}" + ) + + # check if the transformation worked + assert check_transformed_graph(gm_transformed) + + if strict_loading and not skip_output_assert: + # check if output equals without loading state dict + torch.testing.assert_close(y_model, y_transformed, atol=atol, rtol=rtol) + + if test_load_hook and not skip_output_assert: + # check if loading hook works from original state dict + reset_parameters(gm_transformed) + y_random = gm_transformed(x) + assert not all_close(y_model, y_random), f"{y_model=}, {y_random=}" + + gm_transformed.load_state_dict(model.state_dict(), strict=True if strict_loading else False) + y_loaded_from_original = gm_transformed(x) + torch.testing.assert_close(y_model, y_loaded_from_original, atol=atol, rtol=rtol) + + # check if loading hook works from state_dict of a transformed model + state_dict_sharded = copy.deepcopy(gm_transformed.state_dict()) + reset_parameters(gm_transformed) + y_random2 = gm_transformed(x) + assert not all_close(y_model, y_random2), f"{y_model=}, {y_random2=}" + + gm_transformed.load_state_dict(state_dict_sharded, strict=True if strict_loading else False) + y_loaded_from_transformed = gm_transformed(x) + torch.testing.assert_close(y_model, y_loaded_from_transformed, atol=atol, rtol=rtol) + + # check if we can still export the model as expected + export(gm_transformed, args=(x,)) + + def run_test( model: nn.Module, x: torch.Tensor, diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py index 33ace089018..92457666a71 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/test_ad_build_small_multi.py @@ -19,9 +19,6 @@ ], ) def test_build_ad(world_size: int, experiment_config: Dict): - if world_size > 1: - pytest.skip("https://nvbugspro.nvidia.com/bug/5331013") - experiment_config["args"]["world_size"] = world_size experiment_config["args"]["runtime"] = "trtllm" # Default runtime set to trtllm experiment_config = ExperimentConfig(**experiment_config) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py 
b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py index ea27c66d035..b378cc06d09 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher.py @@ -7,14 +7,8 @@ from torch.fx import GraphModule from transformers.integrations.sdpa_attention import repeat_kv as hf_repeat_kv +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer -from tensorrt_llm._torch.auto_deploy.transformations.library.attention import ( - match_attention_layout, - match_causal_attn_mask, - match_eager_attention, - match_grouped_attention, - match_repeat_kv, -) from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op torch.manual_seed(0) @@ -164,16 +158,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Multiplication pattern attn_weights = torch.matmul(q, k.transpose(2, 3)) * self.scaling - # Add attention mask if enabled + # Add causal attention mask if enabled if self.has_mask: - # Create a simple causal mask for testing - make sure all tensors are on the same device - mask = torch.triu( - torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), - diagonal=1, + # [1, 1, seq_len, seq_len] causal mask with -inf in the upper triangle + attn_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1 ) - mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len] - attn_mask = torch.zeros_like(attn_weights, device=device) - attn_mask = attn_mask.masked_fill(mask, float("-inf")) + attn_mask = ( + attn_mask.unsqueeze(0).unsqueeze(0).to(x.dtype) + ) # shape: [1, 1, seq_len, seq_len] attn_weights = attn_weights + attn_mask # Apply softmax, dtype conversion, and dropout @@ -249,13 +242,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # Add attention mask if enabled if self.has_mask: - mask = torch.triu( - torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), - diagonal=1, + # [1, 1, seq_len, seq_len] causal mask with -inf in the upper triangle + attn_mask = torch.triu( + torch.full((seq_len, seq_len), float("-inf"), device=device), diagonal=1 ) - mask = mask.unsqueeze(0).unsqueeze(0) - attn_mask = torch.zeros_like(attn_weights, device=device) - attn_mask = attn_mask.masked_fill(mask, float("-inf")) + attn_mask = ( + attn_mask.unsqueeze(0).unsqueeze(0).to(x.dtype) + ) # shape: [1, 1, seq_len, seq_len] attn_weights = attn_weights + attn_mask # Add a to_dtype node before softmax to match pattern in the graph @@ -366,8 +359,6 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: batch_size, seq_len, _ = x.shape - device = x.device - dtype = x.dtype # Generate q, k, v q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim) @@ -385,28 +376,26 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: v = torch.ops.auto_deploy.torch_attention_repeat_kv(v, self.n_rep) # Create attention mask if needed - attn_mask = None if self.has_mask: - # Simple causal mask - mask = torch.triu( - torch.ones(seq_len, seq_len, dtype=torch.bool, device=device), - diagonal=1, + attn_output = torch.ops.auto_deploy.torch_attention_sdpa( + q, + k, + v, + attn_mask=None, + dropout_p=self.dropout, + is_causal=True, + scale=1.0 / (self.head_dim**0.5), + ) + else: + attn_output = torch.ops.auto_deploy.torch_attention_sdpa( + 
q, + k, + v, + attn_mask=None, + dropout_p=self.dropout, + is_causal=False, + scale=1.0 / (self.head_dim**0.5), ) - mask = mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_len, seq_len] - attn_mask = torch.zeros( - (batch_size, 1, seq_len, seq_len), device=device, dtype=dtype - ).masked_fill(mask, float("-inf")) - - # Apply scaled dot product attention - attn_output = torch.ops.auto_deploy.torch_attention_sdpa( - q, - k, - v, - attn_mask=attn_mask, - dropout_p=self.dropout, - is_causal=False, - scale=1.0 / (self.head_dim**0.5), - ) # Reshape output for the linear projection attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) @@ -423,11 +412,47 @@ def _get_match_repeat_kv_optimizer() -> Callable: "cleanup_noop_slice": { "stage": "post_export", }, + "match_repeat_kv": { + "stage": "pattern_matcher", + }, + } + + def _transform(gm: GraphModule) -> GraphModule: + gm = InferenceOptimizer(None, config)(None, gm) + return gm + + return _transform + + +def _get_match_eager_attention_optimizer() -> Callable: + config = { + "cleanup_noop_slice": { + "stage": "post_export", + }, + "match_eager_attention": { + "stage": "pattern_matcher", + }, + } + + def _transform(gm: GraphModule) -> GraphModule: + gm = InferenceOptimizer(None, config)(None, gm) + return gm + + return _transform + + +def _get_match_grouped_attention_optimizer() -> Callable: + config = { + "cleanup_noop_slice": { + "stage": "post_export", + }, + "match_grouped_attention": { + "stage": "pattern_matcher", + }, } def _transform(gm: GraphModule) -> GraphModule: gm = InferenceOptimizer(None, config)(None, gm) - match_repeat_kv(gm) return gm return _transform @@ -516,8 +541,8 @@ def verify_matcher(gm): ) -@pytest.mark.parametrize("has_mask", [True, False]) -@pytest.mark.parametrize("use_division", [False, True]) +@pytest.mark.parametrize("has_mask", [False, True]) +@pytest.mark.parametrize("use_division", [True, False]) @pytest.mark.parametrize( "dropout, skip_output_assert", [ @@ -537,8 +562,10 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse # Create different model types based on the parameter if model_type == "standard": - model = EagerAttentionModel(hidden_size, num_heads, has_mask, dropout, use_division).to( - "cuda", dtype=torch.float16 + model = ( + EagerAttentionModel(hidden_size, num_heads, has_mask, dropout, use_division) + .to("cuda", dtype=torch.float16) + .eval() ) # Print the original scaling approach and value if use_division: @@ -549,8 +576,10 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse expected_scale = model.scaling else: # complex # Complex model only uses division for scaling - model = ComplexEagerAttentionModel(hidden_size, num_heads, has_mask, dropout).to( - "cuda", dtype=torch.float16 + model = ( + ComplexEagerAttentionModel(hidden_size, num_heads, has_mask, dropout) + .to("cuda", dtype=torch.float16) + .eval() ) expected_scale = 1.0 / model.scale_divisor # Override use_division and only run test once (ignore the parameterization) @@ -567,6 +596,7 @@ def test_match_eager_attention(has_mask, use_division, dropout, skip_output_asse expected_matches = 1 def verify_matcher(gm): + # torch_attention_sdpa is replaced with torch_attention_sdpa after the transformation sdpa_nodes = [ n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa) ] @@ -636,13 +666,15 @@ def verify_matcher(gm): # Check mask handling for masked attention if has_mask: - has_mask_arg = "attn_mask" in kwargs - if not has_mask_arg 
and len(node.args) >= 4: - has_mask_arg = node.args[3] is not None + is_causal = kwargs.get("is_causal", None) + if is_causal is None and len(node.args) >= 6: + is_causal = node.args[5] - if not has_mask_arg: - print("❌ Missing mask information in SDPA node") + if is_causal is not True: + print(f"❌ Expected is_causal=True for masked attention, got {is_causal}") valid = False + else: + print("✅ is_causal correctly set to True") print("Graph verification successful" if valid else "Graph verification failed") return valid @@ -651,7 +683,7 @@ def verify_matcher(gm): run_test( model, x, - match_eager_attention, + _get_match_eager_attention_optimizer(), verify_matcher, lambda num_p_og: num_p_og, atol=1e-3, @@ -685,7 +717,7 @@ def verify_no_matches(gm): _ = run_test( model, x, - match_repeat_kv, + _get_match_eager_attention_optimizer(), verify_no_matches, lambda num_p_og: num_p_og, atol=1e-3, @@ -709,9 +741,8 @@ def test_match_grouped_attention(num_heads, num_kv_heads, has_mask): x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) dynamic_shapes = model.get_dynamic_shapes() - # We should find 1 instance of the pattern if num_heads != num_kv_heads - # Otherwise, no pattern should be matched (no grouped attention) - expected_matches = 1 if num_heads != num_kv_heads else 0 + # We should find 1 instance of torch_attention_grouped_sdpa + expected_matches = 1 def verify_matcher(gm): grouped_sdpa_nodes = [ @@ -727,10 +758,6 @@ def verify_matcher(gm): ) return False - # If we don't expect any matches, we're done - if expected_matches == 0: - return True - # Otherwise, check the node properties for node in grouped_sdpa_nodes: # Basic checks: should have at least 3 positional args (q, k, v) @@ -743,16 +770,14 @@ def verify_matcher(gm): # Mask handling should be preserved if has_mask: - # Check if attn_mask is in kwargs or provided via args - has_mask_arg = "attn_mask" in kwargs - if ( - not has_mask_arg and len(node.args) >= 4 - ): # Assuming attn_mask is the 4th positional arg - has_mask_arg = node.args[3] is not None + is_causal = kwargs.get("is_causal", None) + if is_causal is None and len(node.args) >= 6: + is_causal = node.args[5] - if not has_mask_arg: - print("❌ Expected attn_mask in args or kwargs but not found") - return False + if is_causal is not True: + print(f"❌ Expected is_causal=True for masked attention, got {is_causal}") + else: + print("✅ is_causal correctly set to True") return True @@ -760,7 +785,7 @@ def verify_matcher(gm): _ = run_test( model, x, - match_grouped_attention, + _get_match_grouped_attention_optimizer(), verify_matcher, lambda num_p_og: num_p_og, atol=1e-3, @@ -884,98 +909,6 @@ def get_dynamic_shapes(self): return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)} -@pytest.mark.parametrize("mask_type", ["triu", "negative_fill", "non_causal"]) -@pytest.mark.parametrize("use_grouped_sdpa", [False, True]) -@torch.inference_mode() -def test_match_causal_attention(mask_type, use_grouped_sdpa): - batch_size, seq_len = 4, 12 - hidden_size = 512 - num_heads = 8 - num_kv_heads = 4 if use_grouped_sdpa else num_heads - - model = CausalAttentionModel( - hidden_size, - num_heads, - mask_type=mask_type, - use_grouped_sdpa=use_grouped_sdpa, - num_kv_heads=num_kv_heads, - ).to("cuda", dtype=torch.float16) - - x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float16) - dynamic_shapes = model.get_dynamic_shapes() - - # We expect optimization (None mask + is_causal=True) when using causal masks - should_optimize = 
mask_type in ["triu", "negative_fill"] - - def verify_matcher(gm): - # Find attention operations - if use_grouped_sdpa: - attn_nodes = [ - n - for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) - ] - else: - attn_nodes = [ - n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa) - ] - - if len(attn_nodes) != 1: - print(f"Expected 1 attention node, found {len(attn_nodes)}") - return False - - node = attn_nodes[0] - - # Check if attention mask was set to None and is_causal was set to True - if should_optimize: - # Attention mask (4th arg) should be None - has_mask = ( - node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs - ) - - # is_causal (6th arg) should be True - is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) - - # Check if optimization was correctly applied - if has_mask or not is_causal: - print("❌ Expected optimization: mask=None, is_causal=True") - print( - f" Got: mask={node.args[3] if len(node.args) > 3 else 'not in args'}, " - f"is_causal={is_causal}" - ) - return False - - print("✅ Successfully optimized causal mask: mask=None, is_causal=True") - else: - # Non-causal masks should remain as is - has_mask = ( - node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs - ) - - # Check if non-optimization was correctly preserved - if not has_mask: - print("❌ Expected non-causal mask to be preserved") - return False - - print("✅ Successfully preserved non-causal mask") - - return True - - # Run the test - _ = run_test( - model, - x, - match_causal_attn_mask, - verify_matcher, - lambda num_p_og: num_p_og, - atol=1e-3, - rtol=1e-3, - test_load_hook=True, - strict_loading=True, - dynamic_shapes=dynamic_shapes, - ) - - class Llama3CausalAttentionModel(torch.nn.Module): """Model that creates a causal attention mask mimicking the llama-3.1 pattern.""" @@ -1082,78 +1015,7 @@ def get_dynamic_shapes(self): return {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)} -@pytest.mark.parametrize("use_grouped_sdpa", [False, True]) -@pytest.mark.skip(reason="Skip until we have more robust attention masking handling, see #4783") -@torch.inference_mode() -def test_match_llama3_causal_attention(use_grouped_sdpa): - batch_size, seq_len = 4, 12 - hidden_size = 512 - num_heads = 8 - num_kv_heads = 4 if use_grouped_sdpa else num_heads - - model = Llama3CausalAttentionModel( - hidden_size, - num_heads, - use_grouped_sdpa=use_grouped_sdpa, - num_kv_heads=num_kv_heads, - ).to("cuda", dtype=torch.float32) - - x = torch.randn(batch_size, seq_len, hidden_size, device="cuda", dtype=torch.float32) - dynamic_shapes = model.get_dynamic_shapes() - - def verify_matcher(gm): - # Find attention operations - if use_grouped_sdpa: - attn_nodes = [ - n - for n in gm.graph.nodes - if is_op(n, torch.ops.auto_deploy.torch_attention_grouped_sdpa) - ] - else: - attn_nodes = [ - n for n in gm.graph.nodes if is_op(n, torch.ops.auto_deploy.torch_attention_sdpa) - ] - - if len(attn_nodes) != 1: - print(f"Expected 1 attention node, found {len(attn_nodes)}") - return False - - node = attn_nodes[0] - - # Attention mask (4th arg) should be None - has_mask = node.args[3] is not None if len(node.args) > 3 else "attn_mask" in node.kwargs - - # is_causal (6th arg) should be True - is_causal = node.args[5] if len(node.args) > 5 else node.kwargs.get("is_causal", False) - - # Check if optimization was correctly applied - if has_mask or not is_causal: - print("❌ Expected 
optimization: mask=None, is_causal=True") - print( - f" Got: mask={node.args[3] if len(node.args) > 3 else 'not in args'}, " - f"is_causal={is_causal}" - ) - return False - - print("✅ Successfully optimized llama-3.1 causal mask: mask=None, is_causal=True") - return True - - # Run the test - run_test( - model, - x, - match_causal_attn_mask, - verify_matcher, - lambda num_p_og: num_p_og, - atol=1e-3, - rtol=1e-3, - test_load_hook=True, - strict_loading=True, - dynamic_shapes=dynamic_shapes, - ) - - -class MockAttentionDescriptor: +class MockAttentionDescriptor(AttentionDescriptor): """A mock class that mimics the AttentionDescriptor interface for testing.""" layout: str = "bnsd" @@ -1458,7 +1320,15 @@ def verify_matcher(gm): run_test( model, x, - lambda gm: match_attention_layout(gm, MockAttentionDescriptor), + lambda gm: InferenceOptimizer( + None, + { + "match_attention_layout": { + "stage": "pattern_matcher", + "attention_op": MockAttentionDescriptor, + }, + }, + )(None, gm), verify_matcher, lambda num_p_og: num_p_og, atol=1e-3, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py index 42de0bbe159..a813e9906af 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_attention_matcher_hf.py @@ -1,26 +1,26 @@ """Test that the attention matcher works with HF's SDPA backends.""" +import copy from typing import Any, Callable, Dict import pytest import torch import torch.nn as nn -from _graph_test_helpers import run_test +from accelerate import init_empty_weights from torch.export import Dim from torch.fx import GraphModule from transformers.models.llama.configuration_llama import LlamaConfig from transformers.models.llama.modeling_llama import LlamaModel -from tensorrt_llm._torch.auto_deploy.transformations.library import ( - match_attention_layout, - match_causal_attn_mask, - match_eager_attention, - match_grouped_attention, - match_repeat_kv, -) +from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import AttentionDescriptor +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer +from tensorrt_llm._torch.auto_deploy.transformations._graph import move_to_device + +torch.manual_seed(0) -class MockAttentionDescriptor: +class MockAttentionDescriptor(AttentionDescriptor): """A mock class that mimics the AttentionDescriptor interface for testing.""" layout: str = "bsnd" @@ -45,11 +45,24 @@ def forward(self, x: torch.Tensor): def _joint_transform(gm: GraphModule) -> None: - match_repeat_kv(gm) - match_eager_attention(gm) - match_grouped_attention(gm) - match_causal_attn_mask(gm) - match_attention_layout(gm, MockAttentionDescriptor()) + gm = InferenceOptimizer( + None, + { + "match_repeat_kv": { + "stage": "pattern_matcher", + }, + "match_eager_attention": { + "stage": "pattern_matcher", + }, + "match_grouped_attention": { + "stage": "pattern_matcher", + }, + "match_attention_layout": { + "stage": "pattern_matcher", + "attention_op": MockAttentionDescriptor, + }, + }, + )(None, gm) @pytest.mark.parametrize( @@ -65,23 +78,6 @@ def _joint_transform(gm: GraphModule) -> None: ["eager", "sdpa"], ) def test_match_llama_attention(config: Dict[str, Any], attn_implementation: str): - 
batch_size, seq_len = 4, 12 - full_config = { - "num_hidden_layers": 1, - "vocab_size": 256, - "hidden_size": 128, - "intermediate_size": 128, - "attn_implementation": attn_implementation, - **config, - } - dynamic_shapes = {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=4, max=16)} - - model = HFWrapper(LlamaModel(LlamaConfig(**full_config))).to("cuda") - model.eval() - x = torch.randint( - 0, full_config["vocab_size"], (batch_size, seq_len), dtype=torch.long, device="cuda" - ) - def verify_matcher(gm: GraphModule): """Ensure that there is exactly one torch.ops.auto_deploy.torch_attention_bsnd_grouped_sdpa call in the graph. Also check that there is no repeat_kv pattern left. @@ -106,18 +102,69 @@ def verify_matcher(gm: GraphModule): op="call_function", target=torch.ops.auto_deploy.torch_attention_repeat_kv ) assert len(nodes) == 0, "Found repeat_kv pattern in the graph" + attn_nodes = gm.graph.find_nodes( + op="call_function", target=torch.ops.auto_deploy.torch_attention_sdpa + ) + assert len(attn_nodes) == 0, "Found torch_attention_sdpa node in the graph" return True - _ = run_test( - model, - x, - _joint_transform, - verify_matcher, - lambda num_p_og: num_p_og, - atol=1e-3, - rtol=5e-2, - test_load_hook=True, - strict_loading=True, - dynamic_shapes=dynamic_shapes, + batch_size, seq_len = 2, 4 + full_config = { + "num_hidden_layers": 1, + "vocab_size": 256, + "hidden_size": 128, + "intermediate_size": 128, + "attn_implementation": attn_implementation, + **config, + } + dynamic_shapes = {0: Dim("batch_size", max=8), 1: Dim("seq_len", min=2, max=8)} + + # Build and export model on meta device + with init_empty_weights(): + model = HFWrapper(LlamaModel(LlamaConfig(**full_config))).eval() + x = torch.randint( + 0, full_config["vocab_size"], (batch_size, seq_len), dtype=torch.long, device="cuda" + ) + gm = torch_export_to_gm(model, args=(x,), dynamic_shapes=(dynamic_shapes,), clone=True) + + print("Exported gm", gm) + gm_exported = copy.deepcopy(gm) + + # Move model to cuda + device = "cuda" + model._apply( + lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype) + if t.device == torch.device("meta") + else t.to(device) ) + y_model = model(x) + + gm_exported._apply( + lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype) + if t.device == torch.device("meta") + else t.to(device) + ) + gm_exported.load_state_dict(model.state_dict()) + move_to_device(gm_exported, "cuda") + y_gm_exported = gm_exported(x) + torch.testing.assert_close(y_gm_exported, y_model, atol=5e-3, rtol=5e-3) + + # Apply transformation + _joint_transform(gm) + assert verify_matcher(gm) + print("Transformed gm", gm) + + # Move gm to cuda + gm._apply( + lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype) + if t.device == torch.device("meta") + else t.to(device) + ) + gm.load_state_dict(model.state_dict()) + move_to_device(gm, "cuda") + + # Verify output + y_gm = gm(x) + torch.testing.assert_close(y_gm_exported, y_gm, atol=5e-2, rtol=5e-2) + torch.testing.assert_close(y_model, y_gm, atol=5e-2, rtol=5e-2) diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py index 3d328be658c..0327f01329d 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py @@ -1,10 +1,11 @@ import pytest import 
torch -from _graph_test_helpers import run_test +from _graph_test_helpers import FakeFactory, run_test_transformed_gm from _model_test_utils import MoEOpModel from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available -from tensorrt_llm._torch.auto_deploy.transformations.library import quantize_moe +from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op @@ -62,13 +63,20 @@ def _expected_num_params(n): quant_config = {"quant_algo": quant_algo} - def _transform(gm, *args): - return quantize_moe(gm, quant_config) + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + FakeFactory(quant_config=quant_config), + { + "quantize_moe": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) - _ = run_test( + run_test_transformed_gm( model=model, x=x, - transform=_transform, + gm_transformed=gm_transformed, check_transformed_graph=_check_transformed_graph, _get_expected_num_params=_expected_num_params, atol=0.5, diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py index 1e063e76573..6f2734bc6c6 100644 --- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py +++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py @@ -4,13 +4,14 @@ import pytest import torch -from _graph_test_helpers import run_test +from _graph_test_helpers import run_test_transformed_gm from _model_test_utils import MLP, BMMDynamicModel, BMMModel from _torch_test_utils import fp4_compatible, fp8_compatible from tensorrt_llm._torch.auto_deploy.custom_ops.quant import QUANT_OPS from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transformations.library import quantize +from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory +from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp8_scale @@ -19,6 +20,22 @@ def check_quantized(gm): return any(is_op(n, QUANT_OPS) for n in gm.graph.nodes) +class DummyFactory(ModelFactory): + """Dummy factory to pass quant_config for testing.""" + + def __init__(self, quant_config): + self.quant_config = quant_config + + def _build_model(self, device: str): + return + + def _load_checkpoint(self, model, device): + return + + def get_quant_config(self): + return self.quant_config + + @pytest.mark.parametrize( "quant_config,atol,rtol,num_p_og", [ @@ -51,11 +68,22 @@ def test_quantization(quant_config, atol, rtol, num_p_og): model.linear2.register_buffer( "input_scale", torch.tensor([1.0], device=model.linear2.weight.device) ) - - gm_transformed = run_test( + # set up sequence+cache objects + gm = torch_export_to_gm(model, args=(x,), clone=True) + gm_transformed = InferenceOptimizer( + DummyFactory(quant_config), + { + "quantize": { + "stage": "pattern_matcher", + }, + }, + )(None, gm) + gm_transformed.to("cuda") + + run_test_transformed_gm( model, x, - quantize, + gm_transformed, check_quantized, num_p_og, atol, @@ -122,10 +150,22 @@ def test_bmm_quantization(quant_config, atol, rtol, num_p_og, model_class): 
         model.register_buffer("bmm_dynamic_input_scale", fp8_scale(x))
         model.register_buffer("bmm_dynamic_weight_scale", fp8_scale(dummy_weight))
 
-    gm_transformed = run_test(
+    # export the model and apply the quantize transform
+    gm = torch_export_to_gm(model, args=(x,), clone=True)
+    gm_transformed = InferenceOptimizer(
+        DummyFactory(quant_config),
+        {
+            "quantize": {
+                "stage": "pattern_matcher",
+            },
+        },
+    )(None, gm)
+    gm_transformed.to("cuda")
+
+    run_test_transformed_gm(
         model,
         x,
-        quantize,
+        gm_transformed,
         check_quantized,
         num_p_og,
         atol,
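
The attention-matcher tests above replace direct calls to `match_eager_attention` / `match_grouped_attention` with helpers such as `_get_match_eager_attention_optimizer()` and `_get_match_grouped_attention_optimizer()`, whose definitions are not shown in this excerpt. The sketch below reconstructs what such a helper plausibly looks like from the `InferenceOptimizer` calls used elsewhere in the patch; the helper name and the exact transform list it enables are assumptions.

```python
# Sketch only: mirrors the InferenceOptimizer usage seen in this patch; the real helper
# in the test module may enable additional transforms or pass a different factory.
from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer


def _get_match_grouped_attention_optimizer():
    """Return a callable that applies only the grouped-attention matcher to a GraphModule."""

    def _transform(gm):
        # First argument is the model factory (None here, since no quant config is needed);
        # second is a dict mapping transform names to their per-transform config.
        return InferenceOptimizer(
            None,
            {
                "match_grouped_attention": {
                    "stage": "pattern_matcher",
                },
            },
        )(None, gm)

    return _transform
```

The returned callable is what `run_test(model, x, <transform>, ...)` expects as its transform argument, which is why the call sites invoke the helper rather than passing it directly.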
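
The rewritten `test_match_llama_attention` exports the model with `init_empty_weights` and then materializes meta tensors by hand, repeating the same `_apply` lambda for the model, the exported graph, and the transformed graph. A small helper like the hypothetical one below captures that pattern; the function name is mine, not part of the patch.

```python
# Hypothetical helper mirroring the inline _apply lambdas in test_match_llama_attention:
# replace meta tensors with randomly initialized ones on the target device, move the rest.
import torch


def materialize_meta_tensors_(module: torch.nn.Module, device: str = "cuda") -> None:
    """In-place: random-init any tensor still on the meta device, move real tensors to `device`."""
    module._apply(
        lambda t: torch.normal(0.0, 1.0, size=t.shape, device=device).to(t.dtype)
        if t.device == torch.device("meta")
        else t.to(device)
    )
```

After materializing both the eager model and the exported GraphModule this way, the test copies the model's `state_dict` into the graph so that the eager, exported, and transformed outputs are compared against identical weights.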
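
Several verifiers now read `is_causal` from the matched `torch_attention_grouped_sdpa` nodes instead of looking for an `attn_mask` argument. A condensed sketch of that check follows, assuming (as the verifiers above do) that the mask is the 4th and `is_causal` the 6th positional argument; unlike the grouped-attention verifier in this patch, the sketch treats a mismatch as a hard failure and returns False.

```python
# Sketch of the is_causal verification applied to matched grouped-SDPA nodes.
import torch
from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op


def _check_is_causal(gm) -> bool:
    for node in gm.graph.nodes:
        if not is_op(node, torch.ops.auto_deploy.torch_attention_grouped_sdpa):
            continue
        is_causal = node.kwargs.get("is_causal", None)
        if is_causal is None and len(node.args) >= 6:
            # is_causal is the 6th positional argument of the grouped-SDPA op
            is_causal = node.args[5]
        if is_causal is not True:
            print(f"❌ Expected is_causal=True, got {is_causal}")
            return False
        print("✅ is_causal correctly set to True")
    return True
```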