diff --git a/examples/timeseries/timeseries_traffic_forecasting.ipynb b/examples/timeseries/timeseries_traffic_forecasting.ipynb new file mode 100644 index 0000000000..b63405df9a --- /dev/null +++ b/examples/timeseries/timeseries_traffic_forecasting.ipynb @@ -0,0 +1,920 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# Traffic forecasting using graph neural networks and LSTM\n", + "\n", + "**Author:** [Arash Khodadadi](https://www.linkedin.com/in/arash-khodadadi-08a02490/)
\n", + "**Date created:** 2021/12/28
\n", + "**Last modified:** 2023/11/22
\n", + "**Description:** This example demonstrates how to do timeseries forecasting over graphs." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Introduction\n", + "\n", + "This example shows how to forecast traffic condition using graph neural networks and LSTM.\n", + "Specifically, we are interested in predicting the future values of the traffic speed given\n", + "a history of the traffic speed for a collection of road segments.\n", + "\n", + "One popular method to\n", + "solve this problem is to consider each road segment's traffic speed as a separate\n", + "timeseries and predict the future values of each timeseries\n", + "using the past values of the same timeseries.\n", + "\n", + "This method, however, ignores the dependency of the traffic speed of one road segment on\n", + "the neighboring segments. To be able to take into account the complex interactions between\n", + "the traffic speed on a collection of neighboring roads, we can define the traffic network\n", + "as a graph and consider the traffic speed as a signal on this graph. In this example,\n", + "we implement a neural network architecture which can process timeseries data over a graph.\n", + "We first show how to process the data and create a\n", + "[tf.data.Dataset](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) for\n", + "forecasting over graphs. Then, we implement a model which uses graph convolution and\n", + "LSTM layers to perform forecasting over a graph.\n", + "\n", + "The data processing and the model architecture are inspired by this paper:\n", + "\n", + "Yu, Bing, Haoteng Yin, and Zhanxing Zhu. \"Spatio-temporal graph convolutional networks:\n", + "a deep learning framework for traffic forecasting.\" Proceedings of the 27th International\n", + "Joint Conference on Artificial Intelligence, 2018.\n", + "([github](https://github.com/VeritasYin/STGCN_IJCAI-18))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n", + "\n", + "import pandas as pd\n", + "import numpy as np\n", + "import typing\n", + "import matplotlib.pyplot as plt\n", + "\n", + "import tensorflow as tf\n", + "import keras\n", + "from keras import layers\n", + "from keras import ops" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Data preparation" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Data description\n", + "\n", + "We use a real-world traffic speed dataset named `PeMSD7`. We use the version\n", + "collected and prepared by [Yu et al., 2018](https://arxiv.org/abs/1709.04875)\n", + "and available\n", + "[here](https://github.com/VeritasYin/STGCN_IJCAI-18/tree/master/dataset).\n", + "\n", + "The data consists of two files:\n", + "\n", + "- `PeMSD7_W_228.csv` contains the distances between 228\n", + "stations across the District 7 of California.\n", + "- `PeMSD7_V_228.csv` contains traffic\n", + "speed collected for those stations in the weekdays of May and June of 2012.\n", + "\n", + "The full description of the dataset can be found in\n", + "[Yu et al., 2018](https://arxiv.org/abs/1709.04875)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Loading data" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "url = \"https://github.com/VeritasYin/STGCN_IJCAI-18/raw/master/dataset/PeMSD7_Full.zip\"\n", + "# 1. Download and extract normally\n", + "zip_path = keras.utils.get_file(origin=url, extract=True, archive_format=\"zip\")\n", + "\n", + "# 2. FIX: Use os.path.dirname to safely get the folder where it was extracted\n", + "data_dir = os.path.dirname(zip_path)\n", + "\n", + "# 3. Construct the paths to the inner files safely\n", + "route_distances = pd.read_csv(\n", + " os.path.join(data_dir, \"PeMSD7_W_228.csv\"), header=None\n", + ").to_numpy()\n", + "speeds_array = pd.read_csv(\n", + " os.path.join(data_dir, \"PeMSD7_V_228.csv\"), header=None\n", + ").to_numpy()\n", + "\n", + "print(f\"route_distances shape={route_distances.shape}\")\n", + "print(f\"speeds_array shape={speeds_array.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### sub-sampling roads\n", + "\n", + "To reduce the problem size and make the training faster, we will only\n", + "work with a sample of 26 roads out of the 228 roads in the dataset.\n", + "We have chosen the roads by starting from road 0, choosing the 5 closest\n", + "roads to it, and continuing this process until we get 25 roads. You can choose\n", + "any other subset of the roads. We chose the roads in this way to increase the likelihood\n", + "of having roads with correlated speed timeseries.\n", + "`sample_routes` contains the IDs of the selected roads." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "sample_routes = [\n", + " 0,\n", + " 1,\n", + " 4,\n", + " 7,\n", + " 8,\n", + " 11,\n", + " 15,\n", + " 108,\n", + " 109,\n", + " 114,\n", + " 115,\n", + " 118,\n", + " 120,\n", + " 123,\n", + " 124,\n", + " 126,\n", + " 127,\n", + " 129,\n", + " 130,\n", + " 132,\n", + " 133,\n", + " 136,\n", + " 139,\n", + " 144,\n", + " 147,\n", + " 216,\n", + "]\n", + "route_distances = route_distances[np.ix_(sample_routes, sample_routes)]\n", + "speeds_array = speeds_array[:, sample_routes]\n", + "\n", + "print(f\"route_distances shape={route_distances.shape}\")\n", + "print(f\"speeds_array shape={speeds_array.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Data visualization\n", + "\n", + "Here are the timeseries of the traffic speed for two of the routes:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(18, 6))\n", + "plt.plot(speeds_array[:, [0, -1]])\n", + "plt.legend([\"route_0\", \"route_25\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "We can also visualize the correlation between the timeseries in different routes." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "plt.figure(figsize=(8, 8))\n", + "plt.matshow(np.corrcoef(speeds_array.T), 0)\n", + "plt.xlabel(\"road number\")\n", + "plt.ylabel(\"road number\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Using this correlation heatmap, we can see that for example the speed in\n", + "routes 4, 5, 6 are highly correlated." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Splitting and normalizing data\n", + "\n", + "Next, we split the speed values array into train/validation/test sets,\n", + "and normalize the resulting arrays:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "train_size, val_size = 0.5, 0.2\n", + "\n", + "\n", + "def preprocess(data_array: np.ndarray, train_size: float, val_size: float):\n", + " \"\"\"Splits data into train/val/test sets and normalizes the data.\n", + "\n", + " Args:\n", + " data_array: ndarray of shape `(num_time_steps, num_routes)`\n", + " train_size: A float value between 0.0 and 1.0 that represent the proportion of the dataset\n", + " to include in the train split.\n", + " val_size: A float value between 0.0 and 1.0 that represent the proportion of the dataset\n", + " to include in the validation split.\n", + "\n", + " Returns:\n", + " `train_array`, `val_array`, `test_array`\n", + " \"\"\"\n", + "\n", + " num_time_steps = data_array.shape[0]\n", + " num_train, num_val = (\n", + " int(num_time_steps * train_size),\n", + " int(num_time_steps * val_size),\n", + " )\n", + " train_array = data_array[:num_train]\n", + " mean, std = train_array.mean(axis=0), train_array.std(axis=0)\n", + "\n", + " train_array = (train_array - mean) / std\n", + " val_array = (data_array[num_train : (num_train + num_val)] - mean) / std\n", + " test_array = (data_array[(num_train + num_val) :] - mean) / std\n", + "\n", + " return train_array, val_array, test_array\n", + "\n", + "\n", + "train_array, val_array, test_array = preprocess(speeds_array, train_size, val_size)\n", + "\n", + "print(f\"train set size: {train_array.shape}\")\n", + "print(f\"validation set size: {val_array.shape}\")\n", + "print(f\"test set size: {test_array.shape}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Creating TensorFlow Datasets\n", + "\n", + "Next, we create the datasets for our forecasting problem. The forecasting problem\n", + "can be stated as follows: given a sequence of the\n", + "road speed values at times `t+1, t+2, ..., t+T`, we want to predict the future values of\n", + "the roads speed for times `t+T+1, ..., t+T+h`. So for each time `t` the inputs to our\n", + "model are `T` vectors each of size `N` and the targets are `h` vectors each of size `N`,\n", + "where `N` is the number of roads." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "We use the Keras built-in function\n", + "`keras.utils.timeseries_dataset_from_array`.\n", + "The function `create_tf_dataset()` below takes as input a `numpy.ndarray` and returns a\n", + "`tf.data.Dataset`. In this function `input_sequence_length=T` and `forecast_horizon=h`.\n", + "\n", + "The argument `multi_horizon` needs more explanation. Assume `forecast_horizon=3`.\n", + "If `multi_horizon=True` then the model will make a forecast for time steps\n", + "`t+T+1, t+T+2, t+T+3`. So the target will have shape `(T,3)`. But if\n", + "`multi_horizon=False`, the model will make a forecast only for time step `t+T+3` and\n", + "so the target will have shape `(T, 1)`.\n", + "\n", + "You may notice that the input tensor in each batch has shape\n", + "`(batch_size, input_sequence_length, num_routes, 1)`. The last dimension is added to\n", + "make the model more general: at each time step, the input features for each raod may\n", + "contain multiple timeseries. For instance, one might want to use temperature timeseries\n", + "in addition to historical values of the speed as input features. In this example,\n", + "however, the last dimension of the input is always 1.\n", + "\n", + "We use the last 12 values of the speed in each road to forecast the speed for 3 time\n", + "steps ahead:" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "batch_size = 64\n", + "input_sequence_length = 12\n", + "forecast_horizon = 3\n", + "multi_horizon = False\n", + "\n", + "\n", + "def create_tf_dataset(\n", + " data_array: np.ndarray,\n", + " input_sequence_length: int,\n", + " forecast_horizon: int,\n", + " batch_size: int = 128,\n", + " shuffle=True,\n", + " multi_horizon=True,\n", + "):\n", + " \"\"\"Creates tensorflow dataset from numpy array.\n", + "\n", + " This function creates a dataset where each element is a tuple `(inputs, targets)`.\n", + " `inputs` is a Tensor\n", + " of shape `(batch_size, input_sequence_length, num_routes, 1)` containing\n", + " the `input_sequence_length` past values of the timeseries for each node.\n", + " `targets` is a Tensor of shape `(batch_size, forecast_horizon, num_routes)`\n", + " containing the `forecast_horizon`\n", + " future values of the timeseries for each node.\n", + "\n", + " Args:\n", + " data_array: np.ndarray with shape `(num_time_steps, num_routes)`\n", + " input_sequence_length: Length of the input sequence (in number of timesteps).\n", + " forecast_horizon: If `multi_horizon=True`, the target will be the values of the timeseries for 1 to\n", + " `forecast_horizon` timesteps ahead. If `multi_horizon=False`, the target will be the value of the\n", + " timeseries `forecast_horizon` steps ahead (only one value).\n", + " batch_size: Number of timeseries samples in each batch.\n", + " shuffle: Whether to shuffle output samples, or instead draw them in chronological order.\n", + " multi_horizon: See `forecast_horizon`.\n", + "\n", + " Returns:\n", + " A tf.data.Dataset instance.\n", + " \"\"\"\n", + "\n", + " inputs = keras.utils.timeseries_dataset_from_array(\n", + " np.expand_dims(data_array[:-forecast_horizon], axis=-1),\n", + " None,\n", + " sequence_length=input_sequence_length,\n", + " shuffle=False,\n", + " batch_size=batch_size,\n", + " )\n", + "\n", + " target_offset = (\n", + " input_sequence_length\n", + " if multi_horizon\n", + " else input_sequence_length + forecast_horizon - 1\n", + " )\n", + " target_seq_length = forecast_horizon if multi_horizon else 1\n", + " targets = keras.utils.timeseries_dataset_from_array(\n", + " data_array[target_offset:],\n", + " None,\n", + " sequence_length=target_seq_length,\n", + " shuffle=False,\n", + " batch_size=batch_size,\n", + " )\n", + "\n", + " dataset = tf.data.Dataset.zip((inputs, targets))\n", + " if shuffle:\n", + " dataset = dataset.shuffle(100)\n", + "\n", + " return dataset.prefetch(16).cache()\n", + "\n", + "\n", + "train_dataset, val_dataset = (\n", + " create_tf_dataset(data_array, input_sequence_length, forecast_horizon, batch_size)\n", + " for data_array in [train_array, val_array]\n", + ")\n", + "\n", + "test_dataset = create_tf_dataset(\n", + " test_array,\n", + " input_sequence_length,\n", + " forecast_horizon,\n", + " batch_size=test_array.shape[0],\n", + " shuffle=False,\n", + " multi_horizon=multi_horizon,\n", + ")\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Roads Graph\n", + "\n", + "As mentioned before, we assume that the road segments form a graph.\n", + "The `PeMSD7` dataset has the road segments distance. The next step\n", + "is to create the graph adjacency matrix from these distances. Following\n", + "[Yu et al., 2018](https://arxiv.org/abs/1709.04875) (equation 10) we assume there\n", + "is an edge between two nodes in the graph if the distance between the corresponding roads\n", + "is less than a threshold." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def compute_adjacency_matrix(\n", + " route_distances: np.ndarray, sigma2: float, epsilon: float\n", + "):\n", + " \"\"\"Computes the adjacency matrix from distances matrix.\n", + "\n", + " It uses the formula in https://github.com/VeritasYin/STGCN_IJCAI-18#data-preprocessing to\n", + " compute an adjacency matrix from the distance matrix.\n", + " The implementation follows that paper.\n", + "\n", + " Args:\n", + " route_distances: np.ndarray of shape `(num_routes, num_routes)`. Entry `i,j` of this array is the\n", + " distance between roads `i,j`.\n", + " sigma2: Determines the width of the Gaussian kernel applied to the square distances matrix.\n", + " epsilon: A threshold specifying if there is an edge between two nodes. Specifically, `A[i,j]=1`\n", + " if `np.exp(-w2[i,j] / sigma2) >= epsilon` and `A[i,j]=0` otherwise, where `A` is the adjacency\n", + " matrix and `w2=route_distances * route_distances`\n", + "\n", + " Returns:\n", + " A boolean graph adjacency matrix.\n", + " \"\"\"\n", + " num_routes = route_distances.shape[0]\n", + " route_distances = route_distances / 10000.0\n", + " w2, w_mask = (\n", + " route_distances * route_distances,\n", + " np.ones([num_routes, num_routes]) - np.identity(num_routes),\n", + " )\n", + " return (np.exp(-w2 / sigma2) >= epsilon) * w_mask\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "The function `compute_adjacency_matrix()` returns a boolean adjacency matrix\n", + "where 1 means there is an edge between two nodes. We use the following class\n", + "to store the information about the graph." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "class GraphInfo:\n", + " def __init__(self, edges: typing.Tuple[list, list], num_nodes: int):\n", + " self.edges = edges\n", + " self.num_nodes = num_nodes\n", + "\n", + "\n", + "sigma2 = 0.1\n", + "epsilon = 0.5\n", + "adjacency_matrix = compute_adjacency_matrix(route_distances, sigma2, epsilon)\n", + "node_indices, neighbor_indices = np.where(adjacency_matrix == 1)\n", + "graph = GraphInfo(\n", + " edges=(node_indices.tolist(), neighbor_indices.tolist()),\n", + " num_nodes=adjacency_matrix.shape[0],\n", + ")\n", + "print(f\"number of nodes: {graph.num_nodes}, number of edges: {len(graph.edges[0])}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Network architecture\n", + "\n", + "Our model for forecasting over the graph consists of a graph convolution\n", + "layer and a LSTM layer." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### Graph convolution layer\n", + "\n", + "Our implementation of the graph convolution layer resembles the implementation\n", + "in [this Keras example](https://keras.io/examples/graph/gnn_citations/). Note that\n", + "in that example input to the layer is a 2D tensor of shape `(num_nodes,in_feat)`\n", + "but in our example the input to the layer is a 4D tensor of shape\n", + "`(num_nodes, batch_size, input_seq_length, in_feat)`. The graph convolution layer\n", + "performs the following steps:\n", + "\n", + "- The nodes' representations are computed in `self.compute_nodes_representation()`\n", + "by multiplying the input features by `self.weight`\n", + "- The aggregated neighbors' messages are computed in `self.compute_aggregated_messages()`\n", + "by first aggregating the neighbors' representations and then multiplying the results by\n", + "`self.weight`\n", + "- The final output of the layer is computed in `self.update()` by combining the nodes\n", + "representations and the neighbors' aggregated messages" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "class GraphConv(layers.Layer):\n", + " def __init__(\n", + " self,\n", + " in_feat,\n", + " out_feat,\n", + " graph_info: GraphInfo,\n", + " aggregation_type=\"mean\",\n", + " combination_type=\"concat\",\n", + " activation: typing.Optional[str] = None,\n", + " **kwargs,\n", + " ):\n", + " super().__init__(**kwargs)\n", + " self.in_feat = in_feat\n", + " self.out_feat = out_feat\n", + " self.graph_info = graph_info\n", + " self.aggregation_type = aggregation_type\n", + " self.combination_type = combination_type\n", + " self.weight = self.add_weight(\n", + " initializer=keras.initializers.GlorotUniform(),\n", + " shape=(in_feat, out_feat),\n", + " dtype=\"float32\",\n", + " trainable=True,\n", + " )\n", + " self.activation = layers.Activation(activation)\n", + "\n", + " def aggregate(self, neighbour_representations):\n", + " aggregation_func = {\n", + " \"sum\": tf.math.unsorted_segment_sum,\n", + " \"mean\": tf.math.unsorted_segment_mean,\n", + " \"max\": tf.math.unsorted_segment_max,\n", + " }.get(self.aggregation_type)\n", + "\n", + " if aggregation_func:\n", + " return aggregation_func(\n", + " neighbour_representations,\n", + " self.graph_info.edges[0],\n", + " num_segments=self.graph_info.num_nodes,\n", + " )\n", + "\n", + " raise ValueError(f\"Invalid aggregation type: {self.aggregation_type}\")\n", + "\n", + " def compute_nodes_representation(self, features):\n", + " \"\"\"Computes each node's representation.\n", + "\n", + " The nodes' representations are obtained by multiplying the features tensor with\n", + " `self.weight`. Note that\n", + " `self.weight` has shape `(in_feat, out_feat)`.\n", + "\n", + " Args:\n", + " features: Tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`\n", + "\n", + " Returns:\n", + " A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`\n", + " \"\"\"\n", + " return ops.matmul(features, self.weight)\n", + "\n", + " def compute_aggregated_messages(self, features):\n", + " neighbour_representations = tf.gather(features, self.graph_info.edges[1])\n", + " aggregated_messages = self.aggregate(neighbour_representations)\n", + " return ops.matmul(aggregated_messages, self.weight)\n", + "\n", + " def update(self, nodes_representation, aggregated_messages):\n", + " if self.combination_type == \"concat\":\n", + " h = ops.concatenate([nodes_representation, aggregated_messages], axis=-1)\n", + " elif self.combination_type == \"add\":\n", + " h = nodes_representation + aggregated_messages\n", + " else:\n", + " raise ValueError(f\"Invalid combination type: {self.combination_type}.\")\n", + " return self.activation(h)\n", + "\n", + " def call(self, features):\n", + " \"\"\"Forward pass.\n", + "\n", + " Args:\n", + " features: tensor of shape `(num_nodes, batch_size, input_seq_len, in_feat)`\n", + "\n", + " Returns:\n", + " A tensor of shape `(num_nodes, batch_size, input_seq_len, out_feat)`\n", + " \"\"\"\n", + " nodes_representation = self.compute_nodes_representation(features)\n", + " aggregated_messages = self.compute_aggregated_messages(features)\n", + " return self.update(nodes_representation, aggregated_messages)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "### LSTM plus graph convolution\n", + "\n", + "By applying the graph convolution layer to the input tensor, we get another tensor\n", + "containing the nodes' representations over time (another 4D tensor). For each time\n", + "step, a node's representation is informed by the information from its neighbors.\n", + "\n", + "To make good forecasts, however, we need not only information from the neighbors\n", + "but also we need to process the information over time. To this end, we can pass each\n", + "node's tensor through a recurrent layer. The `LSTMGC` layer below, first applies\n", + "a graph convolution layer to the inputs and then passes the results through a\n", + "`LSTM` layer." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "class LSTMGC(layers.Layer):\n", + " \"\"\"Layer comprising a convolution layer followed by LSTM and dense layers.\"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " in_feat,\n", + " out_feat,\n", + " lstm_units: int,\n", + " input_seq_len: int,\n", + " output_seq_len: int,\n", + " graph_info: GraphInfo,\n", + " graph_conv_params: typing.Optional[dict] = None,\n", + " **kwargs,\n", + " ):\n", + " super().__init__(**kwargs)\n", + "\n", + " # graph conv layer\n", + " if graph_conv_params is None:\n", + " graph_conv_params = {\n", + " \"aggregation_type\": \"mean\",\n", + " \"combination_type\": \"concat\",\n", + " \"activation\": None,\n", + " }\n", + " self.graph_conv = GraphConv(in_feat, out_feat, graph_info, **graph_conv_params)\n", + "\n", + " self.lstm = layers.LSTM(lstm_units, activation=\"relu\")\n", + " self.dense = layers.Dense(output_seq_len)\n", + "\n", + " self.input_seq_len, self.output_seq_len = input_seq_len, output_seq_len\n", + "\n", + " def call(self, inputs):\n", + " \"\"\"Forward pass.\n", + "\n", + " Args:\n", + " inputs: tensor of shape `(batch_size, input_seq_len, num_nodes, in_feat)`\n", + "\n", + " Returns:\n", + " A tensor of shape `(batch_size, output_seq_len, num_nodes)`.\n", + " \"\"\"\n", + "\n", + " # convert shape to (num_nodes, batch_size, input_seq_len, in_feat)\n", + " inputs = ops.transpose(inputs, [2, 0, 1, 3])\n", + "\n", + " gcn_out = self.graph_conv(\n", + " inputs\n", + " ) # gcn_out has shape: (num_nodes, batch_size, input_seq_len, out_feat)\n", + " shape = ops.shape(gcn_out)\n", + " num_nodes, batch_size, input_seq_len, out_feat = (\n", + " shape[0],\n", + " shape[1],\n", + " shape[2],\n", + " shape[3],\n", + " )\n", + "\n", + " # LSTM takes only 3D tensors as input\n", + " gcn_out = ops.reshape(\n", + " gcn_out, (batch_size * num_nodes, input_seq_len, out_feat)\n", + " )\n", + " lstm_out = self.lstm(\n", + " gcn_out\n", + " ) # lstm_out has shape: (batch_size * num_nodes, lstm_units)\n", + "\n", + " dense_output = self.dense(\n", + " lstm_out\n", + " ) # dense_output has shape: (batch_size * num_nodes, output_seq_len)\n", + " output = ops.reshape(dense_output, (num_nodes, batch_size, self.output_seq_len))\n", + " return ops.transpose(\n", + " output, [1, 2, 0]\n", + " ) # returns Tensor of shape (batch_size, output_seq_len, num_nodes)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Model training" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "in_feat = 1\n", + "batch_size = 64\n", + "epochs = 20\n", + "input_sequence_length = 12\n", + "forecast_horizon = 3\n", + "multi_horizon = False\n", + "out_feat = 10\n", + "lstm_units = 64\n", + "graph_conv_params = {\n", + " \"aggregation_type\": \"mean\",\n", + " \"combination_type\": \"concat\",\n", + " \"activation\": None,\n", + "}\n", + "\n", + "st_gcn = LSTMGC(\n", + " in_feat,\n", + " out_feat,\n", + " lstm_units,\n", + " input_sequence_length,\n", + " forecast_horizon,\n", + " graph,\n", + " graph_conv_params,\n", + ")\n", + "inputs = layers.Input((input_sequence_length, graph.num_nodes, in_feat))\n", + "outputs = st_gcn(inputs)\n", + "\n", + "model = keras.models.Model(inputs, outputs)\n", + "model.compile(\n", + " optimizer=keras.optimizers.RMSprop(learning_rate=0.0002),\n", + " loss=keras.losses.MeanSquaredError(),\n", + ")\n", + "model.fit(\n", + " train_dataset,\n", + " validation_data=val_dataset,\n", + " epochs=epochs,\n", + " callbacks=[keras.callbacks.EarlyStopping(patience=10)],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Making forecasts on test set\n", + "\n", + "Now we can use the trained model to make forecasts for the test set. Below, we\n", + "compute the MAE of the model and compare it to the MAE of naive forecasts.\n", + "The naive forecasts are the last value of the speed for each node." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "x_test, y = next(test_dataset.as_numpy_iterator())\n", + "y_pred = model.predict(x_test)\n", + "plt.figure(figsize=(18, 6))\n", + "plt.plot(y[:, 0, 0])\n", + "plt.plot(y_pred[:, 0, 0])\n", + "plt.legend([\"actual\", \"forecast\"])\n", + "\n", + "naive_mse, model_mse = (\n", + " np.square(x_test[:, -1, :, 0] - y[:, 0, :]).mean(),\n", + " np.square(y_pred[:, 0, :] - y[:, 0, :]).mean(),\n", + ")\n", + "print(f\"naive MAE: {naive_mse}, model MAE: {model_mse}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "Of course, the goal here is to demonstrate the method,\n", + "not to achieve the best performance. To improve the\n", + "model's accuracy, all model hyperparameters should be tuned carefully. In addition,\n", + "several of the `LSTMGC` blocks can be stacked to increase the representation power\n", + "of the model." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "timeseries_traffic_forecasting", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/timeseries/timeseries_traffic_forecasting.py b/examples/timeseries/timeseries_traffic_forecasting.py index ea66ee2131..f9fa5b9900 100644 --- a/examples/timeseries/timeseries_traffic_forecasting.py +++ b/examples/timeseries/timeseries_traffic_forecasting.py @@ -83,9 +83,13 @@ """ url = "https://github.com/VeritasYin/STGCN_IJCAI-18/raw/master/dataset/PeMSD7_Full.zip" -data_dir = keras.utils.get_file(origin=url, extract=True, archive_format="zip") -data_dir = data_dir.rstrip("PeMSD7_Full.zip") +# 1. Download and extract normally +zip_path = keras.utils.get_file(origin=url, extract=True, archive_format="zip") +# 2. FIX: Use os.path.dirname to safely get the folder where it was extracted +data_dir = os.path.dirname(zip_path) + +# 3. Construct the paths to the inner files safely route_distances = pd.read_csv( os.path.join(data_dir, "PeMSD7_W_228.csv"), header=None ).to_numpy() diff --git a/examples/timeseries/timeseries_weather_forecasting.ipynb b/examples/timeseries/timeseries_weather_forecasting.ipynb new file mode 100644 index 0000000000..dafe27bd7e --- /dev/null +++ b/examples/timeseries/timeseries_weather_forecasting.ipynb @@ -0,0 +1,574 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# Timeseries forecasting for weather prediction\n", + "\n", + "**Authors:** [Prabhanshu Attri](https://prabhanshu.com/github), [Yashika Sharma](https://github.com/yashika51), [Kristi Takach](https://github.com/ktakattack), [Falak Shah](https://github.com/falaktheoptimist)
\n", + "**Date created:** 2020/06/23
\n", + "**Last modified:** 2023/11/22
\n", + "**Description:** This notebook demonstrates how to do timeseries forecasting using a LSTM model." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "import keras" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Climate Data Time-Series\n", + "\n", + "We will be using Jena Climate dataset recorded by the\n", + "[Max Planck Institute for Biogeochemistry](https://www.bgc-jena.mpg.de/wetter/).\n", + "The dataset consists of 14 features such as temperature, pressure, humidity etc, recorded once per\n", + "10 minutes.\n", + "\n", + "**Location**: Weather Station, Max Planck Institute for Biogeochemistry\n", + "in Jena, Germany\n", + "\n", + "**Time-frame Considered**: Jan 10, 2009 - December 31, 2016\n", + "\n", + "\n", + "The table below shows the column names, their value formats, and their description.\n", + "\n", + "Index| Features |Format |Description\n", + "-----|---------------|-------------------|-----------------------\n", + "1 |Date Time |01.01.2009 00:10:00|Date-time reference\n", + "2 |p (mbar) |996.52 |The pascal SI derived unit of pressure used to quantify internal pressure. Meteorological reports typically state atmospheric pressure in millibars.\n", + "3 |T (degC) |-8.02 |Temperature in Celsius\n", + "4 |Tpot (K) |265.4 |Temperature in Kelvin\n", + "5 |Tdew (degC) |-8.9 |Temperature in Celsius relative to humidity. Dew Point is a measure of the absolute amount of water in the air, the DP is the temperature at which the air cannot hold all the moisture in it and water condenses.\n", + "6 |rh (%) |93.3 |Relative Humidity is a measure of how saturated the air is with water vapor, the %RH determines the amount of water contained within collection objects.\n", + "7 |VPmax (mbar) |3.33 |Saturation vapor pressure\n", + "8 |VPact (mbar) |3.11 |Vapor pressure\n", + "9 |VPdef (mbar) |0.22 |Vapor pressure deficit\n", + "10 |sh (g/kg) |1.94 |Specific humidity\n", + "11 |H2OC (mmol/mol)|3.12 |Water vapor concentration\n", + "12 |rho (g/m ** 3) |1307.75 |Airtight\n", + "13 |wv (m/s) |1.03 |Wind speed\n", + "14 |max. wv (m/s) |1.75 |Maximum wind speed\n", + "15 |wd (deg) |152.3 |Wind direction in degrees" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "from zipfile import ZipFile\n", + "\n", + "uri = \"https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip\"\n", + "zip_path = keras.utils.get_file(origin=uri, fname=\"jena_climate_2009_2016.csv.zip\")\n", + "zip_file = ZipFile(zip_path)\n", + "\n", + "# FIX: Extract to the cache directory, not the current working directory\n", + "zip_file.extractall(path=os.path.dirname(zip_path))\n", + "\n", + "# FIX: Construct the absolute path safely (works on Windows/Linux/Mac)\n", + "csv_path = os.path.join(os.path.dirname(zip_path), \"jena_climate_2009_2016.csv\")\n", + "\n", + "df = pd.read_csv(csv_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Raw Data Visualization\n", + "\n", + "To give us a sense of the data we are working with, each feature has been plotted below.\n", + "This shows the distinct pattern of each feature over the time period from 2009 to 2016.\n", + "It also shows where anomalies are present, which will be addressed during normalization." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "titles = [\n", + " \"Pressure\",\n", + " \"Temperature\",\n", + " \"Temperature in Kelvin\",\n", + " \"Temperature (dew point)\",\n", + " \"Relative Humidity\",\n", + " \"Saturation vapor pressure\",\n", + " \"Vapor pressure\",\n", + " \"Vapor pressure deficit\",\n", + " \"Specific humidity\",\n", + " \"Water vapor concentration\",\n", + " \"Airtight\",\n", + " \"Wind speed\",\n", + " \"Maximum wind speed\",\n", + " \"Wind direction in degrees\",\n", + "]\n", + "\n", + "feature_keys = [\n", + " \"p (mbar)\",\n", + " \"T (degC)\",\n", + " \"Tpot (K)\",\n", + " \"Tdew (degC)\",\n", + " \"rh (%)\",\n", + " \"VPmax (mbar)\",\n", + " \"VPact (mbar)\",\n", + " \"VPdef (mbar)\",\n", + " \"sh (g/kg)\",\n", + " \"H2OC (mmol/mol)\",\n", + " \"rho (g/m**3)\",\n", + " \"wv (m/s)\",\n", + " \"max. wv (m/s)\",\n", + " \"wd (deg)\",\n", + "]\n", + "\n", + "colors = [\n", + " \"blue\",\n", + " \"orange\",\n", + " \"green\",\n", + " \"red\",\n", + " \"purple\",\n", + " \"brown\",\n", + " \"pink\",\n", + " \"gray\",\n", + " \"olive\",\n", + " \"cyan\",\n", + "]\n", + "\n", + "date_time_key = \"Date Time\"\n", + "\n", + "\n", + "def show_raw_visualization(data):\n", + " time_data = data[date_time_key]\n", + " fig, axes = plt.subplots(\n", + " nrows=7, ncols=2, figsize=(15, 20), dpi=80, facecolor=\"w\", edgecolor=\"k\"\n", + " )\n", + " for i in range(len(feature_keys)):\n", + " key = feature_keys[i]\n", + " c = colors[i % (len(colors))]\n", + " t_data = data[key]\n", + " t_data.index = time_data\n", + " t_data.head()\n", + " ax = t_data.plot(\n", + " ax=axes[i // 2, i % 2],\n", + " color=c,\n", + " title=\"{} - {}\".format(titles[i], key),\n", + " rot=25,\n", + " )\n", + " ax.legend([titles[i]])\n", + " plt.tight_layout()\n", + "\n", + "\n", + "show_raw_visualization(df)\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Data Preprocessing\n", + "\n", + "Here we are picking ~300,000 data points for training. Observation is recorded every\n", + "10 mins, that means 6 times per hour. We will resample one point per hour since no\n", + "drastic change is expected within 60 minutes. We do this via the `sampling_rate`\n", + "argument in `timeseries_dataset_from_array` utility.\n", + "\n", + "We are tracking data from past 720 timestamps (720/6=120 hours). This data will be\n", + "used to predict the temperature after 72 timestamps (72/6=12 hours).\n", + "\n", + "Since every feature has values with\n", + "varying ranges, we do normalization to confine feature values to a range of `[0, 1]` before\n", + "training a neural network.\n", + "We do this by subtracting the mean and dividing by the standard deviation of each feature.\n", + "\n", + "71.5 % of the data will be used to train the model, i.e. 300,693 rows. `split_fraction` can\n", + "be changed to alter this percentage.\n", + "\n", + "The model is shown data for first 5 days i.e. 720 observations, that are sampled every\n", + "hour. The temperature after 72 (12 hours * 6 observation per hour) observation will be\n", + "used as a label." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "split_fraction = 0.715\n", + "train_split = int(split_fraction * int(df.shape[0]))\n", + "step = 6\n", + "\n", + "past = 720\n", + "future = 72\n", + "learning_rate = 0.001\n", + "batch_size = 256\n", + "epochs = 10\n", + "\n", + "\n", + "def normalize(data, train_split):\n", + " data_mean = data[:train_split].mean(axis=0)\n", + " data_std = data[:train_split].std(axis=0)\n", + " return (data - data_mean) / data_std\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "We can see from the correlation heatmap, few parameters like Relative Humidity and\n", + "Specific Humidity are redundant. Hence we will be using select features, not all." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "print(\n", + " \"The selected parameters are:\",\n", + " \", \".join([titles[i] for i in [0, 1, 5, 7, 8, 10, 11]]),\n", + ")\n", + "selected_features = [feature_keys[i] for i in [0, 1, 5, 7, 8, 10, 11]]\n", + "features = df[selected_features]\n", + "features.index = df[date_time_key]\n", + "features.head()\n", + "\n", + "features = normalize(features.values, train_split)\n", + "features = pd.DataFrame(features)\n", + "features.head()\n", + "\n", + "train_data = features.loc[0 : train_split - 1]\n", + "val_data = features.loc[train_split:]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "# Training dataset\n", + "\n", + "The training dataset labels starts from the 792nd observation (720 + 72)." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "start = past + future\n", + "end = start + train_split\n", + "\n", + "x_train = train_data[[i for i in range(7)]].values\n", + "y_train = features.iloc[start:end][[1]]\n", + "\n", + "sequence_length = int(past / step)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "The `timeseries_dataset_from_array` function takes in a sequence of data-points gathered at\n", + "equal intervals, along with time series parameters such as length of the\n", + "sequences/windows, spacing between two sequence/windows, etc., to produce batches of\n", + "sub-timeseries inputs and targets sampled from the main timeseries." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "dataset_train = keras.preprocessing.timeseries_dataset_from_array(\n", + " x_train,\n", + " y_train,\n", + " sequence_length=sequence_length,\n", + " sampling_rate=step,\n", + " batch_size=batch_size,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Validation dataset\n", + "\n", + "The validation dataset must not contain the last 792 rows as we won't have label data for\n", + "those records, hence 792 must be subtracted from the end of the data.\n", + "\n", + "The validation label dataset must start from 792 after train_split, hence we must add\n", + "past + future (792) to label_start." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "x_end = len(val_data) - past - future\n", + "\n", + "label_start = train_split + past + future\n", + "\n", + "x_val = val_data.iloc[:x_end][[i for i in range(7)]].values\n", + "y_val = features.iloc[label_start:][[1]]\n", + "\n", + "dataset_val = keras.preprocessing.timeseries_dataset_from_array(\n", + " x_val,\n", + " y_val,\n", + " sequence_length=sequence_length,\n", + " sampling_rate=step,\n", + " batch_size=batch_size,\n", + ")\n", + "\n", + "\n", + "for batch in dataset_train.take(1):\n", + " inputs, targets = batch\n", + "\n", + "print(\"Input shape:\", inputs.numpy().shape)\n", + "print(\"Target shape:\", targets.numpy().shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "inputs = keras.layers.Input(shape=(inputs.shape[1], inputs.shape[2]))\n", + "lstm_out = keras.layers.LSTM(32)(inputs)\n", + "outputs = keras.layers.Dense(1)(lstm_out)\n", + "\n", + "model = keras.Model(inputs=inputs, outputs=outputs)\n", + "model.compile(optimizer=keras.optimizers.Adam(learning_rate=learning_rate), loss=\"mse\")\n", + "model.summary()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "We'll use the `ModelCheckpoint` callback to regularly save checkpoints, and\n", + "the `EarlyStopping` callback to interrupt training when the validation loss\n", + "is not longer improving." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "path_checkpoint = \"model_checkpoint.weights.h5\"\n", + "es_callback = keras.callbacks.EarlyStopping(monitor=\"val_loss\", min_delta=0, patience=5)\n", + "\n", + "modelckpt_callback = keras.callbacks.ModelCheckpoint(\n", + " monitor=\"val_loss\",\n", + " filepath=path_checkpoint,\n", + " verbose=1,\n", + " save_weights_only=True,\n", + " save_best_only=True,\n", + ")\n", + "\n", + "history = model.fit(\n", + " dataset_train,\n", + " epochs=epochs,\n", + " validation_data=dataset_val,\n", + " callbacks=[es_callback, modelckpt_callback],\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "We can visualize the loss with the function below. After one point, the loss stops\n", + "decreasing." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def visualize_loss(history, title):\n", + " loss = history.history[\"loss\"]\n", + " val_loss = history.history[\"val_loss\"]\n", + " epochs = range(len(loss))\n", + " plt.figure()\n", + " plt.plot(epochs, loss, \"b\", label=\"Training loss\")\n", + " plt.plot(epochs, val_loss, \"r\", label=\"Validation loss\")\n", + " plt.title(title)\n", + " plt.xlabel(\"Epochs\")\n", + " plt.ylabel(\"Loss\")\n", + " plt.legend()\n", + " plt.show()\n", + "\n", + "\n", + "visualize_loss(history, \"Training and Validation Loss\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Prediction\n", + "\n", + "The trained model above is now able to make predictions for 5 sets of values from\n", + "validation set." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "def show_plot(plot_data, delta, title):\n", + " labels = [\"History\", \"True Future\", \"Model Prediction\"]\n", + " marker = [\".-\", \"rx\", \"go\"]\n", + " time_steps = list(range(-(plot_data[0].shape[0]), 0))\n", + " if delta:\n", + " future = delta\n", + " else:\n", + " future = 0\n", + "\n", + " plt.title(title)\n", + " for i, val in enumerate(plot_data):\n", + " if i:\n", + " plt.plot(future, plot_data[i], marker[i], markersize=10, label=labels[i])\n", + " else:\n", + " plt.plot(time_steps, plot_data[i].flatten(), marker[i], label=labels[i])\n", + " plt.legend()\n", + " plt.xlim([time_steps[0], (future + 5) * 2])\n", + " plt.xlabel(\"Time-Step\")\n", + " plt.show()\n", + " return\n", + "\n", + "\n", + "for x, y in dataset_val.take(5):\n", + " show_plot(\n", + " [x[0][:, 1].numpy(), y[0].numpy(), model.predict(x)[0]],\n", + " 12,\n", + " \"Single Step Prediction\",\n", + " )" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "timeseries_weather_forecasting", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/timeseries/timeseries_weather_forecasting.py b/examples/timeseries/timeseries_weather_forecasting.py index 0b30771948..824395cc90 100644 --- a/examples/timeseries/timeseries_weather_forecasting.py +++ b/examples/timeseries/timeseries_weather_forecasting.py @@ -10,7 +10,7 @@ """ ## Setup """ - +import os import pandas as pd import matplotlib.pyplot as plt import keras @@ -50,13 +50,18 @@ 15 |wd (deg) |152.3 |Wind direction in degrees """ + from zipfile import ZipFile uri = "https://storage.googleapis.com/tensorflow/tf-keras-datasets/jena_climate_2009_2016.csv.zip" zip_path = keras.utils.get_file(origin=uri, fname="jena_climate_2009_2016.csv.zip") zip_file = ZipFile(zip_path) -zip_file.extractall() -csv_path = "jena_climate_2009_2016.csv" + +# FIX: Extract to the cache directory, not the current working directory +zip_file.extractall(path=os.path.dirname(zip_path)) + +# FIX: Construct the absolute path safely (works on Windows/Linux/Mac) +csv_path = os.path.join(os.path.dirname(zip_path), "jena_climate_2009_2016.csv") df = pd.read_csv(csv_path) diff --git a/guides/define_custom_kernel.py b/guides/define_custom_kernel.py deleted file mode 100644 index 4aba1c5d3a..0000000000 --- a/guides/define_custom_kernel.py +++ /dev/null @@ -1,400 +0,0 @@ -""" -Title: Define a Custom TPU/GPU Kernel -Author: [jeffcarp](https://www.jeffcarp.com/) -Date created: 2025/12/18 -Last modified: 2025/12/18 -Description: Write high-performance custom Keras layers for TPUs and GPUs. -Accelerator: TPU -""" - -""" -# How to Write a Custom TPU or GPU Kernel in Keras - -Keras has [many pre-made layers to choose from](/api/layers/), and the -ability to easily [create your -own](/guides/making_new_layers_and_models_via_subclassing/) if you can't -find the exact one you need. However, if you have a need for speed, or otherwise -need to customize the exact behavior of your model at the hardware level, you -may want to look into writing a custom kernel. A good way to know if you need a -custom kernel is to look at the profile of your model and see if there are any -idle gaps caused by computation or memory transfer bottlenecks (see the -[TensorBoard callback](/api/callbacks/tensorboard/) for how to get a profile). - -This guide will explore how to write a custom kernel and add it to your -Keras model. We will utilize **Pallas**, a library that lets you write -kernels in Python that can run on both TPU or GPU, where they're lowered -to Mosaic or Triton, respectively. You can learn more in the [Pallas -docs](https://docs.jax.dev/en/latest/pallas/index.html). - -**Compatibility note:** Pallas is only available when using the JAX backend on -certain hardware: - -- TPU v4 and above -- NVIDIA Ampere GPUs (compute capability 8.0) and above - -If you're running in Colab, the v5e-1 in the free tier supports running this -guide. - -First, make sure you're running the latest version of `libtpu`: -""" - -"""shell -pip install --upgrade -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -""" - -from functools import partial -import os -import time - - -os.environ["KERAS_BACKEND"] = "jax" - -import jax -from jax.experimental import pallas as pl -import jax.numpy as jnp -import keras - - -""" -# Simple Example - -Let's start with the example from the [Pallas -quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html): a simple -kernel to add two vectors together. -""" - - -def add_vectors_kernel(x_ref, y_ref, o_ref): - """Pallas kernel for adding two vectors together.""" - x, y = x_ref[...], y_ref[...] - o_ref[...] = x + y - - -""" -Now jit-compile the Pallas function into a function that can be used by JAX. -""" - - -@jax.jit -def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: - return pl.pallas_call( - add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) - - -add_vectors(jnp.arange(8), jnp.arange(8)) - -""" -Now we can embed the jitted `add_vectors` function containing the Pallas kernel into a -Keras layer, just by calling it there. -""" - - -class PallasAddLayer(keras.Layer): - def call(self, x, y): - # Reuse the JIT-compiled Pallas function - return add_vectors(x, y) - - -layer = PallasAddLayer() - -x_data = jnp.arange(8, dtype=jnp.int32) -y_data = jnp.arange(8, dtype=jnp.int32) - -layer(x_data, y_data) - -""" -That's how to integrate a Pallas kernel into a Keras layer! Now for a more -in-depth example. -""" - -""" -# Writing a Fused Linear Activation Layer - -Some common reasons you might want to write a custom kernel is to take advantage of -**fusion** and **tiling**. - -**Operator fusion** is the process of combining two or more ops into one "fused" op, for -example instead of calling `keras.ops.matmul` then `keras.ops.relu` sequentially, we -could write a custom op that combines both into one more efficient operator. -XLA already [does operator fusion when possible](https://arxiv.org/abs/2301.13062) for -certain use cases, but to squeeze even more performance out of the TPU or GPU, we need to -write a custom op to specify the fusion exactly. - -**Tiling** is the ability to control how blocks of memory are loaded from the TPU or -GPU's larger High Bandwidth Memory (HBM) to the smaller, extremely fast on-chip -memory (called VMEM on TPU or SMEM on GPU) that the accelerator's computation -units (e.g., TPU's Matrix Units or a GPU's Tensor Cores) use directly. This is -critical for improving the performance of large matrix multiplications, for -example those in the MLP layer at the end of Transformer blocks. - -In Pallas, tiling is controlled by the `BlockSpec`. Learn more in the -[Pallas BlockSpec guide -here](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#blockspec-a-k-a-how-to-chunk-up-inputs). - -In this section, we'll take two operations that commonly appear together: a -matrix multiplication (like in a `Dense` layer) and a ReLU activation. We will -write a new op that fuses them together for better performance. - -## Original Unoptimized Implementation -""" - - -class StandardDenseReLU(keras.layers.Layer): - """Standard Matmul and ReLU implementation using keras.ops.""" - - def __init__(self, units, **kwargs): - super().__init__(**kwargs) - self.units = units - - def build(self, input_shape): - self.w = self.add_weight( - shape=(input_shape[-1], self.units), - initializer="glorot_uniform", - trainable=True, - ) - - def call(self, inputs): - # The standard implementation performs two separate operations. - # Each one involves expensive data transfer with the main device memory (HBM). - # 1. Matmul: inputs (HBM) -> compute -> intermediate (HBM) - y = keras.ops.matmul(inputs, self.w) - # 2. ReLU: intermediate (HBM) -> compute -> output (HBM) - return keras.ops.relu(y) - - -""" -## 1. Define the Fused Kernel - -First we create an inner kernel function that defines the fused computation that -combines both matmul (`pl.dot`) and activation (`jnp.maximum`). -""" - -import jax.numpy as jnp -from jax.experimental import pallas as pl - - -def matmul_relu_kernel(a_ref, b_ref, c_ref): - """Pallas kernel for fused matmul + ReLU.""" - # Perform the matrix multiplication on the local tile - # pl.dot leverages the hardware's Matrix Unit (MXU) - acc = pl.dot(a_ref[...], b_ref[...]) - - # Fusion happens here: apply activation while data is in VMEM - result = jnp.maximum(acc, 0) - - # Write the final result to the output reference - c_ref[...] = result - - -""" -## 2. Specify the Tiling (BlockSpec) - -Since the input matrices are usually too large to fit into VMEM, Pallas needs ot -know how to "slice" them for loading from HBM to VMEM. - -We define this using `BlockSpec` - this tells the hardware: "Take a 128-row -chunk of Matrix A and a 128-column chunk of Matrix B to produce a 128x128 tile -of Matrix C." -""" - - -@jax.jit -def fused_matmul(a, b): - m, k = a.shape - _, n = b.shape - - # Define tile sizes - tile_m, tile_n = 128, 128 - assert ( - m % tile_m == 0 and n % tile_n == 0 - ), "Inputs must be multiples of 128 for this demo" - - return pl.pallas_call( - matmul_relu_kernel, - # Map output indices to input blocks - out_shape=jax.ShapeDtypeStruct((m, n), a.dtype), - in_specs=[ - # For each output tile, we take a slice of A of shape (tile_m, k) - pl.BlockSpec( - index_map=lambda i, j: (i, 0), block_shape=(tile_m, k) - ), # Matrix A - # For each output tile, we take a slice of B of shape (k, tile_n) - pl.BlockSpec( - index_map=lambda i, j: (0, j), block_shape=(k, tile_n) - ), # Matrix B - ], - out_specs=pl.BlockSpec( - index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n) - ), # Matrix C - grid=(m // tile_m, n // tile_n), - )(a, b) - - -fused_matmul(jnp.ones((256, 256)), jnp.ones((256, 256))) - -""" -## 3. Integrating into a Keras Layer - -Now for the final step, call the jit-compiled `fused_matmul` kernel from a -`keras.Layer`. -""" - - -class FusedDense(keras.layers.Layer): - """Custom Keras layer that applies the fused Dense and ReLU op.""" - - def __init__(self, units, **kwargs): - super().__init__(**kwargs) - self.units = units - - def build(self, input_shape): - self.w = self.add_weight( - shape=(input_shape[-1], self.units), initializer="glorot_uniform" - ) - - def call(self, inputs): - # Dispatch to our Pallas kernel - return fused_matmul(inputs, self.w.value) - - -FusedDense(256)(jnp.ones((256, 256))) - -""" -## 4. Benchmarking the Speedup -""" - -# 1. Setup Data -N = 8192 # Large enough to be memory bound -input_data = jnp.ones((N, N), dtype="float32") - -# Initialize layers -standard_layer = StandardDenseReLU(units=N) -pallas_layer = FusedDense(units=N) - -# Build layers by calling them once -standard_layer(input_data) -pallas_layer(input_data) - - -def benchmark(layer, x, name, iterations=100): - # Warm up to ensure JIT compilation is finished - for _ in range(10): - layer(x).block_until_ready() - - start_time = time.perf_counter() - for _ in range(iterations): - layer(x).block_until_ready() - end_time = time.perf_counter() - - avg_time = (end_time - start_time) / iterations * 1000 # convert to ms - print(f"{name} Average Latency: {avg_time:.3f} ms") - - -# 2. Run Comparison -print(f"Benchmarking Matrix Size: {N}x{N}\n" + "-" * 30) -benchmark(standard_layer, input_data, "Standard Keras (Matmul + ReLU)") -benchmark(pallas_layer, input_data, "Pallas Fused (Matmul + ReLU)") - - -""" -### Why this Works - -**Memory Bandwidth Efficiency:** By fusing the matrix multiplication and -activation, we perform the ReLU computation while data is still in the chip's -fast VMEM. This drastically reduces expensive read/write roundtrips to HBM. - -**Automatic Parallelization:** Pallas handles the "grid" execution, meaning -it automatically parallelizes your defined tiles across the available hardware -cores (whether TPU MXUs or GPU Tensor Cores). - -**Drop-in Inference Speed:** This `FusedDense` kernel can be integrated into any -Keras model, giving an example of improving serving/inference performance with -minimal code changes. -""" - -""" -## 5. Enabling Training - -In order for a Pallas kernel to be trainable, you must also supply -a second kernel to define the custom backward pass, since JAX can't -[AutoGrad](https://docs.jax.dev/en/latest/automatic-differentiation.html) -through Pallas kernels. Without it, you might see an error like this: - -``` -model = keras.Sequential([FusedDense(256)]) -model.compile(optimizer="adam", loss="mse") -model.fit(jnp.ones((256, 256)), jnp.ones((256, 256))) ->>> Linearization failed to produce known values for all output primals. This is -typically caused by attempting to differentiate a function uses an operation -that does not support reverse-mode autodiff. -``` - -To extend our fused matmul example above: -""" - - -# 1. Define the wrapper with `custom_vjp` using our original `fused_matmul`. -@jax.custom_vjp -def fused_matmul_trainable(x, w): - return fused_matmul(x, w) - - -# 2. Define the Forward Pass -# It must return the output AND "residuals" (data needed for the backward pass) -def fused_matmul_fwd(x, w): - y = fused_matmul_trainable(x, w) - # We save inputs x, w and output y for the backward calculation - return y, (x, w, y) - - -# 3. Define the Backward Pass -# JAX gives us the residuals and the incoming gradient (g) -def fused_matmul_bwd(residuals, g): - x, w, y = residuals - - # Calculate the gradient of ReLU: 1 if y > 0, else 0 - # g is the gradient flowing back from the next layer - grad_relu = g * (y > 0) - - # Standard backprop math for matmul: - # grad_x = grad_relu @ w.T - grad_x = jnp.dot(grad_relu, w.T) - - # grad_w = x.T @ grad_relu - grad_w = jnp.dot(x.T, grad_relu) - - return grad_x, grad_w - - -# 4. Register the forward and backward functions -fused_matmul_trainable.defvjp(fused_matmul_fwd, fused_matmul_bwd) - - -class FusedDenseTrainable(FusedDense): - """Updated layer that contains Pallas forward and backward pass.""" - - def call(self, inputs): - # Dispatch to our trainable Pallas kernel - return fused_matmul_trainable(inputs, self.w.value) - - -# Demonstrate trainability on dummy data -model = keras.Sequential([FusedDenseTrainable(256)]) -model.compile(optimizer="adam", loss="mse") -model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)), batch_size=128) - -""" -# Followups - -In this guide we covered how to define a simple custom Pallas kernel performing vector -addition to include in a Keras model. Then we followed up with a more in-depth -example of a fused matmul + activation kernel that you might use in a real-world -model to improve performance. - -Please refer to the [Pallas -docs](https://docs.jax.dev/en/latest/pallas/index.html#) for further -documentation on writing custom kernels. Additionally to explore more examples -of Pallas kernels, including FlashAttention and MoE layers, check out the -[Tokamax](https://github.com/openxla/tokamax) library. -""" diff --git a/guides/ipynb/define_custom_kernel.ipynb b/guides/ipynb/define_custom_kernel.ipynb deleted file mode 100644 index 68d9bc9f84..0000000000 --- a/guides/ipynb/define_custom_kernel.ipynb +++ /dev/null @@ -1,601 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "# Define a Custom TPU/GPU Kernel\n", - "\n", - "**Author:** [jeffcarp](https://www.jeffcarp.com/)
\n", - "**Date created:** 2025/12/18
\n", - "**Last modified:** 2025/12/18
\n", - "**Description:** Write high-performance custom Keras layers for TPUs and GPUs." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "# How to Write a Custom TPU or GPU Kernel in Keras\n", - "\n", - "Keras has [many pre-made layers to choose from](/api/layers/), and the\n", - "ability to easily [create your\n", - "own](/guides/making_new_layers_and_models_via_subclassing/) if you can't\n", - "find the exact one you need. However, if you have a need for speed, or otherwise\n", - "need to customize the exact behavior of your model at the hardware level, you\n", - "may want to look into writing a custom kernel. A good way to know if you need a\n", - "custom kernel is to look at the profile of your model and see if there are any\n", - "idle gaps caused by computation or memory transfer bottlenecks (see the\n", - "[TensorBoard callback](/api/callbacks/tensorboard/) for how to get a profile).\n", - "\n", - "This guide will explore how to write a custom kernel and add it to your\n", - "Keras model. We will utilize **Pallas**, a library that lets you write\n", - "kernels in Python that can run on both TPU or GPU, where they're lowered\n", - "to Mosaic or Triton, respectively. You can learn more in the [Pallas\n", - "docs](https://docs.jax.dev/en/latest/pallas/index.html).\n", - "\n", - "**Compatibility note:** Pallas is only available when using the JAX backend on\n", - "certain hardware:\n", - "\n", - "- TPU v4 and above\n", - "- NVIDIA Ampere GPUs (compute capability 8.0) and above\n", - "\n", - "If you're running in Colab, the v5e-1 in the free tier supports running this\n", - "guide.\n", - "\n", - "First, make sure you're running the latest version of `libtpu`:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "!pip install --upgrade -q \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "from functools import partial\n", - "import os\n", - "import time\n", - "\n", - "\n", - "os.environ[\"KERAS_BACKEND\"] = \"jax\"\n", - "\n", - "import jax\n", - "from jax.experimental import pallas as pl\n", - "import jax.numpy as jnp\n", - "import keras\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "# Simple Example\n", - "\n", - "Let's start with the example from the [Pallas\n", - "quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html): a simple\n", - "kernel to add two vectors together." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "def add_vectors_kernel(x_ref, y_ref, o_ref):\n", - " \"\"\"Pallas kernel for adding two vectors together.\"\"\"\n", - " x, y = x_ref[...], y_ref[...]\n", - " o_ref[...] = x + y\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "Now jit-compile the Pallas function into a function that can be used by JAX." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "@jax.jit\n", - "def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:\n", - " return pl.pallas_call(\n", - " add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype)\n", - " )(x, y)\n", - "\n", - "\n", - "add_vectors(jnp.arange(8), jnp.arange(8))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "Now we can embed the jitted `add_vectors` function containing the Pallas kernel into a\n", - "Keras layer, just by calling it there." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "class PallasAddLayer(keras.Layer):\n", - " def call(self, x, y):\n", - " # Reuse the JIT-compiled Pallas function\n", - " return add_vectors(x, y)\n", - "\n", - "\n", - "layer = PallasAddLayer()\n", - "\n", - "x_data = jnp.arange(8, dtype=jnp.int32)\n", - "y_data = jnp.arange(8, dtype=jnp.int32)\n", - "\n", - "layer(x_data, y_data)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "That's how to integrate a Pallas kernel into a Keras layer! Now for a more\n", - "in-depth example." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "# Writing a Fused Linear Activation Layer\n", - "\n", - "Some common reasons you might want to write a custom kernel is to take advantage of\n", - "**fusion** and **tiling**.\n", - "\n", - "**Operator fusion** is the process of combining two or more ops into one \"fused\" op, for\n", - "example instead of calling `keras.ops.matmul` then `keras.ops.relu` sequentially, we\n", - "could write a custom op that combines both into one more efficient operator.\n", - "XLA already [does operator fusion when possible](https://arxiv.org/abs/2301.13062) for\n", - "certain use cases, but to squeeze even more performance out of the TPU or GPU, we need to\n", - "write a custom op to specify the fusion exactly.\n", - "\n", - "**Tiling** is the ability to control how blocks of memory are loaded from the TPU or\n", - "GPU's larger High Bandwidth Memory (HBM) to the smaller, extremely fast on-chip\n", - "memory (called VMEM on TPU or SMEM on GPU) that the accelerator's computation\n", - "units (e.g., TPU's Matrix Units or a GPU's Tensor Cores) use directly. This is\n", - "critical for improving the performance of large matrix multiplications, for\n", - "example those in the MLP layer at the end of Transformer blocks.\n", - "\n", - "In Pallas, tiling is controlled by the `BlockSpec`. Learn more in the\n", - "[Pallas BlockSpec guide\n", - "here](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#blockspec-a-k-a-how-to-chunk-up-inputs).\n", - "\n", - "In this section, we'll take two operations that commonly appear together: a\n", - "matrix multiplication (like in a `Dense` layer) and a ReLU activation. We will\n", - "write a new op that fuses them together for better performance.\n", - "\n", - "## Original Unoptimized Implementation" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "class StandardDenseReLU(keras.layers.Layer):\n", - " \"\"\"Standard Matmul and ReLU implementation using keras.ops.\"\"\"\n", - "\n", - " def __init__(self, units, **kwargs):\n", - " super().__init__(**kwargs)\n", - " self.units = units\n", - "\n", - " def build(self, input_shape):\n", - " self.w = self.add_weight(\n", - " shape=(input_shape[-1], self.units),\n", - " initializer=\"glorot_uniform\",\n", - " trainable=True,\n", - " )\n", - "\n", - " def call(self, inputs):\n", - " # The standard implementation performs two separate operations.\n", - " # Each one involves expensive data transfer with the main device memory (HBM).\n", - " # 1. Matmul: inputs (HBM) -> compute -> intermediate (HBM)\n", - " y = keras.ops.matmul(inputs, self.w)\n", - " # 2. ReLU: intermediate (HBM) -> compute -> output (HBM)\n", - " return keras.ops.relu(y)\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## 1. Define the Fused Kernel\n", - "\n", - "First we create an inner kernel function that defines the fused computation that\n", - "combines both matmul (`pl.dot`) and activation (`jnp.maximum`)." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "import jax.numpy as jnp\n", - "from jax.experimental import pallas as pl\n", - "\n", - "\n", - "def matmul_relu_kernel(a_ref, b_ref, c_ref):\n", - " \"\"\"Pallas kernel for fused matmul + ReLU.\"\"\"\n", - " # Perform the matrix multiplication on the local tile\n", - " # pl.dot leverages the hardware's Matrix Unit (MXU)\n", - " acc = pl.dot(a_ref[...], b_ref[...])\n", - "\n", - " # Fusion happens here: apply activation while data is in VMEM\n", - " result = jnp.maximum(acc, 0)\n", - "\n", - " # Write the final result to the output reference\n", - " c_ref[...] = result\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## 2. Specify the Tiling (BlockSpec)\n", - "\n", - "Since the input matrices are usually too large to fit into VMEM, Pallas needs ot\n", - "know how to \"slice\" them for loading from HBM to VMEM.\n", - "\n", - "We define this using `BlockSpec` - this tells the hardware: \"Take a 128-row\n", - "chunk of Matrix A and a 128-column chunk of Matrix B to produce a 128x128 tile\n", - "of Matrix C.\"" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "@jax.jit\n", - "def fused_matmul(a, b):\n", - " m, k = a.shape\n", - " _, n = b.shape\n", - "\n", - " # Define tile sizes\n", - " tile_m, tile_n = 128, 128\n", - " assert (\n", - " m % tile_m == 0 and n % tile_n == 0\n", - " ), \"Inputs must be multiples of 128 for this demo\"\n", - "\n", - " return pl.pallas_call(\n", - " matmul_relu_kernel,\n", - " # Map output indices to input blocks\n", - " out_shape=jax.ShapeDtypeStruct((m, n), a.dtype),\n", - " in_specs=[\n", - " # For each output tile, we take a slice of A of shape (tile_m, k)\n", - " pl.BlockSpec(\n", - " index_map=lambda i, j: (i, 0), block_shape=(tile_m, k)\n", - " ), # Matrix A\n", - " # For each output tile, we take a slice of B of shape (k, tile_n)\n", - " pl.BlockSpec(\n", - " index_map=lambda i, j: (0, j), block_shape=(k, tile_n)\n", - " ), # Matrix B\n", - " ],\n", - " out_specs=pl.BlockSpec(\n", - " index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n)\n", - " ), # Matrix C\n", - " grid=(m // tile_m, n // tile_n),\n", - " )(a, b)\n", - "\n", - "\n", - "fused_matmul(jnp.ones((256, 256)), jnp.ones((256, 256)))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## 3. Integrating into a Keras Layer\n", - "\n", - "Now for the final step, call the jit-compiled `fused_matmul` kernel from a\n", - "`keras.Layer`." - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "class FusedDense(keras.layers.Layer):\n", - " \"\"\"Custom Keras layer that applies the fused Dense and ReLU op.\"\"\"\n", - "\n", - " def __init__(self, units, **kwargs):\n", - " super().__init__(**kwargs)\n", - " self.units = units\n", - "\n", - " def build(self, input_shape):\n", - " self.w = self.add_weight(\n", - " shape=(input_shape[-1], self.units), initializer=\"glorot_uniform\"\n", - " )\n", - "\n", - " def call(self, inputs):\n", - " # Dispatch to our Pallas kernel\n", - " return fused_matmul(inputs, self.w.value)\n", - "\n", - "\n", - "FusedDense(256)(jnp.ones((256, 256)))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## 4. Benchmarking the Speedup" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "# 1. Setup Data\n", - "N = 8192 # Large enough to be memory bound\n", - "input_data = jnp.ones((N, N), dtype=\"float32\")\n", - "\n", - "# Initialize layers\n", - "standard_layer = StandardDenseReLU(units=N)\n", - "pallas_layer = FusedDense(units=N)\n", - "\n", - "# Build layers by calling them once\n", - "standard_layer(input_data)\n", - "pallas_layer(input_data)\n", - "\n", - "\n", - "def benchmark(layer, x, name, iterations=100):\n", - " # Warm up to ensure JIT compilation is finished\n", - " for _ in range(10):\n", - " layer(x).block_until_ready()\n", - "\n", - " start_time = time.perf_counter()\n", - " for _ in range(iterations):\n", - " layer(x).block_until_ready()\n", - " end_time = time.perf_counter()\n", - "\n", - " avg_time = (end_time - start_time) / iterations * 1000 # convert to ms\n", - " print(f\"{name} Average Latency: {avg_time:.3f} ms\")\n", - "\n", - "\n", - "# 2. Run Comparison\n", - "print(f\"Benchmarking Matrix Size: {N}x{N}\\n\" + \"-\" * 30)\n", - "benchmark(standard_layer, input_data, \"Standard Keras (Matmul + ReLU)\")\n", - "benchmark(pallas_layer, input_data, \"Pallas Fused (Matmul + ReLU)\")\n", - "" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "### Why this Works\n", - "\n", - "**Memory Bandwidth Efficiency:** By fusing the matrix multiplication and\n", - "activation, we perform the ReLU computation while data is still in the chip's\n", - "fast VMEM. This drastically reduces expensive read/write roundtrips to HBM.\n", - "\n", - "**Automatic Parallelization:** Pallas handles the \"grid\" execution, meaning\n", - "it automatically parallelizes your defined tiles across the available hardware\n", - "cores (whether TPU MXUs or GPU Tensor Cores).\n", - "\n", - "**Drop-in Inference Speed:** This `FusedDense` kernel can be integrated into any\n", - "Keras model, giving an example of improving serving/inference performance with\n", - "minimal code changes." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "## 5. Enabling Training\n", - "\n", - "In order for a Pallas kernel to be trainable, you must also supply\n", - "a second kernel to define the custom backward pass, since JAX can't\n", - "[AutoGrad](https://docs.jax.dev/en/latest/automatic-differentiation.html)\n", - "through Pallas kernels. Without it, you might see an error like this:\n", - "\n", - "```\n", - "model = keras.Sequential([FusedDense(256)])\n", - "model.compile(optimizer=\"adam\", loss=\"mse\")\n", - "model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)))\n", - ">>> Linearization failed to produce known values for all output primals. This is\n", - "typically caused by attempting to differentiate a function uses an operation\n", - "that does not support reverse-mode autodiff.\n", - "```\n", - "\n", - "To extend our fused matmul example above:" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab_type": "code" - }, - "outputs": [], - "source": [ - "\n", - "# 1. Define the wrapper with `custom_vjp` using our original `fused_matmul`.\n", - "@jax.custom_vjp\n", - "def fused_matmul_trainable(x, w):\n", - " return fused_matmul(x, w)\n", - "\n", - "\n", - "# 2. Define the Forward Pass\n", - "# It must return the output AND \"residuals\" (data needed for the backward pass)\n", - "def fused_matmul_fwd(x, w):\n", - " y = fused_matmul_trainable(x, w)\n", - " # We save inputs x, w and output y for the backward calculation\n", - " return y, (x, w, y)\n", - "\n", - "\n", - "# 3. Define the Backward Pass\n", - "# JAX gives us the residuals and the incoming gradient (g)\n", - "def fused_matmul_bwd(residuals, g):\n", - " x, w, y = residuals\n", - "\n", - " # Calculate the gradient of ReLU: 1 if y > 0, else 0\n", - " # g is the gradient flowing back from the next layer\n", - " grad_relu = g * (y > 0)\n", - "\n", - " # Standard backprop math for matmul:\n", - " # grad_x = grad_relu @ w.T\n", - " grad_x = jnp.dot(grad_relu, w.T)\n", - "\n", - " # grad_w = x.T @ grad_relu\n", - " grad_w = jnp.dot(x.T, grad_relu)\n", - "\n", - " return grad_x, grad_w\n", - "\n", - "\n", - "# 4. Register the forward and backward functions\n", - "fused_matmul_trainable.defvjp(fused_matmul_fwd, fused_matmul_bwd)\n", - "\n", - "\n", - "class FusedDenseTrainable(FusedDense):\n", - " \"\"\"Updated layer that contains Pallas forward and backward pass.\"\"\"\n", - "\n", - " def call(self, inputs):\n", - " # Dispatch to our trainable Pallas kernel\n", - " return fused_matmul_trainable(inputs, self.w.value)\n", - "\n", - "\n", - "# Demonstrate trainability on dummy data\n", - "model = keras.Sequential([FusedDenseTrainable(256)])\n", - "model.compile(optimizer=\"adam\", loss=\"mse\")\n", - "model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)), batch_size=128)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "colab_type": "text" - }, - "source": [ - "# Followups\n", - "\n", - "In this guide we covered how to define a simple custom Pallas kernel performing vector\n", - "addition to include in a Keras model. Then we followed up with a more in-depth\n", - "example of a fused matmul + activation kernel that you might use in a real-world\n", - "model to improve performance.\n", - "\n", - "Please refer to the [Pallas\n", - "docs](https://docs.jax.dev/en/latest/pallas/index.html#) for further\n", - "documentation on writing custom kernels. Additionally to explore more examples\n", - "of Pallas kernels, including FlashAttention and MoE layers, check out the\n", - "[Tokamax](https://github.com/openxla/tokamax) library." - ] - } - ], - "metadata": { - "accelerator": "TPU", - "colab": { - "collapsed_sections": [], - "name": "define_custom_kernel", - "private_outputs": false, - "provenance": [], - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.7.0" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} \ No newline at end of file diff --git a/guides/md/define_custom_kernel.md b/guides/md/define_custom_kernel.md deleted file mode 100644 index 4fdf4af23c..0000000000 --- a/guides/md/define_custom_kernel.md +++ /dev/null @@ -1,496 +0,0 @@ -# Define a Custom TPU/GPU Kernel - -**Author:** [jeffcarp](https://www.jeffcarp.com/)
-**Date created:** 2025/12/18
-**Last modified:** 2025/12/18
-**Description:** Write high-performance custom Keras layers for TPUs and GPUs. - - - [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/guides/ipynb/define_custom_kernel.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/guides/define_custom_kernel.py) - - - -# How to Write a Custom TPU or GPU Kernel in Keras - -Keras has [many pre-made layers to choose from](/api/layers/), and the -ability to easily [create your -own](/guides/making_new_layers_and_models_via_subclassing/) if you can't -find the exact one you need. However, if you have a need for speed, or otherwise -need to customize the exact behavior of your model at the hardware level, you -may want to look into writing a custom kernel. A good way to know if you need a -custom kernel is to look at the profile of your model and see if there are any -idle gaps caused by computation or memory transfer bottlenecks (see the -[TensorBoard callback](/api/callbacks/tensorboard/) for how to get a profile). - -This guide will explore how to write a custom kernel and add it to your -Keras model. We will utilize **Pallas**, a library that lets you write -kernels in Python that can run on both TPU or GPU, where they're lowered -to Mosaic or Triton, respectively. You can learn more in the [Pallas -docs](https://docs.jax.dev/en/latest/pallas/index.html). - -**Compatibility note:** Pallas is only available when using the JAX backend on -certain hardware: - -- TPU v4 and above -- NVIDIA Ampere GPUs (compute capability 8.0) and above - -If you're running in Colab, the v5e-1 in the free tier supports running this -guide. - -First, make sure you're running the latest version of `libtpu`: - - -```python -!pip install --upgrade -q "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html -``` - - -```python -from functools import partial -import os -import time - - -os.environ["KERAS_BACKEND"] = "jax" - -import jax -from jax.experimental import pallas as pl -import jax.numpy as jnp -import keras - -``` -
-``` -[notice] To update, run: pip install --upgrade pip - -/home/jeffcarp/venv/lib/python3.10/site-packages/jax/_src/cloud_tpu_init.py:84: UserWarning: Transparent hugepages are not enabled. TPU runtime startup and shutdown time should be significantly improved on TPU v5e and newer. If not already set, you may need to enable transparent hugepages in your VM image (sudo sh -c "echo always > /sys/kernel/mm/transparent_hugepage/enabled") - warnings.warn( -``` -
- -# Simple Example - -Let's start with the example from the [Pallas -quickstart](https://docs.jax.dev/en/latest/pallas/quickstart.html): a simple -kernel to add two vectors together. - - -```python - -def add_vectors_kernel(x_ref, y_ref, o_ref): - """Pallas kernel for adding two vectors together.""" - x, y = x_ref[...], y_ref[...] - o_ref[...] = x + y - -``` - -Now jit-compile the Pallas function into a function that can be used by JAX. - - -```python - -@jax.jit -def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array: - return pl.pallas_call( - add_vectors_kernel, out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype) - )(x, y) - - -add_vectors(jnp.arange(8), jnp.arange(8)) -``` - - - - -
-``` -Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32) -``` -
- -Now we can embed the jitted `add_vectors` function containing the Pallas kernel into a -Keras layer, just by calling it there. - - -```python - -class PallasAddLayer(keras.Layer): - def call(self, x, y): - # Reuse the JIT-compiled Pallas function - return add_vectors(x, y) - - -layer = PallasAddLayer() - -x_data = jnp.arange(8, dtype=jnp.int32) -y_data = jnp.arange(8, dtype=jnp.int32) - -layer(x_data, y_data) -``` - - - - -
-``` -Array([ 0, 2, 4, 6, 8, 10, 12, 14], dtype=int32) -``` -
- -That's how to integrate a Pallas kernel into a Keras layer! Now for a more -in-depth example. - -# Writing a Fused Linear Activation Layer - -Some common reasons you might want to write a custom kernel is to take advantage of -**fusion** and **tiling**. - -**Operator fusion** is the process of combining two or more ops into one "fused" op, for -example instead of calling `keras.ops.matmul` then `keras.ops.relu` sequentially, we -could write a custom op that combines both into one more efficient operator. -XLA already [does operator fusion when possible](https://arxiv.org/abs/2301.13062) for -certain use cases, but to squeeze even more performance out of the TPU or GPU, we need to -write a custom op to specify the fusion exactly. - -**Tiling** is the ability to control how blocks of memory are loaded from the TPU or -GPU's larger High Bandwidth Memory (HBM) to the smaller, extremely fast on-chip -memory (called VMEM on TPU or SMEM on GPU) that the accelerator's computation -units (e.g., TPU's Matrix Units or a GPU's Tensor Cores) use directly. This is -critical for improving the performance of large matrix multiplications, for -example those in the MLP layer at the end of Transformer blocks. - -In Pallas, tiling is controlled by the `BlockSpec`. Learn more in the -[Pallas BlockSpec guide -here](https://docs.jax.dev/en/latest/pallas/grid_blockspec.html#blockspec-a-k-a-how-to-chunk-up-inputs). - -In this section, we'll take two operations that commonly appear together: a -matrix multiplication (like in a `Dense` layer) and a ReLU activation. We will -write a new op that fuses them together for better performance. - ---- -## Original Unoptimized Implementation - - -```python - -class StandardDenseReLU(keras.layers.Layer): - """Standard Matmul and ReLU implementation using keras.ops.""" - - def __init__(self, units, **kwargs): - super().__init__(**kwargs) - self.units = units - - def build(self, input_shape): - self.w = self.add_weight( - shape=(input_shape[-1], self.units), - initializer="glorot_uniform", - trainable=True, - ) - - def call(self, inputs): - # The standard implementation performs two separate operations. - # Each one involves expensive data transfer with the main device memory (HBM). - # 1. Matmul: inputs (HBM) -> compute -> intermediate (HBM) - y = keras.ops.matmul(inputs, self.w) - # 2. ReLU: intermediate (HBM) -> compute -> output (HBM) - return keras.ops.relu(y) - -``` - ---- -## 1. Define the Fused Kernel - -First we create an inner kernel function that defines the fused computation that -combines both matmul (`pl.dot`) and activation (`jnp.maximum`). - - -```python -import jax.numpy as jnp -from jax.experimental import pallas as pl - - -def matmul_relu_kernel(a_ref, b_ref, c_ref): - """Pallas kernel for fused matmul + ReLU.""" - # Perform the matrix multiplication on the local tile - # pl.dot leverages the hardware's Matrix Unit (MXU) - acc = pl.dot(a_ref[...], b_ref[...]) - - # Fusion happens here: apply activation while data is in VMEM - result = jnp.maximum(acc, 0) - - # Write the final result to the output reference - c_ref[...] = result - -``` - ---- -## 2. Specify the Tiling (BlockSpec) - -Since the input matrices are usually too large to fit into VMEM, Pallas needs ot -know how to "slice" them for loading from HBM to VMEM. - -We define this using `BlockSpec` - this tells the hardware: "Take a 128-row -chunk of Matrix A and a 128-column chunk of Matrix B to produce a 128x128 tile -of Matrix C." - - -```python - -@jax.jit -def fused_matmul(a, b): - m, k = a.shape - _, n = b.shape - - # Define tile sizes - tile_m, tile_n = 128, 128 - assert ( - m % tile_m == 0 and n % tile_n == 0 - ), "Inputs must be multiples of 128 for this demo" - - return pl.pallas_call( - matmul_relu_kernel, - # Map output indices to input blocks - out_shape=jax.ShapeDtypeStruct((m, n), a.dtype), - in_specs=[ - # For each output tile, we take a slice of A of shape (tile_m, k) - pl.BlockSpec( - index_map=lambda i, j: (i, 0), block_shape=(tile_m, k) - ), # Matrix A - # For each output tile, we take a slice of B of shape (k, tile_n) - pl.BlockSpec( - index_map=lambda i, j: (0, j), block_shape=(k, tile_n) - ), # Matrix B - ], - out_specs=pl.BlockSpec( - index_map=lambda i, j: (i, j), block_shape=(tile_m, tile_n) - ), # Matrix C - grid=(m // tile_m, n // tile_n), - )(a, b) - - -fused_matmul(jnp.ones((256, 256)), jnp.ones((256, 256))) -``` - - - - -
-``` -Array([[256., 256., 256., ..., 256., 256., 256.], - [256., 256., 256., ..., 256., 256., 256.], - [256., 256., 256., ..., 256., 256., 256.], - ..., - [256., 256., 256., ..., 256., 256., 256.], - [256., 256., 256., ..., 256., 256., 256.], - [256., 256., 256., ..., 256., 256., 256.]], dtype=float32) -``` -
- ---- -## 3. Integrating into a Keras Layer - -Now for the final step, call the jit-compiled `fused_matmul` kernel from a -`keras.Layer`. - - -```python - -class FusedDense(keras.layers.Layer): - """Custom Keras layer that applies the fused Dense and ReLU op.""" - - def __init__(self, units, **kwargs): - super().__init__(**kwargs) - self.units = units - - def build(self, input_shape): - self.w = self.add_weight( - shape=(input_shape[-1], self.units), initializer="glorot_uniform" - ) - - def call(self, inputs): - # Dispatch to our Pallas kernel - return fused_matmul(inputs, self.w.value) - - -FusedDense(256)(jnp.ones((256, 256))) -``` - - - - -
-``` -Array([[0. , 0.511034 , 0.19506836, ..., 0.29304314, 0. , - 0.04899597], - [0. , 0.511034 , 0.19506836, ..., 0.29304314, 0. , - 0.04899597], - [0. , 0.511034 , 0.19506836, ..., 0.29304314, 0. , - 0.04899597], - ..., - [0. , 0.511034 , 0.19506836, ..., 0.29304314, 0. , - 0.04899597], - [0. , 0.511034 , 0.19506836, ..., 0.29304314, 0. , - 0.04899597], - [0. , 0.511034 , 0.19506836, ..., 0.29304314, 0. , - 0.04899597]], dtype=float32) -``` -
- ---- -## 4. Benchmarking the Speedup - - -```python -# 1. Setup Data -N = 8192 # Large enough to be memory bound -input_data = jnp.ones((N, N), dtype="float32") - -# Initialize layers -standard_layer = StandardDenseReLU(units=N) -pallas_layer = FusedDense(units=N) - -# Build layers by calling them once -standard_layer(input_data) -pallas_layer(input_data) - - -def benchmark(layer, x, name, iterations=100): - # Warm up to ensure JIT compilation is finished - for _ in range(10): - layer(x).block_until_ready() - - start_time = time.perf_counter() - for _ in range(iterations): - layer(x).block_until_ready() - end_time = time.perf_counter() - - avg_time = (end_time - start_time) / iterations * 1000 # convert to ms - print(f"{name} Average Latency: {avg_time:.3f} ms") - - -# 2. Run Comparison -print(f"Benchmarking Matrix Size: {N}x{N}\n" + "-" * 30) -benchmark(standard_layer, input_data, "Standard Keras (Matmul + ReLU)") -benchmark(pallas_layer, input_data, "Pallas Fused (Matmul + ReLU)") - -``` - -
-``` -Benchmarking Matrix Size: 8192x8192 ------------------------------- - -Standard Keras (Matmul + ReLU) Average Latency: 7.811 ms - -Pallas Fused (Matmul + ReLU) Average Latency: 35.039 ms -``` -
- -### Why this Works - -**Memory Bandwidth Efficiency:** By fusing the matrix multiplication and -activation, we perform the ReLU computation while data is still in the chip's -fast VMEM. This drastically reduces expensive read/write roundtrips to HBM. - -**Automatic Parallelization:** Pallas handles the "grid" execution, meaning -it automatically parallelizes your defined tiles across the available hardware -cores (whether TPU MXUs or GPU Tensor Cores). - -**Drop-in Inference Speed:** This `FusedDense` kernel can be integrated into any -Keras model, giving an example of improving serving/inference performance with -minimal code changes. - ---- -## 5. Enabling Training - -In order for a Pallas kernel to be trainable, you must also supply -a second kernel to define the custom backward pass, since JAX can't -[AutoGrad](https://docs.jax.dev/en/latest/automatic-differentiation.html) -through Pallas kernels. Without it, you might see an error like this: - -``` -model = keras.Sequential([FusedDense(256)]) -model.compile(optimizer="adam", loss="mse") -model.fit(jnp.ones((256, 256)), jnp.ones((256, 256))) ->>> Linearization failed to produce known values for all output primals. This is -typically caused by attempting to differentiate a function uses an operation -that does not support reverse-mode autodiff. -``` - -To extend our fused matmul example above: - - -```python - -# 1. Define the wrapper with `custom_vjp` using our original `fused_matmul`. -@jax.custom_vjp -def fused_matmul_trainable(x, w): - return fused_matmul(x, w) - - -# 2. Define the Forward Pass -# It must return the output AND "residuals" (data needed for the backward pass) -def fused_matmul_fwd(x, w): - y = fused_matmul_trainable(x, w) - # We save inputs x, w and output y for the backward calculation - return y, (x, w, y) - - -# 3. Define the Backward Pass -# JAX gives us the residuals and the incoming gradient (g) -def fused_matmul_bwd(residuals, g): - x, w, y = residuals - - # Calculate the gradient of ReLU: 1 if y > 0, else 0 - # g is the gradient flowing back from the next layer - grad_relu = g * (y > 0) - - # Standard backprop math for matmul: - # grad_x = grad_relu @ w.T - grad_x = jnp.dot(grad_relu, w.T) - - # grad_w = x.T @ grad_relu - grad_w = jnp.dot(x.T, grad_relu) - - return grad_x, grad_w - - -# 4. Register the forward and backward functions -fused_matmul_trainable.defvjp(fused_matmul_fwd, fused_matmul_bwd) - - -class FusedDenseTrainable(FusedDense): - """Updated layer that contains Pallas forward and backward pass.""" - - def call(self, inputs): - # Dispatch to our trainable Pallas kernel - return fused_matmul_trainable(inputs, self.w.value) - - -# Demonstrate trainability on dummy data -model = keras.Sequential([FusedDenseTrainable(256)]) -model.compile(optimizer="adam", loss="mse") -model.fit(jnp.ones((256, 256)), jnp.ones((256, 256)), batch_size=128) -``` - - -
-``` -2/2 ━━━━━━━━━━━━━━━━━━━━ 0s 65ms/step - loss: 0.6481 - - -``` -
- -# Followups - -In this guide we covered how to define a simple custom Pallas kernel performing vector -addition to include in a Keras model. Then we followed up with a more in-depth -example of a fused matmul + activation kernel that you might use in a real-world -model to improve performance. - -Please refer to the [Pallas -docs](https://docs.jax.dev/en/latest/pallas/index.html#) for further -documentation on writing custom kernels. Additionally to explore more examples -of Pallas kernels, including FlashAttention and MoE layers, check out the -[Tokamax](https://github.com/openxla/tokamax) library. diff --git a/scripts/guides_master.py b/scripts/guides_master.py index 767f6a6eaa..be0ca935e9 100644 --- a/scripts/guides_master.py +++ b/scripts/guides_master.py @@ -143,11 +143,7 @@ { "path": "writing_quantization_compatible_layers", "title": "Writing quantization-compatible layers in Keras", - }, - { - "path": "define_custom_kernel", - "title": "Define a Custom TPU/GPU Kernel", - }, + } # { # "path": "preprocessing_layers", # "title": "Working with preprocessing layers", diff --git a/scripts/tutobooks.py b/scripts/tutobooks.py index 3a63ef48a4..a689eee122 100644 --- a/scripts/tutobooks.py +++ b/scripts/tutobooks.py @@ -224,7 +224,7 @@ def nb_to_md(nb_path, md_path, img_dir, working_dir=None): del_working_dir = False if working_dir is None: del_working_dir = True - working_dir = "tmp_" + str(random.randint(int(1e6), int(1e7))) + working_dir = "tmp_" + str(random.randint(1e6, 1e7)) if not os.path.exists(working_dir): os.makedirs(working_dir) print("Using working_dir:", working_dir) @@ -335,7 +335,7 @@ def validate(py): # Validate style with black tmp = tempfile.gettempdir() - fpath = os.path.join(tmp, str(random.randint(int(1e6), int(1e7))) + ".py") + fpath = os.path.join(tmp, str(random.randint(1e6, 1e7)) + ".py") f = open(fpath, "w") pre_formatting = "\n".join(lines) f.write(pre_formatting)