Skip to content

Commit fc118d2

Browse files
binaryaaronmckornfieldcursoragent
authored
fix: resolve cached model refs for training loads (#476)
## Summary followup to #473. expands using the `ModelRef` targets in metadata and Hugging Face training loads. --------- Signed-off-by: Matt Kornfield <mkornfield@nvidia.com> Signed-off-by: Aaron Gonzales <aagonzales@nvidia.com> Co-authored-by: Matt Kornfield <mkornfield@nvidia.com> Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 80fd838 commit fc118d2

15 files changed

Lines changed: 924 additions & 108 deletions

File tree

docs/user-guide/running.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ execute in order (`config` → `dataframe` → `metadata` → `advisory`).
275275
|-------|-------|-------------------|
276276
| `gpu.cuda` | config | PyTorch is importable and a CUDA GPU is visible |
277277
| `env.inference_key` | config | `NSS_INFERENCE_KEY` is set when PII classification is enabled (warning only) |
278-
| `env.hf_token` | config | `HF_TOKEN` or `HUGGING_FACE_HUB_TOKEN` is set; warns unconditionally when neither is present so gated-repo downloads don't fail later (warning only) |
278+
| `env.hf_model_availability` | config | The pretrained model reference is usable locally or can be fetched from Hugging Face; warns about a missing HF token only when online HF access may be needed |
279279
| `dataset.size` | dataframe | Training split meets the hard minimum row count |
280280
| `columns.groupby` | dataframe | `group_training_examples_by` column is present and has no nulls |
281281
| `columns.orderby` | dataframe | `order_training_examples_by` column is present |

docs/user-guide/troubleshooting.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,10 @@ check of its own.
467467
| `no_gpu` | error | `gpu.cuda` | No CUDA GPU detected (required for training or generation) |
468468
| `low_vram` | warning | `gpu.vram` | Free GPU VRAM may be insufficient |
469469
| `inference_key_missing` | warning | `env.inference_key` | `NSS_INFERENCE_KEY` not set; PII classification degraded |
470-
| `hf_token_missing` | warning | `env.hf_token` | Neither `HF_TOKEN` nor `HUGGING_FACE_HUB_TOKEN` set; gated model downloads may fail |
470+
| `hf_token_missing` | warning | `env.hf_model_availability` | Neither `HF_TOKEN` nor `HUGGING_FACE_HUB_TOKEN` set, and model loading may need online Hugging Face access |
471+
| `hf_model_not_cached` | warning/error | `env.hf_model_availability` | Hugging Face model is not present in the local cache; severity is error when HF offline mode is enabled |
472+
| `hf_model_cache_incomplete` | error | `env.hf_model_availability` | Cached Hugging Face model snapshot is missing required config, tokenizer, weights, or shards |
473+
| `hf_remote_code_not_cached` | warning/error | `env.hf_model_availability` | Trusted model references remote code that is not cached locally; severity is error when HF offline mode is enabled |
471474
| `preflight.check_crash` | error | (crashing check) | A check raised an unexpected exception; the issue's `check` field names the crashing check and other checks continued running |
472475
| `column_not_found` | error | `columns.groupby` / `columns.orderby` | Required column missing from dataset, or input DataFrame uses unsupported MultiIndex columns |
473476
| `column_nulls` | error | `columns.groupby` | Required column contains null values |

src/nemo_safe_synthesizer/llm/metadata.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ..errors import ParameterError
2929
from ..observability import get_logger
3030
from ..utils import load_json, write_json
31-
from .utils import trust_remote_code_for_model
31+
from .utils import ModelRef
3232

3333
logger = get_logger(__name__)
3434

@@ -96,9 +96,12 @@ def from_tokenizer(cls, name: str, tokenizer: PreTrainedTokenizerBase | None = N
9696
Returns:
9797
A new ``LLMPromptConfig`` populated from the tokenizer.
9898
"""
99-
tokenizer = tokenizer or AutoTokenizer.from_pretrained(
100-
name, trust_remote_code=trust_remote_code_for_model(name)
101-
)
99+
if tokenizer is None:
100+
model_ref = ModelRef.parse(name)
101+
tokenizer = AutoTokenizer.from_pretrained(
102+
model_ref.target(),
103+
trust_remote_code=model_ref.trust_remote_code,
104+
)
102105
bos_token = kwargs.get("bos_token", getattr(tokenizer, "bos_token", None))
103106
bos_token_id = kwargs.get("bos_token_id", getattr(tokenizer, "bos_token_id", None))
104107
eos_token = kwargs.get("eos_token", getattr(tokenizer, "eos_token", None))
@@ -362,10 +365,11 @@ def populate_derived_fields(cls, data: dict) -> dict:
362365
"""
363366
if data.get("autoconfig") is None:
364367
model_name_or_path = data["model_name_or_path"]
368+
model_ref = ModelRef.parse(model_name_or_path)
365369
try:
366370
data["autoconfig"] = AutoConfig.from_pretrained(
367-
model_name_or_path,
368-
trust_remote_code=trust_remote_code_for_model(model_name_or_path),
371+
model_ref.target(),
372+
trust_remote_code=model_ref.trust_remote_code,
369373
)
370374
except OSError as err:
371375
raise _model_load_parameter_error(model_name_or_path, err) from err
@@ -496,11 +500,16 @@ def _load_config_and_tokenizer(
496500
Returns:
497501
A ``(config, tokenizer)`` tuple ready to pass to ``super().__init__``.
498502
"""
499-
trust = trust_remote_code_for_model(model_name_or_path)
503+
model_ref = ModelRef.parse(model_name_or_path)
500504
try:
501-
config: PretrainedConfig = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust)
505+
config: PretrainedConfig = AutoConfig.from_pretrained(
506+
model_ref.target(), trust_remote_code=model_ref.trust_remote_code
507+
)
502508
if tokenizer is None:
503-
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=trust)
509+
tokenizer = AutoTokenizer.from_pretrained(
510+
model_ref.target(),
511+
trust_remote_code=model_ref.trust_remote_code,
512+
)
504513
except OSError as err:
505514
raise _model_load_parameter_error(model_name_or_path, err) from err
506515
return config, tokenizer

src/nemo_safe_synthesizer/llm/utils.py

Lines changed: 183 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from __future__ import annotations
1212

1313
import gc
14+
import json
1415
from dataclasses import dataclass
1516
from fnmatch import fnmatchcase
1617
from pathlib import Path
@@ -27,7 +28,35 @@
2728

2829
@dataclass(frozen=True, slots=True)
2930
class ModelRef:
30-
"""Resolved model reference for local cache and trust policy decisions."""
31+
"""Resolved model reference for local cache and trust policy decisions.
32+
33+
Intended public API:
34+
- ``parse()`` normalizes a user-supplied model string or path without
35+
contacting Hugging Face.
36+
- ``target()`` returns the value that should be passed to
37+
``from_pretrained``-style loaders: a local snapshot path when available,
38+
otherwise the original model reference.
39+
- ``trust_remote_code`` reports whether the reference belongs to a trusted
40+
organization after accounting for resolved local HF cache paths.
41+
- ``partial_cached_snapshot()`` returns HF's local snapshot path for the
42+
repo/revision, even when the snapshot is incomplete.
43+
- ``missing_required_components()`` reports whether a local model directory
44+
has the components this project expects before an offline load.
45+
- ``missing_remote_code_components()`` reports trusted remote-code files
46+
referenced by Transformers ``auto_map`` metadata but absent locally.
47+
48+
Deliberate Hugging Face coupling:
49+
repo-id validation, cache-root resolution, cache scanning, snapshot layout,
50+
artifact names, tokenizer filenames, and sharded weight index parsing mirror
51+
current Hugging Face Hub and Transformers behavior. This is intentional so
52+
NSS decisions match the libraries that load the model. If model loading or
53+
cache preflight behavior changes after an upstream HF release, inspect this
54+
class first.
55+
56+
Internal helpers are not a generic model-layout abstraction. They should
57+
stay close to HF's implementation rather than grow compatibility shims for
58+
unrelated storage formats.
59+
"""
3160

3261
original: str | Path
3362
repo_id: str | None = None
@@ -36,6 +65,17 @@ class ModelRef:
3665
cache_root: Path | None = None
3766

3867
trusted_orgs: ClassVar[frozenset[str]] = frozenset({"nvidia"})
68+
tokenizer_artifact_names: ClassVar[frozenset[str]] = frozenset(
69+
{
70+
"tokenizer.json",
71+
"tokenizer.model",
72+
"sentencepiece.bpe.model",
73+
"spiece.model",
74+
"vocab.json",
75+
"vocab.txt",
76+
"merges.txt",
77+
}
78+
)
3979

4080
@classmethod
4181
def parse(
@@ -45,8 +85,18 @@ def parse(
4585
revision: str = "main",
4686
cache_root: str | Path | None = None,
4787
) -> Self:
48-
"""Parse a model identifier or path without contacting Hugging Face."""
88+
"""Parse a model identifier or path without contacting Hugging Face.
89+
90+
This is safe to call in preflight and loader setup because it uses
91+
Hugging Face's local cache APIs only. Cached-model hits may still cost a
92+
few milliseconds because HF cache scanning walks cache metadata to
93+
confirm model artifacts exist.
94+
"""
4995
cache_root_path = Path(cache_root) if cache_root is not None else cls._default_hf_cache_root()
96+
model_ref = str(model_name)
97+
if not model_ref:
98+
return cls(original=model_name, revision=revision, cache_root=cache_root_path)
99+
50100
model_path = Path(model_name)
51101
if model_path.exists():
52102
repo_id = cls._repo_id_from_hf_cache_path(model_path, cache_root_path)
@@ -58,7 +108,6 @@ def parse(
58108
cache_root=cache_root_path,
59109
)
60110

61-
model_ref = str(model_name)
62111
repo_id = cls._repo_id_from_hub_identifier(model_ref)
63112
local_path = cls._cached_snapshot_for_repo(repo_id, revision, cache_root_path) if repo_id else None
64113
return cls(
@@ -95,6 +144,12 @@ def _repo_id_from_hub_identifier(model_ref: str) -> str | None:
95144

96145
@staticmethod
97146
def _repo_id_from_hf_cache_path(path: Path, cache_root: Path) -> str | None:
147+
"""Return the HF repo id for a path inside the configured Hub cache.
148+
149+
This relies on ``huggingface_hub.scan_cache_dir`` and the current
150+
``models--org--repo/snapshots/<commit>`` cache model. It is deliberately
151+
not a generic path parser.
152+
"""
98153
path_resolved = path.resolve(strict=False)
99154
from huggingface_hub import scan_cache_dir
100155
from huggingface_hub.errors import CacheNotFound
@@ -114,7 +169,13 @@ def _repo_id_from_hf_cache_path(path: Path, cache_root: Path) -> str | None:
114169
return None
115170

116171
@staticmethod
117-
def _cached_snapshot_for_repo(repo_id: str, revision: str, cache_root: Path) -> Path | None:
172+
def _local_snapshot_for_repo(repo_id: str, revision: str, cache_root: Path) -> Path | None:
173+
"""Return HF's local snapshot path without validating completeness.
174+
175+
Delegates to ``snapshot_download(local_files_only=True)`` so behavior
176+
stays aligned with Hugging Face cache resolution instead of duplicating
177+
ref-file lookup rules.
178+
"""
118179
from huggingface_hub import snapshot_download
119180
from huggingface_hub.errors import LocalEntryNotFoundError
120181

@@ -129,7 +190,14 @@ def _cached_snapshot_for_repo(repo_id: str, revision: str, cache_root: Path) ->
129190
)
130191
except LocalEntryNotFoundError:
131192
return None
132-
if not ModelRef._snapshot_has_model_artifacts(snapshot_path, cache_root):
193+
return snapshot_path
194+
195+
@classmethod
196+
def _cached_snapshot_for_repo(cls, repo_id: str, revision: str, cache_root: Path) -> Path | None:
197+
snapshot_path = cls._local_snapshot_for_repo(repo_id, revision, cache_root)
198+
if snapshot_path is None:
199+
return None
200+
if not cls._snapshot_has_model_artifacts(snapshot_path, cache_root):
133201
return None
134202
return snapshot_path
135203

@@ -162,7 +230,11 @@ def _snapshot_has_model_artifacts(cls, snapshot_path: Path, cache_root: Path) ->
162230

163231
@staticmethod
164232
def _model_artifact_patterns() -> tuple[str, ...]:
165-
"""Return known model artifact names using HF Hub's public constants."""
233+
"""Return known model artifact names using HF Hub's public constants.
234+
235+
Keep this close to Hugging Face's weight naming conventions. New HF
236+
artifact names or index formats should be reflected here.
237+
"""
166238
from huggingface_hub.constants import (
167239
FLAX_WEIGHTS_NAME,
168240
PYTORCH_WEIGHTS_FILE_PATTERN,
@@ -187,6 +259,111 @@ def _model_artifact_patterns() -> tuple[str, ...]:
187259
"consolidated*.pth",
188260
)
189261

262+
@classmethod
263+
def _required_component_status(cls, model_dir: Path) -> dict[str, bool]:
264+
"""Return required local model component presence for a Transformers load.
265+
266+
The checks are intentionally shaped around ``from_pretrained`` layouts:
267+
root ``config.json``, recognized tokenizer files, and HF-style weight
268+
files or shard indexes. Revisit this if Transformers changes accepted
269+
directory layouts.
270+
"""
271+
files = [path for path in model_dir.rglob("*") if path.is_file()]
272+
return {
273+
"config": (model_dir / "config.json").is_file(),
274+
"tokenizer": any(path.name in cls.tokenizer_artifact_names for path in files),
275+
"model weights": cls._has_complete_model_artifacts(model_dir, files),
276+
}
277+
278+
@classmethod
279+
def missing_required_components(cls, model_dir: Path) -> list[str]:
280+
"""Return local model components missing from ``model_dir``."""
281+
return [name for name, present in cls._required_component_status(model_dir).items() if not present]
282+
283+
@classmethod
284+
def missing_remote_code_components(cls, model_dir: Path) -> list[str]:
285+
"""Return trusted remote-code components referenced by config but absent locally."""
286+
required = cls._remote_code_components(model_dir)
287+
missing: list[str] = []
288+
for component, local_path in required:
289+
if local_path is None or not (model_dir / local_path).is_file():
290+
missing.append(component)
291+
return sorted(missing)
292+
293+
@classmethod
294+
def _remote_code_components(cls, model_dir: Path) -> list[tuple[str, Path | None]]:
295+
config_path = model_dir / "config.json"
296+
try:
297+
data = json.loads(config_path.read_text())
298+
except (OSError, json.JSONDecodeError):
299+
return []
300+
301+
auto_map = data.get("auto_map")
302+
if not isinstance(auto_map, dict):
303+
return []
304+
305+
components: list[tuple[str, Path | None]] = []
306+
for value in auto_map.values():
307+
for class_ref in cls._auto_map_class_refs(value):
308+
component = cls._remote_code_component(class_ref)
309+
if component is not None:
310+
components.append(component)
311+
return components
312+
313+
@staticmethod
314+
def _auto_map_class_refs(value: object) -> list[str]:
315+
if isinstance(value, str):
316+
return [value]
317+
if isinstance(value, list):
318+
return [item for item in value if isinstance(item, str)]
319+
return []
320+
321+
@staticmethod
322+
def _remote_code_component(class_ref: str) -> tuple[str, Path | None] | None:
323+
repo_id: str | None = None
324+
module_ref = class_ref
325+
if "--" in class_ref:
326+
repo_id, module_ref = class_ref.split("--", 1)
327+
if "." not in module_ref:
328+
return None
329+
330+
module_name, _ = module_ref.rsplit(".", 1)
331+
module_path = Path(*module_name.split(".")).with_suffix(".py")
332+
if repo_id is not None:
333+
return f"remote code from {repo_id} ({module_path.as_posix()})", None
334+
return module_path.as_posix(), module_path
335+
336+
@classmethod
337+
def _has_complete_model_artifacts(cls, model_dir: Path, files: list[Path]) -> bool:
338+
weight_indexes = [path for path in files if path.name.endswith(".index.json")]
339+
if weight_indexes:
340+
return any(cls._index_references_existing_shards(model_dir, index_path) for index_path in weight_indexes)
341+
342+
return any(fnmatchcase(path.name, pattern) for path in files for pattern in cls._model_artifact_patterns())
343+
344+
@staticmethod
345+
def _index_references_existing_shards(model_dir: Path, index_path: Path) -> bool:
346+
"""Return whether an HF weight index references shards present on disk."""
347+
try:
348+
data = json.loads(index_path.read_text())
349+
except (OSError, json.JSONDecodeError):
350+
return False
351+
352+
weight_map = data.get("weight_map")
353+
if not isinstance(weight_map, dict) or not weight_map:
354+
return False
355+
356+
shard_names = {name for name in weight_map.values() if isinstance(name, str)}
357+
if not shard_names:
358+
return False
359+
return all((model_dir / name).is_file() for name in shard_names)
360+
361+
def partial_cached_snapshot(self) -> Path | None:
362+
"""Return the local HF snapshot for this repo/revision, even if it is partial."""
363+
if self.repo_id is None or self.cache_root is None:
364+
return None
365+
return self._local_snapshot_for_repo(self.repo_id, self.revision, self.cache_root)
366+
190367
@classmethod
191368
def is_trusted_org(cls, org: str) -> bool:
192369
"""Return whether an organization is allowed to load remote code."""

src/nemo_safe_synthesizer/preflight/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
CUDAAvailabilityCheck,
1919
DatasetSizeCheck,
2020
GroupbyColumnCheck,
21-
HFTokenCheck,
21+
HFModelAvailabilityCheck,
2222
InferenceKeyCheck,
2323
OrderbyColumnCheck,
2424
OversamplingCheck,
@@ -60,7 +60,7 @@
6060
"SmallDatasetCheck",
6161
"DatasetSizeCheck",
6262
"GroupbyColumnCheck",
63-
"HFTokenCheck",
63+
"HFModelAvailabilityCheck",
6464
"InferenceKeyCheck",
6565
"IssueCollector",
6666
"MetadataCheck",

src/nemo_safe_synthesizer/preflight/checks/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from .environment import (
2828
CUDAAvailabilityCheck,
29-
HFTokenCheck,
29+
HFModelAvailabilityCheck,
3030
InferenceKeyCheck,
3131
VRAMHeadroomCheck,
3232
)
@@ -38,7 +38,7 @@
3838
"SmallDatasetCheck",
3939
"DatasetSizeCheck",
4040
"GroupbyColumnCheck",
41-
"HFTokenCheck",
41+
"HFModelAvailabilityCheck",
4242
"InferenceKeyCheck",
4343
"OrderbyColumnCheck",
4444
"OversamplingCheck",
@@ -58,7 +58,7 @@
5858
# CONFIG
5959
CUDAAvailabilityCheck(),
6060
InferenceKeyCheck(),
61-
HFTokenCheck(),
61+
HFModelAvailabilityCheck(),
6262
# DATAFRAME
6363
DatasetSizeCheck(),
6464
GroupbyColumnCheck(),

0 commit comments

Comments
 (0)