Skip to content

Commit 37acb87

Browse files
committed
Make regularization factor required
1 parent 7329fe6 commit 37acb87

File tree

1 file changed

+22
-25
lines changed

1 file changed

+22
-25
lines changed

deepxde/nn/regularizers.py

Lines changed: 22 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,27 +14,25 @@ def get(identifier):
1414
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
1515
1616
Args:
17-
identifier (str or list/tuple): Specifies the type of regularizer and
18-
the optional scale values. If a string, it should be one of "l1",
19-
"l2", "orthogonal", or "l1l2" ("l1+l2"); default scale values will
20-
be used. If a list or tuple, the first element should be one of
21-
the above strings, followed by scale values. For "l1", "l2", or
22-
"orthogonal", you can provide a single scale value. For "l1l2" ("l1+l2"),
23-
you can provide both "l1" and "l2" scale values.
17+
identifier (list/tuple): Specifies the type of regularizer and
18+
regularization factor. The first element should be one of "l1", "l2",
19+
"orthogonal", or "l1l2" ("l1+l2"). For "l1", "l2", or "orthogonal",
20+
you can provide a single factor value. For "l1l2" ("l1+l2"),
21+
both "l1" and "l2" factors are required.
2422
"""
2523

2624
# TODO: other backends
2725
if identifier is None:
2826
return None
2927

30-
if isinstance(identifier, str):
31-
name = identifier.lower()
32-
scales = []
33-
elif isinstance(identifier, (list, tuple)) and identifier:
28+
if isinstance(identifier, (list, tuple)) and identifier:
3429
name = identifier[0].lower()
35-
scales = identifier[1:]
30+
factor = identifier[1:]
3631
else:
37-
raise ValueError("Identifier must be a string or a non-empty list or tuple.")
32+
raise ValueError("Identifier must be a non-empty list or tuple.")
33+
34+
if not factor:
35+
raise ValueError("Regularization factor must be provided.")
3836

3937
regularizer_class = REGULARIZER_DICT.get(name)
4038
if not regularizer_class:
@@ -46,16 +44,15 @@ def get(identifier):
4644
raise ValueError(f"Unknown regularizer name: {name}")
4745

4846
regularizer_kwargs = {}
49-
if scales:
50-
if name == "l1":
51-
regularizer_kwargs["l1"] = scales[0]
52-
elif name == "l2":
53-
regularizer_kwargs["l2"] = scales[0]
54-
elif name == "orthogonal":
55-
regularizer_kwargs["factor"] = scales[0]
56-
elif name in ("l1l2", "l1+l2"):
57-
if len(scales) < 2:
58-
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
59-
regularizer_kwargs["l1"] = scales[0]
60-
regularizer_kwargs["l2"] = scales[1]
47+
if name == "l1":
48+
regularizer_kwargs["l1"] = factor[0]
49+
elif name == "l2":
50+
regularizer_kwargs["l2"] = factor[0]
51+
elif name == "orthogonal":
52+
regularizer_kwargs["factor"] = factor[0]
53+
elif name in ("l1l2", "l1+l2"):
54+
if len(factor) < 2:
55+
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
56+
regularizer_kwargs["l1"] = factor[0]
57+
regularizer_kwargs["l2"] = factor[1]
6158
return regularizer_class(**regularizer_kwargs)

0 commit comments

Comments
 (0)