Commit da53c3b

Backend paddle: add regularizer
1 parent 8275aeb · commit da53c3b

4 files changed: +37 −9 lines changed


deepxde/backend/backend.py

Lines changed: 9 additions & 0 deletions
@@ -502,3 +502,12 @@ def sparse_dense_matmul(x, y):
     Returns:
         Tensor: The multiplication result.
     """
+
+def l1_decay(x):
+    """Implement the L1 weight decay regularization."""
+
+def l2_decay(x):
+    """Implement the L2 weight decay regularization."""
+
+def l1_l2_decay(x,y):
+    """Implement the L1 and L2 weight decay regularization."""

deepxde/backend/paddle/tensor.py

Lines changed: 6 additions & 0 deletions
@@ -229,3 +229,9 @@ def matmul(x, y):
 
 def sparse_dense_matmul(x, y):
     return paddle.sparse.matmul(x, y)
+
+def l1_decay(x):
+    return paddle.regularizer.L1Decay(coeff=x)
+
+def l2_decay(x):
+    return paddle.regularizer.L2Decay(coeff=x)
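The objects returned here are plain paddle.regularizer instances. Not part of this diff, but as a sketch of how such an object is typically consumed in Paddle, it is passed to an optimizer through its weight_decay argument (the layer and hyperparameters below are illustrative):

    import paddle

    linear = paddle.nn.Linear(10, 10)
    opt = paddle.optimizer.Adam(
        learning_rate=1e-3,
        parameters=linear.parameters(),
        # Same kind of object that l2_decay(1e-4) returns.
        weight_decay=paddle.regularizer.L2Decay(coeff=1e-4),
    )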

deepxde/backend/tensorflow_compat_v1/tensor.py

Lines changed: 9 additions & 0 deletions
@@ -245,3 +245,12 @@ def matmul(x, y):
 
 def sparse_dense_matmul(x, y):
     return tf.sparse.sparse_dense_matmul(x, y)
+
+def l1_decay(x):
+    return tf.keras.regularizers.L1(l1=x)
+
+def l2_decay(x):
+    return tf.keras.regularizers.L2(l2=x)
+
+def l1_l2_decay(x,y):
+    return tf.keras.regularizers.L1L2(l1=x, l2=y)
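For the tensorflow.compat.v1 backend the returned objects are standard Keras regularizers, which are normally attached to layer weights so the penalty is added to the training loss. A brief sketch, not part of this diff (layer size and coefficients illustrative):

    import tensorflow.compat.v1 as tf

    layer = tf.keras.layers.Dense(
        32,
        # Same kind of object that l1_l2_decay(1e-5, 1e-4) returns.
        kernel_regularizer=tf.keras.regularizers.L1L2(l1=1e-5, l2=1e-4),
    )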

deepxde/nn/regularizers.py

Lines changed: 13 additions & 9 deletions
@@ -1,4 +1,5 @@
-from ..backend import tf
+from .. import backend as bkd
+from ..backend import backend_name
 
 
 def get(identifier):
@@ -22,12 +23,15 @@ def get(identifier):
     if not factor:
         raise ValueError("Regularization factor must be provided.")
 
-    if name == "l1":
-        return tf.keras.regularizers.L1(l1=factor[0])
-    if name == "l2":
-        return tf.keras.regularizers.L2(l2=factor[0])
-    if name in ("l1l2", "l1+l2"):
-        if len(factor) < 2:
-            raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
-        return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1])
+    try:
+        if name == "l1":
+            return bkd.l1_decay(factor[0])
+        if name == "l2":
+            return bkd.l2_decay(factor[0])
+        if name in ("l1l2", "l1+l2"):
+            # TODO: only supported by 'tensorflow.compat.v1' now.
+            if len(factor) < 2:
+                return bkd.l1_l2_decay(factor[0], factor[1])
+    except Exception:
+        print(f"{name} regularization to be implemented for backend {backend_name} now.")
     raise ValueError(f"Unknown regularizer name: {name}")
