
Commit 1757009

Merge pull request #553 from jeverink/regularizedGaussian_extension-ADMM
Add ADMM + parameter reordering FISTA
2 parents 8f8b008 + 3d0fc81 commit 1757009

File tree: 6 files changed, +374 -11 lines


cuqi/experimental/mcmc/_rto.py

Lines changed: 2 additions & 2 deletions
@@ -235,8 +235,8 @@ def prior(self):

     def step(self):
         y = self.b_tild + np.random.randn(len(self.b_tild))
-        sim = FISTA(self.M, y, self.current_point, self.proximal,
-                    maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
+        sim = FISTA(self.M, y, self.proximal,
+                    self.current_point, maxit = self.maxit, stepsize = self._stepsize, abstol = self.abstol, adaptive = self.adaptive)
         self.current_point, _ = sim.solve()
         acc = 1
         return acc

cuqi/sampler/_rto.py

Lines changed: 2 additions & 2 deletions
@@ -267,8 +267,8 @@ def _sample(self, N, Nb):
         samples[:, 0] = self.x0
         for s in range(Ns-1):
             y = self.b_tild + np.random.randn(len(self.b_tild))
-            sim = FISTA(self.M, y, samples[:, s], self.proximal,
-                        maxit = self.maxit, stepsize = _stepsize, abstol = self.abstol, adaptive = self.adaptive)
+            sim = FISTA(self.M, y, self.proximal,
+                        samples[:, s], maxit = self.maxit, stepsize = _stepsize, abstol = self.abstol, adaptive = self.adaptive)
             samples[:, s+1], _ = sim.solve()

             self._print_progress(s+2,Ns) #s+2 is the sample number, s+1 is index assuming x0 is the first sample

cuqi/solver/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
     LM,
     PDHG,
     FISTA,
+    ADMM,
     ProjectNonnegative,
     ProjectBox,
     ProximalL1

cuqi/solver/_solver.py

Lines changed: 169 additions & 4 deletions
@@ -584,8 +584,8 @@ class FISTA(object):
     ----------
     A : ndarray or callable f(x,*args).
     b : ndarray.
-    x0 : ndarray. Initial guess.
     proximal : callable f(x, gamma) for proximal mapping.
+    x0 : ndarray. Initial guess.
     maxit : The maximum number of iterations.
     stepsize : The stepsize of the gradient step.
     abstol : The numerical tolerance for convergence checks.
@@ -606,11 +606,11 @@ class FISTA(object):
         b = rng.standard_normal(m)
         stepsize = 0.99/(sp.linalg.interpolative.estimate_spectral_norm(A)**2)
         x0 = np.zeros(n)
-        fista = FISTA(A, b, x0, proximal = ProximalL1, stepsize = stepsize, maxit = 100, abstol=1e-12, adaptive = True)
+        fista = FISTA(A, b, ProximalL1, x0, stepsize = stepsize, maxit = 100, abstol=1e-12, adaptive = True)
         sol, _ = fista.solve()

     """
-    def __init__(self, A, b, x0, proximal, maxit=100, stepsize=1e0, abstol=1e-14, adaptive = True):
+    def __init__(self, A, b, proximal, x0, maxit=100, stepsize=1e0, abstol=1e-14, adaptive = True):

         self.A = A
         self.b = b
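Note that #553 reorders only positional call sites: `proximal` is now the third argument and `x0` the fourth. A minimal migration sketch (variable names illustrative, mirroring the docstring example above):

    from cuqi.solver import FISTA, ProximalL1
    import numpy as np

    rng = np.random.default_rng()
    A = rng.standard_normal((10, 5))
    b = rng.standard_normal(10)
    x0 = np.zeros(5)

    # pre-#553 order:  FISTA(A, b, x0, ProximalL1, ...)
    # post-#553 order: FISTA(A, b, ProximalL1, x0, ...)
    fista = FISTA(A, b, ProximalL1, x0, stepsize = 0.01, maxit = 100)
    sol, _ = fista.solve()

Call sites that pass both arguments by keyword (`proximal = ProximalL1, x0 = x0`) are unaffected by the reordering.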
@@ -650,8 +650,157 @@ def solve(self):
             x_new = x_new + ((k-1)/(k+2))*(x_new - x_old)

             x = x_new.copy()
+
+
+class ADMM(object):
+    """Alternating Direction Method of Multipliers for solving regularized linear least squares problems of the form:
+    minimize ||Ax - b||^2 + sum_{i=1}^{n} f_i(L_i x),
+    for an arbitrary number of penalty terms n. See the definition of the parameter `penalty_terms` below for more details about f_i and L_i.
+
+    Reference:
+    [1] Boyd et al. "Distributed optimization and statistical learning via the alternating direction method of multipliers." Foundations and Trends® in Machine Learning, 2011.
+
+    Parameters
+    ----------
+    A : ndarray or callable
+        Represents a matrix or a function that performs matrix-vector multiplications.
+        When A is a callable, it accepts arguments (x, flag) where:
+        - flag=1 indicates multiplication of A with vector x, that is A @ x.
+        - flag=2 indicates multiplication of the transpose of A with vector x, that is A.T @ x.
+    b : ndarray.
+    penalty_terms : List of tuples (callable proximal operator of f_i, linear operator L_i).
+        Each proximal operator accepts two arguments (x, p) and should return the minimizer of f_i(z) + (1/(2p))*||z - x||^2 over z.
+    x0 : ndarray. Initial guess.
+    penalty_parameter : Trade-off between the linear least squares term and the regularization terms in the solver iterates. Denoted "rho" in [1].
+    maxit : The maximum number of iterations.
+    inner_max_it : The maximum number of CGLS iterations for each least squares subproblem.
+    adaptive : Whether to adaptively update penalty_parameter each iteration such that the primal and dual residual norms are of the same order of magnitude. Based on [1], Subsection 3.4.1.
+
+    Example
+    -----------
+    .. code-block:: python
+
+        from cuqi.solver import ADMM, ProximalL1, ProjectNonnegative
+        import numpy as np
+
+        rng = np.random.default_rng()
+
+        m, n, k = 10, 5, 4
+        A = rng.standard_normal((m, n))
+        b = rng.standard_normal(m)
+        L = rng.standard_normal((k, n))
+
+        x0 = np.zeros(n)
+        admm = ADMM(A, b, penalty_terms = [(ProximalL1, L), (lambda z, _ : ProjectNonnegative(z), np.eye(n))], x0 = x0, penalty_parameter = 10)
+        sol, _ = admm.solve()
+
+    """
+
+    def __init__(self, A, b, penalty_terms, x0, penalty_parameter = 10, maxit = 100, inner_max_it = 10, adaptive = True):
+
+        self.A = A
+        self.b = b
+        self.x_cur = x0
+
+        dual_len = [penalty[1].shape[0] for penalty in penalty_terms]
+        self.z_cur = [np.zeros(l) for l in dual_len]
+        self.u_cur = [np.zeros(l) for l in dual_len]
+        self.n = penalty_terms[0][1].shape[1]
+
+        self.rho = penalty_parameter
+        self.maxit = maxit
+        self.inner_max_it = inner_max_it
+        self.adaptive = adaptive
+
+        self.penalty_terms = penalty_terms
+
+        self.p = len(self.penalty_terms)
+        self._big_matrix = None
+        self._big_vector = None
+
+    def solve(self):
+        """
+        Solves the regularized linear least squares problem using ADMM in scaled form. Based on [1], Subsection 3.1.1.
+        """
+        z_new = self.p*[0]
+        u_new = self.p*[0]
+
+        # Iterating
+        for i in range(self.maxit):
+            self._iteration_pre_processing()
+
+            # Main update (least squares)
+            solver = CGLS(self._big_matrix, self._big_vector, self.x_cur, self.inner_max_it)
+            x_new, _ = solver.solve()
+
+            # Regularization update
+            for j, penalty in enumerate(self.penalty_terms):
+                z_new[j] = penalty[0](penalty[1]@x_new + self.u_cur[j], 1.0/self.rho)
+
+            res_primal = 0.0
+            # Dual update
+            for j, penalty in enumerate(self.penalty_terms):
+                r_partial = penalty[1]@x_new - z_new[j]
+                res_primal += LA.norm(r_partial)**2
+
+                u_new[j] = self.u_cur[j] + r_partial
+
+            res_dual = 0.0
+            for j, penalty in enumerate(self.penalty_terms):
+                res_dual += LA.norm(penalty[1].T@(z_new[j] - self.z_cur[j]))**2
+
+            # Adaptive approach based on [1], Subsection 3.4.1
+            if self.adaptive:
+                if res_dual > 1e2*res_primal:
+                    self.rho *= 0.5 # More regularization
+                elif res_primal > 1e2*res_dual:
+                    self.rho *= 2.0 # More data fidelity
+
+            self.x_cur, self.z_cur, self.u_cur = x_new, z_new.copy(), u_new
+
+        return self.x_cur, i
+
+    def _iteration_pre_processing(self):
+        """ Preprocessing
+        Every iteration of ADMM requires solving a linear least squares problem of the form
+        minimize (1/rho)*\|Ax - b\|_2^2 + sum_{i=1}^{p} \|L_i x - (z_i - u_i)\|_2^2.
+        To solve this, all linear least squares terms are combined into a single term
+        with matrix big_matrix and data big_vector.
+
+        The matrix only needs to be updated when rho changes, i.e., when the adaptive option is used.
+        The data vector needs to be updated every iteration.
+        """
+
+        self._big_vector = np.hstack([np.sqrt(1/self.rho)*self.b] + [self.z_cur[i] - self.u_cur[i] for i in range(self.p)])
+
+        # Check whether the matrix needs to be updated
+        if self._big_matrix is not None and not self.adaptive:
+            return
+
+        # Update big_matrix
+        if callable(self.A):
+            def matrix_eval(x, flag):
+                if flag == 1:
+                    out1 = np.sqrt(1/self.rho)*self.A(x, 1)
+                    out2 = [penalty[1]@x for penalty in self.penalty_terms]
+                    out = np.hstack([out1] + out2)
+                elif flag == 2:
+                    idx_start = len(x)
+                    idx_end = len(x)
+                    out1 = np.zeros(self.n)
+                    for _, t in reversed(self.penalty_terms):
+                        idx_start -= t.shape[0]
+                        out1 += t.T@x[idx_start:idx_end]
+                        idx_end = idx_start
+                    out2 = np.sqrt(1/self.rho)*self.A(x[:idx_end], 2)
+                    out = out1 + out2
+                return out
+            self._big_matrix = matrix_eval
+        else:
+            self._big_matrix = np.vstack([np.sqrt(1/self.rho)*self.A] + [penalty[1] for penalty in self.penalty_terms])
+
+
 def ProjectNonnegative(x):
     """(Euclidean) projection onto the nonnegative orthant.
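The least squares ("main") update stacks the scaled data term and all penalty operators into a single system, which _iteration_pre_processing builds and CGLS solves. A minimal standalone sketch (NumPy only; names illustrative) checking that, for one penalty term, the stacked solve matches the normal equations of minimize (1/rho)||Ax - b||^2 + ||Lx - (z - u)||^2:

    import numpy as np

    rng = np.random.default_rng(0)
    m, n, k, rho = 10, 5, 4, 10.0
    A = rng.standard_normal((m, n)); b = rng.standard_normal(m)
    L = rng.standard_normal((k, n))
    z = rng.standard_normal(k); u = rng.standard_normal(k)

    # Stacked form, as built by _iteration_pre_processing
    big_matrix = np.vstack([np.sqrt(1/rho)*A, L])
    big_vector = np.hstack([np.sqrt(1/rho)*b, z - u])
    x_stacked, *_ = np.linalg.lstsq(big_matrix, big_vector, rcond=None)

    # Normal equations of the two-term objective
    lhs = (1/rho)*A.T@A + L.T@L
    rhs = (1/rho)*A.T@b + L.T@(z - u)
    x_normal = np.linalg.solve(lhs, rhs)

    assert np.allclose(x_stacked, x_normal)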
@@ -678,6 +827,22 @@ def ProjectBox(x, lower = None, upper = None):

     return np.minimum(np.maximum(x, lower), upper)

+def ProjectHalfspace(x, a, b):
+    """(Euclidean) projection onto the halfspace defined by {z | <a,z> <= b}.
+
+    Parameters
+    ----------
+    x : array_like.
+    a : array_like.
+    b : scalar.
+    """
+
+    ax_b = np.inner(a,x) - b
+    if ax_b <= 0:
+        return x
+    else:
+        return x - (ax_b/np.inner(a,a))*a
+
 def ProximalL1(x, gamma):
     """(Euclidean) proximal operator of the \|x\|_1 norm.
     Also known as the shrinkage or soft thresholding operator.
@@ -687,4 +852,4 @@ def ProximalL1(x, gamma):
     x : array_like.
     gamma : scale parameter.
     """
-    return np.multiply(np.sign(x), np.maximum(np.abs(x)-gamma, 0))
+    return np.multiply(np.sign(x), np.maximum(np.abs(x)-gamma, 0))
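A quick numeric sanity check of the two helpers (a sketch; ProjectHalfspace is imported from the internal _solver module, since this commit does not add it to the cuqi/solver/__init__.py exports):

    import numpy as np
    from cuqi.solver import ProximalL1
    from cuqi.solver._solver import ProjectHalfspace

    # ProjectHalfspace: points inside {z | <a,z> <= b} are returned unchanged,
    # points outside are projected onto the hyperplane <a,z> = b.
    a = np.array([1.0, 1.0])
    assert np.allclose(ProjectHalfspace(np.array([0.2, 0.3]), a, 1.0), [0.2, 0.3])
    p = ProjectHalfspace(np.array([2.0, 2.0]), a, 1.0)
    assert np.isclose(np.inner(a, p), 1.0)

    # ProximalL1: soft thresholding shrinks each entry toward zero by gamma.
    x = np.array([3.0, -0.5, 1.2])
    assert np.allclose(ProximalL1(x, 1.0), [2.0, 0.0, 0.2])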

test.ipynb

Lines changed: 153 additions & 0 deletions
@@ -0,0 +1,153 @@
New notebook, added in full (kernelspec "base"/python3, Python 3.11.9, nbformat 4). The JSON scaffolding is omitted below; the three executed code cells and their printed outputs are:

Cell 1: ADMM with A given as a matrix.

import numpy as np
import scipy as sp

from cuqi.solver import CGLS, LM, FISTA, ADMM, ProximalL1, ProjectNonnegative
from scipy.optimize import lsq_linear


def test_ADMM_matrix_form():
    # Parameters
    rng = np.random.default_rng(seed = 42)
    m, n = 10, 5
    A = rng.standard_normal((m, n))
    b = rng.standard_normal(m)

    k = 4
    L = rng.standard_normal((k, n))

    x0 = np.zeros(n)
    sol, _ = ADMM(A, b, [(ProximalL1, np.eye(n)), (lambda z, _ : ProjectNonnegative(z), L)],
                  x0, 10, maxit = 100, adaptive = True).solve()

    print(sol)
    ref_sol = np.array([-3.99513417e-03, -1.32339656e-01, -4.52822633e-02, -7.44973888e-02, -3.35005208e-11])
    # Compare
    assert np.allclose(sol, ref_sol, atol=1e-4)

test_ADMM_matrix_form()

Output:
[-3.99513417e-03 -1.32339656e-01 -4.52822633e-02 -7.44973888e-02
 -3.35005208e-11]

Cell 2: ADMM with A given as a callable.

def test_ADMM_function_form():
    # Parameters
    rng = np.random.default_rng(seed = 42)
    m, n = 10, 5
    A = rng.standard_normal((m, n))
    def A_fun(x, flag):
        if flag == 1:
            return A@x
        if flag == 2:
            return A.T@x

    b = rng.standard_normal(m)

    k = 4
    L = rng.standard_normal((k, n))

    x0 = np.zeros(n)
    sol, _ = ADMM(A_fun, b, [(ProximalL1, np.eye(n)), (lambda z, _ : ProjectNonnegative(z), L)],
                  x0, 10, maxit = 100, adaptive = True).solve()

    print(sol)
    ref_sol = np.array([-3.99513417e-03, -1.32339656e-01, -4.52822633e-02, -7.44973888e-02, -3.35005208e-11])
    # Compare
    assert np.allclose(sol, ref_sol, atol=1e-4)

test_ADMM_function_form()

Output:
[-3.99513417e-03 -1.32339656e-01 -4.52822633e-02 -7.44973888e-02
 -3.35005152e-11]

Cell 3: non-adaptive run.

# Parameters
rng = np.random.default_rng(seed = 42)
m, n = 10, 5
A = rng.standard_normal((m, n))
b = rng.standard_normal(m)

k = 4
L = rng.standard_normal((k, n))

x0 = np.zeros(n)
sol, _ = ADMM(A, b, [(ProximalL1, np.eye(n)), (lambda z, _ : ProjectNonnegative(z), L)],
              x0, 10, maxit = 100, adaptive = False).solve()

print(sol)

Output:
[-3.99513417e-03 -1.32339656e-01 -4.52822633e-02 -7.44973888e-02
 -3.35005208e-11]
