Skip to content

Commit 8321164

Browse files
use helion instead of triton for low precision attention quantization kernels
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: ghstack-source-id: 4e2211f Pull-Request: #3880
1 parent e9a3547 commit 8321164

File tree

2 files changed

+56
-159
lines changed

2 files changed

+56
-159
lines changed

torchao/prototype/attention/fp8_fa3/helion_qkv_quantization.py

Lines changed: 51 additions & 154 deletions
Original file line numberDiff line numberDiff line change
@@ -5,25 +5,17 @@
55
# LICENSE file in the root directory of this source tree.
66

77
"""
8-
Fused RoPE + FP8 Quantization kernels using Helion (Optimized version).
8+
FP8 Quantization kernels using Helion.
99
10-
This module provides Helion kernels that fuse:
11-
- RoPE (Rotary Position Embeddings) for Q and K
12-
- FP8 quantization for Q, K, V
13-
- Layout transformation from [B, S, H, D] (FLUX) to [B, H, S, D] (SDPA)
14-
15-
The layout transformation is fused directly into the kernels, avoiding
16-
separate transpose+contiguous memory copies that would otherwise add ~70%
17-
overhead to the overall kernel execution time.
10+
This module provides Helion kernels that perform per-head FP8 quantization
11+
for Q, K, V tensors.
1812
1913
The 3-kernel structure parallelizes over (B, H, S) with nested D loop:
20-
- Phase 1: RoPE + partial max (reads [B,S,H,D], writes [B,H,S,D])
21-
- Reduce: Aggregate maxes per head (parallel over B * H)
22-
- Phase 2: Quantize (reads V from [B,S,H,D], writes [B,H,S,D])
23-
24-
Input format: [B, S, H, D] (FLUX-style)
25-
Output format: [B, H, S, D] (SDPA-style)
14+
- Phase 1: Compute partial absmax values per S-block
15+
- Reduce: Aggregate maxes per head and compute scale/descale factors
16+
- Phase 2: Apply scales and cast to FP8
2617
18+
Input/output format: [B, H, S, D]
2719
"""
2820

2921
from typing import Optional, Tuple
@@ -35,74 +27,35 @@
3527

3628

3729
# =============================================================================
38-
# Phase 1: RoPE + Max computation
39-
# Reads from [B, S, H, D] (FLUX layout), writes to [B, H, S, D] (SDPA layout)
40-
# Fuses layout transformation with RoPE computation to avoid separate copy
30+
# Phase 1: Partial absmax computation
4131
# =============================================================================
4232

4333

