-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
58 lines (44 loc) · 1.43 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import numpy as np
import torch
import os
from torch.utils.data.dataset import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.autograd import Variable
from PIL import Image
from os import path
from config import *
class LocalDataset(Dataset):
    """Image dataset described by a comma-separated annotation file.

    Every row of the annotation file must contain six fields, in this order:
    the image file name (resolved against ``base_path``), four numeric
    values, and an integer class value. Which of them becomes the label
    depends on the ``REGRESSION`` flag pulled in from ``config``.

    Args:
        base_path: Directory that image file names are joined onto.
        txt_list: Path (or file-like object) of the annotation file.
        transform: Optional callable applied to each opened PIL image.
    """

    def __init__(self, base_path, txt_list, transform=None):
        self.base_path = base_path
        # ndmin=2 keeps the table two-dimensional even when the file holds a
        # single row. Without it loadtxt returns a 1-D array of fields, so
        # len() would report the field count and row unpacking would fail.
        self.images = np.loadtxt(
            txt_list, dtype=str, delimiter=',', ndmin=2
        )
        self.transform = transform

    def __getitem__(self, index):
        """Load one sample.

        Returns:
            A dict with keys ``'image'`` (the PIL image, or the output of
            ``transform`` when one was given), ``'label'`` (a 4-element
            float tensor when ``REGRESSION`` is truthy, otherwise an int)
            and ``'img_name'`` (the file name as stored in the annotation
            file).
        """
        f, x, y, u, v, c = self.images[index]
        im = Image.open(path.join(self.base_path, f))
        # Image.open is lazy: it only reads the header and keeps the file
        # open. Force the pixel data in now so the handle can be released
        # even when no transform touches the image.
        im.load()
        if self.transform is not None:
            im = self.transform(im)

        if REGRESSION:
            label = torch.tensor([
                float(x),
                float(y),
                float(u),
                float(v),
            ])
        else:
            label = int(c)

        return {'image': im, 'label': label, 'img_name': f}

    def __len__(self):
        """Return the number of annotation rows (one per sample)."""
        return len(self.images)
# Reference snippets (kept commented out) for computing the per-channel mean and standard deviation of the dataset images
#dataset = LocalDataset("images", "training_list.csv", transform=transforms.ToTensor())
# Mean
#m = torch.zeros(3)
#for sample in dataset:
# m += sample['image'].sum(1).sum(1)
#m /= len(dataset)*256*144
# Standard Deviation
#s = torch.zeros(3)
#for sample in dataset:
# s+=((sample['image']-m.view(3,1,1))**2).sum(1).sum(1)
#s=torch.sqrt(s/(len(dataset)*256*144))