Commit 13b01cb

add AdamW and Optimizer class test for closure

1 parent f5d0741

1 file changed: test/legacy_test/test_optimizer.py
Lines changed: 44 additions & 22 deletions
@@ -186,35 +186,57 @@ def test_step_without_closure(self):
         paddle.disable_static()
         x = paddle.arange(26, dtype="float32").reshape([2, 13])
         linear = paddle.nn.Linear(13, 5)
-        adam = paddle.optimizer.Adam(
-            learning_rate=0.01,
-            parameters=linear.parameters(),
-        )
-        adam.zero_grad()
-        output = linear(x)
-        loss = paddle.mean(output)
-        loss.backward()
-        adam.step()
+        optimizers = [
+            paddle.optimizer.Adam(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+            ),
+            paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+            ),
+            paddle.optimizer.ASGD(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+            ),
+        ]
+        for optimizer in optimizers:
+            optimizer.zero_grad()
+            output = linear(x)
+            loss = paddle.mean(output)
+            loss.backward()
+            optimizer.step()
 
     def test_step_with_closure(self):
         paddle.seed(100)
         numpy.random.seed(100)
         paddle.disable_static()
         x = paddle.arange(26, dtype="float32").reshape([2, 13])
         linear = paddle.nn.Linear(13, 5)
-        adam = paddle.optimizer.Adam(
-            learning_rate=0.01,
-            parameters=linear.parameters(),
-        )
-
-        def closure():
-            adam.zero_grad()
-            output = linear(x)
-            loss = paddle.mean(output)
-            loss.backward()
-            return loss
-
-        loss = adam.step(closure)
+        optimizers = [
+            paddle.optimizer.Adam(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+            ),
+            paddle.optimizer.AdamW(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+            ),
+            paddle.optimizer.ASGD(
+                learning_rate=0.01,
+                parameters=linear.parameters(),
+            ),
+        ]
+        for optimizer in optimizers:
+
+            def closure():
+                optimizer.zero_grad()
+                output = linear(x)
+                loss = paddle.mean(output)
+                loss.backward()
+                return loss
+
+            loss = optimizer.step(closure)
 
 
 if __name__ == '__main__':
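
For context, the closure form of step() exercised by the second test recomputes the loss inside the optimizer step: step(closure) invokes the callable, which clears old gradients, runs forward and backward, and returns the loss, and step(closure) then applies the update and passes that loss back to the caller. Below is a minimal standalone sketch of that usage, assuming a Paddle build that includes the closure support this commit tests; the choice of AdamW and the data shapes are taken from the test and are illustrative only.

    import paddle

    paddle.disable_static()
    x = paddle.arange(26, dtype="float32").reshape([2, 13])
    linear = paddle.nn.Linear(13, 5)
    opt = paddle.optimizer.AdamW(
        learning_rate=0.01,
        parameters=linear.parameters(),
    )

    def closure():
        # step(closure) calls this to recompute gradients before the
        # update: clear stale grads, run forward/backward, return loss.
        opt.zero_grad()
        loss = paddle.mean(linear(x))
        loss.backward()
        return loss

    # As in the test, step(closure) returns the loss the closure computed.
    loss = opt.step(closure)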
