55# LICENSE file in the root directory of this source tree.
66
77"""
8- Fused RoPE + FP8 Quantization kernels using Helion (Optimized version) .
8+ FP8 Quantization kernels using Helion.
99
10- This module provides Helion kernels that fuse:
11- - RoPE (Rotary Position Embeddings) for Q and K
12- - FP8 quantization for Q, K, V
13- - Layout transformation from [B, S, H, D] (FLUX) to [B, H, S, D] (SDPA)
14-
15- The layout transformation is fused directly into the kernels, avoiding
16- separate transpose+contiguous memory copies that would otherwise add ~70%
17- overhead to the overall kernel execution time.
10+ This module provides Helion kernels that perform per-head FP8 quantization
11+ for Q, K, V tensors.
1812
1913The 3-kernel structure parallelizes over (B, H, S) with nested D loop:
20- - Phase 1: RoPE + partial max (reads [B,S,H,D], writes [B,H,S,D])
21- - Reduce: Aggregate maxes per head (parallel over B * H)
22- - Phase 2: Quantize (reads V from [B,S,H,D], writes [B,H,S,D])
23-
24- Input format: [B, S, H, D] (FLUX-style)
25- Output format: [B, H, S, D] (SDPA-style)
14+ - Phase 1: Compute partial absmax values per S-block
15+ - Reduce: Aggregate maxes per head and compute scale/descale factors
16+ - Phase 2: Apply scales and cast to FP8
2617
18+ Input/output format: [B, H, S, D]
2719"""
2820
2921from typing import Optional , Tuple
3527
3628
3729# =============================================================================
38- # Phase 1: RoPE + Max computation
39- # Reads from [B, S, H, D] (FLUX layout), writes to [B, H, S, D] (SDPA layout)
40- # Fuses layout transformation with RoPE computation to avoid separate copy
30+ # Phase 1: Partial absmax computation
4131# =============================================================================
4232
4333
4434@helion .kernel (static_shapes = True )
45- def rope_qkv_phase1_helion (
46- q : torch .Tensor , # [B, S, H, D] - FLUX input layout
47- k : torch .Tensor , # [B, S, H, D] - FLUX input layout
48- v : torch .Tensor , # [B, S, H, D] - FLUX input layout
49- cos : torch .Tensor , # [S, D]
50- sin : torch .Tensor , # [S, D]
51- q_rope_out : torch .Tensor , # [B, H, S, D] - output in SDPA layout
52- k_rope_out : torch .Tensor , # [B, H, S, D] - output in SDPA layout
35+ def qkv_phase1_helion (
36+ q : torch .Tensor , # [B, H, S, D]
37+ k : torch .Tensor , # [B, H, S, D]
38+ v : torch .Tensor , # [B, H, S, D]
5339 partial_max : torch .Tensor , # [B, H, num_s_blocks, 3] - output
5440) -> None :
5541 """
56- Phase 1: Apply RoPE to Q and K, store results, compute partial max.
57-
58- Reads from [B, S, H, D] (FLUX layout) and writes to [B, H, S, D] (SDPA layout),
59- fusing the layout transformation with the RoPE computation.
42+ Phase 1: Compute partial absmax for Q, K, V per S-block.
6043
6144 Uses 3D tiling over (B, H, S) with block_size=[1, 1, block_s], plus
6245 a nested inner loop over D with block_size=D (single iteration).
6346 """
64- B , S , H , D = q .size ()
65- D_HALF = hl .specialize (D // 2 )
47+ B , H , S , D = q .size ()
6648
6749 block_size_s = hl .register_block_size (S )
6850
6951 for tile_b , tile_h , tile_s in hl .tile ([B , H , S ], block_size = [1 , 1 , block_size_s ]):
7052 for tile_d in hl .tile (D , block_size = D ):
71- # Load from [B, S, H, D] input layout
72- q_tile = q [tile_b .begin , tile_s , tile_h .begin , tile_d ].to (torch .float32 )
73- k_tile = k [tile_b .begin , tile_s , tile_h .begin , tile_d ].to (torch .float32 )
74- v_tile = v [tile_b .begin , tile_s , tile_h .begin , tile_d ].to (torch .float32 )
75-
76- cos_tile = cos [tile_s , tile_d ].to (torch .float32 )
77- sin_tile = sin [tile_s , tile_d ].to (torch .float32 )
78-
79- # Split into real/imag components
80- q_tile_ri = q_tile .reshape (- 1 , D_HALF , 2 )
81- k_tile_ri = k_tile .reshape (- 1 , D_HALF , 2 )
82- q_real , q_imag = hl .split (q_tile_ri )
83- k_real , k_imag = hl .split (k_tile_ri )
84-
85- cos_tile_ri = cos_tile .reshape (- 1 , D_HALF , 2 )
86- sin_tile_ri = sin_tile .reshape (- 1 , D_HALF , 2 )
87- cos_real , cos_imag = hl .split (cos_tile_ri )
88- sin_real , sin_imag = hl .split (sin_tile_ri )
89-
90- # Apply RoPE
91- q_rope_real = q_real * cos_real - q_imag * sin_real
92- q_rope_imag = q_real * sin_imag + q_imag * cos_imag
93- k_rope_real = k_real * cos_real - k_imag * sin_real
94- k_rope_imag = k_real * sin_imag + k_imag * cos_imag
95-
96- q_rope = hl .join (q_rope_real , q_rope_imag ).reshape (- 1 , D )
97- k_rope = hl .join (k_rope_real , k_rope_imag ).reshape (- 1 , D )
98-
99- # Store RoPE'd Q, K to [B, H, S, D] output layout
100- q_rope_out [tile_b .begin , tile_h .begin , tile_s , tile_d ] = q_rope .to (q .dtype )
101- k_rope_out [tile_b .begin , tile_h .begin , tile_s , tile_d ] = k_rope .to (k .dtype )
102-
103- # Compute partial max for this block
104- q_max_tile = torch .amax (torch .abs (q_rope ), dim = - 1 )
105- k_max_tile = torch .amax (torch .abs (k_rope ), dim = - 1 )
53+ q_tile = q [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
54+ k_tile = k [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
55+ v_tile = v [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
56+
57+ q_max_tile = torch .amax (torch .abs (q_tile ), dim = - 1 )
58+ k_max_tile = torch .amax (torch .abs (k_tile ), dim = - 1 )
10659 v_max_tile = torch .amax (torch .abs (v_tile ), dim = - 1 )
10760
10861 q_max = torch .amax (q_max_tile )
@@ -116,12 +69,11 @@ def rope_qkv_phase1_helion(
11669
11770# =============================================================================
11871# Reduce kernel: Aggregate partial maxes and compute scales
119- # Parallelizes over (B, H) using hl.tile([...]) with block_size=[1, 1]
12072# =============================================================================
12173
12274
12375@helion .kernel (static_shapes = True )
124- def rope_qkv_reduce_helion (
76+ def qkv_reduce_helion (
12577 partial_max : torch .Tensor , # [B, H, num_s_blocks, 3]
12678 q_scale : torch .Tensor , # [B, H] - output
12779 k_scale : torch .Tensor , # [B, H] - output
@@ -134,13 +86,10 @@ def rope_qkv_reduce_helion(
13486 Reduce partial maxes across S-blocks and compute scales/descales.
13587
13688 Uses 2D tiling over (B, H) with block_size=[1, 1].
137- - tile_b.begin, tile_h.begin are scalar indices
138- - Sequential reduction over S blocks using vectorized access
13989 """
14090 FP8_MAX : float = 448.0
14191 eps : float = 1e-12
14292 B , H = q_scale .size ()
143- num_s_blocks = partial_max .size (2 )
14493
14594 for tile_b , tile_h in hl .tile ([B , H ], block_size = [1 , 1 ]):
14695 q_partial = partial_max [tile_b .begin , tile_h .begin , :, 0 ]
@@ -171,17 +120,15 @@ def rope_qkv_reduce_helion(
171120
172121
173122# =============================================================================
174- # Phase 2: Quantize
175- # Reads V from [B, S, H, D] (FLUX layout), writes to [B, H, S, D] (SDPA layout)
176- # Fuses layout transformation with quantization to avoid separate copy
123+ # Phase 2: Quantize to FP8
177124# =============================================================================
178125
179126
180127@helion .kernel (static_shapes = True )
181- def rope_qkv_phase2_helion (
182- q_rope : torch .Tensor , # [B, H, S, D] - intermediate RoPE'd Q (SDPA layout)
183- k_rope : torch .Tensor , # [B, H, S, D] - intermediate RoPE'd K (SDPA layout)
184- v : torch .Tensor , # [B, S, H , D] - original V (FLUX layout)
128+ def qkv_phase2_helion (
129+ q : torch .Tensor , # [B, H, S, D]
130+ k : torch .Tensor , # [B, H, S, D]
131+ v : torch .Tensor , # [B, H, S , D]
185132 q_out : torch .Tensor , # [B, H, S, D] - FP8 output
186133 k_out : torch .Tensor , # [B, H, S, D] - FP8 output
187134 v_out : torch .Tensor , # [B, H, S, D] - FP8 output
@@ -190,16 +137,12 @@ def rope_qkv_phase2_helion(
190137 v_scale : torch .Tensor , # [B, H]
191138) -> None :
192139 """
193- Phase 2: Quantize pre-computed RoPE'd Q, K and V.
194-
195- Q and K are read from intermediate buffers in [B, H, S, D] (SDPA layout).
196- V is read from original input in [B, S, H, D] (FLUX layout).
197- All outputs are written in [B, H, S, D] (SDPA layout).
140+ Phase 2: Quantize Q, K, V to FP8 using precomputed scales.
198141
199142 Uses 3D tiling over (B, H, S) with block_size=[1, 1, block_s], plus
200143 a nested inner loop over D with block_size=D (single iteration).
201144 """
202- B , H , S , D = q_rope .size ()
145+ B , H , S , D = q .size ()
203146
204147 block_size_s = hl .register_block_size (S )
205148
@@ -209,14 +152,10 @@ def rope_qkv_phase2_helion(
209152 k_sc = k_scale [tile_b .begin , tile_h .begin ]
210153 v_sc = v_scale [tile_b .begin , tile_h .begin ]
211154
212- # Load Q, K from [B, H, S, D] intermediate buffers
213- q_val = q_rope [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
214- k_val = k_rope [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
215-
216- # Load V from [B, S, H, D] input layout
217- v_val = v [tile_b .begin , tile_s , tile_h .begin , tile_d ].to (torch .float32 )
155+ q_val = q [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
156+ k_val = k [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
157+ v_val = v [tile_b .begin , tile_h .begin , tile_s , tile_d ].to (torch .float32 )
218158
219- # Quantize to FP8
220159 q_fp8 = (q_val * q_sc ).to (torch .float8_e4m3fn )
221160 k_fp8 = (k_val * k_sc ).to (torch .float8_e4m3fn )
222161 v_fp8 = (v_val * v_sc ).to (torch .float8_e4m3fn )
@@ -227,16 +166,14 @@ def rope_qkv_phase2_helion(
227166
228167
229168# =============================================================================
230- # Main entry point (same API as Triton version)
169+ # Main entry point
231170# =============================================================================
232171
233172
234- def fp8_rope_quantize_func (
173+ def helion_fp8_sdpa_quantize (
235174 q : torch .Tensor ,
236175 k : torch .Tensor ,
237176 v : torch .Tensor ,
238- cos : torch .Tensor ,
239- sin : torch .Tensor ,
240177 num_chunks : Optional [int ] = None , # Ignored - block sizes are autotuned
241178) -> Tuple [
242179 torch .Tensor ,
@@ -247,64 +184,39 @@ def fp8_rope_quantize_func(
247184 torch .Tensor ,
248185]:
249186 """
250- Fused RoPE + FP8 quantization for Q, K, V tensors.
251-
252- Applies RoPE to Q and K, then quantizes all tensors to FP8 with per-head scaling.
253- Also performs layout transformation from [B, S, H, D] to [B, H, S, D].
254-
255- The layout transformation is fused into the kernels themselves, avoiding
256- the need for separate transpose+contiguous memory copies.
187+ Fused per-head FP8 quantization for Q, K, V using Helion kernels.
257188
258189 Uses 3-kernel structure with full parallelization:
259- - Phase 1: RoPE + partial max (parallel over B * H * S_blocks)
190+ - Phase 1: Partial absmax (parallel over B * H * S_blocks)
260191 - Reduce: Aggregate maxes per head (parallel over B * H)
261192 - Phase 2: Quantize (parallel over B * H * S_blocks)
262193
263194 Note: The num_chunks parameter is ignored. Block sizes are autotuned by Helion.
264195
265196 Args:
266- q: Query tensor of shape [B, S, H, D] in bf16/fp16
267- k: Key tensor of shape [B, S, H, D] in bf16/fp16
268- v: Value tensor of shape [B, S, H, D] in bf16/fp16
269- cos: Cosine frequencies for RoPE, shape [S, D]
270- sin: Sine frequencies for RoPE, shape [S, D]
197+ q: Query tensor of shape [B, H, S, D] in bf16/fp16
198+ k: Key tensor of shape [B, H, S, D] in bf16/fp16
199+ v: Value tensor of shape [B, H, S, D] in bf16/fp16
271200 num_chunks: Ignored (kept for API compatibility)
272201
273202 Returns:
274- q_fp8: Quantized query with RoPE, shape [B, H, S, D] in fp8
275- k_fp8: Quantized key with RoPE, shape [B, H, S, D] in fp8
276- v_fp8: Quantized value, shape [B, H, S, D] in fp8
277- q_descale: Query descale factors, shape [B, H] in fp32
278- k_descale: Key descale factors, shape [B, H] in fp32
279- v_descale: Value descale factors, shape [B, H] in fp32
203+ q_fp8, k_fp8, v_fp8: Quantized tensors in float8_e4m3fn, shape [B, H, S, D]
204+ q_descale, k_descale, v_descale: Descale factors of shape [B, H] in fp32
280205 """
281- assert q .dim () == 4 , f"Expected 4D tensor [B, S, H , D], got { q .dim ()} D"
282- assert k .dim () == 4 , f"Expected 4D tensor [B, S, H , D], got { k .dim ()} D"
283- assert v .dim () == 4 , f"Expected 4D tensor [B, S, H , D], got { v .dim ()} D"
206+ assert q .dim () == 4 , f"Expected 4D tensor [B, H, S , D], got { q .dim ()} D"
207+ assert k .dim () == 4 , f"Expected 4D tensor [B, H, S , D], got { k .dim ()} D"
208+ assert v .dim () == 4 , f"Expected 4D tensor [B, H, S , D], got { v .dim ()} D"
284209 assert q .shape == k .shape == v .shape , "Q, K, V must have the same shape"
285- assert cos .dim () == 2 , f"Expected 2D cos tensor [S, D], got { cos .dim ()} D"
286- assert sin .dim () == 2 , f"Expected 2D sin tensor [S, D], got { sin .dim ()} D"
287210
288- B , S , H , D = q .shape
211+ B , H , S , D = q .shape
289212
290- assert D % 2 == 0 , f"Head dimension D must be even for RoPE, got D={ D } "
291- assert cos .shape == (S , D ), f"Expected cos shape [{ S } , { D } ], got { cos .shape } "
292- assert sin .shape == (S , D ), f"Expected sin shape [{ S } , { D } ], got { sin .shape } "
293-
294- # Ensure inputs are contiguous
295213 q = q .contiguous ()
296214 k = k .contiguous ()
297215 v = v .contiguous ()
298- cos = cos .contiguous ()
299- sin = sin .contiguous ()
300216
301217 # Upper bound for S blocks (block_size_s is autotuned)
302218 max_s_blocks = S
303219
304- # Allocate intermediate buffers
305- q_rope_intermediate = torch .empty (B , H , S , D , dtype = q .dtype , device = q .device )
306- k_rope_intermediate = torch .empty (B , H , S , D , dtype = k .dtype , device = q .device )
307-
308220 partial_max = torch .zeros (
309221 B , H , max_s_blocks , 3 , dtype = torch .float32 , device = q .device
310222 )
@@ -316,20 +228,11 @@ def fp8_rope_quantize_func(
316228 k_descale = torch .empty (B , H , dtype = torch .float32 , device = q .device )
317229 v_descale = torch .empty (B , H , dtype = torch .float32 , device = q .device )
318230
319- # Phase 1: RoPE + partial max
320- rope_qkv_phase1_helion (
321- q ,
322- k ,
323- v ,
324- cos ,
325- sin ,
326- q_rope_intermediate ,
327- k_rope_intermediate ,
328- partial_max ,
329- )
231+ # Phase 1: partial absmax
232+ qkv_phase1_helion (q , k , v , partial_max )
330233
331234 # Reduce: aggregate maxes per head
332- rope_qkv_reduce_helion (
235+ qkv_reduce_helion (
333236 partial_max ,
334237 q_scale ,
335238 k_scale ,
@@ -345,16 +248,10 @@ def fp8_rope_quantize_func(
345248 v_fp8 = torch .empty (B , H , S , D , dtype = torch .float8_e4m3fn , device = q .device )
346249
347250 # Phase 2: quantize
348- rope_qkv_phase2_helion (
349- q_rope_intermediate ,
350- k_rope_intermediate ,
351- v ,
352- q_fp8 ,
353- k_fp8 ,
354- v_fp8 ,
355- q_scale ,
356- k_scale ,
357- v_scale ,
251+ qkv_phase2_helion (
252+ q , k , v ,
253+ q_fp8 , k_fp8 , v_fp8 ,
254+ q_scale , k_scale , v_scale ,
358255 )
359256
360257 return q_fp8 , k_fp8 , v_fp8 , q_descale , k_descale , v_descale
0 commit comments