diff --git a/examples/keras_rs/distributed_embedding_jax.py b/examples/keras_rs/distributed_embedding_jax.py
index e0fa3f0425..b12b8fe053 100644
--- a/examples/keras_rs/distributed_embedding_jax.py
+++ b/examples/keras_rs/distributed_embedding_jax.py
@@ -1,8 +1,8 @@
"""
Title: DistributedEmbedding using TPU SparseCore and JAX
-Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
+Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/), [C. Antonio Sánchez](https://github.com/cantonios/)
Date created: 2025/06/03
-Last modified: 2025/06/03
+Last modified: 2025/09/02
Description: Rank movies using a two tower model with embeddings on SparseCore.
Accelerator: TPU
"""
@@ -56,7 +56,7 @@
"""
## Preparing the dataset
-We're going to use the same Movielens data. The ratings are the objectives we
+We're going to use the same MovieLens data. The ratings are the objectives we
are trying to predict.
"""
@@ -150,8 +150,8 @@ def preprocess_rating(x):
- A name.
- A table, the embedding table to use.
-- An input shape (per replica).
-- An output shape (per replica).
+- An input shape (the batch size is the global batch size across all TPUs).
+- An output shape (the batch size is the global batch size across all TPUs).
We can organize features in any structure we want, which can be nested. A dict
is often a good choice to have names for the inputs and outputs.
diff --git a/examples/keras_rs/distributed_embedding_tf.py b/examples/keras_rs/distributed_embedding_tf.py
new file mode 100644
index 0000000000..fc6a8ea099
--- /dev/null
+++ b/examples/keras_rs/distributed_embedding_tf.py
@@ -0,0 +1,319 @@
+"""
+Title: DistributedEmbedding using TPU SparseCore and TensorFlow
+Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
+Date created: 2025/09/02
+Last modified: 2025/09/02
+Description: Rank movies using a two tower model with embeddings on SparseCore.
+Accelerator: TPU
+"""
+
+"""
+## Introduction
+
+In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed
+how to build a ranking model for the MovieLens dataset to suggest movies to
+users.
+
+This tutorial implements the same model trained on the same dataset but with the
+use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on
+TPU. This is the TensorFlow version of the tutorial. It needs to be run on TPU
+v5p or v6e.
+
+Let's begin by installing the necessary libraries. Note that we need
+`tensorflow-tpu` version 2.19. We'll also install `keras-rs`.
+"""
+
+"""shell
+pip install -U -q tensorflow-tpu==2.19.1
+pip install -q keras-rs
+"""
+
+"""
+We're using the PJRT version of the runtime for TensorFlow. We're also enabling
+the MLIR bridge. This requires setting a few flags before importing TensorFlow.
+"""
+
+import os
+import libtpu
+
+os.environ["PJRT_DEVICE"] = "TPU"
+os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"
+os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()
+os.environ["TF_XLA_FLAGS"] = (
+ "--tf_mlir_enable_mlir_bridge=true "
+ "--tf_mlir_enable_convert_control_to_data_outputs_pass=true "
+ "--tf_mlir_enable_merge_control_flow_pass=true"
+)
+
+import tensorflow as tf
+
+"""
+We now set the Keras backend to TensorFlow and import the necessary libraries.
+"""
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+import keras_rs
+import tensorflow_datasets as tfds
+
+"""
+## Creating a `TPUStrategy`
+
+To run TensorFlow on TPU, you need to use a
+[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)
+to handle the distribution of the model.
+
+The core of the model is replicated across TPU instances, which is done by the
+`TPUStrategy`. Note that on GPU you would use
+[`tf.distribute.MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)
+instead, but that strategy does not support TPUs.
+
+Only the embedding tables handled by `DistributedEmbedding` are sharded across
+the SparseCore chips of all the available TPUs.
+"""
+
+resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
+topology = tf.tpu.experimental.initialize_tpu_system(resolver)
+tpu_metadata = resolver.get_tpu_system_metadata()
+
+device_assignment = tf.tpu.experimental.DeviceAssignment.build(
+ topology, num_replicas=tpu_metadata.num_cores
+)
+strategy = tf.distribute.TPUStrategy(
+ resolver, experimental_device_assignment=device_assignment
+)
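+
+"""
+As mentioned above, on a multi-GPU machine you would create a
+`MirroredStrategy` instead. A minimal sketch, shown for reference only and not
+run as part of this TPU tutorial:
+
+```python
+# Replicates the model across all local GPUs.
+strategy = tf.distribute.MirroredStrategy()
+```
+"""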
+
+"""
+## Dataset distribution
+
+While the model is replicated and the embedding tables are sharded across
+SparseCores, the dataset is distributed by sharding each batch across the TPUs.
+We need to make sure the batch size is a multiple of the number of TPUs.
+"""
+
+PER_REPLICA_BATCH_SIZE = 256
+BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
+
+"""
+## Preparing the dataset
+
+We're going to use the same MovieLens data. The ratings are the objectives we
+are trying to predict.
+"""
+
+# Ratings data.
+ratings = tfds.load("movielens/100k-ratings", split="train")
+# Features of all the available movies.
+movies = tfds.load("movielens/100k-movies", split="train")
+
+"""
+We need to know the number of users as we're using the user ID directly as an
+index in the user embedding table.
+"""
+
+users_count = int(
+ ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
+ .reduce(tf.constant(0, tf.int32), tf.maximum)
+ .numpy()
+)
+
+"""
+We also need to know the number of movies as we're using the movie ID directly
+as an index in the movie embedding table.
+"""
+
+movies_count = int(movies.cardinality().numpy())
+
+"""
+The inputs to the model are the user IDs and movie IDs and the labels are the
+ratings.
+"""
+
+
+def preprocess_rating(x):
+ return (
+ # Inputs are user IDs and movie IDs
+ {
+ "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
+ "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
+ },
+ # Labels are ratings between 0 and 1.
+ (x["user_rating"] - 1.0) / 4.0,
+ )
+
+
+"""
+We'll split the data by putting 80% of the ratings in the train set, and 20% in
+the test set.
+"""
+
+shuffled_ratings = ratings.map(preprocess_rating).shuffle(
+ 100_000, seed=42, reshuffle_each_iteration=False
+)
+train_ratings = (
+ shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
+)
+test_ratings = (
+ shuffled_ratings.skip(80_000)
+ .take(20_000)
+ .batch(BATCH_SIZE, drop_remainder=True)
+ .cache()
+)
+
+"""
+## Configuring DistributedEmbedding
+
+The `keras_rs.layers.DistributedEmbedding` handles multiple features and
+multiple embedding tables. This is to enable the sharing of tables between
+features and allow some optimizations that come from combining multiple
+embedding lookups into a single invocation. In this section, we'll describe
+how to configure these.
+
+### Configuring tables
+
+Tables are configured using `keras_rs.layers.TableConfig`, which has:
+
+- A name.
+- A vocabulary size (input size).
+- An embedding dimension (output size).
+- A combiner to specify how to reduce multiple embeddings into a single one in
+ the case when we embed a sequence. Note that this doesn't apply to our example
+ because we're getting a single embedding for each user and each movie.
+- A placement to tell whether to put the table on the SparseCore chips or not.
+ In this case, we want the `"sparsecore"` placement.
+- An optimizer to specify how to apply gradients when training. Each table has
+ its own optimizer and the one passed to `model.compile()` is not used for the
+ embedding tables.
+
+### Configuring features
+
+Features are configured using `keras_rs.layers.FeatureConfig`, which has:
+
+- A name.
+- A table, the embedding table to use.
+- An input shape (the batch size is the global batch size across all TPUs).
+- An output shape (the batch size is the global batch size across all TPUs).
+
+We can organize features in any structure we want, which can be nested. A dict
+is often a good choice to have names for the inputs and outputs.
+"""
+
+EMBEDDING_DIMENSION = 32
+
+movie_table = keras_rs.layers.TableConfig(
+ name="movie_table",
+ vocabulary_size=movies_count + 1, # +1 for movie ID 0, which is not used
+ embedding_dim=EMBEDDING_DIMENSION,
+ optimizer="adam",
+ placement="sparsecore",
+)
+user_table = keras_rs.layers.TableConfig(
+ name="user_table",
+ vocabulary_size=users_count + 1, # +1 for user ID 0, which is not used
+ embedding_dim=EMBEDDING_DIMENSION,
+ optimizer="adam",
+ placement="sparsecore",
+)
+
+FEATURE_CONFIGS = {
+ "movie_id": keras_rs.layers.FeatureConfig(
+ name="movie",
+ table=movie_table,
+ input_shape=(BATCH_SIZE,),
+ output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
+ ),
+ "user_id": keras_rs.layers.FeatureConfig(
+ name="user",
+ table=user_table,
+ input_shape=(BATCH_SIZE,),
+ output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
+ ),
+}
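+
+"""
+As an aside, this is also where a multi-valued feature would be configured: it
+gets a table with a combiner to reduce the multiple embeddings into one, and
+two features can share a table by pointing at the same `TableConfig` instance.
+The following sketch uses a hypothetical list of genre IDs per movie; the names
+and sizes are made up for illustration and are not used in the rest of this
+example:
+
+```python
+NUM_GENRES = 19  # hypothetical genre vocabulary size
+MAX_GENRES_PER_MOVIE = 4  # hypothetical fixed number of genre IDs per movie
+
+genre_table = keras_rs.layers.TableConfig(
+    name="genre_table",
+    vocabulary_size=NUM_GENRES,
+    embedding_dim=EMBEDDING_DIMENSION,
+    combiner="mean",  # average the genre embeddings into a single vector
+    optimizer="adam",
+    placement="sparsecore",
+)
+genre_feature = keras_rs.layers.FeatureConfig(
+    name="genres",
+    table=genre_table,
+    input_shape=(BATCH_SIZE, MAX_GENRES_PER_MOVIE),
+    output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
+)
+```
+"""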
+
+"""
+## Defining the Model
+
+We're now ready to create a `DistributedEmbedding` inside a model. Once we have
+the configuration, we simply pass it to the constructor of `DistributedEmbedding`.
+Then, within the model `call` method, `DistributedEmbedding` is the first layer
+we call.
+
+The outputs have the exact same structure as the inputs. In our example, we
+concatenate the embeddings we got as outputs and run them through a tower of
+dense layers.
+"""
+
+
+class EmbeddingModel(keras.Model):
+ """Create the model with the embedding configuration.
+
+ Args:
+ feature_configs: the configuration for `DistributedEmbedding`.
+ """
+
+ def __init__(self, feature_configs):
+ super().__init__()
+
+ self.embedding_layer = keras_rs.layers.DistributedEmbedding(
+ feature_configs=feature_configs
+ )
+ self.ratings = keras.Sequential(
+ [
+ # Learn multiple dense layers.
+ keras.layers.Dense(256, activation="relu"),
+ keras.layers.Dense(64, activation="relu"),
+ # Make rating predictions in the final layer.
+ keras.layers.Dense(1),
+ ]
+ )
+
+ def call(self, features):
+ # Embedding lookup. Outputs have the same structure as the inputs.
+ embedding = self.embedding_layer(features)
+ return self.ratings(
+ keras.ops.concatenate(
+ [embedding["user_id"], embedding["movie_id"]],
+ axis=1,
+ )
+ )
+
+
+"""
+Let's now instantiate the model. We then use `model.compile()` to configure the
+loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to
+the dense layers and not the embedding tables.
+"""
+
+with strategy.scope():
+ model = EmbeddingModel(FEATURE_CONFIGS)
+
+ model.compile(
+ loss=keras.losses.MeanSquaredError(),
+ metrics=[keras.metrics.RootMeanSquaredError()],
+ optimizer="adagrad",
+ )
+
+"""
+## Fitting and evaluating
+
+We can use the standard Keras `model.fit()` to train the model. Keras will
+automatically use the `TPUStrategy` to distribute the model and the data.
+"""
+
+with strategy.scope():
+ model.fit(train_ratings, epochs=5)
+
+"""
+Same for `model.evaluate()`.
+"""
+
+with strategy.scope():
+ model.evaluate(test_ratings, return_dict=True)
+
+"""
+That's it.
+
+This example shows that after setting up the `TPUStrategy` and configuring the
+`DistributedEmbedding`, you can use the standard Keras workflows.
+"""
diff --git a/examples/keras_rs/ipynb/distributed_embedding_jax.ipynb b/examples/keras_rs/ipynb/distributed_embedding_jax.ipynb
index 67e3f96b2d..1e17903b47 100644
--- a/examples/keras_rs/ipynb/distributed_embedding_jax.ipynb
+++ b/examples/keras_rs/ipynb/distributed_embedding_jax.ipynb
@@ -8,9 +8,9 @@
"source": [
"# DistributedEmbedding using TPU SparseCore and JAX\n",
"\n",
- "**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
\n",
+ "**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/), [C. Antonio Sánchez](https://github.com/cantonios/)
\n",
"**Date created:** 2025/06/03
\n",
- "**Last modified:** 2025/06/03
\n",
+ "**Last modified:** 2025/09/02
\n",
"**Description:** Rank movies using a two tower model with embeddings on SparseCore."
]
},
@@ -103,7 +103,7 @@
"source": [
"## Preparing the dataset\n",
"\n",
- "We're going to use the same Movielens data. The ratings are the objectives we\n",
+ "We're going to use the same MovieLens data. The ratings are the objectives we\n",
"are trying to predict."
]
},
@@ -267,8 +267,8 @@
"\n",
"- A name.\n",
"- A table, the embedding table to use.\n",
- "- An input shape (per replica).\n",
- "- An output shape (per replica).\n",
+ "- An input shape (batch size is for all TPUs).\n",
+ "- An output shape (batch size is for all TPUs).\n",
"\n",
"We can organize features in any structure we want, which can be nested. A dict\n",
"is often a good choice to have names for the inputs and outputs."
diff --git a/examples/keras_rs/ipynb/distributed_embedding_tf.ipynb b/examples/keras_rs/ipynb/distributed_embedding_tf.ipynb
new file mode 100644
index 0000000000..90cc6564f5
--- /dev/null
+++ b/examples/keras_rs/ipynb/distributed_embedding_tf.ipynb
@@ -0,0 +1,571 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# DistributedEmbedding using TPU SparseCore and TensorFlow\n",
+ "\n",
+ "**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
\n",
+ "**Date created:** 2025/09/02
\n",
+ "**Last modified:** 2025/09/02
\n",
+ "**Description:** Rank movies using a two tower model with embeddings on SparseCore."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed\n",
+ "how to build a ranking model for the MovieLens dataset to suggest movies to\n",
+ "users.\n",
+ "\n",
+ "This tutorial implements the same model trained on the same dataset but with the\n",
+ "use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on\n",
+ "TPU. This is the TensorFlow version of the tutorial. It needs to be run on TPU\n",
+ "v5p or v6e.\n",
+ "\n",
+ "Let's begin by installing the necessary libraries. Note that we need\n",
+ "`tensorflow-tpu` version 2.19. We'll also install `keras-rs`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -U -q tensorflow-tpu==2.19.1\n",
+ "!pip install -q keras-rs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We're using the PJRT version of the runtime for TensorFlow. We're also enabling\n",
+ "the MLIR bridge. This requires setting a few flags before importing TensorFlow."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import libtpu\n",
+ "\n",
+ "os.environ[\"PJRT_DEVICE\"] = \"TPU\"\n",
+ "os.environ[\"NEXT_PLUGGABLE_DEVICE_USE_C_API\"] = \"true\"\n",
+ "os.environ[\"TF_PLUGGABLE_DEVICE_LIBRARY_PATH\"] = libtpu.get_library_path()\n",
+ "os.environ[\"TF_XLA_FLAGS\"] = (\n",
+ " \"--tf_mlir_enable_mlir_bridge=true \"\n",
+ " \"--tf_mlir_enable_convert_control_to_data_outputs_pass=true \"\n",
+ " \"--tf_mlir_enable_merge_control_flow_pass=true\"\n",
+ ")\n",
+ "\n",
+ "import tensorflow as tf"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We now set the Keras backend to TensorFlow and import the necessary libraries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+ "\n",
+ "import keras\n",
+ "import keras_rs\n",
+ "import tensorflow_datasets as tfds"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Creating a `TPUStrategy`\n",
+ "\n",
+ "To run TensorFlow on TPU, you need to use a\n",
+ "[`tf.distribute.TPUStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/TPUStrategy)\n",
+ "to handle the distribution of the model.\n",
+ "\n",
+ "The core of the model is replicated across TPU instances, which is done by the\n",
+ "`TPUStrategy`. Note that on GPU you would use\n",
+ "[`tf.distribute.MirroredStrategy`](https://www.tensorflow.org/api_docs/python/tf/distribute/MirroredStrategy)\n",
+ "instead, but this strategy is not for TPU.\n",
+ "\n",
+ "Only the embedding tables handled by `DistributedEmbedding` are sharded across\n",
+ "the SparseCore chips of all the available TPUs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=\"local\")\n",
+ "topology = tf.tpu.experimental.initialize_tpu_system(resolver)\n",
+ "tpu_metadata = resolver.get_tpu_system_metadata()\n",
+ "\n",
+ "device_assignment = tf.tpu.experimental.DeviceAssignment.build(\n",
+ " topology, num_replicas=tpu_metadata.num_cores\n",
+ ")\n",
+ "strategy = tf.distribute.TPUStrategy(\n",
+ " resolver, experimental_device_assignment=device_assignment\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Dataset distribution\n",
+ "\n",
+ "While the model is replicated and the embedding tables are sharded across\n",
+ "SparseCores, the dataset is distributed by sharding each batch across the TPUs.\n",
+ "We need to make sure the batch size is a multiple of the number of TPUs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "PER_REPLICA_BATCH_SIZE = 256\n",
+ "BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Preparing the dataset\n",
+ "\n",
+ "We're going to use the same MovieLens data. The ratings are the objectives we\n",
+ "are trying to predict."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "# Ratings data.\n",
+ "ratings = tfds.load(\"movielens/100k-ratings\", split=\"train\")\n",
+ "# Features of all the available movies.\n",
+ "movies = tfds.load(\"movielens/100k-movies\", split=\"train\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We need to know the number of users as we're using the user ID directly as an\n",
+ "index in the user embedding table."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "users_count = int(\n",
+ " ratings.map(lambda x: tf.strings.to_number(x[\"user_id\"], out_type=tf.int32))\n",
+ " .reduce(tf.constant(0, tf.int32), tf.maximum)\n",
+ " .numpy()\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We also need do know the number of movies as we're using the movie ID directly\n",
+ "as an index in the movie embedding table."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "movies_count = int(movies.cardinality().numpy())"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The inputs to the model are the user IDs and movie IDs and the labels are the\n",
+ "ratings."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def preprocess_rating(x):\n",
+ " return (\n",
+ " # Inputs are user IDs and movie IDs\n",
+ " {\n",
+ " \"user_id\": tf.strings.to_number(x[\"user_id\"], out_type=tf.int32),\n",
+ " \"movie_id\": tf.strings.to_number(x[\"movie_id\"], out_type=tf.int32),\n",
+ " },\n",
+ " # Labels are ratings between 0 and 1.\n",
+ " (x[\"user_rating\"] - 1.0) / 4.0,\n",
+ " )\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We'll split the data by putting 80% of the ratings in the train set, and 20% in\n",
+ "the test set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "shuffled_ratings = ratings.map(preprocess_rating).shuffle(\n",
+ " 100_000, seed=42, reshuffle_each_iteration=False\n",
+ ")\n",
+ "train_ratings = (\n",
+ " shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()\n",
+ ")\n",
+ "test_ratings = (\n",
+ " shuffled_ratings.skip(80_000)\n",
+ " .take(20_000)\n",
+ " .batch(BATCH_SIZE, drop_remainder=True)\n",
+ " .cache()\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Configuring DistributedEmbedding\n",
+ "\n",
+ "The `keras_rs.layers.DistributedEmbedding` handles multiple features and\n",
+ "multiple embedding tables. This is to enable the sharing of tables between\n",
+ "features and allow some optimizations that come from combining multiple\n",
+ "embedding lookups into a single invocation. In this section, we'll describe\n",
+ "how to configure these.\n",
+ "\n",
+ "### Configuring tables\n",
+ "\n",
+ "Tables are configured using `keras_rs.layers.TableConfig`, which has:\n",
+ "\n",
+ "- A name.\n",
+ "- A vocabulary size (input size).\n",
+ "- an embedding dimension (output size).\n",
+ "- A combiner to specify how to reduce multiple embeddings into a single one in\n",
+ " the case when we embed a sequence. Note that this doesn't apply to our example\n",
+ " because we're getting a single embedding for each user and each movie.\n",
+ "- A placement to tell whether to put the table on the SparseCore chips or not.\n",
+ " In this case, we want the `\"sparsecore\"` placement.\n",
+ "- An optimizer to specify how to apply gradients when training. Each table has\n",
+ " its own optimizer and the one passed to `model.compile()` is not used for the\n",
+ " embedding tables.\n",
+ "\n",
+ "### Configuring features\n",
+ "\n",
+ "Features are configured using `keras_rs.layers.FeatureConfig`, which has:\n",
+ "\n",
+ "- A name.\n",
+ "- A table, the embedding table to use.\n",
+ "- An input shape (batch size is for all TPUs).\n",
+ "- An output shape (batch size is for all TPUs).\n",
+ "\n",
+ "We can organize features in any structure we want, which can be nested. A dict\n",
+ "is often a good choice to have names for the inputs and outputs."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "EMBEDDING_DIMENSION = 32\n",
+ "\n",
+ "movie_table = keras_rs.layers.TableConfig(\n",
+ " name=\"movie_table\",\n",
+ " vocabulary_size=movies_count + 1, # +1 for movie ID 0, which is not used\n",
+ " embedding_dim=EMBEDDING_DIMENSION,\n",
+ " optimizer=\"adam\",\n",
+ " placement=\"sparsecore\",\n",
+ ")\n",
+ "user_table = keras_rs.layers.TableConfig(\n",
+ " name=\"user_table\",\n",
+ " vocabulary_size=users_count + 1, # +1 for user ID 0, which is not used\n",
+ " embedding_dim=EMBEDDING_DIMENSION,\n",
+ " optimizer=\"adam\",\n",
+ " placement=\"sparsecore\",\n",
+ ")\n",
+ "\n",
+ "FEATURE_CONFIGS = {\n",
+ " \"movie_id\": keras_rs.layers.FeatureConfig(\n",
+ " name=\"movie\",\n",
+ " table=movie_table,\n",
+ " input_shape=(BATCH_SIZE,),\n",
+ " output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),\n",
+ " ),\n",
+ " \"user_id\": keras_rs.layers.FeatureConfig(\n",
+ " name=\"user\",\n",
+ " table=user_table,\n",
+ " input_shape=(BATCH_SIZE,),\n",
+ " output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),\n",
+ " ),\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Defining the Model\n",
+ "\n",
+ "We're now ready to create a `DistributedEmbedding` inside a model. Once we have\n",
+ "the configuration, we simply pass it the constructor of `DistributedEmbedding`.\n",
+ "Then, within the model `call` method, `DistributedEmbedding` is the first layer\n",
+ "we call.\n",
+ "\n",
+ "The ouputs have the exact same structure as the inputs. In our example, we\n",
+ "concatenate the embeddings we got as outputs and run them through a tower of\n",
+ "dense layers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class EmbeddingModel(keras.Model):\n",
+ " \"\"\"Create the model with the embedding configuration.\n",
+ "\n",
+ " Args:\n",
+ " feature_configs: the configuration for `DistributedEmbedding`.\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, feature_configs):\n",
+ " super().__init__()\n",
+ "\n",
+ " self.embedding_layer = keras_rs.layers.DistributedEmbedding(\n",
+ " feature_configs=feature_configs\n",
+ " )\n",
+ " self.ratings = keras.Sequential(\n",
+ " [\n",
+ " # Learn multiple dense layers.\n",
+ " keras.layers.Dense(256, activation=\"relu\"),\n",
+ " keras.layers.Dense(64, activation=\"relu\"),\n",
+ " # Make rating predictions in the final layer.\n",
+ " keras.layers.Dense(1),\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " def call(self, features):\n",
+ " # Embedding lookup. Outputs have the same structure as the inputs.\n",
+ " embedding = self.embedding_layer(features)\n",
+ " return self.ratings(\n",
+ " keras.ops.concatenate(\n",
+ " [embedding[\"user_id\"], embedding[\"movie_id\"]],\n",
+ " axis=1,\n",
+ " )\n",
+ " )\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's now instantiate the model. We then use `model.compile()` to configure the\n",
+ "loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to\n",
+ "the dense layers and not the embedding tables."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "with strategy.scope():\n",
+ " model = EmbeddingModel(FEATURE_CONFIGS)\n",
+ "\n",
+ " model.compile(\n",
+ " loss=keras.losses.MeanSquaredError(),\n",
+ " metrics=[keras.metrics.RootMeanSquaredError()],\n",
+ " optimizer=\"adagrad\",\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Fitting and evaluating\n",
+ "\n",
+ "We can use the standard Keras `model.fit()` to train the model. Keras will\n",
+ "automatically use the `TPUStrategy` to distribute the model and the data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "with strategy.scope():\n",
+ " model.fit(train_ratings, epochs=5)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Same for `model.evaluate()`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "with strategy.scope():\n",
+ " model.evaluate(test_ratings, return_dict=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "That's it.\n",
+ "\n",
+ "This example shows that after setting up the `TPUStrategy` and configuring the\n",
+ "`DistributedEmbedding`, you can use the standard Keras workflows."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "TPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "distributed_embedding_tf",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/keras_rs/md/distributed_embedding_jax.md b/examples/keras_rs/md/distributed_embedding_jax.md
index 6b16cff59b..390b02fa8d 100644
--- a/examples/keras_rs/md/distributed_embedding_jax.md
+++ b/examples/keras_rs/md/distributed_embedding_jax.md
@@ -1,8 +1,8 @@
# DistributedEmbedding using TPU SparseCore and JAX
-**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
+**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/), [C. Antonio Sánchez](https://github.com/cantonios/)
**Date created:** 2025/06/03
-**Last modified:** 2025/06/03
+**Last modified:** 2025/09/02
**Description:** Rank movies using a two tower model with embeddings on SparseCore.
@@ -63,7 +63,7 @@ keras.distribution.set_distribution(distribution)
---
## Preparing the dataset
-We're going to use the same Movielens data. The ratings are the objectives we
+We're going to use the same MovieLens data. The ratings are the objectives we
are trying to predict.
@@ -163,8 +163,8 @@ Features are configured using `keras_rs.layers.FeatureConfig`, which has:
- A name.
- A table, the embedding table to use.
-- An input shape (per replica).
-- An output shape (per replica).
+- An input shape (the batch size is the global batch size across all TPUs).
+- An output shape (the batch size is the global batch size across all TPUs).
We can organize features in any structure we want, which can be nested. A dict
is often a good choice to have names for the inputs and outputs.
diff --git a/examples/keras_rs/md/distributed_embedding_tf.md b/examples/keras_rs/md/distributed_embedding_tf.md
new file mode 100644
index 0000000000..c0ad1c258e
--- /dev/null
+++ b/examples/keras_rs/md/distributed_embedding_tf.md
@@ -0,0 +1,438 @@
+# DistributedEmbedding using TPU SparseCore and TensorFlow
+
+**Author:** [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
+**Date created:** 2025/09/02
+**Last modified:** 2025/09/02
+**Description:** Rank movies using a two tower model with embeddings on SparseCore.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/distributed_embedding_tf.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/distributed_embedding_tf.py)
+
+
+
+---
+## Introduction
+
+In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed
+how to build a ranking model for the MovieLens dataset to suggest movies to
+users.
+
+This tutorial implements the same model trained on the same dataset but with the
+use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on
+TPU. This is the TensorFlow version of the tutorial. It needs to be run on TPU
+v5p or v6e.
+
+Let's begin by installing the necessary libraries. Note that we need
+`tensorflow-tpu` version 2.19. We'll also install `keras-rs`.
+
+
+```python
+!pip install -U -q tensorflow-tpu==2.19.1
+!pip install -q keras-rs
+```
+
+We're using the PJRT version of the runtime for TensorFlow. We're also enabling
+the MLIR bridge. This requires setting a few flags before importing TensorFlow.
+
+
+```python
+import os
+import libtpu
+
+os.environ["PJRT_DEVICE"] = "TPU"
+os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"
+os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()
+os.environ["TF_XLA_FLAGS"] = (
+ "--tf_mlir_enable_mlir_bridge=true "
+ "--tf_mlir_enable_convert_control_to_data_outputs_pass=true "
+ "--tf_mlir_enable_merge_control_flow_pass=true"
+)
+
+import tensorflow as tf
+```
+
+