forked from NVIDIA/Model-Optimizer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathexample_utils.py
More file actions
executable file
·743 lines (612 loc) · 28.4 KB
/
example_utils.py
File metadata and controls
executable file
·743 lines (612 loc) · 28.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import glob
import inspect
import json
import os
import shutil
import sys
import warnings
from pathlib import Path
from typing import Any
import torch
import transformers
from accelerate import infer_auto_device_map, init_empty_weights
from accelerate.utils import get_max_memory
from safetensors.torch import load_file
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
AutoProcessor,
AutoTokenizer,
PreTrainedTokenizerBase,
ProcessorMixin,
)
try:
from huggingface_hub import snapshot_download
except ImportError:
snapshot_download = None
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.image_processor import BaseImageProcessor, MllamaImageProcessor
# Substrings of an HF architecture name that mark a speculative-decoding model
# (checked against ``hf_config.architectures[0]`` by ``is_speculative``).
SPECULATIVE_MODEL_LIST = [
    "Eagle",
    "Medusa",
]
def run_nemotron_vl_preview(
    full_model, tokenizer, input_ids, pyt_ckpt_path, stage_name, allow_fallback=False
):
    """Preview a Nemotron VL model: a text-only generation, then a generation with images.

    The prompt is recovered by decoding the first row of ``input_ids``. It is sent through
    ``vlm_utils.run_text_only_generation`` with greedy decoding and at most 100 new tokens.
    Independently of that result, ``vlm_utils.run_vl_preview_generation`` is always run.

    Args:
        full_model: The full VL model.
        tokenizer: Tokenizer used to decode ``input_ids`` and supply ``eos_token_id``.
        input_ids: Input tensor for generation; only ``input_ids[0]`` is decoded.
        pyt_ckpt_path: Path to the model checkpoint.
        stage_name: Label used in log lines (e.g. "before quantization").
        allow_fallback: If the text-only generation returns ``None``, fall back to
            ``full_model.generate(input_ids, max_new_tokens=100)``.

    Returns:
        The text response, or the output of ``full_model.generate`` when the fallback is
        used. Returns ``None`` when text-only generation fails and no fallback is allowed.
    """
    from vlm_utils import run_text_only_generation, run_vl_preview_generation

    print(f"Running text-only preview generation for Nemotron VL model ({stage_name})...")
    prompt_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
    gen_cfg = dict(
        max_new_tokens=100,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Text-only generation may fail (return None) for encoder-decoder models like Nemotron-Parse.
    text_out = run_text_only_generation(full_model, tokenizer, prompt_text, gen_cfg, pyt_ckpt_path)

    if text_out is not None:
        print(f"✅ Text-only generation successful: {text_out[:100]}...")
        result = text_out
    elif allow_fallback:
        print("Text-only generation failed, falling back to standard generate...")
        result = full_model.generate(input_ids, max_new_tokens=100)
    else:
        result = None

    # The image-based preview runs regardless of the text-only outcome.
    print(f"Running additional VL test with images ({stage_name})...")
    run_vl_preview_generation(full_model, tokenizer, pyt_ckpt_path, stage_name)

    return result
def _is_multimodal_config(config):
"""Check if a config indicates a multimodal model (config-only version of is_multimodal_model)."""
return (
hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL)
or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal
or hasattr(config, "vision_lora") # Vision LoRA configurations
or hasattr(config, "audio_processor") # Audio processing capabilities
or (
hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
) # Image embedding layers
or getattr(config, "is_encoder_decoder", False) # Encoder-decoder VL models
or any( # Architecture-based detection for custom VL models (e.g., Nemotron-Parse)
"conditionalgeneration" in arch.lower() for arch in getattr(config, "architectures", [])
)
)
def is_nemotron_vl(model_or_config):
    """Check if model or config indicates a Nemotron VL model.

    The object is first required to look multimodal:
      * an object with a ``.config`` attribute (a model) must pass
        ``modelopt.torch.export.model_utils.is_multimodal_model``;
      * any other object is treated as a config and must pass ``_is_multimodal_config``.

    It is then a Nemotron VL model if any architecture name contains "nemotron"
    (case-insensitive).

    Args:
        model_or_config: Either a model instance or a config object.

    Returns:
        bool: True if it's a Nemotron VL model, False otherwise.
    """
    if hasattr(model_or_config, "config"):
        config = model_or_config.config
        from modelopt.torch.export.model_utils import is_multimodal_model

        if not is_multimodal_model(model_or_config):
            return False
    else:
        config = model_or_config
        if not _is_multimodal_config(config):
            return False

    # ``architectures`` can be present but ``None`` on HF configs; treat that as empty.
    architectures = getattr(config, "architectures", None) or []
    return any("nemotron" in arch.lower() for arch in architectures)
