Skip to content

Commit 6d9e3f2

Browse files
Balandatmeta-codesync[bot]
authored andcommitted
Phase 4: Exception Groups Migration (PEP 654)
Summary: Convert error collection patterns to use Python 3.11 ExceptionGroup instead of ad-hoc exception aggregation. This provides better structured exception handling and enables the use of except* syntax in callers. Files changed: - ax/core/metric.py: Update _unwrap_experiment_data(), _unwrap_trial_data_multi(), and _unwrap_experiment_data_multi() to raise ExceptionGroup with all collected errors instead of chaining to a single exception. - ax/generation_strategy/generation_node.py: Update new_trial_limit() to collect all generation-blocking errors and raise as ExceptionGroup. Reviewed By: saitcakmak Differential Revision: D91648883
1 parent ce4dc42 commit 6d9e3f2

5 files changed

Lines changed: 59 additions & 47 deletions

File tree

ax/core/metric.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ax.core.data import Data
2121
from ax.utils.common.base import SortableBase
2222
from ax.utils.common.logger import get_logger
23-
from ax.utils.common.result import Err, Ok, Result, UnwrapError
23+
from ax.utils.common.result import Err, Ok, Result
2424
from ax.utils.common.serialization import SerializationMixin
2525

2626
if TYPE_CHECKING:
@@ -505,7 +505,6 @@ def _unwrap_experiment_data(cls, results: Mapping[int, MetricFetchResult]) -> Da
505505
result for result in results.values() if isinstance(result, Err)
506506
]
507507

