File tree Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Expand file tree Collapse file tree 1 file changed +6
-4
lines changed Original file line number Diff line number Diff line change 33from .data import Data
44from .sampler import BatchSampler
55from .. import backend as bkd
6+ from ..backend import backend_name
67from .. import config
78from ..utils import run_if_all_none
89
@@ -274,13 +275,14 @@ def forward_call(trunk_input):
274275
275276 f = []
276277 if self .pde .pde is not None :
278+ if backend_name in ["tensorflow.compat.v1" ]:
279+ outputs_pde = bkd .reshape (outputs , shape_2d )
280+ elif backend_name in ["tensorflow" , "pytorch" ]:
281+ outputs_pde = (bkd .reshape (outputs , shape_2d ), forward_call )
277282 # Each f has the shape (N1, N2)
278283 f = self .pde .pde (
279284 inputs [1 ],
280- (
281- bkd .reshape (outputs , shape_2d ),
282- forward_call ,
283- ),
285+ outputs_pde ,
284286 bkd .reshape (
285287 model .net .auxiliary_vars ,
286288 shape_2d ,
You can’t perform that action at this time.
0 commit comments