3838import numpy as np
3939import numpy .typing as npt
4040from matplotlib .gridspec import GridSpec
41+ from nibabel .spatialimages import SpatialImage
4142from nilearn import image as nlimage
4243from nilearn .plotting import plot_anat
4344from svgutils .transform import SVGFigure , fromstring
5556
5657
5758def plot_segs (
58- image_nii : ty .Union [str , nb . Nifti1Image ],
59- seg_niis : list [ty .Union [str , nb . Nifti1Image ]],
60- bbox_nii : ty .Union [str , nb . Nifti1Image , None ] = None ,
59+ image_nii : ty .Union [str , SpatialImage ],
60+ seg_niis : list [ty .Union [str , SpatialImage ]],
61+ bbox_nii : ty .Union [str , SpatialImage , None ] = None ,
6162 masked : bool = False ,
6263 compress : ty .Union [bool , L ["auto" ]] = "auto" ,
6364 ** plot_params ,
@@ -76,7 +77,7 @@ def plot_segs(
7677 image_nii = _3d_in_file (image_nii )
7778 canonical_r = rotation2canonical (image_nii )
7879 image_nii = rotate_affine (image_nii , rot = canonical_r )
79- seg_imgs : list [nb . Nifti1Image ] = [
80+ seg_imgs : list [SpatialImage ] = [
8081 rotate_affine (_3d_in_file (f ), rot = canonical_r ) for f in seg_niis
8182 ]
8283 data = image_nii .get_fdata ()
@@ -88,7 +89,7 @@ def plot_segs(
8889 )
8990
9091 if masked :
91- bbox_nii : nb . Nifti1Image = nlimage .threshold_img (bbox_nii , 1e-3 ) # type: ignore[no-redef]
92+ bbox_nii : SpatialImage = nlimage .threshold_img (bbox_nii , 1e-3 ) # type: ignore[no-redef]
9293
9394 cuts = cuts_from_bbox (bbox_nii , cuts = 7 )
9495 out_files = []
@@ -104,14 +105,14 @@ def plot_segs(
104105
105106
106107def plot_registration (
107- anat_nii : nb . spatialimages . SpatialImage ,
108+ anat_nii : SpatialImage ,
108109 div_id : str ,
109110 plot_params : ty .Union [dict [str , ty .Any ], None ] = None ,
110111 order : tuple [L ["x" , "y" , "z" ], L ["x" , "y" , "z" ], L ["x" , "y" , "z" ]] = ("z" , "x" , "y" ),
111112 cuts : ty .Union [dict [str , list [float ]], None ] = None ,
112113 estimate_brightness : bool = False ,
113114 label : ty .Union [str , None ] = None ,
114- contour : ty .Union [nb . spatialimages . SpatialImage , None ] = None ,
115+ contour : ty .Union [SpatialImage , None ] = None ,
115116 compress : ty .Union [bool , L ["auto" ]] = "auto" ,
116117 dismiss_affine : bool = False ,
117118) -> list [SVGFigure ]:
@@ -184,8 +185,8 @@ def plot_registration(
184185
185186
186187def _plot_anat_with_contours (
187- image : nb . Nifti1Image ,
188- segs : ty .Union [list [nb . Nifti1Image ], None ] = None ,
188+ image : SpatialImage ,
189+ segs : ty .Union [list [SpatialImage ], None ] = None ,
189190 compress : ty .Union [bool , L ["auto" ]] = "auto" ,
190191 ** plot_params ,
191192) -> str :
@@ -233,12 +234,12 @@ def plot_segmentation(anat_file: str, segmentation: str, out_file: str, **kwargs
233234 vmax = kwargs .get ("vmax" )
234235 vmin = kwargs .get ("vmin" )
235236
236- anat_ras = nb .as_closest_canonical (load_api (anat_file , nb . spatialimages . SpatialImage ))
237+ anat_ras = nb .as_closest_canonical (load_api (anat_file , SpatialImage ))
237238 anat_ras_plumb = anat_ras .__class__ (
238239 anat_ras .dataobj , _dicom_real_to_card (anat_ras .affine ), anat_ras .header
239240 )
240241
241- seg_ras = nb .as_closest_canonical (load_api (segmentation , nb . spatialimages . SpatialImage ))
242+ seg_ras = nb .as_closest_canonical (load_api (segmentation , SpatialImage ))
242243 seg_ras_plumb = seg_ras .__class__ (
243244 seg_ras .dataobj , _dicom_real_to_card (seg_ras .affine ), seg_ras .header
244245 )
@@ -430,8 +431,8 @@ def plot_spikes(
430431 """Plot a mosaic enhancing EM spikes."""
431432 from mpl_toolkits .axes_grid1 import make_axes_locatable
432433
433- nii = nb .as_closest_canonical (load_api (in_file , nb . spatialimages . SpatialImage ))
434- fft = load_api (in_file , nb . spatialimages . SpatialImage ).get_fdata ()
434+ nii = nb .as_closest_canonical (load_api (in_file , SpatialImage ))
435+ fft = load_api (in_file , SpatialImage ).get_fdata ()
435436
436437 data = nii .get_fdata ()
437438 zooms = nii .header .get_zooms ()[:2 ]
0 commit comments