Skip to content

Commit 4a08dfb

Browse files
authored
A Vision Transformer without Attention example to Keras 3 (#2000)
* shiftvit migrated to keras 3 * gelu tf ops fixed * generated files are added * keras3 tag added
1 parent 990c611 commit 4a08dfb

File tree

4 files changed

+129
-158
lines changed

4 files changed

+129
-158
lines changed

examples/vision/ipynb/shiftvit.ipynb

Lines changed: 49 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"\n",
1111
"**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498), [Ritwik Raha](https://twitter.com/ritwik_raha), [Shivalika Singh](https://www.linkedin.com/in/shivalika-singh/)<br>\n",
1212
"**Date created:** 2022/02/24<br>\n",
13-
"**Last modified:** 2022/10/15<br>\n",
13+
"**Last modified:** 2024/12/06<br>\n",
1414
"**Description:** A minimal implementation of ShiftViT."
1515
]
1616
},
@@ -39,19 +39,7 @@
3939
"In this example, we minimally implement the paper with close alignement to the author's\n",
4040
"[official implementation](https://github.com/microsoft/SPACH/blob/main/models/shiftvit.py).\n",
4141
"\n",
42-
"This example requires TensorFlow 2.9 or higher, as well as TensorFlow Addons, which can\n",
43-
"be installed using the following command:"
44-
]
45-
},
46-
{
47-
"cell_type": "code",
48-
"execution_count": 0,
49-
"metadata": {
50-
"colab_type": "code"
51-
},
52-
"outputs": [],
53-
"source": [
54-
"!pip install -qq -U tensorflow-addons"
42+
"This example requires TensorFlow 2.9 or higher."
5543
]
5644
},
5745
{
@@ -74,10 +62,10 @@
7462
"import numpy as np\n",
7563
"import matplotlib.pyplot as plt\n",
7664
"\n",
65+
"import keras\n",
66+
"from keras import ops\n",
67+
"from keras import layers\n",
7768
"import tensorflow as tf\n",
78-
"from tensorflow import keras\n",
79-
"from tensorflow.keras import layers\n",
80-
"import tensorflow_addons as tfa\n",
8169
"\n",
8270
"import pathlib\n",
8371
"import glob\n",
@@ -237,8 +225,7 @@
237225
" layers.Rescaling(1 / 255.0),\n",
238226
" ]\n",
239227
" )\n",
240-
" return data_augmentation\n",
241-
""
228+
" return data_augmentation\n"
242229
]
243230
},
244231
{
@@ -341,7 +328,7 @@
341328
" [\n",
342329
" layers.Dense(\n",
343330
" units=initial_filters,\n",
344-
" activation=tf.nn.gelu,\n",
331+
" activation=\"gelu\",\n",
345332
" ),\n",
346333
" layers.Dropout(rate=self.mlp_dropout_rate),\n",
347334
" layers.Dense(units=input_channels),\n",
@@ -351,8 +338,7 @@
351338
"\n",
352339
" def call(self, x):\n",
353340
" x = self.mlp(x)\n",
354-
" return x\n",
355-
""
341+
" return x\n"
356342
]
357343
},
358344
{
@@ -389,16 +375,18 @@
389375
" def __init__(self, drop_path_prob, **kwargs):\n",
390376
" super().__init__(**kwargs)\n",
391377
" self.drop_path_prob = drop_path_prob\n",
378+
" self.seed_generator = keras.random.SeedGenerator(1337)\n",
392379
"\n",
393380
" def call(self, x, training=False):\n",
394381
" if training:\n",
395382
" keep_prob = 1 - self.drop_path_prob\n",
396-
" shape = (tf.shape(x)[0],) + (1,) * (len(tf.shape(x)) - 1)\n",
397-
" random_tensor = keep_prob + tf.random.uniform(shape, 0, 1)\n",
398-
" random_tensor = tf.floor(random_tensor)\n",
383+
" shape = (ops.shape(x)[0],) + (1,) * (len(ops.shape(x)) - 1)\n",
384+
" random_tensor = keep_prob + keras.random.uniform(\n",
385+
" shape, 0, 1, seed=self.seed_generator\n",
386+
" )\n",
387+
" random_tensor = ops.floor(random_tensor)\n",
399388
" return (x / keep_prob) * random_tensor\n",
400-
" return x\n",
401-
""
389+
" return x\n"
402390
]
403391
},
404392
{
@@ -523,25 +511,25 @@
523511
" offset_width = 0\n",
524512
" target_height = self.shift_pixel\n",
525513
" target_width = 0\n",
526-
" crop = tf.image.crop_to_bounding_box(\n",
514+
" crop = ops.image.crop_images(\n",
527515
" x,\n",
528-
" offset_height=offset_height,\n",
529-
" offset_width=offset_width,\n",
516+
" top_cropping=offset_height,\n",
517+
" left_cropping=offset_width,\n",
530518
" target_height=self.H - target_height,\n",
531519
" target_width=self.W - target_width,\n",
532520
" )\n",
533-
" shift_pad = tf.image.pad_to_bounding_box(\n",
521+
" shift_pad = ops.image.pad_images(\n",
534522
" crop,\n",
535-
" offset_height=offset_height,\n",
536-
" offset_width=offset_width,\n",
523+
" top_padding=offset_height,\n",
524+
" left_padding=offset_width,\n",
537525
" target_height=self.H,\n",
538526
" target_width=self.W,\n",
539527
" )\n",
540528
" return shift_pad\n",
541529
"\n",
542530
" def call(self, x, training=False):\n",
543531
" # Split the feature maps\n",
544-
" x_splits = tf.split(x, num_or_size_splits=self.C // self.num_div, axis=-1)\n",
532+
" x_splits = ops.split(x, indices_or_sections=self.C // self.num_div, axis=-1)\n",
545533
"\n",
546534
" # Shift the feature maps\n",
547535
" x_splits[0] = self.get_shift_pad(x_splits[0], mode=\"left\")\n",
@@ -550,13 +538,12 @@
550538
" x_splits[3] = self.get_shift_pad(x_splits[3], mode=\"down\")\n",
551539
"\n",
552540
" # Concatenate the shifted and unshifted feature maps\n",
553-
" x = tf.concat(x_splits, axis=-1)\n",
541+
" x = ops.concatenate(x_splits, axis=-1)\n",
554542
"\n",
555543
" # Add the residual connection\n",
556544
" shortcut = x\n",
557545
" x = shortcut + self.drop_path(self.mlp(self.layer_norm(x)), training=training)\n",
558-
" return x\n",
559-
""
546+
" return x\n"
560547
]
561548
},
562549
{
@@ -622,8 +609,7 @@
622609
" # Apply the patch merging algorithm on the feature maps\n",
623610
" x = self.layer_norm(x)\n",
624611
" x = self.reduction(x)\n",
625-
" return x\n",
626-
""
612+
" return x\n"
627613
]
628614
},
629615
{
@@ -737,8 +723,7 @@
737723
" \"mlp_expand_ratio\": self.mlp_expand_ratio,\n",
738724
" }\n",
739725
" )\n",
740-
" return config\n",
741-
""
726+
" return config\n"
742727
]
743728
},
744729
{
@@ -903,8 +888,7 @@
903888
" x = stage(x, training=False)\n",
904889
" x = self.global_avg_pool(x)\n",
905890
" logits = self.classifier(x)\n",
906-
" return logits\n",
907-
""
891+
" return logits\n"
908892
]
909893
},
910894
{
@@ -979,7 +963,7 @@
979963
" self.lr_max = lr_max\n",
980964
" self.warmup_steps = warmup_steps\n",
981965
" self.total_steps = total_steps\n",
982-
" self.pi = tf.constant(np.pi)\n",
966+
" self.pi = ops.array(np.pi)\n",
983967
"\n",
984968
" def __call__(self, step):\n",
985969
" # Check whether the total number of steps is larger than the warmup\n",
@@ -993,10 +977,10 @@
993977
" # `cos_annealed_lr` is a graph that increases to 1 from the initial\n",
994978
" # step to the warmup step. After that this graph decays to -1 at the\n",
995979
" # final step mark.\n",
996-
" cos_annealed_lr = tf.cos(\n",
980+
" cos_annealed_lr = ops.cos(\n",
997981
" self.pi\n",
998-
" * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
999-
" / tf.cast(self.total_steps - self.warmup_steps, tf.float32)\n",
982+
" * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n",
983+
" / ops.cast(self.total_steps - self.warmup_steps, dtype=\"float32\")\n",
1000984
" )\n",
1001985
"\n",
1002986
" # Shift the mean of the `cos_annealed_lr` graph to 1. Now the grpah goes\n",
@@ -1021,20 +1005,18 @@
10211005
"\n",
10221006
" # With the formula for a straight line (y = mx+c) build the warmup\n",
10231007
" # schedule\n",
1024-
" warmup_rate = slope * tf.cast(step, tf.float32) + self.lr_start\n",
1008+
" warmup_rate = slope * ops.cast(step, dtype=\"float32\") + self.lr_start\n",
10251009
"\n",
10261010
" # When the current step is lesser that warmup steps, get the line\n",
10271011
" # graph. When the current step is greater than the warmup steps, get\n",
10281012
" # the scaled cos graph.\n",
1029-
" learning_rate = tf.where(\n",
1013+
" learning_rate = ops.where(\n",
10301014
" step < self.warmup_steps, warmup_rate, learning_rate\n",
10311015
" )\n",
10321016
"\n",
10331017
" # When the current step is more that the total steps, return 0 else return\n",
10341018
" # the calculated graph.\n",
1035-
" return tf.where(\n",
1036-
" step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
1037-
" )\n",
1019+
" return ops.where(step > self.total_steps, 0.0, learning_rate)\n",
10381020
"\n",
10391021
" def get_config(self):\n",
10401022
" config = {\n",
@@ -1043,8 +1025,7 @@
10431025
" \"total_steps\": self.total_steps,\n",
10441026
" \"warmup_steps\": self.warmup_steps,\n",
10451027
" }\n",
1046-
" return config\n",
1047-
""
1028+
" return config\n"
10481029
]
10491030
},
10501031
{
@@ -1085,7 +1066,7 @@
10851066
")\n",
10861067
"\n",
10871068
"# Get the optimizer.\n",
1088-
"optimizer = tfa.optimizers.AdamW(\n",
1069+
"optimizer = keras.optimizers.AdamW(\n",
10891070
" learning_rate=scheduled_lrs, weight_decay=config.weight_decay\n",
10901071
")\n",
10911072
"\n",
@@ -1142,7 +1123,7 @@
11421123
},
11431124
"outputs": [],
11441125
"source": [
1145-
"model.save(\"ShiftViT\")"
1126+
"model.export(\"ShiftViT\")"
11461127
]
11471128
},
11481129
{
@@ -1192,12 +1173,9 @@
11921173
},
11931174
"outputs": [],
11941175
"source": [
1195-
"# Custom objects are not included when the model is saved.\n",
1196-
"# At loading time, these objects need to be passed for reconstruction of the model\n",
1197-
"saved_model = tf.keras.models.load_model(\n",
1198-
" \"ShiftViT\",\n",
1199-
" custom_objects={\"WarmUpCosine\": WarmUpCosine, \"AdamW\": tfa.optimizers.AdamW},\n",
1200-
")"
1176+
"# Using TFSMLayer to reload the TF SavedModel as a Keras layer.\n",
1177+
"# This is not limited to SavedModels that originate from Keras – it will work with any SavedModel, e.g. TF-Hub models.\n",
1178+
"saved_model = keras.layers.TFSMLayer(\"ShiftViT\", call_endpoint=\"serving_default\")"
12011179
]
12021180
},
12031181
{
@@ -1226,9 +1204,9 @@
12261204
" img = tf.io.decode_jpeg(img, channels=3)\n",
12271205
"\n",
12281206
" # resize image to match input size accepted by model\n",
1229-
" # use `method` as `nearest` to preserve dtype of input passed to `resize()`\n",
1230-
" img = tf.image.resize(\n",
1231-
" img, [config.input_shape[0], config.input_shape[1]], method=\"nearest\"\n",
1207+
" # use `interpolation` as `nearest` to preserve dtype of input passed to `resize()`\n",
1208+
" img = ops.image.resize(\n",
1209+
" img, [config.input_shape[0], config.input_shape[1]], interpolation=\"nearest\"\n",
12321210
" )\n",
12331211
" return img\n",
12341212
"\n",
@@ -1250,10 +1228,12 @@
12501228
"\n",
12511229
"def predict(predict_ds):\n",
12521230
" # ShiftViT model returns logits (non-normalized predictions)\n",
1253-
" logits = saved_model.predict(predict_ds)\n",
1231+
" model = keras.Sequential([saved_model])\n",
1232+
" output_dict = model.predict(predict_ds)\n",
1233+
" logits = list(output_dict.values())[0]\n",
12541234
"\n",
12551235
" # normalize predictions by calling softmax()\n",
1256-
" probabilities = tf.nn.softmax(logits)\n",
1236+
" probabilities = ops.softmax(logits)\n",
12571237
" return probabilities\n",
12581238
"\n",
12591239
"\n",
@@ -1270,8 +1250,7 @@
12701250
" config.label_map[label]: np.round((probabilities[label]) * 100, 2)\n",
12711251
" for label in labels\n",
12721252
" }\n",
1273-
" return confidences\n",
1274-
""
1253+
" return confidences\n"
12751254
]
12761255
},
12771256
{
@@ -1398,4 +1377,4 @@
13981377
},
13991378
"nbformat": 4,
14001379
"nbformat_minor": 0
1401-
}
1380+
}

0 commit comments

Comments
 (0)