diff --git a/xfields/__init__.py b/xfields/__init__.py index e68c3215..1b7319c1 100644 --- a/xfields/__init__.py +++ b/xfields/__init__.py @@ -25,6 +25,7 @@ from .beam_elements.temp_slicer import TempSlicer from .beam_elements.electroncloud import ElectronCloud from .beam_elements.electronlens_interpolated import ElectronLensInterpolated +from .beam_elements.waketracker import WakeTracker from .general import _pkg_root from .config_tools import replace_spacecharge_with_quasi_frozen diff --git a/xfields/beam_elements/element_with_slicer.py b/xfields/beam_elements/element_with_slicer.py index 270b8ab4..714fc9c2 100644 --- a/xfields/beam_elements/element_with_slicer.py +++ b/xfields/beam_elements/element_with_slicer.py @@ -77,6 +77,7 @@ def __init__(self, num_slices=num_slices, # Per bunch, this is N_1 in the paper bunch_spacing_zeta=bunch_spacing_zeta, # This is P in the paper filling_scheme=filling_scheme, + bunch_selection=bunch_selection, num_turns=num_turns, circumference=circumference) @@ -104,19 +105,28 @@ def _initialize_moments( num_slices=None, # Per bunch, this is N_1 in the paper bunch_spacing_zeta=None, # This is P in the paper filling_scheme=None, + bunch_selection=None, num_turns=1, circumference=None): + if filling_scheme is not None: i_last_bunch = np.where(filling_scheme)[0][-1] num_periods = i_last_bunch + 1 else: num_periods = 1 + + if bunch_selection is None: + num_targets = num_periods + else: + num_targets = 1+ np.max(bunch_selection)-np.min(bunch_selection) + self.moments_data = CompressedProfile( moments=self.source_moments + ['result'], zeta_range=zeta_range, num_slices=num_slices, bunch_spacing_zeta=bunch_spacing_zeta, + num_targets = num_targets, num_periods=num_periods, num_turns=num_turns, circumference=circumference, diff --git a/xfields/beam_elements/waketracker/convolution.py b/xfields/beam_elements/waketracker/convolution.py index 17fc2d0a..56e0b343 100644 --- a/xfields/beam_elements/waketracker/convolution.py +++ b/xfields/beam_elements/waketracker/convolution.py @@ -1,4 +1,6 @@ +import time import numpy as np +from matplotlib import pyplot as plt from scipy.constants import e as qe import xobjects as xo @@ -64,13 +66,17 @@ def __init__(self, component, waketracker=None, _flatten=False, log_moments=None def my_rfft(self, data, **kwargs): if type(self._context) in (xo.ContextCpu, xo.ContextCupy): - return self._context.nplike_lib.fft.rfft(data, **kwargs) + if hasattr(self._context,'omp_num_threads') and self._context.omp_num_threads > 1: + kwargs['workers'] = self._context.omp_num_threads + return self._context.splike_lib.fft.rfft(data, **kwargs) else: raise NotImplementedError('Waketacker implemented only for CPU and Cupy') def my_irfft(self, data, **kwargs): if type(self._context) in (xo.ContextCpu, xo.ContextCupy): - return self._context.nplike_lib.fft.irfft(data, **kwargs) + if hasattr(self._context,'omp_num_threads') and self._context.omp_num_threads > 1: + kwargs['workers'] = self._context.omp_num_threads + return self._context.splike_lib.fft.irfft(data, **kwargs) else: raise NotImplementedError('Waketacker implemented only for CPU and Cupy') @@ -81,13 +87,13 @@ def _initialize_conv_data(self, _flatten=False, moments_data=None, beta0=None): self._M_aux = moments_data._M_aux self._N_1 = moments_data._N_1 self._N_S = moments_data._N_S - self._N_T = moments_data._N_S + self._N_T = moments_data._N_T self._BB = 1 # B in the paper # (for now we assume that B=0 is the first bunch in time and the # last one in zeta) self._AA = self._BB - self._N_S - self._CC = self._AA - self._DD = self._BB + self._DD = -1*np.min(self.waketracker.slicer.bunch_selection)+1 + self._CC = self._DD - self._N_T # Build wake matrix self.z_wake = _build_z_wake(moments_data._z_a, moments_data._z_b, @@ -97,6 +103,7 @@ def _initialize_conv_data(self, _flatten=False, moments_data=None, beta0=None): moments_data.dz, self._AA, self._BB, self._CC, self._DD, moments_data._z_P) + assert beta0 is not None # here below I had to add float() to beta0 because when using Cupy # context particles.beta0[0] turns out to be a 0d array. To be checked @@ -246,6 +253,6 @@ def _build_z_wake(z_a, z_b, num_turns, n_aux, m_aux, circumference, dz, z_p = 0 for ii, ll in enumerate(range( - cc - bb + 1, dd - aa)): + int(cc - bb + 1), int(dd - aa))): z_wake[tt, ii * n_aux:(ii + 1) * n_aux] = temp_z + ll * z_p return z_wake diff --git a/xfields/beam_elements/waketracker/waketracker.py b/xfields/beam_elements/waketracker/waketracker.py index e21a6688..e892f076 100644 --- a/xfields/beam_elements/waketracker/waketracker.py +++ b/xfields/beam_elements/waketracker/waketracker.py @@ -1,5 +1,4 @@ from typing import Tuple - import numpy as np from scipy.constants import c as clight @@ -51,6 +50,9 @@ def __init__(self, components, filling_scheme=None, bunch_selection=None, num_turns=1, + fake_coupled_bunch_phase_x = None, + fake_coupled_bunch_phase_y = None, + beta_x = None, beta_y = None, circumference=None, log_moments=None, _flatten=False, @@ -61,11 +63,30 @@ def __init__(self, components, self.components = components self.pipeline_manager = None + self.fake_coupled_bunch_phases = {} + self.betas = {} + if fake_coupled_bunch_phase_x is not None: + self.fake_coupled_bunch_phases['x'] = fake_coupled_bunch_phase_x + assert beta_x is not None and beta_x > 0 + self.betas['x'] = beta_x + if fake_coupled_bunch_phase_y is not None: + self.fake_coupled_bunch_phases['y'] = fake_coupled_bunch_phase_y + assert beta_y is not None and beta_y > 0 + self.betas['y'] = beta_y + if self.fake_coupled_bunch_phases: + assert bunch_selection is not None and filling_scheme is not None + assert bunch_selection, "When faking a coupled bunch mode, only one bunch should be selected as ref." + all_slicer_moments = [] for cc in self.components: assert not hasattr(cc, 'moments_data') or cc.moments_data is None all_slicer_moments += cc.source_moments + if self.fake_coupled_bunch_phases: + for moment_name in self.fake_coupled_bunch_phases.keys(): + if moment_name in all_slicer_moments: + all_slicer_moments.append('p'+moment_name) + self.all_slicer_moments = list(set(all_slicer_moments)) super().__init__( @@ -79,18 +100,18 @@ def __init__(self, components, num_turns=num_turns, circumference=circumference, with_compressed_profile=True, - _context=self.context) + _context=self._context) self._initialize_moments( zeta_range=zeta_range, # These are [a, b] in the paper num_slices=num_slices, # Per bunch, this is N_1 in the paper bunch_spacing_zeta=bunch_spacing_zeta, # This is P in the paper filling_scheme=filling_scheme, + bunch_selection=bunch_selection, num_turns=num_turns, circumference=circumference) self._flatten = _flatten - all_slicer_moments = list(set(all_slicer_moments)) def init_pipeline(self, pipeline_manager, element_name, partner_names): @@ -99,12 +120,11 @@ def init_pipeline(self, pipeline_manager, element_name, partner_names): partner_names=partner_names) def track(self, particles): - # Find first active particle to get beta0 if particles.state[0] > 0: beta0 = particles.beta0[0] else: - i_alive = np.where(particles.state > 0)[0] + i_alive = self._context.nplike_lib.where(particles.state > 0)[0] if len(i_alive) == 0: return i_first = i_alive[0] @@ -122,19 +142,35 @@ def track(self, particles): cc._conv_data._initialize_conv_data(_flatten=self._flatten, moments_data=self.moments_data, beta0=beta0) - # Use common slicer from parent class to measure all moments status = super().track(particles) - if status and status.on_hold == True: return status + if self.fake_coupled_bunch_phases: + self._compute_fake_bunch_moments() + for wf in self.components: wf._conv_data.track(particles, i_slot_particles=self.i_slot_particles, i_slice_particles=self.i_slice_particles, moments_data=self.moments_data) + def _compute_fake_bunch_moments(self): + conjugate_names = {'x':'px','y':'py'} + n_slots = int(self._context.nplike_lib.max(self.slicer.filled_slots))+1 + for moment_name in self.fake_coupled_bunch_phases.keys(): + z_dummy,mom = self.moments_data.get_source_moment_profile(moment_name,0,self.bunch_selection[0]) + z_dummy,mom_conj = self.moments_data.get_source_moment_profile(conjugate_names[moment_name],0,self.bunch_selection[0]) + complex_normalised_moments = mom + (1j*self.betas[moment_name])*mom_conj + slots = self._context.nplike_lib.transpose(self._context.nplike_lib.tile(self.slicer.filled_slots,(len(complex_normalised_moments),1))) + complex_normalised_moments = self._context.nplike_lib.tile(complex_normalised_moments,(n_slots,1)) + all_beam_moments = self._context.nplike_lib.real(complex_normalised_moments*self._context.nplike_lib.exp(1j*self.fake_coupled_bunch_phases[moment_name]*(self.bunch_selection[0]-slots))) + self.moments_data.set_all_beam_moments(moment_name,0,all_beam_moments) + z_dummy,mom = self.moments_data.get_source_moment_profile('num_particles',0,self.bunch_selection[0]) + all_beam_num_particles = self._context.nplike_lib.tile(mom,(n_slots,1)) + self.moments_data.set_all_beam_moments('num_particles',0,all_beam_num_particles) + @property def zeta_range(self): return self.slicer.zeta_range @@ -158,7 +194,7 @@ def num_turns(self): @property def circumference(self): return self.moments_data.circumference - + def __add__(self, other): if other == 0: diff --git a/xfields/slicers/compressed_profile.py b/xfields/slicers/compressed_profile.py index 1895690e..fdcb02ac 100644 --- a/xfields/slicers/compressed_profile.py +++ b/xfields/slicers/compressed_profile.py @@ -220,3 +220,39 @@ def get_moment_profile(self, moment_name, i_turn): i_start_in_moments_data:i_end_in_moments_data]) return z_out, moment_out + + def get_source_moment_profile(self, moment_name, i_turn,i_source): + """ + Get the moment profile for a given turn. + + Parameters + ---------- + moment_name : str + The name of the moment to get + i_turn : int + The turn index, 0 <= i_turn < self.num_turns + + Returns + ------- + z_out : np.ndarray + The z positions within the moment profile + moment_out : np.ndarray + The moment profile + """ + + z_out = self._arr2ctx(np.zeros(self._N_1)) + moment_out = self._arr2ctx(np.zeros(self._N_1)) + i_moment = self.moments_names.index(moment_name) + _z_P = self._z_P or 0 + + z_out = ( + self._z_a + self.dz / 2 + - i_source * _z_P + self.dz * self._arr2ctx(np.arange(self._N_1))) + + i_start_in_moments_data = (self._N_S - i_source - 1) * self._N_aux + i_end_in_moments_data = i_start_in_moments_data + self._N_1 + moment_out = ( + self.data[i_moment, i_turn, + i_start_in_moments_data:i_end_in_moments_data]) + + return z_out, moment_out