112 changes: 112 additions & 0 deletions Wrappers/Python/cil/optimisation/algorithms/ProxSkip.py
@@ -0,0 +1,112 @@
from cil.optimisation.algorithms import Algorithm
import numpy as np
import logging
from warnings import warn


class ProxSkip(Algorithm):


r"""Proximal Skip (ProxSkip) algorithm, see "ProxSkip: Yes! Local Gradient Steps Provably Lead to Communication Acceleration! Finally!†"

Parameters
----------

initial : DataContainer
Initial point for the ProxSkip algorithm.
f : Function
A smooth function with Lipschitz continuous gradient.
g : Function
A convex function with a "simple" proximal.
prob : positive :obj:`float`
Probability of applying the proximal step at each iteration. If :code:`prob=1`, the proximal step is applied in every iteration.
step_size : positive :obj:`float`
Step size for the ProxSkip algorithm: :math:`1/L` for strongly convex :math:`f` and :math:`2/L` for convex :math:`f`, where :math:`L` is the Lipschitz constant of the gradient of :math:`f`.
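seed : :obj:`int`, optional, default=None
Seed for the random number generator used for the Bernoulli trials that decide whether the proximal step is taken.

Note
----
ProxSkip maintains a control variate :math:`h_k` so that the (possibly expensive) proximal step of :math:`g` can be skipped in most iterations while still converging to a minimiser of :math:`f + g`.

Example
-------
A minimal sketch on a least-squares problem with an L1 penalty, mirroring the unit tests in this PR (the operator ``A`` and data ``b`` are placeholders):

>>> from cil.optimisation.algorithms import ProxSkip
>>> from cil.optimisation.functions import LeastSquares, L1Norm
>>> f = LeastSquares(A, b=b, c=0.5)
>>> g = 0.5 * L1Norm()
>>> proxskip = ProxSkip(initial=A.domain.allocate(), f=f, g=g,
...                     step_size=1.99/f.L, prob=0.2, seed=10)
>>> proxskip.run(500, verbose=0)
>>> solution = proxskip.solution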

"""


def __init__(self, initial, f, g, step_size, prob, seed=None, **kwargs):
""" Set up of the algorithm
"""

super(ProxSkip, self).__init__(**kwargs)

self.f = f # smooth function
self.g = g # proximable
self.step_size = step_size
self.prob = prob
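# RNG for the Bernoulli coin flips; thetas records each flip so the skipping pattern can be inspected and tested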
self.rng = np.random.default_rng(seed=seed)
self.thetas = []

if self.prob<=0:
raise ValueError("Need a positive probability")
if self.prob == 1:
warn("If prob=1, ProxSkip is equivalent to ISTA/PGD. Please use ISTA/PGD to avoid computing updates of the control variate that are never used.")

self.set_up(initial, f, g, step_size, prob, **kwargs)


def set_up(self, initial, f, g, step_size, prob, **kwargs):

logging.info("{} setting up".format(self.__class__.__name__, ))

# TODO: allow different initial values for x and h.
self.initial = initial
self.x = initial.copy()
self.xhat_new = initial.copy()
self.x_new = initial.copy()
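# control variate h_t, initialised to the same value as x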
self.ht = initial.copy()

self.configured = True

logging.info("{} configured".format(self.__class__.__name__, ))


def update(self):
r""" Performs a single iteration of the ProxSkip algorithm
"""

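# gradient step with the control variate folded in: xhat_new = x - step_size*(grad f(x) - ht)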
self.f.gradient(self.x, out=self.xhat_new)
self.xhat_new -= self.ht
self.x.sapyb(1., self.xhat_new, -self.step_size, out=self.xhat_new)

theta = self.rng.random() < self.prob
# convention: use proximal in the first iteration
if self.iteration==0:
theta = True
self.thetas.append(theta)

if theta:
# Proximal step is used
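# x_new = prox_{(step_size/prob) g}( xhat_new - (step_size/prob) * ht )
# then refresh the control variate: ht += (prob/step_size) * (x_new - xhat_new)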
self.g.proximal(self.xhat_new - (self.step_size/self.prob)*self.ht, self.step_size/self.prob, out=self.x_new)
self.ht.sapyb(1., (self.x_new - self.xhat_new), (self.prob/self.step_size), out=self.ht)
else:
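# prox skipped: accept the gradient step; the control variate ht is unchanged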
self.x_new.fill(self.xhat_new)

def _update_previous_solution(self):
""" Swaps the references to current and previous solution based on the :func:`~Algorithm.update_previous_solution` of the base class :class:`Algorithm`.
"""
tmp = self.x_new
self.x = self.x_new
self.x = tmp

def get_output(self):
" Returns the current solution. "
return self.x


def update_objective(self):

""" Updates the objective

.. math:: f(x) + g(x)

"""

self.loss.append(self.f(self.x) + self.g(self.x))


3 changes: 2 additions & 1 deletion Wrappers/Python/cil/optimisation/algorithms/__init__.py
@@ -22,10 +22,11 @@
from .GD import GD
from .FISTA import FISTA
from .FISTA import ISTA
from .ProxSkip import ProxSkip
from .FISTA import ISTA as PGD
from .APGD import APGD
from .PDHG import PDHG
from .ADMM import LADMM
from .SPDHG import SPDHG
from .PD3O import PD3O
from .LSQR import LSQR
101 changes: 98 additions & 3 deletions Wrappers/Python/test/test_algorithms.py
@@ -38,7 +38,7 @@


from cil.optimisation.functions import MixedL21Norm, BlockFunction, L1Norm, KullbackLeibler, IndicatorBox, LeastSquares, ZeroFunction, L2NormSquared, OperatorCompositionFunction, TotalVariation, SGFunction, SVRGFunction, SAGAFunction, SAGFunction, LSVRGFunction, ScaledFunction
from cil.optimisation.algorithms import Algorithm, GD, CGLS, SIRT, FISTA, ISTA, SPDHG, PDHG, LADMM, PD3O, PGD, APGD , LSQR
from cil.optimisation.algorithms import Algorithm, GD, CGLS, SIRT, FISTA, ISTA, SPDHG, PDHG, LADMM, PD3O, PGD, APGD , LSQR, ProxSkip


from scipy.optimize import minimize, rosen
@@ -340,8 +340,7 @@ def test_provable_convergence(self):
with self.assertRaises(NotImplementedError):
alg.is_provably_convergent()






class TestFISTA(CCPiTestClass):
@@ -533,6 +532,102 @@ def get_step_size(self, algorithm):
self.assertEqual(alg.step_size, 0.99/2)


class TestProxSkip(CCPiTestClass):

def setUp(self):

np.random.seed(10)
n = 50
m = 500

A = np.random.uniform(0, 1, (m, n)).astype('float32')
b = (A.dot(np.random.randn(n)) + 0.1 *
np.random.randn(m)).astype('float32')

self.Aop = MatrixOperator(A)
self.bop = VectorData(b)

self.f = LeastSquares(self.Aop, b=self.bop, c=0.5)
self.g = 0.5 * L1Norm()
self.step_size = 1.99/self.f.L

self.ig = self.Aop.domain

self.initial = self.ig.allocate()

def tearDown(self):
pass

def test_signature(self):

# omitting a required argument (here prob) raises a TypeError
with np.testing.assert_raises(TypeError):
proxskip = ProxSkip(initial = self.initial, f=self.f, g=self.g, step_size=self.step_size)

# test neg prob
with np.testing.assert_raises(ValueError):
proxskip = ProxSkip(initial = self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=-0.1)

# zero prob
with np.testing.assert_raises(ValueError):
proxskip = ProxSkip(initial = self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=0.)

def test_coin_flip(self):

seed = 10
num_it = 100
prob = 0.2

proxskip1 = ProxSkip(initial=self.initial, f=self.f, g=self.g,
step_size=self.step_size, prob=prob, seed=seed)
proxskip1.run(num_it, verbose=0)

rng = np.random.default_rng(seed)

thetas1 = []
for k in range(num_it):
tmp = rng.random() < prob
theta = True if k == 0 else tmp
thetas1.append(theta)

assert np.array_equal(proxskip1.thetas, thetas1)


def test_seeds(self):

# same seeds
proxskip1 = ProxSkip(initial = self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=0.1, seed=10)
proxskip1.run(100, verbose=0)

proxskip2 = ProxSkip(initial = self.initial, f=self.f, g=self.g, step_size=self.step_size, prob=0.1, seed=10)
proxskip2.run(100, verbose=0)

np.testing.assert_allclose(proxskip2.thetas, proxskip1.thetas)

# different seeds
proxskip2 = ProxSkip(initial = self.initial, f=self.f, g=self.g, step_size=self.step_size,
prob=0.1, seed=20)
proxskip2.run(100, verbose=0)

assert not np.array_equal(proxskip2.thetas, proxskip1.thetas)


def test_ista_vs_proxskip(self):

prox = ProxSkip(initial=self.initial, f=self.f,
g=self.g, step_size=self.step_size, prob=0.1)
prox.run(2000, verbose=0)

ista = ISTA(initial=self.initial, f=self.f,
g=self.g, step_size=self.step_size)
ista.run(1000, verbose=0)

np.testing.assert_allclose(ista.objective[-1], prox.objective[-1], atol=1e-3)
np.testing.assert_allclose(
prox.solution.array, ista.solution.array, atol=1e-3)


class testISTA(CCPiTestClass):

def setUp(self):