Skip to content

Commit e7c7da4

Browse files
author
Holger Kohr
committed
TST: add tests for auto_weighting decorator
1 parent 3804665 commit e7c7da4

File tree

1 file changed

+171
-1
lines changed

1 file changed

+171
-1
lines changed

odl/test/space/space_utils_test.py

Lines changed: 171 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@
1313

1414
import odl
1515
from odl import vector
16-
from odl.util.testutils import all_equal
16+
from odl.space.space_utils import auto_weighting
17+
from odl.util.testutils import all_equal, simple_fixture, noise_element
18+
19+
20+
auto_weighting_optimize = simple_fixture('optimize', [True, False])
21+
call_variant = simple_fixture('call_variant', ['oop', 'ip', 'dual'])
22+
weighting = simple_fixture('weighting', [1.0, 2.0, [1.0, 2.0]])
1723

1824

1925
def test_vector_numpy():
@@ -77,5 +83,169 @@ def test_vector_numpy():
7783
vector([[1, 0], [0, 1]])
7884

7985

86+
def test_auto_weighting(call_variant, weighting, auto_weighting_optimize):
87+
"""Test the auto_weighting decorator for different adjoint variants."""
88+
rn = odl.rn(2)
89+
rn_w = odl.rn(2, weighting=weighting)
90+
91+
class ScalingOpBase(odl.Operator):
92+
93+
def __init__(self, dom, ran, c):
94+
super(ScalingOpBase, self).__init__(dom, ran, linear=True)
95+
self.c = c
96+
97+
if call_variant == 'oop':
98+
99+
class ScalingOp(ScalingOpBase):
100+
101+
def _call(self, x):
102+
return self.c * x
103+
104+
@property
105+
@auto_weighting(optimize=auto_weighting_optimize)
106+
def adjoint(self):
107+
return ScalingOp(self.range, self.domain, self.c)
108+
109+
elif call_variant == 'ip':
110+
111+
class ScalingOp(ScalingOpBase):
112+
113+
def _call(self, x, out):
114+
out[:] = self.c * x
115+
return out
116+
117+
@property
118+
@auto_weighting(optimize=auto_weighting_optimize)
119+
def adjoint(self):
120+
return ScalingOp(self.range, self.domain, self.c)
121+
122+
elif call_variant == 'dual':
123+
124+
class ScalingOp(ScalingOpBase):
125+
126+
def _call(self, x, out=None):
127+
if out is None:
128+
out = self.c * x
129+
else:
130+
out[:] = self.c * x
131+
return out
132+
133+
@property
134+
@auto_weighting(optimize=auto_weighting_optimize)
135+
def adjoint(self):
136+
return ScalingOp(self.range, self.domain, self.c)
137+
138+
else:
139+
assert False
140+
141+
op1 = ScalingOp(rn, rn_w, 1.5)
142+
op2 = ScalingOp(rn_w, rn, 1.5)
143+
144+
for op in [op1, op2]:
145+
dom_el = noise_element(op.domain)
146+
ran_el = noise_element(op.range)
147+
assert pytest.approx(op(dom_el).inner(ran_el),
148+
dom_el.inner(op.adjoint(ran_el)))
149+
150+
151+
def test_auto_weighting_noarg():
152+
"""Test the auto_weighting decorator without the optimize argument."""
153+
rn = odl.rn(2)
154+
rn_w = odl.rn(2, weighting=2)
155+
156+
class ScalingOp(odl.Operator):
157+
158+
def __init__(self, dom, ran, c):
159+
super(ScalingOp, self).__init__(dom, ran, linear=True)
160+
self.c = c
161+
162+
def _call(self, x):
163+
return self.c * x
164+
165+
@property
166+
@auto_weighting
167+
def adjoint(self):
168+
return ScalingOp(self.range, self.domain, self.c)
169+
170+
op1 = ScalingOp(rn, rn, 1.5)
171+
op2 = ScalingOp(rn_w, rn_w, 1.5)
172+
op3 = ScalingOp(rn, rn_w, 1.5)
173+
op4 = ScalingOp(rn_w, rn, 1.5)
174+
175+
for op in [op1, op2, op3, op4]:
176+
dom_el = noise_element(op.domain)
177+
ran_el = noise_element(op.range)
178+
assert pytest.approx(op(dom_el).inner(ran_el),
179+
dom_el.inner(op.adjoint(ran_el)))
180+
181+
182+
def test_auto_weighting_cached_adjoint():
183+
"""Check if auto_weighting plays well with adjoint caching."""
184+
rn = odl.rn(2)
185+
rn_w = odl.rn(2, weighting=2)
186+
187+
class ScalingOp(odl.Operator):
188+
189+
def __init__(self, dom, ran, c):
190+
super(ScalingOp, self).__init__(dom, ran, linear=True)
191+
self.c = c
192+
self._adjoint = None
193+
194+
def _call(self, x):
195+
return self.c * x
196+
197+
@property
198+
@auto_weighting
199+
def adjoint(self):
200+
if self._adjoint is None:
201+
self._adjoint = ScalingOp(self.range, self.domain, self.c)
202+
return self._adjoint
203+
204+
op = ScalingOp(rn, rn_w, 1.5)
205+
dom_el = noise_element(op.domain)
206+
op_eval_before = op(dom_el)
207+
208+
adj = op.adjoint
209+
adj_again = op.adjoint
210+
assert adj_again is adj
211+
212+
# Check that original op is intact
213+
assert not hasattr(op, '_call_unweighted') # op shouldn't be mutated
214+
op_eval_after = op(dom_el)
215+
assert all_equal(op_eval_before, op_eval_after)
216+
217+
dom_el = noise_element(op.domain)
218+
ran_el = noise_element(op.range)
219+
op(dom_el)
220+
op.adjoint(ran_el)
221+
assert pytest.approx(op(dom_el).inner(ran_el),
222+
dom_el.inner(op.adjoint(ran_el)))
223+
224+
225+
def test_auto_weighting_raise_on_return_self():
226+
"""Check that auto_weighting raises when adjoint returns self."""
227+
rn = odl.rn(2)
228+
229+
class InvalidScalingOp(odl.Operator):
230+
231+
def __init__(self, dom, ran, c):
232+
super(InvalidScalingOp, self).__init__(dom, ran, linear=True)
233+
self.c = c
234+
self._adjoint = None
235+
236+
def _call(self, x):
237+
return self.c * x
238+
239+
@property
240+
@auto_weighting
241+
def adjoint(self):
242+
return self
243+
244+
# This would be a vaild situation for adjont just returning self
245+
op = InvalidScalingOp(rn, rn, 1.5)
246+
with pytest.raises(TypeError):
247+
op.adjoint
248+
249+
80250
if __name__ == '__main__':
81251
odl.util.test_file(__file__)

0 commit comments

Comments
 (0)