1010 " \n " ,
1111 " **Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)<br>\n " ,
1212 " **Date created:** 2022/02/24<br>\n " ,
13- " **Last modified:** 2022/10/15 <br>\n " ,
13+ " **Last modified:** 2024/12/06 <br>\n " ,
1414 " **Description:** A minimal implementation of ShiftViT."
1515 ]
1616 },
3939 " In this example, we minimally implement the paper with close alignement to the author's\n " ,
4040 " [official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).\n " ,
4141 " \n " ,
42- " This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can\n " ,
43- " be installed using the following command:"
44- ]
45- },
46- {
47- "cell_type" : " code" ,
48- "execution_count" : 0 ,
49- "metadata" : {
50- "colab_type" : " code"
51- },
52- "outputs" : [],
53- "source" : [
54- " !pip install -qq -U tensorflow-addons"
42+ " This example requires TensorFlow 2.9 or higher."
5543 ]
5644 },
5745 {
7462 " import numpy as np\n " ,
7563 " import matplotlib.pyplot as plt\n " ,
7664 " \n " ,
65+ " import keras\n " ,
66+ " from keras import ops\n " ,
67+ " from keras import layers\n " ,
7768 " import tensorflow as tf\n " ,
78- " from tensorflow import keras\n " ,
79- " from tensorflow.keras import layers\n " ,
80- " import tensorflow_addons as tfa\n " ,
8169 " \n " ,
8270 " import pathlib\n " ,
8371 " import glob\n " ,
237225 " layers.Rescaling(1 / 255.0),\n " ,
238226 " ]\n " ,
239227 " )\n " ,
240- " return data_augmentation\n " ,
241- " "
228+ " return data_augmentation\n "
242229 ]
243230 },
244231 {
341328 " [\n " ,
342329 " layers.Dense(\n " ,
343330 " units=initial_filters,\n " ,
344- " activation=tf.nn. gelu,\n " ,
331+ " activation=\" gelu\" ,\n " ,
345332 " ),\n " ,
346333 " layers.Dropout(rate=self.mlp_dropout_rate),\n " ,
347334 " layers.Dense(units=input_channels),\n " ,
351338 " \n " ,
352339 " def call(self, x):\n " ,
353340 " x = self.mlp(x)\n " ,
354- " return x\n " ,
355- " "
341+ " return x\n "
356342 ]
357343 },
358344 {
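
Note: the `tf.nn.gelu` callable becomes the string `"gelu"`, which Keras resolves through its activation registry. A quick sketch of the equivalence (illustrative, not from the commit):

    import keras
    from keras import layers

    # A string activation is looked up via keras.activations.get, so these two
    # layers apply the same function.
    dense_by_name = layers.Dense(8, activation="gelu")
    dense_by_fn = layers.Dense(8, activation=keras.activations.gelu)
    print(keras.activations.get("gelu"))
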
389375 " def __init__(self, drop_path_prob, **kwargs):\n " ,
390376 " super().__init__(**kwargs)\n " ,
391377 " self.drop_path_prob = drop_path_prob\n " ,
378+ " self.seed_generator = keras.random.SeedGenerator(1337)\n " ,
392379 " \n " ,
393380 " def call(self, x, training=False):\n " ,
394381 " if training:\n " ,
395382 " keep_prob = 1 - self.drop_path_prob\n " ,
396- " shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)\n " ,
397- " random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)\n " ,
398- " random_tensor = tf.floor(random_tensor)\n " ,
383+ " shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)\n " ,
384+ " random_tensor = keep_prob + keras.random.uniform(\n " ,
385+ " shape, 0, 1, seed=self.seed_generator\n " ,
386+ " )\n " ,
387+ " random_tensor = ops.floor(random_tensor)\n " ,
399388 " return (x / keep_prob) * random_tensor\n " ,
400- " return x\n " ,
401- " "
389+ " return x\n "
402390 ]
403391 },
404392 {
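
Note: for readers new to stochastic depth, the rewritten `DropPath` draws one Bernoulli sample per example and rescales survivors by `1 / keep_prob`, so the expected activation is unchanged. A standalone sketch of that math (my own illustration, mirroring the layer above):

    import keras
    from keras import ops

    drop_path_prob = 0.5
    keep_prob = 1 - drop_path_prob
    seed_gen = keras.random.SeedGenerator(1337)

    x = ops.ones((8, 4))  # (batch, features)

    # One draw per sample, broadcast over all remaining axes.
    shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)
    mask = ops.floor(keep_prob + keras.random.uniform(shape, 0, 1, seed=seed_gen))

    # Survivors are scaled up so E[out] == x; dropped rows become zero.
    out = (x / keep_prob) * mask
    print(out)  # each row is either all 0.0 or all 2.0
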
523511 " offset_width = 0\n " ,
524512 " target_height = self.shift_pixel\n " ,
525513 " target_width = 0\n " ,
526- " crop = tf .image.crop_to_bounding_box (\n " ,
514+ " crop = ops .image.crop_images (\n " ,
527515 " x,\n " ,
528- " offset_height =offset_height,\n " ,
529- " offset_width =offset_width,\n " ,
516+ " top_cropping =offset_height,\n " ,
517+ " left_cropping =offset_width,\n " ,
530518 " target_height=self.H - target_height,\n " ,
531519 " target_width=self.W - target_width,\n " ,
532520 " )\n " ,
533- " shift_pad = tf .image.pad_to_bounding_box (\n " ,
521+ " shift_pad = ops .image.pad_images (\n " ,
534522 " crop,\n " ,
535- " offset_height =offset_height,\n " ,
536- " offset_width =offset_width,\n " ,
523+ " top_padding =offset_height,\n " ,
524+ " left_padding =offset_width,\n " ,
537525 " target_height=self.H,\n " ,
538526 " target_width=self.W,\n " ,
539527 " )\n " ,
540528 " return shift_pad\n " ,
541529 " \n " ,
542530 " def call(self, x, training=False):\n " ,
543531 " # Split the feature maps\n " ,
544- " x_splits = tf .split(x, num_or_size_splits =self.C // self.num_div, axis=-1)\n " ,
532+ " x_splits = ops .split(x, indices_or_sections =self.C // self.num_div, axis=-1)\n " ,
545533 " \n " ,
546534 " # Shift the feature maps\n " ,
547535 " x_splits[0] = self.get_shift_pad(x_splits[0], mode=\" left\" )\n " ,
550538 " x_splits[3] = self.get_shift_pad(x_splits[3], mode=\" down\" )\n " ,
551539 " \n " ,
552540 " # Concatenate the shifted and unshifted feature maps\n " ,
553- " x = tf.concat (x_splits, axis=-1)\n " ,
541+ " x = ops.concatenate (x_splits, axis=-1)\n " ,
554542 " \n " ,
555543 " # Add the residual connection\n " ,
556544 " shortcut = x\n " ,
557545 " x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)\n " ,
558- " return x\n " ,
559- " "
546+ " return x\n "
560547 ]
561548 },
562549 {
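
Note: the shift itself is just a crop followed by a zero-pad. A toy sketch of the idiom with `ops.image.crop_images` / `ops.image.pad_images` (sizes and offsets are mine, chosen to show a one-pixel downward shift; the block above derives its offsets from `mode` and `shift_pixel`):

    import numpy as np
    from keras import ops

    H, W = 4, 4
    x = ops.convert_to_tensor(np.arange(H * W, dtype="float32").reshape(1, H, W, 1))

    # Drop the bottom row...
    crop = ops.image.crop_images(
        x, top_cropping=0, left_cropping=0, target_height=H - 1, target_width=W
    )
    # ...and zero-pad one row back on top: the content moves down one pixel.
    shifted = ops.image.pad_images(
        crop, top_padding=1, left_padding=0, target_height=H, target_width=W
    )
    print(ops.squeeze(shifted))  # first row is zeros; rows 1..3 hold rows 0..2 of x
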
622609 " # Apply the patch merging algorithm on the feature maps\n " ,
623610 " x = self.layer_norm(x)\n " ,
624611 " x = self.reduction(x)\n " ,
625- " return x\n " ,
626- " "
612+ " return x\n "
627613 ]
628614 },
629615 {
737723 " \" mlp_expand_ratio\" : self.mlp_expand_ratio,\n " ,
738724 " }\n " ,
739725 " )\n " ,
740- " return config\n " ,
741- " "
726+ " return config\n "
742727 ]
743728 },
744729 {
903888 " x = stage(x, training=False)\n " ,
904889 " x = self.global_avg_pool(x)\n " ,
905890 " logits = self.classifier(x)\n " ,
906- " return logits\n " ,
907- " "
891+ " return logits\n "
908892 ]
909893 },
910894 {
979963 " self.lr_max = lr_max\n " ,
980964 " self.warmup_steps = warmup_steps\n " ,
981965 " self.total_steps = total_steps\n " ,
982- " self.pi = tf.constant (np.pi)\n " ,
966+ " self.pi = ops.array (np.pi)\n " ,
983967 " \n " ,
984968 " def __call__(self, step):\n " ,
985969 " # Check whether the total number of steps is larger than the warmup\n " ,
993977 " # `cos_annealed_lr` is a graph that increases to 1 from the initial\n " ,
994978 " # step to the warmup step. After that this graph decays to -1 at the\n " ,
995979 " # final step mark.\n " ,
996- " cos_annealed_lr = tf .cos(\n " ,
980+ " cos_annealed_lr = ops .cos(\n " ,
997981 " self.pi\n " ,
998- " * (tf .cast(step, tf. float32) - self.warmup_steps)\n " ,
999- " / tf .cast(self.total_steps - self.warmup_steps, tf. float32)\n " ,
982+ " * (ops .cast(step, dtype= \" float32\" ) - self.warmup_steps)\n " ,
983+ " / ops .cast(self.total_steps - self.warmup_steps, dtype= \" float32\" )\n " ,
1000984 " )\n " ,
1001985 " \n " ,
1002986 " # Shift the mean of the `cos_annealed_lr` graph to 1. Now the grpah goes\n " ,
10211005 " \n " ,
10221006 " # With the formula for a straight line (y = mx+c) build the warmup\n " ,
10231007 " # schedule\n " ,
1024- " warmup_rate = slope * tf .cast(step, tf. float32) + self.lr_start\n " ,
1008+ " warmup_rate = slope * ops .cast(step, dtype= \" float32\" ) + self.lr_start\n " ,
10251009 " \n " ,
10261010 " # When the current step is lesser that warmup steps, get the line\n " ,
10271011 " # graph. When the current step is greater than the warmup steps, get\n " ,
10281012 " # the scaled cos graph.\n " ,
1029- " learning_rate = tf .where(\n " ,
1013+ " learning_rate = ops .where(\n " ,
10301014 " step < self.warmup_steps, warmup_rate, learning_rate\n " ,
10311015 " )\n " ,
10321016 " \n " ,
10331017 " # When the current step is more that the total steps, return 0 else return\n " ,
10341018 " # the calculated graph.\n " ,
1035- " return tf.where(\n " ,
1036- " step > self.total_steps, 0.0, learning_rate, name=\" learning_rate\"\n " ,
1037- " )\n " ,
1019+ " return ops.where(step > self.total_steps, 0.0, learning_rate)\n " ,
10381020 " \n " ,
10391021 " def get_config(self):\n " ,
10401022 " config = {\n " ,
10431025 " \" total_steps\" : self.total_steps,\n " ,
10441026 " \" warmup_steps\" : self.warmup_steps,\n " ,
10451027 " }\n " ,
1046- " return config\n " ,
1047- " "
1028+ " return config\n "
10481029 ]
10491030 },
10501031 {
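
Note: numerically, this schedule is a straight warmup line glued to a cosine decay. A NumPy sketch of the two pieces with toy hyperparameters (values are mine; the class above additionally clamps steps beyond `total_steps` to 0):

    import numpy as np

    lr_start, lr_max = 1e-5, 1e-3   # toy values, not the notebook's config
    warmup_steps, total_steps = 100, 1000

    steps = np.arange(total_steps, dtype="float32")

    # Warmup: y = m*x + c, rising from lr_start to lr_max.
    slope = (lr_max - lr_start) / warmup_steps
    warmup_rate = slope * steps + lr_start

    # After warmup: cosine from 1 to -1, shifted/scaled to decay lr_max -> 0.
    cos_annealed = np.cos(np.pi * (steps - warmup_steps) / (total_steps - warmup_steps))
    decayed = lr_max / 2 * (1 + cos_annealed)

    lr = np.where(steps < warmup_steps, warmup_rate, decayed)
    print(lr[0], lr[warmup_steps], lr[-1])  # ~1e-05, 0.001, ~0.0
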
10851066 " )\n " ,
10861067 " \n " ,
10871068 " # Get the optimizer.\n " ,
1088- " optimizer = tfa .optimizers.AdamW(\n " ,
1069+ " optimizer = keras .optimizers.AdamW(\n " ,
10891070 " learning_rate=scheduled_lrs, weight_decay=config.weight_decay\n " ,
10901071 " )\n " ,
10911072 " \n " ,
@@ -1142,7 +1123,7 @@
    },
    "outputs": [],
    "source": [
-    "model.save(\"ShiftViT\")"
+    "model.export(\"ShiftViT\")"
    ]
   },
   {
@@ -1192,12 +1173,9 @@
    },
    "outputs": [],
    "source": [
-    "# Custom objects are not included when the model is saved.\n",
-    "# At loading time, these objects need to be passed for reconstruction of the model\n",
-    "saved_model = tf.keras.models.load_model(\n",
-    "    \"ShiftViT\",\n",
-    "    custom_objects={\"WarmUpCosine\": WarmUpCosine, \"AdamW\": tfa.optimizers.AdamW},\n",
-    ")"
+    "# Using TFSMLayer to reload the TF SavedModel as a Keras layer.\n",
+    "# This is not limited to SavedModels that originate from Keras – it will work with any SavedModel, e.g. TF-Hub models.\n",
+    "saved_model = keras.layers.TFSMLayer(\"ShiftViT\", call_endpoint=\"serving_default\")"
    ]
   },
   {
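
Note: once wrapped in `TFSMLayer`, the exported SavedModel behaves like an ordinary inference-only layer. A usage sketch, assuming the 32x32x3 CIFAR-10-sized inputs this example trains on and that the endpoint returns a dict of named outputs (which is what the `predict()` helper later in this notebook assumes):

    import numpy as np
    import keras

    reloaded = keras.layers.TFSMLayer("ShiftViT", call_endpoint="serving_default")

    # Dummy batch shaped like the example's inputs.
    dummy = np.random.uniform(0, 255, size=(4, 32, 32, 3)).astype("float32")
    outputs = reloaded(dummy)

    # `serving_default` typically returns a dict keyed by the signature's output
    # name, which is why `predict()` below unpacks `output_dict.values()`.
    print({key: value.shape for key, value in outputs.items()})
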
12261204 " img = tf.io.decode_jpeg(img, channels=3)\n " ,
12271205 " \n " ,
12281206 " # resize image to match input size accepted by model\n " ,
1229- " # use `method ` as `nearest` to preserve dtype of input passed to `resize()`\n " ,
1230- " img = tf .image.resize(\n " ,
1231- " img, [config.input_shape[0], config.input_shape[1]], method =\" nearest\"\n " ,
1207+ " # use `interpolation ` as `nearest` to preserve dtype of input passed to `resize()`\n " ,
1208+ " img = ops .image.resize(\n " ,
1209+ " img, [config.input_shape[0], config.input_shape[1]], interpolation =\" nearest\"\n " ,
12321210 " )\n " ,
12331211 " return img\n " ,
12341212 " \n " ,
12501228 " \n " ,
12511229 " def predict(predict_ds):\n " ,
12521230 " # ShiftViT model returns logits (non-normalized predictions)\n " ,
1253- " logits = saved_model.predict(predict_ds)\n " ,
1231+ " model = keras.Sequential([saved_model])\n " ,
1232+ " output_dict = model.predict(predict_ds)\n " ,
1233+ " logits = list(output_dict.values())[0]\n " ,
12541234 " \n " ,
12551235 " # normalize predictions by calling softmax()\n " ,
1256- " probabilities = tf.nn .softmax(logits)\n " ,
1236+ " probabilities = ops .softmax(logits)\n " ,
12571237 " return probabilities\n " ,
12581238 " \n " ,
12591239 " \n " ,
12701250 " config.label_map[label]: np.round((probabilities[label]) * 100, 2)\n " ,
12711251 " for label in labels\n " ,
12721252 " }\n " ,
1273- " return confidences\n " ,
1274- " "
1253+ " return confidences\n "
12751254 ]
12761255 },
12771256 {
@@ -1398,4 +1377,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 0
-}
+}