25 changes: 12 additions & 13 deletions tests/ut/attention/test_mla_cp.py
@@ -451,10 +451,10 @@ def mock_all_gather_func(tensor, dim):
self.assertIsNone(decode_res)
self.assertIsNotNone(prefill_res)

@patch("torch.distributed.all_gather")
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("torch.distributed.all_to_all_single")
def test_process_attn_out_lse(self, mock_all_to_all_single,
mock_all_gather):
def test_process_attn_out_lse(self, mock_all_to_all_single, mock_pcp):
self.impl.dcp_size = 2
self.impl.pcp_size = 2

@@ -468,11 +468,10 @@ def test_process_attn_out_lse(self, mock_all_to_all_single,
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)

def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
def make_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)

mock_all_gather.side_effect = mock_all_gather_func
mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))

decode_metadata = MagicMock()
decode_metadata.actual_seq_lengths_q = MagicMock()
@@ -487,15 +486,16 @@ def mock_all_gather_func(tensor_list, tensor, group=None):
self.assertEqual(result[0].shape[1], N / self.impl.dcp_size)
self.assertEqual(result[0].shape[2], self.impl.kv_lora_rank + 1)

@patch("torch.distributed.all_gather")
@patch('vllm.distributed.parallel_state._PCP',
new_callable=lambda: MagicMock(spec=GroupCoordinator))
@patch("torch.distributed.all_to_all_single")
@patch('vllm_ascend.attention.mla_cp.get_forward_context')
@patch("torch_npu.atb.npu_multi_head_latent_attention")
@patch('torch_npu.npu_attention_update')
def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_npu_multi_head_latent_attention,
mock_get_forward_context,
mock_all_to_all_single, mock_all_gather):
mock_all_to_all_single, mock_pcp):
self.impl.dcp_size = 2
self.impl.pcp_size = 2
self.impl.num_kv_heads = 1
@@ -534,11 +534,10 @@ def test_forward_decode_pcp_dcp(self, mock_npu_attention_update,
mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_(
input)

def mock_all_gather_func(tensor_list, tensor, group=None):
tensor_list[0] = tensor
tensor_list[1] = tensor.clone()
def make_all_gather(ws):
return lambda tensor, dim: torch.cat([tensor] * ws, dim=dim)

mock_all_gather.side_effect = mock_all_gather_func
mock_pcp.all_gather = MagicMock(side_effect=make_all_gather(2))

self.impl._v_up_proj = MagicMock()
self.impl._v_up_proj.return_value = torch.randn(
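For context on the test change above: the patched collective moved from torch.distributed.all_gather, which fills a caller-provided list in place and returns None, to the GroupCoordinator.all_gather method on the mocked _PCP group, which returns a single tensor concatenated along the requested dim. A minimal, self-contained sketch of that mocking pattern (world size and shapes are illustrative, not taken from the fixtures):

import torch
from unittest.mock import MagicMock

world_size = 2  # stands in for pcp_size in the tests

def fake_group_all_gather(tensor: torch.Tensor, dim: int) -> torch.Tensor:
    # GroupCoordinator.all_gather returns one gathered tensor; repeating the
    # local tensor along `dim` emulates the other rank's contribution.
    return torch.cat([tensor] * world_size, dim=dim)

mock_pcp = MagicMock()
mock_pcp.all_gather = MagicMock(side_effect=fake_group_all_gather)

local = torch.randn(4, 8)
gathered = mock_pcp.all_gather(local, 0)
assert gathered.shape == (world_size * 4, 8)

The old mock, by contrast, had to mutate tensor_list[0] and tensor_list[1] precisely because torch.distributed.all_gather communicates through its output list rather than a return value.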
75 changes: 34 additions & 41 deletions vllm_ascend/attention/mla_cp.py
@@ -1120,26 +1120,37 @@ def _forward_decode_pcp_dcp(
lse=softmax_lse)

# Update out&lse
attn_out_lse_list = self._process_attn_out_lse(attn_output,
softmax_lse,
decode_meta)
attn_output = self._npu_attention_update(attn_out_lse_list)
attn_out_lse = self._process_attn_out_lse(attn_output, softmax_lse,
decode_meta)
attn_output = self._npu_attention_update(attn_out_lse)
return self._v_up_proj(attn_output)

def _npu_attention_update(
self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
attn_out_split_cp = []
attn_lse_split_cp = []

for attn_out_lse in attn_out_lse_list:
attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
*torch.split(attn_out_lse, [self.kv_lora_rank, 1], dim=-1))
attn_out_split_cp.append(attn_out_allgather)
attn_lse_split_cp.append(attn_lse_allgather)
attn_out, _ = torch_npu.npu_attention_update(attn_lse_split_cp,
attn_out_split_cp, 0)
attn_out = attn_out.view(-1, attn_out_lse_list[0].shape[1],
self.kv_lora_rank)
def _npu_attention_update(self,
attn_out_lse: torch.Tensor) -> torch.Tensor:
B_total, H_total, D_plus_1 = attn_out_lse.shape
S = B_total // self.pcp_size
H = H_total // self.dcp_size
D = self.kv_lora_rank
assert D_plus_1 == D + 1
# [PCP, S, DCP, H, D+1]
x = attn_out_lse.view(self.pcp_size, S, self.dcp_size, H, D_plus_1)
# [PCP, DCP, S, H, D+1]
x = x.permute(0, 2, 1, 3, 4).contiguous()
# Flatten [N, S, H, D+1], N = pcp_size * dcp_size
x = x.view(-1, S, H, D_plus_1)
# Split out lse
out_flat, lse_flat = torch.split(x, [D, 1],
dim=-1) # [N, S, H, D], [N, S, H, 1]
# out: [N, S, H, D] -> [N, S*H, D]
# lse: [N, S, H, 1] -> [N, S*H]
out_flat = out_flat.flatten(1, 2) # [N, S*H, D]
lse_flat = lse_flat.squeeze(-1).flatten(1) # [N, S*H]
# unbind to list
out_list = out_flat.unbind(0) # [S*H, D]
lse_list = lse_flat.unbind(0) # [S*H]
attn_out, _ = torch_npu.npu_attention_update(out_list, lse_list, 0)
attn_out = attn_out.view(-1, H, D)
return attn_out
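To make the shape bookkeeping in _npu_attention_update easier to follow, here is a plain-PyTorch sketch of the regrouping, together with an assumed log-sum-exp merge standing in for torch_npu.npu_attention_update (whose exact semantics are not spelled out in this diff); all sizes below are illustrative:

import torch

pcp_size, dcp_size, S, H, D = 2, 2, 3, 4, 8  # illustrative sizes
attn_out_lse = torch.randn(pcp_size * S, dcp_size * H, D + 1)

# Regroup [PCP*S, DCP*H, D+1] into N = pcp_size * dcp_size chunks of
# [S*H, D] outputs and [S*H] LSE values, mirroring the view/permute/split above.
x = attn_out_lse.view(pcp_size, S, dcp_size, H, D + 1)
x = x.permute(0, 2, 1, 3, 4).contiguous().view(-1, S, H, D + 1)
out_flat, lse_flat = torch.split(x, [D, 1], dim=-1)
out_list = out_flat.flatten(1, 2).unbind(0)            # N tensors of [S*H, D]
lse_list = lse_flat.squeeze(-1).flatten(1).unbind(0)   # N tensors of [S*H]

# Assumed reference for the merge: a log-sum-exp weighted combination of the
# partial outputs. This is the standard way to fuse attention partials and is
# only a stand-in for torch_npu.npu_attention_update, not its documented API.
out = torch.stack(out_list, dim=0)                      # [N, S*H, D]
lse = torch.stack(lse_list, dim=0)                      # [N, S*H]
weights = torch.softmax(lse, dim=0).unsqueeze(-1)       # [N, S*H, 1]
merged = (weights * out).sum(dim=0).view(-1, H, D)      # [S*H, D] -> [S, H, D]
assert merged.shape == (S, H, D)

The final view(-1, H, D) mirrors the reshape at the end of the method, so each of the N partial results contributes one LSE-weighted share per token and head.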

def _out_lse_reshape(self, attn_out: torch.Tensor,
@@ -1156,7 +1167,6 @@ def _process_attn_out_lse(
softmax_lse: torch.Tensor,
decode_meta: AscendMLADecodeMetadata,
) -> List[torch.Tensor]:
attn_out_lse_list = []
out_mask = decode_meta.batch_seq_mask[:, None,
None].expand_as(attn_output)
attn_output = torch.where(out_mask, 0, attn_output)
@@ -1175,30 +1185,13 @@ def _process_attn_out_lse(
dist.all_to_all_single(attn_out_lse_all2all,
attn_out_lse,
group=self.dcp_group)
# permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
if self.pcp_size > 1:
attn_out_lse = attn_out_lse_all2all.contiguous()
attn_out_lse_list = list(
torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
attn_out_lse = attn_out_lse_all2all.permute([2, 0, 1])

if self.pcp_size > 1:
# AllGather out&lse within PCP group
attn_out_lse_list = [
torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
]
dist.all_gather(attn_out_lse_list,
attn_out_lse,
group=self.pcp_group)
if self.dcp_size > 1 and self.pcp_size > 1:
attn_out_lse_list_pcp_dcp = []
for s in attn_out_lse_list:
attn_out_lse_list_split = list(
torch.chunk(s, self.dcp_size, dim=1))
attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
attn_out_lse_list = attn_out_lse_list_pcp_dcp

return attn_out_lse_list
# AllGather out&lse within CP group
attn_out_lse = get_pcp_group().all_gather(attn_out_lse.contiguous(), dim=0)

return attn_out_lse
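For orientation, a hedged sketch of the communication pattern in _process_attn_out_lse, with shapes annotated. The packing and pre-all_to_all permute sit in the elided part of the hunk, so they are reconstructed here from the surrounding comments and are assumptions; the PCP gather is modeled after the GroupCoordinator.all_gather behavior mocked in the tests.

import torch
import torch.distributed as dist

def process_attn_out_lse_sketch(attn_output, softmax_lse, batch_seq_mask,
                                dcp_size, dcp_group, pcp_all_gather):
    # attn_output: [bs, num_heads, kv_lora_rank]; softmax_lse: [bs, num_heads, 1]
    # Zero out rows belonging to masked sequences before the collectives.
    out_mask = batch_seq_mask[:, None, None].expand_as(attn_output)
    attn_output = torch.where(out_mask, 0, attn_output)

    # Pack out and lse so a single collective moves both: [bs, H, D+1].
    attn_out_lse = torch.cat([attn_output, softmax_lse], dim=-1)

    if dcp_size > 1:
        # Assumed pre-permute to [num_heads, v_head_dim+1, bs] so that
        # all_to_all_single splits along the head dim across the DCP group.
        attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
        recv = torch.empty_like(attn_out_lse)
        dist.all_to_all_single(recv, attn_out_lse, group=dcp_group)
        # permute back: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
        attn_out_lse = recv.permute([2, 0, 1])

    # AllGather along the batch dim within the PCP group; pcp_all_gather stands
    # in for get_pcp_group().all_gather and is assumed to concatenate on `dim`.
    return pcp_all_gather(attn_out_lse.contiguous(), dim=0)

# Single-process smoke test with dcp_size=1 and a concatenating stand-in
# for the PCP gather (mirrors the test mocks).
gathered = process_attn_out_lse_sketch(
    torch.randn(4, 8, 16), torch.randn(4, 8, 1),
    torch.zeros(4, dtype=torch.bool), dcp_size=1, dcp_group=None,
    pcp_all_gather=lambda t, dim: torch.cat([t, t], dim=dim))
assert gathered.shape == (8, 8, 17)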

def _reorg_kvcache(
self,