Skip to content

Commit af68f80

Browse files
author
rgao user
committed
Add full-model GP correctness test (no-GP vs allgather vs A2A)
1 parent 50ab32d commit af68f80

2 files changed

Lines changed: 100 additions & 18 deletions

File tree

src/fairchem/core/common/parallelism/graph_parallel_verfication.md

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -38,28 +38,15 @@ pytest tests/core/common/parallelism/test_a2a_correctness.py -v
3838
8 tests: correctness at 100/500 atoms × 2 strategies, consistency across graph sizes × 2,
3939
1536-dim embeddings × 2.
4040

41-
### 1d. Full-model GPU correctness (8 GPUs, ~10 min)
42-
43-
Run BL and A2A benchmarks at 1000 atoms with a single repeat and compare outputs:
41+
### 1d. Full-model GPU correctness (4+ GPUs, ~5 min)
4442

4543
```bash
46-
# BL baseline
47-
fairchem -c configs/uma/speed/uma-speed.yaml \
48-
job=local_8gpu \
49-
runner.natoms_list=[1000] \
50-
runner.timeiters=1 \
51-
runner.repeats=1
52-
53-
# A2A + spatial
54-
fairchem -c configs/uma/speed/uma-speed.yaml \
55-
job=local_8gpu \
56-
runner.natoms_list=[1000] \
57-
runner.timeiters=1 \
58-
runner.repeats=1 \
59-
'+runner.overrides={backbone: {use_all_to_all_gp: true, gp_partition_strategy: spatial}}'
44+
pytest tests/core/units/mlip_unit/test_predict.py::test_full_model_gp_correctness -v
6045
```
6146

62-
Verify energy/forces/stress match between BL and A2A (tol=1e-4).
47+
27 tests: 3 atom counts (10, 50, 100) × 9 configs (no-GP, allgather, A2A-spatial,
48+
A2A-index_split at 1/2/4 workers). Compares energy/forces/stress against single-GPU
49+
reference (tol: energy/stress 5e-4, forces 1e-4). Skipped on CI (`CI=true`).
6350

6451
### 1e. Predict pipeline + MD consistency (CPU, PR3 branch, ~2 min)
6552

tests/core/units/mlip_unit/test_predict.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import contextlib
1111
import logging
12+
import os
1213
from copy import deepcopy
1314

1415
import numpy as np
@@ -1860,3 +1861,97 @@ def test_execution_mode_not_set_when_conditions_not_met(model_name):
18601861
f"Expected execution_mode to be None when activation_checkpointing=True, "
18611862
f"got {predict_unit.inference_settings.execution_mode}"
18621863
)
1864+
1865+
1866+
# =========================================================================
1867+
# Full-model GP correctness: no-GP vs all-gather vs A2A
1868+
# Skipped on CI (no multi-GPU), run locally or via SLURM
1869+
# =========================================================================
1870+
1871+
_skip_if_ci = pytest.mark.skipif(
1872+
os.environ.get("CI") == "true",
1873+
reason="Multi-GPU test, skipped in CI",
1874+
)
1875+
1876+
1877+
@_skip_if_ci
1878+
@pytest.mark.gpu()
1879+
@pytest.mark.parametrize("num_atoms", [10, 50, 100])
1880+
@pytest.mark.parametrize(
1881+
"workers, gp_mode",
1882+
[
1883+
# All-gather (default GP)
1884+
(1, None),
1885+
(2, None),
1886+
(4, None),
1887+
# A2A + spatial
1888+
(1, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}),
1889+
(2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}),
1890+
(4, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}),
1891+
# A2A + index_split
1892+
(1, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}),
1893+
(2, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}),
1894+
(4, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}),
1895+
],
1896+
)
1897+
def test_full_model_gp_correctness(num_atoms, workers, gp_mode):
1898+
seed = 42
1899+
model_path = pretrained_checkpoint_path_from_name("uma-s-1p1")
1900+
ifsets = InferenceSettings(
1901+
tf32=False,
1902+
merge_mole=True,
1903+
activation_checkpointing=False,
1904+
internal_graph_gen_version=2,
1905+
external_graph_gen=False,
1906+
)
1907+
atoms = get_fcc_crystal_by_num_atoms(num_atoms)
1908+
atomic_data = AtomicData.from_ase(atoms, task_name=["omat"])
1909+
1910+
overrides = None
1911+
if gp_mode is not None:
1912+
overrides = {"backbone": gp_mode}
1913+
1914+
seed_everywhere(seed)
1915+
ppunit = ParallelMLIPPredictUnit(
1916+
inference_model_path=model_path,
1917+
device="cuda",
1918+
inference_settings=ifsets,
1919+
num_workers=workers,
1920+
overrides=overrides,
1921+
)
1922+
pp_results = ppunit.predict(atomic_data)
1923+
distutils.cleanup_gp_ray()
1924+
1925+
seed_everywhere(seed)
1926+
ref_unit = pretrained_mlip.get_predict_unit(
1927+
"uma-s-1p1", device="cuda", inference_settings=ifsets
1928+
)
1929+
ref_results = ref_unit.predict(atomic_data)
1930+
1931+
assert torch.allclose(
1932+
pp_results["energy"].detach().cpu(),
1933+
ref_results["energy"].detach().cpu(),
1934+
atol=ATOL,
1935+
), (
1936+
f"Energy mismatch: workers={workers}, gp_mode={gp_mode}, "
1937+
f"num_atoms={num_atoms}, "
1938+
f"pp={pp_results['energy'].item():.6f}, "
1939+
f"ref={ref_results['energy'].item():.6f}"
1940+
)
1941+
assert torch.allclose(
1942+
pp_results["forces"].detach().cpu(),
1943+
ref_results["forces"].detach().cpu(),
1944+
atol=FORCE_TOL,
1945+
), (
1946+
f"Forces mismatch: workers={workers}, gp_mode={gp_mode}, "
1947+
f"num_atoms={num_atoms}, "
1948+
f"max_diff={torch.max(torch.abs(pp_results['forces'].detach().cpu() - ref_results['forces'].detach().cpu())).item():.6e}"
1949+
)
1950+
assert torch.allclose(
1951+
pp_results["stress"].detach().cpu(),
1952+
ref_results["stress"].detach().cpu(),
1953+
atol=ATOL,
1954+
), (
1955+
f"Stress mismatch: workers={workers}, gp_mode={gp_mode}, "
1956+
f"num_atoms={num_atoms}"
1957+
)

0 commit comments

Comments
 (0)