diff --git a/tests/test_madnginterface.py b/tests/test_madnginterface.py index 8c21c6006..0643b5cde 100644 --- a/tests/test_madnginterface.py +++ b/tests/test_madnginterface.py @@ -3,6 +3,7 @@ import xobjects as xo import pathlib import numpy as np +from xtrack._temp import lhc_match as lm test_data_folder = pathlib.Path( __file__).parent.joinpath('../test_data').absolute() @@ -68,7 +69,7 @@ def test_madng_interface_with_multipole_errors_and_misalignments(): line[nn_quad].shift_y = sy * line.ref['on_error'] line[nn_quad].rot_s_rad = rr * line.ref['on_error'] line[nn_quad].knl[2] = kkk * line.ref['on_error'] - tw = line.madng_twiss() + tw = line.madng_twiss(coupling_edw_teng=True, compute_chromatic_properties=True) xo.assert_allclose(tw.x, tw.x_ng, atol=5e-4*tw.x.std(), rtol=0) xo.assert_allclose(tw.y, tw.y_ng, atol=5e-4*tw.y.std(), rtol=0) @@ -82,7 +83,7 @@ def test_madng_interface_with_multipole_errors_and_misalignments(): xo.assert_allclose(tw.by_chrom, tw.by_ng, atol=5e-3*tw.wy_chrom.max(), rtol=0) line['on_error'] = 0 - tw = line.madng_twiss() + tw = line.madng_twiss(coupling_edw_teng=True, compute_chromatic_properties=True) xo.assert_allclose(tw.x, 0, atol=1e-10, rtol=0) xo.assert_allclose(tw.y, 0, atol=1e-10, rtol=0) xo.assert_allclose(tw.betx2, 0, atol=1e-10, rtol=0) @@ -173,7 +174,7 @@ def test_madng_interface_with_slicing(): line.cut_at_s(np.arange(1000)) tw_xs = line.twiss4d() - tw = line.madng_twiss() + tw = line.madng_twiss(coupling_edw_teng=True, compute_chromatic_properties=True) assert len(tw) == len(tw_xs) @@ -196,16 +197,16 @@ def test_madng_twiss_with_initial_conditions(): line = xt.load(test_data_folder / 'hllhc15_thick/lhc_thick_with_knobs.json') #pytest.set_trace() - tw_xs = line.twiss(betx=120, bety=150) - tw = line.madng_twiss(beta11=120, beta22=150) + tw_xs = line.twiss(betx=120, bety=150, alfx=5, alfy=5, dx=1e-4) + tw = line.madng_twiss(beta11=120, beta22=150, alfa11=5, alfa22=5, dx=1e-4) assert len(tw) == len(tw_xs) assert len(tw.betx) == len(tw.beta11_ng) - xo.assert_allclose(tw.betx, tw.beta11_ng, rtol=1e-7, atol=1e-6) - xo.assert_allclose(tw.bety, tw.beta22_ng, rtol=1e-7, atol=1e-6) - xo.assert_allclose(tw.alfx, tw.alfa11_ng, rtol=1e-7, atol=1e-6) - xo.assert_allclose(tw.alfy, tw.alfa22_ng, rtol=1e-7, atol=1e-6) + xo.assert_allclose(tw.betx, tw.beta11_ng, rtol=1e-6, atol=1e-6) + xo.assert_allclose(tw.bety, tw.beta22_ng, rtol=1e-6, atol=1e-6) + xo.assert_allclose(tw.alfx, tw.alfa11_ng, rtol=1e-6, atol=1e-6) + xo.assert_allclose(tw.alfy, tw.alfa22_ng, rtol=1e-6, atol=1e-6) xo.assert_allclose(tw.dx, tw.dx_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw.dy, tw.dy_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw.dpx, tw.dpx_ng, rtol=1e-8, atol=1e-6) @@ -213,8 +214,8 @@ def test_madng_twiss_with_initial_conditions(): xo.assert_allclose(tw.x, tw.x_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw.y, tw.y_ng, rtol=1e-8, atol=1e-6) - tw2_xs = line.twiss(start='s.ds.l8.b1', end='ip1', betx=100, bety=34) - tw2_xsng = line.madng_twiss(start='s.ds.l8.b1', end='ip1', beta11=100, beta22=34, xsuite_tw=False) + tw2_xs = line.twiss(start='s.ds.l8.b1', end='ip1', betx=100, bety=34, dx=1e-5) + tw2_xsng = line.madng_twiss(start='s.ds.l8.b1', end='ip1', beta11=100, beta22=34, dx=1e-5, xsuite_tw=False) assert len(tw2_xs.betx) == len(tw2_xsng.beta11_ng) xo.assert_allclose(tw2_xs.betx, tw2_xsng.beta11_ng, rtol=1e-8, atol=1e-6) @@ -248,16 +249,34 @@ def test_madng_twiss_with_initial_conditions(): xo.assert_allclose(tw3_xsng.alfx, tw3_xsng.alfa11_ng, rtol=1e-8, atol=1e-6) xo.assert_allclose(tw3_xsng.alfy, tw3_xsng.alfa22_ng, rtol=1e-8, atol=1e-6) + tw4_xs = line.twiss(start='ip3', end='ip4', betx=121.5668, bety=218.58374, alfx=2.295, alfy=-2.6429, dx=-0.51) + tw4_xsng = line.madng_twiss(start='ip3', end='ip4', beta11=121.5668, beta22=218.58374, alfa11=2.295, + alfa22=-2.6429, dx=-0.51, xsuite_tw=False) + + assert len(tw4_xs.betx) == len(tw4_xsng.beta11_ng) + xo.assert_allclose(tw4_xs.betx, tw4_xsng.beta11_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.bety, tw4_xsng.beta22_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.alfx, tw4_xsng.alfa11_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.alfy, tw4_xsng.alfa22_ng, rtol=1e-6, atol=1e-5) + xo.assert_allclose(tw4_xs.dx, tw4_xsng.dx_ng, rtol=1e-7, atol=1e-8) + xo.assert_allclose(tw4_xs.dy, tw4_xsng.dy_ng, rtol=1e-7, atol=1e-8) + xo.assert_allclose(tw4_xs.x, tw4_xsng.x_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.y, tw4_xsng.y_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.px, tw4_xsng.px_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.py, tw4_xsng.py_ng, rtol=1e-8, atol=1e-10) + xo.assert_allclose(tw4_xs.mux, tw4_xsng.mu1_ng, rtol=1e-8, atol=1e-5) + xo.assert_allclose(tw4_xs.muy, tw4_xsng.mu2_ng, rtol=1e-8, atol=1e-5) + def test_madng_slices(): line = xt.load(test_data_folder / 'hllhc15_thick/lhc_thick_with_knobs.json') tw = line.twiss4d() - twng = line.madng_twiss() + twng = line.madng_twiss(compute_chromatic_properties=True) line.cut_at_s(np.linspace(0, line.get_length(), 5000)) tw_sliced = line.twiss4d() - twng_sliced = line.madng_twiss() + twng_sliced = line.madng_twiss(compute_chromatic_properties=True) tt_sliced = line.get_table() assert np.all(np.array(sorted(list(set(tt_sliced.element_type)))) == @@ -302,3 +321,201 @@ def test_madng_slices(): xo.assert_allclose(twng_ip.wy_ng, twng_ip_sliced.wy_ng, rtol=1e-3) xo.assert_allclose(twng_ip.dx_ng, twng_ip_sliced.dx_ng, atol=1e-6) xo.assert_allclose(twng_ip.dy_ng, twng_ip_sliced.dy_ng, atol=1e-6) + +def test_madng_match_optics(): + collider = xt.Environment.from_json(test_data_folder / + 'hllhc15_thick/hllhc15_collider_thick.json') + collider.vars.load_madx(test_data_folder / + 'hllhc15_thick/opt_round_150_1500.madx') + + line = collider.lhcb1 + tw0 = line.madng_twiss() + + lm.set_var_limits_and_steps(collider) + + # Match with Xsuite Targets + opt = line.match( + solve=False, + default_tol={None: 1e-8, 'betx': 1e-6, 'bety': 1e-6, 'alfx': 1e-6, 'alfy': 1e-6}, + start='s.ds.l8.b1', end='ip1', + init=tw0, init_at=xt.START, + vary=[ + # Only IR8 quadrupoles including DS + xt.VaryList(['kq6.l8b1', 'kq7.l8b1', 'kq8.l8b1', 'kq9.l8b1', 'kq10.l8b1', + 'kqtl11.l8b1', 'kqt12.l8b1', 'kqt13.l8b1', + 'kq4.l8b1', 'kq5.l8b1', 'kq4.r8b1', 'kq5.r8b1', + 'kq6.r8b1', 'kq7.r8b1', 'kq8.r8b1', 'kq9.r8b1', + 'kq10.r8b1', 'kqtl11.r8b1', 'kqt12.r8b1', 'kqt13.r8b1'])], + targets=[ + xt.TargetSet(at='ip8', tars=('betx', 'bety', 'alfx', 'alfy', 'dx', 'dpx'), value=tw0, weight=1), + xt.TargetSet(at='ip1', betx=0.15, bety=0.1, alfx=0, alfy=0, dx=0, dpx=0, weight=1), + xt.TargetRelPhaseAdvance('mux', value = tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + xt.TargetRelPhaseAdvance('muy', value = tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + ], + use_tpsa=True) + + opt.step(30) + + assert opt._err.call_counter < 20 + assert len(opt.log()) < 10 + + tw = line.twiss(init=tw0, start='s.ds.l8.b1', end='ip1') + + xo.assert_allclose(tw['betx', 'ip1'], 0.15, atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip1'], 0.1, atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip1'], 0., atol=1e-6, rtol=0) + + xo.assert_allclose(tw['betx', 'ip8'], tw0['betx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip8'], tw0['bety', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip8'], tw0['alfx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip8'], tw0['alfy', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip8'], tw0['dx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip8'], tw0['dy', 'ip8'], atol=1e-6, rtol=0) + + xo.assert_allclose(tw['mux', 'ip1.l1'] - tw['mux', 's.ds.l8.b1'], tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['muy', 'ip1.l1'] - tw['muy', 's.ds.l8.b1'], tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], atol=1e-6, rtol=0) + + opt.reload(0) + opt.actions[0].cleanup() + + # Match with MAD-NG and Xsuite Targets mixed + opt = line.match( + solve=False, + default_tol={None: 1e-8, 'betx': 1e-6, 'bety': 1e-6, 'alfx': 1e-6, 'alfy': 1e-6}, + start='s.ds.l8.b1', end='ip1', + init=tw0, init_at=xt.START, + vary=[ + # Only IR8 quadrupoles including DS + xt.VaryList(['kq6.l8b1', 'kq7.l8b1', 'kq8.l8b1', 'kq9.l8b1', 'kq10.l8b1', + 'kqtl11.l8b1', 'kqt12.l8b1', 'kqt13.l8b1', + 'kq4.l8b1', 'kq5.l8b1', 'kq4.r8b1', 'kq5.r8b1', + 'kq6.r8b1', 'kq7.r8b1', 'kq8.r8b1', 'kq9.r8b1', + 'kq10.r8b1', 'kqtl11.r8b1', 'kqt12.r8b1', 'kqt13.r8b1'])], + targets=[ + xt.TargetSet(at='ip8', tars=('beta11_ng', 'bety', 'alfa11_ng', 'alfy', 'dx_ng', 'dpx'), value=tw0, weight=1), + xt.TargetSet(at='ip1', betx=0.15, beta22_ng=0.1, alfx=0, alfa22_ng=0, dx=0, dpx_ng=0, weight=1), + xt.TargetRelPhaseAdvance('mux', value = tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + xt.TargetRelPhaseAdvance('mu2_ng', value = tw0['mu2_ng', 'ip1.l1'] - tw0['mu2_ng', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + ], + use_tpsa=True) + + opt.step(30) + + assert opt._err.call_counter < 20 + assert len(opt.log()) < 10 + + tw = line.twiss(init=tw0, start='s.ds.l8.b1', end='ip1') + + xo.assert_allclose(tw['betx', 'ip1'], 0.15, atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip1'], 0.1, atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip1'], 0., atol=1e-6, rtol=0) + + xo.assert_allclose(tw['betx', 'ip8'], tw0['betx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip8'], tw0['bety', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip8'], tw0['alfx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip8'], tw0['alfy', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip8'], tw0['dx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip8'], tw0['dy', 'ip8'], atol=1e-6, rtol=0) + + xo.assert_allclose(tw['mux', 'ip1.l1'] - tw['mux', 's.ds.l8.b1'], tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['muy', 'ip1.l1'] - tw['muy', 's.ds.l8.b1'], tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], atol=1e-6, rtol=0) + + opt.reload(0) + opt.actions[0].cleanup() + + # Match on full line without initial conditions + opt = line.match( + solve=False, + default_tol={None: 1e-8, 'betx': 1e-6, 'bety': 1e-6, 'alfx': 1e-6, 'alfy': 1e-6}, + vary=[ + # Only IR8 quadrupoles including DS + xt.VaryList(['kq6.l8b1', 'kq7.l8b1', 'kq8.l8b1', 'kq9.l8b1', 'kq10.l8b1', + 'kqtl11.l8b1', 'kqt12.l8b1', 'kqt13.l8b1', + 'kq4.l8b1', 'kq5.l8b1', 'kq4.r8b1', 'kq5.r8b1', + 'kq6.r8b1', 'kq7.r8b1', 'kq8.r8b1', 'kq9.r8b1', + 'kq10.r8b1', 'kqtl11.r8b1', 'kqt12.r8b1', 'kqt13.r8b1'])], + targets=[ + xt.TargetSet(at='ip8', tars=('beta11_ng', 'beta22_ng', 'alfa11_ng', 'alfa22_ng', 'dx_ng', 'dpx_ng'), value=tw0, weight=1), + xt.TargetSet(at='ip1.l1', beta11_ng=0.15, beta22_ng=0.1, alfa11_ng=0, alfa22_ng=0, dx_ng=0, dpx_ng=0, weight=1), + xt.TargetRelPhaseAdvance('mu1_ng', value = tw0['mu1_ng', 'ip1.l1'] - tw0['mu1_ng', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + xt.TargetRelPhaseAdvance('mu2_ng', value = tw0['mu2_ng', 'ip1.l1'] - tw0['mu2_ng', 's.ds.l8.b1'], start='s.ds.l8.b1', end='ip1.l1', weight=1), + ], + use_tpsa=True) + + opt.step(30) + + assert opt._err.call_counter < 20 + assert len(opt.log()) < 10 + + tw = line.twiss(init=tw0) + + xo.assert_allclose(tw['betx', 'ip1.l1'], 0.15, atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip1.l1'], 0.1, atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip1.l1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip1.l1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip1.l1'], 0., atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip1.l1'], 0., atol=1e-6, rtol=0) + + xo.assert_allclose(tw['betx', 'ip8'], tw0['betx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['bety', 'ip8'], tw0['bety', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfx', 'ip8'], tw0['alfx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['alfy', 'ip8'], tw0['alfy', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dx', 'ip8'], tw0['dx', 'ip8'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['dy', 'ip8'], tw0['dy', 'ip8'], atol=1e-6, rtol=0) + + xo.assert_allclose(tw['mux', 'ip1.l1'] - tw['mux', 's.ds.l8.b1'], tw0['mux', 'ip1.l1'] - tw0['mux', 's.ds.l8.b1'], atol=1e-6, rtol=0) + xo.assert_allclose(tw['muy', 'ip1.l1'] - tw['muy', 's.ds.l8.b1'], tw0['muy', 'ip1.l1'] - tw0['muy', 's.ds.l8.b1'], atol=1e-6, rtol=0) + +def test_madng_orbit_bump(): + env = xt.Environment() + env.vars.default_to_zero = True + line = env.new_line(length=10, components=[ + env.new('corr1', xt.Multipole, isthick=True, + knl=['kick_h_1'], ksl=['kick_v_1'], length=0.1, at=1), + env.new('corr2', xt.Multipole, isthick=True, + knl=['kick_h_2'], ksl=['kick_v_2'], length=0.1, at=2), + env.new('corr3', xt.Multipole, isthick=True, + knl=['kick_h_3'], ksl=['kick_v_3'], length=0.1, at=8), + env.new('corr4', xt.Multipole, isthick=True, + knl=['kick_h_4'], ksl=['kick_v_4'], length=0.1, at=9), + env.new('mid', xt.Marker, at=5), + env.new('end', xt.Marker, at=10) + ]) + line.set_particle_ref('proton', p0c=26e9) + + opt = line.match( + solve=False, + betx=1, bety=1, + vary=xt.VaryList(['kick_h_1', 'kick_v_1', + 'kick_h_2', 'kick_v_2', + 'kick_h_3', 'kick_v_3', + 'kick_h_4', 'kick_v_4']), + targets=[ + xt.TargetSet(x=1e-3, y=-2e-3, px=0, py=0, at='mid'), + xt.TargetSet(x=0, y=0, px=0, py=0, at='end'), + ], + use_tpsa=True + ) + + jac_ng = opt._err.get_jacobian(opt._err._get_x()) + + jac_opt = np.array([[-40, 0, -30, 0, 0, 0, 0, 0], + [-100, 0, -100, 0, 0, 0, 0, 0], + [0, 40, 0, 30, 0, 0, 0, 0], + [0, 100, 0, 100, 0, 0, 0, 0], + [-90, 0, -80, 0, -20, 0, -10, 0], + [-100, 0, -100, 0, -100, 0, -100, 0], + [0, 90, 0, 80, 0, 20, 0, 10], + [0, 100, 0, 100, 0, 100, 0, 100]]) + + xo.assert_allclose(jac_ng, jac_opt, rtol=1e-12, atol=1e-12) + + opt.solve() + + assert opt._err.call_counter < 7 \ No newline at end of file diff --git a/xtrack/madng_interface.py b/xtrack/madng_interface.py index dce0b70d1..827fc0a67 100644 --- a/xtrack/madng_interface.py +++ b/xtrack/madng_interface.py @@ -1,9 +1,9 @@ import numpy as np + from .match import Action import os import uuid -from .mad_writer import mad_str_or_value import xtrack as xt NG_XS_MAP = { @@ -15,14 +15,70 @@ 'mu2': 'muy', } +XS_NG_MAP = { + 'betx': 'beta11', + 'bety': 'beta22', + 'alfx': 'alfa11', + 'alfy': 'alfa22', + 'mux': 'mu1', + 'muy': 'mu2', + 'dx': 'dx', + 'dpx': 'dpx', + 'x': 'x', + 'px': 'px', + 'y': 'y', + 'py': 'py', + 'zeta': 't', + 'delta': 'pt', +} + BETA0_COLUMNS = ['x', 'px', 'y', 'py', 't', 'pt', - 'dx', 'dy', 'dpx', 'dpy', 'ddx', 'ddpx', 'ddy', 'ddpy', 'wx', 'phix', - 'wy', 'phiy', 'mu1', 'mu2', 'mu3', 'dmu1', 'dmu2', 'dmu3', 'r11', - 'r12', 'r21', 'r22', 'alfa11', 'alfa12', 'alfa13', 'alfa21', - 'alfa22', 'alfa23', 'alfa31', 'alfa32', 'alfa33', 'beta11', - 'beta12', 'beta13', 'beta21', 'beta22', 'beta23', 'beta31', - 'beta32', 'beta33', 'gama11', 'gama12', 'gama13', 'gama21', - 'gama22', 'gama23', 'gama31', 'gama32', 'gama33'] + 'dx', 'dy', 'dpx', 'dpy', 'ddx', 'ddpx', 'ddy', 'ddpy', 'wx', 'phix', + 'wy', 'phiy', 'mu1', 'mu2', 'mu3', 'dmu1', 'dmu2', 'dmu3', 'r11', + 'r12', 'r21', 'r22', 'alfa11', 'alfa12', 'alfa13', 'alfa21', + 'alfa22', 'alfa23', 'alfa31', 'alfa32', 'alfa33', 'beta11', + 'beta12', 'beta13', 'beta21', 'beta22', 'beta23', 'beta31', + 'beta32', 'beta33', 'gama11', 'gama12', 'gama13', 'gama21', + 'gama22', 'gama23', 'gama31', 'gama32', 'gama33'] + +TW_BASE_COLUMNS = ['s', 'beta11', 'beta22', 'beta33', 'alfa11', 'alfa22', 'alfa33', + 'gama11', 'gama22', 'gama33', 'x', 'px', 'y', 'py', 't', 'pt', + 'dx', 'dy', 'dpx', 'dpy', 'mu1', 'mu2', 'mu3'] + +OPTFUN_QUANTITIES = ['beta11', 'beta22', 'alfa11', 'alfa22', 'gama11', 'gama22', + 'dx', 'dy', 'dpx', 'dpy', 'mu1', 'mu2'] + +CHROM_COLUMNS = ['dmu1', 'dmu2', 'dmu3', 'Dx', 'Dpx', 'Dy', + 'Dpy', 'ddx', 'ddpx', 'ddy', 'ddpy', 'wx', 'wy', 'phix', 'phiy'] + +COUPLING_COLUMNS = ['alfa12', 'alfa13', 'alfa21', 'alfa23', 'alfa31', 'alfa32', + 'beta12', 'beta13', 'beta21', 'beta23', 'beta31', 'beta32', + 'gama12', 'gama13', 'gama21', 'gama23', 'gama31', 'gama32', + 'f1001', 'f1010', 'r11', 'r12', 'r21', 'r22'] + +TPSA_ALLOWED_TARGETS = { 'beta11', 'beta22', 'alfa11', 'alfa22', 'dx', 'dpx', 'dy', 'dpy', + 'mu1', 'mu2', 'x', 'px', 'y', 'py', 't', 'pt' } + +XSUITE_MADNG_ENV_NAME = "_xsuite_matching_env" + +def dp2pt(dp, beta0): + """Convert relative momentum deviation dp/p to transverse momentum pt/p. + + Parameters + ---------- + dp : float + Relative momentum deviation (dp/p, dimensionless). + beta0 : float + Particle relativistic beta (v/c). + + Returns + ------- + float + Transverse momentum relative to total momentum (pt/p, dimensionless). + """ + + _beta0 = 1 / beta0 + return np.sqrt((1 + dp) ** 2 + (_beta0**2 - 1)) - _beta0 class MadngVars: @@ -30,7 +86,17 @@ def __init__(self, mad): self.mad = mad def __setitem__(self, key, value): - setattr(self.mad.MADX, key.replace('.', '_'), value) + #setattr(self.mad.MADX, key.replace('.', '_'), value) + # Check for key if it's a ctpsa or tpsa + + var = f"MADX['{key.replace('.', '_')}']" + is_tpsa = self.mad.send(f"py:send(MAD.typeid.is_tpsa({var}) or MAD.typeid.is_ctpsa({var}))").recv() + if is_tpsa: + self.mad.send(f"{var}:set0(py:recv())").send(value) + else: + self.mad[var] = value + + #Expressions still to be handled, could use the following: # mng.send( # MADX:open_env() @@ -93,24 +159,31 @@ def _build_rdt_script(mng_sequence_name, rdts, columns): def _build_beta0_block_string(tw_kwargs): flag_init = False beta0_keys = [] + beta0_vals = [] for k in tw_kwargs.keys(): if k in BETA0_COLUMNS: beta0_keys.append(k) + beta0_vals.append(tw_kwargs[k]) + flag_init = True + elif k in XS_NG_MAP: + beta0_keys.append(XS_NG_MAP[k]) + beta0_vals.append(tw_kwargs[k]) flag_init = True if flag_init: # Construct beta0 string beta0_str = 'X0 = beta0 {' - for k in beta0_keys: - beta0_str += f'{k} = {tw_kwargs[k]}, ' + for k, v in zip(beta0_keys, beta0_vals): + beta0_str += f'{k} = {v}, ' beta0_str = beta0_str[:-2] + '}, ' else: beta0_str = '' return beta0_str -def _tw_ng(line, rdts=(), normal_form=True, +def _tw_ng(line, rdts=(), normal_form=False, mapdef_twiss=2, mapdef_normal_form=4, - nslice=3, xsuite_tw=True, X0=None, **tw_kwargs): + nslice=3, xsuite_tw=True, X0=None, compute_chromatic_properties=False, + coupling_edw_teng=False, **tw_kwargs): _action = ActionTwissMadng(line, { "rdts": rdts, @@ -134,25 +207,28 @@ def _tw_ng(line, rdts=(), normal_form=True, raise NotImplementedError('TwissTable as init not implemented.') X0_str = _build_beta0_block_string(tw_kwargs) else: - X0_str = f'X0={X0}, ' + X0_str = f'X0 = {X0}, ' + + if start is not None and end is None or start is None and end is not None: + raise ValueError('Start and end must be specified together.') - if not (start is None and end is None and init is None) \ - and not (start is not None and end is not None and X0_str != ''): - raise ValueError('Start and end must be specified together, as well as initial conditions, if open twiss is used.') + if start is not None and end is not None and X0_str == '': + raise ValueError('Initial conditions must be specified when start and end are given.') + + # if not (start is None and end is None and init is None) \ + # and not (start is not None and end is not None and X0_str != ''): + # raise ValueError('Start and end must be specified together, as well as initial conditions, if open twiss is used.') full_twiss_str = '' - tw_columns = ['s', 'beta11', 'beta22', 'alfa11', 'alfa22', - 'x', 'px', 'y', 'py', 't', 'pt', - 'dx', 'dy', 'dpx', 'dpy', 'mu1', 'mu2'] - if start is None and end is None: - extended_tw_columns = ['beta12', 'beta21', 'alfa12', 'alfa21', - 'wx', 'wy', 'phix', 'phiy', 'dmu1', 'dmu2', - 'f1001', 'f1010', 'r11', 'r12', 'r21', 'r22', - ] - full_twiss_str = f"mapdef={mapdef_twiss}, implicit=true, nslice={nslice}, misalgn=true, coupling=true, chrom=true" - tw_columns += extended_tw_columns + tw_columns = TW_BASE_COLUMNS.copy() + full_twiss_str = f"implicit=true, nslice={nslice}, misalign=true, coupling={str(coupling_edw_teng).lower()}, chrom={str(compute_chromatic_properties).lower()}" + + if coupling_edw_teng: + tw_columns += COUPLING_COLUMNS + if compute_chromatic_properties: + tw_columns += CHROM_COLUMNS columns = tw_columns + list(rdts) send_cmd = _build_column_send_script(columns) @@ -160,7 +236,6 @@ def _tw_ng(line, rdts=(), normal_form=True, if len(rdts) > 0: mng_script = _build_rdt_script(mng._sequence_name, rdts, columns) else: - # If start/end -> range, if only start: cycle - twiss - cycle back range_str = '' if start is not None and end is not None: @@ -175,6 +250,7 @@ def _tw_ng(line, rdts=(), normal_form=True, '''} ''' + send_cmd) + mng.send(mng_script) out = mng.recv('columns') @@ -190,8 +266,13 @@ def _tw_ng(line, rdts=(), normal_form=True, xs_tw_kwargs = { NG_XS_MAP.get(k, k): v for k, v in tw_kwargs.items() } - tw = line.twiss(method='4d', reverse=False, **xs_tw_kwargs) - else: + try: + tw = line.twiss(method='4d', reverse=False, **xs_tw_kwargs) + except Exception as e: + print(f"Error occurred while getting twiss: {e}\nContinuing without Xsuite Twiss") + xsuite_tw = False + + if not xsuite_tw: # Handle wrap-around range if i_start > i_end: name_co = np.array(names[i_start:] + names[:i_end + 1] + ('_end_point',)) @@ -237,11 +318,11 @@ def _process_data(data): for nn in rdts: tw[nn] = np.atleast_1d(np.squeeze(out_dct[nn]))[:-1] - if start is None or end is None: - temp_x = tw.wx_ng * np.exp(1j*2*np.pi*tw.phix_ng) + if compute_chromatic_properties: + temp_x = tw.wx_ng * np.exp(1j*2*np.pi*tw.phix_ng) # potentially multiply phix_ng by 2 (MAD-NG 1.1.8) tw['ax_ng'] = np.imag(temp_x) tw['bx_ng'] = np.real(temp_x) - temp_y = tw.wy_ng * np.exp(1j*2*np.pi*tw.phiy_ng) + temp_y = tw.wy_ng * np.exp(1j*2*np.pi*tw.phiy_ng) # potentially multiply phiy_ng by 2 (MAD-NG 1.1.8) tw['ay_ng'] = np.imag(temp_y) tw['by_ng'] = np.real(temp_y) del tw['phix_ng'] @@ -310,15 +391,18 @@ def madng_get_init(line, at): if not hasattr(line.tracker, '_madng'): line.build_madng_model() mng = line.tracker._madng + if at == xt.START: + at = "1" + else: + at = f"'{at}'" mng.send(f""" local observed in MAD.element.flags - {mng._sequence_name}:select(observed, {{list = {{'{at}'}}}}) + {mng._sequence_name}:select(observed, {{list = {{{at}}}}}) twpart, mf = twiss {{sequence = {mng._sequence_name}, observe = 1, savemap = true, info = 2}} - {mng._sequence_name}.X0 = twpart['{at}'].__map - {mng._sequence_name}.X0.status = "Aset" ! Bug corrected in next version + {XSUITE_MADNG_ENV_NAME}.X0 = twpart[{at}].__map """) - return f"{mng._sequence_name}.X0" + return f"{XSUITE_MADNG_ENV_NAME}.X0" def _survey_ng(line): if not hasattr(line.tracker, '_madng'): @@ -390,11 +474,353 @@ def prepare(self, force=False): if init is not None and start is not None and end is not None: assert isinstance(init, xt.TwissTable) self.X0 = madng_get_init(self.line, at=start) + elif init is not None: + assert isinstance(init, xt.TwissTable) + self.X0 = madng_get_init(self.line, at=xt.START) + self._alredy_prepared = True def run(self): return self.line.madng_twiss(xsuite_tw = False, X0=self.X0, **self.tw_kwargs) +class ActionTwissMadngTPSA(Action): + def __init__(self, line, vary_names, targets = [], tw_kwargs={}, twiss_flag=True, **kwargs): + self.line = line + self.vary_names = vary_names + self.targets = targets + self.target_locations = None # set in prepare + self.tw_kwargs = tw_kwargs + self.tw_kwargs.update(kwargs) + self.twiss_flag = twiss_flag + self._already_prepared = False + + def prepare(self, force=False): + ''' + Prepare the MAD-NG TPSA matching environment. + This method sets up the MAD-NG environment for TPSA matching by + configuring the initial conditions, setting target locations, and quantities + based on the provided targets. + To achieve that, arrays and maps are created within MAD-NG to keep track of + the target locations, quantities and differential algebraic maps. + + Parameters + ---------- + force : bool, optional + If True, forces re-preparation even if already prepared. Default is False. + + Raises + ------ + ValueError + If the target quantity is not allowed with TPSA matching + or if start and end are provided without initial conditions. + ''' + + if self._already_prepared and not force: + return + + init = self.tw_kwargs.get('init', None) + start = self.tw_kwargs.get('start', None) + end = self.tw_kwargs.get('end', None) + + if init is None: + if start is not None and end is not None: + raise ValueError('If start and end are specified, init must be provided as TwissTable.') + else: + init = self.line.madng_twiss(**self.tw_kwargs) + self.tw_kwargs.update({'init': init}) + + assert isinstance(init, xt.TwissTable) + madng_init_flag = "x_ng" in init.cols + quantity_appendix = "_ng" if madng_init_flag else "" + + if not hasattr(self.line.tracker, '_madng'): + self.line.build_madng_model() + mng = self.line.tracker._madng + self.mng = mng + + self.target_locations = set() + targets_map_str = '' + self.target_quantities = set() + + xs_ng_target_map = XSUITE_MADNG_ENV_NAME + '.xs_ng_target_map = {}\n' + + for i, target in enumerate(self.targets): + if isinstance(target.tar, tuple): + qty = target.tar[0][:-3] if target.tar[0].endswith('_ng') else XS_NG_MAP[target.tar[0]] + assert qty in TPSA_ALLOWED_TARGETS, f"Target quantity '{target.tar[0]}' not allowed with TPSA matching." + self.target_locations.add(target.tar[1]) + + aux_str = '' + if qty in OPTFUN_QUANTITIES: + aux_str = 'optfun = true' + elif qty in ['x', 'px', 'y', 'py', 't', 'pt']: + aux_str = f'orbit = {['x', 'px', 'y', 'py', 't', 'pt'].index(qty) + 1}' + + # set string for quantity mapping + loc to save in madng + targets_map_str += f"{XSUITE_MADNG_ENV_NAME}.targets_map[{i+1}] = {{ loc = '{target.tar[1]}', qty = '{qty}', {aux_str} }}\n" + self.target_quantities.add(target.tar[0]) + xs_ng_target_map += f"{XSUITE_MADNG_ENV_NAME}.xs_ng_target_map['{target.tar[0]}'] = '{qty}'\n" + + elif hasattr(target, "start") and hasattr(target, "end"): + start_loc_str = 'nil' + end_loc_str = end + if target.start != '__ele_start__': + self.target_locations.add(target.start) + start_loc_str = target.start + if target.end != '__ele_stop__': + self.target_locations.add(target.end) + end_loc_str = target.end + + qty = target.var[:-3] if target.var.endswith('_ng') else XS_NG_MAP[target.var] + assert qty in TPSA_ALLOWED_TARGETS, f"Target quantity '{target.var}' not allowed with TPSA matching." + targets_map_str += f"{XSUITE_MADNG_ENV_NAME}.targets_map[{i+1}] = {{ loc = '{end_loc_str}', qty = '{qty}', loc_start = '{start_loc_str}', optfun = true }}\n" + self.target_quantities.add(target.var) + xs_ng_target_map += f"{XSUITE_MADNG_ENV_NAME}.xs_ng_target_map['{target.var}'] = '{qty}'\n" + self.target_locations = list(self.target_locations) + + # set coords (TODO: delta) + beta0 = self.line.particle_ref.beta0[0] + + start_loc = 0 if start is None else start + init_coord = np.array([init['x' + quantity_appendix, start_loc], + init['px' + quantity_appendix, start_loc], + init['y' + quantity_appendix, start_loc], + init['py' + quantity_appendix, start_loc], + 0, + 0]) + init_coord[4] = init['t' + quantity_appendix, start_loc] if madng_init_flag else init['zeta', start_loc] * beta0 + init_coord[5] = init['pt' + quantity_appendix, start_loc] if madng_init_flag else dp2pt(init['delta', start_loc], beta0) + + coord_str = '' + part_order = ['x', 'px', 'y', 'py', 't', 'pt'] + for part, val in zip(part_order, init_coord): + if np.abs(val) > 1e-12: + coord_str += f'X0.{part} = {val} ' + + param_assignment_str = '' + param_list_str = '{' + for name in self.vary_names: + param_assignment_str += f"MADX['{name}'] = MADX['{name}'] + X0['{name}'] \n" + param_list_str += f"'{name}', " + param_list_str = param_list_str[:-2] + '}' + + observables_str = '{' + if start is not None and end is not None: + observables_str += f"'{start}', '{end}', " + + if self.target_locations is not None: + for loc in self.target_locations: + if loc != start and loc != end: + observables_str += f"'{loc}', " + observables_str += '}' + + qty_str = '{' + for qty in self.target_quantities: + qty_str += f"'{qty}', " + qty_str += '}' + + if madng_init_flag: + init_cond_str = f"local B0 = MAD.beta0 {{ beta11 = {init['beta11' + quantity_appendix, start_loc]},\n" + f"beta22 = {init['beta22' + quantity_appendix, start_loc]},\n"\ + + f"alfa11 = {init['alfa11' + quantity_appendix, start_loc]},\n" + f"alfa22 = {init['alfa22' + quantity_appendix, start_loc]},\n"\ + + f"dx = {init['dx' + quantity_appendix, start_loc]},\n" + f"dpx = {init['dpx' + quantity_appendix, start_loc]},\n"\ + + f"dy = {init['dy' + quantity_appendix, start_loc]},\n" + f"dpy = {init['dpy' + quantity_appendix, start_loc]}\n }}" + else: + init_cond_str = f"local B0 = MAD.beta0 {{ beta11 = {init['betx', start_loc]},\n" + f"beta22 = {init['bety', start_loc]},\n"\ + + f"alfa11 = {init['alfx', start_loc]},\n" + f"alfa22 = {init['alfy', start_loc]},\n"\ + + f"dx = {init['dx', start_loc]},\n" + f"dpx = {init['dpx', start_loc]},\n"\ + + f"dy = {init['dy', start_loc]},\n" + f"dpy = {init['dpy', start_loc]}\n }}" + + mng_init_str = r''' + ''' + XSUITE_MADNG_ENV_NAME + r''' = {} -- to avoid variable name clashes + local obs_flag = MAD.element.flags.observed + + local pts=''' + observables_str + r''' + + ''' + mng._sequence_name + r''':select(obs_flag, {list=pts}) + + local params = ''' + param_list_str + r''' + + local X0 = MAD.damap { + nv=6, -- number of variables + mo=2, -- max order of variables + np=#params, -- number of parameters + po=1, -- max order of parameters + pn=params, -- parameter names + } + + ''' + coord_str + r''' + + -- Converting to TPSA (mutating type) + for _, v in ipairs(params) do + MADX[v] = MADX[v] + X0[v] + end + + ''' + init_cond_str + r''' + + local map1 = MAD.gphys.bet2map(B0, X0) + + -- Maps target locations to damaps + ''' + XSUITE_MADNG_ENV_NAME + r'''.target_loc_map = table.new(0, ''' + str(len(self.target_locations)) + r''') + -- Array of targets with additional info (location, quantity, orbit/optical function) + ''' + XSUITE_MADNG_ENV_NAME + r'''.targets_map = table.new(''' + str(len(self.targets)) + r''', 0) + -- List of target quantities (suitable for MAD-NG) + ''' + XSUITE_MADNG_ENV_NAME + r'''.tar_qtys = ''' + qty_str + r''' + -- Initial map for tracking/twiss + ''' + XSUITE_MADNG_ENV_NAME + r'''.init_X0_map = map1 + -- Defining targets array + ''' + targets_map_str + r''' + -- Mapping from xsuite quantity names to madng quantity names + ''' + xs_ng_target_map + r''' + ''' + + mng.send(mng_init_str) + + self._already_prepared = True + + def run(self): + ''' + Execute the MAD-NG TPSA matching action. + This method performs either a Twiss or Track operation in MAD-NG + depending if quantities can be calculated through tracking or not. + It retrieves the results and constructs a TwissTable with the requested + target quantities at the specified target locations. + + Returns + ------- + xt.TwissTable + A TwissTable containing the results of the Twiss or Track operation + with the requested target quantities at the specified target locations. + ''' + + if self._already_prepared is False: + self.prepare() + + start = self.tw_kwargs.get('start', None) + end = self.tw_kwargs.get('end', None) + + operation = 'twiss' if self.twiss_flag else 'track' + + range_str = '' + if start is not None and end is not None: + range_str = f"range = '{start}/{end}', " + mng_track_str = ( + f"local trk, mflw = MAD.{operation}{{\n" + f" sequence={self.mng._sequence_name},\n" + f" X0={XSUITE_MADNG_ENV_NAME}.init_X0_map,\n" + f" savemap=true,\n" + f" observe=1,\n" + f" {range_str}\n" + f"}}\n" + f"{XSUITE_MADNG_ENV_NAME}.trk = trk\n" + ) + + self.mng.send(mng_track_str) + + loc_map_str = '' + for i, loc in enumerate(self.target_locations): + loc_map_str += f"{XSUITE_MADNG_ENV_NAME}.target_loc_map['{loc}'] = {XSUITE_MADNG_ENV_NAME}.trk['{loc}'].__map\n" + + + if self.twiss_flag: + mng_table_str = r''' + local trk = ''' + XSUITE_MADNG_ENV_NAME + r'''.trk + ''' + loc_map_str + r''' + py:send(trk) + ''' + + res = self.mng.send(mng_table_str).recv(XSUITE_MADNG_ENV_NAME + '.trk').to_df() + res = xt.TwissTable(res) + + # Add quantities which are not present yet with the name corresponding to the target quantity + for qty in self.target_quantities: + if qty not in res.cols: + res[qty] = res[qty[:-3] if qty.endswith('_ng') else XS_NG_MAP[qty]] + else: + loc_map_str = '' + for loc in self.target_locations: + loc_map_str += f"{XSUITE_MADNG_ENV_NAME}.target_loc_map['{loc}'] = {XSUITE_MADNG_ENV_NAME}.trk['{loc}'].__map\n" + + mng_table_str = r''' + local trk = ''' + XSUITE_MADNG_ENV_NAME + r'''.trk + -- Add derived columns which are not present due to Track calculation + -- and use target names as defined from the user (Xsuite) + for _, tar in ipairs( ''' + XSUITE_MADNG_ENV_NAME + r'''.tar_qtys ) do + if not trk[''' + XSUITE_MADNG_ENV_NAME + r'''.xs_ng_target_map[tar]] then + trk:addcol(tar, \ri -> MAD.gphys.optfun(trk[ri].__map, ''' + XSUITE_MADNG_ENV_NAME + r'''.xs_ng_target_map[tar] .. '_')) + end + end + + -- Save damaps + ''' + loc_map_str + r''' + py:send(trk) + ''' + + res = self.mng.send(mng_table_str).recv(XSUITE_MADNG_ENV_NAME + '.trk').to_df() + res = xt.TwissTable(res) + + return res + + def acquire_jacobian(self): + ''' + Acquire the Jacobian matrix for the TPSA matching targets and variables. + This method computes the Jacobian matrix for the specified targets and + variables using MAD-NG's TPSA capabilities. It constructs + the Jacobian matrix by evaluating the sensitivity of each target quantity + with respect to each variable using MAD-NG's optfun function (optical functions) + or by direct extraction from the TPSA (orbit). + + Returns + ------- + np.ndarray + A 2D NumPy array representing the Jacobian matrix, where each row + corresponds to a target and each column corresponds to a variable. + ''' + + tar_len_str = f"local tarlen = {len(self.targets)}\n" + vary_len_str = f"local varylen = {len(self.vary_names)}\n" + jac_decl_str = f"{XSUITE_MADNG_ENV_NAME}.jac = MAD.matrix(tarlen, varylen)\n" + + mng_str = tar_len_str + vary_len_str + jac_decl_str + r''' + -- Compute Jacobian + for i, target in ipairs( ''' + XSUITE_MADNG_ENV_NAME + r'''.targets_map ) do + local map = ''' + XSUITE_MADNG_ENV_NAME + r'''.target_loc_map[target.loc] + for j = 1, map.np(map), 1 do + -- Quantity which can be calculated with optfun + if target.optfun then + -- If loc_start (phase advance) is defined, we provide initial map + if target.loc_start then + local a0 = ''' + XSUITE_MADNG_ENV_NAME + r'''.target_loc_map[target.loc_start] + ''' + XSUITE_MADNG_ENV_NAME + r'''.jac[(i-1)*varylen + j] = MAD.gphys.optfun(map, target.qty .. "_", j, 1, a0) + else + ''' + XSUITE_MADNG_ENV_NAME + r'''.jac[(i-1)*varylen + j] = MAD.gphys.optfun(map, target.qty .. "_", j, 1) + end + + -- Orbit Quantity + elseif target.orbit then + -- Index 7 + j corresponds to one index of order 0 plus six (for x,px,y,py,t,pt) of order 1 plus parameter index + ''' + XSUITE_MADNG_ENV_NAME + r'''.jac[(i-1)*varylen + j] = map[target.orbit]:get(7 + j) + end + end + end + + py:send(''' + XSUITE_MADNG_ENV_NAME + r'''.jac) + ''' + + self.mng.send(mng_str) + + jac = np.array(self.mng.recv()) + + return jac + + def cleanup(self): + # Need to reconvert TPSAs to normal values + if self._already_prepared is True: + mng_str = '' + for var_name in self.vary_names: + mng_str += f"MADX['{var_name}'] = MADX['{var_name}']:get0()\n" + mng_str += f"{XSUITE_MADNG_ENV_NAME}.X0 = nil\n" + self.mng.send(mng_str) + self._already_prepared = False def line_to_madng(line, sequence_name='seq', temp_fname=None, keep_files=False, **kwargs): @@ -410,11 +836,16 @@ def line_to_madng(line, sequence_name='seq', temp_fname=None, keep_files=False, from pymadng import MAD - mng = MAD(**kwargs) + nocharge = str(kwargs.pop('nocharge', True)).lower() + + mng = MAD(mad_path="/home/babreufi/git/MAD-NG-1.1.5/bin/mad", **kwargs) mng.send(f""" local mad_func = loadfile('{temp_fname}.mad', nil, MADX) assert(mad_func) mad_func() + MAD.option.nocharge = {nocharge} + MADX.option.rbarc = true + {XSUITE_MADNG_ENV_NAME} = {{}} -- to avoid variable name clashes """) mng._init_madx_data = madx_seq diff --git a/xtrack/match.py b/xtrack/match.py index 040983844..19b604a1b 100644 --- a/xtrack/match.py +++ b/xtrack/match.py @@ -535,7 +535,6 @@ def __repr__(self): return f'TargetPhaseAdv({self.var}({self.end} - {self.start}), val={self.value}, tol={self.tol}, weight={self.weight})' def compute(self, tw): - if self.end == '__ele_stop__': mu_1 = tw[self.var, -1] else: @@ -812,6 +811,52 @@ def run(self, allow_failure=True): out.line = self.line return out +class MeritFunctionLine(xd.MeritFunctionForMatch): + def __init__( + self, + merit_function_match, + use_tpsa=False, + ): + + self.vary = merit_function_match.vary + self.targets = merit_function_match.targets + self.actions = merit_function_match.actions + self.return_scalar = merit_function_match.return_scalar + self.call_counter = merit_function_match.call_counter + self.verbose = merit_function_match.verbose + self.tw_kwargs = merit_function_match.tw_kwargs + self.steps_for_jacobian = merit_function_match.steps_for_jacobian + self.found_point_within_tol = merit_function_match.found_point_within_tol + self.zero_if_met = merit_function_match.zero_if_met + self.show_call_counter = merit_function_match.show_call_counter + self.check_limits = merit_function_match.check_limits + self.use_tpsa = use_tpsa + + def get_jacobian(self, x=None, f0=None): + if self.use_tpsa: + return self.get_jacobian_tpsa() + else: + return super().get_jacobian(x, f0=f0) + + def get_jacobian_tpsa(self): + from .madng_interface import ActionTwissMadngTPSA + action = None + for a in self.actions: + if isinstance(a, ActionTwissMadngTPSA): + action = a + break + if action is None: + raise RuntimeError('No ActionTwissMadngTPSA found in actions for TPSA jacobian computation') + + + + jacobian = action.acquire_jacobian() + + for i, tar in enumerate(self.targets): + jacobian[i] *= tar.weight + + return jacobian + class OptimizeLine(xd.Optimize): def __init__(self, line, vary, targets, assert_within_tol=True, @@ -821,7 +866,7 @@ def __init__(self, line, vary, targets, assert_within_tol=True, n_steps_max=20, default_tol=None, solver=None, check_limits=True, action_twiss=None, action_twiss_ng=None, - name="", + use_tpsa=False, name="", **kwargs): if hasattr(targets, 'values'): # dict like @@ -840,17 +885,44 @@ def __init__(self, line, vary, targets, assert_within_tol=True, aux_vary = [] + # part of the `auxvar` experimental code + # if isinstance(tt.value, (GreaterThan, LessThan)): + # if tt.value.mode == 'auxvar': + # aux_vary.append(tt.value.gen_vary(aux_vary_container)) + # aux_vary_container[aux_vary[-1].name] = 0 + # val = tt.runeval() + # if val > 0: + # aux_vary_container[aux_vary[-1].name] = np.sqrt(val) + + if not isinstance(vary, (list, tuple)): + vary = [vary] + + vary = list(vary) + aux_vary + + vary_flatten = _flatten_vary(vary) + _complete_vary_with_info_from_line(vary_flatten, line) + for tt in targets_flatten: # Handle action if tt.action is None: if (isinstance(tt.tar, tuple) and tt.tar[0].endswith('_ng')) or ( - isinstance(tt, TargetRelPhaseAdvance) and tt.var.endswith('_ng')): + isinstance(tt, TargetRelPhaseAdvance) and tt.var.endswith('_ng')) or use_tpsa: if action_twiss_ng is None: - from .madng_interface import ActionTwissMadng - action_twiss_ng = ActionTwissMadng( - line, {}, **kwargs) - action_twiss_ng.prepare() + if use_tpsa: + twiss_flag_ng = any(isinstance(tar, xt.TargetRelPhaseAdvance) for tar in targets_flatten) + + from .madng_interface import ActionTwissMadngTPSA + + action_twiss_ng = ActionTwissMadngTPSA( + line, [v.name for v in vary_flatten], targets_flatten, {}, twiss_flag=twiss_flag_ng, **kwargs) + action_twiss_ng.prepare() + + else: + from .madng_interface import ActionTwissMadng + action_twiss_ng = ActionTwissMadng( + line, {}, **kwargs) + action_twiss_ng.prepare() tt.action = action_twiss_ng else: if action_twiss is None: @@ -906,22 +978,6 @@ def __init__(self, line, vary, targets, assert_within_tol=True, else: tt.tol = default_tol - # part of the `auxvar` experimental code - # if isinstance(tt.value, (GreaterThan, LessThan)): - # if tt.value.mode == 'auxvar': - # aux_vary.append(tt.value.gen_vary(aux_vary_container)) - # aux_vary_container[aux_vary[-1].name] = 0 - # val = tt.runeval() - # if val > 0: - # aux_vary_container[aux_vary[-1].name] = np.sqrt(val) - - if not isinstance(vary, (list, tuple)): - vary = [vary] - - vary = list(vary) + aux_vary - - vary_flatten = _flatten_vary(vary) - _complete_vary_with_info_from_line(vary_flatten, line) xd.Optimize.__init__(self, vary=vary_flatten, targets=targets_flatten, solver=solver, @@ -931,9 +987,12 @@ def __init__(self, line, vary, targets, assert_within_tol=True, restore_if_fail=restore_if_fail, check_limits=check_limits, name=name) + + _err = MeritFunctionLine(self._err, use_tpsa=use_tpsa) self.line = line self.action_twiss = action_twiss self.default_tol = default_tol + self._err = _err def clone(self, add_targets=None, add_vary=None, remove_targets=None, remove_vary=None, @@ -991,6 +1050,40 @@ def clone(self, add_targets=None, add_vary=None, def plot(self, *args, **kwargs): return self.action_twiss.run().plot(*args, **kwargs) + def step( + self, + n_steps=1, + take_best=True, + enable_target=None, + enable_vary=None, + enable_vary_name=None, + disable_target=None, + disable_vary=None, + disable_vary_name=None, + rcond=None, + sing_val_cutoff=None, + verbose=None, + broyden=False, + cleanup_madng_tpsa=False, + ): + super().step(n_steps, take_best, enable_target, enable_vary, enable_vary_name, disable_target, + disable_vary, disable_vary_name, rcond, sing_val_cutoff, verbose, broyden) + + if cleanup_madng_tpsa and self._err.use_tpsa: + for a in self.actions: + if hasattr(a, "cleanup"): + a.cleanup() + break + + def solve(self, n_steps=None, verbose=None, take_best=True, rcond=None, sing_val_cutoff=None, broyden=False, cleanup_madng_tpsa=True): + super().solve(n_steps, verbose, take_best, rcond, sing_val_cutoff, broyden) + + if cleanup_madng_tpsa and self._err.use_tpsa: + for a in self.actions: + if hasattr(a, "cleanup"): + a.cleanup() + break + def _flatten_vary(vary): vary_flatten = [] for vv in vary: