diff --git a/.gitignore b/.gitignore index 9652162f41..1004e9532b 100644 --- a/.gitignore +++ b/.gitignore @@ -7,5 +7,8 @@ *.hdf5 .DS_Store */datasets +*.idea +*.ipynb +*.zip __pycache__ diff --git a/srgan/data_loader.py b/srgan/data_loader.py index 7ec4f141ab..383f16b14c 100644 --- a/srgan/data_loader.py +++ b/srgan/data_loader.py @@ -1,6 +1,7 @@ import scipy from glob import glob import numpy as np +from PIL import Image import matplotlib.pyplot as plt class DataLoader(): @@ -18,13 +19,13 @@ def load_data(self, batch_size=1, is_testing=False): imgs_hr = [] imgs_lr = [] for img_path in batch_images: - img = self.imread(img_path) + img = Image.open(img_path) h, w = self.img_res low_h, low_w = int(h / 4), int(w / 4) - img_hr = scipy.misc.imresize(img, self.img_res) - img_lr = scipy.misc.imresize(img, (low_h, low_w)) + img_hr = np.array(img.resize(self.img_res)) + img_lr = np.array(img.resize((low_h, low_w))) # If training => do random flip if not is_testing and np.random.random() < 0.5: @@ -38,7 +39,3 @@ def load_data(self, batch_size=1, is_testing=False): imgs_lr = np.array(imgs_lr) / 127.5 - 1. return imgs_hr, imgs_lr - - - def imread(self, path): - return scipy.misc.imread(path, mode='RGB').astype(np.float) diff --git a/srgan/srgan.py b/srgan/srgan.py index a372b3edea..ec143f466e 100644 --- a/srgan/srgan.py +++ b/srgan/srgan.py @@ -1,7 +1,11 @@ """ Super-resolution of CelebA using Generative Adversarial Networks. -The dataset can be downloaded from: https://www.dropbox.com/sh/8oqt9vytwxb3s4r/AADIKlz8PR9zr6Y20qbkunrba/Img/img_align_celeba.zip?dl=0 +The dataset can be downloaded from: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html +Instructions to download dataset: +1. Click on the above link, followed by the Google Drive link +2. Click on the Img folder and save the img_align_celeba.zip +3. Create a folder 'datasets/' and unizp the above file to this dataset folder. Instrustion on running the script: 1. Download the dataset from the provided link @@ -103,17 +107,12 @@ def build_vgg(self): Builds a pre-trained VGG19 model that outputs image features extracted at the third block of the model """ - vgg = VGG19(weights="imagenet") + vgg = VGG19(weights="imagenet", input_shape=self.hr_shape, include_top=False) # Set outputs to outputs of last conv. layer in block 3 # See architecture at: https://github.com/keras-team/keras/blob/master/keras/applications/vgg19.py - vgg.outputs = [vgg.layers[9].output] + outputs = vgg.layers[9].output - img = Input(shape=self.hr_shape) - - # Extract image features - img_features = vgg(img) - - return Model(img, img_features) + return Model(vgg.input, outputs) def build_generator(self):