Skip to content

Commit c0c4692

Browse files
committed
update
1 parent c96af11 commit c0c4692

File tree

4 files changed

+81
-56
lines changed

4 files changed

+81
-56
lines changed

src/fmri/operators/gradient.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
Adapted from pysap-mri and Modopt libraries.
44
"""
55

6+
from functools import cached_property
7+
68
import numpy as np
9+
import cupy as cp
710
from modopt.math.matrix import PowerMethod
8-
from modopt.opt.gradient import GradBasic
9-
from modopt.base.backend import get_backend
11+
from modopt.opt.gradient import GradBasic, GradParent
12+
from modopt.base.backend import get_backend, get_array_module
1013

1114

1215
def check_lipschitz_cst(f, x_shape, x_dtype, lipschitz_cst, max_nb_of_iter=10):
@@ -224,3 +227,41 @@ def _op_method(self, data):
224227

225228
def _trans_op_method(self, data):
226229
return self.linear_op.op(self.fourier_op.adj_op(data))
230+
231+
232+
class CustomGradAnalysis(GradParent):
    """Gradient operator for the analysis formulation of fMRI reconstruction.

    Computes the data-consistency gradient via the Fourier operator's
    ``data_consistency`` method, keeping the observed k-space samples in a
    (possibly shared) GPU buffer.

    Parameters
    ----------
    fourier_op :
        Fourier operator exposing ``op``, ``data_consistency``,
        ``get_lipschitz_cst`` and ``shape``.
    obs_data : numpy.ndarray
        Observed k-space data, kept on host.
    obs_data_gpu : cupy.ndarray, optional
        Preallocated GPU buffer for the observed data. When ``None``, a
        fresh GPU copy of ``obs_data`` is allocated. Passing a shared
        buffer lets many gradient operators reuse one GPU allocation
        (see ``lazy``).
    lazy : bool
        If True, re-upload ``obs_data`` into ``obs_data_gpu`` at every
        ``get_grad`` call. Required when ``obs_data_gpu`` is shared
        between several operators, since another operator may have
        overwritten the buffer in between.

    Raises
    ------
    ValueError
        If ``obs_data_gpu`` is neither ``None`` nor a ``cupy.ndarray``.
    """

    def __init__(self, fourier_op, obs_data, obs_data_gpu=None, lazy=True):
        self.fourier_op = fourier_op
        self._grad_data_type = np.complex64
        # NOTE(review): stored under the private name while get_grad()/cost()
        # read ``self.obs_data`` — presumably GradParent exposes an
        # ``obs_data`` property backed by ``_obs_data`` (direct assignment
        # bypasses its dtype check). Confirm against modopt.
        self._obs_data = obs_data
        if obs_data_gpu is None:
            self.obs_data_gpu = cp.array(obs_data)
        elif isinstance(obs_data_gpu, cp.ndarray):
            self.obs_data_gpu = obs_data_gpu
        else:
            raise ValueError("Invalid data type for obs_data_gpu")
        self.lazy = lazy
        self.shape = fourier_op.shape

    def get_grad(self, x):
        """Compute, store and return the data-fidelity gradient at ``x``."""
        if self.lazy:
            # Shared GPU buffer: upload this operator's host data first,
            # as a sibling operator may have overwritten it.
            self.obs_data_gpu.set(self.obs_data)
        self.grad = self.fourier_op.data_consistency(x, self.obs_data_gpu)
        return self.grad

    @cached_property
    def spec_rad(self):
        """Spectral radius (Lipschitz constant) of the gradient operator.

        Computed once on first access and cached; ``get_lipschitz_cst``
        is typically an expensive power iteration.
        """
        return self.fourier_op.get_lipschitz_cst()

    @property
    def inv_spec_rad(self):
        """Inverse spectral radius, usable directly as a step size.

        BUG FIX: this was a plain method, yet callers pass
        ``g.inv_spec_rad`` (no call) as a numeric step size — matching
        modopt's attribute-style ``inv_spec_rad`` convention — which
        handed them a bound method instead of a float. Exposed as a
        property so attribute access yields the value.
        """
        return 1.0 / self.spec_rad

    def cost(self, x, *args, **kwargs):
        """Return the data-fidelity cost ``||F(x) - y||_2`` as a host scalar.

        Extra positional/keyword arguments are accepted for interface
        compatibility and ignored.
        """
        xp = get_array_module(x)
        cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data)
        if xp != np:
            # GPU array: move the scalar back to host before returning.
            return cost.get()
        return cost

src/fmri/operators/weighted.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,8 +430,6 @@ def _auto_thresh(self, input_data):
430430
weights = self._thresh_scale(weights, self._n_op_calls)
431431
else:
432432
weights *= self._thresh_scale
433-
xp = get_array_module(weights)
434-
logger.info(xp.unique(weights))
435433
return weights
436434

437435
def _op_method(self, input_data, extra_factor=1.0):

src/fmri/optimizer.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class AccProxSVRG(SetUp):
3131
def __init__(
3232
self,
3333
x,
34-
grad_list,
34+
fourier_op_list,
3535
prox,
3636
cost="auto",
3737
step_size=1.0,
@@ -79,7 +79,7 @@ def _update(self):
7979
self._v_tld = self.xp.zeros_like(self._v_tld)
8080
# Compute the average gradient.
8181
for g in self._grad_ops:
82-
self._v_tld += g.get_grad(self._x_old)
82+
self._v_tld += g.get_grad(self._x_tld)
8383
self._v_tld /= len(self._grad_ops)
8484

8585
self.xp.copyto(self._x_old, self._x_tld)
@@ -89,16 +89,17 @@ def _update(self):
8989
self.xp.copyto(self._v, self._v_tld)
9090
self._v *= self.batch_size
9191
for g in gIk:
92-
self._v += g.get_grad(self._x_tld)
93-
self._v -= g.get_grad(self._y)
92+
self._v -= g.get_grad(self._x_tld)
93+
self._v += g.get_grad(self._y)
9494
self._v *= self.step_size / self.batch_size
95-
self._x_new = self._y - self._v # Reuse the array
95+
self.xp.copyto(self._x_new, self._y)
96+
self._x_new -= self._v # Reuse the array
9697
self._x_new = self._prox.op(self._x_new, extra_factor=self.step_size)
97-
self._v = self._x_new - self._x_old # Reuse the array
98-
99-
self._y = self._x_new + self.beta * self._v
98+
self.xp.copyto(self._v, self._x_new)
99+
self._v -= self._x_old # Reuse the array
100+
self.xp.copyto(self._y, self._x_new)
101+
self._y += self.beta * self._v
100102
self.xp.copyto(self._x_old, self._x_new)
101-
102103
self.xp.copyto(self._x_tld, self._x_new)
103104

104105
# Test cost function for convergence.
@@ -184,14 +185,13 @@ def __init__(
184185
super().__init__(**kwargs)
185186

186187
# Set the initial variable values
187-
self._check_input_data(x)
188188

189189
self.step_size = step_size
190190

191191
self.update_frequency = update_frequency
192192
self.batch_size = batch_size
193193
self._grad_ops = grad_list
194-
self._prox_op = prox
194+
self._prox = prox
195195

196196
self._rng = np.random.default_rng(seed)
197197

@@ -213,9 +213,9 @@ def _update(self):
213213
self._g += g.get_grad(self._x)
214214
self._g /= len(self._grad_ops)
215215
self.xp.copyto(self._y, self._x)
216-
tk = self.rng.randint(1, self.update_frequency)
216+
tk = self._rng.integers(1, self.update_frequency)
217217
for _ in range(tk):
218-
Ak = self.rng.choices(self._grad_ops, k=self.batch_size)
218+
Ak = self._rng.choice(self._grad_ops, size=self.batch_size, replace=False)
219219
self.xp.copyto(self._g_sto, self._g)
220220
self._g_sto *= self.batch_size
221221
for g in Ak:

src/fmri/reconstructors/frame_based.py

Lines changed: 25 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
66
"""
77

