-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathrun_lotkavolterra.py
More file actions
278 lines (214 loc) · 8.66 KB
/
run_lotkavolterra.py
File metadata and controls
278 lines (214 loc) · 8.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
"""Lotka-Volterra benchmark.
See makefile for instructions.
"""
import argparse
import functools
import os
import statistics
import timeit
import warnings
from collections.abc import Callable
import diffrax
import jax
import jax.numpy as jnp
import numpy as np
import scipy.integrate
import tqdm
from probdiffeq import ivpsolve, ivpsolvers, taylor
from probdiffeq.util.doc_util import info
def set_jax_config() -> None:
    """Configure JAX for this benchmark.

    Enables 64-bit floating point and selects the CPU platform, in that order.
    """
    settings = (
        ("jax_enable_x64", True),  # x64 precision
        ("jax_platform_name", "cpu"),  # CPU
    )
    for option, value in settings:
        jax.config.update(option, value)
def print_library_info() -> None:
    """Print the environment info for this benchmark, then a separator line.

    The environment summary comes from ``probdiffeq.util.doc_util.info``.
    """
    separator = "\n------------------------------------------\n"
    info.print_info()
    print(separator)
def parse_arguments() -> argparse.Namespace:
    """Parse the benchmark options from ``sys.argv``.

    Options:
        --start, --stop: integer exponent range for the tolerances (defaults 1, 3).
        --repeats: number of timing repetitions per tolerance (default 10).
        --save / --no-save: whether to write results; ``None`` if neither is given.
    """
    parser = argparse.ArgumentParser()
    integer_options = (("--start", 1), ("--stop", 3), ("--repeats", 10))
    for flag, default in integer_options:
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--save", action=argparse.BooleanOptionalAction)
    return parser.parse_args()
def tolerances_from_args(arguments: argparse.Namespace, /) -> jax.Array:
    """Return the tolerance vector ``0.1**k`` for ``k = start, ..., stop - 1``.

    ``stop`` is exclusive (``jnp.arange`` semantics), so the defaults
    ``start=1, stop=3`` give ``[1e-1, 1e-2]``.
    """
    exponents = jnp.arange(arguments.start, arguments.stop, step=1.0)
    return 0.1**exponents
def timeit_fun_from_args(arguments: argparse.Namespace, /) -> Callable:
    """Build a timer configured by the command-line arguments.

    The returned ``timer(fun)`` runs the zero-argument callable ``fun`` once per
    repetition, ``arguments.repeats`` times, and returns the list of run times
    in seconds (via ``timeit.repeat`` with ``number=1``).
    """

    def timer(fun, /):
        measurements = timeit.repeat(fun, number=1, repeat=arguments.repeats)
        return [*measurements]

    return timer
