[PD] MORI-IO: Add state transfer, async I/O workers, and high-concurrency fixes #22665

Draft
maning00 wants to merge 5 commits into sgl-project:main from maning00:v0.5.10-mori-io-state

Conversation

@maning00 (Contributor) commented Apr 13, 2026

Motivation

Follow-up to #14626 which introduced MORI-IO as the RDMA-based KV transfer backend for PD disaggregation on AMD hardware. This PR addresses the known limitation (no state transfer) and resolves several performance bottlenecks and correctness issues discovered under high-concurrency workloads:

  1. State data transfer was not implemented for hybrid models (Mamba, SWA, NSA).
  2. The synchronous inline transfer model blocked the sender's poll loop, limiting throughput.
  3. ZMQ-based auxiliary data transfer caused message flooding and transfer queue hangs under high concurrency.
  4. TP slice head mapping was incorrect for prefill_tp_size > decode_tp_size with GQA/MQA.

Modifications

All changes are confined to python/sglang/srt/disaggregation/mori/conn.py.

1. State Transfer Support (Mamba, SWA, NSA)

  • Added send_state() method on MoriKVManager that dispatches to _send_mamba_state() or _send_swa_nsa_state() based on state_type.
  • _send_mamba_state(): Single-index Mamba SSM state transfer with TP-mismatch slice support (computes per-dimension offsets when prefill TP != decode TP).
  • _send_swa_nsa_state(): Multi-token SWA/NSA state transfer using group_concurrent_contiguous() and batched RDMA writes.
  • Extended TransferInfo with dst_state_indices and KVArgsRegisterInfo with dst_state_item_lens / dst_state_dim_per_tensor.
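The dispatch described above can be sketched as follows. Method names mirror the PR (`send_state()`, `_send_mamba_state()`, `_send_swa_nsa_state()`), but the signatures, the request argument, and the return values are illustrative assumptions, not the actual `conn.py` code:

```python
# Hypothetical sketch of the state-transfer dispatch; the real methods issue
# RDMA writes rather than returning strings.
class MoriKVManagerSketch:
    def send_state(self, req, state_type: str):
        # Route by the hybrid-model state kind (Mamba vs. SWA/NSA).
        if state_type == "mamba":
            return self._send_mamba_state(req)
        elif state_type in ("swa", "nsa"):
            return self._send_swa_nsa_state(req)
        raise ValueError(f"unknown state_type: {state_type}")

    def _send_mamba_state(self, req):
        # Single-index SSM state transfer (TP-mismatch slicing elided).
        return "mamba"

    def _send_swa_nsa_state(self, req):
        # Multi-token transfer over grouped contiguous index ranges.
        return "swa_nsa"
```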

2. Async Worker-Exclusive I/O Model (FastQueue)

Replaced the synchronous inline transfer model with an asynchronous worker-thread architecture:

  • TransferKVChunk: New dataclass encapsulating a single chunk of transfer work.
  • Transfer worker pool: Configurable number of daemon threads (SGLANG_MORI_TRANSFER_QUEUE_SIZE, default: 4), each blocking on its own FastQueue.
  • Sharded queue assignment: Chunks are routed to queues by hash of destination engine key, ensuring per-destination ordering.
  • _process_transfer_chunk(): Full lifecycle management per chunk — KV transfer, state transfer, aux transfer, RDMA status polling, and decode notification.
  • Simplified poll() to a pure status reader; all transfer work is offloaded to worker threads.
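The sharded worker-queue model above can be sketched with stdlib primitives. `TransferKVChunk` and the hash-routing follow the PR's description; `FastQueue` is stood in for by `queue.Queue`, and the worker body is a placeholder for `_process_transfer_chunk()`:

```python
# Minimal sketch: N daemon threads each block on their own queue; chunks are
# routed by hash of the destination engine key, so all chunks for one
# destination land on the same queue and stay ordered.
import queue
import threading
from dataclasses import dataclass

@dataclass
class TransferKVChunk:          # one unit of transfer work (fields assumed)
    dst_engine_key: str
    payload: bytes

class WorkerPool:
    def __init__(self, num_queues: int = 4):
        self.queues = [queue.Queue() for _ in range(num_queues)]
        self.done = []          # processed destinations (list.append is atomic)
        for q in self.queues:
            threading.Thread(target=self._worker, args=(q,), daemon=True).start()

    def submit(self, chunk: TransferKVChunk):
        # Same destination always hashes to the same queue -> per-dest ordering.
        idx = hash(chunk.dst_engine_key) % len(self.queues)
        self.queues[idx].put(chunk)

    def _worker(self, q: queue.Queue):
        while True:
            chunk = q.get()
            # The real _process_transfer_chunk() does KV/state/aux transfer,
            # RDMA status polling, and decode notification here.
            self.done.append(chunk.dst_engine_key)
            q.task_done()
```

With this shape, `poll()` only needs to read a status dict; it never blocks on transfer work itself.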

3. AUX Data via RDMA (Replacing ZMQ TCP)

  • Rewrote send_aux() to use engine.batch_write() with registered aux memory descriptors instead of ZMQ send_multipart. This eliminates ZMQ message flooding that caused decode-side hangs under high concurrency.
  • Added _connect_threadsafe() with thread-local ZMQ sockets for remaining ZMQ usage.
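Per-thread socket caching, as in `_connect_threadsafe()`, can be sketched like this. ZMQ sockets are not thread-safe (the `zmq.Context` is), so each worker thread lazily creates and caches its own socket; `make_socket` here is a stand-in for the real ZMQ connect call:

```python
# Sketch: thread-local socket cache keyed by endpoint. The factory argument
# is an assumption for testability; the real code would call
# context.socket(zmq.PUSH) and connect it.
import threading

_tls = threading.local()

def connect_threadsafe(endpoint: str, make_socket=lambda ep: object()):
    cache = getattr(_tls, "sockets", None)
    if cache is None:
        cache = _tls.sockets = {}
    if endpoint not in cache:
        # One socket per (thread, endpoint); never shared across threads.
        cache[endpoint] = make_socket(endpoint)
    return cache[endpoint]
```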

4. Batched Multi-Layer RDMA Transfers

  • _batched_layer_transfers(): Issues a single engine.batch_write() across all layers, reducing RDMA calls from O(layers) to O(1).
  • _batched_tp_slice_transfers(): Same batching for TP-mismatch slice transfers with vectorized NumPy offset computation.
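The batching idea can be illustrated with a vectorized offset computation: instead of one `engine.batch_write()` per layer, build all `(src, dst, len)` triples up front with NumPy broadcasting and issue a single call. Layer base addresses, the per-token `item_len`, and the triple layout are assumptions for illustration, not the actual MORI-IO API:

```python
# Sketch: collapse O(layers) RDMA calls into one batch by computing every
# offset with an outer-sum broadcast (layers x token slots), then raveling.
import numpy as np

def build_batched_offsets(src_bases, dst_bases, token_indices, item_len):
    src = (np.asarray(src_bases)[:, None]
           + np.asarray(token_indices)[None, :] * item_len).ravel()
    dst = (np.asarray(dst_bases)[:, None]
           + np.asarray(token_indices)[None, :] * item_len).ravel()
    lens = np.full(src.shape, item_len, dtype=np.int64)
    return src, dst, lens   # feed a single engine.batch_write(...) call
```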

5. Bug Fixes

  • TP slice fix: Corrected head mapping for prefill_tp_size > decode_tp_size with GQA/MQA. Introduced src_replication and unique_head_idx for correct replicated head mapping.
  • Stale metadata guard: _handle_transfer_message() now accepts metadata when room status is None (room not yet created by scheduler), preventing hangs.
  • update_status() state machine guard: Failed is terminal and never overwritten.
  • CP rank support: _compute_prefill_unique_rank() now correctly encodes TP/PP/CP ranks.
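The replicated-head mapping in the TP slice fix can be sketched as follows. With GQA/MQA and `prefill_tp_size > decode_tp_size`, a KV head may be replicated across several prefill ranks, so slice offsets must use a unique head index rather than the raw rank-local one. The formulas are an illustration in the spirit of `src_replication` / `unique_head_idx`, not the exact `conn.py` code:

```python
# Sketch: ranks sharing one replicated KV head collapse to the same unique
# head index, so the same source slice is addressed regardless of which
# replica rank sends it.
def map_prefill_rank_to_unique_head(prefill_rank, num_kv_heads, prefill_tp_size):
    # How many prefill ranks hold a copy of the same KV head (>= 1).
    src_replication = max(prefill_tp_size // num_kv_heads, 1)
    # All ranks in one replication group map to one unique head.
    unique_head_idx = prefill_rank // src_replication
    return src_replication, unique_head_idx
```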

6. Default Parallelism Tuning

| Parameter | Before | After |
|---|---|---|
| SGLANG_MORI_QP_PER_TRANSFER | 1 | 4 |
| SGLANG_MORI_NUM_WORKERS | 1 | 4 |
| SGLANG_MORI_TRANSFER_QUEUE_SIZE | N/A | 4 |
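For deployments that want to pin these explicitly rather than rely on the new defaults, the values from the table can be exported before launching the server:

```shell
# Explicitly set the PR's new defaults (values from the table above).
export SGLANG_MORI_QP_PER_TRANSFER=4        # queue pairs per transfer
export SGLANG_MORI_NUM_WORKERS=4            # MORI engine workers
export SGLANG_MORI_TRANSFER_QUEUE_SIZE=4    # transfer worker threads/queues
```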

Benchmarking

Hardware Configuration

  • GPUs: 8x AMD Instinct MI355X per node
  • Network: 8x AMD Pensando Pollara 400 AI-NIC per node (ionic_0 ~ ionic_7)
  • Model: DeepSeek-R1 with TP=8
  • Setup: 2-node PD disaggregation (1 prefill + 1 decode) + router

Benchmark Commands

Prefill instance:

```shell
export MORI_RDMA_TC=104
python3 -m sglang.launch_server --model-path DeepSeek-R1 \
    --disaggregation-mode prefill --host 0.0.0.0 --port 30002 \
    --tp-size 8 --kv-cache-dtype fp8_e4m3 \
    --disaggregation-transfer-backend mori \
    --disaggregation-ib-device ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
    --disable-radix-cache --trust-remote-code
```

Decode instance:

```shell
export MORI_RDMA_TC=104
python3 -m sglang.launch_server --model-path DeepSeek-R1 \
    --disaggregation-mode decode --host 0.0.0.0 --port 30003 \
    --tp-size 8 --kv-cache-dtype fp8_e4m3 \
    --disaggregation-transfer-backend mori \
    --disaggregation-ib-device ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7 \
    --disable-radix-cache --trust-remote-code
```

Router:

```shell
python -m sglang_router.launch_router --pd-disaggregation \
    --port 30000 --policy random \
    --prefill-policy random --decode-policy random \
    --prefill http://<prefill_host>:30002 \
    --decode http://<decode_host>:30003
```

Benchmark client:

```shell
python3 -m sglang.bench_serving \
    --backend sglang --host 0.0.0.0 --port 30000 \
    --dataset-name random --num-prompts 2048 --max-concurrency 2048 \
    --random-input-len 8192 --random-output-len 1024 \
    --model DeepSeek-R1
```

Performance Results

FP8 KV Cache (`--kv-cache-dtype fp8_e4m3`, `--num-prompts 1`, `--random-output 16`; each configuration run twice and averaged):

| Input Tokens | TTFT Mean (ms) | E2E Latency Mean (ms) |
|---|---|---|
| 1024 | 85.64 | 162.30 |
| 2048 | 85.72 | 166.53 |
| 4096 | 86.82 | 167.70 |
| 8192 | 196.60 | 280.95 |

BF16 KV Cache (`--num-prompts 1`, `--random-output 16`; each configuration run twice and averaged):

| Input Tokens | TTFT Mean (ms) | E2E Latency Mean (ms) |
|---|---|---|
| 1024 | 85.50 | 162.83 |
| 2048 | 84.63 | 168.36 |
| 4096 | 85.69 | 169.20 |
| 8192 | 207.45 | 296.53 |

Accuracy Test

```shell
python3 -m sglang.test.few_shot_gsm8k \
    --host http://127.0.0.1 --port 30000 \
    --num-questions 200 --parallel 128 --num-shots 5
```

| Metric | Result |
|---|---|
| Accuracy | 0.975 |
| Invalid | 0.000 |

Known Limitations

  • SWA/NSA state transfer does not yet support TP-mismatch with non-MLA attention (consistent with Mooncake and NIXL backends).

cc @Duyi-Wang @ZhaiFeiyue

github-actions bot added labels lora, deepseek, blackwell SM100/SM120, and diffusion (SGLang Diffusion) on Apr 13, 2026
Three fixes for the Worker-Exclusive I/O model (Transfer Queue):

1. AUX data via RDMA instead of ZMQ TCP: send_aux() now uses
   engine.batch_write() with registered aux_mem_descs, eliminating
   10 ZMQ messages per room that flooded the decode PULL socket and
   blocked status notifications.

2. Stale metadata guard relaxed: _handle_transfer_message() now
   accepts metadata when current status is None (room not yet
   created by scheduler), since decode can send metadata before
   prefill creates MoriKVSender. Only rejects active/terminal states.

3. ZMQ socket improvements: _connect_threadsafe() shares a single
   zmq.Context across worker threads, sets SNDHWM=0 and LINGER=0
   to prevent message loss under load.
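The stale-metadata and terminal-state guards described above can be sketched together. The status values and helper names are illustrative (`KVPoll` follows SGLang naming conventions, but the enum members and dict layout here are assumptions):

```python
# Sketch: Failed is terminal and never overwritten; metadata is accepted only
# for rooms the scheduler has not created yet (status None), and rejected for
# rooms already in an active or terminal state.
from enum import Enum

class KVPoll(Enum):
    Bootstrapping = 0
    Transferring = 1
    Success = 2
    Failed = 3

def update_status(statuses: dict, room: int, new: KVPoll):
    if statuses.get(room) is KVPoll.Failed:
        return              # terminal state: drop the update
    statuses[room] = new

def accept_metadata(statuses: dict, room: int) -> bool:
    # Decode can send metadata before prefill creates the sender, so a
    # missing room (None) is accepted rather than treated as stale.
    return statuses.get(room) is None
```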
maning00 force-pushed the v0.5.10-mori-io-state branch from a9ef95a to fcb1aa3 on April 13, 2026 at 06:42