
Commit 2893ac8

Merge branch 'keras-team:master' into master
2 parents 43793bc + 0f66e3c commit 2893ac8

25 files changed: +3837 -25 lines

examples/keras_rs/basic_ranking.py

Lines changed: 2 additions & 2 deletions
@@ -64,7 +64,7 @@
 starting at 1 and with no gap. Normally, you would need to create a lookup table
 to map user IDs to integers from 0 to N-1. But as a simplication, we'll use the
 user id directly as an index in our model, in particular to lookup the user
-embedding from the user embedding table. So we need do know the number of users.
+embedding from the user embedding table. So we need to know the number of users.
 """

 users_count = (
@@ -78,7 +78,7 @@
 starting at 1 and with no gap. Normally, you would need to create a lookup table
 to map movie IDs to integers from 0 to N-1. But as a simplication, we'll use the
 movie id directly as an index in our model, in particular to lookup the movie
-embedding from the movie embedding table. So we need do know the number of
+embedding from the movie embedding table. So we need to know the number of
 movies.
 """
examples/keras_rs/deep_recommender.py

Lines changed: 2 additions & 0 deletions
@@ -17,13 +17,15 @@
 usually not be immediately usable in a model.

 For example:
+
 - User and item IDs may be strings (titles, usernames) or large, non-contiguous
   integers (database IDs).
 - Item descriptions could be raw text.
 - Interaction timestamps could be raw Unix timestamps.

 These need to be appropriately transformed in order to be useful in building
 models:
+
 - User and item IDs have to be translated into embedding vectors,
   high-dimensional numerical representations that are adjusted during training
   to help the model predict its objective better.
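The hunk above touches the passage explaining that user and item IDs must be translated into embedding vectors. As a minimal sketch of that idea (illustrative only, not part of this commit; the IDs and dimensions below are made up), Keras pairs a lookup layer with a trainable embedding table:

import keras
import tensorflow as tf  # StringLookup runs on TensorFlow ops

# Hypothetical vocabulary of raw string user IDs.
user_ids = ["u_042", "u_107", "u_999"]

# Map arbitrary string IDs to contiguous integer indices.
lookup = keras.layers.StringLookup(vocabulary=user_ids)

# One trainable embedding vector per index (vocabulary_size() includes the OOV index).
embed = keras.layers.Embedding(input_dim=lookup.vocabulary_size(), output_dim=32)

indices = lookup(tf.constant(["u_107", "some_unknown_user"]))
vectors = embed(indices)  # shape (2, 32), adjusted during training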
Lines changed: 285 additions & 0 deletions
@@ -0,0 +1,285 @@
"""
Title: DistributedEmbedding using TPU SparseCore and JAX
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
Date created: 2025/06/03
Last modified: 2025/06/03
Description: Rank movies using a two tower model with embeddings on SparseCore.
Accelerator: TPU
"""

"""
## Introduction

In the [basic ranking](/keras_rs/examples/basic_ranking/) tutorial, we showed
how to build a ranking model for the MovieLens dataset to suggest movies to
users.

This tutorial implements the same model trained on the same dataset but with the
use of `keras_rs.layers.DistributedEmbedding`, which makes use of SparseCore on
TPU. This is the JAX version of the tutorial. It needs to be run on TPU v5p or
v6e.

Let's begin by choosing JAX as the backend and importing all the necessary
libraries.
"""

import os

os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import keras_rs
import tensorflow as tf  # Needed for the dataset
import tensorflow_datasets as tfds

"""
37+
## Dataset distribution
38+
39+
While the model is replicated and the embedding tables are sharded across
40+
SparseCores, the dataset is distributed by sharding each batch across the TPUs.
41+
We need to make sure the batch size is a multiple of the number of TPUs.
42+
"""
43+
44+
PER_REPLICA_BATCH_SIZE = 256
45+
BATCH_SIZE = PER_REPLICA_BATCH_SIZE * jax.local_device_count("tpu")
46+
47+
distribution = keras.distribution.DataParallel(devices=jax.devices("tpu"))
48+
keras.distribution.set_distribution(distribution)
49+
50+
"""
51+
## Preparing the dataset
52+
53+
We're going to use the same Movielens data. The ratings are the objectives we
54+
are trying to predict.
55+
"""
56+
57+
# Ratings data.
58+
ratings = tfds.load("movielens/100k-ratings", split="train")
59+
# Features of all the available movies.
60+
movies = tfds.load("movielens/100k-movies", split="train")
61+
62+
"""
63+
We need to know the number of users as we're using the user ID directly as an
64+
index in the user embedding table.
65+
"""
66+
67+
users_count = (
68+
ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
69+
.reduce(tf.constant(0, tf.int32), tf.maximum)
70+
.numpy()
71+
)
72+
73+
"""
74+
We also need do know the number of movies as we're using the movie ID directly
75+
as an index in the movie embedding table.
76+
"""
77+
78+
movies_count = movies.cardinality().numpy()
79+
80+
"""
81+
The inputs to the model are the user IDs and movie IDs and the labels are the
82+
ratings.
83+
"""
84+
85+
86+
def preprocess_rating(x):
87+
return (
88+
# Inputs are user IDs and movie IDs
89+
{
90+
"user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
91+
"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
92+
},
93+
# Labels are ratings between 0 and 1.
94+
(x["user_rating"] - 1.0) / 4.0,
95+
)
96+
97+
98+
"""
99+
We'll split the data by putting 80% of the ratings in the train set, and 20% in
100+
the test set.
101+
"""
102+
103+
shuffled_ratings = ratings.map(preprocess_rating).shuffle(
104+
100_000, seed=42, reshuffle_each_iteration=False
105+
)
106+
train_ratings = (
107+
shuffled_ratings.take(80_000).batch(BATCH_SIZE, drop_remainder=True).cache()
108+
)
109+
test_ratings = (
110+
shuffled_ratings.skip(80_000)
111+
.take(20_000)
112+
.batch(BATCH_SIZE, drop_remainder=True)
113+
.cache()
114+
)
115+
116+
"""
117+
## Configuring DistributedEmbedding
118+
119+
The `keras_rs.layers.DistributedEmbedding` handles multiple features and
120+
multiple embedding tables. This is to enable the sharing of tables between
121+
features and allow some optimizations that come from combining multiple
122+
embedding lookups into a single invocation. In this section, we'll describe
123+
how to configure these.
124+
125+
### Configuring tables
126+
127+
Tables are configured using `keras_rs.layers.TableConfig`, which has:
128+
129+
- A name.
130+
- A vocabulary size (input size).
131+
- an embedding dimension (output size).
132+
- A combiner to specify how to reduce multiple embeddings into a single one in
133+
the case when we embed a sequence. Note that this doesn't apply to our example
134+
because we're getting a single embedding for each user and each movie.
135+
- A placement to tell whether to put the table on the SparseCore chips or not.
136+
In this case, we want the `"sparsecore"` placement.
137+
- An optimizer to specify how to apply gradients when training. Each table has
138+
its own optimizer and the one passed to `model.compile()` is not used for the
139+
embedding tables.
140+
141+
### Configuring features
142+
143+
Features are configured using `keras_rs.layers.FeatureConfig`, which has:
144+
145+
- A name.
146+
- A table, the embedding table to use.
147+
- An input shape (per replica).
148+
- An output shape (per replica).
149+
150+
We can organize features in any structure we want, which can be nested. A dict
151+
is often a good choice to have names for the inputs and outputs.
152+
"""
153+
154+
EMBEDDING_DIMENSION = 32

movie_table = keras_rs.layers.TableConfig(
    name="movie_table",
    vocabulary_size=movies_count + 1,  # +1 for movie ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)
user_table = keras_rs.layers.TableConfig(
    name="user_table",
    vocabulary_size=users_count + 1,  # +1 for user ID 0, which is not used
    embedding_dim=EMBEDDING_DIMENSION,
    optimizer="adam",
    placement="sparsecore",
)

FEATURE_CONFIGS = {
    "movie_id": keras_rs.layers.FeatureConfig(
        name="movie",
        table=movie_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
    "user_id": keras_rs.layers.FeatureConfig(
        name="user",
        table=user_table,
        input_shape=(BATCH_SIZE,),
        output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),
    ),
}

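The configuration above only uses single-ID features. As a hedged illustration of the combiner mentioned earlier (not part of this commit), a multi-valued feature such as a hypothetical "recently watched" movie list might be configured as follows; the `combiner="mean"` value and the `(BATCH_SIZE, SEQ_LEN)` input shape are assumptions for the sketch:

SEQ_LEN = 10  # hypothetical length of a "recently watched" list of movie IDs

watch_history_table = keras_rs.layers.TableConfig(
    name="watch_history_table",
    vocabulary_size=movies_count + 1,
    embedding_dim=EMBEDDING_DIMENSION,
    combiner="mean",  # assumption: average the embeddings of the sequence
    optimizer="adam",
    placement="sparsecore",
)
watch_history_feature = keras_rs.layers.FeatureConfig(
    name="watch_history",
    table=watch_history_table,
    input_shape=(BATCH_SIZE, SEQ_LEN),  # assumption: one row of movie IDs per example
    output_shape=(BATCH_SIZE, EMBEDDING_DIMENSION),  # sequence reduced by the combiner
)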
"""
187+
## Defining the Model
188+
189+
We're now ready to create a `DistributedEmbedding` inside a model. Once we have
190+
the configuration, we simply pass it the constructor of `DistributedEmbedding`.
191+
Then, within the model `call` method, `DistributedEmbedding` is the first layer
192+
we call.
193+
194+
The ouputs have the exact same structure as the inputs. In our example, we
195+
concatenate the embeddings we got as outputs and run them through a tower of
196+
dense layers.
197+
"""
198+
199+
200+
class EmbeddingModel(keras.Model):
    """Create the model with the embedding configuration.

    Args:
        feature_configs: the configuration for `DistributedEmbedding`.
    """

    def __init__(self, feature_configs):
        super().__init__()

        self.embedding_layer = keras_rs.layers.DistributedEmbedding(
            feature_configs=feature_configs
        )
        self.ratings = keras.Sequential(
            [
                # Learn multiple dense layers.
                keras.layers.Dense(256, activation="relu"),
                keras.layers.Dense(64, activation="relu"),
                # Make rating predictions in the final layer.
                keras.layers.Dense(1),
            ]
        )

    def call(self, preprocessed_features):
        # Embedding lookup. Outputs have the same structure as the inputs.
        embedding = self.embedding_layer(preprocessed_features)
        return self.ratings(
            keras.ops.concatenate(
                [embedding["user_id"], embedding["movie_id"]],
                axis=1,
            )
        )


"""
235+
Let's now instantiate the model. We then use `model.compile()` to configure the
236+
loss, metrics and optimizer. Again, this Adagrad optimizer will only apply to
237+
the dense layers and not the embedding tables.
238+
"""
239+
240+
model = EmbeddingModel(FEATURE_CONFIGS)
241+
242+
model.compile(
243+
loss=keras.losses.MeanSquaredError(),
244+
metrics=[keras.metrics.RootMeanSquaredError()],
245+
optimizer="adagrad",
246+
)
247+
248+
"""
249+
With the JAX backend, we need to preprocess the inputs to convert them to a
250+
hardware-dependent format required for use with SparseCores. We'll do this by
251+
wrapping the datasets into generator functions.
252+
"""
253+
254+
255+
def train_dataset_generator():
256+
for inputs, labels in iter(train_ratings):
257+
yield model.embedding_layer.preprocess(inputs, training=True), labels
258+
259+
260+
def test_dataset_generator():
261+
for inputs, labels in iter(test_ratings):
262+
yield model.embedding_layer.preprocess(inputs, training=False), labels
263+
264+
265+
"""
266+
## Fitting and evaluating
267+
268+
We can use the standard Keras `model.fit()` to train the model. Keras will
269+
automatically use the `TPUStrategy` to distribute the model and the data.
270+
"""
271+
272+
model.fit(train_dataset_generator(), epochs=5)
273+
274+
"""
275+
Same for `model.evaluate()`.
276+
"""
277+
278+
model.evaluate(test_dataset_generator(), return_dict=True)

"""
That's it.

This example shows that after configuring the `DistributedEmbedding` and setting
up the required preprocessing, you can use the standard Keras workflows.
"""

examples/keras_rs/ipynb/basic_ranking.ipynb

Lines changed: 3 additions & 3 deletions
@@ -104,7 +104,7 @@
 "starting at 1 and with no gap. Normally, you would need to create a lookup table\n",
 "to map user IDs to integers from 0 to N-1. But as a simplication, we'll use the\n",
 "user id directly as an index in our model, in particular to lookup the user\n",
-"embedding from the user embedding table. So we need do know the number of users."
+"embedding from the user embedding table. So we need to know the number of users."
 ]
 },
 {
@@ -132,7 +132,7 @@
 "starting at 1 and with no gap. Normally, you would need to create a lookup table\n",
 "to map movie IDs to integers from 0 to N-1. But as a simplication, we'll use the\n",
 "movie id directly as an index in our model, in particular to lookup the movie\n",
-"embedding from the movie embedding table. So we need do know the number of\n",
+"embedding from the movie embedding table. So we need to know the number of\n",
 "movies."
 ]
 },
@@ -459,4 +459,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}

examples/keras_rs/ipynb/deep_recommender.ipynb

Lines changed: 3 additions & 1 deletion
@@ -29,13 +29,15 @@
 "usually not be immediately usable in a model.\n",
 "\n",
 "For example:\n",
+"\n",
 "- User and item IDs may be strings (titles, usernames) or large, non-contiguous\n",
 "  integers (database IDs).\n",
 "- Item descriptions could be raw text.\n",
 "- Interaction timestamps could be raw Unix timestamps.\n",
 "\n",
 "These need to be appropriately transformed in order to be useful in building\n",
 "models:\n",
+"\n",
 "- User and item IDs have to be translated into embedding vectors,\n",
 "  high-dimensional numerical representations that are adjusted during training\n",
 "  to help the model predict its objective better.\n",
@@ -1351,4 +1353,4 @@
 },
 "nbformat": 4,
 "nbformat_minor": 0
-}
+}
