
Commit 1bc92b2

ganyi1996ppo authored and cwazai committed
[ROCm][Deepseekv3.2] Refactor Sparse Indexer as CustomOp (vllm-project#29287)
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: 陈建华 <1647430658@qq.com>
1 parent d1d5acf commit 1bc92b2

8 files changed: 982 additions & 323 deletions


vllm/_aiter_ops.py

Lines changed: 12 additions & 0 deletions
@@ -9,6 +9,10 @@
 import vllm.envs as envs
 from vllm.platforms import current_platform
 from vllm.utils.torch_utils import direct_register_custom_op, is_torch_equal_or_newer
+from vllm.v1.attention.ops.rocm_aiter_mla_sparse import (
+    rocm_aiter_sparse_attn_indexer,
+    rocm_aiter_sparse_attn_indexer_fake,
+)
 
 _FP8_DTYPE = current_platform.fp8_dtype()
 
@@ -1091,6 +1095,14 @@ def register_ops_once() -> None:
             dispatch_key=current_platform.dispatch_key,
         )
 
+        direct_register_custom_op(
+            op_name="rocm_aiter_sparse_attn_indexer",
+            op_func=rocm_aiter_sparse_attn_indexer,
+            mutates_args=["topk_indices_buffer"],
+            fake_impl=rocm_aiter_sparse_attn_indexer_fake,
+            dispatch_key=current_platform.dispatch_key,
+        )
+
         _OPS_REGISTERED = True
 
     @staticmethod
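Note that this registration mirrors the CUDA-path op registered in the new indexer module below: `topk_indices_buffer` is written in place, so it must be declared in `mutates_args`, and the fake implementation gives `torch.compile` a shape-only stand-in that never launches the kernel. A minimal, self-contained sketch of the same pattern using plain `torch.library` (the `toy::topk_indexer` op is hypothetical, not part of this commit):

import torch

# Toy op writing top-k indices into a caller-provided buffer, mirroring the
# mutates_args + fake_impl pattern used by direct_register_custom_op above.
@torch.library.custom_op("toy::topk_indexer", mutates_args=("out_buffer",))
def topk_indexer(logits: torch.Tensor, out_buffer: torch.Tensor) -> None:
    k = out_buffer.shape[-1]
    out_buffer.copy_(logits.topk(k, dim=-1).indices)

@topk_indexer.register_fake
def _(logits: torch.Tensor, out_buffer: torch.Tensor) -> None:
    # Shape-only stand-in: the real op only mutates out_buffer in place.
    pass

logits = torch.randn(4, 16)
buf = torch.empty(4, 2, dtype=torch.int64)
torch.ops.toy.topk_indexer(logits, buf)  # buf now holds each row's top-2 ids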

vllm/config/compilation.py

Lines changed: 1 addition & 0 deletions
@@ -611,6 +611,7 @@ class CompilationConfig:
         "vllm::gdn_attention_core",
         "vllm::kda_attention",
         "vllm::sparse_attn_indexer",
+        "vllm::rocm_aiter_sparse_attn_indexer",
     ]
 
     def compute_hash(self) -> str:
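Adding the op name here lets the compiler treat `vllm::rocm_aiter_sparse_attn_indexer` the same way as the existing `vllm::sparse_attn_indexer` entry. Separately, as the new layer's docstring below notes, a platform opts into its specific indexer path by naming the op in `custom_ops`. A minimal sketch of what that could look like from the offline API; the model name and the exact op-toggle spelling are illustrative assumptions, not taken from this diff:

from vllm import LLM
from vllm.config import CompilationConfig

# "+<name>" force-enables a specific custom op; "-<name>" would disable it.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # illustrative model choice
    compilation_config=CompilationConfig(
        custom_ops=["+sparse_attn_indexer"],
    ),
)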
Lines changed: 318 additions & 0 deletions
@@ -0,0 +1,318 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Custom Sparse Attention Indexer layers."""

import torch

from vllm._aiter_ops import rocm_aiter_ops
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits
from vllm.utils.torch_utils import direct_register_custom_op
from vllm.v1.attention.backends.mla.indexer import (
    DeepseekV32IndexerMetadata,
)
from vllm.v1.attention.ops.common import pack_seq_triton, unpack_seq_triton
from vllm.v1.worker.workspace import current_workspace_manager

if current_platform.is_cuda_alike():
    from vllm import _custom_ops as ops
elif current_platform.is_xpu():
    from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)

def sparse_attn_indexer(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
    # Careful! This will be None during a dummy run.
    attn_metadata = get_forward_context().attn_metadata
    fp8_dtype = current_platform.fp8_dtype()

    if not isinstance(attn_metadata, dict):
        # Reserve workspace for the indexer during the profiling run.
        current_workspace_manager().get_simultaneous(
            ((total_seq_lens, head_dim), torch.float8_e4m3fn),
            ((total_seq_lens, 4), torch.uint8),
        )
        return sparse_attn_indexer_fake(
            hidden_states,
            k_cache_prefix,
            kv_cache,
            q_fp8,
            k,
            weights,
            quant_block_size,
            scale_fmt,
            topk_tokens,
            head_dim,
            max_model_len,
            total_seq_lens,
            topk_indices_buffer,
        )
    attn_metadata = attn_metadata[k_cache_prefix]
    assert isinstance(attn_metadata, DeepseekV32IndexerMetadata)
    slot_mapping = attn_metadata.slot_mapping
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0
    num_decode_tokens = attn_metadata.num_decode_tokens

    ops.indexer_k_quant_and_cache(
        k,
        kv_cache,
        slot_mapping,
        quant_block_size,
        scale_fmt,
    )

    topk_indices_buffer[: hidden_states.shape[0]] = -1
    if has_prefill:
        prefill_metadata = attn_metadata.prefill

        # Get the full shared workspace buffers once (allocated on first use).
        workspace_manager = current_workspace_manager()
        k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
            ((total_seq_lens, head_dim), fp8_dtype),
            ((total_seq_lens, 4), torch.uint8),
        )
        for chunk in prefill_metadata.chunks:
            k_fp8 = k_fp8_full[: chunk.total_seq_lens]
            k_scale = k_scale_full[: chunk.total_seq_lens]
            ops.cp_gather_indexer_k_quant_cache(
                kv_cache,
                k_fp8,
                k_scale,
                chunk.block_table,
                chunk.cu_seq_lens,
            )

            logits = fp8_mqa_logits(
                q_fp8[chunk.token_start : chunk.token_end],
                (k_fp8, k_scale.view(torch.float32).flatten()),
                weights[chunk.token_start : chunk.token_end],
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
            )
            num_rows = logits.shape[0]

            topk_indices = topk_indices_buffer[
                chunk.token_start : chunk.token_end, :topk_tokens
            ]
            torch.ops._C.top_k_per_row_prefill(
                logits,
                chunk.cu_seqlen_ks,
                chunk.cu_seqlen_ke,
                topk_indices,
                num_rows,
                logits.stride(0),
                logits.stride(1),
                topk_tokens,
            )
    if has_decode:
        decode_metadata = attn_metadata.decode
        # The paged MQA kernel expects kv_cache shaped
        # [num_block, block_size, n_head, head_dim]; we only have
        # [num_block, block_size, head_dim], so insert a singleton head dim.
        kv_cache = kv_cache.unsqueeze(-2)
        decode_lens = decode_metadata.decode_lens
        if decode_metadata.requires_padding:
            # Pad in the edge case where a short chunked-prefill length is
            # < decode_threshold, since prefill and decode are split
            # non-strictly by decode_threshold (currently set to
            # 1 + the number of speculative tokens).
            padded_q_fp8_decode_tokens = pack_seq_triton(
                q_fp8[:num_decode_tokens], decode_lens
            )
        else:
            padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
                decode_lens.shape[0], -1, *q_fp8.shape[1:]
            )
        # TODO: move and optimize the logic below with Triton kernels.
        batch_size = padded_q_fp8_decode_tokens.shape[0]
        next_n = padded_q_fp8_decode_tokens.shape[1]
        assert batch_size == decode_metadata.seq_lens.shape[0]
        num_padded_tokens = batch_size * next_n

        logits = fp8_paged_mqa_logits(
            padded_q_fp8_decode_tokens,
            kv_cache,
            weights[:num_padded_tokens],
            decode_metadata.seq_lens,
            decode_metadata.block_table,
            decode_metadata.schedule_metadata,
            max_model_len=max_model_len,
        )

        num_rows = logits.shape[0]

        topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
        torch.ops._C.top_k_per_row_decode(
            logits,
            next_n,
            decode_metadata.seq_lens,
            topk_indices,
            num_rows,
            logits.stride(0),
            logits.stride(1),
            topk_tokens,
        )

        if decode_metadata.requires_padding:
            # If padded, unpack the top-k indices, dropping the padded
            # tokens, and write the result back to the head of the buffer.
            topk_indices = unpack_seq_triton(
                topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]),
                decode_lens,
            )
            topk_indices_buffer[:num_decode_tokens, : topk_indices.shape[-1]] = (
                topk_indices
            )

    return topk_indices_buffer

def sparse_attn_indexer_fake(
    hidden_states: torch.Tensor,
    k_cache_prefix: str,
    kv_cache: torch.Tensor,
    q_fp8: torch.Tensor,
    k: torch.Tensor,
    weights: torch.Tensor,
    quant_block_size: int,
    scale_fmt: str | None,
    topk_tokens: int,
    head_dim: int,
    max_model_len: int,
    total_seq_lens: int,
    topk_indices_buffer: torch.Tensor,
) -> torch.Tensor:
    return topk_indices_buffer


direct_register_custom_op(
    op_name="sparse_attn_indexer",
    op_func=sparse_attn_indexer,
    mutates_args=["topk_indices_buffer"],
    fake_impl=sparse_attn_indexer_fake,
    dispatch_key=current_platform.dispatch_key,
)

@CustomOp.register("sparse_attn_indexer")
class SparseAttnIndexer(CustomOp):
    """Sparse Attention Indexer custom-op layer.

    This layer is extracted as a separate custom op since it involves heavy
    custom kernels such as `mqa_logits`, `paged_mqa_logits`, and
    `top_k_per_row`. These kernels may require a backend-specific memory
    layout or implementation to achieve optimal performance.

    For now, the default native path dispatches to the CUDA backend. Other
    platforms may need to add the corresponding custom-op name
    `sparse_attn_indexer` to `custom_ops` in `CompilationConfig` to enable
    their platform-specific path.
    """

    def __init__(
        self,
        k_cache,
        quant_block_size: int,
        scale_fmt: str,
        topk_tokens: int,
        head_dim: int,
        max_model_len: int,
        max_total_seq_len: int,
        topk_indices_buffer: torch.Tensor,
    ):
        super().__init__()
        self.k_cache = k_cache
        self.quant_block_size = quant_block_size
        self.scale_fmt = scale_fmt
        self.topk_tokens = topk_tokens
        self.head_dim = head_dim
        self.max_model_len = max_model_len
        self.max_total_seq_len = max_total_seq_len
        self.topk_indices_buffer = topk_indices_buffer

    def forward_native(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if current_platform.is_cuda():
            return self.forward_cuda(hidden_states, q_fp8, k, weights)
        elif current_platform.is_rocm():
            return self.forward_hip(hidden_states, q_fp8, k, weights)
        else:
            raise NotImplementedError(
                "SparseAttnIndexer native forward is only implemented for "
                "the CUDA and ROCm platforms."
            )

    def forward_cuda(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        return torch.ops.vllm.sparse_attn_indexer(
            hidden_states,
            self.k_cache.prefix,
            self.k_cache.kv_cache[0],
            q_fp8,
            k,
            weights,
            self.quant_block_size,
            self.scale_fmt,
            self.topk_tokens,
            self.head_dim,
            self.max_model_len,
            self.max_total_seq_len,
            self.topk_indices_buffer,
        )

    def forward_hip(
        self,
        hidden_states: torch.Tensor,
        q_fp8: torch.Tensor,
        k: torch.Tensor,
        weights: torch.Tensor,
    ):
        if rocm_aiter_ops.is_enabled():
            return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
                hidden_states,
                self.k_cache.prefix,
                self.k_cache.kv_cache[0],
                q_fp8,
                k,
                weights,
                self.quant_block_size,
                self.scale_fmt,
                self.topk_tokens,
                self.head_dim,
                self.max_model_len,
                self.max_total_seq_len,
                self.topk_indices_buffer,
            )
        else:
            raise RuntimeError(
                "The sparse attention indexer custom op on ROCm requires "
                "ROCm AITER ops to be enabled."
            )
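For reference, a hedged usage sketch of the new layer. The dimension values (DeepSeek-V3.2-style indexer sizes) and the free variables `k_cache`, `hidden_states`, `q_fp8`, `k`, `weights`, `max_num_tokens`, `max_model_len`, and `max_total_seq_len` are illustrative stand-ins for values the model runner already has; they are not taken from this commit:

import torch

# The buffer is preallocated once; the op writes into it in place and
# returns it (rows are reset to -1 on every call).
topk_indices_buffer = torch.full(
    (max_num_tokens, 2048), -1, dtype=torch.int32, device="cuda"
)

indexer = SparseAttnIndexer(
    k_cache=k_cache,  # the indexer's KV-cache layer (assumed)
    quant_block_size=128,
    scale_fmt="ue8m0",
    topk_tokens=2048,
    head_dim=128,
    max_model_len=max_model_len,
    max_total_seq_len=max_total_seq_len,
    topk_indices_buffer=topk_indices_buffer,
)

# CustomOp dispatch routes to forward_cuda on NVIDIA and forward_hip on
# ROCm, which in turn call the registered torch custom ops shown above.
topk_indices = indexer(hidden_states, q_fp8, k, weights)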
