@@ -19,12 +19,14 @@ def is_external_optimizer(optimizer):
     return optimizer in ["L-BFGS", "L-BFGS-B"]


-def get(params, optimizer, learning_rate=None, decay=None):
+def get(params, optimizer, learning_rate=None, decay=None, weight_decay=None):
     """Retrieves an Optimizer instance."""
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer

     if optimizer in ["L-BFGS", "L-BFGS-B"]:
+        if weight_decay is not None:
+            raise ValueError("L-BFGS optimizer doesn't support weight_decay")
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
         optim = paddle.optimizer.LBFGS(
@@ -46,5 +48,17 @@ def get(params, optimizer, learning_rate=None, decay=None):
     learning_rate = _get_lr_scheduler(learning_rate, decay)

     if optimizer == "adam":
-        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params)
+        return paddle.optimizer.Adam(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+    elif optimizer == "sgd":
+        return paddle.optimizer.SGD(learning_rate=learning_rate, parameters=params, weight_decay=weight_decay)
+    elif optimizer == "rmsprop":
+        return paddle.optimizer.RMSProp(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay,
+        )
+    elif optimizer == "adamw":
+        if weight_decay[0] == 0:
+            raise ValueError("AdamW optimizer requires non-zero weight decay")
+        return paddle.optimizer.AdamW(
+            learning_rate=learning_rate, parameters=params, weight_decay=weight_decay[0],
+        )
     raise NotImplementedError(f"{optimizer} to be implemented for backend Paddle.")
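
Usage note: a minimal sketch of how the extended get() might be called, assuming the module in the diff is importable and that weight_decay arrives as a one-element sequence (inferred from the weight_decay[0] indexing above); the toy layer below is hypothetical, not part of the change.

    import paddle

    net = paddle.nn.Linear(2, 1)  # hypothetical toy layer standing in for a real model

    # Adam without weight decay: weight_decay stays None and is forwarded as-is.
    opt_adam = get(net.parameters(), "adam", learning_rate=1e-3)

    # AdamW: the diff reads the coefficient as weight_decay[0], so a one-element
    # list is assumed here; a zero value would raise ValueError.
    opt_adamw = get(net.parameters(), "adamw", learning_rate=1e-3, weight_decay=[1e-4])

    # L-BFGS rejects any weight_decay and raises ValueError.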