def create_vlm_calibration_loop(full_model, calib_dataloader):
    """Create a calibration loop for VLM models that handles multimodal inputs.

    For every batch, the returned loop does the following:
      * For encoder-decoder models (``config.is_encoder_decoder``), when the batch has both
        ``input_ids`` and ``pixel_values``, it renames ``input_ids`` ->
        ``decoder_input_ids`` and ``attention_mask`` -> ``decoder_attention_mask``, and
        sets ``use_cache=False``.
      * It keeps only keys accepted by ``full_model.forward``, or all keys if the forward
        takes ``**kwargs``, and drops ``None`` values.
      * It runs the forward pass under ``torch.no_grad()``.

    The loop works on a shallow copy of each batch, so the dataloader's own batch objects
    are never modified.

    Args:
        full_model: The full VLM model. The loop always calls this model; the model
            argument passed to the loop (e.g. by ``mtq.quantize``) is ignored.
        calib_dataloader: Iterable yielding multimodal batches as dict-like mappings.

    Returns:
        A calibration function that can be passed to mtq.quantize()
    """

    def calibrate_loop(_model):
        # Inspect model's forward signature to determine what parameters it accepts
        forward_params = inspect.signature(full_model.forward).parameters
        accepts_kwargs = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in forward_params.values()
        )
        allowed_keys = set(forward_params.keys())
        # Check if model is encoder-decoder (needs decoder_input_ids instead of input_ids)
        is_enc_dec = getattr(full_model.config, "is_encoder_decoder", False)

        # Nemotron Nano VL (embedding-injection style, exposes ``img_context_token_id``)
        # needs safe_nemotron_vl_forward; other VLMs (like Nemotron-Parse) use the standard
        # forward. The helper is imported here rather than at module level to avoid a
        # circular dependency, and only when it is actually needed.
        safe_forward = None
        if hasattr(full_model, "img_context_token_id"):
            from nemotron_vl_calib import safe_nemotron_vl_forward as safe_forward

        full_model.eval()
        with torch.no_grad():
            for batch in calib_dataloader:
                # Shallow copy: the renaming below must not mutate the dataloader's batch,
                # which may be yielded again on a later calibration pass.
                batch = dict(batch)
                # For encoder-decoder models, rename input_ids → decoder_input_ids
                # and disable KV caching to avoid tuple index errors in decoder layers
                if is_enc_dec and "input_ids" in batch and "pixel_values" in batch:
                    batch["decoder_input_ids"] = batch.pop("input_ids")
                    if "attention_mask" in batch:
                        batch["decoder_attention_mask"] = batch.pop("attention_mask")
                    batch["use_cache"] = False

                # Filter batch to only include parameters the model accepts, then drop Nones
                call_kwargs = {
                    k: v
                    for k, v in batch.items()
                    if v is not None and (accepts_kwargs or k in allowed_keys)
                }

                if safe_forward is not None:
                    safe_forward(full_model, call_kwargs)
                else:
                    full_model(**call_kwargs)

    return calibrate_loop
