Skip to content

Commit e8fd90f

Browse files
authored
Merge pull request #625 from CUQI-DTU/access_pde_sol
Enable general observation function for PDE, and accessing original solution (or other observable quantities)
2 parents 503cf03 + 5d1df61 commit e8fd90f

File tree

10 files changed

+186
-57
lines changed

10 files changed

+186
-57
lines changed

cuqi/pde/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,7 @@
44
SteadyStateLinearPDE,
55
TimeDependentLinearPDE
66
)
7+
8+
from ._observation_map import (
9+
FD_spatial_gradient
10+
)

cuqi/pde/_observation_map.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import scipy
2+
import numpy as np
3+
"""
4+
This module contains observation map examples for PDE problems. The map can
5+
be passed to the `PDE` object initializer via the `observation_map` argument.
6+
7+
For example on how to use set observation maps in time dependent PDEs, see
8+
`demos/howtos/TimeDependentLinearPDE.py`.
9+
"""
10+
11+
# 1. Steady State Observation Maps
12+
# --------------------------------
13+
14+
# 2. Time-Dependent Observation Maps
15+
# -----------------------------------
16+
def FD_spatial_gradient(sol, grid, times):
17+
"""Time dependent observation map that computes the finite difference (FD) spatial gradient of a solution given at grid points (grid) and times (times). This map is supported for 1D spatial domains only.
18+
19+
Parameters
20+
----------
21+
sol : np.ndarray
22+
The solution array of shape (number of grid points, number of time steps).
23+
24+
grid : np.ndarray
25+
The spatial grid points of shape (number of grid points,).
26+
27+
times : np.ndarray
28+
The discretized time steps of shape (number of time steps,)."""
29+
30+
if len(grid.shape) != 1:
31+
raise ValueError("FD_spatial_gradient only supports 1D spatial domains.")
32+
observed_quantity = np.zeros((len(grid)-1, len(times)))
33+
for i in range(observed_quantity.shape[0]):
34+
observed_quantity[i, :] = ((sol[i, :] - sol[i+1, :])/
35+
(grid[i] - grid[i+1]))
36+
return observed_quantity

cuqi/pde/_pde.py

