diff --git a/README.md b/README.md
index d75be7f..77baa3c 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,9 @@
+---
+noteId: "708531c0f33b11ec8e08f3de8bf47f07"
+tags: []
+
+---
+
 # SimCLR
 PyTorch implementation of SimCLR: A Simple Framework for Contrastive Learning of Visual Representations by T. Chen et al.
 Including support for:
diff --git a/config/config.yaml b/config/config.yaml
index 2c79b24..2c09ac9 100644
--- a/config/config.yaml
+++ b/config/config.yaml
@@ -4,15 +4,15 @@ gpus: 1 # I recommend always assigning 1 GPU to 1 node
 nr: 0 # machine nr. in node (0 -- nodes - 1)
 dataparallel: 0 # Use DataParallel instead of DistributedDataParallel
 workers: 8
-dataset_dir: "./datasets"
+dataset_dir: "/Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/all/train"
 
 # train options
 seed: 42 # sacred handles automatic seeding when passed in the config
 batch_size: 128
-image_size: 224
+image_size: [108, 124]
 start_epoch: 0
 epochs: 100
-dataset: "CIFAR10" # STL10
+dataset: "TFIW" #"CIFAR10" # STL10
 pretrain: True
 
 # model options
diff --git a/linear_evaluation.py b/linear_evaluation.py
index a98038f..2ee9a14 100644
--- a/linear_evaluation.py
+++ b/linear_evaluation.py
@@ -8,6 +8,7 @@ from simclr import SimCLR
 from simclr.modules import LogisticRegression, get_resnet
 from simclr.modules.transformations import TransformsSimCLR
+from simclr.modules.tfiwDataset import TFIWDataset
 
 from utils import yaml_config_hook
 
 
@@ -138,12 +139,26 @@ def test(args, loader, simclr_model, model, criterion, optimizer):
             download=True,
             transform=TransformsSimCLR(size=args.image_size).test_transform,
         )
+
         test_dataset = torchvision.datasets.CIFAR10(
             args.dataset_dir,
             train=False,
             download=True,
             transform=TransformsSimCLR(size=args.image_size).test_transform,
         )
+
+    elif args.dataset == "TFIW":
+        train_dataset = TFIWDataset(
+            args.dataset_dir, #enter /Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/train
+            transform = TransformsSimCLR(size=args.image_size).test_transform,
+        )
+
+        test_dataset = TFIWDataset(
+            #args.dataset_dir,
+            "/Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/val",
+            transform = TransformsSimCLR(size=args.image_size).test_transform,
+        )
+
     else:
         raise NotImplementedError
 
@@ -174,7 +189,8 @@ def test(args, loader, simclr_model, model, criterion, optimizer):
     simclr_model.eval()
 
     ## Logistic Regression
-    n_classes = 10  # CIFAR-10 / STL-10
+    #n_classes = 10 # CIFAR-10 / STL-10
+    n_classes = 571 #TFIW has 571 families in the training dataset
     model = LogisticRegression(simclr_model.n_features, n_classes)
     model = model.to(args.device)
 
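A note on the linear-evaluation hunk above: LogisticRegression is built with n_classes = 571, which assumes class labels in the range 0..570, while TFIWDataset (added later in this diff) returns raw family ids parsed from the file names; those ids are not guaranteed to be zero-based or contiguous. A minimal sketch of a remapping wrapper, hypothetical and not part of the diff, that turns the ids into valid cross-entropy targets:

from torch.utils.data import Dataset

class RemappedLabels(Dataset):
    # Hypothetical helper: maps raw TFIW family ids to contiguous 0-based class indices.
    def __init__(self, base_dataset):
        self.base = base_dataset
        self.id_to_idx = {fid: idx for idx, fid in enumerate(sorted(set(base_dataset.labels)))}

    def __getitem__(self, i):
        image, family_id = self.base[i]
        return image, self.id_to_idx[family_id]

    def __len__(self):
        return len(self.base)

With such a wrapper, n_classes can be taken as len(wrapped_dataset.id_to_idx) instead of being hard-coded to 571.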
diff --git a/main.py b/main.py
index fec4778..040e4af 100644
--- a/main.py
+++ b/main.py
@@ -19,16 +19,23 @@ from simclr.modules.transformations import TransformsSimCLR
 from simclr.modules.sync_batchnorm import convert_model
 
+from simclr.modules.tfiwDataset import TFIWDataset
+from net import LResNet50E_IR, LResNet
+
 from model import load_optimizer, save_model
 from utils import yaml_config_hook
 
 
 def train(args, train_loader, model, criterion, optimizer, writer):
     loss_epoch = 0
+    print(enumerate(train_loader))
     for step, ((x_i, x_j), _) in enumerate(train_loader):
+        #for step, (x_i, x_j) in enumerate(train_loader):
+        #print(x_i)
+        #print(x_j)
         optimizer.zero_grad()
-        x_i = x_i.cuda(non_blocking=True)
-        x_j = x_j.cuda(non_blocking=True)
+        #x_i = x_i.cuda(non_blocking=True)
+        #x_j = x_j.cuda(non_blocking=True)
 
         # positive pair, with encoding
         h_i, h_j, z_i, z_j = model(x_i, x_j)
@@ -76,6 +83,13 @@ def main(gpu, args):
             download=True,
             transform=TransformsSimCLR(size=args.image_size),
         )
+
+    elif args.dataset == "TFIW":
+        train_dataset = TFIWDataset(
+            args.dataset_dir, #enter /Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/all/train
+            transform = TransformsSimCLR(size=args.image_size),
+        )
+
     else:
         raise NotImplementedError
 
@@ -96,8 +110,10 @@ def main(gpu, args):
     )
 
     # initialize ResNet
-    encoder = get_resnet(args.resnet, pretrained=False)
-    n_features = encoder.fc.in_features  # get dimensions of fc layer
+    #encoder = get_resnet(args.resnet, pretrained=False)
+    encoder = LResNet50E_IR(is_gray=False)
+    n_features = 256 #encoder.fc.in_features # get dimensions of fc layer
+    print(encoder)
 
     # initialize model
     model = SimCLR(encoder, args.projection_dim, n_features)
diff --git a/main_pl.py b/main_pl.py
index b7da1cc..66c1d37 100644
--- a/main_pl.py
+++ b/main_pl.py
@@ -11,6 +11,7 @@ from simclr.modules import NT_Xent, get_resnet
 from simclr.modules.transformations import TransformsSimCLR
 from simclr.modules.sync_batchnorm import convert_model
+from simclr.modules.tfiwDataset import TFIWDataset
 
 from utils import yaml_config_hook
 
 
@@ -76,7 +77,7 @@ def configure_optimizers(self):
 
     parser = argparse.ArgumentParser(description="SimCLR")
 
-    config = yaml_config_hook("./config/config.yaml")
+    config = yaml_config_hook("/Users/gaurav/Desktop/thesis-work/contrastive/SimCLR-Faces/config/config.yaml")
     for k, v in config.items():
         parser.add_argument(f"--{k}", default=v, type=type(v))
 
@@ -95,6 +96,13 @@ def configure_optimizers(self):
             download=True,
             transform=TransformsSimCLR(size=args.image_size),
         )
+
+    elif args.dataset == "TFIW":
+        train_dataset = TFIWDataset(
+            args.dataset_dir, #enter /Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/train
+            transform = TransformsSimCLR(size=args.image_size),
+        )
+
     else:
         raise NotImplementedError
 
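A note on n_features in the main.py hunk above: it is hard-coded to 256, while the LResNet50E_IR defined in net.py below ends in a 512-dimensional linear head; and if the SimCLR wrapper replaces encoder.fc with nn.Identity(), as the upstream implementation of this repository does, the encoder actually emits the flattened feature map rather than that 512-d embedding. A sketch, under those assumptions plus the 112x96 input size documented in net.py, that measures the value instead of guessing it:

import torch
import torch.nn as nn
from net import LResNet50E_IR

encoder = LResNet50E_IR(is_gray=False)
encoder.fc = nn.Identity()                   # mimic what the SimCLR wrapper does to encoder.fc
encoder.eval()
with torch.no_grad():
    h = encoder(torch.zeros(1, 3, 112, 96))  # 112x96 is the input size net.py assumes
print(h.shape[1])                            # pass this value as n_features to SimCLR(...)

Under those assumptions the probe prints 21504 (512 * 7 * 6), so neither 256 nor 512 matches once fc has been replaced.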
diff --git a/net.py b/net.py
new file mode 100644
index 0000000..9cc8f45
--- /dev/null
+++ b/net.py
@@ -0,0 +1,182 @@
+import torch
+import torch.nn as nn
+
+
+
+# -------------------------------------- sphere network Begin --------------------------------------
+class Block(nn.Module):
+    def __init__(self, planes):
+        super(Block, self).__init__()
+        self.conv1 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.prelu1 = nn.PReLU(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.prelu2 = nn.PReLU(planes)
+
+    def forward(self, x):
+        return x + self.prelu2(self.conv2(self.prelu1(self.conv1(x))))
+
+
+class sphere(nn.Module):
+    def __init__(self, type=20, is_gray=False):
+        super(sphere, self).__init__()
+        block = Block
+        if type == 20:
+            layers = [1, 2, 4, 1]
+        elif type == 64:
+            layers = [3, 7, 16, 3]
+        else:
+            raise ValueError('sphere' + str(type) + " IS NOT SUPPORTED! (sphere20 or sphere64)")
+        filter_list = [3, 64, 128, 256, 512]
+        if is_gray:
+            filter_list[0] = 1
+
+        self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
+        self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
+        self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
+        self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)
+        self.fc = nn.Linear(512 * 7 * 6, 512)
+
+        # Weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                if m.bias is not None:
+                    nn.init.xavier_uniform_(m.weight)
+                    nn.init.constant_(m.bias, 0.0)
+                else:
+                    nn.init.normal_(m.weight, 0, 0.01)
+
+
+    def _make_layer(self, block, inplanes, planes, blocks, stride):
+        layers = []
+        layers.append(nn.Conv2d(inplanes, planes, 3, stride, 1))
+        layers.append(nn.PReLU(planes))
+        for i in range(blocks):
+            layers.append(block(planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+    def save(self, file_path):
+        with open(file_path, 'wb') as f:
+            torch.save(self.state_dict(), f)
+
+
+# -------------------------------------- sphere network END --------------------------------------
+
+# ---------------------------------- LResNet50E-IR network Begin ----------------------------------
+
+class BlockIR(nn.Module):
+    def __init__(self, inplanes, planes, stride, dim_match):
+        super(BlockIR, self).__init__()
+        self.bn1 = nn.BatchNorm2d(inplanes)
+        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.prelu1 = nn.PReLU(planes)
+        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
+        self.bn3 = nn.BatchNorm2d(planes)
+
+        if dim_match:
+            self.downsample = None
+        else:
+            self.downsample = nn.Sequential(
+                nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False),
+                nn.BatchNorm2d(planes),
+            )
+
+    def forward(self, x):
+        residual = x
+
+        out = self.bn1(x)
+        out = self.conv1(out)
+        out = self.bn2(out)
+        out = self.prelu1(out)
+        out = self.conv2(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            residual = self.downsample(x)
+
+        out += residual
+
+        return out
+
+
+class LResNet(nn.Module):
+
+    def __init__(self, block, layers, filter_list, is_gray=False):
+        self.inplanes = 64
+        super(LResNet, self).__init__()
+        # input is (mini-batch,3 or 1,112,96)
+        # use (conv3x3, stride=1, padding=1) instead of (conv7x7, stride=2, padding=3)
+        if is_gray:
+            self.conv1 = nn.Conv2d(1, filter_list[0], kernel_size=3, stride=1, padding=1, bias=False)  # gray
+        else:
+            self.conv1 = nn.Conv2d(3, filter_list[0], kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(filter_list[0])
+        self.prelu1 = nn.PReLU(filter_list[0])
+        self.layer1 = self._make_layer(block, filter_list[0], filter_list[1], layers[0], stride=2)
+        self.layer2 = self._make_layer(block, filter_list[1], filter_list[2], layers[1], stride=2)
+        self.layer3 = self._make_layer(block, filter_list[2], filter_list[3], layers[2], stride=2)
+        self.layer4 = self._make_layer(block, filter_list[3], filter_list[4], layers[3], stride=2)
+        self.fc = nn.Sequential(
+            nn.BatchNorm1d(filter_list[4] * 7 * 6),
+            nn.Dropout(p=0.4),
+            nn.Linear(filter_list[4] * 7 * 6, 512),
+            nn.BatchNorm1d(512),  # fix gamma ???
+        )
+
+        # Weight initialization
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+                nn.init.xavier_uniform_(m.weight)
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0.0)
+            elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+
+    def _make_layer(self, block, inplanes, planes, blocks, stride):
+        layers = []
+        layers.append(block(inplanes, planes, stride, False))
+        for i in range(1, blocks):
+            layers.append(block(planes, planes, stride=1, dim_match=True))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.prelu1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+
+        return x
+
+    def save(self, file_path):
+        with open(file_path, 'wb') as f:
+            torch.save(self.state_dict(), f)
+
+
+def LResNet50E_IR(is_gray=False):
+    print("Using LResNet50E from ArcFace")
+    filter_list = [64, 64, 128, 256, 512]
+    layers = [3, 4, 14, 3]
+    return LResNet(BlockIR, layers, filter_list, is_gray)
+# ---------------------------------- LResNet50E-IR network End ----------------------------------
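A note on net.py above: the LResNet head flattens a fixed 512 x 7 x 6 feature map, i.e. the four stride-2 stages assume 112 x 96 inputs (downsampling by 16), while config.yaml sets image_size to [108, 124]; at that size the flatten/fc shapes no longer line up. One way to decouple the trunk from a fixed crop size is global average pooling. This is an illustration only, not the author's code:

import torch
import torch.nn as nn
from net import LResNet50E_IR

class PooledLResNet(nn.Module):
    # Illustrative wrapper: LResNet trunk + adaptive pooling, so any crop size works.
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.trunk = LResNet50E_IR(is_gray=False)
        del self.trunk.fc                         # drop the size-dependent head
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, embedding_dim)   # SimCLR can still swap this fc for an Identity

    def forward(self, x):
        t = self.trunk
        x = t.prelu1(t.bn1(t.conv1(x)))
        x = t.layer4(t.layer3(t.layer2(t.layer1(x))))
        x = self.pool(x).flatten(1)               # (N, 512) regardless of the input size
        return self.fc(x)

encoder = PooledLResNet()
encoder.eval()
with torch.no_grad():
    print(encoder(torch.zeros(1, 3, 108, 124)).shape)  # torch.Size([1, 512])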
diff --git a/simclr/modules/__init__.py b/simclr/modules/__init__.py
index 56b4ca9..73ecee4 100644
--- a/simclr/modules/__init__.py
+++ b/simclr/modules/__init__.py
@@ -3,3 +3,4 @@
 from .lars import LARS
 from .resnet import get_resnet
 from .gather import GatherLayer
+#from simclr.modules.tfiwDataset import TFIWDataset
diff --git a/simclr/modules/tfiwDataset.py b/simclr/modules/tfiwDataset.py
new file mode 100644
index 0000000..1f0dcbc
--- /dev/null
+++ b/simclr/modules/tfiwDataset.py
@@ -0,0 +1,42 @@
+from torch.utils.data import Dataset
+from PIL import Image
+import os
+import pandas as pd
+
+
+class TFIWDataset(Dataset):
+    def __init__(self, img_dir=os.getcwd(), transform=None):
+        self.img_dir = img_dir
+        self.transform = transform
+
+        self.img_names = os.listdir(img_dir)
+
+        # Keep only .jpg files so that stray entries such as .DS_Store are dropped.
+        # The family id is encoded in the file name, e.g. F0123_... -> label 123.
+        file_names = []
+        labels = []
+        for i in self.img_names:
+            if i[-3:] == 'jpg':
+                file_names.append(i)
+                labels.append(int(i[1:5]))
+        self.labels = labels
+        self.img_names = file_names
+
+        # Dump the (file name, label) pairs to a CSV for manual inspection.
+        img_names_csv = pd.DataFrame(data=[file_names, self.labels])
+        img_names_csv.T.to_csv("/Users/gaurav/Desktop/data.csv")
+
+    def __getitem__(self, idx):
+        # Image.open raises an error for unreadable files rather than returning None,
+        # so the image can be transformed and returned directly.
+        image = Image.open(os.path.join(self.img_dir, self.img_names[idx]))
+        if self.transform is not None:
+            image = self.transform(image)
+        return image, self.labels[idx]
+
+    def __len__(self):
+        return len(self.img_names)
+
+#tfiw = TFIWDataset(img_dir='/Users/gaurav/Desktop/thesis-work/Datasets/T-1/train-faces/all')
+#example = tfiw[7]
+#print(example)
\ No newline at end of file
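Finally, a short sketch of how the added pieces fit together for pretraining, mirroring the loop in main.py; the directory is a placeholder and the batch size follows config.yaml:

from torch.utils.data import DataLoader
from simclr.modules.transformations import TransformsSimCLR
from simclr.modules.tfiwDataset import TFIWDataset

train_dataset = TFIWDataset(
    "/path/to/T-1/train-faces/all/train",         # placeholder: directory of aligned face crops
    transform=TransformsSimCLR(size=(108, 124)),  # the SimCLR transform returns an augmented pair (x_i, x_j)
)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, drop_last=True)

# Each batch unpacks exactly like the training loop in main.py: ((x_i, x_j), family_id)
(x_i, x_j), labels = next(iter(train_loader))
print(x_i.shape, x_j.shape, labels.shape)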