Skip to content

Commit 8812057

Browse files
Balandatmeta-codesync[bot]
authored andcommitted
Phase 2: Migrate to StrEnum (#4868)
Summary: Pull Request resolved: #4868 Convert (str, Enum) and Enum patterns to Python 3.11's built-in StrEnum class. This removes the need for the dual inheritance pattern and removes pyre-fixme annotations that were needed to annotate enum values with str type. Files changed: - ax/utils/common/constants.py - Keys(str, Enum) -> Keys(StrEnum) - ax/storage/utils.py - MetricIntent(enum.Enum) -> MetricIntent(StrEnum) - ax/core/auxiliary.py - AuxiliaryExperimentPurpose(Enum) -> AuxiliaryExperimentPurpose(StrEnum) - ax/utils/stats/model_fit_stats.py - ModelFitMetricDirection(Enum) -> ModelFitMetricDirection(StrEnum) - ax/metrics/chemistry.py - ChemistryProblemType(Enum) -> ChemistryProblemType(StrEnum) - ax/metrics/sklearn.py - SklearnModelType(Enum), SklearnDataset(Enum) -> StrEnum Reviewed By: saitcakmak Differential Revision: D91648882 fbshipit-source-id: 56eaaced06aafe8dc1743f6633562e6afcf3224f
1 parent a412ff0 commit 8812057

6 files changed

Lines changed: 20 additions & 26 deletions

File tree

ax/core/auxiliary.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from abc import ABC
1111
from dataclasses import dataclass
12-
from enum import Enum, unique
12+
from enum import StrEnum, unique
1313
from typing import TYPE_CHECKING
1414

1515
from ax.core.data import Data
@@ -51,7 +51,7 @@ def _unique_id(self) -> str:
5151

5252

5353
@unique
54-
class AuxiliaryExperimentPurpose(Enum):
54+
class AuxiliaryExperimentPurpose(StrEnum):
5555
# BOPE Aux Experiment Usage pattern:
5656
# 1. Run the exploratory batch for the main / BO experiment.
5757
# 2. Use the BO experiment as the auxiliary experiment for the PE experiment

ax/metrics/chemistry.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
from __future__ import annotations
3434

3535
from dataclasses import dataclass
36-
from enum import Enum
36+
from enum import StrEnum
3737
from functools import lru_cache
3838
from pathlib import Path
3939
from typing import Any
@@ -48,11 +48,9 @@
4848
from pyre_extensions import none_throws
4949

5050

51-
class ChemistryProblemType(Enum):
52-
# pyre-fixme[35]: Target cannot be annotated.
53-
SUZUKI: str = "suzuki"
54-
# pyre-fixme[35]: Target cannot be annotated.
55-
DIRECT_ARYLATION: str = "direct_arylation"
51+
class ChemistryProblemType(StrEnum):
52+
SUZUKI = "suzuki"
53+
DIRECT_ARYLATION = "direct_arylation"
5654

5755

5856
@dataclass(frozen=True)

ax/metrics/sklearn.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from __future__ import annotations
1010

1111
from copy import deepcopy
12-
from enum import Enum
12+
from enum import StrEnum
1313
from functools import lru_cache
1414
from math import sqrt
1515
from typing import Any
@@ -28,20 +28,15 @@
2828
from sklearn.neural_network import MLPClassifier, MLPRegressor
2929

3030

31-
class SklearnModelType(Enum):
32-
# pyre-fixme[35]: Target cannot be annotated.
33-
RF: str = "rf"
34-
# pyre-fixme[35]: Target cannot be annotated.
35-
NN: str = "nn"
31+
class SklearnModelType(StrEnum):
32+
RF = "rf"
33+
NN = "nn"
3634

3735

38-
class SklearnDataset(Enum):
39-
# pyre-fixme[35]: Target cannot be annotated.
40-
DIGITS: str = "digits"
41-
# pyre-fixme[35]: Target cannot be annotated.
42-
BOSTON: str = "boston"
43-
# pyre-fixme[35]: Target cannot be annotated.
44-
CANCER: str = "cancer"
36+
class SklearnDataset(StrEnum):
37+
DIGITS = "digits"
38+
BOSTON = "boston"
39+
CANCER = "cancer"
4540

4641

4742
@lru_cache(maxsize=8)

ax/storage/utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import enum
1010
from collections import OrderedDict
1111
from collections.abc import Mapping
12+
from enum import StrEnum
1213
from hashlib import md5
1314

1415
import pandas as pd
@@ -25,7 +26,7 @@ class DomainType(enum.Enum):
2526
DERIVED = 4
2627

2728

28-
class MetricIntent(enum.Enum):
29+
class MetricIntent(StrEnum):
2930
"""Class for enumerating metric use types."""
3031

3132
OBJECTIVE = "objective"

ax/utils/common/constants.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# pyre-strict
88

9-
from enum import Enum, unique
9+
from enum import StrEnum, unique
1010

1111
# ------------------------- Miscellaneous -------------------------
1212

@@ -36,7 +36,7 @@
3636

3737

3838
@unique
39-
class Keys(str, Enum):
39+
class Keys(StrEnum):
4040
"""Enum of reserved keys in options dicts etc, alphabetized.
4141
4242

ax/utils/stats/model_fit_stats.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
# pyre-strict
77

88
from collections.abc import Mapping
9-
from enum import Enum
9+
from enum import StrEnum
1010
from logging import Logger
1111
from typing import Protocol
1212

@@ -33,7 +33,7 @@
3333
KENDALL_TAU_RANK_CORRELATION = "Kendall tau rank correlation"
3434

3535

36-
class ModelFitMetricDirection(Enum):
36+
class ModelFitMetricDirection(StrEnum):
3737
"""Model fit metric directions."""
3838

3939
MINIMIZE = "minimize"

0 commit comments

Comments
 (0)