Skip to content

Commit 89de84d

Browse files
committed
STY: mainly style updates in bregman functional and corresponding test.
1 parent bd4f408 commit 89de84d

File tree

2 files changed

+26
-24
lines changed

2 files changed

+26
-24
lines changed

odl/solvers/functional/default_functionals.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2476,11 +2476,11 @@ class BregmanDistance(Functional):
24762476
Notes
24772477
-----
24782478
Given a functional :math:`f`, which has a (sub)gradient :math:`\partial f`,
2479-
and given a point :math:`y`, the Bregman distance functional :math:`D_f`
2480-
in a point :math:`x` is given by
2479+
and given a point :math:`y`, the Bregman distance functional
2480+
:math:`D_f(\cdot, y)` in a point :math:`x` is given by
24812481
24822482
.. math::
2483-
D_f(x) = f(x) - f(y) - \langle \partial f(y), x - y \\rangle.
2483+
D_f(x, y) = f(x) - f(y) - \langle \partial f(y), x - y \\rangle.
24842484
24852485
24862486
References
@@ -2503,7 +2503,7 @@ def __init__(self, functional, point, subgradient_op=None):
25032503
optional argument `subgradient_op` is not given, the functional
25042504
needs to implement `functional.gradient`.
25052505
point : element of ``functional.domain``
2506-
The point from which to define the Bregman distance
2506+
The point from which to define the Bregman distance.
25072507
subgradient_op : `Operator`, optional
25082508
The operator that takes an element in `functional.domain` and
25092509
returns a subgradient of the functional in that point.
@@ -2513,15 +2513,16 @@ def __init__(self, functional, point, subgradient_op=None):
25132513
--------
25142514
Example of initializing the Bregman distance functional:
25152515
2516-
>>> space = odl.uniform_discr(0, 2, 14)
2516+
>>> space = odl.uniform_discr(0, 1, 10)
25172517
>>> l2_squared = odl.solvers.L2NormSquared(space)
25182518
>>> point = space.one()
25192519
>>> Bregman_dist = odl.solvers.BregmanDistance(l2_squared, point)
25202520
2521-
This is gives the shifted L2 norm squared ||x - 1||:
2521+
This gives the shifted L2 norm squared ||x - 1||^2:
25222522
2523-
>>> Bregman_dist(space.zero())
2524-
2.0
2523+
>>> expected_value = l2_squared(space.one())
2524+
>>> Bregman_dist(space.zero()) == expected_value
2525+
True
25252526
"""
25262527
if not isinstance(functional, Functional):
25272528
raise TypeError('`functional` {} not an instance of ``Functional``'
@@ -2545,16 +2546,16 @@ def __init__(self, functional, point, subgradient_op=None):
25452546
''.format(functional))
25462547
else:
25472548
# Check that given subgradient is an operator that maps from the
2548-
# domain of the functional to the domain of the functional
2549+
# domain of the functional to itself
25492550
if not isinstance(subgradient_op, Operator):
25502551
raise TypeError('`subgradient_op` {} is not an instance of '
25512552
'``Operator``'.format(subgradient_op))
2552-
if not self.__functional.domain == subgradient_op.domain:
2553+
if not subgradient_op.domain == self.__functional.domain:
25532554
raise ValueError('`functional.domain` {} is not the same as '
25542555
'`subgradient_op.domain` {}'
25552556
''.format(self.__functional.domain,
25562557
subgradient_op.domain))
2557-
if not self.__functional.domain == subgradient_op.range:
2558+
if not subgradient_op.range == self.__functional.domain:
25582559
raise ValueError('`functional.domain` {} is not the same as '
25592560
'`subgradient_op.range` {}'
25602561
''.format(self.__functional.domain,
@@ -2568,7 +2569,8 @@ def __init__(self, functional, point, subgradient_op=None):
25682569

25692570
super(BregmanDistance, self).__init__(
25702571
space=functional.domain, linear=False,
2571-
grad_lipschitz=self.__functional.grad_lipschitz)
2572+
grad_lipschitz=(self.__functional.grad_lipschitz +
2573+
self.__subgrad_eval.norm()))
25722574

25732575
@property
25742576
def functional(self):
@@ -2606,11 +2608,9 @@ def gradient(self):
26062608

26072609
def __repr__(self):
26082610
'''Return ``repr(self)``.'''
2609-
return '{}({!r}, {!r}, {!r}, {!r})'.format(self.__class__.__name__,
2610-
self.domain,
2611-
self.functional,
2612-
self.point,
2613-
self.subgradient_op)
2611+
return '{}({!r}, {!r}, {!r})'.format(self.__class__.__name__,
2612+
self.functional, self.point,
2613+
self.subgradient_op)
26142614

26152615

26162616
if __name__ == '__main__':

odl/test/solvers/functional/default_functionals_test.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -517,9 +517,12 @@ def test_moreau_envelope_l2_sq(space, sigma):
517517

518518

519519
def test_bregman_functional_no_gradient(space):
520-
"""Test that the Bregman distance functional fails if the underlying
520+
"""Test Bregman distance for functional without gradient.
521+
522+
Test that the Bregman distance functional fails if the underlying
521523
functional does not have a gradient and no subgradient operator is
522-
given."""
524+
given. Also test giving the subgradient operator separately.
525+
"""
523526

524527
ind_func = odl.solvers.IndicatorNonnegativity(space)
525528
point = noise_element(space)
@@ -542,14 +545,12 @@ def test_bregman_functional_no_gradient(space):
542545

543546

544547
def test_bregman_functional_l2_squared(space, sigma):
545-
"""Test for the Bregman distance functional, using l2 norm squared as
546-
underlying functional."""
548+
"""Test Bregman distance using l2 norm squared as underlying functional."""
547549
sigma = float(sigma)
548550

549551
l2_sq = odl.solvers.L2NormSquared(space)
550552
point = noise_element(space)
551-
subgrad_op = odl.ScalingOperator(space, 2.0)
552-
bregman_dist = odl.solvers.BregmanDistance(l2_sq, point, subgrad_op)
553+
bregman_dist = odl.solvers.BregmanDistance(l2_sq, point)
553554

554555
expected_func = odl.solvers.L2NormSquared(space).translated(point)
555556

@@ -559,7 +560,8 @@ def test_bregman_functional_l2_squared(space, sigma):
559560
assert all_almost_equal(bregman_dist(x), expected_func(x))
560561

561562
# Gradient evaluation
562-
assert all_almost_equal(bregman_dist(x), expected_func(x))
563+
assert all_almost_equal(bregman_dist.gradient(x),
564+
expected_func.gradient(x))
563565

564566
# Convex conjugate
565567
cc_bregman_dist = bregman_dist.convex_conj

0 commit comments

Comments
 (0)