diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..9a7a43f
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+weights/*
+*.pyc
+data
+weights_1
diff --git a/cam.py b/cam.py
index 29982c4..4c243fb 100644
--- a/cam.py
+++ b/cam.py
@@ -1,27 +1,69 @@
 from keras.models import *
 from keras.callbacks import *
+from keras.preprocessing import image
+from keras.preprocessing.image import ImageDataGenerator
 import keras.backend as K
 from model import *
-from data import *
+import numpy as np
+from keras.utils.np_utils import to_categorical
 import cv2
 import argparse
 
+BATCH_SIZE = 32
+NB_EPOCHS = 40
+IMAGE_SIZE = (128, 128)
+
+def get_batches(
+        dirname,
+        gen=image.ImageDataGenerator(),
+        shuffle=False,
+        save_to_dir=None,
+        batch_size=32,
+        class_mode='categorical',
+        target_size=IMAGE_SIZE):
+    return gen.flow_from_directory(
+        dirname,
+        save_to_dir=save_to_dir,
+        target_size=target_size,
+        class_mode=class_mode,
+        shuffle=shuffle,
+        batch_size=batch_size)
+
 def train(dataset_path):
-    model = get_model()
-    X, y = load_inria_person(dataset_path)
-    print "Training.."
-    checkpoint_path="weights.{epoch:02d}-{val_loss:.2f}.hdf5"
+    gen = ImageDataGenerator(
+        rotation_range=15,
+        rescale=1./255,
+        shear_range=0.1,
+        zoom_range=0.1,
+        horizontal_flip=True)
+    train_generator = get_batches(dataset_path+"/train", gen=gen, shuffle=True, batch_size=BATCH_SIZE)
+    # Don't shuffle or augment the validation set
+    valid_generator = get_batches(dataset_path+"/valid", shuffle=False, batch_size=BATCH_SIZE)
+
+    x_train = train_generator.classes
+    x_valid = valid_generator.classes
+    y_train = to_categorical(x_train)
+    nb_classes = len(y_train[0])
+    model = get_model(nb_classes)
+    nb_train_samples = len(x_train)
+    nb_valid_samples = len(x_valid)
+    checkpoint_path="weights/weights.{epoch:02d}-{val_loss:.2f}.hdf5"
     checkpoint = ModelCheckpoint(checkpoint_path, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto')
-    model.fit(X, y, nb_epoch=40, batch_size=32, validation_split=0.2, verbose=1, callbacks=[checkpoint])
+    model.fit_generator(
+        train_generator,
+        nb_train_samples,
+        NB_EPOCHS,
+        validation_data=valid_generator,
+        nb_val_samples=nb_valid_samples,
+        callbacks=[checkpoint])
 
 def visualize_class_activation_map(model_path, img_path, output_path):
     model = load_model(model_path)
-    original_img = cv2.imread(img_path, 1)
+    original_img = cv2.resize(cv2.imread(img_path, 1), IMAGE_SIZE)
     width, height, _ = original_img.shape
 
     #Reshape to the network input shape (3, w, h).
     img = np.array([np.transpose(np.float32(original_img), (2, 0, 1))])
-
     #Get the 512 input weights to the softmax.
     class_weights = model.layers[-1].get_weights()[0]
     final_conv_layer = get_output_layer(model, "conv5_3")
@@ -31,13 +73,17 @@ def visualize_class_activation_map(model_path, img_path, output_path):
     #Create the class activation map.
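+    # CAM for class c is a weighted sum of the conv5_3 feature maps:
+    # cam[x, y] = sum_i class_weights[i, c] * conv_outputs[i, x, y]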
     cam = np.zeros(dtype = np.float32, shape = conv_outputs.shape[1:3])
-    for i, w in enumerate(class_weights[:, 1]):
+
+    class_index = predictions.argmax()
+    print(class_index)
+    for i, w in enumerate(class_weights[:, class_index]):
         cam += w * conv_outputs[i, :, :]
     print "predictions", predictions
     cam /= np.max(cam)
     cam = cv2.resize(cam, (height, width))
     heatmap = cv2.applyColorMap(np.uint8(255*cam), cv2.COLORMAP_JET)
     heatmap[np.where(cam < 0.2)] = 0
+    print(heatmap.shape)
     img = heatmap*0.5 + original_img
     cv2.imwrite(output_path, img)
diff --git a/data.py b/data.py
deleted file mode 100644
index f8b79d6..0000000
--- a/data.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import cv2
-import glob
-import os
-import numpy as np
-from keras.utils.np_utils import to_categorical
-
-def load_inria_person(path):
-    pos_path = os.path.join(path, "pos")
-    neg_path = os.path.join(path, "/neg")
-    pos_images = [cv2.resize(cv2.imread(x), (64, 128)) for x in glob.glob(pos_path + "/*.png")]
-    pos_images = [np.transpose(img, (2, 0, 1)) for img in pos_images]
-    neg_images = [cv2.resize(cv2.imread(x), (64, 128)) for x in glob.glob(neg_path + "/*.png")]
-    neg_images = [np.transpose(img, (2, 0, 1)) for img in neg_images]
-    y = [1] * len(pos_images) + [0] * len(neg_images)
-    y = to_categorical(y, 2)
-    X = np.float32(pos_images + neg_images)
-
-    return X, y
diff --git a/model.py b/model.py
index 42df8d4..f11c904 100644
--- a/model.py
+++ b/model.py
@@ -49,14 +49,14 @@ def VGG16_convolutions():
     model.add(Convolution2D(512, 3, 3, activation='relu', name='conv5_3'))
     return model
 
-def get_model():
+def get_model(nb_classes):
     model = VGG16_convolutions()
 
     model = load_model_weights(model, "vgg16_weights.h5")
 
     model.add(Lambda(global_average_pooling, output_shape=global_average_pooling_shape))
-    model.add(Dense(2, activation = 'softmax', init='uniform'))
+    model.add(Dense(nb_classes, activation = 'softmax', init='uniform'))
     sgd = SGD(lr=0.01, decay=1e-6, momentum=0.5, nesterov=True)
     model.compile(loss = 'categorical_crossentropy', optimizer = sgd, metrics=['accuracy'])
     return model
@@ -80,4 +80,4 @@ def get_output_layer(model, layer_name):
     # get the symbolic outputs of each "key" layer (we gave them unique names).
     layer_dict = dict([(layer.name, layer) for layer in model.layers])
     layer = layer_dict[layer_name]
-    return layer
\ No newline at end of file
+    return layer
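For reference, the rewritten loop in cam.py above is the standard CAM computation, now indexed by the predicted class instead of the hard-coded class 1. A minimal numpy sketch of the same math, with hypothetical names (conv_outputs as the (filters, h, w) activations of conv5_3, class_weights as the post-GAP Dense weight matrix):

import numpy as np

def compute_cam(conv_outputs, class_weights, class_index):
    # Weighted sum over feature maps: cam = sum_i class_weights[i, c] * A_i
    cam = np.zeros(conv_outputs.shape[1:3], dtype=np.float32)
    for i, w in enumerate(class_weights[:, class_index]):
        cam += w * conv_outputs[i, :, :]
    return cam / np.max(cam)  # scale to [0, 1] before applying the colormap

Equivalently, the loop collapses to a single contraction over the filter axis: np.tensordot(class_weights[:, class_index], conv_outputs, axes=1).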