diff --git a/tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py b/tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py index e8c8816ba4c8..d08b44df88df 100644 --- a/tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py +++ b/tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import numpy as np @@ -331,22 +331,6 @@ def _mamba_tp(ri: RankInfo) -> Tuple[int, int]: return 1, 0 return ri.tp_size, ri.tp_rank - @staticmethod - def _build_layer_ptrs( - pool: PhysicalPool, - layer_offsets: Dict[int, int], - overlapping_layers: List[int], - slot: int, - ) -> np.ndarray: - """Build per-layer pointers for a given pool (conv or ssm) and slot.""" - ptrs = [] - for glid in overlapping_layers: - lid = layer_offsets[glid] - ptrs.append( - pool.base_address + lid * pool.num_slots * pool.slot_bytes + slot * pool.slot_bytes - ) - return np.array(ptrs, dtype=np.int64) - @staticmethod def _select_mapper( *, @@ -368,7 +352,7 @@ def _select_mapper( transfer_layers=transfer_layers, src_layer_off=0, dst_layer_off=0, - block_bytes_per_layer=self_pool.slot_bytes, + block_bytes_per_layer=self_pool.data_bytes, ) if is_conv: return ConvStateMismatchMapper( @@ -383,8 +367,8 @@ def _select_mapper( peer_tp_rank=peer_mamba_tp_rank, ) # SSM state: head-level granularity - self_nheads = self_pool.slot_bytes // self_mlg.ssm_bytes_per_head - peer_nheads = peer_pool.slot_bytes // peer_mlg.ssm_bytes_per_head + self_nheads = self_pool.data_bytes // self_mlg.ssm_bytes_per_head + peer_nheads = peer_pool.data_bytes // peer_mlg.ssm_bytes_per_head return MambaHeadMismatchMapper( transfer_layers=transfer_layers, src_layer_off=0, @@ -430,18 +414,14 @@ def build_mamba_frags( (self_mlg.conv_states, peer_mlg.conv_states, True), (self_mlg.ssm_states, peer_mlg.ssm_states, False), ]: - src_ptrs = MambaPolicy._build_layer_ptrs( - self_pool, self_mlg.mamba_layer_offsets, overlapping_layers, src_slot - ) - dst_ptrs = MambaPolicy._build_layer_ptrs( - peer_pool, peer_mlg.mamba_layer_offsets, overlapping_layers, dst_slot - ) + src_ptrs = self_mlg.resolve_layer_ptrs(self_pool, overlapping_layers, src_slot) + dst_ptrs = peer_mlg.resolve_layer_ptrs(peer_pool, overlapping_layers, dst_slot) src_region = SpecRegion( - memory=MemRegionGroup(ptrs=src_ptrs, bytes_per_region=self_pool.slot_bytes) + memory=MemRegionGroup(ptrs=src_ptrs, bytes_per_region=self_pool.data_bytes) ) dst_region = SpecRegion( - memory=MemRegionGroup(ptrs=dst_ptrs, bytes_per_region=peer_pool.slot_bytes) + memory=MemRegionGroup(ptrs=dst_ptrs, bytes_per_region=peer_pool.data_bytes) ) mapper = MambaPolicy._select_mapper( diff --git a/tensorrt_llm/_torch/disaggregation/native/peer.py b/tensorrt_llm/_torch/disaggregation/native/peer.py index 0d305cee40df..38c0816c5def 100644 --- a/tensorrt_llm/_torch/disaggregation/native/peer.py +++ b/tensorrt_llm/_torch/disaggregation/native/peer.py @@ -255,8 +255,8 @@ def get_kv_map( peer_layer_offset=peer_layer_offset, self_pool_num_layers=self_num_layers, peer_pool_num_layers=peer_num_layers, - self_pool_slot_bytes=self_phys.slot_bytes, - peer_pool_slot_bytes=peer_phys.slot_bytes, + self_pool_slot_bytes=self_phys.slot_stride, + peer_pool_slot_bytes=peer_phys.slot_stride, ) self._kv_map_cache[cache_key] = mapper diff --git a/tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py b/tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py index c83d26c26f29..36636cca3ab5 100644 --- a/tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py +++ b/tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py @@ -62,7 +62,7 @@ def extract( SpecRegion whose memory is a MemRegionGroup containing all blocks described by region_ids. - For KV cache: each ptr = base_address + slot_id * slot_bytes, pointing + For KV cache: each ptr = base_address + slot_id * slot_stride, pointing to the start of a full slot. The slot contains buffer entries for all layers in this layer_group laid out contiguously from offset 0. @@ -75,7 +75,7 @@ def extract( pool = get_physical_pool(self._page_table, layer_group_id, pv.pool_idx) base_ptr = pool.base_address - block_size = pool.slot_bytes + block_size = pool.slot_stride # KV cache: filter out invalid block_ids (BAD_PAGE_INDEX = -1) valid = region_ids >= 0 @@ -104,15 +104,19 @@ def _build_layer_group_for_mamba( conv_state = manager._impl.mamba_cache.conv ssm_state = manager._impl.mamba_cache.temporal + conv_slot_bytes = conv_state.stride(1) * conv_state.element_size() conv_pool = PhysicalPool( base_address=conv_state.data_ptr(), - slot_bytes=conv_state.stride(1) * conv_state.element_size(), + slot_stride=conv_slot_bytes, + data_bytes=conv_slot_bytes, num_slots=conv_state.shape[1], ) + ssm_slot_bytes = ssm_state.stride(1) * ssm_state.element_size() ssm_pool = PhysicalPool( base_address=ssm_state.data_ptr(), - slot_bytes=ssm_state.stride(1) * ssm_state.element_size(), + slot_stride=ssm_slot_bytes, + data_bytes=ssm_slot_bytes, num_slots=ssm_state.shape[1], ) @@ -190,7 +194,10 @@ def build_page_table(kv_cache_manager: KVCacheManager) -> KVCachePageTable: entries.append((lid, int(DataRole.VALUE), base_offset + buffer_size, buffer_size)) kv_physical = PhysicalPool( - base_address=base_addr, slot_bytes=slot_bytes, num_slots=num_blocks + base_address=base_addr, + slot_stride=slot_bytes, + data_bytes=slot_bytes, + num_slots=num_blocks, ) kv_view = PoolView(pool_idx=0, buffer_entries=np.array(entries, dtype=BUFFER_ENTRY_DTYPE)) physical_pools = [kv_physical] @@ -207,7 +214,8 @@ def build_page_table(kv_cache_manager: KVCacheManager) -> KVCachePageTable: indexer_slot_bytes = per_block_elems * indexer_pool.element_size() indexer_physical = PhysicalPool( base_address=int(indexer_pool.data_ptr()), - slot_bytes=indexer_slot_bytes, + slot_stride=indexer_slot_bytes, + data_bytes=indexer_slot_bytes, num_slots=num_blocks, ) indexer_view = PoolView( @@ -366,7 +374,8 @@ def _role_str_to_enum(role: str) -> DataRole: pools=[ PhysicalPool( base_address=int(pool_group._pools[pi].slot_address(0)), - slot_bytes=int(pool_group._pools[pi].slot_size), + slot_stride=int(pool_group._pools[pi].slot_size), + data_bytes=int(pool_group._pools[pi].slot_size), num_slots=int(pool_group._pools[pi].num_slots), ) for pi in range(num_pools) diff --git a/tensorrt_llm/_torch/disaggregation/resource/page.py b/tensorrt_llm/_torch/disaggregation/resource/page.py index 8e96e3889745..641d0b67a4b0 100644 --- a/tensorrt_llm/_torch/disaggregation/resource/page.py +++ b/tensorrt_llm/_torch/disaggregation/resource/page.py @@ -18,21 +18,31 @@ @dataclass class PhysicalPool: base_address: int # uint64 - slot_bytes: int + slot_stride: int # bytes between consecutive slots (for addressing) + data_bytes: int # actual data bytes per slot (for transfer size) num_slots: int def to_dict(self) -> dict: return { "base_address": int(self.base_address), - "slot_bytes": int(self.slot_bytes), + "slot_stride": int(self.slot_stride), + "data_bytes": int(self.data_bytes), "num_slots": int(self.num_slots), } @staticmethod def from_dict(data: dict) -> "PhysicalPool": + # Backward compat: old format has "slot_bytes" only. + if "slot_stride" in data: + slot_stride = int(data["slot_stride"]) + data_bytes = int(data["data_bytes"]) + else: + slot_stride = int(data["slot_bytes"]) + data_bytes = slot_stride return PhysicalPool( base_address=int(data["base_address"]), - slot_bytes=int(data["slot_bytes"]), + slot_stride=slot_stride, + data_bytes=data_bytes, num_slots=int(data["num_slots"]), ) @@ -158,6 +168,23 @@ class MambaLayerGroup(LayerGroup): conv_section_bytes: Optional[List[int]] = None ssm_bytes_per_head: Optional[int] = None + def resolve_layer_ptrs( + self, + pool: PhysicalPool, + overlapping_layers: List[int], + slot: int, + ) -> np.ndarray: + """Compute per-layer physical addresses for a given pool and slot.""" + ptrs = np.empty(len(overlapping_layers), dtype=np.int64) + for i, glid in enumerate(overlapping_layers): + lid = self.mamba_layer_offsets[glid] + ptrs[i] = ( + pool.base_address + + lid * pool.num_slots * pool.slot_stride + + slot * pool.slot_stride + ) + return ptrs + def to_dict(self) -> dict: return { "pool_group_idx": int(self.pool_group_idx), diff --git a/tensorrt_llm/_torch/disaggregation/resource/utils.py b/tensorrt_llm/_torch/disaggregation/resource/utils.py index c97dd2721325..247016a27aa9 100644 --- a/tensorrt_llm/_torch/disaggregation/resource/utils.py +++ b/tensorrt_llm/_torch/disaggregation/resource/utils.py @@ -23,14 +23,14 @@ class PoolRole(Enum): def get_pool_bytes(pool: PhysicalPool) -> int: """Total bytes across all slots in this pool.""" - return pool.slot_bytes * pool.num_slots + return pool.slot_stride * pool.num_slots def get_slot_address(pool: PhysicalPool, slot_id: int) -> int: """Base address of *slot_id*.""" if slot_id >= pool.num_slots: raise ValueError(f"slot_id {slot_id} >= num_slots {pool.num_slots}") - return pool.base_address + slot_id * pool.slot_bytes + return pool.base_address + slot_id * pool.slot_stride # ------------------------------------------------------------------------- @@ -160,7 +160,7 @@ def get_device_pointer( raise ValueError(f"slot_id {slot_id} >= num_slots {pool.num_slots}") for e in pool_view.buffer_entries: if int(e["local_layer_id"]) == int(local_layer_id) and int(e["role"]) == int(role): - return int(pool.base_address) + int(slot_id) * int(pool.slot_bytes) + int(e["offset"]) + return int(pool.base_address) + int(slot_id) * int(pool.slot_stride) + int(e["offset"]) raise ValueError(f"Buffer not found: local_layer_id={local_layer_id}, role={role}") @@ -179,7 +179,7 @@ def get_unique_pool_memory_descs( if isinstance(lg, MambaLayerGroup): num_mamba_layers = len(lg.mamba_layer_offsets) for pool in [lg.conv_states, lg.ssm_states]: - pool_size = num_mamba_layers * pool.num_slots * pool.slot_bytes + pool_size = num_mamba_layers * pool.num_slots * pool.slot_stride pool_key = (pool.base_address, pool_size) if pool_key not in unique_pools: unique_pools[pool_key] = pool_counter diff --git a/tests/unittest/disaggregated/region/test_page.py b/tests/unittest/disaggregated/region/test_page.py index c11391b530a7..066e04cb6159 100644 --- a/tests/unittest/disaggregated/region/test_page.py +++ b/tests/unittest/disaggregated/region/test_page.py @@ -23,18 +23,20 @@ def _make_buffer_entries(): def test_physical_pool_construction(): - pool = PhysicalPool(base_address=0x10000, slot_bytes=256, num_slots=4) + pool = PhysicalPool(base_address=0x10000, slot_stride=256, data_bytes=256, num_slots=4) assert pool.base_address == 0x10000 - assert pool.slot_bytes == 256 + assert pool.slot_stride == 256 + assert pool.data_bytes == 256 assert pool.num_slots == 4 def test_physical_pool_roundtrip(): - pool = PhysicalPool(base_address=0x10000, slot_bytes=256, num_slots=4) + pool = PhysicalPool(base_address=0x10000, slot_stride=256, data_bytes=256, num_slots=4) d = pool.to_dict() restored = PhysicalPool.from_dict(d) assert restored.base_address == pool.base_address - assert restored.slot_bytes == pool.slot_bytes + assert restored.slot_stride == pool.slot_stride + assert restored.data_bytes == pool.data_bytes assert restored.num_slots == pool.num_slots @@ -90,7 +92,9 @@ def test_kv_cache_page_table_roundtrip(): ], pool_groups=[ PhysicalPoolGroup( - pools=[PhysicalPool(base_address=0x10000, slot_bytes=256, num_slots=4)] + pools=[ + PhysicalPool(base_address=0x10000, slot_stride=256, data_bytes=256, num_slots=4) + ] ) ], ) diff --git a/tests/unittest/disaggregated/test_extractor.py b/tests/unittest/disaggregated/test_extractor.py index 57b38ce317ea..4534df09f81d 100644 --- a/tests/unittest/disaggregated/test_extractor.py +++ b/tests/unittest/disaggregated/test_extractor.py @@ -214,7 +214,7 @@ def test_layer_group_meta_serialization(): [(0, int(DataRole.KEY), 0, 256), (0, int(DataRole.VALUE), 256, 256)], dtype=BUFFER_ENTRY_DTYPE, ) - kv_pool = PhysicalPool(base_address=1000, slot_bytes=512, num_slots=10) + kv_pool = PhysicalPool(base_address=1000, slot_stride=512, data_bytes=512, num_slots=10) pv = PoolView(pool_idx=0, buffer_entries=entries) local_layers = [ LocalLayer(local_layer_id=0, global_layer_id=0), @@ -245,8 +245,8 @@ def test_layer_group_meta_serialization(): def test_mamba_layer_group_serialization(): from tensorrt_llm._torch.disaggregation.resource.page import MambaLayerGroup, PhysicalPool - conv_pool = PhysicalPool(base_address=1000, slot_bytes=128, num_slots=10) - ssm_pool = PhysicalPool(base_address=8000, slot_bytes=256, num_slots=8) + conv_pool = PhysicalPool(base_address=1000, slot_stride=128, data_bytes=128, num_slots=10) + ssm_pool = PhysicalPool(base_address=8000, slot_stride=256, data_bytes=256, num_slots=8) mlg = MambaLayerGroup( pool_group_idx=1, mamba_layer_offsets={10: 0, 11: 1, 12: 2}, @@ -266,10 +266,12 @@ def test_mamba_layer_group_serialization(): assert isinstance(restored, MambaLayerGroup) assert restored.mamba_layer_offsets == {10: 0, 11: 1, 12: 2} assert restored.conv_states.base_address == 1000 - assert restored.conv_states.slot_bytes == 128 + assert restored.conv_states.slot_stride == 128 + assert restored.conv_states.data_bytes == 128 assert restored.conv_states.num_slots == 10 assert restored.ssm_states.base_address == 8000 - assert restored.ssm_states.slot_bytes == 256 + assert restored.ssm_states.slot_stride == 256 + assert restored.ssm_states.data_bytes == 256 assert restored.ssm_states.num_slots == 8 assert restored.conv_section_bytes == [512, 256, 256] assert restored.ssm_bytes_per_head == 128 @@ -306,8 +308,8 @@ def test_mixed_page_table_serialization(): mamba_lg = MambaLayerGroup( pool_group_idx=1, mamba_layer_offsets={1: 0, 2: 1}, - conv_states=PhysicalPool(base_address=5000, slot_bytes=1024, num_slots=4), - ssm_states=PhysicalPool(base_address=9000, slot_bytes=2048, num_slots=4), + conv_states=PhysicalPool(base_address=5000, slot_stride=1024, data_bytes=1024, num_slots=4), + ssm_states=PhysicalPool(base_address=9000, slot_stride=2048, data_bytes=2048, num_slots=4), conv_section_bytes=[256, 128, 128], ssm_bytes_per_head=64, ) @@ -315,7 +317,7 @@ def test_mixed_page_table_serialization(): page_table = KVCachePageTable( tokens_per_block=16, layer_groups=[attn_lg, mamba_lg], - pool_groups=[PhysicalPoolGroup(pools=[PhysicalPool(1000, 512, 10)])], + pool_groups=[PhysicalPoolGroup(pools=[PhysicalPool(1000, 512, 512, 10)])], ) d = page_table.to_dict() diff --git a/tests/unittest/disaggregated/test_kv_transfer.py b/tests/unittest/disaggregated/test_kv_transfer.py index d7808688bc24..da81b92f3207 100644 --- a/tests/unittest/disaggregated/test_kv_transfer.py +++ b/tests/unittest/disaggregated/test_kv_transfer.py @@ -230,7 +230,7 @@ def create_transfer_worker_setup( for pg in page_table.pool_groups: for pool_desc in pg.pools: key = pool_desc.base_address - pool_bytes = pool_desc.slot_bytes * pool_desc.num_slots + pool_bytes = pool_desc.slot_stride * pool_desc.num_slots if key not in unique_pools or pool_bytes > unique_pools[key]: unique_pools[key] = pool_bytes @@ -370,7 +370,7 @@ def create_transfer_worker_setup( for pg in gen_page_table.pool_groups: for pool_desc in pg.pools: key = pool_desc.base_address - pool_bytes = pool_desc.slot_bytes * pool_desc.num_slots + pool_bytes = pool_desc.slot_stride * pool_desc.num_slots if key not in gen_unique_pools or pool_bytes > gen_unique_pools[key]: gen_unique_pools[key] = pool_bytes gen_element_bytes = get_size_in_bytes(1, gen_kv_cache_manager.dtype) diff --git a/tests/unittest/disaggregated/test_peer.py b/tests/unittest/disaggregated/test_peer.py index d7a9f9913f8b..d4e8ee0b419c 100644 --- a/tests/unittest/disaggregated/test_peer.py +++ b/tests/unittest/disaggregated/test_peer.py @@ -48,7 +48,7 @@ def make_page_table(pool_ptrs=None, block_bytes=None, global_layer_ids=None): PoolView(pool_idx=pi, buffer_entries=buffer_entries) for pi in range(len(pool_ptrs)) ] physical_pools = [ - PhysicalPool(base_address=ptr, slot_bytes=bs, num_slots=128) + PhysicalPool(base_address=ptr, slot_stride=bs, data_bytes=bs, num_slots=128) for ptr, bs in zip(pool_ptrs, block_bytes) ] @@ -62,8 +62,12 @@ def make_page_table(pool_ptrs=None, block_bytes=None, global_layer_ids=None): mamba_lg = MambaLayerGroup( pool_group_idx=1, mamba_layer_offsets={100: 0, 101: 1}, - conv_states=PhysicalPool(base_address=0xA000, slot_bytes=2048, num_slots=128), - ssm_states=PhysicalPool(base_address=0xB000, slot_bytes=4096, num_slots=128), + conv_states=PhysicalPool( + base_address=0xA000, slot_stride=2048, data_bytes=2048, num_slots=128 + ), + ssm_states=PhysicalPool( + base_address=0xB000, slot_stride=4096, data_bytes=4096, num_slots=128 + ), conv_section_bytes=[512, 256, 256], ssm_bytes_per_head=64, )