
Commit bd43c6c

Backend PyTorch: Add L1 and L1+L2 regularizers (#1905)
1 parent bb7ddb4 commit bd43c6c

File tree

1 file changed (+18 −11)

deepxde/model.py

Lines changed: 18 additions & 11 deletions
@@ -279,6 +279,17 @@ def build_loss():
     def _compile_pytorch(self, lr, loss_fn, decay):
         """pytorch"""
 
+        l1_factor, l2_factor = 0, 0
+        if self.net.regularizer is not None:
+            if self.net.regularizer[0] == "l1":
+                l1_factor = self.net.regularizer[1]
+            elif self.net.regularizer[0] == "l2":
+                l2_factor = self.net.regularizer[1]
+            else:
+                raise NotImplementedError(
+                    f"{self.net.regularizer[0]} regularizer hasn't been implemented for backend pytorch."
+                )
+
         def outputs(training, inputs):
             self.net.train(mode=training)
             with torch.no_grad():
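
For context, self.net.regularizer is the ["<name>", factor] list supplied when the network is built; with this commit the pytorch backend understands "l1" and "l2". A minimal usage sketch, assuming the regularization keyword of dde.nn.FNN (layer sizes and factor are purely illustrative):

import deepxde as dde

# Illustrative values; regularization=["l1", 1e-4] populates
# self.net.regularizer, which _compile_pytorch parses above.
net = dde.nn.FNN(
    [2] + [50] * 3 + [1],
    "tanh",
    "Glorot uniform",
    regularization=["l1", 1e-4],
)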
@@ -318,6 +329,11 @@ def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
             # Weighted losses
             if self.loss_weights is not None:
                 losses *= torch.as_tensor(self.loss_weights)
+            if l1_factor > 0:
+                l1_loss = l1_factor * torch.sum(
+                    torch.stack([torch.sum(p.abs()) for p in self.net.parameters()])
+                )
+                losses = torch.cat([losses, l1_loss.unsqueeze(0)])
             # Clear cached Jacobians and Hessians.
             grad.clear()
             return outputs_, losses
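
The L1 term is appended to the loss vector at training time rather than being folded into the optimizer. A self-contained sketch of the same computation outside DeepXDE (the toy model and loss values are stand-ins):

import torch

net = torch.nn.Linear(3, 1)  # stand-in for self.net
l1_factor = 1e-4
losses = torch.stack([torch.tensor(0.5), torch.tensor(0.2)])  # stand-in PDE/BC losses

# Sum the absolute values of every parameter, scale by l1_factor,
# and append the result as one extra entry of the loss vector.
l1_loss = l1_factor * torch.sum(
    torch.stack([torch.sum(p.abs()) for p in net.parameters()])
)
losses = torch.cat([losses, l1_loss.unsqueeze(0)])
print(losses)  # the last entry is the L1 penalty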
@@ -332,23 +348,14 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
                 False, inputs, targets, auxiliary_vars, self.data.losses_test
             )
 
-        weight_decay = 0
-        if self.net.regularizer is not None:
-            if self.net.regularizer[0] != "l2":
-                raise NotImplementedError(
-                    f"{self.net.regularizer[0]} regularization to be implemented for "
-                    "backend pytorch"
-                )
-            weight_decay = self.net.regularizer[1]
-
         optimizer_params = self.net.parameters()
         if self.external_trainable_variables:
             # L-BFGS doesn't support per-parameter options.
             if self.opt_name in ["L-BFGS", "L-BFGS-B"]:
                 optimizer_params = (
                     list(optimizer_params) + self.external_trainable_variables
                 )
-                if weight_decay > 0:
+                if l2_factor > 0:
                     print(
                         "Warning: L2 regularization will also be applied to external_trainable_variables. "
                         "Ensure this is intended behavior."

@@ -364,7 +371,7 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
             self.opt_name,
             learning_rate=lr,
             decay=decay,
-            weight_decay=weight_decay,
+            weight_decay=l2_factor,
         )
 
         def train_step(inputs, targets, auxiliary_vars):
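
Unlike L1, the L2 factor never enters the loss vector: it is forwarded to the optimizer as weight_decay, which standard PyTorch optimizers apply to every parameter during the update. A sketch of the equivalent plain-PyTorch call (optimizer choice and values are illustrative):

import torch

net = torch.nn.Linear(3, 1)
l2_factor = 1e-4

# weight_decay adds l2_factor * p to each parameter's gradient,
# i.e. an L2 penalty handled inside the optimizer step.
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=l2_factor)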
