@@ -186,35 +186,57 @@ def test_step_without_closure(self):
186186 paddle .disable_static ()
187187 x = paddle .arange (26 , dtype = "float32" ).reshape ([2 , 13 ])
188188 linear = paddle .nn .Linear (13 , 5 )
189- adam = paddle .optimizer .Adam (
190- learning_rate = 0.01 ,
191- parameters = linear .parameters (),
192- )
193- adam .zero_grad ()
194- output = linear (x )
195- loss = paddle .mean (output )
196- loss .backward ()
197- adam .step ()
189+ optimizers = [
190+ paddle .optimizer .Adam (
191+ learning_rate = 0.01 ,
192+ parameters = linear .parameters (),
193+ ),
194+ paddle .optimizer .AdamW (
195+ learning_rate = 0.01 ,
196+ parameters = linear .parameters (),
197+ ),
198+ paddle .optimizer .ASGD (
199+ learning_rate = 0.01 ,
200+ parameters = linear .parameters (),
201+ ),
202+ ]
203+ for optimizer in optimizers :
204+ optimizer .zero_grad ()
205+ output = linear (x )
206+ loss = paddle .mean (output )
207+ loss .backward ()
208+ optimizer .step ()
198209
199210 def test_step_with_closure (self ):
200211 paddle .seed (100 )
201212 numpy .random .seed (100 )
202213 paddle .disable_static ()
203214 x = paddle .arange (26 , dtype = "float32" ).reshape ([2 , 13 ])
204215 linear = paddle .nn .Linear (13 , 5 )
205- adam = paddle .optimizer .Adam (
206- learning_rate = 0.01 ,
207- parameters = linear .parameters (),
208- )
209-
210- def closure ():
211- adam .zero_grad ()
212- output = linear (x )
213- loss = paddle .mean (output )
214- loss .backward ()
215- return loss
216-
217- loss = adam .step (closure )
216+ optimizers = [
217+ paddle .optimizer .Adam (
218+ learning_rate = 0.01 ,
219+ parameters = linear .parameters (),
220+ ),
221+ paddle .optimizer .AdamW (
222+ learning_rate = 0.01 ,
223+ parameters = linear .parameters (),
224+ ),
225+ paddle .optimizer .ASGD (
226+ learning_rate = 0.01 ,
227+ parameters = linear .parameters (),
228+ ),
229+ ]
230+ for optimizer in optimizers :
231+
232+ def closure ():
233+ optimizer .zero_grad ()
234+ output = linear (x )
235+ loss = paddle .mean (output )
236+ loss .backward ()
237+ return loss
238+
239+ loss = optimizer .step (closure )
218240
219241
220242if __name__ == '__main__' :
0 commit comments