def build_quant_cfg(
qformat,
kv_cache_qformat,
awq_block_size,
model_type,
quant_cfg_choices,
kv_quant_cfg_choices,
moe_calib_experts_ratio: float | None = None,
) -> dict[str, Any]:
quant_cfg = {}
assert qformat in quant_cfg_choices, (
f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache"
)
quant_cfg = quant_cfg_choices[qformat]
if "awq" in qformat:
quant_cfg = copy.deepcopy(quant_cfg_choices[qformat])
weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"]
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
# If awq_block_size argument is provided, update weight_quantizer
if awq_block_size:
weight_quantizer["block_sizes"][-1] = awq_block_size
# Coarser optimal scale search seems to resolve the overflow in TRT-LLM for some models
if qformat == "w4a8_awq" and model_type in ["gemma", "mpt"]:
quant_cfg["algorithm"] = {"method": "awq_lite", "alpha_step": 1}
enable_quant_kv_cache = kv_cache_qformat != "none"
print(f"{'Enable' if enable_quant_kv_cache else 'Disable'} KV cache quantization")
# Check if any bmm_quantizer is in the quant_cfg. If so, we need to enable the bmm_quantizer.
if enable_quant_kv_cache:
quant_cfg = mtq.update_quant_cfg_with_kv_cache_quant(
quant_cfg,
getattr(mtq, kv_quant_cfg_choices[kv_cache_qformat])["quant_cfg"],
)
if moe_calib_experts_ratio:
assert 0 < moe_calib_experts_ratio <= 1, "moe_calib_experts_ratio must be between 0 and 1"
if isinstance(quant_cfg["algorithm"], str):
quant_cfg["algorithm"] = {
"method": quant_cfg["algorithm"],
"moe_calib_experts_ratio": moe_calib_experts_ratio,
}
elif isinstance(quant_cfg["algorithm"], dict):
quant_cfg["algorithm"]["moe_calib_experts_ratio"] = moe_calib_experts_ratio
else:
warnings.warn(
f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio"
)
# Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead.
if model_type == "gemma" and "int8_sq" in qformat:
quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5}
if model_type == "phi4mm":
# Only quantize the language model
quant_cfg["quant_cfg"]["*speech*"] = {"enable": False}
quant_cfg["quant_cfg"]["*audio*"] = {"enable": False}
quant_cfg["quant_cfg"]["*image*"] = {"enable": False}
quant_cfg["quant_cfg"]["*vision*"] = {"enable": False}
return quant_cfg
def is_speculative(hf_config):
    """Check if the model architecture is a speculative model.

    A config is speculative when any entry of ``SPECULATIVE_MODEL_LIST`` is a substring of
    its first architecture name. A config whose ``architectures`` attribute is missing,
    ``None`` or empty is not speculative.

    Returns:
        bool: Always a real boolean. The previous form could leak ``None``/``[]``.
    """
    architectures = getattr(hf_config, "architectures", None)
    if not architectures:
        return False
    return any(name in architectures[0] for name in SPECULATIVE_MODEL_LIST)
def get_tokenizer(ckpt_path, trust_remote_code=False, **kwargs) -> PreTrainedTokenizerBase:
    """Load the tokenizer for ``ckpt_path`` and make sure it has a pad token.

    VILA checkpoints (path contains "vila", case-insensitive) are loaded from the ``llm``
    sub-folder. Extra ``kwargs`` are forwarded to ``AutoTokenizer.from_pretrained``.

    Raises:
        AssertionError: If the tokenizer still has no pad token afterwards.
    """
    print(f"Initializing tokenizer from {ckpt_path}")
    source = f"{ckpt_path}/llm" if "vila" in ckpt_path.lower() else ckpt_path

    tokenizer = AutoTokenizer.from_pretrained(
        source, trust_remote_code=trust_remote_code, **kwargs
    )

    # The pad token is replaced by EOS unless it is "<unk>" (that attribute can't be set
    # there; this is skipped for Nemo models). A missing (None) pad token also differs from
    # "<unk>", so it is replaced too. An existing non-"<unk>" pad token is overwritten as well.
    if tokenizer.pad_token != "<unk>":
        tokenizer.pad_token = tokenizer.eos_token

    assert tokenizer.pad_token is not None, f"Pad token for {source} cannot be set!"
    return tokenizer
