Commit 95e767d

yizhuoz004 and claude committed
feat(attention): GQA/MQA + decode-phase support via IAttentionLayer
- Extend flash attention validator to accept GQA shapes (Hq != Hkv): IAttentionLayer natively handles non-equal head counts without K/V expansion. Requires Hq divisible by Hkv and matching batch/head_dim.
- Add decode-phase support (seq_q != seq_k) to all three attention validators; only the seq dimension is skipped in shape checks.
- Document why GQA is not supported in the efficient attention validator: PyTorch's eager kernel rejects Hq != Hkv, so no reference output exists; GQA models dispatch to flash attention (FP16) or decompose via matmul+_safe_softmax (FP32) and never produce this op with GQA shapes.
- Restructure test_attention.py: merge five SDPA classes into TestSDPA, expand TestFlashAttention with test_decode and test_gqa methods, add TestEfficientAttention.test_with_bias_decode; trim redundant cases and remove BUG-1 inline annotations (kept only in module docstring).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
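For context on the divisibility rule in the commit message (Hq divisible by Hkv): in grouped-query attention, each group of Hq / Hkv query heads shares one K/V head. A minimal eager-PyTorch sketch of that equivalence, assuming PyTorch 2.5+ where `enable_gqa` is available; sizes are illustrative and unrelated to the tests in this commit:

```python
import torch
import torch.nn.functional as F

# Illustrative GQA shapes: 32 query heads share 8 KV heads (groups of 4).
B, Hq, Hkv, S, D = 2, 32, 8, 16, 64
q = torch.randn(B, Hq, S, D)
k = torch.randn(B, Hkv, S, D)
v = torch.randn(B, Hkv, S, D)

# Native GQA path: no K/V expansion in user code (PyTorch >= 2.5).
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Reference: expand each KV head to its group of Hq // Hkv query heads.
k_rep = k.repeat_interleave(Hq // Hkv, dim=1)
v_rep = v.repeat_interleave(Hq // Hkv, dim=1)
out_ref = F.scaled_dot_product_attention(q, k_rep, v_rep)

torch.testing.assert_close(out_gqa, out_ref, rtol=1e-3, atol=1e-3)
```

The point of the commit is that IAttentionLayer consumes the first, unexpanded form directly, so the converter no longer has to reject GQA or rely on the expanded form.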
1 parent 15d0831 commit 95e767d

4 files changed

Lines changed: 445 additions & 341 deletions

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 112 additions & 23 deletions
@@ -3964,11 +3964,7 @@ def aten_ops_linear(
 def scaled_dot_product_attention_validator(
     node: Node, settings: Optional[CompilationSettings] = None
 ) -> bool:
-    if node.kwargs.get("enable_gqa", False):
-        _LOGGER.debug(
-            "enable_gqa is not yet supported by the converter. Please try setting decompose_attention=True in the compilation settings."
-        )
-        return False
+    enable_gqa = node.kwargs.get("enable_gqa", False)
 
     query_shape, key_shape, value_shape = None, None, None
     if "val" in node.args[0].meta:
@@ -3977,15 +3973,57 @@ def scaled_dot_product_attention_validator(
         key_shape = node.args[1].meta["val"].size()
     if "val" in node.args[2].meta:
         value_shape = node.args[2].meta["val"].size()
-    if (
-        query_shape != key_shape
-        or query_shape != value_shape
-        or key_shape != value_shape
-    ):
+
+    if key_shape != value_shape:
         _LOGGER.debug(
-            "query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings."
+            "key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings."
         )
         return False
+
+    if query_shape is not None and key_shape is not None:
+        if len(query_shape) != len(key_shape):
+            _LOGGER.debug(
+                "query and key have different ranks, which is not supported."
+            )
+            return False
+        ndim = len(query_shape)
+        if enable_gqa:
+            # IAttentionLayer natively supports GQA: Q and K/V may differ on the
+            # head dim (dim 1) as long as Hq is divisible by Hkv.
+            # Check batch (dim 0) and head_dim (last dim) match; skip seq (dim -2)
+            # and head (dim 1) dims.
+            head_dim = ndim - 1
+            seq_dim = ndim - 2
+            heads_dim = 1
+            for i in range(ndim):
+                if i in (seq_dim, heads_dim):
+                    continue
+                if query_shape[i] != key_shape[i]:
+                    _LOGGER.debug(
+                        f"query and key mismatch on dim {i} with enable_gqa=True."
+                    )
+                    return False
+            num_q_heads = query_shape[1]
+            num_kv_heads = key_shape[1]
+            if num_q_heads % num_kv_heads != 0:
+                _LOGGER.debug(
+                    f"enable_gqa=True but num_q_heads={num_q_heads} is not divisible "
+                    f"by num_kv_heads={num_kv_heads}."
+                )
+                return False
+        else:
+            # IAttentionLayer supports decode-phase (seq_q != seq_k).
+            # Check all dims except the seq dim.
+            seq_dim = ndim - 2
+            if any(
+                query_shape[i] != key_shape[i]
+                for i in range(ndim)
+                if i != seq_dim
+            ):
+                _LOGGER.debug(
+                    "query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings."
+                )
+                return False
     return True
 
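To summarize the rule the rewritten validator enforces, here is a hypothetical standalone mirror of its shape logic (plain Python, for illustration only, not part of the converter's API): with `enable_gqa` both the seq and head dims may differ as long as Hq % Hkv == 0; without it, only the seq dim may differ.

```python
def sdpa_shapes_supported(q_shape, kv_shape, enable_gqa=False):
    """Hypothetical mirror of the validator's shape rule, for illustration only."""
    if len(q_shape) != len(kv_shape):
        return False
    ndim = len(q_shape)
    seq_dim, heads_dim = ndim - 2, 1
    skipped = {seq_dim, heads_dim} if enable_gqa else {seq_dim}
    if any(q_shape[i] != kv_shape[i] for i in range(ndim) if i not in skipped):
        return False
    if enable_gqa and q_shape[heads_dim] % kv_shape[heads_dim] != 0:
        return False
    return True

assert sdpa_shapes_supported((2, 32, 128, 64), (2, 8, 128, 64), enable_gqa=True)      # GQA: 32 % 8 == 0
assert sdpa_shapes_supported((2, 32, 1, 64), (2, 32, 128, 64))                        # decode: only seq differs
assert not sdpa_shapes_supported((2, 32, 128, 64), (2, 8, 128, 64))                   # GQA heads without enable_gqa
assert not sdpa_shapes_supported((2, 12, 128, 64), (2, 8, 128, 64), enable_gqa=True)  # 12 % 8 != 0
```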
@@ -4032,15 +4070,50 @@ def scaled_dot_product_flash_attention_validator(
         key_shape = node.args[1].meta["val"].size()
     if "val" in node.args[2].meta:
         value_shape = node.args[2].meta["val"].size()
-    if (
-        query_shape != key_shape
-        or query_shape != value_shape
-        or key_shape != value_shape
-    ):
+    if key_shape != value_shape:
         _LOGGER.debug(
-            "query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings."
+            "key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings."
         )
         return False
+    if query_shape is not None and key_shape is not None:
+        if len(query_shape) != len(key_shape):
+            _LOGGER.debug(
+                "query and key have different ranks, which is not supported."
+            )
+            return False
+        ndim = len(query_shape)
+        seq_dim = ndim - 2
+        heads_dim = 1
+        num_q_heads = query_shape[heads_dim]
+        num_kv_heads = key_shape[heads_dim]
+        is_gqa = num_q_heads != num_kv_heads
+        if is_gqa:
+            # GQA: IAttentionLayer natively handles Hq != Hkv.
+            # Require batch/head_dim to match and Hq divisible by Hkv.
+            for i in range(ndim):
+                if i in (seq_dim, heads_dim):
+                    continue
+                if query_shape[i] != key_shape[i]:
+                    _LOGGER.debug(
+                        f"GQA: query and key mismatch on dim {i}."
+                    )
+                    return False
+            if num_q_heads % num_kv_heads != 0:
+                _LOGGER.debug(
+                    f"GQA: num_q_heads={num_q_heads} not divisible by num_kv_heads={num_kv_heads}."
+                )
+                return False
+        else:
+            # MHA / decode-phase: seq may differ, all other dims must match.
+            if any(
+                query_shape[i] != key_shape[i]
+                for i in range(ndim)
+                if i != seq_dim
+            ):
+                _LOGGER.debug(
+                    "query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings."
+                )
+                return False
     return True
 
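The two new cases this validator admits correspond to the test_gqa and test_decode methods named in the commit message. A rough eager-mode sketch of both with illustrative sizes; on CUDA with FP16/BF16 inputs this is the pattern that, per the commit message, typically lowers to _scaled_dot_product_flash_attention (enable_gqa assumes PyTorch 2.5+):

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
B, Hq, Hkv, D = 2, 32, 8, 64

# GQA prefill: 32 query heads, 8 KV heads, equal sequence lengths.
q = torch.randn(B, Hq, 128, D, device=device, dtype=dtype)
k = torch.randn(B, Hkv, 128, D, device=device, dtype=dtype)
v = torch.randn(B, Hkv, 128, D, device=device, dtype=dtype)
out_gqa = F.scaled_dot_product_attention(q, k, v, enable_gqa=True)

# Decode phase: one new query token attends to 128 cached keys (Hq == Hkv here).
q_dec = torch.randn(B, Hq, 1, D, device=device, dtype=dtype)
k_dec = torch.randn(B, Hq, 128, D, device=device, dtype=dtype)
v_dec = torch.randn(B, Hq, 128, D, device=device, dtype=dtype)
out_dec = F.scaled_dot_product_attention(q_dec, k_dec, v_dec)
```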
@@ -4086,15 +4159,31 @@ def scaled_dot_product_efficient_attention_validator(
         key_shape = node.args[1].meta["val"].size()
     if "val" in node.args[2].meta:
         value_shape = node.args[2].meta["val"].size()
-    if (
-        query_shape != key_shape
-        or query_shape != value_shape
-        or key_shape != value_shape
-    ):
+    if key_shape != value_shape:
         _LOGGER.debug(
-            "query, key, and value have different shapes. Please try setting decompose_attention=True in the compilation settings."
+            "key and value have different shapes, which is not supported. Please try setting decompose_attention=True in the compilation settings."
         )
         return False
+    # GQA (Hq != Hkv) is intentionally not supported here.
+    # PyTorch's eager _scaled_dot_product_efficient_attention kernel rejects
+    # non-equal head counts at runtime, so no valid reference output exists for
+    # comparison. In practice, GQA models on CUDA dispatch to
+    # _scaled_dot_product_flash_attention (FP16/BF16) or decompose into
+    # matmul+_safe_softmax (FP32) — this op never appears with GQA shapes in
+    # a real FX graph. GQA is handled by the flash attention validator instead.
+    #
+    # IAttentionLayer does support decode-phase (seq_q != seq_k), so only the
+    # sequence dimension is skipped in the shape check below.
+    if query_shape is not None and key_shape is not None:
+        if len(query_shape) != len(key_shape) or any(
+            query_shape[i] != key_shape[i]
+            for i in range(len(query_shape))
+            if i != len(query_shape) - 2  # skip the seq dim
+        ):
+            _LOGGER.debug(
+                "query and key have incompatible shapes (batch, heads, or head_dim mismatch). Please try setting decompose_attention=True in the compilation settings."
+            )
+            return False
     return True
 
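The decode-with-bias case this validator now accepts matches TestEfficientAttention.test_with_bias_decode from the commit message. A hedged eager-mode sketch with illustrative sizes: FP32 inputs, Hq == Hkv (GQA is still rejected here), and an explicit bias where only the sequence dimension differs between query and key.

```python
import torch
import torch.nn.functional as F

B, H, D, S_kv = 2, 8, 64, 128  # illustrative sizes; Hq == Hkv by design

q = torch.randn(B, H, 1, D)        # a single decode-step query token (FP32)
k = torch.randn(B, H, S_kv, D)
v = torch.randn(B, H, S_kv, D)
bias = torch.zeros(B, H, 1, S_kv)  # explicit additive bias over all cached keys

out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([2, 8, 1, 64])
```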
py/torch_tensorrt/dynamo/lowering/passes/force_causal_efficient_attention.py

Lines changed: 43 additions & 16 deletions
@@ -12,32 +12,59 @@
 def force_causal_efficient_attention(
     gm: torch.fx.GraphModule, settings: CompilationSettings
 ) -> torch.fx.GraphModule:
-    """Force efficient-attention calls to causal mode when enabled in settings."""
+    """Force efficient-attention calls to causal mode when enabled in settings.
+
+    For square attention (seq_q == seq_k): replaces attn_bias with is_causal=True
+    so IAttentionLayer can use its native causal path.
+
+    For decode-phase attention (seq_q != seq_k): skip the transformation.
+    Applying is_causal=True is semantically wrong here — it creates a lower-
+    triangular mask aligned to position 0, so the query attends only to k[0]
+    instead of all past keys. The node is left unchanged and passed to
+    IAttentionLayer, which supports non-square Q/K natively.
+    """
     if not settings.attn_bias_is_causal:
         return gm
 
     changed = False
     for node in gm.graph.nodes:
         if (
             node.target
-            == torch.ops.aten._scaled_dot_product_efficient_attention.default
+            != torch.ops.aten._scaled_dot_product_efficient_attention.default
+        ):
+            continue
+
+        attn_bias = node.args[3] if len(node.args) > 3 else None
+        if attn_bias is None:
+            continue
+
+        query_node, key_node = node.args[0], node.args[1]
+        query_meta = query_node.meta.get("val") if hasattr(query_node, "meta") else None
+        key_meta = key_node.meta.get("val") if hasattr(key_node, "meta") else None
+        if (
+            query_meta is not None
+            and key_meta is not None
+            and query_meta.size(2) != key_meta.size(2)
         ):
-            attn_bias = node.args[3] if len(node.args) > 3 else None
-            if attn_bias is None:
-                continue
-            node.args = (
-                node.args[0],
-                node.args[1],
-                node.args[2],
-                None,
-                False,
-                0.0,
-                True,
-            )
-            changed = True
             logger.debug(
-                f"The args of node {node} was changed to causal mode. Now the node's arguments are: {node.args}"
+                f"Skipping causal force for node {node}: seq_q={query_meta.size(2)} "
+                f"!= seq_k={key_meta.size(2)} (decode-phase, IAttentionLayer handles it)"
             )
+            continue
+
+        node.args = (
+            node.args[0],
+            node.args[1],
+            node.args[2],
+            None,
+            False,
+            0.0,
+            True,
+        )
+        changed = True
+        logger.debug(
+            f"Node {node} changed to causal mode: {node.args}"
+        )
 
     if changed:
         gm = clean_up_graph_after_modifications(gm)
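A minimal illustration of the misalignment the new docstring describes, using PyTorch's documented upper-left-aligned causal convention (is_causal implies a tril mask of shape [seq_q, seq_k]):

```python
import torch

seq_q, seq_k = 1, 4  # one decode-phase query token, four cached keys

# Mask implied by is_causal=True: torch.ones(seq_q, seq_k).tril(), aligned to
# position 0, so the new token may only attend to k[0].
causal = torch.ones(seq_q, seq_k).tril().bool()
print(causal)  # tensor([[ True, False, False, False]])

# What decode actually needs: the new token attends to every cached key.
decode = torch.ones(seq_q, seq_k).bool()
print(decode)  # tensor([[True, True, True, True]])
```

This is why the pass now leaves decode-phase nodes untouched instead of forcing is_causal=True.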
tests/py/dynamo/conversion/harness.py

Lines changed: 0 additions & 54 deletions
@@ -6,14 +6,12 @@
 from typing import Any, Callable, List, Optional, Sequence, Tuple
 
 import torch
-import torch_tensorrt
 from torch.fx.experimental.proxy_tensor import unset_fake_temporarily
 from torch.fx.passes.shape_prop import ShapeProp
 from torch.testing._internal.common_utils import TestCase
 from torch_tensorrt import Input
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
-from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._defaults import default_device
 from torch_tensorrt.dynamo._settings import CompilationSettings
 from torch_tensorrt.dynamo._tracer import get_dynamic_shapes_args
@@ -109,58 +107,6 @@ def get_use_dynamo_tracer(use_dynamo_tracer: Any) -> bool:
     return False
 
 
-# this method is only used in our converter test to infer the module output dtypes via dummy inference
-# which is due to fx.symbolic_trace does not have the meta['val'] info in the node
-# TODO: lan to remove this once our converter test is moved from fx.symbolic_trace to dynamo trace
-def infer_module_output_dtypes_for_test(
-    module: torch.fx.GraphModule,
-    inputs: Sequence[Input],
-    device: Device,
-    kwarg_inputs: Optional[dict[str, Any]] = None,
-    truncate_double: bool = False,
-) -> List[dtype]:
-    """
-    This function performs model inference to determine the output dtypes
-    and truncates them accordingly. inputs can be either arg_inputs or flattened input list.
-    If it is flattened list, kwarg_inputs should be None, as it is already included in the flattened input.
-    """
-    # TODO: We can also determine output dtypes from the module.graph based on node metadata.
-    # However, our converter tests use fx.symbolic_trace which sometimes does not provide metadata,
-    # so we stick to the model inference approach currently.
-    with unset_fake_temporarily():
-        # Get the device on which the model exists
-        # For large models, this can be done on CPU to save GPU memory allocation for TRT.
-        device = get_model_device(module)
-        torch_inputs = get_torch_inputs(inputs, device)
-        if kwarg_inputs is None:
-            kwarg_inputs = {}
-        torch_kwarg_inputs = get_torch_inputs(kwarg_inputs, device)
-        module_outputs = module(*torch_inputs, **torch_kwarg_inputs)
-        if not isinstance(module_outputs, (list, tuple)):
-            module_outputs = [module_outputs]
-
-        # Int64 outputs can sometimes be generated from within other operators
-        # such as aten.sum - such outputs can be truncated
-        output_dtypes = []
-        for output in module_outputs:
-            output_ = output
-            # We don't need to check if output is nested here because the input module will be flattened
-            if not isinstance(output, torch.Tensor):
-                if isinstance(output, str):
-                    raise ValueError(
-                        f"Received an output type {type(output)} that's not in the acceptable datatypes (https://pytorch.org/docs/stable/tensor_attributes.html#torch.dtype)"
-                    )
-                else:
-                    output_ = torch.tensor(output)
-
-            if truncate_double and output_.dtype == dtype.float64:
-                output_dtypes.append(dtype.float32)
-            else:
-                output_dtypes.append(dtype._from(output_.dtype))
-
-    return output_dtypes
-
-
 def fetch_attr(mod, target):
     """
     Fetch an attribute from the ``Module`` hierarchy of ``mod.module``.

0 commit comments
