77import itertools
88from abc import ABC , abstractmethod
99from collections .abc import Sequence
10+ from tracemalloc import BaseFilter
1011
1112import numpy as np
12- from mrinufft import get_operator
1313from modopt .base .backend import get_array_module
14+ from mrinufft import get_operator
15+ from numpy .typing import NDArray
1416
15- try :
16- from mrinufft .operators .interfaces .gpunufft import make_pinned_smaps
17- except ImportError :
18- make_pinned_smaps = None
17+ from mrinufft .operators .interfaces .gpunufft import make_pinned_smaps
1918
2019from .utils .fft import fft , ifft
2120
22- MRINUFFT_AVAILABLE = True
2321CUPY_AVAILABLE = True
2422
2523try :
@@ -51,15 +49,19 @@ def __init__(self):
5149 self .smaps = None
5250
5351 @abstractmethod
54- def op (self , img ) :
52+ def op (self , img : NDArray ) -> NDArray :
5553 """Forward operator."""
5654 pass
5755
5856 @abstractmethod
59- def adj_op (self , data ) :
57+ def adj_op (self , data : NDArray ) -> NDArray :
6058 """Adjoint operator."""
6159 pass
6260
61+ def data_consistency (self , data , obs_data ):
62+ """Data consistency operation"""
63+ return self .adj_op (self .op (data ) - obs_data )
64+
6365
6466class CartesianSpaceFourier (SpaceFourierBase ):
6567 """A Fourier Operator in space."""
@@ -131,7 +133,7 @@ def op(self, img):
131133 ksp = fft (img , axis = axes )
132134 return ksp * self .mask
133135
134- def adj_op (self , kspace_data ):
136+ def adj_op (self , data ):
135137 """Apply the adjoint operator.
136138
137139 Parameters
@@ -146,12 +148,12 @@ def adj_op(self, kspace_data):
146148 """
147149 axes = tuple (range (- len (self .shape ), 0 ))
148150 if self .n_coils > 1 :
149- img = ifft (kspace_data , axis = axes )
151+ img = ifft (data , axis = axes )
150152 if self .smaps is None :
151153 return img
152154 return np .sum (img * np .conj (self .smaps ), axis = 1 )
153155 else :
154- return ifft (kspace_data , axis = axes )
156+ return ifft (data , axis = axes )
155157
156158
157159class RepeatOperator (SpaceFourierBase ):
@@ -160,22 +162,22 @@ class RepeatOperator(SpaceFourierBase):
160162 def __init__ (self , fourier_ops ):
161163 self .fourier_ops = list (fourier_ops )
162164
163- def op (self , images ):
165+ def op (self , img ):
164166 """Apply the forward operator."""
165167 final_ksp = np .empty (
166- (len (images ), self .n_coils , self .n_samples ), dtype = np .complex64
168+ (len (img ), self .n_coils , self .n_samples ), dtype = np .complex64
167169 )
168- for i in range (len (images )):
169- final_ksp [i ] = self .fourier_ops [i ].op (images [i ])
170+ for i in range (len (img )):
171+ final_ksp [i ] = self .fourier_ops [i ].op (img [i ])
170172 return final_ksp
171173
172- def adj_op (self , coeffs ):
174+ def adj_op (self , data ):
173175 """Apply Adjoint Operator."""
174176 c = 1 if self .uses_sense else self .n_coils
175- xp = get_array_module (coeffs )
177+ xp = get_array_module (data )
176178 final_image = xp .empty ((self .n_frames , c , * self .shape ), dtype = np .complex64 )
177- for i in range (len (coeffs )):
178- final_image [i ] = self .fourier_ops [i ].adj_op (coeffs [i ])
179+ for i in range (len (data )):
180+ final_image [i ] = self .fourier_ops [i ].adj_op (data [i ])
179181 return final_image .squeeze ()
180182
181183 def __getattr__ (self , attrName ):
@@ -327,7 +329,7 @@ def _init_density(self, density):
327329
328330 def _init_operators (self , ** kwargs ):
329331 # initialize all the operators
330- factory = get_operator ("gpunufft" )
332+ factory : SpaceFourierBase = get_operator ("gpunufft" )
331333 self .fourier_ops = [None ] * self .n_frames
332334 for i , p_img , p_ksp in zip (
333335 range (self .n_frames ),
@@ -346,13 +348,13 @@ def _init_operators(self, **kwargs):
346348 ** kwargs ,
347349 )
348350
349- def op (self , images ):
351+ def op (self , img ):
350352 """Apply the forward operator."""
351353 final_ksp = np .empty (
352- (len (images ), self .n_coils , self .n_samples ), dtype = np .complex64
354+ (len (img ), self .n_coils , self .n_samples ), dtype = np .complex64
353355 )
354- for i in range (len (images )):
355- final_ksp [i ] = self .fourier_ops [i ].op (images [i ])
356+ for i in range (len (img )):
357+ final_ksp [i ] = self .fourier_ops [i ].op (img [i ])
356358 return final_ksp
357359
358360 def adj_op (self , coeffs ):
0 commit comments