|
4 | 4 | from probdiffeq.backend import ( |
5 | 5 | containers, |
6 | 6 | control_flow, |
7 | | - functools, |
8 | 7 | linalg, |
9 | 8 | tree_array_util, |
10 | 9 | tree_util, |
@@ -89,42 +88,21 @@ def solve_adaptive_terminal_values( |
89 | 88 | vector_field, initial_condition, t0, t1, adaptive_solver, dt0, *, ssm |
90 | 89 | ) -> IVPSolution: |
91 | 90 | """Simulate the terminal values of an initial value problem.""" |
92 | | - save_at = np.asarray([t1]) |
93 | | - (_t, solution_save_at), _, num_steps = _solve_adaptive_save_at( |
94 | | - tree_util.Partial(vector_field), |
95 | | - t0, |
| 91 | + save_at = np.asarray([t0, t1]) |
| 92 | + solution = solve_adaptive_save_at( |
| 93 | + vector_field, |
96 | 94 | initial_condition, |
97 | 95 | save_at=save_at, |
98 | 96 | adaptive_solver=adaptive_solver, |
99 | 97 | dt0=dt0, |
100 | | - ) |
101 | | - # "squeeze"-type functionality (there is only a single state!) |
102 | | - squeeze_fun = functools.partial(np.squeeze_along_axis, axis=0) |
103 | | - solution_save_at = tree_util.tree_map(squeeze_fun, solution_save_at) |
104 | | - num_steps = tree_util.tree_map(squeeze_fun, num_steps) |
105 | | - |
106 | | - # I think the user expects marginals, so we compute them here |
107 | | - # todo: do this in IVPSolution.* methods? |
108 | | - posterior, output_scale = solution_save_at |
109 | | - marginals = posterior.init if isinstance(posterior, stats.MarkovSeq) else posterior |
110 | | - |
111 | | - u = ssm.stats.qoi_from_sample(marginals.mean) |
112 | | - std = ssm.stats.standard_deviation(marginals) |
113 | | - u_std = ssm.stats.qoi_from_sample(std) |
114 | | - return IVPSolution( |
115 | | - t=t1, |
116 | | - u=u, |
117 | | - u_std=u_std, |
118 | 98 | ssm=ssm, |
119 | | - marginals=marginals, |
120 | | - posterior=posterior, |
121 | | - output_scale=output_scale, |
122 | | - num_steps=num_steps, |
| 99 | + warn=False, # Turn off warnings because any solver goes for terminal values |
123 | 100 | ) |
| 101 | + return tree_util.tree_map(lambda s: s[-1], solution) |
124 | 102 |
|
125 | 103 |
|
126 | 104 | def solve_adaptive_save_at( |
127 | | - vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm |
| 105 | + vector_field, initial_condition, save_at, adaptive_solver, dt0, *, ssm, warn=True |
128 | 106 | ) -> IVPSolution: |
129 | 107 | r"""Solve an initial value problem and return the solution at a pre-determined grid. |
130 | 108 |
|
@@ -152,7 +130,7 @@ def solve_adaptive_save_at( |
152 | 130 | } |
153 | 131 | ``` |
154 | 132 | """ |
155 | | - if not adaptive_solver.solver.is_suitable_for_save_at: |
| 133 | + if not adaptive_solver.solver.is_suitable_for_save_at and warn: |
156 | 134 | msg = ( |
157 | 135 | f"Strategy {adaptive_solver.solver} should not " |
158 | 136 | f"be used in solve_adaptive_save_at. " |
|
0 commit comments