36 changes: 8 additions & 28 deletions tensorrt_llm/_torch/disaggregation/native/mixers/ssm/peer.py
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import List, Optional, Tuple

import numpy as np

@@ -331,22 +331,6 @@ def _mamba_tp(ri: RankInfo) -> Tuple[int, int]:
return 1, 0
return ri.tp_size, ri.tp_rank

@staticmethod
def _build_layer_ptrs(
pool: PhysicalPool,
layer_offsets: Dict[int, int],
overlapping_layers: List[int],
slot: int,
) -> np.ndarray:
"""Build per-layer pointers for a given pool (conv or ssm) and slot."""
ptrs = []
for glid in overlapping_layers:
lid = layer_offsets[glid]
ptrs.append(
pool.base_address + lid * pool.num_slots * pool.slot_bytes + slot * pool.slot_bytes
)
return np.array(ptrs, dtype=np.int64)

@staticmethod
def _select_mapper(
*,
@@ -368,7 +352,7 @@ def _select_mapper(
transfer_layers=transfer_layers,
src_layer_off=0,
dst_layer_off=0,
block_bytes_per_layer=self_pool.slot_bytes,
block_bytes_per_layer=self_pool.data_bytes,
)
if is_conv:
return ConvStateMismatchMapper(
@@ -383,8 +367,8 @@ def _select_mapper(
peer_tp_rank=peer_mamba_tp_rank,
)
# SSM state: head-level granularity
self_nheads = self_pool.slot_bytes // self_mlg.ssm_bytes_per_head
peer_nheads = peer_pool.slot_bytes // peer_mlg.ssm_bytes_per_head
self_nheads = self_pool.data_bytes // self_mlg.ssm_bytes_per_head
peer_nheads = peer_pool.data_bytes // peer_mlg.ssm_bytes_per_head
return MambaHeadMismatchMapper(
transfer_layers=transfer_layers,
src_layer_off=0,
@@ -430,18 +414,14 @@ def build_mamba_frags(
(self_mlg.conv_states, peer_mlg.conv_states, True),
(self_mlg.ssm_states, peer_mlg.ssm_states, False),
]:
src_ptrs = MambaPolicy._build_layer_ptrs(
self_pool, self_mlg.mamba_layer_offsets, overlapping_layers, src_slot
)
dst_ptrs = MambaPolicy._build_layer_ptrs(
peer_pool, peer_mlg.mamba_layer_offsets, overlapping_layers, dst_slot
)
src_ptrs = self_mlg.resolve_layer_ptrs(self_pool, overlapping_layers, src_slot)
dst_ptrs = peer_mlg.resolve_layer_ptrs(peer_pool, overlapping_layers, dst_slot)

src_region = SpecRegion(
memory=MemRegionGroup(ptrs=src_ptrs, bytes_per_region=self_pool.slot_bytes)
memory=MemRegionGroup(ptrs=src_ptrs, bytes_per_region=self_pool.data_bytes)
)
dst_region = SpecRegion(
memory=MemRegionGroup(ptrs=dst_ptrs, bytes_per_region=peer_pool.slot_bytes)
memory=MemRegionGroup(ptrs=dst_ptrs, bytes_per_region=peer_pool.data_bytes)
)

mapper = MambaPolicy._select_mapper(
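Note: the duplicated static helper `_build_layer_ptrs` is gone; both the conv and ssm call sites now go through `MambaLayerGroup.resolve_layer_ptrs` (added in `resource/page.py` below). A minimal sketch of the new call site, with made-up addresses and sizes:

```python
# Reviewer sketch (hypothetical values): the per-layer pointer math that used
# to live in MambaPolicy._build_layer_ptrs is now a method on MambaLayerGroup.
from tensorrt_llm._torch.disaggregation.resource.page import MambaLayerGroup, PhysicalPool

conv = PhysicalPool(base_address=0xA000, slot_stride=2048, data_bytes=2048, num_slots=128)
ssm = PhysicalPool(base_address=0xB000, slot_stride=4096, data_bytes=4096, num_slots=128)
mlg = MambaLayerGroup(
    pool_group_idx=1,
    mamba_layer_offsets={10: 0, 11: 1},  # global layer id -> local layer offset
    conv_states=conv,
    ssm_states=ssm,
    conv_section_bytes=[512, 256, 256],
    ssm_bytes_per_head=64,
)

ptrs = mlg.resolve_layer_ptrs(conv, overlapping_layers=[10, 11], slot=3)
# ptr = base + local_layer * num_slots * slot_stride + slot * slot_stride
assert ptrs[0] == 0xA000 + 0 * 128 * 2048 + 3 * 2048
assert ptrs[1] == 0xA000 + 1 * 128 * 2048 + 3 * 2048
```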
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/disaggregation/native/peer.py
@@ -255,8 +255,8 @@ def get_kv_map(
peer_layer_offset=peer_layer_offset,
self_pool_num_layers=self_num_layers,
peer_pool_num_layers=peer_num_layers,
self_pool_slot_bytes=self_phys.slot_bytes,
peer_pool_slot_bytes=peer_phys.slot_bytes,
self_pool_slot_bytes=self_phys.slot_stride,
peer_pool_slot_bytes=peer_phys.slot_stride,
)

self._kv_map_cache[cache_key] = mapper
23 changes: 16 additions & 7 deletions tensorrt_llm/_torch/disaggregation/resource/kv_extractor.py
@@ -62,7 +62,7 @@ def extract(
SpecRegion whose memory is a MemRegionGroup containing all blocks
described by region_ids.

For KV cache: each ptr = base_address + slot_id * slot_bytes, pointing
For KV cache: each ptr = base_address + slot_id * slot_stride, pointing
to the start of a full slot. The slot contains buffer entries for all
layers in this layer_group laid out contiguously from offset 0.

Expand All @@ -75,7 +75,7 @@ def extract(
pool = get_physical_pool(self._page_table, layer_group_id, pv.pool_idx)

base_ptr = pool.base_address
block_size = pool.slot_bytes
block_size = pool.slot_stride

# KV cache: filter out invalid block_ids (BAD_PAGE_INDEX = -1)
valid = region_ids >= 0
@@ -104,15 +104,19 @@ def _build_layer_group_for_mamba(
conv_state = manager._impl.mamba_cache.conv
ssm_state = manager._impl.mamba_cache.temporal

conv_slot_bytes = conv_state.stride(1) * conv_state.element_size()
conv_pool = PhysicalPool(
base_address=conv_state.data_ptr(),
slot_bytes=conv_state.stride(1) * conv_state.element_size(),
slot_stride=conv_slot_bytes,
data_bytes=conv_slot_bytes,
num_slots=conv_state.shape[1],
)

ssm_slot_bytes = ssm_state.stride(1) * ssm_state.element_size()
ssm_pool = PhysicalPool(
base_address=ssm_state.data_ptr(),
slot_bytes=ssm_state.stride(1) * ssm_state.element_size(),
slot_stride=ssm_slot_bytes,
data_bytes=ssm_slot_bytes,
num_slots=ssm_state.shape[1],
)

@@ -190,7 +194,10 @@ def build_page_table(kv_cache_manager: KVCacheManager) -> KVCachePageTable:
entries.append((lid, int(DataRole.VALUE), base_offset + buffer_size, buffer_size))

kv_physical = PhysicalPool(
base_address=base_addr, slot_bytes=slot_bytes, num_slots=num_blocks
base_address=base_addr,
slot_stride=slot_bytes,
data_bytes=slot_bytes,
num_slots=num_blocks,
)
kv_view = PoolView(pool_idx=0, buffer_entries=np.array(entries, dtype=BUFFER_ENTRY_DTYPE))
physical_pools = [kv_physical]
@@ -207,7 +214,8 @@ def build_page_table(kv_cache_manager: KVCacheManager) -> KVCachePageTable:
indexer_slot_bytes = per_block_elems * indexer_pool.element_size()
indexer_physical = PhysicalPool(
base_address=int(indexer_pool.data_ptr()),
slot_bytes=indexer_slot_bytes,
slot_stride=indexer_slot_bytes,
data_bytes=indexer_slot_bytes,
num_slots=num_blocks,
)
indexer_view = PoolView(
@@ -366,7 +374,8 @@ def _role_str_to_enum(role: str) -> DataRole:
pools=[
PhysicalPool(
base_address=int(pool_group._pools[pi].slot_address(0)),
slot_bytes=int(pool_group._pools[pi].slot_size),
slot_stride=int(pool_group._pools[pi].slot_size),
data_bytes=int(pool_group._pools[pi].slot_size),
num_slots=int(pool_group._pools[pi].num_slots),
)
for pi in range(num_pools)
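Note: the extractor now records both an addressing stride and a payload size per slot. For the Mamba pools both are derived from the state tensor's slot stride, so they coincide today; the split leaves room for padded layouts later. A sketch of that derivation, under an assumed (hypothetical) state-tensor shape:

```python
# Hypothetical conv-state tensor: (num_layers, num_slots, dim, kernel_width).
import torch

conv_state = torch.zeros(4, 8, 16, 32, dtype=torch.float32)
slot_bytes = conv_state.stride(1) * conv_state.element_size()
# stride(1) is the element count between consecutive slots; for a contiguous
# tensor this equals the slot payload, so slot_stride == data_bytes here.
assert slot_bytes == 16 * 32 * conv_state.element_size()  # 2048
```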
33 changes: 30 additions & 3 deletions tensorrt_llm/_torch/disaggregation/resource/page.py
@@ -18,21 +18,31 @@
@dataclass
class PhysicalPool:
base_address: int # uint64
slot_bytes: int
slot_stride: int # bytes between consecutive slots (for addressing)
data_bytes: int # actual data bytes per slot (for transfer size)
num_slots: int

def to_dict(self) -> dict:
return {
"base_address": int(self.base_address),
"slot_bytes": int(self.slot_bytes),
"slot_stride": int(self.slot_stride),
"data_bytes": int(self.data_bytes),
"num_slots": int(self.num_slots),
}

@staticmethod
def from_dict(data: dict) -> "PhysicalPool":
# Backward compat: old format has "slot_bytes" only.
if "slot_stride" in data:
slot_stride = int(data["slot_stride"])
data_bytes = int(data["data_bytes"])
else:
slot_stride = int(data["slot_bytes"])
data_bytes = slot_stride
return PhysicalPool(
base_address=int(data["base_address"]),
slot_bytes=int(data["slot_bytes"]),
slot_stride=slot_stride,
data_bytes=data_bytes,
num_slots=int(data["num_slots"]),
)

@@ -158,6 +168,23 @@ class MambaLayerGroup(LayerGroup):
conv_section_bytes: Optional[List[int]] = None
ssm_bytes_per_head: Optional[int] = None

def resolve_layer_ptrs(
self,
pool: PhysicalPool,
overlapping_layers: List[int],
slot: int,
) -> np.ndarray:
"""Compute per-layer physical addresses for a given pool and slot."""
ptrs = np.empty(len(overlapping_layers), dtype=np.int64)
for i, glid in enumerate(overlapping_layers):
lid = self.mamba_layer_offsets[glid]
ptrs[i] = (
pool.base_address
+ lid * pool.num_slots * pool.slot_stride
+ slot * pool.slot_stride
)
return ptrs

def to_dict(self) -> dict:
return {
"pool_group_idx": int(self.pool_group_idx),
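Note: splitting `slot_bytes` into `slot_stride`/`data_bytes` lets a pool describe padded layouts (stride larger than payload) while `from_dict` keeps old serialized snapshots loadable. A quick sketch of both paths, using made-up addresses:

```python
from tensorrt_llm._torch.disaggregation.resource.page import PhysicalPool

# Hypothetical padded pool: 1024 payload bytes per slot, 1280-byte stride.
padded = PhysicalPool(base_address=0x10000, slot_stride=1280, data_bytes=1024, num_slots=4)
d = padded.to_dict()
assert d["slot_stride"] == 1280 and d["data_bytes"] == 1024

# Legacy payload carrying only "slot_bytes": both new fields fall back to it.
legacy = PhysicalPool.from_dict({"base_address": 0x10000, "slot_bytes": 512, "num_slots": 4})
assert legacy.slot_stride == legacy.data_bytes == 512
```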
8 changes: 4 additions & 4 deletions tensorrt_llm/_torch/disaggregation/resource/utils.py
@@ -23,14 +23,14 @@ class PoolRole(Enum):

def get_pool_bytes(pool: PhysicalPool) -> int:
"""Total bytes across all slots in this pool."""
return pool.slot_bytes * pool.num_slots
return pool.slot_stride * pool.num_slots


def get_slot_address(pool: PhysicalPool, slot_id: int) -> int:
"""Base address of *slot_id*."""
if slot_id >= pool.num_slots:
raise ValueError(f"slot_id {slot_id} >= num_slots {pool.num_slots}")
return pool.base_address + slot_id * pool.slot_bytes
return pool.base_address + slot_id * pool.slot_stride


# -------------------------------------------------------------------------
@@ -160,7 +160,7 @@ def get_device_pointer(
raise ValueError(f"slot_id {slot_id} >= num_slots {pool.num_slots}")
for e in pool_view.buffer_entries:
if int(e["local_layer_id"]) == int(local_layer_id) and int(e["role"]) == int(role):
return int(pool.base_address) + int(slot_id) * int(pool.slot_bytes) + int(e["offset"])
return int(pool.base_address) + int(slot_id) * int(pool.slot_stride) + int(e["offset"])
raise ValueError(f"Buffer not found: local_layer_id={local_layer_id}, role={role}")


@@ -179,7 +179,7 @@ def get_unique_pool_memory_descs(
if isinstance(lg, MambaLayerGroup):
num_mamba_layers = len(lg.mamba_layer_offsets)
for pool in [lg.conv_states, lg.ssm_states]:
pool_size = num_mamba_layers * pool.num_slots * pool.slot_bytes
pool_size = num_mamba_layers * pool.num_slots * pool.slot_stride
pool_key = (pool.base_address, pool_size)
if pool_key not in unique_pools:
unique_pools[pool_key] = pool_counter
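Note: all addressing helpers now step by `slot_stride`, so slot bases stay correct for a padded pool even when `data_bytes` is smaller. A sanity check against the formulas above, with a made-up pool:

```python
from tensorrt_llm._torch.disaggregation.resource.page import PhysicalPool
from tensorrt_llm._torch.disaggregation.resource.utils import get_pool_bytes, get_slot_address

pool = PhysicalPool(base_address=0x10000, slot_stride=1280, data_bytes=1024, num_slots=4)
assert get_slot_address(pool, 2) == 0x10000 + 2 * 1280
assert get_pool_bytes(pool) == 4 * 1280  # stride-based footprint, padding included
```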
14 changes: 9 additions & 5 deletions tests/unittest/disaggregated/region/test_page.py
@@ -23,18 +23,20 @@ def _make_buffer_entries():


def test_physical_pool_construction():
pool = PhysicalPool(base_address=0x10000, slot_bytes=256, num_slots=4)
pool = PhysicalPool(base_address=0x10000, slot_stride=256, data_bytes=256, num_slots=4)
assert pool.base_address == 0x10000
assert pool.slot_bytes == 256
assert pool.slot_stride == 256
assert pool.data_bytes == 256
assert pool.num_slots == 4


def test_physical_pool_roundtrip():
pool = PhysicalPool(base_address=0x10000, slot_bytes=256, num_slots=4)
pool = PhysicalPool(base_address=0x10000, slot_stride=256, data_bytes=256, num_slots=4)
d = pool.to_dict()
restored = PhysicalPool.from_dict(d)
assert restored.base_address == pool.base_address
assert restored.slot_bytes == pool.slot_bytes
assert restored.slot_stride == pool.slot_stride
assert restored.data_bytes == pool.data_bytes
assert restored.num_slots == pool.num_slots


@@ -90,7 +92,9 @@ def test_kv_cache_page_table_roundtrip():
],
pool_groups=[
PhysicalPoolGroup(
pools=[PhysicalPool(base_address=0x10000, slot_bytes=256, num_slots=4)]
pools=[
PhysicalPool(base_address=0x10000, slot_stride=256, data_bytes=256, num_slots=4)
]
)
],
)
18 changes: 10 additions & 8 deletions tests/unittest/disaggregated/test_extractor.py
@@ -214,7 +214,7 @@ def test_layer_group_meta_serialization():
[(0, int(DataRole.KEY), 0, 256), (0, int(DataRole.VALUE), 256, 256)],
dtype=BUFFER_ENTRY_DTYPE,
)
kv_pool = PhysicalPool(base_address=1000, slot_bytes=512, num_slots=10)
kv_pool = PhysicalPool(base_address=1000, slot_stride=512, data_bytes=512, num_slots=10)
pv = PoolView(pool_idx=0, buffer_entries=entries)
local_layers = [
LocalLayer(local_layer_id=0, global_layer_id=0),
@@ -245,8 +245,8 @@ def test_layer_group_meta_serialization():
def test_mamba_layer_group_serialization():
from tensorrt_llm._torch.disaggregation.resource.page import MambaLayerGroup, PhysicalPool

conv_pool = PhysicalPool(base_address=1000, slot_bytes=128, num_slots=10)
ssm_pool = PhysicalPool(base_address=8000, slot_bytes=256, num_slots=8)
conv_pool = PhysicalPool(base_address=1000, slot_stride=128, data_bytes=128, num_slots=10)
ssm_pool = PhysicalPool(base_address=8000, slot_stride=256, data_bytes=256, num_slots=8)
mlg = MambaLayerGroup(
pool_group_idx=1,
mamba_layer_offsets={10: 0, 11: 1, 12: 2},
@@ -266,10 +266,12 @@ def test_mamba_layer_group_serialization():
assert isinstance(restored, MambaLayerGroup)
assert restored.mamba_layer_offsets == {10: 0, 11: 1, 12: 2}
assert restored.conv_states.base_address == 1000
assert restored.conv_states.slot_bytes == 128
assert restored.conv_states.slot_stride == 128
assert restored.conv_states.data_bytes == 128
assert restored.conv_states.num_slots == 10
assert restored.ssm_states.base_address == 8000
assert restored.ssm_states.slot_bytes == 256
assert restored.ssm_states.slot_stride == 256
assert restored.ssm_states.data_bytes == 256
assert restored.ssm_states.num_slots == 8
assert restored.conv_section_bytes == [512, 256, 256]
assert restored.ssm_bytes_per_head == 128
@@ -306,16 +308,16 @@ def test_mixed_page_table_serialization():
mamba_lg = MambaLayerGroup(
pool_group_idx=1,
mamba_layer_offsets={1: 0, 2: 1},
conv_states=PhysicalPool(base_address=5000, slot_bytes=1024, num_slots=4),
ssm_states=PhysicalPool(base_address=9000, slot_bytes=2048, num_slots=4),
conv_states=PhysicalPool(base_address=5000, slot_stride=1024, data_bytes=1024, num_slots=4),
ssm_states=PhysicalPool(base_address=9000, slot_stride=2048, data_bytes=2048, num_slots=4),
conv_section_bytes=[256, 128, 128],
ssm_bytes_per_head=64,
)

page_table = KVCachePageTable(
tokens_per_block=16,
layer_groups=[attn_lg, mamba_lg],
pool_groups=[PhysicalPoolGroup(pools=[PhysicalPool(1000, 512, 10)])],
pool_groups=[PhysicalPoolGroup(pools=[PhysicalPool(1000, 512, 512, 10)])],
)

d = page_table.to_dict()
4 changes: 2 additions & 2 deletions tests/unittest/disaggregated/test_kv_transfer.py
@@ -230,7 +230,7 @@ def create_transfer_worker_setup(
for pg in page_table.pool_groups:
for pool_desc in pg.pools:
key = pool_desc.base_address
pool_bytes = pool_desc.slot_bytes * pool_desc.num_slots
pool_bytes = pool_desc.slot_stride * pool_desc.num_slots
if key not in unique_pools or pool_bytes > unique_pools[key]:
unique_pools[key] = pool_bytes

@@ -370,7 +370,7 @@ def create_transfer_worker_setup(
for pg in gen_page_table.pool_groups:
for pool_desc in pg.pools:
key = pool_desc.base_address
pool_bytes = pool_desc.slot_bytes * pool_desc.num_slots
pool_bytes = pool_desc.slot_stride * pool_desc.num_slots
if key not in gen_unique_pools or pool_bytes > gen_unique_pools[key]:
gen_unique_pools[key] = pool_bytes
gen_element_bytes = get_size_in_bytes(1, gen_kv_cache_manager.dtype)
10 changes: 7 additions & 3 deletions tests/unittest/disaggregated/test_peer.py
@@ -48,7 +48,7 @@ def make_page_table(pool_ptrs=None, block_bytes=None, global_layer_ids=None):
PoolView(pool_idx=pi, buffer_entries=buffer_entries) for pi in range(len(pool_ptrs))
]
physical_pools = [
PhysicalPool(base_address=ptr, slot_bytes=bs, num_slots=128)
PhysicalPool(base_address=ptr, slot_stride=bs, data_bytes=bs, num_slots=128)
for ptr, bs in zip(pool_ptrs, block_bytes)
]

@@ -62,8 +62,12 @@
mamba_lg = MambaLayerGroup(
pool_group_idx=1,
mamba_layer_offsets={100: 0, 101: 1},
conv_states=PhysicalPool(base_address=0xA000, slot_bytes=2048, num_slots=128),
ssm_states=PhysicalPool(base_address=0xB000, slot_bytes=4096, num_slots=128),
conv_states=PhysicalPool(
base_address=0xA000, slot_stride=2048, data_bytes=2048, num_slots=128
),
ssm_states=PhysicalPool(
base_address=0xB000, slot_stride=4096, data_bytes=4096, num_slots=128
),
conv_section_bytes=[512, 256, 256],
ssm_bytes_per_head=64,
)