Skip to content

Commit 75d2020

Browse files
committed
Backend tensorflow supports more examples
1 parent 1604c03 commit 75d2020

14 files changed

+268
-309
lines changed

deepxde/callbacks.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,9 @@ class ModelCheckpoint(Callback):
116116
Args:
117117
filepath (string): Path to save the model file.
118118
verbose: Verbosity mode, 0 or 1.
119-
save_better_only: If True, only save a better model according to the quantity monitored.
120-
Model is only checked at validation step according to ``display_every`` in ``Model.train``.
119+
save_better_only: If True, only save a better model according to the quantity
120+
monitored. Model is only checked at validation step according to
121+
``display_every`` in ``Model.train``.
121122
period: Interval (number of epochs) between checkpoints.
122123
"""
123124

@@ -364,6 +365,12 @@ def __init__(
364365
self.spectrum = []
365366
self.epochs_since_last_save = 0
366367

368+
# TODO: support backend tensorflow
369+
if backend_name != "tensorflow.compat.v1":
370+
raise RuntimeError(
371+
"MovieDumper only supports backend tensorflow.compat.v1."
372+
)
373+
367374
def init(self):
368375
self.tf_op = self.model.net.outputs[:, self.component]
369376
self.feed_dict = self.model.net.feed_dict(False, False, 2, self.x)

deepxde/data/pde.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def add_anchors(self, anchors):
206206
self.train_aux_vars = self.auxiliary_var_fn(self.train_x)
207207

208208
def train_points(self):
209-
X = np.empty((0, self.geom.dim))
209+
X = np.empty((0, self.geom.dim), dtype=config.real(np))
210210
if self.num_domain > 0:
211211
if self.train_distribution == "uniform":
212212
X = self.geom.uniform_points(self.num_domain, boundary=False)

deepxde/geometry/geometry_nd.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from .geometry import Geometry
1212
from .sampler import sample
13+
from .. import config
1314

1415

1516
class Hypercube(Geometry):
@@ -19,7 +20,8 @@ def __init__(self, xmin, xmax):
1920
if np.any(np.array(xmin) >= np.array(xmax)):
2021
raise ValueError("xmin >= xmax")
2122

22-
self.xmin, self.xmax = np.array(xmin), np.array(xmax)
23+
self.xmin = np.array(xmin, dtype=config.real(np))
24+
self.xmax = np.array(xmax, dtype=config.real(np))
2325
self.side_length = self.xmax - self.xmin
2426
super(Hypercube, self).__init__(
2527
len(xmin), (self.xmin, self.xmax), np.linalg.norm(self.side_length)

deepxde/geometry/sampler.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import numpy as np
88
import skopt
99

10+
from .. import config
11+
1012

1113
def sample(n_samples, dimension, sampler="pseudo"):
1214
"""Generate random or quasirandom samples in [0, 1]^dimension.
@@ -28,7 +30,7 @@ def sample(n_samples, dimension, sampler="pseudo"):
2830
def pseudo(n_samples, dimension):
2931
"""Pseudo random."""
3032
rng = np.random.default_rng()
31-
return rng.random((n_samples, dimension))
33+
return rng.random(size=(n_samples, dimension), dtype=config.real(np))
3234

3335

3436
def quasirandom(n_samples, dimension, sampler):
@@ -48,6 +50,8 @@ def quasirandom(n_samples, dimension, sampler):
4850
else:
4951
sampler = skopt.sampler.Sobol(skip=0, randomize=False)
5052
space = [(0.0, 1.0)] * dimension
51-
return np.array(sampler.generate(space, n_samples + 2)[2:])
53+
return np.array(
54+
sampler.generate(space, n_samples + 2)[2:], dtype=config.real(np)
55+
)
5256
space = [(0.0, 1.0)] * dimension
53-
return np.array(sampler.generate(space, n_samples))
57+
return np.array(sampler.generate(space, n_samples), dtype=config.real(np))

examples/Euler_beam.py

