Skip to content

Commit 99ef2d7

Browse files
authored
chore(materials): Include callback to log progress when sampling (#214)
Enables custom logging of sampling progress. Co-authored-by: Andrew Fowler <andrew31416@users.noreply.github.com>
1 parent 27bda83 commit 99ef2d7

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

mattergen/common/utils/data_classes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import asdict, dataclass, field
77
from functools import cached_property
88
from pathlib import Path
9-
from typing import Any, Literal
9+
from typing import Any, Literal, Protocol
1010

1111
import numpy as np
1212
from huggingface_hub import hf_hub_download
@@ -154,3 +154,13 @@ def checkpoint_path(self) -> str:
154154
raise ValueError(f"Unrecognized load_epoch {self.load_epoch}")
155155
ckpt = ckpts[ckpt_ix]
156156
return ckpt
157+
158+
159+
class ProgressCallback(Protocol):
160+
def __call__(self, progress: float):
161+
"""Callback which can be used to report progress on long-running inference.
162+
163+
Args:
164+
progress: Float between 0 and 1.
165+
"""
166+
pass

mattergen/generator.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from mattergen.common.utils.globals import DEFAULT_SAMPLING_CONFIG_PATH, get_device
3333
from mattergen.diffusion.lightning_module import DiffusionLightningModule
3434
from mattergen.diffusion.sampling.pc_sampler import PredictorCorrector
35+
from mattergen.common.utils.data_classes import ProgressCallback
3536

3637

3738
def draw_samples_from_sampler(
@@ -41,6 +42,7 @@ def draw_samples_from_sampler(
4142
output_path: Path | None = None,
4243
cfg: DictConfig | None = None,
4344
record_trajectories: bool = True,
45+
progress_callback: ProgressCallback | None = None,
4446
) -> list[Structure]:
4547

4648
# Dict
@@ -51,7 +53,9 @@ def draw_samples_from_sampler(
5153

5254
all_samples_list = []
5355
all_trajs_list = []
54-
for conditioning_data, mask in tqdm(condition_loader, desc="Generating samples"):
56+
for batch_idx, (conditioning_data, mask) in enumerate(tqdm(condition_loader, desc="Generating samples")):
57+
if progress_callback is not None:
58+
progress_callback(progress=batch_idx / len(condition_loader))
5559

5660
# generate samples
5761
if record_trajectories:
@@ -60,6 +64,11 @@ def draw_samples_from_sampler(
6064
else:
6165
sample, mean = sampler.sample(conditioning_data, mask)
6266
all_samples_list.extend(mean.to_data_list())
67+
68+
if progress_callback is not None:
69+
# log 100% progress
70+
progress_callback(progress=1.0)
71+
6372
all_samples = collate(all_samples_list)
6473
assert isinstance(all_samples, ChemGraph)
6574
lengths, angles = lattice_matrix_to_params_torch(all_samples.cell)
@@ -199,6 +208,9 @@ class CrystalGenerator:
199208
_model: DiffusionLightningModule | None = None
200209
_cfg: DictConfig | None = None
201210

211+
# can be used to monitor progress of generation
212+
progress_callback: ProgressCallback | None = None
213+
202214
def __post_init__(self) -> None:
203215
assert self.num_atoms_distribution in NUM_ATOMS_DISTRIBUTIONS, (
204216
f"num_atoms_distribution must be one of {list(NUM_ATOMS_DISTRIBUTIONS.keys())}, "
@@ -374,6 +386,7 @@ def generate(
374386
output_path=Path(output_dir),
375387
properties_to_condition_on=self.properties_to_condition_on,
376388
record_trajectories=self.record_trajectories,
389+
progress_callback=self.progress_callback,
377390
)
378391

379392
return generated_structures

mattergen/scripts/generate.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pymatgen.core.structure import Structure
1010

1111
from mattergen.common.data.types import TargetProperty
12-
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME, MatterGenCheckpointInfo
12+
from mattergen.common.utils.data_classes import PRETRAINED_MODEL_NAME, MatterGenCheckpointInfo, ProgressCallback
1313
from mattergen.generator import CrystalGenerator
1414

1515

@@ -29,6 +29,7 @@ def main(
2929
diffusion_guidance_factor: float | None = None,
3030
strict_checkpoint_loading: bool = True,
3131
target_compositions: list[dict[str, int]] | None = None,
32+
progress_callback: ProgressCallback | None = None,
3233
) -> list[Structure]:
3334
"""
3435
Evaluate diffusion model against molecular metrics.
@@ -46,7 +47,7 @@ def main(
4647
strict_checkpoint_loading: Whether to raise an exception when not all parameters from the checkpoint can be matched to the model.
4748
target_compositions: List of dictionaries with target compositions to condition on. Each dictionary should have the form `{element: number_of_atoms}`. If None, the target compositions are not conditioned on.
4849
Only supported for models trained for crystal structure prediction (CSP) (default: None)
49-
50+
progress_callback: Optional callback function that takes in a single float argument representing the progress of the generation process (between 0 and 1).
5051
NOTE: When specifying dictionary values via the CLI, make sure there is no whitespace between the key and value, e.g., `--properties_to_condition_on={key1:value1}`.
5152
"""
5253
assert (
@@ -94,6 +95,7 @@ def main(
9495
diffusion_guidance_factor if diffusion_guidance_factor is not None else 0.0
9596
),
9697
target_compositions_dict=target_compositions,
98+
progress_callback=progress_callback,
9799
)
98100
return generator.generate(output_dir=Path(output_path))
99101

0 commit comments

Comments
 (0)