Commit d20a997

Backend PyTorch supports Adam optimizer
1 parent 23c57d7 commit d20a997

2 files changed: +21 -2 lines changed

deepxde/model.py

Lines changed: 4 additions & 1 deletion
@@ -189,7 +189,9 @@ def outputs_losses(data_id, inputs, targets):
             losses = compute_losses(targets, outputs)
             return outputs, losses
 
-        opt = torch.optim.Adam(self.net.parameters(), lr=lr)
+        opt = optimizers.get(
+            self.net.parameters(), self.opt_name, learning_rate=lr, decay=decay
+        )
 
         def train_step(data_id, inputs, targets):
             _, losses = outputs_losses(data_id, inputs, targets)
@@ -404,6 +406,7 @@ def predict(self, x, operator=None, callbacks=None):
         self.callbacks.set_model(self)
         self.callbacks.on_predict_begin()
         # TODO: use self._run for tensorflow
+        # TODO: predict operator with auxiliary_vars
        if backend_name == "tensorflow.compat.v1":
            if operator is None:
                op = self.net.outputs
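
The `opt` object returned by `optimizers.get` is consumed by `train_step` through the standard PyTorch update cycle (zero_grad, backward, step). Below is a minimal standalone sketch of that pattern; the network, data, and single MSE loss are placeholders, not the actual `Model` internals, which aggregate several loss terms via `compute_losses`.

```python
import torch

# Placeholder network and data; in Model, these come from self.net and the training data.
net = torch.nn.Linear(2, 1)
opt = torch.optim.Adam(net.parameters(), lr=1e-3)  # what optimizers.get returns for "adam"

inputs = torch.randn(16, 2)
targets = torch.randn(16, 1)


def train_step(inputs, targets):
    opt.zero_grad()  # clear gradients from the previous step
    loss = torch.nn.functional.mse_loss(net(inputs), targets)
    loss.backward()  # backpropagate through the network
    opt.step()       # apply the Adam update
    return loss.item()


print(train_step(inputs, targets))
```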
Lines changed: 17 additions & 1 deletion
@@ -1,5 +1,21 @@
-__all__ = ["is_external_optimizer"]
+import torch
+
+
+__all__ = ["get", "is_external_optimizer"]
 
 
 def is_external_optimizer(optimizer):
     return False
+
+
+def get(params, optimizer, learning_rate=None, decay=None):
+    """Retrieves an Optimizer instance."""
+    if isinstance(optimizer, torch.optim.Optimizer):
+        return optimizer
+    if learning_rate is None:
+        raise ValueError("No learning rate for {}.".format(optimizer))
+
+    # TODO: decay
+    if optimizer == "adam":
+        return torch.optim.Adam(params, lr=learning_rate)
+    raise NotImplementedError(f"{optimizer} to be implemented for backend pytorch.")
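
For reference, a short usage sketch of the new `get` helper, assuming the `get` function shown above is in scope (the module's import path is not shown on this page); `net` is an illustrative stand-in.

```python
import torch

net = torch.nn.Linear(2, 1)

# A string name is resolved to the corresponding torch.optim class ("adam" -> torch.optim.Adam).
opt = get(net.parameters(), "adam", learning_rate=1e-3)

# A ready-made torch.optim.Optimizer instance passes through unchanged.
sgd = torch.optim.SGD(net.parameters(), lr=0.01)
assert get(None, sgd) is sgd

# A string name without a learning rate raises ValueError, and any name other
# than "adam" currently raises NotImplementedError (decay is still a TODO).
```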
