-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathperm_mnist_data_generation.py
More file actions
94 lines (81 loc) · 3.6 KB
/
perm_mnist_data_generation.py
File metadata and controls
94 lines (81 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os
from argparse import ArgumentParser, Namespace
from typing import Dict, List
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import Subset
from tqdm import tqdm
from utils import set_seed
def gen_data(img_size: int,
             downsample_ratio: float,
             download: bool) -> Dict[str, torch.Tensor]:
    '''
    Load MNIST, resize the images and return both splits as flat tensors.

    Only the training split is subsampled (a random subset drawn with
    torch.randperm); the test split is returned in full, in dataset order.

    param img_size: side length the images are resized to (img_size x img_size)
    param downsample_ratio: fraction of the training set to keep
    param download: whether torchvision may download the dataset into "data"
    return: dictionary with keys x_train, y_train, x_test, y_test; the image
            tensors are flattened from dimension 1 onwards
    raises ValueError: if downsample_ratio selects no training samples
    '''
    transform = transforms.Compose([
        # Resize to img_size x img_size
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),  # Convert to tensor
    ])
    train_dataset = datasets.MNIST(root="data",
                                   train=True,
                                   download=download,
                                   transform=transform)
    test_dataset = datasets.MNIST(root="data",
                                  download=download,
                                  train=False,
                                  transform=transform)

    train_samples = int(len(train_dataset) * downsample_ratio)
    if train_samples <= 0:
        # Fail early with a clear message instead of letting the DataLoader
        # reject a non-positive batch size further down.
        raise ValueError(
            f'downsample_ratio={downsample_ratio} selects no training samples')
    train_indices = torch.randperm(len(train_dataset))[:train_samples]
    train_subset = Subset(train_dataset, train_indices)

    # Each loader yields its whole dataset as one batch. The batch sizes are
    # taken from the datasets themselves rather than hard-coded split sizes.
    train_loader = torch.utils.data.DataLoader(dataset=train_subset,
                                               batch_size=len(train_subset),
                                               shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                              batch_size=len(test_dataset),
                                              shuffle=False)

    # Exactly one batch exists per loader, so take it directly instead of
    # looping and keeping whatever the last iteration happened to assign.
    images, labels = next(iter(train_loader))
    images_test, labels_test = next(iter(test_loader))

    return dict(x_train=images.flatten(start_dim=1),
                y_train=labels,
                x_test=images_test.flatten(start_dim=1),
                y_test=labels_test)
def gen_permutation(img_size: int, num_tasks: int) -> torch.Tensor:
    '''
    Draw one random pixel permutation per task.

    param img_size: image size; each permutation covers img_size**2 indices
    param num_tasks: number of tasks (permutations) to generate
    return: tensor of shape (num_tasks, img_size**2) whose rows are
            independent permutations of range(img_size**2)
    '''
    num_pixels = img_size ** 2
    if num_tasks <= 0:
        # torch.stack rejects an empty list; return an empty tensor with the
        # same dtype and trailing dimension instead.
        return torch.empty((0, num_pixels), dtype=torch.int64)
    # One torch.randperm call per task, in order, so the global RNG is
    # consumed exactly as a plain append loop would.
    return torch.stack([torch.randperm(num_pixels) for _ in range(num_tasks)])
if __name__ == '__main__':
    # Command-line options; the defaults give 7x7 images, the full training
    # set, 50_000 permutations and seeds 0 through 9.
    cli = ArgumentParser()
    cli.add_argument(
        '--img_size',
        type=int,
        default=7,
        help='output image size',
    )
    cli.add_argument(
        '--downsample',
        type=float,
        default=1.0,
        help='downsampling proportion of the training set',
    )
    cli.add_argument(
        '-d', '--download',
        action='store_true',
        help='whether to download the dataset',
    )
    cli.add_argument(
        '--num_tasks',
        type=int,
        default=50_000,
        help='number of tasks',
    )
    cli.add_argument(
        '--seed',
        type=int,
        nargs='+',
        default=list(range(10)),
        help='random seed',
    )
    options: Namespace = cli.parse_args()

    print('Generating permuted MNIST data...')
    output_dir = os.path.join('.', 'data', 'MNIST')
    for seed in tqdm(options.seed):
        # set_seed comes from the project's utils module; presumably it seeds
        # the RNGs used below -- verify against utils.
        set_seed(seed)
        # The data is generated before the permutations, so the random draws
        # happen in that order for each seed.
        payload = {
            'data': gen_data(img_size=options.img_size,
                             downsample_ratio=options.downsample,
                             download=options.download),
            'perm': gen_permutation(img_size=options.img_size,
                                    num_tasks=options.num_tasks),
        }
        # One file per seed: ./data/MNIST/<seed>.pt
        torch.save(payload, os.path.join(output_dir, f'{seed}.pt'))