diff --git a/ax/analysis/analysis.py b/ax/analysis/analysis.py index ee76d61eb24..cacd2b077f2 100644 --- a/ax/analysis/analysis.py +++ b/ax/analysis/analysis.py @@ -206,10 +206,18 @@ def error_card_from_analysis_e( analysis_name = analysis_e.analysis.__class__.__name__ exception_name = analysis_e.exception.__class__.__name__ + # Include the exception message in the subtitle if available, so users can + # see the reasoning in the error card. + subtitle = ( + f"{exception_name}: {exception_message}" + if (exception_message := str(analysis_e.exception)) + else f"{exception_name} encountered while computing {analysis_name}." + ) + return ErrorAnalysisCard( name=analysis_name, title=f"{analysis_name} Error", - subtitle=f"{exception_name} encountered while computing {analysis_name}.", + subtitle=subtitle, df=pd.DataFrame(), blob=analysis_e.tb_str() or "", ) diff --git a/ax/analysis/tests/test_analysis.py b/ax/analysis/tests/test_analysis.py new file mode 100644 index 00000000000..caf60d5851f --- /dev/null +++ b/ax/analysis/tests/test_analysis.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +from ax.analysis.analysis import AnalysisE, error_card_from_analysis_e +from ax.analysis.plotly.parallel_coordinates import ParallelCoordinatesPlot +from ax.utils.common.testutils import TestCase + + +class AnalysisTest(TestCase): + def test_error_card_from_analysis_e(self) -> None: + for exception, expected_subtitle in ( + ( + ValueError("something went wrong"), + "ValueError: something went wrong", + ), + ( + ValueError(), + "ValueError encountered while computing ParallelCoordinatesPlot.", + ), + ): + with self.subTest(exception=exception): + analysis_e = AnalysisE( + message="test", + exception=exception, + analysis=ParallelCoordinatesPlot(), + ) + + card = error_card_from_analysis_e(analysis_e) + + self.assertEqual(card.name, "ParallelCoordinatesPlot") + self.assertEqual(card.title, "ParallelCoordinatesPlot Error") + self.assertEqual(card.subtitle, expected_subtitle) + self.assertIn("ValueError", card.blob) diff --git a/ax/api/tests/test_client.py b/ax/api/tests/test_client.py index 49f764e0bbb..afa38da272e 100644 --- a/ax/api/tests/test_client.py +++ b/ax/api/tests/test_client.py @@ -23,7 +23,6 @@ from ax.api.protocols.metric import IMetric from ax.api.protocols.runner import IRunner from ax.api.types import TParameterization -from ax.core.analysis_card import AnalysisCard from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.map_metric import MapMetric @@ -1280,12 +1279,7 @@ def test_compute_analyses(self) -> None: self.assertEqual(cards[0].title, "ParallelCoordinatesPlot Error") self.assertEqual( cards[0].subtitle, - "AnalysisNotApplicableStateError encountered while computing " - "ParallelCoordinatesPlot.", - ) - self.assertIn( - "Experiment has no trials", - assert_is_instance(cards[0], AnalysisCard).blob, + "AnalysisNotApplicableStateError: Experiment has no trials.", ) for trial_index, _ in client.get_next_trials(max_trials=1).items():