|
32 | 32 | from tensorrt_llm.mapping import Mapping |
33 | 33 | from tensorrt_llm.models.modeling_utils import QuantConfig |
34 | 34 |
|
35 | | -from .kernel import triton_convert_req_index_to_global_index |
| 35 | +from .kernel import (triton_convert_req_index_to_global_index, |
| 36 | + triton_gather_k_cache) |
36 | 37 |
|
37 | 38 | ModelConfig = tensorrt_llm.bindings.ModelConfig |
38 | 39 |
|
@@ -94,24 +95,6 @@ def _compute_slot_mappings( |
94 | 95 | return fp8_indices, scale_indices |
95 | 96 |
|
96 | 97 |
|
97 | | -def _unravel_indices(flat_indices: torch.Tensor, |
98 | | - shape: Tuple[int, ...]) -> Tuple[torch.Tensor, ...]: |
99 | | - """ |
100 | | - Unravel indices into multiple dimensions. |
101 | | - """ |
102 | | - d3 = shape[3] |
103 | | - i3 = flat_indices % d3 |
104 | | - flat_indices = flat_indices // d3 |
105 | | - d2 = shape[2] |
106 | | - i2 = flat_indices % d2 |
107 | | - flat_indices = flat_indices // d2 |
108 | | - d1 = shape[1] |
109 | | - i1 = flat_indices % d1 |
110 | | - flat_indices = flat_indices // d1 |
111 | | - i0 = flat_indices |
112 | | - return i0, i1, i2, i3 |
113 | | - |
114 | | - |
115 | 98 | def rotate_activation(x: torch.Tensor) -> torch.Tensor: |
116 | 99 | assert x.dtype == torch.bfloat16 |
117 | 100 |
|
@@ -1402,68 +1385,6 @@ def _update_k_cache(self, k_fp8: torch.Tensor, k_scale: torch.Tensor, |
1402 | 1385 | k_cache, flat_indices_fp8, |
1403 | 1386 | flat_indices_scale) |
1404 | 1387 |
|
def _gather_k_cache_for_chunk(
    self,
    metadata: DSAtrtllmAttentionMetadata,
    chunk: IndexerPrefillChunkMetadata,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Read one prefill chunk's K entries back out of the indexer K cache.

    ``chunk.k_token_start`` / ``chunk.k_token_end`` are used as a slice
    into the extended slot mappings stored on ``metadata``
    (``slot_mapping_fp8_fullkv`` and ``slot_mapping_scale_fullkv``). Each
    selected entry is treated as the flat offset of the first byte of a
    token's record in the layer's cache buffer.

    Args:
        metadata: Attention metadata providing the KV cache manager and
            the extended slot mappings.
        chunk: Chunk metadata whose ``k_token_start``/``k_token_end``
            delimit the tokens to gather.

    Returns:
        ``(k_fp8, k_scale)`` where ``k_fp8`` is the gathered bytes viewed
        as ``torch.float8_e4m3fn`` with shape ``[num_k_tokens, head_dim]``
        and ``k_scale`` is the gathered bytes viewed as ``torch.float32``
        with shape ``[num_k_tokens, 1]``.
    """
    assert metadata.slot_mapping_fp8_fullkv is not None, \
        "_gather_k_cache_for_chunk requires extended slot mappings (only available with cached tokens)"

    cache = metadata.kv_cache_manager.get_indexer_k_cache_buffers(
        self.layer_idx)

    first, last = chunk.k_token_start, chunk.k_token_end
    n_tokens = last - first
    width_fp8 = self.head_dim
    width_scale = 4  # one float32 scale per token = 4 bytes

    # Per-token starting offsets for this chunk, taken straight from the
    # extended slot mappings.
    fp8_starts = metadata.slot_mapping_fp8_fullkv[first:last]
    scale_starts = metadata.slot_mapping_scale_fullkv[first:last]

    def _read_bytes(starts: torch.Tensor, width: int) -> torch.Tensor:
        # Expand every starting offset into `width` consecutive byte
        # offsets, then translate them into multi-dimensional indices so
        # the cache can be indexed without flattening it.
        lanes = torch.arange(width, device=cache.device).unsqueeze(0)
        flat = starts.unsqueeze(1) + lanes  # [n_tokens, width]
        return cache[_unravel_indices(flat, cache.shape)]

    # Reinterpret the raw bytes as FP8 values / float32 scales.
    k_fp8 = _read_bytes(fp8_starts, width_fp8).view(
        torch.float8_e4m3fn).view(n_tokens, width_fp8)
    k_scale = _read_bytes(scale_starts, width_scale).view(
        torch.float32).view(n_tokens, 1)

    return k_fp8, k_scale
1466 | | - |
1467 | 1388 | def sparse_attn_indexer( |
1468 | 1389 | self, |
1469 | 1390 | metadata: DSAtrtllmAttentionMetadata, |
@@ -1502,10 +1423,23 @@ def sparse_attn_indexer( |
1502 | 1423 | tp_rank = metadata.mapping.tp_rank |
1503 | 1424 | tp_size = metadata.mapping.tp_size |
1504 | 1425 |
|
| 1426 | + # Use the 2D pool data directly (contiguous) instead of the |
| 1427 | + # 4D view, because the 4D view may have strides that |
| 1428 | + # prevent flattening via .view(-1). |
| 1429 | + layer_offset = metadata.kv_cache_manager.layer_offsets[ |
| 1430 | + self.layer_idx] |
| 1431 | + gather_k_cache_pool = metadata.kv_cache_manager.indexer_k_cache_pool_per_layer[ |
| 1432 | + layer_offset] |
| 1433 | + |
1505 | 1434 | for chunk in metadata.indexer_prefill_chunks: |
1506 | | - # Gather K from cache for this chunk (dual to _update_k_cache) |
1507 | | - chunk_k_fp8, chunk_k_scale = self._gather_k_cache_for_chunk( |
1508 | | - metadata, chunk) |
| 1435 | + chunk_k_fp8, chunk_k_scale = triton_gather_k_cache( |
| 1436 | + gather_k_cache_pool, |
| 1437 | + metadata.slot_mapping_fp8_fullkv, |
| 1438 | + metadata.slot_mapping_scale_fullkv, |
| 1439 | + chunk.k_token_start, |
| 1440 | + chunk.k_token_end, |
| 1441 | + self.head_dim, |
| 1442 | + ) |
1509 | 1443 |
|
1510 | 1444 | chunk_num_token = chunk.token_end - chunk.token_start |
1511 | 1445 | apply_q_split = q_split_eligible and chunk_num_token >= q_split_threshold |
|
0 commit comments