
Commit 8a55f0a

Migrating Train a Vision Transformer on small datasets example to Keras 3 (#1991)
* migrate vit small dataset example to keras3
* requested changes added and backend agnostic done
* other generated files are added too
1 parent ede75bc commit 8a55f0a

File tree

3 files changed: +121 −133 lines changed

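The substance of the change is mechanical: backend-specific `tf.*` and `tensorflow_addons` calls are swapped for their backend-agnostic `keras.ops` and core-Keras equivalents. A minimal sketch of the substitution pattern, distilled from the hunks below (illustrative only, not part of the commit):

```python
import numpy as np
import keras
from keras import ops

x = ops.convert_to_tensor(np.ones((2, 4), dtype="float32"))

# tf.concat([...], axis=-1)   ->  ops.concatenate([...], axis=-1)
y = ops.concatenate([x, x], axis=-1)

# tf.cast(x, tf.float32)      ->  ops.cast(x, dtype="float32")
y = ops.cast(y, dtype="float32")

# tf.eye(n) / tf.nn.gelu      ->  ops.eye(n) / activation="gelu"
mask = 1 - ops.eye(4)
dense = keras.layers.Dense(8, activation="gelu")

# tfa.optimizers.AdamW(...)   ->  keras.optimizers.AdamW(...)
optimizer = keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=1e-4)
```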

examples/vision/ipynb/vit_small_ds.ipynb

+44 −51
@@ -10,7 +10,7 @@
 "\n",
 "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)<br>\n",
 "**Date created:** 2022/01/07<br>\n",
-"**Last modified:** 2022/01/10<br>\n",
+"**Last modified:** 2024/11/27<br>\n",
 "**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention."
 ]
 },
@@ -47,13 +47,7 @@
 "example is inspired from\n",
 "[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).\n",
 "\n",
-"_Note_: This example requires TensorFlow 2.6 or higher, as well as\n",
-"[TensorFlow Addons](https://www.tensorflow.org/addons), which can be\n",
-"installed using the following command:\n",
-"\n",
-"```python\n",
-"pip install -qq -U tensorflow-addons\n",
-"```"
+"_Note_: This example requires TensorFlow 2.6 or higher."
 ]
 },
 {
@@ -75,11 +69,11 @@
 "source": [
 "import math\n",
 "import numpy as np\n",
+"import keras\n",
+"from keras import ops\n",
+"from keras import layers\n",
 "import tensorflow as tf\n",
-"from tensorflow import keras\n",
-"import tensorflow_addons as tfa\n",
 "import matplotlib.pyplot as plt\n",
-"from tensorflow.keras import layers\n",
 "\n",
 "# Setting seed for reproducibiltiy\n",
 "SEED = 42\n",
@@ -279,17 +273,17 @@
 "            shift_width = self.half_patch\n",
 "\n",
 "        # Crop the shifted images and pad them\n",
-"        crop = tf.image.crop_to_bounding_box(\n",
+"        crop = ops.image.crop_images(\n",
 "            images,\n",
-"            offset_height=crop_height,\n",
-"            offset_width=crop_width,\n",
+"            top_cropping=crop_height,\n",
+"            left_cropping=crop_width,\n",
 "            target_height=self.image_size - self.half_patch,\n",
 "            target_width=self.image_size - self.half_patch,\n",
 "        )\n",
-"        shift_pad = tf.image.pad_to_bounding_box(\n",
+"        shift_pad = ops.image.pad_images(\n",
 "            crop,\n",
-"            offset_height=shift_height,\n",
-"            offset_width=shift_width,\n",
+"            top_padding=shift_height,\n",
+"            left_padding=shift_width,\n",
 "            target_height=self.image_size,\n",
 "            target_width=self.image_size,\n",
 "        )\n",
@@ -298,7 +292,7 @@
 "    def call(self, images):\n",
 "        if not self.vanilla:\n",
 "            # Concat the shifted images with the original image\n",
-"            images = tf.concat(\n",
+"            images = ops.concatenate(\n",
 "                [\n",
 "                    images,\n",
 "                    self.crop_shift_pad(images, mode=\"left-up\"),\n",
@@ -309,11 +303,11 @@
 "                axis=-1,\n",
 "            )\n",
 "        # Patchify the images and flatten it\n",
-"        patches = tf.image.extract_patches(\n",
+"        patches = ops.image.extract_patches(\n",
 "            images=images,\n",
-"            sizes=[1, self.patch_size, self.patch_size, 1],\n",
+"            size=(self.patch_size, self.patch_size),\n",
 "            strides=[1, self.patch_size, self.patch_size, 1],\n",
-"            rates=[1, 1, 1, 1],\n",
+"            dilation_rate=1,\n",
 "            padding=\"VALID\",\n",
 "        )\n",
 "        flat_patches = self.flatten_patches(patches)\n",
@@ -324,8 +318,7 @@
 "        else:\n",
 "            # Linearly project the flat patches\n",
 "            tokens = self.projection(flat_patches)\n",
-"        return (tokens, patches)\n",
-""
+"        return (tokens, patches)\n"
 ]
 },
 {
@@ -348,8 +341,9 @@
 "# Get a random image from the training dataset\n",
 "# and resize the image\n",
 "image = x_train[np.random.choice(range(x_train.shape[0]))]\n",
-"resized_image = tf.image.resize(\n",
-"    tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)\n",
+"resized_image = ops.cast(\n",
+"    ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),\n",
+"    dtype=\"float32\",\n",
 ")\n",
 "\n",
 "# Vanilla patch maker: This takes an image and divides into\n",
@@ -363,7 +357,7 @@
 "    for col in range(n):\n",
 "        plt.subplot(n, n, count)\n",
 "        count = count + 1\n",
-"        image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
+"        image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
 "        plt.imshow(image)\n",
 "        plt.axis(\"off\")\n",
 "plt.show()\n",
@@ -382,7 +376,7 @@
 "        for col in range(n):\n",
 "            plt.subplot(n, n, count)\n",
 "            count = count + 1\n",
-"            image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
+"            image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
 "            plt.imshow(image[..., 3 * index : 3 * index + 3])\n",
 "            plt.axis(\"off\")\n",
 "    plt.show()"
@@ -418,13 +412,12 @@
 "        self.position_embedding = layers.Embedding(\n",
 "            input_dim=num_patches, output_dim=projection_dim\n",
 "        )\n",
-"        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
+"        self.positions = ops.arange(start=0, stop=self.num_patches, step=1)\n",
 "\n",
 "    def call(self, encoded_patches):\n",
 "        encoded_positions = self.position_embedding(self.positions)\n",
 "        encoded_patches = encoded_patches + encoded_positions\n",
-"        return encoded_patches\n",
-""
+"        return encoded_patches\n"
 ]
 },
 {
@@ -479,25 +472,24 @@
 "outputs": [],
 "source": [
 "\n",
-"class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):\n",
+"class MultiHeadAttentionLSA(layers.MultiHeadAttention):\n",
 "    def __init__(self, **kwargs):\n",
 "        super().__init__(**kwargs)\n",
 "        # The trainable temperature term. The initial value is\n",
 "        # the square root of the key dimension.\n",
-"        self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
+"        self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
 "\n",
 "    def _compute_attention(self, query, key, value, attention_mask=None, training=None):\n",
-"        query = tf.multiply(query, 1.0 / self.tau)\n",
-"        attention_scores = tf.einsum(self._dot_product_equation, key, query)\n",
+"        query = ops.multiply(query, 1.0 / self.tau)\n",
+"        attention_scores = ops.einsum(self._dot_product_equation, key, query)\n",
 "        attention_scores = self._masked_softmax(attention_scores, attention_mask)\n",
 "        attention_scores_dropout = self._dropout_layer(\n",
 "            attention_scores, training=training\n",
 "        )\n",
-"        attention_output = tf.einsum(\n",
+"        attention_output = ops.einsum(\n",
 "            self._combine_equation, attention_scores_dropout, value\n",
 "        )\n",
-"        return attention_output, attention_scores\n",
-""
+"        return attention_output, attention_scores\n"
 ]
 },
 {
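For context on the hunk above: locality self-attention (LSA) replaces the usual fixed 1/sqrt(d_k) scale with a learnable temperature `tau`, initialized to sqrt(key_dim), applied to the query before the query-key dot product. A standalone sketch in `keras.ops`, with hypothetical shapes and an explicit einsum standing in for the layer's internal `_dot_product_equation`:

```python
import math
from keras import ops

def lsa_scores(query, key, tau):
    # Scale queries by the learnable temperature instead of 1/sqrt(d_k),
    # then take the query-key dot product over the feature dimension.
    query = ops.multiply(query, 1.0 / tau)
    return ops.einsum("bhqd,bhkd->bhqk", query, key)

tau = math.sqrt(64.0)  # initial value for key_dim=64
scores = lsa_scores(ops.ones((1, 4, 16, 64)), ops.ones((1, 4, 16, 64)), tau)
print(scores.shape)  # -> (1, 4, 16, 16)
```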
@@ -520,14 +512,14 @@
 "\n",
 "def mlp(x, hidden_units, dropout_rate):\n",
 "    for units in hidden_units:\n",
-"        x = layers.Dense(units, activation=tf.nn.gelu)(x)\n",
+"        x = layers.Dense(units, activation=\"gelu\")(x)\n",
 "        x = layers.Dropout(dropout_rate)(x)\n",
 "    return x\n",
 "\n",
 "\n",
 "# Build the diagonal attention mask\n",
-"diag_attn_mask = 1 - tf.eye(NUM_PATCHES)\n",
-"diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)"
+"diag_attn_mask = 1 - ops.eye(NUM_PATCHES)\n",
+"diag_attn_mask = ops.cast([diag_attn_mask], dtype=\"int8\")"
 ]
 },
 {
@@ -589,8 +581,7 @@
 "    logits = layers.Dense(NUM_CLASSES)(features)\n",
 "    # Create the Keras model.\n",
 "    model = keras.Model(inputs=inputs, outputs=logits)\n",
-"    return model\n",
-""
+"    return model\n"
 ]
 },
 {
@@ -622,15 +613,15 @@
 "        self.total_steps = total_steps\n",
 "        self.warmup_learning_rate = warmup_learning_rate\n",
 "        self.warmup_steps = warmup_steps\n",
-"        self.pi = tf.constant(np.pi)\n",
+"        self.pi = ops.array(np.pi)\n",
 "\n",
 "    def __call__(self, step):\n",
 "        if self.total_steps < self.warmup_steps:\n",
 "            raise ValueError(\"Total_steps must be larger or equal to warmup_steps.\")\n",
 "\n",
-"        cos_annealed_lr = tf.cos(\n",
+"        cos_annealed_lr = ops.cos(\n",
 "            self.pi\n",
-"            * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
+"            * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n",
 "            / float(self.total_steps - self.warmup_steps)\n",
 "        )\n",
 "        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)\n",
@@ -644,11 +635,13 @@
 "            slope = (\n",
 "                self.learning_rate_base - self.warmup_learning_rate\n",
 "            ) / self.warmup_steps\n",
-"            warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
-"            learning_rate = tf.where(\n",
+"            warmup_rate = (\n",
+"                slope * ops.cast(step, dtype=\"float32\") + self.warmup_learning_rate\n",
+"            )\n",
+"            learning_rate = ops.where(\n",
 "                step < self.warmup_steps, warmup_rate, learning_rate\n",
 "            )\n",
-"        return tf.where(\n",
+"        return ops.where(\n",
 "            step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
 "        )\n",
 "\n",
@@ -664,7 +657,7 @@
 "        warmup_steps=warmup_steps,\n",
 "    )\n",
 "\n",
-"    optimizer = tfa.optimizers.AdamW(\n",
+"    optimizer = keras.optimizers.AdamW(\n",
 "        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
 "    )\n",
 "\n",
@@ -720,7 +713,7 @@
 "I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for\n",
 "generously helping with GPU credits.\n",
 "\n",
-"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) ",
+"You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2) \n",
 "and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds)."
 ]
 }
@@ -754,4 +747,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
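Because the migrated notebook now routes tensor work through `keras.ops`, it should run unchanged on any Keras 3 backend. A quick smoke test of that claim (hypothetical, not part of this commit) is to select the backend before importing keras:

```python
import os

# Must be set before keras is imported; "tensorflow" and "torch" work the same way.
os.environ["KERAS_BACKEND"] = "jax"

import keras
from keras import ops

print(keras.backend.backend())  # -> "jax"
print(ops.concatenate([ops.ones((2,)), ops.zeros((2,))]))
```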
