Skip to content

Commit c5c4e60

Browse files
committed
update code to support different backend
1 parent d09ddd1 commit c5c4e60

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

deepxde/data/pde_operator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from .data import Data
44
from .sampler import BatchSampler
55
from .. import backend as bkd
6+
from ..backend import backend_name
67
from .. import config
78
from ..utils import run_if_all_none
89

@@ -274,13 +275,14 @@ def forward_call(trunk_input):
274275

275276
f = []
276277
if self.pde.pde is not None:
278+
if backend_name in ["tensorflow.compat.v1"]:
279+
outputs_pde = bkd.reshape(outputs, shape_2d)
280+
elif backend_name in ["tensorflow", "pytorch"]:
281+
outputs_pde = (bkd.reshape(outputs, shape_2d), forward_call)
277282
# Each f has the shape (N1, N2)
278283
f = self.pde.pde(
279284
inputs[1],
280-
(
281-
bkd.reshape(outputs, shape_2d),
282-
forward_call,
283-
),
285+
outputs_pde,
284286
bkd.reshape(
285287
model.net.auxiliary_vars,
286288
shape_2d,

0 commit comments

Comments
 (0)