Skip to content

Commit 0cea6c3

Browse files
mpolson64facebook-github-bot
authored and committed
Add summarize method (#3506)
Summary: Pull Request resolved: #3506 Add a convenience function for accessing a DF explaining the experiment state. Note that we do not call this "to_df", which IMO implies some sort of “lossless”-ness and makes it seem more analogous to “to_json” etc. The DataFrame computed will contain one row per arm and the following columns: - trial_index: The trial index of the arm - arm_name: The name of the arm - trial_status: The status of the trial (e.g. RUNNING, COMPLETED, FAILED) - failure_reason: The reason for the failure, if applicable - generation_method: The model_key of the model that generated the arm - generation_node: The name of the ``GenerationNode`` that generated the arm - **METADATA: Any metadata associated with the trial, as specified by the Experiment's runner.run_metadata_report_keys field - **METRIC_NAME: The observed mean of the metric specified, for each metric - **PARAMETER_NAME: The parameter value for the arm, for each parameter Reviewed By: lena-kashtelyan Differential Revision: D70923695
1 parent 203c317 commit 0cea6c3

File tree

2 files changed

+101
-0
lines changed

2 files changed

+101
-0
lines changed

ax/preview/api/client.py

+32
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Any
1212

1313
import numpy as np
14+
import pandas as pd
1415

1516
from ax.analysis.analysis import ( # Used as a return type
1617
Analysis,
@@ -20,6 +21,7 @@
2021
from ax.analysis.markdown.markdown_analysis import (
2122
markdown_analysis_card_from_analysis_e,
2223
)
24+
from ax.analysis.summary import Summary
2325
from ax.analysis.utils import choose_analyses
2426
from ax.core.experiment import Experiment
2527
from ax.core.metric import Metric
@@ -662,6 +664,36 @@ def compute_analyses(
662664

663665
return cards
664666

667+
def summarize(self) -> pd.DataFrame:
    """
    Return the DataFrame computed by the ``Summary`` analysis for this
    experiment.

    This is a convenience for inspecting the current state of the experiment.
    The shape of the returned DataFrame depends on the experiment state, so
    both users and Ax developers should prefer other methods when extracting
    information from the experiment for downstream consumption.

    The DataFrame has one row per arm. Columns that would be entirely empty
    are omitted; otherwise the columns are:
        - trial_index: The trial index of the arm
        - arm_name: The name of the arm
        - trial_status: The status of the trial (e.g. COMPLETED, FAILED)
        - failure_reason: The reason for the failure, if applicable
        - generation_node: The name of the ``GenerationNode`` that generated
            the arm
        - **METADATA: Any metadata associated with the trial, as specified by
            the Experiment's runner.run_metadata_report_keys field
        - **METRIC_NAME: The observed mean of the metric specified, for each
            metric
        - **PARAMETER_NAME: The parameter value for the arm, for each
            parameter

    Returns:
        The ``df`` attribute of the result computed by the ``Summary``
        analysis.
    """
    # Ask the analysis to drop empty columns (e.g. failure_reason when no
    # reason has been recorded) so the table only shows populated fields.
    summary_analysis = Summary(omit_empty_columns=True)

    summary_card = summary_analysis.compute(
        experiment=self._experiment,
        generation_strategy=self._generation_strategy,
    )

    return summary_card.df
696+
665697
def get_best_parameterization(
666698
self, use_model_predictions: bool = True
667699
) -> tuple[TParameterization, TOutcome, int, str]:

ax/preview/api/tests/test_client.py

+69
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,75 @@ def test_get_next_trials_then_run_trials(self) -> None:
834834
5,
835835
)
836836

837+
def test_summarize(self) -> None:
    """Check the columns and contents of ``Client.summarize`` on a small
    experiment with one completed and one failed trial."""
    client = Client()

    # Two float parameters, each bounded to [0, 1].
    range_parameters = [
        RangeParameterConfig(
            name=parameter_name,
            parameter_type=ParameterType.FLOAT,
            bounds=(0, 1),
        )
        for parameter_name in ("x1", "x2")
    ]
    client.configure_experiment(
        experiment_config=ExperimentConfig(
            name="test_experiment",
            parameters=range_parameters,
        )
    )
    client.configure_optimization(objective="foo, bar")

    # Generate two trials, then complete the first and fail the second so
    # that the summary is ragged (metric values exist for only one trial).
    client.get_next_trials(maximum_trials=2)
    client.complete_trial(trial_index=0, raw_data={"foo": 1.0, "bar": 2.0})
    client.mark_trial_failed(trial_index=1)

    summary_df = client.summarize()

    expected_columns = {
        "trial_index",
        "arm_name",
        "trial_status",
        "generation_node",
        "foo",
        "bar",
        "x1",
        "x2",
    }
    self.assertEqual(set(summary_df.columns), expected_columns)

    # Look up the parameter values that were actually generated for each
    # trial's single arm, so the expected frame can mirror them.
    trial_indices = (0, 1)
    parameters_by_trial = {
        trial_index: none_throws(
            assert_is_instance(client._experiment.trials[trial_index], Trial).arm
        ).parameters
        for trial_index in trial_indices
    }

    expected = pd.DataFrame(
        {
            "trial_index": {0: 0, 1: 1},
            "arm_name": {0: "0_0", 1: "1_0"},
            "trial_status": {0: "COMPLETED", 1: "FAILED"},
            "generation_node": {0: "Sobol", 1: "Sobol"},
            "foo": {0: 1.0, 1: np.nan},  # NaN because trial 1 failed
            "bar": {0: 2.0, 1: np.nan},
            "x1": {
                trial_index: parameters_by_trial[trial_index]["x1"]
                for trial_index in trial_indices
            },
            "x2": {
                trial_index: parameters_by_trial[trial_index]["x2"]
                for trial_index in trial_indices
            },
        }
    )
    pd.testing.assert_frame_equal(summary_df, expected)
905+
837906
def test_compute_analyses(self) -> None:
838907
client = Client()
839908

0 commit comments

Comments
 (0)