Commit 63c4fd8

Fix data split strategy in GRU4Rec (#2105)
1 parent 7567602 commit 63c4fd8
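
In short: instead of pooling every user's (context, label) examples and carving off a random fraction for testing, the generator now holds out each user's chronologically last example for the test set. A minimal standalone sketch of that strategy, assuming interactions are already sorted by timestamp (`leave_last_out_split` and `user_sequences` are hypothetical names, not code from this repository):

```python
# Hypothetical sketch of the per-user leave-last-out split this commit adopts.
# `user_sequences` maps user_id -> chronologically sorted item ids.
def leave_last_out_split(user_sequences, min_length=3):
    train, test = [], []
    for items in user_sequences.values():
        if len(items) < min_length:
            continue
        for label_idx in range(1, len(items)):
            example = {"context": items[:label_idx], "label": items[label_idx]}
            # Only the last (context, label) pair of each user is held out.
            if label_idx == len(items) - 1:
                test.append(example)
            else:
                train.append(example)
    return train, test


train, test = leave_last_out_split({"u1": [10, 20, 30, 40], "u2": [7, 8, 9]})
print(len(train), len(test))  # 3 2: exactly one test example per user
```

With this split, no training example from a user can contain that user's held-out label, which is the leakage the old shuffled split allowed.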

File tree: 3 files changed, +126 -90 lines


examples/keras_rs/ipynb/sequential_retrieval.ipynb

Lines changed: 29 additions & 21 deletions
@@ -96,7 +96,6 @@
 "# Data processing args\n",
 "MAX_CONTEXT_LENGTH = 10\n",
 "MIN_SEQUENCE_LENGTH = 3\n",
-"TRAIN_DATA_FRACTION = 0.9\n",
 "\n",
 "RATINGS_DATA_COLUMNS = [\"UserID\", \"MovieID\", \"Rating\", \"Timestamp\"]\n",
 "MOVIES_DATA_COLUMNS = [\"MovieID\", \"Title\", \"Genres\"]\n",
@@ -251,7 +250,13 @@
 "   will be the last token.\n",
 "3. Remove all user sequences with less than `MIN_SEQUENCE_LENGTH`\n",
 "   movies.\n",
-"4. Pad all sequences to `MAX_CONTEXT_LENGTH`."
+"4. Pad all sequences to `MAX_CONTEXT_LENGTH`.\n",
+"\n",
+"An important point to note is how we form the train-test splits. We do not\n",
+"form the entire dataset of sequences and then split it into train and test.\n",
+"Instead, for every user, we take the last sequence to be part of the test set,\n",
+"and all other sequences to be part of the train set. This is to prevent data\n",
+"leakage."
 ]
 },
 {
@@ -269,7 +274,8 @@
 "    def generate_examples_from_user_sequence(sequence):\n",
 "        \"\"\"Generates examples for a single user sequence.\"\"\"\n",
 "\n",
-"        examples = []\n",
+"        train_examples = []\n",
+"        test_examples = []\n",
 "        for label_idx in range(1, len(sequence)):\n",
 "            start_idx = max(0, label_idx - MAX_CONTEXT_LENGTH)\n",
 "            context = sequence[start_idx:label_idx]\n",
@@ -287,24 +293,32 @@
 "            label_movie_id = int(sequence[label_idx][\"movie_id\"])\n",
 "            context_movie_id = [int(movie[\"movie_id\"]) for movie in context]\n",
 "\n",
-"            examples.append(\n",
-"                {\n",
-"                    \"context_movie_id\": context_movie_id,\n",
-"                    \"label_movie_id\": label_movie_id,\n",
-"                },\n",
-"            )\n",
-"        return examples\n",
+"            example = {\n",
+"                \"context_movie_id\": context_movie_id,\n",
+"                \"label_movie_id\": label_movie_id,\n",
+"            }\n",
+"\n",
+"            if label_idx == len(sequence) - 1:\n",
+"                test_examples.append(example)\n",
+"            else:\n",
+"                train_examples.append(example)\n",
 "\n",
-"    all_examples = []\n",
+"        return train_examples, test_examples\n",
+"\n",
+"    all_train_examples = []\n",
+"    all_test_examples = []\n",
 "    for sequence in sequences.values():\n",
 "        if len(sequence) < MIN_SEQUENCE_LENGTH:\n",
 "            continue\n",
 "\n",
-"        user_examples = generate_examples_from_user_sequence(sequence)\n",
+"        user_train_examples, user_test_example = generate_examples_from_user_sequence(\n",
+"            sequence\n",
+"        )\n",
 "\n",
-"        all_examples.extend(user_examples)\n",
+"        all_train_examples.extend(user_train_examples)\n",
+"        all_test_examples.extend(user_test_example)\n",
 "\n",
-"    return all_examples\n",
+"    return all_train_examples, all_test_examples\n",
 ""
 ]
 },
@@ -328,13 +342,7 @@
 "outputs": [],
 "source": [
 "sequences = get_movie_sequence_per_user(ratings_df)\n",
-"examples = generate_examples_from_user_sequences(sequences)\n",
-"\n",
-"# Train-test split.\n",
-"random.shuffle(examples)\n",
-"split_index = int(TRAIN_DATA_FRACTION * len(examples))\n",
-"train_examples = examples[:split_index]\n",
-"test_examples = examples[split_index:]\n",
+"train_examples, test_examples = generate_examples_from_user_sequences(sequences)\n",
 "\n",
 "\n",
 "def list_of_dicts_to_dict_of_lists(list_of_dicts):\n",

examples/keras_rs/md/sequential_retrieval.md

Lines changed: 69 additions & 49 deletions
@@ -61,7 +61,6 @@ MOVIES_FILE_NAME = "movies.dat"
 # Data processing args
 MAX_CONTEXT_LENGTH = 10
 MIN_SEQUENCE_LENGTH = 3
-TRAIN_DATA_FRACTION = 0.9
 
 RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]
 MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]
@@ -149,11 +148,13 @@ movies_count = movies_df["MovieID"].max()
 <div class="k-default-codeblock">
 ```
 Downloading data from https://files.grouplens.org/datasets/movielens/ml-1m.zip
-5917549/5917549 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
 
-/var/tmp/ipykernel_688439/1372663084.py:26: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
+5917549/5917549 ━━━━━━━━━━━━━━━━━━━━ 2s 0us/step
+
+<ipython-input-4-6fc962858754>:26: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
   ratings_df = pd.read_csv(
-/var/tmp/ipykernel_688439/1372663084.py:38: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
+
+<ipython-input-4-6fc962858754>:38: ParserWarning: Falling back to the 'python' engine because the 'c' engine does not support regex separators (separators > 1 char and different from '\s+' are interpreted as regex); you can avoid this warning by specifying engine='python'.
   movies_df = pd.read_csv(
 ```
 </div>
@@ -199,6 +200,12 @@ with training the model:
    movies.
 4. Pad all sequences to `MAX_CONTEXT_LENGTH`.
 
+An important point to note is how we form the train-test splits. We do not
+form the entire dataset of sequences and then split it into train and test.
+Instead, for every user, we take the last sequence to be part of the test set,
+and all other sequences to be part of the train set. This is to prevent data
+leakage.
+
 
 ```python
 
@@ -208,7 +215,8 @@ def generate_examples_from_user_sequences(sequences):
     def generate_examples_from_user_sequence(sequence):
         """Generates examples for a single user sequence."""
 
-        examples = []
+        train_examples = []
+        test_examples = []
         for label_idx in range(1, len(sequence)):
             start_idx = max(0, label_idx - MAX_CONTEXT_LENGTH)
             context = sequence[start_idx:label_idx]
@@ -226,24 +234,32 @@ def generate_examples_from_user_sequences(sequences):
             label_movie_id = int(sequence[label_idx]["movie_id"])
             context_movie_id = [int(movie["movie_id"]) for movie in context]
 
-            examples.append(
-                {
-                    "context_movie_id": context_movie_id,
-                    "label_movie_id": label_movie_id,
-                },
-            )
-        return examples
+            example = {
+                "context_movie_id": context_movie_id,
+                "label_movie_id": label_movie_id,
+            }
+
+            if label_idx == len(sequence) - 1:
+                test_examples.append(example)
+            else:
+                train_examples.append(example)
 
-    all_examples = []
+        return train_examples, test_examples
+
+    all_train_examples = []
+    all_test_examples = []
     for sequence in sequences.values():
         if len(sequence) < MIN_SEQUENCE_LENGTH:
             continue
 
-        user_examples = generate_examples_from_user_sequence(sequence)
+        user_train_examples, user_test_example = generate_examples_from_user_sequence(
+            sequence
+        )
 
-        all_examples.extend(user_examples)
+        all_train_examples.extend(user_train_examples)
+        all_test_examples.extend(user_test_example)
 
-    return all_examples
+    return all_train_examples, all_test_examples
 
 ```
 
@@ -254,13 +270,7 @@ to a `tf.data.Dataset` object.
 
 ```python
 sequences = get_movie_sequence_per_user(ratings_df)
-examples = generate_examples_from_user_sequences(sequences)
-
-# Train-test split.
-random.shuffle(examples)
-split_index = int(TRAIN_DATA_FRACTION * len(examples))
-train_examples = examples[:split_index]
-test_examples = examples[split_index:]
+train_examples, test_examples = generate_examples_from_user_sequences(sequences)
 
 
 def list_of_dicts_to_dict_of_lists(list_of_dicts):
@@ -305,13 +315,13 @@ for sample in train_ds.take(1):
 <div class="k-default-codeblock">
 ```
 (<tf.Tensor: shape=(4096, 10), dtype=int32, numpy=
-array([[3512, 3617, 3623, ..., 3007, 2858, 1617],
-       [1952, 1206, 1233, ..., 1394, 3683,  593],
-       [2114, 1274, 2407, ..., 2100, 1257, 2001],
+array([[3186,    0,    0, ...,    0,    0,    0],
+       [3186, 1270,    0, ...,    0,    0,    0],
+       [3186, 1270, 1721, ...,    0,    0,    0],
        ...,
-       [  81, 2210, 1343, ..., 1625, 1748, 1407],
-       [ 494,  832,  543, ...,   23,  432, 1682],
-       [2421,    0,    0, ...,    0,    0,    0]], dtype=int32)>, <tf.Tensor: shape=(4096,), dtype=int32, numpy=array([3265, 1203, 2003, ..., 3044,  367,  110], dtype=int32)>)
+       [2194, 1291, 2159, ...,  300, 2076,  866],
+       [1291, 2159, 1012, ..., 2076,  866, 2206],
+       [2159, 1012, 1092, ...,  866, 2206,  377]], dtype=int32)>, <tf.Tensor: shape=(4096,), dtype=int32, numpy=array([1270, 1721, 1022, ..., 2206,  377, 1357], dtype=int32)>)
 ```
 </div>
@@ -432,17 +442,26 @@ model.fit(
 <div class="k-default-codeblock">
 ```
 Epoch 1/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 6s 24ms/step - loss: 7.9460 - val_loss: 6.4827
+
+228/228 ━━━━━━━━━━━━━━━━━━━━ 7s 24ms/step - loss: 7.9319 - val_loss: 6.8823
+
 Epoch 2/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 2s 3ms/step - loss: 7.0764 - val_loss: 6.1424
+
+228/228 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - loss: 7.0997 - val_loss: 6.5517
+
 Epoch 3/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 6.7779 - val_loss: 6.0004
+
+228/228 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - loss: 6.8198 - val_loss: 6.4342
+
 Epoch 4/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 6.6406 - val_loss: 5.9258
+
+228/228 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - loss: 6.6873 - val_loss: 6.3748
+
 Epoch 5/5
-207/207 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 6.5584 - val_loss: 5.8826
 
-<keras.src.callbacks.history.History at 0x7fd1506dc670>
+228/228 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - loss: 6.6105 - val_loss: 6.3444
+
+<keras.src.callbacks.history.History at 0x795792c69b90>
 ```
 </div>
@@ -487,22 +506,23 @@ for movie_id in predictions[0]:
 <div class="k-default-codeblock">
 ```
 ==> Movies the user has watched:
-Rob Roy (1995), Legends of the Fall (1994), French Kiss (1995), Terminator 2: Judgment Day (1991), Nikita (La Femme Nikita) (1990), Professional, The (a.k.a. Leon: The Professional) (1994), Seven (Se7en) (1995), Fugitive, The (1993), Enemy of the State (1998), Reservoir Dogs (1992)
-1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 225ms/step
+Beauty and the Beast (1991), Tarzan (1999), Close Shave, A (1995), Aladdin (1992), Toy Story (1995), Bug's Life, A (1998), Antz (1998), Hunchback of Notre Dame, The (1996), Hercules (1997), Mulan (1998)
+
+1/1 ━━━━━━━━━━━━━━━━━━━━ 0s 272ms/step
 
 ==> Recommended movies for the above sequence:
-Red Rock West (1992)
-Casino (1995)
-Cape Fear (1991)
-Simple Plan, A (1998)
-Seven (Se7en) (1995)
-Hard 8 (a.k.a. Sydney, a.k.a. Hard Eight) (1996)
-Primal Fear (1996)
-Heat (1995)
-Scream (1996)
-Zero Effect (1998)
-
-/opt/conda/envs/keras-jax/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py:151: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
+Hunchback of Notre Dame, The (1996)
+Anastasia (1997)
+Beavis and Butt-head Do America (1996)
+Hercules (1997)
+Pocahontas (1995)
+Thumbelina (1994)
+James and the Giant Peach (1996)
+We're Back! A Dinosaur's Story (1993)
+Rescuers Down Under, The (1990)
+Prince of Egypt, The (1998)
+
+/usr/local/lib/python3.11/dist-packages/keras/src/trainers/epoch_iterator.py:151: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
   self._interrupted_warning()
 ```
 </div>

examples/keras_rs/sequential_retrieval.py

Lines changed: 28 additions & 20 deletions
@@ -55,7 +55,6 @@
 # Data processing args
 MAX_CONTEXT_LENGTH = 10
 MIN_SEQUENCE_LENGTH = 3
-TRAIN_DATA_FRACTION = 0.9
 
 RATINGS_DATA_COLUMNS = ["UserID", "MovieID", "Rating", "Timestamp"]
 MOVIES_DATA_COLUMNS = ["MovieID", "Title", "Genres"]
@@ -177,6 +176,12 @@ def get_movie_sequence_per_user(ratings_df):
 3. Remove all user sequences with less than `MIN_SEQUENCE_LENGTH`
    movies.
 4. Pad all sequences to `MAX_CONTEXT_LENGTH`.
+
+An important point to note is how we form the train-test splits. We do not
+form the entire dataset of sequences and then split it into train and test.
+Instead, for every user, we take the last sequence to be part of the test set,
+and all other sequences to be part of the train set. This is to prevent data
+leakage.
 """
 
 
@@ -186,7 +191,8 @@ def generate_examples_from_user_sequences(sequences):
     def generate_examples_from_user_sequence(sequence):
         """Generates examples for a single user sequence."""
 
-        examples = []
+        train_examples = []
+        test_examples = []
         for label_idx in range(1, len(sequence)):
             start_idx = max(0, label_idx - MAX_CONTEXT_LENGTH)
             context = sequence[start_idx:label_idx]
@@ -204,24 +210,32 @@ def generate_examples_from_user_sequence(sequence):
             label_movie_id = int(sequence[label_idx]["movie_id"])
             context_movie_id = [int(movie["movie_id"]) for movie in context]
 
-            examples.append(
-                {
-                    "context_movie_id": context_movie_id,
-                    "label_movie_id": label_movie_id,
-                },
-            )
-        return examples
+            example = {
+                "context_movie_id": context_movie_id,
+                "label_movie_id": label_movie_id,
+            }
+
+            if label_idx == len(sequence) - 1:
+                test_examples.append(example)
+            else:
+                train_examples.append(example)
 
-    all_examples = []
+        return train_examples, test_examples
+
+    all_train_examples = []
+    all_test_examples = []
     for sequence in sequences.values():
         if len(sequence) < MIN_SEQUENCE_LENGTH:
             continue
 
-        user_examples = generate_examples_from_user_sequence(sequence)
+        user_train_examples, user_test_example = generate_examples_from_user_sequence(
+            sequence
+        )
 
-        all_examples.extend(user_examples)
+        all_train_examples.extend(user_train_examples)
+        all_test_examples.extend(user_test_example)
 
-    return all_examples
+    return all_train_examples, all_test_examples
 
 
 """
@@ -230,13 +244,7 @@ def generate_examples_from_user_sequence(sequence):
 to a `tf.data.Dataset` object.
 """
 sequences = get_movie_sequence_per_user(ratings_df)
-examples = generate_examples_from_user_sequences(sequences)
-
-# Train-test split.
-random.shuffle(examples)
-split_index = int(TRAIN_DATA_FRACTION * len(examples))
-train_examples = examples[:split_index]
-test_examples = examples[split_index:]
+train_examples, test_examples = generate_examples_from_user_sequences(sequences)
 
 
 def list_of_dicts_to_dict_of_lists(list_of_dicts):

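For context on the "data leakage" the new split prevents: consecutive sliding-window examples from one user overlap almost entirely, and each example's label reappears verbatim inside the next, longer context. The old global shuffle could therefore place a test example whose answer is visible inside a training input. A toy illustration (hypothetical, not repository code):

```python
# One user's chronological watch history and the sliding-window examples
# generated from it, mirroring the tutorial's (context, label) scheme.
items = [10, 20, 30, 40, 50]
examples = [
    {"context": items[:i], "label": items[i]} for i in range(1, len(items))
]

# Each context is a strict prefix of the next, and each label appears
# inside the next, longer context, so a shuffled split can leak answers;
# holding out only each user's last example cannot.
for a, b in zip(examples, examples[1:]):
    assert b["context"][: len(a["context"])] == a["context"]
    assert b["context"][len(a["context"])] == a["label"]
print("every label reappears in a longer context; a random split can leak it")
```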