Lines changed: 27 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow"""
22
import deepxde as dde
33
import numpy as np
44
from deepxde.backend import tf
@@ -29,34 +29,29 @@ def func(x):
2929
return -(x ** 4) / 24 + x ** 3 / 6 - x ** 2 / 4
3030

3131

32-
def main():
33-
geom = dde.geometry.Interval(0, 1)
34-
35-
bc1 = dde.DirichletBC(geom, lambda x: 0, boundary_l)
36-
bc2 = dde.NeumannBC(geom, lambda x: 0, boundary_l)
37-
bc3 = dde.OperatorBC(geom, lambda x, y, _: ddy(x, y), boundary_r)
38-
bc4 = dde.OperatorBC(geom, lambda x, y, _: dddy(x, y), boundary_r)
39-
40-
data = dde.data.PDE(
41-
geom,
42-
pde,
43-
[bc1, bc2, bc3, bc4],
44-
num_domain=10,
45-
num_boundary=2,
46-
solution=func,
47-
num_test=100,
48-
)
49-
layer_size = [1] + [20] * 3 + [1]
50-
activation = "tanh"
51-
initializer = "Glorot uniform"
52-
net = dde.maps.FNN(layer_size, activation, initializer)
53-
54-
model = dde.Model(data, net)
55-
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
56-
losshistory, train_state = model.train(epochs=10000)
57-
58-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
59-
60-
61-
if __name__ == "__main__":
62-
main()
32+
geom = dde.geometry.Interval(0, 1)
33+
34+
bc1 = dde.DirichletBC(geom, lambda x: 0, boundary_l)
35+
bc2 = dde.NeumannBC(geom, lambda x: 0, boundary_l)
36+
bc3 = dde.OperatorBC(geom, lambda x, y, _: ddy(x, y), boundary_r)
37+
bc4 = dde.OperatorBC(geom, lambda x, y, _: dddy(x, y), boundary_r)
38+
39+
data = dde.data.PDE(
40+
geom,
41+
pde,
42+
[bc1, bc2, bc3, bc4],
43+
num_domain=10,
44+
num_boundary=2,
45+
solution=func,
46+
num_test=100,
47+
)
48+
layer_size = [1] + [20] * 3 + [1]
49+
activation = "tanh"
50+
initializer = "Glorot uniform"
51+
net = dde.maps.FNN(layer_size, activation, initializer)
52+
53+
model = dde.Model(data, net)
54+
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
55+
losshistory, train_state = model.train(epochs=10000)
56+
57+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

examples/Laplace_disk.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow"""
22
import deepxde as dde
33
import numpy as np
44
from deepxde.backend import tf
@@ -16,31 +16,26 @@ def solution(x):
1616
return r * np.cos(theta)
1717

1818

