diff --git a/tests/test_optimize.py b/tests/test_optimize.py index 21b7a91..83523f9 100644 --- a/tests/test_optimize.py +++ b/tests/test_optimize.py @@ -8,14 +8,29 @@ def test_merit_function_view(): def my_function(x): return x - x0 = [0., 0., 0.] + x0 = [0.0, 0.0, 0.0] opt = xd.Optimize.from_callable(my_function, x0=x0, steps=[1e-6, 1e-6, 1e-6], tar=[0., 0., 0.], tols=[1e-12, 1e-12, 1e-12], limits=[[-1, 2], [-1, 4], [-2, 2]]) - opt.solve() + # Check that Jacobian doesn't introduce side effects mf = opt.get_merit_function() + xo.assert_allclose(mf.get_x(), x0, atol=1e-6, rtol=0) + xo.assert_allclose(opt.get_knob_values(), x0, atol=1e-6, rtol=0) + + jac_init = mf.get_jacobian([0.5, 2, -1]) + xo.assert_allclose(mf.get_x(), x0, atol=1e-6, rtol=0) + xo.assert_allclose(opt.get_knob_values(), x0, atol=1e-6, rtol=0) + xo.assert_allclose(jac_init, [[1, 0, 0], [0, 1, 0], [0, 0, 1]], atol=1e-6, rtol=0) + + jac_init2 = mf.get_jacobian([0.5, 2, -1]) + xo.assert_allclose(mf.get_x(), x0, atol=1e-6, rtol=0) + xo.assert_allclose(opt.get_knob_values(), x0, atol=1e-6, rtol=0) + xo.assert_allclose(jac_init2, [[1, 0, 0], [0, 1, 0], [0, 0, 1]], atol=1e-6, rtol=0) + + opt.solve() + jmf = mf.get_jacobian([0.5, 2, -1]) xo.assert_allclose(jmf, [[1, 0, 0], [0, 1, 0], [ 0, 0, 1]], atol=1e-6, rtol=0) diff --git a/xdeps/optimize/optimize.py b/xdeps/optimize/optimize.py index 3a5d097..53316db 100644 --- a/xdeps/optimize/optimize.py +++ b/xdeps/optimize/optimize.py @@ -391,6 +391,7 @@ def __call__(self, x=None, check_limits=None, return_scalar=None, zero_if_met=No def get_jacobian(self, x, f0=None): if hasattr(self, "_force_jacobian"): return self._force_jacobian + prev_x = self._get_x() x = np.array(x).copy() steps = self._knobs_to_x(self.steps_for_jacobian) assert len(x) == len(steps) @@ -409,6 +410,7 @@ def get_jacobian(self, x, f0=None): x[ii] -= steps[ii] self._last_jac = jac + self._set_x(prev_x) return jac def _clip_to_max_steps(self, x_step): @@ -1295,6 +1297,7 @@ def target_mismatch(self, ret=False, max_col_width=40): def get_knob_values(self, iteration=None): """ Get the knob values at a given iteration. + If no iteration is given, return current knob values. Parameters ---------- @@ -1308,7 +1311,7 @@ def get_knob_values(self, iteration=None): """ if iteration is None: - iteration = len(self._log["penalty"]) - 1 + return self._err._extract_knob_values() out = dict() for ii, vv in enumerate(self.vary): out[vv.name] = self._log["knobs"][iteration][ii]