forked from modelscope/ms-swift
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlmdeploy_engine.py
More file actions
350 lines (318 loc) · 15.8 KB
/
Copy pathlmdeploy_engine.py
File metadata and controls
350 lines (318 loc) · 15.8 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
# Copyright (c) ModelScope Contributors. All rights reserved.
import asyncio
import inspect
import lmdeploy
import os
import time
import torch
from contextlib import contextmanager
from copy import deepcopy
from lmdeploy import PytorchEngineConfig, TurbomindEngineConfig, VisionConfig, pipeline
from lmdeploy.api import autoget_backend_config
from lmdeploy.serve import async_engine
from packaging import version
from PIL import Image
from transformers import GenerationConfig
from transformers.utils.versions import require_version
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
from swift.metrics import Metric
from swift.model import get_processor
from swift.template import Template
from swift.utils import get_logger, get_seed, safe_snapshot_download
from .infer_engine import InferEngine
from .patch import patch_auto_config, patch_auto_tokenizer
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, InferRequest, RequestConfig)
from .utils import InferStreamer
# Resolve lmdeploy's generation-config class under a single module-level alias. The first name is tried
# first; per the compat note below, newer lmdeploy releases expose it as ``GenerationConfig`` instead.
try:
    from lmdeploy import EngineGenerationConfig as LmdeployGenerationConfig
except ImportError:
    # compat lmdeploy >= 0.6.*
    from lmdeploy import GenerationConfig as LmdeployGenerationConfig

