
Commit ad52e77

Backend paddle: add optimizers with supporting regularizer
1 parent 8275aeb commit ad52e77

File tree

2 files changed: +22 −3 lines changed

deepxde/model.py

Lines changed: 6 additions & 1 deletion

@@ -506,8 +506,13 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
         trainable_variables = (
             list(self.net.parameters()) + self.external_trainable_variables
         )
+        regularizer = getattr(self.net, 'regularizer', None)
+        if regularizer is not None:
+            weight_decay = self.net.regularizer_value if self.opt_name == "adamw" else self.net.regularizer
+        else:
+            weight_decay = None
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay
+            trainable_variables, self.opt_name, learning_rate=lr, decay=decay, weight_decay=weight_decay,
         )

         def train_step(inputs, targets, auxiliary_vars):
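
In words, the compile step now looks up an optional `regularizer` attribute on the network: AdamW receives `regularizer_value` (whose first entry is treated downstream as the decay coefficient), every other optimizer receives the regularizer object itself, and `weight_decay` stays `None` when no regularization was requested. Below is a minimal, self-contained sketch of that resolution step; `ToyNet`, `resolve_weight_decay`, and the attribute values are assumptions for illustration, shaped to match what the optimizer helper in the second file expects.

```python
# A minimal sketch (not from the commit) of the weight_decay resolution added
# above. ToyNet and its attribute values are assumptions for illustration;
# AdamW later reads weight_decay[0], while the other Paddle optimizers accept
# a float or a WeightDecayRegularizer directly.
import paddle


class ToyNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.linear = paddle.nn.Linear(2, 1)
        # Passed through unchanged for Adam/SGD/RMSProp.
        self.regularizer = paddle.regularizer.L2Decay(coeff=1e-4)
        # Sequence whose first entry is the AdamW decay coefficient.
        self.regularizer_value = [1e-4]


def resolve_weight_decay(net, opt_name):
    """Mirror the branch this commit adds before calling optimizers.get."""
    regularizer = getattr(net, "regularizer", None)
    if regularizer is None:
        return None
    return net.regularizer_value if opt_name == "adamw" else net.regularizer


net = ToyNet()
print(resolve_weight_decay(net, "adam"))   # -> L2Decay regularizer object
print(resolve_weight_decay(net, "adamw"))  # -> [0.0001]
```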

deepxde/optimizers/paddle/optimizers.py

Lines changed: 16 additions & 2 deletions

@@ -19,12 +19,14 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]


-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer

     if optimizer in ["L-BFGS", "L-BFGS-B"]:
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
         optim = paddle.optimizer.LBFGS(
@@ -46,5 +48,17 @@ def get(params, optimizer, learning_rate=None, decay=None):
     learning_rate = _get_lr_scheduler(learning_rate, decay)

     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+    elif optimizer == "sgd":
+        return paddle.optimizer.SGD(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+    elif optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay,
+        )
+    elif optimizer == "adamw":
+        if weight_decay[0] == 0:
+            raise ValueError("AdamW optimizer requires non-zero weight decay")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay[0],
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
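
The helper above still depends on module internals such as `_get_lr_scheduler`, so the sketch below exercises the same Paddle constructors directly, outside deepxde; the toy layer, learning rate, and coefficient are assumptions for illustration. Adam, SGD, and RMSProp accept either a float or a `paddle.regularizer` object for `weight_decay`, whereas AdamW applies decoupled decay and takes only the scalar coefficient, which is why the new branch unpacks `weight_decay[0]` and rejects a zero value.

```python
# A minimal sketch (not from the repository) calling the same Paddle
# constructors as the new branches, without going through deepxde itself.
# The toy layer, learning rate, and decay coefficient are assumptions.
import paddle

params = list(paddle.nn.Linear(2, 1).parameters())
lr, coeff = 1e-3, 1e-4

# Adam/SGD/RMSProp: weight_decay may be a float or a WeightDecayRegularizer.
adam = paddle.optimizer.Adam(learning_rate=lr, parameters=params, weight_decay=coeff)
sgd = paddle.optimizer.SGD(learning_rate=lr, parameters=params, weight_decay=coeff)
rmsprop = paddle.optimizer.RMSProp(learning_rate=lr, parameters=params, weight_decay=coeff)

# AdamW: decoupled weight decay, so only the scalar coefficient is passed
# (the helper above pulls it out of weight_decay[0]).
adamw = paddle.optimizer.AdamW(learning_rate=lr, parameters=params, weight_decay=coeff)
```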
