
Commit cdec9c7

Fix interaction of L1 regularization with loss weights
1 parent a31d197

File tree

1 file changed: +6 -5 lines changed


deepxde/model.py

Lines changed: 6 additions & 5 deletions
@@ -324,15 +324,16 @@ def outputs_losses(training, inputs, targets, auxiliary_vars, losses_fn):
             losses = losses_fn(targets, outputs_, loss_fn, inputs, self, aux=aux)
             if not isinstance(losses, list):
                 losses = [losses]
-            if l1_factor > 0:
-                l1_loss = torch.sum(
-                    torch.stack([torch.sum(p.abs()) for p in self.net.parameters()])
-                )
-                losses.append(l1_factor * l1_loss)
             losses = torch.stack(losses)
             # Weighted losses
             if self.loss_weights is not None:
                 losses *= torch.as_tensor(self.loss_weights)
+            if l1_factor > 0:
+                l1_loss = torch.sum(
+                    torch.stack([torch.sum(p.abs()) for p in self.net.parameters()])
+                )
+                l1_loss *= l1_factor
+                losses = torch.cat([losses, l1_loss.unsqueeze(0)])
             # Clear cached Jacobians and Hessians.
             grad.clear()
             return outputs_, losses
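
Why the ordering matters: in the old code, the L1 penalty was appended to the list of data/PDE losses before torch.stack and before the per-term loss_weights were applied, so the penalty was also affected by (and had to be accounted for in) loss_weights. After the fix, loss_weights scales only the losses returned by losses_fn, and the L1 term, scaled by l1_factor alone, is concatenated afterwards. Below is a minimal standalone sketch of the new ordering; the net, losses, loss_weights, and l1_factor values are hypothetical stand-ins for illustration, not DeepXDE's internals.

import torch

# Hypothetical stand-ins (not DeepXDE objects): a tiny network and two loss terms.
net = torch.nn.Linear(2, 1)                      # stand-in for self.net
losses = [torch.tensor(0.3), torch.tensor(0.1)]  # stand-in data/PDE losses
loss_weights = [10.0, 1.0]                       # per-term weights for the two losses only
l1_factor = 1e-4

losses = torch.stack(losses)
losses *= torch.as_tensor(loss_weights)          # weights scale only the original terms
l1_loss = torch.sum(
    torch.stack([torch.sum(p.abs()) for p in net.parameters()])
)
losses = torch.cat([losses, (l1_factor * l1_loss).unsqueeze(0)])  # L1 term appended after weighting
print(losses)  # last entry is l1_factor * sum(|theta|), untouched by loss_weights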
