Skip to content

Commit 772782b

Browse files
committed
Backend PyTorch supports more inverse problems
1 parent fb4cb72 commit 772782b

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

deepxde/model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ def outputs(inputs):
209209
def outputs_losses(inputs, targets):
210210
outputs_ = outputs(inputs)
211211
# Data losses
212-
targets = torch.from_numpy(targets)
212+
if targets is not None:
213+
targets = torch.from_numpy(targets)
213214
losses = self.data.losses(targets, outputs_, loss_fn, self)
214215
if not isinstance(losses, list):
215216
losses = [losses]

examples/Lorenz_inverse.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,28 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch
22
33
Documentation: https://deepxde.readthedocs.io/en/latest/demos/lorenz.inverse.html
44
"""
55
import deepxde as dde
66
import numpy as np
7+
# Backend tensorflow.compat.v1 or tensorflow
78
from deepxde.backend import tf
9+
# Backend pytorch
10+
# import torch
811

912

1013
def gen_traindata():
1114
data = np.load("dataset/Lorenz.npz")
1215
return data["t"], data["y"]
1316

1417

18+
# Backend tensorflow.compat.v1 or tensorflow
1519
C1 = tf.Variable(1.0)
1620
C2 = tf.Variable(1.0)
1721
C3 = tf.Variable(1.0)
22+
# Backend pytorch
23+
# C1 = torch.tensor(1.0, requires_grad=True)
24+
# C2 = torch.tensor(1.0, requires_grad=True)
25+
# C3 = torch.tensor(1.0, requires_grad=True)
1826

1927

2028
def Lorenz_system(x, y):

examples/reaction_inverse.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
22
import deepxde as dde
33
import numpy as np
4+
# Backend tensorflow.compat.v1 or tensorflow
45
from deepxde.backend import tf
6+
# Backend pytorch
7+
# import torch
58

69

710
def gen_traindata():
@@ -15,8 +18,12 @@ def gen_traindata():
1518
return np.hstack((X, T)), Ca, Cb
1619

1720

21+
# Backend tensorflow.compat.v1 or tensorflow
1822
kf = tf.Variable(0.05)
1923
D = tf.Variable(1.0)
24+
# Backend pytorch
25+
# kf = torch.tensor(0.05, requires_grad=True)
26+
# D = torch.tensor(1.0, requires_grad=True)
2027

2128

2229
def pde(x, y):

0 commit comments

Comments
 (0)