-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathcifar10_download.py
More file actions
63 lines (51 loc) · 1.57 KB
/
cifar10_download.py
File metadata and controls
63 lines (51 loc) · 1.57 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import argparse
import copy
import os
import socket
import time
import random
import sys
import numpy as np
from itertools import cycle
from functools import reduce
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.multiprocessing import Process
from torch.autograd import Variable
import torch.nn as nn
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
def make_dataloader(*, root='./data', batch_size=128, num_workers=2,
                    download=True):
    """Build a shuffling DataLoader over the CIFAR-10 training split.

    Previously this function constructed the training dataset and then
    discarded it (returning ``None``). It now returns a loader configured like
    ``make_validation_dataloader`` but with ``shuffle=True`` for training.
    All parameters default to the original hard-coded values.

    Args:
        root: Directory the CIFAR-10 files are read from / downloaded into.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        download: Forwarded to ``torchvision.datasets.CIFAR10``; when true the
            dataset is fetched into ``root`` if it is not already there.

    Returns:
        A ``torch.utils.data.DataLoader`` over the augmented, normalized
        CIFAR-10 training images.
    """
    # Reseed torch's RNGs from Python's ``random`` module (which this script
    # does not seed). The seed is itself random, so this does NOT make runs
    # reproducible; it is kept to preserve the existing behavior. CUDA is
    # seeded too, matching make_validation_dataloader (silently ignored
    # when CUDA is unavailable).
    ii64 = np.iinfo(np.int64)
    r = random.randint(0, ii64.max)
    torch.manual_seed(r)
    torch.cuda.manual_seed(r)

    # These are the standard ImageNet channel statistics, not CIFAR-10's own.
    # NOTE(review): presumably chosen to match ImageNet-pretrained
    # torchvision models — confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = torchvision.datasets.CIFAR10(
        root=root,
        train=True,
        download=download,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            # Pad by 4 pixels, then randomly crop back to 32x32.
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ]),
    )

    return torch.utils.data.DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
def make_validation_dataloader(*, root='./data', batch_size=128,
                               num_workers=2, download=True):
    """Build a non-shuffling DataLoader over the CIFAR-10 test split.

    All parameters default to the values that were previously hard-coded, so
    calling with no arguments behaves exactly as before.

    Args:
        root: Directory the CIFAR-10 files are read from / downloaded into.
        batch_size: Number of samples per batch.
        num_workers: Number of DataLoader worker processes.
        download: Forwarded to ``datasets.CIFAR10``; when true the dataset is
            fetched into ``root`` if it is not already there.

    Returns:
        A ``torch.utils.data.DataLoader`` over the normalized (unaugmented)
        CIFAR-10 test images, in dataset order (``shuffle=False``).
    """
    # Reseed torch's CPU and CUDA RNGs from Python's ``random`` module (which
    # this script does not seed). The seed is itself random, so this gives no
    # reproducibility; it is kept to preserve existing behavior.
    ii64 = np.iinfo(np.int64)
    r = random.randint(0, ii64.max)
    torch.manual_seed(r)
    torch.cuda.manual_seed(r)

    # Standard ImageNet channel statistics (not CIFAR-10's own).
    # NOTE(review): presumably to match ImageNet-pretrained models — confirm.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    val_dataset = datasets.CIFAR10(
        root=root,
        train=False,
        download=download,
        transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]),
    )

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )
    return val_loader
# Script body. There is no ``if __name__ == "__main__":`` guard, so these calls
# run at import time and trigger the CIFAR-10 downloads into ./data.
# The training call is made only for that side effect; its result is not kept.
make_dataloader()
# Bound as a module-level global; other code may import this name, so keep it.
val_loader = make_validation_dataloader()