From 7f2f3b765aa1ec13bd88a63c0f5791de64fc25cc Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Thu, 20 Feb 2025 13:08:10 -0800 Subject: [PATCH] Early stopping tutorial (#3389) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3389 Tutorial demonstrating early stopping via the Client using modified hartmann6. Differential Revision: D69484561 --- tutorials/early_stopping/early_stopping.ipynb | 461 ++++++++++++++++++ website/tutorials.json | 6 + 2 files changed, 467 insertions(+) create mode 100644 tutorials/early_stopping/early_stopping.ipynb diff --git a/tutorials/early_stopping/early_stopping.ipynb b/tutorials/early_stopping/early_stopping.ipynb new file mode 100644 index 00000000000..1e8d6e54036 --- /dev/null +++ b/tutorials/early_stopping/early_stopping.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "e0bc4e17-2f06-4c45-b5c2-da9ae6121f64", + "outputsInitialized": false, + "showInput": true + }, + "source": [ + "# Ask-tell experimentation with trial-level early stopping\n", + "\n", + "Sometimes, there is stepwise information available on the way to a final measurement.\n", + "The goal of trial-level early stopping is to monitor the results of expensive evaluations with timeseries-like data and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations.\n", + "\n", + "Like the [ask-tell tutorial](#) we'll be minimizing the Hartmann6 function, but this time we've modified it to incorporate a new parameter $t$ which allows the function to produce timeseries-like data where the value returned is closer and closer to Hartmann6's true value as $t$ increases.\n", + "At $t = 100$ the function will simply return Hartmann6's unaltered value.\n", + "$$\n", + "f(x, t) = hartmann6(x) - log_2(t/100)\n", + "$$\n", + "While the function is synthetic, the workflow captures the intended principles for this tutorial and is similar to the process of training typical machine learning models.\n", + "\n", + "## Learning Objectives\n", + "- Understand when time-series-like data can be used in an optimization experiment\n", + "- Run a simple optimization experiment with early stopping\n", + "- Configure details of an early stopping strategy\n", + "- Analyze the results of the optimization\n", + "\n", + "## Prerequisites\n", + "- Familiarity with Python and basic programming concepts\n", + "- Understanding of [adaptive experimentation](#) and [Bayesian optimization](#)\n", + "- [Ask-tell Optimization of Python Functions](#)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "language": "markdown", + "originalKey": "6c37f67a-3e56-4338-947d-915c6e62bd79", + "showInput": false + }, + "source": [ + "## Step 1: Import Necessary Modules\n", + "\n", + "First, ensure you have all the necessary imports:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739309769186, + "executionStopTime": 1739309769452, + "isAgentGenerated": false, + "language": "python", + "originalKey": "288b1d67-ac58-4cbc-b625-26445141ce64", + "outputsInitialized": true, + "requestMsgId": "288b1d67-ac58-4cbc-b625-26445141ce64", + "serverExecutionDuration": 1.428663963452, + "showInput": true + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import plotly.express as px\n", + "import plotly.graph_objects as go\n", + "from ax.early_stopping.strategies import PercentileEarlyStoppingStrategy\n", + "from ax.preview.api.client import Client\n", + "from ax.preview.api.configs import ExperimentConfig, ParameterType, RangeParameterConfig" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "2d90d50e-8258-4fd3-a99e-dc26077a90a7", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "## Step 2: Initialize the Client\n", + "Create an instance of the `Client` to manage the state of your experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739309769456, + "executionStopTime": 1739309770044, + "isAgentGenerated": false, + "language": "python", + "originalKey": "14d7212e-426e-407a-8ad7-3e5c9b9881cb", + "outputsInitialized": true, + "requestMsgId": "14d7212e-426e-407a-8ad7-3e5c9b9881cb", + "serverExecutionDuration": 1.5575950965285, + "showInput": true + }, + "outputs": [], + "source": [ + "client = Client(random_seed=42)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "02887e7c-6e0e-4940-8cd5-4e3aa78ae16c", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "## Step 3: Configure the Experiment\n", + "\n", + "The `Client` expects a series of `Config`s which define how the experiment will be run.\n", + "We'll set this up the same way as we did in our previous tutorial.\n", + "\n", + "The Hartmann6 is usually evaluated on the hypercube $x_i \\in (0, 1)$, so we will define six identical `RangeParameterConfig`s with the appropriate bounds, and add these to an `ExperimentConfig` along with other metadata about the experiment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739309770048, + "executionStopTime": 1739309770238, + "isAgentGenerated": false, + "language": "python", + "originalKey": "3027314b-2076-4199-9641-6e8b8ec401da", + "outputsInitialized": true, + "requestMsgId": "3027314b-2076-4199-9641-6e8b8ec401da", + "serverExecutionDuration": 1.8919380381703, + "showInput": true + }, + "outputs": [], + "source": [ + "# Define six float parameters for the Hartmann6 function\n", + "parameters = [\n", + " RangeParameterConfig(\n", + " name=f\"x{i + 1}\", parameter_type=ParameterType.FLOAT, bounds=(0, 1)\n", + " )\n", + " for i in range(6)\n", + "]\n", + "\n", + "# Create an experiment configuration\n", + "experiment_config = ExperimentConfig(\n", + " name=\"hartmann6_experiment\",\n", + " parameters=parameters,\n", + " # The following arguments are optional\n", + " description=\"Optimization of the Hartmann6 function\",\n", + " owner=\"developer\",\n", + ")\n", + "\n", + "# Apply the experiment configuration to the client\n", + "client.configure_experiment(experiment_config=experiment_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "d73ca67f-eb86-4d74-b4cc-ef211dba66ba", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "## Step 4: Configure Optimization\n", + "Now, we must set up the optimization objective in `Client`, where `objective` is a string that specifies which metric we would like to optimize and which direction we consider optimal.\n", + "\n", + "We set the objective to be `-hartmann6` to signify that we want to minimize the function." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739309770243, + "executionStopTime": 1739309770451, + "isAgentGenerated": false, + "language": "python", + "originalKey": "237f31f6-cad2-4cfd-8cc8-ade5d5f7cc30", + "outputsInitialized": true, + "requestMsgId": "237f31f6-cad2-4cfd-8cc8-ade5d5f7cc30", + "serverExecutionDuration": 2.1110590314493, + "showInput": true + }, + "outputs": [], + "source": [ + "client.configure_optimization(objective=\"-hartmann6\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "ba47a991-d87b-48b5-88ba-40a72620332a", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "## Step 5: Run Trials with early stopping\n", + "Here, we will configure the ask-tell loop.\n", + "\n", + "We begin by defining our Hartmann6 function as written above.\n", + "Remember, this is just an example problem and any Python function can be substituted here.\n", + "\n", + "Then we will iteratively do the following:\n", + "* Call `client.get_next_trials` to \"ask\" Ax for a parameterization to evaluate\n", + "* Evaluate `hartmann6_curve` using those parameters in an inner loop to simulate the generation of timeseries data\n", + "* \"Tell\" Ax the partial result using `client.attach_data`\n", + "* Query whether the trial should be stopped via `client.should_stop_trial_early`\n", + "* Stop the underperforming trial and report back to Ax that is has been stopped\n", + "\n", + "This loop will run multiple trials to optimize the function.\n", + "\n", + "Ax will configure an EarlyStoppingStrategy when `should_stop_trial_early` is called for the first time.\n", + "By default Ax uses Percentile early stopping, a strategy where trials are ended early if its performance falls below that of other trials at the same step.\n", + "This is parameterized via a minimum number of \"progressions\" to prevent the trial from stopping prematurely, i.e., before enough data is gathered to make a decision as well as the minimum number of trials with curve data that should complete before trials are allowed to be stopped early." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739309770456, + "executionStopTime": 1739309770668, + "isAgentGenerated": false, + "language": "python", + "originalKey": "b1208371-8d12-43d1-9602-c85b30a38492", + "outputsInitialized": true, + "requestMsgId": "b1208371-8d12-43d1-9602-c85b30a38492", + "serverExecutionDuration": 1.9606580026448, + "showInput": true + }, + "outputs": [], + "source": [ + "# Hartmann6 function\n", + "def hartmann6(x1, x2, x3, x4, x5, x6):\n", + " alpha = np.array([1.0, 1.2, 3.0, 3.2])\n", + " A = np.array([\n", + " [10, 3, 17, 3.5, 1.7, 8],\n", + " [0.05, 10, 17, 0.1, 8, 14],\n", + " [3, 3.5, 1.7, 10, 17, 8],\n", + " [17, 8, 0.05, 10, 0.1, 14]\n", + " ])\n", + " P = 10**-4 * np.array([\n", + " [1312, 1696, 5569, 124, 8283, 5886],\n", + " [2329, 4135, 8307, 3736, 1004, 9991],\n", + " [2348, 1451, 3522, 2883, 3047, 6650],\n", + " [4047, 8828, 8732, 5743, 1091, 381]\n", + " ])\n", + "\n", + " outer = 0.0\n", + " for i in range(4):\n", + " inner = 0.0\n", + " for j, x in enumerate([x1, x2, x3, x4, x5, x6]):\n", + " inner += A[i, j] * (x - P[i, j])**2\n", + " outer += alpha[i] * np.exp(-inner)\n", + " return -outer\n", + "\n", + "# Hartmann6 function with additional t term\n", + "def hartmann6_curve(x1, x2, x3, x4, x5, x6, t):\n", + " return hartmann6(x1, x2, x3, x4, x5, x6) - np.log2(t)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739309770670, + "executionStopTime": 1739309792689, + "isAgentGenerated": false, + "language": "python", + "originalKey": "850df094-3b9b-4527-bd28-ce7d991d22ce", + "outputsInitialized": true, + "requestMsgId": "850df094-3b9b-4527-bd28-ce7d991d22ce", + "serverExecutionDuration": 21860.721127014, + "showInput": true + }, + "outputs": [], + "source": [ + "maximum_progressions = 100 # Observe hartmann6_curve over 100 progressions\n", + "\n", + "for _ in range(30): # Run 30 trials\n", + " trials = client.get_next_trials(maximum_trials=1)\n", + " for trial_index, parameters in trials.items():\n", + " for t in range(1, maximum_progressions + 1):\n", + " raw_data = {\"hartmann6\": hartmann6_curve(t=t, **parameters)}\n", + "\n", + " # On the final reading call complete_trial, else call attach_data\n", + " if t == maximum_progressions:\n", + " client.complete_trial(\n", + " trial_index=trial_index, raw_data=raw_data, progression=t\n", + " )\n", + " else:\n", + " client.attach_data(\n", + " trial_index=trial_index, raw_data=raw_data, progression=t\n", + " )\n", + "\n", + " # If the trial is underperforming, stop it\n", + " if client.should_stop_trial_early(trial_index=trial_index):\n", + " client.mark_trial_early_stopped(\n", + " trial_index=trial_index, raw_data=raw_data, progression=t\n", + " )\n", + " break" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "c2e49a05-899e-4274-9c9a-bc86b4e7be5e", + "outputsInitialized": false, + "showInput": true + }, + "source": [ + "## Step 6: Analyze Results\n", + "\n", + "After running trials, you can analyze the results.\n", + "Most commonly this means extracting the parameterization from the best performing trial you conducted." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739309857743, + "executionStopTime": 1739309861050, + "isAgentGenerated": false, + "language": "python", + "originalKey": "c196b391-6966-456e-9400-2aef149595ff", + "outputsInitialized": true, + "requestMsgId": "c196b391-6966-456e-9400-2aef149595ff", + "serverExecutionDuration": 3080.3436069982, + "showInput": true + }, + "outputs": [], + "source": [ + "best_parameters, prediction, index, name = client.get_best_parameterization()\n", + "print(\"Best Parameters:\", best_parameters)\n", + "print(\"Prediction (mean, variance):\", prediction)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "767ee110-8ab8-4856-805b-a1b8b7912d4e", + "outputsInitialized": false, + "showInput": true + }, + "source": [ + "## Step 7: Compute Analyses\n", + "\n", + "Ax can also produce a number of analyses to help interpret the results of the experiment via `client.compute_analyses`.\n", + "Users can manually select which analyses to run, or can allow Ax to select which would be most relevant.\n", + "In this case Ax selects the following:\n", + "* **Parrellel Coordinates Plot** shows which parameterizations were evaluated and what metric values were observed -- this is useful for getting a high level overview of how thoroughly the search space was explored and which regions tend to produce which outcomes\n", + "* **Interaction Analysis Plot** shows which parameters have the largest affect on the function and plots the most important parameters as 1 or 2 dimensional surfaces\n", + "* **Summary** lists all trials generated along with their parameterizations, observations, and miscellaneous metadata" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customOutput": null, + "executionStartTime": 1739309914120, + "executionStopTime": 1739309919999, + "isAgentGenerated": false, + "language": "python", + "originalKey": "9f0fcad5-74c8-409d-bd49-47e8112c664c", + "outputsInitialized": true, + "requestMsgId": "9f0fcad5-74c8-409d-bd49-47e8112c664c", + "serverExecutionDuration": 5307.3820680147 + }, + "outputs": [], + "source": [ + "client.compute_analyses(display=True) # By default Ax will display the AnalysisCards produced by compute_analyses" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "b21c1ff4-3472-48ef-9c32-103cf9a17d01", + "outputsInitialized": false, + "showInput": true + }, + "source": [ + "## Conclusion\n", + "\n", + "This tutorial demonstates Ax's early stopping capabilities, which utilize timeseries-like data to monitor the results of expensive evaluations and terminate those that are unlikely to produce promising results, freeing up resources to explore more configurations.\n", + "This can be used in a number of applications, and is especially useful in machine learning contexts." + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "d58f434d-b29c-4c6d-a882-70fdf8fb4978", + "isAdHoc": false, + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/website/tutorials.json b/website/tutorials.json index cedd47c994c..21f59d2a340 100644 --- a/website/tutorials.json +++ b/website/tutorials.json @@ -8,5 +8,11 @@ "id": "human_in_the_loop", "title": "Ask-tell Optimization in a Human-in-the-loop Setting" } + ], + "Advanced features": [ + { + "id": "early_stopping", + "title": "Ask-tell experimentation with early stopping" + } ] }