Skip to content

Commit ba5087c

Browse files
committed
Add some tests for julia codegen
1 parent 70d3a7e commit ba5087c

File tree

1 file changed

+193
-0
lines changed

1 file changed

+193
-0
lines changed

tests/test_julia_codegen.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
import pytest
2+
from gotranx.schemes import get_scheme
3+
from gotranx.codegen import JuliaCodeGenerator
4+
from gotranx.codegen import RHSArgument
5+
from gotranx.ode import make_ode
6+
7+
8+
@pytest.fixture(scope="module")
9+
def ode(trans, parser):
10+
expr = """
11+
parameters(a=0)
12+
parameters("My component",
13+
sigma=ScalarParam(12.0, description="Some description"),
14+
rho=21.0,
15+
beta=2.4
16+
)
17+
states("My component", x=1.0, y=2.0,z=3.05)
18+
19+
expressions("My component")
20+
rhoz = rho - z
21+
dy_dt = x*rhoz - y # millivolt
22+
dx_dt = sigma*(-x + y)
23+
betaz = beta*z
24+
dz_dt = -betaz + x*y
25+
"""
26+
tree = parser.parse(expr)
27+
return make_ode(*trans.transform(tree), name="lorentz")
28+
29+
30+
@pytest.fixture(scope="module")
31+
def codegen(ode) -> JuliaCodeGenerator:
32+
return JuliaCodeGenerator(ode)
33+
34+
35+
def test_julia_codegen_initial_state_values(codegen: JuliaCodeGenerator):
36+
assert codegen.initial_state_values() == (
37+
"\nfunction init_state_values!(states)"
38+
"\n #="
39+
"\n x=1.0, z=3.05, y=2.0"
40+
"\n =#"
41+
"\n states[1] = 1.0"
42+
"\n states[2] = 3.05"
43+
"\n states[3] = 2.0"
44+
"\nend"
45+
"\n"
46+
)
47+
48+
49+
def test_julia_codegen_parameter_index(codegen: JuliaCodeGenerator):
50+
assert codegen.parameter_index() == (
51+
"# Parameter index"
52+
"\nfunction parameter_index(name::String)"
53+
"\n"
54+
'\n if name == "a"'
55+
"\n return 1"
56+
"\n"
57+
"\n"
58+
'\n elseif name == "beta"'
59+
"\n return 2"
60+
"\n"
61+
"\n"
62+
'\n elseif name == "rho"'
63+
"\n return 3"
64+
"\n"
65+
"\n"
66+
'\n elseif name == "sigma"'
67+
"\n return 4"
68+
"\n"
69+
"\n end"
70+
"\n return -1"
71+
"\nend"
72+
)
73+
74+
75+
def test_julia_codegen_initial_parameter_values(codegen: JuliaCodeGenerator):
76+
assert codegen.initial_parameter_values() == (
77+
"\nfunction init_parameter_values!(parameters)"
78+
"\n #="
79+
"\n a=0, beta=2.4, rho=21.0, sigma=12.0"
80+
"\n =#"
81+
"\n parameters[1] = 0"
82+
"\n parameters[2] = 2.4"
83+
"\n parameters[3] = 21.0"
84+
"\n parameters[4] = 12.0"
85+
"\nend"
86+
"\n"
87+
)
88+
89+
90+
@pytest.mark.parametrize(
91+
"order, arguments",
92+
[
93+
(
94+
RHSArgument.stp,
95+
("states, t, parameters"),
96+
),
97+
(
98+
RHSArgument.spt,
99+
("states, parameters, t"),
100+
),
101+
(
102+
RHSArgument.tsp,
103+
("t, states, parameters"),
104+
),
105+
],
106+
)
107+
def test_julia_codegen_rhs(order: str, arguments: str, codegen: JuliaCodeGenerator):
108+
assert codegen.rhs(order=order) == (
109+
f"\nfunction rhs({arguments}, values)"
110+
"\n"
111+
"\n # Assign states"
112+
"\n x = states[1]"
113+
"\n z = states[2]"
114+
"\n y = states[3]"
115+
"\n"
116+
"\n # Assign parameters"
117+
"\n a = parameters[1]"
118+
"\n beta = parameters[2]"
119+
"\n rho = parameters[3]"
120+
"\n sigma = parameters[4]"
121+
"\n"
122+
"\n # Assign expressions"
123+
"\n betaz = beta .* z"
124+
"\n rhoz = rho - z"
125+
"\n dx_dt = sigma .* (-x + y)"
126+
"\n values[1] = dx_dt"
127+
"\n dz_dt = -betaz + x .* y"
128+
"\n values[2] = dz_dt"
129+
"\n dy_dt = rhoz .* x - y"
130+
"\n values[3] = dy_dt"
131+
"\nend"
132+
"\n"
133+
)
134+
135+
136+
def test_julia_codegen_explicit_euler(codegen: JuliaCodeGenerator):
137+
assert codegen.scheme(get_scheme("explicit_euler")) == (
138+
"\nfunction explicit_euler(states, t, dt, parameters, values)"
139+
"\n"
140+
"\n # Assign states"
141+
"\n x = states[1]"
142+
"\n z = states[2]"
143+
"\n y = states[3]"
144+
"\n"
145+
"\n # Assign parameters"
146+
"\n a = parameters[1]"
147+
"\n beta = parameters[2]"
148+
"\n rho = parameters[3]"
149+
"\n sigma = parameters[4]"
150+
"\n"
151+
"\n # Assign expressions"
152+
"\n betaz = beta .* z"
153+
"\n rhoz = rho - z"
154+
"\n dx_dt = sigma .* (-x + y)"
155+
"\n values[1] = dt .* dx_dt + x"
156+
"\n dz_dt = -betaz + x .* y"
157+
"\n values[2] = dt .* dz_dt + z"
158+
"\n dy_dt = rhoz .* x - y"
159+
"\n values[3] = dt .* dy_dt + y"
160+
"\nend"
161+
"\n"
162+
)
163+
164+
165+
def test_julia_monitored(codegen: JuliaCodeGenerator):
166+
assert codegen.monitor_values() == (
167+
"\nfunction monitor_values(t, states, parameters, values)"
168+
"\n"
169+
"\n # Assign states"
170+
"\n x = states[1]"
171+
"\n z = states[2]"
172+
"\n y = states[3]"
173+
"\n"
174+
"\n # Assign parameters"
175+
"\n a = parameters[1]"
176+
"\n beta = parameters[2]"
177+
"\n rho = parameters[3]"
178+
"\n sigma = parameters[4]"
179+
"\n"
180+
"\n # Assign expressions"
181+
"\n betaz = beta .* z"
182+
"\n values[1] = betaz"
183+
"\n rhoz = rho - z"
184+
"\n values[2] = rhoz"
185+
"\n dx_dt = sigma .* (-x + y)"
186+
"\n values[3] = dx_dt"
187+
"\n dz_dt = -betaz + x .* y"
188+
"\n values[4] = dz_dt"
189+
"\n dy_dt = rhoz .* x - y"
190+
"\n values[5] = dy_dt"
191+
"\nend"
192+
"\n"
193+
)

0 commit comments

Comments
 (0)