Commit 802bb88

Add DistributedEmbedding example for TPU on TensorFlow.
1 parent 7894001 commit 802bb88

File tree: 4 files changed, +1320 -0 lines changed

Lines changed: 315 additions & 0 deletions
"""
Title: DistributedEmbedding using TPU SparseCore and TensorFlow
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/09/02
Last modified: 2025/09/02
Description: Rank movies using a two-tower model with embeddings on SparseCore.
Accelerator: TPU
"""

"""
## Introduction

In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed
how to build a ranking model for the MovieLens dataset to suggest movies to
users.

This tutorial implements the same model trained on the same dataset but with the
use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on
TPU. This is the TensorFlow version of the tutorial. It needs to be run on TPU
v5p or v6e.

Let's begin by installing the necessary libraries. Note that we need
`tensorflow-tpu` version 2.19. We'll also install `keras-rs`.
"""

"""shell
pip install -U -q tensorflow-tpu==2.19.1
pip install -U -q keras-rs
"""

"""
We're using the PJRT version of the runtime for TensorFlow. We're also enabling
the MLIR bridge. This requires setting a few flags before importing `tensorflow`.
"""

import os
import libtpu

os.environ["PJRT_DEVICE"] = "TPU"
os.environ["NEXT_PLUGGABLE_DEVICE_USE_C_API"] = "true"
os.environ["TF_PLUGGABLE_DEVICE_LIBRARY_PATH"] = libtpu.get_library_path()
os.environ["TF_XLA_FLAGS"] = (
    "--tf_mlir_enable_mlir_bridge=true "
    "--tf_mlir_enable_convert_control_to_data_outputs_pass=true "
    "--tf_mlir_enable_merge_control_flow_pass=true"
)

import tensorflow as tf

"""
We're now ready to import `keras` and `keras-rs`, but first we need to set the
Keras backend to TensorFlow.
"""

os.environ["KERAS_BACKEND"] = "tensorflow"

import keras
import keras_rs
import tensorflow_datasets as tfds

"""
## Creating a TPUStrategy

To run TensorFlow on TPU, you need to use a `tf.distribute.TPUStrategy` to
handle the distribution of the model.

Note that the core of the model is replicated across TPU instances. Only the
embedding tables handled by `DistributedEmbedding` are sharded across the
SparseCore chips of all the available TPUs.
"""

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="local")
topology = tf.tpu.experimental.initialize_tpu_system(resolver)
tpu_metadata = resolver.get_tpu_system_metadata()

device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, num_replicas=tpu_metadata.num_cores
)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment
)
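
"""
As a quick sanity check, we can print the number of model replicas the strategy
will use. With the device assignment above, there is one replica per TPU core,
while the embedding tables themselves are sharded across the SparseCores.
"""

print("Number of model replicas:", strategy.num_replicas_in_sync)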

"""
## Dataset distribution

While the model is replicated and the embedding tables are sharded across
SparseCores, the dataset is distributed by sharding each batch across the TPUs.
We need to make sure the batch size is a multiple of the number of TPUs.
"""

PER_REPLICA_BATCH_SIZE = 256
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * strategy.num_replicas_in_sync
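
"""
For example, if the strategy reports 8 replicas in sync, the global batch size
above is `256 * 8 = 2048`, and each replica still processes 256 examples per
step.
"""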

"""
## Preparing the dataset

We're going to use the same MovieLens data. The ratings are the objectives we
are trying to predict.
"""

# Ratings data.
ratings = tfds.load("movielens/100k-ratings", split="train")
# Features of all the available movies.
movies = tfds.load("movielens/100k-movies", split="train")

"""
We need to know the number of users as we're using the user ID directly as an
index in the user embedding table.
"""

users_count = int(
    ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
    .reduce(tf.constant(0, tf.int32), tf.maximum)
    .numpy()
)

"""
We also need to know the number of movies as we're using the movie ID directly
as an index in the movie embedding table.
"""

movies_count = int(movies.cardinality().numpy())

"""
The inputs to the model are the user IDs and movie IDs, and the labels are the
ratings, rescaled from the 1-5 star range to the 0-1 range.
"""


def preprocess_rating(x):
    return (
        # Inputs are user IDs and movie IDs.
        {
            "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
            "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
        },
        # Labels are ratings rescaled to be between 0 and 1.
        (x["user_rating"] - 1.0) / 4.0,
    )
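
"""
To see what a preprocessed example looks like, we can peek at a single element
of the mapped dataset (a small check, not required for training).
"""

for features, label in ratings.map(preprocess_rating).take(1):
    print(features, label)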


"""
We'll split the data by putting 80% of the ratings in the train set, and 20% in
the test set.
"""

shuffled_ratings = ratings.map(preprocess_rating).shuffle(
    100_000, seed=42, reshuffle_each_iteration=False
)
train_ratings = (
    shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
)
test_ratings = (
    shuffled_ratings.skip(80_000)
    .take(20_000)
    .batch(BATCH_SIZE, drop_remainder=True)
    .cache()
)

"""
## Configuring DistributedEmbedding

The `keras_rs.layers.DistributedEmbedding` layer handles multiple features and
multiple embedding tables. This enables the sharing of tables between features
and allows some optimizations that come from combining multiple embedding
lookups into a single invocation. In this section, we'll describe how to
configure these.

### Configuring tables

Tables are configured using `keras_rs.layers.TableConfig`, which has:

- A name.
- A vocabulary size (input size).
- An embedding dimension (output size).
- A combiner to specify how to reduce multiple embeddings into a single one
  when we embed a sequence. Note that this doesn't apply to our example because
  we're getting a single embedding for each user and each movie (see the sketch
  after the configuration code below).
- A placement to tell whether to put the table on the SparseCore chips or not.
  In this case, we want the `"sparsecore"` placement.
- An optimizer to specify how to apply gradients when training. Each table has
  its own optimizer and the one passed to `model.compile()` is not used for the
  embedding tables.

### Configuring features

Features are configured using `keras_rs.layers.FeatureConfig`, which has:

- A name.
- A table, the embedding table to use.
- An input shape (per replica).
- An output shape (per replica).

We can organize features in any structure we want, which can be nested. A dict
is often a good choice to have names for the inputs and outputs.
"""

EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}
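
"""
As an aside, if we had a multi-valued input, such as a list of recently watched
movies for each user, we would give its table a combiner and a sequence input
shape. The snippet below is only a sketch of such a configuration and is not
used in this example; the `"mean"` combiner value, the `watched_table` and
`watched_movies` names, and the sequence length of 10 are assumptions for
illustration.

```python
watched_table = keras_rs.layers.TableConfig(
    name="watched_table",
    vocabulary_size=movies_count + 1,
    embedding_dim=EMBEDDING_DIMENSION,
    combiner="mean",  # average the embeddings of the watched movies
    optimizer="adam",
    placement="sparsecore",
)
watched_feature = keras_rs.layers.FeatureConfig(
    name="watched_movies",
    table=watched_table,
    input_shape=(BATCH_SIZE, 10),  # up to 10 watched movie IDs per example
    output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
)
```
"""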

"""
## Defining the model

We're now ready to create a `DistributedEmbedding` inside a model. Once we have
the configuration, we simply pass it to the constructor of
`DistributedEmbedding`. Then, within the model `call` method,
`DistributedEmbedding` is the first layer we call.

The outputs have the exact same structure as the inputs. In our example, we
concatenate the embeddings we got as outputs and run them through a tower of
dense layers.
"""


class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )


"""
Let's now instantiate the model. We then use `model.compile()` to configure the
loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to
the dense layers and not the embedding tables.
"""

with strategy.scope():
    model = EmbeddingModel(FEATURE_CONFIGS)

    model.compile(
        loss=keras.losses.MeanSquaredError(),
        metrics=[keras.metrics.RootMeanSquaredError()],
        optimizer="adagrad",
    )

"""
## Fitting and evaluating

We can use the standard Keras `model.fit()` to train the model. Keras will
automatically use the `TPUStrategy` to distribute the model and the data.
"""

with strategy.scope():
    model.fit(train_ratings, epochs=5)

"""
Same for `model.evaluate()`.
"""

with strategy.scope():
    model.evaluate(test_ratings, return_dict=True)

"""
That's it.

This example shows that after setting up the `TPUStrategy` and configuring the
`DistributedEmbedding`, you can use the standard Keras workflows.
"""
"""
