55
66"""
77
8+ import cupy as cp
9+ import logging
10+
811import gc
912from functools import cached_property
1013
1316import copy
1417from tqdm .auto import tqdm , trange
1518
16- from ..operators .gradient import GradAnalysis , GradSynthesis
19+ from ..operators .gradient import GradAnalysis , GradSynthesis , CustomGradAnalysis
1720from .base import BaseFMRIReconstructor
1821from .utils import OPTIMIZERS , initialize_opt
1922
2023from modopt .opt .algorithms import POGM
2124from modopt .opt .linear import Identity
25+ from modopt .opt .gradient import GradParent
2226from ..optimizer import AccProxSVRG , MS2GD
2327
28+ logger = logging .getLogger ("pysap-fmri" )
29+
2430
2531class SequentialReconstructor (BaseFMRIReconstructor ):
2632 """Sequential Reconstruction of fMRI data.
@@ -107,6 +113,8 @@ def reconstruct(
107113 final_estimate [i , ...] = x_iter
108114 # Progressbar update
109115 progbar .close ()
116+
117+ logger .info ("final prox weight: %f " , xp .unique (self .space_prox_op .weights ))
110118 return final_estimate
111119
112120 def _reconstruct_frame (
@@ -219,34 +227,6 @@ def reconstruct(
219227 return final_estimate
220228
221229
222- class CustomGradAnalysis :
223- """Custom Gradient Analysis Operator."""
224-
225- def __init__ (self , fourier_op , obs_data ):
226- self .fourier_op = fourier_op
227- self .obs_data = obs_data
228- self .shape = fourier_op .shape
229-
230- def get_grad (self , x ):
231- """Get the gradient value"""
232- self .grad = self .fourier_op .data_consistency (x , self .obs_data )
233- return self .grad
234-
235- @cached_property
236- def spec_rad (self ):
237- return self .fourier_op .get_lipschitz_cst ()
238-
239- def inv_spec_rad (self ):
240- return 1.0 / self .spec_rad
241-
242- def cost (self , x , * args , ** kwargs ):
243- xp = get_array_module (x )
244- cost = xp .linalg .norm (self .fourier_op .op (x ) - self .obs_data )
245- if xp != np :
246- return cost .get ()
247- return cost
248-
249-
250230class StochasticSequentialReconstructor (BaseFMRIReconstructor ):
251231 """Stochastic Sequential Reconstruction of fMRI data."""
252232
@@ -255,12 +235,18 @@ def __init__(
255235 fourier_op ,
256236 space_linear_op ,
257237 space_prox_op ,
238+ space_prox_op_refine = None ,
258239 progbar_disable = False ,
259240 compute_backend = "numpy" ,
260241 ** kwargs ,
261242 ):
262243 super ().__init__ (fourier_op , space_linear_op , space_prox_op , ** kwargs )
263244
245+ if space_prox_op_refine is None :
246+ self .space_prox_op_refine = space_prox_op
247+ else :
248+ self .space_prox_op_refine = space_prox_op_refine
249+
264250 self .progbar_disable = progbar_disable
265251 self .compute_backend = compute_backend
266252
@@ -269,6 +255,7 @@ def reconstruct(
269255 kspace_data ,
270256 x_init = None ,
271257 max_iter_per_frame = 15 ,
258+ max_iter_stochastic = 20 ,
272259 grad_kwargs = None ,
273260 algorithm = "accproxsvrg" ,
274261 progbar_disable = False ,
@@ -283,6 +270,7 @@ def reconstruct(
283270 xp , _ = get_backend (self .compute_backend )
284271 # Create the gradients operators
285272 grad_list = []
273+ tmp_ksp = cp .zeros_like (kspace_data [0 ])
286274 for i , fop in enumerate (self .fourier_op .fourier_ops ):
287275 # L = fop.get_lipschitz_cst()
288276
@@ -296,7 +284,7 @@ def reconstruct(
296284 # input_data_writeable=True,
297285 # )
298286 # g._obs_data = kspace_data[i, ...]
299- g = CustomGradAnalysis (fop , kspace_data [i , ...])
287+ g = CustomGradAnalysis (fop , kspace_data [i , ...], obs_data_gpu = tmp_ksp )
300288 grad_list .append (g )
301289
302290 max_lip = max (g .spec_rad for g in grad_list )
@@ -307,10 +295,9 @@ def reconstruct(
307295 x = xp .zeros (grad_list [0 ].shape , dtype = "complex64" ),
308296 grad_list = grad_list ,
309297 prox = self .space_prox_op ,
310- step_size = 1.0 / max_lip ,
298+ step_size = 1.0 / 2 * max_lip ,
311299 auto_iterate = False ,
312300 cost = None ,
313- update_frequency = 10 ,
314301 compute_backend = self .compute_backend ,
315302 ** algorithm_kwargs ,
316303 )
@@ -323,12 +310,11 @@ def reconstruct(
323310 prox = self .space_prox_op ,
324311 step_size = 1.0 / max_lip ,
325312 auto_iterate = False ,
326- update_frequency = 10 ,
327313 cost = None ,
328314 ** algorithm_kwargs ,
329315 )
330316
331- opt .iterate (max_iter = 20 )
317+ opt .iterate (max_iter = max_iter_stochastic )
332318
333319 x_anat = opt .x_final .squeeze ()
334320
@@ -348,7 +334,7 @@ def reconstruct(
348334 x_anat ,
349335 x_anat ,
350336 grad = grad_list [i ],
351- prox = self .space_prox_op ,
337+ prox = self .space_prox_op_refine ,
352338 linear = Identity (),
353339 beta = grad_list [i ].inv_spec_rad ,
354340 compute_backend = self .compute_backend ,
@@ -360,9 +346,9 @@ def reconstruct(
360346 progbar .reset (total = max_iter_per_frame )
361347 img = opt .x_final
362348
363- if self .compute_backend == "cupy" :
364- final_img [i ] = img .get ()
365- else :
366- final_img [i ] = img
349+ if self .compute_backend == "cupy" :
350+ final_img [i ] = img .get (). squeeze ()
351+ else :
352+ final_img [i ] = img
367353
368354 return final_img , x_anat
0 commit comments