33from  typing  import  Literal 
44import  SimpleITK  as  sitk 
55import  numpy  as  np 
6+ import  array_api_compat 
7+ from  array_api_compat  import  numpy  as  xnp 
68from  ._grids  import  Grid 
79
10+ from  numpy .typing  import  NDArray 
11+ from  ..core .xp_utils .typing  import  Array 
12+ from  ..core .xp_utils  import  to_numpy , from_numpy 
813
9- def  sitk_mask_to_linear_indices (mask : sitk .Image , order = "sitk" ) ->  np .ndarray :
14+ 
15+ def  sitk_mask_to_linear_indices (mask : sitk .Image , order = "sitk" ) ->  NDArray :
1016    """ 
1117    Convert a SimpleITK mask to linear indices. 
1218
@@ -20,7 +26,7 @@ def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> np.ndarray:
2026
2127    Returns 
2228    ------- 
23-     np.ndarray  
29+     NDArray  
2430        A 1D numpy array of linear indices where the mask is non-zero. 
2531
2632    Raises 
@@ -37,16 +43,14 @@ def sitk_mask_to_linear_indices(mask: sitk.Image, order="sitk") -> np.ndarray:
3743    raise  ValueError ("Invalid ordering. Must be 'sitk' or 'numpy'." )
3844
3945
40- def  linear_indices_to_sitk_mask (
41-     indices : np .ndarray , ref_image : sitk .Image , order = "sitk" 
42- ) ->  sitk .Image :
46+ def  linear_indices_to_sitk_mask (indices : Array , ref_image : sitk .Image , order = "sitk" ) ->  sitk .Image :
4347    """ 
4448    Convert linear indices to a SimpleITK mask. 
4549
4650    Parameters 
4751    ---------- 
48-     indices : np.ndarray  
49-         A 1D numpy  array of linear indices where the mask is non-zero. 
52+     indices : Array  
53+         A 1D Array API conform  array of linear indices where the mask is non-zero. 
5054    ref_image : sitk.Image 
5155        The reference image on which the mask is defined. 
5256    order : str, optional 
@@ -64,7 +68,9 @@ def linear_indices_to_sitk_mask(
6468        If the ordering is not 'sitk' or 'numpy'. 
6569    """ 
6670
67-     arr  =  np .zeros_like (sitk .GetArrayViewFromImage (ref_image ), dtype = np .uint8 )
71+     indices  =  to_numpy (indices )
72+ 
73+     arr : NDArray  =  xnp .zeros_like (sitk .GetArrayViewFromImage (ref_image ), dtype = xnp .uint8 )
6874
6975    if  order  ==  "sitk" :
7076        arr .T .flat [indices ] =  1 
@@ -80,18 +86,18 @@ def linear_indices_to_sitk_mask(
8086
8187
8288def  linear_indices_to_grid_coordinates (
83-     indices : np . ndarray ,
89+     indices : Array ,
8490    grid : Grid ,
8591    index_type : Literal ["numpy" , "sitk" ] =  "numpy" ,
8692    dtype : np .dtype  =  np .float64 ,
87- ) ->  np . ndarray :
93+ ) ->  Array :
8894    """ 
8995    Convert linear indices to gridcoordinates. 
9096
9197    Parameters 
9298    ---------- 
93-     indices : np.ndarray  
94-         A 1D numpy  array of linear indices where the mask is non-zero. 
99+     indices : Array  
100+         A 1D Array API conform  array of linear indices where the mask is non-zero. 
95101    grid : Grid 
96102        The image grid on which the indices lie. 
97103    index_type : Literal["numpy", "sitk"], optional 
@@ -102,10 +108,12 @@ def linear_indices_to_grid_coordinates(
102108
103109    Returns 
104110    ------- 
105-     np.ndarray  
106-         A 2D numpy  array of image coordinates. 
111+     Array  
112+         A 2D Array API conform array  array of image coordinates. 
107113    """ 
108114
115+     xp  =  array_api_compat .array_namespace (indices )
116+ 
109117    # this is a manual reimplementation of np.unravel_index 
110118    # to avoid the overhead of creating a tuple of arrays 
111119    if  index_type  ==  "numpy" :
@@ -117,6 +125,8 @@ def linear_indices_to_grid_coordinates(
117125    else :
118126        raise  ValueError ("Invalid index type. Must be 'numpy' or 'sitk'." )
119127
128+     indices  =  to_numpy (indices )
129+ 
120130    v  =  np .empty ((3 , np .asarray (indices ).size ), dtype = dtype )
121131    tmp , v [order [0 ]] =  np .divmod (indices , d2 )
122132    v [order [2 ]], v [order [1 ]] =  np .divmod (tmp , d1 )
@@ -128,22 +138,22 @@ def linear_indices_to_grid_coordinates(
128138
129139    physical_point  =  origin  +  np .matmul (np .matmul (grid .direction , spacing_diag ), v ).T 
130140
131-     return  physical_point 
141+     return  from_numpy ( xp ,  physical_point ) 
132142
133143
134144def  linear_indices_to_image_coordinates (
135-     indices : np . ndarray ,
145+     indices : Array ,
136146    image : sitk .Image ,
137147    index_type : Literal ["numpy" , "sitk" ] =  "numpy" ,
138148    dtype : np .dtype  =  np .float64 ,
139- ) ->  np . ndarray :
149+ ) ->  Array :
140150    """ 
141151    Convert linear indices to image coordinates. 
142152
143153    Parameters 
144154    ---------- 
145-     indices : np.ndarray  
146-         A 1D numpy  array of linear indices where the mask is non-zero. 
155+     indices : Array  
156+         A 1D Array API conform  array of linear indices where the mask is non-zero. 
147157    image : sitk.Image 
148158        The reference image on which the mask is defined. 
149159    index_type : Literal["numpy", "sitk"], optional 
@@ -154,8 +164,8 @@ def linear_indices_to_image_coordinates(
154164
155165    Returns 
156166    ------- 
157-     np.ndarray  
158-         A 2D numpy  array of image coordinates. 
167+     Array  
168+         A 2D Array API conform  array of image coordinates. 
159169    """ 
160170
161171    grid  =  Grid .from_sitk_image (image )
0 commit comments