Commit 75b57e3

F.Liu committed
[Perf] Optimize torch.where and vectorize PCP/DCP loops in mla_v1.py
Signed-off-by: F.Liu <[email protected]>
1 parent bb7b74c commit 75b57e3


vllm_ascend/attention/mla_v1.py

Lines changed: 39 additions & 46 deletions
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
-                    Type, TypeVar)
+from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
+                    TypeVar)

 import numpy as np
 import torch
@@ -376,7 +376,7 @@ def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs,
         Only use for acl full graph mode.
         Pad the last element of the actual_seq_lengths_q equal to the TND(T) and
         the num of dimensions equal to the batch_size of main model.
-
+
         For example:
         batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
         input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token)
@@ -661,7 +661,7 @@ def build(
                 batch_seq_mask, non_blocking=True)
             batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
                                                       shape[0]]
-            cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+            cp_seq_len.masked_fill_(cp_seq_len == 0, 1)
         else:
             cp_seq_len, batch_seq_mask = None, None

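The hunk above is the `torch.where` half of the optimization: instead of materializing a fresh tensor and rebinding the name, the zero-length entries are patched in place with `masked_fill_`. A minimal standalone sketch of the two patterns (toy values, not taken from the repository):

```python
import torch

cp_seq_len = torch.tensor([0, 3, 0, 7])

# Old pattern: torch.where allocates a new tensor and rebinds the name.
out_where = torch.where(cp_seq_len == 0, 1, cp_seq_len)

# New pattern: fill masked positions in the existing buffer, no new allocation.
cp_seq_len.masked_fill_(cp_seq_len == 0, 1)

assert torch.equal(out_where, cp_seq_len)  # tensor([1, 3, 1, 7])
```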
@@ -1971,20 +1971,32 @@ def _forward_decode_pcp_dcp(
         attn_output = self._npu_attention_update(attn_out_lse_list)
         return self._v_up_proj(attn_output)

-    def _npu_attention_update(
-            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
-        attn_out_split_cp = []
-        attn_lse_split_cp = []
-
-        for attn_out_lse in attn_out_lse_list:
-            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
-                *torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
-            attn_out_split_cp.append(attn_out_allgather)
-            attn_lse_split_cp.append(attn_lse_allgather)
-        attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
-                                                     attn_out_split_cp, 0)
-        attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
-                                 self.kv_lora_rank)
+    def _npu_attention_update(self,
+                              attn_out_lse: torch.Tensor) -> torch.Tensor:
+        B_total, H_total, D_plus_1 = attn_out_lse.shape
+        S = B_total // self.pcp_size
+        H = H_total // self.dcp_size
+        D = self.kv_lora_rank
+        assert D_plus_1 == D + 1
+        # [PCP, S, DCP, H, D+1]
+        x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
+        # [PCP, DCP, S, H, D+1]
+        x = x.permute(0, 2, 1, 3, 4).contiguous()
+        # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
+        x = x.view(-1, S, H, D_plus_1)
+        # Split out lse
+        out_flat, lse_flat = torch.split(x, [D, 1],
+                                         dim=-1)  # [N, S, H, D], [N, S, H, 1]
+        # out: [N, S, H, D] -> [N, S*H, D]
+        # lse: [N, S, H, 1] -> [N, S*H]
+        out_flat = out_flat.flatten(1, 2)  # [N, S*H, D]
+        lse_flat = lse_flat.squeeze(-1).flatten(1)  # [N, S*H]
+        # unbind to list
+        out_list = out_flat.unbind(0)  # [S*H, D]
+        lse_list = lse_flat.unbind(0)  # [S*H]
+
+        attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
+        attn_out = attn_out.view(-1, H, self.kv_lora_rank)
         return attn_out

     def _out_lse_reshape(self, attn_out: torch.Tensor,
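The rewritten `_npu_attention_update` replaces the Python loop over per-(PCP, DCP)-rank chunks with a single view/permute/split over the gathered tensor. A toy equivalence check of the chunk ordering (shapes invented here, assuming a gathered `[PCP*S, DCP*H, D+1]` layout as used above):

```python
import torch

pcp, dcp, S, H, D = 2, 2, 3, 4, 8
x = torch.randn(pcp * S, dcp * H, D + 1)  # gathered out||lse tensor (toy shapes)

# Loop-style splitting in the spirit of the removed code: PCP-major, DCP-minor chunks.
loop_chunks = []
for per_pcp_rank in torch.chunk(x, pcp, dim=0):
    loop_chunks.extend(torch.chunk(per_pcp_rank, dcp, dim=1))

# Vectorized splitting, mirroring the new view/permute/unbind path.
v = x.view(pcp, S, dcp, H, D + 1).permute(0, 2, 1, 3, 4).contiguous()
vec_chunks = v.view(-1, S, H, D + 1).unbind(0)

# Same chunks, in the same (PCP-major, DCP-minor) order.
for a, b in zip(loop_chunks, vec_chunks):
    assert torch.equal(a, b)
```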
@@ -2000,14 +2012,10 @@ def _process_attn_out_lse(
         attn_output: torch.Tensor,
         softmax_lse: torch.Tensor,
         decode_meta: AscendMLADecodeMetadata,
-    ) -> List[torch.Tensor]:
-        attn_out_lse_list = []
-        out_mask = decode_meta.batch_seq_mask[:, None,
-                                              None].expand_as(attn_output)
-        attn_output = torch.where(out_mask, 0, attn_output)
-        lse_mask = decode_meta.batch_seq_mask[:, None,
-                                              None].expand_as(softmax_lse)
-        softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
+    ) -> torch.Tensor:
+        out_lse_mask = decode_meta.batch_seq_mask[:, None, None].bool()
+        attn_output.masked_fill_(out_lse_mask, 0)
+        softmax_lse.masked_fill_(out_lse_mask, -torch.inf)

         softmax_lse = softmax_lse.to(torch.float32)
         attn_output = attn_output.to(torch.float32)
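Here the per-request `batch_seq_mask` is broadcast as `[bs, 1, 1]`, so whole requests are zeroed (and their LSE set to `-inf`, which gives them zero weight in the later log-sum-exp merge) without the `expand_as` plus `torch.where` round trip. A small illustrative sketch with invented shapes, not from the module:

```python
import torch

batch_seq_mask = torch.tensor([False, True, False])  # per-request "mask out" flags
attn_output = torch.randn(3, 2, 4)                   # toy [bs, num_heads, head_dim]
softmax_lse = torch.randn(3, 2, 1)                   # toy [bs, num_heads, 1]

out_lse_mask = batch_seq_mask[:, None, None]         # [bs, 1, 1], broadcasts over the rest
attn_output.masked_fill_(out_lse_mask, 0)
softmax_lse.masked_fill_(out_lse_mask, -torch.inf)

assert attn_output[1].abs().sum() == 0               # masked request fully zeroed
assert bool(torch.isinf(softmax_lse[1]).all())       # its LSE contributes zero weight
```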
@@ -2020,30 +2028,15 @@ def _process_attn_out_lse(
             dist.all_to_all_single(attn_out_lse_all2all,
                                    attn_out_lse,
                                    group=self.dcp_group)
-            # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
             if self.pcp_size > 1:
                 attn_out_lse = attn_out_lse_all2all.contiguous()
-            attn_out_lse_list = list(
-                torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))

         if self.pcp_size > 1:
-            # AllGather out&lse within PCP group
-            attn_out_lse_list = [
-                torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
-            ]
-            dist.all_gather(attn_out_lse_list,
-                            attn_out_lse,
-                            group=self.pcp_group)
-        if self.dcp_size > 1 and self.pcp_size > 1:
-            attn_out_lse_list_pcp_dcp = []
-            for s in attn_out_lse_list:
-                attn_out_lse_list_split = list(
-                    torch.chunk(s, self.dcp_size, dim=1))
-                attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
-            attn_out_lse_list = attn_out_lse_list_pcp_dcp
-
-        return attn_out_lse_list
+            # AllGather out&lse within CP group
+            attn_out_lse = get_pcp_group().all_gather(attn_out_lse, dim=0)
+
+        return attn_out_lse

     def _reorg_kvcache(
         self,
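The per-rank `dist.all_gather` into a Python list, followed by manual chunking, is replaced by a single gather that concatenates shards along dim 0 (assuming `get_pcp_group()` returns a vLLM-style group coordinator whose `all_gather(..., dim=0)` concatenates along that dim). A toy sketch of the layout equivalence, with invented shapes and no real process group:

```python
import torch

# Toy stand-in for one PCP rank's out||lse block: [S, DCP*H, D+1].
pcp_size, S, H_total, D1 = 2, 3, 8, 9
local = [torch.randn(S, H_total, D1) for _ in range(pcp_size)]

# Removed path: dist.all_gather fills a Python list, one tensor per PCP rank.
gathered_list = [t.clone() for t in local]

# New path: one tensor with ranks stacked along dim 0, i.e. [PCP*S, DCP*H, D+1].
gathered = torch.cat(local, dim=0)

# The per-rank blocks are still recoverable as views, so no information is lost.
for rank, blk in enumerate(torch.chunk(gathered, pcp_size, dim=0)):
    assert torch.equal(blk, gathered_list[rank])
```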
@@ -2116,4 +2109,4 @@ def _reorg_kvcache(
         assert reorganized_kv_c_normed.shape[0] == sum_seq_len
         assert reorganized_k_pe.shape[0] == sum_seq_len
         assert max_seq_len_check == max_seq_len
-        return reorganized_kv_c_normed, reorganized_k_pe
+        return reorganized_kv_c_normed, reorganized_k_pe
