Skip to content

Commit 29a3fa4

Browse files
committed
Merge branch 'feature/array_api_compatibility' into 'develop'
Array API compatibility See merge request e040/e0404/pyRadPlan!101
2 parents 954058a + d7f2707 commit 29a3fa4

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

49 files changed

+2379
-698
lines changed

examples/pencilbeam_proton.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,14 @@
2323
fluence_optimization,
2424
plot_slice,
2525
load_tg119,
26+
xp_utils,
2627
)
2728

2829
from pyRadPlan.optimization.objectives import SquaredDeviation, SquaredOverdosing, MeanDose
2930

31+
32+
xp_utils.PREFER_GPU = True
33+
xp_utils.PREFERRED_CPU_ARRAY_BACKEND = "numpy"
3034
logging.basicConfig(level=logging.INFO)
3135

3236
# %%
@@ -39,6 +43,7 @@
3943
# Create a plan object
4044
pln = IonPlan(radiation_mode="protons", machine="Generic")
4145
pln.prop_opt = {"solver": "scipy"}
46+
pln.prop_dose_calc = {"dose_grid": ct.grid}
4247

4348
# Generate Steering Geometry ("stf")
4449
stf = generate_stf(ct, cst, pln)

examples/raytracer_rad_depth.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@
2828
from pyRadPlan.visualization import plot_slice
2929
from pyRadPlan.io import load_patient
3030

31+
from pyRadPlan import xp_utils
32+
33+
xp_utils.PREFERRED_GPU_ARRAY_BACKEND = "cupy"
34+
xp_utils.PREFER_GPU = True
35+
3136
# Configure the Logger to show you debug information
3237
logging.basicConfig(level=logging.INFO)
3338
logging.getLogger("pyRadPlan").setLevel(logging.DEBUG)
@@ -63,7 +68,7 @@
6368
# with a single beam at 0 degrees
6469
stfgen = StfGeneratorIMPT(pln)
6570
stfgen.bixel_width = 5.0
66-
stfgen.gantry_angles = [90.0]
71+
stfgen.gantry_angles = [0.0]
6772

6873
# We generate the beam geometry on the CT and CST
6974
stf = stfgen.generate(ct, cst)
@@ -72,6 +77,8 @@
7277
# For that, we use a default HU->rSP table to convert our CT to water-equivalent thickness and then
7378
# call a voxel-wise RayTracing algorithm (proposed by Siddon) to calculate the radiological depth.
7479
rt = RayTracerSiddon([ct.compute_wet(default_hlut())])
80+
rt.lateral_cut_off = 150.0
81+
rt.precision = np.float32
7582
rt.debug_core_performance = True
7683
rad_depth_cubes = rt.trace_cubes(stf[0])
7784

matRad

Submodule matRad updated 47 files

pyRadPlan/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from .analysis._dvh import DVH, DVHCollection
3838
from .visualization import plot_slice
3939
from .io import load_patient, load_tg119
40+
from .core import xp_utils
4041

4142
try:
4243
__version__ = version(__name__)
@@ -49,6 +50,7 @@
4950

