diff --git a/wgan/README.md b/wgan/README.md new file mode 100644 index 0000000000..e588803975 --- /dev/null +++ b/wgan/README.md @@ -0,0 +1,9 @@ +### For custom dataset +- Put the image dataset inside `Keras-GAN/wgan/dataset/` folder +- Update the `self.img_rows`, `self.img_cols`, `self.channels` value. +- Update the following lines inside `build_generator()` function. +```python +model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim)) +model.add(Reshape((7, 7, 128))) +``` +- Change the `(X_train, _), (_, _) = mnist.load_data()` with `X_train = load_image()` \ No newline at end of file diff --git a/wgan/wgan.py b/wgan/wgan.py index 2ed96ead43..a9d1be8ec0 100644 --- a/wgan/wgan.py +++ b/wgan/wgan.py @@ -1,9 +1,8 @@ from __future__ import print_function, division -from keras.datasets import mnist from keras.layers import Input, Dense, Reshape, Flatten, Dropout from keras.layers import BatchNormalization, Activation, ZeroPadding2D -from keras.layers.advanced_activations import LeakyReLU +from keras.layers import LeakyReLU from keras.layers.convolutional import UpSampling2D, Conv2D from keras.models import Sequential, Model from keras.optimizers import RMSprop @@ -16,11 +15,33 @@ import numpy as np +# For mnist dataset +from keras.datasets import mnist + +# For custom dataset +from pathlib import Path +import cv2 +import glob + +''' +For custom dataset, Put the image dataset inside Keras-GAN/wgan/dataset/ folder. This load_image() function will load the images +''' +def load_image(dirName='dataset'): + path = str(Path().absolute()) + "/" + dirName + "/*.jpg" + + images = [] + for file in glob.glob(path): + img = cv2.imread(file) + img = cv2.resize(img, (228, 228)) + images.append(img) + + return np.array(images) + class WGAN(): def __init__(self): self.img_rows = 28 self.img_cols = 28 - self.channels = 1 + self.channels = 1 # For RGB image, channel = 3 self.img_shape = (self.img_rows, self.img_cols, self.channels) self.latent_dim = 100