Skip to content

Commit ec4bdd3

Browse files
authored
Backend paddle: Refactor and add regularizer (#1894)
1 parent 3544fdf commit ec4bdd3

File tree

5 files changed

+77
-6
lines changed

5 files changed

+77
-6
lines changed

deepxde/backend/backend.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -502,3 +502,42 @@ def sparse_dense_matmul(x, y):
502502
Returns:
503503
Tensor: The multiplication result.
504504
"""
505+
506+
507+
###############################################################################
508+
# Regularization
509+
510+
511+
def l1_regularization(l1):
512+
"""A regularizer that applies a L1 regularization penalty or L1 weight decay.
513+
514+
Warning:
515+
The implementation may vary across different backends.
516+
517+
Args:
518+
l1 (float): L1 regularization factor.
519+
"""
520+
521+
522+
def l2_regularization(l2):
523+
"""A regularizer that applies a L2 regularization penalty or L2 weight decay.
524+
525+
Warning:
526+
The implementation may vary across different backends.
527+
528+
Args:
529+
l2 (float): L2 regularization factor.
530+
"""
531+
532+
533+
def l1_l2_regularization(l1, l2):
534+
"""A regularizer that applies both L1 and L2 regularization penalties or
535+
L1 and L2 weight decay.
536+
537+
Warning:
538+
The implementation may vary across different backends.
539+
540+
Args:
541+
l1 (float): L1 regularization factor.
542+
l2 (float): L2 regularization factor.
543+
"""

deepxde/backend/paddle/tensor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,3 +229,11 @@ def matmul(x, y):
229229

230230
def sparse_dense_matmul(x, y):
231231
return paddle.sparse.matmul(x, y)
232+
233+
234+
def l1_regularization(l1):
235+
return paddle.regularizer.L1Decay(coeff=l1)
236+
237+
238+
def l2_regularization(l2):
239+
return paddle.regularizer.L2Decay(coeff=l2)

deepxde/backend/tensorflow/tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,3 +210,15 @@ def zeros_like(input_tensor):
210210

211211
def matmul(x, y):
212212
return tf.linalg.matmul(x, y)
213+
214+
215+
def l1_regularization(l1):
216+
return tf.keras.regularizers.L1(l1=l1)
217+
218+
219+
def l2_regularization(l2):
220+
return tf.keras.regularizers.L2(l2=l2)
221+
222+
223+
def l1_l2_regularization(l1, l2):
224+
return tf.keras.regularizers.L1L2(l1=l1, l2=l2)

deepxde/backend/tensorflow_compat_v1/tensor.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,3 +245,15 @@ def matmul(x, y):
245245

246246
def sparse_dense_matmul(x, y):
247247
return tf.sparse.sparse_dense_matmul(x, y)
248+
249+
250+
def l1_regularization(l1):
251+
return tf.keras.regularizers.L1(l1=l1)
252+
253+
254+
def l2_regularization(l2):
255+
return tf.keras.regularizers.L2(l2=l2)
256+
257+
258+
def l1_l2_regularization(l1, l2):
259+
return tf.keras.regularizers.L1L2(l1=l1, l2=l2)

deepxde/nn/regularizers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from ..backend import tf
1+
from .. import backend as bkd
22

33

44
def get(identifier):
5-
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
5+
"""Retrieves a regularizer instance based on the given identifier.
66
77
Args:
88
identifier (list/tuple): Specifies the type and factor(s) of the regularizer.
@@ -11,7 +11,6 @@ def get(identifier):
1111
For "l1l2", provide both "l1" and "l2" factors.
1212
"""
1313

14-
# TODO: other backends
1514
if identifier is None or not identifier:
1615
return None
1716
if not isinstance(identifier, (list, tuple)):
@@ -23,11 +22,12 @@ def get(identifier):
2322
raise ValueError("Regularization factor must be provided.")
2423

2524
if name == "l1":
26-
return tf.keras.regularizers.L1(l1=factor[0])
25+
return bkd.l1_regularization(factor[0])
2726
if name == "l2":
28-
return tf.keras.regularizers.L2(l2=factor[0])
27+
return bkd.l2_regularization(factor[0])
2928
if name in ("l1l2", "l1+l2"):
29+
# TODO: only supported by 'tensorflow.compat.v1' and 'tensorflow' now.
3030
if len(factor) < 2:
3131
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
32-
return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1])
32+
return bkd.l1_l2_regularization(factor[0], factor[1])
3333
raise ValueError(f"Unknown regularizer name: {name}")

0 commit comments

Comments
 (0)