Skip to content

Commit 19787c1

Browse files
authored
Consolidate model and model_name args in FairchemModel (#377)
1 parent 62954ca commit 19787c1

File tree

5 files changed, with 75 additions and 77 deletions.

examples/scripts/1_Introduction/1.3_fairchem.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -36,10 +36,9 @@
3636
si_dc = bulk("Si", "diamond", a=5.43).repeat((2, 2, 2))
3737
atomic_numbers = si_dc.get_atomic_numbers()
3838
model = FairChemModel(
39-
model=None,
40-
model_name=MODEL_NAME,
39+
model=MODEL_NAME,
4140
task_name="omat", # Open Materials task for crystalline systems
42-
cpu=False,
41+
device=device,
4342
)
4443
atoms_list = [si_dc, si_dc]
4544
state = ts.io.atoms_to_state(atoms_list, device=device, dtype=dtype)

tests/models/test_fairchem.py

Lines changed: 15 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -29,8 +29,7 @@
2929
@pytest.fixture
3030
def eqv2_uma_model_pbc() -> FairChemModel:
3131
"""UMA model for periodic boundary condition systems."""
32-
cpu = DEVICE.type == "cpu"
33-
return FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
32+
return FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
3433

3534

3635
@pytest.mark.skipif(
@@ -39,7 +38,9 @@ def eqv2_uma_model_pbc() -> FairChemModel:
3938
@pytest.mark.parametrize("task_name", ["omat", "omol", "oc20"])
4039
def test_task_initialization(task_name: str) -> None:
4140
"""Test that different UMA task names work correctly."""
42-
model = FairChemModel(model=None, model_name="uma-s-1", task_name=task_name, cpu=True)
41+
model = FairChemModel(
42+
model="uma-s-1", task_name=task_name, device=torch.device("cpu")
43+
)
4344
assert model.task_name
4445
assert str(model.task_name.value) == task_name
4546
assert hasattr(model, "predictor")
@@ -75,9 +76,7 @@ def test_homogeneous_batching(task_name: str, systems_func: Callable) -> None:
7576
for mol in systems:
7677
mol.info |= {"charge": 0, "spin": 1}
7778

78-
model = FairChemModel(
79-
model=None, model_name="uma-s-1", task_name=task_name, cpu=DEVICE.type == "cpu"
80-
)
79+
model = FairChemModel(model="uma-s-1", task_name=task_name, device=DEVICE)
8180
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
8281
results = model(state)
8382

@@ -109,10 +108,9 @@ def test_heterogeneous_tasks() -> None:
109108
systems[0].info |= {"charge": 0, "spin": 1}
110109

111110
model = FairChemModel(
112-
model=None,
113-
model_name="uma-s-1",
111+
model="uma-s-1",
114112
task_name=task_name,
115-
cpu=DEVICE.type == "cpu",
113+
device=DEVICE,
116114
)
117115
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
118116
results = model(state)
@@ -151,9 +149,7 @@ def test_batch_size_variations(systems_func: Callable, expected_count: int) -> N
151149
"""Test batching with different numbers and sizes of systems."""
152150
systems = systems_func()
153151

154-
model = FairChemModel(
155-
model=None, model_name="uma-s-1", task_name="omat", cpu=DEVICE.type == "cpu"
156-
)
152+
model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
157153
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
158154
results = model(state)
159155

@@ -173,10 +169,9 @@ def test_stress_computation(*, compute_stress: bool) -> None:
173169
systems = [bulk("Si", "diamond", a=5.43), bulk("Al", "fcc", a=4.05)]
174170

175171
model = FairChemModel(
176-
model=None,
177-
model_name="uma-s-1",
172+
model="uma-s-1",
178173
task_name="omat",
179-
cpu=DEVICE.type == "cpu",
174+
device=DEVICE,
180175
compute_stress=compute_stress,
181176
)
182177
state = ts.io.atoms_to_state(systems, device=DEVICE, dtype=DTYPE)
@@ -195,9 +190,7 @@ def test_stress_computation(*, compute_stress: bool) -> None:
195190
)
196191
def test_device_consistency() -> None:
197192
"""Test device consistency between model and data."""
198-
cpu = DEVICE.type == "cpu"
199-
200-
model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=cpu)
193+
model = FairChemModel(model="uma-s-1", task_name="omat", device=DEVICE)
201194
system = bulk("Si", "diamond", a=5.43)
202195
state = ts.io.atoms_to_state([system], device=DEVICE, dtype=DTYPE)
203196

