Skip to content

Commit d52f438

Browse files
committed
Backend PyTorch supports more examples
1 parent d20a997 commit d52f438

File tree

3 files changed

+13
-18
lines changed

3 files changed

+13
-18
lines changed

deepxde/backend/pytorch/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def square(x):
4949

5050

5151
def tanh(x):
52-
return torch.nn.functional.tanh(x)
52+
return torch.tanh(x)
5353

5454

5555
def mean(input_tensor, dim, keepdims=False):

examples/dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1, tensorflow"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
22
import deepxde as dde
33

44

examples/func.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,22 +11,17 @@ def func(x):
1111
return x * np.sin(5 * x)
1212

1313

14-
def main():
15-
geom = dde.geometry.Interval(-1, 1)
16-
num_train = 16
17-
num_test = 100
18-
data = dde.data.Function(geom, func, num_train, num_test)
14+
geom = dde.geometry.Interval(-1, 1)
15+
num_train = 16
16+
num_test = 100
17+
data = dde.data.Function(geom, func, num_train, num_test)
1918

20-
activation = "tanh"
21-
initializer = "Glorot uniform"
22-
net = dde.maps.FNN([1] + [20] * 3 + [1], activation, initializer)
19+
activation = "tanh"
20+
initializer = "Glorot uniform"
21+
net = dde.maps.FNN([1] + [20] * 3 + [1], activation, initializer)
2322

24-
model = dde.Model(data, net)
25-
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
26-
losshistory, train_state = model.train(epochs=10000)
23+
model = dde.Model(data, net)
24+
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
25+
losshistory, train_state = model.train(epochs=10000)
2726

28-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
29-
30-
31-
if __name__ == "__main__":
32-
main()
27+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

0 commit comments

Comments
 (0)