Skip to content

Commit c960373

Browse files
authored
fix(sla-core): support Ampere graphics for sagesla (#44)
1 parent 4a60478 commit c960373

File tree

1 file changed

+26
-19
lines changed

1 file changed

+26
-19
lines changed

turbodiffusion/SLA/core.py

Lines changed: 26 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -206,30 +206,37 @@ def forward(self, q, k, v, return_sparsity=False):
206206

207207
assert headdim in [64, 128], "headdim should be in [64, 128]. For other headdim, you can use padding and specify the softmax scale."
208208

209-
## quant v
210-
b, h_kv, kv_len, head_dim = v.shape
211-
padded_len = (kv_len + 127) // 128 * 128
212-
v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
213-
fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
214-
v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
215-
v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
216-
fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
217-
218209
o_s = torch.empty_like(q)
219-
if arch == "sm90":
220-
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90(
221-
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, q_scale, k_scale, v_scale, 1, False, 1, scale
210+
211+
if arch in ("sm80", "sm86", "sm87"):
212+
pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device)
213+
v_fp16 = v.to(torch.float16)
214+
qattn.qk_int8_sv_f16_accum_f16_block_sparse_attn_inst_buf_with_pv_threshold(
215+
q_int8, k_int8, v_fp16, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, 1, False, 1, scale, 0
222216
)
223217
else:
224-
pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device)
225-
if SAGE2PP_ENABLED:
226-
qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
227-
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
218+
b, h_kv, kv_len, head_dim = v.shape
219+
padded_len = (kv_len + 127) // 128 * 128
220+
v_transposed_permutted = torch.empty((b, h_kv, head_dim, padded_len), dtype=v.dtype, device=v.device)
221+
fused.transpose_pad_permute_cuda(v, v_transposed_permutted, 1)
222+
v_fp8 = torch.empty(v_transposed_permutted.shape, dtype=torch.float8_e4m3fn, device=v.device)
223+
v_scale = torch.empty((b, h_kv, head_dim), dtype=torch.float32, device=v.device)
224+
fused.scale_fuse_quant_cuda(v_transposed_permutted, v_fp8, v_scale, kv_len, 2.25, 1)
225+
226+
if arch == "sm90":
227+
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_sm90(
228+
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, q_scale, k_scale, v_scale, 1, False, 1, scale
228229
)
229230
else:
230-
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
231-
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
232-
)
231+
pvthreshold = torch.full((q.shape[-3],), 1e6, dtype=torch.float32, device=q.device)
232+
if SAGE2PP_ENABLED:
233+
qk_int8_sv_f8_accum_f16_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
234+
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
235+
)
236+
else:
237+
qattn.qk_int8_sv_f8_accum_f32_block_sparse_attn_inst_buf_fuse_v_scale_with_pv_threshold(
238+
q_int8, k_int8, v_fp8, o_s, lut, valid_block_num, pvthreshold, q_scale, k_scale, v_scale, 1, False, 1, scale, 0
239+
)
233240

234241
########## SPARGE END ##########
235242

0 commit comments

Comments
 (0)