
Commit c727805

#10-add dataloaders (#12)
1 parent 4b9caa6 commit c727805

File tree

2 files changed: 71 additions, 0 deletions


data/__init__.py

+1 line

@@ -0,0 +1 @@
from data.dataloader import get_mnist_loader, get_cifar10_loader
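The re-export above is what allows callers to import the loaders from the package root rather than the submodule. A minimal sketch of the intended import, assuming the repository root is on the Python path (not stated in the commit):

# Both forms resolve to the same functions once data/__init__.py re-exports them.
from data import get_mnist_loader, get_cifar10_loader
# equivalent to:
from data.dataloader import get_mnist_loader, get_cifar10_loader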

data/dataloader.py

+70 lines

@@ -0,0 +1,70 @@
from typing import Union, Tuple

from torch.utils.data import random_split, DataLoader
from torchvision.datasets import MNIST, CIFAR10
from torchvision.transforms import ToTensor, Normalize, Compose


def get_mnist_loader(
    train: bool,
    batch_size: int,
    normalize: bool = True,
    data_path: str = "./data",
    shuffle: bool = True,
    num_workers: int = 1
) -> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
    transforms = [
        ToTensor()
    ]

    if normalize:
        transforms.append(
            Normalize((0.1307,), (0.3081,))
        )

    transform = Compose(transforms)

    if train:
        dataset = MNIST(data_path, train=True, transform=transform, download=True)
        train_dataset, val_dataset = random_split(dataset, [50_000, 10_000])
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=shuffle,
                                  num_workers=num_workers)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        return train_loader, val_loader

    else:
        test_dataset = MNIST(data_path, train=False, transform=transform, download=True)
        return DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)


def get_cifar10_loader(
    train: bool,
    batch_size: int,
    normalize: bool = True,
    data_path: str = "./data",
    shuffle: bool = True,
    num_workers: int = 1
) -> Union[DataLoader, Tuple[DataLoader, DataLoader]]:
    transforms = [
        ToTensor()
    ]

    if normalize:
        transforms.append(
            Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        )

    transform = Compose(transforms)

    if train:
        dataset = CIFAR10(data_path, train=True, transform=transform, download=True)
        train_dataset, val_dataset = random_split(dataset, [42_000, 8_000])
        train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=shuffle,
                                  num_workers=num_workers)
        val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
        return train_loader, val_loader

    else:
        test_dataset = CIFAR10(data_path, train=False, transform=transform, download=True)
        return DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers
                          )
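For context, a minimal usage sketch of the two new loaders. The call sites are hypothetical, not part of the commit, and assume torchvision can download MNIST and CIFAR-10 into ./data:

from data import get_mnist_loader, get_cifar10_loader

# train=True returns a (train, val) pair built by random_split;
# train=False returns a single test loader.
train_loader, val_loader = get_mnist_loader(train=True, batch_size=64)
test_loader = get_cifar10_loader(train=False, batch_size=128, shuffle=False)

images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([64, 1, 28, 28]) torch.Size([64])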

0 commit comments

Comments
 (0)