508-
# TODO[mpolson64] Raise all errors in a group via PEP 654
509508
exceptions = [
510509
(
511510
err.err.exception
@@ -515,8 +514,8 @@ def _unwrap_experiment_data(cls, results: Mapping[int, MetricFetchResult]) -> Da
515514
for err in errs
516515
]
517516

518-
raise UnwrapError(errs) from (
519-
exceptions[0] if len(exceptions) == 1 else Exception(exceptions)
517+
raise ExceptionGroup(
518+
f"Failed to fetch data for {len(errs)} trial(s)", exceptions
520519
)
521520

522521
return Data.from_multiple_data(data=[ok.ok for ok in oks])
@@ -561,7 +560,6 @@ def _unwrap_trial_data_multi(
561560
]
562561

563562
if len(critical_errs) > 0:
564-
# TODO[mpolson64] Raise all errors in a group via PEP 654
565563
exceptions = [
566564
(
567565
err.err.exception
@@ -570,8 +568,9 @@ def _unwrap_trial_data_multi(
570568
)
571569
for err in critical_errs
572570
]
573-
raise UnwrapError(critical_errs) from (
574-
exceptions[0] if len(exceptions) == 1 else Exception(exceptions)
571+
raise ExceptionGroup(
572+
f"Failed to fetch data for {len(critical_errs)} critical metric(s)",
573+
exceptions,
575574
)
576575

577576
return Data.from_multiple_data(data=[ok.ok for ok in oks])
@@ -592,7 +591,6 @@ def _unwrap_experiment_data_multi(
592591
result for result in flattened if isinstance(result, Err)
593592
]
594593

595-
# TODO[mpolson64] Raise all errors in a group via PEP 654
596594
exceptions = [
597595
(
598596
err.err.exception
@@ -601,8 +599,8 @@ def _unwrap_experiment_data_multi(
601599
)
602600
for err in errs
603601
]
604-
raise UnwrapError(errs) from (
605-
exceptions[0] if len(exceptions) == 1 else Exception(exceptions)
602+
raise ExceptionGroup(
603+
f"Failed to fetch data for {len(errs)} metric(s)", exceptions
606604
)
607605

608606
return Data.from_multiple_data(data=[ok.ok for ok in oks])

ax/core/tests/test_metric.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,35 @@ def test_wrap_unwrap(self) -> None:
105105
def test_wrap_err(self) -> None:
106106
err = Err(MetricFetchE(message="failed!", exception=Exception("panic!")))
107107

108-
with self.assertRaisesRegex(Exception, "panic"):
109-
Metric._unwrap_experiment_data_multi(results={0: {"foo": err}})
110-
111-
with self.assertRaisesRegex(Exception, "panic"):
112-
Metric._unwrap_experiment_data(results={0: err})
113-
114-
with self.assertRaisesRegex(Exception, "panic"):
115-
Metric._unwrap_trial_data_multi(results={"foo": err})
108+
# With PEP 654 ExceptionGroups, errors are now wrapped in ExceptionGroup.
109+
# We verify both the top-level ExceptionGroup message and that the
110+
# original "panic!" exception is preserved as a nested exception.
111+
with self.subTest("experiment_data_multi"):
112+
with self.assertRaisesRegex(
113+
ExceptionGroup, "Failed to fetch data for 1 metric"
114+
) as cm:
115+
Metric._unwrap_experiment_data_multi(results={0: {"foo": err}})
116+
eg = cm.exception
117+
self.assertEqual(len(eg.exceptions), 1)
118+
self.assertIn("panic", str(eg.exceptions[0]))
119+
120+
with self.subTest("experiment_data"):
121+
with self.assertRaisesRegex(
122+
ExceptionGroup, "Failed to fetch data for 1 trial"
123+
) as cm:
124+
Metric._unwrap_experiment_data(results={0: err})
125+
eg = cm.exception
126+
self.assertEqual(len(eg.exceptions), 1)
127+
self.assertIn("panic", str(eg.exceptions[0]))
128+
129+
with self.subTest("trial_data_multi"):
130+
with self.assertRaisesRegex(
131+
ExceptionGroup, "Failed to fetch data for 1 critical metric"
132+
) as cm:
133+
Metric._unwrap_trial_data_multi(results={"foo": err})
134+
eg = cm.exception
135+
self.assertEqual(len(eg.exceptions), 1)
136+
self.assertIn("panic", str(eg.exceptions[0]))
116137

117138
def test_MetricFetchE(self) -> None:
118139
def foo() -> bool:

ax/generation_strategy/generation_node.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -880,23 +880,27 @@ def new_trial_limit(self, raise_generation_errors: bool = False) -> int:
880880
]
881881

882882
# Raise any necessary generation errors: for any met criterion,
883-
# call its `block_continued_generation_error` method The method might not
884-
# raise an error, depending on its implementation on given criterion, so the
885-
# error from the first met one that does block continued generation, will be
886-
# raised.
883+
# collect all errors from blocking criteria and raise as ExceptionGroup.
887884
if raise_generation_errors:
885+
generation_errors: list[Exception] = []
888886
for criterion in trial_based_gen_blocking_criteria:
889-
# TODO[mgarrard]: Raise a group of all the errors, from each gen-
890-
# blocking transition criterion.
891887
if criterion.is_met(
892888
self.experiment,
893889
curr_node=self,
894890
):
895-
criterion.block_continued_generation_error(
896-
node_name=self.name,
897-
experiment=self.experiment,
898-
trials_from_node=self.trials_from_node,
899-
)
891+
try:
892+
criterion.block_continued_generation_error(
893+
node_name=self.name,
894+
experiment=self.experiment,
895+
trials_from_node=self.trials_from_node,
896+
)
897+
except Exception as e:
898+
generation_errors.append(e)
899+
if generation_errors:
900+
raise ExceptionGroup(
901+
f"Generation blocked by {len(generation_errors)} criteria",
902+
generation_errors,
903+
)
900904
if len(gen_blocking_criterion_delta_from_threshold) == 0:
901905
return -1
902906
return max(min(gen_blocking_criterion_delta_from_threshold), -1)

ax/utils/common/result.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def unwrap(self) -> T:
9191
"""
9292
Returns the contained Ok value.
9393
94-
Because this function may raise an UnwrapError, its use is generally
94+
Because this function may raise a RuntimeError, its use is generally
9595
discouraged. Instead, prefer to handle the Err case explicitly, or call
9696
unwrap_or, unwrap_or_else, or unwrap_or_default.
9797
"""
@@ -103,7 +103,7 @@ def unwrap_err(self) -> E:
103103
"""
104104
Returns the contained Err value.
105105
106-
Because this function may raise an UnwrapError, its use is generally
106+
Because this function may raise a RuntimeError, its use is generally
107107
discouraged. Instead, prefer to handle the Err case explicitly, or call
108108
unwrap_or, unwrap_or_else, or unwrap_or_default.
109109
"""
@@ -178,7 +178,7 @@ def unwrap(self) -> T:
178178
return self._value
179179

180180
def unwrap_err(self) -> NoReturn:
181-
raise UnwrapError(f"Tried to unwrap_err {self}.")
181+
raise RuntimeError(f"Tried to unwrap_err {self}.")
182182

183183
def unwrap_or(self, default: U) -> T:
184184
return self._value
@@ -235,7 +235,7 @@ def map_or_else(self, default_op: Callable[[], U], op: Callable[[T], U]) -> U:
235235
return default_op()
236236

237237
def unwrap(self) -> NoReturn:
238-
raise UnwrapError(f"Tried to unwrap {self}.")
238+
raise RuntimeError(f"Tried to unwrap {self}.")
239239

240240
def unwrap_err(self) -> E:
241241
return self._value
@@ -249,17 +249,6 @@ def unwrap_or_else(self, op: Callable[[E], T]) -> T:
249249
return op(self._value)
250250

251251

252-
class UnwrapError(Exception):
253-
"""
254-
Exception that indicates something has gone wrong in an unwrap call.
255-
256-
This should not happen in real world use and indicates a user has improperly
257-
or unsafely used the Result abstraction.
258-
"""
259-
260-
pass
261-
262-
263252
class ExceptionE:
264253
"""
265254
A class that holds an Exception and can be used as the E type in Result[T, E].

ax/utils/common/tests/test_result.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
# pyre-strict
77

8-
from ax.utils.common.result import Err, Ok, Result, UnwrapError
8+
from ax.utils.common.result import Err, Ok, Result
99
from ax.utils.common.testutils import TestCase
1010

1111

@@ -59,12 +59,12 @@ def h() -> int:
5959

6060
def test_unwrap(self) -> None:
6161
self.assertEqual(self.ok.unwrap(), 0)
62-
with self.assertRaises(UnwrapError):
62+
with self.assertRaises(RuntimeError):
6363
self.ok.unwrap_err()
6464
self.assertEqual(self.ok.unwrap_or(1), 0)
6565
self.assertEqual(self.ok.unwrap_or_else(1), 0)
6666

67-
with self.assertRaises(UnwrapError):
67+
with self.assertRaises(RuntimeError):
6868
self.err.unwrap()
6969
self.assertEqual(self.err.unwrap_err(), "yikes")
7070
self.assertEqual(self.err.unwrap_or(1), 1)

0 commit comments

Comments
 (0)