Skip to content

Commit 1f1a15d

Browse files
committed
06_20_2025: non-autonomous hamiltonian integration example
1 parent a120b71 commit 1f1a15d

File tree

2 files changed

+279
-0
lines changed

2 files changed

+279
-0
lines changed
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
{
2+
"cells": [
3+
{
4+
"attachments": {},
5+
"cell_type": "markdown",
6+
"id": "556562f3-8ece-4517-8c93-ee5e2fc29131",
7+
"metadata": {},
8+
"source": [
9+
"# Example-06: Non-autonomous hamiltonian integration"
10+
]
11+
},
12+
{
13+
"cell_type": "code",
14+
"execution_count": 1,
15+
"id": "c9e9a3dc-a1c5-4f5b-b40d-dc97f4759d40",
16+
"metadata": {},
17+
"outputs": [],
18+
"source": [
19+
"# In this example integration of non-autonomous hamiltonian is illustrated\n",
20+
"# Such integration has only limmited support, since function iteration tools do not carry time\n",
21+
"# Thus, only one second order integration step can be performed and time should be adjusted manually after each step, i.e. using normal python loop or custom scan body\n",
22+
"\n",
23+
"# Support for more general case would require to modife function iterations, for example, instead of the following loop:\n",
24+
"# for _ in range(n): x = f(x, *args)\n",
25+
"# nesting should correspond to:\n",
26+
"# for _ in range(n): x = f(x, dt, t, *args) ; t = t + dt\n",
27+
"# Similary, fold (and other functions)should be modified to carry time\n",
28+
"\n",
29+
"# Instead, it is possible to use extended phase space with midpoint or tao integrators"
30+
]
31+
},
32+
{
33+
"cell_type": "code",
34+
"execution_count": 2,
35+
"id": "d92f8ee7-ea12-4a11-921d-a0b44783f3ac",
36+
"metadata": {},
37+
"outputs": [],
38+
"source": [
39+
"# Import \n",
40+
"\n",
41+
"import jax\n",
42+
"from jax import Array\n",
43+
"from jax import jit\n",
44+
"from jax import vmap\n",
45+
"\n",
46+
"from sympint import fold\n",
47+
"from sympint import nest\n",
48+
"from sympint import midpoint\n",
49+
"from sympint import sequence\n",
50+
"\n",
51+
"jax.numpy.set_printoptions(linewidth=256, precision=12)"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": 3,
57+
"id": "24d31165-c4b0-4e24-bd10-25695e6e008a",
58+
"metadata": {},
59+
"outputs": [],
60+
"source": [
61+
"# Set data type\n",
62+
"\n",
63+
"jax.config.update(\"jax_enable_x64\", True)"
64+
]
65+
},
66+
{
67+
"cell_type": "code",
68+
"execution_count": 4,
69+
"id": "0bbcbcc1-5fd7-4dc5-91d2-04157d0774d1",
70+
"metadata": {},
71+
"outputs": [],
72+
"source": [
73+
"# Set device\n",
74+
"\n",
75+
"device, *_ = jax.devices('cpu')\n",
76+
"jax.config.update('jax_default_device', device)"
77+
]
78+
},
79+
{
80+
"cell_type": "code",
81+
"execution_count": 5,
82+
"id": "f1208433-15a8-403d-b1ae-45d8f5b00b6f",
83+
"metadata": {},
84+
"outputs": [],
85+
"source": [
86+
"# Set parameters\n",
87+
"\n",
88+
"si = jax.numpy.array(0.0)\n",
89+
"ds = jax.numpy.array(0.01)\n",
90+
"kn = jax.numpy.array(1.0)"
91+
]
92+
},
93+
{
94+
"cell_type": "code",
95+
"execution_count": 6,
96+
"id": "4d85b8da-a6f6-425d-8b13-069af52e8541",
97+
"metadata": {},
98+
"outputs": [],
99+
"source": [
100+
"# Set initial condition\n",
101+
"\n",
102+
"qs = jax.numpy.array([0.1, 0.1])\n",
103+
"ps = jax.numpy.array([0.0, 0.0])\n",
104+
"x = jax.numpy.hstack([qs, ps])"
105+
]
106+
},
107+
{
108+
"cell_type": "code",
109+
"execution_count": 7,
110+
"id": "8b4ab5d9-676d-450d-8e9e-4927881bc542",
111+
"metadata": {},
112+
"outputs": [],
113+
"source": [
114+
"# Define hamiltonian\n",
115+
"\n",
116+
"def hamiltonian(qs, ps, s, kn, *args):\n",
117+
" q_x, q_y = qs\n",
118+
" p_x, p_y = ps\n",
119+
" return 1/2*(p_x**2 + p_y**2) + 1/2*kn*(1 + jax.numpy.cos(s))*(q_x**2 + q_y**2)"
120+
]
121+
},
122+
{
123+
"cell_type": "code",
124+
"execution_count": 8,
125+
"id": "dc84c3f8-14ce-4344-b5d4-faa013098b27",
126+
"metadata": {},
127+
"outputs": [],
128+
"source": [
129+
"# Set implicit midpoint integration step\n",
130+
"\n",
131+
"integrator = jit(fold(sequence(0, 0, [midpoint(hamiltonian, ns=2**4)], merge=False)))"
132+
]
133+
},
134+
{
135+
"cell_type": "code",
136+
"execution_count": 9,
137+
"id": "8c2ef87f-263b-47a6-a8f1-1f2242fb2773",
138+
"metadata": {},
139+
"outputs": [
140+
{
141+
"name": "stdout",
142+
"output_type": "stream",
143+
"text": [
144+
"[ 0.017983795895 0.017983795895 -0.133154567382 -0.133154567382]\n"
145+
]
146+
}
147+
],
148+
"source": [
149+
"# Perform integration with explicit time update\n",
150+
"\n",
151+
"time = si\n",
152+
"data = x\n",
153+
"for _ in range(10**2):\n",
154+
" data = integrator(data, ds, time, kn)\n",
155+
" time = time + ds\n",
156+
"print(data)"
157+
]
158+
},
159+
{
160+
"cell_type": "code",
161+
"execution_count": 10,
162+
"id": "2c5650e3-bf09-4d95-acbf-110d56001e47",
163+
"metadata": {},
164+
"outputs": [],
165+
"source": [
166+
"# Define hamiltonian (extended)\n",
167+
"\n",
168+
"def extended(qs, ps, s, kn, *args):\n",
169+
" q_x, q_y, q_t = qs\n",
170+
" p_x, p_y, p_t = ps\n",
171+
" return p_t + 1/2*(p_x**2 + p_y**2) + 1/2*kn*(1 + jax.numpy.cos(q_t))*(q_x**2 + q_y**2)"
172+
]
173+
},
174+
{
175+
"cell_type": "code",
176+
"execution_count": 11,
177+
"id": "184e803c-716e-4410-89fc-705ccc3e50ac",
178+
"metadata": {},
179+
"outputs": [],
180+
"source": [
181+
"# Set extended initial condition\n",
182+
"\n",
183+
"Qs = jax.numpy.concat([qs, si.reshape(-1)])\n",
184+
"Ps = jax.numpy.concat([ps, -hamiltonian(qs, ps, si, kn).reshape(-1)])\n",
185+
"X = jax.numpy.hstack([Qs, Ps])"
186+
]
187+
},
188+
{
189+
"cell_type": "code",
190+
"execution_count": 12,
191+
"id": "e00ec64a-601b-4f56-9404-f42d399a1616",
192+
"metadata": {},
193+
"outputs": [],
194+
"source": [
195+
"# Set implicit midpoint integration step using extended hamiltonian\n",
196+
"\n",
197+
"integrator = jit(fold(sequence(0, 0, [midpoint(extended, ns=2**4)], merge=False)))"
198+
]
199+
},
200+
{
201+
"cell_type": "code",
202+
"execution_count": 13,
203+
"id": "b2137273-e1a7-485d-8746-3601849015c6",
204+
"metadata": {},
205+
"outputs": [
206+
{
207+
"name": "stdout",
208+
"output_type": "stream",
209+
"text": [
210+
"[ 0.017983795895 0.017983795895 1. -0.133154567382 -0.133154567382 -0.018228323463]\n"
211+
]
212+
}
213+
],
214+
"source": [
215+
"# Set and compile element\n",
216+
"\n",
217+
"element = jit(nest(10**2, integrator))\n",
218+
"out = element(X, ds, si, kn)\n",
219+
"print(out)"
220+
]
221+
},
222+
{
223+
"cell_type": "code",
224+
"execution_count": null,
225+
"id": "886d34e5-f90c-473e-b6fc-f99d7b534d71",
226+
"metadata": {},
227+
"outputs": [],
228+
"source": []
229+
}
230+
],
231+
"metadata": {
232+
"colab": {
233+
"collapsed_sections": [
234+
"myt0_gMIOq7b",
235+
"5d97819c"
236+
],
237+
"name": "03_frequency.ipynb",
238+
"provenance": []
239+
},
240+
"kernelspec": {
241+
"display_name": "Python 3 (ipykernel)",
242+
"language": "python",
243+
"name": "python3"
244+
},
245+
"language_info": {
246+
"codemirror_mode": {
247+
"name": "ipython",
248+
"version": 3
249+
},
250+
"file_extension": ".py",
251+
"mimetype": "text/x-python",
252+
"name": "python",
253+
"nbconvert_exporter": "python",
254+
"pygments_lexer": "ipython3",
255+
"version": "3.12.1"
256+
},
257+
"latex_envs": {
258+
"LaTeX_envs_menu_present": true,
259+
"autoclose": false,
260+
"autocomplete": true,
261+
"bibliofile": "biblio.bib",
262+
"cite_by": "apalike",
263+
"current_citInitial": 1,
264+
"eqLabelWithNumbers": true,
265+
"eqNumInitial": 1,
266+
"hotkeys": {
267+
"equation": "Ctrl-E",
268+
"itemize": "Ctrl-I"
269+
},
270+
"labels_anchors": false,
271+
"latex_user_defs": false,
272+
"report_style_numbering": false,
273+
"user_envs_cfg": false
274+
}
275+
},
276+
"nbformat": 4,
277+
"nbformat_minor": 5
278+
}

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ JAX composable symplectic integrators.
1212
examples/example-02.ipynb
1313
examples/example-03.ipynb
1414
examples/example-04.ipynb
15+
examples/example-05.ipynb
1516

1617
.. toctree::
1718
:caption: API:

0 commit comments

Comments
 (0)