Skip to content

Commit bf4754e

Browse files
fix: Use single method to determine trust_remote_code value throughout (#229)
## Summary

- Some `AutoConfig`, `AutoModelForCausalLM`, and `AutoTokenizer` `from_pretrained` calls and the `vLLM` init were missing `trust_remote_code=True` for `nvidia/` models (e.g. Nemotron), causing a `ValueError` when loading models with custom code.
- Consolidates the check into a single `trust_remote_code_for_model()` in `llm/utils.py`, called by all 8 `ModelMetadata` subclasses, `populate_derived_fields`, `LLMPromptConfig.from_tokenizer`, `HuggingFaceBackend`, and `VllmBackend`.
- Removes the redundant `TrainingBackend._trust_remote_code_for_model()` method.

## Test plan

- [x] Verified `nvidia/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16` loads successfully through the SDK `SafeSynthesizer` pipeline (requires extra deps to load, but `trust_remote_code` no longer causes a problem)
- [x] Added unit tests for `trust_remote_code_for_model`
- [x] Unit tests pass (`make test`)

## Other notes

Created #231 to follow up and make this behavior user-configurable.

Made with [Cursor](https://cursor.com)

---------

Signed-off-by: Kendrick Boyd <kendrickb@nvidia.com>
Signed-off-by: Aaron Gonzales <aagonzales@nvidia.com>
Co-authored-by: Aaron Gonzales <aagonzales@nvidia.com>
1 parent 68abb4c commit bf4754e

7 files changed

Lines changed: 148 additions & 55 deletions

File tree

src/nemo_safe_synthesizer/generation/vllm_backend.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ..generation.regex_manager import build_json_based_regex
3232
from ..generation.results import GenerateJobResults, GenerationBatches, GenerationStatus
3333
from ..llm.metadata import ModelMetadata
34-
from ..llm.utils import cleanup_memory, get_max_vram
34+
from ..llm.utils import cleanup_memory, get_max_vram, trust_remote_code_for_model
3535
from ..observability import get_logger, heartbeat
3636
from ..utils import all_equal_type, load_json
3737

@@ -212,6 +212,7 @@ def initialize(self, **kwargs) -> None:
212212
max_lora_rank=self.config.training.lora_r,
213213
structured_outputs_config=structured_outputs_config,
214214
attention_config=attention_config,
215+
trust_remote_code=trust_remote_code_for_model(self.config.training.pretrained_model),
215216
)
216217

217218
# vLLM's get_tokenizer() returns a wider union than HF's PreTrainedTokenizerBase;

src/nemo_safe_synthesizer/llm/metadata.py

Lines changed: 47 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
field_validator,
1717
model_validator,
1818
)
19-
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig
19+
from transformers import AutoConfig, AutoTokenizer, PretrainedConfig, PreTrainedTokenizerBase
2020

2121
from ..cli.artifact_structure import Workdir
2222
from ..config.parameters import SafeSynthesizerParameters
@@ -27,6 +27,7 @@
2727
)
2828
from ..observability import get_logger
2929
from ..utils import load_json, write_json
30+
from .utils import trust_remote_code_for_model
3031

3132
logger = get_logger(__name__)
3233

@@ -77,7 +78,7 @@ class LLMPromptConfig(BaseModel):
7778
"""Integer id for the EOS token."""
7879

