Commit 5917be1 (parent 6c258db)

Use reverse mode in the if, forward mode in the else.
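
Context: which branch runs is selected by DeepXDE's global config.autodiff flag, so this commit only reorders the two branches; the computation performed in each mode is unchanged. A minimal usage sketch, assuming the dde.config.set_default_autodiff setter from the DeepXDE docs and that "reverse" remains the default mode:

    import deepxde as dde

    # Assumption: "reverse" is the default, so without this call the `if`
    # branch below executes. Opting in to forward-mode AD exercises the
    # `else` branch instead.
    dde.config.set_default_autodiff("forward")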


deepxde/data/pde_operator.py (1 file changed: 34 additions, 34 deletions)
@@ -240,7 +240,40 @@ def __init__(
     def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
         bcs_start = np.cumsum([0] + self.pde.num_bcs)

-        if config.autodiff == "forward":  # forward mode AD
+        if config.autodiff == "reverse":  # reverse mode AD
+            losses = []
+            for i in range(num_func):
+                out = outputs[i]
+                # Single output
+                if bkd.ndim(out) == 1:
+                    out = out[:, None]
+                f = []
+                if self.pde.pde is not None:
+                    f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
+                    if not isinstance(f, (list, tuple)):
+                        f = [f]
+                error_f = [fi[bcs_start[-1] :] for fi in f]
+                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
+
+                for j, bc in enumerate(self.pde.bcs):
+                    beg, end = bcs_start[j], bcs_start[j + 1]
+                    # The same BC points are used for training and testing.
+                    error = bc.error(
+                        self.train_x[1],
+                        inputs[1],
+                        out,
+                        beg,
+                        end,
+                        aux_var=model.net.auxiliary_vars[i][:, None],
+                    )
+                    losses_i.append(loss_fn(bkd.zeros_like(error), error))
+
+                losses.append(losses_i)
+
+            losses = zip(*losses)
+            # Use stack instead of as_tensor to keep the gradients.
+            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        else:  # forward mode AD
             losses = []

             def forward_call(trunk_input):
@@ -274,39 +307,6 @@ def forward_call(trunk_input):
                 error_k = bkd.stack(error_k, axis=0)  # noqa
                 loss_k = loss_fn(bkd.zeros_like(error_k), error_k)  # noqa
                 losses.append(loss_k)
-        else:  # reverse mode AD
-            losses = []
-            for i in range(num_func):
-                out = outputs[i]
-                # Single output
-                if bkd.ndim(out) == 1:
-                    out = out[:, None]
-                f = []
-                if self.pde.pde is not None:
-                    f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
-                    if not isinstance(f, (list, tuple)):
-                        f = [f]
-                error_f = [fi[bcs_start[-1] :] for fi in f]
-                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
-
-                for j, bc in enumerate(self.pde.bcs):
-                    beg, end = bcs_start[j], bcs_start[j + 1]
-                    # The same BC points are used for training and testing.
-                    error = bc.error(
-                        self.train_x[1],
-                        inputs[1],
-                        out,
-                        beg,
-                        end,
-                        aux_var=model.net.auxiliary_vars[i][:, None],
-                    )
-                    losses_i.append(loss_fn(bkd.zeros_like(error), error))
-
-                losses.append(losses_i)
-
-            losses = zip(*losses)
-            # Use stack instead of as_tensor to keep the gradients.
-            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
         return losses

     def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
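
Note on the moved reverse-mode block's final lines: each sampled function i contributes a list of loss terms (one per PDE residual component, then one per BC), zip(*losses) regroups them per term, and each term is stacked and mean-reduced so the result stays differentiable, which is what the comment about stack vs. as_tensor refers to. A standalone sketch of that aggregation pattern in plain PyTorch, not DeepXDE's bkd backend wrapper:

    import torch

    def aggregate(per_func_losses):
        # per_func_losses[i][t] is loss term t for sampled function i.
        # zip(*...) transposes to per-term tuples; stacking the tensors
        # keeps the autograd graph, unlike rebuilding from Python floats.
        return [torch.mean(torch.stack(term, 0)) for term in zip(*per_func_losses)]

    w = torch.tensor(1.0, requires_grad=True)
    per_func = [[w * 1.0, w * 2.0], [w * 3.0, w * 4.0]]  # 2 functions, 2 terms
    terms = aggregate(per_func)  # [mean(1, 3), mean(2, 4)] == [2.0, 3.0]
    terms[0].backward()          # gradients flow back to w through the stack
    print([t.item() for t in terms], w.grad)  # [2.0, 3.0] tensor(2.)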
