-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathpySPIRALdemo.py
More file actions
332 lines (295 loc) · 14.5 KB
/
pySPIRALdemo.py
File metadata and controls
332 lines (295 loc) · 14.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
#-*-coding:utf-8-*-
# SPIRAL: Sparse Poisson Intensity Reconstruction Algorithms
# Demonstration Code Version 1.0
# Matlab version by Zachary T. Harmany (zth@duke.edu)
# Python version by Maxime Woringer, Apr. 2016
# Included here are demonstrations selected by changing the variable 'demo'.
# Three simulations (demo = 1, 2, or 3) were originally described, but only
# demos 1 and 2 are included in this file. Details of each can be found below.
# ==== Imports
from __future__ import print_function

import sys

# The Rice Wavelet Toolbox provides the wavelet transforms used by demo 2.
# Any failure while loading it is reported as one ImportError that carries
# an installation hint.
try:
    import rwt
except Exception:
    raise ImportError("the 'rwt' (Rice Wavelet Toolbox) package could not be loaded. It can be installed from https://github.com/ricedsp/rwt/")

import pySPIRALTAP
import numpy as np
import matplotlib.pyplot as plt
import scipy.io  # scipy.io.loadmat reads the .mat demo data files
from conv2 import conv2

# ==== Variables
# Which demonstration below to run.
demo = 2
# ==== Demo 1
if demo == 1:
    """
    # =============================================================================
    # = Demonstration 1 =
    # =============================================================================
    # Description: One dimensional compressed sensing example penalizing the
    # sparsity (l1 norm) of the coefficients in the canonical basis. Here the
    # true signal f is of length 100,000 with 1,500 nonzero entries yielding a
    # sparsity of 1.5%. We take 40,000 compressive measurements in y. The
    # average number of photons per measurement is 15.03, with a maximum of 145.
    # We run SPIRAL until the relative change in the iterates falls below
    # a tolerance of 1x10^-8, up to a maximum of 100 iterations (however only
    # 37 iterations are required to satisfy the stopping criterion).
    #
    # Output: This demonstration automatically displays the following:
    # Figure 1: Simulation setup (true signal, true detector intensity,
    # observed counts),
    # Figure 2: Reconstructed signal overlayed ontop of true signal,
    # Figure 3: RMSE error evolution versus iteration and compute time, and
    # Figure 4: Objective evolution versus iteration and compute time.
    """
    # ==== Load example data:
    # f = True signal
    # A = Sensing matrix
    # y ~ Poisson(Af)
    rf = scipy.io.loadmat('./demodata/canonicaldata.mat')
    f, y, Aorig = (rf['f'], rf['y'], rf['A'])  # A stored as a sparse matrix

    # ==== Setup function handles for computing A and A^T:
    AT = lambda x: Aorig.transpose().dot(x)
    A = lambda x: Aorig.dot(x)

    # ==== Set regularization parameters and iteration limit:
    tau = 1e-6
    maxiter = 100
    tolerance = 1e-8
    verbose = 10

    # ==== Simple initialization:
    # AT(y) rescaled to a least-squares fit to the mean intensity.
    # AT(y) is a full sensing-matrix product, so compute it only once.
    ATy = AT(y)
    finit = y.sum() * ATy.size / ATy.sum() / AT(np.ones_like(y)).sum() * ATy

    # ==== Run the algorithm, demonstrating all the options for our algorithm:
    # NOTE(review): the data are described as y ~ Poisson(Af), yet this call
    # requests noisetype='gaussian' -- confirm against pySPIRALTAP whether the
    # Poisson likelihood was intended.
    resSPIRAL = pySPIRALTAP.SPIRALTAP(y, A, tau,
                                      AT=AT,
                                      maxiter=maxiter,
                                      miniter=5,
                                      penalty='canonical',
                                      noisetype='gaussian',
                                      initialization=finit,
                                      stopcriterion=3,
                                      tolerance=tolerance,
                                      alphainit=1,
                                      alphamin=1e-30,
                                      alphamax=1e30,
                                      alphaaccept=1e30,
                                      logepsilon=1e-10,
                                      saveobjective=True,
                                      savereconerror=True,
                                      savesolutionpath=False,
                                      truth=f,
                                      verbose=verbose, savecputime=True)

    # ==== Deparse outputs: [0] is the estimate, [1] a dict of diagnostics.
    fhatSPIRAL = resSPIRAL[0]
    parSPIRAL = resSPIRAL[1]
    iterationsSPIRAL = parSPIRAL['iterations']
    objectiveSPIRAL = parSPIRAL['objective']
    reconerrorSPIRAL = parSPIRAL['reconerror']
    cputimeSPIRAL = parSPIRAL['cputime']

    # ==== Display Results:
    # Figure 1 -- problem data. A(f) is used twice, so compute it once.
    Af = A(f)
    plt.figure(1)
    plt.subplot(311)
    plt.plot(f)
    plt.title('True Signal (f), Nonzeros = {}, Mean Intensity = {}'.format((f != 0).sum(), f.mean()))
    plt.ylim((0, 1.24 * f.max()))
    plt.subplot(312)
    plt.plot(Af)
    plt.title('True Detector Intensity (Af), Mean Intensity = {}'.format(Af.mean()))
    plt.subplot(313)
    plt.plot(y)
    plt.title('Observed Photon Counts (y), Mean Count = {}'.format(y.mean()))

    # Figure 2 -- reconstructed signal (red) over the true signal (blue).
    plt.figure(2)
    plt.plot(f, color='blue')
    plt.plot(fhatSPIRAL, color='red')
    plt.xlabel('Sample number')
    plt.ylabel('Amplitude')
    plt.title('SPIRAL Estimate, RMS error = {}, Nonzero Components = {}'.format(np.linalg.norm(f - fhatSPIRAL) / np.linalg.norm(f), (fhatSPIRAL != 0).sum()))

    # Figure 3 -- RMS error versus iteration and versus CPU time.
    plt.figure(3)
    plt.subplot(211)
    plt.plot(range(iterationsSPIRAL), reconerrorSPIRAL, color='blue')
    plt.xlabel('Iteration')
    plt.ylabel('RMS Error')
    plt.title('RMS Error Evolution (Iteration)')
    plt.subplot(212)
    plt.plot(cputimeSPIRAL, reconerrorSPIRAL, color='blue')
    plt.xlabel('CPU Time')
    plt.ylabel('RMS Error')
    plt.title('RMS Error Evolution (CPU Time)')

    # Figure 4 -- objective versus iteration and versus CPU time.
    plt.figure(4)
    plt.subplot(211)
    plt.plot(range(iterationsSPIRAL), objectiveSPIRAL)
    plt.xlabel('Iteration')
    plt.ylabel('Objective')
    plt.title('Objective Evolution (Iteration)')
    plt.subplot(212)
    plt.plot(cputimeSPIRAL, objectiveSPIRAL)
    plt.xlabel('CPU Time')
    plt.ylabel('Objective')
    plt.title('Objective Evolution (CPU Time)')
    plt.show()