# Module logger obtained from swift.utils; used for backend/vision config messages below.
logger = get_logger()
class LmdeployEngine(InferEngine):
    """Inference engine that runs a model through an LMDeploy ``pipeline``.

    Requests are encoded with the swift ``Template`` (in ``'lmdeploy'`` mode) and generated through the
    async API of ``self.engine``; both streaming and non-streaming chat-completion responses are produced.
    """

    def __init__(
        self,
        model_id_or_path: str,
        *,
        template: Optional[Template] = None,
        torch_dtype: Optional[torch.dtype] = None,
        model_type: Optional[str] = None,
        template_type: Optional[str] = None,
        use_hf: Optional[bool] = None,
        hub_token: Optional[str] = None,
        revision: Optional[str] = None,
        # engine_kwargs
        tp: int = 1,
        session_len: Optional[int] = None,
        cache_max_entry_count: float = 0.8,
        quant_policy: int = 0,  # e.g. 4, 8
        vision_batch_size: int = 1,  # max_batch_size in VisionConfig
        engine_kwargs: Optional[Dict[str, Any]] = None,
        devices: Optional[List[int]] = None,
    ) -> None:
        """Resolve the model, build the template and start the LMDeploy pipeline.

        Args:
            model_id_or_path: Hub model id or local model directory.
            template: Template to use. When ``None`` one is built from the model's processor; otherwise the
                model files are downloaded with ``safe_snapshot_download``.
            torch_dtype: dtype to run with; falls back to ``self.model_info.torch_dtype`` when ``None``.
            model_type, template_type, use_hf, hub_token, revision: Forwarded to processor / template
                resolution and to the model download.
            tp, session_len, cache_max_entry_count, quant_policy, devices: Written into the
                ``TurbomindEngineConfig`` (``devices`` only when the installed lmdeploy accepts it).
            vision_batch_size: ``max_batch_size`` of the ``VisionConfig`` used for multimodal models.
            engine_kwargs: Extra ``TurbomindEngineConfig`` keyword arguments. The explicit arguments above
                override same-named keys; the passed dict itself is not modified.
        """
        self.model_id_or_path = model_id_or_path
        self.torch_dtype = torch_dtype
        self.model_type = model_type
        self.use_hf = use_hf
        self.hub_token = hub_token
        self.revision = revision
        self.tp = tp
        self.session_len = session_len
        self.cache_max_entry_count = cache_max_entry_count
        self.quant_policy = quant_policy
        self.vision_batch_size = vision_batch_size
        self.devices = devices
        if template is None:
            processor = self._get_processor()
            template = self._get_template(processor, template_type=template_type)
        else:
            safe_snapshot_download(
                model_id_or_path,
                revision=revision,
                download_model=True,
                use_hf=use_hf,
                ignore_patterns=getattr(template.model_meta, 'ignore_patterns', None),
                hub_token=hub_token)
        super().__init__(template)
        if self.max_model_len is not None:
            # NOTE(review): shrinks the usable context by one token; the reason is not visible in this
            # file — confirm before changing.
            self.max_model_len -= 1
        self._prepare_engine_kwargs(engine_kwargs)
        self.config.torch_dtype = self.torch_dtype = self.torch_dtype or self.model_info.torch_dtype
        self._prepare_engine()
        self._load_generation_config()

    def _get_processor(self):
        """Load the model's processor/tokenizer (downloading the model) from the stored constructor args."""
        return get_processor(
            model_id_or_path=self.model_id_or_path,
            torch_dtype=self.torch_dtype,
            download_model=True,
            model_type=self.model_type,
            use_hf=self.use_hf,
            hub_token=self.hub_token,
            revision=self.revision)

    def _prepare_engine_kwargs(self, engine_kwargs):
        """Build ``self.backend_config`` and ``self.pipeline_kwargs`` for ``lmdeploy.pipeline``.

        Args:
            engine_kwargs: Optional extra ``TurbomindEngineConfig`` kwargs. It is copied, never mutated;
                the engine settings stored on ``self`` take precedence over same-named keys.
        """
        # Copy so the overrides below do not leak back into the caller's dict.
        engine_kwargs = dict(engine_kwargs or {})
        engine_kwargs['tp'] = self.tp
        engine_kwargs['session_len'] = self.session_len
        engine_kwargs['cache_max_entry_count'] = self.cache_max_entry_count
        engine_kwargs['quant_policy'] = self.quant_policy
        # ``devices`` is not accepted by every lmdeploy version of TurbomindEngineConfig.
        if 'devices' in inspect.signature(TurbomindEngineConfig).parameters:
            engine_kwargs['devices'] = self.devices
        backend_config = TurbomindEngineConfig(**engine_kwargs)
        # Let lmdeploy adjust/select the backend config for this model; the rest of this class also
        # handles ``self.engine.backend == 'pytorch'``, so the result may not be a Turbomind config.
        backend_config = autoget_backend_config(self.model_dir, backend_config)
        self.backend_config = backend_config
        logger.info(f'backend_config: {backend_config}')

        pipeline_kwargs = {}
        if self.model_meta.is_multimodal:
            require_version(
                'lmdeploy<0.9', 'LmdeployEngine will no longer maintain inference for '
                'multimodal models in lmdeploy>=0.9.')
            vision_config = VisionConfig(max_batch_size=self.vision_batch_size)
            pipeline_kwargs['vision_config'] = vision_config
            logger.info(f'vision_config: {vision_config}')
        self.pipeline_kwargs = pipeline_kwargs

    @contextmanager
    def _patch_pipeline(self):
        """Temporarily make ``async_engine.best_match_model`` return this model's swift ``model_type``.

        The original function is restored on exit, even if the body raises.
        """
        _old_best_match_model = async_engine.best_match_model

        def _best_match_model(*args, **kwargs) -> Optional[str]:
            return self.model_info.model_type

        async_engine.best_match_model = _best_match_model
        try:
            yield
        finally:
            async_engine.best_match_model = _old_best_match_model

    def _prepare_engine(self):
        """Create ``self.engine`` via ``lmdeploy.pipeline`` using swift's already-loaded tokenizer/config."""
        with patch_auto_tokenizer(self.tokenizer), patch_auto_config(self.config), self._patch_pipeline():
            engine = pipeline(self.model_dir, backend_config=self.backend_config, **self.pipeline_kwargs)
        self.engine = engine

    def _load_generation_config(self):
        """Set ``self.generation_config`` from the model's ``generation_config.json`` when it exists.

        Only keys accepted by ``LmdeployGenerationConfig`` with non-``None`` values are kept; without the
        file, lmdeploy's defaults are used.
        """
        generation_config_path = os.path.join(self.model_dir, 'generation_config.json')
        if not os.path.isfile(generation_config_path):
            self.generation_config = LmdeployGenerationConfig()
            return
        hf_kwargs = GenerationConfig.from_pretrained(self.model_dir).to_dict()
        parameters = inspect.signature(LmdeployGenerationConfig).parameters
        # ``None`` values (e.g. an unset max_new_tokens) are dropped so lmdeploy's defaults apply.
        kwargs = {k: v for k, v in hf_kwargs.items() if k in parameters and v is not None}
        self.generation_config = LmdeployGenerationConfig(**kwargs)

    def _add_stop_words(self, generation_config: LmdeployGenerationConfig, request_config: RequestConfig) -> None:
        """Merge request, default and template stop words into ``generation_config`` as token ids."""
        template_meta = self.template.template_meta
        stop_words = (request_config.stop or []) + (self.generation_config.stop_words or []) + template_meta.stop_words
        generation_config.stop_words = self._get_stop_token_ids(stop_words)
        # compat lmdeploy >= 0.6.*
        generation_config.stop_token_ids = generation_config.stop_words

    def _prepare_generation_config(self, request_config: RequestConfig) -> LmdeployGenerationConfig:
        """Build the per-request ``LmdeployGenerationConfig``.

        Sampling fields left as ``None`` in ``request_config`` fall back to ``self.generation_config``.
        A missing seed is generated and written back to ``request_config``.
        """
        kwargs = {'max_new_tokens': request_config.max_tokens}
        for key in ['temperature', 'top_k', 'top_p', 'repetition_penalty']:
            new_value = getattr(request_config, key)
            kwargs[key] = getattr(self.generation_config, key) if new_value is None else new_value
        if request_config.seed is None:
            request_config.seed = get_seed()
        kwargs['random_seed'] = request_config.seed
        # Check the *effective* temperature so a 0 inherited from the model defaults is also greedy.
        if kwargs['temperature'] == 0:
            kwargs['temperature'] = 1  # avoid unnecessary process
            kwargs['top_k'] = 1

        if request_config.logprobs:
            kwargs['logprobs'] = 1
            if request_config.top_logprobs is not None:
                kwargs['logprobs'] = max(1, request_config.top_logprobs)

        return LmdeployGenerationConfig(**kwargs)

    async def _iter_stream_responses(
        self,
        gen,
        inputs: Dict[str, Any],
        generation_config: LmdeployGenerationConfig,
        request_config: RequestConfig,
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Convert the incremental lmdeploy outputs of ``gen`` into stream response chunks."""
        infer_streamer = InferStreamer(self.template, template_inputs=inputs['template_inputs'])
        token_idx = 0
        is_finished = False
        while not is_finished:
            try:
                output = await gen.__anext__()
            except StopAsyncIteration:
                # Run the loop body once more on the last output with is_finished=True to flush
                # the remaining text and emit the finish_reason / tool calls.
                is_finished = True
            delta_text = infer_streamer.get_printable_text(output.token_ids, is_finished)
            if not delta_text and not is_finished:
                continue

            logprobs = self._get_logprobs(output.logprobs, output.token_ids[token_idx:], request_config.top_logprobs)
            token_idx = len(output.token_ids)

            usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
            toolcall = None
            if is_finished:
                toolcall = self._get_toolcall(
                    self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs']))
            finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                    output.status.name == 'FINISH')
            choices = [
                ChatCompletionResponseStreamChoice(
                    index=0,
                    delta=DeltaMessage(role='assistant', content=delta_text, tool_calls=toolcall),
                    finish_reason=finish_reason,
                    logprobs=logprobs)
            ]
            yield ChatCompletionStreamResponse(model=self.model_name, choices=choices, usage=usage_info)

    async def _infer_stream_async(
        self,
        inputs: Dict[str, Any],
        generation_config: LmdeployGenerationConfig,
        request_config: RequestConfig,
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Stream chat-completion chunks for one encoded request.

        For lmdeploy>=0.6.5 the model instance from ``engine.model_inst`` stays checked out for the whole
        generation, and the pytorch session is ended afterwards, mirroring ``_infer_full_async``.
        """
        session_id = time.time_ns()
        kwargs = {'stream_output': True, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
            # Stream *inside* ``model_inst``: leaving that context hands the instance back to lmdeploy's
            # pool, so it must not exit before generation has finished.
            async with self.engine.model_inst(session_id) as inst:
                async with self.engine.safe_run(inst, session_id, **inputs, **kwargs) as gen:
                    async for response in self._iter_stream_responses(gen, inputs, generation_config,
                                                                      request_config):
                        yield response
                if self.engine.backend == 'pytorch':
                    # manually end pytorch session
                    await inst.async_end(session_id)
        else:
            async with self.engine.safe_run(session_id):
                generator = await self.engine.get_generator(False, session_id)
                gen = generator.async_stream_infer(session_id=session_id, **inputs, **kwargs)
                async for response in self._iter_stream_responses(gen, inputs, generation_config, request_config):
                    yield response

    async def _infer_full_async(
        self,
        inputs: Dict[str, Any],
        generation_config: LmdeployGenerationConfig,
        request_config: RequestConfig,
    ) -> ChatCompletionResponse:
        """Run one encoded request to completion and build a ``ChatCompletionResponse``.

        Token ids, prompt token ids and image sizes are only included when ``request_config.return_details``.
        """
        session_id = time.time_ns()
        kwargs = {'stream_output': False, 'gen_config': generation_config, 'sequence_start': True, 'sequence_end': True}
        if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
            async with self.engine.model_inst(session_id) as inst:
                async with self.engine.safe_run(inst, session_id, **inputs, **kwargs) as gen:
                    # Only the last (complete) output is needed.
                    async for output in gen:
                        pass
                if self.engine.backend == 'pytorch':
                    # manually end pytorch session
                    await inst.async_end(session_id)
        else:
            async with self.engine.safe_run(session_id):
                generator = await self.engine.get_generator(False, session_id)
                async for output in generator.async_stream_infer(session_id=session_id, **inputs, **kwargs):
                    pass

        response = self.template.decode_generate_ids(output.token_ids, template_inputs=inputs['template_inputs'])
        logprobs = self._get_logprobs(output.logprobs, output.token_ids, request_config.top_logprobs)

        usage_info = self._get_usage_info(len(inputs['input_ids']), output.num_token)
        toolcall = self._get_toolcall(response)
        finish_reason = self._get_finish_reason(generation_config.max_new_tokens, output.num_token,
                                                output.status.name == 'FINISH')
        token_ids = output.token_ids if request_config.return_details else None
        choices = [
            ChatCompletionResponseChoice(
                index=0,
                message=ChatMessage(role='assistant', content=response, tool_calls=toolcall),
                finish_reason=finish_reason,
                logprobs=logprobs,
                token_ids=token_ids)
        ]
        prompt_token_ids = None
        images_size = None
        if request_config.return_details:
            prompt_token_ids = inputs['input_ids']
            images = inputs['template_inputs'].images
            if all(isinstance(image, Image.Image) for image in images):
                images_size = [image.size for image in images]
        return ChatCompletionResponse(
            model=self.model_name,
            choices=choices,
            usage=usage_info,
            prompt_token_ids=prompt_token_ids,
            images_size=images_size)

    async def infer_async(self,
                          infer_request: InferRequest,
                          request_config: Optional[RequestConfig] = None,
                          *,
                          pre_infer_hook=None,
                          **kwargs) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionStreamResponse]]:
        """Encode one request, run it, and return a response (or an async chunk iterator when streaming).

        Args:
            infer_request: The request to encode with the template.
            request_config: Generation settings; deep-copied, so the caller's object is never modified.
            pre_infer_hook: Optional callable that receives and returns the kwargs dict
                (``inputs``, ``generation_config``, ``request_config`` plus any extra ``**kwargs``)
                right before inference.
        """
        request_config = deepcopy(request_config or RequestConfig())
        self.template.set_mode('lmdeploy')

        def _encode():
            # torch.inference_mode is thread-local: it must be entered in the executor thread that runs
            # ``encode``. Holding it around the ``await`` instead would only affect the event-loop thread,
            # including unrelated coroutines scheduled meanwhile.
            with torch.inference_mode():
                return self.template.encode(infer_request, True)

        loop = asyncio.get_running_loop()
        inputs = await loop.run_in_executor(None, _encode)
        images = inputs.pop('images', None)
        if images:
            if version.parse(lmdeploy.__version__) >= version.parse('0.6.5'):
                messages = self.engine._convert_prompts(('', images))
                messages = await self.engine.async_convert_to_pil_images(messages)
                results = await self.engine.vl_encoder.preprocess(messages)
                if self.engine.backend == 'turbomind':
                    results = await self.engine.vl_encoder.async_infer(results)
                    inputs['images'] = [result['content'] for result in results if result['role'] == 'forward'][0]
                    await self.template.prepare_lmdeploy_turbomind_inputs(inputs)
                else:
                    inputs['images'] = results[1]['content']
                    await self.template.prepare_lmdeploy_pytorch_inputs(inputs)
            else:
                inputs['images'] = await self.engine.vl_encoder.async_infer(images)
                await self.template.prepare_lmdeploy_turbomind_inputs(inputs)

        self.set_default_max_tokens(request_config, inputs)
        generation_config = self._prepare_generation_config(request_config)
        self._add_stop_words(generation_config, request_config)
        kwargs.update({'inputs': inputs, 'generation_config': generation_config, 'request_config': request_config})
        if pre_infer_hook:
            kwargs = pre_infer_hook(kwargs)
        if request_config.stream:
            return self._infer_stream_async(**kwargs)
        else:
            return await self._infer_full_async(**kwargs)

    def _batch_infer_stream(self, *args, **kwargs):
        """Reset cached lmdeploy engine state, then delegate to ``InferEngine._batch_infer_stream``.

        NOTE(review): the reset fields look event-loop-bound (a loop task and the free-instance pool);
        presumably the base implementation drives requests on a fresh event loop — confirm.
        """
        if hasattr(self.engine, 'vl_encoder'):
            self.engine.vl_encoder._loop_task = None
        if hasattr(self.engine, 'free_insts'):
            self.engine.free_insts = None
        return super()._batch_infer_stream(*args, **kwargs)

    def infer(
        self,
        infer_requests: List[InferRequest],
        request_config: Optional[RequestConfig] = None,
        metrics: Optional[List[Metric]] = None,
        *,
        use_tqdm: Optional[bool] = None,
        **kwargs,
    ) -> List[Union[ChatCompletionResponse, Iterator[ChatCompletionStreamResponse]]]:
        """Synchronous batch inference; delegates unchanged to ``InferEngine.infer``."""
        return super().infer(infer_requests, request_config, metrics, use_tqdm=use_tqdm, **kwargs)