Skip to content

Commit 88ac765

Browse files
authored
Merge pull request #8 from EuropeanSpallationSource/sync_w_python2_version
Sync latest changes in the python2 version
2 parents c10d921 + f92abe9 commit 88ac765

32 files changed

+3245
-2881
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ repos:
44
hooks:
55
- id: end-of-file-fixer
66
- id: trailing-whitespace
7-
exclude: '^tests/examples/SNS_Linac/linac_errors/.*|^tests/examples/AccLattice_Tests/.*|^tests/examples/Optimization/.*'
7+
exclude: '^tests/examples/SNS_Linac/linac_errors/.*|^tests/examples/AccLattice_Tests/.*|^tests/examples/Optimization/.*|^.*\.dat'
88

99

1010
- repo: https://github.com/ambv/black

py/orbit/py_linac/lattice/LinacAccNodes.py

100644100755
Lines changed: 82 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,15 @@ def __init__(self, name="quad"):
308308
self.setType("linacQuad")
309309

310310
def fringeIN(node, paramsDict):
311-
# B*rho = 3.335640952*momentum [T*m] if momentum in GeV/c
311+
# B*rho = 3.335640952*momentum/charge [T*m] if momentum in GeV/c
312312
usageIN = node.getUsage()
313313
if not usageIN:
314314
return
315315
bunch = paramsDict["bunch"]
316+
charge = bunch.charge()
316317
momentum = bunch.getSyncParticle().momentum()
317-
kq = node.getParam("dB/dr") / (3.335640952 * momentum)
318+
# ---- The charge sign will be accounted for inside tracking module functions.
319+
kq = node.getParam("dB/dr") / bunch.B_Rho()
318320
poleArr = node.getParam("poles")
319321
klArr = node.getParam("kls")
320322
skewArr = node.getParam("skews")
@@ -329,12 +331,15 @@ def fringeIN(node, paramsDict):
329331
TPB.multpfringeIN(bunch, pole, k, skew)
330332

