Skip to content

Commit 98429b5

Browse files
adapting the script movielens_recommendations_transformers.py to be Backend-Agnostic (#2039)
* adapting the script movielens_recommendations_transformers.py to be Backend-Agnostic * adjusting the script to be more modular * introducing some efficiency * addressing PR comments * removed the extensive refactoring * generating the .md and .ipynb files
1 parent ece8130 commit 98429b5

File tree

3 files changed

+3803
-3764
lines changed

3 files changed

+3803
-3764
lines changed

examples/structured_data/ipynb/movielens_recommendations_transformers.ipynb

Lines changed: 118 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"\n",
1111
"**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>\n",
1212
"**Date created:** 2020/12/30<br>\n",
13-
"**Last modified:** 2025/01/03<br>\n",
13+
"**Last modified:** 2025/01/27<br>\n",
1414
"**Description:** Rating rate prediction using the Behavior Sequence Transformer (BST) model on the Movielens."
1515
]
1616
},
@@ -82,17 +82,16 @@
8282
"source": [
8383
"import os\n",
8484
"\n",
85-
"os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
85+
"os.environ[\"KERAS_BACKEND\"] = \"jax\" # or torch, or tensorflow\n",
8686
"\n",
8787
"import math\n",
8888
"from zipfile import ZipFile\n",
8989
"from urllib.request import urlretrieve\n",
90-
"\n",
91-
"import keras\n",
9290
"import numpy as np\n",
9391
"import pandas as pd\n",
94-
"import tensorflow as tf\n",
95-
"from keras import layers\n",
92+
"\n",
93+
"import keras\n",
94+
"from keras import layers, ops\n",
9695
"from keras.layers import StringLookup"
9796
]
9897
},
@@ -408,7 +407,8 @@
408407
"\n",
409408
"USER_FEATURES = [\"sex\", \"age_group\", \"occupation\"]\n",
410409
"\n",
411-
"MOVIE_FEATURES = [\"genres\"]"
410+
"MOVIE_FEATURES = [\"genres\"]\n",
411+
""
412412
]
413413
},
414414
{
@@ -417,7 +417,30 @@
417417
"colab_type": "text"
418418
},
419419
"source": [
420-
"## Create `tf.data.Dataset` for training and evaluation"
420+
"## Encode input features\n",
421+
"\n",
422+
"The `encode_input_features` function works as follows:\n",
423+
"\n",
424+
"1. Each categorical user feature is encoded using `layers.Embedding`, with embedding\n",
425+
"dimension equal to the square root of the vocabulary size of the feature.\n",
426+
"The embeddings of these features are concatenated to form a single input tensor.\n",
427+
"\n",
428+
"2. Each movie in the movie sequence and the target movie is encoded using `layers.Embedding`,\n",
429+
"where the dimension size is the square root of the number of movies.\n",
430+
"\n",
431+
"3. A multi-hot genres vector for each movie is concatenated with its embedding vector,\n",
432+
"and processed using a non-linear `layers.Dense` to output a vector of the same movie\n",
433+
"embedding dimensions.\n",
434+
"\n",
435+
"4. A positional embedding is added to each movie embedding in the sequence, and then\n",
436+
"multiplied by its rating from the ratings sequence.\n",
437+
"\n",
438+
"5. The target movie embedding is concatenated to the sequence movie embeddings, producing\n",
439+
"a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected\n",
440+
"by the attention layer for the transformer architecture.\n",
441+
"\n",
442+
"6. The function returns a tuple of two elements: `encoded_transformer_features` and\n",
443+
"`encoded_other_features`."
421444
]
422445
},
423446
{
@@ -428,25 +451,60 @@
428451
},
429452
"outputs": [],
430453
"source": [
454+
"# Required for tf.data.Dataset\n",
455+
"import tensorflow as tf\n",
456+
"\n",
431457
"\n",
432458
"def get_dataset_from_csv(csv_file_path, batch_size, shuffle=True):\n",
459+
"\n",
433460
" def process(features):\n",
434461
" movie_ids_string = features[\"sequence_movie_ids\"]\n",
435462
" sequence_movie_ids = tf.strings.split(movie_ids_string, \",\").to_tensor()\n",
436-
"\n",
437463
" # The last movie id in the sequence is the target movie.\n",
438464
" features[\"target_movie_id\"] = sequence_movie_ids[:, -1]\n",
439465
" features[\"sequence_movie_ids\"] = sequence_movie_ids[:, :-1]\n",
440-
"\n",
466+
" # Sequence ratings\n",
441467
" ratings_string = features[\"sequence_ratings\"]\n",
442468
" sequence_ratings = tf.strings.to_number(\n",
443469
" tf.strings.split(ratings_string, \",\"), tf.dtypes.float32\n",
444470
" ).to_tensor()\n",
445-
"\n",
446471
" # The last rating in the sequence is the target for the model to predict.\n",
447472
" target = sequence_ratings[:, -1]\n",
448473
" features[\"sequence_ratings\"] = sequence_ratings[:, :-1]\n",
449474
"\n",
475+
" def encoding_helper(feature_name):\n",
476+
"\n",
477+
"        # These are target_movie_id and sequence_movie_ids, which share the same\n",
478+
" # vocabulary as movie_id.\n",
479+
" if feature_name not in CATEGORICAL_FEATURES_WITH_VOCABULARY:\n",
480+
" vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[\"movie_id\"]\n",
481+
" index_lookup = StringLookup(\n",
482+
" vocabulary=vocabulary, mask_token=None, num_oov_indices=0\n",
483+
" )\n",
484+
" # Convert the string input values into integer indices.\n",
485+
" value_index = index_lookup(features[feature_name])\n",
486+
" features[feature_name] = value_index\n",
487+
" else:\n",
488+
" # movie_id is not part of the features, hence not processed. It was mainly required\n",
489+
" # for its vocabulary above.\n",
490+
" if feature_name == \"movie_id\":\n",
491+
" pass\n",
492+
" else:\n",
493+
" vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]\n",
494+
" index_lookup = StringLookup(\n",
495+
" vocabulary=vocabulary, mask_token=None, num_oov_indices=0\n",
496+
" )\n",
497+
" # Convert the string input values into integer indices.\n",
498+
" value_index = index_lookup(features[feature_name])\n",
499+
" features[feature_name] = value_index\n",
500+
"\n",
501+
" # Encode the user features\n",
502+
" for feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:\n",
503+
" encoding_helper(feature_name)\n",
504+
"        # Encoding target_movie_id (an input feature; the label is the last rating).\n",
505+
" encoding_helper(\"target_movie_id\")\n",
506+
" # Encoding sequence movie_ids.\n",
507+
" encoding_helper(\"sequence_movie_ids\")\n",
450508
" return dict(features), target\n",
451509
"\n",
452510
" dataset = tf.data.experimental.make_csv_dataset(\n",
@@ -458,94 +516,14 @@
458516
" field_delim=\"|\",\n",
459517
" shuffle=shuffle,\n",
460518
" ).map(process)\n",
461-
"\n",
462519
" return dataset\n",
463-
""
464-
]
465-
},
466-
{
467-
"cell_type": "markdown",
468-
"metadata": {
469-
"colab_type": "text"
470-
},
471-
"source": [
472-
"## Create model inputs"
473-
]
474-
},
475-
{
476-
"cell_type": "code",
477-
"execution_count": 0,
478-
"metadata": {
479-
"colab_type": "code"
480-
},
481-
"outputs": [],
482-
"source": [
483-
"\n",
484-
"def create_model_inputs():\n",
485-
" return {\n",
486-
" \"user_id\": keras.Input(name=\"user_id\", shape=(1,), dtype=\"string\"),\n",
487-
" \"sequence_movie_ids\": keras.Input(\n",
488-
" name=\"sequence_movie_ids\", shape=(sequence_length - 1,), dtype=\"string\"\n",
489-
" ),\n",
490-
" \"target_movie_id\": keras.Input(\n",
491-
" name=\"target_movie_id\", shape=(1,), dtype=\"string\"\n",
492-
" ),\n",
493-
" \"sequence_ratings\": keras.Input(\n",
494-
" name=\"sequence_ratings\", shape=(sequence_length - 1,), dtype=tf.float32\n",
495-
" ),\n",
496-
" \"sex\": keras.Input(name=\"sex\", shape=(1,), dtype=\"string\"),\n",
497-
" \"age_group\": keras.Input(name=\"age_group\", shape=(1,), dtype=\"string\"),\n",
498-
" \"occupation\": keras.Input(name=\"occupation\", shape=(1,), dtype=\"string\"),\n",
499-
" }\n",
500-
""
501-
]
502-
},
503-
{
504-
"cell_type": "markdown",
505-
"metadata": {
506-
"colab_type": "text"
507-
},
508-
"source": [
509-
"## Encode input features\n",
510-
"\n",
511-
"The `encode_input_features` method works as follows:\n",
512-
"\n",
513-
"1. Each categorical user feature is encoded using `layers.Embedding`, with embedding\n",
514-
"dimension equals to the square root of the vocabulary size of the feature.\n",
515-
"The embeddings of these features are concatenated to form a single input tensor.\n",
516520
"\n",
517-
"2. Each movie in the movie sequence and the target movie is encoded `layers.Embedding`,\n",
518-
"where the dimension size is the square root of the number of movies.\n",
519-
"\n",
520-
"3. A multi-hot genres vector for each movie is concatenated with its embedding vector,\n",
521-
"and processed using a non-linear `layers.Dense` to output a vector of the same movie\n",
522-
"embedding dimensions.\n",
523-
"\n",
524-
"4. A positional embedding is added to each movie embedding in the sequence, and then\n",
525-
"multiplied by its rating from the ratings sequence.\n",
526-
"\n",
527-
"5. The target movie embedding is concatenated to the sequence movie embeddings, producing\n",
528-
"a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected\n",
529-
"by the attention layer for the transformer architecture.\n",
530-
"\n",
531-
"6. The method returns a tuple of two elements: `encoded_transformer_features` and\n",
532-
"`encoded_other_features`."
533-
]
534-
},
535-
{
536-
"cell_type": "code",
537-
"execution_count": 0,
538-
"metadata": {
539-
"colab_type": "code"
540-
},
541-
"outputs": [],
542-
"source": [
543521
"\n",
544522
"def encode_input_features(\n",
545523
" inputs,\n",
546-
" include_user_id=True,\n",
547-
" include_user_features=True,\n",
548-
" include_movie_features=True,\n",
524+
" include_user_id,\n",
525+
" include_user_features,\n",
526+
" include_movie_features,\n",
549527
"):\n",
550528
" encoded_transformer_features = []\n",
551529
" encoded_other_features = []\n",
@@ -558,11 +536,7 @@
558536
"\n",
559537
" ## Encode user features\n",
560538
" for feature_name in other_feature_names:\n",
561-
" # Convert the string input values into integer indices.\n",
562539
" vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]\n",
563-
" idx = StringLookup(vocabulary=vocabulary, mask_token=None, num_oov_indices=0)(\n",
564-
" inputs[feature_name]\n",
565-
" )\n",
566540
" # Compute embedding dimensions\n",
567541
" embedding_dims = int(math.sqrt(len(vocabulary)))\n",
568542
" # Create an embedding layer with the specified dimensions.\n",
@@ -572,7 +546,7 @@
572546
" name=f\"{feature_name}_embedding\",\n",
573547
" )\n",
574548
" # Convert the index values to embedding representations.\n",
575-
" encoded_other_features.append(embedding_encoder(idx))\n",
549+
" encoded_other_features.append(embedding_encoder(inputs[feature_name]))\n",
576550
"\n",
577551
" ## Create a single embedding vector for the user features\n",
578552
" if len(encoded_other_features) > 1:\n",
@@ -585,13 +559,6 @@
585559
" ## Create a movie embedding encoder\n",
586560
" movie_vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[\"movie_id\"]\n",
587561
" movie_embedding_dims = int(math.sqrt(len(movie_vocabulary)))\n",
588-
" # Create a lookup to convert string values to integer indices.\n",
589-
" movie_index_lookup = StringLookup(\n",
590-
" vocabulary=movie_vocabulary,\n",
591-
" mask_token=None,\n",
592-
" num_oov_indices=0,\n",
593-
" name=\"movie_index_lookup\",\n",
594-
" )\n",
595562
" # Create an embedding layer with the specified dimensions.\n",
596563
" movie_embedding_encoder = layers.Embedding(\n",
597564
" input_dim=len(movie_vocabulary),\n",
@@ -617,11 +584,10 @@
617584
" ## Define a function to encode a given movie id.\n",
618585
" def encode_movie(movie_id):\n",
619586
" # Convert the string input values into integer indices.\n",
620-
" movie_idx = movie_index_lookup(movie_id)\n",
621-
" movie_embedding = movie_embedding_encoder(movie_idx)\n",
587+
" movie_embedding = movie_embedding_encoder(movie_id)\n",
622588
" encoded_movie = movie_embedding\n",
623589
" if include_movie_features:\n",
624-
" movie_genres_vector = movie_genres_lookup(movie_idx)\n",
590+
" movie_genres_vector = movie_genres_lookup(movie_id)\n",
625591
" encoded_movie = movie_embedding_processor(\n",
626592
" layers.concatenate([movie_embedding, movie_genres_vector])\n",
627593
" )\n",
@@ -640,11 +606,11 @@
640606
" output_dim=movie_embedding_dims,\n",
641607
" name=\"position_embedding\",\n",
642608
" )\n",
643-
" positions = tf.range(start=0, limit=sequence_length - 1, delta=1)\n",
609+
" positions = ops.arange(start=0, stop=sequence_length - 1, step=1)\n",
644610
" encodded_positions = position_embedding_encoder(positions)\n",
645611
" # Retrieve sequence ratings to incorporate them into the encoding of the movie.\n",
646612
" sequence_ratings = inputs[\"sequence_ratings\"]\n",
647-
" sequence_ratings = keras.ops.expand_dims(sequence_ratings, -1)\n",
613+
" sequence_ratings = ops.expand_dims(sequence_ratings, -1)\n",
648614
" # Add the positional encoding to the movie encodings and multiply them by rating.\n",
649615
" encoded_sequence_movies_with_poistion_and_rating = layers.Multiply()(\n",
650616
" [(encoded_sequence_movies + encodded_positions), sequence_ratings]\n",
@@ -653,18 +619,53 @@
653619
" # Construct the transformer inputs.\n",
654620
" for i in range(sequence_length - 1):\n",
655621
" feature = encoded_sequence_movies_with_poistion_and_rating[:, i, ...]\n",
656-
" feature = keras.ops.expand_dims(feature, 1)\n",
622+
" feature = ops.expand_dims(feature, 1)\n",
657623
" encoded_transformer_features.append(feature)\n",
658624
" encoded_transformer_features.append(encoded_target_movie)\n",
659-
"\n",
660625
" encoded_transformer_features = layers.concatenate(\n",
661626
" encoded_transformer_features, axis=1\n",
662627
" )\n",
663-
"\n",
664628
" return encoded_transformer_features, encoded_other_features\n",
665629
""
666630
]
667631
},
632+
{
633+
"cell_type": "markdown",
634+
"metadata": {
635+
"colab_type": "text"
636+
},
637+
"source": [
638+
"## Create model inputs"
639+
]
640+
},
641+
{
642+
"cell_type": "code",
643+
"execution_count": 0,
644+
"metadata": {
645+
"colab_type": "code"
646+
},
647+
"outputs": [],
648+
"source": [
649+
"\n",
650+
"def create_model_inputs():\n",
651+
" return {\n",
652+
" \"user_id\": keras.Input(name=\"user_id\", shape=(1,), dtype=\"int32\"),\n",
653+
" \"sequence_movie_ids\": keras.Input(\n",
654+
" name=\"sequence_movie_ids\", shape=(sequence_length - 1,), dtype=\"int32\"\n",
655+
" ),\n",
656+
" \"target_movie_id\": keras.Input(\n",
657+
" name=\"target_movie_id\", shape=(1,), dtype=\"int32\"\n",
658+
" ),\n",
659+
" \"sequence_ratings\": keras.Input(\n",
660+
" name=\"sequence_ratings\", shape=(sequence_length - 1,), dtype=\"float32\"\n",
661+
" ),\n",
662+
" \"sex\": keras.Input(name=\"sex\", shape=(1,), dtype=\"int32\"),\n",
663+
" \"age_group\": keras.Input(name=\"age_group\", shape=(1,), dtype=\"int32\"),\n",
664+
" \"occupation\": keras.Input(name=\"occupation\", shape=(1,), dtype=\"int32\"),\n",
665+
" }\n",
666+
""
667+
]
668+
},
668669
{
669670
"cell_type": "markdown",
670671
"metadata": {
@@ -692,11 +693,11 @@
692693
"\n",
693694
"\n",
694695
"def create_model():\n",
696+
"\n",
695697
" inputs = create_model_inputs()\n",
696698
" transformer_features, other_features = encode_input_features(\n",
697699
" inputs, include_user_id, include_user_features, include_movie_features\n",
698700
" )\n",
699-
"\n",
700701
" # Create a multi-headed attention layer.\n",
701702
" attention_output = layers.MultiHeadAttention(\n",
702703
" num_heads=num_heads, key_dim=transformer_features.shape[2], dropout=dropout_rate\n",
@@ -713,7 +714,7 @@
713714
" transformer_features = layers.LayerNormalization()(transformer_features)\n",
714715
" features = layers.Flatten()(transformer_features)\n",
715716
"\n",
716-
" # Included the other features.\n",
717+
" # Included the other_features.\n",
717718
" if other_features is not None:\n",
718719
" features = layers.concatenate(\n",
719720
" [features, layers.Reshape([other_features.shape[-1]])(other_features)]\n",
@@ -725,7 +726,6 @@
725726
" features = layers.BatchNormalization()(features)\n",
726727
" features = layers.LeakyReLU()(features)\n",
727728
" features = layers.Dropout(dropout_rate)(features)\n",
728-
"\n",
729729
" outputs = layers.Dense(units=1)(features)\n",
730730
" model = keras.Model(inputs=inputs, outputs=outputs)\n",
731731
" return model\n",
@@ -759,6 +759,7 @@
759759
")\n",
760760
"\n",
761761
"# Read the training data.\n",
762+
"\n",
762763
"train_dataset = get_dataset_from_csv(\"train_data.csv\", batch_size=265, shuffle=True)\n",
763764
"\n",
764765
"# Fit the model with the training data.\n",

0 commit comments

Comments
 (0)