diff --git a/.gitignore b/.gitignore
index 7f756fb5a0..d646eb5568 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,7 @@
tmp/*
sources/*
site/*
+*.DS_Store
*.pyc
*.swp
templates/examples/audio/*
diff --git a/examples/keras_rs/dlrm.py b/examples/keras_rs/dlrm.py
new file mode 100644
index 0000000000..ee1f722122
--- /dev/null
+++ b/examples/keras_rs/dlrm.py
@@ -0,0 +1,495 @@
+"""
+Title: Ranking with Deep Learning Recommendation Model
+Author: [Harshith Kulkarni](https://github.com/kharshith-k)
+Date created: 2025/06/02
+Last modified: 2025/09/04
+Description: Rank movies with DLRM using KerasRS.
+Accelerator: GPU
+"""
+
+"""
+## Introduction
+
+This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to
+effectively learn the relationships between items and user preferences using a
+dot-product interaction mechanism. For more details, please refer to the
+[DLRM](https://arxiv.org/abs/1906.00091) paper.
+
+DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and
+is particularly effective at processing both categorical and continuous (sparse/dense)
+input features. The architecture consists of three main components: dedicated input
+layers to handle diverse features (typically embedding layers for categorical features),
+a dot-product interaction layer to explicitly model feature interactions, and a
+Multi-Layer Perceptron (MLP) to capture implicit feature relationships.
+
+The dot-product interaction layer lies at the heart of DLRM, efficiently computing
+pairwise interactions between different feature embeddings. This contrasts with models
+like Deep & Cross Network (DCN), which can treat elements within a feature vector as
+independent units, potentially leading to a higher-dimensional space and increased
+computational cost. The MLP is a standard feedforward network. The DLRM is formed by
+combining the interaction layer and MLP.
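+
+To make the mechanism concrete, here is a minimal sketch (with toy vectors, not the
+layer's actual implementation) of what a dot-product interaction computes for `F`
+feature embeddings: all `F * (F - 1) / 2` distinct pairwise dot products, which are
+then concatenated and fed to the top MLP.
+
+```python
+import numpy as np
+
+# Three toy feature embeddings of dimension 2.
+features = [np.array([1.0, 2.0]), np.array([0.5, -1.0]), np.array([3.0, 0.0])]
+# All distinct pairs: F * (F - 1) / 2 = 3 interaction values for F = 3 features.
+interactions = [
+    float(np.dot(features[i], features[j]))
+    for i in range(len(features))
+    for j in range(i + 1, len(features))
+]
+print(interactions)  # [-1.5, 3.0, 1.5]
+```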
+
+The following image illustrates the DLRM architecture:
+
+
+
+
+Now that we have a foundational understanding of DLRM's architecture and key
+characteristics, let's dive into the code. We will train a DLRM on a real-world dataset
+to demonstrate its capability to learn meaningful feature interactions. Let's begin by
+setting the backend to TensorFlow and organizing our imports.
+"""
+
+"""shell
+!pip install -q keras-rs
+"""
+
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow" # `"tensorflow"`/`"torch"`
+
+import keras
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+import tensorflow_datasets as tfds
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+import keras_rs
+
+"""
+Let's also define variables which will be reused throughout the example.
+"""
+
+MOVIELENS_CONFIG = {
+ # features
+ "continuous_features": [
+ "raw_user_age",
+ "hour_of_day_sin",
+ "hour_of_day_cos",
+ "hour_of_week_sin",
+ "hour_of_week_cos",
+ ],
+ "categorical_int_features": [
+ "user_gender",
+ ],
+ "categorical_str_features": [
+ "user_zip_code",
+ "user_occupation_text",
+ "movie_id",
+ "user_id",
+ ],
+ # model
+ "embedding_dim": 8,
+ "mlp_dim": 8,
+ "deep_net_num_units": [192, 192, 192],
+ # training
+ "learning_rate": 1e-4,
+ "num_epochs": 30,
+ "batch_size": 8192,
+}
+
+"""
+Here, we define helper functions for plotting training metrics and for visualizing
+the strength of the feature interactions learned by the model, in order to better
+understand its functioning. We also define functions for compiling, training and
+evaluating a given model, and for reporting the results.
+"""
+
+
+def plot_training_metrics(history):
+ """Graphs all metrics tracked in the history object."""
+ plt.figure(figsize=(12, 6))
+
+ for metric_name, metric_values in history.history.items():
+ plt.plot(metric_values, label=metric_name.replace("_", " ").title())
+
+ plt.title("Metrics over Epochs")
+ plt.xlabel("Epoch")
+ plt.ylabel("Metric Value")
+ plt.legend()
+ plt.grid(True)
+
+
+def visualize_layer(matrix, features, cmap=plt.cm.Blues):
+
+ im = plt.matshow(
+ matrix, cmap=cmap, extent=[-0.5, len(features) - 0.5, len(features) - 0.5, -0.5]
+ )
+
+ ax = plt.gca()
+ divider = make_axes_locatable(plt.gca())
+ cax = divider.append_axes("right", size="5%", pad=0.05)
+ plt.colorbar(im, cax=cax)
+ cax.tick_params(labelsize=10)
+
+ # Set tick locations explicitly before setting labels
+ ax.set_xticks(np.arange(len(features)))
+ ax.set_yticks(np.arange(len(features)))
+
+ ax.set_xticklabels(features, rotation=45, fontsize=5)
+ ax.set_yticklabels(features, fontsize=5)
+
+ plt.show()
+
+
+def train_and_evaluate(
+ learning_rate,
+ epochs,
+ train_data,
+ test_data,
+ model,
+ plot_metrics=False,
+):
+ optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, clipnorm=1.0)
+ loss = keras.losses.MeanSquaredError()
+ rmse = keras.metrics.RootMeanSquaredError()
+
+ model.compile(
+ optimizer=optimizer,
+ loss=loss,
+ metrics=[rmse],
+ )
+
+ history = model.fit(
+ train_data,
+ epochs=epochs,
+ verbose=1,
+ )
+ if plot_metrics:
+ plot_training_metrics(history)
+
+ results = model.evaluate(test_data, return_dict=True, verbose=1)
+ rmse_value = results["root_mean_squared_error"]
+
+ return rmse_value, model.count_params()
+
+
+def print_stats(rmse_list, num_params, model_name):
+ # Report metrics.
+ num_trials = len(rmse_list)
+ avg_rmse = np.mean(rmse_list)
+ std_rmse = np.std(rmse_list)
+
+ if num_trials == 1:
+ print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
+ else:
+ print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}")
+
+
+"""
+## Real-world example
+
+Let's use the MovieLens 100K dataset. This dataset is used to train models to
+predict users' movie ratings, based on user-related features and movie-related
+features.
+
+### Preparing the dataset
+
+The dataset processing steps here are similar to what's given in the
+[basic ranking](/keras_rs/examples/basic_ranking/)
+tutorial. Let's load the dataset, and keep only the useful columns.
+"""
+
+ratings_ds = tfds.load("movielens/100k-ratings", split="train")
+
+
+def preprocess_features(x):
+ """Extracts and cyclically encodes timestamp features."""
+ features = {
+ "movie_id": x["movie_id"],
+ "user_id": x["user_id"],
+ "user_gender": tf.cast(x["user_gender"], dtype=tf.int32),
+ "user_zip_code": x["user_zip_code"],
+ "user_occupation_text": x["user_occupation_text"],
+ "raw_user_age": tf.cast(x["raw_user_age"], dtype=tf.float32),
+ }
+ label = tf.cast(x["user_rating"], dtype=tf.float32)
+
+ # The timestamp is in seconds since the epoch.
+ timestamp = tf.cast(x["timestamp"], dtype=tf.float32)
+
+ # Constants for time periods
+ SECONDS_IN_HOUR = 3600.0
+ HOURS_IN_DAY = 24.0
+ HOURS_IN_WEEK = 168.0
+
+ # Calculate hour of day and encode it
+ hour_of_day = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_DAY
+ features["hour_of_day_sin"] = tf.sin(2 * np.pi * hour_of_day / HOURS_IN_DAY)
+ features["hour_of_day_cos"] = tf.cos(2 * np.pi * hour_of_day / HOURS_IN_DAY)
+
+ # Calculate hour of week and encode it
+ hour_of_week = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_WEEK
+ features["hour_of_week_sin"] = tf.sin(2 * np.pi * hour_of_week / HOURS_IN_WEEK)
+ features["hour_of_week_cos"] = tf.cos(2 * np.pi * hour_of_week / HOURS_IN_WEEK)
+
+ return features, label
+
+
+# Apply the new preprocessing function
+ratings_ds = ratings_ds.map(preprocess_features)
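+
+"""
+The sine/cosine pair places each hour on the unit circle, so times on either side of
+midnight (e.g. 23:00 and 01:00) end up close together instead of 22 hours apart. A
+quick sanity check with two toy hours:
+"""
+
+for hour in (23.0, 1.0):
+    print(hour, np.sin(2 * np.pi * hour / 24.0), np.cos(2 * np.pi * hour / 24.0))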
+
+"""
+For every categorical feature, let's get the list of unique values, i.e., the
+vocabulary, so that we can use it for the lookup and embedding layers.
+"""
+
+vocabularies = {}
+for feature_name in (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+):
+ vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])
+ vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))
+
+"""
+We use `keras.layers.StringLookup` and `keras.layers.IntegerLookup` to convert all the
+categorical features into indices, which can then be fed into embedding layers.
+"""
+
+lookup_layers = {}
+lookup_layers.update(
+ {
+ feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])
+ for feature in MOVIELENS_CONFIG["categorical_int_features"]
+ }
+)
+lookup_layers.update(
+ {
+ feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])
+ for feature in MOVIELENS_CONFIG["categorical_str_features"]
+ }
+)
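+
+"""
+As a quick illustration of the lookup behavior (with a toy vocabulary, not the
+dataset's): index 0 is reserved for out-of-vocabulary values, which is why we will add
+1 to each vocabulary size when sizing the embedding layers later on.
+"""
+
+toy_lookup = keras.layers.StringLookup(vocabulary=["comedy", "horror"])
+print(toy_lookup(tf.constant(["horror", "western"])))  # [2, 0]; 0 is the OOV bucket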
+
+"""
+Let's normalize all the continuous features, so that we can use them in the MLP layers.
+"""
+
+normalization_layers = {}
+for feature_name in MOVIELENS_CONFIG["continuous_features"]:
+ normalization_layers[feature_name] = keras.layers.Normalization(axis=-1)
+
+training_data_for_adaptation = ratings_ds.take(80_000).map(lambda x, y: x)
+
+for feature_name in MOVIELENS_CONFIG["continuous_features"]:
+ feature_ds = training_data_for_adaptation.map(
+ lambda x: tf.expand_dims(x[feature_name], axis=-1)
+ )
+ normalization_layers[feature_name].adapt(feature_ds)
+
+ratings_ds = ratings_ds.map(
+ lambda x, y: (
+ {
+ **{
+ feature_name: lookup_layers[feature_name](x[feature_name])
+ for feature_name in vocabularies
+ },
+ # Apply the adapted normalization layers to the continuous features.
+ **{
+ feature_name: tf.squeeze(
+ normalization_layers[feature_name](
+ tf.expand_dims(x[feature_name], axis=-1)
+ ),
+ axis=-1,
+ )
+ for feature_name in MOVIELENS_CONFIG["continuous_features"]
+ },
+ },
+ y,
+ )
+)
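+
+"""
+Under the hood, `adapt()` simply estimates each feature's mean and variance from the
+data, and the layer then standardizes its inputs with those statistics. A toy check
+with made-up values:
+"""
+
+toy_norm = keras.layers.Normalization(axis=-1)
+toy_norm.adapt(np.array([[10.0], [20.0], [30.0]]))
+print(toy_norm(np.array([[20.0]])))  # ~[[0.]], since 20 is the adapted mean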
+
+"""
+Let's split our data into train and test sets. We also use `cache()` and
+`prefetch()` for better performance.
+"""
+
+ratings_ds = ratings_ds.shuffle(100_000)
+
+train_ds = (
+ ratings_ds.take(80_000)
+ .batch(MOVIELENS_CONFIG["batch_size"])
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+test_ds = (
+ ratings_ds.skip(80_000)
+ .batch(MOVIELENS_CONFIG["batch_size"])
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+
+"""
+### Building the model
+
+The model will have embedding layers, followed by DotInteraction and feedforward
+layers.
+"""
+
+
+class DLRM(keras.Model):
+ def __init__(
+ self,
+ dense_num_units_lst,
+ embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
+ mlp_dim=MOVIELENS_CONFIG["mlp_dim"],
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.embedding_layers = {}
+ for feature_name in (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+ ):
+ vocab_size = len(vocabularies[feature_name]) + 1 # +1 for OOV token
+ self.embedding_layers[feature_name] = keras.layers.Embedding(
+ input_dim=vocab_size,
+ output_dim=embedding_dim,
+ )
+
+ self.bottom_mlp = keras.Sequential(
+ [
+ keras.layers.Dense(mlp_dim, activation="relu"),
+ keras.layers.Dense(embedding_dim), # Output must match embedding_dim
+ ]
+ )
+
+ self.dot_layer = keras_rs.layers.DotInteraction()
+
+ self.top_mlp = []
+ for num_units in dense_num_units_lst:
+ self.top_mlp.append(keras.layers.Dense(num_units, activation="relu"))
+
+ self.output_layer = keras.layers.Dense(1)
+
+ self.dense_num_units_lst = dense_num_units_lst
+ self.embedding_dim = embedding_dim
+
+ def call(self, inputs):
+ embeddings = []
+ for feature_name in (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+ ):
+ embedding = self.embedding_layers[feature_name](inputs[feature_name])
+ embeddings.append(embedding)
+
+ # Process all continuous features together.
+ continuous_inputs = []
+ for feature_name in MOVIELENS_CONFIG["continuous_features"]:
+ # Reshape each feature to (batch_size, 1)
+ feature = keras.ops.reshape(
+ keras.ops.cast(inputs[feature_name], dtype="float32"), (-1, 1)
+ )
+ continuous_inputs.append(feature)
+
+ # Concatenate into a single tensor: (batch_size, num_continuous_features)
+ concatenated_continuous = keras.ops.concatenate(continuous_inputs, axis=1)
+
+ # Pass through the Bottom MLP to get one combined vector.
+ processed_continuous = self.bottom_mlp(concatenated_continuous)
+
+ # Combine with categorical embeddings. Note: we add a list containing the
+ # single tensor.
+ combined_features = embeddings + [processed_continuous]
+
+ # Pass the list of features to the DotInteraction layer.
+ x = self.dot_layer(combined_features)
+
+ for layer in self.top_mlp:
+ x = layer(x)
+
+ x = self.output_layer(x)
+
+ return x
+
+
+dot_network = DLRM(
+ dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"],
+ embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
+ mlp_dim=MOVIELENS_CONFIG["mlp_dim"],
+)
+
+rmse, dot_network_num_params = train_and_evaluate(
+ learning_rate=MOVIELENS_CONFIG["learning_rate"],
+ epochs=MOVIELENS_CONFIG["num_epochs"],
+ train_data=train_ds,
+ test_data=test_ds,
+ model=dot_network,
+ plot_metrics=True,
+)
+print_stats(
+ rmse_list=[rmse],
+ num_params=dot_network_num_params,
+ model_name="Dot Network",
+)
+
+"""
+### Visualizing feature interactions
+
+The DotInteraction layer itself doesn't have a conventional "weight" matrix like a Dense
+layer. Instead, its function is to compute the dot product between the embedding vectors
+of your features.
+
+To visualize the strength of these interactions, we can calculate a matrix representing
+the pairwise interaction strength between all feature embeddings. A common way to do this
+is to take the dot product of the embedding matrices for each pair of features and then
+aggregate the result into a single value (like the mean of the absolute values) that
+represents the overall interaction strength.
+"""
+
+
+def get_dot_interaction_matrix(model, categorical_features, continuous_features):
+ # The new feature list for the plot labels
+ all_feature_names = categorical_features + ["all_continuous_features"]
+ num_features = len(all_feature_names)
+
+ # Store all feature outputs in the correct order.
+ all_feature_outputs = []
+
+    # Get outputs for categorical features from their embedding layers.
+ for feature_name in categorical_features:
+ embedding = model.embedding_layers[feature_name](keras.ops.array([0]))
+ all_feature_outputs.append(embedding)
+
+ # Get a single output for ALL continuous features from the shared MLP.
+ num_continuous_features = len(continuous_features)
+ # Create a dummy input of zeros for the MLP
+ dummy_continuous_input = keras.ops.zeros((1, num_continuous_features))
+ processed_continuous = model.bottom_mlp(dummy_continuous_input)
+ all_feature_outputs.append(processed_continuous)
+
+ interaction_matrix = np.zeros((num_features, num_features))
+
+ # Iterate through each pair to calculate interaction strength.
+ for i in range(num_features):
+ for j in range(num_features):
+ interaction = keras.ops.dot(
+ all_feature_outputs[i], keras.ops.transpose(all_feature_outputs[j])
+ )
+            interaction_strength = np.abs(keras.ops.convert_to_numpy(interaction))[0][0]
+ interaction_matrix[i, j] = interaction_strength
+
+ return interaction_matrix, all_feature_names
+
+
+# Get the list of categorical feature names.
+categorical_feature_names = (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+)
+
+# Calculate the interaction matrix.
+interaction_matrix, feature_names = get_dot_interaction_matrix(
+ model=dot_network,
+ categorical_features=categorical_feature_names,
+ continuous_features=MOVIELENS_CONFIG["continuous_features"],
+)
+
+# Visualize the matrix as a heatmap.
+print("\nVisualizing the feature interaction strengths:")
+visualize_layer(interaction_matrix, feature_names)
diff --git a/examples/keras_rs/img/dlrm/dlrm_19_158.png b/examples/keras_rs/img/dlrm/dlrm_19_158.png
new file mode 100644
index 0000000000..48e90b6bf3
Binary files /dev/null and b/examples/keras_rs/img/dlrm/dlrm_19_158.png differ
diff --git a/examples/keras_rs/img/dlrm/dlrm_21_1.png b/examples/keras_rs/img/dlrm/dlrm_21_1.png
new file mode 100644
index 0000000000..81d11ed54b
Binary files /dev/null and b/examples/keras_rs/img/dlrm/dlrm_21_1.png differ
diff --git a/examples/keras_rs/img/dlrm/dlrm_architecture.gif b/examples/keras_rs/img/dlrm/dlrm_architecture.gif
new file mode 100644
index 0000000000..7186adc365
Binary files /dev/null and b/examples/keras_rs/img/dlrm/dlrm_architecture.gif differ
diff --git a/examples/keras_rs/ipynb/dlrm.ipynb b/examples/keras_rs/ipynb/dlrm.ipynb
new file mode 100644
index 0000000000..9de8dab3d7
--- /dev/null
+++ b/examples/keras_rs/ipynb/dlrm.ipynb
@@ -0,0 +1,680 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Ranking with Deep Learning Recommendation Model\n",
+ "\n",
+ "**Author:** [Harshith Kulkarni](https://github.com/kharshith-k)
\n",
+ "**Date created:** 2025/06/02
\n",
+ "**Last modified:** 2025/09/04
\n",
+ "**Description:** Rank movies with DLRM using KerasRS."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to\n",
+ "effectively learn the relationships between items and user preferences using a\n",
+ "dot-product interaction mechanism. For more details, please refer to the\n",
+ "[DLRM](https://arxiv.org/abs/1906.00091) paper.\n",
+ "\n",
+ "DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and\n",
+ "is particularly effective at processing both categorical and continuous (sparse/dense)\n",
+ "input features. The architecture consists of three main components: dedicated input\n",
+ "layers to handle diverse features (typically embedding layers for categorical features),\n",
+ "a dot-product interaction layer to explicitly model feature interactions, and a\n",
+ "Multi-Layer Perceptron (MLP) to capture implicit feature relationships.\n",
+ "\n",
+ "The dot-product interaction layer lies at the heart of DLRM, efficiently computing\n",
+ "pairwise interactions between different feature embeddings. This contrasts with models\n",
+ "like Deep & Cross Network (DCN), which can treat elements within a feature vector as\n",
+ "independent units, potentially leading to a higher-dimensional space and increased\n",
+ "computational cost. The MLP is a standard feedforward network. The DLRM is formed by\n",
+ "combining the interaction layer and MLP.\n",
+ "\n",
+ "The following image illustrates the DLRM architecture:\n",
+ "\n",
+ "\n",
+ "\n",
+ "\n",
+ "Now that we have a foundational understanding of DLRM's architecture and key\n",
+ "characteristics, let's dive into the code. We will train a DLRM on a real-world dataset\n",
+ "to demonstrate its capability to learn meaningful feature interactions. Let's begin by\n",
+ "setting the backend to JAX and organizing our imports."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "!pip install -q keras-rs"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\" # `\"tensorflow\"`/`\"torch\"`\n",
+ "\n",
+ "import keras\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import tensorflow as tf\n",
+ "import tensorflow_datasets as tfds\n",
+ "from mpl_toolkits.axes_grid1 import make_axes_locatable\n",
+ "\n",
+ "import keras_rs"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's also define variables which will be reused throughout the example."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "MOVIELENS_CONFIG = {\n",
+ " # features\n",
+ " \"continuous_features\": [\n",
+ " \"raw_user_age\",\n",
+ " \"hour_of_day_sin\",\n",
+ " \"hour_of_day_cos\",\n",
+ " \"hour_of_week_sin\",\n",
+ " \"hour_of_week_cos\",\n",
+ " ],\n",
+ " \"categorical_int_features\": [\n",
+ " \"user_gender\",\n",
+ " ],\n",
+ " \"categorical_str_features\": [\n",
+ " \"user_zip_code\",\n",
+ " \"user_occupation_text\",\n",
+ " \"movie_id\",\n",
+ " \"user_id\",\n",
+ " ],\n",
+ " # model\n",
+ " \"embedding_dim\": 8,\n",
+ " \"mlp_dim\": 8,\n",
+ " \"deep_net_num_units\": [192, 192, 192],\n",
+ " # training\n",
+ " \"learning_rate\": 1e-4,\n",
+ " \"num_epochs\": 30,\n",
+ " \"batch_size\": 8192,\n",
+ "}"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Here, we define a helper function for visualising weights of the cross layer in\n",
+ "order to better understand its functioning. Also, we define a function for\n",
+ "compiling, training and evaluating a given model."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def plot_training_metrics(history):\n",
+ " \"\"\"Graphs all metrics tracked in the history object.\"\"\"\n",
+ " plt.figure(figsize=(12, 6))\n",
+ "\n",
+ " for metric_name, metric_values in history.history.items():\n",
+ " plt.plot(metric_values, label=metric_name.replace(\"_\", \" \").title())\n",
+ "\n",
+ " plt.title(\"Metrics over Epochs\")\n",
+ " plt.xlabel(\"Epoch\")\n",
+ " plt.ylabel(\"Metric Value\")\n",
+ " plt.legend()\n",
+ " plt.grid(True)\n",
+ "\n",
+ "\n",
+ "def visualize_layer(matrix, features, cmap=plt.cm.Blues):\n",
+ "\n",
+ " im = plt.matshow(\n",
+ " matrix, cmap=cmap, extent=[-0.5, len(features) - 0.5, len(features) - 0.5, -0.5]\n",
+ " )\n",
+ "\n",
+ " ax = plt.gca()\n",
+ " divider = make_axes_locatable(plt.gca())\n",
+ " cax = divider.append_axes(\"right\", size=\"5%\", pad=0.05)\n",
+ " plt.colorbar(im, cax=cax)\n",
+ " cax.tick_params(labelsize=10)\n",
+ "\n",
+ " # Set tick locations explicitly before setting labels\n",
+ " ax.set_xticks(np.arange(len(features)))\n",
+ " ax.set_yticks(np.arange(len(features)))\n",
+ "\n",
+ " ax.set_xticklabels(features, rotation=45, fontsize=5)\n",
+ " ax.set_yticklabels(features, fontsize=5)\n",
+ "\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "def train_and_evaluate(\n",
+ " learning_rate,\n",
+ " epochs,\n",
+ " train_data,\n",
+ " test_data,\n",
+ " model,\n",
+ " plot_metrics=False,\n",
+ "):\n",
+ " optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, clipnorm=1.0)\n",
+ " loss = keras.losses.MeanSquaredError()\n",
+ " rmse = keras.metrics.RootMeanSquaredError()\n",
+ "\n",
+ " model.compile(\n",
+ " optimizer=optimizer,\n",
+ " loss=loss,\n",
+ " metrics=[rmse],\n",
+ " )\n",
+ "\n",
+ " history = model.fit(\n",
+ " train_data,\n",
+ " epochs=epochs,\n",
+ " verbose=1,\n",
+ " )\n",
+ " if plot_metrics:\n",
+ " plot_training_metrics(history)\n",
+ "\n",
+ " results = model.evaluate(test_data, return_dict=True, verbose=1)\n",
+ " rmse_value = results[\"root_mean_squared_error\"]\n",
+ "\n",
+ " return rmse_value, model.count_params()\n",
+ "\n",
+ "\n",
+ "def print_stats(rmse_list, num_params, model_name):\n",
+ " # Report metrics.\n",
+ " num_trials = len(rmse_list)\n",
+ " avg_rmse = np.mean(rmse_list)\n",
+ " std_rmse = np.std(rmse_list)\n",
+ "\n",
+ " if num_trials == 1:\n",
+ " print(f\"{model_name}: RMSE = {avg_rmse}; #params = {num_params}\")\n",
+ " else:\n",
+ " print(f\"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Real-world example\n",
+ "\n",
+ "Let's use the MovieLens 100K dataset. This dataset is used to train models to\n",
+ "predict users' movie ratings, based on user-related features and movie-related\n",
+ "features.\n",
+ "\n",
+ "### Preparing the dataset\n",
+ "\n",
+ "The dataset processing steps here are similar to what's given in the\n",
+ "[basic ranking](/keras_rs/examples/basic_ranking/)\n",
+ "tutorial. Let's load the dataset, and keep only the useful columns."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "ratings_ds = tfds.load(\"movielens/100k-ratings\", split=\"train\")\n",
+ "\n",
+ "\n",
+ "def preprocess_features(x):\n",
+ " \"\"\"Extracts and cyclically encodes timestamp features.\"\"\"\n",
+ " features = {\n",
+ " \"movie_id\": x[\"movie_id\"],\n",
+ " \"user_id\": x[\"user_id\"],\n",
+ " \"user_gender\": tf.cast(x[\"user_gender\"], dtype=tf.int32),\n",
+ " \"user_zip_code\": x[\"user_zip_code\"],\n",
+ " \"user_occupation_text\": x[\"user_occupation_text\"],\n",
+ " \"raw_user_age\": tf.cast(x[\"raw_user_age\"], dtype=tf.float32),\n",
+ " }\n",
+ " label = tf.cast(x[\"user_rating\"], dtype=tf.float32)\n",
+ "\n",
+ " # The timestamp is in seconds since the epoch.\n",
+ " timestamp = tf.cast(x[\"timestamp\"], dtype=tf.float32)\n",
+ "\n",
+ " # Constants for time periods\n",
+ " SECONDS_IN_HOUR = 3600.0\n",
+ " HOURS_IN_DAY = 24.0\n",
+ " HOURS_IN_WEEK = 168.0\n",
+ "\n",
+ " # Calculate hour of day and encode it\n",
+ " hour_of_day = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_DAY\n",
+ " features[\"hour_of_day_sin\"] = tf.sin(2 * np.pi * hour_of_day / HOURS_IN_DAY)\n",
+ " features[\"hour_of_day_cos\"] = tf.cos(2 * np.pi * hour_of_day / HOURS_IN_DAY)\n",
+ "\n",
+ " # Calculate hour of week and encode it\n",
+ " hour_of_week = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_WEEK\n",
+ " features[\"hour_of_week_sin\"] = tf.sin(2 * np.pi * hour_of_week / HOURS_IN_WEEK)\n",
+ " features[\"hour_of_week_cos\"] = tf.cos(2 * np.pi * hour_of_week / HOURS_IN_WEEK)\n",
+ "\n",
+ " return features, label\n",
+ "\n",
+ "\n",
+ "# Apply the new preprocessing function\n",
+ "ratings_ds = ratings_ds.map(preprocess_features)"
+ ]
+ },
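+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "The sine/cosine pair places each hour on the unit circle, so times on either side of\n",
+    "midnight (e.g. 23:00 and 01:00) end up close together instead of 22 hours apart. A\n",
+    "quick sanity check with two toy hours:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "for hour in (23.0, 1.0):\n",
+    "    print(hour, np.sin(2 * np.pi * hour / 24.0), np.cos(2 * np.pi * hour / 24.0))"
+   ]
+  },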
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "For every categorical feature, let's get the list of unique values, i.e., vocabulary, so\n",
+ "that we can use that for the embedding layer."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "vocabularies = {}\n",
+ "for feature_name in (\n",
+ " MOVIELENS_CONFIG[\"categorical_int_features\"]\n",
+ " + MOVIELENS_CONFIG[\"categorical_str_features\"]\n",
+ "):\n",
+ " vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])\n",
+ " vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "One thing we need to do is to use `keras.layers.StringLookup` and\n",
+ "`keras.layers.IntegerLookup` to convert all the categorical features into indices, which\n",
+ "can\n",
+ "then be fed into embedding layers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "lookup_layers = {}\n",
+ "lookup_layers.update(\n",
+ " {\n",
+ " feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])\n",
+ " for feature in MOVIELENS_CONFIG[\"categorical_int_features\"]\n",
+ " }\n",
+ ")\n",
+ "lookup_layers.update(\n",
+ " {\n",
+ " feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])\n",
+ " for feature in MOVIELENS_CONFIG[\"categorical_str_features\"]\n",
+ " }\n",
+ ")"
+ ]
+ },
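+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "As a quick illustration of the lookup behavior (with a toy vocabulary, not the\n",
+    "dataset's): index 0 is reserved for out-of-vocabulary values, which is why we will add\n",
+    "1 to each vocabulary size when sizing the embedding layers later on."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "toy_lookup = keras.layers.StringLookup(vocabulary=[\"comedy\", \"horror\"])\n",
+    "print(toy_lookup(tf.constant([\"horror\", \"western\"])))  # [2, 0]; 0 is the OOV bucket"
+   ]
+  },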
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's normalize all the continuous features, so that we can use that for the MLP layers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "normalization_layers = {}\n",
+ "for feature_name in MOVIELENS_CONFIG[\"continuous_features\"]:\n",
+ " normalization_layers[feature_name] = keras.layers.Normalization(axis=-1)\n",
+ "\n",
+ "training_data_for_adaptation = ratings_ds.take(80_000).map(lambda x, y: x)\n",
+ "\n",
+ "for feature_name in MOVIELENS_CONFIG[\"continuous_features\"]:\n",
+ " feature_ds = training_data_for_adaptation.map(\n",
+ " lambda x: tf.expand_dims(x[feature_name], axis=-1)\n",
+ " )\n",
+ " normalization_layers[feature_name].adapt(feature_ds)\n",
+ "\n",
+ "ratings_ds = ratings_ds.map(\n",
+ " lambda x, y: (\n",
+ " {\n",
+ " **{\n",
+ " feature_name: lookup_layers[feature_name](x[feature_name])\n",
+ " for feature_name in vocabularies\n",
+ " },\n",
+ " # Apply the adapted normalization layers to the continuous features.\n",
+ " **{\n",
+ " feature_name: tf.squeeze(\n",
+ " normalization_layers[feature_name](\n",
+ " tf.expand_dims(x[feature_name], axis=-1)\n",
+ " ),\n",
+ " axis=-1,\n",
+ " )\n",
+ " for feature_name in MOVIELENS_CONFIG[\"continuous_features\"]\n",
+ " },\n",
+ " },\n",
+ " y,\n",
+ " )\n",
+ ")"
+ ]
+ },
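+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "colab_type": "text"
+   },
+   "source": [
+    "Under the hood, `adapt()` simply estimates each feature's mean and variance from the\n",
+    "data, and the layer then standardizes its inputs with those statistics. A toy check\n",
+    "with made-up values:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 0,
+   "metadata": {
+    "colab_type": "code"
+   },
+   "outputs": [],
+   "source": [
+    "toy_norm = keras.layers.Normalization(axis=-1)\n",
+    "toy_norm.adapt(np.array([[10.0], [20.0], [30.0]]))\n",
+    "print(toy_norm(np.array([[20.0]])))  # ~[[0.]], since 20 is the adapted mean"
+   ]
+  },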
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Let's split our data into train and test sets. We also use `cache()` and\n",
+ "`prefetch()` for better performance."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "ratings_ds = ratings_ds.shuffle(100_000)\n",
+ "\n",
+ "train_ds = (\n",
+ " ratings_ds.take(80_000)\n",
+ " .batch(MOVIELENS_CONFIG[\"batch_size\"])\n",
+ " .cache()\n",
+ " .prefetch(tf.data.AUTOTUNE)\n",
+ ")\n",
+ "test_ds = (\n",
+ " ratings_ds.skip(80_000)\n",
+ " .batch(MOVIELENS_CONFIG[\"batch_size\"])\n",
+ " .take(20_000)\n",
+ " .cache()\n",
+ " .prefetch(tf.data.AUTOTUNE)\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Building the model\n",
+ "\n",
+ "The model will have embedding layers, followed by DotInteraction and feedforward\n",
+ "layers."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class DLRM(keras.Model):\n",
+ " def __init__(\n",
+ " self,\n",
+ " dense_num_units_lst,\n",
+ " embedding_dim=MOVIELENS_CONFIG[\"embedding_dim\"],\n",
+ " mlp_dim=MOVIELENS_CONFIG[\"mlp_dim\"],\n",
+ " **kwargs,\n",
+ " ):\n",
+ " super().__init__(**kwargs)\n",
+ "\n",
+ " self.embedding_layers = {}\n",
+ " for feature_name in (\n",
+ " MOVIELENS_CONFIG[\"categorical_int_features\"]\n",
+ " + MOVIELENS_CONFIG[\"categorical_str_features\"]\n",
+ " ):\n",
+ " vocab_size = len(vocabularies[feature_name]) + 1 # +1 for OOV token\n",
+ " self.embedding_layers[feature_name] = keras.layers.Embedding(\n",
+ " input_dim=vocab_size,\n",
+ " output_dim=embedding_dim,\n",
+ " )\n",
+ "\n",
+ " self.bottom_mlp = keras.Sequential(\n",
+ " [\n",
+ " keras.layers.Dense(mlp_dim, activation=\"relu\"),\n",
+ " keras.layers.Dense(embedding_dim), # Output must match embedding_dim\n",
+ " ]\n",
+ " )\n",
+ "\n",
+ " self.dot_layer = keras_rs.layers.DotInteraction()\n",
+ "\n",
+ " self.top_mlp = []\n",
+ " for num_units in dense_num_units_lst:\n",
+ " self.top_mlp.append(keras.layers.Dense(num_units, activation=\"relu\"))\n",
+ "\n",
+ " self.output_layer = keras.layers.Dense(1)\n",
+ "\n",
+ " self.dense_num_units_lst = dense_num_units_lst\n",
+ " self.embedding_dim = embedding_dim\n",
+ "\n",
+ " def call(self, inputs):\n",
+ " embeddings = []\n",
+ " for feature_name in (\n",
+ " MOVIELENS_CONFIG[\"categorical_int_features\"]\n",
+ " + MOVIELENS_CONFIG[\"categorical_str_features\"]\n",
+ " ):\n",
+ " embedding = self.embedding_layers[feature_name](inputs[feature_name])\n",
+ " embeddings.append(embedding)\n",
+ "\n",
+ " # Process all continuous features together.\n",
+ " continuous_inputs = []\n",
+ " for feature_name in MOVIELENS_CONFIG[\"continuous_features\"]:\n",
+ " # Reshape each feature to (batch_size, 1)\n",
+ " feature = keras.ops.reshape(\n",
+ " keras.ops.cast(inputs[feature_name], dtype=\"float32\"), (-1, 1)\n",
+ " )\n",
+ " continuous_inputs.append(feature)\n",
+ "\n",
+ " # Concatenate into a single tensor: (batch_size, num_continuous_features)\n",
+ " concatenated_continuous = keras.ops.concatenate(continuous_inputs, axis=1)\n",
+ "\n",
+ " # Pass through the Bottom MLP to get one combined vector.\n",
+ " processed_continuous = self.bottom_mlp(concatenated_continuous)\n",
+ "\n",
+ " # Combine with categorical embeddings. Note: we add a list containing the\n",
+ " # single tensor.\n",
+ " combined_features = embeddings + [processed_continuous]\n",
+ "\n",
+ " # Pass the list of features to the DotInteraction layer.\n",
+ " x = self.dot_layer(combined_features)\n",
+ "\n",
+ " for layer in self.top_mlp:\n",
+ " x = layer(x)\n",
+ "\n",
+ " x = self.output_layer(x)\n",
+ "\n",
+ " return x\n",
+ "\n",
+ "\n",
+ "dot_network = DLRM(\n",
+ " dense_num_units_lst=MOVIELENS_CONFIG[\"deep_net_num_units\"],\n",
+ " embedding_dim=MOVIELENS_CONFIG[\"embedding_dim\"],\n",
+ " mlp_dim=MOVIELENS_CONFIG[\"mlp_dim\"],\n",
+ ")\n",
+ "\n",
+ "rmse, dot_network_num_params = train_and_evaluate(\n",
+ " learning_rate=MOVIELENS_CONFIG[\"learning_rate\"],\n",
+ " epochs=MOVIELENS_CONFIG[\"num_epochs\"],\n",
+ " train_data=train_ds,\n",
+ " test_data=test_ds,\n",
+ " model=dot_network,\n",
+ " plot_metrics=True,\n",
+ ")\n",
+ "print_stats(\n",
+ " rmse_list=[rmse],\n",
+ " num_params=dot_network_num_params,\n",
+ " model_name=\"Dot Network\",\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Visualizing feature interactions\n",
+ "\n",
+ "The DotInteraction layer itself doesn't have a conventional \"weight\" matrix like a Dense\n",
+ "layer. Instead, its function is to compute the dot product between the embedding vectors\n",
+ "of your features.\n",
+ "\n",
+ "To visualize the strength of these interactions, we can calculate a matrix representing\n",
+ "the pairwise interaction strength between all feature embeddings. A common way to do this\n",
+ "is to take the dot product of the embedding matrices for each pair of features and then\n",
+ "aggregate the result into a single value (like the mean of the absolute values) that\n",
+ "represents the overall interaction strength."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def get_dot_interaction_matrix(model, categorical_features, continuous_features):\n",
+ " # The new feature list for the plot labels\n",
+ " all_feature_names = categorical_features + [\"all_continuous_features\"]\n",
+ " num_features = len(all_feature_names)\n",
+ "\n",
+ " # Store all feature outputs in the correct order.\n",
+ " all_feature_outputs = []\n",
+ "\n",
+ " # Get outputs for categorical features from embedding layers (unchanged).\n",
+ " for feature_name in categorical_features:\n",
+ " embedding = model.embedding_layers[feature_name](keras.ops.array([0]))\n",
+ " all_feature_outputs.append(embedding)\n",
+ "\n",
+ " # Get a single output for ALL continuous features from the shared MLP.\n",
+ " num_continuous_features = len(continuous_features)\n",
+ " # Create a dummy input of zeros for the MLP\n",
+ " dummy_continuous_input = keras.ops.zeros((1, num_continuous_features))\n",
+ " processed_continuous = model.bottom_mlp(dummy_continuous_input)\n",
+ " all_feature_outputs.append(processed_continuous)\n",
+ "\n",
+ " interaction_matrix = np.zeros((num_features, num_features))\n",
+ "\n",
+ " # Iterate through each pair to calculate interaction strength.\n",
+ " for i in range(num_features):\n",
+ " for j in range(num_features):\n",
+ " interaction = keras.ops.dot(\n",
+ " all_feature_outputs[i], keras.ops.transpose(all_feature_outputs[j])\n",
+ " )\n",
+ " interaction_strength = keras.ops.convert_to_numpy(np.abs(interaction))[0][0]\n",
+ " interaction_matrix[i, j] = interaction_strength\n",
+ "\n",
+ " return interaction_matrix, all_feature_names\n",
+ "\n",
+ "\n",
+ "# Get the list of categorical feature names.\n",
+ "categorical_feature_names = (\n",
+ " MOVIELENS_CONFIG[\"categorical_int_features\"]\n",
+ " + MOVIELENS_CONFIG[\"categorical_str_features\"]\n",
+ ")\n",
+ "\n",
+ "# Calculate the interaction matrix with the corrected function.\n",
+ "interaction_matrix, feature_names = get_dot_interaction_matrix(\n",
+ " model=dot_network,\n",
+ " categorical_features=categorical_feature_names,\n",
+ " continuous_features=MOVIELENS_CONFIG[\"continuous_features\"],\n",
+ ")\n",
+ "\n",
+ "# Visualize the matrix as a heatmap.\n",
+ "print(\"\\nVisualizing the feature interaction strengths:\")\n",
+ "visualize_layer(interaction_matrix, feature_names)"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "dlrm",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/examples/keras_rs/md/dlrm.md b/examples/keras_rs/md/dlrm.md
new file mode 100644
index 0000000000..46131c0bad
--- /dev/null
+++ b/examples/keras_rs/md/dlrm.md
@@ -0,0 +1,520 @@
+# Ranking with Deep Learning Recommendation Model
+
+**Author:** [Harshith Kulkarni](https://github.com/kharshith-k)
+**Date created:** 2025/06/02
+**Last modified:** 2025/09/04
+**Description:** Rank movies with DLRM using KerasRS.
+
+
+ [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/keras_rs/ipynb/dlrm.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/keras_rs/dlrm.py)
+
+
+
+---
+## Introduction
+
+This tutorial demonstrates how to use the Deep Learning Recommendation Model (DLRM) to
+effectively learn the relationships between items and user preferences using a
+dot-product interaction mechanism. For more details, please refer to the
+[DLRM](https://arxiv.org/abs/1906.00091) paper.
+
+DLRM is designed to excel at capturing explicit, bounded-degree feature interactions and
+is particularly effective at processing both categorical and continuous (sparse/dense)
+input features. The architecture consists of three main components: dedicated input
+layers to handle diverse features (typically embedding layers for categorical features),
+a dot-product interaction layer to explicitly model feature interactions, and a
+Multi-Layer Perceptron (MLP) to capture implicit feature relationships.
+
+The dot-product interaction layer lies at the heart of DLRM, efficiently computing
+pairwise interactions between different feature embeddings. This contrasts with models
+like Deep & Cross Network (DCN), which can treat elements within a feature vector as
+independent units, potentially leading to a higher-dimensional space and increased
+computational cost. The MLP is a standard feedforward network. The DLRM is formed by
+combining the interaction layer and MLP.
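+
+To make the mechanism concrete, here is a minimal sketch (with toy vectors, not the
+layer's actual implementation) of what a dot-product interaction computes for `F`
+feature embeddings: all `F * (F - 1) / 2` distinct pairwise dot products, which are
+then concatenated and fed to the top MLP.
+
+```python
+import numpy as np
+
+# Three toy feature embeddings of dimension 2.
+features = [np.array([1.0, 2.0]), np.array([0.5, -1.0]), np.array([3.0, 0.0])]
+# All distinct pairs: F * (F - 1) / 2 = 3 interaction values for F = 3 features.
+interactions = [
+    float(np.dot(features[i], features[j]))
+    for i in range(len(features))
+    for j in range(i + 1, len(features))
+]
+print(interactions)  # [-1.5, 3.0, 1.5]
+```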
+
+The following image illustrates the DLRM architecture:
+
+
+
+
+Now that we have a foundational understanding of DLRM's architecture and key
+characteristics, let's dive into the code. We will train a DLRM on a real-world dataset
+to demonstrate its capability to learn meaningful feature interactions. Let's begin by
+setting the backend to TensorFlow and organizing our imports.
+
+
+```python
+!pip install -q keras-rs
+```
+
+
+
+
+```python
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow" # `"tensorflow"`/`"torch"`
+
+import keras
+import matplotlib.pyplot as plt
+import numpy as np
+import tensorflow as tf
+import tensorflow_datasets as tfds
+from mpl_toolkits.axes_grid1 import make_axes_locatable
+
+import keras_rs
+```
+
+Let's also define variables which will be reused throughout the example.
+
+
+```python
+MOVIELENS_CONFIG = {
+ # features
+ "continuous_features": [
+ "raw_user_age",
+ "hour_of_day_sin",
+ "hour_of_day_cos",
+ "hour_of_week_sin",
+ "hour_of_week_cos",
+ ],
+ "categorical_int_features": [
+ "user_gender",
+ ],
+ "categorical_str_features": [
+ "user_zip_code",
+ "user_occupation_text",
+ "movie_id",
+ "user_id",
+ ],
+ # model
+ "embedding_dim": 8,
+ "mlp_dim": 8,
+ "deep_net_num_units": [192, 192, 192],
+ # training
+ "learning_rate": 1e-4,
+ "num_epochs": 30,
+ "batch_size": 8192,
+}
+```
+
+Here, we define helper functions for plotting training metrics and for visualizing
+the strength of the feature interactions learned by the model, in order to better
+understand its functioning. We also define functions for compiling, training and
+evaluating a given model, and for reporting the results.
+
+
+```python
+
+def plot_training_metrics(history):
+ """Graphs all metrics tracked in the history object."""
+ plt.figure(figsize=(12, 6))
+
+ for metric_name, metric_values in history.history.items():
+ plt.plot(metric_values, label=metric_name.replace("_", " ").title())
+
+ plt.title("Metrics over Epochs")
+ plt.xlabel("Epoch")
+ plt.ylabel("Metric Value")
+ plt.legend()
+ plt.grid(True)
+
+
+def visualize_layer(matrix, features, cmap=plt.cm.Blues):
+
+ im = plt.matshow(
+ matrix, cmap=cmap, extent=[-0.5, len(features) - 0.5, len(features) - 0.5, -0.5]
+ )
+
+ ax = plt.gca()
+ divider = make_axes_locatable(plt.gca())
+ cax = divider.append_axes("right", size="5%", pad=0.05)
+ plt.colorbar(im, cax=cax)
+ cax.tick_params(labelsize=10)
+
+ # Set tick locations explicitly before setting labels
+ ax.set_xticks(np.arange(len(features)))
+ ax.set_yticks(np.arange(len(features)))
+
+ ax.set_xticklabels(features, rotation=45, fontsize=5)
+ ax.set_yticklabels(features, fontsize=5)
+
+ plt.show()
+
+
+def train_and_evaluate(
+ learning_rate,
+ epochs,
+ train_data,
+ test_data,
+ model,
+ plot_metrics=False,
+):
+ optimizer = keras.optimizers.AdamW(learning_rate=learning_rate, clipnorm=1.0)
+ loss = keras.losses.MeanSquaredError()
+ rmse = keras.metrics.RootMeanSquaredError()
+
+ model.compile(
+ optimizer=optimizer,
+ loss=loss,
+ metrics=[rmse],
+ )
+
+ history = model.fit(
+ train_data,
+ epochs=epochs,
+ verbose=1,
+ )
+ if plot_metrics:
+ plot_training_metrics(history)
+
+ results = model.evaluate(test_data, return_dict=True, verbose=1)
+ rmse_value = results["root_mean_squared_error"]
+
+ return rmse_value, model.count_params()
+
+
+def print_stats(rmse_list, num_params, model_name):
+ # Report metrics.
+ num_trials = len(rmse_list)
+ avg_rmse = np.mean(rmse_list)
+ std_rmse = np.std(rmse_list)
+
+ if num_trials == 1:
+ print(f"{model_name}: RMSE = {avg_rmse}; #params = {num_params}")
+ else:
+ print(f"{model_name}: RMSE = {avg_rmse} ± {std_rmse}; #params = {num_params}")
+
+```
+
+---
+## Real-world example
+
+Let's use the MovieLens 100K dataset. This dataset is used to train models to
+predict users' movie ratings, based on user-related features and movie-related
+features.
+
+### Preparing the dataset
+
+The dataset processing steps here are similar to what's given in the
+[basic ranking](/keras_rs/examples/basic_ranking/)
+tutorial. Let's load the dataset, and keep only the useful columns.
+
+
+```python
+ratings_ds = tfds.load("movielens/100k-ratings", split="train")
+
+
+def preprocess_features(x):
+ """Extracts and cyclically encodes timestamp features."""
+ features = {
+ "movie_id": x["movie_id"],
+ "user_id": x["user_id"],
+ "user_gender": tf.cast(x["user_gender"], dtype=tf.int32),
+ "user_zip_code": x["user_zip_code"],
+ "user_occupation_text": x["user_occupation_text"],
+ "raw_user_age": tf.cast(x["raw_user_age"], dtype=tf.float32),
+ }
+ label = tf.cast(x["user_rating"], dtype=tf.float32)
+
+ # The timestamp is in seconds since the epoch.
+ timestamp = tf.cast(x["timestamp"], dtype=tf.float32)
+
+ # Constants for time periods
+ SECONDS_IN_HOUR = 3600.0
+ HOURS_IN_DAY = 24.0
+ HOURS_IN_WEEK = 168.0
+
+ # Calculate hour of day and encode it
+ hour_of_day = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_DAY
+ features["hour_of_day_sin"] = tf.sin(2 * np.pi * hour_of_day / HOURS_IN_DAY)
+ features["hour_of_day_cos"] = tf.cos(2 * np.pi * hour_of_day / HOURS_IN_DAY)
+
+ # Calculate hour of week and encode it
+ hour_of_week = (timestamp / SECONDS_IN_HOUR) % HOURS_IN_WEEK
+ features["hour_of_week_sin"] = tf.sin(2 * np.pi * hour_of_week / HOURS_IN_WEEK)
+ features["hour_of_week_cos"] = tf.cos(2 * np.pi * hour_of_week / HOURS_IN_WEEK)
+
+ return features, label
+
+
+# Apply the new preprocessing function
+ratings_ds = ratings_ds.map(preprocess_features)
+```
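+
+The sine/cosine pair places each hour on the unit circle, so times on either side of
+midnight (e.g. 23:00 and 01:00) end up close together instead of 22 hours apart. A
+quick sanity check with two toy hours:
+
+```python
+for hour in (23.0, 1.0):
+    print(hour, np.sin(2 * np.pi * hour / 24.0), np.cos(2 * np.pi * hour / 24.0))
+```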
+
+For every categorical feature, let's get the list of unique values, i.e., the
+vocabulary, so that we can use it for the lookup and embedding layers.
+
+
+```python
+vocabularies = {}
+for feature_name in (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+):
+ vocabulary = ratings_ds.batch(10_000).map(lambda x, y: x[feature_name])
+ vocabularies[feature_name] = np.unique(np.concatenate(list(vocabulary)))
+```
+
+We use `keras.layers.StringLookup` and `keras.layers.IntegerLookup` to convert all the
+categorical features into indices, which can then be fed into embedding layers.
+
+
+```python
+lookup_layers = {}
+lookup_layers.update(
+ {
+ feature: keras.layers.IntegerLookup(vocabulary=vocabularies[feature])
+ for feature in MOVIELENS_CONFIG["categorical_int_features"]
+ }
+)
+lookup_layers.update(
+ {
+ feature: keras.layers.StringLookup(vocabulary=vocabularies[feature])
+ for feature in MOVIELENS_CONFIG["categorical_str_features"]
+ }
+)
+```
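+
+As a quick illustration of the lookup behavior (with a toy vocabulary, not the
+dataset's): index 0 is reserved for out-of-vocabulary values, which is why we will add
+1 to each vocabulary size when sizing the embedding layers later on.
+
+```python
+toy_lookup = keras.layers.StringLookup(vocabulary=["comedy", "horror"])
+print(toy_lookup(tf.constant(["horror", "western"])))  # [2, 0]; 0 is the OOV bucket
+```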
+
+Let's normalize all the continuous features, so that we can use them in the MLP layers.
+
+
+```python
+normalization_layers = {}
+for feature_name in MOVIELENS_CONFIG["continuous_features"]:
+ normalization_layers[feature_name] = keras.layers.Normalization(axis=-1)
+
+training_data_for_adaptation = ratings_ds.take(80_000).map(lambda x, y: x)
+
+for feature_name in MOVIELENS_CONFIG["continuous_features"]:
+ feature_ds = training_data_for_adaptation.map(
+ lambda x: tf.expand_dims(x[feature_name], axis=-1)
+ )
+ normalization_layers[feature_name].adapt(feature_ds)
+
+ratings_ds = ratings_ds.map(
+ lambda x, y: (
+ {
+ **{
+ feature_name: lookup_layers[feature_name](x[feature_name])
+ for feature_name in vocabularies
+ },
+ # Apply the adapted normalization layers to the continuous features.
+ **{
+ feature_name: tf.squeeze(
+ normalization_layers[feature_name](
+ tf.expand_dims(x[feature_name], axis=-1)
+ ),
+ axis=-1,
+ )
+ for feature_name in MOVIELENS_CONFIG["continuous_features"]
+ },
+ },
+ y,
+ )
+)
+```
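+
+Under the hood, `adapt()` simply estimates each feature's mean and variance from the
+data, and the layer then standardizes its inputs with those statistics. A toy check
+with made-up values:
+
+```python
+toy_norm = keras.layers.Normalization(axis=-1)
+toy_norm.adapt(np.array([[10.0], [20.0], [30.0]]))
+print(toy_norm(np.array([[20.0]])))  # ~[[0.]], since 20 is the adapted mean
+```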
+
+Let's split our data into train and test sets. We also use `cache()` and
+`prefetch()` for better performance.
+
+
+```python
+ratings_ds = ratings_ds.shuffle(100_000)
+
+train_ds = (
+ ratings_ds.take(80_000)
+ .batch(MOVIELENS_CONFIG["batch_size"])
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+test_ds = (
+ ratings_ds.skip(80_000)
+ .batch(MOVIELENS_CONFIG["batch_size"])
+ .cache()
+ .prefetch(tf.data.AUTOTUNE)
+)
+```
+
+### Building the model
+
+The model will have embedding layers, followed by DotInteraction and feedforward
+layers.
+
+
+```python
+
+class DLRM(keras.Model):
+ def __init__(
+ self,
+ dense_num_units_lst,
+ embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
+ mlp_dim=MOVIELENS_CONFIG["mlp_dim"],
+ **kwargs,
+ ):
+ super().__init__(**kwargs)
+
+ self.embedding_layers = {}
+ for feature_name in (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+ ):
+ vocab_size = len(vocabularies[feature_name]) + 1 # +1 for OOV token
+ self.embedding_layers[feature_name] = keras.layers.Embedding(
+ input_dim=vocab_size,
+ output_dim=embedding_dim,
+ )
+
+ self.bottom_mlp = keras.Sequential(
+ [
+ keras.layers.Dense(mlp_dim, activation="relu"),
+ keras.layers.Dense(embedding_dim), # Output must match embedding_dim
+ ]
+ )
+
+ self.dot_layer = keras_rs.layers.DotInteraction()
+
+ self.top_mlp = []
+ for num_units in dense_num_units_lst:
+ self.top_mlp.append(keras.layers.Dense(num_units, activation="relu"))
+
+ self.output_layer = keras.layers.Dense(1)
+
+ self.dense_num_units_lst = dense_num_units_lst
+ self.embedding_dim = embedding_dim
+
+ def call(self, inputs):
+ embeddings = []
+ for feature_name in (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+ ):
+ embedding = self.embedding_layers[feature_name](inputs[feature_name])
+ embeddings.append(embedding)
+
+ # Process all continuous features together.
+ continuous_inputs = []
+ for feature_name in MOVIELENS_CONFIG["continuous_features"]:
+ # Reshape each feature to (batch_size, 1)
+ feature = keras.ops.reshape(
+ keras.ops.cast(inputs[feature_name], dtype="float32"), (-1, 1)
+ )
+ continuous_inputs.append(feature)
+
+ # Concatenate into a single tensor: (batch_size, num_continuous_features)
+ concatenated_continuous = keras.ops.concatenate(continuous_inputs, axis=1)
+
+ # Pass through the Bottom MLP to get one combined vector.
+ processed_continuous = self.bottom_mlp(concatenated_continuous)
+
+ # Combine with categorical embeddings. Note: we add a list containing the
+ # single tensor.
+ combined_features = embeddings + [processed_continuous]
+
+ # Pass the list of features to the DotInteraction layer.
+ x = self.dot_layer(combined_features)
+
+ for layer in self.top_mlp:
+ x = layer(x)
+
+ x = self.output_layer(x)
+
+ return x
+
+
+dot_network = DLRM(
+ dense_num_units_lst=MOVIELENS_CONFIG["deep_net_num_units"],
+ embedding_dim=MOVIELENS_CONFIG["embedding_dim"],
+ mlp_dim=MOVIELENS_CONFIG["mlp_dim"],
+)
+
+rmse, dot_network_num_params = train_and_evaluate(
+ learning_rate=MOVIELENS_CONFIG["learning_rate"],
+ epochs=MOVIELENS_CONFIG["num_epochs"],
+ train_data=train_ds,
+ test_data=test_ds,
+ model=dot_network,
+ plot_metrics=True,
+)
+print_stats(
+ rmse_list=[rmse],
+ num_params=dot_network_num_params,
+ model_name="Dot Network",
+)
+```
+
+
+
+
+### Visualizing feature interactions
+
+The DotInteraction layer itself doesn't have a conventional "weight" matrix like a Dense
+layer. Instead, its function is to compute the dot product between the embedding vectors
+of your features.
+
+To visualize the strength of these interactions, we can calculate a matrix representing
+the pairwise interaction strength between all feature embeddings. A common way to do this
+is to take the dot product of the embedding matrices for each pair of features and then
+aggregate the result into a single value (like the mean of the absolute values) that
+represents the overall interaction strength.
+
+
+```python
+
+def get_dot_interaction_matrix(model, categorical_features, continuous_features):
+ # The new feature list for the plot labels
+ all_feature_names = categorical_features + ["all_continuous_features"]
+ num_features = len(all_feature_names)
+
+ # Store all feature outputs in the correct order.
+ all_feature_outputs = []
+
+    # Get outputs for categorical features from their embedding layers.
+ for feature_name in categorical_features:
+ embedding = model.embedding_layers[feature_name](keras.ops.array([0]))
+ all_feature_outputs.append(embedding)
+
+ # Get a single output for ALL continuous features from the shared MLP.
+ num_continuous_features = len(continuous_features)
+ # Create a dummy input of zeros for the MLP
+ dummy_continuous_input = keras.ops.zeros((1, num_continuous_features))
+ processed_continuous = model.bottom_mlp(dummy_continuous_input)
+ all_feature_outputs.append(processed_continuous)
+
+ interaction_matrix = np.zeros((num_features, num_features))
+
+ # Iterate through each pair to calculate interaction strength.
+ for i in range(num_features):
+ for j in range(num_features):
+ interaction = keras.ops.dot(
+ all_feature_outputs[i], keras.ops.transpose(all_feature_outputs[j])
+ )
+            interaction_strength = np.abs(keras.ops.convert_to_numpy(interaction))[0][0]
+ interaction_matrix[i, j] = interaction_strength
+
+ return interaction_matrix, all_feature_names
+
+
+# Get the list of categorical feature names.
+categorical_feature_names = (
+ MOVIELENS_CONFIG["categorical_int_features"]
+ + MOVIELENS_CONFIG["categorical_str_features"]
+)
+
+# Calculate the interaction matrix.
+interaction_matrix, feature_names = get_dot_interaction_matrix(
+ model=dot_network,
+ categorical_features=categorical_feature_names,
+ continuous_features=MOVIELENS_CONFIG["continuous_features"],
+)
+
+# Visualize the matrix as a heatmap.
+print("\nVisualizing the feature interaction strengths:")
+visualize_layer(interaction_matrix, feature_names)
+```
+
+
+
+
diff --git a/scripts/rs_master.py b/scripts/rs_master.py
index ae6b330c58..fcbaa870de 100644
--- a/scripts/rs_master.py
+++ b/scripts/rs_master.py
@@ -260,6 +260,10 @@
"path": "dcn",
"title": "Ranking with Deep and Cross Networks",
},
+ {
+ "path": "dlrm",
+ "title": "Ranking with Deep Learning Recommendation Model",
+ },
{
"path": "sas_rec",
"title": "Sequential retrieval using SASRec",
diff --git a/templates/examples/keras_rs/dlrm.md b/templates/examples/keras_rs/dlrm.md
new file mode 100644
index 0000000000..50ad0f7ef0
--- /dev/null
+++ b/templates/examples/keras_rs/dlrm.md
@@ -0,0 +1,522 @@
+# Ranking with Deep Learning Recommendation Model
+
+**Author:** [Harshith Kulkarni](https://github.com/kharshith-k)
+**Date created:** 2025/06/02
+**Last modified:** 2025/09/04
+**Description:** Rank movies with DLRM using KerasRS.
+
+
+