def get_processor(
    ckpt_path,
    model_type,
    device: torch.device = "auto",
    trust_remote_code=False,
    attn_implementation=None,
) -> BaseImageProcessor | ProcessorMixin | None:
    """
    Load the input processor used for calibration of ``model_type``.

    * "whisper": the left-padded ``AutoProcessor`` itself.
    * "mllama": the same left-padded ``AutoProcessor``, wrapped in
      :class:`modelopt.torch.utils.image_processor.MllamaImageProcessor` on ``device``.
    * any other type: a best-effort ``AutoProcessor`` (e.g. Nemotron-Parse), or ``None``
      if loading raises.

    For "whisper"/"mllama" a missing tokenizer pad token is set to the EOS token.
    """
    model_kwargs = {"trust_remote_code": trust_remote_code}
    if attn_implementation is not None:
        model_kwargs["attn_implementation"] = attn_implementation

    if model_type not in ("whisper", "mllama"):
        # Try to load AutoProcessor for other VL models (e.g., Nemotron-Parse)
        try:
            processor = AutoProcessor.from_pretrained(ckpt_path, **model_kwargs)
            print(f"Loaded AutoProcessor for model type: {model_type}")
            return processor
        except Exception as e:
            print(f"Could not load processor for {model_type}: {e}")
            return None

    # Shared whisper/mllama path: left padding and a guaranteed pad token.
    processor = AutoProcessor.from_pretrained(
        ckpt_path,
        padding_side="left",
        **model_kwargs,
    )
    text_tokenizer = processor.tokenizer
    if text_tokenizer.pad_token is None:
        text_tokenizer.pad_token = text_tokenizer.eos_token
    assert text_tokenizer.pad_token is not None, f"Pad token for {ckpt_path} cannot be set!"

    if model_type == "whisper":
        return processor
    return MllamaImageProcessor(processor, device)
def load_mtp_weights(
    model: torch.nn.Module, model_path: str
) -> tuple[list[str], dict[str, torch.Tensor]]:
    """Load MTP weights from the model checkpoint.

    Some models store additional layers in separate safetensors files with non-standard
    names (e.g., mtp.safetensors). HuggingFace's from_pretrained() may not load these
    files even though they're referenced in model.safetensors.index.json.
    This function detects such cases and explicitly loads the missing weights.

    A tensor is treated as MTP when its key or its shard filename in the index's
    ``weight_map`` contains "mtp". MTP tensors whose key exists in ``model.state_dict()``
    are loaded with ``load_state_dict(strict=False)``. Shard files listed in the index but
    missing on disk are skipped.

    Args:
        model: The loaded model that may be missing weights
        model_path: Path to the model directory

    Returns:
        A tuple ``(layer_prefixes, extra_weights)``:
          * List of layer prefixes (``...layers.<N>``) found among the MTP keys. These layers
            should typically be excluded from quantization. Empty if there is no index file
            or no MTP entries.
          * Dictionary of MTP weights (loaded on CPU) whose keys are not in the model state
            dict and therefore were not loaded into the model.
    """
    model_path = Path(model_path)
    index_file = model_path / "model.safetensors.index.json"
    if not index_file.exists():
        return [], {}

    # Load the index to find all referenced safetensors files. The context manager closes
    # the handle deterministically; JSON is UTF-8 regardless of the platform locale.
    with open(index_file, encoding="utf-8") as f:
        weight_map = json.load(f)["weight_map"]

    # Group MTP tensor names by the shard file that stores them: {filename: [key, ...]}.
    mtp_weight_map: dict[str, list[str]] = {}
    for key, filename in weight_map.items():
        if "mtp" in key or "mtp" in filename:
            mtp_weight_map.setdefault(filename, []).append(key)
    if not mtp_weight_map:
        return [], {}

    def _extract_layer_prefixes(keys):
        """Return the ``<...>.layers.<N>`` prefix of every key that has one."""
        mtp_layer_prefixes = set()
        for key in keys:
            parts = key.split(".")
            for i, part in enumerate(parts):
                if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
                    mtp_layer_prefixes.add(".".join(parts[: i + 2]))
                    break
        return mtp_layer_prefixes

    all_mtp_keys = [k for keys in mtp_weight_map.values() for k in keys]
    mtp_layer_prefixes = _extract_layer_prefixes(all_mtp_keys)

    # Check which non-standard files exist and have missing weights
    model_state = model.state_dict()
    total_loaded = 0
    not_in_state_dict: dict[str, torch.Tensor] = {}
    for filename, file_keys in mtp_weight_map.items():
        filepath = model_path / filename
        if not filepath.exists():
            continue

        print(f"Loading {len(file_keys)} mtp weights from {filename}...")
        wanted = set(file_keys)  # O(1) membership instead of scanning a list per tensor
        weights = {
            k: v for k, v in load_file(str(filepath), device="cpu").items() if k in wanted
        }
        # Load the MTP weights to the model state dict; keep the rest for the caller.
        in_state_dict = {k: v for k, v in weights.items() if k in model_state}
        not_in_state_dict.update({k: v for k, v in weights.items() if k not in model_state})
        if in_state_dict:
            model.load_state_dict(in_state_dict, strict=False)
            total_loaded += len(in_state_dict)

    if total_loaded > 0:
        print(
            f"✓ Successfully loaded {total_loaded} MTP weights, "
            f"{len(not_in_state_dict)} MTP weights not in model.state_dict"
        )

    if mtp_layer_prefixes:
        print(f"✓ Detected MTP layers to exclude from quantization: {mtp_layer_prefixes}")

    return list(mtp_layer_prefixes), not_in_state_dict
