Commit 95e767d
feat(attention): GQA/MQA + decode-phase support via IAttentionLayer
- Extend flash attention validator to accept GQA shapes (Hq != Hkv):
IAttentionLayer natively handles non-equal head counts without K/V
expansion. Requires Hq divisible by Hkv and matching batch/head_dim.
- Add decode-phase support (seq_q != seq_k) to all three attention
validators; only the seq dimension is skipped in shape checks.
- Document why GQA is not supported in the efficient attention validator:
PyTorch's eager kernel rejects Hq != Hkv, so no reference output exists;
GQA models dispatch to flash attention (FP16) or decompose via
matmul+_safe_softmax (FP32) and never produce this op with GQA shapes.
- Restructure test_attention.py: merge five SDPA classes into TestSDPA,
expand TestFlashAttention with test_decode and test_gqa methods,
add TestEfficientAttention.test_with_bias_decode; trim redundant cases
and remove BUG-1 inline annotations (kept only in module docstring).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>1 parent 15d0831 commit 95e767d
4 files changed
Lines changed: 445 additions & 341 deletions
File tree
- py/torch_tensorrt/dynamo
- conversion
- lowering/passes
- tests/py/dynamo
- conversion
- hlo
Lines changed: 112 additions & 23 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
3964 | 3964 | | |
3965 | 3965 | | |
3966 | 3966 | | |
3967 | | - | |
3968 | | - | |
3969 | | - | |
3970 | | - | |
3971 | | - | |
| 3967 | + | |
3972 | 3968 | | |
3973 | 3969 | | |
3974 | 3970 | | |
| |||
3977 | 3973 | | |
3978 | 3974 | | |
3979 | 3975 | | |
3980 | | - | |
3981 | | - | |
3982 | | - | |
3983 | | - | |
3984 | | - | |
| 3976 | + | |
| 3977 | + | |
3985 | 3978 | | |
3986 | | - | |
| 3979 | + | |
3987 | 3980 | | |
3988 | 3981 | | |
| 3982 | + | |
| 3983 | + | |
| 3984 | + | |
| 3985 | + | |
| 3986 | + | |
| 3987 | + | |
| 3988 | + | |
| 3989 | + | |
| 3990 | + | |
| 3991 | + | |
| 3992 | + | |
| 3993 | + | |
| 3994 | + | |
| 3995 | + | |
| 3996 | + | |
| 3997 | + | |
| 3998 | + | |
| 3999 | + | |
| 4000 | + | |
| 4001 | + | |
| 4002 | + | |
| 4003 | + | |
| 4004 | + | |
| 4005 | + | |
| 4006 | + | |
| 4007 | + | |
| 4008 | + | |
| 4009 | + | |
| 4010 | + | |
| 4011 | + | |
| 4012 | + | |
| 4013 | + | |
| 4014 | + | |
| 4015 | + | |
| 4016 | + | |
| 4017 | + | |
| 4018 | + | |
| 4019 | + | |
| 4020 | + | |
| 4021 | + | |
| 4022 | + | |
| 4023 | + | |
| 4024 | + | |
| 4025 | + | |
| 4026 | + | |
3989 | 4027 | | |
3990 | 4028 | | |
3991 | 4029 | | |
| |||
4032 | 4070 | | |
4033 | 4071 | | |
4034 | 4072 | | |
4035 | | - | |
4036 | | - | |
4037 | | - | |
4038 | | - | |
4039 | | - | |
| 4073 | + | |
4040 | 4074 | | |
4041 | | - | |
| 4075 | + | |
4042 | 4076 | | |
4043 | 4077 | | |
| 4078 | + | |
| 4079 | + | |
| 4080 | + | |
| 4081 | + | |
| 4082 | + | |
| 4083 | + | |
| 4084 | + | |
| 4085 | + | |
| 4086 | + | |
| 4087 | + | |
| 4088 | + | |
| 4089 | + | |
| 4090 | + | |
| 4091 | + | |
| 4092 | + | |
| 4093 | + | |
| 4094 | + | |
| 4095 | + | |
| 4096 | + | |
| 4097 | + | |
| 4098 | + | |
| 4099 | + | |
| 4100 | + | |
| 4101 | + | |
| 4102 | + | |
| 4103 | + | |
| 4104 | + | |
| 4105 | + | |
| 4106 | + | |
| 4107 | + | |
| 4108 | + | |
| 4109 | + | |
| 4110 | + | |
| 4111 | + | |
| 4112 | + | |
| 4113 | + | |
| 4114 | + | |
| 4115 | + | |
| 4116 | + | |
4044 | 4117 | | |
4045 | 4118 | | |
4046 | 4119 | | |
| |||
4086 | 4159 | | |
4087 | 4160 | | |
4088 | 4161 | | |
4089 | | - | |
4090 | | - | |
4091 | | - | |
4092 | | - | |
4093 | | - | |
| 4162 | + | |
4094 | 4163 | | |
4095 | | - | |
| 4164 | + | |
4096 | 4165 | | |
4097 | 4166 | | |
| 4167 | + | |
| 4168 | + | |
| 4169 | + | |
| 4170 | + | |
| 4171 | + | |
| 4172 | + | |
| 4173 | + | |
| 4174 | + | |
| 4175 | + | |
| 4176 | + | |
| 4177 | + | |
| 4178 | + | |
| 4179 | + | |
| 4180 | + | |
| 4181 | + | |
| 4182 | + | |
| 4183 | + | |
| 4184 | + | |
| 4185 | + | |
| 4186 | + | |
4098 | 4187 | | |
4099 | 4188 | | |
4100 | 4189 | | |
| |||
Lines changed: 43 additions & 16 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
12 | 12 | | |
13 | 13 | | |
14 | 14 | | |
15 | | - | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
16 | 26 | | |
17 | 27 | | |
18 | 28 | | |
19 | 29 | | |
20 | 30 | | |
21 | 31 | | |
22 | 32 | | |
23 | | - | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
24 | 48 | | |
25 | | - | |
26 | | - | |
27 | | - | |
28 | | - | |
29 | | - | |
30 | | - | |
31 | | - | |
32 | | - | |
33 | | - | |
34 | | - | |
35 | | - | |
36 | | - | |
37 | | - | |
38 | 49 | | |
39 | | - | |
| 50 | + | |
| 51 | + | |
40 | 52 | | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
41 | 68 | | |
42 | 69 | | |
43 | 70 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
6 | 6 | | |
7 | 7 | | |
8 | 8 | | |
9 | | - | |
10 | 9 | | |
11 | 10 | | |
12 | 11 | | |
13 | 12 | | |
14 | 13 | | |
15 | 14 | | |
16 | | - | |
17 | 15 | | |
18 | 16 | | |
19 | 17 | | |
| |||
109 | 107 | | |
110 | 108 | | |
111 | 109 | | |
112 | | - | |
113 | | - | |
114 | | - | |
115 | | - | |
116 | | - | |
117 | | - | |
118 | | - | |
119 | | - | |
120 | | - | |
121 | | - | |
122 | | - | |
123 | | - | |
124 | | - | |
125 | | - | |
126 | | - | |
127 | | - | |
128 | | - | |
129 | | - | |
130 | | - | |
131 | | - | |
132 | | - | |
133 | | - | |
134 | | - | |
135 | | - | |
136 | | - | |
137 | | - | |
138 | | - | |
139 | | - | |
140 | | - | |
141 | | - | |
142 | | - | |
143 | | - | |
144 | | - | |
145 | | - | |
146 | | - | |
147 | | - | |
148 | | - | |
149 | | - | |
150 | | - | |
151 | | - | |
152 | | - | |
153 | | - | |
154 | | - | |
155 | | - | |
156 | | - | |
157 | | - | |
158 | | - | |
159 | | - | |
160 | | - | |
161 | | - | |
162 | | - | |
163 | | - | |
164 | 110 | | |
165 | 111 | | |
166 | 112 | | |
| |||
0 commit comments