Skip to content

Commit 512ed0b

Browse files
committed
typing and deps
1 parent 75b8c68 commit 512ed0b

File tree

4 files changed

+38
-30
lines changed

4 files changed

+38
-30
lines changed

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ description = "Python Sparse data Analysis Package external fMRI plugin."
44
authors = [{name="Pierre-antoine Comby", email="pierre-antoine.comby@crans.org"}]
55
readme = "README.rst"
66

7-
dependencies = ["modopt", "numpy", "tqdm", "joblib", "numba", "scipy", "pywavelets", "mri-nufft"]
7+
dependencies = ["modopt", "numpy", "tqdm", "joblib", "numba", "scipy", "pywavelets", "mri-nufft[cufinufft,gpunufft]"]
88
dynamic = ["version"]
99

1010
[project.optional-dependencies]
@@ -13,7 +13,7 @@ doc = ["pydata-sphinx-theme", "numpydoc", "sphinx_gallery", "sphinx", "sphinx-au
1313
dev = ["black", "isort", "ruff"]
1414

1515
[build-system]
16-
requires = ["setuptools", "setuptools-scm[toml]","wheel"]
16+
requires = ["setuptools", "setuptools-scm[toml]", "wheel"]
1717

1818
######################
1919
# Tool configuration #

src/fmri/operators/fourier.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,17 @@
77
import itertools
88
from abc import ABC, abstractmethod
99
from collections.abc import Sequence
10+
from tracemalloc import BaseFilter
1011

1112
import numpy as np
12-
from mrinufft import get_operator
1313
from modopt.base.backend import get_array_module
14+
from mrinufft import get_operator
15+
from numpy.typing import NDArray
1416

15-
try:
16-
from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps
17-
except ImportError:
18-
make_pinned_smaps = None
17+
from mrinufft.operators.interfaces.gpunufft import make_pinned_smaps
1918

2019
from .utils.fft import fft, ifft
2120

22-
MRINUFFT_AVAILABLE = True
2321
CUPY_AVAILABLE = True
2422

2523
try:
@@ -51,15 +49,19 @@ def __init__(self):
5149
self.smaps = None
5250

5351
@abstractmethod
54-
def op(self, img):
52+
def op(self, img: NDArray) -> NDArray:
5553
"""Forward operator."""
5654
pass
5755

5856
@abstractmethod
59-
def adj_op(self, data):
57+
def adj_op(self, data: NDArray) -> NDArray:
6058
"""Adjoint operator."""
6159
pass
6260

61+
def data_consistency(self, data, obs_data):
62+
"""Data consistency operation"""
63+
return self.adj_op(self.op(data) - obs_data)
64+
6365

6466
class CartesianSpaceFourier(SpaceFourierBase):
6567
"""A Fourier Operator in space."""
@@ -131,7 +133,7 @@ def op(self, img):
131133
ksp = fft(img, axis=axes)
132134
return ksp * self.mask
133135

134-
def adj_op(self, kspace_data):
136+
def adj_op(self, data):
135137
"""Apply the adjoint operator.
136138
137139
Parameters
@@ -146,12 +148,12 @@ def adj_op(self, kspace_data):
146148
"""
147149
axes = tuple(range(-len(self.shape), 0))
148150
if self.n_coils > 1:
149-
img = ifft(kspace_data, axis=axes)
151+
img = ifft(data, axis=axes)
150152
if self.smaps is None:
151153
return img
152154
return np.sum(img * np.conj(self.smaps), axis=1)
153155
else:
154-
return ifft(kspace_data, axis=axes)
156+
return ifft(data, axis=axes)
155157

156158

157159
class RepeatOperator(SpaceFourierBase):
@@ -160,22 +162,22 @@ class RepeatOperator(SpaceFourierBase):
160162
def __init__(self, fourier_ops):
161163
self.fourier_ops = list(fourier_ops)
162164

163-
def op(self, images):
165+
def op(self, img):
164166
"""Apply the forward operator."""
165167
final_ksp = np.empty(
166-
(len(images), self.n_coils, self.n_samples), dtype=np.complex64
168+
(len(img), self.n_coils, self.n_samples), dtype=np.complex64
167169
)
168-
for i in range(len(images)):
169-
final_ksp[i] = self.fourier_ops[i].op(images[i])
170+
for i in range(len(img)):
171+
final_ksp[i] = self.fourier_ops[i].op(img[i])
170172
return final_ksp
171173

172-
def adj_op(self, coeffs):
174+
def adj_op(self, data):
173175
"""Apply Adjoint Operator."""
174176
c = 1 if self.uses_sense else self.n_coils
175-
xp = get_array_module(coeffs)
177+
xp = get_array_module(data)
176178
final_image = xp.empty((self.n_frames, c, *self.shape), dtype=np.complex64)
177-
for i in range(len(coeffs)):
178-
final_image[i] = self.fourier_ops[i].adj_op(coeffs[i])
179+
for i in range(len(data)):
180+
final_image[i] = self.fourier_ops[i].adj_op(data[i])
179181
return final_image.squeeze()
180182

181183
def __getattr__(self, attrName):
@@ -327,7 +329,7 @@ def _init_density(self, density):
327329

328330
def _init_operators(self, **kwargs):
329331
# initialize all the operators
330-
factory = get_operator("gpunufft")
332+
factory: SpaceFourierBase = get_operator("gpunufft")
331333
self.fourier_ops = [None] * self.n_frames
332334
for i, p_img, p_ksp in zip(
333335
range(self.n_frames),
@@ -346,13 +348,13 @@ def _init_operators(self, **kwargs):
346348
**kwargs,
347349
)
348350

349-
def op(self, images):
351+
def op(self, img):
350352
"""Apply the forward operator."""
351353
final_ksp = np.empty(
352-
(len(images), self.n_coils, self.n_samples), dtype=np.complex64
354+
(len(img), self.n_coils, self.n_samples), dtype=np.complex64
353355
)
354-
for i in range(len(images)):
355-
final_ksp[i] = self.fourier_ops[i].op(images[i])
356+
for i in range(len(img)):
357+
final_ksp[i] = self.fourier_ops[i].op(img[i])
356358
return final_ksp
357359

358360
def adj_op(self, coeffs):

src/fmri/operators/gradient.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1-
""""Gradient operators for MRI reconstruction.
1+
""" "Gradient operators for MRI reconstruction.
22
33
Adapted from pysap-mri and Modopt libraries.
44
"""
55

66
from functools import cached_property
77

88
import numpy as np
9-
import cupy as cp
9+
10+
CUPY_AVAILABLE = True
11+
try:
12+
import cupy as cp
13+
except ImportError:
14+
CUPY_AVAILABLE = False
15+
1016
from modopt.math.matrix import PowerMethod
1117
from modopt.opt.gradient import GradBasic, GradParent
1218
from modopt.base.backend import get_backend, get_array_module

src/fmri/operators/utils/fft.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import scipy as sp
44

55

6-
def fft(image, axis=-1):
6+
def fft(image, axis: int | tuple[int, ...] = -1):
77
"""Apply the FFT operator.
88
99
Parameters
@@ -24,7 +24,7 @@ def fft(image, axis=-1):
2424
)
2525

2626

27-
def ifft(kspace_data, axis=-1):
27+
def ifft(kspace_data, axis: int | tuple[int, ...] = -1):
2828
"""Apply the inverse FFT operator."""
2929
return sp.fft.fftshift(
3030
sp.fft.ifftn(sp.fft.ifftshift(kspace_data, axes=axis), norm="ortho", axes=axis),

0 commit comments

Comments
 (0)