Skip to content

Commit 0ea5209

Browse files
authored
[bugfix] fix model_type vllm_engine (modelscope#8117)
1 parent 79e93cf commit 0ea5209

File tree

3 files changed

+27
-12
lines changed

3 files changed

+27
-12
lines changed

swift/infer_engine/lmdeploy_engine.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
1818

1919
from swift.metrics import Metric
20-
from swift.model import get_model_info_meta, get_processor
20+
from swift.model import get_processor
2121
from swift.template import Template
22-
from swift.utils import get_logger, get_seed
22+
from swift.utils import get_logger, get_seed, safe_snapshot_download
2323
from .infer_engine import InferEngine
2424
from .patch import patch_auto_config, patch_auto_tokenizer
2525
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
@@ -72,8 +72,13 @@ def __init__(
7272
processor = self._get_processor()
7373
template = self._get_template(processor)
7474
else:
75-
get_model_info_meta(
76-
model_id_or_path, hub_token=hub_token, use_hf=use_hf, revision=revision, download_model=True)
75+
safe_snapshot_download(
76+
model_id_or_path,
77+
revision=revision,
78+
download_model=True,
79+
use_hf=use_hf,
80+
ignore_patterns=getattr(template.model_meta, 'ignore_patterns', None),
81+
hub_token=hub_token)
7782
super().__init__(template)
7883

7984
if self.max_model_len is not None:

swift/infer_engine/sglang_engine.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
1313

1414
from swift.metrics import Metric
15-
from swift.model import get_model_info_meta, get_processor
15+
from swift.model import get_processor
1616
from swift.template import Template
17-
from swift.utils import get_logger
17+
from swift.utils import get_logger, safe_snapshot_download
1818
from .infer_engine import InferEngine
1919
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
2020
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
@@ -85,8 +85,13 @@ def __init__(
8585
processor = self._get_processor()
8686
template = self._get_template(processor)
8787
else:
88-
get_model_info_meta(
89-
model_id_or_path, hub_token=hub_token, use_hf=use_hf, revision=revision, download_model=True)
88+
safe_snapshot_download(
89+
model_id_or_path,
90+
revision=revision,
91+
download_model=True,
92+
use_hf=use_hf,
93+
ignore_patterns=getattr(template.model_meta, 'ignore_patterns', None),
94+
hub_token=hub_token)
9095
super().__init__(template)
9196
self._prepare_server_args(engine_kwargs)
9297
self.engine = sgl.Engine(server_args=self.server_args)

swift/infer_engine/vllm_engine.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,9 @@
1313
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
1414

1515
from swift.metrics import Metric
16-
from swift.model import get_model_info_meta, get_processor
16+
from swift.model import get_processor
1717
from swift.template import Template
18-
from swift.utils import get_device, get_dist_setting, get_logger, is_dist
18+
from swift.utils import get_device, get_dist_setting, get_logger, is_dist, safe_snapshot_download
1919
from .infer_engine import InferEngine
2020
from .patch import patch_auto_config, patch_auto_tokenizer
2121
from .protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
@@ -141,8 +141,13 @@ def __init__(
141141
processor = self._get_processor()
142142
template = self._get_template(processor)
143143
else:
144-
get_model_info_meta(
145-
model_id_or_path, hub_token=hub_token, use_hf=use_hf, revision=revision, download_model=True)
144+
safe_snapshot_download(
145+
model_id_or_path,
146+
revision=revision,
147+
download_model=True,
148+
use_hf=use_hf,
149+
ignore_patterns=getattr(template.model_meta, 'ignore_patterns', None),
150+
hub_token=hub_token)
146151
super().__init__(template)
147152
if max_model_len is not None:
148153
self.max_model_len = max_model_len

0 commit comments

Comments (0)