Skip to content

Commit ac6d730

Browse files
authored
Merge pull request #153 from effigies/typ/annotations
type: Relax Nifti1Image to SpatialImage
2 parents d88ca78 + 10d4f3a commit ac6d730

File tree

2 files changed

+20
-18
lines changed

2 files changed

+20
-18
lines changed

nireports/reportlets/mosaic.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
import numpy as np
3939
import numpy.typing as npt
4040
from matplotlib.gridspec import GridSpec
41+
from nibabel.spatialimages import SpatialImage
4142
from nilearn import image as nlimage
4243
from nilearn.plotting import plot_anat
4344
from svgutils.transform import SVGFigure, fromstring
@@ -55,9 +56,9 @@
5556

5657

5758
def plot_segs(
58-
image_nii: ty.Union[str, nb.Nifti1Image],
59-
seg_niis: list[ty.Union[str, nb.Nifti1Image]],
60-
bbox_nii: ty.Union[str, nb.Nifti1Image, None] = None,
59+
image_nii: ty.Union[str, SpatialImage],
60+
seg_niis: list[ty.Union[str, SpatialImage]],
61+
bbox_nii: ty.Union[str, SpatialImage, None] = None,
6162
masked: bool = False,
6263
compress: ty.Union[bool, L["auto"]] = "auto",
6364
**plot_params,
@@ -76,7 +77,7 @@ def plot_segs(
7677
image_nii = _3d_in_file(image_nii)
7778
canonical_r = rotation2canonical(image_nii)
7879
image_nii = rotate_affine(image_nii, rot=canonical_r)
79-
seg_imgs: list[nb.Nifti1Image] = [
80+
seg_imgs: list[SpatialImage] = [
8081
rotate_affine(_3d_in_file(f), rot=canonical_r) for f in seg_niis
8182
]
8283
data = image_nii.get_fdata()
@@ -88,7 +89,7 @@ def plot_segs(
8889
)
8990

9091
if masked:
91-
bbox_nii: nb.Nifti1Image = nlimage.threshold_img(bbox_nii, 1e-3) # type: ignore[no-redef]
92+
bbox_nii: SpatialImage = nlimage.threshold_img(bbox_nii, 1e-3) # type: ignore[no-redef]
9293

9394
cuts = cuts_from_bbox(bbox_nii, cuts=7)
9495
out_files = []
@@ -104,14 +105,14 @@ def plot_segs(
104105

105106

106107
def plot_registration(
107-
anat_nii: nb.spatialimages.SpatialImage,
108+
anat_nii: SpatialImage,
108109
div_id: str,
109110
plot_params: ty.Union[dict[str, ty.Any], None] = None,
110111
order: tuple[L["x", "y", "z"], L["x", "y", "z"], L["x", "y", "z"]] = ("z", "x", "y"),
111112
cuts: ty.Union[dict[str, list[float]], None] = None,
112113
estimate_brightness: bool = False,
113114
label: ty.Union[str, None] = None,
114-
contour: ty.Union[nb.spatialimages.SpatialImage, None] = None,
115+
contour: ty.Union[SpatialImage, None] = None,
115116
compress: ty.Union[bool, L["auto"]] = "auto",
116117
dismiss_affine: bool = False,
117118
) -> list[SVGFigure]:
@@ -184,8 +185,8 @@ def plot_registration(
184185

185186

186187
def _plot_anat_with_contours(
187-
image: nb.Nifti1Image,
188-
segs: ty.Union[list[nb.Nifti1Image], None] = None,
188+
image: SpatialImage,
189+
segs: ty.Union[list[SpatialImage], None] = None,
189190
compress: ty.Union[bool, L["auto"]] = "auto",
190191
**plot_params,
191192
) -> str:
@@ -233,12 +234,12 @@ def plot_segmentation(anat_file: str, segmentation: str, out_file: str, **kwargs
233234
vmax = kwargs.get("vmax")
234235
vmin = kwargs.get("vmin")
235236

236-
anat_ras = nb.as_closest_canonical(load_api(anat_file, nb.spatialimages.SpatialImage))
237+
anat_ras = nb.as_closest_canonical(load_api(anat_file, SpatialImage))
237238
anat_ras_plumb = anat_ras.__class__(
238239
anat_ras.dataobj, _dicom_real_to_card(anat_ras.affine), anat_ras.header
239240
)
240241

241-
seg_ras = nb.as_closest_canonical(load_api(segmentation, nb.spatialimages.SpatialImage))
242+
seg_ras = nb.as_closest_canonical(load_api(segmentation, SpatialImage))
242243
seg_ras_plumb = seg_ras.__class__(
243244
seg_ras.dataobj, _dicom_real_to_card(seg_ras.affine), seg_ras.header
244245
)
@@ -430,8 +431,8 @@ def plot_spikes(
430431
"""Plot a mosaic enhancing EM spikes."""
431432
from mpl_toolkits.axes_grid1 import make_axes_locatable
432433

433-
nii = nb.as_closest_canonical(load_api(in_file, nb.spatialimages.SpatialImage))
434-
fft = load_api(in_file, nb.spatialimages.SpatialImage).get_fdata()
434+
nii = nb.as_closest_canonical(load_api(in_file, SpatialImage))
435+
fft = load_api(in_file, SpatialImage).get_fdata()
435436

436437
data = nii.get_fdata()
437438
zooms = nii.header.get_zooms()[:2]

nireports/reportlets/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import nibabel as nb
4545
import numpy as np
4646
import numpy.typing as npt
47+
from nibabel.spatialimages import SpatialImage
4748
from svgutils.transform import SVGFigure
4849

4950
from ..tools.ndimage import load_api
@@ -269,7 +270,7 @@ def _bbox(img_data: npt.NDArray[G], bbox_data: npt.NDArray) -> npt.NDArray[G]:
269270
return img_data[ystart:ystop, xstart:xstop, zstart:zstop]
270271

271272

272-
def cuts_from_bbox(mask_nii: nb.Nifti1Image, cuts: int = 3) -> dict[str, list[float]]:
273+
def cuts_from_bbox(mask_nii: SpatialImage, cuts: int = 3) -> dict[str, list[float]]:
273274
"""Find equi-spaced cuts for presenting images."""
274275
mask_data = np.asanyarray(mask_nii.dataobj) > 0.0
275276

@@ -316,8 +317,8 @@ def cuts_from_bbox(mask_nii: nb.Nifti1Image, cuts: int = 3) -> dict[str, list[fl
316317

317318

318319
def _3d_in_file(
319-
in_file: ty.Union[nb.Nifti1Image, str, os.PathLike, list[ty.Union[str, os.PathLike]]],
320-
) -> nb.Nifti1Image:
320+
in_file: ty.Union[SpatialImage, str, os.PathLike, list[ty.Union[str, os.PathLike]]],
321+
) -> SpatialImage:
321322
"""if self.inputs.in_file is 3d, return it.
322323
if 4d, pick an arbitrary volume and return that.
323324
@@ -329,8 +330,8 @@ def _3d_in_file(
329330
if isinstance(in_file, list):
330331
in_file = in_file[0]
331332

332-
if not isinstance(in_file, nb.Nifti1Image):
333-
in_file = load_api(in_file, nb.Nifti1Image)
333+
if not isinstance(in_file, SpatialImage):
334+
in_file = load_api(in_file, SpatialImage)
334335

335336
if len(in_file.shape) == 3:
336337
return in_file

0 commit comments

Comments
 (0)