 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, ClassVar, List, NamedTuple, Optional, Tuple,
-                    Type, TypeVar)
+from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
+                    TypeVar)
 
 import numpy as np
 import torch
@@ -376,7 +376,7 @@ def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs,
         Only use for acl full graph mode.
         Pad the last element of the actual_seq_lengths_q equal to the TND(T) and
         the num of dimensions equal to the batch_size of main model.
-
+
         For example:
         batch_size = 8, num_reqs = 4, num_speculative_tokens = 1
         input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req was accept a token)
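As an aside (not part of the patch): a minimal standalone sketch of the padding rule this docstring describes. Two assumptions are baked in — the helper name and signature below are hypothetical, and padded entries are filled by repeating the last real cumulative length before the final entry is forced to T. The real method operates on the metadata builder's buffers and may fill the padded slots differently.

def pad_cumulative_seq_lens_q(actual_seq_lengths_q, batch_size,
                              num_speculative_tokens):
    # Hypothetical helper, for illustration only.
    # T: total query tokens in the padded TND layout, with every request
    # padded to (1 + num_speculative_tokens) tokens.
    t = batch_size * (1 + num_speculative_tokens)
    padded = list(actual_seq_lengths_q)
    # Pad the number of entries up to the main model's batch_size
    # (assumption: repeat the last real cumulative length).
    padded += [padded[-1]] * (batch_size - len(padded))
    # Rule from the docstring: the last element equals T.
    padded[-1] = t
    return padded

# Docstring example: batch_size = 8, num_reqs = 4, num_speculative_tokens = 1,
# and the 3rd request accepted a token.
print(pad_cumulative_seq_lens_q([1, 2, 4, 5], 8, 1))
# -> [1, 2, 4, 5, 5, 5, 5, 16]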
@@ -661,7 +661,7 @@ def build(
                     batch_seq_mask, non_blocking=True)
                 batch_seq_mask = self.batch_seq_mask_buf[:batch_seq_mask.
                                                          shape[0]]
-                cp_seq_len = torch.where(cp_seq_len == 0, 1, cp_seq_len)
+                cp_seq_len.masked_fill_(cp_seq_len == 0, 1)
             else:
                 cp_seq_len, batch_seq_mask = None, None
 
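Not part of the diff: a quick pure-PyTorch sanity check that the in-place masked_fill_ introduced above gives the same values as the torch.where it replaces (assuming cp_seq_len is an integer length tensor), while avoiding the allocation of a new output tensor.

import torch

cp_seq_len = torch.tensor([0, 3, 0, 7], dtype=torch.int32)

# Old form: allocates a fresh tensor.
ref = torch.where(cp_seq_len == 0, 1, cp_seq_len)

# New form: clamps zero-length entries to 1 in place.
cp_seq_len.masked_fill_(cp_seq_len == 0, 1)

assert torch.equal(ref, cp_seq_len)  # both give [1, 3, 1, 7]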
@@ -1971,20 +1971,32 @@ def _forward_decode_pcp_dcp(
         attn_output = self._npu_attention_update(attn_out_lse_list)
         return self._v_up_proj(attn_output)
 
-    def _npu_attention_update(
-            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
-        attn_out_split_cp = []
-        attn_lse_split_cp = []
-
-        for attn_out_lse in attn_out_lse_list:
-            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
-                *torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
-            attn_out_split_cp.append(attn_out_allgather)
-            attn_lse_split_cp.append(attn_lse_allgather)
-        attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
-                                                     attn_out_split_cp, 0)
-        attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
-                                 self.kv_lora_rank)
+    def _npu_attention_update(self,
+                              attn_out_lse: torch.Tensor) -> torch.Tensor:
+        B_total, H_total, D_plus_1 = attn_out_lse.shape
+        S = B_total // self.pcp_size
+        H = H_total // self.dcp_size
+        D = self.kv_lora_rank
+        assert D_plus_1 == D + 1
+        # [PCP, S, DCP, H, D+1]
+        x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
+        # [PCP, DCP, S, H, D+1]
+        x = x.permute(0, 2, 1, 3, 4).contiguous()
+        # Flatten [N, S, H, D+1], N = pcp_size * dcp_size
+        x = x.view(-1, S, H, D_plus_1)
+        # Split out lse
+        out_flat, lse_flat = torch.split(x, [D, 1],
+                                         dim=-1)  # [N, S, H, D], [N, S, H, 1]
+        # out: [N, S, H, D] -> [N, S*H, D]
+        # lse: [N, S, H, 1] -> [N, S*H]
+        out_flat = out_flat.flatten(1, 2)  # [N, S*H, D]
+        lse_flat = lse_flat.squeeze(-1).flatten(1)  # [N, S*H]
+        # unbind to list
+        out_list = out_flat.unbind(0)  # [S*H, D]
+        lse_list = lse_flat.unbind(0)  # [S*H]
+
+        attn_out, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
+        attn_out = attn_out.view(-1, H, self.kv_lora_rank)
         return attn_out
 
     def _out_lse_reshape(self, attn_out: torch.Tensor,
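Not from the patch: a CPU-only sketch of the index gymnastics in the rewritten _npu_attention_update above. It checks that entry i * dcp_size + j of the unbound out/lse lists corresponds to PCP chunk i and DCP head group j of the gathered [PCP*S, DCP*H, D+1] tensor. The torch_npu.npu_attention_update kernel itself is NPU-only and is not exercised here, and the sizes are made up.

import torch

pcp_size, dcp_size = 2, 2
S, H, D = 3, 4, 8  # tokens per PCP chunk, heads per DCP group, kv_lora_rank

# Gathered tensor: [PCP*S, DCP*H, D+1] (attention output + lse per head).
attn_out_lse = torch.randn(pcp_size * S, dcp_size * H, D + 1)

# Same view/permute/split/flatten/unbind sequence as in the patch.
x = attn_out_lse.view(pcp_size, S, dcp_size, H, D + 1)
x = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, S, H, D + 1)
out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)
out_list = out_flat.flatten(1, 2).unbind(0)           # N tensors of [S*H, D]
lse_list = lse_flat.squeeze(-1).flatten(1).unbind(0)  # N tensors of [S*H]

# Cross-check against direct slicing of the gathered tensor.
for i in range(pcp_size):          # PCP chunk
    for j in range(dcp_size):      # DCP head group
        ref = attn_out_lse[i * S:(i + 1) * S, j * H:(j + 1) * H]
        n = i * dcp_size + j
        assert torch.equal(out_list[n], ref[..., :D].reshape(S * H, D))
        assert torch.equal(lse_list[n], ref[..., D].reshape(S * H))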
@@ -2000,14 +2012,10 @@ def _process_attn_out_lse(
         attn_output: torch.Tensor,
         softmax_lse: torch.Tensor,
         decode_meta: AscendMLADecodeMetadata,
-    ) -> List[torch.Tensor]:
-        attn_out_lse_list = []
-        out_mask = decode_meta.batch_seq_mask[:, None,
-                                              None].expand_as(attn_output)
-        attn_output = torch.where(out_mask, 0, attn_output)
-        lse_mask = decode_meta.batch_seq_mask[:, None,
-                                              None].expand_as(softmax_lse)
-        softmax_lse = torch.where(lse_mask, -torch.inf, softmax_lse)
+    ) -> torch.Tensor:
+        out_lse_mask = decode_meta.batch_seq_mask[:, None, None].bool()
+        attn_output.masked_fill_(out_lse_mask, 0)
+        softmax_lse.masked_fill_(out_lse_mask, -torch.inf)
 
         softmax_lse = softmax_lse.to(torch.float32)
         attn_output = attn_output.to(torch.float32)
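Not part of the patch: a short numerical illustration of why padded rows are filled with 0 and -inf above. When partial attention results are later merged with log-sum-exp weights (the role npu_attention_update plays on the NPU), a row whose lse is -inf receives weight exp(-inf) = 0, so a padded slot cannot disturb the combined output. merge_two below is a plain-PyTorch stand-in for that merge, not the NPU kernel.

import torch

def merge_two(out_a, lse_a, out_b, lse_b):
    # Reference log-sum-exp merge of two partial attention results.
    lse = torch.logaddexp(lse_a, lse_b)
    return out_a * torch.exp(lse_a - lse) + out_b * torch.exp(lse_b - lse)

out_a = torch.randn(4, 8)
out_b = torch.randn(4, 8)
lse_a = torch.randn(4, 1)
lse_b = torch.full((4, 1), -torch.inf)               # the padded / masked slot
out_b = out_b.masked_fill(lse_b == -torch.inf, 0.0)  # mirrors the 0-fill above

merged = merge_two(out_a, lse_a, out_b, lse_b)
assert torch.allclose(merged, out_a)  # the masked partial contributes nothing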
@@ -2020,30 +2028,15 @@ def _process_attn_out_lse(
             dist.all_to_all_single(attn_out_lse_all2all,
                                    attn_out_lse,
                                    group=self.dcp_group)
-            # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
             if self.pcp_size > 1:
                 attn_out_lse = attn_out_lse_all2all.contiguous()
-            attn_out_lse_list = list(
-                torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
 
         if self.pcp_size > 1:
-            # AllGather out&lse within PCP group
-            attn_out_lse_list = [
-                torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
-            ]
-            dist.all_gather(attn_out_lse_list,
-                            attn_out_lse,
-                            group=self.pcp_group)
-        if self.dcp_size > 1 and self.pcp_size > 1:
-            attn_out_lse_list_pcp_dcp = []
-            for s in attn_out_lse_list:
-                attn_out_lse_list_split = list(
-                    torch.chunk(s, self.dcp_size, dim=1))
-                attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
-            attn_out_lse_list = attn_out_lse_list_pcp_dcp
-
-        return attn_out_lse_list
+            # AllGather out&lse within CP group
+            attn_out_lse = get_pcp_group().all_gather(attn_out_lse, dim=0)
+
+        return attn_out_lse
 
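Not from the patch: the new code relies on get_pcp_group().all_gather(attn_out_lse, dim=0) producing a rank-ordered concatenation along dim 0, which is exactly the [pcp_size * bs, ...] layout that the rewritten _npu_attention_update un-flattens with its first .view. Below is a single-process emulation of that contract with torch.cat, assuming the helper behaves like a rank-ordered concat; shapes are made up.

import torch

pcp_size, bs, num_heads, d = 2, 3, 4, 9  # d stands in for v_head_dim + 1

# Per-PCP-rank tensors right before the gather, shape [bs, num_heads, d].
per_rank = [torch.randn(bs, num_heads, d) for _ in range(pcp_size)]

# all_gather(..., dim=0) emulated as a rank-ordered concat along dim 0.
gathered = torch.cat(per_rank, dim=0)  # [pcp_size * bs, num_heads, d]

# The consumer recovers each rank's chunk by un-flattening dim 0, which is
# what _npu_attention_update's first .view(self.pcp_size, S, ...) relies on.
chunks = gathered.view(pcp_size, bs, num_heads, d)
for rank in range(pcp_size):
    assert torch.equal(chunks[rank], per_rank[rank])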
     def _reorg_kvcache(
         self,
@@ -2116,4 +2109,4 @@ def _reorg_kvcache(
         assert reorganized_kv_c_normed.shape[0] == sum_seq_len
         assert reorganized_k_pe.shape[0] == sum_seq_len
         assert max_seq_len_check == max_seq_len
-        return reorganized_kv_c_normed, reorganized_k_pe
+        return reorganized_kv_c_normed, reorganized_k_pe