Skip to content

Commit e030a3d

Browse files
authored
Backend Tensorflow 1.x: DeepONet & DeepONetCartesianProd support multiple outputs (#1532)
1 parent 5b21146 commit e030a3d

File tree

3 files changed

+273
-36
lines changed

3 files changed

+273
-36
lines changed

deepxde/data/pde_operator.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -236,8 +236,10 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func):
236236

237237
losses = []
238238
for i in range(num_func):
239-
out = outputs[i][:, None]
240-
239+
out = outputs[i]
240+
# Single output
241+
if bkd.ndim(out) == 1:
242+
out = out[:, None]
241243
f = []
242244
if self.pde.pde is not None:
243245
f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])

deepxde/data/triple.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,11 @@ class TripleCartesianProd(Data):
5959
"""
6060

6161
def __init__(self, X_train, y_train, X_test, y_test):
62-
if len(X_train[0]) * len(X_train[1]) != y_train.size:
62+
if len(X_train[0]) != y_train.shape[0] or len(X_train[1]) != y_train.shape[1]:
6363
raise ValueError(
6464
"The training dataset does not have the format of Cartesian product."
6565
)
66-
if len(X_test[0]) * len(X_test[1]) != y_test.size:
66+
if len(X_test[0]) != y_test.shape[0] or len(X_test[1]) != y_test.shape[1]:
6767
raise ValueError(
6868
"The testing dataset does not have the format of Cartesian product."
6969
)

0 commit comments

Comments
 (0)