Skip to content

Commit 4c8b2d2

Browse files
authored
Modify Tensorflow regularizers (#1864)
1 parent 40cd7e5 commit 4c8b2d2

File tree

1 file changed

+27
-11
lines changed

1 file changed

+27
-11
lines changed

deepxde/nn/regularizers.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,32 @@
22

33

44
def get(identifier):
5+
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
6+
7+
Args:
8+
identifier (list/tuple): Specifies the type and factor(s) of the regularizer.
9+
The first element should be one of "l1", "l2", or "l1l2" ("l1+l2").
10+
For "l1" and "l2", a single regularization factor is expected.
11+
For "l1l2", provide both "l1" and "l2" factors.
12+
"""
13+
514
# TODO: other backends
6-
if identifier is None:
15+
if identifier is None or not identifier:
716
return None
8-
name, scales = identifier[0].lower(), identifier[1:]
9-
return (
10-
tf.keras.regularizers.L1(l1=scales[0])
11-
if name == "l1"
12-
else tf.keras.regularizers.L2(l2=scales[0])
13-
if name == "l2"
14-
else tf.keras.regularizers.L1L2(l1=scales[0], l2=scales[1])
15-
if name in ("l1+l2", "l1l2")
16-
else None
17-
)
17+
if not isinstance(identifier, (list, tuple)):
18+
raise ValueError("Identifier must be a list or a tuple.")
19+
20+
name = identifier[0].lower()
21+
factor = identifier[1:]
22+
if not factor:
23+
raise ValueError("Regularization factor must be provided.")
24+
25+
if name == "l1":
26+
return tf.keras.regularizers.L1(l1=factor[0])
27+
if name == "l2":
28+
return tf.keras.regularizers.L2(l2=factor[0])
29+
if name in ("l1l2", "l1+l2"):
30+
if len(factor) < 2:
31+
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
32+
return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1])
33+
raise ValueError(f"Unknown regularizer name: {name}")

0 commit comments

Comments
 (0)