5051
__all__ = [
5152
"__version__",
53+
"xp_utils",
5254
"Plan",
5355
"IonPlan",
5456
"PhotonPlan",

pyRadPlan/core/datamodel.py

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from pydantic.alias_generators import to_camel
1111
from copy import deepcopy
1212

13+
import array_api_compat
14+
1315

1416
class PyRadPlanBaseModel(BaseModel):
1517
"""
@@ -45,7 +47,7 @@ def __eq__(self, other: Any) -> bool:
4547
"""
4648
try:
4749
return super().__eq__(other)
48-
except ValueError:
50+
except (ValueError, TypeError):
4951
if self.__dict__.keys() != other.__dict__.keys():
5052
return False
5153
stack = [(self.__dict__, other.__dict__)]
@@ -56,12 +58,7 @@ def __eq__(self, other: Any) -> bool:
5658
for key in dict_a:
5759
if isinstance(dict_a[key], dict) and isinstance(dict_b[key], dict):
5860
stack.append((dict_a[key], dict_b[key]))
59-
elif isinstance(dict_a[key], np.ndarray) and isinstance(
60-
dict_b[key], np.ndarray
61-
):
62-
if not np.array_equal(dict_a[key], dict_b[key]):
63-
return False
64-
elif dict_a[key] != dict_b[key]:
61+
elif not self._eq_dict_entry(dict_a, dict_b, key):
6562
return False
6663
return True
6764

@@ -76,6 +73,55 @@ def __ne__(self, other: Any) -> bool:
7673
else:
7774
return True
7875

76+
def _eq_dict_entry(self, dict_a: dict, dict_b: dict, key: Any) -> bool:
77+
"""Compare two dictionary entries for equality."""
78+
79+
if (
80+
isinstance(dict_a[key], np.ndarray)
81+
and isinstance(dict_b[key], np.ndarray)
82+
and dict_a[key].dtype == object
83+
and dict_b[key].dtype == object
84+
):
85+
return self._eq_object_arrays(dict_a[key], dict_b[key])
86+
elif array_api_compat.is_array_api_obj(dict_a[key]) and array_api_compat.is_array_api_obj(
87+
dict_b[key]
88+
):
89+
try:
90+
xp = array_api_compat.array_namespace(dict_a[key], dict_b[key])
91+
except (ValueError, TypeError):
92+
return False
93+
if dict_a[key].shape != dict_b[key].shape or not xp.all(dict_a[key] == dict_b[key]):
94+
return False
95+
elif dict_a[key] != dict_b[key]:
96+
return False
97+
98+
return True
99+
100+
def _eq_object_arrays(self, obj_array_a: np.ndarray, obj_array_b: np.ndarray) -> bool:
101+
"""Compare two object arrays for equality."""
102+
assert obj_array_a.dtype == object and obj_array_b.dtype == object
103+
104+
if obj_array_a.shape != obj_array_b.shape:
105+
return False
106+
107+
for a, b in zip(obj_array_a.flat, obj_array_b.flat):
108+
if a is None and b is None:
109+
continue
110+
elif isinstance(a, np.ndarray) and isinstance(b, np.ndarray):
111+
if a.shape != b.shape or a.dtype != b.dtype or not np.all(a == b):
112+
return False
113+
elif array_api_compat.is_array_api_obj(a) and array_api_compat.is_array_api_obj(b):
114+
try:
115+
xp = array_api_compat.array_namespace(a, b)
116+
except (ValueError, TypeError):
117+
return False
118+
if not xp.all(a == b):
119+
return False
120+
elif any(a != b):
121+
return False
122+
123+
return True
124+
79125
def to_matrad(self, context: Union[str, dict] = "mat-file") -> Any:
80126
"""
81127
Perform matRad compatible serialization.

pyRadPlan/core/np2sitk.py

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,16 @@
33
from typing import Literal
44
import SimpleITK as sitk
55
import numpy as np
6+
import array_api_compat
7+
from array_api_compat import numpy as xnp
68
from ._grids import Grid
79

10+
from numpy.typing import NDArray
11+
from ..core.xp_utils.typing import Array
12+
from ..core.xp_utils import to_numpy, from_numpy
813

9-
def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> np.ndarray:
14+
15+
def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> NDArray:
1016
"""
1117
Convert a SimpleITK mask to linear indices.
1218
@@ -20,7 +26,7 @@ def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> np.ndarray:
2026
2127
Returns
2228
-------
23-
np.ndarray
29+
NDArray
2430
A 1D numpy array of linear indices where the mask is non-zero.
2531
2632
Raises
@@ -37,16 +43,14 @@ def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> np.ndarray:
3743
raise ValueError("Invalid ordering. Must be 'sitk' or 'numpy'.")
3844

3945

40-
def linear_indices_to_sitk_mask(
41-
indices: np.ndarray, ref_image: sitk.Image, order="sitk"
42-
) -> sitk.Image:
46+
def linear_indices_to_sitk_mask(indices: Array, ref_image: sitk.Image, order="sitk") -> sitk.Image:
4347
"""
4448
Convert linear indices to a SimpleITK mask.
4549
4650
Parameters
4751
----------
48-
indices : np.ndarray
49-
A 1D numpy array of linear indices where the mask is non-zero.
52+
indices : Array
53+
A 1D Array API conform array of linear indices where the mask is non-zero.
5054
ref_image : sitk.Image
5155
The reference image on which the mask is defined.
5256
order : str, optional
@@ -64,7 +68,9 @@ def linear_indices_to_sitk_mask(
6468
If the ordering is not 'sitk' or 'numpy'.
6569
"""
6670

67-
arr = np.zeros_like(sitk.GetArrayViewFromImage(ref_image), dtype=np.uint8)
71+
indices = to_numpy(indices)
72+
73+
arr: NDArray = xnp.zeros_like(sitk.GetArrayViewFromImage(ref_image), dtype=xnp.uint8)
6874

6975
if order == "sitk":
7076
arr.T.flat[indices] = 1
@@ -80,18 +86,18 @@ def linear_indices_to_sitk_mask(
8086

8187

8288
def linear_indices_to_grid_coordinates(
83-
indices: np.ndarray,
89+
indices: Array,
8490
grid: Grid,
8591
index_type: Literal["numpy", "sitk"] = "numpy",
8692
dtype: np.dtype = np.float64,
87-
) -> np.ndarray:
93+
) -> Array:
8894
"""
8995
Convert linear indices to gridcoordinates.
9096
9197
Parameters
9298
----------
93-
indices : np.ndarray
94-
A 1D numpy array of linear indices where the mask is non-zero.
99+
indices : Array
100+
A 1D Array API conform array of linear indices where the mask is non-zero.
95101
grid : Grid
96102
The image grid on which the indices lie.
97103
index_type : Literal["numpy", "sitk"], optional
@@ -102,10 +108,12 @@ def linear_indices_to_grid_coordinates(
102108
103109
Returns
104110
-------
105-
np.ndarray
106-
A 2D numpy array of image coordinates.
111+
Array
112+
A 2D Array API conform array array of image coordinates.
107113
"""
108114

115+
xp = array_api_compat.array_namespace(indices)
116+
109117
# this is a manual reimplementation of np.unravel_index
110118
# to avoid the overhead of creating a tuple of arrays
111119
if index_type == "numpy":
@@ -117,6 +125,8 @@ def linear_indices_to_grid_coordinates(
117125
else:
118126
raise ValueError("Invalid index type. Must be 'numpy' or 'sitk'.")
119127

128+
indices = to_numpy(indices)
129+
120130
v = np.empty((3, np.asarray(indices).size), dtype=dtype)
121131
tmp, v[order[0]] = np.divmod(indices, d2)
122132
v[order[2]], v[order[1]] = np.divmod(tmp, d1)
@@ -128,22 +138,22 @@ def linear_indices_to_grid_coordinates(
128138

129139
physical_point = origin + np.matmul(np.matmul(grid.direction, spacing_diag), v).T
130140

131-
return physical_point
141+
return from_numpy(xp, physical_point)
132142

133143

134144
def linear_indices_to_image_coordinates(
135-
indices: np.ndarray,
145+
indices: Array,
136146
image: sitk.Image,
137147
index_type: Literal["numpy", "sitk"] = "numpy",
138148
dtype: np.dtype = np.float64,
139-
) -> np.ndarray:
149+
) -> Array:
140150
"""
141151
Convert linear indices to image coordinates.
142152
143153
Parameters
144154
----------
145-
indices : np.ndarray
146-
A 1D numpy array of linear indices where the mask is non-zero.
155+
indices : Array
156+
A 1D Array API conform array of linear indices where the mask is non-zero.
147157
image : sitk.Image
148158
The reference image on which the mask is defined.
149159
index_type : Literal["numpy", "sitk"], optional
@@ -154,8 +164,8 @@ def linear_indices_to_image_coordinates(
154164
155165
Returns
156166
-------
157-
np.ndarray
158-
A 2D numpy array of image coordinates.
167+
Array
168+
A 2D Array API conform array of image coordinates.
159169
"""
160170

161171
grid = Grid.from_sitk_image(image)
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Provide information about the available compute backends."""
2+
3+
import importlib
4+
5+
from typing import Optional
6+
7+
try:
8+
import cupy as cp
9+
except ImportError:
10+
cp = None
11+
12+
try:
13+
import torch
14+
except ImportError:
15+
torch = None
16+
17+
from .helpers import (
18+
get_current_stream,
19+
create_stream,
20+
synchronize,
21+
record_event,
22+
elapsed_time,
23+
to_numpy,
24+
from_numpy,
25+
to_namespace,
26+
)
27+
28+
from .typing import Array, ArrayNamespace
29+
30+
31+
# Check if GPU is available
32+
def cupy_available() -> bool:
33+
"""Check if CuPy is available and a compatible GPU is present."""
34+
return cp is not None and cp.cuda.is_available()
35+
36+
37+
def pytorch_available() -> bool:
38+
"""Check if PyTorch is available."""
39+
return torch is not None
40+
41+
42+
def pytorch_gpu_available() -> bool:
43+
"""Check if PyTorch is available and a compatible GPU is present."""
44+
return torch is not None and torch.cuda.is_available()
45+
46+
47+
PREFERRED_CPU_ARRAY_BACKEND: str = "numpy"
48+
PREFERRED_GPU_ARRAY_BACKEND: str = "cupy" if cupy_available() else None
49+
PREFER_GPU = True
50+
51+
52+
def choose_array_api_namespace(namespace: Optional[str] = None) -> ArrayNamespace:
53+
"""
54+
Get the name of the preferred Array API conform computational namespace / backend.
55+
56+
Parameters
57+
----------
58+
namespace : Optional[str], optional
59+
The name of the desired backend. If None, the preferred backend is used.
60+
"""
61+
if namespace is None:
62+
if PREFER_GPU and PREFERRED_GPU_ARRAY_BACKEND is not None:
63+
namespace = PREFERRED_GPU_ARRAY_BACKEND
64+
else:
65+
namespace = PREFERRED_CPU_ARRAY_BACKEND
66+
else:
67+
namespace = namespace.lower()
68+
69+
try:
70+
return importlib.import_module(f"array_api_compat.{namespace}")
71+
except ModuleNotFoundError:
72+
return importlib.import_module(namespace)
73+
74+
75+
__all__ = [
76+
"cp",
77+
"torch",
78+
"cupy_available",
79+
"torch_gpu_available",
80+
"get_current_stream",
81+
"create_stream",
82+
"record_event",
83+
"synchronize",
84+
"elapsed_time",
85+
"to_numpy",
86+
"from_numpy",
87+
"to_namespace",
88+
"Array",
89+
"ArrayNamespace",
90+
]

0 commit comments

Comments
 (0)