1- from ..backend import tf
1+ from .. import backend as bkd
22
33
44def get (identifier ):
5- """Retrieves a TensorFlow regularizer instance based on the given identifier.
5+ """Retrieves a regularizer instance based on the given identifier.
66
77 Args:
88 identifier (list/tuple): Specifies the type and factor(s) of the regularizer.
@@ -11,7 +11,6 @@ def get(identifier):
1111 For "l1l2", provide both "l1" and "l2" factors.
1212 """
1313
14- # TODO: other backends
1514 if identifier is None or not identifier :
1615 return None
1716 if not isinstance (identifier , (list , tuple )):
@@ -23,11 +22,12 @@ def get(identifier):
2322 raise ValueError ("Regularization factor must be provided." )
2423
2524 if name == "l1" :
26- return tf . keras . regularizers . L1 ( l1 = factor [0 ])
25+ return bkd . l1_regularization ( factor [0 ])
2726 if name == "l2" :
28- return tf . keras . regularizers . L2 ( l2 = factor [0 ])
27+ return bkd . l2_regularization ( factor [0 ])
2928 if name in ("l1l2" , "l1+l2" ):
29+ # TODO: only supported by 'tensorflow.compat.v1' and 'tensorflow' now.
3030 if len (factor ) < 2 :
3131 raise ValueError ("L1L2 regularizer requires both L1/L2 penalties." )
32- return tf . keras . regularizers . L1L2 ( l1 = factor [0 ], l2 = factor [1 ])
32+ return bkd . l1_l2_regularization ( factor [0 ], factor [1 ])
3333 raise ValueError (f"Unknown regularizer name: { name } " )
0 commit comments