-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathivpsolvers.py
More file actions
1244 lines (1003 loc) · 40.3 KB
/
ivpsolvers.py
File metadata and controls
1244 lines (1003 loc) · 40.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""Probabilistic IVP solvers."""
from probdiffeq import stats
from probdiffeq.backend import (
containers,
control_flow,
functools,
linalg,
special,
tree_util,
)
from probdiffeq.backend import numpy as np
from probdiffeq.backend.typing import (
Any,
ArrayLike,
Callable,
Generic,
NamedArg,
TypeVar,
)
from probdiffeq.impl import impl
def prior_wiener_integrated(
    tcoeffs, *, ssm_fact: str, output_scale: ArrayLike | None = None, damp: float = 0.0
):
    """Build a multiply-integrated Wiener process prior (adaptive/continuous-time).

    The state-space-model implementation is selected with ``impl.choose``.
    Returns the triple ``(init, discretize, ssm)``:
    the initial random variable built from ``tcoeffs``,
    whatever ``ibm_transitions`` returns for the chosen base scale,
    and the selected implementation itself.

    If ``output_scale`` is ``None``, an all-ones array shaped like the
    implementation's prototype output scale is used instead.
    """
    ssm = impl.choose(ssm_fact, tcoeffs_like=tcoeffs)

    # TODO: should the output_scale be an argument to solve()?
    # TODO: should the output scale (and all 'damp'-like factors)
    #       mirror the pytree structure of 'tcoeffs'?
    base_scale = (
        np.ones_like(ssm.prototypes.output_scale())
        if output_scale is None
        else output_scale
    )
    transitions = ssm.conditional.ibm_transitions(base_scale=base_scale)

    # Increasing 'damp' gives visually more pleasing uncertainties
    # and more numerical robustness for
    # high-order solvers in low precision arithmetic.
    initial_rv = ssm.normal.from_tcoeffs(tcoeffs, damp=damp)
    return initial_rv, transitions, ssm
def prior_wiener_integrated_discrete(ts, *args, **kwargs):
    """Evaluate the multiply-integrated Wiener process on a fixed time grid.

    ``*args`` and ``**kwargs`` are forwarded to ``prior_wiener_integrated``.
    The resulting discretisation is vectorised over the increments
    ``diff(ts)`` while an all-ones scale is shared by every increment.
    Returns ``(init, conditionals, ssm)``.
    """
    init, discretize, ssm = prior_wiener_integrated(*args, **kwargs)

    # Map over the step sizes (axis 0); do not map over the unit scale.
    unit_scale = np.ones_like(ssm.prototypes.output_scale())
    transitions = functools.vmap(discretize, in_axes=(0, None))(
        np.diff(ts), unit_scale
    )
    return init, transitions, ssm
R = TypeVar("R")


@containers.dataclass
class _InterpRes(Generic[R]):
    """Result of an interpolation: three states that share the type ``R``."""

    step_from: R
    """The new 'step_from' field.

    At time `max(t, s1.t)`.
    Use this as the right-most reference state
    in future interpolations, or continue time-stepping from here.
    """

    interpolated: R
    """The new 'solution' field.

    At time `t`. This is the interpolation result.
    """

    interp_from: R
    """The new `interp_from` field.

    At time `t`. Use this as the left-most reference state
    in future interpolations, i.e. subsequent interpolations start from here.
    (Time-stepping continues from `step_from`, not from this field.)

    The difference between `interpolated` and `interp_from` emerges in save_at* modes.
    `interpolated` belongs to the just-concluded time interval,
    and `interp_from` belongs to the to-be-started time interval.
    Concretely, this means that `interp_from` has a unit backward model
    and `interpolated` remembers how to step back to the previous target location.
    """
class _PositiveCubatureRule(containers.NamedTuple):
    """Nodes and square-root weights of a cubature rule with positive weights."""

    # Cubature nodes; the leading axis enumerates the nodes.
    points: ArrayLike

    # Square roots of the cubature weights, one entry per node.
    weights_sqrtm: ArrayLike
def cubature_third_order_spherical(input_shape) -> _PositiveCubatureRule:
    """Construct the third-order spherical cubature rule.

    ``input_shape`` must be ``()`` or ``(d,)``.
    For ``(d,)`` the points form a matrix with one row per node.
    For ``()`` the rule is computed for ``d=1`` and the points
    are reshaped into a flat vector.
    """
    assert len(input_shape) <= 1

    if len(input_shape) == 1:
        (dim,) = input_shape
        is_scalar = False
    else:
        # If input_shape == (), compute the rule via input_shape=(1,)
        # and 'squeeze' the points afterwards.
        dim, is_scalar = 1, True

    nodes, sqrt_weights = _third_order_spherical_params(d=dim)
    if is_scalar:
        (num_nodes, _) = nodes.shape
        nodes = np.reshape(nodes, (num_nodes,))
    return _PositiveCubatureRule(points=nodes, weights_sqrtm=sqrt_weights)
def _third_order_spherical_params(*, d):
eye_d = np.eye(d) * np.sqrt(d)
pts = np.concatenate((eye_d, -1 * eye_d))
weights_sqrtm = np.ones((2 * d,)) / np.sqrt(2.0 * d)
return pts, weights_sqrtm
def cubature_unscented_transform(input_shape, r=1.0) -> _PositiveCubatureRule:
    """Construct the cubature rule of the unscented transform.

    ``input_shape`` must be ``()`` or ``(d,)``; ``r`` is forwarded
    to the parameter construction.
    For ``()`` the rule is computed for ``d=1`` and the points
    are reshaped into a flat vector.
    """
    assert len(input_shape) <= 1

    if len(input_shape) == 1:
        (dim,) = input_shape
        is_scalar = False
    else:
        # If input_shape == (), compute the rule via input_shape=(1,)
        # and 'squeeze' the points afterwards.
        dim, is_scalar = 1, True

    nodes, sqrt_weights = _unscented_transform_params(d=dim, r=r)
    if is_scalar:
        (num_nodes, _) = nodes.shape
        nodes = np.reshape(nodes, (num_nodes,))
    return _PositiveCubatureRule(points=nodes, weights_sqrtm=sqrt_weights)
def _unscented_transform_params(d, *, r):
eye_d = np.eye(d) * np.sqrt(d + r)
zeros = np.zeros((1, d))
pts = np.concatenate((eye_d, zeros, -1 * eye_d))
_scale = d + r
weights_sqrtm1 = np.ones((d,)) / np.sqrt(2.0 * _scale)
weights_sqrtm2 = np.sqrt(r / _scale)
weights_sqrtm = np.hstack((weights_sqrtm1, weights_sqrtm2, weights_sqrtm1))
return pts, weights_sqrtm
def cubature_gauss_hermite(input_shape, degree=5) -> _PositiveCubatureRule:
    """(Statistician's) Gauss-Hermite cubature.

    The rule is a tensor-product grid of the one-dimensional rule
    with `degree` nodes per dimension.
    The number of cubature points is therefore `degree**prod(input_shape)`.

    Only one-dimensional input shapes `(d,)` are supported.
    """
    assert len(input_shape) == 1
    (dim,) = input_shape

    # Roots of the probabilist/statistician's Hermite polynomials (in Numpy...)
    _roots = special.roots_hermitenorm(n=degree, mu=True)
    pts, weights, sum_of_weights = _roots

    # Normalise so that the weights sum to one.
    weights = weights / sum_of_weights

    # Transform into jax arrays and take square root of weights
    pts = np.asarray(pts)
    weights_sqrtm = np.sqrt(np.asarray(weights))

    # Build a tensor grid and return class
    tensor_pts = _tensor_points(pts, d=dim)
    tensor_weights_sqrtm = _tensor_weights(weights_sqrtm, d=dim)
    return _PositiveCubatureRule(points=tensor_pts, weights_sqrtm=tensor_weights_sqrtm)
# TODO: how does this generalise to an input_shape instead of an input_dimension?
# via tree_map(lambda s: _tensor_points(x, s), input_shape)?
def _tensor_weights(*args, **kwargs):
    """Multiply the entries of each tensor-grid row into a single weight.

    All arguments are forwarded to ``_tensor_points``;
    the product is taken along axis 1 of the resulting grid.
    """
    return np.prod_along_axis(_tensor_points(*args, **kwargs), axis=1)
def _tensor_points(x, /, *, d):
    """Build the ``d``-fold tensor grid of the one-dimensional nodes ``x``.

    Each mesh-grid component is flattened, the components are stacked,
    and the stack is transposed, so every row of the result is one grid point.
    """

    def _flatten(grid):
        return np.reshape(grid, (-1,))

    grids = np.meshgrid(*(x for _ in range(d)))
    columns = tree_util.tree_map(_flatten, grids)
    return np.stack(columns).T
@containers.dataclass
class _Strategy:
    """Estimation-strategy interface.

    Every method raises ``NotImplementedError`` here.
    The factories ``strategy_smoother``, ``strategy_filter``
    and ``strategy_fixedpoint`` in this file return concrete subclasses.
    """

    # State-space-model implementation used by the strategy.
    ssm: Any

    # Capability flags; the strategy factories in this file pass booleans.
    is_suitable_for_save_at: bool
    is_suitable_for_save_every_step: bool
    is_suitable_for_offgrid_marginals: bool

    def init(self, sol, /):
        """Initialise a state from a solution."""
        raise NotImplementedError

    def extrapolate(self, rv, strategy_state, /, *, transition):
        """Extrapolate (also known as prediction)."""
        raise NotImplementedError

    def extract(self, rv, strategy_state, /):
        """Extract a solution from a state."""
        raise NotImplementedError

    def interpolate(self, state_t0, state_t1, *, dt0, dt1, output_scale, prior):
        """Interpolate between the states at t0 and t1.

        ``dt0`` is the distance from t0 to the target time and
        ``dt1`` the distance from the target time to t1.
        Implementations return an ``_InterpRes``.
        """
        raise NotImplementedError

    def interpolate_at_t1(self, state_t0, state_t1, *, dt0, dt1, output_scale, prior):
        """Process the state at a checkpoint."""
        raise NotImplementedError
def strategy_smoother(*, ssm) -> _Strategy:
    """Create the smoother strategy for the given state-space model."""

    @containers.dataclass
    class Smoother(_Strategy):
        def init(self, sol, /):
            """Split a solution into a random variable and a backward model."""
            # Special case for implementing offgrid-marginals:
            # a Markov sequence already carries its own conditional.
            if isinstance(sol, stats.MarkovSeq):
                return sol.init, sol.conditional

            # Otherwise, start from an identity conditional.
            identity = self.ssm.conditional.identity(ssm.num_derivatives + 1)
            return sol, identity

        def extrapolate(self, rv, aux, /, *, transition):
            """Revert the transition; the previous auxiliary state is discarded."""
            del aux
            return self.ssm.conditional.revert(rv, transition)

        def extract(self, hidden_state, extra, /):
            """Bundle the marginal and the conditional into a Markov sequence."""
            return stats.MarkovSeq(init=hidden_state, conditional=extra)

        def interpolate(self, state_t0, state_t1, *, dt0, dt1, output_scale, prior):
            """Interpolate to a time 't' between 't0' and 't1'.

            A smoother interpolates by:

            * Extrapolating from t0 to t, which gives the "filtering" marginal
              and the backward transition from t to t0.
            * Extrapolating from t to t1, which gives another "filtering" marginal
              and the backward transition from t1 to t.
            * Applying the new t1-to-t backward transition
              to compute the interpolation.
              This intermediate result is informed
              about its "right-hand side" datum.

            Subsequent interpolations continue from the value at 't'.
            Subsequent IVP solver steps continue from the value at 't1'.
            """
            # TODO: if we pass prior1 and prior2, then
            #       we don't have to pass dt0, dt1, output_scale, and prior...

            # First leg: from t0 to t.
            transition_t0_to_t = prior(dt0, output_scale)
            filtered_t, backward_t = self.extrapolate(
                *state_t0, transition=transition_t0_to_t
            )

            # Second leg: from t to t1. Only the backward model is needed.
            transition_t_to_t1 = prior(dt1, output_scale)
            _filtered_t1, backward_t1 = self.extrapolate(
                filtered_t, backward_t, transition=transition_t_to_t1
            )

            # Marginalise from t1 to t to obtain the interpolated solution.
            marginal_t1, _ = state_t1
            smoothed_t = self.ssm.conditional.marginalise(marginal_t1, backward_t1)
            at_t = (smoothed_t, backward_t)

            # The state at t1 gets a new backward model;
            # (it must remember how to get back to t, not to t0).
            at_t1 = (marginal_t1, backward_t1)

            return _InterpRes(step_from=at_t1, interpolated=at_t, interp_from=at_t)

        def interpolate_at_t1(
            self, state_t0, state_t1, *, dt0, dt1, output_scale, prior
        ):
            """Return the state at t1 unchanged in all three result fields."""
            del prior, state_t0, dt0, dt1, output_scale
            return _InterpRes(
                step_from=state_t1, interpolated=state_t1, interp_from=state_t1
            )

    return Smoother(
        ssm=ssm,
        is_suitable_for_save_at=False,
        is_suitable_for_save_every_step=True,
        is_suitable_for_offgrid_marginals=True,
    )
def strategy_filter(*, ssm) -> _Strategy:
    """Create the filter strategy for the given state-space model."""

    @containers.dataclass
    class Filter(_Strategy):
        def init(self, sol, /):
            """Use the solution as the random variable; there is no extra state."""
            return sol, None

        def extrapolate(self, rv, aux, /, *, transition):
            """Marginalise through the transition; no backward model is kept."""
            del aux
            predicted = self.ssm.conditional.marginalise(rv, transition)
            return predicted, None

        def extract(self, hidden_state, _extra, /):
            """Return the hidden state as the solution."""
            return hidden_state

        def interpolate(self, state_t0, state_t1, dt0, dt1, output_scale, *, prior):
            """Extrapolate from t0 by dt0 and keep the marginal at t1 for stepping."""
            # todo: by ditching marginal_t1 and dt1, this function _extrapolates
            #       (no *inter*polation happening)
            del dt1

            marginal_t1, _ = state_t1
            rv_t0, extra_t0 = state_t0

            transition = prior(dt0, output_scale)
            rv_t, extra_t = self.extrapolate(rv_t0, extra_t0, transition=transition)

            # Consistent state-types in interpolation result.
            at_t = (rv_t, extra_t)
            at_t1 = (marginal_t1, None)
            return _InterpRes(step_from=at_t1, interpolated=at_t, interp_from=at_t)

        def interpolate_at_t1(
            self, state_t0, state_t1, dt0, dt1, output_scale, *, prior
        ):
            """Return the state at t1 in all three result fields."""
            del prior, state_t0, dt0, dt1, output_scale

            rv, extra = state_t1
            at_t1 = (rv, extra)
            return _InterpRes(step_from=at_t1, interpolated=at_t1, interp_from=at_t1)

    return Filter(
        ssm=ssm,
        is_suitable_for_save_at=True,
        is_suitable_for_save_every_step=True,
        is_suitable_for_offgrid_marginals=True,
    )
def strategy_fixedpoint(*, ssm) -> _Strategy:
    """Create the fixed-point-smoother strategy for the given state-space model."""

    @containers.dataclass
    class FixedPoint(_Strategy):
        def init(self, sol, /):
            """Pair the solution with an identity backward model."""
            identity = self.ssm.conditional.identity(ssm.num_derivatives + 1)
            return sol, identity

        def extrapolate(self, rv, bw0, /, *, transition):
            """Revert the transition and merge the new backward model into 'bw0'."""
            predicted, backward = self.ssm.conditional.revert(rv, transition)
            merged = self.ssm.conditional.merge(bw0, backward)
            return predicted, merged

        def extract(self, hidden_state, extra, /):
            """Bundle the marginal and the conditional into a Markov sequence."""
            return stats.MarkovSeq(init=hidden_state, conditional=extra)

        def interpolate_at_t1(
            self, state_t0, state_t1, *, dt0, dt1, output_scale, prior
        ):
            """Process the state at t1.

            The solution keeps its backward model; the states to step from and
            to interpolate from receive an identity backward model instead.
            """
            del prior, state_t0, dt0, dt1, output_scale

            rv, extra = state_t1
            identity = self.ssm.conditional.identity(ssm.num_derivatives + 1)
            reset = (rv, identity)
            return _InterpRes(
                step_from=reset, interpolated=(rv, extra), interp_from=reset
            )

        def interpolate(self, state_t0, state_t1, *, dt0, dt1, output_scale, prior):
            """Interpolate to a time 't' between 't0' and 't1'.

            A fixed-point smoother interpolates by

            * Extrapolating from t0 to t, which gives the "filtering" marginal
              and the backward transition from t to t0.
            * Extrapolating from t to t1, which gives another "filtering" marginal
              and the backward transition from t1 to t.
            * Applying the t1-to-t backward transition
              to compute the interpolation result.
              This intermediate result is informed
              about its "right-hand side" datum.

            The difference to smoother-interpolation is quite subtle:

            * The backward transition of the solution at 't'
              is merged with that at 't0'.
              The reason is that the backward transition at 't0' knows
              "how to get to the quantity of interest",
              and this is precisely what we want to interpolate.
            * Subsequent interpolations do not continue from the value at 't',
              but from a very similar value where the backward transition
              is replaced with an identity. The reason is that the interpolated
              solution becomes the new quantity of interest,
              and subsequent interpolations need to learn how to get here.
            * Subsequent solver steps do not continue from the value at 't1',
              but from the value at 't1' where the backward model is replaced
              by the 't1-to-t' backward model.
              The reason is similar to the above: future steps need to know
              "how to get back to the quantity of interest",
              which is the interpolated solution.

            These distinctions are precisely why every interpolation result
            has three fields:
            the solution,
            the continue-interpolation-from-here,
            and the continue-stepping-from-here.
            All three are different for fixed point smoothers.
            (Really, I try removing one of them monthly and
            then don't understand why tests fail.)
            """
            marginal_t1, _ = state_t1

            # First leg: from t0 to t.
            # The backward model at t is merged with the one at t0.
            transition_t0_to_t = prior(dt0, output_scale)
            filtered_t, backward_t = self.extrapolate(
                *state_t0, transition=transition_t0_to_t
            )

            # Second leg: restart at t with an identity backward model,
            # then extrapolate from t to t1.
            identity = self.ssm.conditional.identity(ssm.num_derivatives + 1)
            restart_at_t = (filtered_t, identity)
            transition_t_to_t1 = prior(dt1, output_scale)
            _filtered_t1, backward_t1 = self.extrapolate(
                *restart_at_t, transition=transition_t_to_t1
            )

            # Marginalise from t1 to t to obtain the interpolated solution.
            smoothed_t = self.ssm.conditional.marginalise(marginal_t1, backward_t1)

            # Return the right combination of marginals and conditionals.
            return _InterpRes(
                step_from=(marginal_t1, backward_t1),
                interpolated=(smoothed_t, backward_t),
                interp_from=restart_at_t,
            )

    return FixedPoint(
        ssm=ssm,
        is_suitable_for_save_at=True,
        is_suitable_for_save_every_step=False,
        is_suitable_for_offgrid_marginals=False,
    )
@containers.dataclass
class _Correction:
    """Correction model: linearise the vector field and condition on the result."""

    name: str
    ode_order: int
    ssm: Any
    linearize: Any
    vector_field: Callable
    re_linearize: bool

    def init(self, x, /):
        """Pair the solution 'x' with a freshly initialised linearisation state."""
        linearization_state = self.linearize.init()
        return x, linearization_state

    def estimate_error(self, rv, correction_state, /, t):
        """Estimate the error.

        Linearises the vector field (evaluated at time 't') around 'rv' and
        marginalises 'rv' through the linearisation.
        The estimate is the squeezed standard deviation of the observation,
        multiplied by its relative Mahalanobis norm with respect to zero data.

        Returns the estimate, the observation,
        and the pair (updated linearisation state, linearisation).
        """
        vector_field_at_t = functools.partial(self.vector_field, t=t)
        linearization, correction_state = self.linearize.update(
            vector_field_at_t, rv, correction_state
        )
        observed = self.ssm.conditional.marginalise(rv, linearization)

        # Scale of the observation relative to zero data.
        scale = self.ssm.stats.mahalanobis_norm_relative(np.zeros(()), rv=observed)

        unscaled = np.squeeze(self.ssm.stats.standard_deviation(observed))
        return scale * unscaled, observed, (correction_state, linearization)

    def correct(self, rv, correction_state, /, t):
        """Perform the correction step.

        'correction_state' is the pair produced by 'estimate_error'.
        The linearisation is recomputed around 'rv' only if 're_linearize' is set;
        otherwise the stored one is reused.
        Returns the corrected variable, the observation,
        and the linearisation state.
        """
        linearization_state, linearization = correction_state

        if self.re_linearize:
            vector_field_at_t = functools.partial(self.vector_field, t=t)
            linearization, linearization_state = self.linearize.update(
                vector_field_at_t, rv, linearization_state
            )

        observed, backward = self.ssm.conditional.revert(rv, linearization)
        return backward.noise, observed, linearization_state
def correction_ts0(vector_field, *, ssm, ode_order=1, damp: float = 0.0) -> _Correction:
    """Create a correction based on a zeroth-order Taylor linearisation.

    The linearisation is not recomputed during the correction step.
    """
    taylor_0th = ssm.linearise.ode_taylor_0th(ode_order=ode_order, damp=damp)
    return _Correction(
        name="TS0",
        ode_order=ode_order,
        ssm=ssm,
        linearize=taylor_0th,
        vector_field=vector_field,
        re_linearize=False,
    )
def correction_ts1(
    vector_field,
    *,
    ssm,
    ode_order=1,
    damp: float = 0.0,
    jvp_probes=10,
    jvp_probes_seed=1,
) -> _Correction:
    """Create a correction based on a first-order Taylor linearisation.

    ``jvp_probes`` must be positive; it is forwarded, together with
    ``jvp_probes_seed``, to the linearisation.
    The linearisation is not recomputed during the correction step.
    """
    assert jvp_probes > 0

    linearization_options = {
        "ode_order": ode_order,
        "damp": damp,
        "jvp_probes": jvp_probes,
        "jvp_probes_seed": jvp_probes_seed,
    }
    taylor_1st = ssm.linearise.ode_taylor_1st(**linearization_options)
    return _Correction(
        name="TS1",
        ode_order=ode_order,
        ssm=ssm,
        linearize=taylor_1st,
        vector_field=vector_field,
        re_linearize=False,
    )
def correction_slr0(
    vector_field, *, ssm, cubature_fun=cubature_third_order_spherical, damp: float = 0.0
) -> _Correction:
    """Create a correction based on zeroth-order statistical linear regression.

    The ODE order is fixed to one, and the linearisation is recomputed
    during the correction step.
    """
    statistical_0th = ssm.linearise.ode_statistical_0th(cubature_fun, damp=damp)
    return _Correction(
        name="SLR0",
        ode_order=1,
        ssm=ssm,
        linearize=statistical_0th,
        vector_field=vector_field,
        re_linearize=True,
    )
def correction_slr1(
    vector_field, *, ssm, cubature_fun=cubature_third_order_spherical, damp: float = 0.0
) -> _Correction:
    """Create a correction based on first-order statistical linear regression.

    The ODE order is fixed to one, and the linearisation is recomputed
    during the correction step.
    """
    statistical_1st = ssm.linearise.ode_statistical_1st(cubature_fun, damp=damp)
    return _Correction(
        name="SLR1",
        ode_order=1,
        ssm=ssm,
        linearize=statistical_1st,
        vector_field=vector_field,
        re_linearize=True,
    )
@containers.dataclass
class _Calibration:
    """Bundle of the three callables that implement an output-scale calibration."""

    # init() -> initial calibration state.
    init: Callable

    # update(state, observed=...) -> new calibration state.
    update: Callable

    # extract(state) -> pair of output scales; the solvers in this file
    # read it as (prior scale, calibrated scale).
    extract: Callable
class _State(containers.NamedTuple):
    """State that a probabilistic solver carries from step to step."""

    # Time point of this state.
    t: Any

    # Random variable at time 't'.
    rv: Any

    # Auxiliary state of the strategy (second output of the strategy's init).
    strategy_state: Any

    # Auxiliary state of the correction (second output of the correction's init).
    correction_state: Any

    # Calibration state (as returned by the calibration's init or update).
    output_scale: Any
@tree_util.register_dataclass
@containers.dataclass
class _ErrorEstimate:
    """Error estimate of a solver step, paired with its reference magnitude."""

    # The (step-size-scaled) error estimate.
    estimate: ArrayLike

    # Reference values for normalisation; the solvers in this file use the
    # element-wise maximum of the absolute previous and proposed solutions.
    reference: ArrayLike
@containers.dataclass
class _ProbabilisticSolver:
    """Probabilistic IVP solver.

    Combines a prior, a strategy, a correction and a calibration
    with a step implementation.
    """

    name: str
    step_implementation: Callable
    prior: Callable
    ssm: Any
    strategy: _Strategy
    calibration: _Calibration
    correction: _Correction

    def offgrid_marginals(self, *, t, marginals_t1, posterior_t0, t0, t1, output_scale):
        """Compute the marginal at an off-grid time 't' between 't0' and 't1'.

        Raises NotImplementedError if the strategy does not support
        off-grid marginals.
        Returns the quantity of interest and the marginal at 't'.
        """
        if not self.is_suitable_for_offgrid_marginals:
            raise NotImplementedError

        distance_from_t0 = t - t0
        distance_to_t1 = t1 - t

        rv, extra = self.strategy.init(posterior_t0)
        rv, _corr = self.correction.init(rv)

        # TODO: Replace dt0, dt1, prior, and output_scale with prior_dt0, and prior_dt1
        result = self.strategy.interpolate(
            state_t0=(rv, extra),
            state_t1=(marginals_t1, None),
            dt0=distance_from_t0,
            dt1=distance_to_t1,
            output_scale=output_scale,
            prior=self.prior,
        )

        (marginals, _aux) = result.interpolated
        return self.ssm.stats.qoi(marginals), marginals

    @property
    def error_contraction_rate(self):
        """The number of derivatives of the state-space model, plus one."""
        return self.ssm.num_derivatives + 1

    @property
    def is_suitable_for_offgrid_marginals(self):
        """Forward the strategy's flag of the same name."""
        return self.strategy.is_suitable_for_offgrid_marginals

    @property
    def is_suitable_for_save_at(self):
        """Forward the strategy's flag of the same name."""
        return self.strategy.is_suitable_for_save_at

    @property
    def is_suitable_for_save_every_step(self):
        """Forward the strategy's flag of the same name."""
        return self.strategy.is_suitable_for_save_every_step

    def init(self, t, init) -> _State:
        """Assemble the initial solver state at time 't'."""
        rv, strategy_state = self.strategy.init(init)
        rv, correction_state = self.correction.init(rv)

        # TODO: make the init() and extract() an interface.
        #       Then, lots of calibration logic simplifies considerably.
        calibration_state = self.calibration.init()
        return _State(
            t=t,
            rv=rv,
            strategy_state=strategy_state,
            correction_state=correction_state,
            output_scale=calibration_state,
        )

    def step(self, state: _State, *, dt):
        """Advance the state by 'dt' using the step implementation."""
        return self.step_implementation(state, dt=dt, calibration=self.calibration)

    def extract(self, state: _State, /):
        """Return the time and the pair (posterior, calibrated output scale)."""
        posterior = self.strategy.extract(state.rv, state.strategy_state)
        _prior_scale, calibrated_scale = self.calibration.extract(state.output_scale)
        return state.t, (posterior, calibrated_scale)

    def interpolate(self, *, t, interp_from: _State, interp_to: _State) -> _InterpRes:
        """Interpolate to time 't' between the states 'interp_from' and 'interp_to'."""
        scale, _ = self.calibration.extract(interp_to.output_scale)

        result = self.strategy.interpolate(
            state_t0=(interp_from.rv, interp_from.strategy_state),
            state_t1=(interp_to.rv, interp_to.strategy_state),
            dt0=t - interp_from.t,
            dt1=interp_to.t - t,
            output_scale=scale,
            prior=self.prior,
        )

        def _assemble(time, pair, template):
            # Turn a (rv, strategy_state) pair into a valid state; the calibration
            # and correction states are copied from 'template'.
            return _State(
                t=time,
                rv=pair[0],
                strategy_state=pair[1],
                correction_state=template.correction_state,
                output_scale=template.output_scale,
            )

        new_step_from = _assemble(interp_to.t, result.step_from, interp_to)
        new_interpolated = _assemble(t, result.interpolated, interp_to)
        new_interp_from = _assemble(t, result.interp_from, interp_from)
        return _InterpRes(
            step_from=new_step_from,
            interpolated=new_interpolated,
            interp_from=new_interp_from,
        )

    def interpolate_at_t1(
        self, *, t, interp_from: _State, interp_to: _State
    ) -> _InterpRes:
        """Process the solution in case t=t_n."""
        del t

        # Only the state at t1 and the prior are passed on;
        # all other arguments are None.
        result = self.strategy.interpolate_at_t1(
            state_t0=None,
            state_t1=(interp_to.rv, interp_to.strategy_state),
            dt0=None,
            dt1=None,
            output_scale=None,
            prior=self.prior,
        )

        def _assemble(time, pair, template):
            # Turn a (rv, strategy_state) pair into a valid state; the calibration
            # and correction states are copied from 'template'.
            return _State(
                t=time,
                rv=pair[0],
                strategy_state=pair[1],
                correction_state=template.correction_state,
                output_scale=template.output_scale,
            )

        # All three resulting states live at the time of 'interp_to'.
        t1 = interp_to.t
        new_interp_from = _assemble(t1, result.interp_from, interp_from)
        new_interpolated = _assemble(t1, result.interpolated, interp_to)
        new_step_from = _assemble(t1, result.step_from, interp_to)
        return _InterpRes(
            step_from=new_step_from,
            interpolated=new_interpolated,
            interp_from=new_interp_from,
        )
def solver_mle(strategy, *, correction, prior, ssm):
    """Create a solver that calibrates the output scale via maximum-likelihood.

    Warning: needs to be combined with a call to stats.calibrate()
    after solving if the MLE-calibration shall be *used*.
    """

    def _flat_solution(rv):
        # Flattened zeroth entry of the unravelled mean.
        return tree_util.ravel_pytree(ssm.unravel(rv.mean)[0])[0]

    def step_mle(state, /, *, dt, calibration):
        """Advance by 'dt'; return the error estimate and the new state."""
        u_previous = _flat_solution(state.rv)

        # Estimate the error by extrapolating only the mean
        # (with the prior output scale).
        prior_scale, _calibrated = calibration.extract(state.output_scale)
        transition = prior(dt, prior_scale)
        predicted_mean = ssm.conditional.apply(ssm.stats.mean(state.rv), transition)

        t_new = state.t + dt
        error, _, estimated_correction_state = correction.estimate_error(
            predicted_mean, state.correction_state, t=t_new
        )

        # Do the full prediction step (reuse previous discretisation)
        extrapolated, strategy_state = strategy.extrapolate(
            state.rv, state.strategy_state, transition=transition
        )

        # Do the full correction step
        corrected, observed, correction_state = correction.correct(
            extrapolated, estimated_correction_state, t=t_new
        )

        # Calibrate the output scale
        calibration_state = calibration.update(state.output_scale, observed=observed)

        proposed = _State(
            t=t_new,
            rv=corrected,
            strategy_state=strategy_state,
            correction_state=correction_state,
            output_scale=calibration_state,
        )

        # Provide the reference for normalising the error
        u_proposed = _flat_solution(proposed.rv)
        reference = np.maximum(np.abs(u_proposed), np.abs(u_previous))
        return _ErrorEstimate(dt * error, reference=reference), proposed

    return _ProbabilisticSolver(
        name="Probabilistic solver with MLE calibration",
        step_implementation=step_mle,
        prior=prior,
        ssm=ssm,
        strategy=strategy,
        calibration=_calibration_running_mean(ssm=ssm),
        correction=correction,
    )
def _calibration_running_mean(*, ssm) -> _Calibration:
    """Create a calibration that keeps a running mean of the observed scales.

    The state is the triple (prior scale, calibrated scale, number of data).
    """

    def init():
        unit_scale = np.ones_like(ssm.prototypes.output_scale())
        return unit_scale, unit_scale, 0.0

    def update(state, /, observed):
        prior_scale, running_mean, count = state

        # Fold the newest relative Mahalanobis norm into the running mean.
        newest = ssm.stats.mahalanobis_norm_relative(0.0, observed)
        running_mean = ssm.stats.update_mean(running_mean, newest, num=count)
        return prior_scale, running_mean, count + 1.0

    def extract(state, /):
        prior_scale, running_mean, _count = state
        return prior_scale, running_mean

    return _Calibration(init=init, update=update, extract=extract)
def solver_dynamic(strategy, *, correction, prior, ssm):
    """Create a solver that re-calibrates the output scale in every step."""

    def _flat_solution(rv):
        # Flattened zeroth entry of the unravelled mean.
        return tree_util.ravel_pytree(ssm.unravel(rv.mean)[0])[0]

    def step_dynamic(state, /, *, dt, calibration):
        """Advance by 'dt'; return the error estimate and the new state."""
        u_previous = _flat_solution(state.rv)

        # Estimate the error and calibrate the output scale,
        # using a unit-scale transition applied to the mean only.
        unit_scale = np.ones_like(ssm.prototypes.output_scale())
        unit_transition = prior(dt, unit_scale)
        predicted_mean = ssm.conditional.apply(
            ssm.stats.mean(state.rv), unit_transition
        )

        t_new = state.t + dt
        error, observed, estimated_correction_state = correction.estimate_error(
            predicted_mean, state.correction_state, t=t_new
        )
        calibration_state = calibration.update(state.output_scale, observed=observed)

        # Do the full extrapolation with the calibrated output scale
        calibrated_scale, _ = calibration.extract(calibration_state)
        calibrated_transition = prior(dt, calibrated_scale)
        extrapolated, strategy_state = strategy.extrapolate(
            state.rv, state.strategy_state, transition=calibrated_transition
        )

        # Do the full correction step
        corrected, _, correction_state = correction.correct(
            extrapolated, estimated_correction_state, t=t_new
        )

        # Assemble the solution
        proposed = _State(
            t=t_new,
            rv=corrected,
            strategy_state=strategy_state,
            correction_state=correction_state,
            output_scale=calibration_state,
        )

        # Provide the reference for normalising the error
        u_proposed = _flat_solution(proposed.rv)
        reference = np.maximum(np.abs(u_proposed), np.abs(u_previous))
        return _ErrorEstimate(dt * error, reference=reference), proposed

    return _ProbabilisticSolver(
        name="Dynamic probabilistic solver",
        step_implementation=step_dynamic,
        prior=prior,
        ssm=ssm,
        strategy=strategy,
        calibration=_calibration_most_recent(ssm=ssm),
        correction=correction,
    )
def _calibration_most_recent(*, ssm) -> _Calibration:
    """Create a calibration that only remembers the most recent scale.

    The state is the scale itself; 'extract' returns it twice.
    """

    def init():
        prototype = ssm.prototypes.output_scale()
        return np.ones_like(prototype)

    def update(_state, /, observed):
        # The previous state is ignored entirely.
        return ssm.stats.mahalanobis_norm_relative(0.0, observed)

    def extract(state, /):
        return state, state

    return _Calibration(init=init, update=update, extract=extract)
def solver(strategy, *, correction, prior, ssm):
    """Create a solver that does not calibrate the output scale automatically."""

    def _flat_solution(rv):
        # Flattened zeroth entry of the unravelled mean.
        return tree_util.ravel_pytree(ssm.unravel(rv.mean)[0])[0]

    def step(state: _State, *, dt, calibration):
        """Advance by 'dt'; return the error estimate and the new state."""
        del calibration  # unused

        u_previous = _flat_solution(state.rv)

        # Estimate the error by extrapolating only the mean
        transition = prior(dt, state.output_scale)
        predicted_mean = ssm.conditional.apply(ssm.stats.mean(state.rv), transition)

        t_new = state.t + dt
        error, _, estimated_correction_state = correction.estimate_error(
            predicted_mean, state.correction_state, t=t_new
        )

        # Do the full extrapolation step (reuse the transition)
        extrapolated, strategy_state = strategy.extrapolate(
            state.rv, state.strategy_state, transition=transition
        )

        # Do the full correction step
        corrected, _, correction_state = correction.correct(
            extrapolated, estimated_correction_state, t=t_new
        )

        # The output scale is carried over unchanged.
        proposed = _State(
            t=t_new,
            rv=corrected,
            strategy_state=strategy_state,
            correction_state=correction_state,
            output_scale=state.output_scale,
        )

        # Provide the reference for normalising the error
        u_proposed = _flat_solution(proposed.rv)
        reference = np.maximum(np.abs(u_proposed), np.abs(u_previous))
        return _ErrorEstimate(dt * error, reference=reference), proposed

    return _ProbabilisticSolver(
        name="Probabilistic solver",
        step_implementation=step,
        prior=prior,
        ssm=ssm,
        strategy=strategy,
        calibration=_calibration_none(ssm=ssm),
        correction=correction,
    )
def _calibration_none(*, ssm) -> _Calibration:
    """Create a calibration that never updates the output scale.

    The state is the scale itself; 'extract' returns it twice,
    and 'update' always raises NotImplementedError.
    """

    def init():
        prototype = ssm.prototypes.output_scale()
        return np.ones_like(prototype)

    def update(_state, /, observed):
        raise NotImplementedError

    def extract(state, /):
        return state, state

    return _Calibration(init=init, update=update, extract=extract)
def adaptive(
    slvr,
    /,
    *,
    ssm,
    atol=1e-4,
    rtol=1e-2,
    control=None,
    norm_ord=None,
    clip_dt: bool = False,
    eps: float | None = None,
):
    """Wrap an IVP solver into an adaptive solver.

    If ``control`` is ``None``, ``control_proportional_integral()`` is used.
    If ``eps`` is ``None``, it is set to ten times ``np.finfo_eps(float)``.
    All other arguments are forwarded unchanged.
    """
    controller = control_proportional_integral() if control is None else control
    eps_value = 10 * np.finfo_eps(float) if eps is None else eps

    options = {
        "ssm": ssm,
        "atol": atol,
        "rtol": rtol,
        "control": controller,
        "norm_ord": norm_ord,
        "clip_dt": clip_dt,
        "eps": eps_value,
    }
    return _AdaSolver(slvr, **options)
class _AdaState(containers.NamedTuple):
    """State of an adaptive solver."""

    # Step size.
    dt: float

    # NOTE(review): presumably the solver state that time-stepping
    # continues from -- confirm against _AdaSolver.
    step_from: Any

    # NOTE(review): presumably the solver state that interpolation
    # starts from -- confirm against _AdaSolver.
    interp_from: Any

    # NOTE(review): presumably the state of the step-size controller -- confirm.
    control: Any

    # NOTE(review): presumably bookkeeping statistics -- confirm.
    stats: Any
class _AdaSolver:
"""Adaptive IVP solvers."""
def __init__(
self,
slvr: _ProbabilisticSolver,
/,
*,
atol,
rtol,
control,
norm_ord,
ssm,
clip_dt: bool,
eps: float,
):
self.solver = slvr
self.atol = atol
self.rtol = rtol
self.control = control