Skip to content

Commit 8f74b9e

Browse files
mshuaibi and rayg1234
authored
nvidia graph gen support (#1737)
NVIDIA has released a suite of tools, which includes a neighborlist. Initial benchmarks look promising for both CPU and GPU performance. Currently our external cpu implementation uses pymatgen. - [x] cpu support - [x] gpu support <img width="4766" height="2060" alt="image" src="https://github.com/user-attachments/assets/5d514209-c876-4ac5-a59e-60e4f337b922" /> Note - we should bump minimum python version to `3.11` for the CI. # TODO * GP is broken for V3. A separate PR will abstract it out to support the other graph gen methods, including v3. see #1791 * `external_graph_gen_method` is unused; to use nvidia graph gen with this, we need a followup PR to add this option to the FairchemCalculator (and deprecate `external_graph_gen` from inference_settings) --------- Co-authored-by: Ray Gao <rgao@meta.com> Co-authored-by: Ray Gao <7001989+rayg1234@users.noreply.github.com>
1 parent 17b8cde commit 8f74b9e

17 files changed

Lines changed: 2091 additions & 820 deletions

File tree

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535
strategy:
3636
max-parallel: 10
3737
matrix:
38-
python_version: ['3.10', '3.13']
38+
python_version: ['3.11', '3.13']
3939

4040
steps:
4141
- name: Checkout code
@@ -155,7 +155,7 @@ jobs:
155155
- name: Install core dependencies and package
156156
run: |
157157
python -m pip install --upgrade pip
158-
pip install packages/fairchem-core[dev] \
158+
pip install packages/fairchem-core[dev,extras] \
159159
packages/fairchem-data-omol[dev] \
160160
packages/fairchem-data-omat \
161161
-r tests/requirements.txt # pin test packages

packages/fairchem-core/pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ name = "fairchem-core"
77
description = "Machine learning models for chemistry and materials science by the FAIR Chemistry team"
88
license = {text = "MIT License"}
99
dynamic = ["version", "readme"]
10-
requires-python = ">=3.10, <3.14"
10+
requires-python = ">=3.11, <3.14"
1111
dependencies = [
1212
"torch~=2.8.0",
1313
"ray[serve]>=2.53.0",
@@ -36,8 +36,9 @@ dependencies = [
3636
[project.optional-dependencies] # add optional dependencies, e.g. to be installed as pip install fairchem.core[dev]
3737
dev = ["pre-commit", "pytest", "pytest-cov", "coverage", "syrupy", "ruff==0.5.1"]
3838
docs = ["jupyter-book", "jupytext", "sphinx","sphinx-autoapi==3.3.3", "astroid<4", "umap-learn", "vdict", "ipywidgets", "jupyter_book>=2.0", "torch-dftd"]
39-
adsorbml = ["dscribe","x3dase","scikit-image"]
40-
extras = ["ray[default]", "pymatgen", "quacc[phonons]>=0.15.3", "pandas"]
39+
adsorbml = ["dscribe", "x3dase", "scikit-image"]
40+
extras = ["ray[default]", "pymatgen", "quacc[phonons]>=0.15.3", "pandas", "nvalchemi-toolkit-ops"]
41+
4142

4243
[project.scripts]
4344
fairchem = "fairchem.core._cli:main"

src/fairchem/core/datasets/ase_datasets.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from functools import cache, partial
1717
from glob import glob
1818
from pathlib import Path
19-
from typing import Any, Callable
19+
from typing import TYPE_CHECKING, Any
2020

2121
import ase
2222
import numpy as np
@@ -28,6 +28,9 @@
2828
from fairchem.core.datasets.base_dataset import BaseDataset
2929
from fairchem.core.modules.transforms import DataTransforms
3030

31+
if TYPE_CHECKING:
32+
from collections.abc import Callable
33+
3134

3235
def apply_one_tags(
3336
atoms: ase.Atoms, skip_if_nonzero: bool = True, skip_always: bool = False

src/fairchem/core/datasets/atomic_data.py

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import logging
1414
import re
1515
from collections.abc import Sequence
16-
from typing import List, Optional, Union
16+
from typing import Union
1717

1818
import ase
1919
import ase.db.sqlite
@@ -25,6 +25,8 @@
2525
from ase.stress import full_3x3_to_voigt_6_stress, voigt_6_to_full_3x3_stress
2626
from monty.dev import requires
2727

28+
from fairchem.core.common.utils import StrEnum
29+
2830
try:
2931
from pymatgen.io.ase import AseAtomsAdaptor
3032

@@ -33,9 +35,18 @@
3335
AseAtomsAdaptor = None
3436
pmg_installed = False
3537

38+
from fairchem.core.graph.radius_graph_pbc_nvidia import get_neighbors_nvidia_atoms
3639

3740
IndexType = Union[slice, torch.Tensor, np.ndarray, Sequence]
3841

42+
43+
class ExternalGraphMethod(StrEnum):
44+
"""Enum for external graph generation methods."""
45+
46+
PYMATGEN = "pymatgen"
47+
NVIDIA = "nvidia"
48+
49+
3950
# these are all currently certainly output by the current a2g
4051
# except for tags, all fields are required for network inference.
4152
_REQUIRED_KEYS = [
@@ -83,7 +94,7 @@ def size_repr(key: str, item: torch.Tensor, indent=0) -> str:
8394
out = item.item()
8495
elif torch.is_tensor(item):
8596
out = str(list(item.size()))
86-
elif isinstance(item, (List, tuple)):
97+
elif isinstance(item, (list, tuple)):
8798
out = str([len(item)])
8899
elif isinstance(item, dict):
89100
lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()]
@@ -300,10 +311,8 @@ def validate(self):
300311
assert self.forces.dtype == self.pos.dtype
301312
if hasattr(self, "stress"):
302313
# NOTE: usually decomposed. for EFS prediction right now we reshape to (9,). need to discuss, perhaps use (1,3,3)
303-
assert (
304-
self.stress.dim() == 3
305-
and self.stress.shape[1:] == (3, 3)
306-
or (self.stress.dim() == 2 and self.stress.shape[1:] == (9,))
314+
assert (self.stress.dim() == 3 and self.stress.shape[1:] == (3, 3)) or (
315+
self.stress.dim() == 2 and self.stress.shape[1:] == (9,)
307316
)
308317
assert self.stress.shape[0] == self.num_graphs
309318
assert self.stress.dtype == self.pos.dtype
@@ -332,6 +341,7 @@ def from_ase(
332341
r_data_keys: list[str] | None = None, # NOT USED, compat for now
333342
task_name: str | None = None,
334343
target_dtype: torch.dtype = torch.float32,
344+
external_graph_method: ExternalGraphMethod | str = ExternalGraphMethod.PYMATGEN,
335345
) -> AtomicData:
336346
atoms = input_atoms.copy()
337347
calc = input_atoms.calc
@@ -375,7 +385,16 @@ def from_ase(
375385
assert (
376386
max_neigh is not None
377387
), "max_neigh must be specified for cpu graph construction."
378-
split_idx_dist = get_neighbors_pymatgen(atoms, radius, max_neigh)
388+
389+
if external_graph_method == ExternalGraphMethod.PYMATGEN:
390+
split_idx_dist = get_neighbors_pymatgen(atoms, radius, max_neigh)
391+
elif external_graph_method == ExternalGraphMethod.NVIDIA:
392+
split_idx_dist = get_neighbors_nvidia_atoms(atoms, radius, max_neigh)
393+
else:
394+
raise ValueError(
395+
f"external_graph_method must be 'pymatgen' or 'nvidia', got {external_graph_method}"
396+
)
397+
379398
edge_index, cell_offsets = reshape_features(
380399
*split_idx_dist, target_dtype=target_dtype
381400
)
@@ -443,16 +462,20 @@ def from_ase(
443462
# TODO another way to specify this is to spcify a key. maybe total_charge
444463
charge = torch.LongTensor(
445464
[
446-
atoms.info.get("charge", 0)
447-
if r_data_keys is not None and "charge" in r_data_keys
448-
else 0
465+
(
466+
atoms.info.get("charge", 0)
467+
if r_data_keys is not None and "charge" in r_data_keys
468+
else 0
469+
)
449470
]
450471
)
451472
spin = torch.LongTensor(
452473
[
453-
atoms.info.get("spin", 0)
454-
if r_data_keys is not None and "spin" in r_data_keys
455-
else 0
474+
(
475+
atoms.info.get("spin", 0)
476+
if r_data_keys is not None and "spin" in r_data_keys
477+
else 0
478+
)
456479
]
457480
)
458481

@@ -844,7 +867,7 @@ def update_batch_edges(
844867

845868

846869
def atomicdata_list_to_batch(
847-
data_list: list[AtomicData], exclude_keys: Optional[list] = None
870+
data_list: list[AtomicData], exclude_keys: list | None = None
848871
) -> AtomicData:
849872
"""
850873
all data points must be single graphs and have the same set of keys.

src/fairchem/core/datasets/common_structures.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
from __future__ import annotations
22

33
import numpy as np
4-
from ase.build import bulk
4+
from ase import Atoms
5+
from ase.build import bulk, molecule
56
from ase.lattice.cubic import FaceCenteredCubic
67

78

89
def get_fcc_crystal_by_num_atoms(
910
num_atoms: int,
1011
lattice_constant: float = 3.8,
1112
atom_type: str = "C",
12-
):
13+
) -> Atoms:
1314
# lattice_constant = 3.8, fcc generates a supercell with ~50 edges/atom, used for benchmarking
1415
atoms = bulk(atom_type, "fcc", a=lattice_constant)
1516
n_cells = int(np.ceil(np.cbrt(num_atoms)))
@@ -24,7 +25,7 @@ def get_fcc_crystal_by_num_cells(
2425
n_cells: int,
2526
atom_type: str = "Cu",
2627
lattice_constant: float = 3.61,
27-
):
28+
) -> Atoms:
2829
atoms = FaceCenteredCubic(
2930
directions=[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
3031
symbol=atom_type,
@@ -34,3 +35,25 @@ def get_fcc_crystal_by_num_cells(
3435
)
3536
atoms.info = {"charge": 0, "spin": 0}
3637
return atoms
38+
39+
40+
def get_water_box(num_molecules=20, box_size=10.0, seed=42) -> Atoms:
41+
"""Create a random box of water molecules."""
42+
43+
rng = np.random.default_rng(seed)
44+
water = molecule("H2O")
45+
46+
all_positions = []
47+
all_symbols = []
48+
49+
for _ in range(num_molecules):
50+
# Random position and rotation for each water molecule
51+
offset = rng.random(3) * box_size
52+
positions = water.get_positions() + offset
53+
all_positions.extend(positions)
54+
all_symbols.extend(water.get_chemical_symbols())
55+
56+
atoms = Atoms(
57+
symbols=all_symbols, positions=all_positions, cell=[box_size] * 3, pbc=True
58+
)
59+
return atoms

src/fairchem/core/graph/compute.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,11 @@
99

1010
import torch
1111

12-
from fairchem.core.graph.radius_graph_pbc import radius_graph_pbc, radius_graph_pbc_v2
12+
from fairchem.core.graph.radius_graph_pbc import (
13+
radius_graph_pbc,
14+
radius_graph_pbc_v2,
15+
)
16+
from fairchem.core.graph.radius_graph_pbc_nvidia import radius_graph_pbc_nvidia
1317

1418

1519
def get_pbc_distances(
@@ -74,7 +78,7 @@ def generate_graph(
7478
cutoff (float): The maximum distance between atoms to consider them as neighbors.
7579
max_neighbors (int): The maximum number of neighbors to consider for each atom.
7680
enforce_max_neighbors_strictly (bool): Whether to strictly enforce the maximum number of neighbors.
77-
radius_pbc_version: the version of radius_pbc impl
81+
radius_pbc_version: the version of radius_pbc impl (1, 2, or 3 for NVIDIA)
7882
pbc (list[bool]): The periodic boundary conditions in 3 dimensions, defaults to [True,True,True] for 3D pbc
7983
8084
Returns:
@@ -90,6 +94,8 @@ def generate_graph(
9094
radius_graph_pbc_fn = radius_graph_pbc
9195
elif radius_pbc_version == 2:
9296
radius_graph_pbc_fn = radius_graph_pbc_v2
97+
elif radius_pbc_version == 3:
98+
radius_graph_pbc_fn = radius_graph_pbc_nvidia
9399
else:
94100
raise ValueError(f"Invalid radius_pbc version {radius_pbc_version}")
95101

src/fairchem/core/graph/radius_graph_pbc.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,21 @@
1515

1616

1717
def sum_partitions(x: torch.Tensor, partition_idxs: torch.Tensor) -> torch.Tensor:
18-
sums = torch.zeros(partition_idxs.shape[0] - 1, device=x.device, dtype=x.dtype)
19-
for idx in range(partition_idxs.shape[0] - 1):
20-
sums[idx] = x[partition_idxs[idx] : partition_idxs[idx + 1]].sum()
21-
return sums
18+
"""
19+
Sum values within partitions defined by indices.
20+
"""
21+
num_partitions = partition_idxs.shape[0] - 1
22+
if num_partitions == 0:
23+
return torch.zeros(0, device=x.device, dtype=x.dtype)
24+
25+
# Use cumsum-based approach for vectorization
26+
cumsum = torch.zeros(len(x) + 1, device=x.device, dtype=x.dtype)
27+
cumsum[1:] = torch.cumsum(x, dim=0)
28+
29+
# Gather cumsum at partition boundaries and compute differences
30+
starts = cumsum[partition_idxs[:-1]]
31+
ends = cumsum[partition_idxs[1:]]
32+
return ends - starts
2233

2334

2435
def get_counts(x: torch.Tensor, length: int):

0 commit comments

Comments (0)