3333- Async H memory loading with aggressive pipelining
3434- BF16 tensors with FP32 compute for numerical stability
3535- GQA (grouped-query attention) support with configurable H (query) and HV (value) heads
36+ - Uses scalar FP32 FMA operations for SM90+ (Hopper, Blackwell) compatibility
37+ - Can be optimized with packed F32x2 FMA for SM100+ in future releases
3638"""
3739
3840import math
@@ -123,6 +125,37 @@ def load_h_chunk_async(h_sh_chunk, h_global, tidx, row_offset):
123125 cute .copy (atom_async_copy , tS , tD )
124126
125127
128+ # ==============================================================================
129+ # FMA WRAPPER FUNCTIONS (SM90 Compatibility)
130+ # ==============================================================================
131+ # Note: cute.arch.fma_packed_f32x2() generates F32x2 intrinsics that are NOT
132+ # supported on SM90 (Hopper). These wrappers use scalar FMA operations that
133+ # work on all SM90+ architectures. Future optimization: add architecture-
134+ # specific variants for SM100+ (Blackwell) using packed intrinsics.
135+
136+
137+ @cute .jit
138+ def fma_pair_mul (a1 , a2 , b1 , b2 ):
139+ """Multiply two pairs: (a1, a2) * (b1, b2).
140+
141+ Equivalent to fma_packed_f32x2 with c=(0,0), but compatible with SM90+.
142+ """
143+ result1 = a1 * b1
144+ result2 = a2 * b2
145+ return result1 , result2
146+
147+
148+ @cute .jit
149+ def fma_pair (a1 , a2 , b1 , b2 , c1 , c2 ):
150+ """FMA two pairs: (a1, a2) * (b1, b2) + (c1, c2).
151+
152+ Equivalent to fma_packed_f32x2, but compatible with SM90+.
153+ """
154+ result1 = a1 * b1 + c1
155+ result2 = a2 * b2 + c2
156+ return result1 , result2
157+
158+
126159@cute .jit
127160def compute_single_gate (
128161 alpha , beta_raw , dt_bias_val , A_log_val , softplus_beta , softplus_threshold
@@ -161,15 +194,11 @@ def normalize_and_store_qk_to_smem(q_head, k_head, q_sh, k_sh, lane_idx, scale,
161194 k_sum_sq2 = cutlass .Float32 (0.0 )
162195
163196 for i in cutlass .range_constexpr (0 , 4 , 2 ):
164- q_sum_sq , q_sum_sq2 = cute .arch .fma_packed_f32x2 (
165- src_a = (q_reg [i ], q_reg [i + 1 ]),
166- src_b = (q_reg [i ], q_reg [i + 1 ]),
167- src_c = (q_sum_sq , q_sum_sq2 ),
197+ q_sum_sq , q_sum_sq2 = fma_pair (
198+ q_reg [i ], q_reg [i + 1 ], q_reg [i ], q_reg [i + 1 ], q_sum_sq , q_sum_sq2
168199 )
169- k_sum_sq , k_sum_sq2 = cute .arch .fma_packed_f32x2 (
170- src_a = (k_reg [i ], k_reg [i + 1 ]),
171- src_b = (k_reg [i ], k_reg [i + 1 ]),
172- src_c = (k_sum_sq , k_sum_sq2 ),
200+ k_sum_sq , k_sum_sq2 = fma_pair (
201+ k_reg [i ], k_reg [i + 1 ], k_reg [i ], k_reg [i + 1 ], k_sum_sq , k_sum_sq2
173202 )
174203
175204 q_sum_sq = q_sum_sq + q_sum_sq2
@@ -214,20 +243,16 @@ def decay_h_from_smem_and_compute_pred(
214243 pred2 = cutlass .Float32 (0.0 )
215244
216245 for i in cutlass .range_constexpr (0 , 32 , 2 ):
217- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
218- src_a = (
219- h_sh_chunk [lane_idx , k_base + i ].to (cutlass .Float32 ),
220- h_sh_chunk [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
221- ),
222- src_b = (g_exp , g_exp ),
223- src_c = (cutlass .Float32 (0.0 ), cutlass .Float32 (0.0 )),
246+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair_mul (
247+ h_sh_chunk [lane_idx , k_base + i ].to (cutlass .Float32 ),
248+ h_sh_chunk [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
249+ g_exp ,
250+ g_exp ,
224251 )
225252
226253 for i in cutlass .range_constexpr (0 , 32 , 2 ):
227- pred , pred2 = cute .arch .fma_packed_f32x2 (
228- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
229- src_b = (kq_chunk [i ], kq_chunk [i + 1 ]),
230- src_c = (pred , pred2 ),
254+ pred , pred2 = fma_pair (
255+ h_chunk [i ], h_chunk [i + 1 ], kq_chunk [i ], kq_chunk [i + 1 ], pred , pred2
231256 )
232257
233258 pred = pred + pred2
@@ -238,10 +263,8 @@ def decay_h_from_smem_and_compute_pred(
238263def update_h_with_delta (h_chunk , kq_chunk , v_delta ):
239264 """Update H with delta: h = h + k * v_delta."""
240265 for i in cutlass .range_constexpr (0 , 32 , 2 ):
241- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
242- src_a = (kq_chunk [i ], kq_chunk [i + 1 ]),
243- src_b = (v_delta , v_delta ),
244- src_c = (h_chunk [i ], h_chunk [i + 1 ]),
266+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair (
267+ kq_chunk [i ], kq_chunk [i + 1 ], v_delta , v_delta , h_chunk [i ], h_chunk [i + 1 ]
245268 )
246269
247270
@@ -251,10 +274,8 @@ def compute_output(h_chunk, kq_chunk):
251274 out = cutlass .Float32 (0.0 )
252275 out2 = cutlass .Float32 (0.0 )
253276 for i in cutlass .range_constexpr (0 , 32 , 2 ):
254- out , out2 = cute .arch .fma_packed_f32x2 (
255- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
256- src_b = (kq_chunk [i ], kq_chunk [i + 1 ]),
257- src_c = (out , out2 ),
277+ out , out2 = fma_pair (
278+ h_chunk [i ], h_chunk [i + 1 ], kq_chunk [i ], kq_chunk [i + 1 ], out , out2
258279 )
259280 out = out + out2
260281 return out
@@ -264,10 +285,8 @@ def compute_output(h_chunk, kq_chunk):
264285def decay_h_in_place (h_chunk , g_exp ):
265286 """Apply decay to H in place: h = h * g_exp."""
266287 for i in cutlass .range_constexpr (0 , 32 , 2 ):
267- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
268- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
269- src_b = (g_exp , g_exp ),
270- src_c = (cutlass .Float32 (0.0 ), cutlass .Float32 (0.0 )),
288+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair_mul (
289+ h_chunk [i ], h_chunk [i + 1 ], g_exp , g_exp
271290 )
272291
273292
@@ -817,19 +836,15 @@ def gated_delta_rule_decode_kernel_seqlen1(
817836 pred = cutlass .Float32 (0.0 )
818837 pred2 = cutlass .Float32 (0.0 )
819838 for i in cutlass .range_constexpr (0 , 32 , 2 ):
820- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
821- src_a = (
822- h_sh_chunk0 [lane_idx , k_base + i ].to (cutlass .Float32 ),
823- h_sh_chunk0 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
824- ),
825- src_b = (g_exp , g_exp ),
826- src_c = (cutlass .Float32 (0.0 ), cutlass .Float32 (0.0 )),
839+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair_mul (
840+ h_sh_chunk0 [lane_idx , k_base + i ].to (cutlass .Float32 ),
841+ h_sh_chunk0 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
842+ g_exp ,
843+ g_exp ,
827844 )
828845 for i in cutlass .range_constexpr (0 , 32 , 2 ):
829- pred , pred2 = cute .arch .fma_packed_f32x2 (
830- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
831- src_b = (k_chunk [i ], k_chunk [i + 1 ]),
832- src_c = (pred , pred2 ),
846+ pred , pred2 = fma_pair (
847+ h_chunk [i ], h_chunk [i + 1 ], k_chunk [i ], k_chunk [i + 1 ], pred , pred2
833848 )
834849 pred = pred + pred2
835850
@@ -845,10 +860,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
845860 v_val = (v_sh [lane_idx ] - pred_final ) * beta
846861
847862 for i in cutlass .range_constexpr (0 , 32 , 2 ):
848- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
849- src_a = (k_chunk [i ], k_chunk [i + 1 ]),
850- src_b = (v_val , v_val ),
851- src_c = (h_chunk [i ], h_chunk [i + 1 ]),
863+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair (
864+ k_chunk [i ], k_chunk [i + 1 ], v_val , v_val , h_chunk [i ], h_chunk [i + 1 ]
852865 )
853866
854867 # Load Q for output computation
@@ -858,10 +871,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
858871 out = cutlass .Float32 (0.0 )
859872 out2 = cutlass .Float32 (0.0 )
860873 for i in cutlass .range_constexpr (0 , 32 , 2 ):
861- out , out2 = cute .arch .fma_packed_f32x2 (
862- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
863- src_b = (qk_temp [i ], qk_temp [i + 1 ]),
864- src_c = (out , out2 ),
874+ out , out2 = fma_pair (
875+ h_chunk [i ], h_chunk [i + 1 ], qk_temp [i ], qk_temp [i + 1 ], out , out2
865876 )
866877 out = out + out2
867878
@@ -894,19 +905,15 @@ def gated_delta_rule_decode_kernel_seqlen1(
894905 pred = cutlass .Float32 (0.0 )
895906 pred2 = cutlass .Float32 (0.0 )
896907 for i in cutlass .range_constexpr (0 , 32 , 2 ):
897- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
898- src_a = (
899- h_sh_chunk1 [lane_idx , k_base + i ].to (cutlass .Float32 ),
900- h_sh_chunk1 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
901- ),
902- src_b = (g_exp , g_exp ),
903- src_c = (cutlass .Float32 (0.0 ), cutlass .Float32 (0.0 )),
908+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair_mul (
909+ h_sh_chunk1 [lane_idx , k_base + i ].to (cutlass .Float32 ),
910+ h_sh_chunk1 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
911+ g_exp ,
912+ g_exp ,
904913 )
905914 for i in cutlass .range_constexpr (0 , 32 , 2 ):
906- pred , pred2 = cute .arch .fma_packed_f32x2 (
907- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
908- src_b = (k_chunk [i ], k_chunk [i + 1 ]),
909- src_c = (pred , pred2 ),
915+ pred , pred2 = fma_pair (
916+ h_chunk [i ], h_chunk [i + 1 ], k_chunk [i ], k_chunk [i + 1 ], pred , pred2
910917 )
911918 pred = pred + pred2
912919
@@ -922,10 +929,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
922929 v_val = (v_sh [32 + lane_idx ] - pred_final ) * beta
923930
924931 for i in cutlass .range_constexpr (0 , 32 , 2 ):
925- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
926- src_a = (k_chunk [i ], k_chunk [i + 1 ]),
927- src_b = (v_val , v_val ),
928- src_c = (h_chunk [i ], h_chunk [i + 1 ]),
932+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair (
933+ k_chunk [i ], k_chunk [i + 1 ], v_val , v_val , h_chunk [i ], h_chunk [i + 1 ]
929934 )
930935
931936 for i in cutlass .range_constexpr (32 ):
@@ -934,10 +939,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
934939 out = cutlass .Float32 (0.0 )
935940 out2 = cutlass .Float32 (0.0 )
936941 for i in cutlass .range_constexpr (0 , 32 , 2 ):
937- out , out2 = cute .arch .fma_packed_f32x2 (
938- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
939- src_b = (qk_temp [i ], qk_temp [i + 1 ]),
940- src_c = (out , out2 ),
942+ out , out2 = fma_pair (
943+ h_chunk [i ], h_chunk [i + 1 ], qk_temp [i ], qk_temp [i + 1 ], out , out2
941944 )
942945 out = out + out2
943946
@@ -965,19 +968,15 @@ def gated_delta_rule_decode_kernel_seqlen1(
965968 pred = cutlass .Float32 (0.0 )
966969 pred2 = cutlass .Float32 (0.0 )
967970 for i in cutlass .range_constexpr (0 , 32 , 2 ):
968- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
969- src_a = (
970- h_sh_chunk2 [lane_idx , k_base + i ].to (cutlass .Float32 ),
971- h_sh_chunk2 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
972- ),
973- src_b = (g_exp , g_exp ),
974- src_c = (cutlass .Float32 (0.0 ), cutlass .Float32 (0.0 )),
971+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair_mul (
972+ h_sh_chunk2 [lane_idx , k_base + i ].to (cutlass .Float32 ),
973+ h_sh_chunk2 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
974+ g_exp ,
975+ g_exp ,
975976 )
976977 for i in cutlass .range_constexpr (0 , 32 , 2 ):
977- pred , pred2 = cute .arch .fma_packed_f32x2 (
978- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
979- src_b = (k_chunk [i ], k_chunk [i + 1 ]),
980- src_c = (pred , pred2 ),
978+ pred , pred2 = fma_pair (
979+ h_chunk [i ], h_chunk [i + 1 ], k_chunk [i ], k_chunk [i + 1 ], pred , pred2
981980 )
982981 pred = pred + pred2
983982
@@ -993,10 +992,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
993992 v_val = (v_sh [64 + lane_idx ] - pred_final ) * beta
994993
995994 for i in cutlass .range_constexpr (0 , 32 , 2 ):
996- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
997- src_a = (k_chunk [i ], k_chunk [i + 1 ]),
998- src_b = (v_val , v_val ),
999- src_c = (h_chunk [i ], h_chunk [i + 1 ]),
995+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair (
996+ k_chunk [i ], k_chunk [i + 1 ], v_val , v_val , h_chunk [i ], h_chunk [i + 1 ]
1000997 )
1001998
1002999 for i in cutlass .range_constexpr (32 ):
@@ -1005,10 +1002,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
10051002 out = cutlass .Float32 (0.0 )
10061003 out2 = cutlass .Float32 (0.0 )
10071004 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1008- out , out2 = cute .arch .fma_packed_f32x2 (
1009- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
1010- src_b = (qk_temp [i ], qk_temp [i + 1 ]),
1011- src_c = (out , out2 ),
1005+ out , out2 = fma_pair (
1006+ h_chunk [i ], h_chunk [i + 1 ], qk_temp [i ], qk_temp [i + 1 ], out , out2
10121007 )
10131008 out = out + out2
10141009
@@ -1036,19 +1031,15 @@ def gated_delta_rule_decode_kernel_seqlen1(
10361031 pred = cutlass .Float32 (0.0 )
10371032 pred2 = cutlass .Float32 (0.0 )
10381033 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1039- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
1040- src_a = (
1041- h_sh_chunk3 [lane_idx , k_base + i ].to (cutlass .Float32 ),
1042- h_sh_chunk3 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
1043- ),
1044- src_b = (g_exp , g_exp ),
1045- src_c = (cutlass .Float32 (0.0 ), cutlass .Float32 (0.0 )),
1034+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair_mul (
1035+ h_sh_chunk3 [lane_idx , k_base + i ].to (cutlass .Float32 ),
1036+ h_sh_chunk3 [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
1037+ g_exp ,
1038+ g_exp ,
10461039 )
10471040 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1048- pred , pred2 = cute .arch .fma_packed_f32x2 (
1049- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
1050- src_b = (k_chunk [i ], k_chunk [i + 1 ]),
1051- src_c = (pred , pred2 ),
1041+ pred , pred2 = fma_pair (
1042+ h_chunk [i ], h_chunk [i + 1 ], k_chunk [i ], k_chunk [i + 1 ], pred , pred2
10521043 )
10531044 pred = pred + pred2
10541045
@@ -1064,10 +1055,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
10641055 v_val = (v_sh [96 + lane_idx ] - pred_final ) * beta
10651056
10661057 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1067- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
1068- src_a = (k_chunk [i ], k_chunk [i + 1 ]),
1069- src_b = (v_val , v_val ),
1070- src_c = (h_chunk [i ], h_chunk [i + 1 ]),
1058+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair (
1059+ k_chunk [i ], k_chunk [i + 1 ], v_val , v_val , h_chunk [i ], h_chunk [i + 1 ]
10711060 )
10721061
10731062 for i in cutlass .range_constexpr (32 ):
@@ -1076,10 +1065,8 @@ def gated_delta_rule_decode_kernel_seqlen1(
10761065 out = cutlass .Float32 (0.0 )
10771066 out2 = cutlass .Float32 (0.0 )
10781067 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1079- out , out2 = cute .arch .fma_packed_f32x2 (
1080- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
1081- src_b = (qk_temp [i ], qk_temp [i + 1 ]),
1082- src_c = (out , out2 ),
1068+ out , out2 = fma_pair (
1069+ h_chunk [i ], h_chunk [i + 1 ], qk_temp [i ], qk_temp [i + 1 ], out , out2
10831070 )
10841071 out = out + out2
10851072
@@ -1632,19 +1619,15 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk(
16321619 pred = cutlass .Float32 (0.0 )
16331620 pred2 = cutlass .Float32 (0.0 )
16341621 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1635- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
1636- src_a = (
1637- h_sh_chunk [lane_idx , k_base + i ].to (cutlass .Float32 ),
1638- h_sh_chunk [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
1639- ),
1640- src_b = (g_exp , g_exp ),
1641- src_c = (cutlass .Float32 (0.0 ), cutlass .Float32 (0.0 )),
1622+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair_mul (
1623+ h_sh_chunk [lane_idx , k_base + i ].to (cutlass .Float32 ),
1624+ h_sh_chunk [lane_idx , k_base + i + 1 ].to (cutlass .Float32 ),
1625+ g_exp ,
1626+ g_exp ,
16421627 )
16431628 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1644- pred , pred2 = cute .arch .fma_packed_f32x2 (
1645- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
1646- src_b = (k_chunk [i ], k_chunk [i + 1 ]),
1647- src_c = (pred , pred2 ),
1629+ pred , pred2 = fma_pair (
1630+ h_chunk [i ], h_chunk [i + 1 ], k_chunk [i ], k_chunk [i + 1 ], pred , pred2
16481631 )
16491632 pred = pred + pred2
16501633
@@ -1660,10 +1643,8 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk(
16601643 v_val = (v_sh [lane_idx ] - pred_final ) * beta
16611644
16621645 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1663- h_chunk [i ], h_chunk [i + 1 ] = cute .arch .fma_packed_f32x2 (
1664- src_a = (k_chunk [i ], k_chunk [i + 1 ]),
1665- src_b = (v_val , v_val ),
1666- src_c = (h_chunk [i ], h_chunk [i + 1 ]),
1646+ h_chunk [i ], h_chunk [i + 1 ] = fma_pair (
1647+ k_chunk [i ], k_chunk [i + 1 ], v_val , v_val , h_chunk [i ], h_chunk [i + 1 ]
16671648 )
16681649
16691650 for i in cutlass .range_constexpr (32 ):
@@ -1672,10 +1653,8 @@ def gated_delta_rule_decode_kernel_seqlen1_lowBS_1chunk(
16721653 out = cutlass .Float32 (0.0 )
16731654 out2 = cutlass .Float32 (0.0 )
16741655 for i in cutlass .range_constexpr (0 , 32 , 2 ):
1675- out , out2 = cute .arch .fma_packed_f32x2 (
1676- src_a = (h_chunk [i ], h_chunk [i + 1 ]),
1677- src_b = (qk_temp [i ], qk_temp [i + 1 ]),
1678- src_c = (out , out2 ),
1656+ out , out2 = fma_pair (
1657+ h_chunk [i ], h_chunk [i + 1 ], qk_temp [i ], qk_temp [i + 1 ], out , out2
16791658 )
16801659 out = out + out2
16811660
0 commit comments