Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"\n",
"**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>\n",
"**Date created:** 2020/05/10<br>\n",
"**Last modified:** 2021/02/15<br>\n",
"**Last modified:** 2026/02/25<br>\n",
"**Description:** Implement a Switch Transformer for text classification."
]
},
Expand Down Expand Up @@ -50,7 +50,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -72,7 +72,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand All @@ -98,7 +98,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
Expand Down Expand Up @@ -131,12 +131,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class TokenAndPositionEmbedding(layers.Layer):\n",
" def __init__(self, maxlen, vocab_size, embed_dim):\n",
" super().__init__()\n",
Expand All @@ -148,7 +149,8 @@
" positions = ops.arange(start=0, stop=maxlen, step=1)\n",
" positions = self.pos_emb(positions)\n",
" x = self.token_emb(x)\n",
" return x + positions"
" return x + positions\n",
""
]
},
{
Expand All @@ -164,16 +166,18 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def create_feedforward_network(ff_dim, embed_dim, name=None):\n",
" return keras.Sequential(\n",
" [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim)], name=name\n",
" )"
" )\n",
""
]
},
{
Expand All @@ -189,12 +193,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def load_balanced_loss(router_probs, expert_mask):\n",
" # router_probs [tokens_per_batch, num_experts] is the probability assigned for\n",
" # each expert per token. expert_mask [tokens_per_batch, num_experts] contains\n",
Expand All @@ -211,7 +216,8 @@
" # num_expert elements. The two vectors will be pushed towards uniform allocation\n",
" # when the dot product is minimized.\n",
" loss = ops.mean(density_proxy * density) * ops.cast((num_experts**2), \"float32\")\n",
" return loss"
" return loss\n",
""
]
},
{
Expand All @@ -225,12 +231,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class Router(layers.Layer):\n",
" def __init__(self, num_experts, expert_capacity):\n",
" self.num_experts = num_experts\n",
Expand Down Expand Up @@ -281,11 +288,14 @@
" * ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),\n",
" -1,\n",
" ) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)\n",
"\n",
" # Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]\n",
" # that is 1 if the token gets routed to the corresponding expert.\n",
" dispatch_tensor = ops.cast(combined_tensor, \"float32\")\n",
" # cast to float32 so it can be used in the einsum product in the Switch layer.\n",
" dispatch_tensor = ops.cast(combined_tensor, dtype=\"float32\")\n",
Comment on lines +294 to +295
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Explicitly naming the dtype argument in ops.cast improves readability and makes the function call clearer. The added comment also provides valuable context for why the casting is performed.

# cast to float32 so it can be used in the einsum product in the Switch layer.
dispatch_tensor = ops.cast(combined_tensor, dtype="float32")

"\n",
" return dispatch_tensor, combined_tensor"
" return dispatch_tensor, combined_tensor\n",
""
]
},
{
Expand All @@ -299,12 +309,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class Switch(layers.Layer):\n",
" def __init__(\n",
" self, num_experts, embed_dim, ff_dim, num_tokens_per_batch, capacity_factor=1\n",
Expand All @@ -325,10 +336,11 @@
"\n",
" # inputs shape: [num_tokens_per_batch, embed_dim]\n",
" inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])\n",
" # dispatch_tensor shape: [expert_capacity, num_experts, tokens_per_batch]\n",
" # dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]\n",
" # combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]\n",
" dispatch_tensor, combine_tensor = self.router(inputs)\n",
" # expert_inputs shape: [num_experts, expert_capacity, embed_dim]\n",
" # \"ab\" = [tokens, dim], \"acd\" = [tokens, experts, capacity] -> \"cdb\" = [experts, capacity, dim]\n",
Comment on lines +339 to +343
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The correction of the dispatch_tensor shape in the comment is a good improvement for documentation accuracy. Additionally, the new comment explaining the einsum operation clarifies the tensor transformations, which is very helpful for understanding the logic.

# dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
dispatch_tensor, combine_tensor = self.router(inputs)
# expert_inputs shape: [num_experts, expert_capacity, embed_dim]
# "ab" = [tokens, dim], "acd" = [tokens, experts, capacity] -> "cdb" = [experts, capacity, dim]

" expert_inputs = ops.einsum(\"ab,acd->cdb\", inputs, dispatch_tensor)\n",
" expert_inputs = ops.reshape(\n",
" expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]\n",
Expand All @@ -350,7 +362,8 @@
" expert_outputs_combined,\n",
" [batch_size, num_tokens_per_example, self.embed_dim],\n",
" )\n",
" return outputs"
" return outputs\n",
""
]
},
{
Expand All @@ -364,12 +377,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"class TransformerBlock(layers.Layer):\n",
" def __init__(self, embed_dim, num_heads, ffn, dropout_rate=0.1):\n",
" super().__init__()\n",
Expand All @@ -388,7 +402,8 @@
" out1 = self.layernorm1(inputs + attn_output)\n",
" ffn_output = self.ffn(out1)\n",
" ffn_output = self.dropout2(ffn_output, training=training)\n",
" return self.layernorm2(out1 + ffn_output)"
" return self.layernorm2(out1 + ffn_output)\n",
""
]
},
{
Expand All @@ -406,12 +421,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def create_classifier():\n",
" switch = Switch(num_experts, embed_dim, ff_dim, num_tokens_per_batch)\n",
" transformer_block = TransformerBlock(embed_dim // num_heads, num_heads, switch)\n",
Expand All @@ -429,7 +445,8 @@
" outputs = layers.Dense(2, activation=\"softmax\")(x)\n",
"\n",
" classifier = keras.Model(inputs=inputs, outputs=outputs)\n",
" return classifier"
" return classifier\n",
""
]
},
{
Expand All @@ -443,12 +460,13 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 0,
"metadata": {
"colab_type": "code"
},
"outputs": [],
"source": [
"\n",
"def run_experiment(classifier):\n",
" classifier.compile(\n",
" optimizer=keras.optimizers.Adam(learning_rate),\n",
Expand All @@ -466,7 +484,8 @@
"\n",
"\n",
"classifier = create_classifier()\n",
"run_experiment(classifier)"
"run_experiment(classifier)\n",
""
]
},
{
Expand Down
32 changes: 20 additions & 12 deletions examples/nlp/md/text_classification_with_switch_transformer.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>
**Date created:** 2020/05/10<br>
**Last modified:** 2021/02/15<br>
**Last modified:** 2026/02/25<br>
**Description:** Implement a Switch Transformer for text classification.


Expand Down Expand Up @@ -57,9 +57,9 @@ x_val = keras.utils.pad_sequences(x_val, maxlen=num_tokens_per_example)
```
25000 Training sequences
25000 Validation sequences

```
</div>

---
## Define hyperparameters

Expand All @@ -82,9 +82,9 @@ print(f"Number of tokens per batch: {num_tokens_per_batch}")
<div class="k-default-codeblock">
```
Number of tokens per batch: 10000

```
</div>

---
## Implement token & position embedding layer

Expand Down Expand Up @@ -206,9 +206,11 @@ class Router(layers.Layer):
* ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
-1,
) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)

# Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
# that is 1 if the token gets routed to the corresponding expert.
dispatch_tensor = ops.cast(combined_tensor, "float32")
# cast to float32 so it can be used in the einsum product in the Switch layer.
dispatch_tensor = ops.cast(combined_tensor, dtype="float32")

return dispatch_tensor, combined_tensor

Expand Down Expand Up @@ -239,10 +241,11 @@ class Switch(layers.Layer):

# inputs shape: [num_tokens_per_batch, embed_dim]
inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
# dispatch_tensor shape: [expert_capacity, num_experts, tokens_per_batch]
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
dispatch_tensor, combine_tensor = self.router(inputs)
# expert_inputs shape: [num_experts, expert_capacity, embed_dim]
# dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
dispatch_tensor, combine_tensor = self.router(inputs)
# expert_inputs shape: [num_experts, expert_capacity, embed_dim]
# "ab" = [tokens, dim], "acd" = [tokens, experts, capacity] -> "cdb" = [experts, capacity, dim]
expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
expert_inputs = ops.reshape(
expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
Expand Down Expand Up @@ -357,16 +360,21 @@ run_experiment(classifier)
<div class="k-default-codeblock">
```
Epoch 1/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 251s 485ms/step - accuracy: 0.7121 - loss: 1.5394 - val_accuracy: 0.8748 - val_loss: 1.2891

500/500 ━━━━━━━━━━━━━━━━━━━━ 237s 470ms/step - accuracy: 0.7964 - loss: 1.4334 - val_accuracy: 0.8550 - val_loss: 1.3459

Epoch 2/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 240s 480ms/step - accuracy: 0.9243 - loss: 1.2063 - val_accuracy: 0.8752 - val_loss: 1.3090

500/500 ━━━━━━━━━━━━━━━━━━━━ 266s 532ms/step - accuracy: 0.9182 - loss: 1.2174 - val_accuracy: 0.8750 - val_loss: 1.3057

Epoch 3/3
500/500 ━━━━━━━━━━━━━━━━━━━━ 242s 485ms/step - accuracy: 0.9572 - loss: 1.1222 - val_accuracy: 0.8614 - val_loss: 1.3744

<keras.src.callbacks.history.History at 0x7efb79d82a90>
500/500 ━━━━━━━━━━━━━━━━━━━━ 272s 545ms/step - accuracy: 0.9519 - loss: 1.1388 - val_accuracy: 0.8637 - val_loss: 1.3765

<keras.src.callbacks.history.History at 0x176081090>
```
</div>

---
## Conclusion

Expand Down
9 changes: 6 additions & 3 deletions examples/nlp/text_classification_with_switch_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Title: Text classification with Switch Transformer
Author: [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)
Date created: 2020/05/10
Last modified: 2021/02/15
Last modified: 2026/02/25
Description: Implement a Switch Transformer for text classification.
Accelerator: GPU
"""
Expand Down Expand Up @@ -179,9 +179,11 @@ def call(self, inputs, training=False):
* ops.squeeze(ops.one_hot(expert_index, self.num_experts), 1),
-1,
) * ops.squeeze(ops.one_hot(position_in_expert, self.expert_capacity), 1)

# Create binary dispatch_tensor [tokens_per_batch, num_experts, expert_capacity]
# that is 1 if the token gets routed to the corresponding expert.
dispatch_tensor = ops.cast(combined_tensor, "float32")
# cast to float32 so it can be used in the einsum product in the Switch layer.
dispatch_tensor = ops.cast(combined_tensor, dtype="float32")

return dispatch_tensor, combined_tensor

Expand Down Expand Up @@ -211,10 +213,11 @@ def call(self, inputs):

# inputs shape: [num_tokens_per_batch, embed_dim]
inputs = ops.reshape(inputs, [num_tokens_per_batch, self.embed_dim])
# dispatch_tensor shape: [expert_capacity, num_experts, tokens_per_batch]
# dispatch_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
# combine_tensor shape: [tokens_per_batch, num_experts, expert_capacity]
dispatch_tensor, combine_tensor = self.router(inputs)
# expert_inputs shape: [num_experts, expert_capacity, embed_dim]
# "ab" = [tokens, dim], "acd" = [tokens, experts, capacity] -> "cdb" = [experts, capacity, dim]
expert_inputs = ops.einsum("ab,acd->cdb", inputs, dispatch_tensor)
expert_inputs = ops.reshape(
expert_inputs, [self.num_experts, self.expert_capacity, self.embed_dim]
Expand Down