deepxde/model.py (6 additions, 1 deletion)
@@ -517,8 +517,13 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
         trainable_variables = (
             list(self.net.parameters()) + self.external_trainable_variables
         )
+        weight_decay = getattr(self.net, "regularizer", None)
         self.opt = optimizers.get(
-            trainable_variables, self.opt_name, learning_rate=lr, decay=decay
+            trainable_variables,
+            self.opt_name,
+            learning_rate=lr,
+            decay=decay,
+            weight_decay=weight_decay,
         )
 
         def train_step(inputs, targets, auxiliary_vars):
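Note on the getattr fallback above: a network that never defines a `regularizer` attribute yields `weight_decay=None`, so optimizer construction for such models is unchanged. A minimal sketch of that fallback, assuming a Paddle `L2Decay` object is what a regularized net would carry (the two toy classes are illustrative; only the attribute name `regularizer` comes from the diff):

    import paddle

    class PlainNet:                      # illustrative: no regularizer attribute
        pass

    class RegularizedNet:                # illustrative: carries a Paddle regularizer
        regularizer = paddle.regularizer.L2Decay(coeff=1e-4)

    # Mirrors the new line in Model: weight_decay = getattr(self.net, "regularizer", None)
    print(getattr(PlainNet(), "regularizer", None))        # None -> no weight decay applied
    print(getattr(RegularizedNet(), "regularizer", None))  # L2Decay forwarded to optimizers.get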
deepxde/optimizers/paddle/optimizers.py (26 additions, 2 deletions)
@@ -19,12 +19,14 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]
 
 
-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer
 
     if optimizer in ["L-BFGS", "L-BFGS-B"]:
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
         optim = paddle.optimizer.LBFGS(
@@ -46,5 +48,27 @@ def get(params, optimizer, learning_rate=None, decay=None):
         learning_rate = _get_lr_scheduler(learning_rate, decay)
 
     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "sgd":
+        return paddle.optimizer.SGD(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay,
+        )
+    if optimizer == "adamw":
+        if isinstance(weight_decay, paddle.regularizer.L2Decay):
+            if weight_decay._coeff == 0:
+                raise ValueError("AdamW optimizer requires non-zero weight decay")
+            return paddle.optimizer.AdamW(
+                learning_rate=learning_rate,
+                parameters=params,
+                weight_decay=weight_decay._coeff,
+            )
+        raise ValueError("AdamW optimizer requires l2 regularizer")
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
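For context, a rough usage sketch of the extended `get` signature. The import path, the linear layer, and the coefficient are assumptions for illustration; the `weight_decay` handling itself follows the diff above:

    import paddle
    # Assumed import path, matching the file location in this PR.
    from deepxde.optimizers.paddle.optimizers import get

    params = paddle.nn.Linear(2, 1).parameters()    # illustrative parameter list
    l2 = paddle.regularizer.L2Decay(coeff=1e-4)     # illustrative coefficient

    # Adam, SGD, and RMSProp pass the regularizer object straight through.
    adam = get(params, "adam", learning_rate=1e-3, weight_decay=l2)

    # AdamW only accepts an L2Decay regularizer with a non-zero coefficient and
    # receives just its coefficient, since Paddle's AdamW expects a float weight_decay.
    adamw = get(params, "adamw", learning_rate=1e-3, weight_decay=l2)

    # L-BFGS rejects weight_decay with a ValueError.
    # get(params, "L-BFGS", weight_decay=l2)

Extracting `_coeff` for AdamW keeps a single regularizer-style entry point on the network while matching Paddle's decoupled weight decay API, which takes a plain float rather than a regularizer object.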