Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions deepxde/nn/activations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .. import backend as bkd
from .. import config
from .. import utils
from ..backend import backend_name, tf


Expand All @@ -26,10 +27,12 @@ def layer_wise_locally_adaptive(activation, n=1):
<https://doi.org/10.1098/rspa.2020.0334>`_.
"""
# TODO: other backends
if backend_name != "tensorflow.compat.v1":
raise NotImplementedError("Only tensorflow.compat.v1 backend supports L-LAAF.")
a = tf.Variable(1 / n, dtype=config.real(tf))
return lambda x: activation(n * a * x)
if backend_name == "tensorflow.compat.v1":
a = tf.Variable(1 / n, dtype=config.real(tf))
return lambda x: activation(n * a * x)
if backend_name == "pytorch":
return utils.LLAAF(activation, n)
raise NotImplementedError(f"L-LAAF is not implemented for {backend_name} backend.")


def get(identifier):
Expand Down
39 changes: 39 additions & 0 deletions deepxde/utils/pytorch.py
Original file line number Diff line number Diff line change
@@ -1 +1,40 @@
"""Utilities of pytorch."""

import torch


class LLAAF(torch.nn.Module):
"""Pytorch implementation of layer-wise locally adaptive
activation functions (L-LAAF).
Examples:
To define a L-LAAF ReLU with the scaling factor ``n = 10``:
.. code-block:: python
n = 10
llaaf = LLAAF(torch.relu, n)
References:
`A. D. Jagtap, K. Kawaguchi, & G. E. Karniadakis. Locally adaptive activation
functions with slope recovery for deep and physics-informed neural networks.
Proceedings of the Royal Society A, 476(2239), 20200334, 2020
<https://doi.org/10.1098/rspa.2020.0334>`_.
"""

def __init__(self, activation, n):
"""
Initialize the L-LAAF module.
Args:
activation: The activation function to use.
n: The scaling factor.
"""
super().__init__()
self.activation = activation
self.n = n
self.a = torch.nn.Parameter(torch.tensor(1.0 / n))

def forward(self, x):
return self.activation(self.n * self.a * x)