Skip to content

Commit 82e5109

Browse files
committed
06_27_2025: cleanup and docs update
1 parent 993e1c6 commit 82e5109

File tree

4 files changed

+79
-37
lines changed

4 files changed

+79
-37
lines changed

docs/source/examples/example-01.ipynb

Lines changed: 25 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -11,18 +11,17 @@
1111
{
1212
"cell_type": "code",
1313
"execution_count": 1,
14-
"id": "9aca811d-d252-4872-b1d2-84eb2ffb50ef",
14+
"id": "e5b0ff81-f09d-4772-b3b5-a5fa89183572",
1515
"metadata": {},
1616
"outputs": [],
1717
"source": [
18-
"# In this example non-autonomous hamiltonial is integrated using midpoint and tao integrators\n",
19-
"# For explicitly defined hamiltonian function and factory generated one"
18+
"# In this example non-autonomous hamiltonial is integrated"
2019
]
2120
},
2221
{
2322
"cell_type": "code",
2423
"execution_count": 2,
25-
"id": "936385d7-e524-4ceb-b98b-781cd68e9bb6",
24+
"id": "8d36ea78-e74a-4b7c-bea9-b3fb3e912f2d",
2625
"metadata": {},
2726
"outputs": [],
2827
"source": [
@@ -31,7 +30,7 @@
3130
"import jax\n",
3231
"from jax import Array\n",
3332
"from jax import jit\n",
34-
"from jax import vmap\n",
33+
"from jax import jacrev\n",
3534
"\n",
3635
"from elementary import fold\n",
3736
"from elementary import nest\n",
@@ -48,7 +47,7 @@
4847
{
4948
"cell_type": "code",
5049
"execution_count": 3,
51-
"id": "702bdf71-e4d6-49b8-bd42-42a68d52cb21",
50+
"id": "b8fc02ee-9f08-4456-92fe-d561a276a2c0",
5251
"metadata": {},
5352
"outputs": [],
5453
"source": [
@@ -60,7 +59,7 @@
6059
{
6160
"cell_type": "code",
6261
"execution_count": 4,
63-
"id": "fa15f4d1-0702-4969-bafd-63fb23793950",
62+
"id": "dc2e0cd9-8be4-4f2c-a40a-335c41935958",
6463
"metadata": {},
6564
"outputs": [],
6665
"source": [
@@ -73,7 +72,7 @@
7372
{
7473
"cell_type": "code",
7574
"execution_count": 5,
76-
"id": "bb56ad88-26ed-465f-b141-d581cb5cd324",
75+
"id": "fd191c35-b7ce-4038-86dd-2a0f1a56e475",
7776
"metadata": {},
7877
"outputs": [],
7978
"source": [
@@ -87,7 +86,7 @@
8786
{
8887
"cell_type": "code",
8988
"execution_count": 6,
90-
"id": "43d3146a-ae63-4ad0-9971-80c40c8d7986",
89+
"id": "d0bb3c3f-eea3-4894-8a7a-79dcd729b369",
9190
"metadata": {},
9291
"outputs": [],
9392
"source": [
@@ -101,7 +100,7 @@
101100
{
102101
"cell_type": "code",
103102
"execution_count": 7,
104-
"id": "3b57db3d-8b24-4220-bd76-34394cf77dcb",
103+
"id": "d957559c-df5d-430f-b080-74e2327a15c7",
105104
"metadata": {},
106105
"outputs": [],
107106
"source": [
@@ -121,7 +120,7 @@
121120
{
122121
"cell_type": "code",
123122
"execution_count": 8,
124-
"id": "3cdc8a98-2453-4378-b79d-aeb4a5405f25",
123+
"id": "ec7439c5-33c0-4b90-841d-5cb91de3e615",
125124
"metadata": {},
126125
"outputs": [],
127126
"source": [
@@ -135,7 +134,7 @@
135134
{
136135
"cell_type": "code",
137136
"execution_count": 9,
138-
"id": "b1a8cb34-eb7b-4e2e-9b97-99dbee504c36",
137+
"id": "7bf4d7b0-8f7f-4f56-8a37-da51aaa8a8f9",
139138
"metadata": {},
140139
"outputs": [],
141140
"source": [
@@ -147,7 +146,7 @@
147146
{
148147
"cell_type": "code",
149148
"execution_count": 10,
150-
"id": "4c99d4c6-da3b-4330-a48e-bf7f6799895e",
149+
"id": "2cddbce9-6bdc-4775-b704-77a8751fe0d5",
151150
"metadata": {},
152151
"outputs": [
153152
{
@@ -169,7 +168,7 @@
169168
{
170169
"cell_type": "code",
171170
"execution_count": 11,
172-
"id": "a6bfe0bf-dc77-41f7-9553-85c683c42df5",
171+
"id": "19999d5f-4812-49a5-947d-d8ab2dd23921",
173172
"metadata": {},
174173
"outputs": [],
175174
"source": [
@@ -181,7 +180,7 @@
181180
{
182181
"cell_type": "code",
183182
"execution_count": 12,
184-
"id": "a0290e60-2679-457e-932d-49e018ad22ff",
183+
"id": "0197ddbc-306e-473e-8efb-ef7350151b65",
185184
"metadata": {},
186185
"outputs": [
187186
{
@@ -203,13 +202,13 @@
203202
{
204203
"cell_type": "code",
205204
"execution_count": 13,
206-
"id": "f0f85ac6-0edc-4888-83f3-f13aa9fa654e",
205+
"id": "bed740e9-90ae-41cd-9497-4515a1357de8",
207206
"metadata": {},
208207
"outputs": [],
209208
"source": [
210209
"# Define non-autonomous and extended hamiltonian (factory)\n",
211210
"\n",
212-
"def vector(qs:Array, s:Array, kn:Array, *args:Array) -> Array:\n",
211+
"def vector(qs:Array, s:Array, kn:Array, *args:Array) -> tuple[Array, Array, Array]:\n",
213212
" q_x, q_y, q_s = qs\n",
214213
" a_x, a_y, a_s = jax.numpy.zeros_like(qs)\n",
215214
" a_s = - 1/2*kn*(1 + jax.numpy.cos(s))*(q_x**2 + q_y**2)\n",
@@ -227,7 +226,7 @@
227226
{
228227
"cell_type": "code",
229228
"execution_count": 14,
230-
"id": "7281a8c6-e48e-412c-8a7a-4171568cfc3d",
229+
"id": "e0d482f4-339d-447b-b9e4-bf4ea9affe9a",
231230
"metadata": {},
232231
"outputs": [],
233232
"source": [
@@ -241,7 +240,7 @@
241240
{
242241
"cell_type": "code",
243242
"execution_count": 15,
244-
"id": "0f0d5d1c-e319-4903-870b-5a5c889e01ca",
243+
"id": "7c4b69cc-a971-496c-b06a-60e5a5523bf4",
245244
"metadata": {},
246245
"outputs": [],
247246
"source": [
@@ -253,7 +252,7 @@
253252
{
254253
"cell_type": "code",
255254
"execution_count": 16,
256-
"id": "7ee34eb9-5831-499d-acb2-2d7bb47845ff",
255+
"id": "da3b5a65-4622-48ad-98bf-5aeb105efe36",
257256
"metadata": {},
258257
"outputs": [
259258
{
@@ -275,7 +274,7 @@
275274
{
276275
"cell_type": "code",
277276
"execution_count": 17,
278-
"id": "f3e51caf-9d1d-4e4f-9c13-f091eddc705f",
277+
"id": "82028949-c38d-44a5-951f-adb343d576e3",
279278
"metadata": {},
280279
"outputs": [],
281280
"source": [
@@ -287,7 +286,7 @@
287286
{
288287
"cell_type": "code",
289288
"execution_count": 18,
290-
"id": "a7bfdfe6-9287-4301-9db8-b28ab3482e88",
289+
"id": "2fe48996-c406-4f12-af48-bcb6ea190787",
291290
"metadata": {},
292291
"outputs": [
293292
{
@@ -309,7 +308,7 @@
309308
{
310309
"cell_type": "code",
311310
"execution_count": 19,
312-
"id": "3d4ed639-5ede-4c3c-bd0a-29faca01499e",
311+
"id": "951dec84-88db-49fd-980d-3712ee53dac9",
313312
"metadata": {},
314313
"outputs": [
315314
{
@@ -333,13 +332,13 @@
333332
"source": [
334333
"# Diffirentiability (initial condition)\n",
335334
"\n",
336-
"jax.jacrev(element)(X, ds, si, kn)"
335+
"jacrev(element)(X, ds, si, kn)"
337336
]
338337
},
339338
{
340339
"cell_type": "code",
341340
"execution_count": 20,
342-
"id": "a0944868-7d08-4841-9a51-3d71fdec59b5",
341+
"id": "6d84bd97-6ce3-418f-89be-6fb842d1c58d",
343342
"metadata": {},
344343
"outputs": [
345344
{
@@ -356,7 +355,7 @@
356355
"source": [
357356
"# Diffirentiability (parameter)\n",
358357
"\n",
359-
"jax.jacrev(element, argnums=-1)(X, ds, si, kn)"
358+
"jacrev(element, argnums=-1)(X, ds, si, kn)"
360359
]
361360
}
362361
],

docs/source/index.rst

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,56 @@
11
Welcome to elementary's documentation!
22
===================================
33

4-
Generic differentiable accelerator elements modeling in JAX
4+
Generic differentiable accelerator elements modeling in JAX.
5+
Single particle hamiltonian function:
6+
7+
.. math::
8+
9+
H(q_x, q_y, q_s, p_x, p_y, p_s; s) = p_s/\beta - t(s)(q_x p_y - q_y p_x) - (1 + h(s) q_x) (\sqrt{(p_s + 1/beta - \varphi(qs, s))^2 - (p_x - a_x(qs, s))^2 - (p_y - a_y(qs, s))^2 - 1/(\beta \gamma)^2} - a_s)
10+
11+
where :math:`\beta` and :math:`\gamma` are the relativistic factors, :math:`h(s)` is the reference trajectory curvature and :math:`t(s)` is the reference trajectory torsion, :math:`a_x(qs, s)`, :math:`a_y(qs, s)` and :math:`a_s(qs, s)` are the scaled vector potential components, and :math:`\varphi(qs, s)` is the scaled scalar potential.
12+
Additionaly, longitudinal coordinate and momentum are given by:
13+
14+
.. math::
15+
q_s = \frac{s}{\beta} - c t
16+
p_s = \frac{E}{c P} - 1/beta
17+
18+
Common predefined elements are available or you can create your own by specifying scaled potentials and reference trajectory parameters (curvature and torsion).
19+
All but vector potentials arguments are optional for hamiltonian and element construction.
20+
Vector and scalar potentials are assumed to have matching signatures.
21+
22+
.. code-block:: python
23+
def vector(qs:Array, s:Array, *args:Array) -> tuple[Array, Array, Array]:
24+
q_x, q_y, q_s = qs
25+
...
26+
27+
def scalar(qs:Array, s:Array, *args:Array) -> Array:
28+
q_x, q_y, q_s = qs
29+
...
30+
31+
This is also the case for curvature and torsion functions.
32+
Note, same extra arguments as in vector and scalar functions should be passed.
33+
34+
.. code-block:: python
35+
def curvature(s:Array, *args:Array) -> Array:
36+
...
37+
38+
def torsion(s:Array, *args:Array) -> Array:
39+
...
40+
41+
The resulting hamiltonian and element signatures are:
42+
43+
.. code-block:: python
44+
def hamiltonian(qs: Array, ps: Array, s: Array, *args: Array) -> Array:
45+
q_x, q_y, q_s = qs
46+
p_x, p_y, p_s = ps
47+
...
48+
49+
def element(qsps:Array, length:Array, start:Array, *args:Array) -> Array:
50+
qs, ps = jax.numpy.reshape(qsps, (2, -1))
51+
q_x, q_y, q_s = qs
52+
p_x, p_y, p_s = ps
53+
...
554
655
.. toctree::
756
:caption: Examples:

elementary/dipole.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def vector(qs:Array,
9090
final=final)
9191
def dipole(qsps, length, angle, *args):
9292
r = jax.numpy.abs(length)/angle
93-
return element(qsps, length, 0.0, r, *args)
93+
start = jax.numpy.zeros_like(length)
94+
return element(qsps, length, start, r, *args)
9495
return dipole
9596

9697

elementary/hamiltonian.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
"""
1+
r"""
22
Hamiltonian
33
-----------
44
5-
Generic single particle Hamiltonian factory
5+
Generic single particle accelerator Hamiltonian factory
66
77
"""
88
from typing import Callable
@@ -21,13 +21,6 @@ def hamiltonian_factory(vector:Callable[..., tuple[Array, Array, Array]],
2121
"""
2222
Generic single particle Hamiltonian factory
2323
24-
H = H(qs, ps; s) = H(q_x, q_y, q_s, p_x, p_y, p_s; s)
25-
H = p_s/beta - torsion(s)*(q_x*p_y - q_y*p_x) - (1 + curvature(s)*q_x)*(root + a_s(qs, s))
26-
root = sqrt(P_s**2 - P_x**2 - P_y**2 - 1/(beta*gamma)**2)
27-
P_s = p_s + 1/beta - scalar(qs, s)
28-
P_x = p_x - a_x(qs, s)
29-
P_y = p_y - a_y(qs, s)
30-
3124
Parameters
3225
----------
3326
vector: Callable[..., tuple[Array, Array, Array]]

0 commit comments

Comments
 (0)