diff --git a/picasso/examples/keras-vgg16/config.py b/picasso/examples/keras-vgg16/config.py index 7ab3fa5..e6590c2 100644 --- a/picasso/examples/keras-vgg16/config.py +++ b/picasso/examples/keras-vgg16/config.py @@ -1,12 +1,13 @@ +# Note: By default, Flask doesn't know that this file exists. If you want +# Flask to load the settings you specify here, you must set the environment +# variable `PICASSO_SETTINGS` to point to this file. E.g.: +# +# export PICASSO_SETTINGS=/path/to/examples/keras-vgg16/config.py +# import os base_dir = os.path.dirname(os.path.abspath(__file__)) -BACKEND_ML = 'keras' -BACKEND_PREPROCESSOR_NAME = 'preprocess' -BACKEND_PREPROCESSOR_PATH = os.path.join(base_dir, 'util.py') -BACKEND_POSTPROCESSOR_NAME = 'postprocess' -BACKEND_POSTPROCESSOR_PATH = os.path.join(base_dir, 'util.py') -BACKEND_PROB_DECODER_NAME = 'prob_decode' -BACKEND_PROB_DECODER_PATH = os.path.join(base_dir, 'util.py') +MODEL_CLS_PATH = os.path.join(base_dir, 'model.py') +MODEL_CLS_NAME = 'KerasVGG16Model' DATA_DIR = os.path.join(base_dir, 'data-volume') diff --git a/picasso/examples/keras-vgg16/model.py b/picasso/examples/keras-vgg16/model.py new file mode 100644 index 0000000..f7556a3 --- /dev/null +++ b/picasso/examples/keras-vgg16/model.py @@ -0,0 +1,44 @@ +from keras.applications.imagenet_utils import (decode_predictions, + preprocess_input) +import keras.applications.imagenet_utils +import numpy as np +from PIL import Image + +from picasso.ml_frameworks.keras.model import KerasModel + + +VGG16_DIM = (224, 224, 3) + + +class KerasVGG16Model(KerasModel): + + def preprocess(self, targets): + image_arrays = [] + for target in targets: + im = target.resize(VGG16_DIM[:2], Image.ANTIALIAS) + im = im.convert('RGB') + arr = np.array(im).astype('float32') + image_arrays.append(arr) + + all_targets = np.array(image_arrays) + return preprocess_input(all_targets) + + def decode_prob(self, probability_array): + r = decode_predictions(probability_array, top=self.top_probs) + results = [ + [{'code': entry[0], + 'name': entry[1], + 'prob': '{:.3f}'.format(entry[2])} + for entry in row] + for row in r + ] + classes = keras.applications.imagenet_utils.CLASS_INDEX + class_keys = list(classes.keys()) + class_values = list(classes.values()) + + for result in results: + for entry in result: + entry['index'] = int( + class_keys[class_values.index([entry['code'], + entry['name']])]) + return results diff --git a/picasso/examples/keras-vgg16/util.py b/picasso/examples/keras-vgg16/util.py deleted file mode 100644 index e56eaa0..0000000 --- a/picasso/examples/keras-vgg16/util.py +++ /dev/null @@ -1,53 +0,0 @@ -from keras.applications.imagenet_utils import (decode_predictions, - preprocess_input) -import keras.applications.imagenet_utils -from PIL import Image -import numpy as np - -VGG16_DIM = (224, 224, 3) - - -def preprocess(targets): - image_arrays = [] - for target in targets: - im = target.resize(VGG16_DIM[:2], Image.ANTIALIAS) - im = im.convert('RGB') - arr = np.array(im).astype('float32') - image_arrays.append(arr) - - all_targets = np.array(image_arrays) - return preprocess_input(all_targets) - - -def postprocess(output_arr): - images = [] - for row in output_arr: - im_array = row.reshape(VGG16_DIM[:2]) - images.append(im_array) - - return images - - -def prob_decode(probability_array, top=5): - r = decode_predictions(probability_array, top=top) - results = [ - [{'code': entry[0], - 'name': entry[1], - 'prob': '{:.3f}'.format(entry[2])} - for entry in row] - for row in r - ] - classes = keras.applications.imagenet_utils.CLASS_INDEX - class_keys = list(classes.keys()) - class_values = list(classes.values()) - - for result in results: - for entry in result: - entry.update( - {'index': - int( - class_keys[class_values.index([entry['code'], - entry['name']])] - )} - ) - return results diff --git a/picasso/examples/keras/config.py b/picasso/examples/keras/config.py index 0ecdf76..0d8d46e 100644 --- a/picasso/examples/keras/config.py +++ b/picasso/examples/keras/config.py @@ -1,19 +1,13 @@ -# Note: this settings file duplicates the default settings in the top-level -# file `settings.py`. If you want to modify settings here, you must export the -# path to this file: +# Note: By default, Flask doesn't know that this file exists. If you want +# Flask to load the settings you specify here, you must set the environment +# variable `PICASSO_SETTINGS` to point to this file. E.g.: # -# export PICASSO_SETTINGS=/path/to/picasso/picasso/examples/keras/config.py +# export PICASSO_SETTINGS=/path/to/examples/keras/config.py # -# otherwise, these settings will not be loaded. import os base_dir = os.path.dirname(os.path.abspath(__file__)) -BACKEND_ML = 'keras' -BACKEND_PREPROCESSOR_NAME = 'preprocess' -BACKEND_PREPROCESSOR_PATH = os.path.join(base_dir, 'util.py') -BACKEND_POSTPROCESSOR_NAME = 'postprocess' -BACKEND_POSTPROCESSOR_PATH = os.path.join(base_dir, 'util.py') -BACKEND_PROB_DECODER_NAME = 'prob_decode' -BACKEND_PROB_DECODER_PATH = os.path.join(base_dir, 'util.py') +MODEL_CLS_PATH = os.path.join(base_dir, 'model.py') +MODEL_CLS_NAME = 'KerasMNISTModel' DATA_DIR = os.path.join(base_dir, 'data-volume') diff --git a/picasso/examples/keras/model.py b/picasso/examples/keras/model.py new file mode 100644 index 0000000..c46bdae --- /dev/null +++ b/picasso/examples/keras/model.py @@ -0,0 +1,30 @@ +import numpy as np +from PIL import Image + +from picasso.ml_frameworks.keras.model import KerasModel + + +MNIST_DIM = (28, 28) + + +class KerasMNISTModel(KerasModel): + + def preprocess(self, raw_inputs): + """Convert images into the format required by our model. + + Our model requires that inputs be grayscale (mode 'L'), be resized to + `MNIST_DIM`, and be represented as float32 numpy arrays in range + [0, 1]. + + """ + image_arrays = [] + for raw_im in raw_inputs: + im = raw_im.convert('L') + im = im.resize(MNIST_DIM, Image.ANTIALIAS) + arr = np.array(im) + image_arrays.append(arr) + + inputs = np.array(image_arrays) + return inputs.reshape(len(inputs), + MNIST_DIM[0], + MNIST_DIM[1], 1).astype('float32') / 255 diff --git a/picasso/examples/keras/util.py b/picasso/examples/keras/util.py deleted file mode 100644 index 6f6202e..0000000 --- a/picasso/examples/keras/util.py +++ /dev/null @@ -1,90 +0,0 @@ -from PIL import Image -from operator import itemgetter -import numpy as np - -MNIST_DIM = (28, 28) - - -def preprocess(targets): - """Turn images into computation inputs - - Converts an iterable of PIL Images into a suitably-sized numpy array which - can be used as an input to the evaluation portion of the Keras/tensorflow - graph. - - Args: - targets (list of Images): a list of PIL Image objects - - Returns: - array (float32) - - """ - image_arrays = [] - for target in targets: - im = target.convert('L') - im = im.resize(MNIST_DIM, Image.ANTIALIAS) - arr = np.array(im) - image_arrays.append(arr) - - all_targets = np.array(image_arrays) - return all_targets.reshape(len(all_targets), - MNIST_DIM[0], - MNIST_DIM[1], 1).astype('float32') / 255 - - -def postprocess(output_arr): - """Reshape arrays to original image dimensions - - Typically used for outputs or computations on intermediate layers which - make sense to represent as an image in the original dimension of the input - images (see ``SaliencyMaps``). - - Args: - output_arr (array of float32): Array of leading dimension n containing - n arrays to be reshaped - - Returns: - reshaped array - - """ - images = [] - for row in output_arr: - im_array = row.reshape(MNIST_DIM) - images.append(im_array) - - return images - - -def prob_decode(probability_array, top=5): - """Provide class information from output probabilities - - Gives the visualization additional context for the computed class - probabilities. - - Args: - probability_array (array): class probabilities - top (int): number of class entries to return. Useful for limiting - output in models with many classes. Defaults to 5. - - Returns: - result list of dict in the format [{'index': class_index, 'name': - class_name, 'prob': class_probability}, ...] - - """ - results = [] - for row in probability_array: - entries = [] - for i, prob in enumerate(row): - entries.append({'index': i, - 'name': str(i), - 'prob': prob}) - - entries = sorted(entries, - key=itemgetter('prob'), - reverse=True)[:top] - - for entry in entries: - entry['prob'] = '{:.3f}'.format(entry['prob']) - results.append(entries) - - return results diff --git a/picasso/examples/tensorflow/config.py b/picasso/examples/tensorflow/config.py index 60de7b0..1e8ef20 100644 --- a/picasso/examples/tensorflow/config.py +++ b/picasso/examples/tensorflow/config.py @@ -1,14 +1,13 @@ +# Note: By default, Flask doesn't know that this file exists. If you want +# Flask to load the settings you specify here, you must set the environment +# variable `PICASSO_SETTINGS` to point to this file. E.g.: +# +# export PICASSO_SETTINGS=/path/to/examples/tensorflow/config.py +# import os base_dir = os.path.dirname(os.path.abspath(__file__)) -BACKEND_ML = 'tensorflow' -BACKEND_PREPROCESSOR_NAME = 'preprocess' -BACKEND_PREPROCESSOR_PATH = os.path.join(base_dir, 'util.py') -BACKEND_POSTPROCESSOR_NAME = 'postprocess' -BACKEND_POSTPROCESSOR_PATH = os.path.join(base_dir, 'util.py') -BACKEND_PROB_DECODER_NAME = 'prob_decode' -BACKEND_PROB_DECODER_PATH = os.path.join(base_dir, 'util.py') -BACKEND_TF_PREDICT_VAR = 'Softmax:0' -BACKEND_TF_INPUT_VAR = 'convolution2d_input_1:0' +MODEL_CLS_PATH = os.path.join(base_dir, 'model.py') +MODEL_CLS_NAME = 'TensorflowMNISTModel' DATA_DIR = os.path.join(base_dir, 'data-volume') diff --git a/picasso/examples/tensorflow/model.py b/picasso/examples/tensorflow/model.py new file mode 100644 index 0000000..260444f --- /dev/null +++ b/picasso/examples/tensorflow/model.py @@ -0,0 +1,34 @@ +import numpy as np +from PIL import Image + +from picasso.ml_frameworks.tensorflow.model import TFModel + + +MNIST_DIM = (28, 28) + + +class TensorflowMNISTModel(TFModel): + + TF_INPUT_VAR = 'convolution2d_input_1:0' + + TF_PREDICT_VAR = 'Softmax:0' + + def preprocess(self, raw_inputs): + """Convert images into the format required by our model. + + Our model requires that inputs be grayscale (mode 'L'), be resized to + `MNIST_DIM`, and be represented as float32 numpy arrays in range + [0, 1]. + + """ + image_arrays = [] + for raw_im in raw_inputs: + im = raw_im.convert('L') + im = im.resize(MNIST_DIM, Image.ANTIALIAS) + arr = np.array(im) + image_arrays.append(arr) + + inputs = np.array(image_arrays) + return inputs.reshape(len(inputs), + MNIST_DIM[0], + MNIST_DIM[1], 1).astype('float32') / 255 diff --git a/picasso/examples/tensorflow/util.py b/picasso/examples/tensorflow/util.py deleted file mode 100644 index 1820184..0000000 --- a/picasso/examples/tensorflow/util.py +++ /dev/null @@ -1,48 +0,0 @@ -from PIL import Image -from operator import itemgetter -import numpy as np - -MNIST_DIM = (28, 28) - - -def preprocess(targets): - image_arrays = [] - for target in targets: - im = target.convert('L') - im = im.resize(MNIST_DIM, Image.ANTIALIAS) - arr = np.array(im) - image_arrays.append(arr) - - all_targets = np.array(image_arrays) - return all_targets.reshape(len(all_targets), - MNIST_DIM[0], - MNIST_DIM[1], 1).astype('float32') / 255 - - -def postprocess(output_arr): - images = [] - for row in output_arr: - im_array = row.reshape(MNIST_DIM) - images.append(im_array) - - return images - - -def prob_decode(probability_array, top=5): - results = [] - for row in probability_array: - entries = [] - for i, prob in enumerate(row): - entries.append({'index': i, - 'name': str(i), - 'prob': prob}) - - entries = sorted(entries, - key=itemgetter('prob'), - reverse=True)[:top] - - for entry in entries: - entry['prob'] = '{:.3f}'.format(entry['prob']) - results.append(entries) - - return results diff --git a/picasso/ml_frameworks/keras/model.py b/picasso/ml_frameworks/keras/model.py index 5b6715c..09c24ef 100644 --- a/picasso/ml_frameworks/keras/model.py +++ b/picasso/ml_frameworks/keras/model.py @@ -6,27 +6,19 @@ import keras.backend as K from keras.models import model_from_json -from picasso.ml_frameworks.tensorflow.model import TFModel +from picasso.ml_frameworks.model import BaseModel -class KerasModel(TFModel): - """Implements model loading functions for Keras +class KerasModel(BaseModel): + """Implements model loading functions for Keras. - Using this Keras module will require the h5py library, - which is not included with Keras - - Attributes: - sess (Tensorflow :obj:`Session`): underlying Tensorflow session of - the Keras model. - tf_predict_var (:obj:`Tensor`): tensorflow tensor which represents - the class probabilities - tf_input_var (:obj:`Tensor`): tensorflow tensor which represents - the inputs + Using this Keras module will require the h5py library, which is not + included with Keras. """ - def load(self, data_dir='./'): - """Load graph and weight data + def _load(self, data_dir): + """Load graph and weight data. Args: data_dir (:obj:`str`): location of Keras checkpoint (`.hdf5`) files @@ -34,21 +26,17 @@ def load(self, data_dir='./'): is to take the latest of each, by OS timestamp. """ - # find newest ckpt and graph files try: latest_ckpt = max(glob.iglob( - os.path.join(data_dir, '*.h*5')), - key=os.path.getctime) - - self.latest_ckpt_name = os.path.basename(latest_ckpt) - self.latest_ckpt_time = str(datetime.fromtimestamp( - os.path.getmtime(latest_ckpt)) - ) - + os.path.join(data_dir, '*.h*5')), key=os.path.getctime) + self._latest_ckpt_name = os.path.basename(latest_ckpt) + self._latest_ckpt_time = str( + datetime.fromtimestamp(os.path.getmtime(latest_ckpt))) except ValueError: raise FileNotFoundError('No checkpoint (.hdf5 or .h5) files ' 'available at {}'.format(data_dir)) + try: latest_json = max(glob.iglob(os.path.join(data_dir, '*.json')), key=os.path.getctime) @@ -60,13 +48,21 @@ def load(self, data_dir='./'): K.set_learning_phase(0) with open(latest_json, 'r') as f: model_json = json.loads(f.read()) - self.model = model_from_json(model_json) + self._model = model_from_json(model_json) + + self._model.load_weights(latest_ckpt) + self._sess = K.get_session() - self.model.load_weights(latest_ckpt) - self.sess = K.get_session() + self._tf_predict_var = self._model.outputs[0] + self._tf_input_var = self._model.inputs[0] - self.tf_predict_var = self.model.outputs[0] - self.tf_input_var = self.model.inputs[0] + @property + def description(self): + return "%s loaded from %s (name: %s, timestamp: %s)" % ( + type(self).__name__, + self._data_dir, + self._latest_ckpt_name, + self._latest_ckpt_time) - def _predict(self, input_array): - return self.model.predict(input_array) + def predict(self, input_array): + return self._model.predict(input_array) diff --git a/picasso/ml_frameworks/model.py b/picasso/ml_frameworks/model.py index 4d687b1..d3eb7b1 100644 --- a/picasso/ml_frameworks/model.py +++ b/picasso/ml_frameworks/model.py @@ -1,237 +1,189 @@ -import importlib.util -import warnings -from importlib import import_module +import importlib from operator import itemgetter +import warnings -ML_LIBRARIES = { - 'tensorflow': - 'picasso.ml_frameworks.tensorflow.model.TFModel', - 'keras': - 'picasso.ml_frameworks.keras.model.KerasModel' -} +def get_model(model_cls_path, model_cls_name, data_dir, **kwargs): + """Get an instance of the described model. -class Model: - """Model class interface. + Args: + model_cls_path: Path to the module in which the model class + is defined. + model_cls_name: Name of the model class. + data_dir: Directory containing the graph and weights. + kwargs: Arbitrary keyword arguments passed to the model's + constructor. - All ML frameworks should derive from this class for the purposes of - the visualization. This class loads saved files generated by various - ML frameworks and allows us to extract the graph topology, weights, etc. + Returns: + An instance of :class:`.ml_frameworks.model.BaseModel` or subclass + + """ + spec = importlib.util.spec_from_file_location('active_model', + model_cls_path) + model_module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(model_module) + model_cls = getattr(model_module, model_cls_name) + model = model_cls(data_dir, **kwargs) + if not isinstance(model, BaseModel): + warnings.warn("Loaded model '%s' at '%s' is not an instance of %r" + % (model_cls_name, model_cls_path, BaseModel)) + return model + + +class BaseModel: + """Interface encapsulating a trained NN model usable for prediction. + + This interface defines: + + - How to load the model's topology and parameters from disk + - How to preprocess a batch of examples for the model + - How to perform prediction using the model + - Etc """ def __init__(self, - preprocessor_name='preprocess', - preprocessor_path=None, - postprocessor_name='postprocess', - postprocessor_path=None, - prob_decoder_name='prob_decode', - prob_decoder_path=None, + data_dir, top_probs=5, **kwargs): - """Attempt to load utilities - - The class constructor attempts to import a preprocessor, postprocessor, - and probability decoder if a path is supplied. + """Create a new instance of this model. + + `BaseModel` is an interface and should only be instantiated via a + subclass. Args: - preprocessor_name (str, optional): the name of the preprocessing - function. Defaults to 'preprocess'. - preprocessor_path (str, optional): the absolute path to the file - containing the function named above. If `None`, then do not - try to load a preprocessor. Defaults to `None`. - postprocessor_name (str, optional): the name of the postprocessing - function. Defaults to 'postprocess'. - postprocessor_path (str, optional): the absolute path to the file - containing the function named above. If `None`, then do not - try to load a postprocessor. Defaults to `None`. - prob_decoder_name (str, optional): the name of the postprocessing - function. Defaults to 'prob_decode'. - prob_decoder_path (str, optional): the absolute path to the file - containing the function named above. If `None`, then do not - try to load a prob_decoder. Defaults to `None`. top_probs (int): Number of classes to display per result. For instance, VGG16 has 1000 classes, we don't want to display a visualization for every single possibility. Defaults to 5. - **kwargs: Arbitrary keyword arguments, useful for passing specific + kwargs: Arbitrary keyword arguments, useful for passing specific settings to derived classes. - Example: - If you define a function called "preprocess" at "/path/to/util.py", - then try:: - - preprocessor_name='preprocess', - preprocessor_path='/path/to/util.py' - """ - self.latest_ckpt_name = None - self.latest_ckpt_time = None + self._data_dir = data_dir + self._load(data_dir) + self.top_probs = top_probs - self.preprocessor_name = preprocessor_name - self.preprocessor_path = preprocessor_path - self.postprocessor_name = postprocessor_name - self.postprocessor_path = postprocessor_path - self.prob_decoder_name = prob_decoder_name - self.prob_decoder_path = prob_decoder_path - - for util in ('preprocessor', 'postprocessor', 'prob_decoder'): - if getattr(self, '{}_path'.format(util)): - spec = importlib.util.\ - spec_from_file_location( - getattr(self, '{}_name'.format(util)), - getattr(self, '{}_path'.format(util))) - setattr(self, util, importlib.util.module_from_spec(spec)) - spec.loader.exec_module(getattr(self, util)) - - if kwargs: - for key, value in kwargs.items(): - setattr(self, key, value) - - def load(self, data_dir, **kwargs): - """Load the model in the desired framework - - Given a directory where model data (weights and graph - structure), should be able to restore the model locally to the point - where it can be evaluated. + def _load(self, data_dir): + """Load the model's graph and parameters from disk, and restore the + model so that it can be run for inference. Args: - data_dir (:obj:`str`): full path to directory containing - weight and graph data - **kwargs: Arbitrary keyword arguments, useful for passing specific - settings to derived classes. + data_dir (:obj:`str`): Full path to directory containing + graph and weight data. """ raise NotImplementedError - def _predict(self, targets): - """Evaluate new examples and return class probablilites + @property + def description(self): + """A description of the loaded model. - Given an iterable of examples or numpy array where the first - dimension is the number of example, return a n_examples x - n_classes array of class predictions + This description is rendered to the user in the UI. - Args: - targets: iterable of arrays suitable for input into graph + """ + return "%s loaded from %s" % (type(self).__name__, self._data_dir) - Returns: - array of class probabilities + @property + def sess(self): + """A Tensorflow session that can be used to evaluate tensors in the + model. + + (:obj:`tf.Session`) """ - raise NotImplementedError + return self._sess - def predict(self, raw_targets): - """Predict from raw data + @property + def tf_input_var(self): + """The Tensorflow tensor that represents the model's inputs. - Takes an iterable of data in its raw format. Passes to the - preprocessor and then the child class _predict. + (:obj:`tf.Tensor`) - Args: - raw_targets (:obj:`list` of :obj:`PIL.Image`): the images - to be processed + """ + return self._tf_input_var - Returns: - array of class probabilities + @property + def tf_predict_var(self): + """The Tensorflow tensor that represents the model's predicted class + probabilities. + + (:obj:`tf.Tensor`) """ - return self._predict(self.preprocess(raw_targets)) + return self._tf_predict_var + + def preprocess(self, raw_inputs): + """Preprocess raw inputs into the format required by the model. - def preprocess(self, raw_targets): - """Preprocess raw input for evaluation by model + E.g, the raw image may need to converted to a numpy array of the + appropriate dimension. - Usually, input will need some preprocessing before submission - to a computation graph. For instance, the raw image may need - to converted to a numpy array of appropriate dimension + By default, we perform no preprocessing. Args: - raw_targets (:obj:`list` of :obj:`PIL.Image`): the images - to be processed + raw_inputs (:obj:`list` of :obj:`PIL.Image`): List of raw + input images of any mode and shape. Returns: - iterable of arrays of the correct shape for input into graph + array (float32): Images ready to be fed into the model. """ - try: - return getattr(self.preprocessor, - self.preprocessor_name)(raw_targets) - except AttributeError: - warnings.warn('Evaluating without preprocessor') - return raw_targets + return raw_inputs - def postprocess(self, output_arr): - """Postprocess prediction results back into images + def predict(self, inputs): + """Given preprocessed inputs, generate class probabilities by using the + model to perform inference. - Sometimes it's useful to display an intermediate computation - as image. This is model-dependent. + Given an iterable of examples or numpy array where the first + dimension is the number of example, return a n_examples x + n_classes array of class predictions Args: - output_arr (iterable of arrays): any array with the - same total number of entries an input array + inputs: Iterable of examples (e.g., a numpy array whose first + dimension is the batch size). Returns: - iterable of arrays in original image shape + Class probabilities for each input example, as a numpy array of + shape (num_examples, num_classes). """ - - try: - return getattr(self.postprocessor, - self.postprocessor_name)(output_arr) - except AttributeError: - warnings.warn('Evaluating without postprocessor') - return output_arr + raise NotImplementedError def decode_prob(self, output_arr): - """Label class probabilites with class names + """Given predicted class probabilites for a set of examples, annotate + each logit with a class name. + + By default, we name each class using its index in the logits array. Args: - output_arr (array): class probabilities + output_arr (array): Class probabilities as output by + `self.predict`, i.e., a numpy array of shape (num_examples, + num_classes). Returns: - result list of dict in the format [{'index': class_index, 'name': - class_name, 'prob': class_probability}, ...] + Annotated class probabilities for each input example, as a list of + dicts where each dict is formatted as: + { + 'index': class_index, + 'name': class_name, + 'prob': class_probability + } """ - - try: - return getattr(self.prob_decoder, - self.prob_decoder_name)(output_arr, - top=self.top_probs) - except AttributeError: - warnings.warn('Evaluating without class decoder') - results = [] - for row in output_arr: - entries = [] - for i, prob in enumerate(row): - entries.append({'index': i, - 'name': str(i), - 'prob': prob}) - - entries = sorted(entries, - key=itemgetter('prob'), - reverse=True)[:self.top_probs] - - for entry in entries: - entry['prob'] = '{:.3f}'.format(entry['prob']) - results.append(entries) - return results - - -def generate_model(backend_ml, **kwargs): - """Create a new instance of ML backend - - Args: - backend_ml (:obj:`str`): name of the backend to use - **kwargs: Arbitrary keyword arguments - - Returns: - An instance of :class:`.ml_frameworks.model.Model` - - """ - module_name, _, class_name = \ - ML_LIBRARIES[backend_ml].rpartition('.') - - cls = getattr(import_module(module_name), class_name) - - kwargs = {k.partition('_')[-1]: - v for (k, v) in kwargs.items()} - return cls(**kwargs) + results = [] + for row in output_arr: + entries = [] + for i, prob in enumerate(row): + entries.append({'index': i, + 'name': str(i), + 'prob': prob}) + + entries = sorted(entries, + key=itemgetter('prob'), + reverse=True)[:self.top_probs] + + for entry in entries: + entry['prob'] = '{:.3f}'.format(entry['prob']) + results.append(entries) + return results diff --git a/picasso/ml_frameworks/tensorflow/model.py b/picasso/ml_frameworks/tensorflow/model.py index 7f11ffd..1566466 100644 --- a/picasso/ml_frameworks/tensorflow/model.py +++ b/picasso/ml_frameworks/tensorflow/model.py @@ -7,34 +7,43 @@ from picasso.ml_frameworks.model import Model -class TFModel(Model): - """Implements model loading functions for tensorflow""" +class TFModel(BaseModel): + """Implements model loading functions for Tensorflow. + + """ - def load(self, data_dir='./'): + # Name of the tensor corresponding to the model's inputs. You must define + # this if you are loading the model from a checkpoint. + TF_INPUT_VAR = None + + # Name of the tensor corresponding to the model's inputs. You must define + # this if you are loading the model from a checkpoint. + TF_PREDICT_VAR = None + + def _load(self, data_dir): """Load graph and weight data Args: - data_dir (:obj:`str`): location of tensorflow checkpoint - data. We'll need the .meta file to reconstruct - the graph and the data (checkpoint) files to - fill in the weights of the model. The default - behavior is take the latest files, by OS timestamp. + data_dir (:obj:`str`): location of tensorflow checkpoint data. + We'll need the .meta file to reconstruct the graph and the data + (checkpoint) files to fill in the weights of the model. The + default behavior is take the latest files, by OS timestamp. """ + self._sess = tf.Session() + self._sess.as_default() - self.sess = tf.Session() - self.sess.as_default() # find newest ckpt and meta files try: latest_ckpt_fn = max(glob.iglob(os.path.join(data_dir, '*.ckpt*')), key=os.path.getctime) - self.latest_ckpt_time = str(datetime.fromtimestamp( - os.path.getmtime(latest_ckpt_fn) - )) - latest_ckpt = latest_ckpt_fn[:latest_ckpt_fn.rfind('.ckpt') + 5] + self._latest_ckpt_time = str( + datetime.fromtimestamp(os.path.getmtime(latest_ckpt_fn))) except ValueError: raise FileNotFoundError('No checkpoint (.ckpt) files ' 'available at {}'.format(data_dir)) + latest_ckpt = latest_ckpt_fn[:latest_ckpt_fn.rfind('.ckpt') + 5] + try: latest_meta = max(glob.iglob(os.path.join(data_dir, '*.meta')), key=os.path.getctime) @@ -42,15 +51,21 @@ def load(self, data_dir='./'): raise FileNotFoundError('No graph (.meta) files ' 'available at {}'.format(data_dir)) - with self.sess.as_default() as sess: - self.saver = tf.train.import_meta_graph(latest_meta) - self.saver.restore(sess, latest_ckpt) + self._saver = tf.train.import_meta_graph(latest_meta) + self._saver.restore(self._sess, latest_ckpt) + + self._tf_input_var = self.sess.graph.get_tensor_by_name( + self.TF_INPUT_VAR) + self._tf_predict_var = self._sess.graph.get_tensor_by_name( + self.TF_PREDICT_VAR) - self.tf_predict_var = \ - self.sess.graph.get_tensor_by_name(self.tf_predict_var) - self.tf_input_var = \ - self.sess.graph.get_tensor_by_name(self.tf_input_var) + @property + def description(self): + return "%s loaded from %s (timestamp: %s)" % ( + type(self).__name__, + self._data_dir, + self._latest_ckpt_time) - def _predict(self, input_array): + def predict(self, input_array): return self.sess.run(self.tf_predict_var, {self.tf_input_var: input_array}) diff --git a/picasso/picasso.py b/picasso/picasso.py index 10d141a..f917c79 100644 --- a/picasso/picasso.py +++ b/picasso/picasso.py @@ -44,7 +44,7 @@ ) from picasso import app -from picasso.ml_frameworks.model import generate_model +from picasso.ml_frameworks.model import get_model from picasso.visualizations import BaseVisualization from picasso.visualizations import * @@ -76,20 +76,27 @@ # safest way. Would be much better to connect to a # persistent tensorflow session running in another process or # machine. -ml_backend = \ - generate_model( - **{k.lower(): v for (k, v) - in app.config.items() - if k.startswith('BACKEND')} - ) -ml_backend.load(app.config['DATA_DIR']) +model = get_model(app.config['MODEL_CLS_PATH'], + app.config['MODEL_CLS_NAME'], + app.config['DATA_DIR']) -def get_visualizations(): - """Get visualization classes in context +def get_model(): + """Get the NN model that's being analyzed from the request context. Put + the model in the request context if it is not yet there. + + Returns: + instance of :class:`.ml_frameworks.model.Model` or derived + class + """ + if not hasattr(g, 'model'): + g.model = model + return g.model - Puts the available visualizations in the request context - and returns them. + +def get_visualizations(): + """Get the available visualizations from the request context. Put the + visualizations in the request context if they are not yet there. Returns: :obj:`list` of instances of :class:`.BaseVisualization` or @@ -99,26 +106,12 @@ def get_visualizations(): if not hasattr(g, 'visualizations'): g.visualizations = {} for VisClass in VISUALIZATON_CLASSES: - vis = VisClass(get_ml_backend()) + vis = VisClass(get_model()) g.visualizations[vis.__class__.__name__] = vis return g.visualizations -def get_ml_backend(): - """Get machine learning backend in context - - Puts the backend in the request context and returns it. - - Returns: - instance of :class:`.ml_frameworks.model.Model` or derived - class - """ - if not hasattr(g, 'ml_backend'): - g.ml_backend = ml_backend - return g.ml_backend - - def get_app_state(): """Get current status of application in context @@ -127,12 +120,10 @@ def get_app_state(): """ if not hasattr(g, 'app_state'): - model = get_ml_backend() + model = get_model() g.app_state = { 'app_title': APP_TITLE, - 'backend': type(model).__name__, - 'latest_ckpt_name': model.latest_ckpt_name, - 'latest_ckpt_time': model.latest_ckpt_time + 'model_description': model.description, } return g.app_state @@ -150,14 +141,14 @@ def landing(): if request.method == 'POST': session['vis_name'] = request.form.get('choice') vis = get_visualizations()[session['vis_name']] - if hasattr(vis, 'settings'): + if vis.AVAILABLE_SETTINGS: return visualization_settings() return select_files() # otherwise, on GET request visualizations = get_visualizations() vis_desc = [{'name': vis, - 'description': visualizations[vis].description} + 'description': visualizations[vis].DESCRIPTION} for vis in visualizations] session.clear() return render_template('select_visualization.html', @@ -171,7 +162,7 @@ def landing(): def visualization_settings(): """Visualization settings page - Will only render if the visualization object has a `settings` + Will only render if the visualization object has a non-null `settings` attribute. """ @@ -180,7 +171,7 @@ def visualization_settings(): return render_template('settings.html', app_state=get_app_state(), current_vis=session['vis_name'], - settings=vis.settings) + settings=vis.AVAILABLE_SETTINGS) @app.route('/select_files', methods=['GET', 'POST']) @@ -217,10 +208,9 @@ def select_files(): start_time = time.time() session['img_output_dir'] = mkdtemp() - output = \ - vis.make_visualization(inputs, - output_dir=session['img_output_dir'], - settings=session['settings']) + output = vis.make_visualization(inputs, + output_dir=session['img_output_dir'], + settings=session['settings']) duration = '{:.2f}'.format(time.time() - start_time, 2) for i, file_obj in enumerate(request.files.getlist('file[]')): @@ -233,8 +223,8 @@ def select_files(): entry['data'].save(path, 'PNG') kwargs = {} - if hasattr(vis, 'reference_link'): - kwargs.update({'reference_link': vis.reference_link}) + if vis.REFERENCE_LINK: + kwargs['reference_link'] = vis.REFERENCE_LINK return render_template('{}.html'.format(session['vis_name']), inputs=inputs, diff --git a/picasso/settings.py b/picasso/settings.py index d469651..d772174 100644 --- a/picasso/settings.py +++ b/picasso/settings.py @@ -4,41 +4,22 @@ class Default: - """Default configuration settings + """Default configuration settings. - The app will use these settings if none are specified. That is, - if no configuration file is specified by PICASSO_SETTINGS - or any individual setting is specified by environment variable. - These are, in effect, "settings of last resort." + These settings are overridden by any settings defined in the Python module + referred to by the environment variable `PICASSO_SETTINGS`. If + `PICASSO_SETTINGS` is not set, or if any particular parameter value is + not set in the indicated module, then the app uses these settings. - The paths will automatically be generated based on the location of - the source. """ + # :obj:`str`: filepath of the module containing the model to run + MODEL_CLS_PATH = os.path.join( + base_dir, 'examples', 'keras', 'model.py') - #: :obj:`str`: which backend to use - BACKEND_ML = 'keras' + # :obj:`str`: name of model class + MODEL_CLS_NAME = 'KerasMNISTModel' - #: :obj:`str`: name of the preprocess function - BACKEND_PREPROCESSOR_NAME = 'preprocess' - - #: :obj:`str`: filepath of the preprocess function - BACKEND_PREPROCESSOR_PATH = os.path.join( - base_dir, 'examples', 'keras', 'util.py') - - #: :obj:`str`: name of the postprocess function - BACKEND_POSTPROCESSOR_NAME = 'postprocess' - - #: :obj:`str`: filepath of the postprocess function - BACKEND_POSTPROCESSOR_PATH = os.path.join( - base_dir, 'examples', 'keras', 'util.py') - - #: :obj:`str`: name of the probability decoder function - BACKEND_PROB_DECODER_NAME = 'prob_decode' - - #: :obj:`str`: filepath of the probability decoder function - BACKEND_PROB_DECODER_PATH = os.path.join( - base_dir, 'examples', 'keras', 'util.py') - - #: :obj:`str`: path to directory containing weights and graph + # :obj:`str`: path to directory containing weights and graph DATA_DIR = os.path.join( base_dir, 'examples', 'keras', 'data-volume') + diff --git a/picasso/templates/layout.html b/picasso/templates/layout.html index cd30161..6e4a271 100644 --- a/picasso/templates/layout.html +++ b/picasso/templates/layout.html @@ -10,22 +10,26 @@ +

