@@ -12,3 +12,343 @@ $ pip install git+https://github.com/i-a-morozov/sympint.git@main
1212
1313[ https://i-a-morozov.github.io/sympint/ ] ( https://i-a-morozov.github.io/sympint/ )
1414
15+ # Demo
16+
17+
18+ ``` python
19+ # In this demo construction of symplectic integrators is illustated for basic accelerator elements
20+ ```
21+
22+
23+ ``` python
24+ # Import
25+
26+ import torch
27+ import jax
28+
29+ # Exact solutions
30+
31+ from model.library.transformations import drift
32+ from model.library.transformations import quadrupole
33+ from model.library.transformations import bend
34+
35+ # Function iterations
36+
37+ from sympint import nest
38+ from sympint import fold
39+
40+ # Integrators and composer
41+
42+ from sympint import sequence
43+ from sympint import midpoint
44+ from sympint import tao
45+ ```
46+
47+
48+ ``` python
49+ # Set data type
50+
51+ jax.config.update(" jax_enable_x64" , True )
52+ ```
53+
54+
55+ ``` python
56+ # Set device
57+
58+ device, * _ = jax.devices(' cpu' )
59+ jax.config.update(' jax_default_device' , device)
60+ ```
61+
62+
63+ ``` python
64+ # Define Hamiltonial functions for accelerator elements
65+
66+ def h_drif (qs , ps , t , dp , * args ):
67+ qx, qy = qs
68+ px, py = ps
69+ return 1 / 2 * (px** 2 + py** 2 )/ (1 + dp)
70+
71+ def h_quad (qs , ps , t , kn , ks , dp , * args ):
72+ qx, qy = qs
73+ px, py = ps
74+ return 1 / 2 * (px** 2 + py** 2 )/ (1 + dp) + 1 / 2 * kn* (qx** 2 - qy** 2 ) - ks* qx* qy
75+
76+ def h_bend (qs , ps , t , rd , kn , ks , dp , * args ):
77+ qx, qy = qs
78+ px, py = ps
79+ return 1 / 2 * (px** 2 + py** 2 )/ (1 + dp) - qx* dp/ rd + qx** 2 / (2 * rd** 2 ) + 1 / 2 * kn* (qx** 2 - qy** 2 ) - ks* qx* qy
80+ ```
81+
82+
83+ ``` python
84+ # Set parameters
85+
86+ ti = torch.tensor(0.0 , dtype = torch.float64)
87+ dt = torch.tensor(0.1 , dtype = torch.float64)
88+ rd = torch.tensor(25.0 , dtype = torch.float64)
89+ kn = torch.tensor(2.0 , dtype = torch.float64)
90+ ks = torch.tensor(0.1 , dtype = torch.float64)
91+ dp = torch.tensor(0.001 , dtype = torch.float64)
92+ ```
93+
94+
95+ ``` python
96+ # Hamiltonian conservation (drif)
97+
98+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
99+ qs = torch.stack([qx, qy])
100+ ps = torch.stack([px, py])
101+ hi = h_drif(qs, ps, ti, dp)
102+
103+ (qx, px, qy, py) = drift(x, dp, dt)
104+ qs = torch.stack([qx, qy])
105+ ps = torch.stack([px, py])
106+ hf = h_drif(qs, ps, ti, dp)
107+
108+ print (torch.allclose(hi, hf, rtol = 1.0E-16 , atol = 1.0E-16 ))
109+ ```
110+
111+ True
112+
113+
114+
115+ ``` python
116+ # Hamiltonian conservation (quad)
117+
118+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
119+ qs = torch.stack([qx, qy])
120+ ps = torch.stack([px, py])
121+ hi = h_quad(qs, ps, ti, kn, ks, dp)
122+
123+ (qx, px, qy, py) = quadrupole(x, kn, ks, dp, dt)
124+ qs = torch.stack([qx, qy])
125+ ps = torch.stack([px, py])
126+ hf = h_quad(qs, ps, ti, kn, ks, dp)
127+
128+ print (torch.allclose(hi, hf, rtol = 1.0E-16 , atol = 1.0E-16 ))
129+ ```
130+
131+ True
132+
133+
134+
135+ ``` python
136+ # Hamiltonian conservation (bend)
137+
138+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
139+ qs = torch.stack([qx, qy])
140+ ps = torch.stack([px, py])
141+ hi = h_bend(qs, ps, ti, rd, kn, ks, dp)
142+
143+ (qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
144+ qs = torch.stack([qx, qy])
145+ ps = torch.stack([px, py])
146+ hf = h_bend(qs, ps, ti, rd, kn, ks, dp)
147+
148+ print (torch.allclose(hi, hf, rtol = 1.0E-16 , atol = 1.0E-16 ))
149+ ```
150+
151+ True
152+
153+
154+
155+ ``` python
156+ # To illustrate (multi-map) split and (Yoshida) composition explicit symplectic integrator consider the following split
157+ # h = h1 + h2 = 1/2*(px**2 + py**2)/(1 + dp) - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy
158+ # h1 = 1/2*(px**2 + py**2)/(1 + dp)
159+ # qx = qx + dt*px/(1 + dp)
160+ # px = px
161+ # qy = qy + dt*py/(1 + dp)
162+ # py = py
163+ # h2 = - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy
164+ # qx = qx
165+ # px = px + dt*(dp/rd - qx/rd**2 - kn*qx + ks*qy)
166+ # qy = qy
167+ # py = py + dt*(kn*qy + ks*qx)
168+
169+ def fa (x , dt , rd , kn , ks , dp ):
170+ qx, qy, px, py = x
171+ return jax.numpy.stack([qx + dt* px/ (1 + dp), qy + dt* py/ (1 + dp), px, py])
172+
173+ def fb (x , dt , rd , kn , ks , dp ):
174+ qx, qy, px, py = x
175+ return jax.numpy.stack([qx, qy, px + dt* (dp/ rd - qx/ rd** 2 - kn* qx + ks* qy), py + dt* (kn* qy + ks* qx)])
176+ ```
177+
178+
179+ ``` python
180+ # Yoshida (bend)
181+
182+ # Generate integration step
183+
184+ step = fold(sequence(0 , 1 , [fa, fb], merge = True ))
185+
186+ # Evaluate integration step
187+
188+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
189+ qs = torch.stack([qx, qy])
190+ ps = torch.stack([px, py])
191+ qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())
192+ qs, ps = qsps.reshape(2 , - 1 )
193+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
194+
195+ qsps = step(qsps, dt.item(), rd.item(), kn.item(), ks.item(), dp.item())
196+ qs, ps = qsps.reshape(2 , - 1 )
197+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
198+ print ()
199+
200+ # Evaluate exact solution
201+
202+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
203+ qs = torch.stack([qx, qy])
204+ ps = torch.stack([px, py])
205+ QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
206+ qs, ps = QsPs.reshape(2 , - 1 )
207+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
208+
209+ (qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
210+ qs = torch.stack([qx, qy])
211+ ps = torch.stack([px, py])
212+ QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
213+ qs, ps = QsPs.reshape(2 , - 1 )
214+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
215+ print ()
216+
217+ # Compare
218+
219+ print (qsps)
220+ print (QsPs)
221+ print (jax.numpy.linalg.norm(qsps - QsPs))
222+ ```
223+
224+ 8.030437562437563e-05
225+ 8.030437539389926e-05
226+
227+ 8.030437562437563e-05
228+ 8.03043756243756e-05
229+
230+ [ 0.00999747 -0.0049949 -0.00105058 -0.0003978 ]
231+ [ 0.00999746 -0.00499491 -0.00105066 -0.00039784]
232+ 9.384719084898185e-08
233+
234+
235+
236+ ``` python
237+ # Midpoint (bend)
238+
239+ # Generate integration step
240+
241+ step = fold(sequence(0 , 1 , [midpoint(h_bend, ns = 1 )], merge = False ))
242+
243+ # Evaluate integration step
244+
245+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
246+ qs = torch.stack([qx, qy])
247+ ps = torch.stack([px, py])
248+ qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())
249+ qs, ps = qsps.reshape(2 , - 1 )
250+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
251+
252+ qsps = step(qsps, dt.item(), ti.item(), rd.item(), kn.item(), ks.item(), dp.item())
253+ qs, ps = qsps.reshape(2 , - 1 )
254+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
255+ print ()
256+
257+ # Evaluate exact solution
258+
259+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
260+ qs = torch.stack([qx, qy])
261+ ps = torch.stack([px, py])
262+ QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
263+ qs, ps = QsPs.reshape(2 , - 1 )
264+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
265+
266+ (qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
267+ qs = torch.stack([qx, qy])
268+ ps = torch.stack([px, py])
269+ QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
270+ qs, ps = QsPs.reshape(2 , - 1 )
271+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
272+ print ()
273+
274+ # Compare
275+
276+ print (qsps)
277+ print (QsPs)
278+ print (jax.numpy.linalg.norm(qsps - QsPs))
279+ ```
280+
281+ 8.030437562437563e-05
282+ 8.030437562437561e-05
283+
284+ 8.030437562437563e-05
285+ 8.03043756243756e-05
286+
287+ [ 0.00999747 -0.0049949 -0.00105061 -0.00039781]
288+ [ 0.00999746 -0.00499491 -0.00105066 -0.00039784]
289+ 5.8710763609389174e-08
290+
291+
292+
293+ ``` python
294+ # Tao (bend)
295+
296+ # Generate integration step
297+
298+ step = fold(sequence(0 , 1 , [tao(h_bend, binding = 0.0 )], merge = False ))
299+
300+ # Evaluate integration step
301+
302+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
303+ qs = torch.stack([qx, qy])
304+ ps = torch.stack([px, py])
305+ qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())
306+ qs, ps = qsps.reshape(2 , - 1 )
307+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
308+
309+ qsps = step(qsps, dt.item(), ti.item(), rd.item(), kn.item(), ks.item(), dp.item())
310+ qs, ps = qsps.reshape(2 , - 1 )
311+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
312+ print ()
313+
314+ # Evaluate exact solution
315+
316+ (qx, px, qy, py) = x = torch.tensor([0.01 , 0.001 , - 0.005 , 0.0005 ], dtype = torch.float64)
317+ qs = torch.stack([qx, qy])
318+ ps = torch.stack([px, py])
319+ QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
320+ qs, ps = QsPs.reshape(2 , - 1 )
321+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
322+
323+ (qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
324+ qs = torch.stack([qx, qy])
325+ ps = torch.stack([px, py])
326+ QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
327+ qs, ps = QsPs.reshape(2 , - 1 )
328+ print (h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
329+ print ()
330+
331+ # Compare
332+
333+ print (qsps)
334+ print (QsPs)
335+ print (jax.numpy.linalg.norm(qsps - QsPs))
336+ ```
337+
338+ 8.030437562437563e-05
339+ 8.030437585721544e-05
340+
341+ 8.030437562437563e-05
342+ 8.03043756243756e-05
343+
344+ [ 0.00999747 -0.0049949 -0.00105064 -0.00039783]
345+ [ 0.00999746 -0.00499491 -0.00105066 -0.00039784]
346+ 2.50766882131807e-08
347+
348+
349+
350+ ``` python
351+
352+ ```
353+
354+
0 commit comments