Skip to content

Commit 363e73a

Browse files
ItsMrLinfacebook-github-bot
authored andcommitted
Move generation_strategy to a dedicated folder (#3287)
Summary: Moving the following from `ax/modelbridge/` to the new `ax/generation_strategy` directory ``` best_model_selector.py dispatch_utils.py external_generation_node.py generation_node_input_constructors.py generation_node.py generation_strategy.py model_spec.py transition_criterion.py ``` Reviewed By: saitcakmak Differential Revision: D68720587
1 parent 4b8ab75 commit 363e73a

21 files changed

+4769
-4524
lines changed

ax/generation_strategy/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
#
4+
# This source code is licensed under the MIT license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pyre-strict
8+
9+
from __future__ import annotations
10+
11+
from abc import ABC, abstractmethod
12+
from collections.abc import Callable
13+
from enum import Enum
14+
from functools import partial
15+
from typing import Any, Union
16+
17+
import numpy as np
18+
import numpy.typing as npt
19+
from ax.exceptions.core import UserInputError
20+
from ax.generation_strategy.model_spec import GeneratorSpec
21+
from ax.utils.common.base import Base
22+
from pyre_extensions import none_throws
23+
24+
# pyre-fixme[24]: Generic type `np.ndarray` expects 2 type parameters.
25+
ARRAYLIKE = Union[np.ndarray, list[float], list[np.ndarray]]
26+
27+
28+
class BestModelSelector(ABC, Base):
29+
@abstractmethod
30+
def best_model(self, model_specs: list[GeneratorSpec]) -> GeneratorSpec:
31+
"""Return the best ``GeneratorSpec`` based on some criteria.
32+
33+
NOTE: The returned ``GeneratorSpec`` may be a different object than
34+
what was provided in the original list. It may be possible to
35+
clone and modify the original ``GeneratorSpec`` to produce one that
36+
performs better.
37+
"""
38+
39+
40+
class ReductionCriterion(Enum):
41+
"""An enum for callables that are used for aggregating diagnostics over metrics
42+
and selecting the best diagnostic in ``SingleDiagnosticBestModelSelector``.
43+
44+
NOTE: This is used to ensure serializability of the callables.
45+
"""
46+
47+
# NOTE: Callables need to be wrapped in `partial` to be registered as members.
48+
# pyre-fixme[35]: Target cannot be annotated.
49+
MEAN: Callable[[ARRAYLIKE], npt.NDArray] = partial(np.mean)
50+
# pyre-fixme[35]: Target cannot be annotated.
51+
MIN: Callable[[ARRAYLIKE], npt.NDArray] = partial(np.min)
52+
# pyre-fixme[35]: Target cannot be annotated.
53+
MAX: Callable[[ARRAYLIKE], npt.NDArray] = partial(np.max)
54+
55+
def __call__(self, array_like: ARRAYLIKE) -> npt.NDArray:
56+
return self.value(array_like)
57+
58+
59+
class SingleDiagnosticBestModelSelector(BestModelSelector):
60+
"""Choose the best model using a single cross-validation diagnostic.
61+
62+
The input is a list of ``GeneratorSpec``, each corresponding to one model.
63+
The specified diagnostic is extracted from each of the models,
64+
its values (each of which corresponds to a separate metric) are
65+
aggregated with the aggregation function, the best one is determined
66+
with the criterion, and the index of the best diagnostic result is returned.
67+
68+
Example:
69+
::
70+
s = SingleDiagnosticBestModelSelector(
71+
diagnostic='Fisher exact test p',
72+
metric_aggregation=ReductionCriterion.MEAN,
73+
criterion=ReductionCriterion.MIN,
74+
model_cv_kwargs={"untransform": False},
75+
)
76+
best_model = s.best_model(model_specs=model_specs)
77+
78+
Args:
79+
diagnostic: The name of the diagnostic to use, which should be
80+
a key in ``CVDiagnostic``.
81+
metric_aggregation: ``ReductionCriterion`` applied to the values of the
82+
diagnostic for a single model to produce a single number.
83+
criterion: ``ReductionCriterion`` used to determine which of the
84+
(aggregated) diagnostics is the best.
85+
model_cv_kwargs: Optional dictionary of kwargs to pass in while computing
86+
the cross validation diagnostics.
87+
"""
88+
89+
def __init__(
90+
self,
91+
diagnostic: str,
92+
metric_aggregation: ReductionCriterion,
93+
criterion: ReductionCriterion,
94+
model_cv_kwargs: dict[str, Any] | None = None,
95+
) -> None:
96+
self.diagnostic = diagnostic
97+
if not isinstance(metric_aggregation, ReductionCriterion) or not isinstance(
98+
criterion, ReductionCriterion
99+
):
100+
raise UserInputError(
101+
"Both `metric_aggregation` and `criterion` must be "
102+
f"`ReductionCriterion`. Got {metric_aggregation=}, {criterion=}."
103+
)
104+
if criterion == ReductionCriterion.MEAN:
105+
raise UserInputError(
106+
f"{criterion=} is not supported. Please use MIN or MAX."
107+
)
108+
self.metric_aggregation = metric_aggregation
109+
self.criterion = criterion
110+
self.model_cv_kwargs = model_cv_kwargs
111+
112+
def best_model(self, model_specs: list[GeneratorSpec]) -> GeneratorSpec:
113+
"""Return the best ``GeneratorSpec`` based on the specified diagnostic.
114+
115+
Args:
116+
model_specs: List of ``GeneratorSpec`` to choose from.
117+
118+
Returns:
119+
The best ``GeneratorSpec`` based on the specified diagnostic.
120+
"""
121+
for model_spec in model_specs:
122+
model_spec.cross_validate(model_cv_kwargs=self.model_cv_kwargs)
123+
aggregated_diagnostic_values = [
124+
self.metric_aggregation(
125+
list(none_throws(model_spec.diagnostics)[self.diagnostic].values())
126+
)
127+
for model_spec in model_specs
128+
]
129+
best_diagnostic = self.criterion(aggregated_diagnostic_values).item()
130+
best_index = aggregated_diagnostic_values.index(best_diagnostic)
131+
return model_specs[best_index]

0 commit comments

Comments
 (0)