Skip to content

Commit ebe8fd9

Browse files
mgarrardfacebook-github-bot
authored andcommitted
reap gen_multiple and replace with gen_for_multiple_with_multiple
Summary: This diff removes _gen_multiple and replaces it with calls to gen_for_multiple_with_multiple. We plan to replace gen() with gen_for_multiple_with_multiple() with the Ax1.0 release so will keep both around for now. Internal - Tldr: {F1973912988} See diff 1/n in the stack for context Reviewed By: saitcakmak Differential Revision: D67319697
1 parent 51a90db commit ebe8fd9

File tree

6 files changed

+352
-276
lines changed

6 files changed

+352
-276
lines changed
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
from __future__ import annotations
9+
10+
from abc import ABC, abstractmethod
11+
12+
from typing import Any
13+
14+
from ax.core.data import Data
15+
from ax.core.experiment import Experiment
16+
from ax.core.generator_run import GeneratorRun
17+
from ax.core.observation import ObservationFeatures
18+
from ax.exceptions.core import AxError, UnsupportedError
19+
from ax.utils.common.base import Base
20+
from pyre_extensions import none_throws
21+
22+
23+
class GenerationStrategyInterface(ABC, Base):
24+
"""Interface for all generation strategies: standard Ax
25+
``GenerationStrategy``, as well as non-standard (e.g. remote, external)
26+
generation strategies.
27+
28+
NOTE: Currently in Beta; please do not use without discussion with the Ax
29+
developers.
30+
"""
31+
32+
_name: str
33+
# Experiment, for which this generation strategy has generated trials, if
34+
# it exists.
35+
_experiment: Experiment | None = None
36+
37+
# Constant for default number of arms to generate if `n` is not specified in
38+
# `gen` call and "total_concurrent_arms" is not set in experiment properties.
39+
DEFAULT_N: int = 1
40+
41+
def __init__(self, name: str) -> None:
42+
self._name = name
43+
44+
@abstractmethod
45+
def gen_for_multiple_trials_with_multiple_models(
46+
self,
47+
experiment: Experiment,
48+
data: Data | None = None,
49+
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
50+
n: int | None = None,
51+
fixed_features: ObservationFeatures | None = None,
52+
num_trials: int = 1,
53+
arms_per_node: dict[str, int] | None = None,
54+
) -> list[list[GeneratorRun]]:
55+
"""Produce ``GeneratorRun``-s for multiple trials at once with the possibility
56+
of joining ``GeneratorRun``-s from multiple models into one ``BatchTrial``.
57+
58+
Args:
59+
experiment: ``Experiment``, for which the generation strategy is producing
60+
a new generator run in the course of ``gen``, and to which that
61+
generator run will be added as trial(s). Information stored on the
62+
experiment (e.g., trial statuses) is used to determine which model
63+
will be used to produce the generator run returned from this method.
64+
data: Optional data to be passed to the underlying model's ``gen``, which
65+
is called within this method and actually produces the resulting
66+
generator run. By default, data is all data on the ``experiment``.
67+
pending_observations: A map from metric name to pending
68+
observations for that metric, used by some models to avoid
69+
resuggesting points that are currently being evaluated.
70+
n: Integer representing how many total arms should be in the generator
71+
runs produced by this method. NOTE: Some underlying models may ignore
72+
the `n` and produce a model-determined number of arms. In that
73+
case this method will also output generator runs with number of
74+
arms that can differ from `n`.
75+
fixed_features: An optional set of ``ObservationFeatures`` that will be
76+
passed down to the underlying models. Note: if provided this will
77+
override any algorithmically determined fixed features so it is
78+
important to specify all necessary fixed features.
79+
num_trials: Number of trials to generate generator runs for in this call.
80+
If not provided, defaults to 1.
81+
arms_per_node: An optional map from node name to the number of arms to
82+
generate from that node. If not provided, will default to the number
83+
of arms specified in the node's ``InputConstructors`` or n if no
84+
``InputConstructors`` are defined on the node. We expect either n or
85+
arms_per_node to be provided, but not both, and this is an advanced
86+
argument that should only be used by advanced users.
87+
88+
Returns:
89+
A list of lists of ``GeneratorRun``-s. Each outer list item represents
90+
a ``(Batch)Trial`` being suggested, with a list of ``GeneratorRun``-s for
91+
that trial.
92+
"""
93+
# When implementing your subclass' override for this method, don't forget
94+
# to consider using "pending points", corresponding to arms in trials that
95+
# are currently running / being evaluated/
96+
...
97+
98+
def _gen_multiple(
99+
self,
100+
experiment: Experiment,
101+
num_generator_runs: int,
102+
data: Data | None = None,
103+
n: int = 1,
104+
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
105+
**model_gen_kwargs: Any,
106+
) -> list[GeneratorRun]:
107+
"""Produce multiple generator runs at once, to be made into multiple
108+
trials on the experiment.
109+
110+
NOTE: This is used to ensure that maximum parallelism and number
111+
of trials per node are not violated when producing many generator
112+
runs from this generation strategy in a row. Without this function,
113+
if one generates multiple generator runs without first making any
114+
of them into running trials, generation strategy cannot enforce that it only
115+
produces as many generator runs as are allowed by the parallelism
116+
limit and the limit on number of trials in current node.
117+
118+
Args:
119+
experiment: Experiment, for which the generation strategy is producing
120+
a new generator run in the course of `gen`, and to which that
121+
generator run will be added as trial(s). Information stored on the
122+
experiment (e.g., trial statuses) is used to determine which model
123+
will be used to produce the generator run returned from this method.
124+
data: Optional data to be passed to the underlying model's `gen`, which
125+
is called within this method and actually produces the resulting
126+
generator run. By default, data is all data on the `experiment`.
127+
n: Integer representing how many arms should be in the generator run
128+
produced by this method. NOTE: Some underlying models may ignore
129+
the ``n`` and produce a model-determined number of arms. In that
130+
case this method will also output a generator run with number of
131+
arms that can differ from ``n``.
132+
pending_observations: A map from metric name to pending
133+
observations for that metric, used by some models to avoid
134+
resuggesting points that are currently being evaluated.
135+
model_gen_kwargs: Keyword arguments that are passed through to
136+
``GenerationNode.gen``, which will pass them through to
137+
``GeneratorSpec.gen``, which will pass them to ``Adapter.gen``.
138+
"""
139+
...
140+
141+
@abstractmethod
142+
def clone_reset(self) -> GenerationStrategyInterface:
143+
"""Returns a clone of this generation strategy with all state reset."""
144+
...
145+
146+
@property
147+
def name(self) -> str:
148+
"""Name of this generation strategy."""
149+
return self._name
150+
151+
@property
152+
def experiment(self) -> Experiment:
153+
"""Experiment, currently set on this generation strategy."""
154+
if self._experiment is None:
155+
raise AxError("No experiment set on generation strategy.")
156+
return none_throws(self._experiment)
157+
158+
@experiment.setter
159+
def experiment(self, experiment: Experiment) -> None:
160+
"""If there is an experiment set on this generation strategy as the
161+
experiment it has been generating generator runs for, check if the
162+
experiment passed in is the same as the one saved and log an information
163+
statement if its not. Set the new experiment on this generation strategy.
164+
"""
165+
if self._experiment is not None and experiment._name != self.experiment._name:
166+
raise UnsupportedError(
167+
"This generation strategy has been used for experiment "
168+
f"{self.experiment._name} so far; cannot reset experiment"
169+
f" to {experiment._name}. If this is a new experiment, "
170+
"a new generation strategy should be created instead."
171+
)
172+
self._experiment = experiment
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
6+
# pyre-strict
7+
8+
9+
from ax.core.data import Data
10+
from ax.core.experiment import Experiment
11+
from ax.core.generation_strategy_interface import GenerationStrategyInterface
12+
from ax.core.generator_run import GeneratorRun
13+
from ax.core.observation import ObservationFeatures
14+
from ax.exceptions.core import AxError, UnsupportedError
15+
from ax.utils.common.testutils import TestCase
16+
from ax.utils.testing.core_stubs import get_experiment, SpecialGenerationStrategy
17+
18+
19+
class MyGSI(GenerationStrategyInterface):
20+
def gen_for_multiple_trials_with_multiple_models(
21+
self,
22+
experiment: Experiment,
23+
data: Data | None = None,
24+
pending_observations: dict[str, list[ObservationFeatures]] | None = None,
25+
n: int | None = None,
26+
fixed_features: ObservationFeatures | None = None,
27+
num_trials: int = 1,
28+
arms_per_node: dict[str, int] | None = None,
29+
) -> list[list[GeneratorRun]]:
30+
raise NotImplementedError
31+
32+
def clone_reset(self) -> "MyGSI":
33+
raise NotImplementedError
34+
35+
36+
class TestGenerationStrategyInterface(TestCase):
37+
def setUp(self) -> None:
38+
super().setUp()
39+
self.exp = get_experiment()
40+
self.gsi = MyGSI(name="my_GSI")
41+
self.special_gsi = SpecialGenerationStrategy()
42+
43+
def test_constructor(self) -> None:
44+
with self.assertRaisesRegex(TypeError, ".* abstract"):
45+
GenerationStrategyInterface(name="my_GSI") # pyre-ignore[45]
46+
self.assertEqual(self.gsi.name, "my_GSI")
47+
48+
def test_abstract(self) -> None:
49+
with self.assertRaises(NotImplementedError):
50+
self.gsi.gen_for_multiple_trials_with_multiple_models(experiment=self.exp)
51+
52+
with self.assertRaises(NotImplementedError):
53+
self.gsi.clone_reset()
54+
55+
def test_experiment(self) -> None:
56+
with self.assertRaisesRegex(AxError, "No experiment"):
57+
self.gsi.experiment
58+
self.gsi.experiment = self.exp
59+
exp_2 = get_experiment()
60+
exp_2.name = "exp_2"
61+
with self.assertRaisesRegex(UnsupportedError, "has been used for"):
62+
self.gsi.experiment = exp_2

0 commit comments

Comments
 (0)