Skip to content

Commit bdfc1bc

Browse files
authored
Keras RS example fixes (#2094)
* Keras RS example fixes * Lower LR for Adam * Doc-string change * Format
1 parent 7a014ee commit bdfc1bc

File tree

5 files changed

+23
-25
lines changed

5 files changed

+23
-25
lines changed

examples/keras_rs/dcn.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,6 @@
122122
"batch_size": 1024,
123123
}
124124

125-
LOOKUP_LAYERS = {
126-
"int": keras.layers.IntegerLookup,
127-
"str": keras.layers.StringLookup,
128-
}
129125

130126
"""
131127
Here, we define a helper function for visualising weights of the cross layer in
@@ -186,9 +182,7 @@ def print_stats(rmse_list, num_params, model_name):
186182
if num_trials == 1:
187183
print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
188184
else:
189-
print(
190-
f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; " "#params = {num_params}"
191-
)
185+
print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}")
192186

193187

194188
"""
@@ -275,6 +269,7 @@ def get_mixer_data(data_size=100_000):
275269
keras.layers.Dense(512, activation="relu"),
276270
keras.layers.Dense(256, activation="relu"),
277271
keras.layers.Dense(128, activation="relu"),
272+
keras.layers.Dense(1),
278273
]
279274
)
280275

examples/keras_rs/deep_recommender.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def __init__(
668668
self.update_candidates() # Provide an initial set of candidates
669669
self.loss_fn = keras.losses.MeanSquaredError()
670670
self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy(
671-
k=100, from_sorted_ids=True
671+
k=retrieval_k, from_sorted_ids=True
672672
)
673673

674674
def update_candidates(self):

examples/keras_rs/listwise_ranking.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -182,16 +182,19 @@ def get_examples(
182182
}
183183
labels = []
184184
for user_id, user_list in sequences.items():
185-
sampled_list = sample_sublist_from_list(
186-
user_list,
187-
num_examples_per_list,
188-
)
189-
190-
inputs["user_id"].append(user_id)
191-
inputs["movie_id"].append(
192-
tf.convert_to_tensor([f["movie_id"] for f in sampled_list])
193-
)
194-
labels.append(tf.convert_to_tensor([f["user_rating"] for f in sampled_list]))
185+
for _ in range(num_list_per_user):
186+
sampled_list = sample_sublist_from_list(
187+
user_list,
188+
num_examples_per_list,
189+
)
190+
191+
inputs["user_id"].append(user_id)
192+
inputs["movie_id"].append(
193+
tf.convert_to_tensor([f["movie_id"] for f in sampled_list])
194+
)
195+
labels.append(
196+
tf.convert_to_tensor([f["user_rating"] for f in sampled_list])
197+
)
195198

196199
return (
197200
{"user_id": inputs["user_id"], "movie_id": inputs["movie_id"]},

examples/keras_rs/multi_task.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def __init__(
143143
)
144144

145145
# Rating model.
146-
self.rating_model = tf.keras.Sequential(
146+
self.rating_model = keras.Sequential(
147147
[
148148
keras.layers.Dense(layer_size, activation="relu")
149149
for layer_size in layer_sizes
@@ -162,7 +162,7 @@ def __init__(
162162

163163
# Top-k accuracy for retrieval
164164
self.top_k_metric = keras.metrics.SparseTopKCategoricalAccuracy(
165-
k=100, from_sorted_ids=True
165+
k=10, from_sorted_ids=True
166166
)
167167
# RMSE for ranking
168168
self.rmse_metric = keras.metrics.RootMeanSquaredError()
@@ -306,7 +306,7 @@ def compute_metrics(self, x, y, y_pred, sample_weight=None):
306306
ranking_loss_wt=1.0,
307307
retrieval_loss_wt=0.0,
308308
)
309-
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))
309+
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
310310
model.fit(train_ratings, epochs=5)
311311

312312
model.evaluate(test_ratings)
@@ -318,7 +318,7 @@ def compute_metrics(self, x, y, y_pred, sample_weight=None):
318318
ranking_loss_wt=0.0,
319319
retrieval_loss_wt=1.0,
320320
)
321-
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))
321+
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
322322
model.fit(train_ratings, epochs=5)
323323

324324
model.evaluate(test_ratings)
@@ -330,7 +330,7 @@ def compute_metrics(self, x, y, y_pred, sample_weight=None):
330330
ranking_loss_wt=1.0,
331331
retrieval_loss_wt=1.0,
332332
)
333-
model.compile(optimizer=tf.keras.optimizers.Adagrad(0.1))
333+
model.compile(optimizer=keras.optimizers.Adagrad(0.1))
334334
model.fit(train_ratings, epochs=5)
335335

336336
model.evaluate(test_ratings)

examples/keras_rs/sequential_retrieval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@
6262
TEST_BATCH_SIZE = 2048
6363
EMBEDDING_DIM = 32
6464
NUM_EPOCHS = 5
65-
LEARNING_RATE = 0.05
65+
LEARNING_RATE = 0.005
6666

6767
"""
6868
## Dataset
@@ -368,7 +368,7 @@ def compute_loss(self, x, y, y_pred, sample_weight, training=True):
368368
"""
369369

370370
model = SequentialRetrievalModel(
371-
movies_count=movies_count + 1, embedding_dimension=EMBEDDING_DIM
371+
movies_count=movies_count, embedding_dimension=EMBEDDING_DIM
372372
)
373373

374374
# Compile.

0 commit comments

Comments
 (0)