def get_dtype(dtype):
if dtype == "bf16":
dtype = torch.bfloat16
elif dtype == "fp16":
dtype = torch.float16
elif dtype == "fp32":
dtype = torch.float32
else:
raise NotImplementedError(f"Unknown dtype {dtype}")
return dtype
def get_model(
    ckpt_path,
    device="cuda",
    gpu_mem_percentage=0.8,
    trust_remote_code=False,
    use_seq_device_map=False,
    attn_implementation=None,
):
    """Load a Hugging Face checkpoint for calibration/quantization and put it in eval mode.

    Placement (the ``device_map`` passed to ``from_pretrained``):
      * ``"auto"`` by default, ``"cpu"`` when ``device == "cpu"``;
      * ``None`` for Nemotron VL configs and for BART; the model is then moved with
        ``model.to(device)`` unless ``device == "cpu"``;
      * ``"sequential"`` with a reduced ``max_memory`` when ``use_seq_device_map`` is set
        (non-VILA checkpoints only).
    Checkpoints whose quantization_config format is "pack-quantized" always use
    ``device_map="auto"``.

    Args:
        ckpt_path: Checkpoint directory or model ID. It must be a ``str`` because
            ``.lower()`` is called on it. Paths containing "vila" take the VILA-specific route.
        device: Target device. "cpu" forces CPU placement; "cuda" enables a final check
            that prints a warning if any parameter is not on a GPU.
        gpu_mem_percentage: Scale factor applied to the max-memory budgets when limiting
            memory (sequential map, or when the model would not fit on the GPUs).
        trust_remote_code: Forwarded to ``AutoConfig`` / ``from_pretrained``. It must be True
            for architectures not found in ``transformers`` and for names containing "Deepseek".
        use_seq_device_map: Use a "sequential" device map with a scaled memory budget.
        attn_implementation: Optional attention backend. It is passed to
            ``from_pretrained`` but not to ``AutoConfig``, and not on the pack-quantized path.

    Returns:
        The loaded model in eval mode. For VILA this is the ``.llm`` attribute of the
        loaded ``AutoModel``.

    Raises:
        RuntimeError: If the config cannot be loaded from ``ckpt_path``.
        AssertionError: If the Auto* fallback path is needed but ``trust_remote_code`` is False.
    """
    print(f"Initializing model from {ckpt_path}")

    device_map = "auto"
    if device == "cpu":
        device_map = "cpu"

    # Add VILA to sys.path before loading config if needed
    if "vila" in ckpt_path.lower():
        vila_path = os.path.join(ckpt_path, "..", "VILA")
        if vila_path not in sys.path:
            sys.path.append(vila_path)
        # Imported only for its side effects. Presumably this registers the VILA/LLaVA
        # classes so AutoConfig/AutoModel can resolve them (TODO confirm).
        from llava.model import LlavaLlamaConfig, LlavaLlamaModel  # noqa: F401

    # Prepare config kwargs for loading
    config_kwargs = {"trust_remote_code": trust_remote_code} if trust_remote_code else {}

    # Load config once and handle VL model detection
    try:
        hf_config = AutoConfig.from_pretrained(ckpt_path, **config_kwargs)

        if is_nemotron_vl(hf_config):
            print(
                "Detected Nemotron VL model from config. "
                "Disabling automatic device mapping for compatibility."
            )
            # The model is loaded without a device map and moved with .to(device) below.
            device_map = None
    except Exception as e:
        print(f"Error: Could not load config from {ckpt_path}: {e}")
        raise RuntimeError(f"Failed to load model configuration from {ckpt_path}") from e

    # Added after the AutoConfig call, so it only reaches from_pretrained via model_kwargs.
    if attn_implementation is not None:
        config_kwargs["attn_implementation"] = attn_implementation

    # Note: Forcibly converting the model precision between bf16 and fp16 may introduce accuracy drop
    model_kwargs = config_kwargs.copy()
    # Don't set torch_dtype for VILA models as they handle it explicitly in their builder
    if "vila" not in ckpt_path.lower():
        model_kwargs.setdefault("torch_dtype", "auto")

    if "vila" in ckpt_path.lower():
        hf_vila = AutoModel.from_pretrained(
            ckpt_path,
            device_map=device_map,
            **model_kwargs,
        )
        model = hf_vila.llm
    else:
        if use_seq_device_map:
            device_map = "sequential"
            # If we use sequential, set max_memory limit to ensure that the model does not occupy the full GPU
            max_memory = get_max_memory()
            # Every entry returned by get_max_memory is scaled here (the on_cpu branch
            # further down scales only integer, i.e. GPU-index, keys).
            max_memory = {key: value * gpu_mem_percentage for key, value in max_memory.items()}
            model_kwargs["max_memory"] = max_memory

        if hf_config.model_type == "bart":
            # device_map "auto" and "cuda" triggers error regarding meta tensor from safetensors
            device_map = None

        if is_speculative(hf_config):
            model = AutoModelForCausalLM.from_pretrained(
                ckpt_path,
                device_map=device_map,
                **model_kwargs,
            )
        elif (
            hasattr(hf_config, "quantization_config")
            and hf_config.quantization_config.get("format", None) == "pack-quantized"
        ):
            # Pre-quantized (pack-quantized) checkpoints: only trust_remote_code and the
            # config's torch_dtype (bf16 if absent) are forwarded. The device_map computed
            # above and the other model_kwargs are not used on this path.
            torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
            model = AutoModelForCausalLM.from_pretrained(
                ckpt_path,
                device_map="auto",
                trust_remote_code=trust_remote_code,
                torch_dtype=torch_dtype,
            )
        else:
            architecture = hf_config.architectures[0]

            # Use a generic Auto* class when transformers has no class of this name. Names
            # containing "Deepseek" also take this path; the reason is not visible here,
            # presumably to pick up the checkpoint's remote code (TODO confirm).
            if not hasattr(transformers, architecture) or "Deepseek" in architecture:
                if not hasattr(transformers, architecture):
                    warnings.warn(
                        f"Architecture {architecture} not found in transformers: {transformers.__version__}. "
                        "Falling back to AutoModelForCausalLM (or AutoModel for non-causal architectures)."
                    )
                assert trust_remote_code, (
                    "Please set trust_remote_code to True if you want to use this architecture"
                )

                # Use AutoModelForCausalLM for causal LMs, AutoModel for encoder-decoder models
                if getattr(hf_config, "is_encoder_decoder", False):
                    auto_model_module = AutoModel
                else:
                    auto_model_module = AutoModelForCausalLM
                from_config = auto_model_module.from_config
            else:
                auto_model_module = getattr(transformers, architecture)
                from_config = auto_model_module._from_config

            # Build the model skeleton without allocating real weights, only to infer
            # whether it fits in GPU memory.
            with init_empty_weights():
                # When computing the device_map, assuming bfloat16 precision by default,
                # unless specified by the hf_config.
                torch_dtype = getattr(hf_config, "torch_dtype", torch.bfloat16)
                model_kwargs2 = model_kwargs.copy()
                # Concrete transformers classes are given no trust_remote_code kwarg.
                if auto_model_module not in [AutoModelForCausalLM, AutoModel]:
                    model_kwargs2.pop("trust_remote_code", None)
                model_kwargs2["torch_dtype"] = torch_dtype
                model_kwargs2.pop("max_memory", None)
                model = from_config(hf_config, **model_kwargs2)

            max_memory = get_max_memory()
            inferred_device_map = infer_auto_device_map(model, max_memory=max_memory)

            # If any module would be placed on CPU, the model does not fit on the GPUs.
            on_cpu = "cpu" in inferred_device_map.values()

            if on_cpu:
                # Shrink only the GPU (integer-keyed) budgets to leave room for calibration.
                for _device in max_memory:
                    if isinstance(_device, int):
                        max_memory[_device] *= gpu_mem_percentage

                print(
                    "Model does not fit to the GPU mem. "
                    f"We apply the following memory limit for calibration: \n{max_memory}\n"
                    "If you hit GPU OOM issue, please adjust `gpu_mem_percentage` or "
                    "reduce the calibration `batch_size` manually."
                )
                # Overrides any max_memory set by the use_seq_device_map branch above.
                model_kwargs["max_memory"] = max_memory

            # Real load, replacing the empty-weights skeleton.
            model = auto_model_module.from_pretrained(
                ckpt_path,
                device_map=device_map,
                **model_kwargs,
            )
    model.eval()

    # If device_map was disabled (None), manually move model to target device
    if device_map is None and device != "cpu":
        print(f"Moving model to {device} device...")
        model = model.to(device)

    if device == "cuda" and not is_model_on_gpu(model):
        print("Warning: Some parameters are not on a GPU. Calibration can be slow or hit OOM")

    return model