def solver_probdiffeq(num_derivatives: int, implementation, correction) -> Callable:
    """Construct a solver that wraps ProbDiffEq's solution routines.

    Args:
        num_derivatives: Passed as ``num`` to ``taylor.odejet_padded_scan`` to
            compute the Taylor coefficients that initialise the prior.
        implementation: Forwarded as ``ssm_fact`` to
            ``ivpsolvers.prior_wiener_integrated`` (the script uses
            "isotropic", "blockdiag" and "dense").
        correction: Correction factory (e.g. ``ivpsolvers.correction_ts0``),
            called as ``correction(vector_field, ssm=ssm)``.

    Returns:
        A jitted function ``tol -> terminal value``. It solves the
        Lotka--Volterra problem on ``[0, 50]`` from ``(20, 20)`` with
        ``rtol=tol`` and ``atol=1e-2 * tol`` and returns ``solution.u[0]``.
    """

    @jax.jit
    def lotka_volterra(y, *, t):  # noqa: ARG001
        """Lotka--Volterra dynamics (autonomous, so ``t`` is unused)."""
        interaction = 0.05 * y[0] * y[1]
        return jnp.asarray([0.5 * y[0] - interaction, -0.5 * y[1] + interaction])

    initial_state = jnp.asarray((20.0, 20.0))
    t_start, t_end = 0.0, 50.0

    @jax.jit
    def param_to_solution(tol):
        vf_at_t_start = functools.partial(lotka_volterra, t=t_start)

        # Prior, initialised from Taylor coefficients of the vector field.
        taylor_coefficients = taylor.odejet_padded_scan(
            vf_at_t_start, (initial_state,), num=num_derivatives
        )
        init, prior, ssm = ivpsolvers.prior_wiener_integrated(
            taylor_coefficients, ssm_fact=implementation
        )

        # Filter strategy + correction + MLE calibration, wrapped adaptively.
        strategy = ivpsolvers.strategy_filter(ssm=ssm)
        corr = correction(lotka_volterra, ssm=ssm)
        solver = ivpsolvers.solver_mle(strategy, prior=prior, correction=corr, ssm=ssm)
        controller = ivpsolvers.control_proportional_integral()
        adaptive_solver = ivpsolvers.adaptive(
            solver, atol=1e-2 * tol, rtol=tol, control=controller, ssm=ssm
        )

        # Integrate to the final time and keep only the terminal value.
        initial_step = ivpsolve.dt0(vf_at_t_start, (initial_state,))
        solution = ivpsolve.solve_adaptive_terminal_values(
            init,
            t0=t_start,
            t1=t_end,
            dt0=initial_step,
            adaptive_solver=adaptive_solver,
            ssm=ssm,
        )
        return jax.block_until_ready(solution.u[0])

    return param_to_solution
def solver_diffrax(*, solver) -> Callable:
    """Construct a solver that wraps Diffrax' solution routines.

    Args:
        solver: A Diffrax solver instance (the script passes e.g. ``diffrax.Tsit5()``).

    Returns:
        A jitted function ``tol -> terminal value``. It solves the
        Lotka--Volterra problem on ``[0, 50]`` from ``(20, 20)`` using a
        ``PIDController(atol=1e-3 * tol, rtol=tol)`` with at most 10,000
        steps, and returns the state at ``t = 50``.
    """

    def lotka_volterra(_t, y, _args):
        """Lotka--Volterra dynamics in Diffrax' ``(t, y, args)`` convention."""
        interaction = 0.05 * y[0] * y[1]
        return jnp.asarray([0.5 * y[0] - interaction, -0.5 * y[1] + interaction])

    # Equivalent to decorating with `@diffrax.ODETerm` on top of `@jax.jit`.
    term = diffrax.ODETerm(jax.jit(lotka_volterra))
    initial_state = jnp.asarray((20.0, 20.0))
    t_start, t_end = 0.0, 50.0

    @jax.jit
    def param_to_solution(tol):
        solution = diffrax.diffeqsolve(
            term,
            solver=solver,
            t0=t_start,
            t1=t_end,
            dt0=None,
            y0=initial_state,
            saveat=diffrax.SaveAt(t0=False, t1=True, ts=None),
            stepsize_controller=diffrax.PIDController(atol=1e-3 * tol, rtol=tol),
            max_steps=10_000,
        )
        # Row 0 of `ys` is the single saved time point (`t1`).
        return jax.block_until_ready(solution.ys[0, :])

    return param_to_solution
