@@ -237,23 +237,59 @@ def __init__(
         self.train_next_batch()
         self.test()
 
-    def _losses(self, outputs, loss_fn, inputs, model, num_func):
+    def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
         bcs_start = np.cumsum([0] + self.pde.num_bcs)
 
         losses = []
-        for i in range(num_func):
-            out = outputs[i]
-            # Single output
-            if bkd.ndim(out) == 1:
-                out = out[:, None]
+        # PDE loss
+        if config.autodiff == "reverse":  # reverse mode AD
+            for i in range(num_func):
+                out = outputs[i]
+                # Single output
+                if bkd.ndim(out) == 1:
+                    out = out[:, None]
+                f = []
+                if self.pde.pde is not None:
+                    f = self.pde.pde(
+                        inputs[1], out, model.net.auxiliary_vars[i][:, None]
+                    )
+                if not isinstance(f, (list, tuple)):
+                    f = [f]
+                error_f = [fi[bcs_start[-1] :] for fi in f]
+                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
+                losses.append(losses_i)
+
+            losses = zip(*losses)
+            # Use stack instead of as_tensor to keep the gradients.
+            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        elif config.autodiff == "forward":  # forward mode AD
+
+            def forward_call(trunk_input):
+                return aux[0]((inputs[0], trunk_input))
+
             f = []
             if self.pde.pde is not None:
-                f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
+                # Each f has the shape (N1, N2)
+                f = self.pde.pde(
+                    inputs[1], (outputs, forward_call), model.net.auxiliary_vars
+                )
             if not isinstance(f, (list, tuple)):
                 f = [f]
-            error_f = [fi[bcs_start[-1] :] for fi in f]
-            losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
-
+            # Each error has the shape (N1, ~N2)
+            error_f = [fi[:, bcs_start[-1] :] for fi in f]
+            for error in error_f:
+                error_i = []
+                for i in range(num_func):
+                    error_i.append(loss_fn(bkd.zeros_like(error[i]), error[i]))
+                losses.append(bkd.reduce_mean(bkd.stack(error_i, 0)))
+
+        # BC loss
+        losses_bc = []
+        for i in range(num_func):
+            losses_i = []
+            out = outputs[i]
+            if bkd.ndim(out) == 1:
+                out = out[:, None]
             for j, bc in enumerate(self.pde.bcs):
                 beg, end = bcs_start[j], bcs_start[j + 1]
                 # The same BC points are used for training and testing.
@@ -267,19 +303,21 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func):
                 )
                 losses_i.append(loss_fn(bkd.zeros_like(error), error))
 
-            losses.append(losses_i)
+            losses_bc.append(losses_i)
 
-        losses = zip(*losses)
-        # Use stack instead of as_tensor to keep the gradients.
-        losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        losses_bc = zip(*losses_bc)
+        losses_bc = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses_bc]
+        losses += losses_bc
         return losses
 
     def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         num_func = self.num_func if self.batch_size is None else self.batch_size
-        return self._losses(outputs, loss_fn, inputs, model, num_func)
+        return self._losses(outputs, loss_fn, inputs, model, num_func, aux=aux)
 
     def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
-        return self._losses(outputs, loss_fn, inputs, model, len(self.test_x[0]))
+        return self._losses(
+            outputs, loss_fn, inputs, model, len(self.test_x[0]), aux=aux
+        )
 
     def train_next_batch(self, batch_size=None):
         if self.train_x is None:
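
The core of this patch is the forward-mode branch: instead of handing the PDE a plain output tensor, `_losses` now passes the pair `(outputs, forward_call)`, where `forward_call` closes over the branch input `inputs[0]` and re-evaluates the network at arbitrary trunk points. The PDE residual can then obtain derivatives of the full `(N1, N2)` output with a single forward-mode JVP over the trunk coordinates, rather than `N1` reverse-mode passes (one per sampled function). Below is a minimal standalone sketch of that mechanism using PyTorch 2.x's `torch.func.jvp`; this is not DeepXDE's API, and `net`, `trunk_x`, and the shapes are illustrative assumptions:

```python
import torch
from torch.func import jvp

N1, N2 = 4, 7  # N1 sampled input functions (branch), N2 trunk points

def net(trunk_x):
    # Stand-in for forward_call, i.e. model.net((branch_inputs, trunk_x)):
    # any differentiable map from (N2, 1) trunk points to an (N1, N2) output.
    scale = torch.linspace(1.0, 2.0, N1)[:, None]  # (N1, 1), fake branch factors
    return scale * torch.sin(trunk_x).T            # (N1, N2)

trunk_x = torch.rand(N2, 1)
# One JVP with a unit tangent along the coordinate returns the output together
# with du/dx for all N1 functions at once.
u, du_dx = jvp(net, (trunk_x,), (torch.ones_like(trunk_x),))
print(u.shape, du_dx.shape)  # torch.Size([4, 7]) torch.Size([4, 7])
```

This is also why the forward branch computes the PDE loss once for the whole batch (each residual `f` already has shape `(N1, N2)`, per the comment in the patch), while the reverse branch still loops over `num_func`.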