Skip to content

Commit 331203d

Browse files
authored
Add workers arg to from_model_checkpoint (#1844)
1 parent 7557c28 commit 331203d

1 file changed

Lines changed: 4 additions & 0 deletions

File tree

src/fairchem/core/calculate/ase_calculator.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def from_model_checkpoint(
125125
overrides: dict | None = None,
126126
device: Literal["cuda", "cpu"] | None = None,
127127
seed: int = 41,
128+
workers: int = 1,
128129
) -> FAIRChemCalculator:
129130
"""Instantiate a FAIRChemCalculator from a checkpoint file.
130131
@@ -139,6 +140,7 @@ def from_model_checkpoint(
139140
overrides: Optional dictionary of settings to override default inference settings.
140141
device: Optional torch device to load the model onto.
141142
seed: Random seed for reproducibility.
143+
workers: Number of parallel workers for prediction unit. Default is 1.
142144
"""
143145

144146
if name_or_path in pretrained_mlip.available_models:
@@ -147,13 +149,15 @@ def from_model_checkpoint(
147149
inference_settings=inference_settings,
148150
overrides=overrides,
149151
device=device,
152+
workers=workers,
150153
)
151154
elif os.path.isfile(name_or_path):
152155
predict_unit = pretrained_mlip.load_predict_unit(
153156
name_or_path,
154157
inference_settings=inference_settings,
155158
overrides=overrides,
156159
device=device,
160+
workers=workers,
157161
)
158162
else:
159163
raise ValueError(

0 commit comments

Comments
 (0)