Skip to content

Commit d9408ff

Browse files
koush, root, and DarkLight1337
authored
Triton MLA perf fixes (vllm-project#33529)
Signed-off-by: Koushik Dutta <koushd@gmail.com>
Co-authored-by: root <root@ubuntu-nvidia.localdomain>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
1 parent 16a65e4 commit d9408ff

2 files changed

Lines changed: 69 additions & 26 deletions

File tree

vllm/v1/attention/backends/mla/triton_mla.py

Lines changed: 22 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@
1414
MLACommonMetadata,
1515
)
1616
from vllm.platforms.interface import DeviceCapability
17+
from vllm.triton_utils import triton
1718
from vllm.utils.torch_utils import is_quantized_kv_cache
1819
from vllm.v1.attention.backend import (
1920
AttentionLayer,
@@ -115,6 +116,8 @@ def __init__(
115116
if is_quantized_kv_cache(self.kv_cache_dtype):
116117
self.supports_quant_query_input = False
117118

119+
self._sm_count = torch.cuda.get_device_properties(0).multi_processor_count
120+
118121
def _flash_attn_varlen_diff_headdims(
119122
self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
120123
):
@@ -149,7 +152,24 @@ def forward_mqa(
149152
lse = torch.zeros(B, q_num_heads, dtype=q.dtype, device=q.device)
150153

151154
# For batch invariance, use only 1 split to ensure deterministic reduction
152-
num_kv_splits = 1 if envs.VLLM_BATCH_INVARIANT else 4
155+
if envs.VLLM_BATCH_INVARIANT:
156+
num_kv_splits = 1
157+
else:
158+
# Minimum work per split
159+
# hardware dependent
160+
min_work_per_split = 512
161+
162+
ideal_splits = max(1, attn_metadata.max_seq_len // min_work_per_split)
163+
164+
# use power of 2 to avoid excessive kernel instantiations
165+
ideal_splits = triton.next_power_of_2(ideal_splits)
166+
167+
# Calculate SM-based maximum splits with occupancy multiplier
168+
# 2-4x allows multiple blocks per SM for latency hiding
169+
# hardware dependent
170+
occupancy_multiplier = 2
171+
max_splits = self._sm_count * occupancy_multiplier
172+
num_kv_splits = min(ideal_splits, max_splits)
153173

154174
# TODO(lucas) Allocate ahead of time
155175
attn_logits = torch.empty(
@@ -186,6 +206,7 @@ def forward_mqa(
186206
PAGE_SIZE,
187207
k_scale=layer._k_scale,
188208
v_scale=layer._k_scale,
209+
is_mla=True,
189210
)
190211

191212
return o, lse

vllm/v1/attention/ops/triton_decode_attention.py

Lines changed: 47 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -291,6 +291,7 @@ def _fwd_grouped_kernel_stage1(
291291
logit_cap: tl.constexpr,
292292
Lk: tl.constexpr,
293293
Lv: tl.constexpr,
294+
IS_MLA: tl.constexpr = False,
294295
):
295296
cur_batch = tl.program_id(0)
296297
cur_head_id = tl.program_id(1)
@@ -310,7 +311,12 @@ def _fwd_grouped_kernel_stage1(
310311
cur_batch_req_idx = cur_batch
311312

312313
offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
313-
q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
314+
q = tl.load(
315+
Q + offs_q,
316+
mask=(mask_h[:, None]) & (mask_d[None, :]),
317+
other=0.0,
318+
cache_modifier=".ca",
319+
)
314320

315321
if BLOCK_DPE > 0:
316322
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
@@ -319,7 +325,10 @@ def _fwd_grouped_kernel_stage1(
319325
cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :]
320326
)
321327
qpe = tl.load(
322-
Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0
328+
Q + off_qpe,
329+
mask=(mask_h[:, None]) & (mask_dpe[None, :]),
330+
other=0.0,
331+
cache_modifier=".ca",
323332
)
324333

325334
kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS)
@@ -331,41 +340,44 @@ def _fwd_grouped_kernel_stage1(
331340
acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32)
332341

333342
if split_kv_end > split_kv_start:
343+
base_offs_k = cur_kv_head * stride_buf_kh + offs_d[:, None]
344+
base_offs_v = cur_kv_head * stride_buf_vh + offs_dv[None, :]
345+
if BLOCK_DPE > 0:
346+
base_offs_kpe = cur_kv_head * stride_buf_kh + offs_dpe[:, None]
347+
334348
ks = tl.load(k_scale)
335349
vs = tl.load(v_scale)
336-
for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
350+
for start_n in tl.range(split_kv_start, split_kv_end, BLOCK_N):
337351
offs_n = start_n + tl.arange(0, BLOCK_N)
338352
kv_page_number = tl.load(
339353
Req_to_tokens
340354
+ stride_req_to_tokens_b * cur_batch_req_idx
341355
+ offs_n // PAGE_SIZE,
342356
mask=offs_n < split_kv_end,
343357
other=0,
358+
cache_modifier=".ca",
344359
)
345360
kv_loc = kv_page_number * PAGE_SIZE + offs_n % PAGE_SIZE
346-
offs_buf_k = (
347-
kv_loc[None, :] * stride_buf_kbs
348-
+ cur_kv_head * stride_buf_kh
349-
+ offs_d[:, None]
350-
)
361+
362+
# explicitly facilitate overlapping load/compute
363+
offs_buf_k = kv_loc[None, :] * stride_buf_kbs + base_offs_k
351364
k = tl.load(
352365
K_Buffer + offs_buf_k,
353366
mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]),
354367
other=0.0,
368+
cache_modifier=".cg",
355369
)
370+
356371
if k.dtype.is_fp8():
357372
k = (k.to(tl.float32) * ks).to(q.dtype)
358373
qk = tl.dot(q, k.to(q.dtype))
359374
if BLOCK_DPE > 0:
360-
offs_buf_kpe = (
361-
kv_loc[None, :] * stride_buf_kbs
362-
+ cur_kv_head * stride_buf_kh
363-
+ offs_dpe[:, None]
364-
)
375+
offs_buf_kpe = kv_loc[None, :] * stride_buf_kbs + base_offs_kpe
365376
kpe = tl.load(
366377
K_Buffer + offs_buf_kpe,
367378
mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]),
368379
other=0.0,
380+
cache_modifier=".cg",
369381
)
370382
if kpe.dtype.is_fp8():
371383
kpe = (kpe.to(tl.float32) * ks).to(qpe.dtype)
@@ -379,18 +391,20 @@ def _fwd_grouped_kernel_stage1(
379391
mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
380392
)
381393