331333
def fringeOUT(node, paramsDict):
334+
# B*rho = 3.335640952*momentum/charge [T*m] if momentum in GeV/c
332335
usageOUT = node.getUsage()
333336
if not usageOUT:
334337
return
335338
bunch = paramsDict["bunch"]
339+
charge = bunch.charge()
336340
momentum = bunch.getSyncParticle().momentum()
337-
kq = node.getParam("dB/dr") / (3.335640952 * momentum)
341+
# ---- The charge sign will be accounted for inside tracking module functions
342+
kq = node.getParam("dB/dr") / bunch.B_Rho()
338343
poleArr = node.getParam("poles")
339344
klArr = node.getParam("kls")
340345
skewArr = node.getParam("skews")
@@ -393,27 +398,32 @@ def initialize(self):
393398
for i in range(nParts):
394399
self.setLength(lengthStep, i)
395400
"""
396-
#=============================================
397-
# This is an old TEAPOT-like implementation
398-
# of the Quad slicing.
399-
#=============================================
400-
lengthIN = (self.getLength()/(nParts - 1))/2.0
401-
lengthOUT = (self.getLength()/(nParts - 1))/2.0
402-
lengthStep = lengthIN + lengthOUT
403-
self.setLength(lengthIN,0)
404-
self.setLength(lengthOUT,nParts - 1)
405-
for i in range(nParts-2):
406-
self.setLength(lengthStep,i+1)
407-
"""
401+
#=============================================
402+
# This is an old TEAPOT-like implementation
403+
# of the Quad slicing.
404+
#=============================================
405+
lengthIN = (self.getLength()/(nParts - 1))/2.0
406+
lengthOUT = (self.getLength()/(nParts - 1))/2.0
407+
lengthStep = lengthIN + lengthOUT
408+
self.setLength(lengthIN,0)
409+
self.setLength(lengthOUT,nParts - 1)
410+
for i in range(nParts-2):
411+
self.setLength(lengthStep,i+1)
412+
"""
408413

409414
def track(self, paramsDict):
410415
"""
411416
The Quad Combined Function TEAPOT class implementation
412417
of the AccNode class track(probe) method.
413418
"""
414419
bunch = paramsDict["bunch"]
420+
charge = bunch.charge()
415421
momentum = bunch.getSyncParticle().momentum()
416-
kq = self.getParam("dB/dr") / (3.335640952 * momentum)
422+
# ---- The sign of dB/dr will be delivered to tracking module
423+
# ---- functions as kq.
424+
# ---- The charge sign will be accounted for inside tracking module
425+
# ---- functions.
426+
kq = self.getParam("dB/dr") / bunch.B_Rho()
417427
nParts = self.getnParts()
418428
index = self.getActivePartIndex()
419429
length = self.getLength(index)
@@ -447,33 +457,33 @@ def track(self, paramsDict):
447457
self.tracking_module.quad2(bunch, step / 4)
448458
self.tracking_module.quad1(bunch, step / 4, kq)
449459
"""
450-
#=============================================
451-
# This is an old TEAPOT-like implementation
452-
# of the Quad tracking.
453-
#=============================================
454-
if(index == 0):
455-
self.tracking_module.quad1(bunch, length, kq)
456-
return
457-
if(index > 0 and index < (nParts-1)):
458-
self.tracking_module.quad2(bunch, length/2.0)
459-
for i in range(len(poleArr)):
460-
pole = poleArr[i]
461-
kl = klArr[i]/(nParts - 1)
462-
skew = skewArr[i]
463-
TPB.multp(bunch,pole,kl,skew)
464-
self.tracking_module.quad2(bunch, length/2.0)
465-
self.tracking_module.quad1(bunch, length, kq)
466-
return
467-
if(index == (nParts-1)):
468-
self.tracking_module.quad2(bunch, length)
469-
for i in range(len(poleArr)):
470-
pole = poleArr[i]
471-
kl = klArr[i]*kq*length/(nParts - 1)
472-
skew = skewArr[i]
473-
TPB.multp(bunch,pole,kl,skew)
474-
self.tracking_module.quad2(bunch, length)
475-
self.tracking_module.quad1(bunch, length, kq)
476-
"""
460+
#=============================================
461+
# This is an old TEAPOT-like implementation
462+
# of the Quad tracking.
463+
#=============================================
464+
if(index == 0):
465+
self.tracking_module.quad1(bunch, length, kq)
466+
return
467+
if(index > 0 and index < (nParts-1)):
468+
self.tracking_module.quad2(bunch, length/2.0)
469+
for i in range(len(poleArr)):
470+
pole = poleArr[i]
471+
kl = klArr[i]/(nParts - 1)
472+
skew = skewArr[i]
473+
TPB.multp(bunch,pole,kl,skew)
474+
self.tracking_module.quad2(bunch, length/2.0)
475+
self.tracking_module.quad1(bunch, length, kq)
476+
return
477+
if(index == (nParts-1)):
478+
self.tracking_module.quad2(bunch, length)
479+
for i in range(len(poleArr)):
480+
pole = poleArr[i]
481+
kl = klArr[i]*kq*length/(nParts - 1)
482+
skew = skewArr[i]
483+
TPB.multp(bunch,pole,kl,skew)
484+
self.tracking_module.quad2(bunch, length)
485+
self.tracking_module.quad1(bunch, length, kq)
486+
"""
477487
return
478488

