Skip to content

Commit 7a014ee

Browse files
authored
Add KerasRS API documentation and examples (#2083)
* Add KerasRS docs and examples * Small changes in examples * Add all examples * Fix accelerator for DP training * Add imgur images, fix links * Fixes * Add README * Link fixes * Add generated examples * Set backend * Generate all examples * Fix remaining examples * Fix README and dp example * Add logo * Fixes * Fix doc-strings * Make image smaller * Make image smaller * Small home page edits * Fix DCN links * Add correct images to DCN * Add correct images to DCN * Add correct images to DCN * Make titles shorter * Generate all examples * Fixes * Fixes (2) * Add temp path to requirements.txt
1 parent 538dee5 commit 7a014ee

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

63 files changed

+79734
-6
lines changed

examples/keras_rs/basic_ranking.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""
2+
Title: Recommending movies: ranking
3+
Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/)
4+
Date created: 2025/04/28
5+
Last modified: 2025/04/28
6+
Description: Rank movies using a two tower model.
7+
Accelerator: GPU
8+
"""
9+
10+
"""
11+
## Introduction
12+
13+
Recommender systems are often composed of two stages:
14+
15+
1. The retrieval stage is responsible for selecting an initial set of hundreds
16+
of candidates from all possible candidates. The main objective of this model
17+
is to efficiently weed out all candidates that the user is not interested in.
18+
Because the retrieval model may be dealing with millions of candidates, it
19+
has to be computationally efficient.
20+
2. The ranking stage takes the outputs of the retrieval model and fine-tunes
21+
them to select the best possible handful of recommendations. Its task is to
22+
narrow down the set of items the user may be interested in to a shortlist of
23+
likely candidates.
24+
25+
In this tutorial, we're going to focus on the second stage, ranking. If you are
26+
interested in the retrieval stage, have a look at our
27+
[retrieval](/keras_rs/examples/basic_retrieval/)
28+
tutorial.
29+
30+
In this tutorial, we're going to:
31+
32+
1. Get our data and split it into a training and test set.
33+
2. Implement a ranking model.
34+
3. Fit and evaluate it.
35+
4. Test running predictions with the model.
36+
37+
Let's begin by choosing JAX as the backend we want to run on, and import all
38+
the necessary libraries.
39+
"""
40+
41+
import os
42+
43+
os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"`
44+
45+
import keras
46+
import tensorflow as tf # Needed for the dataset
47+
import tensorflow_datasets as tfds
48+
49+
"""
50+
## Preparing the dataset
51+
52+
We're going to use the same data as the
53+
[retrieval](/keras_rs/examples/basic_retrieval/)
54+
tutorial. The ratings are the objectives we are trying to predict.
55+
"""
56+
57+
# Ratings data.
58+
ratings = tfds.load("movielens/100k-ratings", split="train")
59+
# Features of all the available movies.
60+
movies = tfds.load("movielens/100k-movies", split="train")
61+
62+
"""
63+
In the Movielens dataset, user IDs are integers (represented as strings)
64+
starting at 1 and with no gap. Normally, you would need to create a lookup table
65+
to map user IDs to integers from 0 to N-1. But as a simplication, we'll use the
66+
user id directly as an index in our model, in particular to lookup the user
67+
embedding from the user embedding table. So we need do know the number of users.
68+
"""
69+
70+
users_count = (
71+
ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32))
72+
.reduce(tf.constant(0, tf.int32), tf.maximum)
73+
.numpy()
74+
)
75+
76+
"""
77+
In the Movielens dataset, movie IDs are integers (represented as strings)
78+
starting at 1 and with no gap. Normally, you would need to create a lookup table
79+
to map movie IDs to integers from 0 to N-1. But as a simplication, we'll use the
80+
movie id directly as an index in our model, in particular to lookup the movie
81+
embedding from the movie embedding table. So we need do know the number of
82+
movies.
83+
"""
84+
85+
movies_count = movies.cardinality().numpy()
86+
87+
"""
88+
The inputs to the model are the user IDs and movie IDs and the labels are the
89+
ratings.
90+
"""
91+
92+
93+
def preprocess_rating(x):
94+
return (
95+
# Inputs are user IDs and movie IDs
96+
{
97+
"user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32),
98+
"movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32),
99+
},
100+
# Labels are ratings between 0 and 1.
101+
(x["user_rating"] - 1.0) / 4.0,
102+
)
103+
104+
105+
"""
106+
We'll split the data by putting 80% of the ratings in the train set, and 20% in
107+
the test set.
108+
"""
109+
110+
shuffled_ratings = ratings.map(preprocess_rating).shuffle(
111+
100_000, seed=42, reshuffle_each_iteration=False
112+
)
113+
train_ratings = shuffled_ratings.take(80_000).batch(1000).cache()
114+
test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache()
115+
116+
"""
117+
## Implementing the Model
118+
119+
### Architecture
120+
121+
Ranking models do not face the same efficiency constraints as retrieval models
122+
do, and so we have a little bit more freedom in our choice of architectures.
123+
124+
A model composed of multiple stacked dense layers is a relatively common
125+
architecture for ranking tasks. We can implement it as follows:
126+
"""
127+
128+
129+
class RankingModel(keras.Model):
130+
"""Create the ranking model with the provided parameters.
131+
132+
Args:
133+
num_users: Number of entries in the user embedding table.
134+
num_candidates: Number of entries in the candidate embedding table.
135+
embedding_dimension: Output dimension for user and movie embedding tables.
136+
"""
137+
138+
def __init__(
139+
self,
140+
num_users,
141+
num_candidates,
142+
embedding_dimension=32,
143+
**kwargs,
144+
):
145+
super().__init__(**kwargs)
146+
# Embedding table for users.
147+
self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension)
148+
# Embedding table for candidates.
149+
self.candidate_embedding = keras.layers.Embedding(
150+
num_candidates, embedding_dimension
151+
)
152+
# Predictions.
153+
self.ratings = keras.Sequential(
154+
[
155+
# Learn multiple dense layers.
156+
keras.layers.Dense(256, activation="relu"),
157+
keras.layers.Dense(64, activation="relu"),
158+
# Make rating predictions in the final layer.
159+
keras.layers.Dense(1),
160+
]
161+
)
162+
163+
def call(self, inputs):
164+
user_id, movie_id = inputs["user_id"], inputs["movie_id"]
165+
user_embeddings = self.user_embedding(user_id)
166+
candidate_embeddings = self.candidate_embedding(movie_id)
167+
return self.ratings(
168+
keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1)
169+
)
170+
171+
172+
"""
173+
Let's first instantiate the model. Note that we add `+ 1` to the number of users
174+
and movies to account for the fact that id zero is not used for either (IDs
175+
start at 1), but still takes a row in the embedding tables.
176+
"""
177+
178+
model = RankingModel(users_count + 1, movies_count + 1)
179+
180+
"""
181+
### Loss and metrics
182+
183+
The next component is the loss used to train our model. Keras has several losses
184+
to make this easy. In this instance, we'll make use of the `MeanSquaredError`
185+
loss in order to predict the ratings. We'll also look at the
186+
`RootMeanSquaredError` metric.
187+
"""
188+
189+
model.compile(
190+
loss=keras.losses.MeanSquaredError(),
191+
metrics=[keras.metrics.RootMeanSquaredError()],
192+
optimizer=keras.optimizers.Adagrad(learning_rate=0.1),
193+
)
194+
195+
"""
196+
## Fitting and evaluating
197+
198+
After defining the model, we can use the standard Keras `model.fit()` to train
199+
the model.
200+
"""
201+
202+
model.fit(train_ratings, epochs=5)
203+
204+
"""
205+
As the model trains, the loss is falling and the RMSE metric is improving.
206+
207+
Finally, we can evaluate our model on the test set. The lower the RMSE metric,
208+
the more accurate our model is at predicting ratings.
209+
"""
210+
211+
model.evaluate(test_ratings, return_dict=True)
212+
213+
"""
214+
## Testing the ranking model
215+
216+
So far, we have only handled movies by id. Now is the time to create a mapping
217+
keyed by movie IDs to be able to surface the titles.
218+
"""
219+
220+
movie_id_to_movie_title = {
221+
int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator()
222+
}
223+
movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset.
224+
225+
"""
226+
Now we can test the ranking model by computing predictions for a set of movies
227+
and then rank these movies based on the predictions:
228+
"""
229+
230+
user_id = 42
231+
movie_ids = [204, 141, 131]
232+
predictions = model.predict(
233+
{
234+
"user_id": keras.ops.array([user_id] * len(movie_ids)),
235+
"movie_id": keras.ops.array(movie_ids),
236+
}
237+
)
238+
predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=1))
239+
240+
for movie_id, prediction in zip(movie_ids, predictions):
241+
print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}")

0 commit comments

Comments
 (0)