Skip to content

Commit bc1f67b

Browse files
authored
Backend JAX: Fix external variable initialization (#1775)
1 parent bd43c6c commit bc1f67b

File tree

3 files changed

+34
-7
lines changed

3 files changed

+34
-7
lines changed

deepxde/data/pde.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from .. import backend as bkd
55
from .. import config
66
from ..backend import backend_name
7-
from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0
7+
from ..utils import (
8+
get_num_args,
9+
has_default_values,
10+
mpi_scatter_from_rank0,
11+
run_if_all_none,
12+
)
813

914

1015
class PDE(Data):
@@ -150,9 +155,18 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
150155
elif get_num_args(self.pde) == 3:
151156
if self.auxiliary_var_fn is not None:
152157
f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars)
153-
elif backend_name == "jax" and len(aux) == 2:
158+
elif backend_name == "jax":
154159
# JAX inverse problem requires unknowns as the input.
155-
f = self.pde(inputs, outputs_pde, unknowns=aux[1])
160+
if len(aux) == 2:
161+
# External trainable variables in aux[1] are used for unknowns
162+
f = self.pde(inputs, outputs_pde, unknowns=aux[1])
163+
elif len(aux) == 1 and has_default_values(self.pde)[-1]:
164+
# No external trainable variables, default values are used for unknowns
165+
f = self.pde(inputs, outputs_pde)
166+
else:
167+
raise ValueError(
168+
"Default unknowns are required if no trainable variables are provided."
169+
)
156170
else:
157171
raise ValueError("Auxiliary variable function not defined.")
158172
if not isinstance(f, (list, tuple)):

deepxde/model.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,10 +413,10 @@ def _compile_jax(self, lr, loss_fn, decay):
413413
if self.params is None:
414414
key = jax.random.PRNGKey(config.jax_random_seed)
415415
self.net.params = self.net.init(key, self.data.test()[0])
416-
external_trainable_variables_arr = [
417-
var.value for var in self.external_trainable_variables
418-
]
419-
self.params = [self.net.params, external_trainable_variables_arr]
416+
external_trainable_variables_val = [
417+
var.value for var in self.external_trainable_variables
418+
]
419+
self.params = [self.net.params, external_trainable_variables_val]
420420
# TODO: learning rate decay
421421
self.opt = optimizers.get(self.opt_name, learning_rate=lr)
422422
self.opt_state = self.opt.init(self.params)

deepxde/utils/internal.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,19 @@ def get_num_args(func):
203203
return len(params) - ("self" in params)
204204

205205

206+
def has_default_values(func):
207+
"""Check if the given function has default values for its parameters.
208+
209+
Args:
210+
func (function): The function to inspect.
211+
212+
Returns:
213+
list: A list of boolean values indicating whether each parameter has a default value.
214+
"""
215+
params = inspect.signature(func).parameters.values()
216+
return [param.default is not inspect.Parameter.empty for param in params]
217+
218+
206219
def mpi_scatter_from_rank0(array, drop_last=True):
207220
"""Scatter the given array into continuous subarrays of equal size from rank 0 to all ranks.
208221

0 commit comments

Comments
 (0)