def run(config):
    """Train and evaluate one model per benchmark configuration.

    Iterates over every combination of dataset, architecture, embedding size,
    loss, optimizer and training entry found in ``config``. For each train/val
    fold it builds the samplers, fits the model, indexes the test split and
    computes retrieval metrics. Each run writes ``history.json`` and
    ``eval_metrics.json`` under its result path, and the metrics of all runs
    are collected into ``all_eval_metrics.json``.

    Args:
        config: Nested mapping read with the keys "datasets", "architectures",
            "losses", "optimizer", "training", "version", "preprocess",
            "augmentations", "evaluation" and "random_seed". The key
            "tfds_data_dir" is optional.
    """
    # Optionally point tensorflow-datasets at a custom data directory.
    if config.get("tfds_data_dir", None):
        os.environ["TFDS_DATA_DIR"] = config["tfds_data_dir"]

    # Maps the basename of each run's result path to its evaluation results.
    agg_results = {}
    for dataset_name, dconf in config["datasets"].items():
        # Default split settings; note that this writes into the caller's config.
        if "train_val_splits" not in dconf:
            dconf["train_val_splits"] = {
                "n_splits": 1,
                "val_class_pctg": 0.05,
                "max_val_examples": 10000,
            }

        for architecture_name, aconf in config["architectures"].items():
            for embedding_size in aconf.get("embedding_sizes", [128]):
                # The current embedding size is stored on the architecture config
                # (mutating it) so that build_model() and the result path can read it.
                aconf["embedding"] = embedding_size
                for loss_name, lconf in config["losses"].items():
                    for opt_name, oconf in config["optimizer"].items():
                        for training_name, tconf in config["training"].items():
                            version = config["version"]
                            pconf = config["preprocess"]
                            aug_conf = config["augmentations"]

                            # Load the raw dataset
                            # NOTE(review): the load only depends on the dataset and
                            # preprocess config but is repeated for every inner
                            # combination -- confirm whether that is intentional.
                            cprint(f"\n|-loading and preprocessing {dataset_name}\n", "blue")
                            preproc_fns = make_augmentations(pconf)
                            x, y = datasets.load_tf_dataset(dataset_name, dconf, preproc_fns)

                            for fold in range(dconf["train_val_splits"]["n_splits"]):
                                # Drop state left over from the previous run before
                                # building a new model.
                                gc.collect()
                                tf.keras.backend.clear_session()

                                # Re-seed for every fold so each run starts from the same seed.
                                tf.random.set_seed(config["random_seed"])

                                # Print a one-row table describing the current combination.
                                headers = [
                                    "dataset_name",
                                    "architecture_name",
                                    "loss_name",
                                    "opt_name",
                                    "training_name",
                                ]
                                row = [
                                    [
                                        f"{dataset_name}",
                                        f"{architecture_name}-{aconf['embedding']}",
                                        f"{loss_name}",
                                        f"{opt_name}",
                                        f"{training_name}",
                                    ]
                                ]
                                print("\n")
                                cprint(tabulate(row, headers=headers), "yellow")

                                ds_splits = datasets.create_splits(x, y, dconf, fold)
                                aug_fns = make_augmentations(aug_conf["train"])
                                cprint("\n|-building train dataset\n", "blue")
                                train_ds = datasets.make_sampler(
                                    ds_splits["train"][0], ds_splits["train"][1], tconf, aug_fns
                                )
                                # NOTE(review): the validation sampler reuses the *train*
                                # augmentations -- confirm this is intended.
                                cprint("\n|-building val dataset\n", "blue")
                                val_ds = datasets.make_sampler(
                                    ds_splits["val"][0], ds_splits["val"][1], tconf, aug_fns
                                )

                                # Build model
                                model = build_model(aconf, lconf, oconf)

                                # Make result path
                                # NOTE(review): training_name is not passed to make_stub;
                                # verify that different training configs do not share
                                # (and overwrite) the same result path.
                                stub = utils.make_stub(
                                    version,
                                    dataset_name,
                                    architecture_name,
                                    aconf["embedding"],
                                    loss_name,
                                    opt_name,
                                    fold,
                                )
                                utils.clean_dir(stub)

                                # Training params
                                # callbacks[0] must stay the eval callback: the summary
                                # message below reads its queries and targets by index.
                                callbacks = [
                                    metrics.make_eval_callback(
                                        val_ds,
                                        dconf["eval_callback"]["max_num_queries"],
                                        dconf["eval_callback"]["max_num_targets"],
                                    ),
                                    ModelCheckpoint(
                                        stub,
                                        monitor="val_loss",
                                        save_best_only=True,
                                    ),
                                ]

                                # Fall back to one pass over the sampler's examples when
                                # the step counts are not configured explicitly.
                                if "steps_per_epoch" in tconf:
                                    steps_per_epoch = tconf["steps_per_epoch"]
                                else:
                                    batch_size = train_ds.classes_per_batch * train_ds.examples_per_class_per_batch
                                    steps_per_epoch = train_ds.num_examples // batch_size

                                if "validation_steps" in tconf:
                                    validation_steps = tconf["validation_steps"]
                                else:
                                    batch_size = val_ds.classes_per_batch * val_ds.examples_per_class_per_batch
                                    validation_steps = val_ds.num_examples // batch_size

                                # Without an explicit epoch budget, train for up to 1000
                                # epochs and rely on early stopping on val_loss.
                                # NOTE(review): confirm early stopping is meant to apply
                                # only when "epochs" is absent from the training config.
                                if "epochs" in tconf:
                                    epochs = tconf["epochs"]
                                else:
                                    epochs = 1000
                                    early_stopping = EarlyStopping(
                                        monitor="val_loss",
                                        patience=5,
                                        verbose=0,
                                        mode="auto",
                                        restore_best_weights=True,
                                    )
                                    callbacks.append(early_stopping)

                                t_msg = [
                                    "\n|-Training",
                                    f"| - Fold: {fold}",
                                    f"| - Num train examples: {train_ds.num_examples}",
                                    f"| - Num val examples: {val_ds.num_examples}",
                                    f"| - Steps per epoch: {steps_per_epoch}",
                                    f"| - Epochs: {epochs}",
                                    f"| - Validation steps: {validation_steps}",
                                    "| - Eval callback",
                                    f"| -- Num queries: {len(callbacks[0].queries_known)}",
                                    f"| -- Num targets: {len(callbacks[0].targets)}",
                                ]
                                cprint("\n".join(t_msg) + "\n", "green")
                                history = model.fit(
                                    train_ds,
                                    steps_per_epoch=steps_per_epoch,
                                    epochs=epochs,
                                    callbacks=callbacks,
                                    validation_data=val_ds,
                                    validation_steps=validation_steps,
                                )

                                # Evaluation
                                test_aug_fns = make_augmentations(aug_conf["test"])
                                cprint("\n|-building eval dataset\n", "blue")
                                test_x, test_y, class_counts = datasets.make_eval_data(
                                    ds_splits["test"][0], ds_splits["test"][1], test_aug_fns
                                )

                                print("Make Metrics")
                                eval_metrics = metrics.make_eval_metrics(dconf, config["evaluation"], class_counts)

                                # Start from an empty index: reset it if possible, otherwise
                                # create one when reset_index() raises AttributeError.
                                try:
                                    model.reset_index()
                                except AttributeError:
                                    model.create_index()
                                print("Add Examples to Index")
                                model.index(test_x, test_y)

                                # NOTE(review): "Retriveal" is a typo in the logged message.
                                e_msg = [
                                    "\n|-Evaluate Retriveal Metrics",
                                    f"| - Fold: {fold}",
                                    f"| - Num eval examples: {len(test_x)}",
                                ]
                                cprint("\n".join(e_msg) + "\n", "green")
                                # The same test examples are used as both the index
                                # contents and the queries.
                                eval_results = model.evaluate_retrieval(
                                    test_x,
                                    test_y,
                                    retrieval_metrics=eval_metrics,
                                )

                                agg_results[os.path.basename(stub)] = eval_results

                                # Save history
                                with open(os.path.join(stub, "history.json"), "w") as o:
                                    o.write(json.dumps(history.history, cls=utils.NpEncoder))

                                # Save eval metrics
                                with open(os.path.join(stub, "eval_metrics.json"), "w") as o:
                                    o.write(json.dumps(eval_results, cls=utils.NpEncoder))

    # Write the aggregate next to the per-run result paths.
    # NOTE(review): this uses the last `stub` built above, so it assumes at least
    # one run executed and that all result paths share the same parent directory.
    with open(os.path.join(os.path.dirname(stub), "all_eval_metrics.json"), "w") as o:
        o.write(json.dumps(agg_results, cls=utils.NpEncoder))
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "24eda1a6", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "id": "7ca9d025", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# TensorFlow Similarity ArcFace Loss Example" + ] + }, + { + "cell_type": "markdown", + "id": "d072628f", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "Additive Angular Margin Loss (ArcFace) uses the geodesic distance on the hypersphere instead of the Euclidean distance to improve the discriminative power of the face recognition model and to stabilize the training process. All distances are measured along geodesics on the hypersphere. A geodesic is the path taken between two points on the sphere's surface; its length is the geodesic distance, which is the shortest distance between the two points.\n", + "\n", + "ArcFace loss determines the angle between the current feature and the target weight using the arc-cosine function, since the dot product between the DCNN feature and the last fully connected layer after feature and weight normalization equals the cosine similarity. The target logit is then obtained by adding an additive angular margin to the target angle and applying the cosine function. 
After that, we continue as before and rescale all logits to a certain feature norm, just like with softmax loss." + ] + }, + { + "cell_type": "markdown", + "id": "808ac087", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Notebook goal\n", + "\n", + "This notebook demonstrates how to use ArcFaceLoss implementation of TensorFlow Similarity with standalone usage and to train a `SimilarityModel()` on a fraction of the MNIST classes.\n", + "\n", + "You are going to learn about the main features offered by the `ArcFaceLoss()` and will:\n", + "\n", + " 1. Standalone usage of ArcFaceLoss\n", + "\n", + " 2. Usage with `model.compile()`\n", + "\n", + " 3. 3D-Visualization of ArcFaceLoss \n", + "\n", + "### Things to try \n", + "\n", + "Along the way you can try the following things to improve the model performance:\n", + "- Adding more \"seen\" classes at training time.\n", + "- Use a larger embedding by increasing the size of the output.\n", + "- Add data augmentation pre-processing layers to the model.\n", + "- Include more examples in the index to give the models more points to choose from.\n", + "- Try a more challenging dataset, such as Fashion MNIST." + ] + }, + { + "cell_type": "markdown", + "id": "078c53c0", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Notebook goal\n", + "\n", + "This notebook demonstrates how to use ArcFaceLoss implementation of TensorFlow Similarity with standalone usage and to train a `SimilarityModel()` on a fraction of the MNIST classes.\n", + "\n", + "You are going to learn about the main features offered by the `ArcFaceLoss()` and will:\n", + "\n", + " 1. Standalone usage of ArcFaceLoss\n", + "\n", + " 2. Usage with `model.compile()`\n", + "\n", + " 3. 
3D-Visualization of ArcFaceLoss \n", + "\n", + "### Things to try \n", + "\n", + "Along the way you can try the following things to improve the model performance:\n", + "- Adding more \"seen\" classes at training time.\n", + "- Use a larger embedding by increasing the size of the output.\n", + "- Add data augmentation pre-processing layers to the model.\n", + "- Include more examples in the index to give the models more points to choose from.\n", + "- Try a more challenging dataset, such as Fashion MNIST." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "8fd63f16", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import gc\n", + "import os\n", + "\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "from tabulate import tabulate\n", + "from mpl_toolkits.mplot3d import Axes3D\n", + "\n", + "# INFO messages are not printed.\n", + "# This must be run before loading other modules.\n", + "os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"1\"" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "80af5fc0", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "import tensorflow as tf" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "8ba8caf7", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Your CPU supports instructions that this binary was not compiled to use: SSE3 SSE4.1 SSE4.2 AVX AVX2\n", + "For maximum performance, you can install NMSLIB from sources \n", + "pip install --no-binary :all: nmslib\n" + ] + } + ], + "source": [ + "# install TF similarity if needed\n", + "try:\n", + " import tensorflow_similarity as tfsim # main package\n", + "except ModuleNotFoundError:\n", + " !pip install tensorflow_similarity\n", + " import tensorflow_similarity as tfsim" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": 
"2484bd72", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "tfsim.utils.tf_cap_memory()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "3fe0344e", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "# Clear out any old model state.\n", + "gc.collect()\n", + "tf.keras.backend.clear_session()" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "99d9bef9", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TensorFlow: 2.8.0\n", + "TensorFlow Similarity 0.17.0.dev10\n" + ] + } + ], + "source": [ + "print(\"TensorFlow:\", tf.__version__)\n", + "print(\"TensorFlow Similarity\", tfsim.__version__)" + ] + }, + { + "cell_type": "markdown", + "id": "1d534ad3", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "# Standalone Usage of ArcFaceLoss\n", + "\n", + "ArcFace loss alone can be used as follows when it is desired to calculate the additive angular margin loss of the existing data set." 
+ ] + }, + { + "cell_type": "markdown", + "id": "68d526da", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Initialize Loss function as ArcFaceLoss" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "bebf6ef0", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "loss_fn = tfsim.losses.ArcFaceLoss(num_classes=8, embedding_size=10)" + ] + }, + { + "cell_type": "markdown", + "id": "d2ccfd7d", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Create own simple random dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1d1ec43a", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "labels = tf.Variable([0, 1, 2, 3, 4, 5, 6, 7])\n", + "embeddings = tf.Variable(tf.random.uniform(shape=[8, 10]))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "73d0c1c6", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " \n" + ] + } + ], + "source": [ + "print(\"\", embeddings)" + ] + }, + { + "cell_type": "markdown", + "id": "d65b3085", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Calculate loss" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "cdf7c30c", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "loss = loss_fn(labels, embeddings)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "16745b7d", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "loss : tf.Tensor(48.764076, shape=(), dtype=float32)\n" + ] + } + ], + "source": [ + "print(\"loss : \" , loss)" + ] + }, + { + "cell_type": "markdown", + "id": "11ef5236", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": 
[ + "## Data preparation\n", + "\n", + "We are going to load the MNIST dataset to showcase how the model is able to find similar examples from classes unseen during training. The model's ability to generalize the matching to unseen classes, without retraining, is one of the main reasons you would want to use metric learning.\n", + "\n", + "\n", + "**WARNING**: TensorFlow Similarity expects `y_train` to be an IntTensor containing the class ids for each example instead of the standard categorical encoding traditionally used for multi-class classification." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "97152229", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()" + ] + }, + { + "cell_type": "markdown", + "id": "08b766d8", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Model setup" + ] + }, + { + "cell_type": "markdown", + "id": "3eac2da7", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Model definition\n", + "\n", + "`SimilarityModel()` models extend `tensorflow.keras.Model` with additional features and functionality that allow you to index and search for similar looking examples.\n", + "\n", + "As visible in the model definition below, similarity models output a 64 dimensional float embedding using the `MetricEmbedding()` layers. This layer is a Dense layer with L2 normalization. Thanks to the loss, the model learns to minimize the distance between similar examples and maximize the distance between dissimilar examples. As a result, the distance between examples in the embedding space is meaningful; the smaller the distance the more similar the examples are. \n", + "\n", + "Being able to use a distance as a meaningful proxy for how similar two examples are, is what enables the fast ANN (approximate nearest neighbor) search. 
Using a sub-linear ANN search instead of a standard quadratic NN search is what allows deep similarity search to scale to millions of items. The built in memory index used in this notebook scales to a million indexed examples very easily... if you have enough RAM :)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a003c971", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "def get_model():\n", + " inputs = tf.keras.layers.Input(shape=(28, 28, 1))\n", + " x = tf.keras.layers.experimental.preprocessing.Rescaling(1 / 255)(inputs)\n", + " x = tf.keras.layers.Conv2D(32, 3, activation=\"relu\")(x)\n", + " x = tf.keras.layers.Conv2D(32, 3, activation=\"relu\")(x)\n", + " x = tf.keras.layers.MaxPool2D()(x)\n", + " x = tf.keras.layers.Conv2D(64, 3, activation=\"relu\")(x)\n", + " x = tf.keras.layers.Conv2D(64, 3, activation=\"relu\")(x)\n", + " x = tf.keras.layers.Flatten()(x)\n", + " # smaller embeddings will have faster lookup times while a larger embedding will improve the accuracy up to a point.\n", + " outputs = tfsim.layers.MetricEmbedding(64)(x)\n", + " return tfsim.models.SimilarityModel(inputs, outputs)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "a2177b12", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Model: \"similarity_model\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_1 (InputLayer) [(None, 28, 28, 1)] 0 \n", + " \n", + " rescaling (Rescaling) (None, 28, 28, 1) 0 \n", + " \n", + " conv2d (Conv2D) (None, 26, 26, 32) 320 \n", + " \n", + " conv2d_1 (Conv2D) (None, 24, 24, 32) 9248 \n", + " \n", + " max_pooling2d (MaxPooling2D (None, 12, 12, 32) 0 \n", + " ) \n", + " \n", + " conv2d_2 (Conv2D) (None, 10, 10, 64) 18496 \n", + " 
\n", + " conv2d_3 (Conv2D) (None, 8, 8, 64) 36928 \n", + " \n", + " flatten (Flatten) (None, 4096) 0 \n", + " \n", + " metric_embedding (MetricEmb (None, 64) 262208 \n", + " edding) \n", + " \n", + "=================================================================\n", + "Total params: 327,200\n", + "Trainable params: 327,200\n", + "Non-trainable params: 0\n", + "_________________________________________________________________\n" + ] + } + ], + "source": [ + "model = get_model()\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "id": "defb3961", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### ArcFace Loss definition\n", + "\n", + "Overall what makes Metric losses different from tradional losses is that:\n", + "- **They expect different inputs.** Instead of having the prediction equal the true values, they expect embeddings as `y_preds` and the id (as an int32) of the class as `y_true`. \n", + "- **They require a distance.** You need to specify which `distance` function to use to compute the distance between embeddings. `cosine` is usually a great starting point and the default.\n", + "\n", + "ArcFace Loss takes inputs as number of classes which labels includes, and embedding size which we define in model `MetricEmbedding()` layers." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "13b0d745", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "distance = \"cosine\" " + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "c22d10cc", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "num_classes = np.unique(y_train).size\n", + "embedding_size = model.get_layer('metric_embedding').output.shape[1]" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "d5b8e426", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "loss = tfsim.losses.ArcFaceLoss(num_classes=num_classes, embedding_size=embedding_size, name=\"ArcFaceLoss\")" + ] + }, + { + "cell_type": "markdown", + "id": "b6eaf9c8", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### Compilation\n", + "\n", + "TensorFlow Similarity uses an extended `compile()` method that allows you to optionally specify `distance_metrics` (metrics that are computed over the distance between the embeddings), and the distance to use for the indexer.\n", + "\n", + "By default the `compile()` method tries to infer what type of distance you are using by looking at the first loss specified. If you use multiple losses, and the distance loss is not the first one, then you need to specify the distance function used as the `distance=` parameter in the compile function." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "673f986f", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [], + "source": [ + "LR = 0.0005 # @param {type:\"number\"}\n", + "model.compile(optimizer=tf.keras.optimizers.SGD(LR), loss=loss, distance=distance)" + ] + }, + { + "cell_type": "markdown", + "id": "15961601", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Training\n", + "\n", + "Similarity models are trained like normal models. 
" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "147a6863", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1/10\n", + "1875/1875 [==============================] - 11s 4ms/step - loss: 5.2161 - val_loss: 1.8907\n", + "Epoch 2/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 1.8353 - val_loss: 1.6826\n", + "Epoch 3/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 1.3566 - val_loss: 1.1404\n", + "Epoch 4/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 1.1160 - val_loss: 1.0936\n", + "Epoch 5/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.9555 - val_loss: 1.0854\n", + "Epoch 6/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.8343 - val_loss: 1.0062\n", + "Epoch 7/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.7546 - val_loss: 0.9062\n", + "Epoch 8/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.6776 - val_loss: 0.8000\n", + "Epoch 9/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.6194 - val_loss: 0.8160\n", + "Epoch 10/10\n", + "1875/1875 [==============================] - 7s 4ms/step - loss: 0.5676 - val_loss: 0.7515\n" + ] + } + ], + "source": [ + "EPOCHS = 10 # @param {type:\"integer\"}\n", + "history = model.fit(x_train, y_train, epochs=EPOCHS, validation_data=(x_test, y_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "88e1ee4d", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(history.history[\"loss\"])\n", + "plt.plot(history.history[\"val_loss\"])\n", + "plt.legend([\"loss\", \"val_loss\"])\n", + "plt.title(f\"Loss: {loss.name}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "id": "5ad4ba20", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "## Prediction\n", + "\n", + "Let's predict some features and visualize them." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "a1936264", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "313/313 [==============================] - 1s 2ms/step\n" + ] + } + ], + "source": [ + "embedded_features = model.predict(x_test, verbose=1)" + ] + }, + { + "cell_type": "markdown", + "id": "7c0df63b", + "metadata": { + "pycharm": { + "name": "#%% md\n" + } + }, + "source": [ + "### 3D-Visualization of ArcFace Loss" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "5aac5d98", + "metadata": { + "pycharm": { + "name": "#%%\n" + } + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
@tf.keras.utils.register_keras_serializable(package="Similarity")
class ArcFaceLoss(tf.keras.losses.Loss):
    """ArcFace: Additive Angular Margin Loss.

    The loss is computed as follows:

    Step 1: Create a trainable kernel matrix with the shape of
        [embedding_size, num_classes].
    Step 2: L2-normalize the kernel columns and the prediction vectors.
    Step 3: Compute the cosine similarity between each normalized prediction
        and each kernel column, and recover the angle with arc-cosine.
    Step 4: Build a one-hot matrix holding the margin at the ground truth
        class and zero elsewhere.
    Step 5: Add that margin to the angle, take the cosine again and multiply
        the result by `scale` to obtain the logits.
    Step 6: Compute the sparse categorical cross-entropy of the logits.

    Reference:
        ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
        https://arxiv.org/abs/1801.07698

    Standalone usage:
        >>> loss_fn = tfsim.losses.ArcFaceLoss(num_classes=2, embedding_size=3)
        >>> labels = tf.Variable([1, 0])
        >>> embeddings = tf.Variable([[0.2, 0.3, 0.1], [0.4, 0.5, 0.5]])
        >>> loss = loss_fn(labels, embeddings)

    Args:
        num_classes: Number of classes.
        embedding_size: The size of the embedding vectors.
        margin: The angular margin, in radians, added to the target angle.
        scale: s in the paper, feature scale.
        name: Optional name for the operation.
        reduction: Type of loss reduction to apply to the loss.
    """

    def __init__(
        self,
        num_classes: int,
        embedding_size: int,
        margin: float = 0.50,  # margin in radians
        scale: float = 64.0,  # feature scale
        name: Optional[str] = None,
        reduction: Callable = tf.keras.losses.Reduction.AUTO,
        **kwargs
    ):
        # The base class stores `name` and `reduction`; they are not re-assigned here.
        super().__init__(reduction=reduction, name=name, **kwargs)

        self.num_classes = num_classes
        self.embedding_size = embedding_size
        self.margin = margin
        self.scale = scale
        # One column per class. The kernel is randomly initialized for every new
        # instance and is not part of get_config().
        # NOTE(review): the kernel lives on the loss object rather than on the
        # model -- confirm the optimizer actually updates it during `model.fit`.
        self.kernel = tf.Variable(tf.random.normal([embedding_size, num_classes]), trainable=True)

    def call(self, y_true: FloatTensor, y_pred: FloatTensor) -> FloatTensor:
        """Compute the per-example ArcFace loss.

        Args:
            y_true: Class ids, one per example, shaped [batch] or [batch, 1].
                Values are cast to int32.
            y_pred: Embeddings of shape [batch, embedding_size].

        Returns:
            A tensor of shape [batch] holding one loss value per example. The
            reduction configured on the loss is applied by the base class.
        """
        # Flatten and cast the labels so [batch] and [batch, 1] inputs of any
        # integer (or integral float) dtype are handled the same way.
        labels = tf.cast(tf.reshape(y_true, [-1]), tf.int32)

        y_pred_norm = tf.math.l2_normalize(y_pred, axis=1)
        kernel_norm = tf.math.l2_normalize(self.kernel, axis=0)

        cos_theta = tf.matmul(y_pred_norm, kernel_norm)
        # Stay strictly inside (-1, 1): the derivative of acos is unbounded at
        # exactly +/-1, which would produce inf/NaN gradients.
        eps = tf.keras.backend.epsilon()
        cos_theta = tf.clip_by_value(cos_theta, -1.0 + eps, 1.0 - eps)

        # Margin at the ground-truth class, zero everywhere else.
        m_hot = tf.one_hot(labels, self.num_classes, on_value=self.margin, off_value=0.0)

        # The margin is added to the angle, not to the cosine.
        theta = tf.acos(cos_theta) + m_hot
        logits = tf.math.multiply(tf.math.cos(theta), self.scale)

        # Return unreduced losses so the base class applies `sample_weight` and
        # the configured reduction exactly once.
        loss: FloatTensor = tf.keras.losses.sparse_categorical_crossentropy(labels, logits, from_logits=True)

        return loss

    def get_config(self) -> Dict[str, Any]:
        """Contains the loss configuration.

        Returns:
            The configuration of the ArcFace loss. The kernel weights are not
            included.
        """
        config = {
            "num_classes": self.num_classes,
            "embedding_size": self.embedding_size,
            "margin": self.margin,
            "scale": self.scale,
            "name": self.name,
        }
        base_config = super().get_config()
        return {**base_config, **config}
def test_arcface_loss_serialization():
    """A loss rebuilt from its config keeps every configured hyper-parameter."""
    n_classes = 10
    embed_size = 16
    loss = ArcFaceLoss(num_classes=n_classes, embedding_size=embed_size)
    config = loss.get_config()
    loss2 = ArcFaceLoss.from_config(config)
    assert loss.name == loss2.name
    assert loss.margin == loss2.margin
    assert loss.scale == loss2.scale
    assert loss.num_classes == loss2.num_classes
    assert loss.embedding_size == loss2.embedding_size


def test_arcface_loss():
    """The loss on seeded random inputs stays inside the expected range."""
    # The seed must be set before the loss is built: the kernel is drawn from
    # tf.random.normal in the constructor, before the embeddings are sampled.
    tf.random.set_seed(128)
    loss_fn = ArcFaceLoss(num_classes=4, embedding_size=5)
    labels = tf.Variable([0, 1, 2, 3])
    embeddings = tf.Variable(tf.random.uniform(shape=[4, 5]))

    loss = loss_fn(labels, embeddings)

    assert 60.4 < loss.numpy() < 60.5