Skip to content

Commit 6601758

Browse files
authored
[https://nvbugs/5983390][perf] Kernel fusions in _gather_k_cache_for_chunk of Indexer in DSA (NVIDIA#12322)
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent 45c1d93 commit 6601758

File tree

3 files changed

+315
-84
lines changed

3 files changed

+315
-84
lines changed

tensorrt_llm/_torch/attention_backend/sparse/dsa.py

Lines changed: 18 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@
3232
from tensorrt_llm.mapping import Mapping
3333
from tensorrt_llm.models.modeling_utils import QuantConfig
3434

35-
from .kernel import triton_convert_req_index_to_global_index
35+
from .kernel import (triton_convert_req_index_to_global_index,
36+
triton_gather_k_cache)
3637

3738
ModelConfig = tensorrt_llm.bindings.ModelConfig
3839

@@ -94,24 +95,6 @@ def _compute_slot_mappings(
9495
return fp8_indices, scale_indices
9596

9697

97-
def _unravel_indices(flat_indices: torch.Tensor,
98-
shape: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]:
99-
"""
100-
Unravel indices into multiple dimensions.
101-
"""
102-
d3 = shape[3]
103-
i3 = flat_indices % d3
104-
flat_indices = flat_indices // d3
105-
d2 = shape[2]
106-
i2 = flat_indices % d2
107-
flat_indices = flat_indices // d2
108-
d1 = shape[1]
109-
i1 = flat_indices % d1
110-
flat_indices = flat_indices // d1
111-
i0 = flat_indices
112-
return i0, i1, i2, i3
113-
114-
11598
def rotate_activation(x: torch.Tensor) -> torch.Tensor:
11699
assert x.dtype == torch.bfloat16
117100

@@ -1402,68 +1385,6 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor,
14021385
k_cache, flat_indices_fp8,
14031386
flat_indices_scale)
14041387

1405-
def _gather_k_cache_for_chunk(
    self,
    metadata: DSAtrtllmAttentionMetadata,
    chunk: IndexerPrefillChunkMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Gather K values from indexer cache for a specific chunk.

    Uses pre-computed extended slot mappings that cover cached + current batch
    context tokens. chunk.k_token_start/k_token_end directly index into the
    extended slot mapping.

    The cache pool is addressed in raw bytes: each slot-mapping entry is a
    flat byte offset into the (uint8) cache buffer, so FP8 payload and
    float32 scales are gathered as bytes and reinterpreted via dtype views.

    Args:
        metadata: Attention metadata
        chunk: Chunk metadata with k_token_start/end as indices into extended slot mapping

    Returns:
        k_fp8: FP8 quantized k tensor, shape [num_k_tokens, head_dim]
        k_scale: Scaling factors, shape [num_k_tokens, 1]
    """
    assert metadata.slot_mapping_fp8_fullkv is not None, \
        "_gather_k_cache_for_chunk requires extended slot mappings (only available with cached tokens)"

    # Per-layer indexer K cache buffer (byte-addressable pool).
    k_cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
        self.layer_idx)

    head_dim = self.head_dim
    scale_size = 4  # float32 = 4 bytes

    # Extract slot mappings using chunk's k_token_start/end
    # These indices point directly into the extended slot mapping array
    k_token_start = chunk.k_token_start
    k_token_end = chunk.k_token_end
    num_k_tokens = k_token_end - k_token_start

    slot_mapping_fp8_chunk = metadata.slot_mapping_fp8_fullkv[
        k_token_start:k_token_end]
    slot_mapping_scale_chunk = metadata.slot_mapping_scale_fullkv[
        k_token_start:k_token_end]

    # Vectorized gather using pre-computed slot mappings
    # Gather FP8 data: each token's base byte offset is expanded into
    # head_dim consecutive byte addresses, then unraveled into the
    # cache's multi-dimensional index space for advanced indexing.
    byte_offsets_fp8 = torch.arange(
        head_dim, device=k_cache.device).unsqueeze(0)  # [1, head_dim]
    gather_indices_fp8 = slot_mapping_fp8_chunk.unsqueeze(
        1) + byte_offsets_fp8  # [num_k_tokens, head_dim]
    gather_indices_fp8 = _unravel_indices(gather_indices_fp8, k_cache.shape)
    k_fp8_bytes = k_cache[gather_indices_fp8]
    # Reinterpret the gathered bytes as FP8 values (1 byte per element).
    k_fp8 = k_fp8_bytes.view(torch.float8_e4m3fn).view(
        num_k_tokens, head_dim)

    # Gather scale data: 4 consecutive bytes per token, reinterpreted
    # as one float32 scale each.
    byte_offsets_scale = torch.arange(
        scale_size, device=k_cache.device).unsqueeze(0)  # [1, 4]
    gather_indices_scale = slot_mapping_scale_chunk.unsqueeze(
        1) + byte_offsets_scale  # [num_k_tokens, 4]
    gather_indices_scale = _unravel_indices(gather_indices_scale,
                                            k_cache.shape)
    k_scale_bytes = k_cache[gather_indices_scale]
    k_scale = k_scale_bytes.view(torch.float32).view(num_k_tokens, 1)

    return k_fp8, k_scale
1466-
14671388
def sparse_attn_indexer(
14681389
self,
14691390
metadata: DSAtrtllmAttentionMetadata,
@@ -1502,10 +1423,23 @@ def sparse_attn_indexer(
15021423
tp_rank = metadata.mapping.tp_rank
15031424
tp_size = metadata.mapping.tp_size
15041425

1426+
# Use the 2D pool data directly (contiguous) instead of the
1427+
# 4D view, because the 4D view may have strides that
1428+
# prevent flattening via .view(-1).
1429+
layer_offset = metadata.kv_cache_manager.layer_offsets[
1430+
self.layer_idx]
1431+
gather_k_cache_pool = metadata.kv_cache_manager.indexer_k_cache_pool_per_layer[
1432+
layer_offset]
1433+
15051434
for chunk in metadata.indexer_prefill_chunks:
1506-
# Gather K from cache for this chunk (dual to _update_k_cache)
1507-
chunk_k_fp8, chunk_k_scale = self._gather_k_cache_for_chunk(
1508-
metadata, chunk)
1435+
chunk_k_fp8, chunk_k_scale = triton_gather_k_cache(
1436+
gather_k_cache_pool,
1437+
metadata.slot_mapping_fp8_fullkv,
1438+
metadata.slot_mapping_scale_fullkv,
1439+
chunk.k_token_start,
1440+
chunk.k_token_end,
1441+
self.head_dim,
1442+
)
15091443

15101444
chunk_num_token = chunk.token_end - chunk.token_start
15111445
apply_q_split = q_split_eligible and chunk_num_token >= q_split_threshold

tensorrt_llm/_torch/attention_backend/sparse/kernel.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1936,3 +1936,121 @@ def triton_convert_req_index_to_global_index(
19361936
out_stride1,
19371937
)
19381938
return out
1939+
1940+
1941+
########################################################
1942+
# Fused K cache gather kernel
1943+
########################################################
1944+
1945+
1946+
@triton.jit
def _triton_gather_k_cache_kernel(
    k_cache_ptr,
    slot_fp8_ptr,
    slot_scale_ptr,
    out_fp8_ptr,
    out_scale_ptr,
    k_token_start,
    num_k_tokens,
    HEAD_DIM: tl.constexpr,
    SCALE_BYTES: tl.constexpr,
    BLOCK_TOKENS: tl.constexpr,
):
    # Gather FP8 K bytes and scale bytes from a flat (byte-addressed)
    # cache into dense per-chunk output buffers. One program handles
    # BLOCK_TOKENS tokens; grid size = cdiv(num_k_tokens, BLOCK_TOKENS).
    #
    # k_cache_ptr:    flat uint8 cache pool.
    # slot_fp8_ptr:   per-token base byte offset of the FP8 payload.
    # slot_scale_ptr: per-token base byte offset of the float32 scale.
    # out_fp8_ptr:    dense output, num_k_tokens * HEAD_DIM bytes.
    # out_scale_ptr:  dense output, num_k_tokens * SCALE_BYTES bytes.
    pid = tl.program_id(0)
    # int64 to keep byte-offset arithmetic safe for large pools.
    token_offsets = (pid * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)).to(
        tl.int64)
    token_mask = token_offsets < num_k_tokens

    # Slot mappings are indexed from k_token_start (chunk-relative view
    # into the full extended mapping).
    fp8_base = tl.load(slot_fp8_ptr + k_token_start + token_offsets,
                       mask=token_mask,
                       other=0)
    scale_base = tl.load(slot_scale_ptr + k_token_start + token_offsets,
                         mask=token_mask,
                         other=0)

    # FP8 payload: HEAD_DIM consecutive bytes per token, scattered source
    # addresses -> contiguous destination rows.
    byte_offsets = tl.arange(0, HEAD_DIM).to(tl.int64)
    src_fp8 = fp8_base[:, None] + byte_offsets[None, :]
    dst_fp8 = token_offsets[:, None] * HEAD_DIM + byte_offsets[None, :]
    gather_mask = token_mask[:, None]

    fp8_data = tl.load(k_cache_ptr + src_fp8, mask=gather_mask, other=0)
    tl.store(out_fp8_ptr + dst_fp8, fp8_data, mask=gather_mask)

    # Scales: SCALE_BYTES consecutive bytes per token (raw float32 bytes;
    # reinterpretation happens on the host side via dtype views).
    scale_byte_offsets = tl.arange(0, SCALE_BYTES).to(tl.int64)
    src_scale = scale_base[:, None] + scale_byte_offsets[None, :]
    dst_scale = token_offsets[:,
                              None] * SCALE_BYTES + scale_byte_offsets[None, :]

    scale_data = tl.load(k_cache_ptr + src_scale, mask=gather_mask, other=0)
    tl.store(out_scale_ptr + dst_scale, scale_data, mask=gather_mask)
1986+
1987+
1988+
def triton_gather_k_cache(
    k_cache: torch.Tensor,
    slot_mapping_fp8: torch.Tensor,
    slot_mapping_scale: torch.Tensor,
    k_token_start: int,
    k_token_end: int,
    head_dim: int,
):
    """Gather K FP8 values and scales from the indexer K cache for a chunk.

    Single fused Triton launch replacing ``_gather_k_cache_for_chunk``,
    which issued ~8-12 small PyTorch ops (arange, unsqueeze, broadcast add,
    _unravel_indices, advanced indexing) per chunk. Pure data movement —
    bit-exact with the original.

    Args:
        k_cache: Indexer K cache pool data (2D contiguous), uint8.
        slot_mapping_fp8: Flat byte indices for FP8 data
            ``[total_kv_len]``, int64.
        slot_mapping_scale: Flat byte indices for scale data
            ``[total_kv_len]``, int64.
        k_token_start: Start index into slot mapping arrays.
        k_token_end: End index into slot mapping arrays.
        head_dim: FP8 head dimension (typically 128).

    Returns:
        Tuple of (k_fp8, k_scale):
            k_fp8: ``[num_k_tokens, head_dim]``, float8_e4m3fn.
            k_scale: ``[num_k_tokens, 1]``, float32.
    """
    num_k_tokens = k_token_end - k_token_start
    device = k_cache.device

    # Empty chunk: nothing to launch, return correctly-typed empties.
    if num_k_tokens == 0:
        empty_fp8 = torch.empty(0,
                                head_dim,
                                dtype=torch.float8_e4m3fn,
                                device=device)
        empty_scale = torch.empty(0, 1, dtype=torch.float32, device=device)
        return empty_fp8, empty_scale

    SCALE_BYTES = 4  # sizeof(float32)
    BLOCK_TOKENS = 32

    # Raw byte buffers; the kernel writes bytes, dtype views reinterpret
    # them below without a copy.
    out_fp8 = torch.empty((num_k_tokens, head_dim),
                          dtype=torch.uint8,
                          device=device)
    out_scale = torch.empty((num_k_tokens, SCALE_BYTES),
                            dtype=torch.uint8,
                            device=device)

    num_blocks = triton.cdiv(num_k_tokens, BLOCK_TOKENS)
    _triton_gather_k_cache_kernel[(num_blocks, )](
        k_cache.reshape(-1),
        slot_mapping_fp8,
        slot_mapping_scale,
        out_fp8.view(-1),
        out_scale.view(-1),
        k_token_start,
        num_k_tokens,
        HEAD_DIM=head_dim,
        SCALE_BYTES=SCALE_BYTES,
        BLOCK_TOKENS=BLOCK_TOKENS,
    )

    k_fp8 = out_fp8.view(torch.float8_e4m3fn)
    k_scale = out_scale.view(torch.float32).view(num_k_tokens, 1)
    return k_fp8, k_scale

0 commit comments

Comments (0)