{{ app_state.app_title }} by Merantix

-
-

Current backend: {{ app_state.backend }}

- {% if app_state.latest_ckpt_name is defined %} -

Current checkpoint: {{ app_state.latest_ckpt_name }}

- {% endif %} - {% if app_state.latest_ckpt_time is defined %} -

Last updated: {{ app_state.latest_ckpt_time }}

- {% endif %} + + -
+ +
+ {% block body %}{% endblock %} + +
+ + +
diff --git a/picasso/visualizations/__init__.py b/picasso/visualizations/__init__.py index 0146dd3..6c1fd87 100644 --- a/picasso/visualizations/__init__.py +++ b/picasso/visualizations/__init__.py @@ -11,32 +11,58 @@ class BaseVisualization: - """Template for visualizations - - Attributes: - description (:obj:`str`): short description of the visualization - model (instance of :class:`.ml_frameworks.model.Model` or derived class): - backend to use - settings (:obj:`dict`): a settings dictionary. Settings defined - here will be rendered in html for the user to select. See - derived classes for examples. + """Interface encapsulating a NN visualization. + + This interface defines how a visualization is computed for a given NN + model. + """ + # (:obj:`str`): Short description of the visualization. + DESCRIPTION = None + + # (:obj:`str`): Optional link to the paper specifying the visualization. + REFERENCE_LINK = None + + # (:obj:`dict`): Optional visuzalization settings that the user can select. + # Should be a dict mapping setting names to lists of their allowed values. + AVAILABLE_SETTINGS = None + def __init__(self, model): - self.model = model + """Create a new instance of this visualization. + + `BaseVisualization` is an interface and should only be instantiated via + a subclass. + + Args: + model (:obj:`.ml_frameworks.model.BaseModel`): NN model to be + visualized. + + """ + self._model = model + + @property + def model(self): + """NN model to be visualized. + + (:obj:`.ml_frameworks.model.BaseModel`) + + """ + return self._model def make_visualization(self, inputs, output_dir, settings=None): - """Generate the visualization + """Generate the visualization. All visualizations must implement this method. Args: - inputs (iterable of :class:`PIL.Image`): images uploaded by the - user. Will have already been converted to :obj:`Image` - objects. - output_dir (:obj:`str`): a directory to store outputs (e.g. plots) + inputs (iterable of :class:`PIL.Image`): Batch of input images to + make visualizations for, as PIL :obj:`Image` objects. + output_dir (:obj:`str`): A directory to write outputs to (e.g., + plots). Returns: data needed to render the visualization. Since there is an associated HTML template, the return type is arbitrary. + """ raise NotImplementedError diff --git a/picasso/visualizations/class_probabilities.py b/picasso/visualizations/class_probabilities.py index f87d5f3..ca49cbe 100644 --- a/picasso/visualizations/class_probabilities.py +++ b/picasso/visualizations/class_probabilities.py @@ -9,10 +9,9 @@ class probabilities of the input image. """ - description = 'Predict class probabilities from new examples' + DESCRIPTION = 'Predict class probabilities from new examples' - def make_visualization(self, inputs, - output_dir, settings=None): + def make_visualization(self, inputs, output_dir, settings=None): pre_processed_arrays = self.model.preprocess([example['data'] for example in inputs]) predictions = self.model.sess.run(self.model.tf_predict_var, diff --git a/picasso/visualizations/partial_occlusion.py b/picasso/visualizations/partial_occlusion.py index 239b4d6..711454d 100644 --- a/picasso/visualizations/partial_occlusion.py +++ b/picasso/visualizations/partial_occlusion.py @@ -22,16 +22,17 @@ class PartialOcclusion(BaseVisualization): classifying on the image feature we expect. """ - settings = { + DESCRIPTION = ('Partially occlude image to determine regions ' + 'important to classification') + + REFERENCE_LINK = 'https://arxiv.org/abs/1311.2901' + + AVAILABLE_SETTINGS = { 'Window': ['0.50', '0.40', '0.30', '0.20', '0.10', '0.05'], 'Strides': ['2', '5', '10', '20', '30'], 'Occlusion': ['grey', 'black', 'white'] } - description = ('Partially occlude image to determine regions ' - 'important to classification') - reference_link = 'https://arxiv.org/abs/1311.2901' - def __init__(self, model): super(PartialOcclusion, self).__init__(model) self.predict_tensor = self.get_predict_tensor() @@ -53,11 +54,10 @@ def make_visualization(self, inputs, output_dir, settings=None): # get class predictions as in ClassProbabilities pre_processed_arrays = self.model.preprocess([example['data'] - for example in inputs]) - class_predictions = \ - self.model.sess.run(self.model.tf_predict_var, - feed_dict={self.model.tf_input_var: - pre_processed_arrays}) + for example in inputs]) + class_predictions = self.model.sess.run( + self.model.tf_predict_var, + feed_dict={self.model.tf_input_var: pre_processed_arrays}) decoded_predictions = self.model.decode_prob(class_predictions) results = [] @@ -100,8 +100,8 @@ def make_visualization(self, inputs, output_dir, settings=None): def get_predict_tensor(self): # Assume that predict is the softmax # tensor in the computation graph - return self.model.sess.graph. \ - get_tensor_by_name(self.model.tf_predict_var.name) + return self.model.sess.graph.get_tensor_by_name( + self.model.tf_predict_var.name) def update_settings(self, settings): def error_string(setting, setting_val): @@ -112,19 +112,19 @@ def error_string(setting, setting_val): vis=self.__class__.__name__) if 'Window' in settings: - if settings['Window'] in self.settings['Window']: + if settings['Window'] in self.AVAILABLE_SETTINGS['Window']: self.window = float(settings['Window']) else: raise ValueError(error_string(settings['Window'], 'Window')) if 'Strides' in settings: - if settings['Strides'] in self.settings['Strides']: + if settings['Strides'] in self.AVAILABLE_SETTINGS['Strides']: self.num_windows = int(settings['Strides']) else: raise ValueError(error_string(settings['Strides'], 'Strides')) if 'Occlusion' in settings: - if settings['Occlusion'] in self.settings['Occlusion']: + if settings['Occlusion'] in self.AVAILABLE_SETTINGS['Occlusion']: self.occlusion_method = settings['Occlusion'] else: raise ValueError(error_string(settings['Occlusion'], diff --git a/picasso/visualizations/saliency_maps.py b/picasso/visualizations/saliency_maps.py index 060df75..f3d8160 100644 --- a/picasso/visualizations/saliency_maps.py +++ b/picasso/visualizations/saliency_maps.py @@ -21,24 +21,27 @@ class SaliencyMaps(BaseVisualization): classification (as changing them would change the classification). """ - description = ('See maximal derivates against class with respect ' + DESCRIPTION = ('See maximal derivates against class with respect ' 'to input') - reference_link = 'https://arxiv.org/pdf/1312.6034' + + REFERENCE_LINK = 'https://arxiv.org/pdf/1312.6034' def __init__(self, model, logit_tensor_name=None): super(SaliencyMaps, self).__init__(model) if logit_tensor_name: - self.logit_tensor = self.model.sess.graph \ - .get_tensor_by_name(logit_tensor_name) + self.logit_tensor = self.model.sess.graph.get_tensor_by_name( + logit_tensor_name) else: self.logit_tensor = self.get_logit_tensor() + self.input_shape = self.model.tf_input_var.get_shape()[1:].as_list() + def get_gradient_wrt_class(self, class_index): - gradient_name = 'bv_{class_index}_gradient' \ - .format(class_index=class_index) + gradient_name = 'bv_{class_index}_gradient'.format( + class_index=class_index) try: - return self.model.sess.graph. \ - get_tensor_by_name('{}:0'.format(gradient_name)) + return self.model.sess.graph.get_tensor_by_name( + '{}:0'.format(gradient_name)) except KeyError: class_logit = tf.slice(self.logit_tensor, [0, class_index], @@ -61,24 +64,29 @@ def make_visualization(self, inputs, output_dir, settings=None): results = [] for i, inp in enumerate(inputs): class_gradients = [] - output_images = [] relevant_class_indices = [pred['index'] for pred in decoded_predictions[i]] - gradients_wrt_class = [self.get_gradient_wrt_class(index) for index - in relevant_class_indices] + gradients_wrt_class = [self.get_gradient_wrt_class(index) + for index in relevant_class_indices] for gradient_wrt_class in gradients_wrt_class: class_gradients.append([self.model.sess.run( gradient_wrt_class, feed_dict={self.model.tf_input_var: [arr]}) for arr in pre_processed_arrays]) - output_fns = [] - output_arrays = np.array([gradient[i] for - gradient in class_gradients]) + + output_arrays = np.array([gradient[i] + for gradient in class_gradients]) # if images are color, take the maximum channel if output_arrays.shape[-1] == 3: output_arrays = output_arrays.max(-1) + # we care about the size of the derivative, not the sign + output_arrays = np.abs(output_arrays) + + # We want each array to be represented as a 1-channel image of + # the same size as the model's input image. + output_images = output_arrays.reshape([-1] + self.input_shape[0:2]) - output_images = self.model.postprocess(np.abs(output_arrays)) + output_fns = [] for j, image in enumerate(output_images): output_fn = '{fn}-{j}-{ts}.png'.format(ts=str(time.time()), j=j, @@ -106,8 +114,7 @@ def get_logit_tensor(self): # Assume that the logits are the tensor input to the last softmax # operation in the computation graph sm = [node for node in self.model.sess.graph_def.node - if node.name == - self.model.tf_predict_var.name.split(':')[0]][-1] + if node.name == self.model.tf_predict_var.name.split(':')[0]][-1] logit_op_name = sm.input[0] - return self.model.sess.graph. \ - get_tensor_by_name('{}:0'.format(logit_op_name)) + return self.model.sess.graph.get_tensor_by_name( + '{}:0'.format(logit_op_name)) diff --git a/tests/conftest.py b/tests/conftest.py index 777ac6f..238e5f7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -27,5 +27,8 @@ def example_prob_array(): @pytest.fixture def base_model(): - from picasso.ml_frameworks.model import Model - return Model() + from picasso.ml_frameworks.model import BaseModel + class BaseModelForTest(BaseModel): + def _load(self, data_dir): + pass + return BaseModelForTest("")