diff --git a/deepxde/experimental/__init__.py b/deepxde/experimental/__init__.py
new file mode 100644
index 000000000..388086d03
--- /dev/null
+++ b/deepxde/experimental/__init__.py
@@ -0,0 +1,21 @@
+__all__ = [
+    "callbacks",
+    "geometry",
+    "grad",
+    "icbc",
+    "metrics",
+    "nn",
+    "problem",
+    "utils",
+    "Trainer",
+]
+
+from . import callbacks
+from . import geometry
+from . import grad
+from . import icbc
+from . import metrics
+from . import nn
+from . import problem
+from . import utils
+from ._trainer import Trainer
diff --git a/deepxde/experimental/_trainer.py b/deepxde/experimental/_trainer.py
new file mode 100644
index 000000000..69aaeb1ee
--- /dev/null
+++ b/deepxde/experimental/_trainer.py
@@ -0,0 +1,534 @@
+import time
+from typing import Union, Sequence, Callable, Optional
+
+import brainstate as bst
+import brainunit as u
+import jax.numpy as jnp
+import jax.tree
+import numpy as np
+
+from deepxde.model import LossHistory, TrainState as TrainStateBase
+from deepxde.utils.internal import timing
+from . import metrics as metrics_module
+from .callbacks import CallbackList, Callback
+from .problem.base import Problem
+from .utils.display import training_display
+from .utils.external import saveplot
+
+__all__ = [
+    "Trainer",
+    "TrainState",
+    "LossHistory",
+]
+
+
+class Trainer:
+    """
+    A ``Trainer`` trains a neural network on a ``Problem``.
+
+    Args:
+        problem: ``experimental.problem.Problem`` instance.
+        external_trainable_variables: A trainable ``brainstate.ParamState`` object or a list
+            of trainable ``brainstate.ParamState`` objects. The unknown parameters in the
+            physics systems that need to be recovered.
+    """
+
+    __module__ = "deepxde.experimental"
+    optimizer: bst.optim.Optimizer  # optimizer
+    problem: Problem  # problem
+    params: bst.util.FlattedDict  # trainable variables
+
+    def __init__(
+        self,
+        problem: Problem,
+        external_trainable_variables: Union[
+            bst.ParamState, Sequence[bst.ParamState]
+        ] = None,
+        batch_size: Optional[int] = None,
+    ):
+        """
+        Initialize the Trainer.
+
+        Args:
+            problem (Problem): The problem instance to be solved.
+            external_trainable_variables (Union[bst.ParamState, Sequence[bst.ParamState]], optional):
+                External trainable variables to be included in the optimization process.
+                Can be a single ParamState or a sequence of ParamStates. Defaults to None.
+            batch_size (Optional[int], optional): The batch size to be used during training.
+                If None, the entire dataset will be used. Defaults to None.
+
+        Raises:
+            ValueError: If the problem does not define an approximator.
+            AssertionError: If the problem is not a Problem instance or if external_trainable_variables
+                are not ParamState instances.
+
+        Returns:
+            None
+        """
+        # the problem
+        self.problem = problem
+        assert isinstance(self.problem, Problem), "problem must be a Problem instance."
+
+        # the approximator
+        if self.problem.approximator is None:
+            raise ValueError("Problem must define an approximator before training.")
+
+        # parameters and external trainable variables
+        params = bst.graph.states(self.problem.approximator, bst.ParamState)
+        if external_trainable_variables is None:
+            external_trainable_variables = []
+        else:
+            if not isinstance(external_trainable_variables, list):
+                external_trainable_variables = [external_trainable_variables]
+        for i, var in enumerate(external_trainable_variables):
+            assert isinstance(var, bst.ParamState), (
+                "external_trainable_variables must be a list of ParamState instances."
+            )
+            params[("external_trainable_variable", i)] = var
+        self.params = params
+
+        # other useful parameters
+        self.metrics = None
+        self.batch_size = batch_size
+
+        # training state
+        self.train_state = TrainState()
+        self.loss_history = LossHistory()
+        self.stop_training = False
+
+    @timing
+    def compile(
+        self,
+        optimizer: bst.optim.Optimizer,
+        metrics: Union[str, Sequence[str]] = None,
+        measure_train_step_compile_time: bool = False,
+    ):
+        """
+        Configures the trainer for training.
+
+        Args:
+            optimizer: A ``brainstate.optim.Optimizer`` instance.
+            metrics: List of metrics to be evaluated by the trainer during training.
+            measure_train_step_compile_time: If ``True``, measure the time spent
+                compiling one training step and return it together with the trainer.
+        """
+        print("Compiling trainer...")
+
+        # optimizer
+        assert isinstance(
+            optimizer, bst.optim.Optimizer
+        ), "optimizer must be an Optimizer instance."
+        self.optimizer = optimizer
+        self.optimizer.register_trainable_weights(self.params)
+
+        # metrics may use trainer variables such as self.net,
+        # and thus are instantiated after compile.
+        metrics = metrics or []
+        self.metrics = [metrics_module.get(m) for m in metrics]
+
+        def fn_outputs(training: bool, inputs):
+            with bst.environ.context(fit=training):
+                inputs = jax.tree.map(
+                    lambda x: u.math.asarray(x), inputs, is_leaf=u.math.is_quantity
+                )
+                return self.problem.approximator(inputs)
+
+        def fn_outputs_losses(training, inputs, targets, **kwargs):
+            with bst.environ.context(fit=training):
+                # inputs
+                inputs = jax.tree.map(
+                    lambda x: u.math.asarray(x), inputs, is_leaf=u.math.is_quantity
+                )
+
+                # outputs
+                outputs = self.problem.approximator(inputs)
+
+                # targets
+                if targets is not None:
+                    targets = jax.tree.map(
+                        lambda x: u.math.asarray(x), targets, is_leaf=u.math.is_quantity
+                    )
+
+                # compute losses
+                if training:
+                    losses = self.problem.losses_train(
+                        inputs, outputs, targets, **kwargs
+                    )
+                else:
+                    losses = self.problem.losses_test(
+                        inputs, outputs, targets, **kwargs
+                    )
+                return outputs, losses
+
+        def fn_outputs_losses_train(inputs, targets, **aux):
+            return fn_outputs_losses(True, inputs, targets, **aux)
+
+        def fn_outputs_losses_test(inputs, targets, **aux):
+            return fn_outputs_losses(False, inputs, targets, **aux)
+
+        def fn_train_step(inputs, targets, **aux):
+            def _loss_fun():
+                losses = fn_outputs_losses_train(inputs, targets, **aux)[1]
+                return u.math.sum(
+                    u.math.asarray([loss.sum() for loss in jax.tree.leaves(losses)])
+                )
+
+            grads = bst.augment.grad(_loss_fun, grad_states=self.params)()
+            self.optimizer.update(grads)
+
+        # Callables
+        self.fn_outputs = bst.compile.jit(fn_outputs, static_argnums=0)
+        self.fn_outputs_losses_train = bst.compile.jit(fn_outputs_losses_train)
+        self.fn_outputs_losses_test = bst.compile.jit(fn_outputs_losses_test)
+        self.fn_train_step = bst.compile.jit(fn_train_step)
+
+        if measure_train_step_compile_time:
+            t0 = time.time()
+            self._compile_training_step(self.batch_size)
+            t1 = time.time()
+            return self, t1 - t0
+
+        return self
+
+    @timing
+    def train(
+        self,
+        iterations: int,
+        batch_size: int = None,
+        display_every: int = 1000,
+        disregard_previous_best: bool = False,
+        callbacks: Union[Callback, Sequence[Callback]] = None,
+        model_restore_path: str = None,
+        model_save_path: str = None,
+        measure_train_step_time: bool = False,
+    ):
+        """
+        Trains the neural network on the problem.
+
+        Args:
+            iterations (Integer): Number of iterations to train the trainer, i.e., number
+                of times the network weights are updated.
+            batch_size: Integer, tuple, or ``None``.
+
+                - If you solve PDEs via ``experimental.problem.PDE`` or ``experimental.problem.TimePDE``,
+                  do not use `batch_size`; instead use ``experimental.callbacks.PDEPointResampler``
+                  to resample the training points.
+                - For DeepONet in the format of Cartesian product, if `batch_size` is an Integer,
+                  then it is the batch size for the branch input;
+                  if you want to also use mini-batch for the trunk net input,
+                  set `batch_size` as a tuple, where the first number is the batch size for the branch net input
+                  and the second number is the batch size for the trunk net input.
+            display_every (Integer): Print the loss and metrics every ``display_every`` steps.
+            disregard_previous_best: If ``True``, disregard the previous saved best
+                trainer.
+            callbacks: List of ``experimental.callbacks.Callback`` instances. List of callbacks
+                to apply during training.
+            model_restore_path (String): Path where parameters were previously saved.
+            model_save_path (String): Prefix of filenames created for the checkpoint.
+            measure_train_step_time: If ``True``, measure the total training wall time
+                and return it together with the trainer.
+        """
+
+        if measure_train_step_time:
+            t0 = time.time()
+
+        if self.metrics is None:
+            raise ValueError("Compile the trainer before training.")
+
+        # callbacks
+        callbacks = CallbackList(
+            callbacks=[callbacks] if isinstance(callbacks, Callback) else callbacks
+        )
+        callbacks.set_model(self)
+
+        # disregard previous best
+        if disregard_previous_best:
+            self.train_state.disregard_best()
+
+        # restore
+        if model_restore_path is not None:
+            self.restore(model_restore_path, verbose=1)
+
+        print("Training trainer...\n")
+        self.stop_training = False
+
+        # testing
+        self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))
+        self.train_state.set_data_test(*self.problem.test())
+        self._test()
+
+        # training
+        callbacks.on_train_begin()
+        self._train(iterations, display_every, batch_size, callbacks)
+        callbacks.on_train_end()
+
+        # summary
+        print("")
+        training_display.summary(self.train_state)
+        if model_save_path is not None:
+            self.save(model_save_path, verbose=1)
+
+        if measure_train_step_time:
+            t1 = time.time()
+            return self, t1 - t0
+        return self
+
+    def _compile_training_step(self, batch_size=None):
+        # get data
+        self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))
+
+        # train one batch
+        self.fn_train_step.compile(
+            self.train_state.X_train,
+            self.train_state.y_train,
+            **self.train_state.Aux_train,
+        )
+
+    def _train(self, iterations, display_every, batch_size, callbacks):
+        for i in range(iterations):
+            callbacks.on_epoch_begin()
+            callbacks.on_batch_begin()
+
+            # get data
+            self.train_state.set_data_train(*self.problem.train_next_batch(batch_size))
+
+            # train one batch
+            self.fn_train_step(
+                self.train_state.X_train,
+                self.train_state.y_train,
+                **self.train_state.Aux_train,
+            )
+
+            self.train_state.epoch += 1
+            self.train_state.step += 1
+            if self.train_state.step % display_every == 0 or i + 1 == iterations:
+                self._test()
+
+            callbacks.on_batch_end()
+            callbacks.on_epoch_end()
+
+            if self.stop_training:
+                break
+
+    def _test(self):
+        # evaluate the training data
+        (
+            self.train_state.y_pred_train,
+            self.train_state.loss_train,
+        ) = self.fn_outputs_losses_train(
+            self.train_state.X_train,
+            self.train_state.y_train,
+            **self.train_state.Aux_train,
+        )
+
+        # evaluate the test data
+        (self.train_state.y_pred_test, self.train_state.loss_test) = (
+            self.fn_outputs_losses_test(
+                self.train_state.X_test,
+                self.train_state.y_test,
+                **self.train_state.Aux_test,
+            )
+        )
+
+        # metrics
+        if isinstance(self.train_state.y_test, (list, tuple)):
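+            # multiple test outputs: apply every metric to each output component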
+            self.train_state.metrics_test = [
+                m(self.train_state.y_test[i], self.train_state.y_pred_test[i])
+                for m in self.metrics
+                for i in range(len(self.train_state.y_test))
+            ]
+        else:
+            self.train_state.metrics_test = [
+                m(self.train_state.y_test, self.train_state.y_pred_test)
+                for m in self.metrics
+            ]
+
+        # history
+        self.train_state.update_best()
+        self.loss_history.append(
+            self.train_state.step,
+            self.train_state.loss_train,
+            self.train_state.loss_test,
+            self.train_state.metrics_test,
+        )
+
+        # check NaN
+        if (
+            jnp.isnan(jnp.asarray(jax.tree.leaves(self.train_state.loss_train))).any()
+            or jnp.isnan(jnp.asarray(jax.tree.leaves(self.train_state.loss_test))).any()
+        ):
+            self.stop_training = True
+
+        # display
+        training_display(self.train_state)
+
+    def predict(
+        self,
+        xs,
+        operator: Optional[Callable] = None,
+        callbacks: Union[Callback, Sequence[Callback]] = None,
+    ):
+        """Generates predictions for the input samples. If `operator` is ``None``,
+        returns the network output, otherwise returns the output of the `operator`.
+
+        Args:
+            xs: The network inputs. A Numpy array or a tuple of Numpy arrays.
+            operator: A function that takes arguments (`inputs`, `outputs`) and returns a
+                tensor, where `inputs` and `outputs` are the network input and output
+                tensors, respectively. `operator` is typically chosen as the PDE (used to
+                define `experimental.problem.PDE`) to predict the PDE residual.
+            callbacks: List of ``experimental.callbacks.Callback`` instances. List of callbacks
+                to apply during prediction.
+        """
+        xs = jax.tree.map(
+            lambda x: u.math.asarray(x, dtype=bst.environ.dftype()),
+            xs,
+            is_leaf=u.math.is_quantity,
+        )
+        callbacks = CallbackList(
+            callbacks=[callbacks] if isinstance(callbacks, Callback) else callbacks
+        )
+        callbacks.set_model(self)
+        callbacks.on_predict_begin()
+        ys = self.fn_outputs(False, xs)
+        if operator is not None:
+            ys = operator(xs, ys)
+        callbacks.on_predict_end()
+        return ys
+
+    def save(self, save_path, verbose: int = 0):
+        """Saves all variables to a disk file.
+
+        Args:
+            save_path (string): Prefix of filenames to save the trainer file.
+            verbose (int): Verbosity mode, 0 or 1.
+
+        Returns:
+            string: Path where trainer is saved.
+        """
+        import braintools
+
+        # save path
+        save_path = f"{save_path}-{self.train_state.epoch}.msgpack"
+
+        # avoid the duplicate ParamState save
+        model = bst.graph.Dict(params=self.params, optimizer=self.optimizer)
+
+        checkpoint = bst.graph.states(model).to_nest()
+        braintools.file.msgpack_save(save_path, checkpoint)
+
+        if verbose > 0:
+            print(
+                "Epoch {}: saving trainer to {} ...\n".format(
+                    self.train_state.epoch, save_path
+                )
+            )
+        return save_path
+
+    def restore(self, save_path, verbose: int = 0):
+        """Restore all variables from a disk file.
+
+        Args:
+            save_path (string): Path where trainer was previously saved.
+            verbose (int): Verbosity mode, 0 or 1.
+        """
+        import braintools
+
+        if verbose > 0:
+            print("Restoring trainer from {} ...\n".format(save_path))
+
+        data = bst.graph.Dict(params=self.params, optimizer=self.optimizer)
+
+        checkpoint = bst.graph.states(data).to_nest()
+        braintools.file.msgpack_load(save_path, target=checkpoint)
+
+    def saveplot(
+        self,
+        issave: bool = True,
+        isplot: bool = True,
+        loss_fname: str = "loss.dat",
+        train_fname: str = "train.dat",
+        test_fname: str = "test.dat",
+        output_dir: str = None,
+    ):
+        """
+        Saves and plots the loss and metrics.
+
+        Args:
+            issave: If ``True``, save the loss and metrics to files.
+            isplot: If ``True``, plot the loss and metrics.
+ loss_fname: Filename to save the loss. + train_fname: Filename to save the training metrics. + test_fname: Filename to save the test metrics. + output_dir: Directory to save the files. + """ + saveplot( + self.loss_history, + self.train_state, + issave=issave, + isplot=isplot, + loss_fname=loss_fname, + train_fname=train_fname, + test_fname=test_fname, + output_dir=output_dir, + ) + + +class TrainState(TrainStateBase): + __module__ = "deepxde.experimental" + + def __init__(self): + self.epoch = 0 + self.step = 0 + + # Current data + self.X_train = None + self.y_train = None + self.Aux_train = dict() + self.X_test = None + self.y_test = None + self.Aux_test = dict() + + # Results of current step + # Train results + self.loss_train = None + self.y_pred_train = None + # Test results + self.loss_test = None + self.y_pred_test = None + self.y_std_test = None + self.metrics_test = None + + # The best results correspond to the min train loss + self.best_step = 0 + self.best_loss_train = np.inf + self.best_loss_test = np.inf + self.best_y = None + self.best_ystd = None + self.best_metrics = None + + def set_data_train(self, X_train, y_train, *args): + self.X_train = X_train + self.y_train = y_train + if len(args) > 0: + assert len(args) == 1, "Auxiliary training data must be a single argument." + assert isinstance( + args[0], dict + ), "Auxiliary training data must be a dictionary." + self.Aux_train = args[0] + + def set_data_test(self, X_test, y_test, *args): + self.X_test = X_test + self.y_test = y_test + if len(args) > 0: + assert len(args) == 1, "Auxiliary test data must be a single argument." + assert isinstance( + args[0], dict + ), "Auxiliary test data must be a dictionary." + self.Aux_test = args[0] + + def update_best(self): + current_loss_train = jnp.sum(jnp.asarray(jax.tree.leaves(self.loss_train))) + if self.best_loss_train > current_loss_train: + self.best_step = self.step + self.best_loss_train = current_loss_train + self.best_loss_test = jnp.sum(jnp.asarray(jax.tree.leaves(self.loss_test))) + self.best_y = self.y_pred_test + self.best_ystd = self.y_std_test + self.best_metrics = self.metrics_test diff --git a/deepxde/experimental/callbacks.py b/deepxde/experimental/callbacks.py new file mode 100644 index 000000000..d380ee4d5 --- /dev/null +++ b/deepxde/experimental/callbacks.py @@ -0,0 +1,171 @@ +import sys + +import brainstate as bst +import brainunit as u +import jax.tree +import numpy as np + +from deepxde.callbacks import ( + Callback, + CallbackList, + ModelCheckpoint, + Timer, + MovieDumper, + PDEPointResampler, + EarlyStopping as EarlyStoppingCallback, + DropoutUncertainty as DropoutUncertaintyCallback, + OperatorPredictor as OperatorPredictorCallback, +) +from deepxde.utils.internal import list_to_str + +__all__ = [ + "Callback", + "CallbackList", + "ModelCheckpoint", + "EarlyStopping", + "Timer", + "DropoutUncertainty", + "VariableValue", + "OperatorPredictor", + "MovieDumper", + "PDEPointResampler", +] + + +class EarlyStopping(EarlyStoppingCallback): + """Stop training when a monitored quantity (training or testing loss) has stopped improving. + Only checked at validation step according to ``display_every`` in ``Trainer.train``. + + Args: + min_delta: Minimum change in the monitored quantity + to qualify as an improvement, i.e. an absolute + change of less than min_delta, will count as no + improvement. + patience: Number of epochs with no improvement + after which training will be stopped. + baseline: Baseline value for the monitored quantity to reach. 
+            Training will stop if the trainer doesn't show improvement
+            over the baseline.
+        monitor: The loss function that is monitored. Either 'loss_train' or 'loss_test'.
+        start_from_epoch: Number of epochs to wait before starting
+            to monitor improvement. This allows for a warm-up period in which
+            no improvement is expected and thus training will not be stopped.
+    """
+
+    def get_monitor_value(self):
+        if self.monitor == "loss_train":
+            result = np.sum(jax.tree.leaves(self.model.train_state.loss_train))
+        elif self.monitor == "loss_test":
+            result = np.sum(jax.tree.leaves(self.model.train_state.loss_test))
+        else:
+            raise ValueError("The specified monitor function is incorrect.")
+
+        return result
+
+
+class DropoutUncertainty(DropoutUncertaintyCallback):
+    """Uncertainty estimation via MC dropout.
+
+    References:
+        Y. Gal & Z. Ghahramani. Dropout as a Bayesian approximation: Representing
+        model uncertainty in deep learning. International Conference on Machine
+        Learning, 2016.
+
+    Warning:
+        This cannot be used together with other techniques that have different behaviors
+        during training and testing, such as batch normalization.
+    """
+
+    def on_epoch_end(self):
+        self.epochs_since_last += 1
+        if self.epochs_since_last >= self.period:
+            self.epochs_since_last = 0
+            y_preds = []
+            for _ in range(1000):
+                y_pred_test_one = self.model.fn_outputs(
+                    True, self.model.train_state.X_test
+                )
+                y_preds.append(y_pred_test_one)
+            y_preds = jax.tree.map(
+                lambda *x: u.math.stack(x, axis=0), *y_preds, is_leaf=u.math.is_quantity
+            )
+            self.model.train_state.y_std_test = jax.tree.map(
+                lambda x: u.math.std(x, axis=0), y_preds, is_leaf=u.math.is_quantity
+            )
+
+
+class VariableValue(Callback):
+    """Get the variable values.
+
+    Args:
+        var_list: A ``brainstate.State`` object or a list of ``brainstate.State`` objects.
+        period (int): Interval (number of epochs) between checking values.
+        filename (string): Output the values to the file `filename`.
+            The file is kept open to allow instances to be re-used.
+            If ``None``, output to the screen.
+        precision (int): The precision of variables to display.
+    """
+
+    def __init__(self, var_list, period=1, filename=None, precision=2):
+        super().__init__()
+        self.var_list = var_list if isinstance(var_list, (tuple, list)) else [var_list]
+        for v in self.var_list:
+            if not isinstance(v, bst.State):
+                raise ValueError("The variable must be a brainstate.State object.")
+
+        self.period = period
+        self.precision = precision
+
+        self.file = sys.stdout if filename is None else open(filename, "w", buffering=1)
+        self.value = None
+        self.epochs_since_last = 0
+
+    def on_train_begin(self):
+        self.value = [var.value for var in self.var_list]
+
+        print(
+            self.model.train_state.epoch,
+            list_to_str(self.value, precision=self.precision),
+            file=self.file,
+        )
+        self.file.flush()
+
+    def on_epoch_end(self):
+        self.epochs_since_last += 1
+        if self.epochs_since_last >= self.period:
+            self.epochs_since_last = 0
+            self.on_train_begin()
+
+    def on_train_end(self):
+        if self.epochs_since_last != 0:
+            self.on_train_begin()
+
+    def get_value(self):
+        """Return the variable values."""
+        return self.value
+
+
+class OperatorPredictor(OperatorPredictorCallback):
+    """
+    Generates operator values for the input samples.
+
+    Args:
+        x: The input data.
+        op: The operator with inputs (x, y).
+        period (int): Interval (number of epochs) between checking values.
+        filename (string): Output the values to the file `filename`.
+            The file is kept open to allow instances to be re-used.
+ If ``None``, output to the screen. + precision (int): The precision of variables to display. + """ + + def on_predict_end(self): + self.value = self._eval() + # self.value = jax.tree.map(np.asarray, self._eval()) + + @bst.compile.jit(static_argnums=0) + def _eval(self): + with bst.environ.context(fit=False): + outputs = self.model.problem.approximator(self.x) + return self.op(self.x, outputs) diff --git a/deepxde/experimental/geometry/__init__.py b/deepxde/experimental/geometry/__init__.py new file mode 100644 index 000000000..99b6ce842 --- /dev/null +++ b/deepxde/experimental/geometry/__init__.py @@ -0,0 +1,25 @@ +__all__ = [ + "DictPointGeometry", + "Cuboid", + "Disk", + "Ellipse", + "GeometryXTime", + "Hypercube", + "Hypersphere", + "Interval", + "PointCloud", + "Polygon", + "Rectangle", + "Sphere", + "StarShaped", + "TimeDomain", + "Triangle", +] + +from .base import DictPointGeometry +from .geometry_1d import Interval +from .geometry_2d import Disk, Ellipse, Polygon, Rectangle, StarShaped, Triangle +from .geometry_3d import Cuboid, Sphere +from .geometry_nd import Hypercube, Hypersphere +from .pointcloud import PointCloud +from .timedomain import TimeDomain, GeometryXTime diff --git a/deepxde/experimental/geometry/base.py b/deepxde/experimental/geometry/base.py new file mode 100644 index 000000000..a4773c5a6 --- /dev/null +++ b/deepxde/experimental/geometry/base.py @@ -0,0 +1,491 @@ +from typing import Dict, Union + +import brainstate as bst +import brainunit as u +import jax.numpy as jnp +import numpy as np + +from deepxde.geometry.geometry import Geometry +from deepxde.experimental import utils + +__all__ = [ + "GeometryExperimental", + "DictPointGeometry", +] + + +class GeometryExperimental(Geometry): + """ + A base class for geometries in the PINNx (Physics-Informed Neural Networks Extended) framework. + + This class extends the functionality of the base Geometry class to provide additional + features specific to the PINNx framework. It serves as a foundation for creating + more specialized geometry classes that can work with dictionary-based point representations + and unit-aware computations. + + Attributes: + Inherits all attributes from the Geometry base class. + + Methods: + to_dict_point(*names, **kw_names): + Converts the geometry to a dictionary-based point representation. + + Note: + This class is designed to be subclassed for specific geometry implementations + in the PINNx framework. It provides a bridge between the standard Geometry + representations and the more flexible, unit-aware representations used in PINNx. + + Example: + class CustomGeometry(GeometryExperimental): + def __init__(self, dim, bbox, diam): + super().__init__(dim, bbox, diam) + # Additional initialization specific to CustomGeometry + + # Implement other required methods + + # Usage + custom_geom = CustomGeometry(dim=2, bbox=[0, 1, 0, 1], diam=1.414) + dict_geom = custom_geom.to_dict_point('x', 'y', z=u.meter) + """ + + def to_dict_point(self, *names, **kw_names): + """ + Convert the geometry to a dictionary geometry. + + This method creates a DictPointGeometry object, which represents the geometry + using named coordinates and their associated units. + + Args: + *names (str): Variable length argument list of coordinate names. + These are assumed to be unitless. + **kw_names (dict): Arbitrary keyword arguments where keys are coordinate names + and values are their corresponding units. 
+ + Returns: + DictPointGeometry: A new geometry object that represents the current geometry + using a dictionary-based structure with named coordinates + and units. + + Raises: + ValueError: If the number of provided names doesn't match the dimension of the geometry. + + Note: + If a coordinate is specified in both *names and **kw_names, the unit from **kw_names will be used. + """ + return DictPointGeometry(self, *names, **kw_names) + + +def quantity_to_array( + quantity: Union[np.ndarray, jnp.ndarray, u.Quantity], unit: u.Unit +): + """ + Convert a quantity to an array with specified units. + + This function takes a quantity (which can be a numpy array, JAX array, or a Quantity object) + and converts it to an array with the specified units. If the input is already a Quantity, + it is converted to the specified unit and its magnitude is returned. If the input is an array, + it is returned as-is, but only if the specified unit is unitless. + + Parameters: + ----------- + quantity : Union[np.ndarray, jnp.ndarray, u.Quantity] + The input quantity to be converted. Can be a numpy array, JAX array, or a Quantity object. + unit : u.Unit + The target unit for conversion. If the input is not a Quantity, this must be unitless. + + Returns: + -------- + Union[np.ndarray, jnp.ndarray] + The magnitude of the quantity in the specified units, returned as an array. + + Raises: + ------- + AssertionError + If the input is not a Quantity and the specified unit is not unitless. + """ + if isinstance(quantity, u.Quantity): + return quantity.to(unit).magnitude + else: + assert unit.is_unitless, "The unit should be unitless." + return quantity + + +def array_to_quantity(array: Union[np.ndarray, jnp.ndarray], unit: u.Unit): + """ + Convert an array to a Quantity object with specified units. + + This function takes an array (either numpy or JAX) and a unit, and returns + a Quantity object representing the array with the given unit. + + Parameters: + ----------- + array : Union[np.ndarray, jnp.ndarray] + The input array to be converted to a Quantity. Can be either a numpy array + or a JAX array. + unit : u.Unit + The unit to be associated with the array values. + + Returns: + -------- + u.Quantity + A Quantity object representing the input array with the specified unit. + The returned object may be a decimal representation if appropriate. + """ + return u.math.maybe_decimal(u.Quantity(array, unit=unit)) + + +class DictPointGeometry(GeometryExperimental): + """ + A class that converts a standard Geometry object to a dictionary-based geometry representation. + + This class extends GeometryExperimental to provide a more flexible, named coordinate system + with unit awareness. It wraps an existing Geometry object and allows access to its + methods while providing additional functionality for working with named coordinates. + + Attributes: + geom (Geometry): The original geometry object being wrapped. + name2unit (dict): A dictionary mapping coordinate names to their corresponding units. + + Methods: + arr_to_dict(x): Convert an array to a dictionary of named quantities. + dict_to_arr(x): Convert a dictionary of named quantities to an array. + inside(x): Check if points are inside the geometry. + on_initial(x): Check if points are on the initial boundary. + on_boundary(x): Check if points are on the boundary of the geometry. + distance2boundary(x, dirn): Calculate the distance to the boundary in a specific direction. + mindist2boundary(x): Calculate the minimum distance to the boundary. 
+ boundary_constraint_factor(x, **kw): Calculate the boundary constraint factor. + boundary_normal(x): Calculate the boundary normal vectors. + uniform_points(n, boundary): Generate uniformly distributed points in the geometry. + random_points(n, random): Generate random points in the geometry. + uniform_boundary_points(n): Generate uniformly distributed points on the boundary. + random_boundary_points(n, random): Generate random points on the boundary. + periodic_point(x, component): Find the periodic point for a given point and component. + background_points(x, dirn, dist2npt, shift): Generate background points. + random_initial_points(n, random): Generate random initial points. + uniform_initial_points(n): Generate uniformly distributed initial points. + + The class provides a bridge between array-based and dictionary-based representations + of geometric points, allowing for more intuitive and unit-aware manipulations of + geometric data in the context of physics-informed neural networks. + + Example: + geom = Geometry(dim=2, bbox=[0, 1, 0, 1], diam=1.414) + dict_geom = DictPointGeometry(geom, 'x', y=u.meter) + points = dict_geom.uniform_points(100) + # points will be a dictionary with keys 'x' (unitless) and 'y' (in meters) + """ + + def __init__(self, geom: Geometry, *names, **kw_names): + """ + Initialize a DictPointGeometry object. + + Args: + geom (Geometry): The base geometry object to be converted. + *names (str): Variable length argument list of coordinate names (assumed to be unitless). + **kw_names (dict): Arbitrary keyword arguments where keys are coordinate names + and values are their corresponding units. + + Raises: + ValueError: If the number of provided names doesn't match the dimension of the geometry. + """ + super().__init__(geom.dim, geom.bbox, geom.diam) + + self.geom = geom + for name in names: + assert isinstance(name, str), "The name should be a string." + kw_names = { + key: u.UNITLESS if unit is None else unit for key, unit in kw_names.items() + } + for key, unit in kw_names.items(): + assert isinstance(key, str), "The name should be a string." + assert isinstance(unit, u.Unit), "The unit should be a brainunit.Unit." + self.name2unit = {name: u.UNITLESS for name in names} + self.name2unit.update(kw_names) + if len(self.name2unit) != geom.dim: + raise ValueError( + "The number of names should match the dimension of the geometry. " + "But got {} names and {} dimensions.".format( + len(self.name2unit), geom.dim + ) + ) + + def arr_to_dict(self, x: bst.typing.ArrayLike) -> Dict[str, bst.typing.ArrayLike]: + """ + Convert an array to a dictionary of named quantities. + + Args: + x (ArrayLike): The input array to be converted. + + Returns: + Dict[str, ArrayLike]: A dictionary where keys are coordinate names and values are quantities. + """ + return { + name: array_to_quantity(x[..., i], unit) + for i, (name, unit) in enumerate(self.name2unit.items()) + } + + def dict_to_arr(self, x: Dict[str, bst.typing.ArrayLike]) -> bst.typing.ArrayLike: + """ + Convert a dictionary of named quantities to an array. + + Args: + x (Dict[str, ArrayLike]): The input dictionary to be converted. + + Returns: + ArrayLike: The resulting array. + """ + arrs = [ + quantity_to_array(x[name], unit) for name, unit in self.name2unit.items() + ] + mod = utils.smart_numpy(arrs[0]) + return mod.stack(arrs, axis=-1) + + def inside( + self, x: Union[np.ndarray, jnp.ndarray, u.Quantity, Dict] + ) -> np.ndarray[bool]: + """ + Check if points are inside the geometry. 
+ + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The points to check. + + Returns: + np.ndarray[bool]: Boolean array indicating whether each point is inside the geometry. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + return self.geom.inside(x) + + def on_initial( + self, x: Union[np.ndarray, jnp.ndarray, u.Quantity, Dict] + ) -> np.ndarray: + """ + Check if points are on the initial boundary. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The points to check. + + Returns: + np.ndarray: Array indicating whether each point is on the initial boundary. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + return self.geom.on_initial(x) + + def on_boundary( + self, x: Union[np.ndarray, jnp.ndarray, u.Quantity, Dict] + ) -> np.ndarray[bool]: + """ + Check if points are on the boundary of the geometry. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The points to check. + + Returns: + np.ndarray[bool]: Boolean array indicating whether each point is on the boundary. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + return self.geom.on_boundary(x) + + def distance2boundary( + self, x: Union[np.ndarray, jnp.ndarray, u.Quantity, Dict], dirn: int + ) -> np.ndarray: + """ + Calculate the distance to the boundary in a specific direction. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The points to calculate from. + dirn (int): The direction to calculate the distance. + + Returns: + np.ndarray: Array of distances to the boundary. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + return self.geom.distance2boundary(x, dirn) + + def mindist2boundary( + self, x: Union[np.ndarray, jnp.ndarray, u.Quantity, Dict] + ) -> np.ndarray: + """ + Calculate the minimum distance to the boundary. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The points to calculate from. + + Returns: + np.ndarray: Array of minimum distances to the boundary. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + return self.geom.mindist2boundary(x) + + def boundary_constraint_factor( + self, x: Union[np.ndarray, jnp.ndarray, u.Quantity, Dict], **kw + ) -> np.ndarray: + """ + Calculate the boundary constraint factor. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The points to calculate for. + **kw: Additional keyword arguments. + + Returns: + np.ndarray: Array of boundary constraint factors. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + return self.geom.boundary_constraint_factor(x, **kw) + + def boundary_normal( + self, x: Union[np.ndarray, jnp.ndarray, u.Quantity, Dict] + ) -> Dict[str, bst.typing.ArrayLike]: + """ + Calculate the boundary normal vectors. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The points to calculate for. + + Returns: + Dict[str, ArrayLike]: Dictionary of boundary normal vectors. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + normal = self.geom.boundary_normal(x) + return self.arr_to_dict(normal) + + def uniform_points( + self, n, boundary: bool = True + ) -> Dict[str, bst.typing.ArrayLike]: + """ + Generate uniformly distributed points in the geometry. + + Args: + n (int): Number of points to generate. + boundary (bool, optional): Whether to include boundary points. Defaults to True. + + Returns: + Dict[str, ArrayLike]: Dictionary of generated points. 
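+
+        Example (illustrative, following the class-level example above):
+            points = dict_geom.uniform_points(100)
+            # -> {'x': <unitless array>, 'y': <array in meters>}, one entry
+            # per coordinate name, with units attached where declared.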
+ """ + points = self.geom.uniform_points(n, boundary=boundary) + return self.arr_to_dict(points) + + def random_points(self, n, random="pseudo") -> Dict[str, bst.typing.ArrayLike]: + """ + Generate random points in the geometry. + + Args: + n (int): Number of points to generate. + random (str, optional): Type of random number generation. Defaults to "pseudo". + + Returns: + Dict[str, ArrayLike]: Dictionary of generated points. + """ + points = self.geom.random_points(n, random=random) + return self.arr_to_dict(points) + + def uniform_boundary_points(self, n) -> Dict[str, bst.typing.ArrayLike]: + """ + Generate uniformly distributed points on the boundary. + + Args: + n (int): Number of points to generate. + + Returns: + Dict[str, ArrayLike]: Dictionary of generated boundary points. + """ + points = self.geom.uniform_boundary_points(n) + return self.arr_to_dict(points) + + def random_boundary_points( + self, n, random: str = "pseudo" + ) -> Dict[str, bst.typing.ArrayLike]: + """ + Generate random points on the boundary. + + Args: + n (int): Number of points to generate. + random (str, optional): Type of random number generation. Defaults to "pseudo". + + Returns: + Dict[str, ArrayLike]: Dictionary of generated boundary points. + """ + points = self.geom.random_boundary_points(n, random=random) + return self.arr_to_dict(points) + + def periodic_point( + self, x, component: Union[str, int] + ) -> Dict[str, bst.typing.ArrayLike]: + """ + Find the periodic point for a given point and component. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The point to find the periodic point for. + component (Union[str, int]): The component to consider for periodicity. + + Returns: + Dict[str, ArrayLike]: Dictionary of the periodic point. + + Raises: + AssertionError: If the component is not an integer or a string. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + if isinstance(component, str): + component = list(self.name2unit.keys()).index(component) + assert isinstance( + component, int + ), f"The component should be an integer or a string. But got {component}." + x = self.geom.periodic_point(x, component) + return self.arr_to_dict(x) + + def background_points( + self, x, dirn, dist2npt, shift + ) -> Dict[str, bst.typing.ArrayLike]: + """ + Generate background points. + + Args: + x (Union[np.ndarray, jnp.ndarray, u.Quantity, Dict]): The reference points. + dirn: The direction for generating background points. + dist2npt: The distance to number of points mapping. + shift: The shift to apply. + + Returns: + Dict[str, ArrayLike]: Dictionary of generated background points. + """ + if isinstance(x, dict): + x = self.dict_to_arr(x) + points = self.geom.background_points(x, dirn, dist2npt, shift) + return self.arr_to_dict(points) + + def random_initial_points( + self, n: int, random: str = "pseudo" + ) -> Dict[str, bst.typing.ArrayLike]: + """ + Generate random initial points. + + Args: + n (int): Number of points to generate. + random (str, optional): Type of random number generation. Defaults to "pseudo". + + Returns: + Dict[str, ArrayLike]: Dictionary of generated initial points. + """ + points = self.geom.random_initial_points(n, random=random) + return self.arr_to_dict(points) + + def uniform_initial_points(self, n: int) -> Dict[str, bst.typing.ArrayLike]: + """ + Generate uniformly distributed initial points. + + Args: + n (int): Number of points to generate. + + Returns: + Dict[str, ArrayLike]: Dictionary of generated initial points. 
+ """ + points = self.geom.uniform_initial_points(n) + return self.arr_to_dict(points) diff --git a/deepxde/experimental/geometry/geometry_1d.py b/deepxde/experimental/geometry/geometry_1d.py new file mode 100644 index 000000000..aeb1bbf8e --- /dev/null +++ b/deepxde/experimental/geometry/geometry_1d.py @@ -0,0 +1,330 @@ +from typing import Literal, Union + +import brainstate as bst +import jax.numpy as jnp + +from deepxde.geometry.sampler import sample +from deepxde.experimental import utils +from .base import GeometryExperimental as Geometry + + +class Interval(Geometry): + """ + Represents a 1D interval geometry. + + This class defines an interval [l, r] and provides various methods for + working with points within and on the boundary of this interval. + """ + + def __init__(self, l, r): + """ + Initialize the Interval object. + + Args: + l (float): The left endpoint of the interval. + r (float): The right endpoint of the interval. + """ + super().__init__( + 1, + ( + jnp.array([l], dtype=bst.environ.dftype()), + jnp.array([r], dtype=bst.environ.dftype()), + ), + r - l, + ) + self.l, self.r = l, r + + def inside(self, x): + """ + Check if points are inside the interval. + + Args: + x (array-like): The points to check. + + Returns: + array: Boolean array indicating whether each point is inside the interval. + """ + mod = utils.smart_numpy(x) + return mod.logical_and(self.l <= x, x <= self.r).flatten() + + def on_boundary(self, x): + """ + Check if points are on the boundary of the interval. + + Args: + x (array-like): The points to check. + + Returns: + array: Boolean array indicating whether each point is on the boundary. + """ + mod = utils.smart_numpy(x) + return mod.any( + mod.isclose(x, jnp.array([self.l, self.r], dtype=bst.environ.dftype())), + axis=-1, + ) + + def distance2boundary(self, x, dirn): + """ + Compute the distance from points to the boundary in a specified direction. + + Args: + x (array-like): The points to compute distance for. + dirn (int): Direction indicator (-1 for left, 1 for right). + + Returns: + array: Distances to the boundary. + """ + return x - self.l if dirn < 0 else self.r - x + + def mindist2boundary(self, x): + """ + Compute the minimum distance from points to the boundary. + + Args: + x (array-like): The points to compute distance for. + + Returns: + float: Minimum distance to the boundary. + """ + mod = utils.smart_numpy(x) + return min(mod.amin(x - self.l), mod.amin(self.r - x)) + + def boundary_constraint_factor( + self, + x, + smoothness: Literal["C0", "C0+", "Cinf"] = "C0+", + where: Union[None, Literal["left", "right"]] = None, + ): + """Compute the hard constraint factor at x for the boundary. + + This function is used for the hard-constraint methods in Physics-Informed Neural Networks (PINNs). + The hard constraint factor satisfies the following properties: + + - The function is zero on the boundary and positive elsewhere. + - The function is at least continuous. + + In the ansatz `boundary_constraint_factor(x) * NN(x) + boundary_condition(x)`, when `x` is on the boundary, + `boundary_constraint_factor(x)` will be zero, making the ansatz be the boundary condition, which in + turn makes the boundary condition a "hard constraint". + + Args: + x: A 2D array of shape (n, dim), where `n` is the number of points and + `dim` is the dimension of the geometry. Note that `x` should be a tensor type + of backend (e.g., `tf.Tensor` or `torch.Tensor`), not a numpy array. 
+            smoothness (string, optional): A string to specify the smoothness of the distance function,
+                e.g., "C0", "C0+", "Cinf". "C0" is the least smooth, "Cinf" is the most smooth.
+                Default is "C0+".
+
+                - C0
+                    The distance function is continuous but may not be differentiable
+                    everywhere. However, the set of non-differentiable points should have
+                    measure zero, which makes the probability of a collocation point
+                    falling in this set be zero.
+
+                - C0+
+                    The distance function is continuous and differentiable almost everywhere. The
+                    non-differentiable points can only appear on boundaries. If the points in `x` are
+                    all inside or outside the geometry, the distance function is smooth.
+
+                - Cinf
+                    The distance function is continuous and differentiable at any order on any
+                    points. This option may result in a polynomial of HIGH order.
+
+            where (string, optional): A string to specify which part of the boundary to compute the distance,
+                e.g., "left", "right". If `None`, compute the distance to the whole boundary. Default is `None`.
+
+        Returns:
+            A tensor of a type determined by the backend, which will have a shape of (n, 1).
+            Each element in the tensor corresponds to the computed distance value for the respective point in `x`.
+        """
+
+        if where not in [None, "left", "right"]:
+            raise ValueError("where must be one of None, left, right")
+        if smoothness not in ["C0", "C0+", "Cinf"]:
+            raise ValueError("smoothness must be one of C0, C0+, Cinf")
+
+        # Convert self.l and self.r to tensors once,
+        # and avoid repeated conversion in later calls
+        if not hasattr(self, "l_tensor"):
+            self.l_tensor = jnp.asarray(self.l)
+            self.r_tensor = jnp.asarray(self.r)
+
+        dist_l = dist_r = None
+        if where != "right":
+            dist_l = jnp.abs((x - self.l_tensor) / (self.r_tensor - self.l_tensor) * 2)
+        if where != "left":
+            dist_r = jnp.abs((x - self.r_tensor) / (self.r_tensor - self.l_tensor) * 2)
+
+        if where is None:
+            if smoothness == "C0":
+                return jnp.minimum(dist_l, dist_r)
+            if smoothness == "C0+":
+                return dist_l * dist_r
+            return jnp.square(dist_l * dist_r)
+        if where == "left":
+            if smoothness == "Cinf":
+                dist_l = jnp.square(dist_l)
+            return dist_l
+        if smoothness == "Cinf":
+            dist_r = jnp.square(dist_r)
+        return dist_r
+
+    def boundary_normal(self, x):
+        """
+        Compute the normal vector at boundary points.
+
+        Args:
+            x (array-like): The points to compute normal vectors for.
+
+        Returns:
+            array: Normal vectors at the given points.
+        """
+        mod = utils.smart_numpy(x)
+        return -mod.isclose(x, self.l).astype(bst.environ.dftype()) + mod.isclose(
+            x, self.r
+        )
+
+    def uniform_points(self, n, boundary=True):
+        """
+        Generate uniformly distributed points in the interval.
+
+        Args:
+            n (int): Number of points to generate.
+            boundary (bool): Whether to include boundary points.
+
+        Returns:
+            array: Uniformly distributed points.
+        """
+        if boundary:
+            return jnp.linspace(self.l, self.r, num=n, dtype=bst.environ.dftype())[
+                :, None
+            ]
+        return jnp.linspace(
+            self.l, self.r, num=n + 1, endpoint=False, dtype=bst.environ.dftype()
+        )[1:, None]
+
+    def log_uniform_points(self, n, boundary=True):
+        """
+        Generate logarithmically uniformly distributed points in the interval.
+
+        Args:
+            n (int): Number of points to generate.
+            boundary (bool): Whether to include boundary points.
+
+        Returns:
+            array: Logarithmically uniformly distributed points.
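+
+        Example (illustrative):
+            # Five log-spaced points in [1, 100], endpoints included:
+            # [[1.], [~3.16], [10.], [~31.6], [100.]]
+            Interval(1, 100).log_uniform_points(5)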
+ """ + eps = 0 if self.l > 0 else jnp.finfo(bst.environ.dftype()).eps + l = jnp.log(self.l + eps) + r = jnp.log(self.r + eps) + if boundary: + x = jnp.linspace(l, r, num=n, dtype=bst.environ.dftype())[:, None] + else: + x = jnp.linspace( + l, r, num=n + 1, endpoint=False, dtype=bst.environ.dftype() + )[1:, None] + return jnp.exp(x) - eps + + def random_points(self, n, random="pseudo"): + """ + Generate random points in the interval. + + Args: + n (int): Number of points to generate. + random (str): Type of random number generation ("pseudo" or other). + + Returns: + array: Randomly distributed points. + """ + x = sample(n, 1, random) + return (self.diam * x + self.l).astype(bst.environ.dftype()) + + def uniform_boundary_points(self, n): + """ + Generate uniformly distributed points on the boundary. + + Args: + n (int): Number of points to generate. + + Returns: + array: Uniformly distributed boundary points. + """ + if n == 1: + return jnp.array([[self.l]]).astype(bst.environ.dftype()) + xl = jnp.full((n // 2, 1), self.l).astype(bst.environ.dftype()) + xr = jnp.full((n - n // 2, 1), self.r).astype(bst.environ.dftype()) + return jnp.vstack((xl, xr)) + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the boundary. + + Args: + n (int): Number of points to generate. + random (str): Type of random number generation ("pseudo" or other). + + Returns: + array: Randomly distributed boundary points. + """ + if n == 2: + return jnp.array([[self.l], [self.r]]).astype(bst.environ.dftype()) + return bst.random.choice([self.l, self.r], n)[:, None].astype( + bst.environ.dftype() + ) + + def periodic_point(self, x, component=0): + """ + Map points to their periodic equivalents within the interval. + + Args: + x (array-like): Points to map. + component (int): Component to apply periodicity to (unused in 1D). + + Returns: + array: Mapped periodic points. + """ + tmp = jnp.copy(x) + tmp[utils.isclose(x, self.l)] = self.r + tmp[utils.isclose(x, self.r)] = self.l + return tmp + + def background_points(self, x, dirn, dist2npt, shift): + """ + Generate background points based on a given point and direction. + + Args: + x (array-like): Reference point. + dirn (int): Direction (-1 for left, 1 for right, 0 for both). + dist2npt (callable): Function to convert distance to number of points. + shift (int): Number of points to shift. + + Returns: + array: Generated background points. 
+ """ + + def background_points_left(): + dx = x[0] - self.l + n = max(dist2npt(dx), 1) + h = dx / n + pts = ( + x[0] - jnp.arange(-shift, n - shift + 1, dtype=bst.environ.dftype()) * h + ) + return pts[:, None] + + def background_points_right(): + dx = self.r - x[0] + n = max(dist2npt(dx), 1) + h = dx / n + pts = ( + x[0] + jnp.arange(-shift, n - shift + 1, dtype=bst.environ.dftype()) * h + ) + return pts[:, None] + + return ( + background_points_left() + if dirn < 0 + else ( + background_points_right() + if dirn > 0 + else jnp.vstack((background_points_left(), background_points_right())) + ) + ) diff --git a/deepxde/experimental/geometry/geometry_2d.py b/deepxde/experimental/geometry/geometry_2d.py new file mode 100644 index 000000000..cc15cd66f --- /dev/null +++ b/deepxde/experimental/geometry/geometry_2d.py @@ -0,0 +1,1123 @@ +__all__ = ["Disk", "Ellipse", "Polygon", "Rectangle", "StarShaped", "Triangle"] + +from typing import Union, Literal + +import brainstate as bst +import jax.numpy as jnp +from scipy import spatial + +from deepxde.geometry.sampler import sample +from deepxde.experimental import utils +from deepxde.utils.internal import vectorize +from .base import GeometryExperimental as Geometry +from .geometry_nd import Hypercube, Hypersphere +from ..utils import isclose + + +class Disk(Hypersphere): + """ + A class representing a disk in 2D space, inheriting from Hypersphere. + """ + + def inside(self, x): + """ + Check if points are inside the disk. + + Args: + x (array-like): The coordinates of points to check. + + Returns: + array-like: Boolean array indicating whether each point is inside the disk. + """ + mod = utils.smart_numpy(x) + return mod.linalg.norm(x - self.center, axis=-1) <= self.radius + + def on_boundary(self, x): + """ + Check if points are on the boundary of the disk. + + Args: + x (array-like): The coordinates of points to check. + + Returns: + array-like: Boolean array indicating whether each point is on the disk's boundary. + """ + mod = utils.smart_numpy(x) + return mod.isclose(mod.linalg.norm(x - self.center, axis=-1), self.radius) + + def distance2boundary_unitdirn(self, x, dirn): + """ + Calculate the distance from points to the disk boundary in a given unit direction. + + Args: + x (array-like): The coordinates of points. + dirn (array-like): The unit direction vector. + + Returns: + array-like: Distances from points to the disk boundary in the given direction. + """ + mod = utils.smart_numpy(x) + xc = x - self.center + ad = jnp.dot(xc, dirn) + return (-ad + (ad**2 - mod.sum(xc * xc, axis=-1) + self._r2) ** 0.5).astype( + bst.environ.dftype() + ) + + def distance2boundary(self, x, dirn): + """ + Calculate the distance from points to the disk boundary in a given direction. + + Args: + x (array-like): The coordinates of points. + dirn (array-like): The direction vector (not necessarily unit). + + Returns: + array-like: Distances from points to the disk boundary in the given direction. + """ + mod = utils.smart_numpy(x) + return self.distance2boundary_unitdirn(x, dirn / mod.linalg.norm(dirn)) + + def mindist2boundary(self, x): + """ + Calculate the minimum distance from points to the disk boundary. + + Args: + x (array-like): The coordinates of points. + + Returns: + array-like: Minimum distances from points to the disk boundary. + """ + mod = utils.smart_numpy(x) + return mod.amin(self.radius - mod.linalg.norm(x - self.center, axis=1)) + + def boundary_normal(self, x): + """ + Calculate the unit normal vector to the disk boundary at given points. 
+ + Args: + x (array-like): The coordinates of points on the disk boundary. + + Returns: + array-like: Unit normal vectors to the disk boundary at the given points. + """ + mod = utils.smart_numpy(x) + _n = x - self.center + l = mod.linalg.norm(_n, axis=-1, keepdims=True) + _n = _n / l * mod.isclose(l, self.radius) + return _n + + def random_points(self, n, random="pseudo"): + """ + Generate random points inside the disk. + + Args: + n (int): Number of points to generate. + random (str, optional): Method for generating random numbers. Defaults to "pseudo". + + Returns: + array-like: Coordinates of randomly generated points inside the disk. + """ + rng = sample(n, 2, random) + r, theta = rng[:, 0], 2 * jnp.pi * rng[:, 1] + x, y = jnp.cos(theta), jnp.sin(theta) + return self.radius * (jnp.sqrt(r) * jnp.vstack((x, y))).T + self.center + + def uniform_boundary_points(self, n): + """ + Generate uniformly distributed points on the disk boundary. + + Args: + n (int): Number of points to generate. + + Returns: + array-like: Coordinates of uniformly distributed points on the disk boundary. + """ + theta = jnp.linspace(0, 2 * jnp.pi, num=n, endpoint=False) + X = jnp.vstack((jnp.cos(theta), jnp.sin(theta))).T + return self.radius * X + self.center + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the disk boundary. + + Args: + n (int): Number of points to generate. + random (str, optional): Method for generating random numbers. Defaults to "pseudo". + + Returns: + array-like: Coordinates of randomly generated points on the disk boundary. + """ + u = sample(n, 1, random) + theta = 2 * jnp.pi * u + X = jnp.hstack((jnp.cos(theta), jnp.sin(theta))) + return self.radius * X + self.center + + def background_points(self, x, dirn, dist2npt, shift): + """ + Generate background points along a line passing through given points. + + Args: + x (array-like): The coordinates of points. + dirn (array-like): The direction vector. + dist2npt (callable): Function to determine number of points based on distance. + shift (float): Shift factor for point generation. + + Returns: + array-like: Coordinates of generated background points. + """ + dirn = dirn / jnp.linalg.norm(dirn) + dx = self.distance2boundary_unitdirn(x, -dirn) + n = max(dist2npt(dx), 1) + h = dx / n + pts = ( + x + - jnp.arange(-shift, n - shift + 1, dtype=bst.environ.dftype())[:, None] + * h + * dirn + ) + return pts + + +class Ellipse(Geometry): + """ + A class representing an ellipse in 2D space. + + This class inherits from the Geometry class and provides methods for working with ellipses, + including generating points on the boundary and inside the ellipse, and computing distances. + + Args: + center (array-like): The coordinates of the center of the ellipse. + semimajor (float): The length of the semimajor axis of the ellipse. + semiminor (float): The length of the semiminor axis of the ellipse. + angle (float, optional): The rotation angle of the ellipse in radians. Defaults to 0. + A positive angle rotates the ellipse clockwise about the center, + while a negative angle rotates it counterclockwise. 
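+
+    Example (illustrative):
+        # An ellipse at the origin, rotated clockwise by pi/6
+        # (per the sign convention above).
+        ellipse = Ellipse(center=[0, 0], semimajor=2.0, semiminor=1.0, angle=jnp.pi / 6)
+        pts = ellipse.random_points(100)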
+ """ + + def __init__(self, center, semimajor, semiminor, angle=0): + self.center = jnp.array(center, dtype=bst.environ.dftype()) + self.semimajor = semimajor + self.semiminor = semiminor + self.angle = angle + self.c = (semimajor**2 - semiminor**2) ** 0.5 + + self.focus1 = jnp.array( + [ + center[0] - self.c * jnp.cos(angle), + center[1] + self.c * jnp.sin(angle), + ], + dtype=bst.environ.dftype(), + ) + self.focus2 = jnp.array( + [ + center[0] + self.c * jnp.cos(angle), + center[1] - self.c * jnp.sin(angle), + ], + dtype=bst.environ.dftype(), + ) + self.rotation_mat = jnp.array( + [[jnp.cos(-angle), -jnp.sin(-angle)], [jnp.sin(-angle), jnp.cos(-angle)]] + ) + ( + self.theta_from_arc_length, + self.total_arc, + ) = self._theta_from_arc_length_constructor() + super().__init__( + 2, (self.center - semimajor, self.center + semiminor), 2 * self.c + ) + + def on_boundary(self, x): + """ + Check if points are on the boundary of the ellipse. + + Args: + x (array-like): The coordinates of points to check. + + Returns: + array-like: Boolean array indicating whether each point is on the ellipse's boundary. + """ + d1 = jnp.linalg.norm(x - self.focus1, axis=-1) + d2 = jnp.linalg.norm(x - self.focus2, axis=-1) + return isclose(d1 + d2, 2 * self.semimajor) + + def inside(self, x): + """ + Check if points are inside the ellipse. + + Args: + x (array-like): The coordinates of points to check. + + Returns: + array-like: Boolean array indicating whether each point is inside the ellipse. + """ + d1 = jnp.linalg.norm(x - self.focus1, axis=-1) + d2 = jnp.linalg.norm(x - self.focus2, axis=-1) + return d1 + d2 <= 2 * self.semimajor + + def _ellipse_arc(self): + """ + Calculate the cumulative arc length of the ellipse. + + Returns: + tuple: A tuple containing: + - theta (array-like): Angle values. + - cumulative_distance (array-like): Cumulative distance at each theta. + - c (float): Total arc length of the ellipse. + """ + # Divide the interval [0 , theta] into n steps at regular angles + theta = jnp.linspace(0, 2 * jnp.pi, 10000) + coords = jnp.array( + [self.semimajor * jnp.cos(theta), self.semiminor * jnp.sin(theta)] + ) + # Compute vector distance between each successive point + coords_diffs = jnp.diff(coords) + # Compute the full arc + delta_r = jnp.linalg.norm(coords_diffs, axis=0) + cumulative_distance = jnp.concatenate(([0], jnp.cumsum(delta_r))) + c = jnp.sum(delta_r) + return theta, cumulative_distance, c + + def _theta_from_arc_length_constructor(self): + """ + Construct a function that returns the angle associated with a given cumulative arc length. + + Returns: + tuple: A tuple containing: + - f (callable): A function that takes an arc length and returns the corresponding angle. + - total_arc (float): The total arc length of the ellipse. + """ + theta, cumulative_distance, total_arc = self._ellipse_arc() + + # Construct the inverse arc length function + def f(s): + return jnp.interp(s, cumulative_distance, theta) + + return f, total_arc + + def random_points(self, n, random="pseudo"): + """ + Generate random points inside the ellipse. + + Args: + n (int): Number of points to generate. + random (str, optional): Method for generating random numbers. Defaults to "pseudo". + + Returns: + array-like: Coordinates of randomly generated points inside the ellipse. 
+ """ + # http://mathworld.wolfram.com/DiskPointPicking.html + rng = sample(n, 2, random) + r, theta = rng[:, 0], 2 * jnp.pi * rng[:, 1] + x, y = self.semimajor * jnp.cos(theta), self.semiminor * jnp.sin(theta) + X = jnp.sqrt(r) * jnp.vstack((x, y)) + return jnp.matmul(self.rotation_mat, X).T + self.center + + def uniform_boundary_points(self, n): + """ + Generate uniformly distributed points on the ellipse boundary. + + Args: + n (int): Number of points to generate. + + Returns: + array-like: Coordinates of uniformly distributed points on the ellipse boundary. + """ + # https://codereview.stackexchange.com/questions/243590/generate-random-points-on-perimeter-of-ellipse + u = jnp.linspace(0, 1, num=n, endpoint=False).reshape((-1, 1)) + theta = self.theta_from_arc_length(u * self.total_arc) + X = jnp.hstack( + (self.semimajor * jnp.cos(theta), self.semiminor * jnp.sin(theta)) + ) + return jnp.matmul(self.rotation_mat, X.T).T + self.center + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the ellipse boundary. + + Args: + n (int): Number of points to generate. + random (str, optional): Method for generating random numbers. Defaults to "pseudo". + + Returns: + array-like: Coordinates of randomly generated points on the ellipse boundary. + """ + u = sample(n, 1, random) + theta = self.theta_from_arc_length(u * self.total_arc) + X = jnp.hstack( + (self.semimajor * jnp.cos(theta), self.semiminor * jnp.sin(theta)) + ) + return jnp.matmul(self.rotation_mat, X.T).T + self.center + + def boundary_constraint_factor( + self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+" + ): + """ + Compute the boundary constraint factor for given points. + + This function calculates a factor that represents how close points are to the ellipse boundary. + The factor is zero on the boundary and positive elsewhere. + + Args: + x (array-like): The coordinates of points to evaluate. + smoothness (str, optional): The smoothness of the constraint factor. + Must be one of "C0", "C0+", or "Cinf". Defaults to "C0+". + + Returns: + array-like: The computed boundary constraint factors for the input points. + + Raises: + ValueError: If smoothness is not one of the allowed values. + """ + if smoothness not in ["C0", "C0+", "Cinf"]: + raise ValueError("`smoothness` must be one of C0, C0+, Cinf") + + if not hasattr(self, "self.focus1_tensor"): + self.focus1_tensor = jnp.asarray(self.focus1) + self.focus2_tensor = jnp.asarray(self.focus2) + + d1 = jnp.linalg.norm(x - self.focus1_tensor, axis=-1, keepdims=True) + d2 = jnp.linalg.norm(x - self.focus2_tensor, axis=-1, keepdims=True) + dist = d1 + d2 - 2 * self.semimajor + + if smoothness == "Cinf": + dist = jnp.square(dist) + else: + dist = jnp.abs(dist) + + return dist + + +class Rectangle(Hypercube): + """ + A class representing a rectangle in 2D space. + + This class inherits from the Hypercube class and provides methods for working with rectangles, + including generating points on the boundary and inside the rectangle, and computing distances. + + Args: + xmin: Coordinate of bottom left corner. + xmax: Coordinate of top right corner. + """ + + def __init__(self, xmin, xmax): + """ + Initialize a Rectangle object. + + Args: + xmin (array-like): Coordinate of the bottom left corner. + xmax (array-like): Coordinate of the top right corner. 
+ """ + super().__init__(xmin, xmax) + self.perimeter = 2 * jnp.sum(self.xmax - self.xmin) + self.area = jnp.prod(self.xmax - self.xmin) + + def uniform_boundary_points(self, n): + """ + Generate uniformly distributed points on the rectangle boundary. + + Args: + n (int): Number of points to generate. + + Returns: + array-like: Coordinates of uniformly distributed points on the rectangle boundary. + """ + nx, ny = jnp.ceil(n / self.perimeter * (self.xmax - self.xmin)).astype(int) + xbot = jnp.hstack( + ( + jnp.linspace(self.xmin[0], self.xmax[0], num=nx, endpoint=False)[ + :, None + ], + jnp.full([nx, 1], self.xmin[1]), + ) + ) + yrig = jnp.hstack( + ( + jnp.full([ny, 1], self.xmax[0]), + jnp.linspace(self.xmin[1], self.xmax[1], num=ny, endpoint=False)[ + :, None + ], + ) + ) + xtop = jnp.hstack( + ( + jnp.linspace(self.xmin[0], self.xmax[0], num=nx + 1)[1:, None], + jnp.full([nx, 1], self.xmax[1]), + ) + ) + ylef = jnp.hstack( + ( + jnp.full([ny, 1], self.xmin[0]), + jnp.linspace(self.xmin[1], self.xmax[1], num=ny + 1)[1:, None], + ) + ) + x = jnp.vstack((xbot, yrig, xtop, ylef)) + if n != len(x): + print( + "Warning: {} points required, but {} points sampled.".format(n, len(x)) + ) + return x + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the rectangle boundary. + + Args: + n (int): Number of points to generate. + random (str, optional): Method for generating random numbers. Defaults to "pseudo". + + Returns: + array-like: Coordinates of randomly generated points on the rectangle boundary. + """ + l1 = self.xmax[0] - self.xmin[0] + l2 = l1 + self.xmax[1] - self.xmin[1] + l3 = l2 + l1 + u = jnp.ravel(sample(n + 2, 1, random)) + # Remove the possible points very close to the corners + u = u[jnp.logical_not(isclose(u, l1 / self.perimeter))] + u = u[jnp.logical_not(isclose(u, l3 / self.perimeter))] + u = u[:n] + + u *= self.perimeter + x = [] + for l in u: + if l < l1: + x.append([self.xmin[0] + l, self.xmin[1]]) + elif l < l2: + x.append([self.xmax[0], self.xmin[1] + l - l1]) + elif l < l3: + x.append([self.xmax[0] - l + l2, self.xmax[1]]) + else: + x.append([self.xmin[0], self.xmax[1] - l + l3]) + return jnp.vstack(x) + + @staticmethod + def is_valid(vertices): + """ + Check if the geometry is a valid Rectangle. + + Args: + vertices (array-like): An array of 4 vertices defining the rectangle. + + Returns: + bool: True if the geometry is a valid rectangle, False otherwise. + """ + return ( + len(vertices) == 4 + and isclose(jnp.prod(vertices[1] - vertices[0]), 0) + and isclose(jnp.prod(vertices[2] - vertices[1]), 0) + and isclose(jnp.prod(vertices[3] - vertices[2]), 0) + and isclose(jnp.prod(vertices[0] - vertices[3]), 0) + ) + + +class StarShaped(Geometry): + """Star-shaped 2d domain, i.e., a geometry whose boundary is parametrized in polar coordinates as: + + $$ + r(theta) := r_0 + sum_{i = 1}^N [a_i cos( i theta) + b_i sin(i theta) ], theta in [0,2 pi]. + $$ + + For more details, refer to: + `Hiptmair et al. Large deformation shape uncertainty quantification in acoustic + scattering. Adv Comp Math, 2018. + `_ + + Args: + center: Center of the domain. + radius: 0th-order term of the parametrization (r_0). + coeffs_cos: i-th order coefficients for the i-th cos term (a_i). + coeffs_sin: i-th order coefficients for the i-th sin term (b_i). 
+ """ + + def __init__(self, center, radius, coeffs_cos, coeffs_sin): + self.center = jnp.array(center, dtype=bst.environ.dftype()) + self.radius = radius + self.coeffs_cos = coeffs_cos + self.coeffs_sin = coeffs_sin + max_radius = radius + jnp.sum(coeffs_cos) + jnp.sum(coeffs_sin) + super().__init__( + 2, + (self.center - max_radius, self.center + max_radius), + 2 * max_radius, + ) + + def _r_theta(self, theta): + """Define the parametrization r(theta) at angles theta.""" + result = self.radius * jnp.ones(theta.shape) + for i, (coeff_cos, coeff_sin) in enumerate( + zip(self.coeffs_cos, self.coeffs_sin), start=1 + ): + result += coeff_cos * jnp.cos(i * theta) + coeff_sin * jnp.sin(i * theta) + return result + + def _dr_theta(self, theta): + """Evalutate the polar derivative r'(theta) at angles theta""" + result = jnp.zeros(theta.shape) + for i, (coeff_cos, coeff_sin) in enumerate( + zip(self.coeffs_cos, self.coeffs_sin), start=1 + ): + result += -coeff_cos * i * jnp.sin(i * theta) + coeff_sin * i * jnp.cos( + i * theta + ) + return result + + def inside(self, x): + r, theta = polar(x - self.center) + r_theta = self._r_theta(theta) + return r_theta >= r + + def on_boundary(self, x): + r, theta = polar(x - self.center) + r_theta = self._r_theta(theta) + return isclose(jnp.linalg.norm(r_theta - r), 0) + + def boundary_normal(self, x): + _, theta = polar(x - self.center) + dr_theta = self._dr_theta(theta) + r_theta = self._r_theta(theta) + + dxt = jnp.vstack( + ( + dr_theta * jnp.cos(theta) - r_theta * jnp.sin(theta), + dr_theta * jnp.sin(theta) + r_theta * jnp.cos(theta), + ) + ).T + norm = jnp.linalg.norm(dxt, axis=-1, keepdims=True) + dxt /= norm + return jnp.array([dxt[:, 1], -dxt[:, 0]]).T + + def random_points(self, n, random="pseudo"): + x = jnp.empty((0, 2), dtype=bst.environ.dftype()) + vbbox = self.bbox[1] - self.bbox[0] + while len(x) < n: + x_new = sample(n, 2, sampler="pseudo") * vbbox + self.bbox[0] + x = jnp.vstack((x, x_new[self.inside(x_new)])) + return x[:n] + + def uniform_boundary_points(self, n): + theta = jnp.linspace(0, 2 * jnp.pi, num=n, endpoint=False) + r_theta = self._r_theta(theta) + X = jnp.vstack((r_theta * jnp.cos(theta), r_theta * jnp.sin(theta))).T + return X + self.center + + def random_boundary_points(self, n, random="pseudo"): + u = sample(n, 1, random) + theta = 2 * jnp.pi * u + r_theta = self._r_theta(theta) + X = jnp.hstack((r_theta * jnp.cos(theta), r_theta * jnp.sin(theta))) + return X + self.center + + +class Triangle(Geometry): + """Triangle. + + The order of vertices can be in a clockwise or counterclockwise direction. The + vertices will be re-ordered in counterclockwise (right hand rule). 
+ """ + + def __init__(self, x1, x2, x3): + self.area = polygon_signed_area([x1, x2, x3]) + # Clockwise + if self.area < 0: + self.area = -self.area + x2, x3 = x3, x2 + + self.x1 = jnp.array(x1, dtype=bst.environ.dftype()) + self.x2 = jnp.array(x2, dtype=bst.environ.dftype()) + self.x3 = jnp.array(x3, dtype=bst.environ.dftype()) + + self.v12 = self.x2 - self.x1 + self.v23 = self.x3 - self.x2 + self.v31 = self.x1 - self.x3 + self.l12 = jnp.linalg.norm(self.v12) + self.l23 = jnp.linalg.norm(self.v23) + self.l31 = jnp.linalg.norm(self.v31) + self.n12 = self.v12 / self.l12 + self.n23 = self.v23 / self.l23 + self.n31 = self.v31 / self.l31 + self.n12_normal = clockwise_rotation_90(self.n12) + self.n23_normal = clockwise_rotation_90(self.n23) + self.n31_normal = clockwise_rotation_90(self.n31) + self.perimeter = self.l12 + self.l23 + self.l31 + + super().__init__( + 2, + ( + jnp.minimum(x1, jnp.minimum(x2, x3)), + jnp.maximum(x1, jnp.maximum(x2, x3)), + ), + self.l12 + * self.l23 + * self.l31 + / ( + self.perimeter + * (self.l12 + self.l23 - self.l31) + * (self.l23 + self.l31 - self.l12) + * (self.l31 + self.l12 - self.l23) + ) + ** 0.5, + ) + + def inside(self, x): + # https://stackoverflow.com/a/2049593/12679294 + _sign = jnp.hstack( + [ + jnp.cross(self.v12, x - self.x1)[:, jnp.newaxis], + jnp.cross(self.v23, x - self.x2)[:, jnp.newaxis], + jnp.cross(self.v31, x - self.x3)[:, jnp.newaxis], + ] + ) + return ~jnp.logical_and( + jnp.any(_sign > 0, axis=-1), jnp.any(_sign < 0, axis=-1) + ) + + def on_boundary(self, x): + l1 = jnp.linalg.norm(x - self.x1, axis=-1) + l2 = jnp.linalg.norm(x - self.x2, axis=-1) + l3 = jnp.linalg.norm(x - self.x3, axis=-1) + return jnp.any( + isclose( + [l1 + l2 - self.l12, l2 + l3 - self.l23, l3 + l1 - self.l31], + 0, + ), + axis=0, + ) + + def boundary_normal(self, x): + l1 = jnp.linalg.norm(x - self.x1, axis=-1, keepdims=True) + l2 = jnp.linalg.norm(x - self.x2, axis=-1, keepdims=True) + l3 = jnp.linalg.norm(x - self.x3, axis=-1, keepdims=True) + on12 = isclose(l1 + l2, self.l12) + on23 = isclose(l2 + l3, self.l23) + on31 = isclose(l3 + l1, self.l31) + # Check points on the vertexes + if jnp.any(jnp.count_nonzero(jnp.hstack([on12, on23, on31]), axis=-1) > 1): + raise ValueError( + "{}.boundary_normal do not accept points on the vertexes.".format( + self.__class__.__name__ + ) + ) + return self.n12_normal * on12 + self.n23_normal * on23 + self.n31_normal * on31 + + def random_points(self, n, random="pseudo"): + # There are two methods for triangle point picking. 
+        # Method 1 (used here):
+        # - https://math.stackexchange.com/questions/18686/uniform-random-point-in-triangle
+        # Method 2:
+        # - http://mathworld.wolfram.com/TrianglePointPicking.html
+        # - https://hbfs.wordpress.com/2010/10/05/random-points-in-a-triangle-generating-random-sequences-ii/
+        # - https://stackoverflow.com/questions/19654251/random-point-inside-triangle-inside-java
+        sqrt_r1 = jnp.sqrt(bst.random.rand(n, 1))
+        r2 = bst.random.rand(n, 1)
+        return (
+            (1 - sqrt_r1) * self.x1
+            + sqrt_r1 * (1 - r2) * self.x2
+            + r2 * sqrt_r1 * self.x3
+        )
+
+    def uniform_boundary_points(self, n):
+        density = n / self.perimeter
+        x12 = (
+            jnp.linspace(0, 1, num=int(jnp.ceil(density * self.l12)), endpoint=False)[
+                :, None
+            ]
+            * self.v12
+            + self.x1
+        )
+        x23 = (
+            jnp.linspace(0, 1, num=int(jnp.ceil(density * self.l23)), endpoint=False)[
+                :, None
+            ]
+            * self.v23
+            + self.x2
+        )
+        x31 = (
+            jnp.linspace(0, 1, num=int(jnp.ceil(density * self.l31)), endpoint=False)[
+                :, None
+            ]
+            * self.v31
+            + self.x3
+        )
+        x = jnp.vstack((x12, x23, x31))
+        if n != len(x):
+            print(
+                "Warning: {} points required, but {} points sampled.".format(n, len(x))
+            )
+        return x
+
+    def random_boundary_points(self, n, random="pseudo"):
+        u = jnp.ravel(sample(n + 2, 1, random))
+        # Remove the possible points very close to the corners
+        u = u[jnp.logical_not(isclose(u, self.l12 / self.perimeter))]
+        u = u[jnp.logical_not(isclose(u, (self.l12 + self.l23) / self.perimeter))]
+        u = u[:n]
+
+        u *= self.perimeter
+        x = []
+        for l in u:
+            if l < self.l12:
+                x.append(l * self.n12 + self.x1)
+            elif l < self.l12 + self.l23:
+                x.append((l - self.l12) * self.n23 + self.x2)
+            else:
+                x.append((l - self.l12 - self.l23) * self.n31 + self.x3)
+        return jnp.vstack(x)
+
+    def boundary_constraint_factor(
+        self,
+        x,
+        smoothness: Literal["C0", "C0+", "Cinf"] = "C0+",
+        where: Union[None, Literal["x1-x2", "x1-x3", "x2-x3"]] = None,
+    ):
+        """Compute the hard constraint factor at x for the boundary.
+
+        This function is used for the hard-constraint methods in Physics-Informed Neural Networks (PINNs).
+        The hard constraint factor satisfies the following properties:
+
+        - The function is zero on the boundary and positive elsewhere.
+        - The function is at least continuous.
+
+        In the ansatz `boundary_constraint_factor(x) * NN(x) + boundary_condition(x)`, when `x` is on the boundary,
+        `boundary_constraint_factor(x)` will be zero, making the ansatz be the boundary condition, which in
+        turn makes the boundary condition a "hard constraint".
+
+        Args:
+            x: A 2D array of shape (n, dim), where `n` is the number of points and
+                `dim` is the dimension of the geometry. Note that `x` should be a JAX
+                array, not a numpy array.
+            smoothness (string, optional): A string to specify the smoothness of the distance function,
+                e.g., "C0", "C0+", "Cinf". "C0" is the least smooth, "Cinf" is the most smooth.
+                Default is "C0+".
+
+                - C0
+                    The distance function is continuous but may not be differentiable
+                    everywhere. However, the set of non-differentiable points has measure
+                    zero, so the probability that a collocation point falls in this set
+                    is zero.
+
+                - C0+
+                    The distance function is continuous and differentiable almost everywhere. The
+                    non-differentiable points can only appear on boundaries. If the points in `x` are
+                    all inside or outside the geometry, the distance function is smooth.
+
+                - Cinf
+                    The distance function is continuous and differentiable at any order on any
+                    points. This option may result in a polynomial of HIGH order.
+
+            where (string, optional): A string to specify which part of the boundary to compute the distance.
+                If `None`, compute the distance to the whole boundary.
+                "x1-x2" indicates the line segment with vertices x1 and x2 (after reordering). Default is `None`.
+
+        Returns:
+            A tensor of a type determined by the backend, which will have a shape of (n, 1).
+            Each element in the tensor corresponds to the computed distance value for the respective point in `x`.
+        """
+
+        if where not in [None, "x1-x2", "x1-x3", "x2-x3"]:
+            raise ValueError("where must be one of None, x1-x2, x1-x3, x2-x3")
+        if smoothness not in ["C0", "C0+", "Cinf"]:
+            raise ValueError("smoothness must be one of C0, C0+, Cinf")
+
+        if not hasattr(self, "x1_tensor"):
+            self.x1_tensor = jnp.asarray(self.x1)
+            self.x2_tensor = jnp.asarray(self.x2)
+            self.x3_tensor = jnp.asarray(self.x3)
+
+        diff_x1_x2 = diff_x1_x3 = diff_x2_x3 = None
+        if where not in ["x1-x3", "x2-x3"]:
+            diff_x1_x2 = (
+                jnp.linalg.norm(x - self.x1_tensor, axis=-1, keepdims=True)
+                + jnp.linalg.norm(x - self.x2_tensor, axis=-1, keepdims=True)
+                - self.l12
+            )
+        if where not in ["x1-x2", "x2-x3"]:
+            diff_x1_x3 = (
+                jnp.linalg.norm(x - self.x1_tensor, axis=-1, keepdims=True)
+                + jnp.linalg.norm(x - self.x3_tensor, axis=-1, keepdims=True)
+                - self.l31
+            )
+        if where not in ["x1-x2", "x1-x3"]:
+            diff_x2_x3 = (
+                jnp.linalg.norm(x - self.x2_tensor, axis=-1, keepdims=True)
+                + jnp.linalg.norm(x - self.x3_tensor, axis=-1, keepdims=True)
+                - self.l23
+            )
+
+        if where is None:
+            if smoothness == "C0":
+                return jnp.minimum(jnp.minimum(diff_x1_x2, diff_x1_x3), diff_x2_x3)
+            return diff_x1_x2 * diff_x1_x3 * diff_x2_x3
+        if where == "x1-x2":
+            return diff_x1_x2
+        if where == "x1-x3":
+            return diff_x1_x3
+        return diff_x2_x3
+
+
+class Polygon(Geometry):
+    """
+    Represents a simple polygon geometry.
+
+    This class creates a polygon object from a set of vertices. The vertices can be provided
+    in either clockwise or counterclockwise order, and will be reordered to counterclockwise
+    (right-hand rule) if necessary.
+
+    Args:
+        vertices (list or array-like): A sequence of (x, y) coordinates defining the vertices
+            of the polygon. The order can be clockwise or counterclockwise.
+
+    Raises:
+        ValueError: If the polygon is a triangle (use Triangle class instead) or
+            if the polygon is a rectangle (use Rectangle class instead).
+
+    Attributes:
+        vertices (jnp.ndarray): Array of vertex coordinates.
+        area (float): Area of the polygon (made positive during initialization).
+        diagonals (jnp.ndarray): Square matrix of distances between vertices.
+        nvertices (int): Number of vertices in the polygon.
+        perimeter (float): Perimeter of the polygon.
+        bbox (jnp.ndarray): Bounding box of the polygon.
+        segments (jnp.ndarray): Vectors representing the edges of the polygon.
+        normal (jnp.ndarray): Normal vectors for each edge of the polygon.
+
+    Note:
+        This class inherits from the Geometry base class and implements several
+        methods for working with polygons, including checking if points are inside
+        or on the boundary, and generating random points within or on the boundary
+        of the polygon.
+    """
+
+    def __init__(self, vertices):
+        self.vertices = jnp.array(vertices, dtype=bst.environ.dftype())
+        if len(vertices) == 3:
+            raise ValueError("The polygon is a triangle. Use Triangle instead.")
+        if Rectangle.is_valid(self.vertices):
+            raise ValueError("The polygon is a rectangle. Use Rectangle instead.")
+
+        self.area = polygon_signed_area(self.vertices)
+        # Clockwise
+        if self.area < 0:
+            self.area = -self.area
+            self.vertices = jnp.flipud(self.vertices)
+
+        self.diagonals = spatial.distance.squareform(
+            spatial.distance.pdist(self.vertices)
+        )
+        super().__init__(
+            2,
+            (jnp.amin(self.vertices, axis=0), jnp.amax(self.vertices, axis=0)),
+            jnp.max(self.diagonals),
+        )
+        self.nvertices = len(self.vertices)
+        self.perimeter = jnp.sum(
+            jnp.asarray(
+                [self.diagonals[i, i + 1] for i in range(-1, self.nvertices - 1)]
+            )
+        )
+        self.bbox = jnp.array(
+            [jnp.min(self.vertices, axis=0), jnp.max(self.vertices, axis=0)]
+        )
+
+        self.segments = self.vertices[1:] - self.vertices[:-1]
+        self.segments = jnp.vstack(
+            (self.vertices[0] - self.vertices[-1], self.segments)
+        )
+        self.normal = clockwise_rotation_90(self.segments.T).T
+        self.normal = self.normal / jnp.linalg.norm(self.normal, axis=1).reshape(-1, 1)
+
+    def inside(self, x):
+        def wn_PnPoly(P, V):
+            """Winding number algorithm.
+
+            https://en.wikipedia.org/wiki/Point_in_polygon
+            http://geomalgorithms.com/a03-_inclusion.html
+
+            Args:
+                P: A point.
+                V: Vertex points of a polygon.
+
+            Returns:
+                wn: Winding number (=0 only if P is outside polygon).
+            """
+            wn = jnp.zeros(len(P))  # Winding number counter
+
+            # Repeat the first vertex at end
+            # Loop through all edges of the polygon
+            for i in range(-1, self.nvertices - 1):  # Edge from V[i] to V[i+1]
+                tmp = jnp.all(
+                    jnp.hstack(
+                        [
+                            V[i, 1] <= P[:, 1:2],  # Start y <= P[1]
+                            V[i + 1, 1] > P[:, 1:2],  # An upward crossing
+                            is_left(V[i], V[i + 1], P) > 0,  # P left of edge
+                        ]
+                    ),
+                    axis=-1,
+                )
+                wn = wn.at[tmp].add(1)  # Have a valid up intersect
+                tmp = jnp.all(
+                    jnp.hstack(
+                        [
+                            V[i, 1] > P[:, 1:2],  # Start y > P[1]
+                            V[i + 1, 1] <= P[:, 1:2],  # A downward crossing
+                            is_left(V[i], V[i + 1], P) < 0,  # P right of edge
+                        ]
+                    ),
+                    axis=-1,
+                )
+                wn = wn.at[tmp].add(-1)  # Have a valid down intersect
+            return wn
+
+        return wn_PnPoly(x, self.vertices) != 0
+
+    def on_boundary(self, x):
+        _on = jnp.zeros(shape=len(x), dtype=int)
+        for i in range(-1, self.nvertices - 1):
+            l1 = jnp.linalg.norm(self.vertices[i] - x, axis=-1)
+            l2 = jnp.linalg.norm(self.vertices[i + 1] - x, axis=-1)
+            _on = _on.at[isclose(l1 + l2, self.diagonals[i, i + 1])].add(1)
+        return _on > 0
+
+    @vectorize(excluded=[0], signature="(n)->(n)")
+    def boundary_normal(self, x):
+        for i in range(self.nvertices):
+            if is_on_line_segment(self.vertices[i - 1], self.vertices[i], x):
+                return self.normal[i]
+        return jnp.array([0, 0])
+
+    def random_points(self, n, random="pseudo"):
+        x = jnp.empty((0, 2), dtype=bst.environ.dftype())
+        vbbox = self.bbox[1] - self.bbox[0]
+        while len(x) < n:
+            x_new = sample(n, 2, sampler="pseudo") * vbbox + self.bbox[0]
+            x = jnp.vstack((x, x_new[self.inside(x_new)]))
+        return x[:n]
+
+    def uniform_boundary_points(self, n):
+        density = n / self.perimeter
+        x = []
+        for i in range(-1, self.nvertices - 1):
+            x.append(
+                jnp.linspace(
+                    0,
+                    1,
+                    num=int(jnp.ceil(density * self.diagonals[i, i + 1])),
+                    endpoint=False,
+                )[:, None]
+                * (self.vertices[i + 1] - self.vertices[i])
+                + self.vertices[i]
+            )
+        x = jnp.vstack(x)
+        if n != len(x):
+            print(
+                "Warning: {} points required, but {} points sampled.".format(n, len(x))
+            )
+        return x
+
+    def random_boundary_points(self, n, random="pseudo"):
+        u = jnp.ravel(sample(n + self.nvertices, 1, random))
+        # Remove the possible points very close to the corners
+        l = 0
+        for i in range(0, self.nvertices - 1):
+            l += self.diagonals[i, i + 1]
+            u = u[jnp.logical_not(isclose(u, l / self.perimeter))]
+        u = u[:n]
+        u *= self.perimeter
+        u = jnp.sort(u)  # JAX's Array.sort() is not in-place; keep the sorted copy
+
+        x = []
+        i = -1
+        l0 = 0
+        l1 = l0 + self.diagonals[i, i + 1]
+        v = (self.vertices[i + 1] - self.vertices[i]) / self.diagonals[i, i + 1]
+        for l in u:
+            if l > l1:
+                i += 1
+                l0, l1 = l1, l1 + self.diagonals[i, i + 1]
+                v = (self.vertices[i + 1] - self.vertices[i]) / self.diagonals[i, i + 1]
+            x.append((l - l0) * v + self.vertices[i])
+        return jnp.vstack(x)
+
+
+def polygon_signed_area(vertices):
+    """The (signed) area of a simple polygon.
+
+    If the vertices are in the counterclockwise direction, then the area is positive; if
+    they are in the clockwise direction, the area is negative.
+
+    Shoelace formula: https://en.wikipedia.org/wiki/Shoelace_formula
+    """
+    x, y = zip(*vertices)
+    x = jnp.array(list(x) + [x[0]])
+    y = jnp.array(list(y) + [y[0]])
+    return 0.5 * (jnp.sum(x[:-1] * y[1:]) - jnp.sum(x[1:] * y[:-1]))
+
+
+def clockwise_rotation_90(v):
+    """Rotate a vector 90 degrees clockwise about the origin."""
+    return jnp.array([v[1], -v[0]])
+
+
+def is_left(P0, P1, P2):
+    """Test if a point is Left|On|Right of an infinite line.
+
+    See: the January 2001 Algorithm "Area of 2D and 3D Triangles and Polygons".
+
+    Args:
+        P0: One point in the line.
+        P1: One point in the line.
+        P2: An array of points to be tested.
+
+    Returns:
+        >0 if P2 left of the line through P0 and P1, =0 if P2 on the line, <0 if P2
+        right of the line.
+    """
+    return jnp.cross(P1 - P0, P2 - P0, axis=-1).reshape((-1, 1))
+
+
+def is_rectangle(vertices):
+    """Check if the geometry is a rectangle.
+
+    https://stackoverflow.com/questions/2303278/find-if-4-points-on-a-plane-form-a-rectangle/2304031
+
+    1. Find the center of mass of corner points: cx=(x1+x2+x3+x4)/4, cy=(y1+y2+y3+y4)/4
+    2. Test if square of distances from center of mass to all 4 corners are equal
+    """
+    if len(vertices) != 4:
+        return False
+
+    c = jnp.mean(vertices, axis=0)
+    d = jnp.sum((vertices - c) ** 2, axis=1)
+    return jnp.allclose(d, jnp.full(4, d[0]))
+
+
+def is_on_line_segment(P0, P1, P2):
+    """Test if a point is between two other points on a line segment.
+
+    Args:
+        P0: One point in the line.
+        P1: One point in the line.
+        P2: The point to be tested.
+
+    References:
+        https://stackoverflow.com/questions/328107
+    """
+    v01 = P1 - P0
+    v02 = P2 - P0
+    v12 = P2 - P1
+    return (
+        # check that P2 is almost on the line P0 P1
+        isclose(jnp.cross(v01, v02) / jnp.linalg.norm(v01), 0)
+        # check that projection of P2 to line is between P0 and P1
+        and v01 @ v02 >= 0
+        and v01 @ v12 <= 0
+    )
+    # Not between P0 and P1, but close to P0 or P1
+    # or isclose(np.linalg.norm(v02), 0)  # check whether P2 is close to P0
+    # or isclose(np.linalg.norm(v12), 0)  # check whether P2 is close to P1
+
+
+def polar(x):
+    """Get the polar coordinates for a 2d vector in Cartesian coordinates."""
+    r = jnp.sqrt(x[:, 0] ** 2 + x[:, 1] ** 2)
+    theta = jnp.arctan2(x[:, 1], x[:, 0])
+    return r, theta
diff --git a/deepxde/experimental/geometry/geometry_3d.py b/deepxde/experimental/geometry/geometry_3d.py
new file mode 100644
index 000000000..6d7f33954
--- /dev/null
+++ b/deepxde/experimental/geometry/geometry_3d.py
@@ -0,0 +1,222 @@
+import itertools
+from typing import Union, Literal
+
+import brainstate as bst
+import jax.numpy as jnp
+
+from .geometry_2d import Rectangle
+from .geometry_nd import Hypercube, Hypersphere
+
+
+class Cuboid(Hypercube):
+    """
+    A class representing a 3D cuboid, inheriting from Hypercube.
+
+    Args:
+        xmin: Coordinate of bottom left corner.
+ xmax: Coordinate of top right corner. + """ + + def __init__(self, xmin, xmax): + """ + Initialize the Cuboid object. + + Args: + xmin: Coordinate of bottom left corner. + xmax: Coordinate of top right corner. + """ + super().__init__(xmin, xmax) + dx = self.xmax - self.xmin + self.area = 2 * jnp.sum(dx * jnp.roll(dx, 2)) + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the boundary of the cuboid. + + Args: + n (int): The number of points to generate. + random (str, optional): The type of random number generation. Defaults to "pseudo". + + Returns: + jnp.ndarray: An array of shape (n, 3) containing the generated boundary points. + """ + pts = [] + density = n / self.area + rect = Rectangle(self.xmin[:-1], self.xmax[:-1]) + for z in [self.xmin[-1], self.xmax[-1]]: + u = rect.random_points(int(jnp.ceil(density * rect.area)), random=random) + pts.append(jnp.hstack((u, jnp.full((len(u), 1), z)))) + rect = Rectangle(self.xmin[::2], self.xmax[::2]) + for y in [self.xmin[1], self.xmax[1]]: + u = rect.random_points(int(jnp.ceil(density * rect.area)), random=random) + pts.append(jnp.hstack((u[:, 0:1], jnp.full((len(u), 1), y), u[:, 1:]))) + rect = Rectangle(self.xmin[1:], self.xmax[1:]) + for x in [self.xmin[0], self.xmax[0]]: + u = rect.random_points(int(jnp.ceil(density * rect.area)), random=random) + pts.append(jnp.hstack((jnp.full((len(u), 1), x), u))) + pts = jnp.vstack(pts) + if len(pts) > n: + return pts[bst.random.choice(len(pts), size=n, replace=False)] + return pts + + def uniform_boundary_points(self, n): + """ + Generate uniformly distributed points on the boundary of the cuboid. + + Args: + n (int): The target number of points to generate. + + Returns: + jnp.ndarray: An array of shape (m, 3) containing the generated boundary points, + where m may not exactly equal n. + """ + h = (self.area / n) ** 0.5 + nx, ny, nz = jnp.ceil((self.xmax - self.xmin) / h).astype(int) + 1 + x = jnp.linspace(self.xmin[0], self.xmax[0], num=nx) + y = jnp.linspace(self.xmin[1], self.xmax[1], num=ny) + z = jnp.linspace(self.xmin[2], self.xmax[2], num=nz) + + pts = [] + for v in [self.xmin[-1], self.xmax[-1]]: + u = list(itertools.product(x, y)) + pts.append(jnp.hstack((u, jnp.full((len(u), 1), v)))) + if nz > 2: + for v in [self.xmin[1], self.xmax[1]]: + u = jnp.array(list(itertools.product(x, z[1:-1]))) + pts.append(jnp.hstack((u[:, 0:1], jnp.full((len(u), 1), v), u[:, 1:]))) + if ny > 2 and nz > 2: + for v in [self.xmin[0], self.xmax[0]]: + u = list(itertools.product(y[1:-1], z[1:-1])) + pts.append(jnp.hstack((jnp.full((len(u), 1), v), u))) + pts = jnp.vstack(pts) + if n != len(pts): + print( + "Warning: {} points required, but {} points sampled.".format( + n, len(pts) + ) + ) + return pts + + def boundary_constraint_factor( + self, + x, + smoothness: Literal["C0", "C0+", "Cinf"] = "C0+", + where: Union[ + None, Literal["back", "front", "left", "right", "bottom", "top"] + ] = None, + inside: bool = True, + ): + """ + Compute the hard constraint factor at x for the boundary. + + This function is used for the hard-constraint methods in Physics-Informed Neural Networks (PINNs). + The hard constraint factor satisfies the following properties: + + - The function is zero on the boundary and positive elsewhere. + - The function is at least continuous. 
+
+        In the ansatz `boundary_constraint_factor(x) * NN(x) + boundary_condition(x)`, when `x` is on the boundary,
+        `boundary_constraint_factor(x)` will be zero, making the ansatz be the boundary condition, which in
+        turn makes the boundary condition a "hard constraint".
+
+        Args:
+            x: A 2D array of shape (n, dim), where `n` is the number of points and
+                `dim` is the dimension of the geometry. Note that `x` should be a JAX
+                array, not a numpy array.
+            smoothness (string, optional): A string to specify the smoothness of the distance function,
+                e.g., "C0", "C0+", "Cinf". "C0" is the least smooth, "Cinf" is the most smooth.
+                Default is "C0+".
+
+                - C0
+                    The distance function is continuous but may not be differentiable
+                    everywhere. However, the set of non-differentiable points has measure
+                    zero, so the probability that a collocation point falls in this set
+                    is zero.
+
+                - C0+
+                    The distance function is continuous and differentiable almost everywhere. The
+                    non-differentiable points can only appear on boundaries. If the points in `x` are
+                    all inside or outside the geometry, the distance function is smooth.
+
+                - Cinf
+                    The distance function is continuous and differentiable at any order on any
+                    points. This option may result in a polynomial of HIGH order.
+
+            where (string, optional): A string to specify which part of the boundary to compute the distance.
+                "back": x[0] = xmin[0], "front": x[0] = xmax[0], "left": x[1] = xmin[1],
+                "right": x[1] = xmax[1], "bottom": x[2] = xmin[2], "top": x[2] = xmax[2].
+                If `None`, compute the distance to the whole boundary. Default is `None`.
+            inside (bool, optional): Whether the points in `x` are inside or outside the geometry.
+                Mixing points inside and points outside the geometry is NOT allowed.
+                NOTE: currently only `inside=True` is supported.
+
+        Returns:
+            A tensor of a type determined by the backend, which will have a shape of (n, 1).
+            Each element in the tensor corresponds to the computed distance value for the respective point in `x`.
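+
+        Example:
+            A sketch with illustrative values::
+
+                cuboid = Cuboid(xmin=[0, 0, 0], xmax=[1, 1, 1])
+                x = cuboid.random_points(16)
+                d_all = cuboid.boundary_constraint_factor(x)
+                d_top = cuboid.boundary_constraint_factor(x, where="top")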
+ """ + if where not in [None, "back", "front", "left", "right", "bottom", "top"]: + raise ValueError( + "where must be one of None, back, front, left, right, bottom, top" + ) + if smoothness not in ["C0", "C0+", "Cinf"]: + raise ValueError("smoothness must be one of C0, C0+, Cinf") + if self.dim != 3: + raise ValueError("self.dim must be 3") + if not inside: + raise ValueError("inside=False is not supported for Cuboid") + + if not hasattr(self, "self.xmin_tensor"): + self.xmin_tensor = jnp.asarray(self.xmin) + self.xmax_tensor = jnp.asarray(self.xmax) + + dist_l = dist_r = None + if where not in ["front", "right", "top"]: + dist_l = jnp.abs( + (x - self.xmin_tensor) / (self.xmax_tensor - self.xmin_tensor) * 2 + ) + if where not in ["back", "left", "bottom"]: + dist_r = jnp.abs( + (x - self.xmax_tensor) / (self.xmax_tensor - self.xmin_tensor) * 2 + ) + + if where == "back": + return dist_l[:, 0:1] + if where == "front": + return dist_r[:, 0:1] + if where == "left": + return dist_l[:, 1:2] + if where == "right": + return dist_r[:, 1:2] + if where == "bottom": + return dist_l[:, 2:] + if where == "top": + return dist_r[:, 2:] + + if smoothness == "C0": + dist_l = jnp.min(dist_l, axis=-1, keepdims=True) + dist_r = jnp.min(dist_r, axis=-1, keepdims=True) + return jnp.minimum(dist_l, dist_r) + dist_l = jnp.prod(dist_l, axis=-1, keepdims=True) + dist_r = jnp.prod(dist_r, axis=-1, keepdims=True) + return dist_l * dist_r + + +class Sphere(Hypersphere): + """ + A class representing a 3D sphere, inheriting from Hypersphere. + + This class provides functionality for creating and manipulating a 3D sphere + in geometric computations and simulations. + + Args: + center (array-like): The coordinates of the center of the sphere. + Should be a sequence of 3 numbers representing x, y, and z coordinates. + radius (float): The radius of the sphere. + Must be a positive number. + + Attributes: + center (array-like): The center coordinates of the sphere. + radius (float): The radius of the sphere. + + Note: + This class inherits additional methods and attributes from the Hypersphere class. + """ diff --git a/deepxde/experimental/geometry/geometry_nd.py b/deepxde/experimental/geometry/geometry_nd.py new file mode 100644 index 000000000..279efdfda --- /dev/null +++ b/deepxde/experimental/geometry/geometry_nd.py @@ -0,0 +1,518 @@ +import itertools +from typing import Literal + +import brainstate as bst +import jax +import jax.numpy as jnp +from scipy import stats +from sklearn import preprocessing + +from deepxde.geometry.sampler import sample +from deepxde.experimental import utils +from .base import GeometryExperimental as Geometry +from ..utils import isclose + + +class Hypercube(Geometry): + """ + Represents a hypercube geometry in N-dimensional space. + + This class defines a hypercube with specified minimum and maximum coordinates + for each dimension. + """ + + def __init__(self, xmin, xmax): + """ + Initialize a Hypercube object. + + Args: + xmin (array-like): Minimum coordinates for each dimension. + xmax (array-like): Maximum coordinates for each dimension. + + Raises: + ValueError: If dimensions of xmin and xmax do not match or if xmin >= xmax. 
+ """ + if len(xmin) != len(xmax): + raise ValueError("Dimensions of xmin and xmax do not match.") + + self.xmin = jnp.array(xmin, dtype=bst.environ.dftype()) + self.xmax = jnp.array(xmax, dtype=bst.environ.dftype()) + if jnp.any(self.xmin >= self.xmax): + raise ValueError("xmin >= xmax") + + self.side_length = self.xmax - self.xmin + super().__init__( + len(xmin), (self.xmin, self.xmax), jnp.linalg.norm(self.side_length) + ) + self.volume = jnp.prod(self.side_length) + + def inside(self, x): + """ + Check if points are inside the hypercube. + + Args: + x (array-like): Points to check. + + Returns: + array-like: Boolean array indicating whether each point is inside the hypercube. + """ + mod = utils.smart_numpy(x) + return mod.logical_and( + mod.all(x >= self.xmin, axis=-1), mod.all(x <= self.xmax, axis=-1) + ) + + def on_boundary(self, x): + """ + Check if points are on the boundary of the hypercube. + + Args: + x (array-like): Points to check. + + Returns: + array-like: Boolean array indicating whether each point is on the boundary. + """ + mod = utils.smart_numpy(x) + if x.ndim == 0: + _on_boundary = mod.logical_or( + mod.isclose(x, self.xmin), mod.isclose(x, self.xmax) + ) + else: + _on_boundary = mod.logical_or( + mod.any(mod.isclose(x, self.xmin), axis=-1), + mod.any(mod.isclose(x, self.xmax), axis=-1), + ) + return mod.logical_and(self.inside(x), _on_boundary) + + def boundary_normal(self, x): + """ + Compute the normal vectors at boundary points. + + Args: + x (array-like): Points on the boundary. + + Returns: + array-like: Normal vectors at the given points. + """ + mod = utils.smart_numpy(x) + _n = -mod.isclose(x, self.xmin).astype(bst.environ.dftype()) + mod.isclose( + x, self.xmax + ) + # For vertices, the normal is averaged for all directions + idx = mod.count_nonzero(_n, axis=-1) > 1 + _n = jax.vmap( + lambda idx_, n_: jax.numpy.where( + idx_, n_ / mod.linalg.norm(n_, keepdims=True), n_ + ) + )(idx, _n) + return mod.asarray(_n) + + def uniform_points(self, n, boundary=True): + """ + Generate uniformly distributed points in the hypercube. + + Args: + n (int): Number of points to generate. + boundary (bool): Whether to include boundary points. + + Returns: + array-like: Uniformly distributed points. + """ + dx = (self.volume / n) ** (1 / self.dim) + xi = [] + for i in range(self.dim): + ni = int(jnp.ceil(self.side_length[i] / dx)) + if boundary: + xi.append( + jnp.linspace( + self.xmin[i], self.xmax[i], num=ni, dtype=bst.environ.dftype() + ) + ) + else: + xi.append( + jnp.linspace( + self.xmin[i], + self.xmax[i], + num=ni + 1, + endpoint=False, + dtype=bst.environ.dftype(), + )[1:] + ) + x = jnp.array(list(itertools.product(*xi))) + if n != len(x): + print( + "Warning: {} points required, but {} points sampled.".format(n, len(x)) + ) + return x + + def uniform_boundary_points(self, n): + """ + Generate uniformly distributed points on the boundary of the hypercube. + + Args: + n (int): Number of points to generate. + + Returns: + array-like: Uniformly distributed boundary points. 
+ """ + points_per_face = max(1, n // (2 * self.dim)) + points = [] + for d in range(self.dim): + for boundary in [self.xmin[d], self.xmax[d]]: + xi = [] + for i in range(self.dim): + if i == d: + xi.append(jnp.array([boundary], dtype=bst.environ.dftype())) + else: + ni = int(jnp.ceil(points_per_face ** (1 / (self.dim - 1)))) + xi.append( + jnp.linspace( + self.xmin[i], + self.xmax[i], + num=ni + 1, + endpoint=False, + dtype=bst.environ.dftype(), + )[1:] + ) + face_points = jnp.array(list(itertools.product(*xi))) + points.append(face_points) + points = jnp.vstack(points) + if n != len(points): + print( + "Warning: {} points required, but {} points sampled.".format( + n, len(points) + ) + ) + return points + + def random_points(self, n, random="pseudo"): + """ + Generate random points inside the hypercube. + + Args: + n (int): Number of points to generate. + random (str): Type of random number generation ("pseudo" or other). + + Returns: + array-like: Randomly generated points. + """ + x = sample(n, self.dim, random) + return (self.xmax - self.xmin) * x + self.xmin + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the boundary of the hypercube. + + Args: + n (int): Number of points to generate. + random (str): Type of random number generation ("pseudo" or other). + + Returns: + array-like: Randomly generated boundary points. + """ + x = sample(n, self.dim, random) + # Randomly pick a dimension + rand_dim = bst.random.randint(self.dim, size=n) + # Replace value of the randomly picked dimension with the nearest boundary value (0 or 1) + x[jnp.arange(n), rand_dim] = jnp.round(x[jnp.arange(n), rand_dim]) + return (self.xmax - self.xmin) * x + self.xmin + + def periodic_point(self, x, component): + """ + Map points to their periodic counterparts on the opposite face of the hypercube. + + Args: + x (array-like): Points to map. + component (int): The dimension along which to apply periodicity. + + Returns: + array-like: Mapped periodic points. + """ + y = jnp.copy(x) + _on_xmin = isclose(y[:, component], self.xmin[component]) + _on_xmax = isclose(y[:, component], self.xmax[component]) + y[:, component][_on_xmin] = self.xmax[component] + y[:, component][_on_xmax] = self.xmin[component] + return y + + def boundary_constraint_factor( + self, + x, + smoothness: Literal["C0", "C0+", "Cinf"] = "C0", + where: None = None, + inside: bool = True, + ): + """ + Compute the hard constraint factor at x for the boundary. + + This function is used for the hard-constraint methods in Physics-Informed Neural Networks (PINNs). + The hard constraint factor satisfies the following properties: + + - The function is zero on the boundary and positive elsewhere. + - The function is at least continuous. + + In the ansatz `boundary_constraint_factor(x) * NN(x) + boundary_condition(x)`, when `x` is on the boundary, + `boundary_constraint_factor(x)` will be zero, making the ansatz be the boundary condition, which in + turn makes the boundary condition a "hard constraint". + + Args: + x: A 2D array of shape (n, dim), where `n` is the number of points and + `dim` is the dimension of the geometry. Note that `x` should be a tensor type + of backend (e.g., `tf.Tensor` or `torch.Tensor`), not a numpy array. + smoothness (string, optional): A string to specify the smoothness of the distance function, + e.g., "C0", "C0+", "Cinf". "C0" is the least smooth, "Cinf" is the most smooth. + Default is "C0". + + - C0 + The distance function is continuous but may not be non-differentiable. 
+ But the set of non-differentiable points should have measure zero, + which makes the probability of the collocation point falling in this set be zero. + + - C0+ + The distance function is continuous and differentiable almost everywhere. The + non-differentiable points can only appear on boundaries. If the points in `x` are + all inside or outside the geometry, the distance function is smooth. + + - Cinf + The distance function is continuous and differentiable at any order on any + points. This option may result in a polynomial of HIGH order. + + - WARNING + In current implementation, + numerical underflow may happen for high dimensionalities + when `smoothness="C0+"` or `smoothness="Cinf"`. + + where (string, optional): This option is currently not supported for Hypercube. + inside (bool, optional): The `x` is either inside or outside the geometry. + The cases where there are both points inside and points + outside the geometry are NOT allowed. NOTE: currently only support `inside=True`. + + Returns: + A tensor of a type determined by the backend, which will have a shape of (n, 1). + Each element in the tensor corresponds to the computed distance value for the respective point in `x`. + """ + if where is not None: + raise ValueError("where is currently not supported for Hypercube") + if smoothness not in ["C0", "C0+", "Cinf"]: + raise ValueError("smoothness must be one of C0, C0+, Cinf") + if not inside: + raise ValueError("inside=False is not supported for Hypercube") + + if not hasattr(self, "self.xmin_tensor"): + self.xmin_tensor = jnp.asarray(self.xmin) + self.xmax_tensor = jnp.asarray(self.xmax) + + dist_l = jnp.abs( + (x - self.xmin_tensor) / (self.xmax_tensor - self.xmin_tensor) * 2 + ) + dist_r = jnp.abs( + (x - self.xmax_tensor) / (self.xmax_tensor - self.xmin_tensor) * 2 + ) + if smoothness == "C0": + dist_l = jnp.min(dist_l, axis=-1, keepdims=True) + dist_r = jnp.min(dist_r, axis=-1, keepdims=True) + return jnp.minimum(dist_l, dist_r) + # TODO: fix potential numerical underflow + dist_l = jnp.prod(dist_l, axis=-1, keepdims=True) + dist_r = jnp.prod(dist_r, dim=-1, keepdims=True) + return dist_l * dist_r + + +class Hypersphere(Geometry): + """ + Represents a hypersphere geometry in N-dimensional space. + + This class defines a hypersphere with a specified center and radius. + """ + + def __init__(self, center, radius): + """ + Initialize a Hypersphere object. + + Args: + center (array-like): Coordinates of the center of the hypersphere. + radius (float): Radius of the hypersphere. + """ + self.center = jnp.array(center, dtype=bst.environ.dftype()) + self.radius = radius + super().__init__( + len(center), (self.center - radius, self.center + radius), 2 * radius + ) + + self._r2 = radius**2 + + def inside(self, x): + """ + Check if points are inside the hypersphere. + + Args: + x (array-like): Points to check. + + Returns: + array-like: Boolean array indicating whether each point is inside the hypersphere. + """ + return jnp.linalg.norm(x - self.center, axis=-1) <= self.radius + + def on_boundary(self, x): + """ + Check if points are on the boundary of the hypersphere. + + Args: + x (array-like): Points to check. + + Returns: + array-like: Boolean array indicating whether each point is on the boundary. + """ + return isclose(jnp.linalg.norm(x - self.center, axis=-1), self.radius) + + def distance2boundary_unitdirn(self, x, dirn): + """ + Compute the distance from points to the boundary along a unit direction. + + Args: + x (array-like): Points to compute distance from. 
+            dirn (array-like): Unit direction vector.
+
+        Returns:
+            array-like: Distances from points to the boundary along the given direction.
+        """
+        xc = x - self.center
+        ad = jnp.dot(xc, dirn)
+        return (-ad + (ad**2 - jnp.sum(xc * xc, axis=-1) + self._r2) ** 0.5).astype(
+            bst.environ.dftype()
+        )
+
+    def distance2boundary(self, x, dirn):
+        """
+        Compute the distance from points to the boundary along a given direction.
+
+        Args:
+            x (array-like): Points to compute distance from.
+            dirn (array-like): Direction vector (will be normalized).
+
+        Returns:
+            array-like: Distances from points to the boundary along the given direction.
+        """
+        return self.distance2boundary_unitdirn(x, dirn / jnp.linalg.norm(dirn))
+
+    def mindist2boundary(self, x):
+        """
+        Compute the minimum distance from points to the boundary.
+
+        Args:
+            x (array-like): Points to compute distance from.
+
+        Returns:
+            array-like: The minimum distance from the given points to the boundary.
+        """
+        return jnp.amin(self.radius - jnp.linalg.norm(x - self.center, axis=-1))
+
+    def boundary_constraint_factor(
+        self, x, smoothness: Literal["C0", "C0+", "Cinf"] = "C0+"
+    ):
+        """
+        Compute the boundary constraint factor for given points.
+
+        Args:
+            x (array-like): Points to compute the factor for.
+            smoothness (str): Smoothness of the constraint factor. Options are "C0", "C0+", or "Cinf".
+
+        Returns:
+            array-like: Boundary constraint factors for the given points.
+        """
+        if smoothness not in ["C0", "C0+", "Cinf"]:
+            raise ValueError("smoothness must be one of C0, C0+, Cinf")
+
+        if not hasattr(self, "center_tensor"):
+            self.center_tensor = jnp.asarray(self.center)
+            self.radius_tensor = jnp.asarray(self.radius)
+
+        dist = (
+            jnp.linalg.norm(x - self.center_tensor, axis=-1, keepdims=True)
+            - self.radius
+        )
+        if smoothness == "Cinf":
+            dist = jnp.square(dist)
+        else:
+            dist = jnp.abs(dist)
+        return dist
+
+    def boundary_normal(self, x):
+        """
+        Compute the normal vectors at boundary points.
+
+        Args:
+            x (array-like): Points on the boundary.
+
+        Returns:
+            array-like: Normal vectors at the given points.
+        """
+        _n = x - self.center
+        l = jnp.linalg.norm(_n, axis=-1, keepdims=True)
+        _n = _n / l * isclose(l, self.radius)
+        return _n
+
+    def random_points(self, n, random="pseudo"):
+        """
+        Generate random points inside the hypersphere.
+
+        Args:
+            n (int): Number of points to generate.
+            random (str): Type of random number generation ("pseudo" or other).
+
+        Returns:
+            array-like: Randomly generated points.
+        """
+        if random == "pseudo":
+            U = bst.random.rand(n, 1).astype(bst.environ.dftype())
+            X = bst.random.normal(size=(n, self.dim)).astype(bst.environ.dftype())
+        else:
+            rng = sample(n, self.dim + 1, random)
+            U, X = rng[:, 0:1], rng[:, 1:]
+            X = stats.norm.ppf(X).astype(bst.environ.dftype())
+        X = preprocessing.normalize(X)
+        X = U ** (1 / self.dim) * X
+        return self.radius * X + self.center
+
+    def random_boundary_points(self, n, random="pseudo"):
+        """
+        Generate random points on the boundary of the hypersphere.
+
+        Args:
+            n (int): Number of points to generate.
+            random (str): Type of random number generation ("pseudo" or other).
+
+        Returns:
+            array-like: Randomly generated boundary points.
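+
+        Note:
+            Samples are obtained by normalizing (quasi-)Gaussian vectors
+            (the Muller method), which yields a uniform distribution on the
+            hypersphere surface.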
+ """ + if random == "pseudo": + X = bst.random.normal(size=(n, self.dim)).astype(bst.environ.dftype()) + else: + U = sample(n, self.dim, random) + X = stats.norm.ppf(U).astype(bst.environ.dftype()) + X = preprocessing.normalize(X) + return self.radius * X + self.center + + def background_points(self, x, dirn, dist2npt, shift): + """ + Generate background points along a direction from given points. + + Args: + x (array-like): Starting points. + dirn (array-like): Direction vector. + dist2npt (callable): Function to determine number of points based on distance. + shift (float): Shift factor for point generation. + + Returns: + array-like: Generated background points. + """ + dirn = dirn / jnp.linalg.norm(dirn) + dx = self.distance2boundary_unitdirn(x, -dirn) + n = max(dist2npt(dx), 1) + h = dx / n + pts = ( + x + - jnp.arange(-shift, n - shift + 1, dtype=bst.environ.dftype())[:, None] + * h + * dirn + ) + return pts diff --git a/deepxde/experimental/geometry/pointcloud.py b/deepxde/experimental/geometry/pointcloud.py new file mode 100644 index 000000000..836c90f6f --- /dev/null +++ b/deepxde/experimental/geometry/pointcloud.py @@ -0,0 +1,150 @@ +import brainstate as bst +import numpy as np + +from deepxde.data.sampler import BatchSampler +from .base import GeometryExperimental as Geometry +from ..utils import isclose + + +class PointCloud(Geometry): + """A geometry represented by a point cloud, i.e., a set of points in space. + + Args: + points: A 2-D NumPy array. If `boundary_points` is not provided, `points` can + include points both inside the geometry or on the boundary; if `boundary_points` + is provided, `points` includes only points inside the geometry. + boundary_points: A 2-D NumPy array representing points on the boundary. + boundary_normals: A 2-D NumPy array representing normal vectors at boundary points. + """ + + def __init__(self, points, boundary_points=None, boundary_normals=None): + self.points = np.asarray(points, dtype=bst.environ.dftype()) + self.num_points = len(points) + self.boundary_points = None + self.boundary_normals = None + all_points = self.points + if boundary_points is not None: + self.boundary_points = np.asarray( + boundary_points, dtype=bst.environ.dftype() + ) + self.num_boundary_points = len(boundary_points) + all_points = np.vstack((self.points, self.boundary_points)) + self.boundary_sampler = BatchSampler(self.num_boundary_points, shuffle=True) + if boundary_normals is not None: + if len(boundary_normals) != len(boundary_points): + raise ValueError( + "the shape of boundary_normals should be the same as boundary_points" + ) + self.boundary_normals = np.asarray( + boundary_normals, dtype=bst.environ.dftype() + ) + super().__init__( + len(points[0]), + (np.amin(all_points, axis=0), np.amax(all_points, axis=0)), + np.inf, + ) + self.sampler = BatchSampler(self.num_points, shuffle=True) + + def inside(self, x): + """ + Check if points are inside the geometry. + + Args: + x (numpy.ndarray): A 2-D array of points to check. + + Returns: + numpy.ndarray: A boolean array indicating whether each point is inside the geometry. + """ + return ( + isclose((x[:, None, :] - self.points[None, :, :]), 0) + .all(axis=2) + .any(axis=1) + ) + + def on_boundary(self, x): + """ + Check if points are on the boundary of the geometry. + + Args: + x (numpy.ndarray): A 2-D array of points to check. + + Returns: + numpy.ndarray: A boolean array indicating whether each point is on the boundary. + + Raises: + ValueError: If boundary_points is not defined. 
+ """ + if self.boundary_points is None: + raise ValueError("boundary_points must be defined to test on_boundary") + return ( + isclose( + (x[:, None, :] - self.boundary_points[None, :, :]), + 0, + ) + .all(axis=2) + .any(axis=1) + ) + + def boundary_normal(self, x): + """ + Get the normal vectors for points on the boundary. + + Args: + x (numpy.ndarray): A 2-D array of points on the boundary. + + Returns: + numpy.ndarray: A 2-D array of normal vectors corresponding to the input points. + + Raises: + ValueError: If boundary_normals is not defined. + """ + if self.boundary_normals is None: + raise ValueError("boundary_normals must be defined for boundary_normal") + boundary_point_matches = isclose( + (self.boundary_points[:, None, :] - x[None, :, :]), 0 + ).all(axis=2) + normals_idx = np.where(boundary_point_matches)[0] + return self.boundary_normals[normals_idx, :] + + def random_points(self, n, random="pseudo"): + """ + Generate random points within the geometry. + + Args: + n (int): Number of random points to generate. + random (str, optional): Type of random number generation. Defaults to "pseudo". + + Returns: + numpy.ndarray: A 2-D array of randomly generated points. + """ + if n <= self.num_points: + indices = self.sampler.get_next(n) + return self.points[indices] + + x = np.tile(self.points, (n // self.num_points, 1)) + indices = self.sampler.get_next(n % self.num_points) + return np.vstack((x, self.points[indices])) + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the boundary of the geometry. + + Args: + n (int): Number of random boundary points to generate. + random (str, optional): Type of random number generation. Defaults to "pseudo". + + Returns: + numpy.ndarray: A 2-D array of randomly generated boundary points. + + Raises: + ValueError: If boundary_points is not defined. + """ + if self.boundary_points is None: + raise ValueError("boundary_points must be defined to test on_boundary") + if n <= self.num_boundary_points: + indices = self.boundary_sampler.get_next(n) + return self.boundary_points[indices] + + x = np.tile(self.boundary_points, (n // self.num_boundary_points, 1)) + indices = self.boundary_sampler.get_next(n % self.num_boundary_points) + return np.vstack((x, self.boundary_points[indices])) diff --git a/deepxde/experimental/geometry/timedomain.py b/deepxde/experimental/geometry/timedomain.py new file mode 100644 index 000000000..99b52635c --- /dev/null +++ b/deepxde/experimental/geometry/timedomain.py @@ -0,0 +1,292 @@ +import itertools + +import brainstate as bst +import jax.numpy as jnp + +from .base import GeometryExperimental +from .geometry_1d import Interval +from .geometry_2d import Rectangle +from .geometry_3d import Cuboid +from .geometry_nd import Hypercube +from ..utils import isclose + + +class TimeDomain(Interval): + """ + Represents a time domain interval. + + This class extends the Interval class to specifically handle time domains. + It provides functionality to check if a given time point is at the initial time. + + Attributes: + t0 (jnp.ndarray): The start time of the domain. + t1 (jnp.ndarray): The end time of the domain. + """ + + def __init__(self, t0, t1): + """ + Initialize the TimeDomain. + + Parameters: + t0 (float or jnp.ndarray): The start time of the domain. + t1 (float or jnp.ndarray): The end time of the domain. 
+ """ + super().__init__(t0, t1) + self.t0 = jnp.asarray(t0, dtype=bst.environ.dftype()) + self.t1 = jnp.asarray(t1, dtype=bst.environ.dftype()) + + def on_initial(self, t): + """ + Check if the given time point is at the initial time (t0). + + Parameters: + t (jnp.ndarray): The time point(s) to check. + + Returns: + jnp.ndarray: A boolean array indicating whether each time point is at the initial time. + """ + return isclose(t, self.t0).flatten() + + +class GeometryXTime(GeometryExperimental): + """ + Represents a geometry combined with a time domain for spatio-temporal problems. + + This class extends GeometryExperimental to handle both spatial and temporal dimensions. + """ + + def __init__(self, geometry, timedomain): + """ + Initialize the GeometryXTime object. + + Parameters: + geometry (GeometryExperimental): The spatial geometry. + timedomain (TimeDomain): The time domain. + """ + self.geometry = geometry + self.timedomain = timedomain + super().__init__( + geometry.dim + timedomain.dim, + geometry.bbox + timedomain.bbox, + min(geometry.diam, timedomain.diam), + ) + + def inside(self, x): + """ + Check if points are inside the spatio-temporal domain. + + Parameters: + x (jnp.ndarray): Array of points to check. + + Returns: + jnp.ndarray: Boolean array indicating whether each point is inside the domain. + """ + return jnp.logical_and( + self.geometry.inside(x[:, :-1]), self.timedomain.inside(x[:, -1:]) + ) + + def on_boundary(self, x): + """ + Check if points are on the spatial boundary of the domain. + + Parameters: + x (jnp.ndarray): Array of points to check. + + Returns: + jnp.ndarray: Boolean array indicating whether each point is on the boundary. + """ + return self.geometry.on_boundary(x[:, :-1]) + + def on_initial(self, x): + """ + Check if points are at the initial time of the domain. + + Parameters: + x (jnp.ndarray): Array of points to check. + + Returns: + jnp.ndarray: Boolean array indicating whether each point is at the initial time. + """ + return self.timedomain.on_initial(x[:, -1:]) + + def boundary_normal(self, x): + """ + Compute the boundary normal vectors for given points. + + Parameters: + x (jnp.ndarray): Array of points on the boundary. + + Returns: + jnp.ndarray: Array of boundary normal vectors. + """ + _n = self.geometry.boundary_normal(x[:, :-1]) + return jnp.hstack([_n, jnp.zeros((len(_n), 1))]) + + def uniform_points(self, n, boundary=True): + """ + Generate uniform points in the spatio-temporal domain. + + Parameters: + n (int): Number of points to generate. + boundary (bool): Whether to include boundary points. + + Returns: + jnp.ndarray: Array of uniformly distributed points. + """ + nx = int( + jnp.ceil( + ( + n + * jnp.prod(self.geometry.bbox[1] - self.geometry.bbox[0]) + / self.timedomain.diam + ) + ** 0.5 + ) + ) + nt = int(jnp.ceil(n / nx)) + x = self.geometry.uniform_points(nx, boundary=boundary) + nx = len(x) + if boundary: + t = self.timedomain.uniform_points(nt, boundary=True) + else: + t = jnp.linspace( + self.timedomain.t1, + self.timedomain.t0, + num=nt, + endpoint=False, + dtype=bst.environ.dftype(), + )[:, None] + xt = [] + for ti in t: + xt.append(jnp.hstack((x, jnp.full([nx, 1], ti[0])))) + xt = jnp.vstack(xt) + if n != len(xt): + print( + "Warning: {} points required, but {} points sampled.".format(n, len(xt)) + ) + return xt + + def random_points(self, n, random="pseudo"): + """ + Generate random points in the spatio-temporal domain. + + Parameters: + n (int): Number of points to generate. 
+ random (str): Type of random number generation ("pseudo" or "sobol"). + """ + if isinstance(self.geometry, (Cuboid, Hypercube)): + geom = Hypercube( + jnp.append(self.geometry.xmin, self.timedomain.t0), + jnp.append(self.geometry.xmax, self.timedomain.t1), + ) + return geom.random_points(n, random=random) + + x = self.geometry.random_points(n, random=random) + t = self.timedomain.random_points(n, random=random) + t = bst.random.permutation(t) + return jnp.hstack((x, t)) + + def uniform_boundary_points(self, n): + """ + Generate uniform points on the boundary of the spatio-temporal domain. + + Parameters: + n (int): Number of boundary points to generate. + + Returns: + jnp.ndarray: Array of uniformly distributed boundary points. + """ + if self.geometry.dim == 1: + nx = 2 + else: + s = 2 * sum( + map( + lambda l: l[0] * l[1], + itertools.combinations( + self.geometry.bbox[1] - self.geometry.bbox[0], 2 + ), + ) + ) + nx = int((n * s / self.timedomain.diam) ** 0.5) + nt = int(jnp.ceil(n / nx)) + x = self.geometry.uniform_boundary_points(nx) + nx = len(x) + t = jnp.linspace( + self.timedomain.t1, + self.timedomain.t0, + num=nt, + endpoint=False, + dtype=bst.environ.dftype(), + ) + xt = [] + for ti in t: + xt.append(jnp.hstack((x, jnp.full([nx, 1], ti)))) + xt = jnp.vstack(xt) + if n != len(xt): + print( + "Warning: {} points required, but {} points sampled.".format(n, len(xt)) + ) + return xt + + def random_boundary_points(self, n, random="pseudo"): + """ + Generate random points on the boundary of the spatio-temporal domain. + + Parameters: + n (int): Number of boundary points to generate. + random (str): Type of random number generation ("pseudo" or "sobol"). + + Returns: + jnp.ndarray: Array of randomly distributed boundary points. + """ + x = self.geometry.random_boundary_points(n, random=random) + t = self.timedomain.random_points(n, random=random) + t = bst.random.permutation(t) + return jnp.hstack((x, t)) + + def uniform_initial_points(self, n): + """ + Generate uniform points at the initial time of the spatio-temporal domain. + + Parameters: + n (int): Number of initial points to generate. + + Returns: + jnp.ndarray: Array of uniformly distributed initial points. + """ + x = self.geometry.uniform_points(n, True) + t = self.timedomain.t0 + if n != len(x): + print( + "Warning: {} points required, but {} points sampled.".format(n, len(x)) + ) + return jnp.hstack((x, jnp.full([len(x), 1], t, dtype=bst.environ.dftype()))) + + def random_initial_points(self, n, random="pseudo"): + """ + Generate random points at the initial time of the spatio-temporal domain. + + Parameters: + n (int): Number of initial points to generate. + random (str): Type of random number generation ("pseudo" or "sobol"). + + Returns: + jnp.ndarray: Array of randomly distributed initial points. + """ + x = self.geometry.random_points(n, random=random) + t = self.timedomain.t0 + return jnp.hstack((x, jnp.full([n, 1], t, dtype=bst.environ.dftype()))) + + def periodic_point(self, x, component): + """ + Map points to their periodic counterparts in the spatial domain. + + Parameters: + x (jnp.ndarray): Array of points to map. + component (int): The spatial component for which to apply periodicity. + + Returns: + jnp.ndarray: Array of mapped periodic points. 
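+ + Example: + A hedged sketch; it assumes the wrapped spatial geometry implements + ``periodic_point`` for the given component, as ``Interval`` does: + + >>> geomtime = GeometryXTime(Interval(0.0, 1.0), TimeDomain(0.0, 1.0)) + >>> xt = jnp.array([[0.0, 0.3]]) # columns: [x, t] + >>> geomtime.periodic_point(xt, 0) # x mapped across the period, t unchanged # doctest: +SKIP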
+ """ + xp = self.geometry.periodic_point(x[:, :-1], component) + return jnp.hstack([xp, x[:, -1:]]) diff --git a/deepxde/experimental/grad.py b/deepxde/experimental/grad.py new file mode 100644 index 000000000..0a0299885 --- /dev/null +++ b/deepxde/experimental/grad.py @@ -0,0 +1,482 @@ +from __future__ import annotations + +from functools import wraps +from typing import Dict, Callable, Sequence, Union, Optional, Tuple, Any, Iterator + +import brainstate as bst +import brainunit as u + +TransformFn = Callable + +__all__ = [ + "jacobian", + "hessian", + "gradient", +] + + +class GradientTransform(bst.util.PrettyRepr): + """ + A class for transforming gradient computations. + + This class wraps a target function and applies a gradient transformation to it. + It handles auxiliary data and state management during the transformation process. + + Attributes: + target (Callable): The target function to be transformed. + _transform (Callable): The transformed function. + _return_value (bool): Flag to determine if the original function value should be returned. + _has_aux (bool): Flag to indicate if the target function returns auxiliary data. + _states_to_be_written (Tuple[bst.State, ...]): States that need to be updated after computation. + """ + + def __init__( + self, + target: Callable, + transform: TransformFn, + return_value: bool = False, + has_aux: bool = False, + transform_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize the GradientTransform. + + Args: + target (Callable): The target function to be transformed. + transform (TransformFn): The transformation function to be applied. + return_value (bool, optional): If True, return the original function value along with the gradient. Defaults to False. + has_aux (bool, optional): If True, the target function returns auxiliary data. Defaults to False. + transform_params (Optional[Dict[str, Any]], optional): Additional parameters for the transformation. Defaults to None. + """ + self._return_value = return_value + self._has_aux = has_aux + + # target + self.target = target + + # transform + self._states_to_be_written: Tuple[bst.State, ...] = None + _grad_setting = dict() if transform_params is None else transform_params + if self._has_aux: + self._transform = transform( + self._fun_with_aux, has_aux=True, **_grad_setting + ) + else: + self._transform = transform( + self._fun_without_aux, has_aux=True, **_grad_setting + ) + + def __pretty_repr__( + self, + ) -> Iterator[Union[bst.util.PrettyType, bst.util.PrettyAttr]]: + """ + Generate a pretty representation of the GradientTransform instance. + + Returns: + Iterator[Union[bst.util.PrettyType, bst.util.PrettyAttr]]: An iterator of pretty-formatted attributes. + """ + yield bst.util.PrettyType(self.__class__.__name__) + yield bst.util.PrettyAttr("target", self.target) + yield bst.util.PrettyAttr("return_value", self._return_value) + yield bst.util.PrettyAttr("has_aux", self._has_aux) + yield bst.util.PrettyAttr("transform", self._transform) + + def _call_target(self, *args, **kwargs): + """ + Call the target function and collect states to be written. + + Args: + *args: Positional arguments for the target function. + **kwargs: Keyword arguments for the target function. + + Returns: + Any: The output of the target function. 
+ """ + if self._states_to_be_written is None: + with bst.StateTraceStack() as stack: + output = self.target(*args, **kwargs) + self._states_to_be_written = [st for st in stack.get_write_states()] + else: + output = self.target(*args, **kwargs) + return output + + def _fun_with_aux(self, *args, **kwargs): + """ + Wrapper for target function when it returns auxiliary data. + + Args: + *args: Positional arguments for the target function. + **kwargs: Keyword arguments for the target function. + + Returns: + Tuple: A tuple containing the main output and auxiliary data. + """ + outs = self._call_target(*args, **kwargs) + assert ( + self._states_to_be_written is not None + ), "The states to be written should be collected." + return outs[0], (outs, [v.value for v in self._states_to_be_written]) + + def _fun_without_aux(self, *args, **kwargs): + """ + Wrapper for target function when it doesn't return auxiliary data. + + Args: + *args: Positional arguments for the target function. + **kwargs: Keyword arguments for the target function. + + Returns: + Tuple: A tuple containing the output and related data. + """ + out = self._call_target(*args, **kwargs) + assert ( + self._states_to_be_written is not None + ), "The states to be written should be collected." + return out, (out, [v.value for v in self._states_to_be_written]) + + def _return(self, rets): + """ + Process and return the results of the transformation. + + Args: + rets: The results from the transformation. + + Returns: + Tuple: Processed results based on the configuration of return_value and has_aux. + """ + grads, (outputs, new_dyn_vals) = rets + for i, val in enumerate(new_dyn_vals): + self._states_to_be_written[i].value = val + + if self._return_value: + if self._has_aux: + return grads, outputs[0], outputs[1] + else: + return grads, outputs + else: + if self._has_aux: + return grads, outputs[1] + else: + return grads + + def __call__(self, *args, **kwargs): + """ + Call the transformed function and process its results. + + Args: + *args: Positional arguments for the transformed function. + **kwargs: Keyword arguments for the transformed function. + + Returns: + Any: The processed results of the transformation. 
+ """ + rets = self._transform(*args, **kwargs) + return self._return(rets) + + +def _raw_jacrev( + fun: Callable, + has_aux: bool = False, + y: str | Sequence[str] | None = None, + x: str | Sequence[str] | None = None, +) -> Callable: + # process only for y + if isinstance(y, str): + y = [y] + if y is not None: + fun = _format_y(fun, y, has_aux=has_aux) + + # process only for x + if isinstance(x, str): + x = [x] + + def transform(inputs): + if x is not None: + fun2, inputs = _format_x(fun, x, inputs) + return u.autograd.jacrev(fun2, has_aux=has_aux)(inputs) + else: + return u.autograd.jacrev(fun, has_aux=has_aux)(inputs) + + return transform + + +def _raw_jacfwd( + fun: Callable, + has_aux: bool = False, + y: str | Sequence[str] | None = None, + x: str | Sequence[str] | None = None, +) -> Callable: + # process only for y + if isinstance(y, str): + y = [y] + if y is not None: + fun = _format_y(fun, y, has_aux=has_aux) + + # process only for x + if isinstance(x, str): + x = [x] + + def transform(inputs): + if x is not None: + fun2, inputs = _format_x(fun, x, inputs) + return u.autograd.jacfwd(fun2, has_aux=has_aux)(inputs) + else: + return u.autograd.jacfwd(fun, has_aux=has_aux)(inputs) + + return transform + + +def _raw_hessian( + fun: Callable, + has_aux: bool = False, + y: str | Sequence[str] | None = None, + xi: str | Sequence[str] | None = None, + xj: str | Sequence[str] | None = None, +) -> Callable: + r""" + Physical unit-aware version of `jax.hessian `_, + computing Hessian of ``fun`` as a dense array. + + H[y][xi][xj] = d^2y / dxi dxj + + Args: + fun: Function whose Hessian is to be computed. Its arguments at positions + specified by ``argnums`` should be arrays, scalars, or standard Python + containers thereof. It should return arrays, scalars, or standard Python + containers thereof. + has_aux: Optional, bool. Indicates whether ``fun`` returns a pair where the + first element is considered the output of the mathematical function to be + differentiated and the second element is auxiliary data. Default False. + + Returns: + A function with the same arguments as ``fun``, that evaluates the Hessian of + ``fun``. + """ + + inner = _raw_jacrev(fun, has_aux=has_aux, y=y, x=xi) + + # process only for xj + if isinstance(xj, str): + xj = [xj] + + def transform(inputs): + if xj is not None: + fun2, inputs = _format_x(inner, xj, inputs) + return u.autograd.jacfwd(fun2, has_aux=has_aux)(inputs) + else: + return u.autograd.jacfwd(inner, has_aux=has_aux)(inputs) + + return transform + + +def _format_x(fn, x_keys, xs): + assert isinstance(xs, dict), "xs must be a dictionary." + assert isinstance(x_keys, (tuple, list)), "x must be a tuple or list." + assert all( + isinstance(key, str) for key in x_keys + ), "x_keys must be a tuple or list of strings." + others = {key: xs[key] for key in xs if key not in x_keys} + xs = {key: xs[key] for key in x_keys} + + @wraps(fn) + def fn_new(inputs): + return fn({**inputs, **others}) + + return fn_new, xs + + +def _format_y(fn, y, has_aux: bool): + assert isinstance(y, (tuple, list)), "y must be a tuple or list." + assert all( + isinstance(key, str) for key in y + ), "y must be a tuple or list of strings." 
+ + @wraps(fn) + def fn_new(inputs): + if has_aux: + outs, _aux = fn(inputs) + return {key: outs[key] for key in y}, _aux + else: + outs = fn(inputs) + return {key: outs[key] for key in y} + + return fn_new + + +def jacobian( + fn: Callable, + xs: Dict, + y: str | Sequence[str] | None = None, + x: str | Sequence[str] | None = None, + mode: str = "backward", + vmap: bool = True, +): + """ + Compute the Jacobian matrix of a function. + + This function calculates the Jacobian matrix J as J[i, j] = dy_i / dx_j, + where i = 0, ..., dim_y - 1 and j = 0, ..., dim_x - 1. + + Args: + fn (Callable): The function to compute the Jacobian for. + xs (Dict): A dictionary containing the input values for the function. + y (str | Sequence[str] | None, optional): Specifies the output variable(s) for which + to compute the Jacobian. If None, computes for all outputs. Defaults to None. + x (str | Sequence[str] | None, optional): Specifies the input variable(s) with respect + to which the Jacobian is computed. If None, computes for all inputs. Defaults to None. + mode (str, optional): The mode of gradient computation. Either 'backward' or 'forward'. + Defaults to 'backward'. + vmap (bool, optional): Whether to use vectorized mapping. Defaults to True. + + Returns: + The Jacobian matrix. Depending on the inputs, it can be: + - The full Jacobian matrix if both x and y are None or specify all variables. + - A row vector J[i, :] if y specifies a single output and x is None. + - A column vector J[:, j] if x specifies a single input and y is None. + - A scalar J[i, j] if both x and y specify single variables. + + Raises: + ValueError: If an invalid mode is specified. + + Note: + The function uses automatic differentiation techniques to compute the Jacobian. + Reverse-mode ('backward') differentiation is generally more efficient for functions + with more inputs than outputs, while forward mode is more efficient for functions + with more outputs than inputs. + """ + # assert isinstance(xs, dict), 'xs must be a dictionary.' + assert isinstance(mode, str), "mode must be a string." + assert mode in ["backward", "forward"], "mode must be either backward or forward." + + # process only for x + if isinstance(x, str): + x = [x] + + # process only for y + if isinstance(y, str): + y = [y] + + # compute the Jacobian + if mode == "backward": + transform = GradientTransform( + fn, _raw_jacrev, transform_params={"y": y, "x": x} + ) + elif mode == "forward": + transform = GradientTransform( + fn, _raw_jacfwd, transform_params={"y": y, "x": x} + ) + else: + raise ValueError("Invalid mode. Choose between backward and forward.") + if vmap: + return bst.augment.vmap(transform)(xs) + else: + return transform(xs) + + +def hessian( + fn: Callable, + xs: Dict, + y: str | Sequence[str] | None = None, + xi: str | Sequence[str] | None = None, + xj: str | Sequence[str] | None = None, + vmap: bool = True, +): + """ + Compute the Hessian matrix of a function. + + This function calculates the Hessian matrix H as H[i, j] = d^2y / dx_i dx_j, + where i, j = 0, ..., dim_x - 1. + + Args: + fn (Callable): The function for which to compute the Hessian. + xs (Dict): A dictionary containing the input values for the function. + y (str | Sequence[str] | None, optional): Specifies the output variable(s) for which + to compute the Hessian. If None, computes for all outputs. Defaults to None. + xi (str | Sequence[str] | None, optional): Specifies the input variable(s) for the i-th + dimension of the Hessian. If None, computes for all inputs in this dimension. + Defaults to None.
+ xj (str | Sequence[str] | None, optional): Specifies the input variable(s) for the j-th + dimension of the Hessian. If None, computes for all inputs in this dimension. + Defaults to None. + vmap (bool, optional): Whether to use vectorized mapping. Defaults to True. + + Returns: + The Hessian matrix or a part of it, depending on the specified xi and xj: + - If both xi and xj are None, returns the full Hessian matrix. + - If xi is specified and xj is None, returns the i-th row of the Hessian, H[i, :]. + - If xj is specified and xi is None, returns the j-th column of the Hessian, H[:, j]. + - If both xi and xj are specified, returns the specific element H[i, j]. + + Note: + xi and xj cannot both be None unless the Hessian has only one element. + """ + # assert isinstance(xs, dict), 'xs must be a dictionary.' + transform = GradientTransform( + fn, _raw_hessian, transform_params={"y": y, "xi": xi, "xj": xj} + ) + if vmap: + return bst.augment.vmap(transform)(xs) + else: + return transform(xs) + + +def gradient( + fn: Callable, + xs: Dict, + y: str | Sequence[str] | None = None, + *xi: str | Sequence[str] | None, + order: int = 1, +): + """ + Compute the gradient of a function with respect to specified variables. + + This function calculates the gradient dy/dx of a function y = f(x) with respect to x. + It supports computing higher-order gradients by specifying the 'order' parameter. + + Args: + fn (Callable): The function for which to compute the gradient. + xs (Dict): A dictionary containing the input values for the function. + y (str | Sequence[str] | None, optional): Specifies the output variable(s) to differentiate. + If None, computes for all outputs. Defaults to None. + *xi (str | Sequence[str] | None): Variable-length argument specifying the input variable(s) + to differentiate with respect to. The number of xi arguments should match the 'order' parameter. + order (int, optional): The order of the gradient to compute. Default is 1 (first derivative). + + Returns: + The computed gradient. The structure and dimensions of the output depend on the inputs: + - For first-order gradients (order=1), returns dy/dx. + - For higher-order gradients, returns the corresponding higher-order derivative. + + Raises: + AssertionError: If 'order' is not a positive integer or if the number of 'xi' arguments + doesn't match the specified 'order'. + + Note: + The function uses a combination of reverse-mode (for the first derivative) and + forward-mode (for higher-order derivatives) automatic differentiation. + """ + assert isinstance(order, int), "order must be an integer." + assert order > 0, "order must be positive." + + # process only for y + if isinstance(y, str): + y = [y] + if y is not None: + fn = _format_y(fn, y, has_aux=False) + + # process xi + if len(xi) > 0: + assert len(xi) == order, "The number of xi must be equal to order." 
+ xi = list(xi) + for i in range(order): + if isinstance(xi[i], str): + xi[i] = [xi[i]] + else: + xi = [None] * order + + # compute the gradient + for i, x in enumerate(xi): + if i == 0: + fn = _raw_jacrev(fn, y=y, x=x) + else: + fn = _raw_jacfwd(fn, y=None, x=x) + return bst.augment.vmap(fn)(xs) diff --git a/deepxde/experimental/icbc/__init__.py b/deepxde/experimental/icbc/__init__.py new file mode 100644 index 000000000..0309af16f --- /dev/null +++ b/deepxde/experimental/icbc/__init__.py @@ -0,0 +1,29 @@ +"""Initial conditions and boundary conditions.""" + +__all__ = [ + "ICBC", + "BC", + "DirichletBC", + "Interface2DBC", + "NeumannBC", + "RobinBC", + "PeriodicBC", + "OperatorBC", + "PointSetBC", + "PointSetOperatorBC", + "IC", +] + +from .base import ICBC +from .boundary_conditions import ( + BC, + DirichletBC, + Interface2DBC, + NeumannBC, + RobinBC, + PeriodicBC, + OperatorBC, + PointSetBC, + PointSetOperatorBC, +) +from .initial_conditions import IC diff --git a/deepxde/experimental/icbc/base.py b/deepxde/experimental/icbc/base.py new file mode 100644 index 000000000..fab37345f --- /dev/null +++ b/deepxde/experimental/icbc/base.py @@ -0,0 +1,109 @@ +import abc +from typing import Optional, Dict + +import brainstate as bst + +from deepxde.experimental.geometry.base import GeometryExperimental + + +class ICBC(abc.ABC): + """ + Base class for initial and boundary conditions. + """ + + # A ``experimental.geometry.Geometry`` instance. + geometry: Optional[GeometryExperimental] + problem: Optional["Problem"] + + def apply_geometry(self, geom: GeometryExperimental): + """ + Applies a geometry to the ICBC instance. + + Parameters: + ----------- + geom : GeometryExperimental + The geometry to be applied to the ICBC instance. + + Raises: + ------- + AssertionError + If the provided geometry is not an instance of AbstractGeometry. + """ + assert isinstance( + geom, GeometryExperimental + ), "geometry must be an instance of AbstractGeometry." + self.geometry = geom + + def apply_problem(self, problem: "Problem"): + """ + Applies a problem to the ICBC instance. + + Parameters: + ----------- + problem : Problem + The problem to be applied to the ICBC instance. + + Raises: + ------- + AssertionError + If the provided problem is not an instance of Problem. + """ + from deepxde.experimental.problem.base import Problem + + assert isinstance(problem, Problem), "problem must be an instance of Problem." + self.problem = problem + + @abc.abstractmethod + def filter(self, X): + """ + Filters the input data. + + Parameters: + ----------- + X : array-like + The input data to be filtered. + + Returns: + -------- + array-like + The filtered input data. + """ + pass + + @abc.abstractmethod + def collocation_points(self, X): + """ + Returns the collocation points. + + Parameters: + ----------- + X : array-like + The input data for which to compute collocation points. + + Returns: + -------- + array-like + The computed collocation points. + """ + pass + + @abc.abstractmethod + def error(self, inputs, outputs, **kwargs) -> Dict[str, bst.typing.ArrayLike]: + """ + Returns the loss for each component at the initial or boundary conditions. + + Parameters: + ----------- + inputs : array-like + The input data. + outputs : array-like + The output data. + **kwargs : dict + Additional keyword arguments. + + Returns: + -------- + Dict[str, bst.typing.ArrayLike] + A dictionary containing the loss for each component at the initial or boundary conditions. 
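+ + Example: + Concrete subclasses return per-component residuals, e.g. + ``{'y': outputs['y'] - values['y']}`` as in ``DirichletBC.error``.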
+ """ + pass diff --git a/deepxde/experimental/icbc/boundary_conditions.py b/deepxde/experimental/icbc/boundary_conditions.py new file mode 100644 index 000000000..442fa783e --- /dev/null +++ b/deepxde/experimental/icbc/boundary_conditions.py @@ -0,0 +1,655 @@ +from __future__ import annotations + +from typing import Callable, Dict + +import brainstate as bst +import brainunit as u +import jax +import numpy as np + +from deepxde.data.sampler import BatchSampler +from deepxde.experimental import utils +from deepxde.experimental.nn.model import Model +from .base import ICBC + +__all__ = [ + "BC", + "DirichletBC", + "Interface2DBC", + "NeumannBC", + "OperatorBC", + "PeriodicBC", + "PointSetBC", + "PointSetOperatorBC", + "RobinBC", +] + +X = Dict[str, bst.typing.ArrayLike] +Y = Dict[str, bst.typing.ArrayLike] +F = Dict[str, bst.typing.ArrayLike] +Boundary = Dict[str, bst.typing.ArrayLike] + + +class BC(ICBC): + """ + Boundary condition base class. + + This class serves as the foundation for implementing various boundary conditions in the DeepXDE framework. + It provides methods for filtering collocation points, computing normal derivatives, and handling boundary-related operations. + + Args: + on_boundary (Callable[[X, np.array], np.array]): A function that takes two arguments: + - x: The input points. + - on: A boolean array indicating whether each point is on the boundary. + The function should return a boolean array indicating which points satisfy the boundary condition. + + Attributes: + on_boundary (Callable): A vectorized version of the input `on_boundary` function. + """ + + def __init__( + self, + on_boundary: Callable[[X, np.array], np.array], + ): + self.on_boundary = lambda x, on: jax.vmap(on_boundary)(x, on) + + @utils.check_not_none("geometry") + def filter(self, X): + """ + Filter the collocation points for boundary conditions. + + This method applies the boundary condition filter to the given collocation points. + + Args: + X (Dict[str, bst.typing.ArrayLike]): A dictionary of collocation points. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary of filtered collocation points that satisfy the boundary condition. + """ + positions = self.on_boundary(X, self.geometry.on_boundary(X)) + return jax.tree.map(lambda x: x[positions], X) + + def collocation_points(self, X): + """ + Return the collocation points for boundary conditions. + + This method filters the input collocation points to return only those that satisfy the boundary condition. + + Args: + X (Dict[str, bst.typing.ArrayLike]): A dictionary of collocation points. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary of collocation points that satisfy the boundary condition. + """ + return self.filter(X) + + def normal_derivative(self, inputs) -> Dict[str, bst.typing.ArrayLike]: + """ + Compute the normal derivative of the output. + + This method calculates the normal derivative of the output with respect to the input at the boundary. + + Args: + inputs (Dict[str, bst.typing.ArrayLike]): A dictionary of input points. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the normal derivatives of the output + with respect to each input variable. + + Raises: + AssertionError: If the problem approximator is not an instance of the Model class, + or if the boundary normal or jacobian are not dictionaries. + """ + # first order derivative + assert isinstance(self.problem.approximator, Model), ( + "Normal derivative is only supported " "for Sequential approximator." 
+ ) + dydx = self.problem.approximator.jacobian(inputs) + + # boundary normal + n = self.geometry.boundary_normal(inputs) + + assert isinstance(n, dict), "Boundary normal should be a dictionary." + assert isinstance(dydx, dict), "dydx should be a dictionary." + norms = dict() + for y in dydx: + norm = None + for x in dydx[y]: + if norm is None: + norm = dydx[y][x] * n[x] + else: + norm += dydx[y][x] * n[x] + norms[y] = norm + return norms + + +class DirichletBC(BC): + """ + Dirichlet boundary conditions: ``y(x) = func(x)``. + + This class implements Dirichlet boundary conditions, where the solution is specified + on the boundary of the domain. + + Args: + func (Callable[[X, ...], F] | Callable[[X], F] | F): A function that takes an array of points + and returns an array of values, or a constant value to be applied at all boundary points. + on_boundary (Callable[[X, np.array], np.array], optional): A function that takes two arguments: + x (the input points) and on (a boolean array indicating whether each point is on the boundary). + It should return a boolean array indicating which points satisfy the boundary condition. + Defaults to a function that returns the input 'on' array. + + """ + + def __init__( + self, + func: Callable[[X, ...], F] | Callable[[X], F] | F, + on_boundary: Callable[[X, np.array], np.array] = lambda x, on: on, + ): + super().__init__(on_boundary) + self.func = func if callable(func) else lambda x: func + + def error(self, bc_inputs, bc_outputs, **kwargs): + """ + Calculate the error between the predicted and true values at the boundary. + + Args: + bc_inputs (Dict[str, bst.typing.ArrayLike]): Input points on the boundary. + bc_outputs (Dict[str, bst.typing.ArrayLike]): Predicted output values at the boundary points. + **kwargs: Additional keyword arguments to be passed to self.func. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the errors for each output component. + The keys are the component names, and the values are the differences between + the predicted and true values at the boundary points. + """ + values = self.func(bc_inputs, **kwargs) + errors = dict() + for component in values.keys(): + errors[component] = bc_outputs[component] - values[component] + return errors + + +class NeumannBC(BC): + """ + Neumann boundary conditions: ``dy/dn(x) = func(x)``. + + Args: + func: A function that takes an array of points and returns an array of values. + on_boundary: (x, Geometry.on_boundary(x)) -> True/False. + """ + + def __init__( + self, + func: Callable[[X, ...], F] | Callable[[X], F], + on_boundary: Callable[[X, np.array], np.array] = lambda x, on: on, + ): + super().__init__(on_boundary) + self.func = func + + def error(self, bc_inputs, bc_outputs, **kwargs): + """ + Calculate the error for Neumann boundary conditions. + + This method computes the difference between the normal derivative of the solution + and the specified function values at the boundary points. + + Args: + bc_inputs (Dict[str, bst.typing.ArrayLike]): Input points on the boundary. + bc_outputs (Dict[str, bst.typing.ArrayLike]): Predicted output values at the boundary points. + **kwargs: Additional keyword arguments to be passed to self.func. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the errors for each output component. + The keys are the component names, and the values are the differences between + the normal derivatives and the specified function values at the boundary points. 
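+ + Example: + A hedged usage sketch: a zero-flux condition on the whole boundary, with + 'y' and 'x' as assumed component names: + + >>> bc = NeumannBC(lambda x: {'y': u.math.zeros_like(x['x'])})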
+ """ + values = self.func(bc_inputs, **kwargs) + normals = self.normal_derivative(bc_inputs) + return { + component: normals[component] - values[component] + for component in values.keys() + } + + +class RobinBC(BC): + """ + Robin boundary conditions: dy/dn(x) = func(x, y). + + This class implements Robin boundary conditions, which are a combination of + Dirichlet and Neumann boundary conditions. + + Attributes: + func (Callable): The function defining the Robin boundary condition. + """ + + def __init__( + self, + func: Callable[[X, Y, ...], F] | Callable[[X, Y], F], + on_boundary: Callable[[Dict, np.array], np.array] = lambda x, on: on, + ): + """ + Initialize the RobinBC class. + + Args: + func (Callable[[X, Y, ...], F] | Callable[[X, Y], F]): A function that takes + input points (X) and output values (Y) and returns the right-hand side + of the Robin boundary condition equation. + on_boundary (Callable[[Dict, np.array], np.array], optional): A function that + determines which points are on the boundary. Defaults to a function that + returns the input 'on' array. + """ + super().__init__(on_boundary) + self.func = func + + def error(self, bc_inputs, bc_outputs, **kwargs): + """ + Calculate the error for the Robin boundary condition. + + This method computes the difference between the normal derivative of the solution + and the specified function values at the boundary points. + + Args: + bc_inputs (Dict[str, bst.typing.ArrayLike]): Input points on the boundary. + bc_outputs (Dict[str, bst.typing.ArrayLike]): Predicted output values at the boundary points. + **kwargs: Additional keyword arguments to be passed to self.func. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the errors for each output component. + The keys are the component names, and the values are the differences between + the normal derivatives and the specified function values at the boundary points. + """ + values = self.func(bc_inputs, bc_outputs, **kwargs) + normals = self.normal_derivative(bc_inputs) + return { + component: normals[component] - values[component] + for component in values.keys() + } + + +class PeriodicBC(BC): + """ + Implements periodic boundary conditions for a specified component of the solution. + + This class enforces periodicity by ensuring that the values (or derivatives) of the solution + at corresponding points on opposite boundaries are equal. + + Args: + component_y (str): The name of the output component to which the periodic condition is applied. + component_x (str): The name of the input component along which the periodicity is enforced. + on_boundary (Callable[[X, np.array], np.array], optional): A function that takes two arguments: + x (the input points) and on (a boolean array indicating whether each point is on the boundary). + It should return a boolean array indicating which points satisfy the boundary condition. + Defaults to a function that returns the input 'on' array. + derivative_order (int, optional): The order of the derivative for which periodicity is enforced. + Can be 0 (for function values) or 1 (for first derivatives). Defaults to 0. + + Raises: + NotImplementedError: If derivative_order is greater than 1. 
+ """ + + def __init__( + self, + component_y: str, + component_x: str, + on_boundary: Callable[[X, np.array], np.array] = lambda x, on: on, + derivative_order: int = 0, + ): + super().__init__(on_boundary) + self.component_y = component_y + self.component_x = component_x + self.derivative_order = derivative_order + if derivative_order > 1: + raise NotImplementedError( + "PeriodicBC only supports derivative_order 0 or 1." + ) + + @utils.check_not_none("geometry") + def collocation_points(self, X): + """ + Generates collocation points for enforcing periodic boundary conditions. + + This method filters the input points, identifies the periodic points, and concatenates + them to create pairs of points for enforcing periodicity. + + Args: + X (Dict[str, bst.typing.ArrayLike]): A dictionary of input points. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary of collocation points, where each entry + is the concatenation of points on one boundary and their periodic counterparts. + """ + X1 = self.filter(X) + X2 = self.geometry.periodic_point(X1, self.component_x) + return jax.tree.map( + lambda x1, x2: utils.smart_numpy(x1).concatenate((x1, x2), axis=-1), + X1, + X2, + is_leaf=u.math.is_quantity, + ) + + def error(self, bc_inputs, bc_outputs, **kwargs): + """ + Calculates the error for periodic boundary conditions. + + This method computes the difference between the values (or derivatives) of the solution + at corresponding points on opposite boundaries. + + Args: + bc_inputs (Dict[str, bst.typing.ArrayLike]): Input points on the boundary. + bc_outputs (Dict[str, bst.typing.ArrayLike]): Predicted output values at the boundary points. + **kwargs: Additional keyword arguments (unused in this method). + + Returns: + Dict[str, Dict[str, bst.typing.ArrayLike]]: A nested dictionary containing the errors. + The outer key is the output component name, and the inner key is the input component name. + The value is the difference between the left and right boundary values or derivatives. + """ + n_batch = bc_inputs[self.component_x].shape[0] + mid = n_batch // 2 + if self.derivative_order == 0: + yleft = bc_outputs[self.component_y][:mid] + yright = bc_outputs[self.component_y][mid:] + else: + dydx = self.problem.approximator.jacobian( + bc_outputs, y=self.component_y, x=self.component_x + ) + dydx = dydx[self.component_y][self.component_x] + yleft = dydx[:mid] + yright = dydx[mid:] + return {self.component_y: {self.component_x: yleft - yright}} + + +class OperatorBC(BC): + """ + General operator boundary conditions: func(inputs, outputs) = 0. + + Args: + func: A function takes arguments (`inputs`, `outputs`) + and outputs a tensor of size `N x 1`, where `N` is the length of `inputs`. + `inputs` and `outputs` are the network input and output tensors, + respectively; `X` are the NumPy array of the `inputs`. + on_boundary: (x, Geometry.on_boundary(x)) -> True/False. + + Warning: + If you use `X` in `func`, then do not set ``num_test`` when you define + ``experimental.problem.PDE`` or ``experimental.problem.TimePDE``, otherwise DeepXDE would throw an + error. In this case, the training points will be used for testing, and this will + not affect the network training and training loss. This is a bug of DeepXDE, + which cannot be fixed in an easy way for all backends. 
+ """ + + def __init__( + self, + func: Callable[[X, Y, ...], F] | Callable[[X, Y], F], + on_boundary: Callable[[X, np.array], np.array] = lambda x, on: on, + ): + super().__init__(on_boundary) + self.func = func + + def error(self, bc_inputs, bc_outputs, **kwargs): + """ + Calculate the error for the operator boundary condition. + + This method applies the operator function to the boundary inputs and outputs + to compute the error of the boundary condition. + + Args: + bc_inputs (Dict[str, bst.typing.ArrayLike]): A dictionary of input values at the boundary points. + bc_outputs (Dict[str, bst.typing.ArrayLike]): A dictionary of output values at the boundary points. + **kwargs: Additional keyword arguments to be passed to the operator function. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the computed error values + for each component of the boundary condition. + """ + return self.func(bc_inputs, bc_outputs, **kwargs) + + +class PointSetBC(BC): + """ + Dirichlet boundary condition for a set of points. + + Compare the output (that associates with `points`) with `values` (target data). + If more than one component is provided via a list, the resulting loss will + be the addative loss of the provided componets. + + Args: + points (Dict[str, bst.typing.ArrayLike]): A dictionary of arrays representing points + where the corresponding target values are known and used for training. + values (Dict[str, bst.typing.ArrayLike]): A dictionary of scalars or 2D-arrays + representing the exact solution of the problem at the given points. + batch_size (int, optional): The number of points per minibatch, or None to return all points. + This is only supported for the backend PyTorch and PaddlePaddle. Defaults to None. + shuffle (bool, optional): Whether to randomize the order on each pass through the data + when batching. Defaults to True. + + Note: + If you want to use batch size here, you should also set callback + 'experimental.callbacks.PDEPointResampler(bc_points=True)' in training. + """ + + def __init__( + self, + points: Dict[str, bst.typing.ArrayLike], + values: Dict[str, bst.typing.ArrayLike], + batch_size: int = None, + shuffle: bool = True, + ): + super().__init__(lambda x, on: on) + + self.points = points + self.values = values + self.batch_size = batch_size + + if batch_size is not None: # batch iterator and state + self.batch_sampler = BatchSampler(len(self), shuffle=shuffle) + self.batch_indices = None + + def __len__(self): + """ + Get the number of points in the PointSetBC. + + Returns: + int: The number of points in the first value of the points dictionary. + """ + v = tuple(self.points.values())[0] + return v.shape[0] + + def collocation_points(self, X): + """ + Get the collocation points for the boundary condition. + + If batch_size is set, returns a batch of points. Otherwise, returns all points. + + Args: + X: Unused in this method, kept for compatibility with parent class. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary of collocation points, + either a batch or all points depending on the batch_size setting. + """ + if self.batch_size is not None: + self.batch_indices = self.batch_sampler.get_next(self.batch_size) + return jax.tree.map(lambda x: x[self.batch_indices], self.points) + return self.points + + def error(self, bc_inputs, bc_outputs, **kwargs): + """ + Calculate the error between the predicted and true values at the boundary points. + + Args: + bc_inputs: Unused in this method, kept for compatibility with parent class. 
+ bc_outputs (Dict[str, bst.typing.ArrayLike]): A dictionary of predicted output values + at the boundary points. + **kwargs: Additional keyword arguments (unused in this method). + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the errors for each output component. + The keys are the component names, and the values are the differences between + the predicted and true values at the boundary points. + """ + if self.batch_size is not None: + return { + k: bc_outputs[k] - self.values[k][self.batch_indices] + for k in self.values.keys() + } + else: + return {k: bc_outputs[k] - self.values[k] for k in self.values.keys()} + + +class PointSetOperatorBC(BC): + """ + General operator boundary conditions for a set of points. + + Compare the function output, func, (that associates with `points`) + with `values` (target data). + + Args: + points: An array of points where the corresponding target values are + known and used for training. + values: An array of values which the output of func should fulfill. + func: A function that takes arguments (`inputs`, `outputs`) + and outputs a tensor of size `N x 1`, where `N` is the length of + `inputs`. `inputs` and `outputs` are the network input and output + tensors, respectively. + """ + + def __init__( + self, + points: Dict[str, bst.typing.ArrayLike], + values: Dict[str, bst.typing.ArrayLike], + func: Callable[[X, Y], F], + ): + super().__init__(lambda x, on: on) + self.points = points + self.values = values + self.func = func + + def collocation_points(self, X): + """ + Return the collocation points for the boundary condition. + + Args: + X: Unused input parameter, kept for compatibility with parent class. + + Returns: + Dict[str, bst.typing.ArrayLike]: The points where the boundary condition is applied. + """ + return self.points + + def error(self, bc_inputs, bc_outputs, **kwargs): + """ + Calculate the error for the operator boundary condition. + + This method applies the operator function to the boundary inputs and outputs, + then computes the difference between the function output and the target values. + + Args: + bc_inputs (Dict[str, bst.typing.ArrayLike]): Input values at the boundary points. + bc_outputs (Dict[str, bst.typing.ArrayLike]): Output values at the boundary points. + **kwargs: Additional keyword arguments to be passed to the operator function. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the computed error values + for each component of the boundary condition. + """ + outs = self.func(bc_inputs, bc_outputs) + return { + component: outs[component] - self.values[component] + for component in outs.keys() + } + + +class Interface2DBC(BC): + """2D interface boundary condition. + + This BC applies to the case with the following conditions: + (1) the network output has two elements, i.e., output = [y1, y2], + (2) the 2D geometry is ``experimental.geometry.Rectangle`` or ``experimental.geometry.Polygon``, which has two edges of the same length, + (3) uniform boundary points are used, i.e., in ``experimental.problem.PDE`` or ``experimental.problem.TimePDE``, ``train_distribution="uniform"``. + For a pair of points on the two edges, compute ``<output_1, d1>`` for the point on the first edge + and ``<output_2, d2>`` for the point on the second edge in the n/t direction ('n' for normal or 't' for tangent). + Here, ``<v1, v2>`` is the dot product between vectors v1 and v2; + and d1 and d2 are the n/t vectors of the first and second edges, respectively.
+ In the normal case, d1 and d2 are the outward normal vectors; + and in the tangent case, d1 and d2 are the outward normal vectors rotated 90 degrees clockwise. + The points on the two edges are paired as follows: the boundary points on one edge are sampled clockwise, + and the points on the other edge are sampled counterclockwise. Then, compare the sum with 'values', + i.e., the error is calculated as ``<output_1, d1> + <output_2, d2> - values``, + where 'values' is the argument `func` evaluated on the first edge. + + Args: + func: the target discontinuity between edges, evaluated on the first edge, + e.g., ``func=lambda x: 0`` means no discontinuity is wanted. + on_boundary1: First edge func. (x, Geometry.on_boundary(x)) -> True/False. + on_boundary2: Second edge func. (x, Geometry.on_boundary(x)) -> True/False. + direction (string): "normal" or "tangent". + """ + + def __init__( + self, + func: Callable[[X, ...], F] | Callable[[X], F], + on_boundary1: Callable[[X, np.array], np.array] = lambda x, on: on, + on_boundary2: Callable[[X, np.array], np.array] = lambda x, on: on, + direction: str = "normal", + ): + super().__init__(lambda x, on: on) + + self.func = utils.return_tensor(func) + self.on_boundary1 = lambda x, on: np.array( + [on_boundary1(x[i], on[i]) for i in range(len(x))] + ) + self.on_boundary2 = lambda x, on: np.array( + [on_boundary2(x[i], on[i]) for i in range(len(x))] + ) + self.direction = direction + + @utils.check_not_none("geometry") + def collocation_points(self, X): + on_boundary = self.geometry.on_boundary(X) + X1 = X[self.on_boundary1(X, on_boundary)] + X2 = X[self.on_boundary2(X, on_boundary)] + # Flip order of X2 when experimental.geometry.Polygon is used + if self.geometry.__class__.__name__ == "Polygon": + X2 = np.flip(X2, axis=0) + return np.vstack((X1, X2)) + + @utils.check_not_none("geometry") + def error(self, bc_inputs, bc_outputs, **kwargs): + mid = bc_inputs.shape[0] // 2 + if bc_inputs.shape[0] % 2 != 0: + raise RuntimeError( + "There is a different number of points on each edge,\n " + "this is likely because the chosen edges do not have the same length."
+ ) + values = self.func(bc_inputs[:mid], **kwargs) + if np.ndim(values) == 2 and np.shape(values)[1] != 1: + raise RuntimeError("BC function should return an array of shape N by 1") + left_n = self.geometry.boundary_normal(bc_inputs[:mid]) + right_n = self.geometry.boundary_normal(bc_inputs[mid:]) + if self.direction == "normal": + left_side = bc_outputs[:mid, :] + right_side = bc_outputs[mid:, :] + left_values = u.math.sum(left_side * left_n, 1, keepdims=True) + right_values = u.math.sum(right_side * right_n, 1, keepdims=True) + + elif self.direction == "tangent": + # Tangent vector is [n[1],-n[0]] on edge 1 + left_side1 = bc_outputs[:mid, 0:1] + left_side2 = bc_outputs[:mid, 1:2] + right_side1 = bc_outputs[mid:, 0:1] + right_side2 = bc_outputs[mid:, 1:2] + left_values_1 = u.math.sum(left_side1 * left_n[:, 1:2], 1, keepdims=True) + left_values_2 = u.math.sum(-left_side2 * left_n[:, 0:1], 1, keepdims=True) + left_values = left_values_1 + left_values_2 + right_values_1 = u.math.sum(right_side1 * right_n[:, 1:2], 1, keepdims=True) + right_values_2 = u.math.sum( + -right_side2 * right_n[:, 0:1], 1, keepdims=True + ) + right_values = right_values_1 + right_values_2 + + else: + raise ValueError("Invalid direction, must be 'normal' or 'tangent'.") + + return left_values + right_values - values diff --git a/deepxde/experimental/icbc/initial_conditions.py b/deepxde/experimental/icbc/initial_conditions.py new file mode 100644 index 000000000..e9c608e30 --- /dev/null +++ b/deepxde/experimental/icbc/initial_conditions.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import Callable, Dict + +import brainstate as bst +import jax +import numpy as np + +from .base import ICBC + +__all__ = ["IC"] + + +class IC(ICBC): + """ + Represents the Initial Conditions (IC) for a differential equation. + + This class defines and handles the initial conditions of the form: + y([x, t0]) = func([x, t0]), where func is a user-defined function. + + Args: + func (Callable[[Dict, ...], Dict] | Callable[[Dict], Dict]): A function that returns the initial conditions. + This function should take a dictionary of collocation points and + return a dictionary of initial conditions. For example: + import brainunit as u + def func(x): + return {'y': -u.math.sin(np.pi * x['x'] / u.meter) * u.meter / u.second} + on_initial (Callable[[Dict, np.array], np.array], optional): A filter function for initial conditions. + This function should take a dictionary of collocation points and + return a boolean array indicating whether the points are initial conditions. + Defaults to lambda x, on: on. For example: + def on_initial(x, on): + return on + """ + + def __init__( + self, + func: Callable[[Dict, ...], Dict] | Callable[[Dict], Dict], + on_initial: Callable[[Dict, np.array], np.array] = lambda x, on: on, + ): + self.func = func + self.on_initial = lambda x, on: jax.vmap(on_initial)(x, on) + + def filter(self, X): + """ + Filters the collocation points for initial conditions. + + Args: + X (Dict): A dictionary of collocation points. + + Returns: + Dict: Filtered collocation points that satisfy the initial conditions. + """ + # the "geometry" should be "TimeDomain" or "GeometryXTime" + positions = self.on_initial(X, self.geometry.on_initial(X)) + return jax.tree.map(lambda x: x[positions], X) + + def collocation_points(self, X): + """ + Returns the collocation points for initial conditions. + + Args: + X (Dict): A dictionary of collocation points.
+ + Returns: + Dict: Collocation points that satisfy the initial conditions. + """ + return self.filter(X) + + def error(self, inputs, outputs, **kwargs) -> Dict[str, bst.typing.ArrayLike]: + """ + Calculates the error for initial conditions. + + This method compares the initial conditions with the outputs to compute the error. + + Args: + inputs (Dict): A dictionary of collocation points. + outputs (Dict): A dictionary of collocation values. + **kwargs: Additional keyword arguments to be passed to the func method. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary containing the errors for each variable. + The keys correspond to the variable names, and the values are the computed errors. + """ + values = self.func(inputs, **kwargs) + errors = dict() + for key, value in values.items(): + errors[key] = outputs[key] - value + return errors diff --git a/deepxde/experimental/metrics.py b/deepxde/experimental/metrics.py new file mode 100644 index 000000000..1df36a19c --- /dev/null +++ b/deepxde/experimental/metrics.py @@ -0,0 +1,307 @@ +import brainunit as u +import jax + +__all__ = [ + "accuracy", + "l2_relative_error", + "nanl2_relative_error", + "mean_l2_relative_error", + "mean_squared_error", + "mean_absolute_percentage_error", + "max_absolute_percentage_error", + "absolute_percentage_error_std", +] + + +def _accuracy(y_true, y_pred): + return u.math.mean( + u.math.equal(u.math.argmax(y_pred, axis=-1), u.math.argmax(y_true, axis=-1)) + ) + + +def accuracy(y_true, y_pred): + """ + Computes accuracy across nested structures of labels and predictions. + + This function calculates the accuracy by comparing the predicted labels + with the true labels. It can handle nested structures of data. + + Parameters: + ----------- + y_true : array_like or nested structure + The true labels or ground truth values. Can be a single array or a + nested structure of arrays. + y_pred : array_like or nested structure + The predicted labels or values. Should have the same structure as y_true. + + Returns: + -------- + float or nested structure + The computed accuracy. If the input is a nested structure, the output + will have the same structure with accuracy values for each leaf node. + """ + return jax.tree_util.tree_map(_accuracy, y_true, y_pred, is_leaf=u.math.is_quantity) + + +def _l2_relative_error(y_true, y_pred): + return u.linalg.norm(y_true - y_pred) / u.linalg.norm(y_true) + + +def l2_relative_error(y_true, y_pred): + """ + Computes L2 relative error across nested structures of labels and predictions. + + This function calculates the L2 relative error between true values and predicted values. + It can handle nested structures of data by applying the calculation to each leaf node. + + Parameters: + ----------- + y_true : array_like or nested structure + The true values or ground truth. Can be a single array or a nested structure of arrays. + y_pred : array_like or nested structure + The predicted values. Should have the same structure as y_true. + + Returns: + -------- + float or nested structure + The computed L2 relative error. If the input is a nested structure, the output + will have the same structure with L2 relative error values for each leaf node. 
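+ + Example: + A plain-array sketch (quantities with units work the same way): + + >>> import jax.numpy as jnp + >>> l2_relative_error({'y': jnp.array([1.0, 2.0])}, {'y': jnp.array([1.0, 2.5])}) # doctest: +SKIP + {'y': Array(0.2236068, dtype=float32)}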
+ """ + return jax.tree_util.tree_map( + _l2_relative_error, y_true, y_pred, is_leaf=u.math.is_quantity + ) + + +def _nanl2_relative_error(y_true, y_pred): + err = y_true - y_pred + err = u.math.nan_to_num(err) + y_true = u.math.nan_to_num(y_true) + return u.linalg.norm(err) / u.linalg.norm(y_true) + + +def nanl2_relative_error(y_true, y_pred): + """ + Computes L2 relative error across nested structures of labels and predictions, + handling NaN values. + + This function calculates the L2 relative error between true values and predicted values, + treating NaN values as zeros. It can handle nested structures of data by applying + the calculation to each leaf node. + + Parameters: + ----------- + y_true : array_like or nested structure + The true values or ground truth. Can be a single array or a nested structure of arrays. + May contain NaN values. + y_pred : array_like or nested structure + The predicted values. Should have the same structure as y_true. + May contain NaN values. + + Returns: + -------- + float or nested structure + The computed L2 relative error with NaN handling. If the input is a nested structure, + the output will have the same structure with L2 relative error values for each leaf node. + """ + return jax.tree_util.tree_map( + _nanl2_relative_error, y_true, y_pred, is_leaf=u.math.is_quantity + ) + + +def _mean_l2_relative_error(y_true, y_pred): + return u.math.mean( + u.linalg.norm(y_true - y_pred, axis=1) / u.linalg.norm(y_true, axis=1) + ) + + +def mean_l2_relative_error(y_true, y_pred): + """ + Computes mean L2 relative error across nested structures of labels and predictions. + + This function calculates the mean L2 relative error between true values and predicted values. + It can handle nested structures of data by applying the calculation to each leaf node. + + Parameters: + ----------- + y_true : array_like or nested structure + The true values or ground truth. Can be a single array or a nested structure of arrays. + y_pred : array_like or nested structure + The predicted values. Should have the same structure as y_true. + + Returns: + -------- + float or nested structure + The computed mean L2 relative error. If the input is a nested structure, the output + will have the same structure with mean L2 relative error values for each leaf node. + """ + return jax.tree_util.tree_map( + _mean_l2_relative_error, y_true, y_pred, is_leaf=u.math.is_quantity + ) + + +def _absolute_percentage_error(y_true, y_pred): + return 100 * u.math.abs((y_true - y_pred) / u.math.abs(y_true)) + + +def mean_absolute_percentage_error(y_true, y_pred): + """ + Computes mean absolute percentage error across nested structures of labels and predictions. + + This function calculates the mean absolute percentage error between true values and predicted values. + It can handle nested structures of data by applying the calculation to each leaf node. + + Parameters: + ----------- + y_true : array_like or nested structure + The true values or ground truth. Can be a single array or a nested structure of arrays. + y_pred : array_like or nested structure + The predicted values. Should have the same structure as y_true. + + Returns: + -------- + float or nested structure + The computed mean absolute percentage error. If the input is a nested structure, the output + will have the same structure with mean absolute percentage error values for each leaf node. 
+ """ + return jax.tree_util.tree_map( + lambda x, y: _absolute_percentage_error(x, y).mean(), + y_true, + y_pred, + is_leaf=u.math.is_quantity, + ) + + +def max_absolute_percentage_error(y_true, y_pred): + """ + Computes maximum absolute percentage error across nested structures of labels and predictions. + + This function calculates the maximum absolute percentage error between true values and predicted values. + It can handle nested structures of data by applying the calculation to each leaf node. + + Parameters: + ----------- + y_true : array_like or nested structure + The true values or ground truth. Can be a single array or a nested structure of arrays. + y_pred : array_like or nested structure + The predicted values. Should have the same structure as y_true. + + Returns: + -------- + float or nested structure + The computed maximum absolute percentage error. If the input is a nested structure, the output + will have the same structure with maximum absolute percentage error values for each leaf node. + """ + return jax.tree_util.tree_map( + lambda x, y: _absolute_percentage_error(x, y).max(), + y_true, + y_pred, + is_leaf=u.math.is_quantity, + ) + + +def absolute_percentage_error_std(y_true, y_pred): + """ + Computes standard deviation of absolute percentage error across nested structures of labels and predictions. + + This function calculates the standard deviation of the absolute percentage error between true values + and predicted values. It can handle nested structures of data by applying the calculation to each leaf node. + + Parameters: + ----------- + y_true : array_like or nested structure + The true values or ground truth. Can be a single array or a nested structure of arrays. + y_pred : array_like or nested structure + The predicted values. Should have the same structure as y_true. + + Returns: + -------- + float or nested structure + The computed standard deviation of absolute percentage error. If the input is a nested structure, + the output will have the same structure with standard deviation values for each leaf node. + """ + return jax.tree_util.tree_map( + lambda x, y: _absolute_percentage_error(x, y).std(), + y_true, + y_pred, + is_leaf=u.math.is_quantity, + ) + + +def _mean_squared_error(y_true, y_pred): + return u.math.mean(u.math.square(y_true - y_pred)) + + +def mean_squared_error(y_true, y_pred): + """ + Computes mean squared error across nested structures of labels and predictions. + + This function calculates the mean squared error between true values and predicted values. + It can handle nested structures of data by applying the calculation to each leaf node. + + Parameters: + ----------- + y_true : array_like or nested structure + The true values or ground truth. Can be a single array or a nested structure of arrays. + y_pred : array_like or nested structure + The predicted values. Should have the same structure as y_true. + + Returns: + -------- + float or nested structure + The computed mean squared error. If the input is a nested structure, the output + will have the same structure with mean squared error values for each leaf node. + """ + return jax.tree_util.tree_map( + _mean_squared_error, y_true, y_pred, is_leaf=u.math.is_quantity + ) + + +def get(identifier): + """ + Retrieves a metric function based on the provided identifier. + + This function maps string identifiers to their corresponding metric functions + or returns the function if a callable is provided directly. 
+ + Parameters: + ----------- + identifier : str or callable + A string identifier for a predefined metric function or a callable metric function. + Accepted string identifiers include: + - "accuracy" + - "l2 relative error" + - "nanl2 relative error" + - "mean l2 relative error" + - "mean squared error" (also "MSE" or "mse") + - "MAPE" + - "max APE" + - "APE SD" + + Returns: + -------- + callable + The metric function corresponding to the provided identifier. + + Raises: + ------- + ValueError + If the provided identifier is neither a recognized string nor a callable. + """ + metric_identifier = { + "accuracy": accuracy, + "l2 relative error": l2_relative_error, + "nanl2 relative error": nanl2_relative_error, + "mean l2 relative error": mean_l2_relative_error, + "mean squared error": mean_squared_error, + "MSE": mean_squared_error, + "mse": mean_squared_error, + "MAPE": mean_absolute_percentage_error, + "max APE": max_absolute_percentage_error, + "APE SD": absolute_percentage_error_std, + } + + if isinstance(identifier, str): + return metric_identifier[identifier] + if callable(identifier): + return identifier + raise ValueError("Could not interpret metric function identifier:", identifier) diff --git a/deepxde/experimental/nn/__init__.py b/deepxde/experimental/nn/__init__.py new file mode 100644 index 000000000..e47881f7d --- /dev/null +++ b/deepxde/experimental/nn/__init__.py @@ -0,0 +1,34 @@ +"""The ``experimental.nn`` package contains framework-specific implementations for different +neural networks. + +Users can directly import ``experimental.nn.`` (e.g., ``experimental.nn.FNN``), and +the package will dispatch the network name to the actual implementation according to the +backend framework currently in use. + +Note that there are coverage differences among frameworks. If you encounter an +``AttributeError: module 'experimental.nn.XXX' has no attribute 'XXX'`` or ``ImportError: +cannot import name 'XXX' from 'experimental.nn.XXX'`` error, that means the network is not +available to the current backend. If you wish a module to appear in DeepXDE, please +create an issue. If you want to contribute a NN module, please create a pull request. +""" + +__all__ = [ + "DictToArray", + "ArrayToDict", + "Model", + "NN", + "FNN", + "DeepONet", + "DeepONetCartesianProd", + "MIONetCartesianProd", + "PFNN", + "PODDeepONet", + "PODMIONet", +] + +from .base import NN +from .convert import DictToArray, ArrayToDict +from .deeponet import DeepONet, DeepONetCartesianProd, PODDeepONet +from .fnn import FNN, PFNN +from .mionet import MIONetCartesianProd, PODMIONet +from .model import Model diff --git a/deepxde/experimental/nn/base.py b/deepxde/experimental/nn/base.py new file mode 100644 index 000000000..566dc1884 --- /dev/null +++ b/deepxde/experimental/nn/base.py @@ -0,0 +1,90 @@ +from typing import Optional, Callable + +import brainstate as bst +import jax.tree + + +class NN(bst.nn.Module): + """Base class for all neural network modules.""" + + def __init__( + self, + input_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + ): + """ + Initialize the NN class. + + Parameters: + ----------- + input_transform : Optional[Callable], default=None + A callable that transforms the input before it's passed to the network. + output_transform : Optional[Callable], default=None + A callable that transforms the output after it's produced by the network. 
+
+        Returns:
+        --------
+        None
+        """
+        super().__init__()
+        self.regularization = None
+        self._input_transform = input_transform
+        self._output_transform = output_transform
+
+    def apply_feature_transform(self, transform):
+        """
+        Compute the features by applying a transform to the network inputs.
+
+        This method sets the input transform function, which is applied before
+        the input is passed to the network, i.e., ``features = transform(inputs)``.
+        Then, ``outputs = network(features)``.
+
+        Parameters:
+        -----------
+        transform : Callable
+            The transform function to be applied to the inputs.
+
+        Returns:
+        --------
+        None
+        """
+        self._input_transform = transform
+
+    def apply_output_transform(self, transform):
+        """
+        Apply a transform to the network outputs.
+
+        This method sets the output transform function, which is applied after
+        the network produces its output, i.e., ``outputs = transform(inputs, outputs)``.
+
+        Parameters:
+        -----------
+        transform : Callable
+            The transform function to be applied to the outputs.
+
+        Returns:
+        --------
+        None
+        """
+        self._output_transform = transform
+
+    def num_trainable_parameters(self):
+        """
+        Evaluate the number of trainable parameters for the NN.
+
+        This method calculates the total number of trainable parameters in the neural network
+        by iterating through all parameters in the network's state.
+
+        Parameters:
+        -----------
+        None
+
+        Returns:
+        --------
+        int
+            The total number of trainable parameters in the neural network.
+        """
+        n_param = 0
+        for key, val in self.states(bst.ParamState).items():
+            # Sum the sizes of all array leaves in this parameter state.
+            n_param += sum(v.size for v in jax.tree.leaves(val))
+        return n_param
diff --git a/deepxde/experimental/nn/convert.py b/deepxde/experimental/nn/convert.py
new file mode 100644
index 000000000..41a7e3154
--- /dev/null
+++ b/deepxde/experimental/nn/convert.py
@@ -0,0 +1,195 @@
+from typing import Dict
+
+import brainstate as bst
+import brainunit as u
+
+__all__ = [
+    "DictToArray",
+    "ArrayToDict",
+]
+
+
+def dict_to_array(d: Dict[str, bst.typing.ArrayLike], axis: int = 1):
+    """
+    Convert a dictionary of array-like values to a single stacked array.
+
+    This function takes a dictionary where each value is an array-like object,
+    and stacks all these arrays along the specified axis to create a single
+    output array.
+
+    Args:
+        d (Dict[str, bst.typing.ArrayLike]): A dictionary where keys are strings
+            and values are array-like objects (e.g., numpy arrays, lists, etc.).
+        axis (int, optional): The axis along which the arrays should be stacked.
+            Default is 1.
+
+    Returns:
+        ndarray: A single array containing all the input arrays stacked along
+            the specified axis. The order of stacking is determined by the
+            order of the keys in the input dictionary.
+
+    Example:
+        >>> d = {'a': [1, 2, 3], 'b': [4, 5, 6]}
+        >>> dict_to_array(d)
+        array([[1, 4],
+               [2, 5],
+               [3, 6]])
+    """
+    keys = tuple(d.keys())
+    return u.math.stack([d[key] for key in keys], axis=axis)
+
+
+class DictToArray(bst.nn.Module):
+    """
+    DictToArray layer, scaling the input data according to the given units, and merging them into an array.
+
+    This layer takes a dictionary of array-like inputs, scales them according to specified units,
+    and concatenates them into a single array along a specified axis.
+
+    Args:
+        axis (int, optional): The axis along which to concatenate the input arrays. Defaults to -1.
+        **units: Keyword arguments specifying the units for each input. Each unit should be an
+            instance of ``brainunit.Unit`` or None.
+ + Attributes: + axis (int): The axis along which concatenation is performed. + units (dict): A dictionary mapping input keys to their corresponding units. + in_size (int): The number of input elements (length of units dictionary). + out_size (int): The number of output elements (same as in_size). + """ + + def __init__(self, axis: int = -1, **units): + super().__init__() + + # axis + assert isinstance( + axis, int + ), f"DictToArray axis must be an integer. Please check the input values." + self.axis = axis + + # unit scale + self.units = units + for val in units.values(): + assert isinstance(val, u.Unit) or val is None, ( + f"DictToArray values must be a unit or None. " + "Please check the input values." + ) + + self.in_size = len(units) + self.out_size = len(units) + + def update(self, x: Dict[str, bst.typing.ArrayLike]): + """ + Scales the input dictionary values according to their units and concatenates them into an array. + + Args: + x (Dict[str, bst.typing.ArrayLike]): A dictionary of input arrays to be scaled and concatenated. + The keys should match those specified in the units dictionary during initialization. + + Returns: + ndarray: A single array containing all the scaled input arrays concatenated along the specified axis. + + Raises: + AssertionError: If the input dictionary keys don't match the units dictionary keys, + or if the input values are not of the expected type (Quantity or dimensionless). + """ + assert set(x.keys()) == set(self.units.keys()), ( + f"DictToArray keys mismatch. " + f"{set(x.keys())} != {set(self.units.keys())}." + ) + + # scale the input + x_dict = dict() + for key in self.units.keys(): + val = x[key] + if isinstance(self.units[key], u.Unit): + assert ( + isinstance(val, u.Quantity) + or self.units[key].dim == u.DIMENSIONLESS + ), ( + f"DictToArray values must be a quantity. " + "Please check the input values." + ) + x_dict[key] = ( + val.to_decimal(self.units[key]) + if isinstance(val, u.Quantity) + else val + ) + else: + x_dict[key] = u.maybe_decimal(val) + + # convert to array + arr = dict_to_array(x_dict, axis=self.axis) + return arr + + +class ArrayToDict(bst.nn.Module): + """ + Output layer, splitting the output data into a dict and assign the corresponding units. + + This class takes an input array and splits it into a dictionary, where each key-value pair + represents a specific output with its corresponding unit. + + Args: + axis (int, optional): The axis along which to split the output data. Defaults to -1. + **units: Keyword arguments specifying the units for each output. Each unit should be an + instance of ``brainunit.Unit`` or None. + + Attributes: + axis (int): The axis along which splitting is performed. + units (dict): A dictionary mapping output keys to their corresponding units. + in_size (int): The number of input elements (length of units dictionary). + out_size (int): The number of output elements (same as in_size). + """ + + def __init__(self, axis: int = -1, **units): + super().__init__() + + assert isinstance(axis, int), f"Output axis must be an integer. " + self.axis = axis + self.units = units + for val in units.values(): + assert isinstance(val, u.Unit) or val is None, ( + f"Input values must be a unit or None. " + "Please check the input values." + ) + self.in_size = len(units) + self.out_size = len(units) + + def update(self, arr: bst.typing.ArrayLike) -> Dict[str, bst.typing.ArrayLike]: + """ + Splits the input array into a dictionary and assigns the corresponding units. 
+ + This method takes an input array, splits it along the specified axis, and creates + a dictionary where each key-value pair represents a specific output with its + corresponding unit. + + Args: + arr (bst.typing.ArrayLike): The input array to be split and converted into a dictionary. + + Returns: + Dict[str, bst.typing.ArrayLike]: A dictionary where keys are the output names and + values are the corresponding split arrays, potentially with units applied. + + Raises: + AssertionError: If the shape of the input array along the specified axis doesn't + match the number of units provided during initialization. + """ + assert arr.shape[self.axis] == len(self.units), ( + f"The number of columns of x must be " + f"equal to the number of units. " + f"Got {arr.shape[self.axis]} != {len(self.units)}. " + "Please check the input values." + ) + shape = list(arr.shape) + shape.pop(self.axis) + xs = u.math.split(arr, len(self.units), axis=self.axis) + + keys = tuple(self.units.keys()) + units = tuple(self.units.values()) + res = dict() + for key, unit, x in zip(keys, units, xs): + res[key] = u.math.squeeze(x, axis=self.axis) + if unit is not None: + res[key] *= unit + return res diff --git a/deepxde/experimental/nn/deeponet.py b/deepxde/experimental/nn/deeponet.py new file mode 100644 index 000000000..75b3087eb --- /dev/null +++ b/deepxde/experimental/nn/deeponet.py @@ -0,0 +1,347 @@ +from typing import Union, Callable, Sequence, Dict, Optional + +import brainstate as bst +import brainunit as u + +from deepxde.nn.deeponet_strategy import ( + DeepONetStrategy, + SingleOutputStrategy, + IndependentStrategy, + SplitBothStrategy, + SplitBranchStrategy, + SplitTrunkStrategy, +) +from deepxde.experimental.utils import get_activation +from .base import NN +from .fnn import FNN + +strategies = { + None: SingleOutputStrategy, + "independent": IndependentStrategy, + "split_both": SplitBothStrategy, + "split_branch": SplitBranchStrategy, + "split_trunk": SplitTrunkStrategy, +} + +__all__ = ["DeepONet", "DeepONetCartesianProd", "PODDeepONet"] + + +class DeepONet(NN): + """ + Deep operator network. + + `Lu et al. Learning nonlinear operators via DeepONet based on the universal + approximation theorem of operators. Nat Mach Intell, 2021. + `_ + + Args: + layer_sizes_branch: A list of integers as the width of a fully connected network, + or `(dim, f)` where `dim` is the input dimension and `f` is a network + function. The width of the last layer in the branch and trunk net + should be the same for all strategies except "split_branch" and "split_trunk". + layer_sizes_trunk (list): A list of integers as the width of a fully connected + network. + activation: If `activation` is a ``string``, then the same activation is used in + both trunk and branch nets. If `activation` is a ``dict``, then the trunk + net uses the activation `activation["trunk"]`, and the branch net uses + `activation["branch"]`. + num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1, + `multi_output_strategy` below should be set. + multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or + "split_trunk". It makes sense to set in case of multiple outputs. + + - None + Classical implementation of DeepONet with a single output. + Cannot be used with `num_outputs` > 1. + + - independent + Use `num_outputs` independent DeepONets, and each DeepONet outputs only + one function. 
+ + - split_both + Split the outputs of both the branch net and the trunk net into `num_outputs` + groups, and then the kth group outputs the kth solution. + + - split_branch + Split the branch net and share the trunk net. The width of the last layer + in the branch net should be equal to the one in the trunk net multiplied + by the number of outputs. + + - split_trunk + Split the trunk net and share the branch net. The width of the last layer + in the trunk net should be equal to the one in the branch net multiplied + by the number of outputs. + """ + + def __init__( + self, + layer_sizes_branch: Sequence[int], + layer_sizes_trunk: Sequence[int], + activation: Union[str, Callable, Dict[str, str], Dict[str, Callable]], + kernel_initializer: bst.init.Initializer = bst.init.KaimingUniform(), + num_outputs: int = 1, + multi_output_strategy=None, + input_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + ): + super().__init__( + input_transform=input_transform, output_transform=output_transform + ) + + # activation function + if isinstance(activation, dict): + self.activation_branch = get_activation(activation["branch"]) + self.activation_trunk = get_activation(activation["trunk"]) + else: + self.activation_branch = self.activation_trunk = get_activation(activation) + + # initialize kernel + self.kernel_initializer = kernel_initializer + + self.num_outputs = num_outputs + if self.num_outputs == 1: + if multi_output_strategy is not None: + raise ValueError( + "num_outputs is set to 1, but multi_output_strategy is not None." + ) + elif multi_output_strategy is None: + multi_output_strategy = "independent" + print( + f"Warning: There are {num_outputs} outputs, but no multi_output_strategy selected. " + 'Use "independent" as the multi_output_strategy.' + ) + self.multi_output_strategy: DeepONetStrategy = strategies[ + multi_output_strategy + ](self) + + self.branch, self.trunk = self.multi_output_strategy.build( + layer_sizes_branch, layer_sizes_trunk + ) + self.b = bst.ParamState([0.0 for _ in range(self.num_outputs)]) + + def build_branch_net(self, layer_sizes_branch) -> FNN: + # User-defined network + if callable(layer_sizes_branch[1]): + return layer_sizes_branch[1] + # Fully connected network + return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer) + + def build_trunk_net(self, layer_sizes_trunk) -> FNN: + return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer) + + def merge_branch_trunk(self, x_func, x_loc, index): + y = u.math.sum(x_func * x_loc, axis=-1, keepdims=True) + y += self.b.value[index] + return y + + @staticmethod + def concatenate_outputs(ys): + return u.math.concatenate(ys, axis=1) + + def update(self, inputs): + x_func = inputs[0] + x_loc = inputs[1] + # Trunk net input transform + if self._input_transform is not None: + x_loc = self._input_transform(x_loc) + x = self.multi_output_strategy.call(x_func, x_loc) + if self._output_transform is not None: + x = self._output_transform(inputs, x) + return x + + +class DeepONetCartesianProd(NN): + """ + Deep operator network for dataset in the format of Cartesian product. + + Args: + layer_sizes_branch: A list of integers as the width of a fully connected network, + or `(dim, f)` where `dim` is the input dimension and `f` is a network + function. The width of the last layer in the branch and trunk net + should be the same for all strategies except "split_branch" and "split_trunk". 
layer_sizes_trunk (list): A list of integers as the width of a fully connected
+            network.
+        activation: If `activation` is a ``string``, then the same activation is used in
+            both trunk and branch nets. If `activation` is a ``dict``, then the trunk
+            net uses the activation `activation["trunk"]`, and the branch net uses
+            `activation["branch"]`.
+        num_outputs (integer): Number of outputs. In case of multiple outputs, i.e., `num_outputs` > 1,
+            `multi_output_strategy` below should be set.
+        multi_output_strategy (str or None): ``None``, "independent", "split_both", "split_branch" or
+            "split_trunk". It makes sense to set in case of multiple outputs.
+
+            - None
+              Classical implementation of DeepONet with a single output.
+              Cannot be used with `num_outputs` > 1.
+
+            - independent
+              Use `num_outputs` independent DeepONets, and each DeepONet outputs only
+              one function.
+
+            - split_both
+              Split the outputs of both the branch net and the trunk net into `num_outputs`
+              groups, and then the kth group outputs the kth solution.
+
+            - split_branch
+              Split the branch net and share the trunk net. The width of the last layer
+              in the branch net should be equal to the one in the trunk net multiplied
+              by the number of outputs.
+
+            - split_trunk
+              Split the trunk net and share the branch net. The width of the last layer
+              in the trunk net should be equal to the one in the branch net multiplied
+              by the number of outputs.
+    """
+
+    def __init__(
+        self,
+        layer_sizes_branch: Sequence[int],
+        layer_sizes_trunk: Sequence[int],
+        activation: Union[str, Callable, Dict[str, str], Dict[str, Callable]],
+        kernel_initializer: bst.init.Initializer = bst.init.KaimingUniform(),
+        num_outputs: int = 1,
+        multi_output_strategy=None,
+        input_transform: Optional[Callable] = None,
+        output_transform: Optional[Callable] = None,
+    ):
+        super().__init__(
+            input_transform=input_transform, output_transform=output_transform
+        )
+        if isinstance(activation, dict):
+            # Wrap the branch activation as well, mirroring ``DeepONet`` above.
+            self.activation_branch = get_activation(activation["branch"])
+            self.activation_trunk = get_activation(activation["trunk"])
+        else:
+            self.activation_branch = self.activation_trunk = get_activation(activation)
+        self.kernel_initializer = kernel_initializer
+
+        self.num_outputs = num_outputs
+        if self.num_outputs == 1:
+            if multi_output_strategy is not None:
+                raise ValueError(
+                    "num_outputs is set to 1, but multi_output_strategy is not None."
+                )
+        elif multi_output_strategy is None:
+            multi_output_strategy = "independent"
+            print(
+                f"Warning: There are {num_outputs} outputs, but no multi_output_strategy selected. "
+                'Use "independent" as the multi_output_strategy.'
+ ) + self.multi_output_strategy = strategies[multi_output_strategy](self) + + self.branch, self.trunk = self.multi_output_strategy.build( + layer_sizes_branch, layer_sizes_trunk + ) + self.b = bst.ParamState([0.0 for _ in range(self.num_outputs)]) + + def build_branch_net(self, layer_sizes_branch): + # User-defined network + if callable(layer_sizes_branch[1]): + return layer_sizes_branch[1] + # Fully connected network + return FNN(layer_sizes_branch, self.activation_branch, self.kernel_initializer) + + def build_trunk_net(self, layer_sizes_trunk): + return FNN(layer_sizes_trunk, self.activation_trunk, self.kernel_initializer) + + def merge_branch_trunk(self, x_func, x_loc, index): + y = u.math.einsum("bi,ni->bn", x_func, x_loc) + y += self.b.value[index] + return y + + @staticmethod + def concatenate_outputs(ys): + return u.math.stack(ys, axis=2) + + def update(self, inputs): + x_func = inputs[0] + x_loc = inputs[1] + # Trunk net input transform + if self._input_transform is not None: + x_loc = self._input_transform(x_loc) + x = self.multi_output_strategy.call(x_func, x_loc) + if self._output_transform is not None: + x = self._output_transform(inputs, x) + return x if x.ndim == 3 else x[..., None] + + +class PODDeepONet(NN): + """ + Deep operator network with proper orthogonal decomposition (POD) for dataset in + the format of Cartesian product. + + Args: + pod_basis: POD basis used in the trunk net. + layer_sizes_branch: A list of integers as the width of a fully connected network, + or `(dim, f)` where `dim` is the input dimension and `f` is a network + function. The width of the last layer in the branch and trunk net should be + equal. + activation: If `activation` is a ``string``, then the same activation is used in + both trunk and branch nets. If `activation` is a ``dict``, then the trunk + net uses the activation `activation["trunk"]`, and the branch net uses + `activation["branch"]`. + layer_sizes_trunk (list): A list of integers as the width of a fully connected + network. If ``None``, then only use POD basis as the trunk net. + + References: + `L. Lu, X. Meng, S. Cai, Z. Mao, S. Goswami, Z. Zhang, & G. E. Karniadakis. A + comprehensive and fair comparison of two neural operators (with practical + extensions) based on FAIR data. arXiv preprint arXiv:2111.05512, 2021 + `_. 
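+
+    Examples:
+        A minimal sketch; the POD basis shape ``(n_points, n_modes)`` and the
+        layer sizes are illustrative::
+
+            >>> import numpy as np
+            >>> pod_basis = np.random.rand(100, 8)  # hypothetical POD basis
+            >>> net = PODDeepONet(pod_basis, [50, 64, 8], "tanh")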
+ """ + + def __init__( + self, + pod_basis, + layer_sizes_branch: Sequence[int], + activation: Union[str, Callable, Dict[str, str], Dict[str, Callable]], + kernel_initializer: bst.init.Initializer = bst.init.KaimingUniform(), + layer_sizes_trunk: Sequence[int] = None, + regularization=None, + input_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + ): + super().__init__( + input_transform=input_transform, output_transform=output_transform + ) + self.regularization = regularization # TODO: currently unused + self.pod_basis = pod_basis + if isinstance(activation, dict): + activation_branch = activation["branch"] + self.activation_trunk = get_activation(activation["trunk"]) + else: + activation_branch = self.activation_trunk = get_activation(activation) + + if callable(layer_sizes_branch[1]): + # User-defined network + self.branch = layer_sizes_branch[1] + else: + # Fully connected network + self.branch = FNN(layer_sizes_branch, activation_branch, kernel_initializer) + + self.trunk = None + if layer_sizes_trunk is not None: + self.trunk = FNN( + layer_sizes_trunk, self.activation_trunk, kernel_initializer + ) + self.b = bst.ParamState(0.0) + + def forward(self, inputs): + x_func = inputs[0] + x_loc = inputs[1] + + # Branch net to encode the input function + x_func = self.branch(x_func) + # Trunk net to encode the domain of the output function + if self.trunk is None: + # POD only + x = u.math.einsum("bi,ni->bn", x_func, self.pod_basis) + else: + x_loc = self.activation_trunk(self.trunk(x_loc)) + x = u.math.einsum( + "bi,ni->bn", x_func, u.math.concatenate((self.pod_basis, x_loc), axis=1) + ) + x += self.b.value + + if self._output_transform is not None: + x = self._output_transform(inputs, x) + return x diff --git a/deepxde/experimental/nn/fnn.py b/deepxde/experimental/nn/fnn.py new file mode 100644 index 000000000..70bfb8e7f --- /dev/null +++ b/deepxde/experimental/nn/fnn.py @@ -0,0 +1,249 @@ +from typing import Union, Callable, Sequence, Optional + +import brainstate as bst +import brainunit as u + +from deepxde.experimental.utils import get_activation +from .base import NN + + +class FNN(NN): + """ + Fully-connected neural network. + + This class implements a fully-connected neural network with customizable layer sizes, + activation functions, and optional input/output transformations. + + Args: + layer_sizes (Sequence[int]): A sequence of integers defining the number of neurons + in each layer, including input and output layers. + activation (Union[str, Callable, Sequence[str], Sequence[Callable]]): Activation + function(s) to use. Can be a single string/callable for all layers, or a + sequence of strings/callables for each layer. + kernel_initializer (bst.init.Initializer, optional): Initializer for the layer weights. + Defaults to bst.init.KaimingUniform(). + input_transform (Optional[Callable], optional): A function to transform the input + before passing it through the network. Defaults to None. + output_transform (Optional[Callable], optional): A function to transform the output + of the network. Defaults to None. + + Raises: + ValueError: If the number of activation functions doesn't match the number of layers + when a sequence of activations is provided. 
+ """ + + def __init__( + self, + layer_sizes: Sequence[int], + activation: Union[str, Callable, Sequence[str], Sequence[Callable]], + kernel_initializer: bst.init.Initializer = bst.init.KaimingUniform(), + input_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + ): + super().__init__( + input_transform=input_transform, output_transform=output_transform + ) + + # activations + if isinstance(activation, (list, tuple)): + if not (len(layer_sizes) - 1) == len(activation): + raise ValueError( + "Total number of activation functions do not match with " + "sum of hidden layers and output layer!" + ) + self.activation = list(map(get_activation, activation)) + else: + self.activation = get_activation(activation) + + # layers + self.layers = [] + for i in range(1, len(layer_sizes)): + self.layers.append( + bst.nn.Linear( + layer_sizes[i - 1], layer_sizes[i], w_init=kernel_initializer + ) + ) + + # output transform + if output_transform is not None: + self.apply_output_transform(output_transform) + + def update(self, inputs): + """ + Perform a forward pass through the neural network. + + This method applies the input transformation (if any), passes the input through + all layers of the network applying activations, and then applies the output + transformation (if any). + + Args: + inputs: The input data to be passed through the network. + + Returns: + The output of the neural network after processing the inputs. + """ + x = inputs + if self._input_transform is not None: + x = self._input_transform(x) + for j, linear in enumerate(self.layers[:-1]): + x = ( + self.activation[j](linear(x)) + if isinstance(self.activation, list) + else self.activation(linear(x)) + ) + x = self.layers[-1](x) + if self._output_transform is not None: + x = self._output_transform(inputs, x) + return x + + +class PFNN(NN): + """ + Parallel fully-connected network that uses independent sub-networks for each + network output. + + This class implements a parallel fully-connected neural network where each output + can have its own independent sub-network. This allows for more flexibility in + network architecture, especially when different outputs require different levels + of complexity. + + Args: + layer_sizes (Sequence[int]): A nested list that defines the architecture of the neural network + (how the layers are connected). If `layer_sizes[i]` is an int, it represents + one layer shared by all the outputs; if `layer_sizes[i]` is a list, it + represents `len(layer_sizes[i])` sub-layers, each of which is exclusively + used by one output. Note that `len(layer_sizes[i])` should equal the number + of outputs. Every number specifies the number of neurons in that layer. + activation (Union[str, Callable, Sequence[str], Sequence[Callable]]): Activation + function(s) to use. Can be a single string/callable for all layers, or a + sequence of strings/callables for each layer. + kernel_initializer (bst.init.Initializer, optional): Initializer for the layer weights. + Defaults to bst.init.KaimingUniform(). + input_transform (Optional[Callable], optional): A function to transform the input + before passing it through the network. Defaults to None. + output_transform (Optional[Callable], optional): A function to transform the output + of the network. Defaults to None. + + Raises: + ValueError: If the layer sizes are not properly specified or if the number of + sub-layers doesn't match the number of outputs. 
+ """ + + def __init__( + self, + layer_sizes: Sequence[int], + activation: Union[str, Callable, Sequence[str], Sequence[Callable]], + kernel_initializer: bst.init.Initializer = bst.init.KaimingUniform(), + input_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + ): + super().__init__( + input_transform=input_transform, output_transform=output_transform + ) + self.activation = get_activation(activation) + + if len(layer_sizes) <= 1: + raise ValueError("must specify input and output sizes") + if not isinstance(layer_sizes[0], int): + raise ValueError("input size must be integer") + if not isinstance(layer_sizes[-1], int): + raise ValueError("output size must be integer") + + n_output = layer_sizes[-1] + + self.layers = [] + for i in range(1, len(layer_sizes) - 1): + prev_layer_size = layer_sizes[i - 1] + curr_layer_size = layer_sizes[i] + if isinstance(curr_layer_size, (list, tuple)): + if len(curr_layer_size) != n_output: + raise ValueError( + "number of sub-layers should equal number of network outputs" + ) + if isinstance(prev_layer_size, (list, tuple)): + # e.g. [8, 8, 8] -> [16, 16, 16] + self.layers.append( + [ + bst.nn.Linear( + prev_layer_size[j], + curr_layer_size[j], + w_init=kernel_initializer, + ) + for j in range(n_output) + ] + ) + else: # e.g. 64 -> [8, 8, 8] + self.layers.append( + [ + bst.nn.Linear( + prev_layer_size, + curr_layer_size[j], + w_init=kernel_initializer, + ) + for j in range(n_output) + ] + ) + else: # e.g. 64 -> 64 + if not isinstance(prev_layer_size, int): + raise ValueError( + "cannot rejoin parallel subnetworks after splitting" + ) + self.layers.append( + bst.nn.Linear( + prev_layer_size, curr_layer_size, w_init=kernel_initializer + ) + ) + + # output layers + if isinstance(layer_sizes[-2], (list, tuple)): # e.g. [3, 3, 3] -> 3 + self.layers.append( + [ + bst.nn.Linear(layer_sizes[-2][j], 1, w_init=kernel_initializer) + for j in range(n_output) + ] + ) + else: + self.layers.append( + bst.nn.Linear(layer_sizes[-2], n_output, w_init=kernel_initializer) + ) + + def update(self, inputs): + """ + Perform a forward pass through the parallel fully-connected neural network. + + This method applies the input transformation (if any), passes the input through + all layers of the network applying activations, and then applies the output + transformation (if any). It handles both shared layers and parallel sub-networks. + + Args: + inputs: The input data to be passed through the network. + + Returns: + The output of the neural network after processing the inputs. The shape of the + output depends on the network architecture defined in the constructor. 
+ """ + + x = inputs + if self._input_transform is not None: + x = self._input_transform(x) + + for layer in self.layers[:-1]: + if isinstance(layer, list): + if isinstance(x, list): + x = [self.activation(f(x_)) for f, x_ in zip(layer, x)] + else: + x = [self.activation(f(x)) for f in layer] + else: + x = self.activation(layer(x)) + + # output layers + if isinstance(x, list): + x = u.math.concatenate( + [f(x_) for f, x_ in zip(self.layers[-1], x)], axis=-1 + ) + else: + x = self.layers[-1](x) + + if self._output_transform is not None: + x = self._output_transform(inputs, x) + return x diff --git a/deepxde/experimental/nn/mionet.py b/deepxde/experimental/nn/mionet.py new file mode 100644 index 000000000..59623ce57 --- /dev/null +++ b/deepxde/experimental/nn/mionet.py @@ -0,0 +1,269 @@ +from typing import Optional, Callable + +import brainstate as bst +import brainunit as u + +from deepxde.experimental.utils import get_activation +from .base import NN +from .fnn import FNN + + +class MIONetCartesianProd(NN): + """ + MIONet with two input functions for Cartesian product format. + """ + + def __init__( + self, + layer_sizes_branch1, + layer_sizes_branch2, + layer_sizes_trunk, + activation, + kernel_initializer, + regularization=None, + trunk_last_activation=False, + merge_operation="mul", + layer_sizes_merger=None, + output_merge_operation="mul", + layer_sizes_output_merger=None, + input_transform: Optional[Callable] = None, + output_transform: Optional[Callable] = None, + ): + super().__init__( + input_transform=input_transform, output_transform=output_transform + ) + + if isinstance(activation, dict): + self.activation_branch1 = get_activation(activation["branch1"]) + self.activation_branch2 = get_activation(activation["branch2"]) + self.activation_trunk = get_activation(activation["trunk"]) + else: + self.activation_branch1 = self.activation_branch2 = ( + self.activation_trunk + ) = get_activation(activation) + if callable(layer_sizes_branch1[1]): + # User-defined network + self.branch1 = layer_sizes_branch1[1] + else: + # Fully connected network + self.branch1 = FNN( + layer_sizes_branch1, self.activation_branch1, kernel_initializer + ) + if callable(layer_sizes_branch2[1]): + # User-defined network + self.branch2 = layer_sizes_branch2[1] + else: + # Fully connected network + self.branch2 = FNN( + layer_sizes_branch2, self.activation_branch2, kernel_initializer + ) + if layer_sizes_merger is not None: + self.activation_merger = get_activation(activation["merger"]) + if callable(layer_sizes_merger[1]): + # User-defined network + self.merger = layer_sizes_merger[1] + else: + # Fully connected network + self.merger = FNN( + layer_sizes_merger, self.activation_merger, kernel_initializer + ) + else: + self.merger = None + if layer_sizes_output_merger is not None: + self.activation_output_merger = get_activation(activation["output merger"]) + if callable(layer_sizes_output_merger[1]): + # User-defined network + self.output_merger = layer_sizes_output_merger[1] + else: + # Fully connected network + self.output_merger = FNN( + layer_sizes_output_merger, + self.activation_output_merger, + kernel_initializer, + ) + else: + self.output_merger = None + self.trunk = FNN(layer_sizes_trunk, self.activation_trunk, kernel_initializer) + self.b = bst.ParamState(0.0) + self.regularizer = regularization + self.trunk_last_activation = trunk_last_activation + self.merge_operation = merge_operation + self.output_merge_operation = output_merge_operation + + def update(self, inputs): + x_func1 = inputs[0] + 
x_func2 = inputs[1]
+        x_loc = inputs[2]
+        # Branch net to encode the input function
+        y_func1 = self.branch1(x_func1)
+        y_func2 = self.branch2(x_func2)
+        if self.merge_operation == "cat":
+            x_merger = u.math.concatenate((y_func1, y_func2), axis=-1)
+        else:
+            if y_func1.shape[-1] != y_func2.shape[-1]:
+                raise AssertionError(
+                    "Output sizes of branch1 net and branch2 net do not match."
+                )
+            if self.merge_operation == "add":
+                x_merger = y_func1 + y_func2
+            elif self.merge_operation == "mul":
+                x_merger = u.math.multiply(y_func1, y_func2)
+            else:
+                raise NotImplementedError(
+                    f"{self.merge_operation} operation to be implemented"
+                )
+        # Optional merger net
+        if self.merger is not None:
+            y_func = self.merger(x_merger)
+        else:
+            y_func = x_merger
+        # Trunk net to encode the domain of the output function
+        if self._input_transform is not None:
+            x_loc = self._input_transform(x_loc)
+        y_loc = self.trunk(x_loc)
+        if self.trunk_last_activation:
+            y_loc = self.activation_trunk(y_loc)
+        # Dot product
+        if y_func.shape[-1] != y_loc.shape[-1]:
+            raise AssertionError(
+                "Output sizes of merger net and trunk net do not match."
+            )
+        # output merger net
+        if self.output_merger is None:
+            y = u.math.einsum("ip,jp->ij", y_func, y_loc)
+        else:
+            y_func = y_func[:, None, :]
+            y_loc = y_loc[None, :]
+            if self.output_merge_operation == "mul":
+                y = u.math.multiply(y_func, y_loc)
+            elif self.output_merge_operation == "add":
+                y = y_func + y_loc
+            elif self.output_merge_operation == "cat":
+                # Tile both operands to (batch, n_loc, feat) before concatenation
+                # (``Tensor.repeat`` is a PyTorch idiom; use ``tile`` here).
+                y_func = u.math.tile(y_func, (1, y_loc.shape[1], 1))
+                y_loc = u.math.tile(y_loc, (y_func.shape[0], 1, 1))
+                y = u.math.concatenate((y_func, y_loc), axis=2)
+            shape0 = y.shape[0]
+            shape1 = y.shape[1]
+            y = y.reshape(shape0 * shape1, -1)
+            y = self.output_merger(y)
+            y = y.reshape(shape0, shape1)
+        # Add bias; ``self.b`` is a ParamState, so add its value.
+        y += self.b.value
+        if self._output_transform is not None:
+            y = self._output_transform(inputs, y)
+        return y
+
+
+class PODMIONet(NN):
+    """MIONet with two input functions and proper orthogonal decomposition (POD)
+    for Cartesian product format."""
+
+    def __init__(
+        self,
+        pod_basis,
+        layer_sizes_branch1,
+        layer_sizes_branch2,
+        activation,
+        kernel_initializer,
+        layer_sizes_trunk=None,
+        regularization=None,
+        trunk_last_activation=False,
+        merge_operation="mul",
+        layer_sizes_merger=None,
+        input_transform: Optional[Callable] = None,
+        output_transform: Optional[Callable] = None,
+    ):
+        super().__init__(
+            input_transform=input_transform, output_transform=output_transform
+        )
+
+        if isinstance(activation, dict):
+            self.activation_branch1 = get_activation(activation["branch1"])
+            self.activation_branch2 = get_activation(activation["branch2"])
+            self.activation_trunk = get_activation(activation["trunk"])
+            self.activation_merger = get_activation(activation["merger"])
+        else:
+            # Also set the merger activation, so the optional merger net works
+            # when a single activation is given for the whole network.
+            self.activation_branch1 = self.activation_branch2 = get_activation(activation)
+            self.activation_trunk = self.activation_merger = self.activation_branch1
+        self.pod_basis = pod_basis
+        if callable(layer_sizes_branch1[1]):
+            # User-defined network
+            self.branch1 = layer_sizes_branch1[1]
+        else:
+            # Fully connected network
+            self.branch1 = FNN(
+                layer_sizes_branch1, self.activation_branch1, kernel_initializer
+            )
+        if callable(layer_sizes_branch2[1]):
+            # User-defined network
+            self.branch2 = layer_sizes_branch2[1]
+        else:
+            # Fully connected network
+            self.branch2 = FNN(
+                layer_sizes_branch2, self.activation_branch2, kernel_initializer
+            )
+        if layer_sizes_merger is not None:
+            if callable(layer_sizes_merger[1]):
+                # User-defined network
+                self.merger = layer_sizes_merger[1]
+            else:
+                # Fully
connected network
+                self.merger = FNN(
+                    layer_sizes_merger, self.activation_merger, kernel_initializer
+                )
+        else:
+            self.merger = None
+        self.trunk = None
+        if layer_sizes_trunk is not None:
+            self.trunk = FNN(
+                layer_sizes_trunk, self.activation_trunk, kernel_initializer
+            )
+            self.b = bst.ParamState(0.0)
+        self.regularizer = regularization
+        self.trunk_last_activation = trunk_last_activation
+        self.merge_operation = merge_operation
+
+    def update(self, inputs):
+        x_func1 = inputs[0]
+        x_func2 = inputs[1]
+        x_loc = inputs[2]
+        # Branch net to encode the input function
+        y_func1 = self.branch1(x_func1)
+        y_func2 = self.branch2(x_func2)
+        # connect two branch outputs
+        if self.merge_operation == "cat":
+            x_merger = u.math.concatenate((y_func1, y_func2), 1)
+        else:
+            if y_func1.shape[-1] != y_func2.shape[-1]:
+                raise AssertionError(
+                    "Output sizes of branch1 net and branch2 net do not match."
+                )
+            if self.merge_operation == "add":
+                x_merger = y_func1 + y_func2
+            elif self.merge_operation == "mul":
+                x_merger = u.math.multiply(y_func1, y_func2)
+            else:
+                raise NotImplementedError(
+                    f"{self.merge_operation} operation to be implemented"
+                )
+        # Optional merger net
+        if self.merger is not None:
+            y_func = self.merger(x_merger)
+        else:
+            y_func = x_merger
+        # Dot product
+        if self.trunk is None:
+            # POD only
+            y = u.math.einsum("bi,ni->bn", y_func, self.pod_basis)
+        else:
+            y_loc = self.trunk(x_loc)
+            if self.trunk_last_activation:
+                y_loc = self.activation_trunk(y_loc)
+            y = u.math.einsum(
+                "bi,ni->bn", y_func, u.math.concatenate((self.pod_basis, y_loc), axis=1)
+            )
+            # ``self.b`` is a ParamState; add its value, not the state object.
+            y += self.b.value
+        if self._output_transform is not None:
+            y = self._output_transform(inputs, y)
+        return y
diff --git a/deepxde/experimental/nn/model.py b/deepxde/experimental/nn/model.py
new file mode 100644
index 000000000..336244e71
--- /dev/null
+++ b/deepxde/experimental/nn/model.py
@@ -0,0 +1,139 @@
+from __future__ import annotations
+
+from typing import Dict, Sequence
+
+import brainstate as bst
+
+from deepxde.experimental.grad import jacobian, hessian, gradient
+from .convert import DictToArray, ArrayToDict
+
+__all__ = [
+    "Model",
+]
+
+
+class Model(bst.nn.Module):
+    """
+    A neural network approximator.
+
+    Args:
+        input: The input converter (a ``DictToArray`` layer).
+        approx: The neural network model.
+        output: The output converter (an ``ArrayToDict`` layer).
+    """
+
+    def __init__(
+        self,
+        input: DictToArray,
+        approx: bst.nn.Module,
+        output: ArrayToDict,
+        *args,
+    ):
+        """
+        Initialize the Model.
+
+        Args:
+            input (DictToArray): The input converter that transforms dictionary inputs to arrays.
+            approx (bst.nn.Module): The neural network model used for approximation.
+            output (ArrayToDict): The output converter that transforms array outputs to dictionaries.
+            *args: Additional arguments (not used).
+
+        Raises:
+            AssertionError: If input is not an instance of DictToArray, approx is not an
+                instance of bst.nn.Module, or output is not an instance of ArrayToDict.
+        """
+        super().__init__()
+
+        assert isinstance(
+            input, DictToArray
+        ), "input must be an instance of DictToArray."
+        self.input = input
+
+        assert isinstance(
+            approx, bst.nn.Module
+        ), "approx must be an instance of nn.Module."
+        self.approx = approx
+
+        assert isinstance(
+            output, ArrayToDict
+        ), "output must be an instance of ArrayToDict."
+        self.output = output
+
+    @bst.compile.jit(static_argnums=(0,))
+    def update(self, x):
+        """
+        Update the model by passing input through the neural network.
+
+        Args:
+            x: The input data to be processed.
+ + Returns: + The output of the neural network after passing through input conversion, + approximation, and output conversion stages. + """ + return self.output(self.approx(self.input(x))) + + def jacobian( + self, + inputs: Dict[str, bst.typing.ArrayLike], + y: str | Sequence[str] | None = None, + x: str | Sequence[str] | None = None, + ): + """ + Compute the Jacobian of the approximation neural networks. + + Args: + inputs: The input data. + y: The output variables. + x: The input variables. + + Returns: + The Jacobian of the approximation neural networks. + """ + return jacobian(self, inputs, y=y, x=x) + + def hessian( + self, + inputs: Dict[str, bst.typing.ArrayLike], + y: str | Sequence[str] | None = None, + xi: str | Sequence[str] | None = None, + xj: str | Sequence[str] | None = None, + ): + """ + Compute the Hessian of the approximator. + + Compute: `H[y][xi][xj] = d^2y / dxi dxj = d^2y / dxj dxi` + + Args: + inputs: The input data. + y: The output variables. + xi: The first input variables. + xj: The second input variables. + + Returns: + The Hessian of the approximator. + """ + return hessian(self, inputs, y=y, xi=xi, xj=xj) + + def gradient( + self, + inputs: Dict[str, bst.typing.ArrayLike], + order: int, + y: str | Sequence[str] | None = None, + *xi: str | Sequence[str] | None, + ): + """ + Compute the gradient of the approximator. + + Args: + inputs: The input data. + order: The order of the gradient. + y: The output variables. + xi: The input variables. + + Returns: + The gradient of the approximator. + """ + assert ( + isinstance(order, int) and order >= 1 + ), "order must be an integer greater than or equal to 1." + return gradient(self, inputs, y, *xi, order=order) diff --git a/deepxde/experimental/problem/__init__.py b/deepxde/experimental/problem/__init__.py new file mode 100644 index 000000000..735fe9e73 --- /dev/null +++ b/deepxde/experimental/problem/__init__.py @@ -0,0 +1,25 @@ +__all__ = [ + "Problem", + "DataSet", + "Function", + "QuadrupleDataset", + "TripleDataset", + "TripleCartesianProd", + "IDE", + "PDE", + "TimePDE", + "FPDE", + "TimeFPDE", + "PDEOperator", + "PDEOperatorCartesianProd", +] + +from .base import Problem +from .dataset_function import Function +from .dataset_general import DataSet +from .dataset_quadruple import QuadrupleDataset +from .dataset_triple import TripleDataset, TripleCartesianProd +from .fpde import FPDE, TimeFPDE +from .ide import IDE +from .pde import PDE, TimePDE +from .pde_operator import PDEOperator, PDEOperatorCartesianProd diff --git a/deepxde/experimental/problem/base.py b/deepxde/experimental/problem/base.py new file mode 100644 index 000000000..aaa5552c8 --- /dev/null +++ b/deepxde/experimental/problem/base.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import abc +from typing import Callable, Sequence, Any, Tuple + +import brainstate as bst +import jax + +from deepxde.experimental.utils.losses import get_loss + +Inputs = Any +Targets = Any +Auxiliary = Any +Outputs = Any +LOSS = jax.typing.ArrayLike + +__all__ = [ + "Problem", +] + + +class Problem(abc.ABC): + """ + Base Problem Class. + + A problem is defined by the approximator and the loss function. + + Attributes: + approximator: The approximator. + loss_fn: The loss function. + loss_weights: A list specifying scalar coefficients (Python floats) to + weight the loss contributions. The loss value that will be minimized by + the trainer will then be the weighted sum of all individual losses, + weighted by the `loss_weights` coefficients. 
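+
+    Example:
+        A minimal sketch of a custom problem; the ``X_train``/``y_train`` and
+        ``X_test``/``y_test`` arrays are hypothetical placeholders::
+
+            >>> class MyProblem(Problem):
+            ...     def losses(self, inputs, outputs, targets, **kwargs):
+            ...         return self.loss_fn(targets, outputs)
+            ...     def train_next_batch(self, batch_size=None):
+            ...         return X_train, y_train
+            ...     def test(self):
+            ...         return X_test, y_test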
+ """ + + approximator: bst.nn.Module + loss_fn: Callable | Sequence[Callable] + + def __init__( + self, + approximator: bst.nn.Module = None, + loss_fn: str | Callable[[Inputs, Outputs], LOSS] = "MSE", + loss_weights: Sequence[float] = None, + ): + """ + Initialize the problem. + + Args: + approximator (bst.nn.Module, optional): The approximator. Defaults to None. + loss_fn (str | Callable[[Inputs, Outputs], LOSS], optional): The loss function. + If the same loss is used for all errors, then `loss` is a String name of a loss function + or a loss function. If different errors use different losses, then `loss` is a list + whose size is equal to the number of errors. Defaults to 'MSE'. + loss_weights (Sequence[float], optional): A list specifying scalar coefficients (Python floats) to + weight the loss contributions. The loss value that will be minimized by + the trainer will then be the weighted sum of all individual losses, + weighted by the `loss_weights` coefficients. Defaults to None. + """ + # approximator + if approximator is not None: + self.define_approximator(approximator) + else: + self.approximator = None + + # loss function + self.loss_fn = get_loss(loss_fn) + + # loss weights + if loss_weights is not None: + assert isinstance( + loss_weights, (list, tuple) + ), "loss_weights must be a list or tuple." + self.loss_weights = loss_weights + + def define_approximator( + self, + approximator: bst.nn.Module, + ) -> Problem: + """ + Define the approximator for the problem. + + Args: + approximator (bst.nn.Module): The approximator to be used in the problem. + + Returns: + Problem: The current Problem instance with the defined approximator. + + Raises: + AssertionError: If the approximator is not an instance of bst.nn.Module. + """ + assert isinstance( + approximator, bst.nn.Module + ), "approximator must be an instance of bst.nn.Module." + self.approximator = approximator + return self + + def losses(self, inputs, outputs, targets, **kwargs): + """ + Calculate and return a list of losses (constraints) for the problem. + + Args: + inputs: The input data. + outputs: The output data. + targets: The target data. + **kwargs: Additional keyword arguments. + + Returns: + A list of calculated losses. + + Raises: + NotImplementedError: This method should be implemented by subclasses. + """ + raise NotImplementedError("Problem.losses is not implemented.") + + def losses_train(self, inputs, outputs, targets, **kwargs): + """ + Calculate and return a list of losses for the training dataset. + + This method sets the environment context to training mode before calculating losses. + + Args: + inputs: The input data for training. + outputs: The output data for training. + targets: The target data for training. + **kwargs: Additional keyword arguments. + + Returns: + A list of calculated losses for the training dataset. + """ + with bst.environ.context(fit=True): + return self.losses(inputs, outputs, targets, **kwargs) + + def losses_test(self, inputs, outputs, targets, **kwargs): + """ + Calculate and return a list of losses for the test dataset. + + This method sets the environment context to testing mode before calculating losses. + + Args: + inputs: The input data for testing. + outputs: The output data for testing. + targets: The target data for testing. + **kwargs: Additional keyword arguments. + + Returns: + A list of calculated losses for the test dataset. 
+ """ + with bst.environ.context(fit=False): + return self.losses(inputs, outputs, targets, **kwargs) + + @abc.abstractmethod + def train_next_batch( + self, batch_size=None + ) -> Tuple[Inputs, Targets] | Tuple[Inputs, Targets, Auxiliary]: + """ + Generate and return the next batch of training data. + + This method should be implemented by subclasses to provide the next batch of training data. + + Args: + batch_size (int, optional): The size of the batch to be returned. Defaults to None. + + Returns: + Tuple[Inputs, Targets] | Tuple[Inputs, Targets, Auxiliary]: A tuple containing the inputs and targets + for the next training batch. May also include auxiliary data if applicable. + """ + + @abc.abstractmethod + def test(self) -> Tuple[Inputs, Targets] | Tuple[Inputs, Targets, Auxiliary]: + """ + Generate and return the test dataset. + + This method should be implemented by subclasses to provide the test dataset. + + Returns: + Tuple[Inputs, Targets] | Tuple[Inputs, Targets, Auxiliary]: A tuple containing the inputs and targets + for the test dataset. May also include auxiliary data if applicable. + """ diff --git a/deepxde/experimental/problem/dataset_function.py b/deepxde/experimental/problem/dataset_function.py new file mode 100644 index 000000000..6e85d3048 --- /dev/null +++ b/deepxde/experimental/problem/dataset_function.py @@ -0,0 +1,109 @@ +from typing import Callable, Sequence + +import brainstate as bst + +from deepxde.experimental.geometry.base import GeometryExperimental +from deepxde.utils.internal import run_if_any_none +from .base import Problem + +__all__ = [ + "Function", +] + + +class Function(Problem): + """ + Approximate a function via a network. + + Args: + geometry (GeometryExperimental): The domain of the function. Instance of ``Geometry``. + function (Callable): The function to be approximated. A callable function takes a NumPy array as the input and returns the + a NumPy array of corresponding function values. + num_train (int): The number of training points sampled inside the domain. + num_test (int): The number of points for testing. + train_distribution (str, optional): The distribution to sample training points. One of the following: "uniform" + (equispaced grid), "pseudo" (pseudorandom), "LHS" (Latin hypercube sampling), "Halton" (Halton sequence), + "Hammersley" (Hammersley sequence), or "Sobol" (Sobol sequence). Defaults to "uniform". + online (bool, optional): If ``True``, resample the pseudorandom training points every training step, otherwise, use the + same training points. Defaults to False. + approximator (bst.nn.Module, optional): The neural network module to use as an approximator. Defaults to None. + loss_fn (str, optional): The loss function to use. Defaults to 'MSE'. + loss_weights (Sequence[float], optional): The weights for different loss components. Defaults to None. 
+ """ + + def __init__( + self, + geometry: GeometryExperimental, + function: Callable, + num_train: int, + num_test: int, + train_distribution: str = "uniform", + online: bool = False, + approximator: bst.nn.Module = None, + loss_fn: str = "MSE", + loss_weights: Sequence[float] = None, + ): + super().__init__( + approximator=approximator, loss_fn=loss_fn, loss_weights=loss_weights + ) + + self.geom = geometry + self.func = function + self.num_train = num_train + self.num_test = num_test + self.dist_train = train_distribution + self.online = online + + if online and train_distribution != "pseudo": + print("Warning: Online learning should use pseudorandom sampling.") + self.dist_train = "pseudo" + + self.train_x, self.train_y = None, None + self.test_x, self.test_y = None, None + + def losses(self, inputs, outputs, targets, **kwargs): + """ + Compute the loss between the predicted outputs and the target values. + + Args: + inputs: The input data. + outputs: The predicted output from the model. + targets: The target values. + **kwargs: Additional keyword arguments. + + Returns: + The computed loss value. + """ + return self.loss_fn(targets, outputs) + + def train_next_batch(self, batch_size=None): + """ + Generate the next batch of training data. + + Args: + batch_size (int, optional): The size of the batch to generate. Defaults to None. + + Returns: + tuple: A tuple containing the input features (train_x) and target values (train_y) for training. + """ + if self.train_x is None or self.online: + if self.dist_train == "uniform": + self.train_x = self.geom.uniform_points(self.num_train, boundary=True) + else: + self.train_x = self.geom.random_points( + self.num_train, random=self.dist_train + ) + self.train_y = self.func(self.train_x) + return self.train_x, self.train_y + + @run_if_any_none("test_x", "test_y") + def test(self): + """ + Generate test data points and their corresponding function values. + + Returns: + tuple: A tuple containing the test input features (test_x) and their corresponding function values (test_y). + """ + self.test_x = self.geom.uniform_points(self.num_test, boundary=True) + self.test_y = self.func(self.test_x) + return self.test_x, self.test_y diff --git a/deepxde/experimental/problem/dataset_general.py b/deepxde/experimental/problem/dataset_general.py new file mode 100644 index 000000000..e29a1aa53 --- /dev/null +++ b/deepxde/experimental/problem/dataset_general.py @@ -0,0 +1,104 @@ +from typing import Sequence, Dict + +import brainstate as bst +import jax +import numpy as np + +from deepxde.experimental import utils +from .base import Problem + +__all__ = ["DataSet"] + + +class DataSet(Problem): + """ + Fitting Problem set for handling dataset-based machine learning problems. + + This class extends the Problem class to handle dataset-based machine learning tasks, + including data preprocessing, loss calculation, and batch generation for training. + + Args: + X_train (Dict[str, bst.typing.ArrayLike]): Dictionary of training input data. + y_train (Dict[str, bst.typing.ArrayLike]): Dictionary of training output data. + X_test (Dict[str, bst.typing.ArrayLike]): Dictionary of testing input data. + y_test (Dict[str, bst.typing.ArrayLike]): Dictionary of testing output data. + standardize (bool, optional): Whether to standardize input data. Defaults to False. + approximator (bst.nn.Module, optional): The neural network module to use. Defaults to None. + loss_fn (str, optional): The loss function to use. Defaults to 'MSE'. 
+ loss_weights (Sequence[float], optional): Weights for different loss components. Defaults to None. + + Attributes: + train_x (Dict[str, bst.typing.ArrayLike]): Processed training input data. + train_y (Dict[str, bst.typing.ArrayLike]): Processed training output data. + test_x (Dict[str, bst.typing.ArrayLike]): Processed testing input data. + test_y (Dict[str, bst.typing.ArrayLike]): Processed testing output data. + scaler_x (object): Scaler used for standardization, if applied. + """ + + def __init__( + self, + X_train: Dict[str, bst.typing.ArrayLike], + y_train: Dict[str, bst.typing.ArrayLike], + X_test: Dict[str, bst.typing.ArrayLike], + y_test: Dict[str, bst.typing.ArrayLike], + standardize: bool = False, + approximator: bst.nn.Module = None, + loss_fn: str = "MSE", + loss_weights: Sequence[float] = None, + ): + super().__init__( + approximator=approximator, loss_fn=loss_fn, loss_weights=loss_weights + ) + + self.train_x = X_train + self.train_y = y_train + self.test_x = X_test + self.test_y = y_test + self.scaler_x = None + if standardize: + r = jax.tree.map( + lambda train, test: utils.standardize(train, test), + self.train_x, + self.test_x, + ) + self.train_x = dict() + self.test_x = dict() + for key, val in r.items(): + self.train_x[key] = val[0] + self.test_x[key] = val[1] + + def losses(self, inputs, outputs, targets, **kwargs): + """ + Calculate the loss between the model outputs and the target values. + + Args: + inputs: The input data (not used in this method). + outputs: The model's output predictions. + targets: The true target values. + **kwargs: Additional keyword arguments. + + Returns: + The calculated loss value. + """ + return self.loss_fn(targets, outputs) + + def train_next_batch(self, batch_size=None): + """ + Get the next batch of training data. + + Args: + batch_size (int, optional): The size of the batch to return. If None, returns all training data. + + Returns: + tuple: A tuple containing the batch of training inputs (self.train_x) and outputs (self.train_y). + """ + return self.train_x, self.train_y + + def test(self): + """ + Get the test dataset. + + Returns: + tuple: A tuple containing the test inputs (self.test_x) and outputs (self.test_y). + """ + return self.test_x, self.test_y diff --git a/deepxde/experimental/problem/dataset_quadruple.py b/deepxde/experimental/problem/dataset_quadruple.py new file mode 100644 index 000000000..efcf5eee6 --- /dev/null +++ b/deepxde/experimental/problem/dataset_quadruple.py @@ -0,0 +1,92 @@ +from typing import Sequence + +import brainstate as bst + +from deepxde.data.sampler import BatchSampler +from .base import Problem + +__all__ = [ + "QuadrupleDataset", +] + + +class QuadrupleDataset(Problem): + """ + Dataset with each data point as a quadruple. + + The couple of the first three elements are the input, and the fourth element is the + output. This dataset can be used with the network ``MIONet`` for operator + learning. + + Args: + X_train (tuple): A tuple of three NumPy arrays representing the input training data. + y_train (numpy.ndarray): A NumPy array representing the output training data. + X_test (tuple): A tuple of three NumPy arrays representing the input testing data. + y_test (numpy.ndarray): A NumPy array representing the output testing data. + approximator (bst.nn.Module, optional): The neural network module used for approximation. Defaults to None. + loss_fn (str, optional): The loss function to be used. Defaults to 'MSE'. + loss_weights (Sequence[float], optional): Weights for the loss function. 
+            Defaults to None.
+    """
+
+    def __init__(
+        self,
+        X_train,
+        y_train,
+        X_test,
+        y_test,
+        approximator: bst.nn.Module = None,
+        loss_fn: str = "MSE",
+        loss_weights: Sequence[float] = None,
+    ):
+        super().__init__(
+            approximator=approximator, loss_fn=loss_fn, loss_weights=loss_weights
+        )
+        self.train_x = X_train
+        self.train_y = y_train
+        self.test_x = X_test
+        self.test_y = y_test
+
+        self.train_sampler = BatchSampler(len(self.train_y), shuffle=True)
+
+    def losses(self, inputs, outputs, targets, **kwargs):
+        """
+        Calculate the loss between the predicted outputs and the target values.
+
+        Args:
+            inputs: The input data (not used in this method).
+            outputs: The predicted output values.
+            targets: The target output values.
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            The calculated loss value.
+        """
+        return self.loss_fn(targets, outputs)
+
+    def train_next_batch(self, batch_size=None):
+        """
+        Get the next batch of training data.
+
+        Args:
+            batch_size (int, optional): The size of the batch to return. If None, returns all training data.
+
+        Returns:
+            tuple: A tuple containing the input data (as a tuple of three arrays) and the corresponding output data.
+        """
+        if batch_size is None:
+            return self.train_x, self.train_y
+        indices = self.train_sampler.get_next(batch_size)
+        return (
+            (
+                self.train_x[0][indices],
+                self.train_x[1][indices],
+                self.train_x[2][indices],
+            ),
+            self.train_y[indices],
+        )
+
+    def test(self):
+        """
+        Get the testing data.
+
+        Returns:
+            tuple: A tuple containing the input testing data and the corresponding output testing data.
+        """
+        return self.test_x, self.test_y
diff --git a/deepxde/experimental/problem/dataset_triple.py b/deepxde/experimental/problem/dataset_triple.py
new file mode 100644
index 000000000..eb9f7c3c9
--- /dev/null
+++ b/deepxde/experimental/problem/dataset_triple.py
@@ -0,0 +1,214 @@
+from typing import Sequence
+
+import brainstate as bst
+
+from deepxde.data.sampler import BatchSampler
+from .base import Problem
+
+__all__ = ["TripleDataset", "TripleCartesianProd"]
+
+
+class TripleDataset(Problem):
+    """
+    Dataset with each data point as a triple.
+
+    The pair of the first two elements is the input, and the third element is the
+    output. This dataset can be used with the network ``DeepONet`` for operator
+    learning.
+
+    Args:
+        X_train (tuple): A tuple of two NumPy arrays representing the input training data.
+        y_train (numpy.ndarray): A NumPy array representing the output training data.
+        X_test (tuple): A tuple of two NumPy arrays representing the input testing data.
+        y_test (numpy.ndarray): A NumPy array representing the output testing data.
+        approximator (bst.nn.Module, optional): The neural network module used for approximation. Defaults to None.
+        loss_fn (str, optional): The loss function to be used. Defaults to 'MSE'.
+        loss_weights (Sequence[float], optional): Weights for the loss function. Defaults to None.
+
+    References:
+        `L. Lu, P. Jin, G. Pang, Z. Zhang, & G. E. Karniadakis. Learning nonlinear
+        operators via DeepONet based on the universal approximation theorem of
+        operators. Nature Machine Intelligence, 3, 218--229, 2021
+        `_.
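+
+    Example:
+        A minimal sketch with random arrays (shapes only; in practice the data
+        would pair with a ``DeepONet`` approximator):
+
+        .. code-block:: python
+
+            import numpy as np
+
+            n, m, d = 100, 50, 1
+            X = (np.random.rand(n, m), np.random.rand(n, d))  # (branch, trunk)
+            y = np.random.rand(n, 1)
+            data = TripleDataset(X_train=X, y_train=y, X_test=X, y_test=y)
+            inputs, targets = data.train_next_batch(32)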
+ """ + + def __init__( + self, + X_train, + y_train, + X_test, + y_test, + approximator: bst.nn.Module = None, + loss_fn: str = "MSE", + loss_weights: Sequence[float] = None, + ): + super().__init__( + approximator=approximator, loss_fn=loss_fn, loss_weights=loss_weights + ) + self.train_x = X_train + self.train_y = y_train + self.test_x = X_test + self.test_y = y_test + + self.train_sampler = BatchSampler(len(self.train_y), shuffle=True) + + def losses(self, inputs, outputs, targets, **kwargs): + """ + Compute the loss between the model outputs and the targets. + + Args: + inputs: The input data (not used in this method). + outputs: The model outputs. + targets: The target values. + **kwargs: Additional keyword arguments. + + Returns: + The computed loss value. + """ + return self.loss_fn(targets, outputs) + + def train_next_batch(self, batch_size=None): + """ + Get the next batch of training data. + + Args: + batch_size (int, optional): The size of the batch to return. If None, returns all training data. + + Returns: + tuple: A tuple containing two elements: + - A tuple of two arrays representing the input training data for the batch. + - An array representing the output training data for the batch. + """ + if batch_size is None: + return self.train_x, self.train_y + indices = self.train_sampler.get_next(batch_size) + return ( + (self.train_x[0][indices], self.train_x[1][indices]), + self.train_y[indices], + ) + + def test(self): + """ + Get the testing data. + + Returns: + tuple: A tuple containing two elements: + - The input testing data. + - The output testing data. + """ + return self.test_x, self.test_y + + +class TripleCartesianProd(Problem): + """ + Dataset with each data point as a triple. The ordered pair of the first two + elements are created from a Cartesian product of the first two lists. If we compute + the Cartesian product of the first two arrays, then we have a ``TripleDataset`` dataset. + + This dataset can be used with the network ``DeepONetCartesianProd`` for operator + learning. + + Args: + X_train: A tuple of two NumPy arrays. The first element has the shape (`N1`, + `dim1`), and the second element has the shape (`N2`, `dim2`). + y_train: A NumPy array of shape (`N1`, `N2`). + """ + + def __init__( + self, + X_train, + y_train, + X_test, + y_test, + approximator: bst.nn.Module = None, + loss_fn: str = "MSE", + loss_weights: Sequence[float] = None, + ): + """ + Initialize the TripleCartesianProd dataset. + + Args: + X_train (tuple): A tuple of two NumPy arrays for training input data. + y_train (numpy.ndarray): A NumPy array for training output data. + X_test (tuple): A tuple of two NumPy arrays for testing input data. + y_test (numpy.ndarray): A NumPy array for testing output data. + approximator (bst.nn.Module, optional): The neural network module used for approximation. Defaults to None. + loss_fn (str, optional): The loss function to be used. Defaults to 'MSE'. + loss_weights (Sequence[float], optional): Weights for the loss function. Defaults to None. + + Raises: + ValueError: If the training or testing dataset does not have the format of Cartesian product. + """ + super().__init__( + approximator=approximator, loss_fn=loss_fn, loss_weights=loss_weights + ) + + if len(X_train[0]) != y_train.shape[0] or len(X_train[1]) != y_train.shape[1]: + raise ValueError( + "The training dataset does not have the format of Cartesian product." 
+ ) + if len(X_test[0]) != y_test.shape[0] or len(X_test[1]) != y_test.shape[1]: + raise ValueError( + "The testing dataset does not have the format of Cartesian product." + ) + self.train_x, self.train_y = X_train, y_train + self.test_x, self.test_y = X_test, y_test + + self.branch_sampler = BatchSampler(len(X_train[0]), shuffle=True) + self.trunk_sampler = BatchSampler(len(X_train[1]), shuffle=True) + + def losses(self, inputs, outputs, targets, **kwargs): + """ + Compute the loss between the model outputs and the targets. + + Args: + inputs: The input data (not used in this method). + outputs: The model outputs. + targets: The target values. + **kwargs: Additional keyword arguments. + + Returns: + The computed loss value. + """ + return self.loss_fn(targets, outputs) + + def train_next_batch(self, batch_size=None): + """ + Get the next batch of training data. + + Args: + batch_size (int, tuple, or list, optional): The size of the batch to return. + If None, returns all training data. + If int, returns a batch with the specified size for branch data and all trunk data. + If tuple or list, returns a batch with specified sizes for both branch and trunk data. + + Returns: + tuple: A tuple containing two elements: + - A tuple of two arrays representing the input training data for the batch. + - An array representing the output training data for the batch. + """ + if batch_size is None: + return self.train_x, self.train_y + if not isinstance(batch_size, (tuple, list)): + indices = self.branch_sampler.get_next(batch_size) + return (self.train_x[0][indices], self.train_x[1]), self.train_y[indices] + indices_branch = self.branch_sampler.get_next(batch_size[0]) + indices_trunk = self.trunk_sampler.get_next(batch_size[1]) + return ( + ( + self.train_x[0][indices_branch], + self.train_x[1][indices_trunk], + ), + self.train_y[indices_branch, indices_trunk], + ) + + def test(self): + """ + Get the testing data. + + Returns: + tuple: A tuple containing two elements: + - The input testing data. + - The output testing data. + """ + return self.test_x, self.test_y diff --git a/deepxde/experimental/problem/fpde.py b/deepxde/experimental/problem/fpde.py new file mode 100644 index 000000000..f8e32d2b5 --- /dev/null +++ b/deepxde/experimental/problem/fpde.py @@ -0,0 +1,718 @@ +from __future__ import annotations + +import warnings +from typing import Callable, Sequence, Optional, Dict, Any + +import brainstate as bst +import brainunit as u +import jax +import numpy as np + +from deepxde.data.fpde import ( + Scheme, + Fractional as FractionalBase, + FractionalTime as FractionalTimeBase, +) +from deepxde.experimental.geometry import GeometryXTime, DictPointGeometry +from deepxde.experimental.icbc.base import ICBC +from deepxde.experimental.utils import array_ops +from deepxde.utils.internal import run_if_all_none +from .pde import PDE + +__all__ = ["FPDE", "TimeFPDE"] + +X = Dict[str, bst.typing.ArrayLike] +Y = Dict[str, bst.typing.ArrayLike] +InitMat = bst.typing.ArrayLike + + +class FPDE(PDE): + r""" + Fractional PDE solver. + + This class implements a solver for Fractional Partial Differential Equations (FPDEs) using the Physics-Informed Neural Network (PINN) approach. 
+
+    D-dimensional fractional Laplacian of order alpha/2 (1 < alpha < 2) is defined as:
+    (-Delta)^(alpha/2) u(x) = C(alpha, D) \int_{||theta||=1} D_theta^alpha u(x) d theta,
+    where C(alpha, D) = gamma((1-alpha)/2) * gamma((D+alpha)/2) / (2 pi^((D+1)/2)),
+    D_theta^alpha is the Riemann-Liouville directional fractional derivative,
+    and theta is the differentiation direction vector.
+    The solution u(x) is assumed to be identically zero in the boundary and exterior of the domain.
+    When D = 1, C(alpha, D) = 1 / (2 cos(alpha * pi / 2)).
+
+    This solver does not consider C(alpha, D) in the fractional Laplacian,
+    and only discretizes \int_{||theta||=1} D_theta^alpha u(x) d theta.
+    D_theta^alpha is approximated by the Grunwald-Letnikov formula.
+
+    Args:
+        geometry (DictPointGeometry): The geometry of the problem domain.
+        pde (Callable[[X, Y, InitMat], Any]): The PDE to be solved.
+        alpha (float | bst.State[float]): The order of the fractional derivative.
+        constraints (ICBC | Sequence[ICBC]): The initial and boundary conditions.
+        resolution (Sequence[int]): The resolution for discretization.
+        approximator (Optional[bst.nn.Module], optional): The neural network approximator. Defaults to None.
+        meshtype (str, optional): The type of mesh to use ("static" or "dynamic"). Defaults to "dynamic".
+        num_domain (int, optional): The number of domain points. Defaults to 0.
+        num_boundary (int, optional): The number of boundary points. Defaults to 0.
+        train_distribution (str, optional): The distribution method for training points. Defaults to "Hammersley".
+        anchors (Any, optional): Anchor points for the domain. Defaults to None.
+        solution (Callable[[Dict], Dict], optional): The analytical solution of the PDE, if available. Defaults to None.
+        num_test (int, optional): The number of test points. Defaults to None.
+        loss_fn (str | Callable, optional): The loss function to use. Defaults to 'MSE'.
+        loss_weights (Sequence[float], optional): The weights for different components of the loss. Defaults to None.
+
+    References:
+        `G. Pang, L. Lu, & G. E. Karniadakis. fPINNs: Fractional physics-informed neural
+        networks. SIAM Journal on Scientific Computing, 41(4), A2603--A2626, 2019
+        `_.
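+
+    Example:
+        A sketch of the expected ``pde`` callable for a static mesh, where
+        ``int_mat`` is the dense discretization matrix supplied by this class
+        (the output key ``"y"`` is an assumption about the approximator):
+
+        .. code-block:: python
+
+            import brainunit as u
+
+            def fpde(inputs, outputs, int_mat):
+                # int_mat @ y approximates the fractional Laplacian of order alpha/2
+                lhs = int_mat @ outputs["y"]
+                return lhs - u.math.ones_like(lhs)  # residual of (-Delta)^(alpha/2) u = 1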
+ """ + + def __init__( + self, + geometry: DictPointGeometry, + pde: Callable[[X, Y, InitMat], Any], + alpha: float | bst.State[float], + constraints: ICBC | Sequence[ICBC], + resolution: Sequence[int], + approximator: Optional[bst.nn.Module] = None, + meshtype: str = "dynamic", + num_domain: int = 0, + num_boundary: int = 0, + train_distribution: str = "Hammersley", + anchors=None, + solution: Callable[[Dict], Dict] = None, + num_test: int = None, + loss_fn: str | Callable = "MSE", + loss_weights: Sequence[float] = None, + ): + self.alpha = alpha + self.disc = Scheme(meshtype, resolution) + self.frac_train, self.frac_test = None, None + self.int_mat_train = None + + super().__init__( + geometry, + pde, + constraints, + approximator=approximator, + num_domain=num_domain, + num_boundary=num_boundary, + train_distribution=train_distribution, + anchors=anchors, + solution=solution, + num_test=num_test, + loss_fn=loss_fn, + loss_weights=loss_weights, + ) + + def call_pde_errors(self, inputs, outputs, **kwargs): + bcs_start = np.cumsum([0] + self.num_bcs) + + # # PDE inputs and outputs + # pde_inputs = jax.tree.map(lambda x: x[bcs_start[-1]:], inputs) + # pde_outputs = jax.tree.map(lambda x: x[bcs_start[-1]:], outputs) + + # do not cache int_mat when alpha is a learnable parameter + fit = bst.environ.get("fit") + + if fit: + if isinstance(self.alpha, bst.State): + int_mat = self.get_int_matrix(True) + else: + if self.int_mat_train is not None: + # use cached int_mat + int_mat = self.int_mat_train + else: + # initialize self.int_mat_train with int_mat + int_mat = self.get_int_matrix(True) + self.int_mat_train = int_mat + else: + int_mat = self.get_int_matrix(False) + + # computing PDE losses + # pde_errors = self.pde(pde_inputs, pde_outputs, int_mat, **kwargs) + # return pde_errors + pde_errors = self.pde(inputs, outputs, int_mat, **kwargs) + return jax.tree.map(lambda x: x[bcs_start[-1] :], pde_errors) + + def call_bc_errors(self, loss_fns, loss_weights, inputs, outputs, **kwargs): + return super().call_bc_errors(loss_fns, loss_weights, inputs, outputs, **kwargs) + # fit = bst.environ.get('fit') + # if fit: + # return super().call_bc_errors(loss_fns, loss_weights, inputs, outputs, **kwargs) + # else: + # return [u.math.zeros((), dtype=bst.environ.dftype()) for _ in self.constraints] + + @run_if_all_none("train_x", "train_y") + def train_next_batch(self, batch_size=None): + alpha = self.alpha.value if isinstance(self.alpha, bst.State) else self.alpha + + # do not cache train data when alpha is a learnable parameter + if self.disc.meshtype == "static": + if self.geometry.geom.idstr != "Interval": + raise ValueError("Only Interval supports static mesh.") + + self.frac_train = Fractional(alpha, self.geometry.geom, self.disc, None) + X = self.frac_train.get_x() + X = self.geometry.arr_to_dict(u.math.roll(X, -1)) + + # FPDE is only applied to the domain points. + # Boundary points are auxiliary points, and appended in the end. + self.train_x_all = X + if self.anchors is not None: + self.train_x_all = jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=-1), + self.anchors, + self.train_x_all, + ) + x_bc = self.bc_points() + + elif self.disc.meshtype == "dynamic": + self.train_x_all = self.train_points() + x_bc = self.bc_points() + + # FPDE is only applied to the domain points. 
+            train_x_all = self.geometry.dict_to_arr(self.train_x_all)
+            x_f = train_x_all[~self.geometry.on_boundary(self.train_x_all)]
+            self.frac_train = Fractional(alpha, self.geometry.geom, self.disc, x_f)
+            X = self.geometry.arr_to_dict(self.frac_train.get_x())
+
+        else:
+            raise ValueError("Unknown meshtype %s" % self.disc.meshtype)
+
+        self.train_x = jax.tree.map(
+            lambda x, y: u.math.concatenate((x, y), axis=-1),
+            x_bc,
+            X,
+            is_leaf=u.math.is_quantity,
+        )
+        self.train_y = self.solution(self.train_x) if self.solution else None
+        return self.train_x, self.train_y
+
+    @run_if_all_none("test_x", "test_y")
+    def test(self):
+        # do not cache test data when alpha is a learnable parameter
+        if self.disc.meshtype == "static" and self.num_test is not None:
+            raise ValueError("Cannot use test points in static mesh.")
+
+        if self.num_test is None:
+            # assign the training points to the testing points
+            num_bc = sum(self.num_bcs)
+            self.test_x = jax.tree.map(lambda x: x[num_bc:], self.train_x)
+            self.frac_test = self.frac_train
+        else:
+            alpha = (
+                self.alpha.value if isinstance(self.alpha, bst.State) else self.alpha
+            )
+
+            # Generate `self.test_x`, resampling the test points
+            self.test_x = self.test_points()
+            not_boundary = ~self.geometry.on_boundary(self.test_x)
+            x_f = self.geometry.dict_to_arr(self.test_x)[not_boundary]
+            self.frac_test = Fractional(alpha, self.geometry.geom, self.disc, x_f)
+            self.test_x = self.geometry.arr_to_dict(self.frac_test.get_x())
+
+        self.test_y = self.solution(self.test_x) if self.solution else None
+        return self.test_x, self.test_y
+
+    def test_points(self):
+        return self.geometry.uniform_points(self.num_test, True)
+
+    def get_int_matrix(self, training):
+        if training:
+            int_mat = self.frac_train.get_matrix(sparse=True)
+            num_bc = sum(self.num_bcs)
+        else:
+            int_mat = self.frac_test.get_matrix(sparse=True)
+            num_bc = 0
+
+        if self.disc.meshtype == "static":
+            int_mat = np.roll(int_mat, -1, 1)
+            int_mat = int_mat[1:-1]
+
+        int_mat = array_ops.zero_padding(int_mat, ((num_bc, 0), (num_bc, 0)))
+        return int_mat
+
+
+class TimeFPDE(FPDE):
+    r"""Time-dependent fractional PDE solver.
+
+    D-dimensional fractional Laplacian of order alpha/2 (1 < alpha < 2) is defined as:
+    (-Delta)^(alpha/2) u(x) = C(alpha, D) \int_{||theta||=1} D_theta^alpha u(x) d theta,
+    where C(alpha, D) = gamma((1-alpha)/2) * gamma((D+alpha)/2) / (2 pi^((D+1)/2)),
+    D_theta^alpha is the Riemann-Liouville directional fractional derivative,
+    and theta is the differentiation direction vector.
+    The solution u(x) is assumed to be identically zero in the boundary and exterior of the domain.
+    When D = 1, C(alpha, D) = 1 / (2 cos(alpha * pi / 2)).
+
+    This solver does not consider C(alpha, D) in the fractional Laplacian,
+    and only discretizes \int_{||theta||=1} D_theta^alpha u(x) d theta.
+    D_theta^alpha is approximated by the Grunwald-Letnikov formula.
+
+    References:
+        `G. Pang, L. Lu, & G. E. Karniadakis. fPINNs: Fractional physics-informed neural
+        networks. SIAM Journal on Scientific Computing, 41(4), A2603--A2626, 2019
+        `_.
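+
+    Note:
+        Compared with ``FPDE``, this class additionally samples ``num_initial``
+        initial-condition points and builds the discretization with
+        ``FractionalTime`` over the space-time domain.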
+ """ + + def __init__( + self, + geometry: DictPointGeometry, + pde: Callable[[X, Y, InitMat], Any], + alpha: float | bst.State[float], + constraints: ICBC | Sequence[ICBC], + resolution: Sequence[int], + approximator: Optional[bst.nn.Module] = None, + meshtype: str = "dynamic", + num_domain: int = 0, + num_boundary: int = 0, + num_initial: int = 0, + train_distribution: str = "Hammersley", + anchors=None, + solution=None, + num_test: int = None, + loss_fn: str | Callable = "MSE", + loss_weights: Sequence[float] = None, + ): + self.num_initial = num_initial + assert isinstance( + geometry, DictPointGeometry + ), f"DictPointGeometry is required. But got {geometry}" + super().__init__( + geometry, + pde, + alpha, + constraints, + resolution, + approximator=approximator, + meshtype=meshtype, + num_domain=num_domain, + num_boundary=num_boundary, + train_distribution=train_distribution, + anchors=anchors, + solution=solution, + num_test=num_test, + loss_fn=loss_fn, + loss_weights=loss_weights, + ) + + @run_if_all_none("train_x", "train_y") + def train_next_batch(self, batch_size=None): + assert isinstance( + self.geometry.geom, GeometryXTime + ), "GeometryXTime is required." + geometry = self.geometry.geom + alpha = self.alpha.value if isinstance(self.alpha, bst.State) else self.alpha + + if self.disc.meshtype == "static": + if geometry.geometry.idstr != "Interval": + raise ValueError("Only Interval supports static mesh.") + + nt = int(round(self.num_domain / (self.disc.resolution[0] - 2))) + 1 + self.frac_train = FractionalTime( + alpha, + geometry.geometry, + geometry.timedomain.t0, + geometry.timedomain.t1, + self.disc, + nt, + None, + ) + X = self.geometry.arr_to_dict(self.frac_train.get_x()) + self.train_x_all = X + if self.anchors is not None: + self.train_x_all = jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=-1), + self.anchors, + self.train_x_all, + ) + x_bc = self.bc_points() + + # Remove the initial and boundary points at the beginning of X, + # which are not considered in the integral matrix. + n_start = self.disc.resolution[0] + 2 * nt - 2 + X = jax.tree.map(lambda x: x[n_start:], X) + + elif self.disc.meshtype == "dynamic": + self.train_x_all = self.train_points() + train_x_all = self.geometry.dict_to_arr(self.train_x_all) + x_bc = self.bc_points() + + # FPDE is only applied to the non-boundary points. + x_f = train_x_all[~geometry.on_boundary(train_x_all)] + self.frac_train = FractionalTime( + alpha, + geometry.geometry, + geometry.timedomain.t0, + geometry.timedomain.t1, + self.disc, + None, + x_f, + ) + X = self.geometry.arr_to_dict(self.frac_train.get_x()) + + else: + raise ValueError("Unknown meshtype %s" % self.disc.meshtype) + + self.train_x = jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=-1), + x_bc, + X, + is_leaf=u.math.is_quantity, + ) + self.train_y = self.solution(self.train_x) if self.solution else None + return self.train_x, self.train_y + + @run_if_all_none("test_x", "test_y") + def test(self): + alpha = self.alpha.value if isinstance(self.alpha, bst.State) else self.alpha + assert isinstance( + self.geometry.geom, GeometryXTime + ), "GeometryXTime is required." 
+ geometry = self.geometry.geom + if self.disc.meshtype == "static" and self.num_test is not None: + raise ValueError("Cannot use test points in static mesh.") + + if self.num_test is None: + n_bc = sum(self.num_bcs) + self.test_x = jax.tree.map(lambda x: x[n_bc:], self.train_x) + self.frac_test = self.frac_train + + else: + self.test_x = self.test_points() + test_x = self.geometry.dict_to_arr(self.test_x) + x_f = test_x[~geometry.on_boundary(test_x)] + self.frac_test = FractionalTime( + alpha, + geometry.geometry, + geometry.timedomain.t0, + geometry.timedomain.t1, + self.disc, + None, + x_f, + ) + self.test_x = self.geometry.arr_to_dict(self.frac_test.get_x()) + self.test_y = self.solution(self.test_x) if self.solution else None + return self.test_x, self.test_y + + def train_points(self): + X = super().train_points() + if self.num_initial > 0: + if self.train_distribution == "uniform": + tmp = self.geometry.uniform_initial_points(self.num_initial) + else: + tmp = self.geometry.random_initial_points( + self.num_initial, random=self.train_distribution + ) + X = jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=-1), + tmp, + X, + is_leaf=u.math.is_quantity, + ) + return X + + def get_int_matrix(self, training): + if training: + int_mat = self.frac_train.get_matrix(sparse=True) + num_bc = sum(self.num_bcs) + else: + int_mat = self.frac_test.get_matrix(sparse=True) + num_bc = 0 + + int_mat = array_ops.zero_padding(int_mat, ((num_bc, 0), (num_bc, 0))) + return int_mat + + +class Fractional(FractionalBase): + """Fractional derivative. + + Args: + x0: If ``disc.meshtype = static``, then x0 should be None; + if ``disc.meshtype = 'dynamic'``, then x0 are non-boundary points. + """ + + def _check_dynamic_stepsize(self): + h = 1 / self.disc.resolution[-1] + min_h = self.geom.mindist2boundary(self.x0) + if min_h < h: + warnings.warn( + "Warning: mesh step size %f is larger than the boundary distance %f." + % (h, min_h), + UserWarning, + ) + + def _init_weights(self): + """If ``disc.meshtype = 'static'``, then n is number of points; + if ``disc.meshtype = 'dynamic'``, then n is resolution lambda. 
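+
+        The weights are the Grunwald-Letnikov coefficients, generated by the
+        recurrence ``w[0] = 1``, ``w[j] = w[j-1] * (j - 1 - alpha) / j``,
+        i.e. ``w[j] = (-1)**j * binom(alpha, j)``.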
+ """ + n = ( + self.disc.resolution[0] + if self.disc.meshtype == "static" + else self.dynamic_dist2npts(self.geom.diam) + 1 + ) + w = [1.0] + for j in range(1, n): + w.append(w[-1] * (j - 1 - self.alpha) / j) + return np.asarray(w) + + def get_x_dynamic(self): + if np.any(self.geom.on_boundary(self.x0)): + raise ValueError("x0 contains boundary points.") + if self.geom.dim == 1: + dirns, dirn_w = [-1, 1], [1, 1] + elif self.geom.dim == 2: + gauss_x, gauss_w = np.polynomial.legendre.leggauss(self.disc.resolution[0]) + gauss_x, gauss_w = gauss_x.astype(bst.environ.dftype()), gauss_w.astype( + bst.environ.dftype() + ) + thetas = np.pi * gauss_x + np.pi + dirns = np.vstack((np.cos(thetas), np.sin(thetas))).T + dirn_w = np.pi * gauss_w + elif self.geom.dim == 3: + gauss_x, gauss_w = np.polynomial.legendre.leggauss( + max(self.disc.resolution[:2]) + ) + gauss_x, gauss_w = gauss_x.astype(bst.environ.dftype()), gauss_w.astype( + bst.environ.dftype() + ) + thetas = (np.pi * gauss_x[: self.disc.resolution[0]] + np.pi) / 2 + phis = np.pi * gauss_x[: self.disc.resolution[1]] + np.pi + dirns, dirn_w = [], [] + for i in range(self.disc.resolution[0]): + for j in range(self.disc.resolution[1]): + dirns.append( + [ + np.sin(thetas[i]) * np.cos(phis[j]), + np.sin(thetas[i]) * np.sin(phis[j]), + np.cos(thetas[i]), + ] + ) + dirn_w.append(gauss_w[i] * gauss_w[j] * np.sin(thetas[i])) + dirn_w = np.pi**2 / 2 * np.array(dirn_w) + x, self.w = [], [] + for x0i in self.x0: + xi = list( + map( + lambda dirn: self.geom.background_points( + x0i, dirn, self.dynamic_dist2npts, 0 + ), + dirns, + ) + ) + wi = list( + map( + lambda i: dirn_w[i] + * np.linalg.norm(xi[i][1] - xi[i][0]) ** (-self.alpha) + * self.get_weight(len(xi[i]) - 1), + range(len(dirns)), + ) + ) + # first order + xi, wi = zip(*map(self.modify_first_order, xi, wi)) + # second order + # xi, wi = zip(*map(self.modify_second_order, xi, wi)) + # third order + # xi, wi = zip(*map(self.modify_third_order, xi, wi)) + x.append(np.vstack(xi)) + self.w.append(array_ops.hstack(wi)) + self.xindex_start = np.hstack(([0], np.cumsum(list(map(len, x))))) + len( + self.x0 + ) + return np.vstack([self.x0] + x) + + def modify_first_order(self, x, w): + x = np.vstack(([2 * x[0] - x[1]], x[:-1])) + if not self.geom.inside(x[0:1])[0]: + return x[1:], w[1:] + return x, w + + def modify_second_order(self, x=None, w=None): + w0 = np.hstack(([bst.environ.dftype()(0)], w)) + w1 = np.hstack((w, [bst.environ.dftype()(0)])) + beta = 1 - self.alpha / 2 + w = beta * w0 + (1 - beta) * w1 + if x is None: + return w + x = np.vstack(([2 * x[0] - x[1]], x)) + if not self.geom.inside(x[0:1])[0]: + return x[1:], w[1:] + return x, w + + def modify_third_order(self, x=None, w=None): + w0 = np.hstack(([bst.environ.dftype()(0)], w)) + w1 = np.hstack((w, [bst.environ.dftype()(0)])) + w2 = np.hstack(([bst.environ.dftype()(0)] * 2, w[:-1])) + beta = 1 - self.alpha / 2 + w = ( + (-6 * beta**2 + 11 * beta + 1) / 6 * w0 + + (11 - 6 * beta) * (1 - beta) / 12 * w1 + + (6 * beta + 1) * (beta - 1) / 12 * w2 + ) + if x is None: + return w + x = np.vstack(([2 * x[0] - x[1]], x)) + if not self.geom.inside(x[0:1])[0]: + return x[1:], w[1:] + return x, w + + def get_matrix_static(self): + if not isinstance(self.alpha, (np.ndarray, jax.Array)): + int_mat = np.zeros( + (self.disc.resolution[0], self.disc.resolution[0]), + dtype=bst.environ.dftype(), + ) + h = self.geom.diam / (self.disc.resolution[0] - 1) + for i in range(1, self.disc.resolution[0] - 1): + # first order + int_mat[i, 1 : i + 2] = 
np.flipud(self.get_weight(i)) + int_mat[i, i - 1 : -1] += self.get_weight( + self.disc.resolution[0] - 1 - i + ) + # second order + # int_mat[i, 0:i+2] = np.flipud(self.modify_second_order(w=self.get_weight(i))) + # int_mat[i, i-1:] += self.modify_second_order(w=self.get_weight(self.disc.resolution[0]-1-i)) + # third order + # int_mat[i, 0:i+2] = np.flipud(self.modify_third_order(w=self.get_weight(i))) + # int_mat[i, i-1:] += self.modify_third_order(w=self.get_weight(self.disc.resolution[0]-1-i)) + return h ** (-self.alpha) * int_mat + int_mat = np.zeros((1, self.disc.resolution[0]), dtype=bst.environ.dftype()) + for i in range(1, self.disc.resolution[0] - 1): + # shifted + row = np.concatenate( + [ + np.zeros(1, dtype=bst.environ.dftype()), + np.flip(self.get_weight(i), (0,)), + np.zeros( + self.disc.resolution[0] - i - 2, dtype=bst.environ.dftype() + ), + ], + 0, + ) + row += np.concatenate( + [ + np.zeros(i - 1, dtype=bst.environ.dftype()), + self.get_weight(self.disc.resolution[0] - 1 - i), + np.zeros(1, dtype=bst.environ.dftype()), + ], + 0, + ) + row = np.expand_dims(row, 0) + int_mat = np.concatenate([int_mat, row], 0) + int_mat = np.concatenate( + [ + int_mat, + np.zeros([1, self.disc.resolution[0]], dtype=bst.environ.dftype()), + ], + 0, + ) + h = self.geom.diam / (self.disc.resolution[0] - 1) + return h ** (-self.alpha) * int_mat + + def get_matrix_dynamic(self, sparse): + if self.x is None: + raise AssertionError("No dynamic points") + + if sparse: + print("Generating sparse fractional matrix...") + dense_shape = (self.x0.shape[0], self.x.shape[0]) + indices, values = [], [] + beg = self.x0.shape[0] + for i in range(self.x0.shape[0]): + for _ in range(self.w[i].shape[0]): + indices.append([i, beg]) + beg += 1 + values = array_ops.hstack((values, self.w[i])) + return indices, values, dense_shape + + print("Generating dense fractional matrix...") + int_mat = np.zeros( + (self.x0.shape[0], self.x.shape[0]), dtype=bst.environ.dftype() + ) + beg = self.x0.shape[0] + for i in range(self.x0.shape[0]): + int_mat[i, beg : beg + self.w[i].size] = self.w[i] + beg += self.w[i].size + return int_mat + + +class FractionalTime(FractionalTimeBase): + """Fractional derivative with time. + + Args: + nt: If ``disc.meshtype = static``, then nt is the number of t points; + if ``disc.meshtype = 'dynamic'``, then nt is None. + x0: If ``disc.meshtype = static``, then x0 should be None; + if ``disc.meshtype = 'dynamic'``, then x0 are non-boundary points. + + Attributes: + nx: If ``disc.meshtype = static``, then nx is the number of x points; + if ``disc.meshtype = dynamic``, then nx is the resolution lambda. 
+ """ + + def get_x_static(self): + # Points are ordered as initial --> boundary --> inside + x = self.geom.uniform_points(self.disc.resolution[0], True) + x = np.roll(x, 1)[:, 0] + dt = (self.tmax - self.tmin) / (self.nt - 1) + d = np.empty( + (self.disc.resolution[0] * self.nt, self.geom.dim + 1), dtype=x.dtype + ) + d[0 : self.disc.resolution[0], 0] = x + d[0 : self.disc.resolution[0], 1] = self.tmin + beg = self.disc.resolution[0] + for i in range(1, self.nt): + d[beg : beg + 2, 0] = x[:2] + d[beg : beg + 2, 1] = self.tmin + i * dt + beg += 2 + for i in range(1, self.nt): + d[beg : beg + self.disc.resolution[0] - 2, 0] = x[2:] + d[beg : beg + self.disc.resolution[0] - 2, 1] = self.tmin + i * dt + beg += self.disc.resolution[0] - 2 + return d + + def get_x_dynamic(self): + self.fracx = Fractional(self.alpha, self.geom, self.disc, self.x0[:, :-1]) + xx = self.fracx.get_x() + x = np.empty((len(xx), self.geom.dim + 1), dtype=xx.dtype) + x[: len(self.x0)] = self.x0 + beg = len(self.x0) + for i in range(len(self.x0)): + tmp = xx[self.fracx.xindex_start[i] : self.fracx.xindex_start[i + 1]] + x[beg : beg + len(tmp), :1] = tmp + x[beg : beg + len(tmp), -1] = self.x0[i, -1] + beg += len(tmp) + return x + + def get_matrix_static(self): + # Only consider the inside points + print("Warning: assume zero boundary condition.") + n = (self.disc.resolution[0] - 2) * (self.nt - 1) + int_mat = np.zeros((n, n), dtype=bst.environ.dftype()) + self.fracx = Fractional(self.alpha, self.geom, self.disc, None) + int_mat_one = self.fracx.get_matrix() + beg = 0 + for _ in range(self.nt - 1): + int_mat[ + beg : beg + self.disc.resolution[0] - 2, + beg : beg + self.disc.resolution[0] - 2, + ] = int_mat_one[1:-1, 1:-1] + beg += self.disc.resolution[0] - 2 + return int_mat diff --git a/deepxde/experimental/problem/ide.py b/deepxde/experimental/problem/ide.py new file mode 100644 index 000000000..57a2acf04 --- /dev/null +++ b/deepxde/experimental/problem/ide.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from typing import Callable, Sequence, Union, Optional, Dict, Any + +import brainstate as bst +import brainunit as u +import jax +import numpy as np + +from deepxde.experimental.geometry import DictPointGeometry +from deepxde.experimental.icbc.base import ICBC +from deepxde.utils.internal import run_if_all_none +from .pde import PDE + +__all__ = [ + "IDE", +] + +X = Dict[str, bst.typing.ArrayLike] +Y = Dict[str, bst.typing.ArrayLike] +InitMat = Any + + +class IDE(PDE): + """ + Integro-Differential Equation (IDE) solver class. + + This class extends the PDE solver to handle Integro-Differential Equations. + It specifically focuses on solving 1D problems with integral terms of the form: + int_0^x K(x, t) y(t) dt, where K is the kernel function. + + The IDE solver uses a Physics-Informed Neural Network (PINN) approach, + combining neural networks with numerical integration techniques to + approximate solutions to IDEs. + + Attributes: + kernel (Callable): The kernel function K(x, t) used in the integral term. + quad_deg (int): The degree of quadrature used for numerical integration. + quad_x (np.ndarray): Quadrature points for Gauss-Legendre quadrature. + quad_w (np.ndarray): Quadrature weights for Gauss-Legendre quadrature. + + Inherits from: + PDE: Base class for partial differential equation solvers. + + Note: + - This implementation currently supports only 1D problems. + - The solver uses Gauss-Legendre quadrature for numerical integration. 
+ - The neural network approximator and other PDE-related functionalities + are inherited from the parent PDE class. + """ + + def __init__( + self, + geometry: DictPointGeometry, + ide: Callable[[X, Y, InitMat], Any], + constraints: Union[ICBC, Sequence[ICBC]], + quad_deg: int, + approximator: Optional[bst.nn.Module] = None, + kernel: Callable = None, + num_domain: int = 0, + num_boundary: int = 0, + train_distribution: str = "Hammersley", + anchors=None, + solution=None, + num_test: int = None, + loss_fn: str | Callable = "MSE", + loss_weights: Sequence[float] = None, + ): + """ + Initialize the IDE (Integro-Differential Equation) solver. + + Args: + geometry (DictPointGeometry): The geometry of the problem domain. + ide (Callable[[X, Y, InitMat], Any]): The IDE function to be solved. + constraints (Union[ICBC, Sequence[ICBC]]): Initial and boundary conditions. + quad_deg (int): The degree of quadrature for numerical integration. + approximator (Optional[bst.nn.Module], optional): The neural network approximator. Defaults to None. + kernel (Callable, optional): The kernel function for the integral term. Defaults to None. + num_domain (int, optional): Number of domain points. Defaults to 0. + num_boundary (int, optional): Number of boundary points. Defaults to 0. + train_distribution (str, optional): Distribution method for training points. Defaults to "Hammersley". + anchors (optional): Anchor points for the geometry. Defaults to None. + solution (optional): The analytical solution if available. Defaults to None. + num_test (int, optional): Number of test points. Defaults to None. + loss_fn (str | Callable, optional): Loss function to be used. Defaults to 'MSE'. + loss_weights (Sequence[float], optional): Weights for different components of the loss. Defaults to None. 
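+
+        Example:
+            A sketch of an ``ide`` residual for the Volterra equation
+            y(x) + int_0^x K(x, t) y(t) dt = f(x) with f(x) = 1; ``int_mat`` is
+            the quadrature matrix built by this class, and the output key
+            ``"y"`` is an assumed name:
+
+            .. code-block:: python
+
+                def ide(inputs, outputs, int_mat):
+                    y = outputs["y"]
+                    integral = int_mat @ y  # int_0^x K(x, t) y(t) dt
+                    return y[: len(integral)] + integral - 1.0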
+ + Returns: + None + """ + self.kernel = kernel or (lambda x, *args: np.ones((len(x), 1))) + self.quad_deg = quad_deg + self.quad_x, self.quad_w = np.polynomial.legendre.leggauss(quad_deg) + self.quad_x = self.quad_x.astype(bst.environ.dftype()) + self.quad_w = self.quad_w.astype(bst.environ.dftype()) + + super().__init__( + geometry, + ide, + constraints, + approximator=approximator, + num_domain=num_domain, + num_boundary=num_boundary, + train_distribution=train_distribution, + anchors=anchors, + solution=solution, + num_test=num_test, + loss_fn=loss_fn, + loss_weights=loss_weights, + ) + + def call_pde_errors(self, inputs, outputs, **kwargs): + bcs_start = np.cumsum([0] + self.num_bcs) + fit = bst.environ.get("fit") + int_mat = self.get_int_matrix(fit) + pde_errors = self.pde(inputs, outputs, int_mat, **kwargs) + return jax.tree.map(lambda x: x[bcs_start[-1] :], pde_errors) + + @run_if_all_none("train_x", "train_y") + def train_next_batch(self, batch_size=None): + self.train_x_all = self.train_points() + x_bc = self.bc_points() + x_quad = self.quad_points(self.train_x_all) + self.train_x = jax.tree.map( + lambda x, y, z: u.math.concatenate((x, y, z), axis=0), + x_bc, + self.train_x_all, + x_quad, + is_leaf=u.math.is_quantity, + ) + self.train_y = self.solution(self.train_x) if self.solution else None + return self.train_x, self.train_y + + @run_if_all_none("test_x", "test_y") + def test(self): + if self.num_test is None: + self.test_x = self.train_x_all + else: + self.test_x = self.test_points() + x_quad = self.quad_points(self.test_x) + self.test_x = jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=0), + self.test_x, + x_quad, + is_leaf=u.math.is_quantity, + ) + self.test_y = self.solution(self.test_x) if self.solution else None + return self.test_x, self.test_y + + def test_points(self): + return self.geometry.uniform_points(self.num_test, True) + + def quad_points(self, X): + fn = lambda xs: (jax.vmap(lambda x: (self.quad_x + 1) * x / 2)(xs)).flatten() + return jax.tree.map(fn, X, is_leaf=u.math.is_quantity) + + def get_int_matrix(self, training): + def get_quad_weights(x): + return self.quad_w * x / 2 + + with jax.ensure_compile_time_eval(): + if training: + num_bc = sum(self.num_bcs) + X = self.train_x + else: + num_bc = 0 + X = self.test_x + + X = np.asarray(self.geometry.dict_to_arr(X)) + if training or self.num_test is None: + num_f = tuple(self.train_x_all.values())[0].shape[0] + else: + num_f = self.num_test + + int_mat = np.zeros((num_bc + num_f, X.size), dtype=bst.environ.dftype()) + for i in range(num_f): + x = X[i + num_bc, 0] + beg = num_f + num_bc + self.quad_deg * i + end = beg + self.quad_deg + K = np.ravel(self.kernel(np.full((self.quad_deg, 1), x), X[beg:end])) + int_mat[i + num_bc, beg:end] = get_quad_weights(x) * K + return int_mat diff --git a/deepxde/experimental/problem/pde.py b/deepxde/experimental/problem/pde.py new file mode 100644 index 000000000..2ba970595 --- /dev/null +++ b/deepxde/experimental/problem/pde.py @@ -0,0 +1,562 @@ +from __future__ import annotations + +from typing import Callable, Sequence, Union, Optional, Dict, List + +import brainstate as bst +import brainunit as u +import jax.tree +import numpy as np + +from deepxde.experimental import utils +from deepxde.experimental.geometry import GeometryXTime, DictPointGeometry +from deepxde.utils.internal import run_if_all_none +from .base import Problem +from ..icbc.base import ICBC + +__all__ = ["PDE", "TimePDE"] + + +class PDE(Problem): + """ODE or time-independent PDE solver. 
+
+    Args:
+        geometry: Instance of ``Geometry``.
+        constraints: A boundary condition or a list of boundary conditions. Use ``[]`` if no
+            boundary condition.
+        approximator: A neural network trainer for approximating the solution.
+        num_domain (int): The number of training points sampled inside the domain.
+        num_boundary (int): The number of training points sampled on the boundary.
+        train_distribution (string): The distribution to sample training points. One of
+            the following: "uniform" (equispaced grid), "pseudo" (pseudorandom), "LHS"
+            (Latin hypercube sampling), "Halton" (Halton sequence), "Hammersley"
+            (Hammersley sequence), or "Sobol" (Sobol sequence).
+        anchors: A Numpy array of training points, in addition to the `num_domain` and
+            `num_boundary` sampled points.
+        exclusions: A Numpy array of points to be excluded for training.
+        solution: The reference solution.
+        num_test: The number of points sampled inside the domain for testing PDE loss.
+            The testing points for BCs/ICs are the same set of points used for training.
+            If ``None``, then the training points will be used for testing.
+
+    Warning:
+        The testing points include points inside the domain and points on the boundary,
+        and they may not have the same density, so the testing points as a whole may
+        not be uniformly distributed. As a result, if you have a reference solution
+        (`solution`) and would like to compute a metric such as
+
+        .. code-block:: python
+
+            Trainer.compile(metrics=["l2 relative error"])
+
+        then the metric may not be very accurate. To better compute a metric, you can
+        sample the points manually, and then use ``Trainer.predict()`` to predict the
+        solution on these points and compute the metric:
+
+        .. code-block:: python
+
+            x = geometry.uniform_points(num, boundary=True)
+            y_true = ...
+            y_pred = trainer.predict(x)
+            error = experimental.metrics.l2_relative_error(y_true, y_pred)
+
+    Attributes:
+        train_x_all: A Numpy array of points for PDE training. `train_x_all` is
+            unordered, and does not have duplication. If there is a PDE, then
+            `train_x_all` is used as the training points of the PDE.
+        train_x_bc: A Numpy array of the training points for BCs. `train_x_bc` is
+            constructed from `train_x_all` at the first step of training; by default it
+            won't be updated when `train_x_all` changes. To update `train_x_bc`, set it
+            to `None` and call `bc_points`, and then update the loss function by
+            ``trainer.compile()``.
+        num_bcs (list): `num_bcs[i]` is the number of points for `constraints[i]`.
+        train_x: A Numpy array of the points fed into the network for training.
+            `train_x` is ordered from BC points (`train_x_bc`) to PDE points
+            (`train_x_all`), and may have duplicate points.
+        test_x: A Numpy array of the points fed into the network for testing, ordered
+            from BCs to PDE. The BC points are exactly the same points in `train_x_bc`.
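+
+    Example:
+        A toy residual showing only the expected callable shape (no derivatives;
+        the keys ``"x"`` and ``"y"`` are assumptions about the geometry and the
+        approximator):
+
+        .. code-block:: python
+
+            import numpy as np
+            import brainunit as u
+
+            def pde(inputs, outputs):
+                # drive the network output toward sin(pi * x)
+                return outputs["y"] - u.math.sin(np.pi * inputs["x"])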
+ """ + + def __init__( + self, + geometry: DictPointGeometry, + pde: Callable, + constraints: Union[ICBC, Sequence[ICBC]], + approximator: Optional[bst.nn.Module] = None, + solution: Callable[[bst.typing.PyTree], bst.typing.PyTree] = None, + loss_fn: str | Callable = "MSE", + num_domain: int = 0, + num_boundary: int = 0, + num_test: int = None, + train_distribution: str = "Hammersley", + anchors: Optional[bst.typing.ArrayLike] = None, + exclusions=None, + loss_weights: Sequence[float] = None, + ): + super().__init__( + approximator=approximator, loss_fn=loss_fn, loss_weights=loss_weights + ) + + assert isinstance( + geometry, DictPointGeometry + ), f"Expected DictPointGeometry, got {type(geometry)}" + # geometry is a Geometry object + self.geometry = geometry + + # PDE function + self._pde = pde + if pde is not None: + assert callable(pde), f"Expected callable, got {type(pde)}" + + # initial and boundary conditions + self.constraints = ( + constraints if isinstance(constraints, (list, tuple)) else [constraints] + ) + for bc in self.constraints: + assert isinstance(bc, ICBC), f"Expected ICBC, got {type(bc)}" + bc.apply_geometry(self.geometry) + bc.apply_problem(self) + + # anchors + self.anchors = ( + None + if anchors is None + else jax.tree.map(lambda x: x.astype(bst.environ.dftype()), anchors) + ) + + # solution + if solution is not None: + assert callable(solution), f"Expected callable, got {type(solution)}" + self.solution = solution + + # exclusions + self.exclusions = exclusions + + # others + self.num_domain = num_domain + self.num_boundary = num_boundary + self.num_test = num_test + self.train_distribution = train_distribution + + # training data + self.train_x_all: Dict[str, bst.typing.ArrayLike] = None + self.train_x_bc: Dict[str, bst.typing.ArrayLike] = None + self.num_bcs: List[int] = None + + # these include both BC and PDE points + self.train_x: Dict[str, bst.typing.ArrayLike] = None + self.train_y: Dict[str, bst.typing.ArrayLike] = None + self.test_x: Dict[str, bst.typing.ArrayLike] = None + self.test_y: Dict[str, bst.typing.ArrayLike] = None + + # generate training data and testing data + self.train_next_batch() + self.test() + + def pde(self, *args, **kwargs): + """ + Compute the PDE residual. 
+ """ + if self._pde is not None: + return self._pde(*args, **kwargs) + else: + raise NotImplementedError("PDE is not defined.") + + def call_pde_errors(self, inputs, outputs, **kwargs): + bcs_start = np.cumsum([0] + self.num_bcs) + + # PDE inputs and outputs, computing PDE losses + pde_inputs = jax.tree.map(lambda x: x[bcs_start[-1] :], inputs) + pde_outputs = jax.tree.map(lambda x: x[bcs_start[-1] :], outputs) + pde_kwargs = jax.tree.map(lambda x: x[bcs_start[-1] :], kwargs) + + # error + pde_errors = self.pde(pde_inputs, pde_outputs, **pde_kwargs) + return pde_errors + + def call_bc_errors(self, loss_fns, loss_weights, inputs, outputs, **kwargs): + bcs_start = np.cumsum([0] + self.num_bcs) + losses = [] + for i, bc in enumerate(self.constraints): + # ICBC inputs and outputs, computing ICBC losses + beg, end = bcs_start[i], bcs_start[i + 1] + icbc_inputs = jax.tree.map(lambda x: x[beg:end], inputs) + icbc_outputs = jax.tree.map(lambda x: x[beg:end], outputs) + icbc_kwargs = jax.tree.map(lambda x: x[beg:end], kwargs) + + # error + error: Dict = bc.error(icbc_inputs, icbc_outputs, **icbc_kwargs) + + # loss and weights + f_loss = loss_fns[i] + if loss_weights is not None: + w = loss_weights[i] + bc_loss = jax.tree.map( + lambda err: f_loss(u.math.zeros_like(err), err) * w, error + ) + else: + bc_loss = jax.tree.map( + lambda err: f_loss(u.math.zeros_like(err), err), error + ) + + # append to losses + losses.append({f"ibc{i}": bc_loss}) + return losses + + @utils.check_not_none("num_bcs") + def losses(self, inputs, outputs, targets, **kwargs): + # PDE inputs and outputs, computing PDE losses + pde_errors = self.call_pde_errors(inputs, outputs, **kwargs) + if not isinstance(pde_errors, (list, tuple)): + pde_errors = [pde_errors] + + # loss functions + if not isinstance(self.loss_fn, (list, tuple)): + loss_fn = [self.loss_fn] * (len(pde_errors) + len(self.constraints)) + else: + loss_fn = self.loss_fn + if len(loss_fn) != len(pde_errors) + len(self.constraints): + raise ValueError( + f"There are {len(pde_errors) + len(self.constraints)} errors, " + f"but only {len(loss_fn)} losses." + ) + + # PDE loss + losses = [ + loss_fn[i](u.math.zeros_like(error), error) + for i, error in enumerate(pde_errors) + ] + if self.loss_weights is not None: + n_loss = len(losses) + len(self.constraints) + if len(self.loss_weights) != len(losses) + len(self.constraints): + raise ValueError( + f"Expected {n_loss} weights, got {len(self.loss_weights)}. " + f"There are {len(losses)} PDE losses and {len(self.constraints)} IC+BC losses." 
+                )
+            del n_loss
+            losses = [
+                w * loss for w, loss in zip(self.loss_weights[: len(losses)], losses)
+            ]
+
+        # loss of boundary or initial conditions
+        bc_errors = self.call_bc_errors(
+            loss_fn[len(pde_errors) :],
+            (
+                self.loss_weights[len(pde_errors) :]
+                if self.loss_weights is not None
+                else None
+            ),
+            inputs,
+            outputs,
+            **kwargs,
+        )
+        losses.extend(bc_errors)
+        return losses
+
+    @run_if_all_none("train_x", "train_y")
+    def train_next_batch(self, batch_size=None):
+        # Generate `self.train_x_all`
+        self.train_points()
+
+        # Generate `self.num_bcs` and `self.train_x_bc`
+        self.bc_points()
+
+        if self._pde is not None:
+            # include data in boundary, initial conditions, and PDE
+            if len(self.train_x_bc):
+                self.train_x = jax.tree.map(
+                    lambda x, y: u.math.concatenate((x, y), axis=0),
+                    self.train_x_bc,
+                    self.train_x_all,
+                )
+            else:
+                self.train_x = self.train_x_all
+
+        else:
+            # only include data in boundary or initial conditions
+            self.train_x = self.train_x_bc
+
+        self.train_y = (
+            self.solution(self.train_x) if self.solution is not None else None
+        )
+        return self.train_x, self.train_y
+
+    @run_if_all_none("test_x", "test_y")
+    def test(self):
+        if self.num_test is None:
+            # assign the training points to the testing points
+            self.test_x = self.train_x
+        else:
+            # Generate `self.test_x`, resampling the test points
+            self.test_x = self.test_points()
+
+        # solution on the test points
+        self.test_y = self.solution(self.test_x) if self.solution is not None else None
+        return self.test_x, self.test_y
+
+    def resample_train_points(self, pde_points=True, bc_points=True):
+        """Resample the training points for PDE and/or BC."""
+        if pde_points:
+            self.train_x_all = None
+        if bc_points:
+            self.train_x_bc = None
+        self.train_x, self.train_y = None, None
+        self.train_next_batch()
+
+    def add_anchors(self, anchors: bst.typing.PyTree):
+        """
+        Add new points for training PDE losses.
+
+        The BC points will not be updated.
+        """
+        anchors = jax.tree.map(lambda x: x.astype(bst.environ.dftype()), anchors)
+        if self.anchors is None:
+            self.anchors = anchors
+        else:
+            self.anchors = jax.tree.map(
+                lambda x, y: u.math.concatenate((x, y), axis=0), self.anchors, anchors
+            )
+
+        # include anchors in the training points
+        self.train_x_all = jax.tree.map(
+            lambda x, y: u.math.concatenate((x, y), axis=0), anchors, self.train_x_all
+        )
+
+        if self._pde is not None:
+            # include data in boundary, initial conditions, and PDE
+            self.train_x = jax.tree.map(
+                lambda x, y: u.math.concatenate((x, y), axis=0),
+                self.bc_points(),
+                self.train_x_all,
+            )
+
+        else:
+            # only include data in boundary or initial conditions
+            self.train_x = self.bc_points()
+
+        # solution on the training points
+        self.train_y = (
+            self.solution(self.train_x) if self.solution is not None else None
+        )
+
+    def replace_with_anchors(self, anchors):
+        """Replace the current PDE training points with anchors.
+
+        The BC points will not be changed.
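+
+        For example (hypothetical; the dict keys must match the geometry's
+        coordinate names):
+
+        .. code-block:: python
+
+            import numpy as np
+
+            problem.replace_with_anchors({"x": np.linspace(0.0, 1.0, 128)})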
+ """ + self.anchors = jax.tree.map(lambda x: x.astype(bst.environ.dftype()), anchors) + self.train_x_all = self.anchors + + if self.pde is not None: + # include data in boundary, initial conditions, and PDE + self.train_x = jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=-1), + self.bc_points(), + self.train_x_all, + ) + else: + # only include data in boundary or initial conditions + self.train_x = self.bc_points() + + # solution on the training points + self.train_y = ( + self.solution(self.train_x) if self.solution is not None else None + ) + + @run_if_all_none("train_x_all") + def train_points(self): + X = None + + # sampling points in the domain + if self.num_domain > 0: + if self.train_distribution == "uniform": + X = self.geometry.uniform_points(self.num_domain, boundary=False) + else: + X = self.geometry.random_points( + self.num_domain, random=self.train_distribution + ) + + # sampling points on the boundary + if self.num_boundary > 0: + if self.train_distribution == "uniform": + tmp = self.geometry.uniform_boundary_points(self.num_boundary) + else: + tmp = self.geometry.random_boundary_points( + self.num_boundary, random=self.train_distribution + ) + X = ( + tmp + if X is None + else jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=0), X, tmp + ) + ) + + # add anchors + if self.anchors is not None: + X = ( + self.anchors + if X is None + else jax.tree.map( + lambda x, y: u.math.concatenate((x, y), axis=0), self.anchors, X + ) + ) + + # exclude points + if self.exclusions is not None: + raise NotImplementedError + + # TODO: Check if this is correct + def is_not_excluded(x): + return not np.any([np.allclose(x, y) for y in self.exclusions]) + + X = np.array(list(filter(is_not_excluded, X))) + + # save the training points + self.train_x_all = X + return X + + @run_if_all_none("train_x_bc") + def bc_points(self): + """ + Generate boundary condition points. + + Returns: + np.ndarray: The boundary condition points. + """ + x_bcs = [bc.collocation_points(self.train_x_all) for bc in self.constraints] + # self.num_bcs = list([len(x[self.geometry.names[0]]) for x in x_bcs]) + self.num_bcs = list([len(tuple(x.values())[0]) for x in x_bcs]) + if len(self.num_bcs): + self.train_x_bc = jax.tree.map( + lambda *x: u.math.concatenate(x, axis=0), *x_bcs + ) + else: + self.train_x_bc = dict() + return self.train_x_bc + + def test_points(self): + # different points from self.train_x_all + x = self.geometry.uniform_points(self.num_test, boundary=False) + + # # different BC points from self.train_x_bc + # x_bcs = [bc.collocation_points(x) for bc in self.constraints] + # x_bcs = jax.tree.map(lambda *x: u.math.vstack(x), *x_bcs) + + # reuse the same BC points + if len(self.num_bcs): + x_bcs = self.train_x_bc + x = jax.tree.map( + lambda x_, y_: u.math.concatenate((x_, y_), axis=0), x_bcs, x + ) + return x + + +class TimePDE(PDE): + """Time-dependent PDE solver. + + This class extends the PDE solver to handle time-dependent partial differential equations. + It provides functionality to generate training points for both spatial and temporal domains, + including initial condition points. + + Args: + geometry (DictPointGeometry): The geometry of the problem domain, including both spatial and temporal dimensions. + pde (Callable): The partial differential equation to be solved. + constraints (Union[ICBC, Sequence[ICBC]]): Initial and boundary conditions for the PDE. + approximator (Optional[bst.nn.Module]): The neural network used to approximate the solution. Defaults to None. 
+ num_domain (int): Number of training points in the domain. Defaults to 0. + num_boundary (int): Number of training points on the boundary. Defaults to 0. + num_initial (int): Number of training points for the initial condition. Defaults to 0. + train_distribution (str): Method for distributing training points. Options include "uniform" and "Hammersley". Defaults to "Hammersley". + anchors (Optional): Specific points to include in the training set. Defaults to None. + exclusions (Optional): Points to exclude from the training set. Defaults to None. + solution (Optional): The analytical solution to the PDE, if known. Defaults to None. + num_test (Optional[int]): Number of test points. If None, training points are used for testing. Defaults to None. + loss_fn (Union[str, Callable]): Loss function for training. Can be a string identifier or a callable. Defaults to 'MSE'. + loss_weights (Optional[Sequence[float]]): Weights for different components of the loss function. Defaults to None. + + Attributes: + num_initial (int): Number of initial condition points. + geometry (GeometryXTime): The geometry of the problem, including time. + + Methods: + train_points(): Generates training points for the time-dependent PDE, including initial condition points. + + Note: + This class is specifically designed for time-dependent PDEs and extends the functionality + of the base PDE class to handle the temporal aspect of the problem. + """ + + def __init__( + self, + geometry: DictPointGeometry, + pde: Callable, + constraints: Union[ICBC, Sequence[ICBC]], + approximator: Optional[bst.nn.Module] = None, + num_domain: int = 0, + num_boundary: int = 0, + num_initial: int = 0, + train_distribution: str = "Hammersley", + anchors=None, + exclusions=None, + solution=None, + num_test: int = None, + loss_fn: str | Callable = "MSE", + loss_weights: Sequence[float] = None, + ): + self.num_initial = num_initial + super().__init__( + geometry, + pde, + constraints, + num_domain=num_domain, + num_boundary=num_boundary, + train_distribution=train_distribution, + anchors=anchors, + exclusions=exclusions, + solution=solution, + num_test=num_test, + approximator=approximator, + loss_fn=loss_fn, + loss_weights=loss_weights, + ) + + @run_if_all_none("train_x_all") + def train_points(self): + """ + Generate training points for the time-dependent PDE solver. + + This method extends the base PDE class's train_points method by adding + initial condition points for time-dependent problems. + + Returns: + X (Dict[str, bst.typing.ArrayLike]): A dictionary containing the generated training points. + The keys are the names of the spatial dimensions and time, and the values are + the corresponding coordinates. + + Note: + - The method uses the geometry attribute (of type GeometryXTime) to generate points. + - If num_initial > 0, it adds initial condition points to the training set. + - The distribution of initial points can be either uniform or based on the specified + train_distribution. + - If exclusions are specified, the method filters out excluded points. 
+ """ + self.geometry: GeometryXTime + + X = super().train_points() + + if self.num_initial > 0: + if self.train_distribution == "uniform": + tmp = self.geometry.uniform_initial_points(self.num_initial) + else: + tmp = self.geometry.random_initial_points( + self.num_initial, random=self.train_distribution + ) + if self.exclusions is not None: + + def is_not_excluded(x): + return not np.any([np.allclose(x, y) for y in self.exclusions]) + + tmp = np.array(list(filter(is_not_excluded, tmp))) + X = jax.tree.map(lambda x, y: u.math.concatenate((x, y), axis=0), X, tmp) + self.train_x_all = X + return X diff --git a/deepxde/experimental/problem/pde_operator.py b/deepxde/experimental/problem/pde_operator.py new file mode 100644 index 000000000..8860d50ba --- /dev/null +++ b/deepxde/experimental/problem/pde_operator.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +from typing import Callable, Sequence, Union, Optional, Any, Dict + +import brainstate as bst +import brainunit as u +import jax +import numpy as np + +from deepxde.data.function_spaces import FunctionSpace +from deepxde.data.sampler import BatchSampler +from deepxde.experimental.geometry import DictPointGeometry +from deepxde.experimental.icbc.base import ICBC +from deepxde.utils.internal import run_if_all_none +from .pde import TimePDE + +__all__ = [ + "PDEOperator", + "PDEOperatorCartesianProd", +] + +Inputs = Any +Outputs = Any +Auxiliary = Any +Residual = Any + + +class PDEOperator(TimePDE): + """ + PDE solution operator. + + Args: + function_space: Instance of ``experimental.fnspace.FunctionSpace``. + evaluation_points: A NumPy array of shape (n_points, dim). Discretize the input + function sampled from `function_space` using point-wise evaluations at a set + of points as the input of the branch net. + num_function (int): The number of functions for training. + function_variables: ``None`` or a list of integers. The functions in the + `function_space` may not have the same domain as the PDE. For example, the + PDE is defined on a spatio-temporal domain (`x`, `t`), but the function is + IC, which is only a function of `x`. In this case, we need to specify the + variables of the function by `function_variables=[0]`, where `0` indicates + the first variable `x`. If ``None``, then we assume the domains of the + function and the PDE are the same. + num_fn_test: The number of functions for testing PDE loss. The testing functions + for BCs/ICs are the same functions used for training. If ``None``, then the + training functions will be used for testing. + """ + + def __init__( + self, + geometry: DictPointGeometry, + pde: Callable[[Inputs, Outputs, Auxiliary], Residual], + constraints: Union[ICBC, Sequence[ICBC]], + function_space: FunctionSpace, + evaluation_points, + num_function: int, + function_variables: Optional[Sequence[int]] = None, + num_test: int = None, + approximator: Optional[bst.nn.Module] = None, + solution: Callable[[bst.typing.PyTree], bst.typing.PyTree] = None, + num_domain: int = 0, # for space PDE + num_boundary: int = 0, # for space PDE + num_initial: int = 0, # for time PDE + num_fn_test: int = None, + train_distribution: str = "Hammersley", + anchors: Optional[bst.typing.ArrayLike] = None, + exclusions=None, + loss_fn: str | Callable = "MSE", + loss_weights: Sequence[float] = None, + ): + + assert isinstance(function_space, FunctionSpace), ( + f"Expected `function_space` to be an instance of `FunctionSpace`, " + f"but got {type(function_space)}." 
+        )
+        self.fn_space = function_space
+        self.eval_pts = evaluation_points
+        self.func_vars = (
+            function_variables
+            if function_variables is not None
+            else list(range(geometry.dim))
+        )
+
+        self.num_fn = num_function
+        self.num_fn_test = num_fn_test
+
+        self.fn_train_bc = None
+        self.fn_train_x = None
+        self.fn_train_y = None
+        self.fn_train_aux_vars = None
+        self.fn_test_x = None
+        self.fn_test_y = None
+        self.fn_test_aux_vars = None
+
+        super().__init__(
+            geometry=geometry,
+            pde=pde,
+            constraints=constraints,
+            approximator=approximator,
+            loss_fn=loss_fn,
+            loss_weights=loss_weights,
+            num_initial=num_initial,
+            num_domain=num_domain,
+            num_boundary=num_boundary,
+            train_distribution=train_distribution,
+            anchors=anchors,
+            exclusions=exclusions,
+            solution=solution,
+            num_test=num_test,
+        )
+
+    def call_pde_errors(self, inputs, outputs, **kwargs):
+        num_bcs = self.num_bcs
+        self.num_bcs = self.num_fn_bcs
+        losses = super().call_pde_errors(inputs, outputs, **kwargs)
+        self.num_bcs = num_bcs
+        return losses
+
+    def call_bc_errors(self, loss_fns, loss_weights, inputs, outputs, **kwargs):
+        num_bcs = self.num_bcs
+        self.num_bcs = self.num_fn_bcs
+        losses = super().call_bc_errors(
+            loss_fns, loss_weights, inputs, outputs, **kwargs
+        )
+        self.num_bcs = num_bcs
+        return losses
+
+    @run_if_all_none("fn_train_x", "fn_train_y", "fn_train_aux_vars")
+    def train_next_batch(self, batch_size=None):
+        super().train_next_batch(batch_size)
+
+        self.num_fn_bcs = [n * self.num_fn for n in self.num_bcs]
+        func_feats = self.fn_space.random(self.num_fn)
+        func_vals = self.fn_space.eval_batch(func_feats, self.eval_pts)
+        v, x, vx = self.bc_inputs(func_feats, func_vals)
+
+        if self._pde is not None:
+            v_pde, x_pde, vx_pde = self.gen_inputs(
+                func_feats, func_vals, self.geometry.dict_to_arr(self.train_x_all)
+            )
+            v = np.vstack((v, v_pde))
+            x = np.vstack((x, x_pde))
+            vx = np.vstack((vx, vx_pde))
+        self.fn_train_x = (v, x)
+        self.fn_train_aux_vars = {"aux": vx}
+        # fn_train_y stays None; the targets are the zero residuals of the PDE/BCs.
+        return self.fn_train_x, self.fn_train_y, self.fn_train_aux_vars
+
+    @run_if_all_none("fn_test_x", "fn_test_y", "fn_test_aux_vars")
+    def test(self):
+        super().test()
+
+        if self.num_fn_test is None:
+            self.fn_test_x = self.fn_train_x
+            self.fn_test_aux_vars = self.fn_train_aux_vars
+
+        else:
+            func_feats = self.fn_space.random(self.num_fn_test)
+            func_vals = self.fn_space.eval_batch(func_feats, self.eval_pts)
+            # TODO: Use different BC data from self.fn_train_x
+            v, x, vx = self.train_bc
+            if self._pde is not None:
+                test_x = self.geometry.dict_to_arr(self.test_x)
+                v_pde, x_pde, vx_pde = self.gen_inputs(
+                    func_feats, func_vals, test_x[sum(self.num_bcs) :]
+                )
+                v = np.vstack((v, v_pde))
+                x = np.vstack((x, x_pde))
+                vx = np.vstack((vx, vx_pde))
+            self.fn_test_x = (v, x)
+            self.fn_test_aux_vars = {"aux": vx}
+        return self.fn_test_x, self.fn_test_y, self.fn_test_aux_vars
+
+    def gen_inputs(self, func_feats, func_vals, points):
+        # Pair every sampled function with every evaluation point. Format:
+        # v1, x_1
+        # ...
+        # v1, x_N1
+        # v2, x_1
+        # ...
+        # v2, x_N1
+        v = np.repeat(func_vals, len(points), axis=0)
+        x = np.tile(points, (len(func_feats), 1))
+        vx = self.fn_space.eval_batch(func_feats, points[:, self.func_vars]).reshape(
+            -1, 1
+        )
+        return v, x, vx
+
+    def bc_inputs(self, func_feats, func_vals):
+        if len(self.constraints) == 0:
+            self.train_bc = (
+                np.empty((0, len(self.eval_pts)), dtype=bst.environ.dftype()),
+                np.empty((0, self.geometry.dim), dtype=bst.environ.dftype()),
+                np.empty((0, 1), dtype=bst.environ.dftype()),
+            )
+            return self.train_bc
+
+        v, x, vx = [], [], []
+        bcs_start = np.cumsum([0] + self.num_bcs)
+        train_x_bc = self.geometry.dict_to_arr(self.train_x_bc)
+        for i, _ in enumerate(self.num_bcs):
+            beg, end = bcs_start[i], bcs_start[i + 1]
+            vi, xi, vxi = self.gen_inputs(func_feats, func_vals, train_x_bc[beg:end])
+            v.append(vi)
+            x.append(xi)
+            vx.append(vxi)
+        self.train_bc = (np.vstack(v), np.vstack(x), np.vstack(vx))
+        return self.train_bc
+
+    def resample_train_points(self, pde_points=True, bc_points=True):
+        """
+        Resample the training points for the operator.
+        """
+        super().resample_train_points(pde_points=pde_points, bc_points=bc_points)
+
+        self.fn_train_x, self.fn_train_y, self.fn_train_aux_vars = None, None, None
+        self.train_next_batch()
+
+
+class PDEOperatorCartesianProd(TimePDE):
+    """
+    PDE solution operator with problem in the format of Cartesian product.
+
+    Args:
+        geometry: Instance of ``experimental.geometry.DictPointGeometry``.
+        pde: Callable. The PDE residual function, with the signature
+            ``pde(inputs, outputs, **auxiliary)``.
+        constraints: A single ``ICBC`` instance or a sequence of ``ICBC`` instances.
+        function_space: Instance of ``experimental.fnspace.FunctionSpace``.
+        evaluation_points: A NumPy array of shape (n_points, dim). Discretize the input
+            function sampled from `function_space` using pointwise evaluations at a set
+            of points as the input of the branch net.
+        num_function (int): The number of functions for training.
+        function_variables: ``None`` or a list of integers. The functions in the
+            `function_space` may not have the same domain as the PDE. For example, the
+            PDE is defined on a spatio-temporal domain (`x`, `t`), but the function is
+            IC, which is only a function of `x`. In this case, we need to specify the
+            variables of the function by `function_variables=[0]`, where `0` indicates
+            the first variable `x`. If ``None``, then we assume the domains of the
+            function and the PDE are the same.
+        num_fn_test: The number of functions for testing PDE loss. The testing functions
+            for BCs/ICs are the same functions used for training. If ``None``, then the
+            training functions will be used for testing.
+        batch_size: Integer or ``None``. The number of functions per training batch.
+
+    Attributes:
+        fn_train_x: A tuple of two NumPy arrays (v, x) fed into the physics-informed
+            DeepONet for training. v is the function input to the branch net and has
+            the shape (`N1`, `dim1`); x is the point input to the trunk net and has
+            the shape (`N2`, `dim2`).
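+
+    Example:
+        A minimal sketch (illustrative only; assumes DeepXDE's ``GRF`` random
+        function space and pre-built ``geometry``, ``pde``, ``constraints``, and
+        ``approximator`` objects)::
+
+            from deepxde.data.function_spaces import GRF
+
+            space = GRF(T=1.0, length_scale=0.2, N=100)
+            eval_pts = np.linspace(0, 1, 50)[:, None]
+            problem = PDEOperatorCartesianProd(
+                geometry, pde, constraints, space, eval_pts,
+                num_function=1000, function_variables=[0],
+                approximator=approximator, batch_size=32,
+            )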
+ """ + + def __init__( + self, + geometry: DictPointGeometry, + pde: Callable[[Inputs, Outputs, Auxiliary], Residual], + constraints: Union[ICBC, Sequence[ICBC]], + function_space: FunctionSpace, + evaluation_points, + num_function: int, + function_variables: Optional[Sequence[int]] = None, + num_test: int = None, + approximator: Optional[bst.nn.Module] = None, + solution: Callable[[bst.typing.PyTree], bst.typing.PyTree] = None, + num_domain: int = 0, # for space PDE + num_boundary: int = 0, # for space PDE + num_initial: int = 0, # for time PDE + num_fn_test: int = None, # for function space + train_distribution: str = "Hammersley", + anchors: Optional[bst.typing.ArrayLike] = None, + exclusions=None, + loss_fn: str | Callable = "MSE", + loss_weights: Sequence[float] = None, + batch_size: int = None, + ): + + assert isinstance(function_space, FunctionSpace), ( + f"Expected `function_space` to be an instance of `FunctionSpace`, " + f"but got {type(function_space)}." + ) + self.fn_space = function_space + self.eval_pts = evaluation_points + self.func_vars = ( + function_variables + if function_variables is not None + else list(range(geometry.dim)) + ) + self.num_fn = num_function + self.num_fn_test = num_fn_test + + self.train_sampler = BatchSampler(self.num_fn, shuffle=True) + self.batch_size = batch_size + + self.fn_train_bc = None + self.fn_train_x = None + self.fn_train_y = None + self.fn_train_aux_vars = None + self.fn_test_x = None + self.fn_test_y = None + self.fn_test_aux_vars = None + + super().__init__( + geometry=geometry, + pde=pde, + constraints=constraints, + approximator=approximator, + loss_fn=loss_fn, + loss_weights=loss_weights, + num_initial=num_initial, + num_domain=num_domain, + num_boundary=num_boundary, + train_distribution=train_distribution, + anchors=anchors, + exclusions=exclusions, + solution=solution, + num_test=num_test, + ) + + def call_pde_errors(self, inputs, outputs, **kwargs): + bcs_start = np.cumsum([0] + self.num_bcs) + + # PDE inputs and outputs, computing PDE losses + pde_inputs = (inputs[0], jax.tree.map(lambda x: x[bcs_start[-1] :], inputs[1])) + pde_outputs = jax.tree.map(lambda x: x[:, bcs_start[-1] :], outputs) + pde_kwargs = jax.tree.map(lambda x: x[:, bcs_start[-1] :], kwargs) + + # error + pde_errors = self.pde(pde_inputs, pde_outputs, **pde_kwargs) + return pde_errors + + def call_bc_errors(self, loss_fns, loss_weights, inputs, outputs, **kwargs): + bcs_start = np.cumsum([0] + self.num_bcs) + losses = [] + for i, bc in enumerate(self.constraints): + # ICBC inputs and outputs, computing ICBC losses + beg, end = bcs_start[i], bcs_start[i + 1] + icbc_inputs = (inputs[0], jax.tree.map(lambda x: x[beg:end], inputs[1])) + icbc_outputs = jax.tree.map(lambda x: x[:, beg:end], outputs) + icbc_kwargs = jax.tree.map(lambda x: x[:, beg:end], kwargs) + + # error + error: Dict = bc.error(icbc_inputs, icbc_outputs, **icbc_kwargs) + + # loss and weights + f_loss = loss_fns[i] + if loss_weights is not None: + w = loss_weights[i] + bc_loss = jax.tree.map( + lambda err: f_loss(u.math.zeros_like(err), err) * w, error + ) + else: + bc_loss = jax.tree.map( + lambda err: f_loss(u.math.zeros_like(err), err), error + ) + + # append to losses + losses.append({f"ibc{i}": bc_loss}) + return losses + + def train_next_batch(self, batch_size=None): + super().train_next_batch(batch_size) + + if self.fn_train_x is None: + train_x = self.geometry.dict_to_arr(self.train_x) + func_feats = self.fn_space.random(self.num_fn) + func_vals = self.fn_space.eval_batch(func_feats, 
self.eval_pts)
+            vx = self.fn_space.eval_batch(func_feats, train_x[:, self.func_vars])
+            self.fn_train_x = (func_vals, train_x)
+            self.fn_train_aux_vars = {"aux": vx}
+
+        if self.batch_size is None:
+            return self.fn_train_x, self.train_y, self.fn_train_aux_vars
+
+        indices = self.train_sampler.get_next(self.batch_size)
+        train_x = (self.fn_train_x[0][indices], self.fn_train_x[1])
+        return train_x, self.train_y, {"aux": self.fn_train_aux_vars["aux"][indices]}
+
+    @run_if_all_none("fn_test_x", "test_y", "fn_test_aux_vars")
+    def test(self):
+        super().test()
+
+        if self.num_fn_test is None:
+            self.fn_test_x = self.fn_train_x
+            self.fn_test_aux_vars = self.fn_train_aux_vars
+        else:
+            test_x = self.geometry.dict_to_arr(self.test_x)
+            func_feats = self.fn_space.random(self.num_fn_test)
+            func_vals = self.fn_space.eval_batch(func_feats, self.eval_pts)
+            vx = self.fn_space.eval_batch(func_feats, test_x[:, self.func_vars])
+            self.fn_test_x = (func_vals, test_x)
+            self.fn_test_aux_vars = {"aux": vx}
+        # fn_test_aux_vars is already of the form {"aux": vx}; do not wrap it again.
+        return self.fn_test_x, self.test_y, self.fn_test_aux_vars
diff --git a/deepxde/experimental/utils/__init__.py b/deepxde/experimental/utils/__init__.py
new file mode 100644
index 000000000..fa5ce829e
--- /dev/null
+++ b/deepxde/experimental/utils/__init__.py
@@ -0,0 +1,6 @@
+"""Utility helpers for the experimental API."""
+
+from . import array_ops
+from ._convert import *
+from .external import *
+from .internal import *
diff --git a/deepxde/experimental/utils/_convert.py b/deepxde/experimental/utils/_convert.py
new file mode 100644
index 000000000..d3ed6240c
--- /dev/null
+++ b/deepxde/experimental/utils/_convert.py
@@ -0,0 +1,52 @@
+from typing import Sequence, Dict
+
+import brainstate as bst
+import brainunit as u
+import numpy as np
+
+__all__ = [
+    "array_to_dict",
+    "dict_to_array",
+]
+
+
+def array_to_dict(
+    x: bst.typing.ArrayLike, names: Sequence[str], keep_dim: bool = False
+):
+    """
+    Split the last axis of an array into a dictionary keyed by ``names``.
+
+    Args:
+        x (ArrayLike): The array; its last dimension must have length ``len(names)``.
+        names (Sequence[str]): One key per column of the last axis.
+        keep_dim (bool): If ``True``, each value keeps a trailing singleton dimension.
+
+    Returns:
+        dict: A mapping from each name to the corresponding column of ``x``.
+    """
+    if x.shape[-1] != len(names):
+        raise ValueError(
+            "The number of columns of x must be equal to the number of names."
+        )
+
+    if keep_dim:
+        return {key: x[..., i : i + 1] for i, key in enumerate(names)}
+    else:
+        return {key: x[..., i] for i, key in enumerate(names)}
+
+
+def dict_to_array(d: Dict[str, bst.typing.ArrayLike], keep_dim: bool = False):
+    """
+    Combine the values of a dictionary into a single array along the last axis.
+
+    Args:
+        d (dict): The dictionary of arrays.
+        keep_dim (bool): If ``True``, the values are assumed to already carry a
+            trailing singleton dimension and are concatenated; otherwise they are
+            stacked along a new last axis.
+
+    Returns:
+        ndarray: The combined array.
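+
+    Example:
+        A minimal sketch (illustrative values)::
+
+            d = {"x": np.ones((4,)), "t": np.zeros((4,))}
+            dict_to_array(d).shape                   # (4, 2), stacked
+            d2 = {"x": np.ones((4, 1)), "t": np.zeros((4, 1))}
+            dict_to_array(d2, keep_dim=True).shape   # (4, 2), concatenated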
+ """ + keys = tuple(d.keys()) + if isinstance(d[keys[0]], np.ndarray): + if keep_dim: + return np.concatenate([d[key] for key in keys], axis=-1) + else: + return np.stack([d[key] for key in keys], axis=-1) + else: + if keep_dim: + return u.math.concatenate([d[key] for key in keys], axis=-1) + else: + return u.math.stack([d[key] for key in keys], axis=-1) diff --git a/deepxde/experimental/utils/array_ops.py b/deepxde/experimental/utils/array_ops.py new file mode 100644 index 000000000..694f2edc0 --- /dev/null +++ b/deepxde/experimental/utils/array_ops.py @@ -0,0 +1,44 @@ +from typing import Sequence + +import brainstate as bst +import brainunit as u +import jax +import numpy as np + + +def is_tensor(obj): + return isinstance(obj, (jax.Array, u.Quantity, np.ndarray)) + + +def istensorlist(values): + return any(map(is_tensor, values)) + + +def convert_to_array(value: Sequence): + """Convert a list of numpy arrays or tensors to a numpy array or a tensor.""" + if istensorlist(value): + return np.stack(value, axis=0) + return np.array(value, dtype=bst.environ.dftype()) + + +def hstack(tup): + if not is_tensor(tup[0]) and isinstance(tup[0], list) and tup[0] == []: + tup = list(tup) + if istensorlist(tup[1:]): + tup[0] = np.asarray([], dtype=bst.environ.dftype()) + else: + tup[0] = np.array([], dtype=bst.environ.dftype()) + return np.concatenate(tup, 0) if is_tensor(tup[0]) else np.hstack(tup) + + +def zero_padding(array, pad_width): + # SparseTensor + if isinstance(array, (list, tuple)) and len(array) == 3: + indices, values, dense_shape = array + indices = [(i + pad_width[0][0], j + pad_width[1][0]) for i, j in indices] + dense_shape = ( + dense_shape[0] + sum(pad_width[0]), + dense_shape[1] + sum(pad_width[1]), + ) + return indices, values, dense_shape + return np.pad(array, pad_width) diff --git a/deepxde/experimental/utils/display.py b/deepxde/experimental/utils/display.py new file mode 100644 index 000000000..6b537986e --- /dev/null +++ b/deepxde/experimental/utils/display.py @@ -0,0 +1,115 @@ +import sys +from pprint import pformat + +import brainunit as u +import jax.tree + +from deepxde.experimental.utils import tree_repr + + +class TrainingDisplay: + """ + Display training progress. 
+ """ + + def __init__(self): + self.len_train = None + self.len_test = None + self.len_metric = None + self.is_header_print = False + + def print_one(self, s1, s2, s3, s4): + s1 = s1.split("\n") + s2 = s2.split("\n") + s3 = s3.split("\n") + s4 = s4.split("\n") + + lines = [] + for i in range(max([len(s1), len(s2), len(s3), len(s4)])): + s1_ = s1[i] if i < len(s1) else "" + s2_ = s2[i] if i < len(s2) else "" + s3_ = s3[i] if i < len(s3) else "" + s4_ = s4[i] if i < len(s4) else "" + lines.append( + "{:{l1}s}{:{l2}s}{:{l3}s}{:{l4}s}".format( + s1_, + s2_, + s3_, + s4_, + l1=10, + l2=self.len_train, + l3=self.len_test, + l4=self.len_metric, + ) + ) + + print("\n".join(lines)) + sys.stdout.flush() + + def header(self): + self.print_one("Step", "Train loss", "Test loss", "Test metric") + self.is_header_print = True + + def __call__(self, train_state): + train_loss_repr = pformat(train_state.loss_train, width=40) + test_loss_repr = pformat(train_state.loss_test, width=40) + test_metrics_repr = pformat(train_state.metrics_test, width=40) + + if not self.is_header_print: + train_loss_repr_max = max( + [len(line) for line in train_loss_repr.split("\n") if line] + ) + test_loss_repr_max = max( + [len(line) for line in test_loss_repr.split("\n") if line] + ) + test_metrics_repr_max = max( + [len(line) for line in test_metrics_repr.split("\n") if line] + ) + self.len_train = train_loss_repr_max + 10 + self.len_test = test_loss_repr_max + 10 + self.len_metric = test_metrics_repr_max + 10 + self.header() + + self.print_one( + str(train_state.step), + train_loss_repr, + test_loss_repr, + test_metrics_repr, + ) + + def summary(self, train_state): + print("Best trainer at step {}:".format(train_state.best_step)) + print(" train loss: {}".format(train_state.best_loss_train)) + print(" test loss: {}".format(train_state.best_loss_test)) + print(" test metric: {}".format(tree_repr(train_state.best_metrics))) + if train_state.best_ystd is not None: + print(" Uncertainty:") + print( + " l2: {}".format( + jax.tree.map(lambda x: u.linalg.norm(x), train_state.best_ystd) + ) + ) + print( + " l_infinity: {}".format( + jax.tree_map( + lambda x: u.linalg.norm(x, ord=u.math.inf), + train_state.best_ystd, + is_leaf=u.math.is_quantity, + ) + ) + ) + if len(train_state.best_ystd) == 1: + index = u.math.argmax(tuple(train_state.best_ystd.values())[0]) + print( + " max uncertainty location:", + jax.tree_map( + lambda test: test[index], + train_state.X_test, + is_leaf=u.math.is_quantity, + ), + ) + print("") + self.is_header_print = False + + +training_display = TrainingDisplay() diff --git a/deepxde/experimental/utils/external.py b/deepxde/experimental/utils/external.py new file mode 100644 index 000000000..3dc2b0c26 --- /dev/null +++ b/deepxde/experimental/utils/external.py @@ -0,0 +1,362 @@ +"""External utilities.""" + +import csv +import os +from multiprocessing import Pool + +import braintools +import brainunit as u +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from mpl_toolkits.mplot3d import Axes3D +from sklearn import preprocessing + + +def apply(func, args=None, kwds=None): + """Launch a new process to call the function. 
+ + This can be used to clear Tensorflow GPU memory after trainer execution: + https://stackoverflow.com/questions/39758094/clearing-tensorflow-gpu-memory-after-model-execution + """ + with Pool(1) as p: + if args is None and kwds is None: + r = p.apply(func) + elif kwds is None: + r = p.apply(func, args=args) + elif args is None: + r = p.apply(func, kwds=kwds) + else: + r = p.apply(func, args=args, kwds=kwds) + return r + + +def standardize(X_train, X_test): + """Standardize features by removing the mean and scaling to unit variance. + + The mean and std are computed from the training data `X_train` using + `sklearn.preprocessing.StandardScaler `_, + and then applied to the testing data `X_test`. + + Args: + X_train: A NumPy array of shape (n_samples, n_features). The data used to + compute the mean and standard deviation used for later scaling along the + features axis. + X_test: A NumPy array. + + Returns: + scaler: Instance of ``sklearn.preprocessing.StandardScaler``. + X_train: Transformed training data. + X_test: Transformed testing data. + """ + + train_exp_dim = False + if u.math.ndim(X_train) == 1: + train_exp_dim = True + X_train = X_train.reshape(-1, 1) + test_exp_dim = False + if u.math.ndim(X_test) == 1: + test_exp_dim = True + X_test = X_test.reshape(-1, 1) + + scaler = preprocessing.StandardScaler(with_mean=True, with_std=True) + X_train = scaler.fit_transform(X_train) + X_test = scaler.transform(X_test) + if train_exp_dim: + X_train = X_train.flatten() + if test_exp_dim: + X_test = X_test.flatten() + return X_train, X_test + + +def saveplot( + loss_history, + train_state, + issave=True, + isplot=True, + loss_fname="loss.dat", + train_fname="train.dat", + test_fname="test.dat", + output_dir=None, +): + """Save/plot the loss history and best trained result. + + This function is used to quickly check your results. To better investigate your + result, use ``save_loss_history()`` and ``save_best_state()``. + + Args: + loss_history: ``LossHistory`` instance. The first variable returned from + ``Trainer.train()``. + train_state: ``TrainState`` instance. The second variable returned from + ``Trainer.train()``. + issave (bool): Set ``True`` (default) to save the loss, training points, + and testing points. + isplot (bool): Set ``True`` (default) to plot loss, metric, and the predicted + solution. + loss_fname (string): Name of the file to save the loss in. + train_fname (string): Name of the file to save the training points in. + test_fname (string): Name of the file to save the testing points in. + output_dir (string): If ``None``, use the current working directory. + """ + if output_dir is None: + output_dir = os.getcwd() + if not os.path.exists(output_dir): + print(f"Warning: Directory {output_dir} doesn't exist. Creating it.") + os.mkdir(output_dir) + + if issave: + loss_fname = os.path.join(output_dir, loss_fname) + train_fname = os.path.join(output_dir, train_fname) + test_fname = os.path.join(output_dir, test_fname) + save_loss_history(loss_history, loss_fname) + save_best_state(train_state, train_fname, test_fname) + + if isplot: + plot_loss_history(loss_history) + plot_best_state(train_state) + plt.show() + + +def plot_loss_history(loss_history, fname=None): + """Plot the training and testing loss history. + + Note: + You need to call ``plt.show()`` to show the figure. + + Args: + loss_history: ``LossHistory`` instance. The first variable returned from + ``Trainer.train()``. 
+        fname (string): If `fname` is a string (e.g., 'loss_history.png'), save the
+            figure to a file named `fname`.
+    """
+    # ``np.sum(loss_history.loss_train, axis=1)`` would fail when the per-step
+    # losses have varying lengths, so sum over the leaves of each loss pytree.
+    loss_train = jnp.array(
+        [
+            jnp.sum(jnp.asarray(jax.tree.leaves(loss)))
+            for loss in loss_history.loss_train
+        ]
+    )
+    loss_test = jnp.array(
+        [jnp.sum(jnp.asarray(jax.tree.leaves(loss))) for loss in loss_history.loss_test]
+    )
+
+    plt.figure()
+    plt.semilogy(loss_history.steps, loss_train, label="Train loss")
+    plt.semilogy(loss_history.steps, loss_test, label="Test loss")
+    metric_tests = jax.tree.map(
+        lambda *a: u.math.asarray(a), *loss_history.metrics_test
+    )
+
+    for i in range(len(loss_history.metrics_test[0])):
+        if isinstance(metric_tests[i], dict):
+            for k, v in metric_tests[i].items():
+                plt.semilogy(loss_history.steps, v, label=f"Test metric {k}")
+        else:
+            plt.semilogy(loss_history.steps, metric_tests[i], label=f"Test metric {i}")
+    plt.xlabel("# Steps")
+    plt.legend()
+
+    if isinstance(fname, str):
+        plt.savefig(fname)
+
+
+def save_loss_history(loss_history, fname):
+    """Save the training loss history to a file."""
+    print("Saving loss history to {} ...".format(fname))
+
+    train_losses = jax.tree.map(lambda *a: u.math.asarray(a), *loss_history.loss_train)
+    braintools.file.msgpack_save(fname, train_losses)
+
+
+def _pack_data(train_state):
+    def merge_values(values):
+        if values is None:
+            return None
+        return jnp.hstack(values) if isinstance(values, (list, tuple)) else values
+
+    # y_train = merge_values(train_state.y_train)
+    # y_test = merge_values(train_state.y_test)
+    # best_y = merge_values(train_state.best_y)
+    # best_ystd = merge_values(train_state.best_ystd)
+    y_train = train_state.y_train
+    y_test = train_state.y_test
+    best_y = train_state.best_y
+    best_ystd = train_state.best_ystd
+    return y_train, y_test, best_y, best_ystd
+
+
+def plot_best_state(train_state):
+    """Plot the best result of the smallest training loss.
+
+    This function only works for 1D and 2D problems. For other problems and to better
+    customize the figure, use ``save_best_state()``.
+
+    Note:
+        You need to call ``plt.show()`` to show the figure.
+
+    Args:
+        train_state: ``TrainState`` instance. The second variable returned from
+            ``Trainer.train()``.
+    """
+    if isinstance(train_state.X_train, (list, tuple)):
+        print(
+            "Error: The network has multiple inputs, and plotting such result hasn't been implemented."
+        )
+        return
+
+    y_train, y_test, best_y, best_ystd = _pack_data(train_state)
+    xkeys = tuple(train_state.X_test.keys())
+
+    # Regression plot
+    # 1D
+    if len(train_state.X_test) == 1:
+        idx = u.math.argsort(train_state.X_test[xkeys[0]])
+        X = train_state.X_test[xkeys[0]][idx]
+        plt.figure()
+        for ykey in best_y:
+            if y_train is not None:
+                plt.plot(
+                    train_state.X_train[xkeys[0]], y_train[ykey], "ok", label="Train"
+                )
+            if y_test is not None:
+                # Apply the same sort permutation so the curve matches the sorted X.
+                plt.plot(X, y_test[ykey][idx], "-k", label="True")
+            y_val, y_unit = u.split_mantissa_unit(best_y[ykey])
+            y_val = y_val[idx]  # keep the predictions aligned with the sorted X
+            plt.plot(
+                X,
+                y_val,
+                "--r",
+                label=(
+                    f"{ykey} Prediction"
+                    if y_unit.is_unitless
+                    else f"{ykey} Prediction [{y_unit}]"
+                ),
+            )
+            if best_ystd is not None:
+                ystd_val = u.get_magnitude(u.Quantity(best_ystd[ykey], unit=y_unit))[idx]
+                plt.plot(X, y_val + 1.96 * ystd_val, "-b", label="95% CI")
+                plt.plot(X, y_val - 1.96 * ystd_val, "-b")
+        plt.xlabel("x")
+        plt.ylabel("y")
+        plt.legend()
+
+    # 2D
+    elif len(train_state.X_test) == 2:
+        for ykey in best_y:
+            plt.figure()
+            ax = plt.axes(projection=Axes3D.name)
+            ax.plot3D(
+                u.get_magnitude(train_state.X_test[xkeys[0]]),
+                u.get_magnitude(train_state.X_test[xkeys[1]]),
+                u.get_magnitude(best_y[ykey]),
+                ".",
+            )
+            unit = u.get_unit(train_state.X_test[xkeys[0]])
+            if unit.is_unitless:
+                ax.set_xlabel(f"{xkeys[0]}")
+            else:
+                ax.set_xlabel(f"{xkeys[0]} [{unit}]")
+            unit = u.get_unit(train_state.X_test[xkeys[1]])
+            if unit.is_unitless:
+                ax.set_ylabel(f"{xkeys[1]}")
+            else:
+                ax.set_ylabel(f"{xkeys[1]} [{unit}]")
+            unit = u.get_unit(best_y[ykey])
+            if unit.is_unitless:
+                ax.set_zlabel(f"{ykey}")
+            else:
+                ax.set_zlabel(f"{ykey} [{unit}]")
+
+    # Residual plot
+    # Not necessary to plot
+    # if y_test is not None:
+    #     plt.figure()
+    #     residual = y_test[:, 0] - best_y[:, 0]
+    #     plt.plot(best_y[:, 0], residual, "o", zorder=1)
+    #     plt.hlines(0, plt.xlim()[0], plt.xlim()[1], linestyles="dashed", zorder=2)
+    #     plt.xlabel("Predicted")
+    #     plt.ylabel("Residual = Observed - Predicted")
+    #     plt.tight_layout()
+
+    # Uncertainty plot
+    # Not necessary to plot
+    # if best_ystd is not None:
+    #     plt.figure()
+    #     for i in range(y_dim):
+    #         plt.plot(train_state.X_test[:, 0], best_ystd[:, i], "-b")
+    #         plt.plot(
+    #             train_state.X_train[:, 0],
+    #             np.interp(
+    #                 train_state.X_train[:, 0], train_state.X_test[:, 0], best_ystd[:, i]
+    #             ),
+    #             "ok",
+    #         )
+    #     plt.xlabel("x")
+    #     plt.ylabel("std(y)")
+
+
+def save_best_state(train_state, fname_train, fname_test):
+    """Save the best result of the smallest training loss to a file."""
+    if isinstance(train_state.X_train, (list, tuple)):
+        print(
+            "Error: The network has multiple inputs, and saving such result hasn't been implemented."
+        )
+        return
+
+    print("Saving training data to {} ...".format(fname_train))
+    y_train, y_test, best_y, best_ystd = _pack_data(train_state)
+    if y_train is None:
+        data = {"X_train": train_state.X_train}
+    else:
+        data = {"X_train": train_state.X_train, "y_train": y_train}
+    braintools.file.msgpack_save(fname_train, data)
+
+    print("Saving test data to {} ...".format(fname_test))
+    data = {"X_test": train_state.X_test, "best_y": best_y}
+    if y_test is not None:
+        data["y_test"] = y_test
+    if best_ystd is not None:
+        data["best_ystd"] = best_ystd
+    braintools.file.msgpack_save(fname_test, data)
+
+
+def isclose(a, b):
+    """A modified version of `np.isclose` for DeepXDE.
+
+    This function chooses `atol` based on the dtype of `a`.
+    If the dtype is float16, `atol` is `1e-4`.
+    If it is float32, `atol` is `1e-6`.
+    Otherwise (for float64), the default `1e-8` is used.
+    If you want to set `atol` manually, use `np.isclose` instead.
+
+    Args:
+        a, b (array like): Input arrays to compare.
+    """
+    pack = smart_numpy(a)
+    a_dtype = a.dtype
+    a_unit = u.get_unit(a)
+    if a_dtype == jnp.float32:
+        atol = u.maybe_decimal(u.Quantity(1e-6, unit=a_unit))
+    elif a_dtype == jnp.float16:
+        atol = u.maybe_decimal(u.Quantity(1e-4, unit=a_unit))
+    else:
+        atol = u.maybe_decimal(u.Quantity(1e-8, unit=a_unit))
+    return pack.isclose(a, b, atol=atol)
+
+
+def smart_numpy(x):
+    """Return the array module that matches the type of ``x``."""
+    if isinstance(x, (jnp.ndarray, jax.Array)):
+        return jnp
+    elif isinstance(x, u.Quantity):
+        return u.math
+    elif isinstance(x, np.ndarray):
+        return np
+    else:
+        raise TypeError(f"Unknown type {type(x)}.")
diff --git a/deepxde/experimental/utils/internal.py b/deepxde/experimental/utils/internal.py
new file mode 100644
index 000000000..8fbf63ecc
--- /dev/null
+++ b/deepxde/experimental/utils/internal.py
@@ -0,0 +1,50 @@
+"""Internal utilities."""
+
+from functools import wraps
+from typing import Callable, Union
+
+import brainstate as bst
+import brainunit as u
+import numpy as np
+
+
+def check_not_none(*attr):
+    """Decorator: raise if any of the named instance attributes is ``None``."""
+
+    def decorator(func):
+        @wraps(func)
+        def wrapper(self, *args, **kwargs):
+            is_none = []
+            for a in attr:
+                if not hasattr(self, a):
+                    raise ValueError(f"{a} must be an attribute of the class.")
+                is_none.append(getattr(self, a) is None)
+            if any(is_none):
+                raise ValueError(f"{attr} must not be None.")
+            return func(self, *args, **kwargs)
+
+        return wrapper
+
+    return decorator
+
+
+def return_tensor(func):
+    """Convert the output to a Tensor."""
+
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        return u.math.asarray(func(*args, **kwargs), dtype=bst.environ.dftype())
+
+    return wrapper
+
+
+def tree_repr(tree, precision: int = 2):
+    """Compact ``repr`` of a pytree with truncated, low-precision array printing."""
+    with np.printoptions(precision=precision, suppress=True, threshold=5):
+        return repr(tree)
+    # return repr(jax.tree.map(lambda x: repr(x), tree, is_leaf=u.math.is_quantity))
+
+
+def get_activation(activation: Union[str, Callable]):
+    """Get the activation function."""
+    if isinstance(activation, str):
+        return getattr(bst.functional, activation)
+    else:
+        return activation
diff --git a/deepxde/experimental/utils/losses.py b/deepxde/experimental/utils/losses.py
new file mode 100644
index 000000000..65f4258a5
--- /dev/null
+++ b/deepxde/experimental/utils/losses.py
@@ -0,0 +1,74 @@
+import braintools
+import brainunit as u
+import jax
+
+
+def mean_absolute_error(y_true, y_pred):
+    return jax.tree.map(
+        lambda x, y: braintools.metric.absolute_error(x, y).mean(),
+        y_true,
+        y_pred,
+        is_leaf=u.math.is_quantity,
+    )
+
+
+def mean_squared_error(y_true, y_pred):
+    return jax.tree.map(
+        lambda x, y: braintools.metric.squared_error(x, y).mean(),
+        y_true,
+        y_pred,
+        is_leaf=u.math.is_quantity,
+    )
+
+
+def mean_l2_relative_error(y_true, y_pred):
+    return jax.tree.map(
+        lambda x, y: braintools.metric.l2_norm(x, y).mean(),
+        y_true,
+        y_pred,
+        is_leaf=u.math.is_quantity,
+    )
+
+
+def softmax_cross_entropy(y_true, y_pred):
+    return jax.tree.map(
+        lambda x, y: braintools.metric.softmax_cross_entropy(x, y).mean(),
+        y_true,
+        y_pred,
+        is_leaf=u.math.is_quantity,
+    )
+
+
+LOSS_DICT = {
+    # mean absolute error
+    "mean absolute error": mean_absolute_error,
+    "MAE": mean_absolute_error,
+    "mae": mean_absolute_error,
+    # mean squared error
+    "mean squared error": mean_squared_error,
+    "MSE": mean_squared_error,
+    "mse": mean_squared_error,
+    # mean l2 relative error
+    "mean l2 relative error": mean_l2_relative_error,
+    # softmax cross entropy
+    "softmax cross entropy": softmax_cross_entropy,
+}
+
+
+def get_loss(identifier):
+    """Retrieve a loss function.
+
+    Args:
+        identifier: A loss identifier. String name of a loss function, or a loss function.
+
+    Returns:
+        A loss function.
+    """
+    if isinstance(identifier, (list, tuple)):
+        return list(map(get_loss, identifier))
+
+    if isinstance(identifier, str):
+        if identifier not in LOSS_DICT:
+            raise ValueError(f"Unknown loss function identifier: {identifier!r}")
+        return LOSS_DICT[identifier]
+    if callable(identifier):
+        return identifier
+    raise ValueError(f"Could not interpret loss function identifier: {identifier}")
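+
+
+# A minimal usage sketch (illustrative only, not part of the module API):
+#
+#     import jax.numpy as jnp
+#     loss_fn = get_loss("mse")
+#     y_true = {"u": jnp.zeros((4, 1))}
+#     y_pred = {"u": jnp.ones((4, 1))}
+#     losses = loss_fn(y_true, y_pred)  # pytree of per-leaf mean losses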