@@ -19,14 +19,16 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]
 
 
-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer
 
     if optimizer in ["L-BFGS", "L-BFGS-B"]:
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         optim = paddle.optimizer.LBFGS(
             learning_rate=1,
             max_iter=LBFGS_options["iter_per_step"],
@@ -46,5 +48,28 @@ def get(params, optimizer, learning_rate=None, decay=None):
     learning_rate = _get_lr_scheduler(learning_rate, decay)
 
     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "sgd":
+        return paddle.optimizer.SGD(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay
+        )
+    if optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay,
+        )
+    if optimizer == "adamw":
+        if (
+            not isinstance(weight_decay, paddle.regularizer.L2Decay)
+            or weight_decay._coeff == 0
+        ):
+            raise ValueError("AdamW optimizer requires non-zero L2 regularizer")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate,
+            parameters=params,
+            weight_decay=weight_decay._coeff,
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
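
For context, a minimal usage sketch of the new weight_decay argument, assuming get is in scope (e.g. imported from this optimizers module) and Paddle is installed; the layer and hyperparameter values are illustrative only:

import paddle

# Illustrative parameters; any iterable of Paddle parameters works here.
net = paddle.nn.Linear(2, 1)

# Adam/SGD/RMSProp forward weight_decay to Paddle as-is (float coefficient or regularizer).
opt_adam = get(net.parameters(), "adam", learning_rate=1e-3, weight_decay=1e-4)

# AdamW requires a non-zero L2Decay regularizer; only its coefficient is passed on.
opt_adamw = get(
    net.parameters(),
    "adamw",
    learning_rate=1e-3,
    weight_decay=paddle.regularizer.L2Decay(1e-4),
)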