2 files changed: +43 -4 lines changed
@@ -1,5 +1,6 @@
from .. import backend as bkd
from .. import config
+from .. import utils
from ..backend import backend_name, tf


@@ -26,10 +27,12 @@ def layer_wise_locally_adaptive(activation, n=1):
    <https://doi.org/10.1098/rspa.2020.0334>`_.
    """
    # TODO: other backends
-    if backend_name != "tensorflow.compat.v1":
-        raise NotImplementedError("Only tensorflow.compat.v1 backend supports L-LAAF.")
-    a = tf.Variable(1 / n, dtype=config.real(tf))
-    return lambda x: activation(n * a * x)
+    if backend_name == "tensorflow.compat.v1":
+        a = tf.Variable(1 / n, dtype=config.real(tf))
+        return lambda x: activation(n * a * x)
+    if backend_name == "pytorch":
+        return utils.LLAAF(activation, n)
+    raise NotImplementedError(f"L-LAAF is not implemented for {backend_name} backend.")


def get(identifier):
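For orientation (not part of this diff): with the dispatch above in place, L-LAAF is reachable from the string-identifier path through ``get``. A minimal sketch, assuming DeepXDE's documented ``"LAAF-<n> <activation>"`` identifier convention and the ``dde.nn.FNN(layer_sizes, activation, initializer)`` signature:

import deepxde as dde

# "LAAF-10 tanh" requests a layer-wise locally adaptive tanh with n = 10;
# the identifier format is an assumption based on DeepXDE's docs.
net = dde.nn.FNN([2] + [50] * 3 + [1], "LAAF-10 tanh", "Glorot normal")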
@@ -1 +1,37 @@
"""Utilities of pytorch."""
+
+import torch
+
+
+class LLAAF(torch.nn.Module):
+    """PyTorch implementation of layer-wise locally adaptive
+    activation functions (L-LAAF).
+
+    Args:
+        activation: The activation function to use.
+        n: The scaling factor.
+
+    Examples:
+
+        To define an L-LAAF ReLU with the scaling factor ``n = 10``:
+
+        .. code-block:: python
+
+            n = 10
+            llaaf = LLAAF(torch.relu, n)
+
+    References:
+        `A. D. Jagtap, K. Kawaguchi, & G. E. Karniadakis. Locally adaptive activation
+        functions with slope recovery for deep and physics-informed neural networks.
+        Proceedings of the Royal Society A, 476(2239), 20200334, 2020
+        <https://doi.org/10.1098/rspa.2020.0334>`_.
+    """
+
+    def __init__(self, activation, n):
+        super().__init__()
+        self.activation = activation
+        self.n = n
+        self.a = torch.nn.Parameter(torch.tensor(1.0 / n))
+
+    def forward(self, x):
+        return self.activation(self.n * self.a * x)
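A minimal usage sketch (illustrative assumption, not taken from this diff): because ``LLAAF`` is an ``nn.Module`` whose slope ``a`` is an ``nn.Parameter``, it drops into a model like any other layer, and ``a`` is trained together with the surrounding weights:

import torch

model = torch.nn.Sequential(
    torch.nn.Linear(2, 50),
    LLAAF(torch.tanh, 10),  # adaptive tanh, scaling factor n = 10
    torch.nn.Linear(50, 1),
)
y = model(torch.randn(8, 2))  # forward pass; LLAAF scales pre-activations by n * a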