Commit 3aca6f7

Backend paddle: support regularizer (#1896)
1 parent ec4bdd3 commit 3aca6f7

File tree: 3 files changed, +33 −3 lines

deepxde/model.py
Lines changed: 5 additions & 1 deletion

@@ -518,7 +518,11 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
             list(self.net.parameters()) + self.external_trainable_variables
         )
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay
+            trainable_variables,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=self.net.regularizer,
         )

         def train_step(inputs, targets, auxiliary_vars):
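For orientation, here is a minimal sketch of what this wiring amounts to on the Paddle backend: the Paddle branch of Model.compile forwards the network's regularizer attribute (None by default, see deepxde/nn/paddle/nn.py below) to optimizers.get as weight_decay. FakeNet, the layer sizes, and the 1e-4 coefficient are illustrative stand-ins, not part of this commit.

import paddle

# Sketch only: a stand-in for a DeepXDE Paddle network.  The real call site is
# the Paddle branch of Model.compile shown above, which forwards
# self.net.regularizer to optimizers.get as weight_decay.
class FakeNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.regularizer = None  # default added in deepxde/nn/paddle/nn.py
        self.dense = paddle.nn.Linear(2, 1)

net = FakeNet()
# Assumed usage: attach a Paddle weight-decay regularizer before compiling.
net.regularizer = paddle.regularizer.L2Decay(coeff=1e-4)

# Equivalent of the new call:
#   optimizers.get(trainable_variables, self.opt_name, learning_rate=lr,
#                  decay=decay, weight_decay=self.net.regularizer)
opt = paddle.optimizer.Adam(
    learning_rate=1e-3,
    parameters=list(net.parameters()),
    weight_decay=net.regularizer,
)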

deepxde/nn/paddle/nn.py
Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ class NN(paddle.nn.Layer):

     def __init__(self):
         super().__init__()
+        self.regularizer = None
         self._input_transform = None
         self._output_transform = None
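Because the attribute defaults to None, existing networks are unaffected: optimizers.get then receives weight_decay=None, which Paddle's optimizers treat as no weight decay. A minimal sketch, assuming the Paddle backend is active (DDE_BACKEND=paddle); assigning an L2Decay instance is an assumed usage pattern, not something this commit does itself.

import paddle
from deepxde.nn.paddle.nn import NN  # base class touched in this commit

net = NN()
assert net.regularizer is None  # new default keeps old behavior
net.regularizer = paddle.regularizer.L2Decay(coeff=1e-4)  # illustrative coefficient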

deepxde/optimizers/paddle/optimizers.py
Lines changed: 27 additions & 2 deletions

@@ -19,14 +19,16 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]


-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer

     if optimizer in ["L-BFGS", "L-BFGS-B"]:
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         optim = paddle.optimizer.LBFGS(
             learning_rate=1,
             max_iter=LBFGS_options["iter_per_step"],
@@ -46,5 +48,28 @@ def get(params, optimizer, learning_rate=None, decay=None):
     learning_rate = _get_lr_scheduler(learning_rate, decay)

     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "sgd":
+        return paddle.optimizer.SGD(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay,
+        )
+    if optimizer == "adamw":
+        if (
+            not isinstance(weight_decay, paddle.regularizer.L2Decay)
+            or weight_decay._coeff == 0
+        ):
+            raise ValueError("AdamW optimizer requires non-zero L2 regularizer")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay._coeff,
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
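Taken together, the new weight_decay handling can be exercised directly, assuming the Paddle backend is selected (DDE_BACKEND=paddle) and using the import path from this commit; the Linear layer and the 1e-4 coefficient below are illustrative. Adam, SGD, and RMSProp pass the regularizer object straight to Paddle, AdamW unwraps the L2 coefficient and rejects anything else, and L-BFGS refuses weight_decay.

import paddle
from deepxde.optimizers.paddle.optimizers import get

# Hypothetical parameter list from a small Paddle layer.
params = list(paddle.nn.Linear(3, 1).parameters())
l2 = paddle.regularizer.L2Decay(coeff=1e-4)

# Adam/SGD/RMSProp: the regularizer object is forwarded as weight_decay.
adam = get(params, "adam", learning_rate=1e-3, weight_decay=l2)

# AdamW: requires a non-zero L2Decay; only its coefficient reaches Paddle.
adamw = get(params, "adamw", learning_rate=1e-3, weight_decay=l2)

# L-BFGS: weight_decay is rejected outright.
try:
    get(params, "L-BFGS", weight_decay=l2)
except ValueError as err:
    print(err)  # L-BFGS optimizer doesn't support weight_decay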
