|
9 | 9 |
|
10 | 10 | import contextlib |
11 | 11 | import logging |
| 12 | +import os |
12 | 13 | from copy import deepcopy |
13 | 14 |
|
14 | 15 | import numpy as np |
@@ -1860,3 +1861,97 @@ def test_execution_mode_not_set_when_conditions_not_met(model_name): |
1860 | 1861 | f"Expected execution_mode to be None when activation_checkpointing=True, " |
1861 | 1862 | f"got {predict_unit.inference_settings.execution_mode}" |
1862 | 1863 | ) |
| 1864 | + |
| 1865 | + |
| 1866 | +# ========================================================================= |
| 1867 | +# Full-model GP correctness: no-GP vs all-gather vs A2A |
| 1868 | +# Skipped on CI (no multi-GPU), run locally or via SLURM |
| 1869 | +# ========================================================================= |
| 1870 | + |
# Marker for tests that need several GPUs: they are skipped whenever the
# ``CI`` environment variable is exactly the string "true".
_running_in_ci = os.getenv("CI") == "true"
_skip_if_ci = pytest.mark.skipif(
    _running_in_ci,
    reason="Multi-GPU test, skipped in CI",
)
| 1875 | + |
| 1876 | + |
@_skip_if_ci
@pytest.mark.gpu()
@pytest.mark.parametrize("num_atoms", [10, 50, 100])
@pytest.mark.parametrize(
    "workers, gp_mode",
    [
        # All-gather (default GP)
        (1, None),
        (2, None),
        (4, None),
        # A2A + spatial
        (1, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}),
        (2, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}),
        (4, {"use_all_to_all_gp": True, "gp_partition_strategy": "spatial"}),
        # A2A + index_split
        (1, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}),
        (2, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}),
        (4, {"use_all_to_all_gp": True, "gp_partition_strategy": "index_split"}),
    ],
)
def test_full_model_gp_correctness(num_atoms, workers, gp_mode):
    """Check that parallel prediction matches the single-process reference.

    The same FCC crystal is run through a ``ParallelMLIPPredictUnit`` and
    through the predict unit returned by ``pretrained_mlip.get_predict_unit``;
    energy, forces and stress must agree within the module tolerances
    (``ATOL`` for energy and stress, ``FORCE_TOL`` for forces).

    Args:
        num_atoms: Size passed to ``get_fcc_crystal_by_num_atoms``.
        workers: ``num_workers`` given to ``ParallelMLIPPredictUnit``.
        gp_mode: Backbone override dict, or ``None`` to pass no overrides.
    """
    seed = 42
    model_path = pretrained_checkpoint_path_from_name("uma-s-1p1")
    ifsets = InferenceSettings(
        tf32=False,
        merge_mole=True,
        activation_checkpointing=False,
        internal_graph_gen_version=2,
        external_graph_gen=False,
    )
    atoms = get_fcc_crystal_by_num_atoms(num_atoms)
    atomic_data = AtomicData.from_ase(atoms, task_name=["omat"])

    overrides = None if gp_mode is None else {"backbone": gp_mode}

    seed_everywhere(seed)
    ppunit = ParallelMLIPPredictUnit(
        inference_model_path=model_path,
        device="cuda",
        inference_settings=ifsets,
        num_workers=workers,
        overrides=overrides,
    )
    try:
        pp_results = ppunit.predict(atomic_data)
    finally:
        # Run the cleanup even when the parallel prediction raises, so a
        # failing case does not skip it for the parametrized cases that follow.
        distutils.cleanup_gp_ray()

    # Re-seed so the reference run starts from the same RNG state.
    seed_everywhere(seed)
    ref_unit = pretrained_mlip.get_predict_unit(
        "uma-s-1p1", device="cuda", inference_settings=ifsets
    )
    ref_results = ref_unit.predict(atomic_data)

    # Move everything to CPU once so comparisons and messages share tensors.
    pp_energy = pp_results["energy"].detach().cpu()
    ref_energy = ref_results["energy"].detach().cpu()
    pp_forces = pp_results["forces"].detach().cpu()
    ref_forces = ref_results["forces"].detach().cpu()
    pp_stress = pp_results["stress"].detach().cpu()
    ref_stress = ref_results["stress"].detach().cpu()

    assert torch.allclose(pp_energy, ref_energy, atol=ATOL), (
        f"Energy mismatch: workers={workers}, gp_mode={gp_mode}, "
        f"num_atoms={num_atoms}, "
        f"pp={pp_results['energy'].item():.6f}, "
        f"ref={ref_results['energy'].item():.6f}"
    )
    assert torch.allclose(pp_forces, ref_forces, atol=FORCE_TOL), (
        f"Forces mismatch: workers={workers}, gp_mode={gp_mode}, "
        f"num_atoms={num_atoms}, "
        f"max_diff={torch.max(torch.abs(pp_forces - ref_forces)).item():.6e}"
    )
    assert torch.allclose(pp_stress, ref_stress, atol=ATOL), (
        f"Stress mismatch: workers={workers}, gp_mode={gp_mode}, "
        f"num_atoms={num_atoms}"
    )
0 commit comments