Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions deepxde/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,30 +330,33 @@ def outputs_losses_test(inputs, targets, auxiliary_vars):
False, inputs, targets, auxiliary_vars, self.data.losses_test
)

# Another way is using per-parameter options
# https://pytorch.org/docs/stable/optim.html#per-parameter-options,
# but not all optimizers (such as L-BFGS) support this.
trainable_variables = (
list(self.net.parameters()) + self.external_trainable_variables
)
if self.net.regularizer is None:
self.opt, self.lr_scheduler = optimizers.get(
trainable_variables, self.opt_name, learning_rate=lr, decay=decay
optimizer_params = (
list(self.net.parameters()) + self.external_trainable_variables
)
else:
if self.net.regularizer[0] == "l2":
self.opt, self.lr_scheduler = optimizers.get(
trainable_variables,
self.opt_name,
learning_rate=lr,
decay=decay,
weight_decay=self.net.regularizer[1],
)
if self.opt_name in ["L-BFGS", "L-BFGS-B"]:
print(f"Warning: weight decay is ignored for {self.opt_name}")
optimizer_params = (
list(self.net.parameters()) + self.external_trainable_variables
)
else:
optimizer_params = [
{
"params": self.net.parameters(),
"weight_decay": self.net.regularizer[1],
},
{"params": self.external_trainable_variables},
]
else:
raise NotImplementedError(
f"{self.net.regularizer[0]} regularization to be implemented for "
"backend pytorch."
)
self.opt, self.lr_scheduler = optimizers.get(
optimizer_params, self.opt_name, learning_rate=lr, decay=decay
)

def train_step(inputs, targets, auxiliary_vars):
def closure():
Expand Down
20 changes: 5 additions & 15 deletions deepxde/optimizers/pytorch/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,12 @@ def is_external_optimizer(optimizer):
return optimizer in ["L-BFGS", "L-BFGS-B"]


def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0):
def get(params, optimizer, learning_rate=None, decay=None):
"""Retrieves an Optimizer instance."""
# Custom Optimizer
if isinstance(optimizer, torch.optim.Optimizer):
optim = optimizer
elif optimizer in ["L-BFGS", "L-BFGS-B"]:
if weight_decay > 0:
raise ValueError("L-BFGS optimizer doesn't support weight_decay > 0")
if learning_rate is not None or decay is not None:
print("Warning: learning rate is ignored for {}".format(optimizer))
optim = torch.optim.LBFGS(
Expand All @@ -33,21 +31,13 @@ def get(params, optimizer, learning_rate=None, decay=None, weight_decay=0):
if learning_rate is None:
raise ValueError("No learning rate for {}.".format(optimizer))
if optimizer == "sgd":
optim = torch.optim.SGD(params, lr=learning_rate, weight_decay=weight_decay)
optim = torch.optim.SGD(params, lr=learning_rate)
elif optimizer == "rmsprop":
optim = torch.optim.RMSprop(
params, lr=learning_rate, weight_decay=weight_decay
)
optim = torch.optim.RMSprop(params, lr=learning_rate)
elif optimizer == "adam":
optim = torch.optim.Adam(
params, lr=learning_rate, weight_decay=weight_decay
)
optim = torch.optim.Adam(params, lr=learning_rate)
elif optimizer == "adamw":
if weight_decay == 0:
raise ValueError("AdamW optimizer requires non-zero weight decay")
optim = torch.optim.AdamW(
params, lr=learning_rate, weight_decay=weight_decay
)
optim = torch.optim.AdamW(params, lr=learning_rate)
else:
raise NotImplementedError(
f"{optimizer} to be implemented for backend pytorch."
Expand Down