@@ -240,6 +240,7 @@ def __init__(
240240 def _losses (self , outputs , loss_fn , inputs , model , num_func , aux = None ):
241241 bcs_start = np .cumsum ([0 ] + self .pde .num_bcs )
242242 losses = []
243+ # PDE loss
243244 if config .autodiff == "reverse" : # reverse mode AD
244245 for i in range (num_func ):
245246 out = outputs [i ]
@@ -253,20 +254,6 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
253254 f = [f ]
254255 error_f = [fi [bcs_start [- 1 ]:] for fi in f ]
255256 losses_i = [loss_fn (bkd .zeros_like (error ), error ) for error in error_f ]
256-
257- for j , bc in enumerate (self .pde .bcs ):
258- beg , end = bcs_start [j ], bcs_start [j + 1 ]
259- # The same BC points are used for training and testing.
260- error = bc .error (
261- self .train_x [1 ],
262- inputs [1 ],
263- out ,
264- beg ,
265- end ,
266- aux_var = model .net .auxiliary_vars [i ][:, None ],
267- )
268- losses_i .append (loss_fn (bkd .zeros_like (error ), error ))
269-
270257 losses .append (losses_i )
271258
272259 losses = zip (* losses )
@@ -283,26 +270,26 @@ def forward_call(trunk_input):
283270 f = [f ]
284271 error_f = [fi [:, bcs_start [- 1 ]:] for fi in f ]
285272 losses = [loss_fn (bkd .zeros_like (error ), error ) for error in error_f ] # noqa
286- # BC
287- for k , bc in enumerate (self .pde .bcs ):
288- beg , end = bcs_start [k ], bcs_start [k + 1 ]
289- error_k = []
290- for i in range (num_func ):
291- output_i = outputs [i ]
292- if bkd .ndim (output_i ) == 1 : # noqa
293- output_i = output_i [:, None ]
294- error_ki = bc .error (
295- self .train_x [1 ],
296- inputs [1 ],
297- output_i ,
298- beg ,
299- end ,
300- aux_var = model .net .auxiliary_vars [i ][:, None ],
301- )
302- error_k .append (error_ki )
303- error_k = bkd .stack (error_k , axis = 0 ) # noqa
304- loss_k = loss_fn (bkd .zeros_like (error_k ), error_k ) # noqa
305- losses .append (loss_k )
273+ # BC loss
274+ for k , bc in enumerate (self .pde .bcs ):
275+ beg , end = bcs_start [k ], bcs_start [k + 1 ]
276+ error_k = []
277+ for i in range (num_func ):
278+ output_i = outputs [i ]
279+ if bkd .ndim (output_i ) == 1 : # noqa
280+ output_i = output_i [:, None ]
281+ error_ki = bc .error (
282+ self .train_x [1 ],
283+ inputs [1 ],
284+ output_i ,
285+ beg ,
286+ end ,
287+ aux_var = model .net .auxiliary_vars [i ][:, None ],
288+ )
289+ error_k .append (error_ki )
290+ error_k = bkd .stack (error_k , axis = 0 ) # noqa
291+ loss_k = loss_fn (bkd .zeros_like (error_k ), error_k ) # noqa
292+ losses .append (loss_k )
306293 return losses
307294
308295 def losses_train (self , targets , outputs , loss_fn , inputs , model , aux = None ):
0 commit comments