
Commit 6df12b5

PDEOperator supports testing functions
1 parent db7bcf8 commit 6df12b5

File tree

deepxde/data/pde.py
deepxde/data/pde_operator.py

2 files changed: +47 -33 lines changed

deepxde/data/pde.py

Lines changed: 2 additions & 2 deletions
@@ -24,8 +24,8 @@ class PDE(Data):
             `num_boundary` sampled points.
         exclusions: A Numpy array of points to be excluded for training.
         solution: The reference solution.
-        num_test: The number of points sampled inside the domain for testing. The
-            testing points on the boundary are the same set of points used for training.
+        num_test: The number of points sampled inside the domain for testing PDE loss.
+            The testing points for BCs/ICs are the same set of points used for training.
             If ``None``, then the training points will be used for testing.
         auxiliary_var_function: A function that inputs `train_x` or `test_x` and outputs
             auxiliary variables.
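
Note: after this change, `num_test` in `dde.data.PDE` only controls how many points are sampled inside the domain for the PDE-residual test loss; the BC/IC test points always reuse the training points. A minimal, hypothetical sketch of passing it (the interval, Laplace residual, and point counts below are illustrative placeholders, not part of this commit):

import deepxde as dde

# Hypothetical 1D setup; geometry, residual, and point counts are placeholders.
geom = dde.geometry.Interval(0, 1)

def pde(x, y):
    # Laplace equation u'' = 0 used as a stand-in residual
    return dde.grad.hessian(y, x)

bc = dde.DirichletBC(geom, lambda x: 0, lambda _, on_boundary: on_boundary)

data = dde.data.PDE(
    geom,
    pde,
    bc,
    num_domain=32,
    num_boundary=2,
    num_test=100,  # test points for the PDE loss only; BC test points reuse training
)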

deepxde/data/pde_operator.py

Lines changed: 45 additions & 31 deletions
@@ -23,6 +23,9 @@ class PDEOperator(Data):
             variables of the function by `function_variables=[0]`, where `0` indicates
             the first variable `x`. If ``None``, then we assume the domains of the
             function and the PDE are the same.
+        num_test: The number of functions for testing PDE loss. The testing functions
+            for BCs/ICs are the same functions used for training. If ``None``, then the
+            training functions will be used for testing.
 
     Attributes:
         train_x_bc: A triple of three Numpy arrays (v, x, vx) fed into PIDeepONet for
@@ -43,6 +46,7 @@ def __init__(
         evaluation_points,
         num_function,
         function_variables=None,
+        num_test=None,
     ):
         self.pde = pde
         self.func_space = function_space
@@ -53,6 +57,7 @@ def __init__(
             if function_variables is not None
             else list(range(pde.geom.dim))
         )
+        self.num_test = num_test
 
         self.num_bcs = [n * self.num_func for n in self.pde.num_bcs]
         self.train_x_bc = None
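
With the three hunks above, `PDEOperator` now threads `num_test` through `__init__`. A hedged construction sketch, reusing the hypothetical `data` object from the note above; the `GRF` function space, sensor locations, and counts are assumptions for illustration:

import numpy as np
import deepxde as dde

# `data` is the hypothetical dde.data.PDE instance sketched earlier.
func_space = dde.data.GRF(length_scale=0.2)      # random input functions
eval_pts = np.linspace(0, 1, 50).reshape(-1, 1)  # sensor points for the branch net

operator_data = dde.data.PDEOperator(
    data,
    func_space,
    eval_pts,
    num_function=100,  # functions sampled for training
    num_test=20,       # functions sampled for testing the PDE loss
)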
@@ -92,33 +97,53 @@ def losses(self, targets, outputs, loss, model):
     def train_next_batch(self, batch_size=None):
         func_feats = self.func_space.random(self.num_func)
         func_vals = self.func_space.eval_batch(func_feats, self.eval_pts)
-
         v, x, vx = self.bc_inputs(func_feats, func_vals)
         if self.pde.pde is not None:
-            # Branch input: v
-            v_pde = np.repeat(func_vals, len(self.pde.train_x_all), axis=0)
+            v_pde, x_pde, vx_pde = self.gen_inputs(
+                func_feats, func_vals, self.pde.train_x_all
+            )
             v = np.vstack((v, v_pde))
-            # Trunk input: x
-            x_pde = np.tile(self.pde.train_x_all, (self.num_func, 1))
             x = np.vstack((x, x_pde))
-            # vx
-            vx_pde = self.func_space.eval_batch(
-                func_feats, self.pde.train_x_all[:, self.func_vars]
-            ).reshape(-1, 1)
             vx = np.vstack((vx, vx_pde))
-
         self.train_x = (v, x, vx)
         self.train_y = None
         return self.train_x, self.train_y
 
     @run_if_all_none("test_x", "test_y")
     def test(self):
-        # TODO: Use different BC data from self.train_x
-        # TODO
-        self.test_x = self.train_x
-        self.test_y = self.train_y
+        if self.num_test is None:
+            self.test_x = self.train_x
+        else:
+            func_feats = self.func_space.random(self.num_test)
+            func_vals = self.func_space.eval_batch(func_feats, self.eval_pts)
+            # TODO: Use different BC data from self.train_x
+            v, x, vx = self.train_x_bc
+            if self.pde.pde is not None:
+                v_pde, x_pde, vx_pde = self.gen_inputs(
+                    func_feats, func_vals, self.pde.test_x[sum(self.pde.num_bcs) :]
+                )
+                v = np.vstack((v, v_pde))
+                x = np.vstack((x, x_pde))
+                vx = np.vstack((vx, vx_pde))
+            self.test_x = (v, x, vx)
+        self.test_y = None
         return self.test_x, self.test_y
 
+    def gen_inputs(self, func_feats, func_vals, points):
+        # Format:
+        # v1, x_1
+        # ...
+        # v1, x_N1
+        # v2, x_1
+        # ...
+        # v2, x_N1
+        v = np.repeat(func_vals, len(points), axis=0)
+        x = np.tile(points, (len(func_feats), 1))
+        vx = self.func_space.eval_batch(func_feats, points[:, self.func_vars]).reshape(
+            -1, 1
+        )
+        return v, x, vx
+
     def bc_inputs(self, func_feats, func_vals):
         if not self.pde.bcs:
             self.train_x_bc = (
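
The `# Format:` comment in `gen_inputs` describes a Cartesian-style pairing: every sampled function v_i is matched with every point x_j by repeating the function rows and tiling the point block. A small standalone NumPy sketch of that pattern with toy arrays (`vx`, the function values at the trunk points, comes from `func_space.eval_batch` and is omitted here):

import numpy as np

# Toy data: 2 functions discretized at 3 sensors, 2 collocation points.
func_vals = np.array([[1.0, 2.0, 3.0],
                      [4.0, 5.0, 6.0]])  # shape (num_func, num_sensors)
points = np.array([[0.1], [0.9]])        # shape (num_points, dim)

v = np.repeat(func_vals, len(points), axis=0)  # each function row repeated per point
x = np.tile(points, (len(func_vals), 1))       # point block repeated per function

# Row order is (v1, x_1), (v1, x_2), (v2, x_1), (v2, x_2),
# matching the Format comment above.
print(v.shape, x.shape)  # (4, 3) (4, 1)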
@@ -127,26 +152,15 @@ def bc_inputs(self, func_feats, func_vals):
                 np.empty((0, 1), dtype=config.real(np)),
             )
             return self.train_x_bc
-        # Format:
-        # v1, x_bc1_1
-        # ...
-        # v1, x_bc1_N1
-        # v2, x_bc1_1
-        # ...
-        # v2, x_bc1_N1
         v, x, vx = [], [], []
         bcs_start = np.cumsum([0] + self.pde.num_bcs)
-        for i, num_bc in enumerate(self.pde.num_bcs):
+        for i, _ in enumerate(self.pde.num_bcs):
             beg, end = bcs_start[i], bcs_start[i + 1]
-            # Branch input: v
-            v.append(np.repeat(func_vals, num_bc, axis=0))
-            # Trunk input: x
-            x.append(np.tile(self.pde.train_x_bc[beg:end], (self.num_func, 1)))
-            # vx
-            vx.append(
-                self.func_space.eval_batch(
-                    func_feats, self.pde.train_x_bc[beg:end, self.func_vars]
-                ).reshape(-1, 1)
+            vi, xi, vxi = self.gen_inputs(
+                func_feats, func_vals, self.pde.train_x_bc[beg:end]
             )
+            v.append(vi)
+            x.append(xi)
+            vx.append(vxi)
         self.train_x_bc = (np.vstack(v), np.vstack(x), np.vstack(vx))
         return self.train_x_bc
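
For context, a hedged end-to-end sketch of how the new `num_test` shows up during training, continuing the hypothetical `operator_data` above. The network sizes, learning rate, and the `dde.maps.DeepONet` path are assumptions for illustration (later releases expose the same class as `dde.nn.DeepONet`):

# Branch takes the 50 sensor values of v, trunk takes the 1D point x.
net = dde.maps.DeepONet(
    [50, 40, 40],
    [1, 40, 40],
    "tanh",
    "Glorot normal",
)

model = dde.Model(operator_data, net)
model.compile("adam", lr=1e-3)
# With num_test set, the PDE-loss test metric uses freshly sampled functions,
# while BC/IC test inputs still reuse the training functions.
losshistory, train_state = model.train(epochs=10000)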
