Skip to content

Commit 82e5203

Browse files
committed
Backend Pytorch: L-LAAF implementation
1 parent b79d2fd commit 82e5203

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

deepxde/nn/activations.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .. import backend as bkd
22
from .. import config
3+
from .. import utils
34
from ..backend import backend_name, tf
45

56

@@ -26,10 +27,12 @@ def layer_wise_locally_adaptive(activation, n=1):
2627
<https://doi.org/10.1098/rspa.2020.0334>`_.
2728
"""
2829
# TODO: other backends
29-
if backend_name != "tensorflow.compat.v1":
30-
raise NotImplementedError("Only tensorflow.compat.v1 backend supports L-LAAF.")
31-
a = tf.Variable(1 / n, dtype=config.real(tf))
32-
return lambda x: activation(n * a * x)
30+
if backend_name == "tensorflow.compat.v1":
31+
a = tf.Variable(1 / n, dtype=config.real(tf))
32+
return lambda x: activation(n * a * x)
33+
if backend_name == "pytorch":
34+
return utils.LLAAF(activation, n)
35+
raise NotImplementedError(f"L-LAAF is not implemented for {backend_name} backend.")
3336

3437

3538
def get(identifier):

deepxde/utils/pytorch.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,14 @@
11
"""Utilities of pytorch."""
2+
3+
from ..backend import torch
4+
5+
6+
class LLAAF(torch.nn.Module):
7+
def __init__(self, activation, n):
8+
super().__init__()
9+
self.activation = activation
10+
self.n = n
11+
self.a = torch.nn.Parameter(torch.tensor(1.0 / n))
12+
13+
def forward(self, x):
14+
return self.activation(self.n * self.a * x)

0 commit comments

Comments
 (0)