diff --git a/tests/test_match_optics_and_ip_knob.py b/tests/test_match_optics_and_ip_knob.py index f257e2001..c33808c28 100644 --- a/tests/test_match_optics_and_ip_knob.py +++ b/tests/test_match_optics_and_ip_knob.py @@ -5,6 +5,7 @@ import xobjects as xo import xtrack as xt from xobjects.test_helpers import for_all_test_contexts +import xtrack._temp.lhc_match as lm test_data_folder = pathlib.Path( __file__).parent.joinpath('../test_data').absolute() @@ -726,3 +727,61 @@ def test_match_ir8_optics(test_context): xo.assert_allclose(tw.lhcb2['betx', 'ip2'], 10., atol=1e-5, rtol=0) xo.assert_allclose(tw.lhcb2['bety', 'ip2'], 10., atol=1e-5, rtol=0) +@for_all_test_contexts +def test_match_optics_ad(test_context): + collider = xt.Environment.from_json(test_data_folder / + 'hllhc15_thick/hllhc15_collider_thick.json') + collider.vars.load_madx(test_data_folder / + 'hllhc15_thick/opt_round_150_1500.madx') + collider.build_trackers(test_context) + + line = collider.lhcb1 + tw0 = line.twiss() + + lm.set_var_limits_and_steps(collider) + + opt = line.match( + solve=False, + default_tol={None: 1e-8, 'betx': 1e-6, 'bety': 1e-6, 'alfx': 1e-6, 'alfy': 1e-6}, + start='s.ds.l8.b1', end='ip1', + init=tw0, init_at=xt.START, + vary=[ + # Only IR8 quadrupoles including DS + xt.VaryList(['kq6.l8b1', 'kq7.l8b1', 'kq8.l8b1', 'kq9.l8b1', 'kq10.l8b1', + 'kqtl11.l8b1', 'kqt12.l8b1', 'kqt13.l8b1', + 'kq4.l8b1', 'kq5.l8b1', 'kq4.r8b1', 'kq5.r8b1', + 'kq6.r8b1', 'kq7.r8b1', 'kq8.r8b1', 'kq9.r8b1', + 'kq10.r8b1', 'kqtl11.r8b1', 'kqt12.r8b1', 'kqt13.r8b1'])], + targets=[ + xt.TargetSet(at='ip8', tars=('betx', 'bety', 'alfx', 'alfy', 'dx', 'dpx'), value=tw0), + xt.TargetSet(at='ip1', betx=0.15, bety=0.1, alfx=0, alfy=0, dx=0, dpx=0), + xt.TargetRelPhaseAdvance('mux', value = tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1']), + xt.TargetRelPhaseAdvance('muy', value = tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1']), + ], + use_ad=True) + + opt.step(30) + + ## With edge effects takes a bit longer + + assert opt._err.call_counter < 120 + assert len(opt.log()) < 30 + + tw = line.twiss(init=tw0, start='s.ds.l8.b1', end='ip1') + + xo.assert_allclose(tw['betx', 'ip1'], 0.15, atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip1'], 0.1, atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip1'], 0., atol=1e-6, rtol=0) + + xo.assert_allclose(tw['betx', 'ip8'], tw0['betx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip8'], tw0['bety', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip8'], tw0['alfx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip8'], tw0['alfy', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip8'], tw0['dx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip8'], tw0['dy', 'ip8'], atol=1e-6, rtol=0) + + xo.assert_allclose(tw['mux', 'ip1.l1'] - tw['mux', 's.ds.l8.b1'], tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['muy', 'ip1.l1'] - tw['muy', 's.ds.l8.b1'], tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], atol=1e-6, rtol=0) \ No newline at end of file diff --git a/xtrack/autodiff.py b/xtrack/autodiff.py new file mode 100644 index 000000000..0493f57f6 --- /dev/null +++ b/xtrack/autodiff.py @@ -0,0 +1,390 @@ +import jax +import jax.numpy as jnp +from functools import partial +from typing import NamedTuple +import xtrack as xt + + +class TransferMatrixFactory: + + @staticmethod + @jax.jit + def quad(k1, l, beta0, gamma0): + """Quadrupole transfer matrix. + + Parameters + ---------- + k1 : float + Quadrupole strength. + l : float + Length of the quadrupole. + beta0 : float + Reference relativistic beta. + gamma0 : float + Reference relativistic gamma. + + Returns + ------- + f_matrix : jnp.ndarray + Transfer matrix for the quadrupole. + """ + + kx = jnp.sqrt(k1.astype(complex)) + ky = jnp.sqrt(-k1.astype(complex)) + sx = l * jnp.sinc(kx * l / jnp.pi) + cx = jnp.cos(kx * l) + sy = l * jnp.sinc(ky * l / jnp.pi) # limit of sin(ky * l) / ky when ky -> 0 + cy = jnp.cos(ky * l) + + f_matrix = jnp.array([ + [cx, sx, 0, 0, 0, 0], + [-kx**2 * sx, cx, 0, 0, 0, 0], + [0, 0, cy, sy, 0, 0], + [0, 0, -ky**2 * sy, cy, 0, 0], + [0, 0, 0, 0, 1, l/(beta0**2 * gamma0**2)], + [0, 0, 0, 0, 0, 1] + ]) + + return f_matrix.real + + @staticmethod + @jax.jit + def drift(l, beta0, gamma0): + """Drift transfer matrix. + + Parameters + ---------- + l : float + Length of the drift. + beta0 : float + Reference relativistic beta. + gamma0 : float + Reference relativistic gamma. + Returns + ------- + f_matrix : jnp.ndarray + Transfer matrix for the drift. + """ + + f_matrix = jnp.array([ + [1, l, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0], + [0, 0, 1, l, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 1, l/(beta0**2 * gamma0**2)], + [0, 0, 0, 0, 0, 1] + ]) + return f_matrix + + @staticmethod + @jax.jit + def bend(k0, k1, l, h, beta0, gamma0): + """Bend transfer matrix. + + Parameters + ---------- + k0 : float + First order curvature (dipole strength). + k1 : float + Second order curvature (quadrupole strength). + l : float + Length of the bend. + h : float + Horizontal offset of the bend. + beta0 : float + Reference relativistic beta. + gamma0 : float + Reference relativistic gamma. + Returns + ------- + f_matrix : jnp.ndarray + Transfer matrix for the bend. + """ + + kx = jnp.sqrt((h * k0 + k1).astype(complex)) + ky = jnp.sqrt(-k1.astype(complex)) # for dipoles usually 0 + sx = l * jnp.sinc(kx * l / jnp.pi) + cx = jnp.cos(kx * l) + sy = l * jnp.sinc(ky * l / jnp.pi) + cy = jnp.cos(ky * l) + dx = (1 - cx) / kx**2 + j1 = (l - sx) / kx**2 + + f_matrix = jnp.array([ + [cx, sx, 0, 0, 0, h/beta0 * dx], + [-kx**2 * sx, cx, 0, 0, 0, h/beta0 * sx], + [0, 0, cy, sy, 0, 0], + [0, 0, -ky**2 * sy, cy, 0, 0], + [-h/beta0 * sx, -h/beta0 * dx, 0, 0, 1, l/(beta0**2 * gamma0**2) - h**2/beta0**2 * j1], + [0, 0, 0, 0, 0, 1] + ]) + + return f_matrix.real + +class EncodedElem(NamedTuple): + etype: int + data0: float = 0.0 + data1: float = 0.0 + data2: float = 0.0 + data3: float = 0.0 + k1_idx: int = -1 + +@jax.jit +def get_values_from_transfer_matrix(r_mat, param_values): + """Compute Twiss parameters and dispersion from transfer matrix. + + Parameters + ---------- + r_mat : jnp.ndarray + Transfer matrix of the element. + param_values : jnp.ndarray + Initial Twiss parameters and dispersion values. + + Returns + ------- + jnp.ndarray + Updated Twiss parameters and dispersion values. + """ + + # Order: betx, bety, alfx, alfy, mux, muy, dx, dy, dpx, dpy + bx0, by0, ax0, ay0, mux0, muy0, dx0, dy0, dpx0, dpy0 = param_values + + # --- Horizontal plane --- + r00, r01, r10, r11 = r_mat[0,0], r_mat[0,1], r_mat[1,0], r_mat[1,1] + + tmp_x = r00 * bx0 - r01 * ax0 + betx = (tmp_x**2 + r01**2) / bx0 + alfx = -((tmp_x * (r10 * bx0 - r11 * ax0) + r01 * r11) / bx0) + mux = mux0 + jnp.arctan2(r01, tmp_x) / (2 * jnp.pi) + + # --- Vertical plane --- + r22, r23, r32, r33 = r_mat[2,2], r_mat[2,3], r_mat[3,2], r_mat[3,3] + + tmp_y = r22 * by0 - r23 * ay0 + bety = (tmp_y**2 + r23**2) / by0 + alfy = -((tmp_y * (r32 * by0 - r33 * ay0) + r23 * r33) / by0) + muy = muy0 + jnp.arctan2(r23, tmp_y) / (2 * jnp.pi) + + # --- Dispersion --- + dx = r00 * dx0 + r01 * dpx0 + r_mat[0,5] + dy = r22 * dy0 + r23 * dpy0 + r_mat[2,5] + dpx = r10 * dx0 + r11 * dpx0 + r_mat[1,5] + dpy = r32 * dy0 + r33 * dpy0 + r_mat[3,5] + + return jnp.array([betx, bety, alfx, alfy, mux, muy, dx, dy, dpx, dpy]) + +def encode_elements(elements, elem_to_deriv): + """Encode elements into a format suitable for JAX, + containing the type of each element and serializing parameters. + + Parameters + ---------- + elements : list of xtrack elements + The elements to encode. + elem_to_deriv : list of xtrack elements or None + The elements for which derivatives are computed. If None, no derivatives are computed. + Returns + ------- + EncodedElem + A NamedTuple containing the encoded elements. + """ + + if elem_to_deriv is not None: + deriv_lookup = {id(elem): i for i, elem in enumerate(elem_to_deriv)} + + encoded = [] + + for elem in elements: + if elem_to_deriv is not None and elem in elem_to_deriv: + encoded.append(EncodedElem( + etype=0, + data0=elem.length, + k1_idx=deriv_lookup[id(elem)] + )) + elif isinstance(elem, xt.Quadrupole): + encoded.append(EncodedElem( + etype=1, + data0=elem.k1, + data1=elem.length + )) + elif isinstance(elem, xt.Bend) or isinstance(elem, xt.RBend): + encoded.append(EncodedElem( + etype=2, + data0=elem.k0, + data1=elem.k1, + data2=elem.length, + data3=elem.h + )) + elif isinstance(elem, xt.Multipole) and elem.isthick and elem.length > 0: + encoded.append(EncodedElem( + etype=3, + data0=elem.length + )) + elif isinstance(elem, xt.Multipole): + encoded.append(EncodedElem( + etype=4 + )) + elif isinstance(elem, xt.Drift) or hasattr(elem, 'length'): + encoded.append(EncodedElem( + etype=3, + data0=elem.length + )) + else: + encoded.append(EncodedElem( + etype=4, + )) + + # Convert list of NamedTuples to NamedTuple of arrays for JAX + return EncodedElem( + etype=jnp.array([e.etype for e in encoded]), + data0=jnp.array([e.data0 for e in encoded]), + data1=jnp.array([e.data1 for e in encoded]), + data2=jnp.array([e.data2 for e in encoded]), + data3=jnp.array([e.data3 for e in encoded]), + k1_idx=jnp.array([e.k1_idx for e in encoded]), + ) + +@partial(jax.jit, static_argnums=(2,3)) +def get_values(k1_arr, encoded_elements, beta0, gamma0, initial_params): + """Compute Twiss parameters and dispersion from encoded elements. + + Parameters + ---------- + k1_arr : jnp.ndarray + Array of k1 values for the elements that require derivatives. + encoded_elements : EncodedElem + Encoded elements containing the type and parameters of each element. + beta0 : float + Reference relativistic beta. + gamma0 : float + Reference relativistic gamma. + initial_params : jnp.ndarray + Initial Twiss parameters and dispersion values. + + Returns + ------- + jnp.ndarray + Updated Twiss parameters and dispersion values. + """ + + def scan_step(params, elem): + TMF = TransferMatrixFactory + + # Defining methods inside the switch to avoid recompilation + tm = jax.lax.switch(elem.etype, [ + lambda: TMF.quad(k1_arr[elem.k1_idx], elem.data0, beta0, gamma0), + lambda: TMF.quad(elem.data0, elem.data1, beta0, gamma0), + lambda: TMF.bend(elem.data0, elem.data1, elem.data2, elem.data3, beta0, gamma0), + lambda: TMF.drift(elem.data0, beta0, gamma0), + lambda: jnp.eye(6) + ] + ) + new_params = get_values_from_transfer_matrix(tm, params) + return new_params, None + + final_params, _ = jax.lax.scan(scan_step, initial_params, encoded_elements) + return final_params + +def compute_param_derivatives(elements, elem_to_deriv, init_cond, beta0, gamma0): + """Compute the derivatives of the Twiss parameters with respect to k1 values. + + Parameters + ---------- + elements : list of xtrack elements + The elements for which to compute the derivatives. + elem_to_deriv : list of xtrack elements + The elements for which derivatives are computed. + init_cond : list of float + Initial conditions for the Twiss parameters and dispersion. + beta0 : float + Reference relativistic beta. + gamma0 : float + Reference relativistic gamma. + + Returns + ------- + jnp.ndarray + The Jacobian matrix of derivatives with respect to k1 values. + """ + + encoded_elements = encode_elements(elements, elem_to_deriv) + k1_arr = jnp.array([elem.k1 for elem in elem_to_deriv]) + + initial_params = jnp.array(init_cond) + + def wrapped_get_values(k1_arr): + return get_values(k1_arr, encoded_elements, beta0, gamma0, initial_params) + + pushfwd = partial(jax.jvp, wrapped_get_values, (k1_arr,)) + basis = jnp.eye(len(k1_arr)) + y, jac = jax.vmap(pushfwd, out_axes=(None, 1))((basis,)) + return jac, y # to yield both jacobian and final values + +@partial(jax.jit, static_argnums=(1,2)) +def get_values_noderiv(encoded_elements, beta0, gamma0, initial_params): + """Compute Twiss parameters and dispersion from encoded elements without derivatives. + + Parameters + ---------- + encoded_elements : EncodedElem + Encoded elements containing the type and parameters of each element. + beta0 : float + Reference relativistic beta. + gamma0 : float + Reference relativistic gamma. + initial_params : jnp.ndarray + Initial Twiss parameters and dispersion values. + + Returns + ------- + jnp.ndarray + Updated Twiss parameters and dispersion values. + """ + + def scan_step(params, elem): + TMF = TransferMatrixFactory + + # Defining methods inside the switch to avoid recompilation + tm = jax.lax.switch(elem.etype, [ + lambda: jnp.eye(6), + lambda: TMF.quad(elem.data0, elem.data1, beta0, gamma0), + lambda: TMF.bend(elem.data0, elem.data1, elem.data2, elem.data3, beta0, gamma0), + lambda: TMF.drift(elem.data0, beta0, gamma0), + lambda: jnp.eye(6) + ] + ) + new_params = get_values_from_transfer_matrix(tm, params) + return new_params, None + + final_params, _ = jax.lax.scan(scan_step, initial_params, encoded_elements) + return final_params + +def compute_values(elements, tw0): + """Compute Twiss parameters and dispersion from elements and initial conditions + without calculating derivatives. + + Parameters + ---------- + elements : list of xtrack elements + The elements for which to compute the Twiss parameters and dispersion. + tw0 : xtrack.Twiss + Initial Twiss parameters and dispersion. + + Returns + ------- + jnp.ndarray + Updated Twiss parameters and dispersion values. + """ + + beta0 = tw0.particle_on_co.beta0[0] + gamma0 = tw0.particle_on_co.gamma0[0] + + encoded_elements = encode_elements(elements, None) + + initial_params = jnp.array([ + tw0.betx[0], tw0.bety[0], tw0.alfx[0], tw0.alfy[0], + tw0.mux[0], tw0.muy[0], tw0.dx[0], tw0.dy[0], + tw0.dpx[0], tw0.dpy[0] + ]) + + return get_values_noderiv(encoded_elements, beta0, gamma0, initial_params) \ No newline at end of file diff --git a/xtrack/match.py b/xtrack/match.py index 040983844..4e05eeb63 100644 --- a/xtrack/match.py +++ b/xtrack/match.py @@ -4,6 +4,8 @@ from .general import _print, _LOC import xtrack as xt import xdeps as xd +import sympy +from xtrack.autodiff import compute_param_derivatives XTRACK_DEFAULT_TOL = 1e-9 XTRACK_DEFAULT_SIGMA_REL = 0.01 @@ -57,6 +59,19 @@ 'dx_ng', 'dpx_ng'] +AD_QTY_IDX = { + "betx": 0, + "bety": 1, + "alfx": 2, + "alfy": 3, + "mux": 4, + "muy": 5, + "dx": 6, + "dy": 7, + "dpx": 8, + "dpy": 9, +} + # Alternative transitions functions # def _transition_sigmoid_integral(x): # x_shift = x - 3 @@ -676,7 +691,7 @@ def match_line(line, vary, targets, solve=True, assert_within_tol=True, restore_if_fail=True, verbose=False, n_steps_max=20, default_tol=None, solver=None, check_limits=True, - name="", + name="", use_ad=False, **kwargs): opt = OptimizeLine(line, vary, targets, @@ -687,7 +702,7 @@ def match_line(line, vary, targets, solve=True, assert_within_tol=True, restore_if_fail=restore_if_fail, verbose=verbose, n_steps_max=n_steps_max, default_tol=default_tol, solver=solver, check_limits=check_limits, - name=name, + name=name, use_ad=use_ad, **kwargs) if solve: @@ -812,6 +827,205 @@ def run(self, allow_failure=True): out.line = self.line return out +class MeritFunctionLine(xd.MeritFunctionForMatch): + def __init__( + self, + merit_function_match, + use_ad=False + ): + + self.vary = merit_function_match.vary + self.targets = merit_function_match.targets + self.actions = merit_function_match.actions + self.return_scalar = merit_function_match.return_scalar + self.call_counter = merit_function_match.call_counter + self.verbose = merit_function_match.verbose + self.tw_kwargs = merit_function_match.tw_kwargs + self.steps_for_jacobian = merit_function_match.steps_for_jacobian + self.found_point_within_tol = merit_function_match.found_point_within_tol + self.zero_if_met = merit_function_match.zero_if_met + self.show_call_counter = merit_function_match.show_call_counter + self.check_limits = merit_function_match.check_limits + self.use_ad = use_ad + + def get_derivatives_elements_knobs(self): + """ + Compute the derivatives of quadrupole k1 with respect to knobs. + This is done by executing the symbolic expressions for each knob + and extracting the derivatives of k1 with respect to the symbolic variable. + + Yields + ------- + dkq_dvv : dict + A dictionary mapping knob names to another dictionary that maps + quadrupole names to the derivative of the quadrupole k1 with respect + to the knob. + quad_sources_ord : list + An ordered list of quadrupole names that appear in the derivatives. + target_places : list + An ordered list of target locations used in the optimization. + """ + + class DummyElement: + """Placeholder object for injecting symbolic attributes.""" + pass + + dkq_dvv = {} # Mapping: knob name -> {quad name -> d(quad)/d(knob)} + + for vary_entry in self.vary: + knob_name = vary_entry.name + symbolic_var = sympy.var("a") + + # Find all quadrupole k1 dependencies on this knob + quad_exprs = [] + dummy_quads = {} + + for dep in self.actions[0].line.ref_manager.find_deps([self.actions[0].line.vars[knob_name]]): + if dep.__class__.__name__ == "AttrRef" and dep._key == "k1": + quad_name = dep._owner._key + quad_exprs.append((quad_name, dep._expr)) + dummy_quads[quad_name] = DummyElement() + + # Build symbolic expression function for the knob + func_code = self.actions[0].line.ref_manager.mk_fun("myfun", a=self.actions[0].line.vars[knob_name]) + func_globals = { + "vars": self.actions[0].line.ref_manager.containers["vars"]._owner.copy(), + "element_refs": dummy_quads, + } + func_locals = {} + + ################### myfun ################################ + # def myfun(a): + # knob_name = a + # element_refs[quad_name].k1 = (1.0 * knob_name) -> SymPy expression + # ... + ########################################################## + + exec(func_code, func_globals, func_locals) # Create function, stored in func_locals + func_locals["myfun"](symbolic_var) # Execute function + + # Extract derivatives of k1 with respect to this knob + k1_derivs = {} + for quad_name, _ in quad_exprs: + derivative = func_globals["element_refs"][quad_name].k1.diff(symbolic_var) + k1_derivs[quad_name] = derivative + + dkq_dvv[knob_name] = k1_derivs + + # Set of all quadrupole names appearing in the derivatives + quad_sources = set() + for derivs in dkq_dvv.values(): + quad_sources.update(derivs.keys()) + + # Set of all target locations used in optimization and sort them by order + target_places = set() + for target in self.targets: + if isinstance(target.tar, tuple): + target_places.add(target.tar[1]) + elif hasattr(target, "start") and hasattr(target, "end"): + if target.start != '__ele_start__': + target_places.add(target.start) + if target.end != '__ele_stop__': + target_places.add(target.end) + else: + if self.actions[0]._tw0.name[-2] not in target_places: + target_places.add(self.actions[0]._tw0.name[-2]) + # Assumption: Point before _end_point is same as endpoint given in opt + else: + raise ValueError(f"Unknown target type: {type(target)}") + # Convert to ordered list based on appearance of name in opt.line + index_map = {name: i for i, name in enumerate(self.actions[0]._tw0.name)} + target_places = sorted(target_places, key=index_map.get) + + # Ordered list of quadrupole sources (based on their position in the beamline) + quad_sources_ordered = [ + name for name in self.actions[0]._tw0.name if name in quad_sources + ] + + self.quad_sources_ord = quad_sources_ordered + self.target_places = target_places + self.dkq_dvv = dkq_dvv + + def get_jacobian(self, x, f0=None): + if self.use_ad: + return self.get_jacobian_ad(x) + else: + return super().get_jacobian(x, f0=f0) + + def get_jacobian_ad(self, x): + if not hasattr(self, "quad_sources_ord") or not hasattr(self, "target_places") or not hasattr(self, "dkq_dvv"): + self.get_derivatives_elements_knobs() + x = np.array(x).copy() + #jacobian = get_jac(opt, all_quad_sources, target_places, dkq_dvv) + + opt_tw = self.actions[0].run() + #opt_tw = opt.action_twiss._tw0 + # Initial conditions for first derivative calculation + init_cond = np.array([opt_tw.betx[0], opt_tw.bety[0], opt_tw.alfx[0], opt_tw.alfy[0], + opt_tw.mux[0], opt_tw.muy[0], opt_tw.dx[0], opt_tw.dy[0], + opt_tw.dpx[0], opt_tw.dpy[0]]) + beta0 = opt_tw.particle_on_co.beta0[0] + gamma0 = opt_tw.particle_on_co.gamma0[0] + + twiss_derivs = {} + for place in self.target_places: # in order of appearance + # Calc derivative for all quadrupoles for target place + # Source point = qqnn, Observation point = target + twiss_derivs[place] = {} + trunc_elements = np.array([self.actions[0].line.element_dict[name] for name in opt_tw.rows[:place].name]) + nonzero_qq = [] + nonzero_qqn = [] + for qqnn in self.quad_sources_ord: + if opt_tw['s', place] < opt_tw['s', qqnn]: + twiss_derivs[place][qqnn] = np.zeros(10) # batch after + else: + nonzero_qqn.append(qqnn) + nonzero_qq.append(self.actions[0].line.element_dict[qqnn]) # first elements + # add to list to be calculated + if len(nonzero_qq) == 0: + continue + nonzero_deriv, _ = compute_param_derivatives(trunc_elements, nonzero_qq, init_cond, beta0, gamma0) + + for i, qqn in enumerate(nonzero_qqn): + twiss_derivs[place][qqn] = nonzero_deriv[i] + for qqn, deriv in zip(nonzero_qqn, nonzero_deriv.T): + twiss_derivs[place][qqn] = deriv + + jac_estim = np.zeros((len(self.targets), len(self.vary))) + for itt, tt in enumerate(self.targets): + + tar_start = None + if isinstance(tt.tar, tuple): + tar_quantity = tt.tar[0] + tar_place = tt.tar[1] + else: + tar_quantity = tt.var + tar_place = self.target_places[-1] if tt.end == '__ele_stop__' else tt.end + tar_start = None if tt.start == '__ele_start__' else tt.start + tar_weight = tt.weight + + tar_weight = tt.weight + quantity_idx = AD_QTY_IDX[tar_quantity] + for ivv in range(len(self.vary)): + vv = self.vary[ivv].name + quad_names = self.dkq_dvv[vv].keys() + + dtar_dvv = 0 + for qqnn in quad_names: + if qqnn in twiss_derivs[tar_place].keys(): + dtar_dvv += (twiss_derivs[tar_place][qqnn][quantity_idx]) * float(self.dkq_dvv[vv][qqnn]) + if tar_start is not None: + dtar_dvv -= (twiss_derivs[tar_start][qqnn][quantity_idx]) * float(self.dkq_dvv[vv][qqnn]) + + dtar_dvv *= tar_weight + + jac_estim[itt, ivv] = dtar_dvv + + #return jac_estim + + self._last_jac = jac_estim + return jac_estim + class OptimizeLine(xd.Optimize): def __init__(self, line, vary, targets, assert_within_tol=True, @@ -821,7 +1035,7 @@ def __init__(self, line, vary, targets, assert_within_tol=True, n_steps_max=20, default_tol=None, solver=None, check_limits=True, action_twiss=None, action_twiss_ng=None, - name="", + name="", use_ad=False, **kwargs): if hasattr(targets, 'values'): # dict like @@ -861,13 +1075,21 @@ def __init__(self, line, vary, targets, assert_within_tol=True, action_twiss.prepare() tt.action = action_twiss + # Handle at if isinstance(tt.tar, tuple): tt_name = tt.tar[0] # `at` is present tt_at = tt.tar[1] + if use_ad == True and tt_name not in ['betx', 'bety', 'alfx', 'alfy', 'mux', 'muy', 'dx', 'dy', 'dpx', 'dpy']: + print("Warning: use_ad is set to True, but the target {} is not supported for automatic differentiation.") + use_ad = False else: tt_name = tt.tar tt_at = None + if use_ad == True and not isinstance(tt, TargetRelPhaseAdvance): + print("Warning: use_ad is set to True, but the target {} is not supported for automatic differentiation.") + use_ad = False + if tt_at is not None and isinstance(tt_at, _LOC): assert isinstance(tt.action, ActionTwiss) tt.action.prepare() # does nothing if already prepared @@ -930,10 +1152,13 @@ def __init__(self, line, vary, targets, assert_within_tol=True, n_steps_max=n_steps_max, restore_if_fail=restore_if_fail, check_limits=check_limits, - name=name) + name=name, line=line) + + _err = MeritFunctionLine(self._err, use_ad=use_ad) self.line = line self.action_twiss = action_twiss self.default_tol = default_tol + self._err = _err def clone(self, add_targets=None, add_vary=None, remove_targets=None, remove_vary=None,