def solver_scipy(*, method: str) -> Callable:
    """Construct a solver that wraps SciPy's solution routines.

    Args:
        method: Name of a ``scipy.integrate.solve_ivp`` method, e.g. ``"RK45"``.

    Returns:
        A function ``tol -> terminal value``. It solves the Lotka--Volterra
        problem on ``[0, 50]`` from ``(20, 20)`` with ``rtol=tol`` and
        ``atol=1e-3 * tol``, and returns the state at ``t = 50`` as a JAX array.
        The returned function raises ``RuntimeError`` if SciPy reports that
        the integration did not succeed.
    """

    def vf_scipy(_t, y):
        """Lotka--Volterra dynamics."""
        dy1 = 0.5 * y[0] - 0.05 * y[0] * y[1]
        dy2 = -0.5 * y[1] + 0.05 * y[0] * y[1]
        return np.asarray([dy1, dy2])

    u0 = jnp.asarray((20.0, 20.0))
    time_span = np.asarray([0.0, 50.0])

    def param_to_solution(tol):
        solution = scipy.integrate.solve_ivp(
            vf_scipy,
            y0=u0,
            t_span=time_span,
            t_eval=time_span,
            atol=1e-3 * tol,
            rtol=tol,
            method=method,
        )
        # A failed integration stops before t1. With `t_eval=time_span` the last
        # column of `solution.y` would then hold the state at t0 (or be empty).
        # That would silently corrupt the work-precision data, so fail loudly.
        if not solution.success:
            msg = f"solve_ivp(method={method!r}, tol={tol}) failed: {solution.message}"
            raise RuntimeError(msg)
        return jnp.asarray(solution.y[:, -1])

    return param_to_solution
def plot_ivp_solution():
    """Compute plotting values for the Lotka--Volterra IVP.

    Integrates from ``(20, 20)`` over ``[0, 50]`` with SciPy's LSODA at
    ``rtol=1e-12`` and ``atol=1e-15``, keeping every accepted step.

    Returns:
        ``(ts, ys)``: the step times, shape ``(n,)``, and the states, shape ``(n, 2)``.
    """

    def lotka_volterra(_t, state):
        """Lotka--Volterra dynamics."""
        interaction = 0.05 * state[0] * state[1]
        return np.asarray([0.5 * state[0] - interaction, -0.5 * state[1] + interaction])

    tolerance = 1e-12
    result = scipy.integrate.solve_ivp(
        lotka_volterra,
        y0=jnp.asarray((20.0, 20.0)),
        t_span=np.asarray([0.0, 50.0]),
        atol=1e-3 * tolerance,
        rtol=tolerance,
        method="LSODA",
    )
    states_by_time = result.y.T
    return result.t, states_by_time
def rmse_relative(expected: jax.Array, *, nugget=1e-5) -> Callable:
    """Construct a relative root-mean-square-error function.

    Args:
        expected: Reference values (anything ``jnp.asarray`` accepts).
        nugget: Regulariser added to ``|expected|`` in the denominator, so that
            zero or tiny reference entries do not cause a division by zero.

    Returns:
        A function ``received -> sqrt(mean(((|expected - received|) / (nugget + |expected|))**2))``.
    """
    expected = jnp.asarray(expected)
    # Use `nugget + |expected|` rather than `|nugget + expected|`. The latter
    # vanishes for negative references (e.g. expected == -nugget) and yields inf.
    # The former is always at least `nugget`. For positive references the two
    # forms are identical.
    scale = nugget + jnp.abs(expected)

    def rmse(received):
        received = jnp.asarray(received)
        error_relative = jnp.abs(expected - received) / scale
        return jnp.linalg.norm(error_relative) / jnp.sqrt(error_relative.size)

    return rmse
