
Commit b0d239b

PDEOperatorCartesianProd supports forward-mode AD (#1903)
1 parent 4e9a283 commit b0d239b
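
This commit lets the PDE-residual part of PDEOperatorCartesianProd's loss be computed with forward-mode automatic differentiation, selected through config.autodiff, alongside the existing reverse-mode path. A minimal usage sketch, not part of this commit: it assumes your DeepXDE version exposes dde.config.set_default_autodiff and that the operator data and network are built as usual.

    import deepxde as dde

    # Assumption: this helper flips the global config.autodiff flag that the
    # new branch in _losses checks ("reverse" is the default, "forward" opts in).
    dde.config.set_default_autodiff("forward")

With forward mode selected, losses_train/losses_test pass the network call through aux (see the diff below), and the new forward_call closure re-evaluates the net on the trunk input so the PDE residual can be formed for all N1 sampled functions at once.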


deepxde/data/pde_operator.py

Lines changed: 54 additions & 16 deletions
@@ -237,23 +237,59 @@ def __init__(
         self.train_next_batch()
         self.test()

-    def _losses(self, outputs, loss_fn, inputs, model, num_func):
+    def _losses(self, outputs, loss_fn, inputs, model, num_func, aux=None):
         bcs_start = np.cumsum([0] + self.pde.num_bcs)

         losses = []
-        for i in range(num_func):
-            out = outputs[i]
-            # Single output
-            if bkd.ndim(out) == 1:
-                out = out[:, None]
+        # PDE loss
+        if config.autodiff == "reverse":  # reverse mode AD
+            for i in range(num_func):
+                out = outputs[i]
+                # Single output
+                if bkd.ndim(out) == 1:
+                    out = out[:, None]
+                f = []
+                if self.pde.pde is not None:
+                    f = self.pde.pde(
+                        inputs[1], out, model.net.auxiliary_vars[i][:, None]
+                    )
+                if not isinstance(f, (list, tuple)):
+                    f = [f]
+                error_f = [fi[bcs_start[-1] :] for fi in f]
+                losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
+                losses.append(losses_i)
+
+            losses = zip(*losses)
+            # Use stack instead of as_tensor to keep the gradients.
+            losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        elif config.autodiff == "forward":  # forward mode AD
+
+            def forward_call(trunk_input):
+                return aux[0]((inputs[0], trunk_input))
+
             f = []
             if self.pde.pde is not None:
-                f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
+                # Each f has the shape (N1, N2)
+                f = self.pde.pde(
+                    inputs[1], (outputs, forward_call), model.net.auxiliary_vars
+                )
             if not isinstance(f, (list, tuple)):
                 f = [f]
-            error_f = [fi[bcs_start[-1] :] for fi in f]
-            losses_i = [loss_fn(bkd.zeros_like(error), error) for error in error_f]
-
+            # Each error has the shape (N1, ~N2)
+            error_f = [fi[:, bcs_start[-1] :] for fi in f]
+            for error in error_f:
+                error_i = []
+                for i in range(num_func):
+                    error_i.append(loss_fn(bkd.zeros_like(error[i]), error[i]))
+                losses.append(bkd.reduce_mean(bkd.stack(error_i, 0)))
+
+        # BC loss
+        losses_bc = []
+        for i in range(num_func):
+            losses_i = []
+            out = outputs[i]
+            if bkd.ndim(out) == 1:
+                out = out[:, None]
             for j, bc in enumerate(self.pde.bcs):
                 beg, end = bcs_start[j], bcs_start[j + 1]
                 # The same BC points are used for training and testing.
@@ -267,19 +303,21 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func):
                 )
                 losses_i.append(loss_fn(bkd.zeros_like(error), error))

-            losses.append(losses_i)
+            losses_bc.append(losses_i)

-        losses = zip(*losses)
-        # Use stack instead of as_tensor to keep the gradients.
-        losses = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses]
+        losses_bc = zip(*losses_bc)
+        losses_bc = [bkd.reduce_mean(bkd.stack(loss, 0)) for loss in losses_bc]
+        losses.append(losses_bc)
         return losses

     def losses_train(self, targets, outputs, loss_fn, inputs, model, aux=None):
         num_func = self.num_func if self.batch_size is None else self.batch_size
-        return self._losses(outputs, loss_fn, inputs, model, num_func)
+        return self._losses(outputs, loss_fn, inputs, model, num_func, aux=aux)

     def losses_test(self, targets, outputs, loss_fn, inputs, model, aux=None):
-        return self._losses(outputs, loss_fn, inputs, model, len(self.test_x[0]))
+        return self._losses(
+            outputs, loss_fn, inputs, model, len(self.test_x[0]), aux=aux
+        )

     def train_next_batch(self, batch_size=None):
         if self.train_x is None:
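
For orientation only, a hedged sketch of a user-side PDE residual that the new branch could consume; it is not part of the commit, and the equation, the constant D, and the column layout of the trunk input are illustrative assumptions. In reverse mode the residual receives one function's output at a time; in forward mode it receives the (outputs, forward_call) pair from the diff above, which dde.grad.jacobian / dde.grad.hessian are expected to accept, so a residual written only in terms of derivatives and the auxiliary values v should work unchanged in both modes.

    import deepxde as dde

    D = 0.01  # illustrative diffusion coefficient

    def pde(x, u, v):
        # x: trunk inputs with columns (space, time); u: network output, or the
        # (outputs, forward_call) pair under forward-mode AD; v: branch/auxiliary
        # values (the source term sampled at the trunk points).
        du_t = dde.grad.jacobian(u, x, j=1)  # du/dt (column 1 assumed to be time)
        du_xx = dde.grad.hessian(u, x, j=0)  # d2u/dx2
        return du_t - D * du_xx - v          # residual of u_t = D * u_xx + v

Note the shape convention the diff relies on: reverse mode loops over the N1 sampled functions and each residual has shape (N2, 1), while forward mode evaluates the residual once with shape (N1, N2) and averages over functions per residual term; the boundary-condition loss is still assembled per function in both modes.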
