Skip to content

Commit 30a957e

Browse files
committed
merge
1 parent 625332d commit 30a957e

5 files changed

Lines changed: 48 additions & 128 deletions

File tree

6.71 KB
Binary file not shown.

cli.py

Lines changed: 17 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
import click
2-
#import genblack
3-
import genwhite
42
import os
3+
import generateface
54
import tensorflow as tf
65

76
@click.command()
87
@click.argument('age')
98
@click.argument('region')
10-
@click.argument('sex')
11-
def main(age,region,sex):
9+
@click.argument('gender')
10+
def main(age,region,gender):
1211
data_dir='/images'
1312
# Image configuration
1413
IMAGE_HEIGHT = 28
@@ -18,44 +17,21 @@ def main(age,region,sex):
1817
z_dim = 100
1918
learning_rate = 0.0002
2019
beta1 = 0.5
21-
epochs = 20
22-
if region=='black':
23-
if sex == 'female':
24-
data_files = genwhite.glob(os.path.join(data_dir, 'black/female/*.jpg'))
25-
data_files.extend(genwhite.glob('*.png'))
26-
shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, 3
27-
with tf.Session() as sess:
28-
sess.run(tf.global_variables_initializer())
29-
with tf.Graph().as_default():
30-
genwhite.train(epochs, batch_size, z_dim, learning_rate, beta1,shape)
31-
elif sex == 'male':
32-
data_files = genwhite.glob(os.path.join(data_dir, 'black/male/*.jpg'))
33-
data_files.extend(genwhite.glob('*.png'))
34-
shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, 3
35-
with tf.Session() as sess:
36-
sess.run(tf.global_variables_initializer())
37-
with tf.Graph().as_default():
38-
genwhite.train(epochs, batch_size, z_dim, learning_rate, beta1,shape)
39-
40-
elif region=='white':
41-
if sex == 'female':
42-
data_files = genwhite.glob(os.path.join(data_dir, 'white/female/*.jpg'))
43-
data_files.extend(genwhite.glob('*.png'))
44-
shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, 3
45-
with tf.Session() as sess:
46-
sess.run(tf.global_variables_initializer())
47-
with tf.Graph().as_default():
48-
genwhite.train(epochs, batch_size, z_dim, learning_rate, beta1,shape)
49-
elif sex == 'male':
50-
data_files = genwhite.glob(os.path.join(data_dir, 'white/male/*.jpg'))
51-
data_files.extend(genwhite.glob('*.png'))
52-
shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, 3
53-
with tf.Session() as sess:
54-
sess.run(tf.global_variables_initializer())
55-
with tf.Graph().as_default():
56-
genwhite.train(epochs, batch_size, z_dim, learning_rate, beta1,shape)
57-
else:
20+
epochs = 5
21+
genders=['male','female']
22+
regions=['black','white']
23+
if (region in regions) and (gender in genders):
24+
data_files = generateface.glob(os.path.join(data_dir, '%s/%s/*.*' %(region,gender)))
25+
shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, 3
26+
with tf.Session() as sess:
27+
sess.run(tf.global_variables_initializer())
28+
with tf.Graph().as_default():
29+
generateface.train(epochs, batch_size, z_dim, learning_rate, beta1,shape)
30+
else :
5831
click.echo("Enter either black or white as region")
32+
click.echo('Enter either male or female as gender')
33+
click.echo('E.g 24 black female')
34+
5935

6036
if __name__=='__main__':
6137
main()

gen.zip

3.59 KB
Binary file not shown.

genwhite.py renamed to generateface.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,16 @@
55
from PIL import Image
66
import numpy as np
77
import tensorflow as tf
8-
import helper #this file is in the folder...helper.py
8+
import math
99
# import dataset
1010
data_dir='/images'
1111

1212
# Image configuration
1313
IMAGE_HEIGHT = 28
1414
IMAGE_WIDTH = 28
1515
data_files = glob('images/**/*.*', recursive=True)
16-
shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, 3
1716

18-
# print(len(data_files))
17+
shape = len(data_files), IMAGE_WIDTH, IMAGE_HEIGHT, 3
1918

2019
def get_image(image_path, width, height, mode):
2120
"""
@@ -187,6 +186,31 @@ def model_opt(d_loss, g_loss, learning_rate, beta1):
187186

188187
return d_train_opt, g_train_opt
189188

189+
def images_square_grid(images, mode='RGB'):
190+
"""
191+
Helper function to save images as a square grid (visualization)
192+
"""
193+
# Get maximum size for square grid of images
194+
save_size = math.floor(np.sqrt(images.shape[0]))
195+
# Scale to 0-255
196+
images = (((images - images.min()) * 255) / (images.max() - images.min())).astype(np.uint8)
197+
# Put images in a square arrangement
198+
try:
199+
images_in_square = np.reshape(
200+
images[:save_size*save_size],
201+
(save_size, save_size, images.shape[1], images.shape[2], images.shape[3]))
202+
# Combine images to grid image
203+
new_im = Image.new(mode, (images.shape[1] * save_size, images.shape[2] * save_size))
204+
for col_i, col_images in enumerate(images_in_square):
205+
for image_i, image in enumerate(col_images):
206+
im = Image.fromarray(image, mode)
207+
new_im.paste(im, (col_i * images.shape[1], image_i * images.shape[2]))
208+
209+
return new_im
210+
except:
211+
print ('the shape of your images are '+ str(images.shape))
212+
print('check image dimensions')
213+
190214
def show_generator_output(sess, n_images, input_z, out_channel_dim):
191215
"""
192216
Show example output for the generator
@@ -197,8 +221,10 @@ def show_generator_output(sess, n_images, input_z, out_channel_dim):
197221
samples = sess.run(
198222
generator(input_z, out_channel_dim, False),
199223
feed_dict={input_z: example_z})
200-
pyplot.imshow(helper.images_square_grid(samples))
224+
pyplot.imshow(images_square_grid(samples))
201225
pyplot.show()
226+
227+
202228

203229

204230
def train(epoch_count, batch_size, z_dim, learning_rate, beta1, data_shape):
@@ -225,7 +251,7 @@ def train(epoch_count, batch_size, z_dim, learning_rate, beta1, data_shape):
225251
_ = sess.run(d_opt, feed_dict={input_real: batch_images, input_z: batch_z})
226252
_ = sess.run(g_opt, feed_dict={input_real: batch_images, input_z: batch_z})
227253

228-
if steps % 20 == 0:
254+
if steps % 400 == 0:
229255
# At the end of every 10 epochs, get the losses and print them out
230256
train_loss_d = d_loss.eval({input_z: batch_z, input_real: batch_images})
231257
train_loss_g = g_loss.eval({input_z: batch_z})

helper.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

0 commit comments

Comments
 (0)