Skip to content
Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ jobs:
- uses: actions/setup-python@v4

- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
python-version: 3.8
miniforge-variant: Mambaforge
miniforge-variant: Miniforge3
miniforge-version: latest

- name: Conda info
Expand All @@ -30,7 +30,7 @@ jobs:
conda info
which python

- name: Mamba install FEniCS
- name: conda install FEniCS
shell: bash -l {0}
run: |
conda config --set always_yes yes
Expand Down
8 changes: 4 additions & 4 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ jobs:
fetch-depth: 1

- name: Setup Miniconda
uses: conda-incubator/setup-miniconda@v2
uses: conda-incubator/setup-miniconda@v3
with:
auto-update-conda: true
python-version: 3.8
miniforge-variant: Mambaforge
miniforge-variant: Miniforge3
miniforge-version: latest

- name: Conda info
Expand All @@ -32,12 +32,12 @@ jobs:
conda info
which python

- name: Mamba install FEniCS
- name: Conda install FEniCS
shell: bash -l {0}
run: |
conda config --set always_yes yes
conda config --add channels conda-forge
mamba create -n fenicsproject -c conda-forge fenics
conda create -n fenicsproject -c conda-forge fenics
conda activate fenicsproject
which python
python -c "from dolfin import *"
Expand Down
175 changes: 110 additions & 65 deletions cuqipy_fenics/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
from abc import ABC, abstractmethod
from cuqi.pde import PDE
from cuqi.array import CUQIarray
from cuqi.utilities import get_non_default_args
import dolfin as dl
from copy import copy
import warnings
from functools import partial
from .utilities import _LazyUFLLoader
ufl = _LazyUFLLoader()

Expand All @@ -23,16 +25,18 @@ class FEniCSPDE(PDE,ABC):
----------
PDE_form : callable or tuple of two callables
If passed as a callable: the callable returns the weak form of the PDE.
The callable should take three arguments, the first argument is the
parameter (input of the forward model), the second argument is the state
variable (solution variable), and the third argument is the adjoint
The callable should take three or more arguments, the first arguments
are the unknown parameters (inputs of the forward model, e.g.
`parameter1`, `parameter2`), the second to last argument is the state
variable (solution variable), and the last argument is the adjoint
variable (the test variable in the weak formulation).

If passed as a tuple of two callables: the first callable returns the
weak form of the PDE left hand side, and the second callable returns the
weak form of the PDE right hand side. The left hand side callable takes
the same three arguments as described above. The right hand side
callable takes only the parameter and the adjoint variable as arguments.
the same arguments as described above. The right hand side callable
takes only the unknown parameters and the adjoint variable (the latter
being the last argument) as arguments.
See the example below.

mesh : FEniCS mesh
Expand All @@ -41,8 +45,9 @@ class FEniCSPDE(PDE,ABC):
solution_function_space : FEniCS function space
FEniCS function space object that defines the function space of the state variable (solution variable).

parameter_function_space : FEniCS function space
FEniCS function space object that defines the function space of the Bayesian parameter (input of the forward model).
parameter_function_space : FEniCS function space or a list of them
FEniCS function space object or a list of them that defines the function space of the unknown parameters (inputs of the forward model).
If multiple parameters are passed, the function space should be a list of FEniCS function spaces, one for each parameter.

dirichlet_bcs: FEniCS Dirichlet boundary condition object or a list of them
FEniCS Dirichlet boundary condition object(s) that define the Dirichlet boundary conditions of the PDE.
Expand All @@ -52,7 +57,7 @@ class FEniCSPDE(PDE,ABC):

observation_operator : python function handle, optional
Function handle of a python function that returns the observed quantity from the PDE solution. If not provided, the identity operator is assumed (i.e. the entire solution is observed).
This python function takes as input the Bayesian parameter (input of the forward model) and the state variable (solution variable) as first and second inputs, respectively.
This python function takes as input the unknown parameters, e.g. `parameter1`, `parameter2`, and the state variable (solution variable) in that order.

The returned observed quantity can be a ufl.algebra.Operator, FEniCS Function, np.ndarray, int, or float.

