
Commit 7733faa

Refactor backend jax: utilize vmap, and add auxiliary arguments to data.losses (#635)
1 parent 80e1d15 commit 7733faa

File tree

deepxde/data/function.py
deepxde/data/pde.py
deepxde/model.py

3 files changed: +22 -9 lines changed
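For context on the first part of this commit: inner_outputs and inner_outputs_losses in deepxde/model.py are now written per sample and vectorized over the batch with jax.vmap, broadcasting the parameters and the training flag via in_axes=None. A minimal, self-contained sketch of that pattern; the names apply_single, apply_batch, params, and xs below are illustrative only, not DeepXDE code:

import functools

import jax
import jax.numpy as jnp

# Hypothetical per-sample forward pass: params and training are shared across
# the batch, x is a single sample of shape (n_features,).
def apply_single(params, training, x):
    return jnp.dot(x, params["w"]) + params["b"]

# Same decorator pattern as the commit: broadcast params and training
# (in_axes=None) and map over the leading batch axis of x (in_axes=0).
apply_batch = functools.partial(jax.vmap, in_axes=(None, None, 0), out_axes=0)(apply_single)

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}
xs = jnp.ones((5, 3))  # batch of 5 samples
print(apply_batch(params, True, xs).shape)  # (5,)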

deepxde/data/function.py

Lines changed: 1 addition & 1 deletion
@@ -41,7 +41,7 @@ def __init__(
         self.train_x, self.train_y = None, None
         self.test_x, self.test_y = None, None
 
-    def losses(self, targets, outputs, loss, model):
+    def losses(self, targets, outputs, loss, model, aux=None):
         return [loss(targets, outputs)]
 
     def train_next_batch(self, batch_size=None):
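The only change here is the widened signature: losses now takes an aux keyword that defaults to None, so both the new call from model.py (which passes aux=None explicitly) and older-style calls resolve to the same behavior. A small standalone sketch of that compatibility, using hypothetical names (FunctionLike, mse) rather than the real DeepXDE classes:

import numpy as np

def mse(targets, outputs):
    return np.mean((targets - outputs) ** 2)

class FunctionLike:
    # Mirrors the new signature: aux defaults to None, so callers that do not
    # pass it keep working unchanged.
    def losses(self, targets, outputs, loss, model, aux=None):
        return [loss(targets, outputs)]

data = FunctionLike()
y_true = np.array([1.0, 2.0])
y_pred = np.array([1.5, 2.5])
print(data.losses(y_true, y_pred, mse, None))            # old-style call
print(data.losses(y_true, y_pred, mse, None, aux=None))  # new call from model.py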

deepxde/data/pde.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def __init__(
         self.train_next_batch()
         self.test()
 
-    def losses(self, targets, outputs, loss, model):
+    def losses(self, targets, outputs, loss, model, aux=None):
         f = []
         if self.pde is not None:
             if get_num_args(self.pde) == 2:
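Same signature change as in function.py. Per the TODO added in model.py, aux is intended to eventually carry extra per-batch information such as the inputs or boundary/initial-condition masks, which the jax backend cannot recover from the outputs alone. A purely illustrative sketch of how a losses function might consume such an argument; none of these names (pde_like_losses, bc_mask) exist in DeepXDE:

import jax.numpy as jnp

def mse(targets, outputs):
    return jnp.mean((targets - outputs) ** 2)

def pde_like_losses(targets, outputs, loss, model, aux=None):
    losses = [loss(targets, outputs)]
    if aux is not None:
        # Hypothetical aux payload: collocation inputs plus a boolean mask that
        # marks which points lie on the boundary/initial conditions.
        inputs, bc_mask = aux  # inputs kept only to illustrate the payload
        bc_error = jnp.where(bc_mask, outputs - targets, 0.0)
        losses.append(jnp.mean(bc_error ** 2))
    return losses

x = jnp.linspace(0.0, 1.0, 5)
y_true = jnp.zeros(5)
y_pred = jnp.full(5, 0.1)
mask = jnp.array([True, False, False, False, True])
print(pde_like_losses(y_true, y_pred, mse, None, aux=(x, mask)))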

deepxde/model.py

Lines changed: 20 additions & 7 deletions
@@ -1,5 +1,6 @@
 __all__ = ["Model", "TrainState", "LossHistory"]
 
+import functools
 import pickle
 from collections import OrderedDict
 
@@ -276,24 +277,31 @@ def _compile_jax(self, lr, loss_fn, decay, loss_weights):
             self.net.params = self.net.init(key, x)
 
         @jax.jit
+        @functools.partial(jax.vmap, in_axes=(None, None, 0), out_axes=0)
         def inner_outputs(params, training, inputs):
-            return self.net.apply(params, training, inputs)
+            return self.net.apply(params, inputs, training=training)
 
         @jax.jit
+        @functools.partial(jax.vmap, in_axes=(None, None, 0, 0), out_axes=(0, 0))
         def inner_outputs_losses(params, training, inputs, targets):
             # TODO: add auxiliary vars, regularization loss, weighted losses
-            outputs_ = self.net.apply(params, inputs, training=training)
+            _outputs = self.net.apply(params, inputs, training=training)
             # Data losses
-            losses = self.data.losses(targets, outputs_, loss_fn, self)
+            # TODO: support passing auxiliary arguments to data.losses, for all data types. Note
+            # that this is particularly useful for jax backend, and is not the same as auxiliary_vars.
+            # Possible auxiliary arguments are inputs, masks indicating whether current inputs are
+            # at boundary/initial conditions.
+            losses = self.data.losses(targets, _outputs, loss_fn, self, aux=None)
             if not isinstance(losses, list):
                 losses = [losses]
-            return outputs_, losses
+            return _outputs, jax.numpy.stack(losses)
 
         @jax.jit
         def inner_train_step(params, opt_state, inputs, targets):
             def loss_function(params):
-                losses = inner_outputs_losses(params, True, inputs, targets)[1]
-                return jax.numpy.sum(jax.numpy.stack(losses))
+                return jax.numpy.sum(
+                    inner_outputs_losses(params, True, inputs, targets)[1], axis=0
+                ).reshape([])
 
             grad_fn = jax.grad(
                 loss_function
@@ -307,7 +315,12 @@ def outputs(training, inputs):
             return inner_outputs(self.net.params, training, inputs)
 
         def outputs_losses(training, inputs, targets):
-            return inner_outputs_losses(self.net.params, training, inputs, targets)
+            _outputs, _losses = inner_outputs_losses(
+                self.net.params, training, inputs, targets
+            )
+            return _outputs, jax.numpy.sum(
+                _losses, axis=0
+            )  # sum over the first axis, because here _losses is a batch
 
         def train_step(inputs, targets):
             self.net.params, self.opt_state = inner_train_step(
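Because inner_outputs_losses is now vmapped, the stacked losses it returns have a leading batch axis: each per-sample call yields a vector with one entry per loss term, so the batched result has shape (batch, n_losses), and summing over axis 0, as outputs_losses and loss_function do above, collapses the batch dimension back to one value per loss term. A toy shape check of that pattern, with made-up loss terms (per_sample_losses) rather than DeepXDE code:

import functools

import jax
import jax.numpy as jnp

@functools.partial(jax.vmap, in_axes=(None, 0, 0), out_axes=0)
def per_sample_losses(params, x, y):
    pred = params * x
    # Two toy loss terms per sample, stacked into a vector of shape (2,).
    return jnp.stack([(pred - y) ** 2, jnp.abs(pred - y)])

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 3.0, 4.0])
batched = per_sample_losses(2.0, xs, ys)  # shape (3, 2): (batch, n_losses)
print(batched.shape)
print(jnp.sum(batched, axis=0))  # one summed value per loss term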
