123 changes: 118 additions & 5 deletions python/sglang/srt/disaggregation/nixl/conn.py
@@ -42,12 +42,19 @@ class TransferInfo:
dst_kv_indices: npt.NDArray[np.int32]
dst_aux_index: int
required_dst_info_num: int
dst_state_indices: List[int]

def is_dummy(self):
return self.dst_kv_indices.size == 0

@classmethod
def from_zmq(cls, msg: List[bytes]):
# Parse state_indices from msg[7] if present
if len(msg) > 7 and msg[7] != b"":
dst_state_indices = list(np.frombuffer(msg[7], dtype=np.int32))
else:
dst_state_indices = []

return cls(
room=int(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
@@ -56,6 +63,7 @@ def from_zmq(cls, msg: List[bytes]):
dst_kv_indices=np.frombuffer(msg[4], dtype=np.int32),
dst_aux_index=int(msg[5].decode("ascii")),
required_dst_info_num=int(msg[6].decode("ascii")),
dst_state_indices=dst_state_indices,
)
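
Note: the new dst_state_indices frame round-trips through NumPy. A minimal sketch of the encode/decode pair, with illustrative values (the encoding side appears in KVReceiver.init further down):

import numpy as np

indices = [7]  # illustrative state slot in the decode-side pool
frame = np.array(indices, dtype=np.int32).tobytes()   # what the receiver packs
decoded = list(np.frombuffer(frame, dtype=np.int32))  # what from_zmq recovers: [7]
assert decoded == indices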


@@ -70,13 +78,20 @@ class KVArgsRegisterInfo:
agent_metadata: bytes
dst_kv_ptrs: list[int]
dst_aux_ptrs: list[int]
dst_state_data_ptrs: list[int]
gpu_id: int
decode_tp_size: int
decode_tp_rank: int
dst_kv_item_len: int

@classmethod
def from_zmq(cls, msg: List[bytes]):
# Parse state_data_ptrs from msg[7] if present
if len(msg) > 7 and msg[7] != b"":
dst_state_data_ptrs = list(struct.unpack(f"{len(msg[7]) // 8}Q", msg[7]))
else:
dst_state_data_ptrs = []

return cls(
room=str(msg[0].decode("ascii")),
endpoint=msg[1].decode("ascii"),
@@ -85,10 +100,11 @@ def from_zmq(cls, msg: List[bytes]):
agent_metadata=msg[4],
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5]) // 8}Q", msg[5])),
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6]) // 8}Q", msg[6])),
- gpu_id=int(msg[7].decode("ascii")),
- decode_tp_size=int(msg[8].decode("ascii")),
- decode_tp_rank=int(msg[9].decode("ascii")),
- dst_kv_item_len=int(msg[10].decode("ascii")),
+ dst_state_data_ptrs=dst_state_data_ptrs,
+ gpu_id=int(msg[8].decode("ascii")),
+ decode_tp_size=int(msg[9].decode("ascii")),
+ decode_tp_rank=int(msg[10].decode("ascii")),
+ dst_kv_item_len=int(msg[11].decode("ascii")),
)
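
Note: the new state-pointer frame occupies slot 7, shifting gpu_id and the TP fields down to slots 8-11. A minimal sketch of a multipart message this parser accepts; all values are illustrative, and the field names for the slots hidden in this diff (msg[2], msg[3]) are assumptions:

import struct

state_ptrs = [0x7F0000001000]  # illustrative device pointers
msg = [
    b"room-1",                          # msg[0]: room
    b"10.0.0.2",                        # msg[1]: endpoint
    b"5757",                            # msg[2]: assumed dst port
    b"decode-agent",                    # msg[3]: assumed agent name
    b"<nixl agent metadata>",           # msg[4]: agent_metadata
    struct.pack("2Q", 0x1000, 0x2000),  # msg[5]: dst_kv_ptrs
    struct.pack("1Q", 0x3000),          # msg[6]: dst_aux_ptrs
    struct.pack("1Q", *state_ptrs),     # msg[7]: dst_state_data_ptrs (b"" if none)
    b"0",                               # msg[8]: gpu_id
    b"1",                               # msg[9]: decode_tp_size
    b"0",                               # msg[10]: decode_tp_rank
    b"576",                             # msg[11]: dst_kv_item_len
]
info = KVArgsRegisterInfo.from_zmq(msg)
assert info.dst_state_data_ptrs == state_ptrs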


@@ -106,6 +122,10 @@ class TransferStatus:
num_pp_ranks_expected: Optional[int] = None
# Whether aux data has been received.
received_aux: bool = False
# PP ranks that have sent state data (state is layer-specific, each PP rank sends its portion).
received_state_per_pp: Set[int] = dataclasses.field(default_factory=set)
# Whether state data is expected (set based on state_type).
expects_state: bool = False
# Mark as failed
is_failure: bool = False

@@ -114,6 +134,9 @@ def is_done(self):
return True
if self.num_pp_ranks_expected is None or not self.received_aux:
return False
# If state data is expected, check all PP ranks have sent it
if self.expects_state and len(self.received_state_per_pp) < self.num_pp_ranks_expected:
return False
# All PP ranks must have reported their expected count
if len(self.expected_kvs_per_pp) < self.num_pp_ranks_expected:
return False
@@ -306,6 +329,18 @@ def register_buffer_to_engine(self):
if not self.aux_descs:
raise Exception("NIXL memory registration failed for aux tensors")

# Register state/extra pool data buffers if present
if self.kv_args.state_data_ptrs and self.kv_args.state_data_lens:
state_addrs = []
for state_data_ptr, state_data_len in zip(
self.kv_args.state_data_ptrs, self.kv_args.state_data_lens
):
state_addrs.append((state_data_ptr, state_data_len, self.kv_args.gpu_id, ""))
self.state_descs = self.agent.register_memory(state_addrs, "VRAM")
logger.debug(f"Register state tensors, len(state_addrs)= {len(state_addrs)}")
if not self.state_descs:
raise Exception("NIXL memory registration failed for state tensors")

def _add_remote_peer(self, decode_kv_args: KVArgsRegisterInfo):
agent_name = decode_kv_args.agent_name
if agent_name in self.decode_kv_args_table:
@@ -562,6 +597,48 @@ def send_aux(
raise Exception("KVSender failed to post transfer")
return xfer_handle

def _send_mamba_state(
self,
peer_name: str,
prefill_state_indices: List[int],
dst_state_data_ptrs: list[int],
dst_state_indices: List[int],
dst_gpu_id: int,
notif: str,
):
"""Transfer Mamba states via RDMA."""
assert len(prefill_state_indices) == 1, "Mamba should have single state index"
Reviewer comment (Contributor), severity: high

The code asserts the length of prefill_state_indices but not dst_state_indices, yet dst_state_indices is accessed at index 0 on line 621. If dst_state_indices is an empty list, that access raises an IndexError. It can be empty, since TransferInfo.from_zmq initializes it to [] when the field is absent from the message. An assertion should be added to ensure dst_state_indices has at least one element before it is accessed.

Suggested change:
  assert len(prefill_state_indices) == 1, "Mamba should have single state index"
+ assert len(dst_state_indices) == 1, "Mamba should have single state index"
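
For illustration, a minimal repro of the failure mode described above, assuming the parse fallback in TransferInfo.from_zmq (values are made up):

import numpy as np

msg7 = b""  # the state-indices frame arrives empty, e.g. for a dummy transfer
dst_state_indices = list(np.frombuffer(msg7, dtype=np.int32)) if msg7 != b"" else []

# _send_mamba_state later evaluates int(dst_state_indices[0]); with the list
# empty that raises a bare IndexError deep in the transfer path. The suggested
# assertion fails fast with a clear message instead:
assert len(dst_state_indices) == 1, "Mamba should have single state index"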


src_addrs = []
dst_addrs = []

prefill_state_data_ptrs = self.kv_args.state_data_ptrs
prefill_state_item_lens = self.kv_args.state_item_lens

for i, dst_state_ptr in enumerate(dst_state_data_ptrs):
length = prefill_state_item_lens[i]
src_addr = prefill_state_data_ptrs[i] + length * int(prefill_state_indices[0])
dst_addr = dst_state_ptr + length * int(dst_state_indices[0])
src_addrs.append((src_addr, length, self.kv_args.gpu_id))
dst_addrs.append((dst_addr, length, dst_gpu_id))

src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")

xfer_handle = self.agent.initialize_xfer(
"WRITE",
src_descs,
dst_descs,
peer_name,
notif.encode("ascii"),
)
if not xfer_handle:
raise Exception("Failed to create Mamba state transfer")
state = self.agent.transfer(xfer_handle)
if state == "ERR":
raise Exception("Failed to post Mamba state transfer")
return xfer_handle
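
Note: the address arithmetic above is a plain strided lookup into each state pool. For example, with an illustrative item length of 1024 bytes, a source base pointer of 0x1000, and prefill_state_indices[0] == 3, the transfer reads 1024 bytes starting at 0x1000 + 1024 * 3 = 0x1C00 and writes them at the analogous offset dst_state_ptr + 1024 * dst_state_indices[0] on the decode GPU.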

def add_transfer_request(
self,
bootstrap_room: int,
@@ -570,6 +647,7 @@ def add_transfer_request(
is_last: bool,
chunk_id: int,
aux_index: Optional[int] = None,
state_indices: Optional[List[int]] = None,
):
assert self.disaggregation_mode == DisaggregationMode.PREFILL
assert not is_last or (is_last and aux_index is not None)
@@ -618,13 +696,31 @@ def add_transfer_request(
handles.append(kv_xfer_handle)
# Only for the last chunk do we need to send the aux data.
if is_last:
if state_indices is not None:
state_type = getattr(self.kv_args, "state_type", "none")
if self.attn_tp_size != self.decode_kv_args_table[req.agent_name].decode_tp_size:
raise RuntimeError(
"PD Disaggregation does NOT support PD different TP sizes for hybrid mamba models yet."
)

if state_type == "mamba":
state_xfer_handle = self._send_mamba_state(
req.agent_name,
state_indices,
self.decode_kv_args_table[req.agent_name].dst_state_data_ptrs,
req.dst_state_indices,
self.decode_kv_args_table[req.agent_name].gpu_id,
f"{req.room}_state_{self.kv_args.pp_rank}",
)
handles.append(state_xfer_handle)

assert aux_index is not None
aux_xfer_handle = self.send_aux(
req.agent_name,
aux_index,
self.decode_kv_args_table[req.agent_name].dst_aux_ptrs,
req.dst_aux_index,
str(req.room) + "_aux",
f"{req.room}_aux",
)
handles.append(aux_xfer_handle)
if is_last:
@@ -661,6 +757,9 @@ def update_transfer_status(self):
)
elif components[1] == "aux":
self.transfer_statuses[room].received_aux = True
elif components[1] == "state":
pp_rank = int(components[2]) if len(components) > 2 else 0
self.transfer_statuses[room].received_state_per_pp.add(pp_rank)
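
Note: a sketch of how the notification keys round-trip, assuming components is the notification string split on "_" (the split itself is outside this hunk); values are illustrative:

notif = "12345_state_0"        # prefill sends f"{room}_state_{pp_rank}"
components = notif.split("_")
room = int(components[0])      # 12345
kind = components[1]           # "state"
pp_rank = int(components[2])   # 0, added to received_state_per_pp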

def check_transfer_done(self, room: int):
if room not in self.transfer_statuses:
@@ -735,6 +834,7 @@ def send(
is_last,
self.chunk_id,
self.aux_index,
state_indices,
)
self.xfer_handles.extend(new_xfer_handles)
self.chunk_id += 1
@@ -808,9 +908,18 @@ def init(
kv_indices.tobytes() if not is_dummy else b"",
str(aux_index).encode("ascii"),
str(self.required_dst_info_num).encode("ascii"),
(
np.array(state_indices, dtype=np.int32).tobytes()
if not is_dummy and state_indices is not None
else b""
),
]
)

# Mark that we expect state data if state_indices was provided
if state_indices is not None:
self.kv_mgr.transfer_statuses[self.bootstrap_room].expects_state = True

self.started_transfer = True
self.init_time = time.time()

@@ -862,6 +971,9 @@ def _register_kv_args(self):
packed_aux_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
)
packed_state_data_ptrs = b"".join(
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.state_data_ptrs
)

with lock:
sock.send_multipart(
Expand All @@ -874,6 +986,7 @@ def _register_kv_args(self):
self.kv_mgr.agent.get_agent_metadata(),
packed_kv_data_ptrs,
packed_aux_data_ptrs,
packed_state_data_ptrs,
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),