Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions deepxde/data/pde.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
elif get_num_args(self.pde) == 3:
if self.auxiliary_var_fn is None:
if aux is None or len(aux) == 1:
raise ValueError("Auxiliary variable function not defined.")
f = self.pde(inputs, outputs_pde, unknowns=aux[1])
f = self.pde(inputs, outputs_pde)
else:
f = self.pde(inputs, outputs_pde, unknowns=aux[1])
else:
f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars)
if not isinstance(f, (list, tuple)):
Expand Down
8 changes: 4 additions & 4 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,10 +375,10 @@ def _compile_jax(self, lr, loss_fn, decay):
if self.params is None:
key = jax.random.PRNGKey(config.jax_random_seed)
self.net.params = self.net.init(key, self.data.test()[0])
external_trainable_variables_arr = [
var.value for var in self.external_trainable_variables
]
self.params = [self.net.params, external_trainable_variables_arr]
external_trainable_variables_val = [
var.value for var in self.external_trainable_variables
]
self.params = [self.net.params, external_trainable_variables_val]
# TODO: learning rate decay
self.opt = optimizers.get(self.opt_name, learning_rate=lr)
self.opt_state = self.opt.init(self.params)
Expand Down