Commit 3544fdf

Backend PyTorch: Fix L2 regularizers for external_trainable_variables (#1884)
1 parent 8275aeb · commit 3544fdf
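In short: with the PyTorch backend, the L2 regularizer is implemented via the optimizer's `weight_decay`, which previously penalized the `external_trainable_variables` along with the network weights. This commit moves the external variables into a separate parameter group with `weight_decay=0`, so only the network is regularized (L-BFGS, which cannot take parameter groups, instead keeps a flat parameter list and prints a warning). A minimal sketch of the scenario this affects; the `data` object, layer sizes, and values are hypothetical placeholders, and the `regularization` argument of `dde.nn.FNN` is assumed to be available for the PyTorch backend:

    import deepxde as dde

    # Hypothetical inverse-problem setup: `data` is assumed to be a
    # dde.data.PDE defined elsewhere; C is an unknown coefficient to recover.
    C = dde.Variable(1.0)
    net = dde.nn.FNN(
        [2] + [32] * 3 + [1], "tanh", "Glorot normal", regularization=["l2", 1e-4]
    )
    model = dde.Model(data, net)

    # After this fix, the L2 penalty applies to the network weights only;
    # C is optimized without weight decay.
    model.compile("adam", lr=0.001, external_trainable_variables=C)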

1 file changed: +35 -24 lines changed

deepxde/model.py

Lines changed: 35 additions & 24 deletions
@@ -112,9 +112,10 @@ def compile(
                 weighted by the `loss_weights` coefficients.
             external_trainable_variables: A trainable ``dde.Variable`` object or a list
                 of trainable ``dde.Variable`` objects. The unknown parameters in the
-                physics systems that need to be recovered. If the backend is
-                tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
-                trainable ``dde.Variable`` objects are automatically collected.
+                physics systems that need to be recovered. Regularization will not be
+                applied to these variables. If the backend is tensorflow.compat.v1,
+                `external_trainable_variables` is ignored, and all trainable ``dde.Variable``
+                objects are automatically collected.
             verbose (Integer): Controls the verbosity of the compile process.
         """
         if verbose > 0 and config.rank == 0:
@@ -330,30 +331,40 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
                 False, inputs, targets, auxiliary_vars, self.data.losses_test
             )

-        # Another way is using per-parameter options
-        # https://pytorch.org/docs/stable/optim.html#per-parameter-options,
-        # but not all optimizers (such as L-BFGS) support this.
-        trainable_variables = (
-            list(self.net.parameters()) + self.external_trainable_variables
-        )
-        if self.net.regularizer is None:
-            self.opt, self.lr_scheduler = optimizers.get(
-                trainable_variables, self.opt_name, learning_rate=lr, decay=decay
-            )
-        else:
-            if self.net.regularizer[0] == "l2":
-                self.opt, self.lr_scheduler = optimizers.get(
-                    trainable_variables,
-                    self.opt_name,
-                    learning_rate=lr,
-                    decay=decay,
-                    weight_decay=self.net.regularizer[1],
-                )
-            else:
+        weight_decay = 0
+        if self.net.regularizer is not None:
+            if self.net.regularizer[0] != "l2":
                 raise NotImplementedError(
                     f"{self.net.regularizer[0]} regularization to be implemented for "
-                    "backend pytorch."
+                    "backend pytorch"
                 )
+            weight_decay = self.net.regularizer[1]
+
+        optimizer_params = self.net.parameters()
+        if self.external_trainable_variables:
+            # L-BFGS doesn't support per-parameter options.
+            if self.opt_name in ["L-BFGS", "L-BFGS-B"]:
+                optimizer_params = (
+                    list(optimizer_params) + self.external_trainable_variables
+                )
+                if weight_decay > 0:
+                    print(
+                        "Warning: L2 regularization will also be applied to external_trainable_variables. "
+                        "Ensure this is intended behavior."
+                    )
+            else:
+                optimizer_params = [
+                    {"params": optimizer_params},
+                    {"params": self.external_trainable_variables, "weight_decay": 0},
+                ]
+
+        self.opt, self.lr_scheduler = optimizers.get(
+            optimizer_params,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=weight_decay,
+        )

         def train_step(inputs, targets, auxiliary_vars):
             def closure():