|
13 | 13 |
|
14 | 14 | import odl |
15 | 15 | from odl import vector |
16 | | -from odl.util.testutils import all_equal |
| 16 | +from odl.space.space_utils import auto_weighting |
| 17 | +from odl.util.testutils import all_equal, simple_fixture, noise_element |
| 18 | + |
| 19 | + |
| 20 | +auto_weighting_optimize = simple_fixture('optimize', [True, False]) |
| 21 | +call_variant = simple_fixture('call_variant', ['oop', 'ip', 'dual']) |
| 22 | +weighting = simple_fixture('weighting', [1.0, 2.0, [1.0, 2.0]]) |
17 | 23 |
|
18 | 24 |
|
19 | 25 | def test_vector_numpy(): |
@@ -77,5 +83,169 @@ def test_vector_numpy(): |
77 | 83 | vector([[1, 0], [0, 1]]) |
78 | 84 |
|
79 | 85 |
|
| 86 | +def test_auto_weighting(call_variant, weighting, auto_weighting_optimize): |
| 87 | + """Test the auto_weighting decorator for different adjoint variants.""" |
| 88 | + rn = odl.rn(2) |
| 89 | + rn_w = odl.rn(2, weighting=weighting) |
| 90 | + |
| 91 | + class ScalingOpBase(odl.Operator): |
| 92 | + |
| 93 | + def __init__(self, dom, ran, c): |
| 94 | + super(ScalingOpBase, self).__init__(dom, ran, linear=True) |
| 95 | + self.c = c |
| 96 | + |
| 97 | + if call_variant == 'oop': |
| 98 | + |
| 99 | + class ScalingOp(ScalingOpBase): |
| 100 | + |
| 101 | + def _call(self, x): |
| 102 | + return self.c * x |
| 103 | + |
| 104 | + @property |
| 105 | + @auto_weighting(optimize=auto_weighting_optimize) |
| 106 | + def adjoint(self): |
| 107 | + return ScalingOp(self.range, self.domain, self.c) |
| 108 | + |
| 109 | + elif call_variant == 'ip': |
| 110 | + |
| 111 | + class ScalingOp(ScalingOpBase): |
| 112 | + |
| 113 | + def _call(self, x, out): |
| 114 | + out[:] = self.c * x |
| 115 | + return out |
| 116 | + |
| 117 | + @property |
| 118 | + @auto_weighting(optimize=auto_weighting_optimize) |
| 119 | + def adjoint(self): |
| 120 | + return ScalingOp(self.range, self.domain, self.c) |
| 121 | + |
| 122 | + elif call_variant == 'dual': |
| 123 | + |
| 124 | + class ScalingOp(ScalingOpBase): |
| 125 | + |
| 126 | + def _call(self, x, out=None): |
| 127 | + if out is None: |
| 128 | + out = self.c * x |
| 129 | + else: |
| 130 | + out[:] = self.c * x |
| 131 | + return out |
| 132 | + |
| 133 | + @property |
| 134 | + @auto_weighting(optimize=auto_weighting_optimize) |
| 135 | + def adjoint(self): |
| 136 | + return ScalingOp(self.range, self.domain, self.c) |
| 137 | + |
| 138 | + else: |
| 139 | + assert False |
| 140 | + |
| 141 | + op1 = ScalingOp(rn, rn_w, 1.5) |
| 142 | + op2 = ScalingOp(rn_w, rn, 1.5) |
| 143 | + |
| 144 | + for op in [op1, op2]: |
| 145 | + dom_el = noise_element(op.domain) |
| 146 | + ran_el = noise_element(op.range) |
| 147 | + assert pytest.approx(op(dom_el).inner(ran_el), |
| 148 | + dom_el.inner(op.adjoint(ran_el))) |
| 149 | + |
| 150 | + |
| 151 | +def test_auto_weighting_noarg(): |
| 152 | + """Test the auto_weighting decorator without the optimize argument.""" |
| 153 | + rn = odl.rn(2) |
| 154 | + rn_w = odl.rn(2, weighting=2) |
| 155 | + |
| 156 | + class ScalingOp(odl.Operator): |
| 157 | + |
| 158 | + def __init__(self, dom, ran, c): |
| 159 | + super(ScalingOp, self).__init__(dom, ran, linear=True) |
| 160 | + self.c = c |
| 161 | + |
| 162 | + def _call(self, x): |
| 163 | + return self.c * x |
| 164 | + |
| 165 | + @property |
| 166 | + @auto_weighting |
| 167 | + def adjoint(self): |
| 168 | + return ScalingOp(self.range, self.domain, self.c) |
| 169 | + |
| 170 | + op1 = ScalingOp(rn, rn, 1.5) |
| 171 | + op2 = ScalingOp(rn_w, rn_w, 1.5) |
| 172 | + op3 = ScalingOp(rn, rn_w, 1.5) |
| 173 | + op4 = ScalingOp(rn_w, rn, 1.5) |
| 174 | + |
| 175 | + for op in [op1, op2, op3, op4]: |
| 176 | + dom_el = noise_element(op.domain) |
| 177 | + ran_el = noise_element(op.range) |
| 178 | + assert pytest.approx(op(dom_el).inner(ran_el), |
| 179 | + dom_el.inner(op.adjoint(ran_el))) |
| 180 | + |
| 181 | + |
| 182 | +def test_auto_weighting_cached_adjoint(): |
| 183 | + """Check if auto_weighting plays well with adjoint caching.""" |
| 184 | + rn = odl.rn(2) |
| 185 | + rn_w = odl.rn(2, weighting=2) |
| 186 | + |
| 187 | + class ScalingOp(odl.Operator): |
| 188 | + |
| 189 | + def __init__(self, dom, ran, c): |
| 190 | + super(ScalingOp, self).__init__(dom, ran, linear=True) |
| 191 | + self.c = c |
| 192 | + self._adjoint = None |
| 193 | + |
| 194 | + def _call(self, x): |
| 195 | + return self.c * x |
| 196 | + |
| 197 | + @property |
| 198 | + @auto_weighting |
| 199 | + def adjoint(self): |
| 200 | + if self._adjoint is None: |
| 201 | + self._adjoint = ScalingOp(self.range, self.domain, self.c) |
| 202 | + return self._adjoint |
| 203 | + |
| 204 | + op = ScalingOp(rn, rn_w, 1.5) |
| 205 | + dom_el = noise_element(op.domain) |
| 206 | + op_eval_before = op(dom_el) |
| 207 | + |
| 208 | + adj = op.adjoint |
| 209 | + adj_again = op.adjoint |
| 210 | + assert adj_again is adj |
| 211 | + |
| 212 | + # Check that original op is intact |
| 213 | + assert not hasattr(op, '_call_unweighted') # op shouldn't be mutated |
| 214 | + op_eval_after = op(dom_el) |
| 215 | + assert all_equal(op_eval_before, op_eval_after) |
| 216 | + |
| 217 | + dom_el = noise_element(op.domain) |
| 218 | + ran_el = noise_element(op.range) |
| 219 | + op(dom_el) |
| 220 | + op.adjoint(ran_el) |
| 221 | + assert pytest.approx(op(dom_el).inner(ran_el), |
| 222 | + dom_el.inner(op.adjoint(ran_el))) |
| 223 | + |
| 224 | + |
| 225 | +def test_auto_weighting_raise_on_return_self(): |
| 226 | + """Check that auto_weighting raises when adjoint returns self.""" |
| 227 | + rn = odl.rn(2) |
| 228 | + |
| 229 | + class InvalidScalingOp(odl.Operator): |
| 230 | + |
| 231 | + def __init__(self, dom, ran, c): |
| 232 | + super(InvalidScalingOp, self).__init__(dom, ran, linear=True) |
| 233 | + self.c = c |
| 234 | + self._adjoint = None |
| 235 | + |
| 236 | + def _call(self, x): |
| 237 | + return self.c * x |
| 238 | + |
| 239 | + @property |
| 240 | + @auto_weighting |
| 241 | + def adjoint(self): |
| 242 | + return self |
| 243 | + |
| 244 | + # This would be a vaild situation for adjont just returning self |
| 245 | + op = InvalidScalingOp(rn, rn, 1.5) |
| 246 | + with pytest.raises(TypeError): |
| 247 | + op.adjoint |
| 248 | + |
| 249 | + |
80 | 250 | if __name__ == '__main__': |
81 | 251 | odl.util.test_file(__file__) |
0 commit comments