def workprec(fun, *, precision_fun: Callable, timeit_fun: Callable) -> Callable:
    """Turn a parameter-to-solution function to a parameter-to-workprecision function.

    Turn a function ``param -> solution`` into a function
    ``(param1, param2, ...) -> results``, where ``results`` is a dictionary with
    keys ``"work_mean"``, ``"work_std"`` and ``"precision"``. Each key maps to
    an array with one entry per parameter.

    Args:
        fun: Parameter-to-solution function. Its return value must provide
            ``.block_until_ready()`` (e.g. a JAX array).
        precision_fun: Maps a solution to a scalar precision value.
        timeit_fun: Takes a zero-argument callable and returns the measured run
            times (a sequence of floats).

    Returns:
        A function mapping an iterable of parameters to the results dictionary.
    """

    def parameter_list_to_workprecision(list_of_args, /):
        works_mean = []
        works_std = []
        precisions = []
        for arg in list_of_args:
            # This first evaluation computes the precision. It also serves as a
            # warm-up (e.g. JIT compilation of jitted solvers) before any timing.
            precision = precision_fun(fun(arg).block_until_ready())

            # Bind `arg` as a default so the closure never sees a later loop value.
            times = list(timeit_fun(lambda a=arg: fun(a).block_until_ready()))

            precisions.append(precision)
            works_mean.append(statistics.mean(times))
            # `statistics.stdev` requires at least two samples. A single timing
            # (e.g. `--repeats 1`) has no measurable spread, so report 0.0.
            works_std.append(statistics.stdev(times) if len(times) >= 2 else 0.0)
        return {
            "work_mean": jnp.asarray(works_mean),
            "work_std": jnp.asarray(works_std),
            "precision": jnp.asarray(precisions),
        }

    return parameter_list_to_workprecision
if __name__ == "__main__":
    # Read configuration from the command line first, so that `--help` and
    # invalid arguments exit immediately instead of after the reference simulation.
    args = parse_arguments()

    # Set up all the configs
    set_jax_config()
    print_library_info()

    # Simulate once to get plotting code
    ts, ys = plot_ivp_solution()

    # If we change the probdiffeq-impl halfway through a script, a warning is raised.
    # But for this benchmark, such a change is on purpose.
    # NOTE(review): this silences *all* warnings (presumably also SciPy's
    # "rtol too small" warning for the 1e-15 reference below); consider narrowing it.
    warnings.filterwarnings("ignore")

    # Built after `set_jax_config()` so the tolerances use the configured precision.
    tolerances = tolerances_from_args(args)
    timeit_fun = timeit_fun_from_args(args)

    # Assemble algorithms
    ts0, ts1 = ivpsolvers.correction_ts0, ivpsolvers.correction_ts1
    ts0_iso = solver_probdiffeq(5, correction=ts0, implementation="isotropic")
    ts0_bd = solver_probdiffeq(5, correction=ts0, implementation="blockdiag")
    ts1_dense = solver_probdiffeq(8, correction=ts1, implementation="dense")
    algorithms = {
        r"ProbDiffEq: TS0($5$, isotropic)": ts0_iso,
        r"ProbDiffEq: TS0($5$, blockdiag)": ts0_bd,
        r"ProbDiffEq: TS1($8$, dense)": ts1_dense,
        "Diffrax: Tsit5()": solver_diffrax(solver=diffrax.Tsit5()),
        "Diffrax: Dopri8()": solver_diffrax(solver=diffrax.Dopri8()),
        "SciPy: 'RK45'": solver_scipy(method="RK45"),
        "SciPy: 'DOP853'": solver_scipy(method="DOP853"),
    }

    # Compute a reference solution
    reference = solver_scipy(method="LSODA")(1e-15)
    precision_fun = rmse_relative(reference)

    # Compute all work-precision diagrams
    results = {}
    for label, algo in tqdm.tqdm(algorithms.items()):
        param_to_wp = workprec(algo, precision_fun=precision_fun, timeit_fun=timeit_fun)
        results[label] = param_to_wp(tolerances)

    # Save results
    if args.save:
        # Resolve an absolute directory, so the files land next to this script
        # and never at "/" when `os.path.dirname(__file__)` would be empty.
        output_dir = os.path.dirname(os.path.abspath(__file__))
        # `results` is a dict. Assuming `jnp.save` forwards to `numpy.save`, it is
        # stored as a pickled object array (loading needs `allow_pickle=True`).
        jnp.save(os.path.join(output_dir, "results.npy"), results)
        jnp.save(os.path.join(output_dir, "plot_ts.npy"), ts)
        jnp.save(os.path.join(output_dir, "plot_ys.npy"), ys)
        print("\nSaving successful.\n")
    else:
        print("\nSkipped saving.\n")