7980
@classmethod
80-
def from_tokenizer(cls, name: str, tokenizer: AutoTokenizer | None = None, **kwargs) -> LLMPromptConfig:
81+
def from_tokenizer(cls, name: str, tokenizer: PreTrainedTokenizerBase | None = None, **kwargs) -> LLMPromptConfig:
8182
"""Create a prompt config by reading from settings of a tokenizer.
8283
8384
If no ``tokenizer`` is supplied one is loaded from ``name``
@@ -94,7 +95,9 @@ def from_tokenizer(cls, name: str, tokenizer: AutoTokenizer | None = None, **kwa
9495
Returns:
9596
A new ``LLMPromptConfig`` populated from the tokenizer.
9697
"""
97-
tokenizer = tokenizer or AutoTokenizer.from_pretrained(name)
98+
tokenizer = tokenizer or AutoTokenizer.from_pretrained(
99+
name, trust_remote_code=trust_remote_code_for_model(name)
100+
)
98101
bos_token = kwargs.get("bos_token", getattr(tokenizer, "bos_token", None))
99102
bos_token_id = kwargs.get("bos_token_id", getattr(tokenizer, "bos_token_id", None))
100103
eos_token = kwargs.get("eos_token", getattr(tokenizer, "eos_token", None))
@@ -339,7 +342,11 @@ def populate_derived_fields(cls, data: dict) -> dict:
339342
The mutated ``data`` dict with derived fields populated.
340343
"""
341344
if data.get("autoconfig") is None:
342-
data["autoconfig"] = AutoConfig.from_pretrained(data["model_name_or_path"])
345+
model_name_or_path = data["model_name_or_path"]
346+
data["autoconfig"] = AutoConfig.from_pretrained(
347+
model_name_or_path,
348+
trust_remote_code=trust_remote_code_for_model(model_name_or_path),
349+
)
343350

344351
if data.get("base_max_seq_length") is None:
345352
data["base_max_seq_length"] = get_base_max_seq_length(data["autoconfig"])
@@ -447,6 +454,32 @@ def save_metadata(self) -> None:
447454
indent=4,
448455
)
449456

457+
@staticmethod
458+
def _load_config_and_tokenizer(
459+
model_name_or_path: str,
460+
tokenizer: PreTrainedTokenizerBase | None = None,
461+
) -> tuple[PretrainedConfig, PreTrainedTokenizerBase]:
462+
"""Load ``PretrainedConfig`` and (optionally) ``AutoTokenizer`` for a model.
463+
464+
Centralises the repeated boilerplate present in every subclass
465+
``__init__``: loading the HuggingFace config and, when no
466+
pre-loaded tokenizer is supplied, fetching one via
467+
``AutoTokenizer.from_pretrained``.
468+
469+
Args:
470+
model_name_or_path: HuggingFace model identifier or local path.
471+
tokenizer: Pre-loaded tokenizer to reuse. When ``None`` a new
472+
one is loaded from ``model_name_or_path``.
473+
474+
Returns:
475+
A ``(config, tokenizer)`` tuple ready to pass to ``super().__init__``.
476+
"""
477+
trust = trust_remote_code_for_model(model_name_or_path)
478+
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust)
479+
if tokenizer is None:
480+
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust)
481+
return config, tokenizer
482+
450483
@classmethod
451484
def _resolve_model_class(cls: type["ModelMetadata"], model_name_or_path: Path | str) -> type["ModelMetadata"]:
452485
"""Resolve model name or path to the matching metadata subclass.
@@ -588,8 +621,7 @@ def __init__(
588621
rope_scaling_factor: float | None = None,
589622
**kwargs,
590623
) -> None:
591-
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer is None else tokenizer
592-
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path)
624+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
593625

594626
super().__init__(
595627
autoconfig=config,
@@ -628,8 +660,7 @@ def __init__(
628660
rope_scaling_factor: float | None = None,
629661
**kwargs,
630662
) -> None:
631-
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path)
632-
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer is None else tokenizer
663+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
633664

634665
super().__init__(
635666
autoconfig=config,
@@ -668,12 +699,11 @@ class Mistral(ModelMetadata):
668699
def __init__(
669700
self,
670701
model_name_or_path: str,
671-
tokenizer: AutoTokenizer | None = None,
702+
tokenizer: PreTrainedTokenizerBase | None = None,
672703
rope_scaling_factor: float | None = None,
673704
**kwargs,
674705
) -> None:
675-
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer is None else tokenizer
676-
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path)
706+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
677707
if rope_scaling_factor:
678708
logger.warning(
679709
f"Rope scaling factor {rope_scaling_factor} is not supported for Mistral due to longer default context lengths. Ignoring."
@@ -714,8 +744,7 @@ def __init__(
714744
rope_scaling_factor: float | None = None,
715745
**kwargs,
716746
) -> None:
717-
tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer is None else tokenizer
718-
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path)
747+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
719748

