|
10 | 10 | "\n", |
11 | 11 | "**Author:** [Khalid Salama](https://www.linkedin.com/in/khalid-salama-24403144/)<br>\n", |
12 | 12 | "**Date created:** 2020/12/30<br>\n", |
13 | | - "**Last modified:** 2025/01/03<br>\n", |
| 13 | + "**Last modified:** 2025/01/27<br>\n", |
14 | 14 | "**Description:** Rating rate prediction using the Behavior Sequence Transformer (BST) model on the Movielens." |
15 | 15 | ] |
16 | 16 | }, |
|
82 | 82 | "source": [ |
83 | 83 | "import os\n", |
84 | 84 | "\n", |
85 | | - "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", |
| 85 | + "os.environ[\"KERAS_BACKEND\"] = \"jax\" # or torch, or tensorflow\n", |
86 | 86 | "\n", |
87 | 87 | "import math\n", |
88 | 88 | "from zipfile import ZipFile\n", |
89 | 89 | "from urllib.request import urlretrieve\n", |
90 | | - "\n", |
91 | | - "import keras\n", |
92 | 90 | "import numpy as np\n", |
93 | 91 | "import pandas as pd\n", |
94 | | - "import tensorflow as tf\n", |
95 | | - "from keras import layers\n", |
| 92 | + "\n", |
| 93 | + "import keras\n", |
| 94 | + "from keras import layers, ops\n", |
96 | 95 | "from keras.layers import StringLookup" |
97 | 96 | ] |
98 | 97 | }, |
|
408 | 407 | "\n", |
409 | 408 | "USER_FEATURES = [\"sex\", \"age_group\", \"occupation\"]\n", |
410 | 409 | "\n", |
411 | | - "MOVIE_FEATURES = [\"genres\"]" |
| 410 | + "MOVIE_FEATURES = [\"genres\"]\n", |
| 411 | + "" |
412 | 412 | ] |
413 | 413 | }, |
414 | 414 | { |
|
417 | 417 | "colab_type": "text" |
418 | 418 | }, |
419 | 419 | "source": [ |
420 | | - "## Create `tf.data.Dataset` for training and evaluation" |
| 420 | + "## Encode input features\n", |
| 421 | + "\n", |
| 422 | + "The `encode_input_features` function works as follows:\n", |
| 423 | + "\n", |
| 424 | + "1. Each categorical user feature is encoded using `layers.Embedding`, with embedding\n", |
| 425 | + "dimension equals to the square root of the vocabulary size of the feature.\n", |
| 426 | + "The embeddings of these features are concatenated to form a single input tensor.\n", |
| 427 | + "\n", |
| 428 | + "2. Each movie in the movie sequence and the target movie is encoded `layers.Embedding`,\n", |
| 429 | + "where the dimension size is the square root of the number of movies.\n", |
| 430 | + "\n", |
| 431 | + "3. A multi-hot genres vector for each movie is concatenated with its embedding vector,\n", |
| 432 | + "and processed using a non-linear `layers.Dense` to output a vector of the same movie\n", |
| 433 | + "embedding dimensions.\n", |
| 434 | + "\n", |
| 435 | + "4. A positional embedding is added to each movie embedding in the sequence, and then\n", |
| 436 | + "multiplied by its rating from the ratings sequence.\n", |
| 437 | + "\n", |
| 438 | + "5. The target movie embedding is concatenated to the sequence movie embeddings, producing\n", |
| 439 | + "a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected\n", |
| 440 | + "by the attention layer for the transformer architecture.\n", |
| 441 | + "\n", |
| 442 | + "6. The method returns a tuple of two elements: `encoded_transformer_features` and\n", |
| 443 | + "`encoded_other_features`." |
421 | 444 | ] |
422 | 445 | }, |
423 | 446 | { |
|
428 | 451 | }, |
429 | 452 | "outputs": [], |
430 | 453 | "source": [ |
| 454 | + "# Required for tf.data.Dataset\n", |
| 455 | + "import tensorflow as tf\n", |
| 456 | + "\n", |
431 | 457 | "\n", |
432 | 458 | "def get_dataset_from_csv(csv_file_path, batch_size, shuffle=True):\n", |
| 459 | + "\n", |
433 | 460 | " def process(features):\n", |
434 | 461 | " movie_ids_string = features[\"sequence_movie_ids\"]\n", |
435 | 462 | " sequence_movie_ids = tf.strings.split(movie_ids_string, \",\").to_tensor()\n", |
436 | | - "\n", |
437 | 463 | " # The last movie id in the sequence is the target movie.\n", |
438 | 464 | " features[\"target_movie_id\"] = sequence_movie_ids[:, -1]\n", |
439 | 465 | " features[\"sequence_movie_ids\"] = sequence_movie_ids[:, :-1]\n", |
440 | | - "\n", |
| 466 | + " # Sequence ratings\n", |
441 | 467 | " ratings_string = features[\"sequence_ratings\"]\n", |
442 | 468 | " sequence_ratings = tf.strings.to_number(\n", |
443 | 469 | " tf.strings.split(ratings_string, \",\"), tf.dtypes.float32\n", |
444 | 470 | " ).to_tensor()\n", |
445 | | - "\n", |
446 | 471 | " # The last rating in the sequence is the target for the model to predict.\n", |
447 | 472 | " target = sequence_ratings[:, -1]\n", |
448 | 473 | " features[\"sequence_ratings\"] = sequence_ratings[:, :-1]\n", |
449 | 474 | "\n", |
| 475 | + " def encoding_helper(feature_name):\n", |
| 476 | + "\n", |
| 477 | + " # This are target_movie_id and sequence_movie_ids and they have the same\n", |
| 478 | + " # vocabulary as movie_id.\n", |
| 479 | + " if feature_name not in CATEGORICAL_FEATURES_WITH_VOCABULARY:\n", |
| 480 | + " vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[\"movie_id\"]\n", |
| 481 | + " index_lookup = StringLookup(\n", |
| 482 | + " vocabulary=vocabulary, mask_token=None, num_oov_indices=0\n", |
| 483 | + " )\n", |
| 484 | + " # Convert the string input values into integer indices.\n", |
| 485 | + " value_index = index_lookup(features[feature_name])\n", |
| 486 | + " features[feature_name] = value_index\n", |
| 487 | + " else:\n", |
| 488 | + " # movie_id is not part of the features, hence not processed. It was mainly required\n", |
| 489 | + " # for its vocabulary above.\n", |
| 490 | + " if feature_name == \"movie_id\":\n", |
| 491 | + " pass\n", |
| 492 | + " else:\n", |
| 493 | + " vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]\n", |
| 494 | + " index_lookup = StringLookup(\n", |
| 495 | + " vocabulary=vocabulary, mask_token=None, num_oov_indices=0\n", |
| 496 | + " )\n", |
| 497 | + " # Convert the string input values into integer indices.\n", |
| 498 | + " value_index = index_lookup(features[feature_name])\n", |
| 499 | + " features[feature_name] = value_index\n", |
| 500 | + "\n", |
| 501 | + " # Encode the user features\n", |
| 502 | + " for feature_name in CATEGORICAL_FEATURES_WITH_VOCABULARY:\n", |
| 503 | + " encoding_helper(feature_name)\n", |
| 504 | + " # Encoding target_movie_id and returning it as the target variable\n", |
| 505 | + " encoding_helper(\"target_movie_id\")\n", |
| 506 | + " # Encoding sequence movie_ids.\n", |
| 507 | + " encoding_helper(\"sequence_movie_ids\")\n", |
450 | 508 | " return dict(features), target\n", |
451 | 509 | "\n", |
452 | 510 | " dataset = tf.data.experimental.make_csv_dataset(\n", |
|
458 | 516 | " field_delim=\"|\",\n", |
459 | 517 | " shuffle=shuffle,\n", |
460 | 518 | " ).map(process)\n", |
461 | | - "\n", |
462 | 519 | " return dataset\n", |
463 | | - "" |
464 | | - ] |
465 | | - }, |
466 | | - { |
467 | | - "cell_type": "markdown", |
468 | | - "metadata": { |
469 | | - "colab_type": "text" |
470 | | - }, |
471 | | - "source": [ |
472 | | - "## Create model inputs" |
473 | | - ] |
474 | | - }, |
475 | | - { |
476 | | - "cell_type": "code", |
477 | | - "execution_count": 0, |
478 | | - "metadata": { |
479 | | - "colab_type": "code" |
480 | | - }, |
481 | | - "outputs": [], |
482 | | - "source": [ |
483 | | - "\n", |
484 | | - "def create_model_inputs():\n", |
485 | | - " return {\n", |
486 | | - " \"user_id\": keras.Input(name=\"user_id\", shape=(1,), dtype=\"string\"),\n", |
487 | | - " \"sequence_movie_ids\": keras.Input(\n", |
488 | | - " name=\"sequence_movie_ids\", shape=(sequence_length - 1,), dtype=\"string\"\n", |
489 | | - " ),\n", |
490 | | - " \"target_movie_id\": keras.Input(\n", |
491 | | - " name=\"target_movie_id\", shape=(1,), dtype=\"string\"\n", |
492 | | - " ),\n", |
493 | | - " \"sequence_ratings\": keras.Input(\n", |
494 | | - " name=\"sequence_ratings\", shape=(sequence_length - 1,), dtype=tf.float32\n", |
495 | | - " ),\n", |
496 | | - " \"sex\": keras.Input(name=\"sex\", shape=(1,), dtype=\"string\"),\n", |
497 | | - " \"age_group\": keras.Input(name=\"age_group\", shape=(1,), dtype=\"string\"),\n", |
498 | | - " \"occupation\": keras.Input(name=\"occupation\", shape=(1,), dtype=\"string\"),\n", |
499 | | - " }\n", |
500 | | - "" |
501 | | - ] |
502 | | - }, |
503 | | - { |
504 | | - "cell_type": "markdown", |
505 | | - "metadata": { |
506 | | - "colab_type": "text" |
507 | | - }, |
508 | | - "source": [ |
509 | | - "## Encode input features\n", |
510 | | - "\n", |
511 | | - "The `encode_input_features` method works as follows:\n", |
512 | | - "\n", |
513 | | - "1. Each categorical user feature is encoded using `layers.Embedding`, with embedding\n", |
514 | | - "dimension equals to the square root of the vocabulary size of the feature.\n", |
515 | | - "The embeddings of these features are concatenated to form a single input tensor.\n", |
516 | 520 | "\n", |
517 | | - "2. Each movie in the movie sequence and the target movie is encoded `layers.Embedding`,\n", |
518 | | - "where the dimension size is the square root of the number of movies.\n", |
519 | | - "\n", |
520 | | - "3. A multi-hot genres vector for each movie is concatenated with its embedding vector,\n", |
521 | | - "and processed using a non-linear `layers.Dense` to output a vector of the same movie\n", |
522 | | - "embedding dimensions.\n", |
523 | | - "\n", |
524 | | - "4. A positional embedding is added to each movie embedding in the sequence, and then\n", |
525 | | - "multiplied by its rating from the ratings sequence.\n", |
526 | | - "\n", |
527 | | - "5. The target movie embedding is concatenated to the sequence movie embeddings, producing\n", |
528 | | - "a tensor with the shape of `[batch size, sequence length, embedding size]`, as expected\n", |
529 | | - "by the attention layer for the transformer architecture.\n", |
530 | | - "\n", |
531 | | - "6. The method returns a tuple of two elements: `encoded_transformer_features` and\n", |
532 | | - "`encoded_other_features`." |
533 | | - ] |
534 | | - }, |
535 | | - { |
536 | | - "cell_type": "code", |
537 | | - "execution_count": 0, |
538 | | - "metadata": { |
539 | | - "colab_type": "code" |
540 | | - }, |
541 | | - "outputs": [], |
542 | | - "source": [ |
543 | 521 | "\n", |
544 | 522 | "def encode_input_features(\n", |
545 | 523 | " inputs,\n", |
546 | | - " include_user_id=True,\n", |
547 | | - " include_user_features=True,\n", |
548 | | - " include_movie_features=True,\n", |
| 524 | + " include_user_id,\n", |
| 525 | + " include_user_features,\n", |
| 526 | + " include_movie_features,\n", |
549 | 527 | "):\n", |
550 | 528 | " encoded_transformer_features = []\n", |
551 | 529 | " encoded_other_features = []\n", |
|
558 | 536 | "\n", |
559 | 537 | " ## Encode user features\n", |
560 | 538 | " for feature_name in other_feature_names:\n", |
561 | | - " # Convert the string input values into integer indices.\n", |
562 | 539 | " vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[feature_name]\n", |
563 | | - " idx = StringLookup(vocabulary=vocabulary, mask_token=None, num_oov_indices=0)(\n", |
564 | | - " inputs[feature_name]\n", |
565 | | - " )\n", |
566 | 540 | " # Compute embedding dimensions\n", |
567 | 541 | " embedding_dims = int(math.sqrt(len(vocabulary)))\n", |
568 | 542 | " # Create an embedding layer with the specified dimensions.\n", |
|
572 | 546 | " name=f\"{feature_name}_embedding\",\n", |
573 | 547 | " )\n", |
574 | 548 | " # Convert the index values to embedding representations.\n", |
575 | | - " encoded_other_features.append(embedding_encoder(idx))\n", |
| 549 | + " encoded_other_features.append(embedding_encoder(inputs[feature_name]))\n", |
576 | 550 | "\n", |
577 | 551 | " ## Create a single embedding vector for the user features\n", |
578 | 552 | " if len(encoded_other_features) > 1:\n", |
|
585 | 559 | " ## Create a movie embedding encoder\n", |
586 | 560 | " movie_vocabulary = CATEGORICAL_FEATURES_WITH_VOCABULARY[\"movie_id\"]\n", |
587 | 561 | " movie_embedding_dims = int(math.sqrt(len(movie_vocabulary)))\n", |
588 | | - " # Create a lookup to convert string values to integer indices.\n", |
589 | | - " movie_index_lookup = StringLookup(\n", |
590 | | - " vocabulary=movie_vocabulary,\n", |
591 | | - " mask_token=None,\n", |
592 | | - " num_oov_indices=0,\n", |
593 | | - " name=\"movie_index_lookup\",\n", |
594 | | - " )\n", |
595 | 562 | " # Create an embedding layer with the specified dimensions.\n", |
596 | 563 | " movie_embedding_encoder = layers.Embedding(\n", |
597 | 564 | " input_dim=len(movie_vocabulary),\n", |
|
617 | 584 | " ## Define a function to encode a given movie id.\n", |
618 | 585 | " def encode_movie(movie_id):\n", |
619 | 586 | " # Convert the string input values into integer indices.\n", |
620 | | - " movie_idx = movie_index_lookup(movie_id)\n", |
621 | | - " movie_embedding = movie_embedding_encoder(movie_idx)\n", |
| 587 | + " movie_embedding = movie_embedding_encoder(movie_id)\n", |
622 | 588 | " encoded_movie = movie_embedding\n", |
623 | 589 | " if include_movie_features:\n", |
624 | | - " movie_genres_vector = movie_genres_lookup(movie_idx)\n", |
| 590 | + " movie_genres_vector = movie_genres_lookup(movie_id)\n", |
625 | 591 | " encoded_movie = movie_embedding_processor(\n", |
626 | 592 | " layers.concatenate([movie_embedding, movie_genres_vector])\n", |
627 | 593 | " )\n", |
|
640 | 606 | " output_dim=movie_embedding_dims,\n", |
641 | 607 | " name=\"position_embedding\",\n", |
642 | 608 | " )\n", |
643 | | - " positions = tf.range(start=0, limit=sequence_length - 1, delta=1)\n", |
| 609 | + " positions = ops.arange(start=0, stop=sequence_length - 1, step=1)\n", |
644 | 610 | " encodded_positions = position_embedding_encoder(positions)\n", |
645 | 611 | " # Retrieve sequence ratings to incorporate them into the encoding of the movie.\n", |
646 | 612 | " sequence_ratings = inputs[\"sequence_ratings\"]\n", |
647 | | - " sequence_ratings = keras.ops.expand_dims(sequence_ratings, -1)\n", |
| 613 | + " sequence_ratings = ops.expand_dims(sequence_ratings, -1)\n", |
648 | 614 | " # Add the positional encoding to the movie encodings and multiply them by rating.\n", |
649 | 615 | " encoded_sequence_movies_with_poistion_and_rating = layers.Multiply()(\n", |
650 | 616 | " [(encoded_sequence_movies + encodded_positions), sequence_ratings]\n", |
|
653 | 619 | " # Construct the transformer inputs.\n", |
654 | 620 | " for i in range(sequence_length - 1):\n", |
655 | 621 | " feature = encoded_sequence_movies_with_poistion_and_rating[:, i, ...]\n", |
656 | | - " feature = keras.ops.expand_dims(feature, 1)\n", |
| 622 | + " feature = ops.expand_dims(feature, 1)\n", |
657 | 623 | " encoded_transformer_features.append(feature)\n", |
658 | 624 | " encoded_transformer_features.append(encoded_target_movie)\n", |
659 | | - "\n", |
660 | 625 | " encoded_transformer_features = layers.concatenate(\n", |
661 | 626 | " encoded_transformer_features, axis=1\n", |
662 | 627 | " )\n", |
663 | | - "\n", |
664 | 628 | " return encoded_transformer_features, encoded_other_features\n", |
665 | 629 | "" |
666 | 630 | ] |
667 | 631 | }, |
| 632 | + { |
| 633 | + "cell_type": "markdown", |
| 634 | + "metadata": { |
| 635 | + "colab_type": "text" |
| 636 | + }, |
| 637 | + "source": [ |
| 638 | + "## Create model inputs" |
| 639 | + ] |
| 640 | + }, |
| 641 | + { |
| 642 | + "cell_type": "code", |
| 643 | + "execution_count": 0, |
| 644 | + "metadata": { |
| 645 | + "colab_type": "code" |
| 646 | + }, |
| 647 | + "outputs": [], |
| 648 | + "source": [ |
| 649 | + "\n", |
| 650 | + "def create_model_inputs():\n", |
| 651 | + " return {\n", |
| 652 | + " \"user_id\": keras.Input(name=\"user_id\", shape=(1,), dtype=\"int32\"),\n", |
| 653 | + " \"sequence_movie_ids\": keras.Input(\n", |
| 654 | + " name=\"sequence_movie_ids\", shape=(sequence_length - 1,), dtype=\"int32\"\n", |
| 655 | + " ),\n", |
| 656 | + " \"target_movie_id\": keras.Input(\n", |
| 657 | + " name=\"target_movie_id\", shape=(1,), dtype=\"int32\"\n", |
| 658 | + " ),\n", |
| 659 | + " \"sequence_ratings\": keras.Input(\n", |
| 660 | + " name=\"sequence_ratings\", shape=(sequence_length - 1,), dtype=\"float32\"\n", |
| 661 | + " ),\n", |
| 662 | + " \"sex\": keras.Input(name=\"sex\", shape=(1,), dtype=\"int32\"),\n", |
| 663 | + " \"age_group\": keras.Input(name=\"age_group\", shape=(1,), dtype=\"int32\"),\n", |
| 664 | + " \"occupation\": keras.Input(name=\"occupation\", shape=(1,), dtype=\"int32\"),\n", |
| 665 | + " }\n", |
| 666 | + "" |
| 667 | + ] |
| 668 | + }, |
668 | 669 | { |
669 | 670 | "cell_type": "markdown", |
670 | 671 | "metadata": { |
|
692 | 693 | "\n", |
693 | 694 | "\n", |
694 | 695 | "def create_model():\n", |
| 696 | + "\n", |
695 | 697 | " inputs = create_model_inputs()\n", |
696 | 698 | " transformer_features, other_features = encode_input_features(\n", |
697 | 699 | " inputs, include_user_id, include_user_features, include_movie_features\n", |
698 | 700 | " )\n", |
699 | | - "\n", |
700 | 701 | " # Create a multi-headed attention layer.\n", |
701 | 702 | " attention_output = layers.MultiHeadAttention(\n", |
702 | 703 | " num_heads=num_heads, key_dim=transformer_features.shape[2], dropout=dropout_rate\n", |
|
713 | 714 | " transformer_features = layers.LayerNormalization()(transformer_features)\n", |
714 | 715 | " features = layers.Flatten()(transformer_features)\n", |
715 | 716 | "\n", |
716 | | - " # Included the other features.\n", |
| 717 | + " # Included the other_features.\n", |
717 | 718 | " if other_features is not None:\n", |
718 | 719 | " features = layers.concatenate(\n", |
719 | 720 | " [features, layers.Reshape([other_features.shape[-1]])(other_features)]\n", |
|
725 | 726 | " features = layers.BatchNormalization()(features)\n", |
726 | 727 | " features = layers.LeakyReLU()(features)\n", |
727 | 728 | " features = layers.Dropout(dropout_rate)(features)\n", |
728 | | - "\n", |
729 | 729 | " outputs = layers.Dense(units=1)(features)\n", |
730 | 730 | " model = keras.Model(inputs=inputs, outputs=outputs)\n", |
731 | 731 | " return model\n", |
|
759 | 759 | ")\n", |
760 | 760 | "\n", |
761 | 761 | "# Read the training data.\n", |
| 762 | + "\n", |
762 | 763 | "train_dataset = get_dataset_from_csv(\"train_data.csv\", batch_size=265, shuffle=True)\n", |
763 | 764 | "\n", |
764 | 765 | "# Fit the model with the training data.\n", |
|
0 commit comments