479489
def getTotalField(self, z):
@@ -537,7 +547,17 @@ def fringeIN(node, paramsDict):
537547
TPB.multpfringeIN(bunch, pole, kl, skew)
538548
frinout = 1
539549
TPB.wedgerotate(bunch, e, frinout)
540-
TPB.wedgebendCF(bunch, e, inout, rho, len(poleArr), poleArr, klArr, skewArr, nParts - 1)
550+
TPB.wedgebendCF(
551+
bunch,
552+
e,
553+
inout,
554+
rho,
555+
len(poleArr),
556+
poleArr,
557+
klArr,
558+
skewArr,
559+
nParts - 1,
560+
)
541561
else:
542562
if usageIN:
543563
TPB.bendfringeIN(bunch, rho)
@@ -560,7 +580,17 @@ def fringeOUT(node, paramsDict):
560580
nParts = paramsDict["parentNode"].getnParts()
561581
if e != 0.0:
562582
inout = 1
563-
TPB.wedgebendCF(bunch, e, inout, rho, len(poleArr), poleArr, klArr, skewArr, nParts - 1)
583+
TPB.wedgebendCF(
584+
bunch,
585+
e,
586+
inout,
587+
rho,
588+
len(poleArr),
589+
poleArr,
590+
klArr,
591+
skewArr,
592+
nParts - 1,
593+
)
564594
if usageOUT:
565595
frinout = 0
566596
TPB.wedgerotate(bunch, -e, frinout)
@@ -697,11 +727,11 @@ def track(self, paramsDict):
697727
length = self.getParam("effLength") / nParts
698728
field = self.getParam("B")
699729
bunch = paramsDict["bunch"]
730+
charge = bunch.charge()
700731
syncPart = bunch.getSyncParticle()
701732
momentum = syncPart.momentum()
702733
# dp/p = Q*c*B*L/p p in GeV/c c = 2.99792*10^8/10^9
703-
# Q is used inside kick-method
704-
kick = field * length * 0.299792 / momentum
734+
kick = -field * charge * length * 0.299792 / momentum
705735
self.tracking_module.kick(bunch, kick, 0.0, 0.0)
706736

707737

@@ -742,11 +772,11 @@ def track(self, paramsDict):
742772
length = self.getParam("effLength") / nParts
743773
field = self.getParam("B")
744774
bunch = paramsDict["bunch"]
775+
charge = bunch.charge()
745776
syncPart = bunch.getSyncParticle()
746777
momentum = syncPart.momentum()
747778
# dp/p = Q*c*B*L/p p in GeV/c, c = 2.99792*10^8/10^9
748-
# Q is used inside kick-method
749-
kick = field * length * 0.299792 / momentum
779+
kick = field * charge * length * 0.299792 / momentum
750780
self.tracking_module.kick(bunch, 0, kick, 0.0)
751781

752782

