Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions examples/vision/focal_modulation_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,14 +78,26 @@
import os

# Set backend before importing keras
os.environ["KERAS_BACKEND"] = "tensorflow" # Or "torch" or "tensorflow"
os.environ["KERAS_BACKEND"] = "tensorflow" # or "torch" or "jax"
# Suppress TensorFlow C++ logging (XLA messages)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import keras
from keras import layers
from keras import ops
from matplotlib import pyplot as plt
from random import randint
import warnings
import logging

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*build.*method.*")
warnings.filterwarnings("ignore", message=".*tf.function.*retracing.*")

# Suppress TensorFlow logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)

# Set seed for reproducibility using Keras 3 utility.
keras.utils.set_random_seed(42)
Expand Down Expand Up @@ -837,7 +849,7 @@ def __call__(self, step):
train_ds,
validation_data=val_ds,
epochs=EPOCHS,
callbacks=[TrainMonitor(epoch_interval=5)],
callbacks=[],
)

"""
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 14 additions & 2 deletions examples/vision/ipynb/focal_modulation_network.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -102,14 +102,26 @@
"import os\n",
"\n",
"# Set backend before importing keras\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # Or \"torch\" or \"tensorflow\"\n",
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # or \"torch\" or \"jax\"\n",
"# Suppress TensorFlow C++ logging (XLA messages)\n",
"os.environ[\"TF_CPP_MIN_LOG_LEVEL\"] = \"3\"\n",
"\n",
"import numpy as np\n",
"import keras\n",
"from keras import layers\n",
"from keras import ops\n",
"from matplotlib import pyplot as plt\n",
"from random import randint\n",
"import warnings\n",
"import logging\n",
"\n",
"# Suppress warnings\n",
"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
"warnings.filterwarnings(\"ignore\", message=\".*build.*method.*\")\n",
"warnings.filterwarnings(\"ignore\", message=\".*tf.function.*retracing.*\")\n",
"\n",
"# Suppress TensorFlow logging\n",
"logging.getLogger(\"tensorflow\").setLevel(logging.ERROR)\n",
"\n",
"# Set seed for reproducibility using Keras 3 utility.\n",
"keras.utils.set_random_seed(42)"
Expand Down Expand Up @@ -1059,7 +1071,7 @@
" train_ds,\n",
" validation_data=val_ds,\n",
" epochs=EPOCHS,\n",
" callbacks=[TrainMonitor(epoch_interval=5)],\n",
" callbacks=[],\n",
")"
]
},
Expand Down
127 changes: 37 additions & 90 deletions examples/vision/md/focal_modulation_network.md
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,26 @@ Keras 3 allows this model to run on JAX, PyTorch, or TensorFlow. We use keras.op
import os

# Set backend before importing keras
os.environ["KERAS_BACKEND"] = "tensorflow" # Or "torch" or "tensorflow"
os.environ["KERAS_BACKEND"] = "tensorflow" # or "torch" or "jax"
# Suppress TensorFlow C++ logging (XLA messages)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import numpy as np
import keras
from keras import layers
from keras import ops
from matplotlib import pyplot as plt
from random import randint
import warnings
import logging

# Suppress warnings
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", message=".*build.*method.*")
warnings.filterwarnings("ignore", message=".*tf.function.*retracing.*")

# Suppress TensorFlow logging
logging.getLogger("tensorflow").setLevel(logging.ERROR)

# Set seed for reproducibility using Keras 3 utility.
keras.utils.set_random_seed(42)
Expand Down Expand Up @@ -847,161 +859,96 @@ history = model.fit(
train_ds,
validation_data=val_ds,
epochs=EPOCHS,
callbacks=[TrainMonitor(epoch_interval=5)],
callbacks=[],
)
```

<div class="k-default-codeblock">
```
Epoch 1/20

/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_block', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(
/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_block_1', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(

/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_block_2', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(
/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_block_3', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(

/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_block_4', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(

/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_block_5', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(
/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_block_6', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(

/Users/lakshmikala/node2vec_env/lib/python3.12/site-packages/keras/src/layers/layer.py:424: UserWarning: `build()` was called on layer 'focal_modulation_network', however the layer does not have a `build()` method implemented and it looks like it has unbuilt state. This will cause the layer to be marked as built, despite not being actually built, which may cause failures down the line. Make sure to implement a proper `build()` method.
warnings.warn(

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1770700186.220793 2002752 service.cc:152] XLA service 0x16cf639d0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1770700186.220808 2002752 service.cc:160] StreamExecutor device (0): Host, Default Version
I0000 00:00:1770700186.251643 2002752 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
I0000 00:00:1772431619.408199 4985427 service.cc:152] XLA service 0x166d7d250 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1772431619.408213 4985427 service.cc:160] StreamExecutor device (0): Host, Default Version
I0000 00:00:1772431619.566351 4985427 device_compiler.h:188] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.

313/313 ━━━━━━━━━━━━━━━━━━━━ 312s 964ms/step - accuracy: 0.1826 - loss: 2.1990 - val_accuracy: 0.2426 - val_loss: 2.0434
313/313 ━━━━━━━━━━━━━━━━━━━━ 310s 958ms/step - accuracy: 0.1826 - loss: 2.1990 - val_accuracy: 0.2426 - val_loss: 2.0434

Epoch 2/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 302s 964ms/step - accuracy: 0.2891 - loss: 1.8906 - val_accuracy: 0.3191 - val_loss: 1.8333
313/313 ━━━━━━━━━━━━━━━━━━━━ 301s 963ms/step - accuracy: 0.2891 - loss: 1.8906 - val_accuracy: 0.3191 - val_loss: 1.8333

Epoch 3/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 303s 968ms/step - accuracy: 0.3669 - loss: 1.7095 - val_accuracy: 0.3869 - val_loss: 1.6693
313/313 ━━━━━━━━━━━━━━━━━━━━ 311s 994ms/step - accuracy: 0.3669 - loss: 1.7095 - val_accuracy: 0.3869 - val_loss: 1.6693

Epoch 4/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 308s 984ms/step - accuracy: 0.4221 - loss: 1.5685 - val_accuracy: 0.4188 - val_loss: 1.5894
313/313 ━━━━━━━━━━━━━━━━━━━━ 305s 975ms/step - accuracy: 0.4221 - loss: 1.5685 - val_accuracy: 0.4188 - val_loss: 1.5894

Epoch 5/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 905ms/step - accuracy: 0.4501 - loss: 1.5031

WARNING:tensorflow:5 out of the last 5 calls to <function conv.<locals>._conv_xla at 0x3190abc40> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.

WARNING:tensorflow:6 out of the last 6 calls to <function conv.<locals>._conv_xla at 0x3190abd80> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
```
</div>

![png](/img/examples/vision/focal_modulation_network/focal_modulation_network_33_1582.png)



<div class="k-default-codeblock">
```
313/313 ━━━━━━━━━━━━━━━━━━━━ 313s 1s/step - accuracy: 0.4618 - loss: 1.4759 - val_accuracy: 0.4519 - val_loss: 1.5107
313/313 ━━━━━━━━━━━━━━━━━━━━ 301s 961ms/step - accuracy: 0.4618 - loss: 1.4759 - val_accuracy: 0.4519 - val_loss: 1.5107

Epoch 6/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 316s 1s/step - accuracy: 0.4919 - loss: 1.4076 - val_accuracy: 0.4692 - val_loss: 1.4941
313/313 ━━━━━━━━━━━━━━━━━━━━ 308s 984ms/step - accuracy: 0.4943 - loss: 1.3977 - val_accuracy: 0.4743 - val_loss: 1.5007

Epoch 7/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 312s 997ms/step - accuracy: 0.5189 - loss: 1.3461 - val_accuracy: 0.5032 - val_loss: 1.3940
313/313 ━━━━━━━━━━━━━━━━━━━━ 304s 971ms/step - accuracy: 0.5185 - loss: 1.3445 - val_accuracy: 0.5042 - val_loss: 1.3920

Epoch 8/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 307s 981ms/step - accuracy: 0.5356 - loss: 1.3025 - val_accuracy: 0.5182 - val_loss: 1.3580
313/313 ━━━━━━━━━━━━━━━━━━━━ 308s 985ms/step - accuracy: 0.5350 - loss: 1.3001 - val_accuracy: 0.5192 - val_loss: 1.3544

Epoch 9/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 299s 954ms/step - accuracy: 0.5440 - loss: 1.2654 - val_accuracy: 0.5273 - val_loss: 1.3291
313/313 ━━━━━━━━━━━━━━━━━━━━ 308s 984ms/step - accuracy: 0.5451 - loss: 1.2637 - val_accuracy: 0.5247 - val_loss: 1.3289

Epoch 10/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 866ms/step - accuracy: 0.5588 - loss: 1.2346
```
</div>

![png](/img/examples/vision/focal_modulation_network/focal_modulation_network_33_3158.png)



<div class="k-default-codeblock">
```
313/313 ━━━━━━━━━━━━━━━━━━━━ 301s 961ms/step - accuracy: 0.5600 - loss: 1.2305 - val_accuracy: 0.5273 - val_loss: 1.3158
313/313 ━━━━━━━━━━━━━━━━━━━━ 306s 976ms/step - accuracy: 0.5607 - loss: 1.2300 - val_accuracy: 0.5312 - val_loss: 1.3112

Epoch 11/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 302s 965ms/step - accuracy: 0.5741 - loss: 1.1958 - val_accuracy: 0.5248 - val_loss: 1.3298
313/313 ━━━━━━━━━━━━━━━━━━━━ 299s 956ms/step - accuracy: 0.5730 - loss: 1.1985 - val_accuracy: 0.5354 - val_loss: 1.3262

Epoch 12/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 302s 965ms/step - accuracy: 0.5836 - loss: 1.1713 - val_accuracy: 0.5500 - val_loss: 1.2602
313/313 ━━━━━━━━━━━━━━━━━━━━ 304s 970ms/step - accuracy: 0.5802 - loss: 1.1739 - val_accuracy: 0.5484 - val_loss: 1.2850

Epoch 13/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 297s 947ms/step - accuracy: 0.5900 - loss: 1.1483 - val_accuracy: 0.5626 - val_loss: 1.2348
313/313 ━━━━━━━━━━━━━━━━━━━━ 312s 996ms/step - accuracy: 0.5905 - loss: 1.1494 - val_accuracy: 0.5640 - val_loss: 1.2334

Epoch 14/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 304s 970ms/step - accuracy: 0.5987 - loss: 1.1270 - val_accuracy: 0.5657 - val_loss: 1.2249
313/313 ━━━━━━━━━━━━━━━━━━━━ 307s 982ms/step - accuracy: 0.5973 - loss: 1.1309 - val_accuracy: 0.5641 - val_loss: 1.2307

Epoch 15/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 884ms/step - accuracy: 0.6118 - loss: 1.1106
```
</div>

![png](/img/examples/vision/focal_modulation_network/focal_modulation_network_33_4734.png)



<div class="k-default-codeblock">
```
313/313 ━━━━━━━━━━━━━━━━━━━━ 308s 982ms/step - accuracy: 0.6081 - loss: 1.1134 - val_accuracy: 0.5671 - val_loss: 1.2246
313/313 ━━━━━━━━━━━━━━━━━━━━ 313s 999ms/step - accuracy: 0.6072 - loss: 1.1127 - val_accuracy: 0.5740 - val_loss: 1.2129

Epoch 16/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 298s 954ms/step - accuracy: 0.6105 - loss: 1.0981 - val_accuracy: 0.5708 - val_loss: 1.2035
313/313 ━━━━━━━━━━━━━━━━━━━━ 311s 992ms/step - accuracy: 0.6120 - loss: 1.0956 - val_accuracy: 0.5734 - val_loss: 1.2059

Epoch 17/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 302s 964ms/step - accuracy: 0.6144 - loss: 1.0838 - val_accuracy: 0.5770 - val_loss: 1.2002
313/313 ━━━━━━━━━━━━━━━━━━━━ 5871s 19s/step - accuracy: 0.6145 - loss: 1.0874 - val_accuracy: 0.5788 - val_loss: 1.1996

Epoch 18/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 308s 984ms/step - accuracy: 0.6209 - loss: 1.0799 - val_accuracy: 0.5764 - val_loss: 1.1978
313/313 ━━━━━━━━━━━━━━━━━━━━ 300s 959ms/step - accuracy: 0.6164 - loss: 1.0800 - val_accuracy: 0.5797 - val_loss: 1.1991

Epoch 19/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 315s 1s/step - accuracy: 0.6174 - loss: 1.0772 - val_accuracy: 0.5777 - val_loss: 1.1951
313/313 ━━━━━━━━━━━━━━━━━━━━ 297s 950ms/step - accuracy: 0.6212 - loss: 1.0737 - val_accuracy: 0.5828 - val_loss: 1.1924

Epoch 20/20

313/313 ━━━━━━━━━━━━━━━━━━━━ 0s 896ms/step - accuracy: 0.6249 - loss: 1.0723
```
</div>

![png](/img/examples/vision/focal_modulation_network/focal_modulation_network_33_6310.png)



<div class="k-default-codeblock">
```
313/313 ━━━━━━━━━━━━━━━━━━━━ 311s 993ms/step - accuracy: 0.6240 - loss: 1.0710 - val_accuracy: 0.5775 - val_loss: 1.1971
313/313 ━━━━━━━━━━━━━━━━━━━━ 297s 950ms/step - accuracy: 0.6204 - loss: 1.0702 - val_accuracy: 0.5822 - val_loss: 1.1950
```
</div>

Expand Down