8+
import cupy as cp
9+
import logging
10+
811
import gc
912
from functools import cached_property
1013

@@ -13,14 +16,17 @@
1316
import copy
1417
from tqdm.auto import tqdm, trange
1518

16-
from ..operators.gradient import GradAnalysis, GradSynthesis
19+
from ..operators.gradient import GradAnalysis, GradSynthesis, CustomGradAnalysis
1720
from .base import BaseFMRIReconstructor
1821
from .utils import OPTIMIZERS, initialize_opt
1922

2023
from modopt.opt.algorithms import POGM
2124
from modopt.opt.linear import Identity
25+
from modopt.opt.gradient import GradParent
2226
from ..optimizer import AccProxSVRG, MS2GD
2327

28+
logger = logging.getLogger("pysap-fmri")
29+
2430

2531
class SequentialReconstructor(BaseFMRIReconstructor):
2632
"""Sequential Reconstruction of fMRI data.
@@ -107,6 +113,8 @@ def reconstruct(
107113
final_estimate[i, ...] = x_iter
108114
# Progressbar update
109115
progbar.close()
116+
117+
logger.info("final prox weight: %f ", xp.unique(self.space_prox_op.weights))
110118
return final_estimate
111119

112120
def _reconstruct_frame(
@@ -219,34 +227,6 @@ def reconstruct(
219227
return final_estimate
220228

221229

222-
class CustomGradAnalysis:
223-
"""Custom Gradient Analysis Operator."""
224-
225-
def __init__(self, fourier_op, obs_data):
226-
self.fourier_op = fourier_op
227-
self.obs_data = obs_data
228-
self.shape = fourier_op.shape
229-
230-
def get_grad(self, x):
231-
"""Get the gradient value"""
232-
self.grad = self.fourier_op.data_consistency(x, self.obs_data)
233-
return self.grad
234-
235-
@cached_property
236-
def spec_rad(self):
237-
return self.fourier_op.get_lipschitz_cst()
238-
239-
def inv_spec_rad(self):
240-
return 1.0 / self.spec_rad
241-
242-
def cost(self, x, *args, **kwargs):
243-
xp = get_array_module(x)
244-
cost = xp.linalg.norm(self.fourier_op.op(x) - self.obs_data)
245-
if xp != np:
246-
return cost.get()
247-
return cost
248-
249-
250230
class StochasticSequentialReconstructor(BaseFMRIReconstructor):
251231
"""Stochastic Sequential Reconstruction of fMRI data."""
252232

@@ -255,12 +235,18 @@ def __init__(
255235
fourier_op,
256236
space_linear_op,
257237
space_prox_op,
238+
space_prox_op_refine=None,
258239
progbar_disable=False,
259240
compute_backend="numpy",
260241
**kwargs,
261242
):
262243
super().__init__(fourier_op, space_linear_op, space_prox_op, **kwargs)
263244

245+
if space_prox_op_refine is None:
246+
self.space_prox_op_refine = space_prox_op
247+
else:
248+
self.space_prox_op_refine = space_prox_op_refine
249+
264250
self.progbar_disable = progbar_disable
265251
self.compute_backend = compute_backend
266252

@@ -269,6 +255,7 @@ def reconstruct(
269255
kspace_data,
270256
x_init=None,
271257
max_iter_per_frame=15,
258+
max_iter_stochastic=20,
272259
grad_kwargs=None,
273260
algorithm="accproxsvrg",
274261
progbar_disable=False,
@@ -283,6 +270,7 @@ def reconstruct(
283270
xp, _ = get_backend(self.compute_backend)
284271
# Create the gradients operators
285272
grad_list = []
273+
tmp_ksp = cp.zeros_like(kspace_data[0])
286274
for i, fop in enumerate(self.fourier_op.fourier_ops):
287275
# L = fop.get_lipschitz_cst()
288276

@@ -296,7 +284,7 @@ def reconstruct(
296284
# input_data_writeable=True,
297285
# )
298286
# g._obs_data = kspace_data[i, ...]
299-
g = CustomGradAnalysis(fop, kspace_data[i, ...])
287+
g = CustomGradAnalysis(fop, kspace_data[i, ...], obs_data_gpu=tmp_ksp)
300288
grad_list.append(g)
301289

302290
max_lip = max(g.spec_rad for g in grad_list)
@@ -307,10 +295,9 @@ def reconstruct(
307295
x=xp.zeros(grad_list[0].shape, dtype="complex64"),
308296
grad_list=grad_list,
309297
prox=self.space_prox_op,
310-
step_size=1.0 / max_lip,
298+
step_size=1.0 / 2 * max_lip,
311299
auto_iterate=False,
312300
cost=None,
313-
update_frequency=10,
314301
compute_backend=self.compute_backend,
315302
**algorithm_kwargs,
316303
)
@@ -323,12 +310,11 @@ def reconstruct(
323310
prox=self.space_prox_op,
324311
step_size=1.0 / max_lip,
325312
auto_iterate=False,
326-
update_frequency=10,
327313
cost=None,
328314
**algorithm_kwargs,
329315
)
330316

331-
opt.iterate(max_iter=20)
317+
opt.iterate(max_iter=max_iter_stochastic)
332318

333319
x_anat = opt.x_final.squeeze()
334320

@@ -348,7 +334,7 @@ def reconstruct(
348334
x_anat,
349335
x_anat,
350336
grad=grad_list[i],
351-
prox=self.space_prox_op,
337+
prox=self.space_prox_op_refine,
352338
linear=Identity(),
353339
beta=grad_list[i].inv_spec_rad,
354340
compute_backend=self.compute_backend,
@@ -360,9 +346,9 @@ def reconstruct(
360346
progbar.reset(total=max_iter_per_frame)
361347
img = opt.x_final
362348

363-
if self.compute_backend == "cupy":
364-
final_img[i] = img.get()
365-
else:
366-
final_img[i] = img
349+
if self.compute_backend == "cupy":
350+
final_img[i] = img.get().squeeze()
351+
else:
352+
final_img[i] = img
367353

368354
return final_img, x_anat

0 commit comments

Comments
 (0)