Skip to content

Commit 1604c03

Browse files
committed
Backend tensorflow supports all diffusion examples
1 parent 7cf28a7 commit 1604c03

File tree

7 files changed

+124
-137
lines changed

7 files changed

+124
-137
lines changed

deepxde/data/pde.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def __init__(
107107
)
108108
)
109109
self.train_distribution = train_distribution
110-
self.anchors = anchors
110+
self.anchors = None if anchors is None else anchors.astype(config.real(np))
111111
self.exclusions = exclusions
112112

113113
self.soln = solution
@@ -192,6 +192,7 @@ def resample_train_points(self):
192192

193193
def add_anchors(self, anchors):
194194
"""Add new points for training PDE losses. The BC points will not be updated."""
195+
anchors = anchors.astype(config.real(np))
195196
if self.anchors is None:
196197
self.anchors = anchors
197198
else:
@@ -236,7 +237,9 @@ def bc_points(self):
236237
x_bcs = [bc.collocation_points(self.train_x_all) for bc in self.bcs]
237238
self.num_bcs = list(map(len, x_bcs))
238239
self.train_x_bc = (
239-
np.vstack(x_bcs) if x_bcs else np.empty([0, self.train_x_all.shape[-1]])
240+
np.vstack(x_bcs)
241+
if x_bcs
242+
else np.empty([0, self.train_x_all.shape[-1]], dtype=config.real(np))
240243
)
241244
return self.train_x_bc
242245

deepxde/icbcs/boundary_conditions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
import numpy as np
1010

11+
from .. import config
1112
from .. import gradients as grad
1213
from ..backend import tf
1314

@@ -146,12 +147,12 @@ class PointSetBC(object):
146147
"""
147148

148149
def __init__(self, points, values, component=0):
149-
self.points = np.array(points)
150+
self.points = np.array(points, dtype=config.real(np))
150151
if not isinstance(values, numbers.Number) and values.shape[1] != 1:
151152
raise RuntimeError(
152153
"PointSetBC should output 1D values. Use argument 'component' for different components."
153154
)
154-
self.values = values
155+
self.values = values.astype(config.real(np))
155156
self.component = component
156157

157158
def collocation_points(self, X):

deepxde/model.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,11 @@ def compile(
7777
weight the loss contributions. The loss value that will be minimized by
7878
the model will then be the weighted sum of all individual losses,
7979
weighted by the loss_weights coefficients.
80-
external_trainable_variables: A trainable ``tf.Variable`` object or a list of
81-
trainable ``tf.Variable`` objects. The unknown parameters in the physics
82-
systems that need to be recovered. Trainable variables from the neural
83-
networks and external_trainable_variables are trained together.
80+
external_trainable_variables: A trainable ``tf.Variable`` object or a list
81+
of trainable ``tf.Variable`` objects. The unknown parameters in the
82+
physics systems that need to be recovered. If the backend is
83+
tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
84+
trainable ``tf.Variable`` objects are automatically collected.
8485
"""
8586
print("Compiling model...")
8687