382-
offs_buf_v = (
383-
kv_loc[:, None] * stride_buf_vbs
384-
+ cur_kv_head * stride_buf_vh
385-
+ offs_dv[None, :]
386-
)
387-
v = tl.load(
388-
V_Buffer + offs_buf_v,
389-
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
390-
other=0.0,
391-
)
392-
if v.dtype.is_fp8():
393-
v = (v.to(tl.float32) * vs).to(q.dtype)
394+
if not IS_MLA:
395+
offs_buf_v = kv_loc[:, None] * stride_buf_vbs + base_offs_v
396+
v = tl.load(
397+
V_Buffer + offs_buf_v,
398+
mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]),
399+
other=0.0,
400+
)
401+
if v.dtype.is_fp8():
402+
v = (v.to(tl.float32) * vs).to(q.dtype)
403+
else:
404+
# MLA uses a single c_kv.
405+
# loading the same c_kv to interpret it as v is not necessary.
406+
# transpose the existing c_kv (aka k) for the dot product.
407+
v = tl.trans(k)
394408

395409
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
396410
re_scale = tl.exp(e_max - n_e_max)
@@ -441,7 +455,10 @@ def _decode_grouped_att_m_fwd(
441455
logit_cap,
442456
k_scale,
443457
v_scale,
458+
is_mla=False,
444459
):
460+
# with is_mla there is only a single c_kv in smem.
461+
# could increase BLOCK or num_stages.
445462
BLOCK = 32
446463
Lk = k_buffer.shape[-1]
447464
Lv = v_buffer.shape[-1]
@@ -514,6 +531,7 @@ def _decode_grouped_att_m_fwd(
514531
num_stages=num_stages,
515532
Lk=Lk,
516533
Lv=Lv,
534+
IS_MLA=is_mla,
517535
**extra_kargs,
518536
)
519537

@@ -673,6 +691,7 @@ def decode_attention_fwd_grouped(
673691
logit_cap=0.0,
674692
k_scale=None,
675693
v_scale=None,
694+
is_mla=False,
676695
):
677696
_decode_grouped_att_m_fwd(
678697
q,
@@ -687,6 +706,7 @@ def decode_attention_fwd_grouped(
687706
logit_cap,
688707
k_scale,
689708
v_scale,
709+
is_mla=is_mla,
690710
)
691711
_decode_softmax_reducev_fwd(
692712
attn_logits, q, o, lse, v_buffer, b_seq_len, num_kv_splits
@@ -708,6 +728,7 @@ def decode_attention_fwd(
708728
logit_cap=0.0,
709729
k_scale=None,
710730
v_scale=None,
731+
is_mla=False,
711732
):
712733
assert num_kv_splits == attn_logits.shape[2]
713734

@@ -753,4 +774,5 @@ def decode_attention_fwd(
753774
logit_cap,
754775
k_scale,
755776
v_scale,
777+
is_mla=is_mla,
756778
)

0 commit comments

Comments
 (0)