@@ -211,7 +204,7 @@ def test_device_consistency() -> None:
211204
)
212205
def test_empty_batch_error() -> None:
213206
"""Test that empty batches raise appropriate errors."""
214-
model = FairChemModel(model=None, model_name="uma-s-1", task_name="omat", cpu=True)
207+
model = FairChemModel(model="uma-s-1", task_name="omat", device=torch.device("cpu"))
215208
with pytest.raises((ValueError, RuntimeError, IndexError)):
216209
model(ts.io.atoms_to_state([], device=torch.device("cpu"), dtype=torch.float32))
217210

@@ -223,7 +216,7 @@ def test_load_from_checkpoint_path() -> None:
223216
"""Test loading model from a saved checkpoint file path."""
224217
checkpoint_path = pretrained_checkpoint_path_from_name("uma-s-1")
225218
loaded_model = FairChemModel(
226-
model=str(checkpoint_path), task_name="omat", cpu=DEVICE == "cpu"
219+
model=str(checkpoint_path), task_name="omat", device=DEVICE
227220
)
228221

229222
# Verify the loaded model works
@@ -278,10 +271,9 @@ def test_fairchem_charge_spin(charge: float, spin: float) -> None:
278271

279272
# Create model with UMA omol task (supports charge/spin for molecules)
280273
model = FairChemModel(
281-
model=None,
282-
model_name="uma-s-1",
274+
model="uma-s-1",
283275
task_name="omol",
284-
cpu=DEVICE.type == "cpu",
276+
device=DEVICE,
285277
)
286278

287279
# This should not raise an error

tests/models/test_fairchem_legacy.py

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -35,16 +35,14 @@ def model_path_oc20(tmp_path_factory: pytest.TempPathFactory) -> str:
3535

3636
@pytest.fixture
3737
def eqv2_oc20_model_pbc(model_path_oc20: str) -> FairChemV1Model:
38-
cpu = DEVICE.type == "cpu"
39-
return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=True)
38+
return FairChemV1Model(model=model_path_oc20, device=DEVICE, seed=0, pbc=True)
4039

4140

4241
@pytest.fixture
4342
def eqv2_oc20_model_non_pbc(
4443
model_path_oc20: str,
4544
) -> FairChemV1Model:
46-
cpu = DEVICE.type == "cpu"
47-
return FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=False)
45+
return FairChemV1Model(model=model_path_oc20, device=DEVICE, seed=0, pbc=False)
4846

4947

5048
if get_token():
@@ -59,8 +57,7 @@ def model_path_omat24(tmp_path_factory: pytest.TempPathFactory) -> str:
5957
def eqv2_omat24_model_pbc(
6058
model_path_omat24: str,
6159
) -> FairChemV1Model:
62-
cpu = DEVICE.type == "cpu"
63-
return FairChemV1Model(model=model_path_omat24, cpu=cpu, seed=0, pbc=True)
60+
return FairChemV1Model(model=model_path_omat24, device=DEVICE, seed=0, pbc=True)
6461

6562

6663
@pytest.fixture
@@ -106,10 +103,9 @@ def ocp_calculator(model_path_oc20: str) -> OCPCalculator:
106103

107104
def test_fairchem_mixed_pbc_init_raises(model_path_oc20: str) -> None:
108105
"""Test that initializing FairChemV1Model with mixed PBC raises ValueError."""
109-
cpu = DEVICE.type == "cpu"
110106
mixed_pbc = torch.tensor([True, False, True], dtype=torch.bool)
111107
with pytest.raises(ValueError, match="FairChemV1Model does not support mixed PBC"):
112-
FairChemV1Model(model=model_path_oc20, cpu=cpu, seed=0, pbc=mixed_pbc)
108+
FairChemV1Model(model=model_path_oc20, device=DEVICE, seed=0, pbc=mixed_pbc)
113109

114110

