From b5fd7900b3a9ab8954d0d422fb950cbde54987d4 Mon Sep 17 00:00:00 2001 From: Miles Olson Date: Wed, 19 Feb 2025 09:47:51 -0800 Subject: [PATCH] Closed loop (scheduler) tutorial (#3391) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/3391 Tutorial demonstrating custom metrics, runner, and run_n_trials (scheduler) Differential Revision: D69810152 --- tutorials/closed_loop/closed_loop.ipynb | 624 ++++++++++++++++++++++++ website/tutorials.json | 4 + 2 files changed, 628 insertions(+) create mode 100644 tutorials/closed_loop/closed_loop.ipynb diff --git a/tutorials/closed_loop/closed_loop.ipynb b/tutorials/closed_loop/closed_loop.ipynb new file mode 100644 index 00000000000..8717a1dc75f --- /dev/null +++ b/tutorials/closed_loop/closed_loop.ipynb @@ -0,0 +1,624 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "edea1464-e04b-4c70-b857-db86ac113945", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "# Closed-loop Optimization with Ax\n", + "Previously, we've discussed how to use Ax for ask-tell optimization, a paradigm in which we \"ask\" Ax for candidate configurations and \"tell\" Ax our observations.\n", + "This can be effective in many scenerios, and it can be automated through use of flow control statements like `for` and `while` loops.\n", + "However there are some situations where it would be beneficial to allow Ax to orchestrate the entire optimization: deploying trials to external systems, polling their status, and reading reading their results.\n", + "This can be common in a number of real world engineering tasks, including:\n", + "* Large scale machine learning experiments running workloads on high-performance computing clusters\n", + "* A/B tests conducted using an external experimentation platform\n", + "* Materials science optimizations utilizing a self-driving laboratory\n", + "\n", + "Ax's `Client` can orchestrate adaptive experiments like this using its method `run_trials`.\n", + "Users create custom classes which implement Ax's `IMetric` and `IRunner` protocols to handle data fetching and trial deployment respectively.\n", + "Then, users simply configure their `Client` as they would normally and call `run_trials`; Ax will deploy trials, fetch data, generate candidates, and repeat as necessary.\n", + "Ax can also handle complex aspects of orchestration, such as launching many trials in parallel up to a user-specified concurrency limit and gracefully handling trial failure, continuing the experiment even if individual trials fail to execute or essential data fails to fetch.\n", + "\n", + "In this tutorial we will optimize the Hartmann6 function as before, but we will configure custom Runners and Metrics to mimic an external execution system.\n", + "The Runner will calculate Hartmann6 with the appropriate parameters, write the result to a file, and tell Ax the trial is ready after 5 seconds.\n", + "The Metric will find the appropriate file and report the results back to Ax.\n", + "\n", + "### Learning Objectives\n", + "* Learn when it can be appropriate and/or advantageous to run Ax in a closed-loop\n", + "* Configure custom Runners and Metrics, allowing Ax to deploy trials and fetch data automatically\n", + "* Understand tradeoffs between parallelism and optimization performance\n", + "\n", + "### Prerequisites\n", + "* Understanding of adaptive experimentation and Bayesian optimization (see [Introduction to Adaptive Experimentation](#) and [Introduction to Bayesian Optimization](#))\n", + "* Familiarity with [configuring and conducting experiments in Ax](#)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "a19e6f85-aebc-40c0-b2e8-f0b0692c23c9", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "## Step 1: Import Necessary Modules\n", + "\n", + "First, ensure you have all the necessary imports:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "executionStartTime": 1739574715872, + "executionStopTime": 1739574716127, + "isAgentGenerated": false, + "language": "python", + "originalKey": "270c46ff-d6ac-43b9-a0fa-bc3b7024f5b8", + "outputsInitialized": false, + "requestMsgId": "270c46ff-d6ac-43b9-a0fa-bc3b7024f5b8", + "serverExecutionDuration": 1.8916884437203 + }, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "from typing import Any, Mapping\n", + "\n", + "import numpy as np\n", + "from ax.preview.api.client import Client\n", + "from ax.preview.api.configs import (\n", + " ExperimentConfig,\n", + " OrchestrationConfig,\n", + " ParameterType,\n", + " RangeParameterConfig,\n", + ")\n", + "from ax.preview.api.protocols.metric import IMetric\n", + "from ax.preview.api.protocols.runner import IRunner, TrialStatus\n", + "from ax.preview.api.types import TParameterization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "3224c8a0-ee09-44c8-9fd1-42e1090a0819", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "# Step 2: Defining our custom Runner and Metric\n", + "\n", + "As stated before, we will be creating custom Runner and Metric classes to mimic an external system.\n", + "Let's start by defining our Hartmann6 function as before." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574716130, + "executionStopTime": 1739574716324, + "isAgentGenerated": false, + "language": "python", + "originalKey": "d60a68d4-8050-47e1-8cf7-0341ac50239b", + "outputsInitialized": true, + "requestMsgId": "d60a68d4-8050-47e1-8cf7-0341ac50239b", + "serverExecutionDuration": 4.697322845459, + "showInput": true + }, + "outputs": [], + "source": [ + "# Hartmann6 function\n", + "def hartmann6(x1, x2, x3, x4, x5, x6):\n", + " alpha = np.array([1.0, 1.2, 3.0, 3.2])\n", + " A = np.array([\n", + " [10, 3, 17, 3.5, 1.7, 8],\n", + " [0.05, 10, 17, 0.1, 8, 14],\n", + " [3, 3.5, 1.7, 10, 17, 8],\n", + " [17, 8, 0.05, 10, 0.1, 14]\n", + " ])\n", + " P = 10**-4 * np.array([\n", + " [1312, 1696, 5569, 124, 8283, 5886],\n", + " [2329, 4135, 8307, 3736, 1004, 9991],\n", + " [2348, 1451, 3522, 2883, 3047, 6650],\n", + " [4047, 8828, 8732, 5743, 1091, 381]\n", + " ])\n", + "\n", + " outer = 0.0\n", + " for i in range(4):\n", + " inner = 0.0\n", + " for j, x in enumerate([x1, x2, x3, x4, x5, x6]):\n", + " inner += A[i, j] * (x - P[i, j])**2\n", + " outer += alpha[i] * np.exp(-inner)\n", + " return -outer\n", + "\n", + "hartmann6(0.1, 0.45, 0.8, 0.25, 0.552, 1.0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "aff023b3-5f72-48ac-ad54-ef683ee32113", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "Next, we will define the `MockRunner`.\n", + "The `MockRunner` requires two methods: `run_trial` and `poll_trial`.\n", + "\n", + "`run_trial` deploys a trial to the external system with the given parameters.\n", + "In this case, we will simply save a file containing the result of a call to the Hartmann6 function.\n", + "\n", + "`poll_trial` queries the external system to see if the trial has completed, failed, or if it's still running.\n", + "In this mock example, we will check to see how many passed have elapsed since the `run_trial` was called and only report a trial as completed once 5 seconds have elapsed.\n", + "\n", + "Runner's may also optionally implement a `stop_trial` method to terminate a trial's execution before it has completed.\n", + "This is necessary for using [early stopping](#) in closed-loop experimentation, but we will skip this for now." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574716327, + "executionStopTime": 1739574716508, + "isAgentGenerated": false, + "language": "python", + "originalKey": "38bd2eba-086b-4cbe-a24c-eb59853d7927", + "outputsInitialized": false, + "requestMsgId": "38bd2eba-086b-4cbe-a24c-eb59853d7927", + "serverExecutionDuration": 1.8612169660628, + "showInput": true + }, + "outputs": [], + "source": [ + "class MockRunner(IRunner):\n", + " def run_trial(\n", + " self, trial_index: int, parameterization: TParameterization\n", + " ) -> dict[str, Any]:\n", + " file_name = f\"{int(time.time())}.txt\"\n", + "\n", + " x1 = parameterization[\"x1\"]\n", + " x2 = parameterization[\"x2\"]\n", + " x3 = parameterization[\"x3\"]\n", + " x4 = parameterization[\"x4\"]\n", + " x5 = parameterization[\"x5\"]\n", + " x6 = parameterization[\"x6\"]\n", + "\n", + " result = hartmann6(x1, x2, x3, x4, x5, x6)\n", + "\n", + " with open(file_name, \"w\") as f:\n", + " f.write(f\"{result}\")\n", + "\n", + " return {\"file_name\": file_name}\n", + "\n", + " def poll_trial(\n", + " self, trial_index: int, trial_metadata: Mapping[str, Any]\n", + " ) -> TrialStatus:\n", + " file_name = trial_metadata[\"file_name\"]\n", + " time_elapsed = time.time() - int(file_name[:4])\n", + "\n", + " if time_elapsed < 5:\n", + " return TrialStatus.RUNNING\n", + "\n", + " return TrialStatus.COMPLETED" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "a5b82b27-7c97-4ae7-8495-1b1cd43fe3fb", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "It's worthwhile to instantiate your Runner and test it is behaving as expected.\n", + "Let's deploy a mock trial by manually calling `run_trial` and ensuring it creates a file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574716511, + "executionStopTime": 1739574716704, + "isAgentGenerated": false, + "language": "python", + "originalKey": "42a853a2-ea90-48a8-b6c8-e852e76d7669", + "outputsInitialized": true, + "requestMsgId": "42a853a2-ea90-48a8-b6c8-e852e76d7669", + "serverExecutionDuration": 3.4155221655965, + "showInput": true + }, + "outputs": [], + "source": [ + "runner = MockRunner()\n", + "\n", + "trial_metadata = runner.run_trial(\n", + " trial_index=-1,\n", + " parameterization={\n", + " \"x1\": 0.1,\n", + " \"x2\": 0.45,\n", + " \"x3\": 0.8,\n", + " \"x4\": 0.25,\n", + " \"x5\": 0.552,\n", + " \"x6\": 1.0,\n", + " },\n", + ")\n", + "\n", + "os.path.exists(trial_metadata[\"file_name\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "c480152b-d5f0-4649-9243-cba145f19ab1", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "Now, we will implement the Metric.\n", + "Metrics only need to implement a `fetch` method, which returns a progression value (i.e. a step in a timeseries) and an observation value.\n", + "Note that the observation can either be a simple float or a (mean, SEM) pair if the external system can report observed noise.\n", + "\n", + "In this case we have neither a relevant progression value nor observed noise so we will simply read the file and report `(0, value)`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574716706, + "executionStopTime": 1739574716923, + "isAgentGenerated": false, + "language": "python", + "originalKey": "17091378-c5c5-446c-bd1c-8f13622cd0dd", + "outputsInitialized": false, + "requestMsgId": "17091378-c5c5-446c-bd1c-8f13622cd0dd", + "serverExecutionDuration": 1.5956750139594, + "showInput": true + }, + "outputs": [], + "source": [ + "class MockMetric(IMetric):\n", + " def fetch(\n", + " self,\n", + " trial_index: int,\n", + " trial_metadata: Mapping[str, Any],\n", + " ) -> tuple[int, float | tuple[float, float]]:\n", + " file_name = trial_metadata[\"file_name\"]\n", + "\n", + " with open(file_name, 'r') as file:\n", + " value = float(file.readline())\n", + " return (0, value)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "200a6713-bbc3-49bd-b277-2f68ebe2c083", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "Again, lets make sure the metric wroks by instantiating it and checking using the file generated when we tested the Runner." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574716926, + "executionStopTime": 1739574717552, + "isAgentGenerated": false, + "language": "python", + "originalKey": "2d695a30-82cb-4a79-b3d0-a7ea44570617", + "outputsInitialized": true, + "requestMsgId": "2d695a30-82cb-4a79-b3d0-a7ea44570617", + "serverExecutionDuration": 4.006918054074, + "showInput": true + }, + "outputs": [], + "source": [ + "# Note: all Metrics must have a name. This will become relevant when attaching metrics to the Client\n", + "hartmann6_metric = MockMetric(name=\"hartmann6\")\n", + "\n", + "hartmann6_metric.fetch(trial_index=-1, trial_metadata=trial_metadata)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "8842c8d2-e531-42c4-8269-7cf15801214b", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "## Step 3: Initialize the Client and Configure the Experiment\n", + "\n", + "Finally, we can initialize the `Client` and configure the experiment as before.\n", + "This will be familiar to readers of the [Ask-tell optimization with Ax tutorial](#) -- the only difference is we will attach the previously defined Runner and Metric by calling `configure_runner` and `configure_metrics` respectively.\n", + "\n", + "Note that when initializing `hartmann6_metric` we set `name=hartmann6`, matching the objective we now set in `configure_optimization`. The `configure_metrics` method uses this name to ensure that data fetched by this Metric is used correctly during the experiment.\n", + "Be careful to make sure if you intend to use a Metric as an objective or outcome constraint its name is set appropriately." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574717554, + "executionStopTime": 1739574717808, + "isAgentGenerated": false, + "language": "python", + "originalKey": "0519a6b6-2197-48bb-b3d5-8156200082a8", + "outputsInitialized": false, + "requestMsgId": "0519a6b6-2197-48bb-b3d5-8156200082a8", + "serverExecutionDuration": 2.8515672311187, + "showInput": true + }, + "outputs": [], + "source": [ + "client = Client()\n", + "# Define six float parameters for the Hartmann6 function\n", + "parameters = [\n", + " RangeParameterConfig(\n", + " name=f\"x{i + 1}\", parameter_type=ParameterType.FLOAT, bounds=(0, 1)\n", + " )\n", + " for i in range(6)\n", + "]\n", + "\n", + "# Create an experiment configuration\n", + "experiment_config = ExperimentConfig(\n", + " name=\"hartmann6_experiment\",\n", + " parameters=parameters,\n", + " # The following arguments are optional\n", + " description=\"Optimization of the Hartmann6 function\",\n", + " owner=\"developer\",\n", + ")\n", + "\n", + "# Apply the experiment configuration to the client\n", + "client.configure_experiment(experiment_config=experiment_config)\n", + "client.configure_optimization(objective=\"-hartmann6\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574717810, + "executionStopTime": 1739574718376, + "isAgentGenerated": false, + "language": "python", + "originalKey": "e0b2b4b1-53e2-4200-b351-5cf08165ddb9", + "outputsInitialized": false, + "requestMsgId": "e0b2b4b1-53e2-4200-b351-5cf08165ddb9", + "serverExecutionDuration": 1.3919230550528, + "showInput": true + }, + "outputs": [], + "source": [ + "client.configure_runner(runner=runner)\n", + "client.configure_metrics(metrics=[hartmann6_metric])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "5203befa-80ac-4bc6-856f-1739eebebead", + "outputsInitialized": false, + "showInput": false + }, + "source": [ + "## Step 5: Run trials\n", + "Now the the `Client` has been configured we can begin running trials.\n", + "\n", + "Internally, Ax uses a class called the `Scheduler` orchestrate the trial deployment, polling, data fetching, and candidate generation.\n", + "\n", + "TODO: IMAGE https://pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html\n", + "\n", + "The `OrchestrationConfig` provides users with control over various orchestration settings:\n", + "* `parallelism` controls how many trials may be run at once. Your external system may be able to support multiple evaluations in parallel; increasing this number can greately decrease the time it takes to conduct your experiment. However, it is important to note that **as parallelism increases, optimiztion performance often decreases.** This is because adaptive experimentation methods rely on previously observed data for candidate generation -- the more trials that have been observed by the time a new candidate needs to be generated the better Ax's model for candidate generation will be.\n", + "* `tolerated_trial_failure_rate` controls what proportion of trials are allowed to fail before Ax raises an Exception. Depending on how expensive a single trial is to evaluate or how unreliable trials are expected to be, the experimenter may want to be notified as soon as a single trial fails or they may not care until more than half the trials are failing. Set this value as is appropriate for your context.\n", + "* `initial_seconds_between_polls` controls how often we poll the status of a trial and how often we attempt to fetch results. Set this to be low for trials that are expected to complete quickly or high for trials the are expected to take a long time.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739574718380, + "executionStopTime": 1739574718664, + "isAgentGenerated": false, + "language": "python", + "originalKey": "4fa62073-6de1-4fe5-9b42-1e0f4de20a96", + "outputsInitialized": false, + "requestMsgId": "4fa62073-6de1-4fe5-9b42-1e0f4de20a96", + "serverExecutionDuration": 1.3189618475735, + "showInput": true + }, + "outputs": [], + "source": [ + "orchestration_config = OrchestrationConfig(\n", + " parallelism=3,\n", + " tolerated_trial_failure_rate=0.1,\n", + " initial_seconds_between_polls=1,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739574738021, + "executionStopTime": 1739574820879, + "isAgentGenerated": false, + "language": "python", + "originalKey": "611981e8-82a1-402f-b097-cf536252e271", + "outputsInitialized": true, + "requestMsgId": "611981e8-82a1-402f-b097-cf536252e271", + "serverExecutionDuration": 82613.62616485, + "showInput": true + }, + "outputs": [], + "source": [ + "client.run_trials(maximum_trials=30, options=orchestration_config)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "9415ff80-5577-4e45-8eda-36a14de96475", + "outputsInitialized": false, + "showInput": true + }, + "source": [ + "## Step 6: Analyze Results\n", + "As before, Ax can compute the best parameterization observed and produce a number of analyses to help interpret the results of the experiment.\n", + "\n", + "It is also worth noting that the experiment can be resumed at any time using Ax's storage functionality.\n", + "When configured to use a SQL databse, the `Client` saves a snapshot of itself at various points throughout the call to `run_trials`, making it incredibly easy to continue optimization after an unexpected failure.\n", + "You can learn more about storage in Ax [here](#)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "customOutput": null, + "executionStartTime": 1739574859795, + "executionStopTime": 1739574867583, + "isAgentGenerated": false, + "language": "python", + "originalKey": "ec93c967-b07e-4b1e-b30e-c5597fa70962", + "outputsInitialized": true, + "requestMsgId": "ec93c967-b07e-4b1e-b30e-c5597fa70962", + "serverExecutionDuration": 7511.262876913, + "showInput": true + }, + "outputs": [], + "source": [ + "best_parameters, prediction, index, name = client.get_best_parameterization()\n", + "print(\"Best Parameters:\", best_parameters)\n", + "print(\"Prediction (mean, variance):\", prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": false, + "customInput": null, + "executionStartTime": 1739917888088, + "executionStopTime": 1739917888552, + "isAgentGenerated": false, + "language": "python", + "originalKey": "df68d3a3-1a05-4d8a-82df-1745bb597926", + "outputsInitialized": false, + "requestMsgId": "7473e106-47bb-456b-af75-32e50650a8af", + "serverExecutionDuration": 263.48441885784, + "showInput": true + }, + "outputs": [], + "source": [ + "client.compute_analyses()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "customInput": null, + "isAgentGenerated": false, + "language": "markdown", + "originalKey": "3af7500a-f513-4a7a-a7b7-b0bfdc577580", + "outputsInitialized": false, + "showInput": true + }, + "source": [ + "## Conclusion\n", + "\n", + "This tutorial demonstrates how to use Ax's `Client` for closed-loop optimization using the Hartmann6 function as an example.\n", + "This style of optimization is useful in scenarios where trials are evaluated on some external system or when experimenters wish to take advantage of parallel evaluation, trial failure handling, or simply to manage long-running optimization tasks without human intervention.\n", + "You can define your own Runner and Metric classes to communicate with whatever external systems you wish to interface with, and control optimization using the `OrchestrationConfig`.\n" + ] + } + ], + "metadata": { + "fileHeader": "", + "fileUid": "f1dfa50c-a38c-494b-82e1-2f6e393f1d13", + "isAdHoc": false, + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/website/tutorials.json b/website/tutorials.json index 21f59d2a340..877ea6405e4 100644 --- a/website/tutorials.json +++ b/website/tutorials.json @@ -7,6 +7,10 @@ { "id": "human_in_the_loop", "title": "Ask-tell Optimization in a Human-in-the-loop Setting" + }, + { + "id": "closed_loop", + "title": "Closed-loop Optimization with Ax" } ], "Advanced features": [