Skip to content

Commit 8144d30

Browse files
committed
fix sampler bugs and update dataloader
1 parent 5798ff5 commit 8144d30

File tree

3 files changed

+12
-15
lines changed

3 files changed

+12
-15
lines changed

train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -427,8 +427,11 @@
427427

428428
UnFreeze_flag = True
429429

430+
if distributed:
431+
train_sampler.set_epoch(epoch)
430432
set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
431433

432434
fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, UnFreeze_Epoch, Cuda, fp16, scaler, save_period, save_dir, local_rank)
433-
434-
loss_history.writer.close()
435+
436+
if local_rank == 0:
437+
loss_history.writer.close()

utils/dataloader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import cv2
22
import numpy as np
3+
import torch
34
from PIL import Image
45
from torch.utils.data.dataset import Dataset
56

@@ -160,7 +161,6 @@ def yolo_dataset_collate(batch):
160161
for img, box in batch:
161162
images.append(img)
162163
bboxes.append(box)
163-
images = np.array(images)
164+
images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
165+
bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
164166
return images, bboxes
165-
166-

utils/utils_fit.py

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,8 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
2121
images, targets = batch[0], batch[1]
2222
with torch.no_grad():
2323
if cuda:
24-
images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
25-
targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
26-
else:
27-
images = torch.from_numpy(images).type(torch.FloatTensor)
28-
targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
24+
images = images.cuda()
25+
targets = [ann.cuda() for ann in targets]
2926
#----------------------#
3027
# 清零梯度
3128
#----------------------#
@@ -94,11 +91,8 @@ def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch,
9491
images, targets = batch[0], batch[1]
9592
with torch.no_grad():
9693
if cuda:
97-
images = torch.from_numpy(images).type(torch.FloatTensor).cuda()
98-
targets = [torch.from_numpy(ann).type(torch.FloatTensor).cuda() for ann in targets]
99-
else:
100-
images = torch.from_numpy(images).type(torch.FloatTensor)
101-
targets = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in targets]
94+
images = images.cuda()
95+
targets = [ann.cuda() for ann in targets]
10296
#----------------------#
10397
# 清零梯度
10498
#----------------------#

0 commit comments

Comments
 (0)