720749
super().__init__(
721750
autoconfig=config,
@@ -751,8 +780,7 @@ def __init__(
751780
rope_scaling_factor: float | None = None,
752781
**kwargs,
753782
) -> None:
754-
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer is None else tokenizer
755-
config = AutoConfig.from_pretrained(model_name_or_path)
783+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
756784

757785
super().__init__(
758786
autoconfig=config,
@@ -792,14 +820,13 @@ def __init__(
792820
rope_scaling_factor: float | None = None,
793821
**kwargs,
794822
) -> None:
795-
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer is None else tokenizer
796-
config = AutoConfig.from_pretrained(model_name_or_path)
823+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
797824
if rope_scaling_factor:
798825
logger.warning(
799826
f"Rope scaling factor {rope_scaling_factor} is not supported for SmolLM2 due to longer default context lengths. Ignoring."
800827
)
801828

802-
im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
829+
im_start_id = tokenizer.convert_tokens_to_ids("<|im_start|>") # ty: ignore[unresolved-attribute] -- third-party stub
803830
super().__init__(
804831
autoconfig=config,
805832
instruction=DEFAULT_INSTRUCTION,
@@ -840,8 +867,7 @@ def __init__(
840867
rope_scaling_factor: float | None = None,
841868
**kwargs,
842869
) -> None:
843-
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) if tokenizer is None else tokenizer
844-
config = AutoConfig.from_pretrained(model_name_or_path)
870+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
845871

846872
# we use the bos token here explicitly for support during group-by SFT.
847873
# the groupby assumes there is a bos token at the start of the prompt.
@@ -890,8 +916,7 @@ def __init__(
890916
rope_scaling_factor: float | None = None,
891917
**kwargs,
892918
) -> None:
893-
tokenizer = tokenizer or AutoTokenizer.from_pretrained(model_name_or_path)
894-
config = AutoConfig.from_pretrained(model_name_or_path)
919+
config, tokenizer = ModelMetadata._load_config_and_tokenizer(model_name_or_path, tokenizer)
895920

896921
super().__init__(
897922
autoconfig=config,

src/nemo_safe_synthesizer/llm/utils.py

Lines changed: 42 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,60 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
"""GPU memory management, quantization, device mapping, and tokenizer helpers for LLM loading."""
4+
"""GPU memory management, quantization, device mapping, and tokenizer helpers for LLM loading.
5+
6+
Optional LLM dependencies are imported inside the helpers that need them so
7+
lightweight utilities such as ``trust_remote_code_for_model`` remain usable
8+
without installing the full training or inference stack.
9+
"""
510

611
from __future__ import annotations
712

813
import gc
914
from pathlib import Path
10-
from typing import Any, Literal
11-
12-
import torch
13-
from accelerate import infer_auto_device_map, init_empty_weights
14-
from peft import (
15-
PeftModel,
16-
)
17-
from transformers import (
18-
AutoConfig,
19-
AutoModelForCausalLM,
20-
AutoTokenizer,
21-
BitsAndBytesConfig,
22-
PreTrainedTokenizer,
23-
)
15+
from typing import TYPE_CHECKING, Any, Literal
2416

2517
from ..observability import get_logger
2618

19+
if TYPE_CHECKING:
20+
from peft import PeftModel
21+
from transformers import AutoConfig, BitsAndBytesConfig, PreTrainedTokenizer
22+
2723
logger = get_logger(__name__)
2824

2925

3026
def trust_remote_code_for_model(model_name: str | Path) -> bool:
3127
"""Determine whether to trust remote code when loading a model.
3228
33-
Returns ``True`` only for models whose name starts with
34-
``"nvidia/"``.
29+
Returns ``True`` for NVIDIA-owned Hub model identifiers and for paths
30+
inside Hugging Face's encoded cache directory for NVIDIA models.
3531
3632
Args:
3733
model_name: HuggingFace model identifier or local path.
3834
3935
Returns:
4036
Whether to set ``trust_remote_code=True`` when loading the model.
4137
"""
42-
mn = str(model_name)
43-
return mn.startswith("nvidia/")
38+
model_ref = str(model_name).casefold()
39+
if model_ref.startswith("nvidia/"):
40+
return True
41+
42+
path_parts = Path(model_ref).parts
43+
while path_parts:
44+
match path_parts:
45+
case ("huggingface", "hub", *cache_parts):
46+
# Brittle by design: this mirrors Hugging Face's current cache path layout.
47+
return any(part.startswith("models--nvidia--") for part in cache_parts)
48+
case (_, *remaining):
49+
path_parts = remaining
50+
51+
return False
4452

4553

4654
def cleanup_memory() -> None:
4755
"""Run garbage collection and empty the CUDA cache."""
56+
import torch
57+
4858
gc.collect()
4959
with torch.no_grad():
5060
torch.cuda.empty_cache()
@@ -56,6 +66,7 @@ def gpu_stats() -> None:
5666
Queries CUDA device 0 and logs the peak reserved memory and total
5767
available memory in GiB.
5868
"""
69+
import torch
5970

6071
def round_gb(value: float) -> float:
6172
return round(value / 1024 / 1024 / 1024, 3)
@@ -80,6 +91,8 @@ def get_max_vram(max_vram_fraction: float | None = None) -> dict[int, float]:
8091
Returns:
8192
Mapping of CUDA device index to the usable memory fraction.
8293
"""
94+
import torch
95+
8396
if max_vram_fraction is None:
8497
max_vram_fraction = 0.8
8598
max_memory = {}
@@ -148,6 +161,8 @@ def get_param_from_config(
148161
Raises:
149162
ValueError: If neither ``model_name`` nor ``config`` is provided.
150163
"""
164+
from transformers import AutoConfig
165+
151166
if config is None:
152167
if model_name is None:
153168
raise ValueError("model_name is required if config is not provided")
@@ -170,6 +185,8 @@ def _get_auto_tokenizer(
170185
Returns:
171186
Configured ``PreTrainedTokenizer`` with BOS/EOS tokens enabled.
172187
"""
188+
from transformers import AutoTokenizer
189+
173190
tokenizer = AutoTokenizer.from_pretrained(
174191
model_name,
175192
model_max_length=max_position_embeddings,
@@ -204,6 +221,9 @@ def get_device_map(
204221
Returns:
205222
Ordered dictionary mapping layer names to device identifiers.
206223
"""
224+
from accelerate import infer_auto_device_map, init_empty_weights
225+
from transformers import AutoConfig, AutoModelForCausalLM
226+
207227
config = autoconfig or AutoConfig.from_pretrained(
208228
model_target,
209229
revision=revision,
@@ -253,6 +273,9 @@ def get_quantization_config(quantization_bits: Literal[4, 8]) -> BitsAndBytesCon
253273
Raises:
254274
ValueError: If ``quantization_bits`` is not 4 or 8.
255275
"""
276+
import torch
277+
from transformers import BitsAndBytesConfig
278+
256279
if quantization_bits == 4:
257280
return BitsAndBytesConfig(
258281
load_in_4bit=True,

src/nemo_safe_synthesizer/training/backend.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,3 @@ def teardown(self) -> None:
296296
this runs even when training raises.
297297
"""
298298
pass
299-
300-
def _trust_remote_code_for_model(self) -> bool:
301-
"""Determine whether the model should be loaded with ``trust_remote_code=True``.
302-
303-
Currently returns ``True`` only for NVIDIA models on HuggingFace Hub.
304-
305-
Returns:
306-
Whether to trust remote code when loading the model.
307-
"""
308-
return str(self.params.training.pretrained_model).startswith("nvidia/")

src/nemo_safe_synthesizer/training/huggingface_backend.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
get_device_map,
5959
get_max_vram,
6060
get_quantization_config,
61+
trust_remote_code_for_model,
6162
)
6263
from ..observability import get_logger, traced_runtime, traced_user
6364
from ..privacy.dp_transformers.dp_utils import (
@@ -120,7 +121,8 @@ def __init__(self, *args, **kwargs):
120121
self.model_loader_type = AutoModelForCausalLM
121122
self.training_output_dir = Path(self.workdir.train.cache)
122123
self.autoconfig = AutoConfig.from_pretrained(
123-
self.params.training.pretrained_model, trust_remote_code=self._trust_remote_code_for_model()
124+
self.params.training.pretrained_model,
125+
trust_remote_code=trust_remote_code_for_model(self.params.training.pretrained_model),
124126
)
125127

126128
def _load_pretrained_model(self, **model_args: Any) -> None:
@@ -132,7 +134,9 @@ def _load_pretrained_model(self, **model_args: Any) -> None:
132134

133135
self.tokenizer: PreTrainedTokenizer = add_bos_eos_tokens_to_tokenizer(
134136
AutoTokenizer.from_pretrained(
135-
self.params.training.pretrained_model, model_max_length=model_args.get("max_seq_length", None)
137+
self.params.training.pretrained_model,
138+
trust_remote_code=trust_remote_code_for_model(self.params.training.pretrained_model),
139+
model_max_length=model_args.get("max_seq_length", None),
136140
)
137141
)
138142

@@ -202,10 +206,17 @@ def _build_base_framework_params(self, model_kwargs: dict) -> dict:
202206
Returns:
203207
Dictionary of parameters for ``from_pretrained``.
204208
"""
209+
trust_remote_code = trust_remote_code_for_model(self.params.training.pretrained_model)
205210
return dict(
206211
pretrained_model_name_or_path=self.params.training.pretrained_model,
212+
trust_remote_code=trust_remote_code,
207213
device_map=model_kwargs.pop(
208-
"device_map", get_device_map(self.params.training.pretrained_model, autoconfig=self.autoconfig)
214+
"device_map",
215+
get_device_map(
216+
self.params.training.pretrained_model,
217+
autoconfig=self.autoconfig,
218+
trust_remote_code=trust_remote_code,
219+
),
209220
),
210221
attn_implementation=model_kwargs.pop(
211222
"attn_implementation", self._resolve_attn_implementation(self.params.training.attn_implementation)

0 commit comments

Comments
 (0)