@@ -1,5 +1,6 @@
 __all__ = ["Model", "TrainState", "LossHistory"]
 
+import functools
 import pickle
 from collections import OrderedDict
 
@@ -276,24 +277,31 @@ def _compile_jax(self, lr, loss_fn, decay, loss_weights):
         self.net.params = self.net.init(key, x)
 
         @jax.jit
+        @functools.partial(jax.vmap, in_axes=(None, None, 0), out_axes=0)
         def inner_outputs(params, training, inputs):
-            return self.net.apply(params, training, inputs)
+            return self.net.apply(params, inputs, training=training)
 
         @jax.jit
+        @functools.partial(jax.vmap, in_axes=(None, None, 0, 0), out_axes=(0, 0))
         def inner_outputs_losses(params, training, inputs, targets):
             # TODO: add auxiliary vars, regularization loss, weighted losses
-            outputs_ = self.net.apply(params, inputs, training=training)
+            _outputs = self.net.apply(params, inputs, training=training)
             # Data losses
-            losses = self.data.losses(targets, outputs_, loss_fn, self)
+            # TODO: support passing auxiliary arguments to data.losses for all data
+            # types. This is particularly useful for the jax backend and is not the
+            # same as auxiliary_vars. Possible auxiliary arguments include the inputs
+            # and masks marking which inputs lie on boundary/initial conditions.
+            losses = self.data.losses(targets, _outputs, loss_fn, self, aux=None)
             if not isinstance(losses, list):
                 losses = [losses]
-            return outputs_, losses
+            return _outputs, jax.numpy.stack(losses)
 
         @jax.jit
         def inner_train_step(params, opt_state, inputs, targets):
             def loss_function(params):
-                losses = inner_outputs_losses(params, True, inputs, targets)[1]
-                return jax.numpy.sum(jax.numpy.stack(losses))
+                return jax.numpy.sum(
+                    inner_outputs_losses(params, True, inputs, targets)[1]
+                )
 
             grad_fn = jax.grad(
                 loss_function
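
For reference, a minimal self-contained sketch of the `functools.partial(jax.vmap, ...)` decorator pattern used in this hunk: the decorated function is written for a single sample, `in_axes=None` broadcasts the parameters unchanged, and `out_axes=0` stacks the per-sample results along a leading batch axis. The toy network below is a hypothetical stand-in for `self.net.apply`.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.vmap, in_axes=(None, 0), out_axes=0)
def per_sample_outputs(params, x):
    # Written for one sample x of shape (2,); vmap maps the leading batch
    # axis of x while broadcasting params unchanged (in_axes=None).
    return jnp.tanh(params["w"] @ x + params["b"])


params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(3)}
batch = jnp.ones((5, 2))  # 5 samples of dimension 2
print(per_sample_outputs(params, batch).shape)  # (5, 3)
```

Since `jax.grad` requires a scalar-valued objective, `loss_function` reduces the vmapped, stacked losses with a plain `jax.numpy.sum` over all axes.
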
@@ -307,7 +315,12 @@ def outputs(training, inputs):
             return inner_outputs(self.net.params, training, inputs)
 
         def outputs_losses(training, inputs, targets):
-            return inner_outputs_losses(self.net.params, training, inputs, targets)
+            _outputs, _losses = inner_outputs_losses(
+                self.net.params, training, inputs, targets
+            )
+            # Sum each loss over the batch axis: after vmap, _losses has shape
+            # (batch_size, num_losses).
+            return _outputs, jax.numpy.sum(_losses, axis=0)
 
         def train_step(inputs, targets):
             self.net.params, self.opt_state = inner_train_step(
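
A companion sketch of the shape bookkeeping behind `inner_outputs_losses` and `outputs_losses`, with hypothetical loss terms standing in for `data.losses`: each sample yields a stacked vector of loss terms, vmap adds a leading batch axis, and summing over `axis=0` recovers one aggregate value per loss term.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.vmap, in_axes=(0, 0), out_axes=0)
def per_sample_losses(output, target):
    # Two hypothetical loss terms per sample, stacked into shape (2,).
    return jnp.stack([(output - target) ** 2, jnp.abs(output - target)])


outputs = jnp.array([1.0, 2.0, 3.0])
targets = jnp.array([1.5, 2.0, 2.0])
losses = per_sample_losses(outputs, targets)  # shape (3, 2): (batch, num_losses)
print(jnp.sum(losses, axis=0))  # one value per loss term, as in outputs_losses
print(jnp.sum(losses))          # scalar total, as loss_function feeds to jax.grad
```
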