|
| 1 | +""" |
| 2 | +Title: DistributedEmbedding using TPU SparseCore and JAX |
| 3 | +Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/) |
| 4 | +Date created: 2025/06/03 |
| 5 | +Last modified: 2025/06/03 |
| 6 | +Description: Rank movies using a two tower model with embeddings on SparseCore. |
| 7 | +Accelerator: TPU |
| 8 | +""" |
| 9 | + |
| 10 | +""" |
| 11 | +## Introduction |
| 12 | +
|
| 13 | +In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed |
| 14 | +how to build a ranking model for the MovieLens dataset to suggest movies to |
| 15 | +users. |
| 16 | +
|
| 17 | +This tutorial implements the same model trained on the same dataset but with the |
| 18 | +use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on |
| 19 | +TPU. This is the JAX version of the tutorial. It needs to be run on TPU v5p or |
| 20 | +v6e. |
| 21 | +
|
| 22 | +Let's begin by choosing JAX as the backend and importing all the necessary |
| 23 | +libraries. |
| 24 | +""" |
| 25 | + |
| 26 | +import os |
| 27 | + |
| 28 | +os.environ["KERAS_BACKEND"] = "jax" |
| 29 | + |
| 30 | +import jax |
| 31 | +import keras |
| 32 | +import keras_rs |
| 33 | +import tensorflow as tf # Needed for the dataset |
| 34 | +import tensorflow_datasets as tfds |
| 35 | + |
| 36 | +""" |
| 37 | +## Dataset distribution |
| 38 | +
|
| 39 | +While the model is replicated and the embedding tables are sharded across |
| 40 | +SparseCores, the dataset is distributed by sharding each batch across the TPUs. |
| 41 | +We need to make sure the batch size is a multiple of the number of TPUs. |
| 42 | +""" |
| 43 | + |
| 44 | +PER_REPLICA_BATCH_SIZE = 256 |
| 45 | +BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.local_device_count("tpu") |
| 46 | + |
| 47 | +distribution = keras.distribution.DataParallel(devices=jax.devices("tpu")) |
| 48 | +keras.distribution.set_distribution(distribution) |
| 49 | + |
| 50 | +""" |
| 51 | +## Preparing the dataset |
| 52 | +
|
| 53 | +We're going to use the same Movielens data. The ratings are the objectives we |
| 54 | +are trying to predict. |
| 55 | +""" |
| 56 | + |
| 57 | +# Ratings data. |
| 58 | +ratings = tfds.load("movielens/100k-ratings", split="train") |
| 59 | +# Features of all the available movies. |
| 60 | +movies = tfds.load("movielens/100k-movies", split="train") |
| 61 | + |
| 62 | +""" |
| 63 | +We need to know the number of users as we're using the user ID directly as an |
| 64 | +index in the user embedding table. |
| 65 | +""" |
| 66 | + |
| 67 | +users_count = ( |
| 68 | + ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) |
| 69 | + .reduce(tf.constant(0, tf.int32), tf.maximum) |
| 70 | + .numpy() |
| 71 | +) |
| 72 | + |
| 73 | +""" |
| 74 | +We also need do know the number of movies as we're using the movie ID directly |
| 75 | +as an index in the movie embedding table. |
| 76 | +""" |
| 77 | + |
| 78 | +movies_count = movies.cardinality().numpy() |
| 79 | + |
| 80 | +""" |
| 81 | +The inputs to the model are the user IDs and movie IDs and the labels are the |
| 82 | +ratings. |
| 83 | +""" |
| 84 | + |
| 85 | + |
| 86 | +def preprocess_rating(x): |
| 87 | + return ( |
| 88 | + # Inputs are user IDs and movie IDs |
| 89 | + { |
| 90 | + "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32), |
| 91 | + "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), |
| 92 | + }, |
| 93 | + # Labels are ratings between 0 and 1. |
| 94 | + (x["user_rating"] - 1.0) / 4.0, |
| 95 | + ) |
| 96 | + |
| 97 | + |
| 98 | +""" |
| 99 | +We'll split the data by putting 80% of the ratings in the train set, and 20% in |
| 100 | +the test set. |
| 101 | +""" |
| 102 | + |
| 103 | +shuffled_ratings = ratings.map(preprocess_rating).shuffle( |
| 104 | + 100_000, seed=42, reshuffle_each_iteration=False |
| 105 | +) |
| 106 | +train_ratings = ( |
| 107 | + shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache() |
| 108 | +) |
| 109 | +test_ratings = ( |
| 110 | + shuffled_ratings.skip(80_000) |
| 111 | + .take(20_000) |
| 112 | + .batch(BATCH_SIZE, drop_remainder=True) |
| 113 | + .cache() |
| 114 | +) |
| 115 | + |
| 116 | +""" |
| 117 | +## Configuring DistributedEmbedding |
| 118 | +
|
| 119 | +The `keras_rs.layers.DistributedEmbedding` handles multiple features and |
| 120 | +multiple embedding tables. This is to enable the sharing of tables between |
| 121 | +features and allow some optimizations that come from combining multiple |
| 122 | +embedding lookups into a single invocation. In this section, we'll describe |
| 123 | +how to configure these. |
| 124 | +
|
| 125 | +### Configuring tables |
| 126 | +
|
| 127 | +Tables are configured using `keras_rs.layers.TableConfig`, which has: |
| 128 | +
|
| 129 | +- A name. |
| 130 | +- A vocabulary size (input size). |
| 131 | +- an embedding dimension (output size). |
| 132 | +- A combiner to specify how to reduce multiple embeddings into a single one in |
| 133 | + the case when we embed a sequence. Note that this doesn't apply to our example |
| 134 | + because we're getting a single embedding for each user and each movie. |
| 135 | +- A placement to tell whether to put the table on the SparseCore chips or not. |
| 136 | + In this case, we want the `"sparsecore"` placement. |
| 137 | +- An optimizer to specify how to apply gradients when training. Each table has |
| 138 | + its own optimizer and the one passed to `model.compile()` is not used for the |
| 139 | + embedding tables. |
| 140 | +
|
| 141 | +### Configuring features |
| 142 | +
|
| 143 | +Features are configured using `keras_rs.layers.FeatureConfig`, which has: |
| 144 | +
|
| 145 | +- A name. |
| 146 | +- A table, the embedding table to use. |
| 147 | +- An input shape (per replica). |
| 148 | +- An output shape (per replica). |
| 149 | +
|
| 150 | +We can organize features in any structure we want, which can be nested. A dict |
| 151 | +is often a good choice to have names for the inputs and outputs. |
| 152 | +""" |
| 153 | + |
| 154 | +EMBEDDING_DIMENSION = 32 |
| 155 | + |
| 156 | +movie_table = keras_rs.layers.TableConfig( |
| 157 | + name="movie_table", |
| 158 | + vocabulary_size=movies_count + 1, # +1 for movie ID 0, which is not used |
| 159 | + embedding_dim=EMBEDDING_DIMENSION, |
| 160 | + optimizer="adam", |
| 161 | + placement="sparsecore", |
| 162 | +) |
| 163 | +user_table = keras_rs.layers.TableConfig( |
| 164 | + name="user_table", |
| 165 | + vocabulary_size=users_count + 1, # +1 for user ID 0, which is not used |
| 166 | + embedding_dim=EMBEDDING_DIMENSION, |
| 167 | + optimizer="adam", |
| 168 | + placement="sparsecore", |
| 169 | +) |
| 170 | + |
| 171 | +FEATURE_CONFIGS = { |
| 172 | + "movie_id": keras_rs.layers.FeatureConfig( |
| 173 | + name="movie", |
| 174 | + table=movie_table, |
| 175 | + input_shape=(BATCH_SIZE,), |
| 176 | + output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION), |
| 177 | + ), |
| 178 | + "user_id": keras_rs.layers.FeatureConfig( |
| 179 | + name="user", |
| 180 | + table=user_table, |
| 181 | + input_shape=(BATCH_SIZE,), |
| 182 | + output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION), |
| 183 | + ), |
| 184 | +} |
| 185 | + |
| 186 | +""" |
| 187 | +## Defining the Model |
| 188 | +
|
| 189 | +We're now ready to create a `DistributedEmbedding` inside a model. Once we have |
| 190 | +the configuration, we simply pass it the constructor of `DistributedEmbedding`. |
| 191 | +Then, within the model `call` method, `DistributedEmbedding` is the first layer |
| 192 | +we call. |
| 193 | +
|
| 194 | +The ouputs have the exact same structure as the inputs. In our example, we |
| 195 | +concatenate the embeddings we got as outputs and run them through a tower of |
| 196 | +dense layers. |
| 197 | +""" |
| 198 | + |
| 199 | + |
| 200 | +class EmbeddingModel(keras.Model): |
| 201 | + """Create the model with the embedding configuration. |
| 202 | +
|
| 203 | + Args: |
| 204 | + feature_configs: the configuration for `DistributedEmbedding`. |
| 205 | + """ |
| 206 | + |
| 207 | + def __init__(self, feature_configs): |
| 208 | + super().__init__() |
| 209 | + |
| 210 | + self.embedding_layer = keras_rs.layers.DistributedEmbedding( |
| 211 | + feature_configs=feature_configs |
| 212 | + ) |
| 213 | + self.ratings = keras.Sequential( |
| 214 | + [ |
| 215 | + # Learn multiple dense layers. |
| 216 | + keras.layers.Dense(256, activation="relu"), |
| 217 | + keras.layers.Dense(64, activation="relu"), |
| 218 | + # Make rating predictions in the final layer. |
| 219 | + keras.layers.Dense(1), |
| 220 | + ] |
| 221 | + ) |
| 222 | + |
| 223 | + def call(self, preprocessed_features): |
| 224 | + # Embedding lookup. Outputs have the same structure as the inputs. |
| 225 | + embedding = self.embedding_layer(preprocessed_features) |
| 226 | + return self.ratings( |
| 227 | + keras.ops.concatenate( |
| 228 | + [embedding["user_id"], embedding["movie_id"]], |
| 229 | + axis=1, |
| 230 | + ) |
| 231 | + ) |
| 232 | + |
| 233 | + |
| 234 | +""" |
| 235 | +Let's now instantiate the model. We then use `model.compile()` to configure the |
| 236 | +loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to |
| 237 | +the dense layers and not the embedding tables. |
| 238 | +""" |
| 239 | + |
| 240 | +model = EmbeddingModel(FEATURE_CONFIGS) |
| 241 | + |
| 242 | +model.compile( |
| 243 | + loss=keras.losses.MeanSquaredError(), |
| 244 | + metrics=[keras.metrics.RootMeanSquaredError()], |
| 245 | + optimizer="adagrad", |
| 246 | +) |
| 247 | + |
| 248 | +""" |
| 249 | +With the JAX backend, we need to preprocess the inputs to convert them to a |
| 250 | +hardware-dependent format required for use with SparseCores. We'll do this by |
| 251 | +wrapping the datasets into generator functions. |
| 252 | +""" |
| 253 | + |
| 254 | + |
| 255 | +def train_dataset_generator(): |
| 256 | + for inputs, labels in iter(train_ratings): |
| 257 | + yield model.embedding_layer.preprocess(inputs, training=True), labels |
| 258 | + |
| 259 | + |
| 260 | +def test_dataset_generator(): |
| 261 | + for inputs, labels in iter(test_ratings): |
| 262 | + yield model.embedding_layer.preprocess(inputs, training=False), labels |
| 263 | + |
| 264 | + |
| 265 | +""" |
| 266 | +## Fitting and evaluating |
| 267 | +
|
| 268 | +We can use the standard Keras `model.fit()` to train the model. Keras will |
| 269 | +automatically use the `TPUStrategy` to distribute the model and the data. |
| 270 | +""" |
| 271 | + |
| 272 | +model.fit(train_dataset_generator(), epochs=5) |
| 273 | + |
| 274 | +""" |
| 275 | +Same for `model.evaluate()`. |
| 276 | +""" |
| 277 | + |
| 278 | +model.evaluate(test_dataset_generator(), return_dict=True) |
| 279 | + |
| 280 | +""" |
| 281 | +That's it. |
| 282 | +
|
| 283 | +This example shows that after configuring the `DistributedEmbedding` and setting |
| 284 | +up the required preprocessing, you can use the standard Keras workflows. |
| 285 | +""" |
0 commit comments