@@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1(
291291 logit_cap : tl .constexpr ,
292292 Lk : tl .constexpr ,
293293 Lv : tl .constexpr ,
294+ IS_MLA : tl .constexpr = False ,
294295):
295296 cur_batch = tl .program_id (0 )
296297 cur_head_id = tl .program_id (1 )
@@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1(
310311 cur_batch_req_idx = cur_batch
311312
312313 offs_q = cur_batch * stride_qbs + cur_head [:, None ] * stride_qh + offs_d [None , :]
313- q = tl .load (Q + offs_q , mask = (mask_h [:, None ]) & (mask_d [None , :]), other = 0.0 )
314+ q = tl .load (
315+ Q + offs_q ,
316+ mask = (mask_h [:, None ]) & (mask_d [None , :]),
317+ other = 0.0 ,
318+ cache_modifier = ".ca" ,
319+ )
314320
315321 if BLOCK_DPE > 0 :
316322 offs_dpe = BLOCK_DMODEL + tl .arange (0 , BLOCK_DPE )
@@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1(
319325 cur_batch * stride_qbs + cur_head [:, None ] * stride_qh + offs_dpe [None , :]
320326 )
321327 qpe = tl .load (
322- Q + off_qpe , mask = (mask_h [:, None ]) & (mask_dpe [None , :]), other = 0.0
328+ Q + off_qpe ,
329+ mask = (mask_h [:, None ]) & (mask_dpe [None , :]),
330+ other = 0.0 ,
331+ cache_modifier = ".ca" ,
323332 )
324333
325334 kv_len_per_split = tl .cdiv (cur_batch_seq_len , NUM_KV_SPLITS )
@@ -331,41 +340,44 @@ def _fwd_grouped_kernel_stage1(
331340 acc = tl .zeros ([BLOCK_H , BLOCK_DV ], dtype = tl .float32 )
332341
333342 if split_kv_end > split_kv_start :
343+ base_offs_k = cur_kv_head * stride_buf_kh + offs_d [:, None ]
344+ base_offs_v = cur_kv_head * stride_buf_vh + offs_dv [None , :]
345+ if BLOCK_DPE > 0 :
346+ base_offs_kpe = cur_kv_head * stride_buf_kh + offs_dpe [:, None ]
347+
334348 ks = tl .load (k_scale )
335349 vs = tl .load (v_scale )
336- for start_n in range (split_kv_start , split_kv_end , BLOCK_N ):
350+ for start_n in tl . range (split_kv_start , split_kv_end , BLOCK_N ):
337351 offs_n = start_n + tl .arange (0 , BLOCK_N )
338352 kv_page_number = tl .load (
339353 Req_to_tokens
340354 + stride_req_to_tokens_b * cur_batch_req_idx
341355 + offs_n // PAGE_SIZE ,
342356 mask = offs_n < split_kv_end ,
343357 other = 0 ,
358+ cache_modifier = ".ca" ,
344359 )
345360 kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
346- offs_buf_k = (
347- kv_loc [None , :] * stride_buf_kbs
348- + cur_kv_head * stride_buf_kh
349- + offs_d [:, None ]
350- )
361+
362+ # explicitly facilitate overlapping load/compute
363+ offs_buf_k = kv_loc [None , :] * stride_buf_kbs + base_offs_k
351364 k = tl .load (
352365 K_Buffer + offs_buf_k ,
353366 mask = (offs_n [None , :] < split_kv_end ) & (mask_d [:, None ]),
354367 other = 0.0 ,
368+ cache_modifier = ".cg" ,
355369 )
370+
356371 if k .dtype .is_fp8 ():
357372 k = (k .to (tl .float32 ) * ks ).to (q .dtype )
358373 qk = tl .dot (q , k .to (q .dtype ))
359374 if BLOCK_DPE > 0 :
360- offs_buf_kpe = (
361- kv_loc [None , :] * stride_buf_kbs
362- + cur_kv_head * stride_buf_kh
363- + offs_dpe [:, None ]
364- )
375+ offs_buf_kpe = kv_loc [None , :] * stride_buf_kbs + base_offs_kpe
365376 kpe = tl .load (
366377 K_Buffer + offs_buf_kpe ,
367378 mask = (offs_n [None , :] < split_kv_end ) & (mask_dpe [:, None ]),
368379 other = 0.0 ,
380+ cache_modifier = ".cg" ,
369381 )
370382 if kpe .dtype .is_fp8 ():
371383 kpe = (kpe .to (tl .float32 ) * ks ).to (qpe .dtype )
@@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1(
379391 mask_h [:, None ] & (offs_n [None , :] < split_kv_end ), qk , float ("-inf" )
380392 )
381393
382- offs_buf_v = (
383- kv_loc [:, None ] * stride_buf_vbs
384- + cur_kv_head * stride_buf_vh
385- + offs_dv [None , :]
386- )
387- v = tl .load (
388- V_Buffer + offs_buf_v ,
389- mask = (offs_n [:, None ] < split_kv_end ) & (mask_dv [None , :]),
390- other = 0.0 ,
391- )
392- if v .dtype .is_fp8 ():
393- v = (v .to (tl .float32 ) * vs ).to (q .dtype )
394+ if not IS_MLA :
395+ offs_buf_v = kv_loc [:, None ] * stride_buf_vbs + base_offs_v
396+ v = tl .load (
397+ V_Buffer + offs_buf_v ,
398+ mask = (offs_n [:, None ] < split_kv_end ) & (mask_dv [None , :]),
399+ other = 0.0 ,
400+ )
401+ if v .dtype .is_fp8 ():
402+ v = (v .to (tl .float32 ) * vs ).to (q .dtype )
403+ else :
404+ # MLA uses a single c_kv.
405+ # loading the same c_kv to interpret it as v is not necessary.
406+ # transpose the existing c_kv (aka k) for the dot product.
407+ v = tl .trans (k )
394408
395409 n_e_max = tl .maximum (tl .max (qk , 1 ), e_max )
396410 re_scale = tl .exp (e_max - n_e_max )
@@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd(
441455 logit_cap ,
442456 k_scale ,
443457 v_scale ,
458+ is_mla = False ,
444459):
460+ # with is_mla there is only a single c_kv in smem.
461+ # could increase BLOCK or num_stages.
445462 BLOCK = 32
446463 Lk = k_buffer .shape [- 1 ]
447464 Lv = v_buffer .shape [- 1 ]
@@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd(
514531 num_stages = num_stages ,
515532 Lk = Lk ,
516533 Lv = Lv ,
534+ IS_MLA = is_mla ,
517535 ** extra_kargs ,
518536 )
519537
@@ -673,6 +691,7 @@ def decode_attention_fwd_grouped(
673691 logit_cap = 0.0 ,
674692 k_scale = None ,
675693 v_scale = None ,
694+ is_mla = False ,
676695):
677696 _decode_grouped_att_m_fwd (
678697 q ,
@@ -687,6 +706,7 @@ def decode_attention_fwd_grouped(
687706 logit_cap ,
688707 k_scale ,
689708 v_scale ,
709+ is_mla = is_mla ,
690710 )
691711 _decode_softmax_reducev_fwd (
692712 attn_logits , q , o , lse , v_buffer , b_seq_len , num_kv_splits
@@ -708,6 +728,7 @@ def decode_attention_fwd(
708728 logit_cap = 0.0 ,
709729 k_scale = None ,
710730 v_scale = None ,
731+ is_mla = False ,
711732):
712733 assert num_kv_splits == attn_logits .shape [2 ]
713734
@@ -753,4 +774,5 @@ def decode_attention_fwd(
753774 logit_cap ,
754775 k_scale ,
755776 v_scale ,
777+ is_mla = is_mla ,
756778 )
0 commit comments