@@ -206,30 +206,37 @@ def forward(self, q, k, v, return_sparsity=False):
206206
207207 assert headdim in [64 , 128 ], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
208208
209- ## quant v
210- b , h_kv , kv_len , head_dim = v .shape
211- padded_len = (kv_len + 127 ) // 128 * 128
212- v_transposed_permutted = torch .empty ((b , h_kv , head_dim , padded_len ), dtype = v .dtype , device = v .device )
213- fused .transpose_pad_permute_cuda (v , v_transposed_permutted , 1 )
214- v_fp8 = torch .empty (v_transposed_permutted .shape , dtype = torch .float8_e4m3fn , device = v .device )
215- v_scale = torch .empty ((b , h_kv , head_dim ), dtype = torch .float32 , device = v .device )
216- fused .scale_fuse_quant_cuda (v_transposed_permutted , v_fp8 , v_scale , kv_len , 2.25 , 1 )
217-
218209 o_s = torch .empty_like (q )
219- if arch == "sm90" :
220- qattn .qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90 (
221- q_int8 , k_int8 , v_fp8 , o_s , lut , valid_block_num , q_scale , k_scale , v_scale , 1 , False , 1 , scale
210+
211+ if arch in ("sm80" , "sm86" , "sm87" ):
212+ pvthreshold = torch .full ((q .shape [- 3 ],), 1e6 , dtype = torch .float32 , device = q .device )
213+ v_fp16 = v .to (torch .float16 )
214+ qattn .qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold (
215+ q_int8 , k_int8 , v_fp16 , o_s , lut , valid_block_num , pvthreshold , q_scale , k_scale , 1 , False , 1 , scale , 0
222216 )
223217 else :
224- pvthreshold = torch .full ((q .shape [- 3 ],), 1e6 , dtype = torch .float32 , device = q .device )
225- if SAGE2PP_ENABLED :
226- qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold (
227- q_int8 , k_int8 , v_fp8 , o_s , lut , valid_block_num , pvthreshold , q_scale , k_scale , v_scale , 1 , False , 1 , scale , 0
218+ b , h_kv , kv_len , head_dim = v .shape
219+ padded_len = (kv_len + 127 ) // 128 * 128
220+ v_transposed_permutted = torch .empty ((b , h_kv , head_dim , padded_len ), dtype = v .dtype , device = v .device )
221+ fused .transpose_pad_permute_cuda (v , v_transposed_permutted , 1 )
222+ v_fp8 = torch .empty (v_transposed_permutted .shape , dtype = torch .float8_e4m3fn , device = v .device )
223+ v_scale = torch .empty ((b , h_kv , head_dim ), dtype = torch .float32 , device = v .device )
224+ fused .scale_fuse_quant_cuda (v_transposed_permutted , v_fp8 , v_scale , kv_len , 2.25 , 1 )
225+
226+ if arch == "sm90" :
227+ qattn .qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90 (
228+ q_int8 , k_int8 , v_fp8 , o_s , lut , valid_block_num , q_scale , k_scale , v_scale , 1 , False , 1 , scale
228229 )
229230 else :
230- qattn .qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold (
231- q_int8 , k_int8 , v_fp8 , o_s , lut , valid_block_num , pvthreshold , q_scale , k_scale , v_scale , 1 , False , 1 , scale , 0
232- )
231+ pvthreshold = torch .full ((q .shape [- 3 ],), 1e6 , dtype = torch .float32 , device = q .device )
232+ if SAGE2PP_ENABLED :
233+ qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold (
234+ q_int8 , k_int8 , v_fp8 , o_s , lut , valid_block_num , pvthreshold , q_scale , k_scale , v_scale , 1 , False , 1 , scale , 0
235+ )
236+ else :
237+ qattn .qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold (
238+ q_int8 , k_int8 , v_fp8 , o_s , lut , valid_block_num , pvthreshold , q_scale , k_scale , v_scale , 1 , False , 1 , scale , 0
239+ )
233240
234241 ########## SPARGE END ##########
235242
0 commit comments