|
2 | 2 |
|
3 | 3 |
|
4 | 4 | def get(identifier): |
| 5 | + """Retrieves a TensorFlow regularizer instance based on the given identifier. |
| 6 | +
|
| 7 | + Args: |
| 8 | + identifier (list/tuple): Specifies the type and factor(s) of the regularizer. |
| 9 | + The first element should be one of "l1", "l2", or "l1l2" ("l1+l2"). |
| 10 | + For "l1" and "l2", a single regularization factor is expected. |
| 11 | + For "l1l2", provide both "l1" and "l2" factors. |
| 12 | + """ |
| 13 | + |
5 | 14 | # TODO: other backends |
6 | | - if identifier is None: |
| 15 | + if identifier is None or not identifier: |
7 | 16 | return None |
8 | | - name, scales = identifier[0].lower(), identifier[1:] |
9 | | - return ( |
10 | | - tf.keras.regularizers.L1(l1=scales[0]) |
11 | | - if name == "l1" |
12 | | - else tf.keras.regularizers.L2(l2=scales[0]) |
13 | | - if name == "l2" |
14 | | - else tf.keras.regularizers.L1L2(l1=scales[0], l2=scales[1]) |
15 | | - if name in ("l1+l2", "l1l2") |
16 | | - else None |
17 | | - ) |
| 17 | + if not isinstance(identifier, (list, tuple)): |
| 18 | + raise ValueError("Identifier must be a list or a tuple.") |
| 19 | + |
| 20 | + name = identifier[0].lower() |
| 21 | + factor = identifier[1:] |
| 22 | + if not factor: |
| 23 | + raise ValueError("Regularization factor must be provided.") |
| 24 | + |
| 25 | + if name == "l1": |
| 26 | + return tf.keras.regularizers.L1(l1=factor[0]) |
| 27 | + if name == "l2": |
| 28 | + return tf.keras.regularizers.L2(l2=factor[0]) |
| 29 | + if name in ("l1l2", "l1+l2"): |
| 30 | + if len(factor) < 2: |
| 31 | + raise ValueError("L1L2 regularizer requires both L1/L2 penalties.") |
| 32 | + return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1]) |
| 33 | + raise ValueError(f"Unknown regularizer name: {name}") |
0 commit comments