Skip to content

Commit 3de1044

Browse files
committed
Simplify the code
1 parent 65d1208 commit 3de1044

File tree

1 file changed

+14
-31
lines changed

1 file changed

+14
-31
lines changed

deepxde/nn/regularizers.py

Lines changed: 14 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,5 @@
11
from ..backend import tf
22

3-
REGULARIZER_DICT = {
4-
"l1": tf.keras.regularizers.L1,
5-
"l2": tf.keras.regularizers.L2,
6-
"l1l2": tf.keras.regularizers.L1L2,
7-
"l1+l2": tf.keras.regularizers.L1L2,
8-
}
9-
if hasattr(tf.keras.regularizers, "OrthogonalRegularizer"):
10-
REGULARIZER_DICT["orthogonal"] = tf.keras.regularizers.OrthogonalRegularizer
11-
123

134
def get(identifier):
145
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
@@ -24,35 +15,27 @@ def get(identifier):
2415
# TODO: other backends
2516
if identifier is None or not identifier:
2617
return None
18+
if not isinstance(identifier, (list, tuple)):
19+
raise ValueError("Identifier must be a list or a tuple.")
2720

28-
if isinstance(identifier, (list, tuple)):
29-
name = identifier[0].lower()
30-
factor = identifier[1:]
31-
else:
32-
raise ValueError("Identifier must be a non-empty list or tuple.")
33-
21+
name = identifier[0].lower()
22+
factor = identifier[1:]
3423
if not factor:
3524
raise ValueError("Regularization factor must be provided.")
3625

37-
regularizer_class = REGULARIZER_DICT.get(name)
38-
if not regularizer_class:
39-
if name == "orthogonal":
26+
if name == "l1":
27+
return tf.keras.regularizers.L1(l1=factor[0])
28+
if name == "l2":
29+
return tf.keras.regularizers.L2(l2=factor[0])
30+
if name == "orthogonal":
31+
if not hasattr(tf.keras.regularizers, "OrthogonalRegularizer"):
4032
raise ValueError(
4133
"The 'orthogonal' regularizer is not available "
4234
"in your version of TensorFlow"
4335
)
44-
raise ValueError(f"Unknown regularizer name: {name}")
45-
46-
regularizer_kwargs = {}
47-
if name == "l1":
48-
regularizer_kwargs["l1"] = factor[0]
49-
elif name == "l2":
50-
regularizer_kwargs["l2"] = factor[0]
51-
elif name == "orthogonal":
52-
regularizer_kwargs["factor"] = factor[0]
53-
elif name in ("l1l2", "l1+l2"):
36+
return tf.keras.regularizers.OrthogonalRegularizer(factor=factor[0])
37+
if name in ("l1l2", "l1+l2"):
5438
if len(factor) < 2:
5539
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
56-
regularizer_kwargs["l1"] = factor[0]
57-
regularizer_kwargs["l2"] = factor[1]
58-
return regularizer_class(**regularizer_kwargs)
40+
return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1])
41+
raise ValueError(f"Unknown regularizer name: {name}")

0 commit comments

Comments
 (0)