Skip to content

Commit a20ad97

Browse files
authored
Backend Pytorch: Support L-LAAF (#1993)
1 parent ec0fb74 commit a20ad97

File tree

2 files changed

+43
-4
lines changed

2 files changed

+43
-4
lines changed

deepxde/nn/activations.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .. import backend as bkd
22
from .. import config
3+
from .. import utils
34
from ..backend import backend_name, tf
45

56

@@ -26,10 +27,12 @@ def layer_wise_locally_adaptive(activation, n=1):
2627
<https://doi.org/10.1098/rspa.2020.0334>`_.
2728
"""
2829
# TODO: other backends
29-
if backend_name != "tensorflow.compat.v1":
30-
raise NotImplementedError("Only tensorflow.compat.v1 backend supports L-LAAF.")
31-
a = tf.Variable(1 / n, dtype=config.real(tf))
32-
return lambda x: activation(n * a * x)
30+
if backend_name == "tensorflow.compat.v1":
31+
a = tf.Variable(1 / n, dtype=config.real(tf))
32+
return lambda x: activation(n * a * x)
33+
if backend_name == "pytorch":
34+
return utils.LLAAF(activation, n)
35+
raise NotImplementedError(f"L-LAAF is not implemented for {backend_name} backend.")
3336

3437

3538
def get(identifier):

deepxde/utils/pytorch.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,37 @@
11
"""Utilities of pytorch."""
2+
3+
import torch
4+
5+
6+
class LLAAF(torch.nn.Module):
7+
"""Pytorch implementation of layer-wise locally adaptive
8+
activation functions (L-LAAF).
9+
10+
Args:
11+
activation: The activation function to use.
12+
n: The scaling factor.
13+
14+
Examples:
15+
16+
To define a L-LAAF ReLU with the scaling factor ``n = 10``:
17+
18+
.. code-block:: python
19+
20+
n = 10
21+
llaaf = LLAAF(torch.relu, n)
22+
23+
References:
24+
`A. D. Jagtap, K. Kawaguchi, & G. E. Karniadakis. Locally adaptive activation
25+
functions with slope recovery for deep and physics-informed neural networks.
26+
Proceedings of the Royal Society A, 476(2239), 20200334, 2020
27+
<https://doi.org/10.1098/rspa.2020.0334>`_.
28+
"""
29+
30+
def __init__(self, activation, n):
31+
super().__init__()
32+
self.activation = activation
33+
self.n = n
34+
self.a = torch.nn.Parameter(torch.tensor(1.0 / n))
35+
36+
def forward(self, x):
37+
return self.activation(self.n * self.a * x)

0 commit comments

Comments
 (0)