|
| 1 | +""" |
| 2 | +Title: Recommending movies: ranking |
| 3 | +Author: [Fabien Hertschuh](https://github.com/hertschuh/), [Abheesht Sharma](https://github.com/abheesht17/) |
| 4 | +Date created: 2025/04/28 |
| 5 | +Last modified: 2025/04/28 |
| 6 | +Description: Rank movies using a two tower model. |
| 7 | +Accelerator: GPU |
| 8 | +""" |
| 9 | + |
| 10 | +""" |
| 11 | +## Introduction |
| 12 | +
|
| 13 | +Recommender systems are often composed of two stages: |
| 14 | +
|
| 15 | +1. The retrieval stage is responsible for selecting an initial set of hundreds |
| 16 | + of candidates from all possible candidates. The main objective of this model |
| 17 | + is to efficiently weed out all candidates that the user is not interested in. |
| 18 | + Because the retrieval model may be dealing with millions of candidates, it |
| 19 | + has to be computationally efficient. |
| 20 | +2. The ranking stage takes the outputs of the retrieval model and fine-tunes |
| 21 | + them to select the best possible handful of recommendations. Its task is to |
| 22 | + narrow down the set of items the user may be interested in to a shortlist of |
| 23 | + likely candidates. |
| 24 | +
|
| 25 | +In this tutorial, we're going to focus on the second stage, ranking. If you are |
| 26 | +interested in the retrieval stage, have a look at our |
| 27 | +[retrieval](/keras_rs/examples/basic_retrieval/) |
| 28 | +tutorial. |
| 29 | +
|
| 30 | +In this tutorial, we're going to: |
| 31 | +
|
| 32 | +1. Get our data and split it into a training and test set. |
| 33 | +2. Implement a ranking model. |
| 34 | +3. Fit and evaluate it. |
| 35 | +4. Test running predictions with the model. |
| 36 | +
|
| 37 | +Let's begin by choosing JAX as the backend we want to run on, and import all |
| 38 | +the necessary libraries. |
| 39 | +""" |
| 40 | + |
| 41 | +import os |
| 42 | + |
| 43 | +os.environ["KERAS_BACKEND"] = "jax" # `"tensorflow"`/`"torch"` |
| 44 | + |
| 45 | +import keras |
| 46 | +import tensorflow as tf # Needed for the dataset |
| 47 | +import tensorflow_datasets as tfds |
| 48 | + |
| 49 | +""" |
| 50 | +## Preparing the dataset |
| 51 | +
|
| 52 | +We're going to use the same data as the |
| 53 | +[retrieval](/keras_rs/examples/basic_retrieval/) |
| 54 | +tutorial. The ratings are the objectives we are trying to predict. |
| 55 | +""" |
| 56 | + |
| 57 | +# Ratings data. |
| 58 | +ratings = tfds.load("movielens/100k-ratings", split="train") |
| 59 | +# Features of all the available movies. |
| 60 | +movies = tfds.load("movielens/100k-movies", split="train") |
| 61 | + |
| 62 | +""" |
| 63 | +In the Movielens dataset, user IDs are integers (represented as strings) |
| 64 | +starting at 1 and with no gap. Normally, you would need to create a lookup table |
| 65 | +to map user IDs to integers from 0 to N-1. But as a simplication, we'll use the |
| 66 | +user id directly as an index in our model, in particular to lookup the user |
| 67 | +embedding from the user embedding table. So we need do know the number of users. |
| 68 | +""" |
| 69 | + |
| 70 | +users_count = ( |
| 71 | + ratings.map(lambda x: tf.strings.to_number(x["user_id"], out_type=tf.int32)) |
| 72 | + .reduce(tf.constant(0, tf.int32), tf.maximum) |
| 73 | + .numpy() |
| 74 | +) |
| 75 | + |
| 76 | +""" |
| 77 | +In the Movielens dataset, movie IDs are integers (represented as strings) |
| 78 | +starting at 1 and with no gap. Normally, you would need to create a lookup table |
| 79 | +to map movie IDs to integers from 0 to N-1. But as a simplication, we'll use the |
| 80 | +movie id directly as an index in our model, in particular to lookup the movie |
| 81 | +embedding from the movie embedding table. So we need do know the number of |
| 82 | +movies. |
| 83 | +""" |
| 84 | + |
| 85 | +movies_count = movies.cardinality().numpy() |
| 86 | + |
| 87 | +""" |
| 88 | +The inputs to the model are the user IDs and movie IDs and the labels are the |
| 89 | +ratings. |
| 90 | +""" |
| 91 | + |
| 92 | + |
| 93 | +def preprocess_rating(x): |
| 94 | + return ( |
| 95 | + # Inputs are user IDs and movie IDs |
| 96 | + { |
| 97 | + "user_id": tf.strings.to_number(x["user_id"], out_type=tf.int32), |
| 98 | + "movie_id": tf.strings.to_number(x["movie_id"], out_type=tf.int32), |
| 99 | + }, |
| 100 | + # Labels are ratings between 0 and 1. |
| 101 | + (x["user_rating"] - 1.0) / 4.0, |
| 102 | + ) |
| 103 | + |
| 104 | + |
| 105 | +""" |
| 106 | +We'll split the data by putting 80% of the ratings in the train set, and 20% in |
| 107 | +the test set. |
| 108 | +""" |
| 109 | + |
| 110 | +shuffled_ratings = ratings.map(preprocess_rating).shuffle( |
| 111 | + 100_000, seed=42, reshuffle_each_iteration=False |
| 112 | +) |
| 113 | +train_ratings = shuffled_ratings.take(80_000).batch(1000).cache() |
| 114 | +test_ratings = shuffled_ratings.skip(80_000).take(20_000).batch(1000).cache() |
| 115 | + |
| 116 | +""" |
| 117 | +## Implementing the Model |
| 118 | +
|
| 119 | +### Architecture |
| 120 | +
|
| 121 | +Ranking models do not face the same efficiency constraints as retrieval models |
| 122 | +do, and so we have a little bit more freedom in our choice of architectures. |
| 123 | +
|
| 124 | +A model composed of multiple stacked dense layers is a relatively common |
| 125 | +architecture for ranking tasks. We can implement it as follows: |
| 126 | +""" |
| 127 | + |
| 128 | + |
| 129 | +class RankingModel(keras.Model): |
| 130 | + """Create the ranking model with the provided parameters. |
| 131 | +
|
| 132 | + Args: |
| 133 | + num_users: Number of entries in the user embedding table. |
| 134 | + num_candidates: Number of entries in the candidate embedding table. |
| 135 | + embedding_dimension: Output dimension for user and movie embedding tables. |
| 136 | + """ |
| 137 | + |
| 138 | + def __init__( |
| 139 | + self, |
| 140 | + num_users, |
| 141 | + num_candidates, |
| 142 | + embedding_dimension=32, |
| 143 | + **kwargs, |
| 144 | + ): |
| 145 | + super().__init__(**kwargs) |
| 146 | + # Embedding table for users. |
| 147 | + self.user_embedding = keras.layers.Embedding(num_users, embedding_dimension) |
| 148 | + # Embedding table for candidates. |
| 149 | + self.candidate_embedding = keras.layers.Embedding( |
| 150 | + num_candidates, embedding_dimension |
| 151 | + ) |
| 152 | + # Predictions. |
| 153 | + self.ratings = keras.Sequential( |
| 154 | + [ |
| 155 | + # Learn multiple dense layers. |
| 156 | + keras.layers.Dense(256, activation="relu"), |
| 157 | + keras.layers.Dense(64, activation="relu"), |
| 158 | + # Make rating predictions in the final layer. |
| 159 | + keras.layers.Dense(1), |
| 160 | + ] |
| 161 | + ) |
| 162 | + |
| 163 | + def call(self, inputs): |
| 164 | + user_id, movie_id = inputs["user_id"], inputs["movie_id"] |
| 165 | + user_embeddings = self.user_embedding(user_id) |
| 166 | + candidate_embeddings = self.candidate_embedding(movie_id) |
| 167 | + return self.ratings( |
| 168 | + keras.ops.concatenate([user_embeddings, candidate_embeddings], axis=1) |
| 169 | + ) |
| 170 | + |
| 171 | + |
| 172 | +""" |
| 173 | +Let's first instantiate the model. Note that we add `+ 1` to the number of users |
| 174 | +and movies to account for the fact that id zero is not used for either (IDs |
| 175 | +start at 1), but still takes a row in the embedding tables. |
| 176 | +""" |
| 177 | + |
| 178 | +model = RankingModel(users_count + 1, movies_count + 1) |
| 179 | + |
| 180 | +""" |
| 181 | +### Loss and metrics |
| 182 | +
|
| 183 | +The next component is the loss used to train our model. Keras has several losses |
| 184 | +to make this easy. In this instance, we'll make use of the `MeanSquaredError` |
| 185 | +loss in order to predict the ratings. We'll also look at the |
| 186 | +`RootMeanSquaredError` metric. |
| 187 | +""" |
| 188 | + |
| 189 | +model.compile( |
| 190 | + loss=keras.losses.MeanSquaredError(), |
| 191 | + metrics=[keras.metrics.RootMeanSquaredError()], |
| 192 | + optimizer=keras.optimizers.Adagrad(learning_rate=0.1), |
| 193 | +) |
| 194 | + |
| 195 | +""" |
| 196 | +## Fitting and evaluating |
| 197 | +
|
| 198 | +After defining the model, we can use the standard Keras `model.fit()` to train |
| 199 | +the model. |
| 200 | +""" |
| 201 | + |
| 202 | +model.fit(train_ratings, epochs=5) |
| 203 | + |
| 204 | +""" |
| 205 | +As the model trains, the loss is falling and the RMSE metric is improving. |
| 206 | +
|
| 207 | +Finally, we can evaluate our model on the test set. The lower the RMSE metric, |
| 208 | +the more accurate our model is at predicting ratings. |
| 209 | +""" |
| 210 | + |
| 211 | +model.evaluate(test_ratings, return_dict=True) |
| 212 | + |
| 213 | +""" |
| 214 | +## Testing the ranking model |
| 215 | +
|
| 216 | +So far, we have only handled movies by id. Now is the time to create a mapping |
| 217 | +keyed by movie IDs to be able to surface the titles. |
| 218 | +""" |
| 219 | + |
| 220 | +movie_id_to_movie_title = { |
| 221 | + int(x["movie_id"]): x["movie_title"] for x in movies.as_numpy_iterator() |
| 222 | +} |
| 223 | +movie_id_to_movie_title[0] = "" # Because id 0 is not in the dataset. |
| 224 | + |
| 225 | +""" |
| 226 | +Now we can test the ranking model by computing predictions for a set of movies |
| 227 | +and then rank these movies based on the predictions: |
| 228 | +""" |
| 229 | + |
| 230 | +user_id = 42 |
| 231 | +movie_ids = [204, 141, 131] |
| 232 | +predictions = model.predict( |
| 233 | + { |
| 234 | + "user_id": keras.ops.array([user_id] * len(movie_ids)), |
| 235 | + "movie_id": keras.ops.array(movie_ids), |
| 236 | + } |
| 237 | +) |
| 238 | +predictions = keras.ops.convert_to_numpy(keras.ops.squeeze(predictions, axis=1)) |
| 239 | + |
| 240 | +for movie_id, prediction in zip(movie_ids, predictions): |
| 241 | + print(f"{movie_id_to_movie_title[movie_id]}: {5.0 * prediction:,.2f}") |
0 commit comments