|
13 | 13 | import logging |
14 | 14 | import re |
15 | 15 | from collections.abc import Sequence |
16 | | -from typing import List, Optional, Union |
| 16 | +from typing import Union |
17 | 17 |
|
18 | 18 | import ase |
19 | 19 | import ase.db.sqlite |
|
25 | 25 | from ase.stress import full_3x3_to_voigt_6_stress, voigt_6_to_full_3x3_stress |
26 | 26 | from monty.dev import requires |
27 | 27 |
|
| 28 | +from fairchem.core.common.utils import StrEnum |
| 29 | + |
28 | 30 | try: |
29 | 31 | from pymatgen.io.ase import AseAtomsAdaptor |
30 | 32 |
|
|
33 | 35 | AseAtomsAdaptor = None |
34 | 36 | pmg_installed = False |
35 | 37 |
|
| 38 | +from fairchem.core.graph.radius_graph_pbc_nvidia import get_neighbors_nvidia_atoms |
36 | 39 |
|
37 | 40 | IndexType = Union[slice, torch.Tensor, np.ndarray, Sequence] |
38 | 41 |
|
| 42 | + |
| 43 | +class ExternalGraphMethod(StrEnum): |
| 44 | + """Enum for external graph generation methods.""" |
| 45 | + |
| 46 | + PYMATGEN = "pymatgen" |
| 47 | + NVIDIA = "nvidia" |
| 48 | + |
| 49 | + |
39 | 50 | # these are all currently certainly output by the current a2g |
40 | 51 | # except for tags, all fields are required for network inference. |
41 | 52 | _REQUIRED_KEYS = [ |
@@ -83,7 +94,7 @@ def size_repr(key: str, item: torch.Tensor, indent=0) -> str: |
83 | 94 | out = item.item() |
84 | 95 | elif torch.is_tensor(item): |
85 | 96 | out = str(list(item.size())) |
86 | | - elif isinstance(item, (List, tuple)): |
| 97 | + elif isinstance(item, (list, tuple)): |
87 | 98 | out = str([len(item)]) |
88 | 99 | elif isinstance(item, dict): |
89 | 100 | lines = [indent_str + size_repr(k, v, 2) for k, v in item.items()] |
@@ -300,10 +311,8 @@ def validate(self): |
300 | 311 | assert self.forces.dtype == self.pos.dtype |
301 | 312 | if hasattr(self, "stress"): |
302 | 313 | # NOTE: usually decomposed. for EFS prediction right now we reshape to (9,). need to discuss, perhaps use (1,3,3) |
303 | | - assert ( |
304 | | - self.stress.dim() == 3 |
305 | | - and self.stress.shape[1:] == (3, 3) |
306 | | - or (self.stress.dim() == 2 and self.stress.shape[1:] == (9,)) |
| 314 | + assert (self.stress.dim() == 3 and self.stress.shape[1:] == (3, 3)) or ( |
| 315 | + self.stress.dim() == 2 and self.stress.shape[1:] == (9,) |
307 | 316 | ) |
308 | 317 | assert self.stress.shape[0] == self.num_graphs |
309 | 318 | assert self.stress.dtype == self.pos.dtype |
@@ -332,6 +341,7 @@ def from_ase( |
332 | 341 | r_data_keys: list[str] | None = None, # NOT USED, compat for now |
333 | 342 | task_name: str | None = None, |
334 | 343 | target_dtype: torch.dtype = torch.float32, |
| 344 | + external_graph_method: ExternalGraphMethod | str = ExternalGraphMethod.PYMATGEN, |
335 | 345 | ) -> AtomicData: |
336 | 346 | atoms = input_atoms.copy() |
337 | 347 | calc = input_atoms.calc |
@@ -375,7 +385,16 @@ def from_ase( |
375 | 385 | assert ( |
376 | 386 | max_neigh is not None |
377 | 387 | ), "max_neigh must be specified for cpu graph construction." |
378 | | - split_idx_dist = get_neighbors_pymatgen(atoms, radius, max_neigh) |
| 388 | + |
| 389 | + if external_graph_method == ExternalGraphMethod.PYMATGEN: |
| 390 | + split_idx_dist = get_neighbors_pymatgen(atoms, radius, max_neigh) |
| 391 | + elif external_graph_method == ExternalGraphMethod.NVIDIA: |
| 392 | + split_idx_dist = get_neighbors_nvidia_atoms(atoms, radius, max_neigh) |
| 393 | + else: |
| 394 | + raise ValueError( |
| 395 | + f"external_graph_method must be 'pymatgen' or 'nvidia', got {external_graph_method}" |
| 396 | + ) |
| 397 | + |
379 | 398 | edge_index, cell_offsets = reshape_features( |
380 | 399 | *split_idx_dist, target_dtype=target_dtype |
381 | 400 | ) |
@@ -443,16 +462,20 @@ def from_ase( |
443 | 462 | # TODO another way to specify this is to specify a key. maybe total_charge |
444 | 463 | charge = torch.LongTensor( |
445 | 464 | [ |
446 | | - atoms.info.get("charge", 0) |
447 | | - if r_data_keys is not None and "charge" in r_data_keys |
448 | | - else 0 |
| 465 | + ( |
| 466 | + atoms.info.get("charge", 0) |
| 467 | + if r_data_keys is not None and "charge" in r_data_keys |
| 468 | + else 0 |
| 469 | + ) |
449 | 470 | ] |
450 | 471 | ) |
451 | 472 | spin = torch.LongTensor( |
452 | 473 | [ |
453 | | - atoms.info.get("spin", 0) |
454 | | - if r_data_keys is not None and "spin" in r_data_keys |
455 | | - else 0 |
| 474 | + ( |
| 475 | + atoms.info.get("spin", 0) |
| 476 | + if r_data_keys is not None and "spin" in r_data_keys |
| 477 | + else 0 |
| 478 | + ) |
456 | 479 | ] |
457 | 480 | ) |
458 | 481 |
|
@@ -844,7 +867,7 @@ def update_batch_edges( |
844 | 867 |
|
845 | 868 |
|
846 | 869 | def atomicdata_list_to_batch( |
847 | | - data_list: list[AtomicData], exclude_keys: Optional[list] = None |
| 870 | + data_list: list[AtomicData], exclude_keys: list | None = None |
848 | 871 | ) -> AtomicData: |
849 | 872 | """ |
850 | 873 | all data points must be single graphs and have the same set of keys. |
|
0 commit comments