Skip to content

Commit 7536b0d

Browse files
authored
Merge pull request #656 from CUQI-DTU/multiple_inputs_pde
Multiple inputs support for pde gradient
2 parents 52c8f8f + f52546b commit 7536b0d

File tree

4 files changed

+382
-59
lines changed

4 files changed

+382
-59
lines changed

cuqi/model/_model.py

Lines changed: 73 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def _non_default_args(self):
179179
self._stored_non_default_args =\
180180
cuqi.utilities.get_non_default_args(self._forward_func)
181181
return self._stored_non_default_args
182-
182+
183183
@property
184184
def number_of_inputs(self):
185185
""" The number of inputs of the model. """
@@ -422,7 +422,7 @@ def _2fun(self, geometry=None, is_par=True, **kwargs):
422422
# Use CUQIarray funvals if geometry is consistent
423423
if isinstance(v, CUQIarray) and v.geometry == geometries[i]:
424424
kwargs[k] = v.funvals
425-
# Else, if we still need to convert to function value (is_par[i] is True)
425+
# Else, if we still need to convert to function value (is_par[i] is True)
426426
# we use the geometry par2fun method
427427
elif is_par[i] and v is not None:
428428
kwargs[k] = geometries[i].par2fun(v)
@@ -496,7 +496,7 @@ def _2par(self, geometry=None, to_CUQIarray=False, is_par=False, **kwargs):
496496
# Use CUQIarray parameters if geometry is consistent
497497
if isinstance(v, CUQIarray) and v.geometry == geometries[i]:
498498
v = v.parameters
499-
# Else, if we still need to convert to parameter value (is_par[i] is False)
499+
# Else, if we still need to convert to parameter value (is_par[i] is False)
500500
# we use the geometry fun2par method
501501
elif not is_par[i] and v is not None:
502502
v = geometries[i].fun2par(v)
@@ -665,7 +665,7 @@ def _parse_args_add_to_kwargs(
665665
error_msg = (
666666
"The "
667667
+ map_name.lower()
668-
+ f" input is specified by a keywords arguments {list(kwargs.keys())} that does not match the non_default_args of the "
668+
+ f" input is specified by keywords arguments {list(kwargs.keys())} that does not match the non_default_args of the "
669669
+ map_name
670670
+ f" {non_default_args}."
671671
)
@@ -808,7 +808,11 @@ def _handle_case_when_model_input_is_distributions(self, kwargs):
808808
new_model = copy(self)
809809

810810
# Store the original non_default_args of the model
811-
new_model._original_non_default_args = self._non_default_args
811+
new_model._original_non_default_args = (
812+
self._original_non_default_args
813+
if hasattr(self, "_original_non_default_args")
814+
else self._non_default_args
815+
)
812816

813817
# Update the non_default_args of the model to match the distribution
814818
# names. Defaults to x in the case of only one distribution that has no
@@ -1052,7 +1056,7 @@ def _apply_chain_rule_to_account_for_domain_geometry_gradient(self,
10521056

10531057
# turn grad_is_par to a tuple of bools if it is not already
10541058
if isinstance(grad_is_par, bool):
1055-
grad_is_par = tuple([grad_is_par]*len(grad))
1059+
grad_is_par = tuple([grad_is_par]*self.number_of_inputs)
10561060

10571061
# If the domain geometry is a _ProductGeometry and the gradient is
10581062
# stacked, split it
@@ -1451,7 +1455,7 @@ class PDEModel(Model):
14511455
:ivar range_geometry: The geometry representing the range.
14521456
:ivar domain_geometry: The geometry representing the domain.
14531457
"""
1454-
def __init__(self, PDE: cuqi.pde.PDE, range_geometry, domain_geometry):
1458+
def __init__(self, PDE: cuqi.pde.PDE, range_geometry, domain_geometry, **kwargs):
14551459

14561460
if not isinstance(PDE, cuqi.pde.PDE):
14571461
raise ValueError("PDE needs to be a cuqi PDE.")
@@ -1460,23 +1464,30 @@ def __init__(self, PDE: cuqi.pde.PDE, range_geometry, domain_geometry):
14601464
self.pde = PDE
14611465
self._stored_non_default_args = None
14621466

1463-
super().__init__(self._forward_func, range_geometry, domain_geometry)
1467+
# If gradient or jacobian is not provided, we create it from the PDE
1468+
if not np.any([k in kwargs.keys() for k in ["gradient", "jacobian"]]):
1469+
# Create gradient or jacobian function to pass to the Model based on
1470+
# the PDE object. The dictionary derivative_kwarg contains the
1471+
# created function along with the function type (either "gradient"
1472+
# or "jacobian")
1473+
derivative_kwarg = self._create_derivative_function()
1474+
# append derivative_kwarg to kwargs
1475+
kwargs.update(derivative_kwarg)
1476+
1477+
super().__init__(forward=self._forward_func_pde,
1478+
range_geometry=range_geometry,
1479+
domain_geometry=domain_geometry,
1480+
**kwargs)
14641481

14651482
@property
14661483
def _non_default_args(self):
14671484
if self._stored_non_default_args is None:
14681485
# extract the non-default arguments of the PDE
1469-
self._stored_non_default_args = cuqi.utilities.get_non_default_args(
1470-
self.pde.PDE_form
1471-
)
1472-
# remove t from the non-default arguments
1473-
self._stored_non_default_args = self._non_default_args
1474-
if "t" in self._non_default_args:
1475-
self._stored_non_default_args.remove("t")
1486+
self._stored_non_default_args = self.pde._non_default_args
14761487

14771488
return self._stored_non_default_args
14781489

1479-
def _forward_func(self, **kwargs):
1490+
def _forward_func_pde(self, **kwargs):
14801491

14811492
self.pde.assemble(**kwargs)
14821493

@@ -1486,14 +1497,55 @@ def _forward_func(self, **kwargs):
14861497

14871498
return obs
14881499

1489-
def _gradient_func(self, direction, wrt):
1490-
""" Compute direction-Jacobian product (gradient) of the model. """
1500+
def _create_derivative_function(self):
1501+
"""Private function that creates the derivative function (gradient or
1502+
jacobian) based on the PDE object. The derivative function is created as
1503+
a lambda function that takes the direction and the parameters as input
1504+
and returns the gradient or jacobian of the PDE. This private function
1505+
returns a dictionary with the created function and the function type
1506+
(either "gradient" or "jacobian")."""
1507+
14911508
if hasattr(self.pde, "gradient_wrt_parameter"):
1492-
return self.pde.gradient_wrt_parameter(direction, wrt)
1509+
# Build the string that will be used to create the lambda function
1510+
function_str = (
1511+
"lambda direction, "
1512+
+ ", ".join(self._non_default_args)
1513+
+ ", pde_func: pde_func(direction, "
1514+
+ ", ".join(self._non_default_args)
1515+
+ ")"
1516+
)
1517+
1518+
# create the lambda function from the string
1519+
function = eval(function_str)
1520+
1521+
# create partial function from the lambda function with gradient_wrt_parameter
1522+
# as the first argument
1523+
grad_func = partial(function, pde_func=self.pde.gradient_wrt_parameter)
1524+
1525+
# Return the gradient function
1526+
return {"gradient": grad_func}
1527+
14931528
elif hasattr(self.pde, "jacobian_wrt_parameter"):
1494-
return direction@self.pde.jacobian_wrt_parameter(wrt)
1529+
# Build the string that will be used to create the lambda function
1530+
function_str = (
1531+
"lambda "
1532+
+ ", ".join(self._non_default_args)
1533+
+ ", pde_func: pde_func( "
1534+
+ ", ".join(self._non_default_args)
1535+
+ ")"
1536+
)
1537+
1538+
# create the lambda function from the string
1539+
function = eval(function_str)
1540+
1541+
# create partial function from the lambda function with jacobian_wrt_parameter
1542+
# as the first argument
1543+
jacobian_func = partial(function, pde_func=self.pde.jacobian_wrt_parameter)
1544+
1545+
# Return the jacobian function
1546+
return {"jacobian": jacobian_func}
14951547
else:
1496-
raise NotImplementedError("Gradient is not implemented for this model.")
1548+
return {} # empty dictionary if no gradient or jacobian is found
14971549

14981550
# Add the underlying PDE class name to the repr.
14991551
def __repr__(self) -> str:

cuqi/pde/_pde.py

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from inspect import getsource
44
from scipy.interpolate import interp1d
55
import numpy as np
6+
from cuqi.utilities import get_non_default_args
67

78

89
class PDE(ABC):
@@ -29,6 +30,7 @@ def __init__(self, PDE_form, grid_sol=None, grid_obs=None, observation_map=None)
2930
self.grid_sol = grid_sol
3031
self.grid_obs = grid_obs
3132
self.observation_map = observation_map
33+
self._stored_non_default_args = None
3234

3335
@abstractmethod
3436
def assemble(self, *args, **kwargs):
@@ -64,6 +66,13 @@ def _compare_grid(grid1,grid2):
6466

6567
return equal_arrays
6668

69+
@property
70+
def _non_default_args(self):
71+
"""Returns the non-default arguments of the PDE_form function"""
72+
if self._stored_non_default_args is None:
73+
self._stored_non_default_args = get_non_default_args(self.PDE_form)
74+
return self._stored_non_default_args
75+
6776
@property
6877
def grid_sol(self):
6978
if hasattr(self,"_grid_sol"):
@@ -94,6 +103,48 @@ def grid_obs(self,value):
94103
def grids_equal(self):
95104
return self._grids_equal
96105

106+
def _parse_args_add_to_kwargs(
107+
self, *args, map_name, **kwargs):
108+
""" Private function that parses the input arguments and adds them as
109+
keyword arguments matching (the order of) the non default arguments of
110+
the pde class.
111+
"""
112+
113+
# If any args are given, add them to kwargs
114+
if len(args) > 0:
115+
if len(kwargs) > 0:
116+
raise ValueError(
117+
+ map_name.lower()
118+
+ " input is specified both as positional and keyword arguments. This is not supported."
119+
)
120+
121+
# Check if the number of args does not match the number of
122+
# non_default_args of the model
123+
if len(args) != len(self._non_default_args):
124+
raise ValueError(
125+
"The number of positional arguments does not match the number of non-default arguments of "
126+
+ map_name.lower()
127+
+ "."
128+
)
129+
130+
# Add args to kwargs following the order of non_default_args
131+
for idx, arg in enumerate(args):
132+
kwargs[self._non_default_args[idx]] = arg
133+
134+
# Check kwargs matches non_default_args
135+
if set(list(kwargs.keys())) != set(self._non_default_args):
136+
error_msg = (
137+
map_name.lower()
138+
+ f" input is specified by keywords arguments {list(kwargs.keys())} that does not match the non_default_args of "
139+
+ map_name
140+
+ f" {self._non_default_args}."
141+
)
142+
raise ValueError(error_msg)
143+
144+
# Make sure order of kwargs is the same as non_default_args
145+
kwargs = {k: kwargs[k] for k in self._non_default_args}
146+
147+
return kwargs
97148

98149
class LinearPDE(PDE):
99150
"""
@@ -143,7 +194,7 @@ class SteadyStateLinearPDE(LinearPDE):
143194
Parameters
144195
-----------
145196
PDE_form : callable function
146-
Callable function with signature `PDE_form(parameter)` where `parameter` is the Bayesian parameter. The function returns a tuple with the discretized differential operator A and right-hand-side b. The types of A and b are determined by what the method :meth:`linalg_solve` accepts as first and second parameters, respectively.
197+
Callable function with signature `PDE_form(parameter1, parameter2, ...)` where `parameter1`, `parameter2`, etc. are the Bayesian unknown parameters (the user can choose any names for these parameters, e.g. `a`, `b`, etc.). The function returns a tuple with the discretized differential operator A and right-hand-side b. The types of A and b are determined by what the method :meth:`linalg_solve` accepts as first and second parameters, respectively.
147198
148199
kwargs:
149200
See :class:`~cuqi.pde.LinearPDE` for the remaining keyword arguments.
@@ -158,7 +209,10 @@ def __init__(self, PDE_form, **kwargs):
158209

159210
def assemble(self, *args, **kwargs):
160211
"""Assembles differential operator and rhs according to PDE_form"""
161-
self.diff_op, self.rhs = self.PDE_form(*args, **kwargs)
212+
kwargs = self._parse_args_add_to_kwargs(
213+
*args, map_name="assemble", **kwargs
214+
)
215+
self.diff_op, self.rhs = self.PDE_form(**kwargs)
162216

163217
def solve(self):
164218
"""Solve the PDE and returns the solution and an information variable `info` which is a tuple of all variables returned by the function `linalg_solve` after the solution."""
@@ -186,7 +240,7 @@ class TimeDependentLinearPDE(LinearPDE):
186240
Parameters
187241
-----------
188242
PDE_form : callable function
189-
Callable function with signature `PDE_form(parameter, t)` where `parameter` is the Bayesian parameter and `t` is the time at which the PDE form is evaluated. The function returns a tuple of (`differential_operator`, `source_term`, `initial_condition`) where `differential_operator` is the linear operator at time `t`, `source_term` is the source term at time `t`, and `initial_condition` is the initial condition. The types of `differential_operator` and `source_term` are determined by what the method :meth:`linalg_solve` accepts as linear operator and right-hand side, respectively. The type of `initial_condition` should be the same type as the solution returned by :meth:`linalg_solve`.
243+
Callable function with signature `PDE_form(parameter1, parameter2, ..., t)` where `parameter1`, `parameter2`, etc. are the Bayesian unknown parameters (the user can choose any names for these parameters, e.g. `a`, `b`, etc.) and `t` is the time at which the PDE form is evaluated. The function returns a tuple of (`differential_operator`, `source_term`, `initial_condition`) where `differential_operator` is the linear operator at time `t`, `source_term` is the source term at time `t`, and `initial_condition` is the initial condition. The types of `differential_operator` and `source_term` are determined by what the method :meth:`linalg_solve` accepts as linear operator and right-hand side, respectively. The type of `initial_condition` should be the same type as the solution returned by :meth:`linalg_solve`.
190244
191245
time_steps : ndarray
192246
An array of the discretized times corresponding to the time steps that starts with the initial time and ends with the final time
@@ -228,6 +282,18 @@ def __init__(self, PDE_form, time_steps, time_obs='final', method='forward_euler
228282
def method(self):
229283
return self._method
230284

285+
@property
286+
def _non_default_args(self):
287+
"""Returns the non-default arguments of the PDE_form function"""
288+
if self._stored_non_default_args is None:
289+
self._stored_non_default_args = get_non_default_args(self.PDE_form)
290+
# Remove the time argument from the non-default arguments
291+
# since it is provided automatically by `solve` method and is not
292+
# an argument to be inferred in Bayesian inference setting.
293+
if 't' in self._stored_non_default_args:
294+
self._stored_non_default_args.remove('t')
295+
return self._stored_non_default_args
296+
231297
@method.setter
232298
def method(self, value):
233299
if value.lower() != 'forward_euler' and value.lower() != 'backward_euler':
@@ -237,13 +303,13 @@ def method(self, value):
237303

238304
def assemble(self, *args, **kwargs):
239305
"""Assemble PDE"""
306+
kwargs = self._parse_args_add_to_kwargs(*args, map_name="assemble", **kwargs)
240307
self._parameter_kwargs = kwargs
241-
self._parameter_args = args
242308

243309
def assemble_step(self, t):
244310
"""Assemble time step at time t"""
245311
self.diff_op, self.rhs, self.initial_condition = self.PDE_form(
246-
*self._parameter_args, **self._parameter_kwargs, t=t
312+
**self._parameter_kwargs, t=t
247313
)
248314

249315
def solve(self):

cuqi/utilities/_utilities.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ def approx_derivative(func, wrt, direction=None, epsilon=np.sqrt(np.finfo(float)
188188
# We compute the Jacobian matrix of func using forward differences.
189189
# If the function is scalar-valued, we compute the gradient instead.
190190
# If the direction is provided, we compute the direction-Jacobian product.
191-
wrt = np.asarray(wrt)
191+
wrt = force_ndarray(wrt, flatten=True)
192192
f0 = func(wrt)
193193
Matr = np.zeros([infer_len(wrt), infer_len(f0)])
194194
dx = np.zeros(len(wrt))

0 commit comments

Comments
 (0)