diff --git a/examples/timeseries/timeseries_traffic_forecasting.ipynb b/examples/timeseries/timeseries_traffic_forecasting.ipynb
new file mode 100644
index 0000000000..b63405df9a
--- /dev/null
+++ b/examples/timeseries/timeseries_traffic_forecasting.ipynb
@@ -0,0 +1,920 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Traffic forecasting using graph neural networks and LSTM\n",
+ "\n",
+ "**Author:** [Arash Khodadadi](https://www.linkedin.com/in/arash-khodadadi-08a02490/)<br>\n",
+ "**Date created:** 2021/12/28<br>\n",
+ "**Last modified:** 2023/11/22<br>\n",
+ "**Description:** This example demonstrates how to do timeseries forecasting over graphs."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Introduction\n",
+ "\n",
+ "This example shows how to forecast traffic conditions using graph neural networks and LSTM.\n",
+ "Specifically, we are interested in predicting the future values of the traffic speed given\n",
+ "a history of the traffic speed for a collection of road segments.\n",
+ "\n",
+ "One popular method to\n",
+ "solve this problem is to consider each road segment's traffic speed as a separate\n",
+ "timeseries and predict the future values of each timeseries\n",
+ "using the past values of the same timeseries.\n",
+ "\n",
+ "This method, however, ignores the dependency of the traffic speed of one road segment on\n",
+ "the neighboring segments. To be able to take into account the complex interactions between\n",
+ "the traffic speed on a collection of neighboring roads, we can define the traffic network\n",
+ "as a graph and consider the traffic speed as a signal on this graph. In this example,\n",
+ "we implement a neural network architecture which can process timeseries data over a graph.\n",
+ "We first show how to process the data and create a\n",
+ "[tf.data.Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for\n",
+ "forecasting over graphs. Then, we implement a model which uses graph convolution and\n",
+ "LSTM layers to perform forecasting over a graph.\n",
+ "\n",
+ "The data processing and the model architecture are inspired by this paper:\n",
+ "\n",
+ "Yu, Bing, Haoteng Yin, and Zhanxing Zhu. \"Spatio-temporal graph convolutional networks:\n",
+ "a deep learning framework for traffic forecasting.\" Proceedings of the 27th International\n",
+ "Joint Conference on Artificial Intelligence, 2018.\n",
+ "([github](https://github.com/VeritasYin/STGCN_IJCAI-18))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+ "\n",
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import typing\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "import tensorflow as tf\n",
+ "import keras\n",
+ "from keras import layers\n",
+ "from keras import ops"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Data preparation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Data description\n",
+ "\n",
+ "We use a real-world traffic speed dataset named `PeMSD7`. We use the version\n",
+ "collected and prepared by [Yu et al., 2018](https://arxiv.org/abs/1709.04875)\n",
+ "and available\n",
+ "[here](https://github.com/VeritasYin/STGCN_IJCAI-18/tree/master/dataset).\n",
+ "\n",
+ "The data consists of two files:\n",
+ "\n",
+ "- `PeMSD7_W_228.csv` contains the distances between 228\n",
+ "stations across District 7 of California.\n",
+ "- `PeMSD7_V_228.csv` contains traffic\n",
+ "speed collected for those stations on the weekdays of May and June of 2012.\n",
+ "\n",
+ "The full description of the dataset can be found in\n",
+ "[Yu et al., 2018](https://arxiv.org/abs/1709.04875)."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Loading data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "url = \"https://github.com/VeritasYin/STGCN_IJCAI-18/raw/master/dataset/PeMSD7_Full.zip\"\n",
+ "# 1. Download and extract normally\n",
+ "zip_path = keras.utils.get_file(origin=url, extract=True, archive_format=\"zip\")\n",
+ "\n",
+ "# 2. FIX: Use os.path.dirname to safely get the folder where it was extracted\n",
+ "data_dir = os.path.dirname(zip_path)\n",
+ "\n",
+ "# 3. Construct the paths to the inner files safely\n",
+ "route_distances = pd.read_csv(\n",
+ " os.path.join(data_dir, \"PeMSD7_W_228.csv\"), header=None\n",
+ ").to_numpy()\n",
+ "speeds_array = pd.read_csv(\n",
+ " os.path.join(data_dir, \"PeMSD7_V_228.csv\"), header=None\n",
+ ").to_numpy()\n",
+ "\n",
+ "print(f\"route_distances shape={route_distances.shape}\")\n",
+ "print(f\"speeds_array shape={speeds_array.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Sub-sampling roads\n",
+ "\n",
+ "To reduce the problem size and make the training faster, we will only\n",
+ "work with a sample of 26 roads out of the 228 roads in the dataset.\n",
+ "We have chosen the roads by starting from road 0, choosing the 5 closest\n",
+ "roads to it, and continuing this process until we get 26 roads. You can choose\n",
+ "any other subset of the roads. We chose the roads in this way to increase the likelihood\n",
+ "of having roads with correlated speed timeseries.\n",
+ "`sample_routes` contains the IDs of the selected roads."
+ ]
+ },
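The `np.ix_` indexing used in the cell above can be illustrated with a toy matrix (hypothetical numbers, unrelated to the dataset): `np.ix_` builds an open mesh so the same index list selects both rows and columns, yielding the pairwise sub-matrix for the sampled roads.

```python
import numpy as np

# A toy 4x4 "distance" matrix; np.ix_ keeps only the rows and
# columns of the sampled indices.
m = np.arange(16).reshape(4, 4)
sample = [0, 2]
sub = m[np.ix_(sample, sample)]
print(sub)  # [[ 0  2]
            #  [ 8 10]]
```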
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "sample_routes = [\n",
+ " 0,\n",
+ " 1,\n",
+ " 4,\n",
+ " 7,\n",
+ " 8,\n",
+ " 11,\n",
+ " 15,\n",
+ " 108,\n",
+ " 109,\n",
+ " 114,\n",
+ " 115,\n",
+ " 118,\n",
+ " 120,\n",
+ " 123,\n",
+ " 124,\n",
+ " 126,\n",
+ " 127,\n",
+ " 129,\n",
+ " 130,\n",
+ " 132,\n",
+ " 133,\n",
+ " 136,\n",
+ " 139,\n",
+ " 144,\n",
+ " 147,\n",
+ " 216,\n",
+ "]\n",
+ "route_distances = route_distances[np.ix_(sample_routes, sample_routes)]\n",
+ "speeds_array = speeds_array[:, sample_routes]\n",
+ "\n",
+ "print(f\"route_distances shape={route_distances.shape}\")\n",
+ "print(f\"speeds_array shape={speeds_array.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Data visualization\n",
+ "\n",
+ "Here are the timeseries of the traffic speed for two of the routes:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(18, 6))\n",
+ "plt.plot(speeds_array[:, [0, -1]])\n",
+ "plt.legend([\"route_0\", \"route_25\"])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We can also visualize the correlation between the timeseries in different routes."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "plt.figure(figsize=(8, 8))\n",
+ "plt.matshow(np.corrcoef(speeds_array.T), 0)\n",
+ "plt.xlabel(\"road number\")\n",
+ "plt.ylabel(\"road number\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Using this correlation heatmap, we can see that, for example, the speeds in\n",
+ "routes 4, 5, and 6 are highly correlated."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Splitting and normalizing data\n",
+ "\n",
+ "Next, we split the speed values array into train/validation/test sets,\n",
+ "and normalize the resulting arrays:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "train_size, val_size = 0.5, 0.2\n",
+ "\n",
+ "\n",
+ "def preprocess(data_array: np.ndarray, train_size: float, val_size: float):\n",
+ " \"\"\"Splits data into train/val/test sets and normalizes the data.\n",
+ "\n",
+ " Args:\n",
+ " data_array: ndarray of shape `(num_time_steps, num_routes)`\n",
+ " train_size: A float value between 0.0 and 1.0 that represents the proportion of the dataset\n",
+ " to include in the train split.\n",
+ " val_size: A float value between 0.0 and 1.0 that represents the proportion of the dataset\n",
+ " to include in the validation split.\n",
+ "\n",
+ " Returns:\n",
+ " `train_array`, `val_array`, `test_array`\n",
+ " \"\"\"\n",
+ "\n",
+ " num_time_steps = data_array.shape[0]\n",
+ " num_train, num_val = (\n",
+ " int(num_time_steps * train_size),\n",
+ " int(num_time_steps * val_size),\n",
+ " )\n",
+ " train_array = data_array[:num_train]\n",
+ " mean, std = train_array.mean(axis=0), train_array.std(axis=0)\n",
+ "\n",
+ " train_array = (train_array - mean) / std\n",
+ " val_array = (data_array[num_train : (num_train + num_val)] - mean) / std\n",
+ " test_array = (data_array[(num_train + num_val) :] - mean) / std\n",
+ "\n",
+ " return train_array, val_array, test_array\n",
+ "\n",
+ "\n",
+ "train_array, val_array, test_array = preprocess(speeds_array, train_size, val_size)\n",
+ "\n",
+ "print(f\"train set size: {train_array.shape}\")\n",
+ "print(f\"validation set size: {val_array.shape}\")\n",
+ "print(f\"test set size: {test_array.shape}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Creating TensorFlow Datasets\n",
+ "\n",
+ "Next, we create the datasets for our forecasting problem. The forecasting problem\n",
+ "can be stated as follows: given a sequence of the\n",
+ "road speed values at times `t+1, t+2, ..., t+T`, we want to predict the future values of\n",
+ "the road speeds for times `t+T+1, ..., t+T+h`. So for each time `t` the inputs to our\n",
+ "model are `T` vectors each of size `N` and the targets are `h` vectors each of size `N`,\n",
+ "where `N` is the number of roads."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We use the Keras built-in function\n",
+ "`keras.utils.timeseries_dataset_from_array`.\n",
+ "The function `create_tf_dataset()` below takes as input a `numpy.ndarray` and returns a\n",
+ "`tf.data.Dataset`. In this function `input_sequence_length=T` and `forecast_horizon=h`.\n",
+ "\n",
+ "The argument `multi_horizon` needs more explanation. Assume `forecast_horizon=3`.\n",
+ "If `multi_horizon=True` then the model will make a forecast for time steps\n",
+ "`t+T+1, t+T+2, t+T+3`, so each target will have shape `(3, num_routes)`. But if\n",
+ "`multi_horizon=False`, the model will make a forecast only for time step `t+T+3`, so\n",
+ "each target will have shape `(1, num_routes)`.\n",
+ "\n",
+ "You may notice that the input tensor in each batch has shape\n",
+ "`(batch_size, input_sequence_length, num_routes, 1)`. The last dimension is added to\n",
+ "make the model more general: at each time step, the input features for each road may\n",
+ "contain multiple timeseries. For instance, one might want to use temperature timeseries\n",
+ "in addition to historical values of the speed as input features. In this example,\n",
+ "however, the last dimension of the input is always 1.\n",
+ "\n",
+ "We use the last 12 values of the speed in each road to forecast the speed for 3 time\n",
+ "steps ahead:"
+ ]
+ },
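The window/target alignment described above can be checked on a toy series (assumed numbers, independent of the dataset). With `multi_horizon=False`, the target for the window starting at `t` is the single value `forecast_horizon` steps after the window's last element, which is why the code below offsets the targets by `input_sequence_length + forecast_horizon - 1`.

```python
import numpy as np

# Toy series for one route: 20 time steps, values 0..19.
series = np.arange(20)

input_sequence_length = 12
forecast_horizon = 3

t = 0
window = series[t : t + input_sequence_length]
# Single-horizon target: 3 steps after the window's last element.
target = series[t + input_sequence_length + forecast_horizon - 1]
print(window)  # [ 0  1  2  3  4  5  6  7  8  9 10 11]
print(target)  # 14
```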
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "batch_size = 64\n",
+ "input_sequence_length = 12\n",
+ "forecast_horizon = 3\n",
+ "multi_horizon = False\n",
+ "\n",
+ "\n",
+ "def create_tf_dataset(\n",
+ " data_array: np.ndarray,\n",
+ " input_sequence_length: int,\n",
+ " forecast_horizon: int,\n",
+ " batch_size: int = 128,\n",
+ " shuffle=True,\n",
+ " multi_horizon=True,\n",
+ "):\n",
+ " \"\"\"Creates a tensorflow dataset from a numpy array.\n",
+ "\n",
+ " This function creates a dataset where each element is a tuple `(inputs, targets)`.\n",
+ " `inputs` is a Tensor\n",
+ " of shape `(batch_size, input_sequence_length, num_routes, 1)` containing\n",
+ " the `input_sequence_length` past values of the timeseries for each node.\n",
+ " `targets` is a Tensor of shape `(batch_size, forecast_horizon, num_routes)`\n",
+ " containing the `forecast_horizon`\n",
+ " future values of the timeseries for each node.\n",
+ "\n",
+ " Args:\n",
+ " data_array: np.ndarray with shape `(num_time_steps, num_routes)`\n",
+ " input_sequence_length: Length of the input sequence (in number of timesteps).\n",
+ " forecast_horizon: If `multi_horizon=True`, the target will be the values of the timeseries for 1 to\n",
+ " `forecast_horizon` timesteps ahead. If `multi_horizon=False`, the target will be the value of the\n",
+ " timeseries `forecast_horizon` steps ahead (only one value).\n",
+ " batch_size: Number of timeseries samples in each batch.\n",
+ " shuffle: Whether to shuffle output samples, or instead draw them in chronological order.\n",
+ " multi_horizon: See `forecast_horizon`.\n",
+ "\n",
+ " Returns:\n",
+ " A tf.data.Dataset instance.\n",
+ " \"\"\"\n",
+ "\n",
+ " inputs = keras.utils.timeseries_dataset_from_array(\n",
+ " np.expand_dims(data_array[:-forecast_horizon], axis=-1),\n",
+ " None,\n",
+ " sequence_length=input_sequence_length,\n",
+ " shuffle=False,\n",
+ " batch_size=batch_size,\n",
+ " )\n",
+ "\n",
+ " target_offset = (\n",
+ " input_sequence_length\n",
+ " if multi_horizon\n",
+ " else input_sequence_length + forecast_horizon - 1\n",
+ " )\n",
+ " target_seq_length = forecast_horizon if multi_horizon else 1\n",
+ " targets = keras.utils.timeseries_dataset_from_array(\n",
+ " data_array[target_offset:],\n",
+ " None,\n",
+ " sequence_length=target_seq_length,\n",
+ " shuffle=False,\n",
+ " batch_size=batch_size,\n",
+ " )\n",
+ "\n",
+ " dataset = tf.data.Dataset.zip((inputs, targets))\n",
+ " if shuffle:\n",
+ " dataset = dataset.shuffle(100)\n",
+ "\n",
+ " return dataset.prefetch(16).cache()\n",
+ "\n",
+ "\n",
+ "train_dataset, val_dataset = (\n",
+ " create_tf_dataset(data_array, input_sequence_length, forecast_horizon, batch_size)\n",
+ " for data_array in [train_array, val_array]\n",
+ ")\n",
+ "\n",
+ "test_dataset = create_tf_dataset(\n",
+ " test_array,\n",
+ " input_sequence_length,\n",
+ " forecast_horizon,\n",
+ " batch_size=test_array.shape[0],\n",
+ " shuffle=False,\n",
+ " multi_horizon=multi_horizon,\n",
+ ")\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Roads Graph\n",
+ "\n",
+ "As mentioned before, we assume that the road segments form a graph.\n",
+ "The `PeMSD7` dataset contains the distances between road segments. The next step\n",
+ "is to create the graph adjacency matrix from these distances. Following\n",
+ "[Yu et al., 2018](https://arxiv.org/abs/1709.04875) (equation 10) we assume there\n",
+ "is an edge between two nodes in the graph if the distance between the corresponding roads\n",
+ "is less than a threshold."
+ ]
+ },
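The thresholding in equation 10 can be sketched on a hypothetical 3x3 distance matrix (the numbers below are made up; the rescaling by 1/10000 mirrors the `compute_adjacency_matrix()` function defined in the next cell):

```python
import numpy as np

# Hypothetical pairwise distances (same units as PeMSD7_W_228.csv).
distances = np.array(
    [[0.0, 2000.0, 8000.0],
     [2000.0, 0.0, 5000.0],
     [8000.0, 5000.0, 0.0]]
)
sigma2, epsilon = 0.1, 0.5

d = distances / 10000.0
w2 = d * d
mask = np.ones_like(w2) - np.identity(3)  # remove self-loops
# Edge iff the Gaussian kernel of the squared distance clears the threshold.
adjacency = (np.exp(-w2 / sigma2) >= epsilon) * mask
print(adjacency)  # only nodes 0 and 1 are close enough to be connected
```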
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def compute_adjacency_matrix(\n",
+ " route_distances: np.ndarray, sigma2: float, epsilon: float\n",
+ "):\n",
+ " \"\"\"Computes the adjacency matrix from distances matrix.\n",
+ "\n",
+ " It uses the formula in https://github.com/VeritasYin/STGCN_IJCAI-18#data-preprocessing to\n",
+ " compute an adjacency matrix from the distance matrix.\n",
+ " The implementation follows that paper.\n",
+ "\n",
+ " Args:\n",
+ " route_distances: np.ndarray of shape `(num_routes, num_routes)`. Entry `i,j` of this array is the\n",
+ " distance between roads `i,j`.\n",
+ " sigma2: Determines the width of the Gaussian kernel applied to the square distances matrix.\n",
+ " epsilon: A threshold specifying if there is an edge between two nodes. Specifically, `A[i,j]=1`\n",
+ " if `np.exp(-w2[i,j] / sigma2) >= epsilon` and `A[i,j]=0` otherwise, where `A` is the adjacency\n",
+ " matrix and `w2=route_distances * route_distances`\n",
+ "\n",
+ " Returns:\n",
+ " A boolean graph adjacency matrix.\n",
+ " \"\"\"\n",
+ " num_routes = route_distances.shape[0]\n",
+ " route_distances = route_distances / 10000.0\n",
+ " w2, w_mask = (\n",
+ " route_distances * route_distances,\n",
+ " np.ones([num_routes, num_routes]) - np.identity(num_routes),\n",
+ " )\n",
+ " return (np.exp(-w2 / sigma2) >= epsilon) * w_mask\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The function `compute_adjacency_matrix()` returns a boolean adjacency matrix\n",
+ "where 1 means there is an edge between two nodes. We use the following class\n",
+ "to store the information about the graph."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class GraphInfo:\n",
+ " def __init__(self, edges: typing.Tuple[list, list], num_nodes: int):\n",
+ " self.edges = edges\n",
+ " self.num_nodes = num_nodes\n",
+ "\n",
+ "\n",
+ "sigma2 = 0.1\n",
+ "epsilon = 0.5\n",
+ "adjacency_matrix = compute_adjacency_matrix(route_distances, sigma2, epsilon)\n",
+ "node_indices, neighbor_indices = np.where(adjacency_matrix == 1)\n",
+ "graph = GraphInfo(\n",
+ " edges=(node_indices.tolist(), neighbor_indices.tolist()),\n",
+ " num_nodes=adjacency_matrix.shape[0],\n",
+ ")\n",
+ "print(f\"number of nodes: {graph.num_nodes}, number of edges: {len(graph.edges[0])}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Network architecture\n",
+ "\n",
+ "Our model for forecasting over the graph consists of a graph convolution\n",
+ "layer and an LSTM layer."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### Graph convolution layer\n",
+ "\n",
+ "Our implementation of the graph convolution layer resembles the implementation\n",
+ "in [this Keras example](https://keras.io/examples/graph/gnn_citations/). Note that\n",
+ "in that example input to the layer is a 2D tensor of shape `(num_nodes,in_feat)`\n",
+ "but in our example the input to the layer is a 4D tensor of shape\n",
+ "`(num_nodes, batch_size, input_seq_length, in_feat)`. The graph convolution layer\n",
+ "performs the following steps:\n",
+ "\n",
+ "- The nodes' representations are computed in `self.compute_nodes_representation()`\n",
+ "by multiplying the input features by `self.weight`\n",
+ "- The aggregated neighbors' messages are computed in `self.compute_aggregated_messages()`\n",
+ "by first aggregating the neighbors' representations and then multiplying the results by\n",
+ "`self.weight`\n",
+ "- The final output of the layer is computed in `self.update()` by combining the nodes\n",
+ "representations and the neighbors' aggregated messages"
+ ]
+ },
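The neighbor aggregation step relies on TensorFlow's segment ops. Here is a minimal pure-numpy re-implementation of what `tf.math.unsorted_segment_mean` computes over a gathered edge list (toy graph, not the traffic graph), using the same `(targets, sources)` layout as `GraphInfo.edges`:

```python
import numpy as np

# Toy graph with 3 nodes: node 0 receives from nodes 1 and 2,
# node 1 receives from node 2, node 2 receives nothing.
edges = ([0, 0, 1], [1, 2, 2])  # (targets, sources)
features = np.array([[1.0], [2.0], [4.0]])  # one feature per node
num_nodes = 3

# Gather each edge's source representation, then average per target node.
gathered = features[edges[1]]
aggregated = np.zeros((num_nodes, 1))
counts = np.zeros(num_nodes)
for target, row in zip(edges[0], gathered):
    aggregated[target] += row
    counts[target] += 1
aggregated = aggregated / np.maximum(counts, 1)[:, None]
print(aggregated.ravel())  # [3. 4. 0.] -- node 0: mean(2, 4); node 2: no edges
```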
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class GraphConv(layers.Layer):\n",
+ " def __init__(\n",
+ " self,\n",
+ " in_feat,\n",
+ " out_feat,\n",
+ " graph_info: GraphInfo,\n",
+ " aggregation_type=\"mean\",\n",
+ " combination_type=\"concat\",\n",
+ " activation: typing.Optional[str] = None,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " super().__init__(**kwargs)\n",
+ " self.in_feat = in_feat\n",
+ " self.out_feat = out_feat\n",
+ " self.graph_info = graph_info\n",
+ " self.aggregation_type = aggregation_type\n",
+ " self.combination_type = combination_type\n",
+ " self.weight = self.add_weight(\n",
+ " initializer=keras.initializers.GlorotUniform(),\n",
+ " shape=(in_feat, out_feat),\n",
+ " dtype=\"float32\",\n",
+ " trainable=True,\n",
+ " )\n",
+ " self.activation = layers.Activation(activation)\n",
+ "\n",
+ " def aggregate(self, neighbour_representations):\n",
+ " aggregation_func = {\n",
+ " \"sum\": tf.math.unsorted_segment_sum,\n",
+ " \"mean\": tf.math.unsorted_segment_mean,\n",
+ " \"max\": tf.math.unsorted_segment_max,\n",
+ " }.get(self.aggregation_type)\n",
+ "\n",
+ " if aggregation_func:\n",
+ " return aggregation_func(\n",
+ " neighbour_representations,\n",
+ " self.graph_info.edges[0],\n",
+ " num_segments=self.graph_info.num_nodes,\n",
+ " )\n",
+ "\n",
+ " raise ValueError(f\"Invalid aggregation type: {self.aggregation_type}\")\n",
+ "\n",
+ " def compute_nodes_representation(self, features):\n",
+ " \"\"\"Computes each node's representation.\n",
+ "\n",
+ " The nodes' representations are obtained by multiplying the features tensor with\n",
+ " `self.weight`. Note that\n",
+ " `self.weight` has shape `(in_feat, out_feat)`.\n",
+ "\n",
+ " Args:\n",
+ " features: Tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`\n",
+ "\n",
+ " Returns:\n",
+ " A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`\n",
+ " \"\"\"\n",
+ " return ops.matmul(features, self.weight)\n",
+ "\n",
+ " def compute_aggregated_messages(self, features):\n",
+ " neighbour_representations = tf.gather(features, self.graph_info.edges[1])\n",
+ " aggregated_messages = self.aggregate(neighbour_representations)\n",
+ " return ops.matmul(aggregated_messages, self.weight)\n",
+ "\n",
+ " def update(self, nodes_representation, aggregated_messages):\n",
+ " if self.combination_type == \"concat\":\n",
+ " h = ops.concatenate([nodes_representation, aggregated_messages], axis=-1)\n",
+ " elif self.combination_type == \"add\":\n",
+ " h = nodes_representation + aggregated_messages\n",
+ " else:\n",
+ " raise ValueError(f\"Invalid combination type: {self.combination_type}.\")\n",
+ " return self.activation(h)\n",
+ "\n",
+ " def call(self, features):\n",
+ " \"\"\"Forward pass.\n",
+ "\n",
+ " Args:\n",
+ " features: tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`\n",
+ "\n",
+ " Returns:\n",
+ " A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`\n",
+ " \"\"\"\n",
+ " nodes_representation = self.compute_nodes_representation(features)\n",
+ " aggregated_messages = self.compute_aggregated_messages(features)\n",
+ " return self.update(nodes_representation, aggregated_messages)\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "### LSTM plus graph convolution\n",
+ "\n",
+ "By applying the graph convolution layer to the input tensor, we get another tensor\n",
+ "containing the nodes' representations over time (another 4D tensor). For each time\n",
+ "step, a node's representation is informed by the information from its neighbors.\n",
+ "\n",
+ "To make good forecasts, however, we need not only the information from the neighbors\n",
+ "but also to process that information over time. To this end, we can pass each\n",
+ "node's tensor through a recurrent layer. The `LSTMGC` layer below first applies\n",
+ "a graph convolution layer to the inputs and then passes the results through a\n",
+ "`LSTM` layer."
+ ]
+ },
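The transpose/reshape choreography inside `LSTMGC.call` can be traced with plain numpy shapes (toy sizes, not the ones used for training): the node axis is folded into the batch axis so a standard `LSTM` can consume a 3D tensor, then unfolded and moved last.

```python
import numpy as np

batch_size, input_seq_len, num_nodes, out_feat = 2, 4, 3, 5
output_seq_len = 1

x = np.zeros((batch_size, input_seq_len, num_nodes, out_feat))

# (batch, time, nodes, feat) -> (nodes, batch, time, feat)
x = np.transpose(x, [2, 0, 1, 3])

# Fold nodes into the batch axis: LSTM takes only 3D tensors.
x = x.reshape(num_nodes * batch_size, input_seq_len, out_feat)

# Stand-in for the LSTM + Dense output: one row per (node, batch) pair.
y = np.zeros((num_nodes * batch_size, output_seq_len))

# Unfold and move the node axis last.
y = y.reshape(num_nodes, batch_size, output_seq_len)
y = np.transpose(y, [1, 2, 0])
print(y.shape)  # (2, 1, 3) == (batch_size, output_seq_len, num_nodes)
```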
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "class LSTMGC(layers.Layer):\n",
+ " \"\"\"Layer comprising a convolution layer followed by LSTM and dense layers.\"\"\"\n",
+ "\n",
+ " def __init__(\n",
+ " self,\n",
+ " in_feat,\n",
+ " out_feat,\n",
+ " lstm_units: int,\n",
+ " input_seq_len: int,\n",
+ " output_seq_len: int,\n",
+ " graph_info: GraphInfo,\n",
+ " graph_conv_params: typing.Optional[dict] = None,\n",
+ " **kwargs,\n",
+ " ):\n",
+ " super().__init__(**kwargs)\n",
+ "\n",
+ " # graph conv layer\n",
+ " if graph_conv_params is None:\n",
+ " graph_conv_params = {\n",
+ " \"aggregation_type\": \"mean\",\n",
+ " \"combination_type\": \"concat\",\n",
+ " \"activation\": None,\n",
+ " }\n",
+ " self.graph_conv = GraphConv(in_feat, out_feat, graph_info, **graph_conv_params)\n",
+ "\n",
+ " self.lstm = layers.LSTM(lstm_units, activation=\"relu\")\n",
+ " self.dense = layers.Dense(output_seq_len)\n",
+ "\n",
+ " self.input_seq_len, self.output_seq_len = input_seq_len, output_seq_len\n",
+ "\n",
+ " def call(self, inputs):\n",
+ " \"\"\"Forward pass.\n",
+ "\n",
+ " Args:\n",
+ " inputs: tensor of shape `(batch_size, input_seq_len, num_nodes, in_feat)`\n",
+ "\n",
+ " Returns:\n",
+ " A tensor of shape `(batch_size, output_seq_len, num_nodes)`.\n",
+ " \"\"\"\n",
+ "\n",
+ " # convert shape to (num_nodes, batch_size, input_seq_len, in_feat)\n",
+ " inputs = ops.transpose(inputs, [2, 0, 1, 3])\n",
+ "\n",
+ " gcn_out = self.graph_conv(\n",
+ " inputs\n",
+ " ) # gcn_out has shape: (num_nodes, batch_size, input_seq_len, out_feat)\n",
+ " shape = ops.shape(gcn_out)\n",
+ " num_nodes, batch_size, input_seq_len, out_feat = (\n",
+ " shape[0],\n",
+ " shape[1],\n",
+ " shape[2],\n",
+ " shape[3],\n",
+ " )\n",
+ "\n",
+ " # LSTM takes only 3D tensors as input\n",
+ " gcn_out = ops.reshape(\n",
+ " gcn_out, (batch_size * num_nodes, input_seq_len, out_feat)\n",
+ " )\n",
+ " lstm_out = self.lstm(\n",
+ " gcn_out\n",
+ " ) # lstm_out has shape: (batch_size * num_nodes, lstm_units)\n",
+ "\n",
+ " dense_output = self.dense(\n",
+ " lstm_out\n",
+ " ) # dense_output has shape: (batch_size * num_nodes, output_seq_len)\n",
+ " output = ops.reshape(dense_output, (num_nodes, batch_size, self.output_seq_len))\n",
+ " return ops.transpose(\n",
+ " output, [1, 2, 0]\n",
+ " ) # returns Tensor of shape (batch_size, output_seq_len, num_nodes)\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Model training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "in_feat = 1\n",
+ "batch_size = 64\n",
+ "epochs = 20\n",
+ "input_sequence_length = 12\n",
+ "forecast_horizon = 3\n",
+ "multi_horizon = False\n",
+ "out_feat = 10\n",
+ "lstm_units = 64\n",
+ "graph_conv_params = {\n",
+ " \"aggregation_type\": \"mean\",\n",
+ " \"combination_type\": \"concat\",\n",
+ " \"activation\": None,\n",
+ "}\n",
+ "\n",
+ "st_gcn = LSTMGC(\n",
+ " in_feat,\n",
+ " out_feat,\n",
+ " lstm_units,\n",
+ " input_sequence_length,\n",
+ " forecast_horizon,\n",
+ " graph,\n",
+ " graph_conv_params,\n",
+ ")\n",
+ "inputs = layers.Input((input_sequence_length, graph.num_nodes, in_feat))\n",
+ "outputs = st_gcn(inputs)\n",
+ "\n",
+ "model = keras.models.Model(inputs, outputs)\n",
+ "model.compile(\n",
+ " optimizer=keras.optimizers.RMSprop(learning_rate=0.0002),\n",
+ " loss=keras.losses.MeanSquaredError(),\n",
+ ")\n",
+ "model.fit(\n",
+ " train_dataset,\n",
+ " validation_data=val_dataset,\n",
+ " epochs=epochs,\n",
+ " callbacks=[keras.callbacks.EarlyStopping(patience=10)],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Making forecasts on test set\n",
+ "\n",
+ "Now we can use the trained model to make forecasts for the test set. Below, we\n",
+ "compute the MAE of the model and compare it to the MAE of naive forecasts.\n",
+ "The naive forecasts are the last value of the speed for each node."
+ ]
+ },
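The naive baseline ("the next value equals the last observed value") and its MAE can be checked on toy arrays (made-up numbers, two samples and two roads):

```python
import numpy as np

x_last = np.array([[1.0, 2.0], [3.0, 4.0]])  # last input step per sample/road
y_true = np.array([[1.5, 2.0], [2.0, 4.5]])  # actual next values

# Mean absolute error of repeating the last observation.
naive_mae = np.abs(x_last - y_true).mean()
print(naive_mae)  # mean of |{-0.5, 0, 1, -0.5}| = 0.5
```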
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "x_test, y = next(test_dataset.as_numpy_iterator())\n",
+ "y_pred = model.predict(x_test)\n",
+ "plt.figure(figsize=(18, 6))\n",
+ "plt.plot(y[:, 0, 0])\n",
+ "plt.plot(y_pred[:, 0, 0])\n",
+ "plt.legend([\"actual\", \"forecast\"])\n",
+ "\n",
+ "naive_mae, model_mae = (\n",
+ " np.abs(x_test[:, -1, :, 0] - y[:, 0, :]).mean(),\n",
+ " np.abs(y_pred[:, 0, :] - y[:, 0, :]).mean(),\n",
+ ")\n",
+ "print(f\"naive MAE: {naive_mae}, model MAE: {model_mae}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "Of course, the goal here is to demonstrate the method,\n",
+ "not to achieve the best performance. To improve the\n",
+ "model's accuracy, all model hyperparameters should be tuned carefully. In addition,\n",
+ "several of the `LSTMGC` blocks can be stacked to increase the representation power\n",
+ "of the model."
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "timeseries_traffic_forecasting",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
diff --git a/examples/timeseries/timeseries_traffic_forecasting.py b/examples/timeseries/timeseries_traffic_forecasting.py
index ea66ee2131..f9fa5b9900 100644
--- a/examples/timeseries/timeseries_traffic_forecasting.py
+++ b/examples/timeseries/timeseries_traffic_forecasting.py
@@ -83,9 +83,13 @@
"""
url = "https://github.com/VeritasYin/STGCN_IJCAI-18/raw/master/dataset/PeMSD7_Full.zip"
-data_dir = keras.utils.get_file(origin=url, extract=True, archive_format="zip")
-data_dir = data_dir.rstrip("PeMSD7_Full.zip")
+# 1. Download and extract normally
+zip_path = keras.utils.get_file(origin=url, extract=True, archive_format="zip")
+# 2. FIX: Use os.path.dirname to safely get the folder where it was extracted
+data_dir = os.path.dirname(zip_path)
+
+# 3. Construct the paths to the inner files safely
route_distances = pd.read_csv(
os.path.join(data_dir, "PeMSD7_W_228.csv"), header=None
).to_numpy()
diff --git a/examples/timeseries/timeseries_weather_forecasting.ipynb b/examples/timeseries/timeseries_weather_forecasting.ipynb
new file mode 100644
index 0000000000..dafe27bd7e
--- /dev/null
+++ b/examples/timeseries/timeseries_weather_forecasting.ipynb
@@ -0,0 +1,574 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "# Timeseries forecasting for weather prediction\n",
+ "\n",
+ "**Authors:** [Prabhanshu Attri](https://prabhanshu.com/github), [Yashika Sharma](https://github.com/yashika51), [Kristi Takach](https://github.com/ktakattack), [Falak Shah](https://github.com/falaktheoptimist)<br>\n",
+ "**Date created:** 2020/06/23<br>\n",
+ "**Last modified:** 2023/11/22<br>\n",
+ "**Description:** This notebook demonstrates how to do timeseries forecasting using an LSTM model."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Setup"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import keras"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Climate Data Time-Series\n",
+ "\n",
+ "We will be using the Jena Climate dataset recorded by the\n",
+ "[Max Planck Institute for Biogeochemistry](https://www.bgc-jena.mpg.de/wetter/).\n",
+ "The dataset consists of 14 features such as temperature, pressure, and humidity,\n",
+ "recorded once every 10 minutes.\n",
+ "\n",
+ "**Location**: Weather Station, Max Planck Institute for Biogeochemistry\n",
+ "in Jena, Germany\n",
+ "\n",
+ "**Time-frame Considered**: Jan 10, 2009 - December 31, 2016\n",
+ "\n",
+ "\n",
+ "The table below shows the column names, their value formats, and their description.\n",
+ "\n",
+ "Index| Features |Format |Description\n",
+ "-----|---------------|-------------------|-----------------------\n",
+ "1 |Date Time |01.01.2009 00:10:00|Date-time reference\n",
+ "2 |p (mbar) |996.52 |Atmospheric pressure. The pascal is the SI derived unit of pressure; meteorological reports typically state atmospheric pressure in millibars.\n",
+ "3 |T (degC) |-8.02 |Temperature in Celsius\n",
+ "4 |Tpot (K) |265.4 |Potential temperature in Kelvin\n",
+ "5 |Tdew (degC) |-8.9 |Dew point temperature in Celsius. The dew point is a measure of the absolute amount of water in the air: the temperature at which the air can no longer hold all of its moisture and water condenses.\n",
+ "6 |rh (%) |93.3 |Relative humidity, a measure of how saturated the air is with water vapor\n",
+ "7 |VPmax (mbar) |3.33 |Saturation vapor pressure\n",
+ "8 |VPact (mbar) |3.11 |Vapor pressure\n",
+ "9 |VPdef (mbar) |0.22 |Vapor pressure deficit\n",
+ "10 |sh (g/kg) |1.94 |Specific humidity\n",
+ "11 |H2OC (mmol/mol)|3.12 |Water vapor concentration\n",
+ "12 |rho (g/m**3) |1307.75 |Air density\n",
+ "13 |wv (m/s) |1.03 |Wind speed\n",
+ "14 |max. wv (m/s) |1.75 |Maximum wind speed\n",
+ "15 |wd (deg) |152.3 |Wind direction in degrees"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "from zipfile import ZipFile\n",
+ "\n",
+ "uri = \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip\"\n",
+ "zip_path = keras.utils.get_file(origin=uri, fname=\"jena_climate_2009_2016.csv.zip\")\n",
+ "zip_file = ZipFile(zip_path)\n",
+ "\n",
+ "# FIX: Extract to the cache directory, not the current working directory\n",
+ "zip_file.extractall(path=os.path.dirname(zip_path))\n",
+ "\n",
+ "# FIX: Construct the absolute path safely (works on Windows/Linux/Mac)\n",
+ "csv_path = os.path.join(os.path.dirname(zip_path), \"jena_climate_2009_2016.csv\")\n",
+ "\n",
+ "df = pd.read_csv(csv_path)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Raw Data Visualization\n",
+ "\n",
+ "To give us a sense of the data we are working with, each feature has been plotted below.\n",
+ "This shows the distinct pattern of each feature over the time period from 2009 to 2016.\n",
+ "It also shows where anomalies are present, which will be addressed during normalization."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "titles = [\n",
+ " \"Pressure\",\n",
+ " \"Temperature\",\n",
+ " \"Temperature in Kelvin\",\n",
+ " \"Temperature (dew point)\",\n",
+ " \"Relative Humidity\",\n",
+ " \"Saturation vapor pressure\",\n",
+ " \"Vapor pressure\",\n",
+ " \"Vapor pressure deficit\",\n",
+ " \"Specific humidity\",\n",
+ " \"Water vapor concentration\",\n",
+ "    \"Air density\",\n",
+ " \"Wind speed\",\n",
+ " \"Maximum wind speed\",\n",
+ " \"Wind direction in degrees\",\n",
+ "]\n",
+ "\n",
+ "feature_keys = [\n",
+ " \"p (mbar)\",\n",
+ " \"T (degC)\",\n",
+ " \"Tpot (K)\",\n",
+ " \"Tdew (degC)\",\n",
+ " \"rh (%)\",\n",
+ " \"VPmax (mbar)\",\n",
+ " \"VPact (mbar)\",\n",
+ " \"VPdef (mbar)\",\n",
+ " \"sh (g/kg)\",\n",
+ " \"H2OC (mmol/mol)\",\n",
+ " \"rho (g/m**3)\",\n",
+ " \"wv (m/s)\",\n",
+ " \"max. wv (m/s)\",\n",
+ " \"wd (deg)\",\n",
+ "]\n",
+ "\n",
+ "colors = [\n",
+ " \"blue\",\n",
+ " \"orange\",\n",
+ " \"green\",\n",
+ " \"red\",\n",
+ " \"purple\",\n",
+ " \"brown\",\n",
+ " \"pink\",\n",
+ " \"gray\",\n",
+ " \"olive\",\n",
+ " \"cyan\",\n",
+ "]\n",
+ "\n",
+ "date_time_key = \"Date Time\"\n",
+ "\n",
+ "\n",
+ "def show_raw_visualization(data):\n",
+ " time_data = data[date_time_key]\n",
+ " fig, axes = plt.subplots(\n",
+ " nrows=7, ncols=2, figsize=(15, 20), dpi=80, facecolor=\"w\", edgecolor=\"k\"\n",
+ " )\n",
+ " for i in range(len(feature_keys)):\n",
+ " key = feature_keys[i]\n",
+ " c = colors[i % (len(colors))]\n",
+ " t_data = data[key]\n",
+ " t_data.index = time_data\n",
+ " t_data.head()\n",
+ " ax = t_data.plot(\n",
+ " ax=axes[i // 2, i % 2],\n",
+ " color=c,\n",
+ " title=\"{} - {}\".format(titles[i], key),\n",
+ " rot=25,\n",
+ " )\n",
+ " ax.legend([titles[i]])\n",
+ " plt.tight_layout()\n",
+ "\n",
+ "\n",
+ "show_raw_visualization(df)\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Data Preprocessing\n",
+ "\n",
+ "Here we are picking ~300,000 data points for training. An observation is recorded every\n",
+ "10 minutes, i.e. 6 times per hour. We will resample to one point per hour, since no\n",
+ "drastic change is expected within 60 minutes. We do this via the `sampling_rate`\n",
+ "argument of the `timeseries_dataset_from_array` utility.\n",
+ "\n",
+ "We are tracking data from the past 720 timestamps (720/6 = 120 hours). This data will\n",
+ "be used to predict the temperature 72 timestamps (72/6 = 12 hours) into the future.\n",
+ "\n",
+ "Since the features have values with varying ranges, we normalize the data before\n",
+ "training a neural network, by subtracting the mean and dividing by the standard\n",
+ "deviation of each feature (z-score normalization).\n",
+ "\n",
+ "71.5% of the data will be used to train the model, i.e. 300,693 rows. `split_fraction`\n",
+ "can be changed to alter this percentage.\n",
+ "\n",
+ "The model is shown data from the previous 5 days, i.e. 720 observations sampled once\n",
+ "per hour. The temperature 72 observations (12 hours) after the end of this window is\n",
+ "used as the label."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "split_fraction = 0.715\n",
+ "train_split = int(split_fraction * int(df.shape[0]))\n",
+ "step = 6\n",
+ "\n",
+ "past = 720\n",
+ "future = 72\n",
+ "learning_rate = 0.001\n",
+ "batch_size = 256\n",
+ "epochs = 10\n",
+ "\n",
+ "\n",
+ "def normalize(data, train_split):\n",
+ " data_mean = data[:train_split].mean(axis=0)\n",
+ " data_std = data[:train_split].std(axis=0)\n",
+ " return (data - data_mean) / data_std\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "From the correlation heatmap, we can see that a few parameters, like Relative Humidity\n",
+ "and Specific Humidity, are redundant. Hence we will use only a subset of the features."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "print(\n",
+ " \"The selected parameters are:\",\n",
+ " \", \".join([titles[i] for i in [0, 1, 5, 7, 8, 10, 11]]),\n",
+ ")\n",
+ "selected_features = [feature_keys[i] for i in [0, 1, 5, 7, 8, 10, 11]]\n",
+ "features = df[selected_features]\n",
+ "features.index = df[date_time_key]\n",
+ "features.head()\n",
+ "\n",
+ "features = normalize(features.values, train_split)\n",
+ "features = pd.DataFrame(features)\n",
+ "features.head()\n",
+ "\n",
+ "train_data = features.loc[0 : train_split - 1]\n",
+ "val_data = features.loc[train_split:]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Training dataset\n",
+ "\n",
+ "The training dataset labels start from the 792nd observation (720 + 72)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "start = past + future\n",
+ "end = start + train_split\n",
+ "\n",
+ "x_train = train_data[[i for i in range(7)]].values\n",
+ "y_train = features.iloc[start:end][[1]]\n",
+ "\n",
+ "sequence_length = int(past / step)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "The `timeseries_dataset_from_array` function takes in a sequence of data points gathered at\n",
+ "equal intervals, along with time series parameters such as the length of the\n",
+ "sequences/windows and the spacing between two sequences/windows, to produce batches of\n",
+ "sub-timeseries inputs and targets sampled from the main timeseries."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "dataset_train = keras.preprocessing.timeseries_dataset_from_array(\n",
+ " x_train,\n",
+ " y_train,\n",
+ " sequence_length=sequence_length,\n",
+ " sampling_rate=step,\n",
+ " batch_size=batch_size,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Validation dataset\n",
+ "\n",
+ "The validation dataset must not contain the last 792 rows, as we won't have label data\n",
+ "for those records; hence 792 must be subtracted from the end of the data.\n",
+ "\n",
+ "The validation labels must start 792 rows after `train_split`, hence we must add\n",
+ "`past + future` (792) to `label_start`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "x_end = len(val_data) - past - future\n",
+ "\n",
+ "label_start = train_split + past + future\n",
+ "\n",
+ "x_val = val_data.iloc[:x_end][[i for i in range(7)]].values\n",
+ "y_val = features.iloc[label_start:][[1]]\n",
+ "\n",
+ "dataset_val = keras.preprocessing.timeseries_dataset_from_array(\n",
+ " x_val,\n",
+ " y_val,\n",
+ " sequence_length=sequence_length,\n",
+ " sampling_rate=step,\n",
+ " batch_size=batch_size,\n",
+ ")\n",
+ "\n",
+ "\n",
+ "for batch in dataset_train.take(1):\n",
+ " inputs, targets = batch\n",
+ "\n",
+ "print(\"Input shape:\", inputs.numpy().shape)\n",
+ "print(\"Target shape:\", targets.numpy().shape)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "inputs = keras.layers.Input(shape=(inputs.shape[1], inputs.shape[2]))\n",
+ "lstm_out = keras.layers.LSTM(32)(inputs)\n",
+ "outputs = keras.layers.Dense(1)(lstm_out)\n",
+ "\n",
+ "model = keras.Model(inputs=inputs, outputs=outputs)\n",
+ "model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate), loss=\"mse\")\n",
+ "model.summary()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We'll use the `ModelCheckpoint` callback to regularly save checkpoints, and\n",
+ "the `EarlyStopping` callback to interrupt training when the validation loss\n",
+ "is no longer improving."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "path_checkpoint = \"model_checkpoint.weights.h5\"\n",
+ "es_callback = keras.callbacks.EarlyStopping(monitor=\"val_loss\", min_delta=0, patience=5)\n",
+ "\n",
+ "modelckpt_callback = keras.callbacks.ModelCheckpoint(\n",
+ " monitor=\"val_loss\",\n",
+ " filepath=path_checkpoint,\n",
+ " verbose=1,\n",
+ " save_weights_only=True,\n",
+ " save_best_only=True,\n",
+ ")\n",
+ "\n",
+ "history = model.fit(\n",
+ " dataset_train,\n",
+ " epochs=epochs,\n",
+ " validation_data=dataset_val,\n",
+ " callbacks=[es_callback, modelckpt_callback],\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "We can visualize the loss with the function below. After a certain point, the loss\n",
+ "stops decreasing."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def visualize_loss(history, title):\n",
+ " loss = history.history[\"loss\"]\n",
+ " val_loss = history.history[\"val_loss\"]\n",
+ " epochs = range(len(loss))\n",
+ " plt.figure()\n",
+ " plt.plot(epochs, loss, \"b\", label=\"Training loss\")\n",
+ " plt.plot(epochs, val_loss, \"r\", label=\"Validation loss\")\n",
+ " plt.title(title)\n",
+ " plt.xlabel(\"Epochs\")\n",
+ " plt.ylabel(\"Loss\")\n",
+ " plt.legend()\n",
+ " plt.show()\n",
+ "\n",
+ "\n",
+ "visualize_loss(history, \"Training and Validation Loss\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "colab_type": "text"
+ },
+ "source": [
+ "## Prediction\n",
+ "\n",
+ "The trained model above is now able to make predictions for 5 sets of values from\n",
+ "the validation set."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 0,
+ "metadata": {
+ "colab_type": "code"
+ },
+ "outputs": [],
+ "source": [
+ "\n",
+ "def show_plot(plot_data, delta, title):\n",
+ " labels = [\"History\", \"True Future\", \"Model Prediction\"]\n",
+ " marker = [\".-\", \"rx\", \"go\"]\n",
+ " time_steps = list(range(-(plot_data[0].shape[0]), 0))\n",
+ " if delta:\n",
+ " future = delta\n",
+ " else:\n",
+ " future = 0\n",
+ "\n",
+ " plt.title(title)\n",
+ " for i, val in enumerate(plot_data):\n",
+ " if i:\n",
+ " plt.plot(future, plot_data[i], marker[i], markersize=10, label=labels[i])\n",
+ " else:\n",
+ " plt.plot(time_steps, plot_data[i].flatten(), marker[i], label=labels[i])\n",
+ " plt.legend()\n",
+ " plt.xlim([time_steps[0], (future + 5) * 2])\n",
+ " plt.xlabel(\"Time-Step\")\n",
+ " plt.show()\n",
+ " return\n",
+ "\n",
+ "\n",
+ "for x, y in dataset_val.take(5):\n",
+ " show_plot(\n",
+ " [x[0][:, 1].numpy(), y[0].numpy(), model.predict(x)[0]],\n",
+ " 12,\n",
+ " \"Single Step Prediction\",\n",
+ " )"
+ ]
+ }
+ ],
+ "metadata": {
+ "accelerator": "GPU",
+ "colab": {
+ "collapsed_sections": [],
+ "name": "timeseries_weather_forecasting",
+ "private_outputs": false,
+ "provenance": [],
+ "toc_visible": true
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
\ No newline at end of file
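The window arithmetic in the notebook above (`past=720`, `future=72`, `step=6`) can be sanity-checked outside of Keras. Below is a minimal NumPy sketch of the z-score normalization and the input/label alignment; the random array is a hypothetical stand-in for the Jena CSV, not the real data:

```python
import numpy as np

# Hypothetical stand-in for the Jena data: 1000 rows, 7 selected features.
rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 7))

split_fraction = 0.715
train_split = int(split_fraction * data.shape[0])

# Normalize with statistics computed on the training split only,
# so no information from the validation rows leaks into training.
mean = data[:train_split].mean(axis=0)
std = data[:train_split].std(axis=0)
features = (data - mean) / std

past, future, step = 720, 72, 6
sequence_length = past // step  # 120 hourly samples per input window

# A window starting at row i spans rows [i, i + past), sampled every
# `step` rows; its label is the value `past + future` rows after i.
i = 0
window = features[i : i + past : step]
label = features[i + past + future, 1]  # column 1 plays the role of T (degC)

print(window.shape)  # (120, 7)
```

This mirrors what `timeseries_dataset_from_array` does in batches: each input sequence has `past / step` rows, and the matching label sits `past + future` raw rows after the window start, which is why the training labels begin at observation 792.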
diff --git a/examples/timeseries/timeseries_weather_forecasting.py b/examples/timeseries/timeseries_weather_forecasting.py
index 0b30771948..824395cc90 100644
--- a/examples/timeseries/timeseries_weather_forecasting.py
+++ b/examples/timeseries/timeseries_weather_forecasting.py
@@ -10,7 +10,7 @@
"""
## Setup
"""
-
+import os
import pandas as pd
import matplotlib.pyplot as plt
import keras
@@ -50,13 +50,18 @@
15 |wd (deg) |152.3 |Wind direction in degrees
"""
+
from zipfile import ZipFile
uri = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip"
zip_path = keras.utils.get_file(origin=uri, fname="jena_climate_2009_2016.csv.zip")
zip_file = ZipFile(zip_path)
-zip_file.extractall()
-csv_path = "jena_climate_2009_2016.csv"
+
+# FIX: Extract to the cache directory, not the current working directory
+zip_file.extractall(path=os.path.dirname(zip_path))
+
+# FIX: Construct the absolute path safely (works on Windows/Linux/Mac)
+csv_path = os.path.join(os.path.dirname(zip_path), "jena_climate_2009_2016.csv")
df = pd.read_csv(csv_path)
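The extraction fix above can be exercised on its own: extracting next to the downloaded archive, rather than into the current working directory, keeps the CSV path deterministic. A minimal stdlib-only sketch, where the throwaway archive and its one-row CSV are placeholders for the real download:

```python
import os
import tempfile
from zipfile import ZipFile

# Build a throwaway archive standing in for the downloaded dataset.
cache_dir = tempfile.mkdtemp()
zip_path = os.path.join(cache_dir, "jena_climate_2009_2016.csv.zip")
with ZipFile(zip_path, "w") as zf:
    zf.writestr(
        "jena_climate_2009_2016.csv",
        "Date Time,T (degC)\n01.01.2009 00:10:00,-8.02\n",
    )

# Extract into the directory that already holds the archive, then build
# the CSV path from that same directory -- portable across OSes.
with ZipFile(zip_path) as zf:
    zf.extractall(path=os.path.dirname(zip_path))
csv_path = os.path.join(os.path.dirname(zip_path), "jena_climate_2009_2016.csv")

print(os.path.exists(csv_path))  # True
```

The same two lines (`extractall(path=...)` plus `os.path.join`) are what the diff applies to the real `keras.utils.get_file` download path.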
diff --git a/guides/define_custom_kernel.py b/guides/define_custom_kernel.py
deleted file mode 100644
index 4aba1c5d3a..0000000000
--- a/guides/define_custom_kernel.py
+++ /dev/null
@@ -1,400 +0,0 @@
-"""
-Title: Define a Custom TPU/GPU Kernel
-Author: [jeffcarp](https://www.jeffcarp.com/)
-Date created: 2025/12/18
-Last modified: 2025/12/18
-Description: Write high-performance custom Keras layers for TPUs and GPUs.
-Accelerator: TPU
-"""
-
-"""
-# How to Write a Custom TPU or GPU Kernel in Keras
-
-Keras has [many pre-made layers to choose from](/api/layers/), and the
-ability to easily [create your
-own](/guides/making_new_layers_and_models_via_subclassing/) if you can't
-find the exact one you need. However, if you have a need for speed, or otherwise
-need to customize the exact behavior of your model at the hardware level, you
-may want to look into writing a custom kernel. A good way to know if you need a
-custom kernel is to look at the profile of your model and see if there are any
-idle gaps caused by computation or memory transfer bottlenecks (see the
-[TensorBoard callback](/api/callbacks/tensorboard/) for how to get a profile).
-
-This guide will explore how to write a custom kernel and add it to your
-Keras model. We will utilize **Pallas**, a library that lets you write
-kernels in Python that can run on both TPU or GPU, where they're lowered
-to Mosaic or Triton, respectively. You can learn more in the [Pallas
-docs](https://docs.jax.dev/en/latest/pallas/index.html).
-
-**Compatibility note:** Pallas is only available when using the JAX backend on
-certain hardware:
-
-- TPU v4 and above
-- NVIDIA Ampere GPUs (compute capability 8.0) and above
-
-If you're running in Colab, the v5e-1 in the free tier supports running this
-guide.
-
-First, make sure you're running the latest version of `libtpu`:
-"""
-
-"""shell
-pip install --upgrade -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-"""
-
-from functools import partial
-import os
-import time
-
-
-os.environ["KERAS_BACKEND"] = "jax"
-
-import jax
-from jax.experimental import pallas as pl
-import jax.numpy as jnp
-import keras
-
-
-"""
-# Simple Example
-
-Let's start with the example from the [Pallas
-quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html): a simple
-kernel to add two vectors together.
-"""
-
-
-def add_vectors_kernel(x_ref, y_ref, o_ref):
- """Pallas kernel for adding two vectors together."""
- x, y = x_ref[...], y_ref[...]
- o_ref[...] = x + y
-
-
-"""
-Now jit-compile the Pallas function into a function that can be used by JAX.
-"""
-
-
-@jax.jit
-def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
- return pl.pallas_call(
- add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)
- )(x, y)
-
-
-add_vectors(jnp.arange(8), jnp.arange(8))
-
-"""
-Now we can embed the jitted `add_vectors` function containing the Pallas kernel into a
-Keras layer, just by calling it there.
-"""
-
-
-class PallasAddLayer(keras.Layer):
- def call(self, x, y):
- # Reuse the JIT-compiled Pallas function
- return add_vectors(x, y)
-
-
-layer = PallasAddLayer()
-
-x_data = jnp.arange(8, dtype=jnp.int32)
-y_data = jnp.arange(8, dtype=jnp.int32)
-
-layer(x_data, y_data)
-
-"""
-That's how to integrate a Pallas kernel into a Keras layer! Now for a more
-in-depth example.
-"""
-
-"""
-# Writing a Fused Linear Activation Layer
-
-Some common reasons you might want to write a custom kernel is to take advantage of
-**fusion** and **tiling**.
-
-**Operator fusion** is the process of combining two or more ops into one "fused" op, for
-example instead of calling `keras.ops.matmul` then `keras.ops.relu` sequentially, we
-could write a custom op that combines both into one more efficient operator.
-XLA already [does operator fusion when possible](https://arxiv.org/abs/2301.13062) for
-certain use cases, but to squeeze even more performance out of the TPU or GPU, we need to
-write a custom op to specify the fusion exactly.
-
-**Tiling** is the ability to control how blocks of memory are loaded from the TPU or
-GPU's larger High Bandwidth Memory (HBM) to the smaller, extremely fast on-chip
-memory (called VMEM on TPU or SMEM on GPU) that the accelerator's computation
-units (e.g., TPU's Matrix Units or a GPU's Tensor Cores) use directly. This is
-critical for improving the performance of large matrix multiplications, for
-example those in the MLP layer at the end of Transformer blocks.
-
-In Pallas, tiling is controlled by the `BlockSpec`. Learn more in the
-[Pallas BlockSpec guide
-here](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#blockspec-a-k-a-how-to-chunk-up-inputs).
-
-In this section, we'll take two operations that commonly appear together: a
-matrix multiplication (like in a `Dense` layer) and a ReLU activation. We will
-write a new op that fuses them together for better performance.
-
-## Original Unoptimized Implementation
-"""
-
-
-class StandardDenseReLU(keras.layers.Layer):
- """Standard Matmul and ReLU implementation using keras.ops."""
-
- def __init__(self, units, **kwargs):
- super().__init__(**kwargs)
- self.units = units
-
- def build(self, input_shape):
- self.w = self.add_weight(
- shape=(input_shape[-1], self.units),
- initializer="glorot_uniform",
- trainable=True,
- )
-
- def call(self, inputs):
- # The standard implementation performs two separate operations.
- # Each one involves expensive data transfer with the main device memory (HBM).
- # 1. Matmul: inputs (HBM) -> compute -> intermediate (HBM)
- y = keras.ops.matmul(inputs, self.w)
- # 2. ReLU: intermediate (HBM) -> compute -> output (HBM)
- return keras.ops.relu(y)
-
-
-"""
-## 1. Define the Fused Kernel
-
-First we create an inner kernel function that defines the fused computation that
-combines both matmul (`pl.dot`) and activation (`jnp.maximum`).
-"""
-
-import jax.numpy as jnp
-from jax.experimental import pallas as pl
-
-
-def matmul_relu_kernel(a_ref, b_ref, c_ref):
- """Pallas kernel for fused matmul + ReLU."""
- # Perform the matrix multiplication on the local tile
- # pl.dot leverages the hardware's Matrix Unit (MXU)
- acc = pl.dot(a_ref[...], b_ref[...])
-
- # Fusion happens here: apply activation while data is in VMEM
- result = jnp.maximum(acc, 0)
-
- # Write the final result to the output reference
- c_ref[...] = result
-
-
-"""
-## 2. Specify the Tiling (BlockSpec)
-
-Since the input matrices are usually too large to fit into VMEM, Pallas needs ot
-know how to "slice" them for loading from HBM to VMEM.
-
-We define this using `BlockSpec` - this tells the hardware: "Take a 128-row
-chunk of Matrix A and a 128-column chunk of Matrix B to produce a 128x128 tile
-of Matrix C."
-"""
-
-
-@jax.jit
-def fused_matmul(a, b):
- m, k = a.shape
- _, n = b.shape
-
- # Define tile sizes
- tile_m, tile_n = 128, 128
- assert (
- m % tile_m == 0 and n % tile_n == 0
- ), "Inputs must be multiples of 128 for this demo"
-
- return pl.pallas_call(
- matmul_relu_kernel,
- # Map output indices to input blocks
- out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),
- in_specs=[
- # For each output tile, we take a slice of A of shape (tile_m, k)
- pl.BlockSpec(
- index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)
- ), # Matrix A
- # For each output tile, we take a slice of B of shape (k, tile_n)
- pl.BlockSpec(
- index_map=lambda i, j: (0, j), block_shape=(k, tile_n)
- ), # Matrix B
- ],
- out_specs=pl.BlockSpec(
- index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)
- ), # Matrix C
- grid=(m // tile_m, n // tile_n),
- )(a, b)
-
-
-fused_matmul(jnp.ones((256, 256)), jnp.ones((256, 256)))
-
-"""
-## 3. Integrating into a Keras Layer
-
-Now for the final step, call the jit-compiled `fused_matmul` kernel from a
-`keras.Layer`.
-"""
-
-
-class FusedDense(keras.layers.Layer):
- """Custom Keras layer that applies the fused Dense and ReLU op."""
-
- def __init__(self, units, **kwargs):
- super().__init__(**kwargs)
- self.units = units
-
- def build(self, input_shape):
- self.w = self.add_weight(
- shape=(input_shape[-1], self.units), initializer="glorot_uniform"
- )
-
- def call(self, inputs):
- # Dispatch to our Pallas kernel
- return fused_matmul(inputs, self.w.value)
-
-
-FusedDense(256)(jnp.ones((256, 256)))
-
-"""
-## 4. Benchmarking the Speedup
-"""
-
-# 1. Setup Data
-N = 8192 # Large enough to be memory bound
-input_data = jnp.ones((N, N), dtype="float32")
-
-# Initialize layers
-standard_layer = StandardDenseReLU(units=N)
-pallas_layer = FusedDense(units=N)
-
-# Build layers by calling them once
-standard_layer(input_data)
-pallas_layer(input_data)
-
-
-def benchmark(layer, x, name, iterations=100):
- # Warm up to ensure JIT compilation is finished
- for _ in range(10):
- layer(x).block_until_ready()
-
- start_time = time.perf_counter()
- for _ in range(iterations):
- layer(x).block_until_ready()
- end_time = time.perf_counter()
-
- avg_time = (end_time - start_time) / iterations * 1000 # convert to ms
- print(f"{name} Average Latency: {avg_time:.3f} ms")
-
-
-# 2. Run Comparison
-print(f"Benchmarking Matrix Size: {N}x{N}\n" + "-" * 30)
-benchmark(standard_layer, input_data, "Standard Keras (Matmul + ReLU)")
-benchmark(pallas_layer, input_data, "Pallas Fused (Matmul + ReLU)")
-
-
-"""
-### Why this Works
-
-**Memory Bandwidth Efficiency:** By fusing the matrix multiplication and
-activation, we perform the ReLU computation while data is still in the chip's
-fast VMEM. This drastically reduces expensive read/write roundtrips to HBM.
-
-**Automatic Parallelization:** Pallas handles the "grid" execution, meaning
-it automatically parallelizes your defined tiles across the available hardware
-cores (whether TPU MXUs or GPU Tensor Cores).
-
-**Drop-in Inference Speed:** This `FusedDense` kernel can be integrated into any
-Keras model, giving an example of improving serving/inference performance with
-minimal code changes.
-"""
-
-"""
-## 5. Enabling Training
-
-In order for a Pallas kernel to be trainable, you must also supply
-a second kernel to define the custom backward pass, since JAX can't
-[AutoGrad](https://docs.jax.dev/en/latest/automatic-differentiation.html)
-through Pallas kernels. Without it, you might see an error like this:
-
-```
-model = keras.Sequential([FusedDense(256)])
-model.compile(optimizer="adam", loss="mse")
-model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)))
->>> Linearization failed to produce known values for all output primals. This is
-typically caused by attempting to differentiate a function uses an operation
-that does not support reverse-mode autodiff.
-```
-
-To extend our fused matmul example above:
-"""
-
-
-# 1. Define the wrapper with `custom_vjp` using our original `fused_matmul`.
-@jax.custom_vjp
-def fused_matmul_trainable(x, w):
- return fused_matmul(x, w)
-
-
-# 2. Define the Forward Pass
-# It must return the output AND "residuals" (data needed for the backward pass)
-def fused_matmul_fwd(x, w):
- y = fused_matmul_trainable(x, w)
- # We save inputs x, w and output y for the backward calculation
- return y, (x, w, y)
-
-
-# 3. Define the Backward Pass
-# JAX gives us the residuals and the incoming gradient (g)
-def fused_matmul_bwd(residuals, g):
- x, w, y = residuals
-
- # Calculate the gradient of ReLU: 1 if y > 0, else 0
- # g is the gradient flowing back from the next layer
- grad_relu = g * (y > 0)
-
- # Standard backprop math for matmul:
- # grad_x = grad_relu @ w.T
- grad_x = jnp.dot(grad_relu, w.T)
-
- # grad_w = x.T @ grad_relu
- grad_w = jnp.dot(x.T, grad_relu)
-
- return grad_x, grad_w
-
-
-# 4. Register the forward and backward functions
-fused_matmul_trainable.defvjp(fused_matmul_fwd, fused_matmul_bwd)
-
-
-class FusedDenseTrainable(FusedDense):
- """Updated layer that contains Pallas forward and backward pass."""
-
- def call(self, inputs):
- # Dispatch to our trainable Pallas kernel
- return fused_matmul_trainable(inputs, self.w.value)
-
-
-# Demonstrate trainability on dummy data
-model = keras.Sequential([FusedDenseTrainable(256)])
-model.compile(optimizer="adam", loss="mse")
-model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)), batch_size=128)
-
-"""
-# Followups
-
-In this guide we covered how to define a simple custom Pallas kernel performing vector
-addition to include in a Keras model. Then we followed up with a more in-depth
-example of a fused matmul + activation kernel that you might use in a real-world
-model to improve performance.
-
-Please refer to the [Pallas
-docs](https://docs.jax.dev/en/latest/pallas/index.html#) for further
-documentation on writing custom kernels. Additionally to explore more examples
-of Pallas kernels, including FlashAttention and MoE layers, check out the
-[Tokamax](https://github.com/openxla/tokamax) library.
-"""
diff --git a/guides/ipynb/define_custom_kernel.ipynb b/guides/ipynb/define_custom_kernel.ipynb
deleted file mode 100644
index 68d9bc9f84..0000000000
--- a/guides/ipynb/define_custom_kernel.ipynb
+++ /dev/null
@@ -1,601 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "# Define a Custom TPU/GPU Kernel\n",
- "\n",
- "**Author:** [jeffcarp](https://www.jeffcarp.com/)
\n",
- "**Date created:** 2025/12/18
\n",
- "**Last modified:** 2025/12/18
\n",
- "**Description:** Write high-performance custom Keras layers for TPUs and GPUs."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "# How to Write a Custom TPU or GPU Kernel in Keras\n",
- "\n",
- "Keras has [many pre-made layers to choose from](/api/layers/), and the\n",
- "ability to easily [create your\n",
- "own](/guides/making_new_layers_and_models_via_subclassing/) if you can't\n",
- "find the exact one you need. However, if you have a need for speed, or otherwise\n",
- "need to customize the exact behavior of your model at the hardware level, you\n",
- "may want to look into writing a custom kernel. A good way to know if you need a\n",
- "custom kernel is to look at the profile of your model and see if there are any\n",
- "idle gaps caused by computation or memory transfer bottlenecks (see the\n",
- "[TensorBoard callback](/api/callbacks/tensorboard/) for how to get a profile).\n",
- "\n",
- "This guide will explore how to write a custom kernel and add it to your\n",
- "Keras model. We will utilize **Pallas**, a library that lets you write\n",
- "kernels in Python that can run on both TPU or GPU, where they're lowered\n",
- "to Mosaic or Triton, respectively. You can learn more in the [Pallas\n",
- "docs](https://docs.jax.dev/en/latest/pallas/index.html).\n",
- "\n",
- "**Compatibility note:** Pallas is only available when using the JAX backend on\n",
- "certain hardware:\n",
- "\n",
- "- TPU v4 and above\n",
- "- NVIDIA Ampere GPUs (compute capability 8.0) and above\n",
- "\n",
- "If you're running in Colab, the v5e-1 in the free tier supports running this\n",
- "guide.\n",
- "\n",
- "First, make sure you're running the latest version of `libtpu`:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "!pip install --upgrade -q \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "from functools import partial\n",
- "import os\n",
- "import time\n",
- "\n",
- "\n",
- "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n",
- "\n",
- "import jax\n",
- "from jax.experimental import pallas as pl\n",
- "import jax.numpy as jnp\n",
- "import keras\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "# Simple Example\n",
- "\n",
- "Let's start with the example from the [Pallas\n",
- "quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html): a simple\n",
- "kernel to add two vectors together."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "def add_vectors_kernel(x_ref, y_ref, o_ref):\n",
- " \"\"\"Pallas kernel for adding two vectors together.\"\"\"\n",
- " x, y = x_ref[...], y_ref[...]\n",
- " o_ref[...] = x + y\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "Now jit-compile the Pallas function into a function that can be used by JAX."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "@jax.jit\n",
- "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n",
- " return pl.pallas_call(\n",
- " add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
- " )(x, y)\n",
- "\n",
- "\n",
- "add_vectors(jnp.arange(8), jnp.arange(8))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "Now we can embed the jitted `add_vectors` function containing the Pallas kernel into a\n",
- "Keras layer, just by calling it there."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "class PallasAddLayer(keras.Layer):\n",
- " def call(self, x, y):\n",
- " # Reuse the JIT-compiled Pallas function\n",
- " return add_vectors(x, y)\n",
- "\n",
- "\n",
- "layer = PallasAddLayer()\n",
- "\n",
- "x_data = jnp.arange(8, dtype=jnp.int32)\n",
- "y_data = jnp.arange(8, dtype=jnp.int32)\n",
- "\n",
- "layer(x_data, y_data)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "That's how to integrate a Pallas kernel into a Keras layer! Now for a more\n",
- "in-depth example."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "# Writing a Fused Linear Activation Layer\n",
- "\n",
- "Some common reasons you might want to write a custom kernel is to take advantage of\n",
- "**fusion** and **tiling**.\n",
- "\n",
- "**Operator fusion** is the process of combining two or more ops into one \"fused\" op, for\n",
- "example instead of calling `keras.ops.matmul` then `keras.ops.relu` sequentially, we\n",
- "could write a custom op that combines both into one more efficient operator.\n",
- "XLA already [does operator fusion when possible](https://arxiv.org/abs/2301.13062) for\n",
- "certain use cases, but to squeeze even more performance out of the TPU or GPU, we need to\n",
- "write a custom op to specify the fusion exactly.\n",
- "\n",
- "**Tiling** is the ability to control how blocks of memory are loaded from the TPU or\n",
- "GPU's larger High Bandwidth Memory (HBM) to the smaller, extremely fast on-chip\n",
- "memory (called VMEM on TPU or SMEM on GPU) that the accelerator's computation\n",
- "units (e.g., TPU's Matrix Units or a GPU's Tensor Cores) use directly. This is\n",
- "critical for improving the performance of large matrix multiplications, for\n",
- "example those in the MLP layer at the end of Transformer blocks.\n",
- "\n",
- "In Pallas, tiling is controlled by the `BlockSpec`. Learn more in the\n",
- "[Pallas BlockSpec guide\n",
- "here](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#blockspec-a-k-a-how-to-chunk-up-inputs).\n",
- "\n",
- "In this section, we'll take two operations that commonly appear together: a\n",
- "matrix multiplication (like in a `Dense` layer) and a ReLU activation. We will\n",
- "write a new op that fuses them together for better performance.\n",
- "\n",
- "## Original Unoptimized Implementation"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "class StandardDenseReLU(keras.layers.Layer):\n",
- " \"\"\"Standard Matmul and ReLU implementation using keras.ops.\"\"\"\n",
- "\n",
- " def __init__(self, units, **kwargs):\n",
- " super().__init__(**kwargs)\n",
- " self.units = units\n",
- "\n",
- " def build(self, input_shape):\n",
- " self.w = self.add_weight(\n",
- " shape=(input_shape[-1], self.units),\n",
- " initializer=\"glorot_uniform\",\n",
- " trainable=True,\n",
- " )\n",
- "\n",
- " def call(self, inputs):\n",
- " # The standard implementation performs two separate operations.\n",
- " # Each one involves expensive data transfer with the main device memory (HBM).\n",
- " # 1. Matmul: inputs (HBM) -> compute -> intermediate (HBM)\n",
- " y = keras.ops.matmul(inputs, self.w)\n",
- " # 2. ReLU: intermediate (HBM) -> compute -> output (HBM)\n",
- " return keras.ops.relu(y)\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## 1. Define the Fused Kernel\n",
- "\n",
- "First we create an inner kernel function that defines the fused computation that\n",
- "combines both matmul (`pl.dot`) and activation (`jnp.maximum`)."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "import jax.numpy as jnp\n",
- "from jax.experimental import pallas as pl\n",
- "\n",
- "\n",
- "def matmul_relu_kernel(a_ref, b_ref, c_ref):\n",
- " \"\"\"Pallas kernel for fused matmul + ReLU.\"\"\"\n",
- " # Perform the matrix multiplication on the local tile\n",
- " # pl.dot leverages the hardware's Matrix Unit (MXU)\n",
- " acc = pl.dot(a_ref[...], b_ref[...])\n",
- "\n",
- " # Fusion happens here: apply activation while data is in VMEM\n",
- " result = jnp.maximum(acc, 0)\n",
- "\n",
- " # Write the final result to the output reference\n",
- " c_ref[...] = result\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## 2. Specify the Tiling (BlockSpec)\n",
- "\n",
- "Since the input matrices are usually too large to fit into VMEM, Pallas needs ot\n",
- "know how to \"slice\" them for loading from HBM to VMEM.\n",
- "\n",
- "We define this using `BlockSpec` - this tells the hardware: \"Take a 128-row\n",
- "chunk of Matrix A and a 128-column chunk of Matrix B to produce a 128x128 tile\n",
- "of Matrix C.\""
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "@jax.jit\n",
- "def fused_matmul(a, b):\n",
- " m, k = a.shape\n",
- " _, n = b.shape\n",
- "\n",
- " # Define tile sizes\n",
- " tile_m, tile_n = 128, 128\n",
- " assert (\n",
- " m % tile_m == 0 and n % tile_n == 0\n",
- " ), \"Inputs must be multiples of 128 for this demo\"\n",
- "\n",
- " return pl.pallas_call(\n",
- " matmul_relu_kernel,\n",
- " # Map output indices to input blocks\n",
- " out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),\n",
- " in_specs=[\n",
- " # For each output tile, we take a slice of A of shape (tile_m, k)\n",
- " pl.BlockSpec(\n",
- " index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)\n",
- " ), # Matrix A\n",
- " # For each output tile, we take a slice of B of shape (k, tile_n)\n",
- " pl.BlockSpec(\n",
- " index_map=lambda i, j: (0, j), block_shape=(k, tile_n)\n",
- " ), # Matrix B\n",
- " ],\n",
- " out_specs=pl.BlockSpec(\n",
- " index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)\n",
- " ), # Matrix C\n",
- " grid=(m // tile_m, n // tile_n),\n",
- " )(a, b)\n",
- "\n",
- "\n",
- "fused_matmul(jnp.ones((256, 256)), jnp.ones((256, 256)))"
- ]
- },
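To build intuition for how the grid and `index_map`s above chunk the work, here is a plain-NumPy emulation (no Pallas, illustration only) of the same tiling scheme: each grid step `(i, j)` reads a `(tile_m, k)` row-slab of A and a `(k, tile_n)` column-slab of B, and writes one `(tile_m, tile_n)` output tile, with the ReLU applied inside the tile just like the kernel body.

```python
import numpy as np


def emulated_fused_matmul(a, b, tile_m=128, tile_n=128):
    """NumPy emulation of the Pallas grid: one (i, j) step per output tile."""
    m, k = a.shape
    _, n = b.shape
    c = np.empty((m, n), dtype=a.dtype)
    for i in range(m // tile_m):  # grid dimension 0
        for j in range(n // tile_n):  # grid dimension 1
            # index_map (i, 0): row-slab of A, shape (tile_m, k)
            a_block = a[i * tile_m : (i + 1) * tile_m, :]
            # index_map (0, j): column-slab of B, shape (k, tile_n)
            b_block = b[:, j * tile_n : (j + 1) * tile_n]
            # matmul + ReLU fused within the tile, like the kernel body
            c[i * tile_m : (i + 1) * tile_m, j * tile_n : (j + 1) * tile_n] = (
                np.maximum(a_block @ b_block, 0)
            )
    return c


a = np.ones((256, 256), dtype=np.float32)
b = np.ones((256, 256), dtype=np.float32)
print(np.allclose(emulated_fused_matmul(a, b), np.maximum(a @ b, 0)))  # True
```

The nested `for` loops stand in for the Pallas grid; on real hardware Pallas schedules these tile computations across the accelerator instead of running them sequentially.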
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## 3. Integrating into a Keras Layer\n",
- "\n",
- "Now for the final step, call the jit-compiled `fused_matmul` kernel from a\n",
- "`keras.Layer`."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "class FusedDense(keras.layers.Layer):\n",
- " \"\"\"Custom Keras layer that applies the fused Dense and ReLU op.\"\"\"\n",
- "\n",
- " def __init__(self, units, **kwargs):\n",
- " super().__init__(**kwargs)\n",
- " self.units = units\n",
- "\n",
- " def build(self, input_shape):\n",
- " self.w = self.add_weight(\n",
- " shape=(input_shape[-1], self.units), initializer=\"glorot_uniform\"\n",
- " )\n",
- "\n",
- " def call(self, inputs):\n",
- " # Dispatch to our Pallas kernel\n",
- " return fused_matmul(inputs, self.w.value)\n",
- "\n",
- "\n",
- "FusedDense(256)(jnp.ones((256, 256)))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## 4. Benchmarking the Speedup"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "# 1. Setup Data\n",
- "N = 8192 # Large enough to be memory bound\n",
- "input_data = jnp.ones((N, N), dtype=\"float32\")\n",
- "\n",
- "# Initialize layers\n",
- "standard_layer = StandardDenseReLU(units=N)\n",
- "pallas_layer = FusedDense(units=N)\n",
- "\n",
- "# Build layers by calling them once\n",
- "standard_layer(input_data)\n",
- "pallas_layer(input_data)\n",
- "\n",
- "\n",
- "def benchmark(layer, x, name, iterations=100):\n",
- " # Warm up to ensure JIT compilation is finished\n",
- " for _ in range(10):\n",
- " layer(x).block_until_ready()\n",
- "\n",
- " start_time = time.perf_counter()\n",
- " for _ in range(iterations):\n",
- " layer(x).block_until_ready()\n",
- " end_time = time.perf_counter()\n",
- "\n",
- " avg_time = (end_time - start_time) / iterations * 1000 # convert to ms\n",
- " print(f\"{name} Average Latency: {avg_time:.3f} ms\")\n",
- "\n",
- "\n",
- "# 2. Run Comparison\n",
- "print(f\"Benchmarking Matrix Size: {N}x{N}\\n\" + \"-\" * 30)\n",
- "benchmark(standard_layer, input_data, \"Standard Keras (Matmul + ReLU)\")\n",
- "benchmark(pallas_layer, input_data, \"Pallas Fused (Matmul + ReLU)\")\n",
- ""
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "### Why this Works\n",
- "\n",
- "**Memory Bandwidth Efficiency:** By fusing the matrix multiplication and\n",
- "activation, we perform the ReLU computation while data is still in the chip's\n",
- "fast VMEM. This drastically reduces expensive read/write roundtrips to HBM.\n",
- "\n",
- "**Automatic Parallelization:** Pallas handles the \"grid\" execution, meaning\n",
- "it automatically parallelizes your defined tiles across the available hardware\n",
- "cores (whether TPU MXUs or GPU Tensor Cores).\n",
- "\n",
- "**Drop-in Inference Speed:** This `FusedDense` kernel can be integrated into any\n",
- "Keras model, giving an example of improving serving/inference performance with\n",
- "minimal code changes."
- ]
- },
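As a rough back-of-the-envelope illustration (assuming float32 and the 8192x8192 benchmark size above), fusing the ReLU saves the intermediate matmul result one write to and one read from HBM per call:

```python
N = 8192
bytes_per_elem = 4  # float32

# In the unfused version, the matmul output is written to HBM,
# then read back for the ReLU. Fusion avoids both transfers.
intermediate = N * N * bytes_per_elem
extra_traffic = 2 * intermediate  # one write + one read avoided

print(f"{extra_traffic / 1e9:.1f} GB of HBM traffic avoided per call")  # 0.5 GB
```

Actual savings depend on what XLA already fuses for a given program, but this is the class of HBM round trip that kernel fusion targets.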
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "## 5. Enabling Training\n",
- "\n",
- "In order for a Pallas kernel to be trainable, you must also supply\n",
- "a second kernel to define the custom backward pass, since JAX can't\n",
- "[AutoGrad](https://docs.jax.dev/en/latest/automatic-differentiation.html)\n",
- "through Pallas kernels. Without it, you might see an error like this:\n",
- "\n",
- "```\n",
- "model = keras.Sequential([FusedDense(256)])\n",
- "model.compile(optimizer=\"adam\", loss=\"mse\")\n",
- "model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)))\n",
- ">>> Linearization failed to produce known values for all output primals. This is\n",
- "typically caused by attempting to differentiate a function uses an operation\n",
- "that does not support reverse-mode autodiff.\n",
- "```\n",
- "\n",
- "To extend our fused matmul example above:"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 0,
- "metadata": {
- "colab_type": "code"
- },
- "outputs": [],
- "source": [
- "\n",
- "# 1. Define the wrapper with `custom_vjp` using our original `fused_matmul`.\n",
- "@jax.custom_vjp\n",
- "def fused_matmul_trainable(x, w):\n",
- " return fused_matmul(x, w)\n",
- "\n",
- "\n",
- "# 2. Define the Forward Pass\n",
- "# It must return the output AND \"residuals\" (data needed for the backward pass)\n",
- "def fused_matmul_fwd(x, w):\n",
- " y = fused_matmul_trainable(x, w)\n",
- " # We save inputs x, w and output y for the backward calculation\n",
- " return y, (x, w, y)\n",
- "\n",
- "\n",
- "# 3. Define the Backward Pass\n",
- "# JAX gives us the residuals and the incoming gradient (g)\n",
- "def fused_matmul_bwd(residuals, g):\n",
- " x, w, y = residuals\n",
- "\n",
- " # Calculate the gradient of ReLU: 1 if y > 0, else 0\n",
- " # g is the gradient flowing back from the next layer\n",
- " grad_relu = g * (y > 0)\n",
- "\n",
- " # Standard backprop math for matmul:\n",
- " # grad_x = grad_relu @ w.T\n",
- " grad_x = jnp.dot(grad_relu, w.T)\n",
- "\n",
- " # grad_w = x.T @ grad_relu\n",
- " grad_w = jnp.dot(x.T, grad_relu)\n",
- "\n",
- " return grad_x, grad_w\n",
- "\n",
- "\n",
- "# 4. Register the forward and backward functions\n",
- "fused_matmul_trainable.defvjp(fused_matmul_fwd, fused_matmul_bwd)\n",
- "\n",
- "\n",
- "class FusedDenseTrainable(FusedDense):\n",
- " \"\"\"Updated layer that contains Pallas forward and backward pass.\"\"\"\n",
- "\n",
- " def call(self, inputs):\n",
- " # Dispatch to our trainable Pallas kernel\n",
- " return fused_matmul_trainable(inputs, self.w.value)\n",
- "\n",
- "\n",
- "# Demonstrate trainability on dummy data\n",
- "model = keras.Sequential([FusedDenseTrainable(256)])\n",
- "model.compile(optimizer=\"adam\", loss=\"mse\")\n",
- "model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)), batch_size=128)"
- ]
- },
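The backward formulas in `fused_matmul_bwd` can be sanity-checked against finite differences in plain NumPy — a standalone verification sketch, independent of Pallas and JAX:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
w = rng.normal(size=(3, 2))
g = rng.normal(size=(4, 2))  # upstream gradient


def f(x, w):
    # Forward pass: matmul + ReLU, same math as the fused kernel
    return np.maximum(x @ w, 0.0)


y = f(x, w)

# Analytic gradients, matching fused_matmul_bwd above
grad_relu = g * (y > 0)
grad_x = grad_relu @ w.T
grad_w = x.T @ grad_relu

# Finite-difference check of a single entry, d<g, f(x, w)>/dw[0, 0]
eps = 1e-6
w_p = w.copy()
w_p[0, 0] += eps
num = np.sum(g * (f(x, w_p) - y)) / eps
print(np.allclose(num, grad_w[0, 0], atol=1e-4))
```

The same check can be repeated for entries of `grad_x` by perturbing `x` instead; this kind of spot check is cheap insurance whenever you hand-write a `custom_vjp`.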
- {
- "cell_type": "markdown",
- "metadata": {
- "colab_type": "text"
- },
- "source": [
- "# Followups\n",
- "\n",
- "In this guide we covered how to define a simple custom Pallas kernel performing vector\n",
- "addition to include in a Keras model. Then we followed up with a more in-depth\n",
- "example of a fused matmul + activation kernel that you might use in a real-world\n",
- "model to improve performance.\n",
- "\n",
- "Please refer to the [Pallas\n",
- "docs](https://docs.jax.dev/en/latest/pallas/index.html#) for further\n",
- "documentation on writing custom kernels. Additionally to explore more examples\n",
- "of Pallas kernels, including FlashAttention and MoE layers, check out the\n",
- "[Tokamax](https://github.com/openxla/tokamax) library."
- ]
- }
- ],
- "metadata": {
- "accelerator": "TPU",
- "colab": {
- "collapsed_sections": [],
- "name": "define_custom_kernel",
- "private_outputs": false,
- "provenance": [],
- "toc_visible": true
- },
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.7.0"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 0
-}
\ No newline at end of file
diff --git a/guides/md/define_custom_kernel.md b/guides/md/define_custom_kernel.md
deleted file mode 100644
index 4fdf4af23c..0000000000
--- a/guides/md/define_custom_kernel.md
+++ /dev/null
@@ -1,496 +0,0 @@
-# Define a Custom TPU/GPU Kernel
-
-**Author:** [jeffcarp](https://www.jeffcarp.com/)
-**Date created:** 2025/12/18
-**Last modified:** 2025/12/18
-**Description:** Write high-performance custom Keras layers for TPUs and GPUs.
-
-
- [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/define_custom_kernel.ipynb) • [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/define_custom_kernel.py)
-
-
-
-# How to Write a Custom TPU or GPU Kernel in Keras
-
-Keras has [many pre-made layers to choose from](/api/layers/), and the
-ability to easily [create your
-own](/guides/making_new_layers_and_models_via_subclassing/) if you can't
-find the exact one you need. However, if you have a need for speed, or otherwise
-need to customize the exact behavior of your model at the hardware level, you
-may want to look into writing a custom kernel. A good way to know if you need a
-custom kernel is to look at the profile of your model and see if there are any
-idle gaps caused by computation or memory transfer bottlenecks (see the
-[TensorBoard callback](/api/callbacks/tensorboard/) for how to get a profile).
-
-This guide will explore how to write a custom kernel and add it to your
-Keras model. We will utilize **Pallas**, a library that lets you write
-kernels in Python that can run on both TPU or GPU, where they're lowered
-to Mosaic or Triton, respectively. You can learn more in the [Pallas
-docs](https://docs.jax.dev/en/latest/pallas/index.html).
-
-**Compatibility note:** Pallas is only available when using the JAX backend on
-certain hardware:
-
-- TPU v4 and above
-- NVIDIA Ampere GPUs (compute capability 8.0) and above
-
-If you're running in Colab, the v5e-1 in the free tier supports running this
-guide.
-
-First, make sure you're running the latest version of `libtpu`:
-
-
-```python
-!pip install --upgrade -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
-```
-
-
-```python
-from functools import partial
-import os
-import time
-
-
-os.environ["KERAS_BACKEND"] = "jax"
-
-import jax
-from jax.experimental import pallas as pl
-import jax.numpy as jnp
-import keras
-
-```
-