
Commit 500479e

Backend PyTorch supports "Step" LR decay (#722)
1 parent e07c649 commit 500479e

2 files changed: +41 -17 lines

deepxde/model.py

Lines changed: 8 additions & 2 deletions
@@ -47,6 +47,8 @@ def __init__(self, data, net):
         if backend_name == "tensorflow.compat.v1":
             self.sess = None
             self.saver = None
+        elif backend_name == "pytorch":
+            self.lr_scheduler = None
         elif backend_name == "jax":
             self.opt_state = None

@@ -251,7 +253,9 @@ def outputs(training, inputs):
         def outputs_losses(training, inputs, targets, losses_fn):
             self.net.train(mode=training)
             if isinstance(inputs, tuple):
-                inputs = tuple(map(lambda x: torch.as_tensor(x).requires_grad_(), inputs))
+                inputs = tuple(
+                    map(lambda x: torch.as_tensor(x).requires_grad_(), inputs)
+                )
             else:
                 inputs = torch.as_tensor(inputs)
                 inputs.requires_grad_()
@@ -283,7 +287,7 @@ def outputs_losses_test(inputs, targets):
         trainable_variables = (
             list(self.net.parameters()) + self.external_trainable_variables
         )
-        self.opt = optimizers.get(
+        self.opt, self.lr_scheduler = optimizers.get(
             trainable_variables, self.opt_name, learning_rate=lr, decay=decay
         )

@@ -296,6 +300,8 @@ def closure():
                 return total_loss

             self.opt.step(closure)
+            if self.lr_scheduler is not None:
+                self.lr_scheduler.step()

         # Callables
         self.outputs = outputs
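
The scheduler step is placed immediately after the optimizer step, which matches the recommended PyTorch ordering (optimizer.step() before scheduler.step()). A minimal standalone sketch of that pattern; the toy model and data below are illustrative and not part of this commit:

import torch

# Toy setup (illustrative): a linear model trained with Adam plus StepLR.
model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Multiply the learning rate by gamma=0.5 every 1000 optimizer steps.
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1000, gamma=0.5)

x = torch.rand(64, 1)
y = 2 * x

for step in range(3000):
    optimizer.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the LR schedule once per optimizer step, as in train_step above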

deepxde/optimizers/pytorch/optimizers.py

Lines changed: 33 additions & 15 deletions
@@ -11,13 +11,13 @@ def is_external_optimizer(optimizer):

 def get(params, optimizer, learning_rate=None, decay=None):
     """Retrieves an Optimizer instance."""
+    # Custom Optimizer
     if isinstance(optimizer, torch.optim.Optimizer):
-        return optimizer
-
-    if optimizer in ["L-BFGS", "L-BFGS-B"]:
+        optim = optimizer
+    elif optimizer in ["L-BFGS", "L-BFGS-B"]:
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
-        return torch.optim.LBFGS(
+        optim = torch.optim.LBFGS(
             params,
             lr=1,
             max_iter=LBFGS_options["iter_per_step"],
@@ -27,15 +27,33 @@ def get(params, optimizer, learning_rate=None, decay=None):
             history_size=LBFGS_options["maxcor"],
             line_search_fn=None,
         )
-
-    if learning_rate is None:
-        raise ValueError("No learning rate for {}.".format(optimizer))
-
-    if decay is not None:
-        # TODO: learning rate decay
-        raise NotImplementedError(
-            "learning rate decay to be implemented for backend pytorch."
+    else:
+        if learning_rate is None:
+            raise ValueError("No learning rate for {}.".format(optimizer))
+        if optimizer == "sgd":
+            optim = torch.optim.SGD(params, lr=learning_rate)
+        elif optimizer == "rmsprop":
+            optim = torch.optim.RMSprop(params, lr=learning_rate)
+        elif optimizer == "adam":
+            optim = torch.optim.Adam(params, lr=learning_rate)
+        else:
+            raise NotImplementedError(
+                f"{optimizer} to be implemented for backend pytorch."
+            )
+    lr_scheduler = _get_learningrate_scheduler(optim, decay)
+    return optim, lr_scheduler
+
+
+def _get_learningrate_scheduler(optim, decay):
+    if decay is None:
+        return None
+
+    if decay[0] == "step":
+        return torch.optim.lr_scheduler.StepLR(
+            optim, step_size=decay[1], gamma=decay[2]
         )
-    if optimizer == "adam":
-        return torch.optim.Adam(params, lr=learning_rate)
-    raise NotImplementedError(f"{optimizer} to be implemented for backend pytorch.")
+
+    # TODO: More learning rate scheduler
+    raise NotImplementedError(
+        f"{decay[0]} learning rate scheduler to be implemented for backend pytorch."
+    )
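
With this change, the PyTorch backend accepts a decay tuple of the form ("step", step_size, gamma) in Model.compile, which _get_learningrate_scheduler maps to torch.optim.lr_scheduler.StepLR. A hedged usage sketch; the function-approximation setup below is an illustrative placeholder based on the standard DeepXDE demo, not part of this commit, and assumes the backend is set to pytorch (e.g. via DDE_BACKEND=pytorch):

import numpy as np
import deepxde as dde


def func(x):
    # Illustrative target function for a small regression problem.
    return x * np.sin(5 * x)


geom = dde.geometry.Interval(-1, 1)
data = dde.data.Function(geom, func, 16, 100)
net = dde.maps.FNN([1] + [20] * 3 + [1], "tanh", "Glorot uniform")

model = dde.Model(data, net)
# New in this commit: step decay for the PyTorch backend, i.e. multiply
# the learning rate by gamma=0.7 every 1000 optimizer steps.
model.compile("adam", lr=0.001, decay=("step", 1000, 0.7))
losshistory, train_state = model.train(epochs=10000)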
