Skip to content

Commit d3ae98f

Browse files
committed
不需要每个batch设置train()
1 parent 912793e commit d3ae98f

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

examples/tutorials_mnist.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class MyEvaluator(Evaluator):
6060
# 重构评价函数
6161
def evaluate(self):
6262
total, hit = 1e-5, 0
63-
for X, y in tqdm(test_dataloader, desc='Evaluating'):
63+
for X, y in tqdm(test_dataloader, desc='Evaluating', ncols=80):
6464
pred_y = model.predict(X).argmax(dim=-1)
6565
hit += pred_y.eq(y).sum().item()
6666
total += y.shape[0]

torch4keras/trainer.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,9 @@ def fit(self, train_dataloader, steps_per_epoch=None, epochs=1, callbacks=None,
361361
self.callbacks.on_batch_begin(self.global_step, self.local_step, logs)
362362

363363
# forward和backward
364-
self.unwrap_model().train() # 设置为train模式
364+
if not self.unwrap_model().training:
365+
self.unwrap_model().train() # 设置为train模式
366+
365367
tr_loss, tr_loss_detail = 0, {}
366368
for _ in range(self.grad_accumulation_steps):
367369
train_X, train_y = self._prepare_nextbatch() # 获取下一个batch的训练数据

0 commit comments

Comments
 (0)