Lines changed: 52 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@ class PDE(ABC):
1515
PDE_form : callable function
1616
Callable function which returns a tuple of the needed PDE components (expected components are explained in the subclasses)
1717
18-
observation_map: a function handle
19-
A function that takes the PDE solution as input and the returns the observed solution. e.g. `observation_map=lambda u: u**2` or `observation_map=lambda u: u[0]`
20-
2118
grid_sol: np.ndarray
2219
The grid on which solution is defined
2320
2421
grid_obs: np.ndarray
25-
The grid on which the observed solution should be interpolated (currently only supported for 1D problems).
22+
The grid on which the observed solution should be interpolated (currently only supported for 1D problems).
23+
24+
observation_map: a function handle
25+
A function that takes the PDE solution, interpolated on `grid_obs`, as input and returns the observed solution. e.g., `observation_map=lambda u, grid_obs: u**2`.
26+
2627
"""
2728

2829
def __init__(self, PDE_form, grid_sol=None, grid_obs=None, observation_map=None):
@@ -187,14 +188,21 @@ def _solve_linear_system(self, A, b, linalg_solve, kwargs):
187188
info = None
188189

189190
return solution, info
191+
192+
def interpolate_on_observed_domain(self, solution):
193+
"""Interpolate solution on observed space domain."""
194+
raise NotImplementedError("interpolate_on_observed_domain method is not implemented for LinearPDE base class.")
190195

191196
class SteadyStateLinearPDE(LinearPDE):
192197
"""Linear steady state PDE.
193198
194199
Parameters
195200
-----------
196201
PDE_form : callable function
197-
Callable function with signature `PDE_form(parameter1, parameter2, ...)` where `parameter1`, `parameter2`, etc. are the Bayesian unknown parameters (the user can choose any names for these parameters, e.g. `a`, `b`, etc.). The function returns a tuple with the discretized differential operator A and right-hand-side b. The types of A and b are determined by what the method :meth:`linalg_solve` accepts as first and second parameters, respectively.
202+
Callable function with signature `PDE_form(parameter1, parameter2, ...)` where `parameter1`, `parameter2`, etc. are the Bayesian unknown parameters (the user can choose any names for these parameters, e.g. `a`, `b`, etc.). The function returns a tuple with the discretized differential operator A and right-hand-side b. The types of A and b are determined by what the method :meth:`linalg_solve` accepts as first and second parameters, respectively.
203+
204+
observation_map: a function handle
205+
A function that takes the PDE solution, interpolated on `grid_obs`, as input and returns the observed solution. e.g. `observation_map=lambda u, grid_obs: u**2`.
198206
199207
kwargs:
200208
See :class:`~cuqi.pde.LinearPDE` for the remaining keyword arguments.
@@ -204,8 +212,8 @@ class SteadyStateLinearPDE(LinearPDE):
204212
See demo demos/demo24_fwd_poisson.py for an illustration on how to use SteadyStateLinearPDE with varying solver choices. And demos demos/demo25_fwd_poisson_2D.py and demos/demo26_fwd_poisson_mixedBC.py for examples with mixed (Dirichlet and Neumann) boundary conditions problems. demos/demo25_fwd_poisson_2D.py also illustrates how to observe on a specific boundary, for example.
205213
"""
206214

207-
def __init__(self, PDE_form, **kwargs):
208-
super().__init__(PDE_form, **kwargs)
215+
def __init__(self, PDE_form, observation_map=None, **kwargs):
216+
super().__init__(PDE_form, observation_map=observation_map, **kwargs)
209217

210218
def assemble(self, *args, **kwargs):
211219
"""Assembles differential operator and rhs according to PDE_form"""
@@ -221,17 +229,25 @@ def solve(self):
221229

222230
return self._solve_linear_system(self.diff_op, self.rhs, self._linalg_solve, self._linalg_solve_kwargs)
223231

224-
225-
def observe(self, solution):
226-
232+
def interpolate_on_observed_domain(self, solution):
233+
"""Interpolate solution on observed space grid."""
227234
if self.grids_equal:
228235
solution_obs = solution
229236
else:
230237
solution_obs = interp1d(self.grid_sol, solution, kind='quadratic')(self.grid_obs)
238+
return solution_obs
239+
240+
def observe(self, solution):
241+
"""Apply observation operator to the solution. This includes
242+
interpolation to observation points (if different from the
243+
solution grid) then applying the observation map (if provided)."""
244+
245+
# Interpolate solution on observed domain
246+
solution_obs = self.interpolate_on_observed_domain(solution)
231247

232248
if self.observation_map is not None:
233-
solution_obs = self.observation_map(solution_obs)
234-
249+
solution_obs = self.observation_map(solution_obs, self.grid_obs)
250+
235251
return solution_obs
236252

237253
class TimeDependentLinearPDE(LinearPDE):
@@ -251,16 +267,20 @@ class TimeDependentLinearPDE(LinearPDE):
251267
method: str
252268
Time stepping method. Currently two options are available `forward_euler` and `backward_euler`.
253269
270+
observation_map: a function handle
271+
A function that takes the PDE solution, interpolated on `grid_obs` and `time_obs`, as input and returns the observed solution. e.g. `observation_map=lambda u, grid_obs, time_obs: u**2`.
272+
254273
kwargs:
255274
See :class:`~cuqi.pde.LinearPDE` for the remaining keyword arguments
256275
257276
Example
258277
-----------
259-
See demos/demo34_TimeDependentLinearPDE.py for 1D heat and 1D wave equations.
278+
See demos/howtos/TimeDependentLinearPDE.py for 1D heat and 1D wave equations examples. It demonstrates setting up `TimeDependentLinearPDE` objects, including the choice of time stepping methods, observation domain, and observation map.
260279
"""
261280

262-
def __init__(self, PDE_form, time_steps, time_obs='final', method='forward_euler', **kwargs):
263-
super().__init__(PDE_form, **kwargs)
281+
def __init__(self, PDE_form, time_steps, time_obs='final',
282+
method='forward_euler', observation_map=None, **kwargs):
283+
super().__init__(PDE_form, observation_map=observation_map, **kwargs)
264284

265285
self.time_steps = time_steps
266286
self.method = method
@@ -339,8 +359,8 @@ def solve(self):
339359

340360
return u, info
341361

342-
def observe(self, solution):
343-
362+
def interpolate_on_observed_domain(self, solution):
363+
"""Interpolate solution on observed time and space points."""
344364
# If observation grid is the same as solution grid and observation time
345365
# is the final time step then no need to interpolate
346366
if self.grids_equal and np.all(self.time_steps[-1:] == self._time_obs):
@@ -361,15 +381,26 @@ def observe(self, solution):
361381
# Interpolate solution in space and time to the observation
362382
# time and space
363383
solution_obs = scipy.interpolate.RectBivariateSpline(
364-
self.grid_sol, self.time_steps, solution)(self.grid_obs,
365-
self._time_obs)
384+
self.grid_sol, self.time_steps, solution
385+
)(self.grid_obs, self._time_obs)
366386

387+
return solution_obs
388+
389+
def observe(self, solution):
390+
"""Apply observation operator to the solution. This includes
391+
interpolation to observation points (if different from the
392+
solution grid) then applying the observation map (if provided)."""
393+
394+
# Interpolate solution on observed domain
395+
solution_obs = self.interpolate_on_observed_domain(solution)
396+
367397
# Apply observation map
368398
if self.observation_map is not None:
369-
solution_obs = self.observation_map(solution_obs)
399+
solution_obs = self.observation_map(solution_obs, self.grid_obs,
400+
self._time_obs)
370401

371402
# squeeze if only one time observation
372403
if len(self._time_obs) == 1:
373404
solution_obs = solution_obs.squeeze()
374405

375-
return solution_obs
406+
return solution_obs

0 commit comments

Comments
 (0)