def is_model_on_gpu(model) -> bool:
    """Tell whether every parameter of ``model`` sits on a CUDA device.

    A model with no parameters is reported as on-GPU (nothing is off the GPU).
    """
    return not any("cuda" not in str(param.device) for param in model.parameters())
def is_enc_dec(model_type) -> bool:
    """Report whether ``model_type`` is one of the encoder-decoder families (T5, BART, Whisper)."""
    encoder_decoder_types = ("t5", "bart", "whisper")
    return model_type in encoder_decoder_types
def _resolve_model_path(model_name_or_path: str, trust_remote_code: bool = False) -> str:
"""Resolve a model name or path to a local directory path.
If the input is already a local directory, returns it as-is.
If the input is a HuggingFace model ID, attempts to resolve it to the local cache path.
Args:
model_name_or_path: Either a local directory path or HuggingFace model ID
trust_remote_code: Whether to trust remote code when loading the model
Returns:
Local directory path to the model files
"""
# If it's already a local directory, return as-is
if os.path.isdir(model_name_or_path):
return model_name_or_path
# Try to resolve HuggingFace model ID to local cache path
try:
# First try to load the config to trigger caching
config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code)
# The config object should have the local path information
# Try different ways to get the cached path
if hasattr(config, "_name_or_path") and os.path.isdir(config._name_or_path):
return config._name_or_path
# Alternative: use snapshot_download if available
if snapshot_download is not None:
try:
local_path = snapshot_download(
repo_id=model_name_or_path,
allow_patterns=["*.py", "*.json"], # Only download Python files and config
)
return local_path
except Exception as e:
print(f"Warning: Could not download model files using snapshot_download: {e}")
# Fallback: try to find in HuggingFace cache
from transformers.utils import TRANSFORMERS_CACHE
# Look for the model in the cache directory
cache_pattern = os.path.join(TRANSFORMERS_CACHE, "models--*")
cache_dirs = glob.glob(cache_pattern)
# Convert model name to cache directory format
model_cache_name = model_name_or_path.replace("/", "--")
for cache_dir in cache_dirs:
if model_cache_name in cache_dir:
# Look for the snapshots directory
snapshots_dir = os.path.join(cache_dir, "snapshots")
if os.path.exists(snapshots_dir):
# Get the latest snapshot
snapshot_dirs = [
d
for d in os.listdir(snapshots_dir)
if os.path.isdir(os.path.join(snapshots_dir, d))
]
if snapshot_dirs:
latest_snapshot = max(snapshot_dirs) # Use lexicographically latest
snapshot_path = os.path.join(snapshots_dir, latest_snapshot)
return snapshot_path
except Exception as e:
print(f"Warning: Could not resolve model path for {model_name_or_path}: {e}")
# If all else fails, return the original path
# This will cause the copy function to skip with a warning
return model_name_or_path
def copy_custom_model_files(source_path: str, export_path: str, trust_remote_code: bool = False):
    """Copy a remote-code model's custom files into the export directory.

    Nothing happens unless ``trust_remote_code`` is True. ``source_path`` may be a local
    directory or a HuggingFace model ID; it is resolved with ``_resolve_model_path``.

    Files copied (top level of the source directory only): ``configuration_*.py``,
    ``modeling*.py``, ``tokenization_*.py``, ``processing_*.py``, ``image_processing*.py``,
    ``feature_extraction_*.py`` and ``*.json``. ``config.json`` and
    ``model.safetensors.index.json`` are excluded because the export writes those itself.
    A missing source or export directory, or a failed copy, only prints a warning.

    Args:
        source_path: Path to the original model directory or HuggingFace model ID
        export_path: Path to the exported model directory
        trust_remote_code: Whether trust_remote_code was used (only copy files if True)
    """
    if not trust_remote_code:
        return

    resolved_source_path = _resolve_model_path(source_path, trust_remote_code)
    source_dir = Path(resolved_source_path)
    export_dir = Path(export_path)

    if not source_dir.exists():
        if resolved_source_path == source_path:
            print(f"Warning: Source directory '{source_path}' does not exist")
        else:
            print(
                f"Warning: Could not find local cache for HuggingFace model '{source_path}' "
                f"(resolved to '{resolved_source_path}')"
            )
        return

    if not export_dir.exists():
        print(f"Warning: Export directory {export_path} does not exist")
        return

    patterns = (
        "configuration_*.py",
        "modeling*.py",
        "tokenization_*.py",
        "processing_*.py",
        "image_processing*.py",
        "feature_extraction_*.py",
        "*.json",
    )
    excluded_names = ("config.json", "model.safetensors.index.json")

    copied = []
    for pattern in patterns:
        for src in source_dir.glob(pattern):
            if not src.is_file() or src.name in excluded_names:
                continue
            try:
                shutil.copy2(src, export_dir / src.name)
            except Exception as e:
                print(f"Warning: Failed to copy {src.name}: {e}")
            else:
                copied.append(src.name)
                print(f"Copied custom model file: {src.name}")

    if not copied:
        print("No custom model files found to copy")
    else:
        print(f"Successfully copied {len(copied)} custom model files to {export_path}")