Skip to content

Commit a5746f0

Browse files
committed
Call solve_and_save_at in solve_and_save_terminal_values
1 parent b91bb40 commit a5746f0

1 file changed

Lines changed: 7 additions & 29 deletions

File tree

probdiffeq/ivpsolve.py

Lines changed: 7 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from probdiffeq.backend import (
55
containers,
66
control_flow,
7-
functools,
87
linalg,
98
tree_array_util,
109
tree_util,
@@ -89,42 +88,21 @@ def solve_adaptive_terminal_values(
8988
vector_field, initial_condition, t0, t1, adaptive_solver, dt0, *, ssm
9089
) -> IVPSolution:
9190
"""Simulate the terminal values of an initial value problem."""
92-
save_at = np.asarray([t1])
93-
(_t, solution_save_at), _, num_steps = _solve_adaptive_save_at(
94-
tree_util.Partial(vector_field),
95-
t0,
91+
save_at = np.asarray([t0, t1])
92+
solution = solve_adaptive_save_at(
93+
vector_field,
9694
initial_condition,
9795
save_at=save_at,
9896
adaptive_solver=adaptive_solver,
9997
dt0=dt0,
100-
)
101-
# "squeeze"-type functionality (there is only a single state!)
102-
squeeze_fun = functools.partial(np.squeeze_along_axis, axis=0)
103-
solution_save_at = tree_util.tree_map(squeeze_fun, solution_save_at)
104-
num_steps = tree_util.tree_map(squeeze_fun, num_steps)
105-
106-
# I think the user expects marginals, so we compute them here
107-
# todo: do this in IVPSolution.* methods?
108-
posterior, output_scale = solution_save_at
109-
marginals = posterior.init if isinstance(posterior, stats.MarkovSeq) else posterior
110-
111-
u = ssm.stats.qoi_from_sample(marginals.mean)
112-
std = ssm.stats.standard_deviation(marginals)
113-
u_std = ssm.stats.qoi_from_sample(std)
114-
return IVPSolution(
115-
t=t1,
116-
u=u,
117-
u_std=u_std,
11898
ssm=ssm,
119-
marginals=marginals,
120-
posterior=posterior,
121-
output_scale=output_scale,
122-
num_steps=num_steps,
99+
warn=False, # Turn off warnings because any solver goes for terminal values
123100
)
101+
return tree_util.tree_map(lambda s: s[-1], solution)
124102

125103

126104
def solve_adaptive_save_at(
127-
vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm
105+
vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm, warn=True
128106
) -> IVPSolution:
129107
r"""Solve an initial value problem and return the solution at a pre-determined grid.
130108
@@ -152,7 +130,7 @@ def solve_adaptive_save_at(
152130
}
153131
```
154132
"""
155-
if not adaptive_solver.solver.is_suitable_for_save_at:
133+
if not adaptive_solver.solver.is_suitable_for_save_at and warn:
156134
msg = (
157135
f"Strategy {adaptive_solver.solver} should not "
158136
f"be used in solve_adaptive_save_at. "

0 commit comments

Comments
 (0)