@@ -803,6 +833,7 @@ def track(self, paramsDict):
803833
The Thick Kick class implementation of the AccNode class track(probe) method.
804834
"""
805835
bunch = paramsDict["bunch"]
836+
charge = bunch.charge()
806837
momentum = bunch.getSyncParticle().momentum()
807838
Bx = self.getParam("Bx")
808839
By = self.getParam("By")
@@ -812,9 +843,8 @@ def track(self, paramsDict):
812843
# print "debug name =",self.getName()," Bx=",Bx," By=",By," L=",self.getLength(index)," index=",index
813844
# ==========================================
814845
# dp/p = Q*c*B*L/p p in GeV/c, c = 2.99792*10^8/10^9
815-
# Q is used inside kick-method
816-
kickY = Bx * length * 0.299792 / momentum
817-
kickX = By * length * 0.299792 / momentum
846+
kickY = +Bx * charge * length * 0.299792 / momentum
847+
kickX = -By * charge * length * 0.299792 / momentum
818848
self.tracking_module.drift(bunch, length / 2.0)
819849
self.tracking_module.kick(bunch, kickX, kickY, 0.0)
820850
self.tracking_module.drift(bunch, length / 2.0)

py/orbit/py_linac/lattice/LinacFieldOverlappingNodes.py

100644100755
Lines changed: 43 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ def track(self, paramsDict):
288288
index = self.getActivePartIndex()
289289
part_length = self.getLength(index)
290290
bunch = paramsDict["bunch"]
291+
charge = bunch.charge()
291292
syncPart = bunch.getSyncParticle()
292293
eKin_in = syncPart.kinEnergy()
293294
momentum = syncPart.momentum()
@@ -300,7 +301,10 @@ def track(self, paramsDict):
300301
phase_shift = rfCavity.getPhase() - rfCavity.getDesignPhase()
301302
phase = rfCavity.getFirstGapEtnrancePhase() + phase_shift
302303
# ----------------------------------------
303-
phase = math.fmod(frequency * (arrival_time - designArrivalTime) * 2.0 * math.pi + phase, 2.0 * math.pi)
304+
phase = math.fmod(
305+
frequency * (arrival_time - designArrivalTime) * 2.0 * math.pi + phase,
306+
2.0 * math.pi,
307+
)
304308
if index == 0:
305309
self.part_pos = self.z_min
306310
self.gap_phase_vs_z_arr = [
@@ -314,21 +318,20 @@ def track(self, paramsDict):
314318
Ep = self.getEzFiledInternal(zp, rfCavity, E0L, rf_ampl)
315319
# ------- track through a quad
316320
G = self.getTotalField((zm + z0) / 2)
317-
GP = 0.0
321+
dB_dz = 0.0
318322
if self.useLongField == True:
319-
GP = self.getTotalFieldDerivative((zm + z0) / 2)
323+
dB_dz = self.getTotalFieldDerivative((zm + z0) / 2)
320324
if abs(G) != 0.0:
321-
kq = G / (3.335640952 * momentum)
325+
kq = G / bunch.B_Rho()
322326
# ------- track through a quad
323327
step = part_length / 2
324328
self.tracking_module.quad1(bunch, step / 4.0, kq)
325329
self.tracking_module.quad2(bunch, step / 2.0)
326330
self.tracking_module.quad1(bunch, step / 2.0, kq)
327331
self.tracking_module.quad2(bunch, step / 2.0)
328332
self.tracking_module.quad1(bunch, step / 4.0, kq)
329-
if abs(GP) != 0.0:
330-
kqP = GP / (3.335640952 * momentum)
331-
self.tracking_module.quad3(bunch, step, kqP)
333+
if abs(dB_dz) != 0.0:
334+
self.tracking_module.quad3(bunch, step, dB_dz)
332335
else:
333336
self.tracking_module.drift(bunch, part_length / 2)
334337
self.part_pos += part_length / 2
@@ -344,23 +347,30 @@ def track(self, paramsDict):
344347
# s += " dE= %9.6f "%((eKin_out-eKin_in)*1000.)
345348
# print s
346349
# ---- this part is the debugging ---STOP---
347-
self.cppGapModel.trackBunch(bunch, part_length / 2, Em, E0, Ep, frequency, phase + delta_phase + modePhase)
350+
self.cppGapModel.trackBunch(
351+
bunch,
352+
part_length / 2,
353+
Em,
354+
E0,
355+
Ep,
356+
frequency,
357+
phase + delta_phase + modePhase,
358+
)
348359
# ------- track through a quad
349360
G = self.getTotalField((z0 + zp) / 2)
350-
GP = 0.0
361+
dB_dz = 0.0
351362
if self.useLongField == True:
352-
GP = self.getTotalFieldDerivative((z0 + zp) / 2)
363+
dB_dz = self.getTotalFieldDerivative((z0 + zp) / 2)
353364
if abs(G) != 0.0:
354-
kq = G / (3.335640952 * momentum)
365+
kq = G / bunch.B_Rho()
355366
step = part_length / 2
356367
self.tracking_module.quad1(bunch, step / 4.0, kq)
357368
self.tracking_module.quad2(bunch, step / 2.0)
358369
self.tracking_module.quad1(bunch, step / 2.0, kq)
359370
self.tracking_module.quad2(bunch, step / 2.0)
360371
self.tracking_module.quad1(bunch, step / 4.0, kq)
361-
if abs(GP) != 0.0:
362-
kqP = GP / (3.335640952 * momentum)
363-
self.tracking_module.quad3(bunch, step, kqP)
372+
if abs(dB_dz) != 0.0:
373+
self.tracking_module.quad3(bunch, step, dB_dz)
364374
else:
365375
self.tracking_module.drift(bunch, part_length / 2)
366376
# ---- advance the particle position
@@ -436,7 +446,10 @@ def trackDesign(self, paramsDict):
436446
else:
437447
first_gap_arr_time = rfCavity.getDesignArrivalTime()
438448
# print "debug name=",self.getName()," delta_phase=",frequency*(arrival_time - first_gap_arr_time)*360.0," phase=",phase*180/math.pi
439-
phase = math.fmod(frequency * (arrival_time - first_gap_arr_time) * 2.0 * math.pi + phase, 2.0 * math.pi)
449+
phase = math.fmod(
450+
frequency * (arrival_time - first_gap_arr_time) * 2.0 * math.pi + phase,
451+
2.0 * math.pi,
452+
)
440453
# print "debug design name=",self.getName()," arr_time=",arrival_time," phase=",phase*180./math.pi," E0TL=",E0TL*1.0e+3," freq=",frequency
441454
if index == 0:
442455
self.part_pos = self.z_min
@@ -464,7 +477,15 @@ def trackDesign(self, paramsDict):
464477
# s += " dE= %9.6f "%((eKin_out-eKin_in)*1000.)
465478
# print s
466479
# ---- this part is the debugging ---STOP---
467-
self.cppGapModel.trackBunch(bunch, part_length / 2, Em, E0, Ep, frequency, phase + delta_phase + modePhase)
480+
self.cppGapModel.trackBunch(
481+
bunch,
482+
part_length / 2,
483+
Em,
484+
E0,
485+
Ep,
486+
frequency,
487+
phase + delta_phase + modePhase,
488+
)
468489
self.tracking_module.drift(bunch, part_length / 2)
469490
# ---- advance the particle position
470491
self.part_pos += part_length / 2
@@ -627,16 +648,17 @@ def track(self, paramsDict):
627648
if index == 0:
628649
self.z_value = -self.getLength() / 2
629650
bunch = paramsDict["bunch"]
651+
charge = bunch.charge()
630652
momentum = bunch.getSyncParticle().momentum()
631653
n_steps = int(length / self.z_step) + 1
632654
z_step = length / n_steps
633655
for z_ind in range(n_steps):
634656
z = self.z_value + z_step * (z_ind + 0.5)
635657
G = self.getTotalField(z)
636-
GP = 0.0
658+
dB_dz = 0.0
637659
if self.useLongField == True:
638-
GP = self.getTotalFieldDerivative(z)
639-
kq = G / (3.335640952 * momentum)
660+
dB_dz = self.getTotalFieldDerivative(z)
661+
kq = G / bunch.B_Rho()
640662
if abs(kq) == 0.0:
641663
self.tracking_module.drift(bunch, z_step)
642664
continue
@@ -646,9 +668,8 @@ def track(self, paramsDict):
646668
self.tracking_module.quad1(bunch, z_step / 2.0, kq)
647669
self.tracking_module.quad2(bunch, z_step / 2.0)
648670
self.tracking_module.quad1(bunch, z_step / 4.0, kq)
649-
if abs(GP) != 0.0:
650-
kqP = GP / (3.335640952 * momentum)
651-
self.tracking_module.quad3(bunch, z_step, kqP)
671+
if abs(dB_dz) != 0.0:
672+
self.tracking_module.quad3(bunch, z_step, dB_dz)
652673
self.z_value += length
653674

654675
def getTotalField(self, z_from_center):

0 commit comments

Comments
 (0)