@@ -11,13 +11,13 @@ def is_external_optimizer(optimizer):
 
 def get(params, optimizer, learning_rate=None, decay=None):
     """Retrieves an Optimizer instance."""
+    # Custom Optimizer
     if isinstance(optimizer, torch.optim.Optimizer):
-        return optimizer
-
-    if optimizer in ["L-BFGS", "L-BFGS-B"]:
+        optim = optimizer
+    elif optimizer in ["L-BFGS", "L-BFGS-B"]:
         if learning_rate is not None or decay is not None:
             print("Warning: learning rate is ignored for {}".format(optimizer))
-        return torch.optim.LBFGS(
+        optim = torch.optim.LBFGS(
             params,
             lr=1,
             max_iter=LBFGS_options["iter_per_step"],
@@ -27,15 +27,33 @@ def get(params, optimizer, learning_rate=None, decay=None):
             history_size=LBFGS_options["maxcor"],
             line_search_fn=None,
         )
-
-    if learning_rate is None:
-        raise ValueError("No learning rate for {}.".format(optimizer))
-
-    if decay is not None:
-        # TODO: learning rate decay
-        raise NotImplementedError(
-            "learning rate decay to be implemented for backend pytorch."
+    else:
+        if learning_rate is None:
+            raise ValueError("No learning rate for {}.".format(optimizer))
+        if optimizer == "sgd":
+            optim = torch.optim.SGD(params, lr=learning_rate)
+        elif optimizer == "rmsprop":
+            optim = torch.optim.RMSprop(params, lr=learning_rate)
+        elif optimizer == "adam":
+            optim = torch.optim.Adam(params, lr=learning_rate)
+        else:
+            raise NotImplementedError(
+                f"{optimizer} to be implemented for backend pytorch."
+            )
+    lr_scheduler = _get_learningrate_scheduler(optim, decay)
+    return optim, lr_scheduler
+
+
+def _get_learningrate_scheduler(optim, decay):
+    if decay is None:
+        return None
+
+    if decay[0] == "step":
+        return torch.optim.lr_scheduler.StepLR(
+            optim, step_size=decay[1], gamma=decay[2]
         )
-    if optimizer == "adam":
-        return torch.optim.Adam(params, lr=learning_rate)
-    raise NotImplementedError(f"{optimizer} to be implemented for backend pytorch.")
+
+    # TODO: More learning rate scheduler
+    raise NotImplementedError(
+        f"{decay[0]} learning rate scheduler to be implemented for backend pytorch."
+    )
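
With this change, `get()` returns an `(optimizer, lr_scheduler)` pair instead of the optimizer alone, and `decay` is interpreted as a `("step", step_size, gamma)` tuple handled by the new `_get_learningrate_scheduler()` helper. Below is a minimal sketch of how a caller might consume the new return value; the toy model, data, loss, and decay values are illustrative assumptions, not taken from this diff, and `get` is assumed to be importable from the patched module.

```python
# Illustrative usage sketch (assumptions: toy model/data, decay values, import path).
import torch

# from the patched module; exact import path is an assumption
# from deepxde.optimizers.pytorch.optimizers import get

model = torch.nn.Linear(2, 1)
x = torch.randn(16, 2)
y = torch.randn(16, 1)

# decay follows the ("step", step_size, gamma) convention used by
# _get_learningrate_scheduler(); values here are arbitrary examples.
optim, lr_scheduler = get(
    model.parameters(), "adam", learning_rate=1e-3, decay=("step", 1000, 0.9)
)

for _ in range(5):
    optim.zero_grad()
    loss = torch.nn.functional.mse_loss(model(x), y)
    loss.backward()
    optim.step()
    # lr_scheduler is None when no decay was requested
    if lr_scheduler is not None:
        lr_scheduler.step()
```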