Skip to content

Commit c3768fd

Browse files
committed
feat: add modelscope download
1 parent c698f57 commit c3768fd

12 files changed

Lines changed: 178 additions & 324 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dependencies = [
2020
"msgpack>=1.0.7",
2121
"safetensors>=0.5.1",
2222
"huggingface-hub",
23+
"modelscope",
2324
"transformers>=4.57.1",
2425
"jinja2>=3.1.0",
2526
"numpy>=1.26",
@@ -69,7 +70,6 @@ benchmark = [
6970
"tqdm",
7071
"datasets",
7172
"pillow",
72-
"modelscope",
7373
]
7474

7575
dev = [

scripts/download_model_shard.sh

Lines changed: 0 additions & 45 deletions
This file was deleted.

scripts/download_shard.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

src/backend/benchmark/backend_request_func.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@
1111
from typing import List, Optional, Union
1212

1313
import aiohttp
14-
import huggingface_hub.constants
15-
from huggingface_hub import snapshot_download
1614
from tqdm.asyncio import tqdm
1715
from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
1816

17+
from parallax.utils.model_download import download_model_snapshot
18+
1919
AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
2020

2121

@@ -268,12 +268,11 @@ async def async_request_openai_chat_completions(
268268

269269
def get_model(pretrained_model_name_or_path: str) -> str:
270270

271-
model_path = snapshot_download(
271+
model_path = download_model_snapshot(
272272
repo_id=pretrained_model_name_or_path,
273-
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
274273
ignore_patterns=[".*.pt", ".*.safetensors", ".*.bin"],
275274
)
276-
return model_path
275+
return str(model_path)
277276

278277

279278
def get_tokenizer(

src/parallax/server/shard_loader.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from typing import Any, Dict, Optional, Tuple
1111

1212
import mlx.core as mx
13-
from huggingface_hub import snapshot_download
1413
from mlx import nn
1514
from mlx.utils import tree_unflatten
1615
from mlx_lm.models.switch_layers import QuantizedSwitchLinear, SwitchLinear
@@ -19,6 +18,7 @@
1918
from mlx_lm.utils import _download, load_config
2019

2120
from parallax.server.model import ShardedModel
21+
from parallax.utils.model_download import download_model_snapshot
2222
from parallax.utils.tokenizer_utils import load_tokenizer
2323
from parallax.utils.utils import normalize_model_config
2424
from parallax_utils.logging_config import get_logger
@@ -195,7 +195,7 @@ def load_lora(self, base_model: nn.Module, adapter_path: str) -> nn.Module:
195195
logger.info(
196196
f"Adapter path {adapter_path} not found locally. Attempting to download from Hugging Face..."
197197
)
198-
downloaded_path = snapshot_download(
198+
downloaded_path = download_model_snapshot(
199199
repo_id=str(adapter_path), local_dir=str(adapter_path)
200200
)
201201
adapter_path = pathlib.Path(downloaded_path)
@@ -236,14 +236,14 @@ def load(
236236
A tuple containing the loaded sharded MLX model and its configuration dictionary.
237237
"""
238238
if use_selective_download and self.start_layer is not None and self.end_layer is not None:
239-
from parallax.utils.selective_download import (
240-
get_model_path_with_selective_download,
239+
from parallax.utils.model_download import (
240+
selective_model_download,
241241
)
242242

243243
logger.info(
244244
f"Using selective download for layers [{self.start_layer}, {self.end_layer})"
245245
)
246-
model_path = get_model_path_with_selective_download(
246+
model_path = selective_model_download(
247247
self.model_path_str,
248248
start_layer=self.start_layer,
249249
end_layer=self.end_layer,

src/parallax/sglang/model_runner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,12 +303,12 @@ def initialize_sgl_model_runner(
303303
use_hfcache = kwargs.get("use_hfcache", False)
304304
nccl_port = kwargs.get("nccl_port", None)
305305
# Use selective download for GPU models to save bandwidth and disk space
306-
from parallax.utils.selective_download import get_model_path_with_selective_download
306+
from parallax.utils.model_download import selective_model_download
307307

308308
logger.info(
309309
f"Downloading model with selective weight files for layers [{start_layer}, {end_layer})"
310310
)
311-
model_path = get_model_path_with_selective_download(
311+
model_path = selective_model_download(
312312
model_repo, start_layer=start_layer, end_layer=end_layer, local_files_only=use_hfcache
313313
)
314314

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import logging
2+
import os
3+
from pathlib import Path
4+
from typing import Optional
5+
6+
from huggingface_hub import hf_hub_download as _hf_hub_download
7+
from huggingface_hub import snapshot_download as _snapshot_download
8+
from modelscope import snapshot_download as _ms_snapshot_download
9+
from modelscope.hub.file_download import model_file_download as _ms_model_file_download
10+
11+
from parallax.utils.weight_filter_utils import (
12+
determine_needed_weight_files_for_download,
13+
)
14+
15+
logger = logging.getLogger(__name__)
16+
_USE_MODELSCOPE_ENV = "USE_MODELSCOPE"
17+
18+
__all__ = [
19+
"download_model_file",
20+
"download_model_snapshot",
21+
"selective_model_download",
22+
]
23+
24+
25+
def download_model_snapshot(
    repo_id: str,
    allow_patterns: Optional[list[str] | str] = None,
    ignore_patterns: Optional[list[str] | str] = None,
    local_dir: Optional[str | Path] = None,
    local_files_only: bool = False,
) -> Path:
    """Download a repository snapshot from the currently selected model hub.

    The hub is chosen by ``_use_modelscope()``: ModelScope when it returns
    True, Hugging Face otherwise. All arguments are forwarded to the
    corresponding ``snapshot_download`` function.

    Args:
        repo_id: Repository identifier (passed as ``model_id`` to ModelScope
            and as ``repo_id`` to Hugging Face).
        allow_patterns: Pattern(s) forwarded unchanged to the hub client.
        ignore_patterns: Pattern(s) forwarded unchanged to the hub client.
        local_dir: Optional target directory forwarded to the hub client.
        local_files_only: Forwarded unchanged to the hub client.

    Returns:
        The path returned by the hub client, wrapped in ``Path``.
    """
    if not _use_modelscope():
        hf_snapshot = _snapshot_download(
            repo_id=repo_id,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            local_dir=local_dir,
            local_files_only=local_files_only,
        )
        return Path(hf_snapshot)

    # The ModelScope client is handed the directory as a plain string (or None).
    ms_local_dir = None if local_dir is None else str(local_dir)
    ms_snapshot = _ms_snapshot_download(
        model_id=repo_id,
        allow_patterns=allow_patterns,
        ignore_patterns=ignore_patterns,
        local_dir=ms_local_dir,
        local_files_only=local_files_only,
    )
    return Path(ms_snapshot)
52+
53+
54+
def download_model_file(
    repo_id: str,
    filename: str,
    local_files_only: bool = False,
) -> Path:
    """Download a single file of a repository from the selected model hub.

    The hub is chosen by ``_use_modelscope()``: ModelScope when it returns
    True, Hugging Face otherwise.

    Args:
        repo_id: Repository identifier (``model_id`` for ModelScope,
            ``repo_id`` for Hugging Face).
        filename: File to fetch (``file_path`` for ModelScope, ``filename``
            for Hugging Face).
        local_files_only: Forwarded unchanged to the hub client.

    Returns:
        The path returned by the hub client, wrapped in ``Path``.
    """
    if not _use_modelscope():
        hf_file = _hf_hub_download(
            repo_id=repo_id,
            filename=filename,
            local_files_only=local_files_only,
        )
        return Path(hf_file)

    ms_file = _ms_model_file_download(
        model_id=repo_id,
        file_path=filename,
        local_files_only=local_files_only,
    )
    return Path(ms_file)
75+
76+
77+
def selective_model_download(
    repo_id: str,
    start_layer: Optional[int] = None,
    end_layer: Optional[int] = None,
    local_files_only: bool = False,
) -> Path:
    """Download a model, fetching only the weight files a layer range needs.

    If ``repo_id`` is an existing filesystem path it is returned as-is and
    nothing is downloaded. Otherwise the non-weight files are downloaded
    first (everything not matching ``_EXCLUDE_WEIGHT_PATTERNS``), and then:

    * with both ``start_layer`` and ``end_layer`` given, only the weight
      files reported by ``determine_needed_weight_files_for_download`` are
      fetched (falling back to a full snapshot if none are reported);
    * otherwise the full snapshot is downloaded.

    Args:
        repo_id: Hub repository identifier, or a path to a local model.
        start_layer: First layer of the range, logged as ``[start, end)``.
        end_layer: End of the layer range, logged as ``[start, end)``.
        local_files_only: Forwarded to every download call.

    Returns:
        The local path, or the path of the metadata snapshot.

    Raises:
        Exception: Any error raised while downloading an individual weight
            file is logged and re-raised unchanged.
    """
    local_path = Path(repo_id)
    if local_path.exists():
        logger.debug("Using local model path: %s", local_path)
        return local_path

    # Fetch config/tokenizer/index files first, skipping all weight files.
    logger.debug("Downloading model metadata for %s", repo_id)
    model_path = download_model_snapshot(
        repo_id=repo_id,
        ignore_patterns=_EXCLUDE_WEIGHT_PATTERNS,
        local_files_only=local_files_only,
    )
    logger.debug("Downloaded model metadata to %s", model_path)

    if start_layer is None or end_layer is None:
        logger.debug("No layer range specified, downloading all model files")
        download_model_snapshot(repo_id=repo_id, local_files_only=local_files_only)
        return model_path

    logger.debug(
        "Determining required weight files for layers [%s, %s)", start_layer, end_layer
    )
    needed_weight_files = determine_needed_weight_files_for_download(
        model_path=model_path,
        start_layer=start_layer,
        end_layer=end_layer,
    )

    if not needed_weight_files:
        logger.debug("Could not determine specific weight files, downloading all")
        download_model_snapshot(repo_id=repo_id, local_files_only=local_files_only)
    else:
        logger.info("Downloading %d weight files", len(needed_weight_files))

        # Name the hub that is actually in use so a failure message does not
        # point the operator at the wrong service.
        hub_name = "ModelScope" if _use_modelscope() else "Hugging Face Hub"

        for weight_file in needed_weight_files:
            # Skip files that are already present next to the metadata.
            if (model_path / weight_file).exists():
                continue

            logger.debug("Downloading %s", weight_file)
            try:
                download_model_file(
                    repo_id=repo_id,
                    filename=weight_file,
                    local_files_only=local_files_only,
                )
            except Exception as e:
                logger.error("Failed to download %s for %s: %s", weight_file, repo_id, e)
                logger.error(
                    "This node cannot reach %s to download weight files. "
                    "Please check network connectivity or pre-download the model.",
                    hub_name,
                )
                raise

    logger.debug("Downloaded weight files for layers [%s, %s)", start_layer, end_layer)
    return model_path
139+
140+
141+
# Patterns passed as ``ignore_patterns`` when selective_model_download fetches
# the metadata-only snapshot, so that no weight files are pulled at that stage.
_EXCLUDE_WEIGHT_PATTERNS = [
    # Generic weight-file extensions.
    "*.safetensors",
    "*.bin",
    "*.pt",
    "*.pth",
    # Explicit common weight-file names.
    # NOTE(review): these look already covered by the extension globs above
    # if the hub clients use fnmatch-style matching - confirm before pruning.
    "pytorch_model*.bin",
    "model*.safetensors",
    "weight*.safetensors",
]
150+
151+
152+
def _use_modelscope() -> bool:
    """Return True when the environment selects ModelScope as the model hub.

    The variable named by ``_USE_MODELSCOPE_ENV`` enables ModelScope unless it
    is unset or holds an explicit "off" value (empty, ``0``, ``false``,
    ``no``, ``off``; case-insensitive, surrounding whitespace ignored).
    Checking only for presence would make ``USE_MODELSCOPE=0`` turn the
    feature on.
    """
    raw_value = os.environ.get(_USE_MODELSCOPE_ENV)
    if raw_value is None:
        return False
    return raw_value.strip().lower() not in ("", "0", "false", "no", "off")

0 commit comments

Comments
 (0)