|
| 1 | +import json |
1 | 2 | import logging |
2 | 3 | import os |
3 | 4 | from pathlib import Path |
4 | | -from typing import Optional |
| 5 | +from typing import Dict, List, Optional, Set |
5 | 6 |
|
6 | 7 | from huggingface_hub import hf_hub_download as _hf_hub_download |
7 | 8 | from huggingface_hub import snapshot_download as _snapshot_download |
8 | 9 | from modelscope import snapshot_download as _ms_snapshot_download |
9 | 10 | from modelscope.hub.file_download import model_file_download as _ms_model_file_download |
10 | 11 |
|
11 | 12 | from parallax.utils.weight_filter_utils import ( |
12 | | - determine_needed_weight_files_for_download, |
| 13 | + normalize_language_model_weight_key, |
| 14 | + should_include_weight_key, |
13 | 15 | ) |
14 | 16 |
|
15 | 17 | logger = logging.getLogger(__name__) |
@@ -96,7 +98,7 @@ def selective_model_download( |
96 | 98 | if start_layer is not None and end_layer is not None: |
97 | 99 | logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})") |
98 | 100 |
|
99 | | - needed_weight_files = determine_needed_weight_files_for_download( |
| 101 | + needed_weight_files = _determine_needed_weight_files_for_download( |
100 | 102 | model_path=model_path, |
101 | 103 | start_layer=start_layer, |
102 | 104 | end_layer=end_layer, |
@@ -149,5 +151,80 @@ def selective_model_download( |
149 | 151 | ] |
150 | 152 |
|
151 | 153 |
|
| 154 | +def _determine_needed_weight_files_for_download( |
| 155 | + model_path: Path, |
| 156 | + start_layer: int, |
| 157 | + end_layer: int, |
| 158 | + config: Optional[Dict] = None, |
| 159 | +) -> List[str]: |
| 160 | + is_first_shard = start_layer == 0 |
| 161 | + |
| 162 | + is_last_shard = False |
| 163 | + if config: |
| 164 | + num_hidden_layers = config.get("num_hidden_layers", 0) |
| 165 | + is_last_shard = end_layer >= num_hidden_layers |
| 166 | + else: |
| 167 | + config_file = model_path / "config.json" |
| 168 | + if config_file.exists(): |
| 169 | + from parallax.utils.utils import normalize_model_config |
| 170 | + |
| 171 | + with open(config_file, "r") as f: |
| 172 | + cfg = normalize_model_config(json.load(f)) |
| 173 | + num_hidden_layers = cfg.get("num_hidden_layers", 0) |
| 174 | + is_last_shard = end_layer >= num_hidden_layers |
| 175 | + |
| 176 | + index_file = model_path / "model.safetensors.index.json" |
| 177 | + |
| 178 | + if not index_file.exists(): |
| 179 | + logger.debug(f"Index file not found at {index_file}, checking for single weight file") |
| 180 | + # For non-sharded models, look for single weight file |
| 181 | + single_weight_files = [ |
| 182 | + "model.safetensors", |
| 183 | + "pytorch_model.bin", |
| 184 | + "model.bin", |
| 185 | + ] |
| 186 | + for weight_file in single_weight_files: |
| 187 | + if (model_path / weight_file).exists(): |
| 188 | + logger.debug(f"Found single weight file: {weight_file}") |
| 189 | + return [weight_file] |
| 190 | + |
| 191 | + logger.debug("No weight files found (neither index nor single file)") |
| 192 | + return [] |
| 193 | + |
| 194 | + with open(index_file, "r") as f: |
| 195 | + index_data = json.load(f) |
| 196 | + |
| 197 | + weight_map = index_data.get("weight_map", {}) |
| 198 | + if not weight_map: |
| 199 | + logger.debug("weight_map is empty in index file") |
| 200 | + return [] |
| 201 | + |
| 202 | + tie_word_embeddings = False |
| 203 | + if config: |
| 204 | + tie_word_embeddings = config.get("tie_word_embeddings", False) |
| 205 | + |
| 206 | + needed_files: Set[str] = set() |
| 207 | + |
| 208 | + for key, filename in weight_map.items(): |
| 209 | + if filename in needed_files: |
| 210 | + continue |
| 211 | + key = normalize_language_model_weight_key(key) |
| 212 | + if should_include_weight_key( |
| 213 | + key=key, |
| 214 | + start_layer=start_layer, |
| 215 | + end_layer=end_layer, |
| 216 | + is_first_shard=is_first_shard, |
| 217 | + is_last_shard=is_last_shard, |
| 218 | + tie_word_embeddings=tie_word_embeddings, |
| 219 | + ): |
| 220 | + needed_files.add(filename) |
| 221 | + |
| 222 | + result = sorted(list(needed_files)) |
| 223 | + logger.debug( |
| 224 | + f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})" |
| 225 | + ) |
| 226 | + return result |
| 227 | + |
| 228 | + |
152 | 229 | def _use_modelscope() -> bool: |
153 | 230 | return _USE_MODELSCOPE_ENV in os.environ |
0 commit comments