
Commit bd46225

DistributedManager cleanup and kNN cuml/scipy hotfixes (#1182)
* Enable a "soft" check path for check_min_version that returns False instead of raising an exception. This lets knn import without crashing if cuml isn't installed.
* Use the soft fail for scipy too.
* Remove the barrier from DistributedManager cleanup. It is still available opt-in if it's ever needed, but by default it's off.
* Finish the cleanup of knn. New logic:
  - Soft check of the cuml and scipy installs.
  - If the backend is explicitly set to cuml or scipy but the package is not installed, it errors loudly.
  - If the backend is "auto", it selects cuml/scipy when available and falls back to torch when not.
* Add a test path for CPU knn when scipy is not installed. Document the hard_fail parameter in check_min_version.
1 parent 4287bc7 commit bd46225
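
As a rough sketch of the behavior described in the commit message above (not part of the commit itself, and assuming a physicsnemo install that includes this change), the soft check and the "auto" fallback can be exercised like this:

import torch

from physicsnemo.utils.neighbors import knn
from physicsnemo.utils.version_check import check_min_version

# Soft check: returns False instead of raising if scipy is missing or too old.
scipy_ok = check_min_version("scipy", "1.7.0", hard_fail=False)
print(f"scipy available: {scipy_ok}")

points = torch.randn(100, 3)
queries = torch.randn(10, 3)

# On CPU, "auto" picks scipy when available and silently falls back to the
# pure-torch implementation when it is not.
indices, distances = knn(points, queries, k=4, backend="auto")

Explicitly requesting backend="scipy" or backend="cuml" without the package installed is still expected to fail loudly, per the new logic.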

File tree

6 files changed: +65 additions, -15 deletions


physicsnemo/distributed/manager.py

Lines changed: 13 additions & 6 deletions
@@ -800,18 +800,25 @@ def create_groups_from_config(

     @atexit.register
     @staticmethod
-    def cleanup():
-        """Clean up distributed group and singleton"""
+    def cleanup(barrier: bool = False):
+        """Clean up distributed group and singleton
+
+        Parameters
+        ----------
+        barrier : bool, optional
+            Whether to use a global barrier before destroying the process group, by default False
+        """
         # Destroying group.WORLD is enough for all process groups to get destroyed
         if (
             "_is_initialized" in DistributedManager._shared_state
             and DistributedManager._shared_state["_is_initialized"]
             and "_distributed" in DistributedManager._shared_state
             and DistributedManager._shared_state["_distributed"]
         ):
-            if torch.cuda.is_available():
-                dist.barrier(device_ids=[DistributedManager().local_rank])
-            else:
-                dist.barrier()
+            if barrier:
+                if torch.cuda.is_available():
+                    dist.barrier(device_ids=[DistributedManager().local_rank])
+                else:
+                    dist.barrier()
             dist.destroy_process_group()
             DistributedManager._shared_state = {}
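
With the barrier now opt-in, the atexit-registered cleanup no longer synchronizes ranks by default. A minimal sketch (not from the commit, and assuming the usual DistributedManager.initialize() entry point) of explicitly requesting the barrier at shutdown:

from physicsnemo.distributed import DistributedManager

# Standard setup for a distributed run.
DistributedManager.initialize()
dm = DistributedManager()

# ... distributed training / inference work ...

# Opt back into the global barrier so all ranks synchronize before the
# process group is destroyed; the default (barrier=False) skips it.
DistributedManager.cleanup(barrier=True)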

physicsnemo/utils/neighbors/knn/_cuml_impl.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@

 from physicsnemo.utils.version_check import check_min_version

-CUML_AVAILABLE = check_min_version("cuml", "24.0.0")
+CUML_AVAILABLE = check_min_version("cuml", "24.0.0", hard_fail=False)

 if CUML_AVAILABLE:
     import cuml

physicsnemo/utils/neighbors/knn/_scipy_impl.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@

 from physicsnemo.utils.version_check import check_min_version

-SCIPY_AVAILABLE = check_min_version("scipy", "1.7.0")
+SCIPY_AVAILABLE = check_min_version("scipy", "1.7.0", hard_fail=False)

 if SCIPY_AVAILABLE:
     from scipy.spatial import KDTree

physicsnemo/utils/neighbors/knn/knn.py

Lines changed: 10 additions & 2 deletions
@@ -18,7 +18,9 @@

 import torch

+from ._cuml_impl import CUML_AVAILABLE
 from ._cuml_impl import knn_impl as knn_cuml
+from ._scipy_impl import SCIPY_AVAILABLE
 from ._scipy_impl import knn_impl as knn_scipy
 from ._torch_impl import knn_impl as knn_torch

@@ -72,9 +74,15 @@ def knn(

     if backend == "auto":
         if points.is_cuda:
-            backend = "cuml"
+            if CUML_AVAILABLE:
+                backend = "cuml"
+            else:
+                backend = "torch"
         else:
-            backend = "scipy"
+            if SCIPY_AVAILABLE:
+                backend = "scipy"
+            else:
+                backend = "torch"

     # Cuml foes not support bfloat16:
     # Autocast to float32:
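
To illustrate the new resolution rules, a hedged sketch (not code from the commit; it assumes a CPU-only environment where scipy may or may not be installed):

import torch

from physicsnemo.utils.neighbors import knn

points = torch.randn(256, 3)  # CPU tensors
queries = torch.randn(32, 3)

# "auto" resolves to scipy on CPU when SCIPY_AVAILABLE is True, and to the
# torch implementation otherwise, so this call works either way.
idx, dist = knn(points, queries, k=8, backend="auto")

# An explicitly requested backend is not silently replaced; if scipy is not
# installed this is expected to raise, so the call is guarded here.
try:
    idx, dist = knn(points, queries, k=8, backend="scipy")
except Exception as err:  # exact exception type depends on the scipy wrapper
    print(f"scipy backend unavailable: {err}")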

physicsnemo/utils/version_check.py

Lines changed: 13 additions & 4 deletions
@@ -37,7 +37,10 @@


 def check_min_version(
-    package_name: str, min_version: str, error_msg: Optional[str] = None
+    package_name: str,
+    min_version: str,
+    error_msg: Optional[str] = None,
+    hard_fail: bool = True,
 ) -> bool:
     """
     Check if an installed package meets the minimum version requirement.
@@ -46,7 +49,7 @@ def check_min_version(
         package_name: Name of the package to check
         min_version: Minimum required version string (e.g. '2.6.0')
         error_msg: Optional custom error message
-
+        hard_fail: Whether to raise an ImportError if the version requirement is not met
     Returns:
         True if version requirement is met

@@ -57,14 +60,20 @@ def check_min_version(
         package = importlib.import_module(package_name)
         package_version = getattr(package, "__version__", "0.0.0")
     except ImportError:
-        raise ImportError(f"Package {package_name} is required but not installed.")
+        if hard_fail:
+            raise ImportError(f"Package {package_name} is required but not installed.")
+        else:
+            return False

     if version.parse(package_version) < version.parse(min_version):
         msg = (
             error_msg
             or f"{package_name} version {min_version} or higher is required, but found {package_version}"
         )
-        raise ImportError(msg)
+        if hard_fail:
+            raise ImportError(msg)
+        else:
+            return False

     return True
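
With hard_fail=False, check_min_version doubles as a quiet availability probe, which is how the cuml and scipy wrappers above use it. A small sketch (the package name some_optional_pkg is a placeholder, not a real physicsnemo dependency):

from physicsnemo.utils.version_check import check_min_version

# Default behavior: raise ImportError if the package is missing or too old.
check_min_version("torch", "2.0.0")

# Soft behavior: return False instead of raising, suitable for module-level flags.
SOME_PKG_AVAILABLE = check_min_version("some_optional_pkg", "1.0.0", hard_fail=False)

if SOME_PKG_AVAILABLE:
    import some_optional_pkg  # guarded import, mirroring the cuml/scipy wrappers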

test/utils/neighbors/test_knn.py

Lines changed: 27 additions & 1 deletion
@@ -20,6 +20,7 @@
 from physicsnemo.utils.neighbors import knn
 from physicsnemo.utils.neighbors.knn._cuml_impl import knn_impl as knn_cuml
 from physicsnemo.utils.neighbors.knn._scipy_impl import knn_impl as knn_scipy
+from physicsnemo.utils.version_check import check_min_version


 @pytest.mark.parametrize("device", ["cpu", "cuda"])
@@ -33,6 +34,15 @@ def test_knn(device: str, k: int, backend: str, dtype: torch.dtype):
     Basic test for KNN functionality.
     We use a predictable grid of points to ensure the results are valid.
     """
+
+    if backend == "cuml":
+        if not check_min_version("cuml", "24.0.0", hard_fail=False):
+            pytest.skip("cuml not available")
+
+    elif backend == "scipy":
+        if not check_min_version("scipy", "1.7.0", hard_fail=False):
+            pytest.skip("scipy not available")
+
     # Skip cuml tests on CPU as it's not supported
     if backend == "cuml" and device == "cpu":
         pytest.skip("cuml backend not supported on CPU")
@@ -102,12 +112,17 @@ def test_knn_torch_compile_no_graph_break(device):
     queries = torch.randn(13, 3, device=device)
     k = 5

+    if not check_min_version("cuml", "24.0.0", hard_fail=False):
+        backend = "torch"
+    else:
+        backend = "auto"
+
     def search_fn(points, queries):
         return knn(
             points,
             queries,
             k=k,
-            backend="auto",
+            backend=backend,
         )

     # Run both and compare outputs
@@ -133,8 +148,12 @@ def test_opcheck(device):
     k = 5

     if device == "cuda":
+        if not check_min_version("cuml", "24.0.0", hard_fail=False):
+            pytest.skip("cuml not available")
         op = knn_cuml
     else:
+        if not check_min_version("scipy", "1.7.0", hard_fail=False):
+            pytest.skip("scipy not available")
         op = knn_scipy

     torch.library.opcheck(op, args=(points, queries, k))
@@ -146,6 +165,13 @@ def test_knn_comparison(device):
     queries = torch.randn(21, 3, device=device)
     k = 5

+    if not check_min_version("cuml", "24.0.0", hard_fail=False):
+        if device == "cuda":
+            pytest.skip("cuml not available")
+    if not check_min_version("scipy", "1.7.0", hard_fail=False):
+        if device == "cuda":
+            pytest.skip("scipy not available")
+
     if device == "cuda":
         indices_cuml, distances_A = knn(points, queries, k, backend="cuml")
         indices_torch, distances_B = knn(points, queries, k, backend="torch")
