11from ..backend import tf
22
3- REGULARIZER_DICT = {
4- "l1" : tf .keras .regularizers .L1 ,
5- "l2" : tf .keras .regularizers .L2 ,
6- "l1l2" : tf .keras .regularizers .L1L2 ,
7- "l1+l2" : tf .keras .regularizers .L1L2 ,
8- }
9- if hasattr (tf .keras .regularizers , "OrthogonalRegularizer" ):
10- REGULARIZER_DICT ["orthogonal" ] = tf .keras .regularizers .OrthogonalRegularizer
11-
123
134def get (identifier ):
145 """Retrieves a TensorFlow regularizer instance based on the given identifier.
@@ -24,35 +15,27 @@ def get(identifier):
2415 # TODO: other backends
2516 if identifier is None or not identifier :
2617 return None
18+ if not isinstance (identifier , (list , tuple )):
19+ raise ValueError ("Identifier must be a list or a tuple." )
2720
28- if isinstance (identifier , (list , tuple )):
29- name = identifier [0 ].lower ()
30- factor = identifier [1 :]
31- else :
32- raise ValueError ("Identifier must be a non-empty list or tuple." )
33-
21+ name = identifier [0 ].lower ()
22+ factor = identifier [1 :]
3423 if not factor :
3524 raise ValueError ("Regularization factor must be provided." )
3625
37- regularizer_class = REGULARIZER_DICT .get (name )
38- if not regularizer_class :
39- if name == "orthogonal" :
26+ if name == "l1" :
27+ return tf .keras .regularizers .L1 (l1 = factor [0 ])
28+ if name == "l2" :
29+ return tf .keras .regularizers .L2 (l2 = factor [0 ])
30+ if name == "orthogonal" :
31+ if not hasattr (tf .keras .regularizers , "OrthogonalRegularizer" ):
4032 raise ValueError (
4133 "The 'orthogonal' regularizer is not available "
4234 "in your version of TensorFlow"
4335 )
44- raise ValueError (f"Unknown regularizer name: { name } " )
45-
46- regularizer_kwargs = {}
47- if name == "l1" :
48- regularizer_kwargs ["l1" ] = factor [0 ]
49- elif name == "l2" :
50- regularizer_kwargs ["l2" ] = factor [0 ]
51- elif name == "orthogonal" :
52- regularizer_kwargs ["factor" ] = factor [0 ]
53- elif name in ("l1l2" , "l1+l2" ):
36+ return tf .keras .regularizers .OrthogonalRegularizer (factor = factor [0 ])
37+ if name in ("l1l2" , "l1+l2" ):
5438 if len (factor ) < 2 :
5539 raise ValueError ("L1L2 regularizer requires both L1/L2 penalties." )
56- regularizer_kwargs ["l1" ] = factor [0 ]
57- regularizer_kwargs ["l2" ] = factor [1 ]
58- return regularizer_class (** regularizer_kwargs )
40+ return tf .keras .regularizers .L1L2 (l1 = factor [0 ], l2 = factor [1 ])
41+ raise ValueError (f"Unknown regularizer name: { name } " )
0 commit comments