Skip to content

Commit 5678a3b

Browse files
author
Michael Dzamba
committed
feat: ban retired UMA checkpoints by md5
Refuse to load uma-s-1p1 and uma-s-1p2 inference checkpoints. The check runs at the single chokepoint (MLIPPredictUnit.__init__) immediately before torch.load, so every end-user inference path -- get_predict_unit, FAIRChemCalculator.from_model_checkpoint, and direct construction -- is covered. New bans can be added by appending an md5 to _BANNED_CHECKPOINTS in predict.py. Includes a parameterized integration test that loads the cached HF artifacts and asserts BannedCheckpointError, plus a small utility script for computing the md5 of a downloaded checkpoint.
1 parent 28292f0 commit 5678a3b

3 files changed

Lines changed: 121 additions & 0 deletions

File tree

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""
2+
Copyright (c) Meta Platforms, Inc. and affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import hashlib
11+
from pathlib import Path
12+
13+
from fairchem.core.calculate.pretrained_mlip import (
14+
pretrained_checkpoint_path_from_name,
15+
)
16+
17+
MODELS = ["uma-s-1p1", "uma-s-1p2"]
18+
19+
20+
def md5_of_file(path: str | Path, chunk_size: int = 1024 * 1024) -> str:
21+
h = hashlib.md5()
22+
with open(path, "rb") as f:
23+
for chunk in iter(lambda: f.read(chunk_size), b""):
24+
h.update(chunk)
25+
return h.hexdigest()
26+
27+
28+
def main() -> None:
29+
for name in MODELS:
30+
print(f"Downloading {name} ...")
31+
path = pretrained_checkpoint_path_from_name(name)
32+
size_mb = Path(path).stat().st_size / (1024 * 1024)
33+
digest = md5_of_file(path)
34+
print(f" path: {path}")
35+
print(f" size: {size_mb:.1f} MB")
36+
print(f" md5: {digest}")
37+
print()
38+
39+
40+
if __name__ == "__main__":
41+
main()

src/fairchem/core/units/mlip_unit/predict.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import copy
11+
import hashlib
1112
import logging
1213
import math
1314
import os
@@ -56,6 +57,46 @@
5657
from fairchem.core.units.mlip_unit.api.inference import MLIPInferenceCheckpoint
5758

5859

60+
class BannedCheckpointError(RuntimeError):
61+
"""
62+
Raised when a checkpoint with a banned MD5 hash is loaded.
63+
"""
64+
65+
66+
# Map of banned checkpoint MD5 -> human-readable reason. Any inference
67+
# checkpoint whose md5 matches an entry here will be rejected before it is
68+
# deserialized by torch.load.
69+
_BANNED_CHECKPOINTS: dict[str, str] = {
70+
# uma-s-1p1
71+
"36a2f071350be0ee4c15e7ebdd16dde1": (
72+
"uma-s-1p1 has been retired. Please upgrade to a newer UMA model "
73+
"(see fairchem.core.calculate.pretrained_mlip.available_models)."
74+
),
75+
# uma-s-1p2
76+
"26ac47f57e7d68af9f031077cdc2cbe9": (
77+
"uma-s-1p2 has been retired. Please upgrade to a newer UMA model "
78+
"(see fairchem.core.calculate.pretrained_mlip.available_models)."
79+
),
80+
}
81+
82+
83+
def _md5_of_file(path: str, chunk_size: int = 1024 * 1024) -> str:
84+
h = hashlib.md5()
85+
with open(path, "rb") as f:
86+
for chunk in iter(lambda: f.read(chunk_size), b""):
87+
h.update(chunk)
88+
return h.hexdigest()
89+
90+
91+
def _verify_checkpoint_not_banned(path: str) -> None:
92+
digest = _md5_of_file(path)
93+
if digest in _BANNED_CHECKPOINTS:
94+
raise BannedCheckpointError(
95+
f"Refusing to load banned checkpoint {path} (md5={digest}): "
96+
f"{_BANNED_CHECKPOINTS[digest]}"
97+
)
98+
99+
59100
def collate_predictions(predict_fn):
60101
@wraps(predict_fn)
61102
def collated_predict(
@@ -123,6 +164,7 @@ def __init__(
123164
)
124165

125166
# Load checkpoint first to get model type
167+
_verify_checkpoint_not_banned(inference_model_path)
126168
checkpoint = torch.load(
127169
inference_model_path, map_location="cpu", weights_only=False
128170
)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
"""
2+
Copyright (c) Meta Platforms, Inc. and affiliates.
3+
4+
This source code is licensed under the MIT license found in the
5+
LICENSE file in the root directory of this source tree.
6+
"""
7+
8+
from __future__ import annotations
9+
10+
import pytest
11+
from huggingface_hub import try_to_load_from_cache
12+
from huggingface_hub.file_download import _CACHED_NO_EXIST
13+
14+
from fairchem.core._config import CACHE_DIR
15+
from fairchem.core.calculate.pretrained_mlip import _MODEL_CKPTS
16+
from fairchem.core.units.mlip_unit import MLIPPredictUnit
17+
from fairchem.core.units.mlip_unit.predict import BannedCheckpointError
18+
19+
20+
@pytest.mark.parametrize("model_name", ["uma-s-1p1", "uma-s-1p2"])
21+
def test_real_uma_checkpoint_is_banned(model_name):
22+
"""
23+
The retired UMA checkpoints shipped on HuggingFace must trip the
24+
BannedCheckpointError. Skipped when the file is not present in the
25+
local HF cache so the test does not require network access in CI.
26+
"""
27+
spec = _MODEL_CKPTS.checkpoints[model_name]
28+
rel_path = f"{spec.subfolder}/{spec.filename}" if spec.subfolder else spec.filename
29+
cached = try_to_load_from_cache(
30+
repo_id=spec.repo_id,
31+
filename=rel_path,
32+
cache_dir=CACHE_DIR,
33+
)
34+
if cached is None or cached is _CACHED_NO_EXIST:
35+
pytest.skip(f"{model_name} not in local HF cache at {CACHE_DIR}")
36+
37+
with pytest.raises(BannedCheckpointError):
38+
MLIPPredictUnit(cached, device="cpu")

0 commit comments

Comments
 (0)