Skip to content

Commit b9a5dd3

Browse files
authored
Merge pull request #610 from CUQI-DTU/mi_fwd_model
Multiple Inputs: update the Model module
2 parents c62aed4 + afe20b4 commit b9a5dd3

File tree

7 files changed

+2547
-415
lines changed

7 files changed

+2547
-415
lines changed

cuqi/experimental/geometry/_productgeometry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,10 @@ def fun2vec(self, *funvals):
172172
return tuple(funvecs)
173173

174174

175-
def __repr__(self) -> str:
175+
def __repr__(self, pad="") -> str:
176176
"""Representation of the product geometry."""
177177
string = "{}(".format(self.__class__.__name__) + "\n"
178178
for g in self.geometries:
179-
string += "\t{}\n".format(g.__repr__())
180-
string += ")"
179+
string += pad + " {}\n".format(g.__repr__())
180+
string += pad + ")"
181181
return string

cuqi/model/_model.py

Lines changed: 1051 additions & 347 deletions
Large diffs are not rendered by default.

cuqi/pde/_pde.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy.interpolate import interp1d
55
import numpy as np
66

7+
78
class PDE(ABC):
89
"""
910
Parametrized PDE abstract base class
@@ -30,7 +31,7 @@ def __init__(self, PDE_form, grid_sol=None, grid_obs=None, observation_map=None)
3031
self.observation_map = observation_map
3132

3233
@abstractmethod
33-
def assemble(self,parameter):
34+
def assemble(self, *args, **kwargs):
3435
pass
3536

3637
@abstractmethod
@@ -155,9 +156,9 @@ class SteadyStateLinearPDE(LinearPDE):
155156
def __init__(self, PDE_form, **kwargs):
156157
super().__init__(PDE_form, **kwargs)
157158

158-
def assemble(self, parameter):
159+
def assemble(self, *args, **kwargs):
159160
"""Assembles differential operator and rhs according to PDE_form"""
160-
self.diff_op, self.rhs = self.PDE_form(parameter)
161+
self.diff_op, self.rhs = self.PDE_form(*args, **kwargs)
161162

162163
def solve(self):
163164
"""Solve the PDE and returns the solution and an information variable `info` which is a tuple of all variables returned by the function `linalg_solve` after the solution."""
@@ -178,7 +179,7 @@ def observe(self, solution):
178179
solution_obs = self.observation_map(solution_obs)
179180

180181
return solution_obs
181-
182+
182183
class TimeDependentLinearPDE(LinearPDE):
183184
"""Time Dependent Linear PDE with fixed time stepping using Euler method (backward or forward).
184185
@@ -234,13 +235,16 @@ def method(self, value):
234235
"method can be set to either `forward_euler` or `backward_euler`")
235236
self._method = value
236237

237-
def assemble(self, parameter):
238+
def assemble(self, *args, **kwargs):
238239
"""Assemble PDE"""
239-
self._parameter = parameter
240+
self._parameter_kwargs = kwargs
241+
self._parameter_args = args
240242

241243
def assemble_step(self, t):
242244
"""Assemble time step at time t"""
243-
self.diff_op, self.rhs, self.initial_condition = self.PDE_form(self._parameter, t)
245+
self.diff_op, self.rhs, self.initial_condition = self.PDE_form(
246+
*self._parameter_args, **self._parameter_kwargs, t=t
247+
)
244248

245249
def solve(self):
246250
"""Solve PDE by time-stepping"""
@@ -279,15 +283,15 @@ def observe(self, solution):
279283
# Interpolate solution in time and space to the observation
280284
# time and space
281285
else:
282-
# Raise error if solution is 2D or 3D in space
286+
# Raise error if solution is 2D or 3D in space
283287
if len(solution.shape) > 2:
284288
raise ValueError("Interpolation of solutions of 2D and 3D "+
285289
"space dimensions based on the provided "+
286290
"grid_obs and time_obs are not supported. "+
287291
"You can, instead, pass a custom "+
288292
"observation_map and pass grid_obs and "+
289293
"time_obs as None.")
290-
294+
291295
# Interpolate solution in space and time to the observation
292296
# time and space
293297
solution_obs = scipy.interpolate.RectBivariateSpline(
@@ -297,7 +301,7 @@ def observe(self, solution):
297301
# Apply observation map
298302
if self.observation_map is not None:
299303
solution_obs = self.observation_map(solution_obs)
300-
304+
301305
# squeeze if only one time observation
302306
if len(self._time_obs) == 1:
303307
solution_obs = solution_obs.squeeze()

cuqi/testproblem/_testproblem.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -863,10 +863,9 @@ def PDE_form(IC, t): return (Dxx, np.zeros(N), IC)
863863
# Bayesian model
864864
x = cuqi.distribution.Gaussian(np.zeros(model.domain_dim), 1)
865865
y = cuqi.distribution.Gaussian(model(x), sigma2)
866-
867-
# Initialize Deconvolution as BayesianProblem problem
868-
super().__init__(y, x, y=data)
869866

867+
# Initialize Heat1D as BayesianProblem problem
868+
super().__init__(y, x, y=data)
870869
# Store exact values
871870
self.exactSolution = x_exact
872871
self.exactData = y_exact

tests/test_distribution.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -677,7 +677,7 @@ def beta_likelihood():
677677
lambda x: x**2,
678678
range_geometry=1,
679679
domain_geometry=1,
680-
gradient=lambda direction, wrt: 2*wrt*direction)
680+
gradient=lambda direction, x: 2*x*direction)
681681

682682
# set a gaussian prior
683683
x = cuqi.distribution.Gaussian(0, 1)

0 commit comments

Comments
 (0)