|
1 | 1 | from ..backend import tf |
2 | 2 |
|
# Lookup table from lower-case regularizer name to the Keras regularizer class.
REGULARIZER_DICT = dict(
    l1=tf.keras.regularizers.L1,
    l2=tf.keras.regularizers.L2,
    l1l2=tf.keras.regularizers.L1L2,
)
# "l1+l2" is an alternative spelling of the combined L1/L2 penalty.
REGULARIZER_DICT["l1+l2"] = REGULARIZER_DICT["l1l2"]

# Register "orthogonal" only when tf.keras.regularizers exposes the class.
try:
    REGULARIZER_DICT["orthogonal"] = tf.keras.regularizers.OrthogonalRegularizer
except AttributeError:
    pass
| 11 | + |
3 | 12 |
|
def get(identifier):
    """Return a TensorFlow (Keras) regularizer instance for the given identifier.

    Args:
        identifier (None, str, or list/tuple): Specifies the type of regularizer
            and the optional scale values. If None, no regularizer is created.
            If a string, it should be one of "l1", "l2", "orthogonal", or
            "l1l2" ("l1+l2"); the regularizer class is then called without
            arguments, so its own default scale values apply. If a list or
            tuple, the first element should be one of the above strings,
            followed by scale values. For "l1", "l2", or "orthogonal", a single
            scale value is used. For "l1l2" ("l1+l2"), both the "l1" and "l2"
            scale values must be given, in that order. Scale values beyond the
            ones a regularizer uses are ignored.

    Returns:
        None if `identifier` is None, otherwise an instance of the class
        registered in `REGULARIZER_DICT` under the (case-insensitive) name.

    Raises:
        ValueError: If `identifier` is neither a string nor a non-empty
            list/tuple, if the regularizer name is not a string, if the name is
            unknown (or "orthogonal" is not registered), or if "l1l2" ("l1+l2")
            is given exactly one scale value.
    """
    # TODO: other backends
    if identifier is None:
        return None

    if isinstance(identifier, str):
        name, scales = identifier, ()
    elif isinstance(identifier, (list, tuple)) and identifier:
        name, scales = identifier[0], identifier[1:]
    else:
        raise ValueError("Identifier must be a string or a non-empty list or tuple.")

    # Reject e.g. [0.01] with a clear message instead of letting `.lower()`
    # fail with an AttributeError.
    if not isinstance(name, str):
        raise ValueError(
            f"Regularizer name must be a string, got {type(name).__name__}."
        )
    name = name.lower()

    regularizer_class = REGULARIZER_DICT.get(name)
    if regularizer_class is None:
        if name == "orthogonal":
            raise ValueError(
                "The 'orthogonal' regularizer is not available "
                "in your version of TensorFlow"
            )
        raise ValueError(f"Unknown regularizer name: {name}")

    # No scale values: let the regularizer class use its own defaults.
    if not scales:
        return regularizer_class()

    # Keyword argument(s) that receive the scale value(s), in positional order.
    # Names without an entry are constructed without keyword arguments.
    scale_kwarg_names = {
        "l1": ("l1",),
        "l2": ("l2",),
        "orthogonal": ("factor",),
        "l1l2": ("l1", "l2"),
        "l1+l2": ("l1", "l2"),
    }.get(name, ())

    # `scales` is non-empty here, so only the two-argument L1L2 case can be
    # short of values.
    if len(scales) < len(scale_kwarg_names):
        raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")

    # zip() stops at the shorter sequence, so surplus scale values are dropped.
    return regularizer_class(**dict(zip(scale_kwarg_names, scales)))
0 commit comments