Skip to content

Commit 7329fe6

Browse files
committed
Modify Tensorflow regularizers
1 parent ba8e824 commit 7329fe6

File tree

1 file changed

+54
-10
lines changed

1 file changed

+54
-10
lines changed

deepxde/nn/regularizers.py

Lines changed: 54 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,61 @@
11
from ..backend import tf
22

3+
REGULARIZER_DICT = {
4+
"l1": tf.keras.regularizers.L1,
5+
"l2": tf.keras.regularizers.L2,
6+
"l1l2": tf.keras.regularizers.L1L2,
7+
"l1+l2": tf.keras.regularizers.L1L2,
8+
}
9+
if hasattr(tf.keras.regularizers, "OrthogonalRegularizer"):
10+
REGULARIZER_DICT["orthogonal"] = tf.keras.regularizers.OrthogonalRegularizer
11+
312

413
def get(identifier):
14+
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
15+
16+
Args:
17+
identifier (str or list/tuple): Specifies the type of regularizer and
18+
the optional scale values. If a string, it should be one of "l1",
19+
"l2", "orthogonal", or "l1l2" ("l1+l2"); default scale values will
20+
be used. If a list or tuple, the first element should be one of
21+
the above strings, followed by scale values. For "l1", "l2", or
22+
"orthogonal", you can provide a single scale value. For "l1l2" ("l1+l2"),
23+
you can provide both "l1" and "l2" scale values.
24+
"""
25+
526
# TODO: other backends
627
if identifier is None:
728
return None
8-
name, scales = identifier[0].lower(), identifier[1:]
9-
return (
10-
tf.keras.regularizers.L1(l1=scales[0])
11-
if name == "l1"
12-
else tf.keras.regularizers.L2(l2=scales[0])
13-
if name == "l2"
14-
else tf.keras.regularizers.L1L2(l1=scales[0], l2=scales[1])
15-
if name in ("l1+l2", "l1l2")
16-
else None
17-
)
29+
30+
if isinstance(identifier, str):
31+
name = identifier.lower()
32+
scales = []
33+
elif isinstance(identifier, (list, tuple)) and identifier:
34+
name = identifier[0].lower()
35+
scales = identifier[1:]
36+
else:
37+
raise ValueError("Identifier must be a string or a non-empty list or tuple.")
38+
39+
regularizer_class = REGULARIZER_DICT.get(name)
40+
if not regularizer_class:
41+
if name == "orthogonal":
42+
raise ValueError(
43+
"The 'orthogonal' regularizer is not available "
44+
"in your version of TensorFlow"
45+
)
46+
raise ValueError(f"Unknown regularizer name: {name}")
47+
48+
regularizer_kwargs = {}
49+
if scales:
50+
if name == "l1":
51+
regularizer_kwargs["l1"] = scales[0]
52+
elif name == "l2":
53+
regularizer_kwargs["l2"] = scales[0]
54+
elif name == "orthogonal":
55+
regularizer_kwargs["factor"] = scales[0]
56+
elif name in ("l1l2", "l1+l2"):
57+
if len(scales) < 2:
58+
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
59+
regularizer_kwargs["l1"] = scales[0]
60+
regularizer_kwargs["l2"] = scales[1]
61+
return regularizer_class(**regularizer_kwargs)

0 commit comments

Comments
 (0)