Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions deepxde/data/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from .. import backend as bkd
from .. import config
from ..backend import backend_name
from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0
from ..utils import (
get_num_args,
has_default_values,
run_if_all_none,
mpi_scatter_from_rank0,
)


class PDE(Data):
Expand Down Expand Up @@ -150,9 +155,14 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
elif get_num_args(self.pde) == 3:
if self.auxiliary_var_fn is not None:
f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars)
elif backend_name == "jax" and len(aux) == 2:
elif backend_name == "jax":
# JAX inverse problem requires unknowns as the input.
f = self.pde(inputs, outputs_pde, unknowns=aux[1])
if len(aux) == 2:
# External trainable variables in aux[1] are used for unknowns
f = self.pde(inputs, outputs_pde, unknowns=aux[1])
if len(aux) == 1 and has_default_values(self.pde)[-1]:
# No external trainable variables, default values are used for unknowns
f = self.pde(inputs, outputs_pde)
else:
raise ValueError("Auxiliary variable function not defined.")
if not isinstance(f, (list, tuple)):
Expand Down
8 changes: 4 additions & 4 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,10 @@ def _compile_jax(self, lr, loss_fn, decay):
if self.params is None:
key = jax.random.PRNGKey(config.jax_random_seed)
self.net.params = self.net.init(key, self.data.test()[0])
external_trainable_variables_arr = [
var.value for var in self.external_trainable_variables
]
self.params = [self.net.params, external_trainable_variables_arr]
external_trainable_variables_val = [
var.value for var in self.external_trainable_variables
]
self.params = [self.net.params, external_trainable_variables_val]
# TODO: learning rate decay
self.opt = optimizers.get(self.opt_name, learning_rate=lr)
self.opt_state = self.opt.init(self.params)
Expand Down
3 changes: 3 additions & 0 deletions deepxde/utils/internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,9 @@ def get_num_args(func):
params = inspect.signature(func).parameters
return len(params) - ("self" in params)

def has_default_values(func):
params = inspect.signature(func).parameters.values()
return [param.default is not inspect.Parameter.empty for param in params]

def mpi_scatter_from_rank0(array, drop_last=True):
"""Scatter the given array into continuous subarrays of equal size from rank 0 to all ranks.
Expand Down