Skip to content

Commit c6cf701

Browse files
committed
13_03_2025: added a mini demo
1 parent 1202c1b commit c6cf701

File tree

4 files changed

+837
-1
lines changed

4 files changed

+837
-1
lines changed

README.MD

Lines changed: 340 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,343 @@ $ pip install git+https://github.com/i-a-morozov/sympint.git@main
1212

1313
[https://i-a-morozov.github.io/sympint/](https://i-a-morozov.github.io/sympint/)
1414

15+
# Demo
16+
17+
18+
```python
19+
# In this demo construction of symplectic integrators is illustated for basic accelerator elements
20+
```
21+
22+
23+
```python
24+
# Import
25+
26+
import torch
27+
import jax
28+
29+
# Exact solutions
30+
31+
from model.library.transformations import drift
32+
from model.library.transformations import quadrupole
33+
from model.library.transformations import bend
34+
35+
# Function iterations
36+
37+
from sympint import nest
38+
from sympint import fold
39+
40+
# Integrators and composer
41+
42+
from sympint import sequence
43+
from sympint import midpoint
44+
from sympint import tao
45+
```
46+
47+
48+
```python
49+
# Set data type
50+
51+
jax.config.update("jax_enable_x64", True)
52+
```
53+
54+
55+
```python
56+
# Set device
57+
58+
device, *_ = jax.devices('cpu')
59+
jax.config.update('jax_default_device', device)
60+
```
61+
62+
63+
```python
64+
# Define Hamiltonial functions for accelerator elements
65+
66+
def h_drif(qs, ps, t, dp, *args):
67+
qx, qy = qs
68+
px, py = ps
69+
return 1/2*(px**2 + py**2)/(1 + dp)
70+
71+
def h_quad(qs, ps, t, kn, ks, dp, *args):
72+
qx, qy = qs
73+
px, py = ps
74+
return 1/2*(px**2 + py**2)/(1 + dp) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy
75+
76+
def h_bend(qs, ps, t, rd, kn, ks, dp, *args):
77+
qx, qy = qs
78+
px, py = ps
79+
return 1/2*(px**2 + py**2)/(1 + dp) - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy
80+
```
81+
82+
83+
```python
84+
# Set parameters
85+
86+
ti = torch.tensor(0.0, dtype=torch.float64)
87+
dt = torch.tensor(0.1, dtype=torch.float64)
88+
rd = torch.tensor(25.0, dtype=torch.float64)
89+
kn = torch.tensor(2.0, dtype=torch.float64)
90+
ks = torch.tensor(0.1, dtype=torch.float64)
91+
dp = torch.tensor(0.001, dtype=torch.float64)
92+
```
93+
94+
95+
```python
96+
# Hamiltonian conservation (drif)
97+
98+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
99+
qs = torch.stack([qx, qy])
100+
ps = torch.stack([px, py])
101+
hi = h_drif(qs, ps, ti, dp)
102+
103+
(qx, px, qy, py) = drift(x, dp, dt)
104+
qs = torch.stack([qx, qy])
105+
ps = torch.stack([px, py])
106+
hf = h_drif(qs, ps, ti, dp)
107+
108+
print(torch.allclose(hi, hf, rtol=1.0E-16, atol=1.0E-16))
109+
```
110+
111+
True
112+
113+
114+
115+
```python
116+
# Hamiltonian conservation (quad)
117+
118+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
119+
qs = torch.stack([qx, qy])
120+
ps = torch.stack([px, py])
121+
hi = h_quad(qs, ps, ti, kn, ks, dp)
122+
123+
(qx, px, qy, py) = quadrupole(x, kn, ks, dp, dt)
124+
qs = torch.stack([qx, qy])
125+
ps = torch.stack([px, py])
126+
hf = h_quad(qs, ps, ti, kn, ks, dp)
127+
128+
print(torch.allclose(hi, hf, rtol=1.0E-16, atol=1.0E-16))
129+
```
130+
131+
True
132+
133+
134+
135+
```python
136+
# Hamiltonian conservation (bend)
137+
138+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
139+
qs = torch.stack([qx, qy])
140+
ps = torch.stack([px, py])
141+
hi = h_bend(qs, ps, ti, rd, kn, ks, dp)
142+
143+
(qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
144+
qs = torch.stack([qx, qy])
145+
ps = torch.stack([px, py])
146+
hf = h_bend(qs, ps, ti, rd, kn, ks, dp)
147+
148+
print(torch.allclose(hi, hf, rtol=1.0E-16, atol=1.0E-16))
149+
```
150+
151+
True
152+
153+
154+
155+
```python
156+
# To illustrate (multi-map) split and (Yoshida) composition explicit symplectic integrator consider the following split
157+
# h = h1 + h2 = 1/2*(px**2 + py**2)/(1 + dp) - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy
158+
# h1 = 1/2*(px**2 + py**2)/(1 + dp)
159+
# qx = qx + dt*px/(1 + dp)
160+
# px = px
161+
# qy = qy + dt*py/(1 + dp)
162+
# py = py
163+
# h2 = - qx*dp/rd + qx**2/(2*rd**2) + 1/2*kn*(qx**2 - qy**2) - ks*qx*qy
164+
# qx = qx
165+
# px = px + dt*(dp/rd - qx/rd**2 - kn*qx + ks*qy)
166+
# qy = qy
167+
# py = py + dt*(kn*qy + ks*qx)
168+
169+
def fa(x, dt, rd, kn, ks, dp):
170+
qx, qy, px, py = x
171+
return jax.numpy.stack([qx + dt*px/(1 + dp), qy + dt*py/(1 + dp), px, py])
172+
173+
def fb(x, dt, rd, kn, ks, dp):
174+
qx, qy, px, py = x
175+
return jax.numpy.stack([qx, qy, px + dt*(dp/rd - qx/rd**2 - kn*qx + ks*qy), py + dt*(kn*qy + ks*qx)])
176+
```
177+
178+
179+
```python
180+
# Yoshida (bend)
181+
182+
# Generate integration step
183+
184+
step = fold(sequence(0, 1, [fa, fb], merge=True))
185+
186+
# Evaluate integration step
187+
188+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
189+
qs = torch.stack([qx, qy])
190+
ps = torch.stack([px, py])
191+
qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())
192+
qs, ps = qsps.reshape(2, -1)
193+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
194+
195+
qsps = step(qsps, dt.item(), rd.item(), kn.item(), ks.item(), dp.item())
196+
qs, ps = qsps.reshape(2, -1)
197+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
198+
print()
199+
200+
# Evaluate exact solution
201+
202+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
203+
qs = torch.stack([qx, qy])
204+
ps = torch.stack([px, py])
205+
QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
206+
qs, ps = QsPs.reshape(2, -1)
207+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
208+
209+
(qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
210+
qs = torch.stack([qx, qy])
211+
ps = torch.stack([px, py])
212+
QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
213+
qs, ps = QsPs.reshape(2, -1)
214+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
215+
print()
216+
217+
# Compare
218+
219+
print(qsps)
220+
print(QsPs)
221+
print(jax.numpy.linalg.norm(qsps - QsPs))
222+
```
223+
224+
8.030437562437563e-05
225+
8.030437539389926e-05
226+
227+
8.030437562437563e-05
228+
8.03043756243756e-05
229+
230+
[ 0.00999747 -0.0049949 -0.00105058 -0.0003978 ]
231+
[ 0.00999746 -0.00499491 -0.00105066 -0.00039784]
232+
9.384719084898185e-08
233+
234+
235+
236+
```python
237+
# Midpoint (bend)
238+
239+
# Generate integration step
240+
241+
step = fold(sequence(0, 1, [midpoint(h_bend, ns=1)], merge=False))
242+
243+
# Evaluate integration step
244+
245+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
246+
qs = torch.stack([qx, qy])
247+
ps = torch.stack([px, py])
248+
qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())
249+
qs, ps = qsps.reshape(2, -1)
250+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
251+
252+
qsps = step(qsps, dt.item(), ti.item(), rd.item(), kn.item(), ks.item(), dp.item())
253+
qs, ps = qsps.reshape(2, -1)
254+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
255+
print()
256+
257+
# Evaluate exact solution
258+
259+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
260+
qs = torch.stack([qx, qy])
261+
ps = torch.stack([px, py])
262+
QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
263+
qs, ps = QsPs.reshape(2, -1)
264+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
265+
266+
(qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
267+
qs = torch.stack([qx, qy])
268+
ps = torch.stack([px, py])
269+
QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
270+
qs, ps = QsPs.reshape(2, -1)
271+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
272+
print()
273+
274+
# Compare
275+
276+
print(qsps)
277+
print(QsPs)
278+
print(jax.numpy.linalg.norm(qsps - QsPs))
279+
```
280+
281+
8.030437562437563e-05
282+
8.030437562437561e-05
283+
284+
8.030437562437563e-05
285+
8.03043756243756e-05
286+
287+
[ 0.00999747 -0.0049949 -0.00105061 -0.00039781]
288+
[ 0.00999746 -0.00499491 -0.00105066 -0.00039784]
289+
5.8710763609389174e-08
290+
291+
292+
293+
```python
294+
# Tao (bend)
295+
296+
# Generate integration step
297+
298+
step = fold(sequence(0, 1, [tao(h_bend, binding=0.0)], merge=False))
299+
300+
# Evaluate integration step
301+
302+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
303+
qs = torch.stack([qx, qy])
304+
ps = torch.stack([px, py])
305+
qsps = jax.numpy.array(torch.hstack([qs, ps]).tolist())
306+
qs, ps = qsps.reshape(2, -1)
307+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
308+
309+
qsps = step(qsps, dt.item(), ti.item(), rd.item(), kn.item(), ks.item(), dp.item())
310+
qs, ps = qsps.reshape(2, -1)
311+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
312+
print()
313+
314+
# Evaluate exact solution
315+
316+
(qx, px, qy, py) = x = torch.tensor([0.01, 0.001, -0.005, 0.0005], dtype=torch.float64)
317+
qs = torch.stack([qx, qy])
318+
ps = torch.stack([px, py])
319+
QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
320+
qs, ps = QsPs.reshape(2, -1)
321+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
322+
323+
(qx, px, qy, py) = bend(x, rd, kn, ks, dp, dt)
324+
qs = torch.stack([qx, qy])
325+
ps = torch.stack([px, py])
326+
QsPs = jax.numpy.array(torch.hstack([qs, ps]).tolist())
327+
qs, ps = QsPs.reshape(2, -1)
328+
print(h_bend(qs, ps, ti.item(), rd.item(), kn.item(), ks.item(), dp.item()))
329+
print()
330+
331+
# Compare
332+
333+
print(qsps)
334+
print(QsPs)
335+
print(jax.numpy.linalg.norm(qsps - QsPs))
336+
```
337+
338+
8.030437562437563e-05
339+
8.030437585721544e-05
340+
341+
8.030437562437563e-05
342+
8.03043756243756e-05
343+
344+
[ 0.00999747 -0.0049949 -0.00105064 -0.00039783]
345+
[ 0.00999746 -0.00499491 -0.00105066 -0.00039784]
346+
2.50766882131807e-08
347+
348+
349+
350+
```python
351+
352+
```
353+
354+

0 commit comments

Comments
 (0)