File tree Expand file tree Collapse file tree 4 files changed +37
-9
lines changed Expand file tree Collapse file tree 4 files changed +37
-9
lines changed Original file line number Diff line number Diff line change @@ -502,3 +502,12 @@ def sparse_dense_matmul(x, y):
502502 Returns:
503503 Tensor: The multiplication result.
504504 """
505+
506+ def l1_decay (x ):
507+ """Implement the L1 weight decay regularization."""
508+
509+ def l2_decay (x ):
510+ """Implement the L2 weight decay regularization."""
511+
512+ def l1_l2_decay (x ,y ):
513+ """Implement the L1 and L2 weight decay regularization."""
Original file line number Diff line number Diff line change @@ -229,3 +229,9 @@ def matmul(x, y):
229229
230230def sparse_dense_matmul (x , y ):
231231 return paddle .sparse .matmul (x , y )
232+
233+ def l1_decay (x ):
234+ return paddle .regularizer .L1Decay (coeff = x )
235+
236+ def l2_decay (x ):
237+ return paddle .regularizer .L2Decay (coeff = x )
Original file line number Diff line number Diff line change @@ -245,3 +245,12 @@ def matmul(x, y):
245245
246246def sparse_dense_matmul (x , y ):
247247 return tf .sparse .sparse_dense_matmul (x , y )
248+
249+ def l1_decay (x ):
250+ return tf .keras .regularizers .L1 (l1 = x )
251+
252+ def l2_decay (x ):
253+ return tf .keras .regularizers .L2 (l2 = x )
254+
255+ def l1_l2_decay (x ,y ):
256+ return tf .keras .regularizers .L1L2 (l1 = x , l2 = y )
Original file line number Diff line number Diff line change 1- from ..backend import tf
1+ from .. import backend as bkd
2+ from ..backend import backend_name
23
34
45def get (identifier ):
@@ -22,12 +23,15 @@ def get(identifier):
2223 if not factor :
2324 raise ValueError ("Regularization factor must be provided." )
2425
25- if name == "l1" :
26- return tf .keras .regularizers .L1 (l1 = factor [0 ])
27- if name == "l2" :
28- return tf .keras .regularizers .L2 (l2 = factor [0 ])
29- if name in ("l1l2" , "l1+l2" ):
30- if len (factor ) < 2 :
31- raise ValueError ("L1L2 regularizer requires both L1/L2 penalties." )
32- return tf .keras .regularizers .L1L2 (l1 = factor [0 ], l2 = factor [1 ])
26+ try :
27+ if name == "l1" :
28+ return bkd .l1_decay (factor [0 ])
29+ if name == "l2" :
30+ return bkd .l2_decay (factor [0 ])
31+ if name in ("l1l2" , "l1+l2" ):
32+ # TODO: only supported by 'tensorflow.compat.v1' now.
33+ if len (factor ) < 2 :
34+ return bkd .l1_l2_decay (factor [0 ], factor [1 ])
35+ except Exception :
36+ print (f"{ name } regularization to be implemented for backend { backend_name } now." )
3337 raise ValueError (f"Unknown regularizer name: { name } " )
You can’t perform that action at this time.
0 commit comments