diff --git a/deepxde/nn/activations.py b/deepxde/nn/activations.py
index 382946861..5dc977089 100644
--- a/deepxde/nn/activations.py
+++ b/deepxde/nn/activations.py
@@ -1,5 +1,6 @@
 from .. import backend as bkd
 from .. import config
+from .. import utils
 from ..backend import backend_name, tf
 
 
@@ -26,10 +27,12 @@ def layer_wise_locally_adaptive(activation, n=1):
         `_.
     """
     # TODO: other backends
-    if backend_name != "tensorflow.compat.v1":
-        raise NotImplementedError("Only tensorflow.compat.v1 backend supports L-LAAF.")
-    a = tf.Variable(1 / n, dtype=config.real(tf))
-    return lambda x: activation(n * a * x)
+    if backend_name == "tensorflow.compat.v1":
+        a = tf.Variable(1 / n, dtype=config.real(tf))
+        return lambda x: activation(n * a * x)
+    if backend_name == "pytorch":
+        return utils.LLAAF(activation, n)
+    raise NotImplementedError(f"L-LAAF is not implemented for {backend_name} backend.")
 
 
 def get(identifier):
diff --git a/deepxde/utils/pytorch.py b/deepxde/utils/pytorch.py
index 7c9c087df..194f7b42f 100644
--- a/deepxde/utils/pytorch.py
+++ b/deepxde/utils/pytorch.py
@@ -1 +1,37 @@
 """Utilities of pytorch."""
+
+import torch
+
+
+class LLAAF(torch.nn.Module):
+    """Pytorch implementation of layer-wise locally adaptive
+    activation functions (L-LAAF).
+
+    Args:
+        activation: The activation function to use.
+        n: The scaling factor.
+
+    Examples:
+
+        To define a L-LAAF ReLU with the scaling factor ``n = 10``:
+
+        .. code-block:: python
+
+            n = 10
+            llaaf = LLAAF(torch.relu, n)
+
+    References:
+        `A. D. Jagtap, K. Kawaguchi, & G. E. Karniadakis. Locally adaptive activation
+        functions with slope recovery for deep and physics-informed neural networks.
+        Proceedings of the Royal Society A, 476(2239), 20200334, 2020
+        `_.
+    """
+
+    def __init__(self, activation, n):
+        super().__init__()
+        self.activation = activation
+        self.n = n
+        self.a = torch.nn.Parameter(torch.tensor(1.0 / n))
+
+    def forward(self, x):
+        return self.activation(self.n * self.a * x)