@@ -10,7 +10,7 @@
     "\n",
     "**Author:** [Aritra Roy Gosthipaty](https://twitter.com/ariG23498)<br>\n",
     "**Date created:** 2022/01/07<br>\n",
-    "**Last modified:** 2022/01/10<br>\n",
+    "**Last modified:** 2024/11/27<br>\n",
     "**Description:** Training a ViT from scratch on smaller datasets with shifted patch tokenization and locality self-attention."
    ]
   },
@@ -47,13 +47,7 @@
     "example is inspired by\n",
     "[Image classification with Vision Transformer](https://keras.io/examples/vision/image_classification_with_vision_transformer/).\n",
     "\n",
-    "_Note_: This example requires TensorFlow 2.6 or higher, as well as\n",
-    "[TensorFlow Addons](https://www.tensorflow.org/addons), which can be\n",
-    "installed using the following command:\n",
-    "\n",
-    "```python\n",
-    "pip install -qq -U tensorflow-addons\n",
-    "```"
+    "_Note_: This example requires TensorFlow 2.6 or higher."
    ]
   },
   {
@@ -75,11 +69,11 @@
    "source": [
     "import math\n",
     "import numpy as np\n",
+    "import keras\n",
+    "from keras import ops\n",
+    "from keras import layers\n",
     "import tensorflow as tf\n",
-    "from tensorflow import keras\n",
-    "import tensorflow_addons as tfa\n",
     "import matplotlib.pyplot as plt\n",
-    "from tensorflow.keras import layers\n",
     "\n",
     "# Setting seed for reproducibility\n",
     "SEED = 42\n",
@@ -279,17 +273,17 @@
     "            shift_width = self.half_patch\n",
     "\n",
     "        # Crop the shifted images and pad them\n",
-    "        crop = tf.image.crop_to_bounding_box(\n",
+    "        crop = ops.image.crop_images(\n",
     "            images,\n",
-    "            offset_height=crop_height,\n",
-    "            offset_width=crop_width,\n",
+    "            top_cropping=crop_height,\n",
+    "            left_cropping=crop_width,\n",
     "            target_height=self.image_size - self.half_patch,\n",
     "            target_width=self.image_size - self.half_patch,\n",
     "        )\n",
-    "        shift_pad = tf.image.pad_to_bounding_box(\n",
+    "        shift_pad = ops.image.pad_images(\n",
     "            crop,\n",
-    "            offset_height=shift_height,\n",
-    "            offset_width=shift_width,\n",
+    "            top_padding=shift_height,\n",
+    "            left_padding=shift_width,\n",
     "            target_height=self.image_size,\n",
     "            target_width=self.image_size,\n",
     "        )\n",
@@ -298,7 +292,7 @@
     "    def call(self, images):\n",
     "        if not self.vanilla:\n",
     "            # Concat the shifted images with the original image\n",
-    "            images = tf.concat(\n",
+    "            images = ops.concatenate(\n",
     "                [\n",
     "                    images,\n",
     "                    self.crop_shift_pad(images, mode=\"left-up\"),\n",
@@ -309,11 +303,11 @@
     "                axis=-1,\n",
     "            )\n",
     "        # Patchify the images and flatten it\n",
-    "        patches = tf.image.extract_patches(\n",
+    "        patches = ops.image.extract_patches(\n",
     "            images=images,\n",
-    "            sizes=[1, self.patch_size, self.patch_size, 1],\n",
+    "            size=(self.patch_size, self.patch_size),\n",
     "            strides=[1, self.patch_size, self.patch_size, 1],\n",
-    "            rates=[1, 1, 1, 1],\n",
+    "            dilation_rate=1,\n",
     "            padding=\"VALID\",\n",
     "        )\n",
     "        flat_patches = self.flatten_patches(patches)\n",
@@ -324,8 +318,7 @@
     "        else:\n",
     "            # Linearly project the flat patches\n",
     "            tokens = self.projection(flat_patches)\n",
-    "        return (tokens, patches)\n",
-    ""
+    "        return (tokens, patches)\n"
    ]
   },
   {
@@ -348,8 +341,9 @@
     "# Get a random image from the training dataset\n",
     "# and resize the image\n",
     "image = x_train[np.random.choice(range(x_train.shape[0]))]\n",
-    "resized_image = tf.image.resize(\n",
-    "    tf.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)\n",
+    "resized_image = ops.cast(\n",
+    "    ops.image.resize(ops.convert_to_tensor([image]), size=(IMAGE_SIZE, IMAGE_SIZE)),\n",
+    "    dtype=\"float32\",\n",
     ")\n",
     "\n",
     "# Vanilla patch maker: This takes an image and divides into\n",
@@ -363,7 +357,7 @@
     "    for col in range(n):\n",
     "        plt.subplot(n, n, count)\n",
     "        count = count + 1\n",
-    "        image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
+    "        image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 3))\n",
     "        plt.imshow(image)\n",
     "        plt.axis(\"off\")\n",
     "plt.show()\n",
@@ -382,7 +376,7 @@
     "    for col in range(n):\n",
     "        plt.subplot(n, n, count)\n",
     "        count = count + 1\n",
-    "        image = tf.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
+    "        image = ops.reshape(patch[row][col], (PATCH_SIZE, PATCH_SIZE, 5 * 3))\n",
     "        plt.imshow(image[..., 3 * index : 3 * index + 3])\n",
     "        plt.axis(\"off\")\n",
     "plt.show()"
@@ -418,13 +412,12 @@
     "        self.position_embedding = layers.Embedding(\n",
     "            input_dim=num_patches, output_dim=projection_dim\n",
     "        )\n",
-    "        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)\n",
+    "        self.positions = ops.arange(start=0, stop=self.num_patches, step=1)\n",
     "\n",
     "    def call(self, encoded_patches):\n",
     "        encoded_positions = self.position_embedding(self.positions)\n",
     "        encoded_patches = encoded_patches + encoded_positions\n",
-    "        return encoded_patches\n",
-    ""
+    "        return encoded_patches\n"
    ]
   },
   {
@@ -479,25 +472,24 @@
    "outputs": [],
    "source": [
     "\n",
-    "class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):\n",
+    "class MultiHeadAttentionLSA(layers.MultiHeadAttention):\n",
     "    def __init__(self, **kwargs):\n",
     "        super().__init__(**kwargs)\n",
     "        # The trainable temperature term. The initial value is\n",
     "        # the square root of the key dimension.\n",
-    "        self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
+    "        self.tau = keras.Variable(math.sqrt(float(self._key_dim)), trainable=True)\n",
     "\n",
     "    def _compute_attention(self, query, key, value, attention_mask=None, training=None):\n",
-    "        query = tf.multiply(query, 1.0 / self.tau)\n",
-    "        attention_scores = tf.einsum(self._dot_product_equation, key, query)\n",
+    "        query = ops.multiply(query, 1.0 / self.tau)\n",
+    "        attention_scores = ops.einsum(self._dot_product_equation, key, query)\n",
     "        attention_scores = self._masked_softmax(attention_scores, attention_mask)\n",
     "        attention_scores_dropout = self._dropout_layer(\n",
     "            attention_scores, training=training\n",
     "        )\n",
-    "        attention_output = tf.einsum(\n",
+    "        attention_output = ops.einsum(\n",
     "            self._combine_equation, attention_scores_dropout, value\n",
     "        )\n",
-    "        return attention_output, attention_scores\n",
-    ""
+    "        return attention_output, attention_scores\n"
    ]
   },
   {
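The subclass above implements half of Locality Self-Attention: it divides the query by a learnable temperature `tau`, initialised to `sqrt(key_dim)`, the usual fixed scale, before the softmax. A toy numeric sketch of just the scaling idea (shapes and the einsum string are illustrative, not the private equations `MultiHeadAttention` uses internally):

```python
import math
import numpy as np
from keras import ops

d_k = 4
q = np.random.rand(1, 3, d_k).astype("float32")  # (batch, tokens, key_dim)
k = np.random.rand(1, 3, d_k).astype("float32")

# Starting value; training can shrink tau below sqrt(d_k),
# which sharpens the attention distribution.
tau = math.sqrt(d_k)
scores = ops.einsum("bqd,bkd->bqk", q / tau, k)
weights = ops.softmax(scores, axis=-1)
print(ops.shape(weights))  # (1, 3, 3)
```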
@@ -520,14 +512,14 @@
     "\n",
     "def mlp(x, hidden_units, dropout_rate):\n",
     "    for units in hidden_units:\n",
-    "        x = layers.Dense(units, activation=tf.nn.gelu)(x)\n",
+    "        x = layers.Dense(units, activation=\"gelu\")(x)\n",
     "        x = layers.Dropout(dropout_rate)(x)\n",
     "    return x\n",
     "\n",
     "\n",
     "# Build the diagonal attention mask\n",
-    "diag_attn_mask = 1 - tf.eye(NUM_PATCHES)\n",
-    "diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)"
+    "diag_attn_mask = 1 - ops.eye(NUM_PATCHES)\n",
+    "diag_attn_mask = ops.cast([diag_attn_mask], dtype=\"int8\")"
    ]
   },
   {
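The diagonal mask is the other half of LSA: it stops each token from attending to itself. Keras's `MultiHeadAttention` treats `attention_mask` entries of 0 as positions to block, so the mask can be passed straight to the layer call. A self-contained toy run (layer sizes are illustrative):

```python
import numpy as np
from keras import layers, ops

N, D = 4, 8  # toy patch count and projection width
x = np.random.rand(1, N, D).astype("float32")

diag_attn_mask = 1 - ops.eye(N)                      # 0 on the diagonal, 1 elsewhere
diag_attn_mask = ops.cast([diag_attn_mask], "int8")  # shape (1, N, N)

attention = layers.MultiHeadAttention(num_heads=2, key_dim=D)
output = attention(x, x, attention_mask=diag_attn_mask)
print(output.shape)  # (1, 4, 8)
```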
@@ -589,8 +581,7 @@
     "    logits = layers.Dense(NUM_CLASSES)(features)\n",
     "    # Create the Keras model.\n",
     "    model = keras.Model(inputs=inputs, outputs=logits)\n",
-    "    return model\n",
-    ""
+    "    return model\n"
    ]
   },
   {
@@ -622,15 +613,15 @@
     "        self.total_steps = total_steps\n",
     "        self.warmup_learning_rate = warmup_learning_rate\n",
     "        self.warmup_steps = warmup_steps\n",
-    "        self.pi = tf.constant(np.pi)\n",
+    "        self.pi = ops.array(np.pi)\n",
     "\n",
     "    def __call__(self, step):\n",
     "        if self.total_steps < self.warmup_steps:\n",
     "            raise ValueError(\"Total_steps must be larger than or equal to warmup_steps.\")\n",
     "\n",
-    "        cos_annealed_lr = tf.cos(\n",
+    "        cos_annealed_lr = ops.cos(\n",
     "            self.pi\n",
-    "            * (tf.cast(step, tf.float32) - self.warmup_steps)\n",
+    "            * (ops.cast(step, dtype=\"float32\") - self.warmup_steps)\n",
     "            / float(self.total_steps - self.warmup_steps)\n",
     "        )\n",
     "        learning_rate = 0.5 * self.learning_rate_base * (1 + cos_annealed_lr)\n",
@@ -644,11 +635,13 @@
     "        slope = (\n",
     "            self.learning_rate_base - self.warmup_learning_rate\n",
     "        ) / self.warmup_steps\n",
-    "        warmup_rate = slope * tf.cast(step, tf.float32) + self.warmup_learning_rate\n",
-    "        learning_rate = tf.where(\n",
+    "        warmup_rate = (\n",
+    "            slope * ops.cast(step, dtype=\"float32\") + self.warmup_learning_rate\n",
+    "        )\n",
+    "        learning_rate = ops.where(\n",
     "            step < self.warmup_steps, warmup_rate, learning_rate\n",
     "        )\n",
-    "        return tf.where(\n",
-    "            step > self.total_steps, 0.0, learning_rate, name=\"learning_rate\"\n",
+    "        return ops.where(\n",
+    "            step > self.total_steps, 0.0, learning_rate\n",
     "        )\n",
     "\n",
@@ -664,7 +657,7 @@
     "    warmup_steps=warmup_steps,\n",
     ")\n",
     "\n",
-    "optimizer = tfa.optimizers.AdamW(\n",
+    "optimizer = keras.optimizers.AdamW(\n",
     "    learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY\n",
     ")\n",
     "\n",
@@ -720,7 +713,7 @@
     "I would like to thank [Jarvislabs.ai](https://jarvislabs.ai/) for\n",
     "generously helping with GPU credits.\n",
     "\n",
-    "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2)",
+    "You can use the trained model hosted on [Hugging Face Hub](https://huggingface.co/keras-io/vit_small_ds_v2)\n",
     "and try the demo on [Hugging Face Spaces](https://huggingface.co/spaces/keras-io/vit-small-ds)."
    ]
   }
@@ -754,4 +747,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 0
-}
+}