Expand Down Expand Up @@ -146,9 +151,15 @@ def __init__(self, PDE_form, mesh, solution_function_space,
for key, value in linalg_solve_kwargs.items():
self._solver.parameters[key] = value

# Initialize the parameter
self.parameter = dl.Function(self.parameter_function_space)

# Initialize the parameter (one or more)
# If only one parameter is passed, it is converted to a list
if not isinstance(self.parameter_function_space, (list, tuple)):
parameter_function_space_list = [self.parameter_function_space]
else:
parameter_function_space_list = self.parameter_function_space
self.parameter= {}
for i, k in enumerate(self._non_default_args):
self.parameter[k] = dl.Function(parameter_function_space_list[i])

@property
def parameter(self):
Expand All @@ -157,8 +168,8 @@ def parameter(self):

@parameter.setter
def parameter(self, value):
""" Set the parameter of the PDE. Since the PDE solution depends on the
parameter, this will set the PDE solution to None. """
""" Set the parameters of the PDE. Since the PDE solution depends on the
parameters, this will set the PDE solution to None. """
if value is None:
raise ValueError('Parameter cannot be None.')

Expand All @@ -169,7 +180,8 @@ def parameter(self, value):
# Subsequent times setting the parameter (avoid assigning the parameter
# to new object, set parameter array in place instead)
elif self._is_parameter_new(value):
self._parameter.vector().set_local(value.vector().get_local())
for key in self._non_default_args:
self._parameter[key].vector().set_local(value[key].vector().get_local())
# The operator in the solver is no longer valid
self._flags["is_operator_valid"] = False

Expand All @@ -179,6 +191,19 @@ def parameter(self, value):
self._gradient = None
self.rhs = None

@property
def parameter_args(self):
"""Get the args form of the parameter"""
args = list(self.parameter.values())
return args

@property
def _non_default_args(self):
form = self._form
if isinstance(self._form, tuple):
# extract non-default args from the lhs first form
form = self._form[0]
return get_non_default_args(form)[:-2] # Exclude the last two arguments (u and p) from the list of non-default args

@property
def forward_solution(self):
Expand All @@ -189,7 +214,24 @@ def forward_solution(self):
def PDE_form(self):
""" Get the PDE form """
if isinstance(self._form, tuple):
return lambda m, u, p: self._form[0](m, u, p) - self._form[1](m, p)
# Create a string for the lambda function that represents the PDE form
form_str = (
"lambda form_lhs, form_rhs, "
+ ", ".join(self._non_default_args)
+ ", u, p: form_lhs("
+ ", ".join(self._non_default_args)
+ ", u, p) - form_rhs("
+ ", ".join(self._non_default_args)
+ ", p)"
)
# Create a lambda function that represents the PDE form
form = eval(form_str)

# partial evaluation of the form
form_partial = partial(form, form_lhs=self._form[0],
form_rhs=self._form[1])

return form_partial
else:
return self._form

Expand Down Expand Up @@ -243,10 +285,14 @@ def observation_operator(self):
@observation_operator.setter
def observation_operator(self, value):
""" Set the observation operator """
self._observation_operator = self._create_observation_operator(value)
if value == None or callable(value):
self._observation_operator = value
else:
raise NotImplementedError(
"observation_operator must be a callable function or None")

@abstractmethod
def assemble(self,parameter):
def assemble(self, *args, **kwargs):
""" Assemble the PDE weak form """
raise NotImplementedError

Expand All @@ -265,34 +311,34 @@ def gradient_wrt_parameter(self):
""" Compute gradient of the PDE weak form w.r.t. the parameter"""
raise NotImplementedError

@abstractmethod
def _create_observation_operator(self, observation_operator):
raise NotImplementedError

def _is_parameter_new(self, input_parameter):
""" A helper function to check if the `input_parameter` is different
from the current parameter (cached in self._parameter). """

if hasattr(self, '_parameter') \
and np.allclose(self._parameter.vector().get_local(),
input_parameter.vector().get_local(),
atol=dl.DOLFIN_EPS, rtol=dl.DOLFIN_EPS):
return False
else:
if not hasattr(self, '_parameter'):
return True

is_new = False
for key in self._non_default_args:
if not np.allclose(self._parameter[key].vector().get_local(),
input_parameter[key].vector().get_local(),
atol=dl.DOLFIN_EPS, rtol=dl.DOLFIN_EPS):
is_new = True
return is_new

class SteadyStateLinearFEniCSPDE(FEniCSPDE):
""" Class representation of steady state linear PDEs defined in FEniCS. It accepts the same arguments as the base class `cuqipy_fenics.pde.FEniCSPDE`."""

def assemble(self, parameter=None):
def assemble(self, *args, **kwargs):
self._solution_trial_function = dl.TrialFunction(
self.solution_function_space)
self._solution_test_function = dl.TestFunction(
self.solution_function_space)

if parameter is not None:
self.parameter = parameter
kwargs = self._parse_args_add_to_kwargs(
*args, map_name="assemble", **kwargs
)
self.parameter = kwargs

# Either assemble the lhs and rhs forms separately or the full PDE form
if self.lhs_form is not None:
Expand Down Expand Up @@ -328,11 +374,10 @@ def _assemble_full(self):
and self._flags["is_operator_valid"] and\
self.rhs is not None:
return

diff_op = dl.lhs(self.PDE_form(self.parameter,
diff_op = dl.lhs(self.PDE_form(*self.parameter_args,
self._solution_trial_function,
self._solution_test_function))
self.rhs = dl.rhs(self.PDE_form(self.parameter,
self.rhs = dl.rhs(self.PDE_form(*self.parameter_args,
self._solution_trial_function,
self._solution_test_function))

Expand All @@ -353,7 +398,7 @@ def _assemble_lhs(self):
and self._flags["is_operator_valid"]:
return

diff_op = dl.assemble(self.lhs_form(self.parameter,
diff_op = dl.assemble(self.lhs_form(*self.parameter_args,
self._solution_trial_function,
self._solution_test_function))

Expand All @@ -367,7 +412,7 @@ def _assemble_rhs(self):
and self.rhs is not None:
return

self.rhs = dl.assemble(self.rhs_form(self.parameter,
self.rhs = dl.assemble(self.rhs_form(*self.parameter_args,
self._solution_test_function))
for bc in self._dirichlet_bcs: bc.apply(self.rhs)

Expand All @@ -381,16 +426,20 @@ def observe(self,PDE_solution_fun):
if self.observation_operator is None:
return PDE_solution_fun
else:
return self._apply_obs_op(self.parameter, PDE_solution_fun)
return self._apply_obs_op(*self.parameter_args, PDE_solution_fun)

def gradient_wrt_parameter(self, direction, wrt, **kwargs):
""" Compute the gradient of the PDE with respect to the parameter
def gradient_wrt_parameter(self, direction, *args, **kwargs):
""" Compute the gradient of the PDE with respect to the parameters

Note: This implementation is largely based on the code:
https://github.com/hippylib/hippylib/blob/master/hippylib/modeling/PDEProblem.py

See also: Gunzburger, M. D. (2002). Perspectives in flow control and optimization. Society for Industrial and Applied Mathematics, for adjoint based derivative derivation.
"""

kwargs = self._parse_args_add_to_kwargs(
*args, map_name="gradient_wrt_parameter", **kwargs
)
# Raise an error if the adjoint boundary conditions are not provided
if self._adjoint_dirichlet_bcs is None:
raise ValueError(
Expand All @@ -402,16 +451,15 @@ def gradient_wrt_parameter(self, direction, wrt, **kwargs):

# Compute forward solution
# TODO: Use stored forward solution if available and wrt == self.parameter
self.parameter = wrt
self.assemble()
self.parameter = kwargs
self.assemble(*self.parameter_args)
self.forward_solution, _ = self.solve()

# Compute adjoint solution
test_parameter = dl.TestFunction(self.parameter_function_space)
test_solution = dl.TestFunction(self.solution_function_space)

# note: temp_form is a weak form used for building the adjoint operator
temp_form = self.PDE_form(wrt, self.forward_solution, trial_adjoint)
temp_form = self.PDE_form(*self.parameter_args, self.forward_solution, trial_adjoint)
adjoint_form = dl.derivative(
temp_form, self.forward_solution, test_solution)

Expand All @@ -431,11 +479,25 @@ def gradient_wrt_parameter(self, direction, wrt, **kwargs):

# Compute gradient
# note: temp_form is a weak form used for building the gradient
temp_form = self.PDE_form(wrt, self.forward_solution, adjoint)
gradient_form = dl.derivative(temp_form, wrt, test_parameter)
gradient = dl.Function(self.parameter_function_space)
dl.assemble(gradient_form, tensor=gradient.vector())
return gradient
temp_form = self.PDE_form(*self.parameter_args, self.forward_solution, adjoint)
parameter_function_space = self.parameter_function_space
gradient_list = []
if not isinstance(self.parameter_function_space, (list, tuple)):
parameter_function_space = [self.parameter_function_space]

for i, k in enumerate(self._non_default_args):
test_parameter = dl.TestFunction(parameter_function_space[i])
gradient_form = dl.derivative(temp_form, self.parameter_args[i], test_parameter)
gradient = dl.Function(parameter_function_space[i])
dl.assemble(gradient_form, tensor=gradient.vector())
gradient_list.append(gradient)

# If only one parameter is passed, return a single gradient
if len(gradient_list) == 1:
gradient_list = gradient_list[0]
else:
gradient_list = tuple(gradient_list)
return gradient_list


def _apply_obs_op(self, PDE_parameter_fun, PDE_solution_fun,):
Expand All @@ -450,21 +512,4 @@ def _apply_obs_op(self, PDE_parameter_fun, PDE_solution_fun,):
raise NotImplementedError("obs_op output must be a number, a numpy array or a ufl.algebra.Operator type")


def _create_observation_operator(self, observation_operator):
"""
"""
if observation_operator == 'potential':
observation_operator = lambda m, u: u
elif observation_operator == 'gradu_squared':
observation_operator = lambda m, u: dl.inner(dl.grad(u),dl.grad(u))
elif observation_operator == 'power_density':
observation_operator = lambda m, u: m*dl.inner(dl.grad(u),dl.grad(u))
elif observation_operator == 'sigma_u':
observation_operator = lambda m, u: m*u
elif observation_operator == 'sigma_norm_gradu':
observation_operator = lambda m, u: m*dl.sqrt(dl.inner(dl.grad(u),dl.grad(u)))
elif observation_operator == None or callable(observation_operator):
observation_operator = observation_operator
else:
raise NotImplementedError
return observation_operator

Loading
Loading