@@ -92,12 +93,18 @@ def compile(
9293
self.saver = tf.train.Saver(max_to_keep=None)
9394

9495
self.opt_name = optimizer
95-
if external_trainable_variables is not None:
96+
if external_trainable_variables is None:
97+
self.external_trainable_variables = []
98+
else:
99+
if backend_name == "tensorflow.compat.v1":
100+
print(
101+
"Warning: For the backend tensorflow.compat.v1, "
102+
"`external_trainable_variables` is ignored, and all trainable "
103+
"``tf.Variable`` objects are automatically collected."
104+
)
96105
if not isinstance(external_trainable_variables, list):
97106
external_trainable_variables = [external_trainable_variables]
98107
self.external_trainable_variables = external_trainable_variables
99-
else:
100-
self.external_trainable_variables = []
101108

102109
loss_fn = losses_module.get(loss)
103110
if backend_name == "tensorflow.compat.v1":
@@ -135,6 +142,7 @@ def compute_losses(targets, outputs):
135142

136143
opt = optimizers.get(self.opt_name, learning_rate=lr, decay=decay)
137144

145+
# TODO: Avoid creating multiple graphs by using tf.TensorSpec.
138146
@tf.function
139147
def outputs_losses(data_id, inputs, targets):
140148
self.net.data_id = data_id

examples/diffusion_1d.py

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow"""
22
import deepxde as dde
33
import numpy as np
44
from deepxde.backend import tf
@@ -19,36 +19,31 @@ def func(x):
1919
return np.sin(np.pi * x[:, 0:1]) * np.exp(-x[:, 1:])
2020

2121

22-
def main():
23-
geom = dde.geometry.Interval(-1, 1)
24-
timedomain = dde.geometry.TimeDomain(0, 1)
25-
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
26-
27-
bc = dde.DirichletBC(geomtime, func, lambda _, on_boundary: on_boundary)
28-
ic = dde.IC(geomtime, func, lambda _, on_initial: on_initial)
29-
data = dde.data.TimePDE(
30-
geomtime,
31-
pde,
32-
[bc, ic],
33-
num_domain=40,
34-
num_boundary=20,
35-
num_initial=10,
36-
solution=func,
37-
num_test=10000,
38-
)
39-
40-
layer_size = [2] + [32] * 3 + [1]
41-
activation = "tanh"
42-
initializer = "Glorot uniform"
43-
net = dde.maps.FNN(layer_size, activation, initializer)
22+
geom = dde.geometry.Interval(-1, 1)
23+
timedomain = dde.geometry.TimeDomain(0, 1)
24+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
4425

45-
model = dde.Model(data, net)
26+
bc = dde.DirichletBC(geomtime, func, lambda _, on_boundary: on_boundary)
27+
ic = dde.IC(geomtime, func, lambda _, on_initial: on_initial)
28+
data = dde.data.TimePDE(
29+
geomtime,
30+
pde,
31+
[bc, ic],
32+
num_domain=40,
33+
num_boundary=20,
34+
num_initial=10,
35+
solution=func,
36+
num_test=10000,
37+
)
4638

47-
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
48-
losshistory, train_state = model.train(epochs=10000)
39+
layer_size = [2] + [32] * 3 + [1]
40+
activation = "tanh"
41+
initializer = "Glorot uniform"
42+
net = dde.maps.FNN(layer_size, activation, initializer)
4943

50-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
44+
model = dde.Model(data, net)
5145

46+
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
47+
losshistory, train_state = model.train(epochs=10000)
5248

53-
if __name__ == "__main__":
54-
main()
49+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

examples/diffusion_1d_exactBC.py

Lines changed: 16 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow"""
22
import deepxde as dde
33
import numpy as np
44
from deepxde.backend import tf
@@ -19,35 +19,23 @@ def func(x):
1919
return np.sin(np.pi * x[:, 0:1]) * np.exp(-x[:, 1:])
2020

2121

22-
def main():
23-
geom = dde.geometry.Interval(-1, 1)
24-
timedomain = dde.geometry.TimeDomain(0, 1)
25-
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
22+
geom = dde.geometry.Interval(-1, 1)
23+
timedomain = dde.geometry.TimeDomain(0, 1)
24+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
2625

27-
data = dde.data.TimePDE(
28-
geomtime,
29-
pde,
30-
[],
31-
num_domain=40,
32-
solution=func,
33-
num_test=10000,
34-
)
35-
36-
layer_size = [2] + [32] * 3 + [1]
37-
activation = "tanh"
38-
initializer = "Glorot uniform"
39-
net = dde.maps.FNN(layer_size, activation, initializer)
40-
net.apply_output_transform(
41-
lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + tf.sin(np.pi * x[:, 0:1])
42-
)
43-
44-
model = dde.Model(data, net)
26+
data = dde.data.TimePDE(geomtime, pde, [], num_domain=40, solution=func, num_test=10000)
4527

46-
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
47-
losshistory, train_state = model.train(epochs=10000)
28+
layer_size = [2] + [32] * 3 + [1]
29+
activation = "tanh"
30+
initializer = "Glorot uniform"
31+
net = dde.maps.FNN(layer_size, activation, initializer)
32+
net.apply_output_transform(
33+
lambda x, y: x[:, 1:2] * (1 - x[:, 0:1] ** 2) * y + tf.sin(np.pi * x[:, 0:1])
34+
)
4835

49-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
36+
model = dde.Model(data, net)
5037

38+
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
39+
losshistory, train_state = model.train(epochs=10000)
5140

52-
if __name__ == "__main__":
53-
main()
41+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

examples/diffusion_1d_inverse.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow"""
22
import deepxde as dde
33
import numpy as np
44
from deepxde.backend import tf
@@ -22,42 +22,39 @@ def func(x):
2222
return np.sin(np.pi * x[:, 0:1]) * np.exp(-x[:, 1:])
2323

2424

25-
def main():
26-
geom = dde.geometry.Interval(-1, 1)
27-
timedomain = dde.geometry.TimeDomain(0, 1)
28-
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
25+
geom = dde.geometry.Interval(-1, 1)
26+
timedomain = dde.geometry.TimeDomain(0, 1)
27+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
2928

30-
bc = dde.DirichletBC(geomtime, func, lambda _, on_boundary: on_boundary)
31-
ic = dde.IC(geomtime, func, lambda _, on_initial: on_initial)
29+
bc = dde.DirichletBC(geomtime, func, lambda _, on_boundary: on_boundary)
30+
ic = dde.IC(geomtime, func, lambda _, on_initial: on_initial)
3231

33-
observe_x = np.vstack((np.linspace(-1, 1, num=10), np.full((10), 1))).T
34-
observe_y = dde.PointSetBC(observe_x, func(observe_x), component=0)
32+
observe_x = np.vstack((np.linspace(-1, 1, num=10), np.full((10), 1))).T
33+
observe_y = dde.PointSetBC(observe_x, func(observe_x), component=0)
3534

36-
data = dde.data.TimePDE(
37-
geomtime,
38-
pde,
39-
[bc, ic, observe_y],
40-
num_domain=40,
41-
num_boundary=20,
42-
num_initial=10,
43-
anchors=observe_x,
44-
solution=func,
45-
num_test=10000,
46-
)
47-
48-
layer_size = [2] + [32] * 3 + [1]
49-
activation = "tanh"
50-
initializer = "Glorot uniform"
51-
net = dde.maps.FNN(layer_size, activation, initializer)
52-
53-
model = dde.Model(data, net)
35+
data = dde.data.TimePDE(
36+
geomtime,
37+
pde,
38+
[bc, ic, observe_y],
39+
num_domain=40,
40+
num_boundary=20,
41+
num_initial=10,
42+
anchors=observe_x,
43+
solution=func,
44+
num_test=10000,
45+
)
5446

55-
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
56-
variable = dde.callbacks.VariableValue(C, period=1000)
57-
losshistory, train_state = model.train(epochs=50000, callbacks=[variable])
47+
layer_size = [2] + [32] * 3 + [1]
48+
activation = "tanh"
49+
initializer = "Glorot uniform"
50+
net = dde.maps.FNN(layer_size, activation, initializer)
5851

59-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
52+
model = dde.Model(data, net)
6053

54+
model.compile(
55+
"adam", lr=0.001, metrics=["l2 relative error"], external_trainable_variables=C
56+
)
57+
variable = dde.callbacks.VariableValue(C, period=1000)
58+
losshistory, train_state = model.train(epochs=50000, callbacks=[variable])
6159

62-
if __name__ == "__main__":
63-
main()
60+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

examples/diffusion_1d_resample.py

Lines changed: 31 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Backend supported: tensorflow.compat.v1"""
1+
"""Backend supported: tensorflow.compat.v1, tensorflow"""
22
import deepxde as dde
33
import numpy as np
44
from deepxde.backend import tf
@@ -19,38 +19,33 @@ def func(x):
1919
return np.sin(np.pi * x[:, 0:1]) * np.exp(-x[:, 1:])
2020

2121

22-
def main():
23-
geom = dde.geometry.Interval(-1, 1)
24-
timedomain = dde.geometry.TimeDomain(0, 1)
25-
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
26-
27-
bc = dde.DirichletBC(geomtime, func, lambda _, on_boundary: on_boundary)
28-
ic = dde.IC(geomtime, func, lambda _, on_initial: on_initial)
29-
data = dde.data.TimePDE(
30-
geomtime,
31-
pde,
32-
[bc, ic],
33-
num_domain=40,
34-
num_boundary=20,
35-
num_initial=10,
36-
train_distribution="pseudo",
37-
solution=func,
38-
num_test=10000,
39-
)
40-
41-
layer_size = [2] + [32] * 3 + [1]
42-
activation = "tanh"
43-
initializer = "Glorot uniform"
44-
net = dde.maps.FNN(layer_size, activation, initializer)
45-
46-
model = dde.Model(data, net)
47-
48-
resampler = dde.callbacks.PDEResidualResampler(period=100)
49-
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
50-
losshistory, train_state = model.train(epochs=2000, callbacks=[resampler])
51-
52-
dde.saveplot(losshistory, train_state, issave=True, isplot=True)
53-
54-
55-
if __name__ == "__main__":
56-
main()
22+
geom = dde.geometry.Interval(-1, 1)
23+
timedomain = dde.geometry.TimeDomain(0, 1)
24+
geomtime = dde.geometry.GeometryXTime(geom, timedomain)
25+
26+
bc = dde.DirichletBC(geomtime, func, lambda _, on_boundary: on_boundary)
27+
ic = dde.IC(geomtime, func, lambda _, on_initial: on_initial)
28+
data = dde.data.TimePDE(
29+
geomtime,
30+
pde,
31+
[bc, ic],
32+
num_domain=40,
33+
num_boundary=20,
34+
num_initial=10,
35+
train_distribution="pseudo",
36+
solution=func,
37+
num_test=10000,
38+
)
39+
40+
layer_size = [2] + [32] * 3 + [1]
41+
activation = "tanh"
42+
initializer = "Glorot uniform"
43+
net = dde.maps.FNN(layer_size, activation, initializer)
44+
45+
model = dde.Model(data, net)
46+
47+
resampler = dde.callbacks.PDEResidualResampler(period=100)
48+
model.compile("adam", lr=0.001, metrics=["l2 relative error"])
49+
losshistory, train_state = model.train(epochs=2000, callbacks=[resampler])
50+
51+
dde.saveplot(losshistory, train_state, issave=True, isplot=True)

0 commit comments

Comments
 (0)