diff --git a/deepxde/data/pde.py b/deepxde/data/pde.py
index a81c5df82..084e35cfa 100644
--- a/deepxde/data/pde.py
+++ b/deepxde/data/pde.py
@@ -4,7 +4,7 @@
 from .. import backend as bkd
 from .. import config
 from ..backend import backend_name
-from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0
+from ..utils import get_num_args, run_if_all_none, mpi_scatter_from_rank0, list_handler
 
 
 class PDE(Data):
@@ -84,6 +84,7 @@ def __init__(
         solution=None,
         num_test=None,
         auxiliary_var_function=None,
+        is_SPINN=False,
     ):
         self.geom = geometry
         self.pde = pde
@@ -108,10 +109,18 @@ def __init__(
         self.anchors = None if anchors is None else anchors.astype(config.real(np))
         self.exclusions = exclusions
 
-        self.soln = solution
+        if solution is not None:
+            # Wrap the reference solution so it also accepts SPINN's
+            # factorized inputs (a list of per-axis arrays).
+            @list_handler
+            def solution_handling_list(inputs):
+                return solution(inputs)
+
+            self.soln = solution_handling_list
+        else:
+            self.soln = solution
+
         self.num_test = num_test
         self.auxiliary_var_fn = auxiliary_var_function
+        self.is_SPINN = is_SPINN
 
         # TODO: train_x_all is used for PDE losses. It is better to add train_x_pde
         # explicitly.
@@ -128,29 +137,41 @@ def __init__(
         self.test()
 
     def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
+        bcs_start = np.cumsum([0] + self.num_bcs)
+        bcs_start = list(map(int, bcs_start))
+
+        if self.is_SPINN:
+            # Each factorized BC block of n points expands to an n x n grid
+            # of rows in the network output.
+            num_bcs_output = [num_bc**2 for num_bc in self.num_bcs]
+            bcs_start_output = np.cumsum([0] + num_bcs_output)
+            bcs_start_output = list(map(int, bcs_start_output))
+        else:
+            bcs_start_output = bcs_start
+
         if backend_name in ["tensorflow.compat.v1", "paddle"]:
-            outputs_pde = outputs
+            outputs_pde = outputs[bcs_start_output[-1] :]
         elif backend_name in ["tensorflow", "pytorch"]:
             if config.autodiff == "reverse":
-                outputs_pde = outputs
+                outputs_pde = outputs[bcs_start_output[-1] :]
             elif config.autodiff == "forward":
                 # forward-mode AD requires functions
-                outputs_pde = (outputs, aux[0])
+                outputs_pde = (outputs[bcs_start_output[-1] :], aux[0])
         elif backend_name == "jax":
             # JAX requires pure functions
-            outputs_pde = (outputs, aux[0])
+            outputs_pde = (outputs[bcs_start_output[-1] :], aux[0])
+        inputs_pde = (
+            inputs[-1] if isinstance(inputs, (list, tuple)) else inputs[bcs_start[-1] :]
+        )
 
         f = []
         if self.pde is not None:
             if get_num_args(self.pde) == 2:
-                f = self.pde(inputs, outputs_pde)
+                f = self.pde(inputs_pde, outputs_pde)
             elif get_num_args(self.pde) == 3:
                 if self.auxiliary_var_fn is None:
                     if aux is None or len(aux) == 1:
                         raise ValueError("Auxiliary variable function not defined.")
-                    f = self.pde(inputs, outputs_pde, unknowns=aux[1])
+                    f = self.pde(inputs_pde, outputs_pde, unknowns=aux[1])
                 else:
-                    f = self.pde(inputs, outputs_pde, model.net.auxiliary_vars)
+                    f = self.pde(inputs_pde, outputs_pde, model.net.auxiliary_vars)
         if not isinstance(f, (list, tuple)):
             f = [f]
@@ -163,16 +184,19 @@ def losses(self, targets, outputs, loss_fn, inputs, model, aux=None):
                 )
             )
 
-        bcs_start = np.cumsum([0] + self.num_bcs)
-        bcs_start = list(map(int, bcs_start))
-        error_f = [fi[bcs_start[-1] :] for fi in f]
+        # The PDE residuals above are already restricted to the domain points,
+        # so no further slicing is needed here.
+        error_f = f
         losses = [
             loss_fn[i](bkd.zeros_like(error), error) for i, error in enumerate(error_f)
         ]
         for i, bc in enumerate(self.bcs):
-            beg, end = bcs_start[i], bcs_start[i + 1]
-            # The same BC points are used for training and testing.
-            error = bc.error(self.train_x, inputs, outputs, beg, end)
+            beg, end = bcs_start_output[i], bcs_start_output[i + 1]
+            if isinstance(inputs, (list, tuple)):
+                # SPINN: BC i has its own factorized input block.
+                error = bc.error(
+                    self.train_x, inputs[i], outputs[beg:end, :], 0, end - beg
+                )
+            else:
+                # The same BC points are used for training and testing.
+                error = bc.error(self.train_x, inputs, outputs, beg, end)
             losses.append(loss_fn[len(error_f) + i](bkd.zeros_like(error), error))
         return losses
@@ -194,7 +218,10 @@ def train_next_batch(self, batch_size=None):
         if config.parallel_scaling == "strong":
             self.train_x_all = mpi_scatter_from_rank0(self.train_x_all)
         if self.pde is not None:
-            self.train_x = np.vstack((self.train_x, self.train_x_all))
+            if self.is_SPINN and len(self.train_x) > 0:
+                # Keep the factorized BC blocks and append the domain points.
+                self.train_x = self.train_x + [self.train_x_all]
+            else:
+                self.train_x = np.vstack((self.train_x, self.train_x_all))
         self.train_y = self.soln(self.train_x) if self.soln else None
         if self.auxiliary_var_fn is not None:
             self.train_aux_vars = self.auxiliary_var_fn(self.train_x).astype(
@@ -247,7 +274,11 @@ def replace_with_anchors(self, anchors):
         self.train_x_all = self.anchors
         self.train_x = self.bc_points()
         if self.pde is not None:
-            self.train_x = np.vstack((self.train_x, self.train_x_all))
+            if self.is_SPINN and len(self.train_x) > 0:
+                self.train_x = self.train_x + [self.train_x_all]
+            else:
+                self.train_x = np.vstack((self.train_x, self.train_x_all))
+
         self.train_y = self.soln(self.train_x) if self.soln else None
         if self.auxiliary_var_fn is not None:
             self.train_aux_vars = self.auxiliary_var_fn(self.train_x).astype(
@@ -258,12 +289,15 @@ def train_points(self):
         X = np.empty((0, self.geom.dim), dtype=config.real(np))
         if self.num_domain > 0:
-            if self.train_distribution == "uniform":
-                X = self.geom.uniform_points(self.num_domain, boundary=False)
+            if self.is_SPINN:
+                X = self.geom.uniform_spinn_points(self.num_domain, boundary=False)
             else:
-                X = self.geom.random_points(
-                    self.num_domain, random=self.train_distribution
-                )
+                if self.train_distribution == "uniform":
+                    X = self.geom.uniform_points(self.num_domain, boundary=False)
+                else:
+                    X = self.geom.random_points(
+                        self.num_domain, random=self.train_distribution
+                    )
         if self.num_boundary > 0:
             if self.train_distribution == "uniform":
                 tmp = self.geom.uniform_boundary_points(self.num_boundary)
@@ -287,17 +321,25 @@ def is_not_excluded(x):
     def bc_points(self):
         x_bcs = [bc.collocation_points(self.train_x_all) for bc in self.bcs]
         self.num_bcs = list(map(len, x_bcs))
-        self.train_x_bc = (
-            np.vstack(x_bcs)
-            if x_bcs
-            else np.empty([0, self.train_x_all.shape[-1]], dtype=config.real(np))
-        )
+        if self.is_SPINN:
+            # Keep the BC point sets factorized (one array per BC).
+            self.train_x_bc = (
+                x_bcs
+                if x_bcs
+                else np.empty([0, self.train_x_all.shape[-1]], dtype=config.real(np))
+            )
+        else:
+            self.train_x_bc = (
+                np.vstack(x_bcs)
+                if x_bcs
+                else np.empty([0, self.train_x_all.shape[-1]], dtype=config.real(np))
+            )
        return self.train_x_bc
 
     def test_points(self):
         # TODO: Use different BC points from self.train_x_bc
-        x = self.geom.uniform_points(self.num_test, boundary=False)
-        x = np.vstack((self.train_x_bc, x))
+        if self.is_SPINN:
+            x = self.geom.uniform_spinn_points(self.num_test, boundary=False)
+            if len(self.train_x_bc) > 0:
+                x = self.train_x_bc + [x]
+        else:
+            x = self.geom.uniform_points(self.num_test, boundary=False)
+            x = np.vstack((self.train_x_bc, x))
         return x
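A note on the `num_bc**2` bookkeeping in `losses` above: in the SPINN case each BC contributes a block of `n` factorized points per axis, but the network evaluates on the implied tensor-product grid, so that block occupies `n**2` rows of the (2-D) output. The input offsets and output offsets therefore diverge. A minimal numpy sketch with hypothetical BC sizes:

```python
import numpy as np

num_bcs = [16, 32]  # hypothetical: two BCs with 16 and 32 factorized points each

# Offsets into the list of input points (one factorized block per BC)...
bcs_start = np.cumsum([0] + num_bcs)                 # [0, 16, 48]
# ...versus offsets into the network output, where each block expands to a
# full n x n tensor-product grid of evaluations.
num_bcs_output = [n**2 for n in num_bcs]
bcs_start_output = np.cumsum([0] + num_bcs_output)   # [0, 256, 1280]
```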
diff --git a/deepxde/geometry/geometry_nd.py b/deepxde/geometry/geometry_nd.py
index a011cd417..0759235f5 100644
--- a/deepxde/geometry/geometry_nd.py
+++ b/deepxde/geometry/geometry_nd.py
@@ -80,6 +80,35 @@ def uniform_points(self, n, boundary=True):
                 "Warning: {} points required, but {} points sampled.".format(n, len(x))
             )
         return x
+
+    def uniform_spinn_points(self, n, boundary=True):
+        # Factorized sampling for SPINN: n is the total size of the implied
+        # tensor-product grid, so each axis gets about n**(1/dim) coordinates.
+        # Assumes every axis yields the same number of points, so the
+        # per-axis arrays can be stacked column-wise.
+        dx = (self.volume / n) ** (1 / self.dim)
+        xi = []
+        for i in range(self.dim):
+            ni = int(np.ceil(self.side_length[i] / dx))
+            if boundary:
+                xi.append(
+                    np.linspace(
+                        self.xmin[i], self.xmax[i], num=ni, dtype=config.real(np)
+                    )
+                )
+            else:
+                xi.append(
+                    np.linspace(
+                        self.xmin[i],
+                        self.xmax[i],
+                        num=ni + 1,
+                        endpoint=False,
+                        dtype=config.real(np),
+                    )[1:]
+                )
+        x = np.array(xi).T
+        if n != np.prod([len(v) for v in xi]):
+            print(
+                "Warning: {} points required, but {} points sampled.".format(
+                    n, np.prod([len(v) for v in xi])
+                )
+            )
+        return x
 
     def random_points(self, n, random="pseudo"):
         x = sample(n, self.dim, random)
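To see what `uniform_spinn_points` produces, here is a standalone numpy sketch of the same logic on the unit square (the assumed `volume`, `side_length`, `xmin`, and `xmax` values are illustrative; the real method reads them from the geometry). The method returns one coordinate column per axis, and the network later evaluates on the tensor-product grid those columns stand for:

```python
import numpy as np

n, dim = 100, 2
xmin, xmax, side_length, volume = [0.0, 0.0], [1.0, 1.0], [1.0, 1.0], 1.0

dx = (volume / n) ** (1 / dim)  # target spacing: about n**(1/dim) points per axis
xi = [
    np.linspace(
        xmin[i], xmax[i], num=int(np.ceil(side_length[i] / dx)) + 1, endpoint=False
    )[1:]
    for i in range(dim)
]
x = np.array(xi).T  # (10, 2): one column of interior coordinates per axis

# The tensor-product grid implied by the two columns:
grid = np.stack(np.meshgrid(*xi, indexing="ij"), axis=-1).reshape(-1, dim)  # (100, 2)
```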
diff --git a/deepxde/model.py b/deepxde/model.py
index 844d6ec22..35f48cf69 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -374,7 +374,12 @@ def _compile_jax(self, lr, loss_fn, decay):
         # Initialize the network's parameters
         if self.params is None:
             key = jax.random.PRNGKey(config.jax_random_seed)
-            self.net.params = self.net.init(key, self.data.test()[0])
+            X_test = (
+                self.data.test()[0][-1]
+                if isinstance(self.data.test()[0], (list, tuple))
+                else self.data.test()[0]
+            )
+            self.net.params = self.net.init(key, X_test)
         external_trainable_variables_arr = [
             var.value for var in self.external_trainable_variables
         ]
@@ -384,6 +389,7 @@ def _compile_jax(self, lr, loss_fn, decay):
             self.opt_state = self.opt.init(self.params)
 
         @jax.jit
+        @utils.list_handler
         def outputs(params, training, inputs):
             return self.net.apply(params, inputs, training=training)
 
@@ -392,9 +398,9 @@ def outputs_losses(params, training, inputs, targets, losses_fn):
             # TODO: Add auxiliary vars
             def outputs_fn(inputs):
-                return self.net.apply(nn_params, inputs, training=training)
+                return outputs(nn_params, training, inputs)
 
-            outputs_ = self.net.apply(nn_params, inputs, training=training)
+            outputs_ = outputs(nn_params, training, inputs)
             # Data losses
             # We use aux so that self.data.losses is a pure function.
             aux = [outputs_fn, ext_params] if ext_params else [outputs_fn]
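The decorator order on `outputs` means `jax.jit` traces the list-dispatching wrapper. Because a Python list of arrays is a pytree, the `isinstance` branch inside `list_handler` is resolved at trace time, and list inputs versus single-array inputs simply produce two cached traces. A self-contained sketch of this interplay (the wrapper below re-implements the idea with `jnp.concatenate` rather than importing deepxde):

```python
import jax
import jax.numpy as jnp

def list_handler(func):
    # Dispatch on the structure of the last argument; resolved at trace time.
    def wrapper(params, training, inputs):
        if isinstance(inputs, (list, tuple)):
            return jnp.concatenate([func(params, training, x) for x in inputs], axis=0)
        return func(params, training, inputs)
    return wrapper

@jax.jit
@list_handler
def outputs(params, training, inputs):
    return params * inputs  # stand-in for self.net.apply

p = jnp.float32(2.0)
print(outputs(p, False, jnp.ones((3, 1))).shape)                      # (3, 1)
print(outputs(p, False, [jnp.ones((2, 1)), jnp.ones((3, 1))]).shape)  # (5, 1)
```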
diff --git a/deepxde/nn/jax/__init__.py b/deepxde/nn/jax/__init__.py
index 925d19b7f..256f29914 100644
--- a/deepxde/nn/jax/__init__.py
+++ b/deepxde/nn/jax/__init__.py
@@ -1,6 +1,7 @@
 """Package for jax NN modules."""
 
-__all__ = ["FNN", "NN", "PFNN"]
+__all__ = ["FNN", "NN", "PFNN", "SPINN"]
 
+from .snn import SPINN
 from .fnn import FNN, PFNN
 from .nn import NN
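With the export in place, the network can be constructed like any other deepxde NN. A hypothetical construction sketch (the argument values are illustrative, not from the source): per `SPINN.setup` in the file below, `layer_sizes` is read as `[in_dim, *hidden_features, rank, out_dim]`.

```python
import deepxde as dde

# Hypothetical usage with the jax backend: 2 input coordinates, three hidden
# layers of width 32 per axis network, a rank-10 factorization, 1 output.
net = dde.nn.jax.SPINN(
    layer_sizes=[2, 32, 32, 32, 10, 1],
    activation="tanh",
    kernel_initializer="Glorot normal",
)
```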
diff --git a/deepxde/nn/jax/snn.py b/deepxde/nn/jax/snn.py
new file mode 100644
index 000000000..3ec67a159
--- /dev/null
+++ b/deepxde/nn/jax/snn.py
@@ -0,0 +1,216 @@
+from typing import Any, Callable
+
+import jax.numpy as jnp
+from flax import linen as nn
+
+from .nn import NN
+from .. import initializers
+
+
+class SPINN(NN):
+    """Separable physics-informed neural network.
+
+    This code is a direct adaptation of the original SPINN paper:
+    - paper: https://arxiv.org/abs/2306.15969
+    - code: https://github.com/stnamjef/SPINN
+    """
+
+    layer_sizes: Any
+    activation: Any
+    kernel_initializer: Any
+    mlp: str = "mlp"
+    pos_enc: int = 0
+
+    params: Any = None
+    _input_transform: Callable = None
+    _output_transform: Callable = None
+
+    def setup(self):
+        self.in_dim = self.layer_sizes[0]  # input dimension
+        self.r = self.layer_sizes[-2]  # rank of the approximated tensor
+        self.out_dim = self.layer_sizes[-1]  # output dimension
+        self.init = initializers.get(self.kernel_initializer)
+        self.features = self.layer_sizes[1:-2]
+
+    @nn.compact
+    def __call__(self, inputs, training=False):
+        if self._input_transform is not None:
+            inputs = self._input_transform(inputs)
+
+        # Split the input into one column per coordinate axis.
+        list_inputs = []
+        for i in range(self.in_dim):
+            if inputs.ndim == 1:
+                list_inputs.append(inputs[i : i + 1])
+            else:
+                list_inputs.append(inputs[:, i : i + 1])
+
+        if self.in_dim == 1:
+            raise ValueError("Input dimension must be greater than 1")
+        elif self.in_dim == 2:
+            outputs = self.SPINN2d(list_inputs)
+        elif self.in_dim == 3:
+            outputs = self.SPINN3d(list_inputs)
+        elif self.in_dim == 4:
+            outputs = self.SPINN4d(list_inputs)
+        else:
+            outputs = self.SPINNnd(list_inputs)
+
+        if self._output_transform is not None:
+            outputs = self._output_transform(inputs, outputs)
+
+        return outputs
+
+    def SPINN2d(self, inputs):
+        # inputs = [x, y]
+        flat_inputs = inputs[0].ndim == 1
+        if flat_inputs:
+            inputs = [inputs_elem.reshape(-1, 1) for inputs_elem in inputs]
+        outputs, pred = [], []
+        if self.mlp == "mlp":
+            for X in inputs:
+                for fs in self.features:
+                    X = nn.Dense(fs, kernel_init=self.init)(X)
+                    X = nn.activation.tanh(X)
+                X = nn.Dense(self.r * self.out_dim, kernel_init=self.init)(X)
+                outputs += [X]
+        else:
+            # modified MLP with gated updates
+            for X in inputs:
+                U = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=self.init)(X))
+                V = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=self.init)(X))
+                H = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=self.init)(X))
+                for fs in self.features:
+                    Z = nn.Dense(fs, kernel_init=self.init)(H)
+                    Z = nn.activation.tanh(Z)
+                    H = (jnp.ones_like(Z) - Z) * U + Z * V
+                H = nn.Dense(self.r * self.out_dim, kernel_init=self.init)(H)
+                outputs += [H]
+
+        for i in range(self.out_dim):
+            pred += [
+                jnp.dot(
+                    outputs[0][:, self.r * i : self.r * (i + 1)],
+                    outputs[-1][:, self.r * i : self.r * (i + 1)].T,
+                ).reshape(-1)
+            ]
+
+        if len(pred) == 1:
+            # 1-dimensional output
+            return pred[0].reshape(-1) if flat_inputs else pred[0].reshape(-1, 1)
+        # n-dimensional output
+        return (
+            jnp.stack(pred, axis=-1).reshape(-1)
+            if flat_inputs
+            else jnp.stack(pred, axis=-1)
+        )
+
+    def SPINN3d(self, inputs):
+        """
+        inputs: input factorized coordinates
+        outputs: feature output of each body network
+        xy: intermediate tensor for feature merge btw. x and y axis
+        pred: final model prediction (e.g. for 2d output, pred=[u, v])
+        """
+        [x, y, z] = inputs
+        if self.pos_enc != 0:
+            # positional encoding only to spatial coordinates
+            freq = jnp.expand_dims(jnp.arange(1, self.pos_enc + 1, 1), 0)
+            y = jnp.concatenate(
+                (jnp.ones((y.shape[0], 1)), jnp.sin(y @ freq), jnp.cos(y @ freq)), 1
+            )
+            z = jnp.concatenate(
+                (jnp.ones((z.shape[0], 1)), jnp.sin(z @ freq), jnp.cos(z @ freq)), 1
+            )
+
+        # causal PINN version (also on time axis)
+        # freq_x = jnp.expand_dims(jnp.power(10.0, jnp.arange(0, 3)), 0)
+        # x = x@freq_x
+
+        inputs, outputs, xy, pred = [x, y, z], [], [], []
+
+        if self.mlp == "mlp":
+            for X in inputs:
+                for fs in self.features:
+                    X = nn.Dense(fs, kernel_init=self.init)(X)
+                    X = nn.activation.tanh(X)
+                X = nn.Dense(self.r * self.out_dim, kernel_init=self.init)(X)
+                outputs += [jnp.transpose(X, (1, 0))]
+        elif self.mlp == "modified_mlp":
+            for X in inputs:
+                U = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=self.init)(X))
+                V = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=self.init)(X))
+                H = nn.activation.tanh(nn.Dense(self.features[0], kernel_init=self.init)(X))
+                for fs in self.features:
+                    Z = nn.Dense(fs, kernel_init=self.init)(H)
+                    Z = nn.activation.tanh(Z)
+                    H = (jnp.ones_like(Z) - Z) * U + Z * V
+                H = nn.Dense(self.r * self.out_dim, kernel_init=self.init)(H)
+                outputs += [jnp.transpose(H, (1, 0))]
+
+        for i in range(self.out_dim):
+            # Merge x- and y-axis features, then contract with the z-axis
+            # features; the rank index f is summed out in the last step.
+            xy += [
+                jnp.einsum(
+                    "fx, fy->fxy",
+                    outputs[0][self.r * i : self.r * (i + 1)],
+                    outputs[1][self.r * i : self.r * (i + 1)],
+                )
+            ]
+            pred += [
+                jnp.einsum(
+                    "fxy, fz->xyz", xy[i], outputs[-1][self.r * i : self.r * (i + 1)]
+                ).ravel()
+            ]
+
+        if len(pred) == 1:
+            # 1-dimensional output
+            return pred[0]
+        # n-dimensional output
+        return jnp.stack(pred, axis=1)
+
+    def SPINN4d(self, inputs):
+        # inputs = [t, x, y, z]
+        outputs, tx, txy, pred = [], [], [], []
+        for X in inputs:
+            for fs in self.features:
+                X = nn.Dense(fs, kernel_init=self.init)(X)
+                X = nn.activation.tanh(X)
+            X = nn.Dense(self.r * self.out_dim, kernel_init=self.init)(X)
+            outputs += [jnp.transpose(X, (1, 0))]
+
+        for i in range(self.out_dim):
+            tx += [
+                jnp.einsum(
+                    "ft, fx->ftx",
+                    outputs[0][self.r * i : self.r * (i + 1)],
+                    outputs[1][self.r * i : self.r * (i + 1)],
+                )
+            ]
+            txy += [
+                jnp.einsum(
+                    "ftx, fy->ftxy", tx[i], outputs[2][self.r * i : self.r * (i + 1)]
+                )
+            ]
+            pred += [
+                jnp.einsum(
+                    "ftxy, fz->txyz", txy[i], outputs[3][self.r * i : self.r * (i + 1)]
+                ).ravel()
+            ]
+
+        if len(pred) == 1:
+            # 1-dimensional output
+            return pred[0]
+        # n-dimensional output
+        return jnp.stack(pred, axis=1)
+
+    def SPINNnd(self, inputs):
+        # inputs = [t, *x]; supports a single output (out_dim == 1).
+        dim = len(inputs)
+        outputs = []
+        for X in inputs:
+            for fs in self.features[:-1]:
+                X = nn.Dense(fs, kernel_init=self.init)(X)
+                X = nn.activation.tanh(X)
+            X = nn.Dense(self.r, kernel_init=self.init)(X)
+            outputs += [jnp.transpose(X, (1, 0))]
+
+        # Contract the per-axis feature matrices one axis at a time:
+        # einsum("za, zb->zab"), then einsum("zab, zc->zabc"), and so on.
+        # The rank index "z" is summed out on the final contraction only.
+        a = "za"
+        b = "zb"
+        c = "zab"
+        pred = jnp.einsum(f"{a}, {b}->{c}", outputs[0], outputs[1])
+        for i in range(dim - 2):
+            a = c
+            b = f"z{chr(97 + i + 2)}"
+            c = c + chr(97 + i + 2)
+            if i == dim - 3:
+                # Drop the rank index on the last contraction.
+                c = c[1:]
+            pred = jnp.einsum(f"{a}, {b}->{c}", pred, outputs[i + 2])
+
+        return pred.ravel()
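The core separable trick in `SPINN2d` is that evaluating on the full grid never requires pointwise network calls: with per-axis features `F = f(x)` of shape `(nx, r)` and `G = g(y)` of shape `(ny, r)`, `jnp.dot(F, G.T)` is exactly the rank-`r` sum of outer products over the grid. A self-contained check:

```python
import numpy as np
import jax.numpy as jnp

nx, ny, r = 5, 7, 3
F = jnp.asarray(np.random.randn(nx, r))  # per-axis features f(x)
G = jnp.asarray(np.random.randn(ny, r))  # per-axis features g(y)

u = jnp.dot(F, G.T)  # (nx, ny): u[i, j] = sum_k F[i, k] * G[j, k]
u_ref = sum(jnp.outer(F[:, k], G[:, k]) for k in range(r))
assert jnp.allclose(u, u_ref)
```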
diff --git a/deepxde/utils/internal.py b/deepxde/utils/internal.py
index c2b162f11..1a117da7b 100644
--- a/deepxde/utils/internal.py
+++ b/deepxde/utils/internal.py
@@ -231,3 +231,15 @@ def mpi_scatter_from_rank0(array, drop_last=True):
     ]  # We truncate array size to be a multiple of num_split to prevent a MPI error.
     config.comm.Scatter(array, array_split, root=0)
     return array_split
+
+
+def list_handler(func):
+    """If the last positional argument is a list/tuple, apply ``func`` to each
+    element and concatenate the results; otherwise call ``func`` directly."""
+
+    @wraps(func)  # assumes ``from functools import wraps`` at module top
+    def wrapper(*args, **kwargs):
+        inputs = args[-1]
+        if isinstance(inputs, (list, tuple)):
+            results = [func(*args[:-1], input_item, **kwargs) for input_item in inputs]
+            return bkd.concat(results, axis=0)  # bkd: the deepxde backend module
+        return func(*args, **kwargs)
+
+    return wrapper
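And a quick sketch of `list_handler` in action, shown with numpy inputs (the concatenation goes through the active backend). This dispatch is what lets the wrapped reference solution in `pde.py` accept either a plain array or SPINN's list of factorized blocks:

```python
import numpy as np
from deepxde.utils import list_handler  # as added above

@list_handler
def square(x):
    return x**2

x1, x2 = np.ones((2, 1)), 2 * np.ones((3, 1))
print(square(x1).shape)        # (2, 1): single-array input passes through
print(square([x1, x2]).shape)  # (5, 1): per-block results are concatenated
```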