19-
def main():
20-
geom = dde.geometry.Rectangle(xmin=[0, 0], xmax=[1, 2 * np.pi])
21-
bc_rad = dde.DirichletBC(
22-
geom,
23-
lambda x: np.cos(x[:, 1:2]),
24-
lambda x, on_boundary: on_boundary and np.isclose(x[0], 1),
19+
geom = dde.geometry.Rectangle(xmin=[0, 0], xmax=[1, 2 * np.pi])
20+
bc_rad = dde.DirichletBC(
21+
geom,
22+
lambda x: np.cos(x[:, 1:2]),
23+
lambda x, on_boundary: on_boundary and np.isclose(x[0], 1),
24+
)
25+
data = dde.data.PDE(
26+
geom, pde, bc_rad, num_domain=2540, num_boundary=80, solution=solution
27+
)
28+
29+
net = dde.maps.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
30+
# Use [r*sin(theta), r*cos(theta)] as features,
31+
# so that the network is automatically periodic along the theta coordinate.
32+
net.apply_feature_transform(
33+
lambda x: tf.concat(
34+
[x[:, 0:1] * tf.sin(x[:, 1:2]), x[:, 0:1] * tf.cos(x[:, 1:2])], axis=1
2535
)
26-
data = dde.data.PDE(
27-
geom, pde, bc_rad, num_domain=2540, num_boundary=80, solution=solution
28-
)
29-
30-
net = dde.maps.FNN([2] + [20] * 3 + [1], "tanh", "Glorot normal")
31-
# Use [r*sin(theta), r*cos(theta)] as features,
32-
# so that the network is automatically periodic along the theta coordinate.
33-
net.apply_feature_transform(
34-
lambda x: tf.concat(
35-
[x[:, 0:1] * tf.sin(x[:, 1:2]), x[:, 0:1] * tf.cos(x[:, 1:2])], axis=1
36-
)
37-
)
38-
39-
model = dde.Model(data, net)
40-
model.compile("adam", lr=1e-3, metrics=["l2 relative error"])
41-
losshistory, train_state = model.train(epochs=15000)
42-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
43-
36+
)
4437

45-
if __name__ == "__main__":
46-
main()
38+
model = dde.Model(data, net)
39+
model.compile("adam", lr=1e-3, metrics=["l2 relative error"])
40+
losshistory, train_state = model.train(epochs=15000)
41+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

examples/Lorenz_inverse.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1
1+
"""Backend supported: tensorflow.compat.v1, tensorflow
22
33
Documentation: https://deepxde.readthedocs.io/en/latest/demos/lorenz.inverse.html
44
"""
@@ -38,38 +38,33 @@ def boundary(_, on_initial):
3838
return on_initial
3939

4040

41-
def main():
42-
geom = dde.geometry.TimeDomain(0, 3)
43-
44-
# Initial conditions
45-
ic1 = dde.IC(geom, lambda X: -8, boundary, component=0)
46-
ic2 = dde.IC(geom, lambda X: 7, boundary, component=1)
47-
ic3 = dde.IC(geom, lambda X: 27, boundary, component=2)
48-
49-
# Get the train data
50-
observe_t, ob_y = gen_traindata()
51-
observe_y0 = dde.PointSetBC(observe_t, ob_y[:, 0:1], component=0)
52-
observe_y1 = dde.PointSetBC(observe_t, ob_y[:, 1:2], component=1)
53-
observe_y2 = dde.PointSetBC(observe_t, ob_y[:, 2:3], component=2)
54-
55-
data = dde.data.PDE(
56-
geom,
57-
Lorenz_system,
58-
[ic1, ic2, ic3, observe_y0, observe_y1, observe_y2],
59-
num_domain=400,
60-
num_boundary=2,
61-
anchors=observe_t,
62-
)
63-
64-
net = dde.maps.FNN([1] + [40] * 3 + [3], "tanh", "Glorot uniform")
65-
model = dde.Model(data, net)
66-
model.compile("adam", lr=0.001)
67-
variable = dde.callbacks.VariableValue(
68-
[C1, C2, C3], period=600, filename="variables.dat"
69-
)
70-
losshistory, train_state = model.train(epochs=60000, callbacks=[variable])
71-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
72-
73-
74-
if __name__ == "__main__":
75-
main()
41+
geom = dde.geometry.TimeDomain(0, 3)
42+
43+
# Initial conditions
44+
ic1 = dde.IC(geom, lambda X: -8, boundary, component=0)
45+
ic2 = dde.IC(geom, lambda X: 7, boundary, component=1)
46+
ic3 = dde.IC(geom, lambda X: 27, boundary, component=2)
47+
48+
# Get the train data
49+
observe_t, ob_y = gen_traindata()
50+
observe_y0 = dde.PointSetBC(observe_t, ob_y[:, 0:1], component=0)
51+
observe_y1 = dde.PointSetBC(observe_t, ob_y[:, 1:2], component=1)
52+
observe_y2 = dde.PointSetBC(observe_t, ob_y[:, 2:3], component=2)
53+
54+
data = dde.data.PDE(
55+
geom,
56+
Lorenz_system,
57+
[ic1, ic2, ic3, observe_y0, observe_y1, observe_y2],
58+
num_domain=400,
59+
num_boundary=2,
60+
anchors=observe_t,
61+
)
62+
63+
net = dde.maps.FNN([1] + [40] * 3 + [3], "tanh", "Glorot uniform")
64+
model = dde.Model(data, net)
65+
model.compile("adam", lr=0.001, external_trainable_variables=[C1, C2, C3])
66+
variable = dde.callbacks.VariableValue(
67+
[C1, C2, C3], period=600, filename="variables.dat"
68+
)
69+
losshistory, train_state = model.train(epochs=60000, callbacks=[variable])
70+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

examples/Poisson_Dirichlet_1d.py

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1
1+
"""Backend supported: tensorflow.compat.v1, tensorflow
22
33
Documentation: https://deepxde.readthedocs.io/en/latest/demos/poisson.1d.dirichlet.html
44
"""
@@ -21,42 +21,38 @@ def func(x):
2121
return np.sin(np.pi * x)
2222

2323

24-
def main():
25-
geom = dde.geometry.Interval(-1, 1)
26-
bc = dde.DirichletBC(geom, func, boundary)
27-
data = dde.data.PDE(geom, pde, bc, 16, 2, solution=func, num_test=100)
28-
29-
layer_size = [1] + [50] * 3 + [1]
30-
activation = "tanh"
31-
initializer = "Glorot uniform"
32-
net = dde.maps.FNN(layer_size, activation, initializer)
33-
34-
model = dde.Model(data, net)
35-
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
36-
37-
checkpointer = dde.callbacks.ModelCheckpoint(
38-
"model/model.ckpt", verbose=1, save_better_only=True
39-
)
40-
# ImageMagick (https://imagemagick.org/) is required to generate the movie.
41-
movie = dde.callbacks.MovieDumper(
42-
"model/movie", [-1], [1], period=100, save_spectrum=True, y_reference=func
43-
)
44-
losshistory, train_state = model.train(
45-
epochs=10000, callbacks=[checkpointer, movie]
46-
)
47-
48-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
49-
50-
# Plot PDE residual
51-
model.restore("model/model.ckpt-" + str(train_state.best_step), verbose=1)
52-
x = geom.uniform_points(1000, True)
53-
y = model.predict(x, operator=pde)
54-
plt.figure()
55-
plt.plot(x, y)
56-
plt.xlabel("x")
57-
plt.ylabel("PDE residual")
58-
plt.show()
59-
60-
61-
if __name__ == "__main__":
62-
main()
24+
geom = dde.geometry.Interval(-1, 1)
25+
bc = dde.DirichletBC(geom, func, boundary)
26+
data = dde.data.PDE(geom, pde, bc, 16, 2, solution=func, num_test=100)
27+
28+
layer_size = [1] + [50] * 3 + [1]
29+
activation = "tanh"
30+
initializer = "Glorot uniform"
31+
net = dde.maps.FNN(layer_size, activation, initializer)
32+
33+
model = dde.Model(data, net)
34+
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
35+
36+
losshistory, train_state = model.train(epochs=10000)
37+
# Optional: Save the model during training.
38+
# checkpointer = dde.callbacks.ModelCheckpoint(
39+
# "model/model.ckpt", verbose=1, save_better_only=True
40+
# )
41+
# Optional: Save the movie of the network solution during training.
42+
# ImageMagick (https://imagemagick.org/) is required to generate the movie.
43+
# movie = dde.callbacks.MovieDumper(
44+
# "model/movie", [-1], [1], period=100, save_spectrum=True, y_reference=func
45+
# )
46+
# losshistory, train_state = model.train(epochs=10000, callbacks=[checkpointer, movie])
47+
48+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
49+
50+
# Plot PDE residual
51+
model.restore("model/model.ckpt-" + str(train_state.best_step), verbose=1)
52+
x = geom.uniform_points(1000, True)
53+
y = model.predict(x, operator=pde)
54+
plt.figure()
55+
plt.plot(x, y)
56+
plt.xlabel("x")
57+
plt.ylabel("PDE residual")
58+
plt.show()

0 commit comments

Comments
 (0)