elif demo == 2:
"""
% =============================================================================
% = Demonstration 2 =
% =============================================================================
% Description: Here we consider an image deblurring example. The true signal
% is a 128x128 Shepp-Logan phantom image with mean intensity 1.22x10^5. The
% true detector mean intensity is 45.8, and the observed photon count mean
% is 45.8 with a maximum of 398. Here we consider four penalization methods
% - Sparsity (l1 norm) of coefficients in an orthonormal (wavelet) basis,
% - Total variation of the image,
% - Penalty based on Recursive Dyadic Partitions (RDPs), and
% - Penalty based on Translationally-Invariant (cycle-spun) RDPs.
% We run all the SPIRAL methods for a minimum of 50 iterations until the
% relative change in the iterates falls below a tolerance of 1x10^-8, up
% to a maximum of 100 iterations (however only ~70 iterations are required
% to satisfy the stopping criterion for all the methods).
%
% Output: This demonstration automatically displays the following:
% Figure 1: Simulation setup (true signal, true detector intensity,
% observed counts),
% Figure 2: The objective evolution for the methods where explicit
% computation of the objective is possible,
% Figure 3: RMSE error evolution versus iteration and compute time,
% Figure 4: The final reconstructions, and
% Figure 5: The the magnitude of the errors between the final
% reconstructions and the true phantom image.
"""
# ==== Load example data
# f = True signal
# blur = Blur PSF
# y ~ Poisson(Af)
rf=scipy.io.loadmat('./demodata/imagedata.mat')
f, blur, y = (np.float_(rf['f']), np.float_(rf['blur']), np.float_(rf['y']))
A = lambda x: conv2(x, blur, 'same')
AT = lambda x: conv2(x, blur, 'same')
Af = A(f)
# ==== Setup wavelet basis for l1-onb
wav = rwt.daubcqf(2)[0]
W = lambda x: rwt.idwt(x,wav)[0]
WT = lambda x: rwt.dwt(x,wav)[0]
# ==== Set regularization parameters and iteration limit:
tauonb = 1.0e-5
tautv = 3.0e-6
taurdp = 2.0e+0
taurdpti = 6.0e-1
miniter = 50
maxiter = 100
stopcriterion = 3
tolerance = 1e-8
verbose = 10
# ==== Simple initialization: AT(y) rescaled to have a least-squares fit to the mean value
finit = y.sum()*AT(y).size/AT(y).sum()/AT(np.ones_like(y)).sum() * AT(y)
# ==== Run the algorithm, demonstrating all the options for our algorithm:
resSPIRAL = pySPIRALTAP.SPIRALTAP(y, A, tauonb, penalty='onb', AT=AT, W=W, WT=WT,
maxiter=maxiter, initialisation=finit, miniter=miniter,
stopcriterion=stopcriterion, monotone=True,
saveobjective=True, savereconerror=True, savecputime=True,
savesolutionpath=False, truth=f, verbose=verbose)
## Deparse outputs
fhatSPIRALonb = resSPIRAL[0]
parSPIRAL = resSPIRAL[1]
iterationsSPIRALonb = parSPIRAL['iterations']
objectiveSPIRALonb = parSPIRAL['objective']
reconerrorSPIRALonb = parSPIRAL['reconerror']
cputimeSPIRALonb = parSPIRAL['cputime']
resSPIRAL = pySPIRALTAP.SPIRALTAP(y, A, tautv, penalty='tv', AT=AT,
maxiter=maxiter, initialisation=finit, miniter=miniter,
stopcriterion=stopcriterion, tolerance=tolerance,
monotone=True, saveobjective=True, savereconerror=True,
savecputime=True, savesolutionpath=False, truth=f,
verbose=verbose)
## Deparse outputs
fhatSPIRALtv = resSPIRAL[0]
parSPIRAL = resSPIRAL[1]
iterationsSPIRALtv = parSPIRAL['iterations']
objectiveSPIRALtv = parSPIRAL['objective']
reconerrorSPIRALtv = parSPIRAL['reconerror']
cputimeSPIRALtv = parSPIRAL['cputime']
# resSPIRAL = pySPIRALTAP.SPIRALTAP(y, A, taurdp, penalty='rdp', AT=AT, maxiter=maxiter,
# initialisation=finit, miniter=miniter,
# stopcriterion=stopcriterion, tolerance=tolerance,
# monotone=False, saveobjective=False, savereconerror=True,
# savecputime=True, savesolutionpath=False, truth=f,
# verbose=verbose)
# ## Deparse outputs
# fhatSPIRALrdp = resSPIRAL[0]
# parSPIRAL = resSPIRAL[1]
# iterationsSPIRALrdp = parSPIRAL['iterations']
# reconerrorSPIRALrdp = parSPIRAL['reconerror']
# cputimeSPIRALrdp = parSPIRAL['cputime']
# [fhatSPIRALrdpti, iterationsSPIRALrdpti, ...
# reconerrorSPIRALrdpti, cputimeSPIRALrdpti] ...
# resSPIRAL = pySPIRALTAP.SPIRALTAP(y, A, taurdpti, penalty='rdp-ti', maxiter=maxiter,
# initialization=finit, miniter=miniter, AT=AT,
# stopcriterion=stopcriterion, tolerance=tolerance,
# monotone=False, saveobjective=False, savereconerror=True,
# savecputime=True, savesolutionpath=False, truth=f,
# verbose=verbose)
# ## Deparse outputs
# fhatSPIRALrdpti = resSPIRAL[0]
# parSPIRAL = resSPIRAL[1]
# iterationsSPIRALrdpti = parSPIRAL['iterations']
# objectiveSPIRALtv = parSPIRAL['objective']
# reconerrorSPIRALrdpti = parSPIRAL['reconerror']
# cputimeSPIRALrdpti = parSPIRAL['cputime']
print("WARNING: RDP-based reconstruction are not implemented yet" , file=sys.stderr)
# ==== Display results
# Problem data
plt.figure()
plt.subplot(131); plt.imshow(f, cmap='gray'); plt.title('True Signal (f)')
plt.subplot(132); plt.imshow(Af, cmap='gray'); plt.title('True Detector Intensity (Af)')
plt.subplot(133); plt.imshow(y, cmap='gray'); plt.title('Observed Photon Counts (y)')
# Display Objectives for Monotonic Methods
plt.figure()
plt.subplot(121)
plt.plot(range(iterationsSPIRALonb), objectiveSPIRALonb,
label='ONB Objective Evolution (Iteration)')
plt.plot(range(iterationsSPIRALtv), objectiveSPIRALtv,
label='TV Objective Evolution (Iteration)')
plt.xlabel('Iteration');plt.ylabel('Objective');plt.legend()
plt.xlim((0, np.max((iterationsSPIRALonb, iterationsSPIRALtv))))
plt.subplot(122)
plt.plot(cputimeSPIRALonb, objectiveSPIRALonb, label='ONB Objective Evolution (CPU Time)')
plt.plot(cputimeSPIRALtv, objectiveSPIRALtv, label='TV Objective Evolution (CPU Time)')
plt.xlabel('CPU Time');plt.ylabel('Objective');plt.legend()
# Display RMS Error Evolution for All Methods
plt.subplot(121)
plt.plot(range(iterationsSPIRALonb), reconerrorSPIRALonb, label='ONB')
plt.plot(range(iterationsSPIRALtv), reconerrorSPIRALtv, label='TV')
#plt.plot(range(iterationsSPIRALrdp), reconerrorSPIRALrdp, label='RDP')
#plt.plot(range(iterationsSPIRALrdpti), reconerrorSPIRALrdpti, label='RDP-TI')
plt.title('Error Evolution (Iteration)');plt.xlabel('Iteration');plt.ylabel('RMS Error')
plt.subplot(122)
plt.plot(cputimeSPIRALonb, reconerrorSPIRALonb, label='ONB')
plt.plot(cputimeSPIRALtv, reconerrorSPIRALtv, label='TV')
#plt.plot(cputimeSPIRALrdp), reconerrorSPIRALrdp, label='RDP')
#plt.plot(cputimeSPIRALrdpti), reconerrorSPIRALrdpti, label='RDP-TI')
plt.title('Error Evolution (CPU Time)');plt.xlabel('CPU Time');plt.ylabel('RMS Error')
# Display Images for All Methods
plt.figure()
plt.subplot(121);plt.imshow(fhatSPIRALonb, cmap='gray')
plt.title("ONB, RMS={}".format(reconerrorSPIRALonb[-1]))
plt.subplot(122);plt.imshow(fhatSPIRALtv, cmap='gray')
plt.title("TV, RMS={}".format(reconerrorSPIRALtv[-1]))
#plt.subplot(223);plt.imshow(fhatSPIRALrdp, cmap='gray')
#plt.title("RDP, RMS=".format(reconerrorSPIRALrdp[-1]))
#plt.subplot(224);plt.imshow(fhatSPIRALrdpti, cmap='gray')
#plt.title("RDP-TI, RMS=".format(reconerrorSPIRALrdpti[-1]))
# Difference images
diffSPIRALonb = np.abs(f-fhatSPIRALonb)
diffSPIRALtv = np.abs(f-fhatSPIRALtv)
#diffSPIRALrdp = np.abs(f-fhatSPIRALrdp)
#diffSPIRALrdpti = np.abs(f-fhatSPIRALrdpti)
plt.figure()
plt.subplot(121);plt.imshow(diffSPIRALonb)
plt.title("ONB, RMS={}".format(reconerrorSPIRALonb[-1]))
plt.subplot(122);plt.imshow(diffSPIRALtv)
plt.title("TV, RMS={}".format(reconerrorSPIRALtv[-1]))
#plt.subplot(223);plt.imshow(diffSPIRALrdp)
#plt.title("RDP, RMS=".format(reconerrorSPIRALrdp[-1]))
#plt.subplot(224);plt.imshow(diffSPIRALrdpti)
#plt.title("RDP-TI, RMS=".format(reconerrorSPIRALrdpti[-1]))
plt.show()