33from .data import Data
44from .sampler import BatchSampler
55from .. import backend as bkd
6- from ..backend import backend_name
76from .. import config
87from ..utils import run_if_all_none
98
@@ -264,33 +263,18 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
264263 # Use stack instead of as_tensor to keep the gradients.
265264 losses = [bkd .reduce_mean (bkd .stack (loss , 0 )) for loss in losses ]
266265 elif config .autodiff == "forward" : # forward mode AD
267- batchsize1 , batchsize2 = bkd .shape (outputs )[:2 ]
268- shape_3d = (batchsize1 , batchsize2 , model .net .num_outputs )
269266
270- # Uniformly reshape the output into the shape (N1, N2, num_outputs),
271267 def forward_call (trunk_input ):
272- output = aux [0 ]((inputs [0 ], trunk_input ))
273- return bkd .reshape (output , shape_3d )
268+ return aux [0 ]((inputs [0 ], trunk_input ))
274269
275270 f = []
276271 if self .pde .pde is not None :
277- if backend_name in ["tensorflow.compat.v1" ]:
278- outputs_pde = bkd .reshape (outputs , shape_3d )
279- elif backend_name in ["tensorflow" , "pytorch" ]:
280- outputs_pde = (bkd .reshape (outputs , shape_3d ), forward_call )
281272 # Each f has the shape (N1, N2)
282273 f = self .pde .pde (
283- inputs [1 ],
284- outputs_pde ,
285- bkd .reshape (
286- model .net .auxiliary_vars ,
287- shape_3d ,
288- ),
274+ inputs [1 ], (outputs , forward_call ), model .net .auxiliary_vars
289275 )
290276 if not isinstance (f , (list , tuple )):
291277 f = [f ]
292- f = [bkd .reshape (fi , (batchsize1 , batchsize2 )) for fi in f ]
293-
294278 # Each error has the shape (N1, ~N2)
295279 error_f = [fi [:, bcs_start [- 1 ] :] for fi in f ]
296280 for error in error_f :
@@ -365,4 +349,4 @@ def test(self):
365349 )
366350 self .test_x = (func_vals , self .pde .test_x )
367351 self .test_aux_vars = vx
368- return self .test_x , self .test_y , self .test_aux_vars
352+ return self .test_x , self .test_y , self .test_aux_vars
0 commit comments