Skip to content

Commit 947abee

Browse files
committed
Refactor the code to reduce redundancy
1 parent cf076d9 commit 947abee

File tree

1 file changed

+21
-34
lines changed

1 file changed

+21
-34
lines changed

deepxde/data/pde_operator.py

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ def __init__(
240240
def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
241241
bcs_start = np.cumsum([0] + self.pde.num_bcs)
242242
losses = []
243+
# PDE loss
243244
if config.autodiff == "reverse": # reverse mode AD
244245
for i in range(num_func):
245246
out = outputs[i]
@@ -253,20 +254,6 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
253254
f = [f]
254255
error_f = [fi[bcs_start[-1]:] for fi in f]
255256
losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
256-
257-
for j, bc in enumerate(self.pde.bcs):
258-
beg, end = bcs_start[j], bcs_start[j + 1]
259-
# The same BC points are used for training and testing.
260-
error = bc.error(
261-
self.train_x[1],
262-
inputs[1],
263-
out,
264-
beg,
265-
end,
266-
aux_var=model.net.auxiliary_vars[i][:, None],
267-
)
268-
losses_i.append(loss_fn(bkd.zeros_like(error), error))
269-
270257
losses.append(losses_i)
271258

272259
losses = zip(*losses)
@@ -283,26 +270,26 @@ def forward_call(trunk_input):
283270
f = [f]
284271
error_f = [fi[:, bcs_start[-1]:] for fi in f]
285272
losses = [loss_fn(bkd.zeros_like(error), error) for error in error_f] # noqa
286-
# BC
287-
for k, bc in enumerate(self.pde.bcs):
288-
beg, end = bcs_start[k], bcs_start[k + 1]
289-
error_k = []
290-
for i in range(num_func):
291-
output_i = outputs[i]
292-
if bkd.ndim(output_i) == 1: # noqa
293-
output_i = output_i[:, None]
294-
error_ki = bc.error(
295-
self.train_x[1],
296-
inputs[1],
297-
output_i,
298-
beg,
299-
end,
300-
aux_var=model.net.auxiliary_vars[i][:, None],
301-
)
302-
error_k.append(error_ki)
303-
error_k = bkd.stack(error_k, axis=0) # noqa
304-
loss_k = loss_fn(bkd.zeros_like(error_k), error_k) # noqa
305-
losses.append(loss_k)
273+
# BC loss
274+
for k, bc in enumerate(self.pde.bcs):
275+
beg, end = bcs_start[k], bcs_start[k + 1]
276+
error_k = []
277+
for i in range(num_func):
278+
output_i = outputs[i]
279+
if bkd.ndim(output_i) == 1: # noqa
280+
output_i = output_i[:, None]
281+
error_ki = bc.error(
282+
self.train_x[1],
283+
inputs[1],
284+
output_i,
285+
beg,
286+
end,
287+
aux_var=model.net.auxiliary_vars[i][:, None],
288+
)
289+
error_k.append(error_ki)
290+
error_k = bkd.stack(error_k, axis=0) # noqa
291+
loss_k = loss_fn(bkd.zeros_like(error_k), error_k) # noqa
292+
losses.append(loss_k)
306293
return losses
307294

308295
def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):

0 commit comments

Comments (0)