@@ -14,27 +14,25 @@ def get(identifier):
1414 """Retrieves a TensorFlow regularizer instance based on the given identifier.
1515
1616 Args:
17- identifier (str or list/tuple): Specifies the type of regularizer and
18- the optional scale values. If a string, it should be one of "l1",
19- "l2", "orthogonal", or "l1l2" ("l1+l2"); default scale values will
20- be used. If a list or tuple, the first element should be one of
21- the above strings, followed by scale values. For "l1", "l2", or
22- "orthogonal", you can provide a single scale value. For "l1l2" ("l1+l2"),
23- you can provide both "l1" and "l2" scale values.
17+ identifier (list/tuple): Specifies the type of regularizer and
18+ regularization factor. The first element should be one of "l1", "l2",
19+ "orthogonal", or "l1l2" ("l1+l2"). For "l1", "l2", or "orthogonal",
20+ you can provide a single factor value. For "l1l2" ("l1+l2"),
21+ both "l1" and "l2" factors are required.
2422 """
2523
2624 # TODO: other backends
2725 if identifier is None :
2826 return None
2927
30- if isinstance (identifier , str ):
31- name = identifier .lower ()
32- scales = []
33- elif isinstance (identifier , (list , tuple )) and identifier :
28+ if isinstance (identifier , (list , tuple )) and identifier :
3429 name = identifier [0 ].lower ()
35- scales = identifier [1 :]
30+ factor = identifier [1 :]
3631 else :
37- raise ValueError ("Identifier must be a string or a non-empty list or tuple." )
32+ raise ValueError ("Identifier must be a non-empty list or tuple." )
33+
34+ if not factor :
35+ raise ValueError ("Regularization factor must be provided." )
3836
3937 regularizer_class = REGULARIZER_DICT .get (name )
4038 if not regularizer_class :
@@ -46,16 +44,15 @@ def get(identifier):
4644 raise ValueError (f"Unknown regularizer name: { name } " )
4745
4846 regularizer_kwargs = {}
49- if scales :
50- if name == "l1" :
51- regularizer_kwargs ["l1" ] = scales [0 ]
52- elif name == "l2" :
53- regularizer_kwargs ["l2" ] = scales [0 ]
54- elif name == "orthogonal" :
55- regularizer_kwargs ["factor" ] = scales [0 ]
56- elif name in ("l1l2" , "l1+l2" ):
57- if len (scales ) < 2 :
58- raise ValueError ("L1L2 regularizer requires both L1/L2 penalties." )
59- regularizer_kwargs ["l1" ] = scales [0 ]
60- regularizer_kwargs ["l2" ] = scales [1 ]
47+ if name == "l1" :
48+ regularizer_kwargs ["l1" ] = factor [0 ]
49+ elif name == "l2" :
50+ regularizer_kwargs ["l2" ] = factor [0 ]
51+ elif name == "orthogonal" :
52+ regularizer_kwargs ["factor" ] = factor [0 ]
53+ elif name in ("l1l2" , "l1+l2" ):
54+ if len (factor ) < 2 :
55+ raise ValueError ("L1L2 regularizer requires both L1/L2 penalties." )
56+ regularizer_kwargs ["l1" ] = factor [0 ]
57+ regularizer_kwargs ["l2" ] = factor [1 ]
6158 return regularizer_class (** regularizer_kwargs )
0 commit comments