33from typing import Literal
44import SimpleITK as sitk
55import numpy as np
6+ import array_api_compat
7+ from array_api_compat import numpy as xnp
68from ._grids import Grid
79
10+ from numpy .typing import NDArray
11+ from ..core .xp_utils .typing import Array
12+ from ..core .xp_utils import to_numpy , from_numpy
813
9- def sitk_mask_to_linear_indices (mask : sitk .Image , order = "sitk" ) -> np .ndarray :
14+
15+ def sitk_mask_to_linear_indices (mask : sitk .Image , order = "sitk" ) -> NDArray :
1016 """
1117 Convert a SimpleITK mask to linear indices.
1218
@@ -20,7 +26,7 @@ def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> np.ndarray:
2026
2127 Returns
2228 -------
23- np.ndarray
29+ NDArray
2430 A 1D numpy array of linear indices where the mask is non-zero.
2531
2632 Raises
@@ -37,16 +43,14 @@ def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> np.ndarray:
3743 raise ValueError ("Invalid ordering. Must be 'sitk' or 'numpy'." )
3844
3945
40- def linear_indices_to_sitk_mask (
41- indices : np .ndarray , ref_image : sitk .Image , order = "sitk"
42- ) -> sitk .Image :
46+ def linear_indices_to_sitk_mask (indices : Array , ref_image : sitk .Image , order = "sitk" ) -> sitk .Image :
4347 """
4448 Convert linear indices to a SimpleITK mask.
4549
4650 Parameters
4751 ----------
48- indices : np.ndarray
49- A 1D numpy array of linear indices where the mask is non-zero.
52+ indices : Array
53+ A 1D Array API conform array of linear indices where the mask is non-zero.
5054 ref_image : sitk.Image
5155 The reference image on which the mask is defined.
5256 order : str, optional
@@ -64,7 +68,9 @@ def linear_indices_to_sitk_mask(
6468 If the ordering is not 'sitk' or 'numpy'.
6569 """
6670
67- arr = np .zeros_like (sitk .GetArrayViewFromImage (ref_image ), dtype = np .uint8 )
71+ indices = to_numpy (indices )
72+
73+ arr : NDArray = xnp .zeros_like (sitk .GetArrayViewFromImage (ref_image ), dtype = xnp .uint8 )
6874
6975 if order == "sitk" :
7076 arr .T .flat [indices ] = 1
@@ -80,18 +86,18 @@ def linear_indices_to_sitk_mask(
8086
8187
8288def linear_indices_to_grid_coordinates (
83- indices : np . ndarray ,
89+ indices : Array ,
8490 grid : Grid ,
8591 index_type : Literal ["numpy" , "sitk" ] = "numpy" ,
8692 dtype : np .dtype = np .float64 ,
87- ) -> np . ndarray :
93+ ) -> Array :
8894 """
8995 Convert linear indices to gridcoordinates.
9096
9197 Parameters
9298 ----------
93- indices : np.ndarray
94- A 1D numpy array of linear indices where the mask is non-zero.
99+ indices : Array
100+ A 1D Array API conform array of linear indices where the mask is non-zero.
95101 grid : Grid
96102 The image grid on which the indices lie.
97103 index_type : Literal["numpy", "sitk"], optional
@@ -102,10 +108,12 @@ def linear_indices_to_grid_coordinates(
102108
103109 Returns
104110 -------
105- np.ndarray
106- A 2D numpy array of image coordinates.
111+ Array
112+ A 2D Array API conform array array of image coordinates.
107113 """
108114
115+ xp = array_api_compat .array_namespace (indices )
116+
109117 # this is a manual reimplementation of np.unravel_index
110118 # to avoid the overhead of creating a tuple of arrays
111119 if index_type == "numpy" :
@@ -117,6 +125,8 @@ def linear_indices_to_grid_coordinates(
117125 else :
118126 raise ValueError ("Invalid index type. Must be 'numpy' or 'sitk'." )
119127
128+ indices = to_numpy (indices )
129+
120130 v = np .empty ((3 , np .asarray (indices ).size ), dtype = dtype )
121131 tmp , v [order [0 ]] = np .divmod (indices , d2 )
122132 v [order [2 ]], v [order [1 ]] = np .divmod (tmp , d1 )
@@ -128,22 +138,22 @@ def linear_indices_to_grid_coordinates(
128138
129139 physical_point = origin + np .matmul (np .matmul (grid .direction , spacing_diag ), v ).T
130140
131- return physical_point
141+ return from_numpy ( xp , physical_point )
132142
133143
134144def linear_indices_to_image_coordinates (
135- indices : np . ndarray ,
145+ indices : Array ,
136146 image : sitk .Image ,
137147 index_type : Literal ["numpy" , "sitk" ] = "numpy" ,
138148 dtype : np .dtype = np .float64 ,
139- ) -> np . ndarray :
149+ ) -> Array :
140150 """
141151 Convert linear indices to image coordinates.
142152
143153 Parameters
144154 ----------
145- indices : np.ndarray
146- A 1D numpy array of linear indices where the mask is non-zero.
155+ indices : Array
156+ A 1D Array API conform array of linear indices where the mask is non-zero.
147157 image : sitk.Image
148158 The reference image on which the mask is defined.
149159 index_type : Literal["numpy", "sitk"], optional
@@ -154,8 +164,8 @@ def linear_indices_to_image_coordinates(
154164
155165 Returns
156166 -------
157- np.ndarray
158- A 2D numpy array of image coordinates.
167+ Array
168+ A 2D Array API conform array of image coordinates.
159169 """
160170
161171 grid = Grid .from_sitk_image (image )
0 commit comments