diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py index acbcc6e25..fcf08214b 100644 --- a/deepxde/data/pde.py +++ b/deepxde/data/pde.py @@ -3,7 +3,6 @@ from .data import Data from .. import backend as bkd from .. import config -from ..backend import backend_name from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0 @@ -128,9 +127,8 @@ def __init__( self.test() def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): - if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]: - outputs_pde = outputs - elif backend_name == "jax": + outputs_pde = outputs + if bkd.backend_name == "jax": # JAX requires pure functions outputs_pde = (outputs, aux[0]) @@ -166,7 +164,7 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): for i, bc in enumerate(self.bcs): beg, end = bcs_start[i], bcs_start[i + 1] # The same BC points are used for training and testing. - error = bc.error(self.train_x, inputs, outputs, beg, end) + error = bc.error(self.train_x, inputs, outputs_pde, beg, end) losses.append(loss_fn[len(error_f) + i](bkd.zeros_like(error), error)) return losses diff --git a/deepxde/data/pde_operator.py b/deepxde/data/pde_operator.py index 5de5a6f46..b1cd6d4bf 100644 --- a/deepxde/data/pde_operator.py +++ b/deepxde/data/pde_operator.py @@ -80,13 +80,19 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None): bcs_start = np.cumsum([0] + self.num_bcs) error_f = [fi[bcs_start[-1] :] for fi in f] losses = [loss_fn(bkd.zeros_like(error), error) for error in error_f] + + outputs_pdeoperator = outputs + if bkd.backend_name == "jax": + # JAX requries pure functions + outputs_pdeoperator = (outputs, aux[0]) + for i, bc in enumerate(self.pde.bcs): beg, end = bcs_start[i], bcs_start[i + 1] # The same BC points are used for training and testing. error = bc.error( self.train_x[1], inputs[1], - outputs, + outputs_pdeoperator, beg, end, aux_var=self.train_aux_vars, diff --git a/deepxde/gradients.py b/deepxde/gradients.py index d8ed25139..64b73a9e0 100644 --- a/deepxde/gradients.py +++ b/deepxde/gradients.py @@ -203,7 +203,7 @@ def __init__(self, y, xs, component=None, grad_y=None): if backend_name in ["tensorflow.compat.v1", "tensorflow", "pytorch", "paddle"]: dim_y = y.shape[1] elif backend_name == "jax": - dim_y = y[0].shape[0] + dim_y = y[0].shape[1] if dim_y > 1: if component is None: diff --git a/deepxde/icbc/boundary_conditions.py b/deepxde/icbc/boundary_conditions.py index e1f863a08..7e0d0e699 100644 --- a/deepxde/icbc/boundary_conditions.py +++ b/deepxde/icbc/boundary_conditions.py @@ -52,9 +52,11 @@ def collocation_points(self, X): return self.filter(X) def normal_derivative(self, X, inputs, outputs, beg, end): - dydx = grad.jacobian(outputs, inputs, i=self.component, j=None)[beg:end] + dydx = grad.jacobian(outputs, inputs, i=self.component, j=None) + if backend_name == "jax": + dydx = dydx[0] n = self.boundary_normal(X, beg, end, None) - return bkd.sum(dydx * n, 1, keepdims=True) + return bkd.sum(dydx[beg:end] * n, 1, keepdims=True) @abstractmethod def error(self, X, inputs, outputs, beg, end, aux_var=None): @@ -77,6 +79,8 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): "DirichletBC function should return an array of shape N by 1 for each " "component. Use argument 'component' for different output components." ) + if backend_name == "jax": + outputs = outputs[0] return outputs[beg:end, self.component : self.component + 1] - values @@ -100,9 +104,11 @@ def __init__(self, geom, func, on_boundary, component=0): self.func = func def error(self, X, inputs, outputs, beg, end, aux_var=None): - return self.normal_derivative(X, inputs, outputs, beg, end) - self.func( - X[beg:end], outputs[beg:end] - ) + normal_derivative = self.normal_derivative(X, inputs, outputs, beg, end) + if backend_name == "jax": + outputs = outputs[0] + values = self.func(X[beg:end], outputs[beg:end]) + return normal_derivative - values class PeriodicBC(BC): @@ -125,10 +131,14 @@ def collocation_points(self, X): def error(self, X, inputs, outputs, beg, end, aux_var=None): mid = beg + (end - beg) // 2 if self.derivative_order == 0: + if backend_name == "jax": + outputs = outputs[0] yleft = outputs[beg:mid, self.component : self.component + 1] yright = outputs[mid:end, self.component : self.component + 1] else: dydx = grad.jacobian(outputs, inputs, i=self.component, j=self.component_x) + if backend_name == "jax": + dydx = dydx[0] yleft = dydx[beg:mid] yright = dydx[mid:end] return yleft - yright @@ -158,6 +168,8 @@ def __init__(self, geom, func, on_boundary): self.func = func def error(self, X, inputs, outputs, beg, end, aux_var=None): + # User defined func is responsible for handling compatibility with the + # desired backend. return self.func(inputs, outputs, X)[beg:end] @@ -210,6 +222,8 @@ def collocation_points(self, X): return self.points def error(self, X, inputs, outputs, beg, end, aux_var=None): + if backend_name == "jax": + outputs = outputs[0] if self.batch_size is not None: if isinstance(self.component, numbers.Number): return ( @@ -260,6 +274,8 @@ def collocation_points(self, X): return self.points def error(self, X, inputs, outputs, beg, end, aux_var=None): + # User defined func is responsible for handling compatibility with the + # desired backend. return self.func(inputs, outputs, X)[beg:end] - self.values diff --git a/deepxde/icbc/initial_conditions.py b/deepxde/icbc/initial_conditions.py index a7ca57f2f..cfef7b44a 100644 --- a/deepxde/icbc/initial_conditions.py +++ b/deepxde/icbc/initial_conditions.py @@ -33,4 +33,6 @@ def error(self, X, inputs, outputs, beg, end, aux_var=None): "IC function should return an array of shape N by 1 for each component." "Use argument 'component' for different output components." ) + if bkd.backend_name == "jax": + outputs = outputs[0] return outputs[beg:end, self.component : self.component + 1] - values diff --git a/deepxde/model.py b/deepxde/model.py index 3584b76c3..86a5210f2 100644 --- a/deepxde/model.py +++ b/deepxde/model.py @@ -1011,6 +1011,23 @@ def save(self, save_path, protocol="backend", verbose=0): elif backend_name == "tensorflow": save_path += ".ckpt" self.net.save_weights(save_path) + elif backend_name == "jax": + # Lazy load Orbax to avoid a hard dependancy when using JAX + # TODO: identify a better solution that complies with PEP8 + import orbax + from flax.training import orbax_utils + save_path += ".ckpt" + checkpoint = { + "params": self.params, + "state": self.opt_state + } + self.checkpointer = orbax.checkpoint.PyTreeCheckpointer() + save_args = orbax_utils.save_args_from_target(checkpoint) + # `Force=True` option causes existing checkpoints to be + # overwritten, matching the PyTorch checkpointer behaviour. + self.checkpointer.save( + save_path, checkpoint, force=True, save_args=save_args + ) elif backend_name == "pytorch": save_path += ".pt" checkpoint = { @@ -1055,6 +1072,10 @@ def restore(self, save_path, device=None, verbose=0): self.saver.restore(self.sess, save_path) elif backend_name == "tensorflow": self.net.load_weights(save_path) + elif backend_name == "jax": + checkpoint = self.checkpointer.restore(save_path) + self.params, self.opt_state = checkpoint["params"], checkpoint["state"] + self.net.params, self.external_trainable_variables = self.params elif backend_name == "pytorch": if device is not None: checkpoint = torch.load(save_path, map_location=torch.device(device))