Commit 50b0ef7

LMDeploy Distserve (#3304)
Squashed commit messages:

* sync main
* correct typo
* 1. fix typo 2. add migration event
* move slime to 'https://github.com/JimyMa/DLSlime.git' and init readme
* update disagg README
* mute slime when distserve is disabled
* remove build_migration.sh
* revert debug code
* 1. identify interface 2. add multi-backend registry
* add dlslime max transfer batch
* add an infinistore interface
* add load/store
* conditional registration of multiple migration backends
* merge router into proxy
* remove redundant print
* 1. remove redundant print 2. revert safe_run
* dsv3 kvtransfer support (bypass v cache)
* dsv3 debug: 1. change log info to log debug of log resp 2. add num_cpus to ray.init for running in dlc
* DSV3 debug, known issues:
  1. [PD Connection more efficiently][High Priority] In the DSV3 DP + EP condition, we need to concurrently construct prefill_dp_size (for example 32) * decode_dp_size (for example 144) links. We add a function `pd_consolidation_multi_thread` to do this; however, we need to check whether the construction operation is thread safe.
  2. [Combine with proxy] Maybe we should save conn_config to avoid repeated reconnection of the PD link.
  3. [PD Control Plane][High Priority] For DP + EP, we need to reconstruct DisaggEngineConfig to record more information (e.g. dp_idx, tp_idx ...).
  4. [Combine with router][Important] How to perform PD load balancing in disaggregated LLM serving.
  5. [PD Data Plane] Adapt to open-source KV cache managers such as Mooncake, InfiniStore or NIXL, and to more transport media.
* revert match to if/else
* [bugfix] rename typo
* [refactor] refactor pd_conn
* 1. format code 2. add engine_role for passing unit tests
* 1. format code 2. parse dp, ep, and dp rank to DisaggEngineConfig
* 1. add pd conn timeout 2. add default EngineRole of Hybrid 3. fix disagg strategy proxy typo
* refactor PDConnection pool
* refactor debug
* fix migration loop bug
* add proxy arguments for distserve
* bugfix
* debug interface
* remove unnecessary EngineRole check
* add v1/chat/completions support
* remove redundant print
* async free cache
* async free cache
* add some comments
* bugfix
* [proxy] add connection_warmup api
* 1. bugfix (warmup_connection typo and wrong args) 2. preserve cache bugfix
* [disagg] update readme: 1. fault tolerance 2. replace router with proxy
* bugfix
* fix decode back-pressure bug
* add migration_request to chat/completions so the cache is freed correctly
* free cache bugfix
* fix lock running bug
* fix dist.broadcast deadlock
* [lint] fix lint
* rename Ethernet to RoCE
* change enum.Enum.__members__[elem] to enum.Enum[elem] directly
* update readme
* update migration-backend
* 1. update readme 2. move module to string for conditional import
* update readme
* 1. remove magic number and handle long assignments in dlslime 2. add uni-executor support
* fix erroneous migration in the dummy situation
* bugfix when a token is not decodable utf-8 (in test)
* overlap migration and forward
* bump dlslime to v0.0.1.post5
* remove print
* remove free in decode engine because the cache is already freed in the proxy
* bump dlslime to 0.0.1.post7
* 1. [proxy] revert self.nodes to nodes 2. [api_server] remove redundant api
* [cli] remove available_nic args
* format comments
* [pytorch paging] remove redundant logger
* [model_agent] bugfix caused by merge
* [model agent] bypass model agent migrate
* revert migrate to sync mode
* bypass model agent migrate in uni_executor
* [proxy] set default serving strategy to DistServe
* [disagg] update readme
* info -> debug
* remove unused code
* lazily initialize migration event
* add nvlink support
* mute TCP support for now
* update readme for exception
* set migration token_ids output to numpy array
* update readme
* in PD disaggregation mode, fall back next token ids to CPU
* [disagg] update readme
* move disagg to pytorch backend
1 parent 31d7290 commit 50b0ef7

32 files changed (+1720 −135 lines)

benchmark/profile_generation.py (+2 −2)

@@ -178,7 +178,7 @@ async def _gather_tasks(tasks):
     out_token_throughput = np.round(token_latency_stats.size / elapsed_time, 2)
     total_token_throughput = np.round(concurrency * test_round * (input_seqlen + output_seqlen) / elapsed_time, 2)
-    print(f'\n{"-" * 50}\ntotal time: {elapsed_time:.2f}s\n'
+    print(f'\n{" - " * 50}\ntotal time: {elapsed_time:.2f}s\n'
           f'concurrency: {concurrency}, test_round: {test_round}\n'
           f'input_tokens: {input_seqlen}, output_tokens: {output_seqlen}\n'
           f'first_token latency(min, max, ave): '
@@ -188,7 +188,7 @@ async def _gather_tasks(tasks):
           f'{token_latency_ave}s\n'
           f'token_latency percentiles(50%,75%,95%,99%)(s): {percentiles}\n'
           f'throughput(output): {out_token_throughput} token/s\n'
-          f'throughput(total): {total_token_throughput} token/s\n{"-" * 50}')
+          f'throughput(total): {total_token_throughput} token/s\n{" - " * 50}')
     return model_path, \
         [first_token_latency_min, first_token_latency_max,
          first_token_latency_ave], \

lmdeploy/cli/serve.py (+22 −3)

@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-
+from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
 from lmdeploy.utils import get_max_batch_size

 from .cli import CLI
@@ -167,6 +167,8 @@ def add_parser_api_server():
         ArgumentHelper.dp_rank(pt_group)
         ArgumentHelper.ep(pt_group)
         ArgumentHelper.enable_microbatch(pt_group)
+        ArgumentHelper.role(pt_group)
+        ArgumentHelper.migration_backend(pt_group)

         # turbomind args
         tb_group = parser.add_argument_group('TurboMind engine arguments')
@@ -216,7 +218,13 @@ def add_parser_proxy():
         parser.set_defaults(run=SubCliServe.proxy)
         parser.add_argument('--server-name', type=str, default='0.0.0.0', help='Host ip for proxy serving')
         parser.add_argument('--server-port', type=int, default=8000, help='Server port of the proxy')
-        parser.add_argument('--strategy',
+        parser.add_argument('--serving-strategy',
+                            type=str,
+                            choices=['Hybrid', 'DistServe'],
+                            default='Hybrid',
+                            help='the strategy to serve, Hybrid for colocating Prefill and Decode'
+                            'workloads into same engine, DistServe for Prefill-Decode Disaggregation')
+        parser.add_argument('--routing-strategy',
                             type=str,
                             choices=['random', 'min_expected_latency', 'min_observed_latency'],
                             default='min_expected_latency',
@@ -226,6 +234,15 @@ def add_parser_proxy():
                             help='Whether to disable cache status of the '
                             'proxy. If set, the proxy will forget the status '
                             'of the previous time')
+
+        # For Disaggregation
+        parser.add_argument('--migration-protocol',
+                            type=str,
+                            choices=['RDMA', 'NVLINK'],
+                            default='RDMA',
+                            help='transport protocol of KV migration')
+        parser.add_argument('--link-type', type=str, choices=['RoCE', 'IB'], default='RoCE', help='RDMA Link Type')
+        parser.add_argument('--disable-gdr', action='store_true', help='with GPU Direct Memory Access')
         ArgumentHelper.api_keys(parser)
         ArgumentHelper.ssl(parser)
         ArgumentHelper.log_level(parser)
@@ -311,7 +328,9 @@ def api_server(args):
                 quant_policy=args.quant_policy,
                 eager_mode=args.eager_mode,
                 max_prefill_token_num=args.max_prefill_token_num,
-                enable_microbatch=args.enable_microbatch)
+                enable_microbatch=args.enable_microbatch,
+                role=EngineRole[args.role],
+                migration_backend=MigrationBackend[args.migration_backend])
         else:
             from lmdeploy.messages import TurbomindEngineConfig
             backend_config = TurbomindEngineConfig(dtype=args.dtype,
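The new wiring above maps CLI strings onto enum members with plain Enum name lookup (`EngineRole[args.role]`), matching the squashed-commit note about switching from `Enum.__members__[elem]` to `Enum[elem]`. A minimal sketch of that pattern, assuming the member names mirror the CLI choices shown in this diff:

```python
# Sketch only: Enum subscripting is a lookup by member name, so the CLI string
# maps straight onto the enum member passed to PytorchEngineConfig.
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend

role = EngineRole['Prefill']           # same object as EngineRole.Prefill
backend = MigrationBackend['DLSlime']  # same object as MigrationBackend.DLSlime
assert role is EngineRole.Prefill and backend is MigrationBackend.DLSlime
```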

lmdeploy/cli/utils.py (+19)

@@ -527,3 +527,22 @@ def enable_microbatch(parser):
         return parser.add_argument('--enable-microbatch',
                                    action='store_true',
                                    help='enable microbatch for specified model')
+
+    # For Disaggregation
+    @staticmethod
+    def role(parser):
+        return parser.add_argument('--role',
+                                   type=str,
+                                   default='Hybrid',
+                                   choices=['Hybrid', 'Prefill', 'Decode'],
+                                   help='Hybrid for Non-Disaggregated Engine;'
+                                   'Prefill for Disaggregated Prefill Engine;'
+                                   'Decode for Disaggregated Decode Engine;')
+
+    @staticmethod
+    def migration_backend(parser):
+        return parser.add_argument('--migration-backend',
+                                   type=str,
+                                   default='DLSlime',
+                                   choices=['DLSlime'],
+                                   help='kvcache migration management backend when PD disaggregation')

lmdeploy/messages.py (+19)

@@ -6,6 +6,9 @@
 import torch
 from pydantic.dataclasses import dataclass as pydantic_dataclass

+from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
+from lmdeploy.pytorch.disagg.request import MigrationRequest
+
 from .tokenizer import Tokenizer
 from .utils import get_logger

@@ -107,6 +110,11 @@ class GenerationConfig:
     output_logits: Literal['all', 'generation'] = None
     output_last_hidden_state: Literal['all', 'generation'] = None

+    # for disaggregation
+    with_cache: bool = False
+    preserve_cache: bool = False
+    migration_request: Optional[MigrationRequest] = None
+
     def convert_stop_bad_words_to_ids(self, tokenizer: Tokenizer):
         """convert stop_words/bad_sords to ids and append the ids to
         stop_token_ids/bad_token_ids."""
@@ -298,6 +306,10 @@ class PytorchEngineConfig:
         distributed_executor_backend (str): backend of distributed backend,
             options: ['uni', 'mp', 'ray']
         enable_microbatch (bool): enable microbatch for specified model
+        role (EngineRole): role of engin, options: ['Hybrid', 'Prefill',
+            'Decode']. Default to `EngineRole.Hybrid`.
+        migration_backend: migration backend. options: ['DLSlime'].
+            Default to `MigrationBackend.DLSlime`.
     """
     dtype: str = 'auto'
     tp: int = 1
@@ -324,6 +336,9 @@ class PytorchEngineConfig:
     distributed_executor_backend: str = None
     enable_microbatch: bool = False

+    role: EngineRole = EngineRole.Hybrid
+    migration_backend: MigrationBackend = MigrationBackend.DLSlime
+
     def __post_init__(self):
         """Check input validation."""
         assert self.dtype in ['auto', 'float16', 'bfloat16']
@@ -404,6 +419,8 @@ class EngineOutput:
         may not equal to the length of token_ids
         logprobs (List[Dict[int, float]]): the top logprobs for each output
             position.
+        cache_block_ids (List[int]): send cache blocks back for migration in
+            Disaggregated LLM Serving when Prefill Engine is Done.
     """
     status: ResponseType
     token_ids: List[int]
@@ -412,6 +429,8 @@ class EngineOutput:
     logits: torch.Tensor = None
     last_hidden_state: torch.Tensor = None

+    cache_block_ids: Optional[List[int]] = None
+

 @dataclass
 class VisionConfig:
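Taken together, the fields added above are enough to configure a disaggregation-aware engine and a prefill-style request from Python. A minimal sketch of that usage, limited to the fields shown in this diff (the exact values are illustrative assumptions, not the commit's defaults):

```python
# Sketch based on the fields added in this diff.
from lmdeploy.messages import GenerationConfig, PytorchEngineConfig
from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend

# Engine configured as a disaggregated Prefill worker using the DLSlime backend.
engine_config = PytorchEngineConfig(
    role=EngineRole.Prefill,
    migration_backend=MigrationBackend.DLSlime,
)

# Prefill-side request: keep the KV cache alive so a Decode engine can migrate it.
gen_config = GenerationConfig(
    max_new_tokens=1,
    with_cache=True,
    preserve_cache=True,
)
```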

lmdeploy/pytorch/config.py (+6)

@@ -4,6 +4,8 @@

 import torch

+from lmdeploy.pytorch.disagg.config import EngineRole, MigrationBackend
+

 def _update_torch_dtype(config: 'ModelConfig', dtype: str):
     """Update the torch dtype from the model config.
@@ -80,6 +82,10 @@ class CacheConfig:
     quant_policy: Literal[0, 4, 8] = 0
     device_type: str = 'cuda'

+    # For PD Disaggregation
+    role: EngineRole = EngineRole.Hybrid
+    migration_backend: MigrationBackend = MigrationBackend.DLSlime
+
     def __post_init__(self):
         """post init."""
         from lmdeploy.utils import get_logger

lmdeploy/pytorch/disagg/README.md (+103, new file)

# LMDeploy-DistServe

## Key Components

1. **Router Service**: Coordinates between prefill/decode engines
2. **Migration Manager**: Facilitates high-performance memory sharing

## Installation

```
# Inference Engine
pip install "lmdeploy[all]>=0.7.0"

# Transfer Engine
pip install "dlslime>=0.0.1.post7"
```

## Quick Start

An example PD-disaggregated deployment is shown below:

### 1. Launch Router Service

```shell
lmdeploy serve proxy --server-name 0.0.0.0 --server-port 8000 --routing-strategy "min_expected_latency" --serving-strategy DistServe --log-level INFO
```

LMDeploy-DistServe supports both NVLink and RDMA for transferring the KV cache from the Prefill Engine to the Decode Engine. RDMA is the default protocol. Set `--migration-protocol NVLINK` for NVLink transport.

### 2. Configure Endpoints

First deploy your prefill and decode engines.

```shell
# Prefill Engine
CUDA_VISIBLE_DEVICES=0 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23333 --role Prefill --proxy-url http://0.0.0.0:8000 --backend pytorch
# Decode Engine
CUDA_VISIBLE_DEVICES=1 lmdeploy serve api_server internlm/internlm2_5-7b-chat --server-port 23334 --role Decode --proxy-url http://0.0.0.0:8000 --backend pytorch
```

Currently, only the **PyTorch backend** supports PD disaggregation.

## API Usage

```shell
# API Invoke
curl -X POST "http://localhost:8000/v1/completions" \
    -H "Content-Type: application/json" \
    -d '{"model": "internlm/internlm2_5-7b-chat", "temperature": 0, "prompt": "Shanghai is a city that ", "max_tokens": 16, "stream": false}'
# Output
{
    "id": "2",
    "object": "text_completion",
    "created": 1743662400,
    "model": "internlm/internlm2_5-7b-chat",
    "choices": [
        {
            "index": 0,
            "text": " is very famous for its skyscrapers. It is also a city",
            "logprobs": null,
            "finish_reason": "length"
        }
    ],
    "usage": {
        "prompt_tokens": 7, "total_tokens": 23, "completion_tokens": 16
    }
}
```

## Troubleshooting

### RDMA Connection Failed

Make sure ibverbs is correctly installed:

```
# on Ubuntu
sudo apt install libibverbs-dev
# on CentOS
sudo yum install ibverbs-devel
```

```bash
ibstat       # Verify IB device status
ibv_devinfo  # Check device capabilities
```

### Check GPU Direct RDMA

Currently, lmdeploy-distserve uses GPUDirect RDMA to perform KV transfer. Make sure the GPUDirect RDMA driver is loaded into the kernel.

```bash
lsmod | grep nv_peer_mem
# GPUDirect RDMA info will be printed if GPUDirect RDMA is correctly loaded.
```

### Connection Pool

Currently, if the proxy disconnects, the connection pool must be warmed up again. A future enhancement could be a dedicated connection-pool management server (e.g. using a Raft-based store such as etcd, as mentioned in Mooncake) to improve connection discovery and avoid repeated warmups.

### Proxy

Do not add an engine node to **different proxies**; this is not supported and is not considered correct usage for now.
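For reference, a Python equivalent of the curl call in the README's API Usage section above, written as a minimal sketch with the third-party `requests` package (installed separately; prompt and model name simply mirror the example):

```python
# Minimal sketch: the same /v1/completions request as the curl example, from Python.
import requests

resp = requests.post(
    'http://localhost:8000/v1/completions',
    json={
        'model': 'internlm/internlm2_5-7b-chat',
        'temperature': 0,
        'prompt': 'Shanghai is a city that ',
        'max_tokens': 16,
        'stream': False,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()['choices'][0]['text'])
```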

lmdeploy/pytorch/disagg/__init__.py (+1, new file)

# Copyright (c) OpenMMLab. All rights reserved.
(+24, new file)

# Copyright (c) OpenMMLab. All rights reserved.
from lmdeploy.logger import get_logger

logger = get_logger('lmdeploy')

try:
    logger.debug('Registering DLSlime Backend')
    from .dlslime import DLSlimeBackend
except ImportError:
    logger.warning('Disable DLSlime Backend')

try:
    logger.debug('Registering Mooncake Backend')
    from .mooncake import MooncakeBackend
except ImportError:
    logger.warning('Disable Mooncake Backend')

try:
    logger.debug('Registering InfiniStoreBackend Backend')
    from .infinistore import InfiniStoreBackend
except ImportError:
    logger.warning('Disable InfiniStoreBackend Backend')

__all__ = ['DLSlimeBackend', 'MooncakeBackend', 'InfiniStoreBackend']
(+4, new file)

# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.registry import Registry

MIGRATION_BACKENDS = Registry('migration_backend', locations=['lmdeploy.pytorch.disagg.backend.backend'])
(+37, new file)

# Copyright (c) OpenMMLab. All rights reserved.
from abc import abstractmethod

from lmdeploy.pytorch.disagg.config import MigrationProtocol
from lmdeploy.pytorch.disagg.messages import DistServeRegisterMRMessage, MigrationAssignment
from lmdeploy.pytorch.disagg.request import DistServeConnectionRequest, DistServeInitRequest


class MigrationBackendImpl:

    @abstractmethod
    def p2p_initialize(self, init_request: DistServeInitRequest):
        raise NotImplementedError

    @abstractmethod
    def register_memory_region(self, register_mr_request: DistServeRegisterMRMessage):
        raise NotImplementedError

    @abstractmethod
    def endpoint_info(self, remote_engine_id: int, protocol: MigrationProtocol):
        return NotImplementedError

    @abstractmethod
    def p2p_connect(self, conn_req: DistServeConnectionRequest):
        raise NotImplementedError

    @abstractmethod
    def p2p_migrate(self, assignment: MigrationAssignment, async_op: bool = False):
        raise NotImplementedError

    @abstractmethod
    def store(self, assignment: MigrationAssignment, async_op: bool = False):
        raise NotImplementedError

    @abstractmethod
    def load(self, assignment: MigrationAssignment, async_op: bool = False):
        raise NotImplementedError
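This abstract interface pairs with the `MIGRATION_BACKENDS` registry defined in the previous file. Below is a minimal sketch of how a custom backend could register itself and be looked up by name; the base-class module path, the `DummyBackend` name, and the no-op bodies are assumptions for illustration only, not part of the commit:

```python
# Sketch only: plugging a hypothetical backend into the migration-backend registry.
# Import paths follow the registry's `locations` hint; the base-class module is assumed.
from lmdeploy.pytorch.disagg.backend.backend import MIGRATION_BACKENDS
from lmdeploy.pytorch.disagg.backend.base import MigrationBackendImpl  # assumed module path


@MIGRATION_BACKENDS.register_module(name='Dummy')
class DummyBackend(MigrationBackendImpl):
    """Illustrative no-op backend; a real backend would move KV blocks over RDMA/NVLink."""

    def p2p_initialize(self, init_request):
        pass

    def p2p_connect(self, conn_req):
        pass

    def p2p_migrate(self, assignment, async_op=False):
        pass


# mmengine's Registry supports name-based lookup, mirroring how the engine could
# resolve the `--migration-backend` choice to a backend class.
backend_cls = MIGRATION_BACKENDS.get('Dummy')
backend = backend_cls()
```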