115111
def test_fairchem_mixed_pbc_forward_raises(

torch_sim/models/fairchem.py

Lines changed: 16 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import traceback
1313
import typing
1414
import warnings
15+
from pathlib import Path
1516
from typing import Any
1617

1718
import torch
@@ -43,7 +44,6 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
4344

4445
if typing.TYPE_CHECKING:
4546
from collections.abc import Callable
46-
from pathlib import Path
4747

4848
from torch_sim.typing import StateDict
4949

@@ -71,34 +71,34 @@ class FairChemModel(ModelInterface):
7171

7272
def __init__(
7373
self,
74-
model: str | Path | None,
74+
model: str | Path,
7575
neighbor_list_fn: Callable | None = None,
7676
*, # force remaining arguments to be keyword-only
77-
model_name: str | None = None,
7877
model_cache_dir: str | Path | None = None,
79-
cpu: bool = False,
78+
device: torch.device | None = None,
8079
dtype: torch.dtype | None = None,
8180
compute_stress: bool = False,
8281
task_name: UMATask | str | None = None,
8382
) -> None:
8483
"""Initialize the FairChem model.
8584
8685
Args:
87-
model (str | Path | None): Path to model checkpoint file
86+
model (str | Path): Either a pretrained model name or path to model
87+
checkpoint file. The function will first check if the input matches
88+
a known pretrained model name, then check if it's a valid file path.
8889
neighbor_list_fn (Callable | None): Function to compute neighbor lists
8990
(not currently supported)
90-
model_name (str | None): Name of pretrained model to load
9191
model_cache_dir (str | Path | None): Path where to save the model
92-
cpu (bool): Whether to use CPU instead of GPU for computation
92+
device (torch.device | None): Device to use for computation. If None,
93+
defaults to CUDA if available, otherwise CPU.
9394
dtype (torch.dtype | None): Data type to use for computation
9495
compute_stress (bool): Whether to compute stress tensor
9596
task_name (UMATask | str | None): Task type for UMA models (optional,
9697
only needed for UMA models)
9798
9899
Raises:
99-
RuntimeError: If both model_name and model are specified
100100
NotImplementedError: If custom neighbor list function is provided
101-
ValueError: If neither model nor model_name is provided
101+
ValueError: If model is not a known model name or valid file path
102102
"""
103103
setup_imports()
104104
setup_logging()
@@ -114,24 +114,19 @@ def __init__(
114114
"Custom neighbor list is not supported for FairChemModel."
115115
)
116116

117-
if model_name is not None:
118-
if model is not None:
119-
raise RuntimeError(
120-
"model_name and checkpoint_path were both specified, "
121-
"please use only one at a time"
122-
)
123-
model = model_name
124-
125-
if model is None:
126-
raise ValueError("Either model or model_name must be provided")
117+
# Convert Path to string for consistency
118+
if isinstance(model, Path):
119+
model = str(model)
127120

128121
# Convert task_name to UMATask if it's a string (only for UMA models)
129122
if isinstance(task_name, str):
130123
task_name = UMATask(task_name)
131124

132125
# Use the efficient predictor API for optimal performance
133-
device_str = "cpu" if cpu else "cuda" if torch.cuda.is_available() else "cpu"
134-
self._device = torch.device(device_str)
126+
self._device = device or torch.device(
127+
"cuda" if torch.cuda.is_available() else "cpu"
128+
)
129+
device_str = str(self._device)
135130
self.task_name = task_name
136131

137132
# Create efficient batch predictor for fast inference

torch_sim/models/fairchem_legacy.py

Lines changed: 38 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -19,9 +19,11 @@
1919
from __future__ import annotations
2020

2121
import copy
22+
import os
2223
import traceback
2324
import typing
2425
import warnings
26+
from pathlib import Path
2527
from types import MappingProxyType
2628
from typing import Any
2729

@@ -71,7 +73,6 @@ def __init__(self, err: ImportError = exc, *_args: Any, **_kwargs: Any) -> None:
7173

7274
if typing.TYPE_CHECKING:
7375
from collections.abc import Callable
74-
from pathlib import Path
7576

7677
from torch_sim.typing import StateDict
7778

@@ -110,6 +111,7 @@ class FairChemV1Model(ModelInterface):
110111
Examples:
111112
>>> model = FairChemV1Model(model="path/to/checkpoint.pt", compute_stress=True)
112113
>>> results = model(state)
114+
113115
"""
114116

115117
_reshaped_props = MappingProxyType(
@@ -118,14 +120,13 @@ class FairChemV1Model(ModelInterface):
118120

119121
def __init__( # noqa: C901, PLR0915
120122
self,
121-
model: str | Path | None,
123+
model: str | Path | None = None,
122124
neighbor_list_fn: Callable | None = None,
123125
*, # force remaining arguments to be keyword-only
124126
config_yml: str | None = None,
125-
model_name: str | None = None,
126127
local_cache: str | None = None,
127128
trainer: str | None = None,
128-
cpu: bool = False,
129+
device: torch.device | None = None,
129130
seed: int | None = None,
130131
dtype: torch.dtype | None = None,
131132
compute_stress: bool = False,
@@ -139,24 +140,28 @@ def __init__( # noqa: C901, PLR0915
139140
in energy and force calculations.
140141
141142
Args:
142-
model (str | Path | None): Path to model checkpoint file
143+
model (str | Path | None): Either a pretrained model name or path to model
144+
checkpoint file. The function will first check if it's a valid file
145+
path, and if not, will attempt to load it as a pretrained model name
146+
(requires local_cache to be set). If None, config_yml must be provided.
143147
neighbor_list_fn (Callable | None): Function to compute neighbor lists
144148
(not currently supported)
145149
config_yml (str | None): Path to configuration YAML file
146-
model_name (str | None): Name of pretrained model to load
147-
local_cache (str | None): Path to local model cache directory
150+
local_cache (str | None): Path to local model cache directory (required
151+
when using pretrained model names)
148152
trainer (str | None): Name of trainer class to use
149-
cpu (bool): Whether to use CPU instead of GPU for computation
153+
device (torch.device | None): Device to use for computation. If None,
154+
defaults to CUDA if available, otherwise CPU.
150155
seed (int | None): Random seed for reproducibility
151156
dtype (torch.dtype | None): Data type to use for computation
152157
compute_stress (bool): Whether to compute stress tensor
153158
pbc (torch.Tensor | bool): Whether to use periodic boundary conditions
154159
disable_amp (bool): Whether to disable AMP
155160
Raises:
156-
RuntimeError: If both model_name and model are specified
157-
NotImplementedError: If local_cache is not set when model_name is used
158161
NotImplementedError: If custom neighbor list function is provided
159162
ValueError: If stress computation is requested but not supported by model
163+
ValueError: If neither config_yml nor model is provided
164+
ValueError: If model cannot be loaded as file or pretrained model
160165
161166
Notes:
162167
Either config_yml or model must be provided. The model loads configuration
@@ -178,19 +183,25 @@ def __init__( # noqa: C901, PLR0915
178183
)
179184
self.pbc = pbc
180185

181-
if model_name is not None:
182-
if model is not None:
183-
raise RuntimeError(
184-
"model_name and checkpoint_path were both specified, "
185-
"please use only one at a time"
186-
)
187-
if local_cache is None:
188-
raise NotImplementedError(
189-
"Local cache must be set when specifying a model name"
186+
# Process model parameter if provided
187+
if model is not None:
188+
# Convert Path to string for consistency
189+
if isinstance(model, Path):
190+
model = str(model)
191+
192+
# Determine if model is a file path or a pretrained model name
193+
# First check if it's a valid file path
194+
if not os.path.isfile(model):
195+
# If not a file, try to load as pretrained model name
196+
if local_cache is None:
197+
raise ValueError(
198+
f"Model '{model}' is not a valid file path. "
199+
"If using a pretrained model name, local_cache must be set."
200+
)
201+
# Attempt to load as pretrained model name
202+
model = model_name_to_local_file(
203+
model_name=model, local_cache=local_cache
190204
)
191-
model = model_name_to_local_file(
192-
model_name=model_name, local_cache=local_cache
193-
)
194205

195206
# Either the config path or the checkpoint path needs to be provided
196207
if not config_yml and model is None:
@@ -276,6 +287,11 @@ def __init__( # noqa: C901, PLR0915
276287
self.config["checkpoint"] = str(model)
277288
del config["dataset"]["src"]
278289

290+
# Determine if CPU should be used (for the legacy trainer API)
291+
cpu = device is not None and device.type == "cpu"
292+
if device is None:
293+
cpu = not torch.cuda.is_available()
294+
279295
self.trainer = registry.get_trainer_class(config["trainer"])(
280296
task=config.get("task", {}),
281297
model=config["model"],

0 commit comments

Comments (0)