Skip to content

Commit 5efa231

Browse files
authored
[misc] Fix training function with distillation (#396)
1 parent c9d0ac5 commit 5efa231

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

tinynn/util/cifar10.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def _calc_loss(label, label_teacher):
237237

238238
avg_data_time.update(time.time() - batch_end)
239239
image = image.to(device=context.device)
240+
context.optimizer.zero_grad()
240241

241242
if context.grad_scaler:
242243
with autocast():

0 commit comments

Comments
 (0)