File tree Expand file tree Collapse file tree 2 files changed +20
-4
lines changed Expand file tree Collapse file tree 2 files changed +20
-4
lines changed Original file line number Diff line number Diff line change 11from .. import backend as bkd
22from .. import config
3+ from .. import utils
34from ..backend import backend_name , tf
45
56
@@ -26,10 +27,12 @@ def layer_wise_locally_adaptive(activation, n=1):
2627 <https://doi.org/10.1098/rspa.2020.0334>`_.
2728 """
2829 # TODO: other backends
29- if backend_name != "tensorflow.compat.v1" :
30- raise NotImplementedError ("Only tensorflow.compat.v1 backend supports L-LAAF." )
31- a = tf .Variable (1 / n , dtype = config .real (tf ))
32- return lambda x : activation (n * a * x )
30+ if backend_name == "tensorflow.compat.v1" :
31+ a = tf .Variable (1 / n , dtype = config .real (tf ))
32+ return lambda x : activation (n * a * x )
33+ if backend_name == "pytorch" :
34+ return utils .LLAAF (activation , n )
35+ raise NotImplementedError (f"L-LAAF is not implemented for { backend_name } backend." )
3336
3437
3538def get (identifier ):
Original file line number Diff line number Diff line change 11"""Utilities of pytorch."""
2+
3+ from ..backend import torch
4+
5+
6+ class LLAAF (torch .nn .Module ):
7+ def __init__ (self , activation , n ):
8+ super ().__init__ ()
9+ self .activation = activation
10+ self .n = n
11+ self .a = torch .nn .Parameter (torch .tensor (1.0 / n ))
12+
13+ def forward (self , x ):
14+ return self .activation (self .n * self .a * x )
You can’t perform that action at this time.
0 commit comments