Skip to content

Commit bd4f408

Browse files
committed
ENH/TST/STY: mainly change gradient in Bregman distance.
If a subgradient opertor is given expicitly, this one is used if the user calls bregman_dist.gradient(x). Also minor style fix in bregman distance functional.
1 parent 64fe3c6 commit bd4f408

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

odl/solvers/functional/default_functionals.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2551,14 +2551,14 @@ def __init__(self, functional, point, subgradient_op=None):
25512551
'``Operator``'.format(subgradient_op))
25522552
if not self.__functional.domain == subgradient_op.domain:
25532553
raise ValueError('`functional.domain` {} is not the same as '
2554-
'`subgradient_op.domain` {}'.format(
2555-
self.__functional.domain,
2556-
subgradient_op.domain))
2554+
'`subgradient_op.domain` {}'
2555+
''.format(self.__functional.domain,
2556+
subgradient_op.domain))
25572557
if not self.__functional.domain == subgradient_op.range:
25582558
raise ValueError('`functional.domain` {} is not the same as '
2559-
'`subgradient_op.range` {}'.format(
2560-
self.__functional.domain,
2561-
subgradient_op.range))
2559+
'`subgradient_op.range` {}'
2560+
''.format(self.__functional.domain,
2561+
subgradient_op.range))
25622562
self.__subgradient_op = subgradient_op
25632563

25642564
self.__subgrad_eval = self.__subgradient_op(self.__point)
@@ -2602,7 +2602,7 @@ def proximal(self):
26022602
@property
26032603
def gradient(self):
26042604
"""Gradient operator of the functional."""
2605-
return self.__bregman_dist.gradient
2605+
return self.subgradient_op - ConstantOperator(self.__subgrad_eval)
26062606

26072607
def __repr__(self):
26082608
'''Return ``repr(self)``.'''

odl/test/solvers/functional/default_functionals_test.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -516,12 +516,11 @@ def test_moreau_envelope_l2_sq(space, sigma):
516516
x * 2 / (1 + 2 * sigma))
517517

518518

519-
def test_bregman_functional_no_gradient():
519+
def test_bregman_functional_no_gradient(space):
520520
"""Test that the Bregman distance functional fails if the underlying
521521
functional does not have a gradient and no subgradient operator is
522522
given."""
523523

524-
space = odl.uniform_discr(0, 1, 3)
525524
ind_func = odl.solvers.IndicatorNonnegativity(space)
526525
point = noise_element(space)
527526

@@ -533,7 +532,13 @@ def test_bregman_functional_no_gradient():
533532
# If a subgradient operator is given separately, it is possible to create
534533
# an instance of the functional
535534
subgrad_op = odl.IdentityOperator(space)
536-
odl.solvers.BregmanDistance(ind_func, point, subgrad_op)
535+
bregman_dist = odl.solvers.BregmanDistance(ind_func, point, subgrad_op)
536+
537+
# In this case we should be able to call the gradient of the bregman
538+
# distance, which would give us a subgradient
539+
x = np.abs(noise_element(space))
540+
expected_result = subgrad_op(x) - subgrad_op(point)
541+
assert all_almost_equal(bregman_dist.gradient(x), expected_result)
537542

538543

539544
def test_bregman_functional_l2_squared(space, sigma):

0 commit comments

Comments
 (0)