4434
@helion.kernel(static_shapes=True)
45-
def rope_qkv_phase1_helion(
46-
q: torch.Tensor, # [B, S, H, D] - FLUX input layout
47-
k: torch.Tensor, # [B, S, H, D] - FLUX input layout
48-
v: torch.Tensor, # [B, S, H, D] - FLUX input layout
49-
cos: torch.Tensor, # [S, D]
50-
sin: torch.Tensor, # [S, D]
51-
q_rope_out: torch.Tensor, # [B, H, S, D] - output in SDPA layout
52-
k_rope_out: torch.Tensor, # [B, H, S, D] - output in SDPA layout
35+
def qkv_phase1_helion(
36+
q: torch.Tensor, # [B, H, S, D]
37+
k: torch.Tensor, # [B, H, S, D]
38+
v: torch.Tensor, # [B, H, S, D]
5339
partial_max: torch.Tensor, # [B, H, num_s_blocks, 3] - output
5440
) -> None:
5541
"""
56-
Phase 1: Apply RoPE to Q and K, store results, compute partial max.
57-
58-
Reads from [B, S, H, D] (FLUX layout) and writes to [B, H, S, D] (SDPA layout),
59-
fusing the layout transformation with the RoPE computation.
42+
Phase 1: Compute partial absmax for Q, K, V per S-block.
6043
6144
Uses 3D tiling over (B, H, S) with block_size=[1, 1, block_s], plus
6245
a nested inner loop over D with block_size=D (single iteration).
6346
"""
64-
B, S, H, D = q.size()
65-
D_HALF = hl.specialize(D // 2)
47+
B, H, S, D = q.size()
6648

6749
block_size_s = hl.register_block_size(S)
6850

6951
for tile_b, tile_h, tile_s in hl.tile([B, H, S], block_size=[1, 1, block_size_s]):
7052
for tile_d in hl.tile(D, block_size=D):
71-
# Load from [B, S, H, D] input layout
72-
q_tile = q[tile_b.begin, tile_s, tile_h.begin, tile_d].to(torch.float32)
73-
k_tile = k[tile_b.begin, tile_s, tile_h.begin, tile_d].to(torch.float32)
74-
v_tile = v[tile_b.begin, tile_s, tile_h.begin, tile_d].to(torch.float32)
75-
76-
cos_tile = cos[tile_s, tile_d].to(torch.float32)
77-
sin_tile = sin[tile_s, tile_d].to(torch.float32)
78-
79-
# Split into real/imag components
80-
q_tile_ri = q_tile.reshape(-1, D_HALF, 2)
81-
k_tile_ri = k_tile.reshape(-1, D_HALF, 2)
82-
q_real, q_imag = hl.split(q_tile_ri)
83-
k_real, k_imag = hl.split(k_tile_ri)
84-
85-
cos_tile_ri = cos_tile.reshape(-1, D_HALF, 2)
86-
sin_tile_ri = sin_tile.reshape(-1, D_HALF, 2)
87-
cos_real, cos_imag = hl.split(cos_tile_ri)
88-
sin_real, sin_imag = hl.split(sin_tile_ri)
89-
90-
# Apply RoPE
91-
q_rope_real = q_real * cos_real - q_imag * sin_real
92-
q_rope_imag = q_real * sin_imag + q_imag * cos_imag
93-
k_rope_real = k_real * cos_real - k_imag * sin_real
94-
k_rope_imag = k_real * sin_imag + k_imag * cos_imag
95-
96-
q_rope = hl.join(q_rope_real, q_rope_imag).reshape(-1, D)
97-
k_rope = hl.join(k_rope_real, k_rope_imag).reshape(-1, D)
98-
99-
# Store RoPE'd Q, K to [B, H, S, D] output layout
100-
q_rope_out[tile_b.begin, tile_h.begin, tile_s, tile_d] = q_rope.to(q.dtype)
101-
k_rope_out[tile_b.begin, tile_h.begin, tile_s, tile_d] = k_rope.to(k.dtype)
102-
103-
# Compute partial max for this block
104-
q_max_tile = torch.amax(torch.abs(q_rope), dim=-1)
105-
k_max_tile = torch.amax(torch.abs(k_rope), dim=-1)
53+
q_tile = q[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
54+
k_tile = k[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
55+
v_tile = v[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
56+
57+
q_max_tile = torch.amax(torch.abs(q_tile), dim=-1)
58+
k_max_tile = torch.amax(torch.abs(k_tile), dim=-1)
10659
v_max_tile = torch.amax(torch.abs(v_tile), dim=-1)
10760

10861
q_max = torch.amax(q_max_tile)
@@ -116,12 +69,11 @@ def rope_qkv_phase1_helion(
11669

11770
# =============================================================================
11871
# Reduce kernel: Aggregate partial maxes and compute scales
119-
# Parallelizes over (B, H) using hl.tile([...]) with block_size=[1, 1]
12072
# =============================================================================
12173

12274

12375
@helion.kernel(static_shapes=True)
124-
def rope_qkv_reduce_helion(
76+
def qkv_reduce_helion(
12577
partial_max: torch.Tensor, # [B, H, num_s_blocks, 3]
12678
q_scale: torch.Tensor, # [B, H] - output
12779
k_scale: torch.Tensor, # [B, H] - output
@@ -134,13 +86,10 @@ def rope_qkv_reduce_helion(
13486
Reduce partial maxes across S-blocks and compute scales/descales.
13587
13688
Uses 2D tiling over (B, H) with block_size=[1, 1].
137-
- tile_b.begin, tile_h.begin are scalar indices
138-
- Sequential reduction over S blocks using vectorized access
13989
"""
14090
FP8_MAX: float = 448.0
14191
eps: float = 1e-12
14292
B, H = q_scale.size()
143-
num_s_blocks = partial_max.size(2)
14493

14594
for tile_b, tile_h in hl.tile([B, H], block_size=[1, 1]):
14695
q_partial = partial_max[tile_b.begin, tile_h.begin, :, 0]
@@ -171,17 +120,15 @@ def rope_qkv_reduce_helion(
171120

172121

173122
# =============================================================================
174-
# Phase 2: Quantize
175-
# Reads V from [B, S, H, D] (FLUX layout), writes to [B, H, S, D] (SDPA layout)
176-
# Fuses layout transformation with quantization to avoid separate copy
123+
# Phase 2: Quantize to FP8
177124
# =============================================================================
178125

179126

180127
@helion.kernel(static_shapes=True)
181-
def rope_qkv_phase2_helion(
182-
q_rope: torch.Tensor, # [B, H, S, D] - intermediate RoPE'd Q (SDPA layout)
183-
k_rope: torch.Tensor, # [B, H, S, D] - intermediate RoPE'd K (SDPA layout)
184-
v: torch.Tensor, # [B, S, H, D] - original V (FLUX layout)
128+
def qkv_phase2_helion(
129+
q: torch.Tensor, # [B, H, S, D]
130+
k: torch.Tensor, # [B, H, S, D]
131+
v: torch.Tensor, # [B, H, S, D]
185132
q_out: torch.Tensor, # [B, H, S, D] - FP8 output
186133
k_out: torch.Tensor, # [B, H, S, D] - FP8 output
187134
v_out: torch.Tensor, # [B, H, S, D] - FP8 output
@@ -190,16 +137,12 @@ def rope_qkv_phase2_helion(
190137
v_scale: torch.Tensor, # [B, H]
191138
) -> None:
192139
"""
193-
Phase 2: Quantize pre-computed RoPE'd Q, K and V.
194-
195-
Q and K are read from intermediate buffers in [B, H, S, D] (SDPA layout).
196-
V is read from original input in [B, S, H, D] (FLUX layout).
197-
All outputs are written in [B, H, S, D] (SDPA layout).
140+
Phase 2: Quantize Q, K, V to FP8 using precomputed scales.
198141
199142
Uses 3D tiling over (B, H, S) with block_size=[1, 1, block_s], plus
200143
a nested inner loop over D with block_size=D (single iteration).
201144
"""
202-
B, H, S, D = q_rope.size()
145+
B, H, S, D = q.size()
203146

204147
block_size_s = hl.register_block_size(S)
205148

@@ -209,14 +152,10 @@ def rope_qkv_phase2_helion(
209152
k_sc = k_scale[tile_b.begin, tile_h.begin]
210153
v_sc = v_scale[tile_b.begin, tile_h.begin]
211154

212-
# Load Q, K from [B, H, S, D] intermediate buffers
213-
q_val = q_rope[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
214-
k_val = k_rope[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
215-
216-
# Load V from [B, S, H, D] input layout
217-
v_val = v[tile_b.begin, tile_s, tile_h.begin, tile_d].to(torch.float32)
155+
q_val = q[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
156+
k_val = k[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
157+
v_val = v[tile_b.begin, tile_h.begin, tile_s, tile_d].to(torch.float32)
218158

219-
# Quantize to FP8
220159
q_fp8 = (q_val * q_sc).to(torch.float8_e4m3fn)
221160
k_fp8 = (k_val * k_sc).to(torch.float8_e4m3fn)
222161
v_fp8 = (v_val * v_sc).to(torch.float8_e4m3fn)
@@ -227,16 +166,14 @@ def rope_qkv_phase2_helion(
227166

228167

229168
# =============================================================================
230-
# Main entry point (same API as Triton version)
169+
# Main entry point
231170
# =============================================================================
232171

233172

234-
def fp8_rope_quantize_func(
173+
def helion_fp8_sdpa_quantize(
235174
q: torch.Tensor,
236175
k: torch.Tensor,
237176
v: torch.Tensor,
238-
cos: torch.Tensor,
239-
sin: torch.Tensor,
240177
num_chunks: Optional[int] = None, # Ignored - block sizes are autotuned
241178
) -> Tuple[
242179
torch.Tensor,
@@ -247,64 +184,39 @@ def fp8_rope_quantize_func(
247184
torch.Tensor,
248185
]:
249186
"""
250-
Fused RoPE + FP8 quantization for Q, K, V tensors.
251-
252-
Applies RoPE to Q and K, then quantizes all tensors to FP8 with per-head scaling.
253-
Also performs layout transformation from [B, S, H, D] to [B, H, S, D].
254-
255-
The layout transformation is fused into the kernels themselves, avoiding
256-
the need for separate transpose+contiguous memory copies.
187+
Fused per-head FP8 quantization for Q, K, V using Helion kernels.
257188
258189
Uses 3-kernel structure with full parallelization:
259-
- Phase 1: RoPE + partial max (parallel over B * H * S_blocks)
190+
- Phase 1: Partial absmax (parallel over B * H * S_blocks)
260191
- Reduce: Aggregate maxes per head (parallel over B * H)
261192
- Phase 2: Quantize (parallel over B * H * S_blocks)
262193
263194
Note: The num_chunks parameter is ignored. Block sizes are autotuned by Helion.
264195
265196
Args:
266-
q: Query tensor of shape [B, S, H, D] in bf16/fp16
267-
k: Key tensor of shape [B, S, H, D] in bf16/fp16
268-
v: Value tensor of shape [B, S, H, D] in bf16/fp16
269-
cos: Cosine frequencies for RoPE, shape [S, D]
270-
sin: Sine frequencies for RoPE, shape [S, D]
197+
q: Query tensor of shape [B, H, S, D] in bf16/fp16
198+
k: Key tensor of shape [B, H, S, D] in bf16/fp16
199+
v: Value tensor of shape [B, H, S, D] in bf16/fp16
271200
num_chunks: Ignored (kept for API compatibility)
272201
273202
Returns:
274-
q_fp8: Quantized query with RoPE, shape [B, H, S, D] in fp8
275-
k_fp8: Quantized key with RoPE, shape [B, H, S, D] in fp8
276-
v_fp8: Quantized value, shape [B, H, S, D] in fp8
277-
q_descale: Query descale factors, shape [B, H] in fp32
278-
k_descale: Key descale factors, shape [B, H] in fp32
279-
v_descale: Value descale factors, shape [B, H] in fp32
203+
q_fp8, k_fp8, v_fp8: Quantized tensors in float8_e4m3fn, shape [B, H, S, D]
204+
q_descale, k_descale, v_descale: Descale factors of shape [B, H] in fp32
280205
"""
281-
assert q.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {q.dim()}D"
282-
assert k.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {k.dim()}D"
283-
assert v.dim() == 4, f"Expected 4D tensor [B, S, H, D], got {v.dim()}D"
206+
assert q.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {q.dim()}D"
207+
assert k.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {k.dim()}D"
208+
assert v.dim() == 4, f"Expected 4D tensor [B, H, S, D], got {v.dim()}D"
284209
assert q.shape == k.shape == v.shape, "Q, K, V must have the same shape"
285-
assert cos.dim() == 2, f"Expected 2D cos tensor [S, D], got {cos.dim()}D"
286-
assert sin.dim() == 2, f"Expected 2D sin tensor [S, D], got {sin.dim()}D"
287210

288-
B, S, H, D = q.shape
211+
B, H, S, D = q.shape
289212

290-
assert D % 2 == 0, f"Head dimension D must be even for RoPE, got D={D}"
291-
assert cos.shape == (S, D), f"Expected cos shape [{S}, {D}], got {cos.shape}"
292-
assert sin.shape == (S, D), f"Expected sin shape [{S}, {D}], got {sin.shape}"
293-
294-
# Ensure inputs are contiguous
295213
q = q.contiguous()
296214
k = k.contiguous()
297215
v = v.contiguous()
298-
cos = cos.contiguous()
299-
sin = sin.contiguous()
300216

301217
# Upper bound for S blocks (block_size_s is autotuned)
302218
max_s_blocks = S
303219

304-
# Allocate intermediate buffers
305-
q_rope_intermediate = torch.empty(B, H, S, D, dtype=q.dtype, device=q.device)
306-
k_rope_intermediate = torch.empty(B, H, S, D, dtype=k.dtype, device=q.device)
307-
308220
partial_max = torch.zeros(
309221
B, H, max_s_blocks, 3, dtype=torch.float32, device=q.device
310222
)
@@ -316,20 +228,11 @@ def fp8_rope_quantize_func(
316228
k_descale = torch.empty(B, H, dtype=torch.float32, device=q.device)
317229
v_descale = torch.empty(B, H, dtype=torch.float32, device=q.device)
318230

319-
# Phase 1: RoPE + partial max
320-
rope_qkv_phase1_helion(
321-
q,
322-
k,
323-
v,
324-
cos,
325-
sin,
326-
q_rope_intermediate,
327-
k_rope_intermediate,
328-
partial_max,
329-
)
231+
# Phase 1: partial absmax
232+
qkv_phase1_helion(q, k, v, partial_max)
330233

331234
# Reduce: aggregate maxes per head
332-
rope_qkv_reduce_helion(
235+
qkv_reduce_helion(
333236
partial_max,
334237
q_scale,
335238
k_scale,
@@ -345,16 +248,10 @@ def fp8_rope_quantize_func(
345248
v_fp8 = torch.empty(B, H, S, D, dtype=torch.float8_e4m3fn, device=q.device)
346249

347250
# Phase 2: quantize
348-
rope_qkv_phase2_helion(
349-
q_rope_intermediate,
350-
k_rope_intermediate,
351-
v,
352-
q_fp8,
353-
k_fp8,
354-
v_fp8,
355-
q_scale,
356-
k_scale,
357-
v_scale,
251+
qkv_phase2_helion(
252+
q, k, v,
253+
q_fp8, k_fp8, v_fp8,
254+
q_scale, k_scale, v_scale,
358255
)
359256

360257
return q_fp8, k_fp8, v_fp8, q_descale, k_descale, v_descale

torchao/prototype/attention/fp8_fa3/quantization.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,17 @@ def _fp8_sdpa_quantize(
7070
if q.shape[3] != k.shape[3]:
7171
raise ValueError(f"Head dim mismatch: {q.shape[3]} vs {k.shape[3]}")
7272

73-
if torch.compiler.is_compiling():
73+
if False:
7474
# Under torch.compile, use the PyTorch primitives path which the
7575
# compiler can trace and optimize.
7676
q_fp8, q_descale = _quantize_per_head(q)
7777
k_fp8, k_descale = _quantize_per_head(k)
7878
v_fp8, v_descale = _quantize_per_head(v)
7979
return q_fp8, k_fp8, v_fp8, q_descale, k_descale, v_descale
8080
else:
81-
# In eager mode, use fused Triton kernels for better performance.
82-
from torchao.prototype.attention.fp8_fa3.triton_qkv_quantization import (
83-
triton_fp8_sdpa_quantize,
81+
# In eager mode, use fused Helion kernels for better performance.
82+
from torchao.prototype.attention.fp8_fa3.helion_qkv_quantization import (
83+
helion_fp8_sdpa_quantize,
8484
)
8585

86-
return triton_fp8_sdpa_quantize(q, k, v)
86+
return helion_fp8_sdpa_quantize(q, k, v)

0 commit comments

Comments
 (0)