@@ -240,7 +240,40 @@ def __init__(
240240 def _losses (self , outputs , loss_fn , inputs , model , num_func , aux = None ):
241241 bcs_start = np .cumsum ([0 ] + self .pde .num_bcs )
242242
243- if config .autodiff == "forward" : # forward mode AD
243+ if config .autodiff == "reverse" : # reverse mode AD
244+ losses = []
245+ for i in range (num_func ):
246+ out = outputs [i ]
247+ # Single output
248+ if bkd .ndim (out ) == 1 :
249+ out = out [:, None ]
250+ f = []
251+ if self .pde .pde is not None :
252+ f = self .pde .pde (inputs [1 ], out , model .net .auxiliary_vars [i ][:, None ])
253+ if not isinstance (f , (list , tuple )):
254+ f = [f ]
255+ error_f = [fi [bcs_start [- 1 ]:] for fi in f ]
256+ losses_i = [loss_fn (bkd .zeros_like (error ), error ) for error in error_f ]
257+
258+ for j , bc in enumerate (self .pde .bcs ):
259+ beg , end = bcs_start [j ], bcs_start [j + 1 ]
260+ # The same BC points are used for training and testing.
261+ error = bc .error (
262+ self .train_x [1 ],
263+ inputs [1 ],
264+ out ,
265+ beg ,
266+ end ,
267+ aux_var = model .net .auxiliary_vars [i ][:, None ],
268+ )
269+ losses_i .append (loss_fn (bkd .zeros_like (error ), error ))
270+
271+ losses .append (losses_i )
272+
273+ losses = zip (* losses )
274+ # Use stack instead of as_tensor to keep the gradients.
275+ losses = [bkd .reduce_mean (bkd .stack (loss , 0 )) for loss in losses ]
276+ else : # forward mode AD
244277 losses = []
245278
246279 def forward_call (trunk_input ):
@@ -274,39 +307,6 @@ def forward_call(trunk_input):
274307 error_k = bkd .stack (error_k , axis = 0 ) # noqa
275308 loss_k = loss_fn (bkd .zeros_like (error_k ), error_k ) # noqa
276309 losses .append (loss_k )
277- else : # reverse mode AD
278- losses = []
279- for i in range (num_func ):
280- out = outputs [i ]
281- # Single output
282- if bkd .ndim (out ) == 1 :
283- out = out [:, None ]
284- f = []
285- if self .pde .pde is not None :
286- f = self .pde .pde (inputs [1 ], out , model .net .auxiliary_vars [i ][:, None ])
287- if not isinstance (f , (list , tuple )):
288- f = [f ]
289- error_f = [fi [bcs_start [- 1 ] :] for fi in f ]
290- losses_i = [loss_fn (bkd .zeros_like (error ), error ) for error in error_f ]
291-
292- for j , bc in enumerate (self .pde .bcs ):
293- beg , end = bcs_start [j ], bcs_start [j + 1 ]
294- # The same BC points are used for training and testing.
295- error = bc .error (
296- self .train_x [1 ],
297- inputs [1 ],
298- out ,
299- beg ,
300- end ,
301- aux_var = model .net .auxiliary_vars [i ][:, None ],
302- )
303- losses_i .append (loss_fn (bkd .zeros_like (error ), error ))
304-
305- losses .append (losses_i )
306-
307- losses = zip (* losses )
308- # Use stack instead of as_tensor to keep the gradients.
309- losses = [bkd .reduce_mean (bkd .stack (loss , 0 )) for loss in losses ]
310310 return losses
311311
312312 def losses_train (self , targets , outputs , loss_fn , inputs , model , aux = None ):
0 commit comments