1919from ax .core .generator_run import GeneratorRun
2020from ax .core .observation import ObservationFeatures
2121from ax .core .utils import extend_pending_observations , extract_pending_observations
22- from ax .exceptions .core import (
23- AxError ,
24- DataRequiredError ,
25- UnsupportedError ,
26- UserInputError ,
27- )
22+ from ax .exceptions .core import AxError , DataRequiredError , UnsupportedError
2823from ax .exceptions .generation_strategy import (
2924 GenerationStrategyCompleted ,
3025 GenerationStrategyMisconfiguredException ,
@@ -247,7 +242,6 @@ def gen(
247242 n : int | None = None ,
248243 fixed_features : ObservationFeatures | None = None ,
249244 num_trials : int = 1 ,
250- arms_per_node : dict [str , int ] | None = None ,
251245 ) -> list [list [GeneratorRun ]]:
252246 """Produce GeneratorRuns for multiple trials at once with the possibility of
253247 using multiple models per trial, getting multiple GeneratorRuns per trial.
@@ -275,12 +269,6 @@ def gen(
275269 important to specify all necessary fixed features.
276270 num_trials: Number of trials to generate generator runs for in this call.
277271 If not provided, defaults to 1.
278- arms_per_node: An optional map from node name to the number of arms to
279- generate from that node. If not provided, will default to the number
280- of arms specified in the node's ``InputConstructors`` or n if no
281- ``InputConstructors`` are defined on the node. We expect either n or
282- arms_per_node to be provided, but not both, and this is an advanced
283- argument that should only be used by advanced users.
284272
285273 Returns:
286274 A list of lists of lists generator runs. Each outer list represents
@@ -306,7 +294,6 @@ def gen(
306294 data = data ,
307295 n = n ,
308296 pending_observations = pending_observations ,
309- arms_per_node = arms_per_node ,
310297 fixed_features = fixed_features ,
311298 first_generation_in_multi = len (grs_for_multiple_trials ) < 1 ,
312299 )
@@ -467,24 +454,6 @@ def _validate_and_set_node_graph(self, nodes: list[GenerationNode]) -> None:
467454
468455 self ._curr = nodes [0 ]
469456
470- def _validate_arms_per_node (self , arms_per_node : dict [str , int ] | None ) -> None :
471- """Validate that the arms_per_node argument is valid if it is provided.
472-
473- Args:
474- arms_per_node: A map from node name to the number of arms to
475- generate from that node.
476- """
477- if arms_per_node is not None and not set (self .nodes_by_name ).issubset (
478- arms_per_node
479- ):
480- raise UserInputError (
481- "Each node defined in the `GenerationStrategy` must have an "
482- "associated number of arms to generate from that node defined "
483- f"in `arms_per_node`. { arms_per_node } does not include all of "
484- f"{ self .nodes_by_name .keys ()} . "
485- "It may help to double-check the spelling."
486- )
487-
488457 def _make_default_name (self ) -> str :
489458 """Make a default name for this generation strategy; used when no name is passed
490459 to the constructor. For node-based generation strategies, the name is
@@ -515,10 +484,6 @@ def _gen_with_multiple_nodes(
515484 pending_observations : dict [str , list [ObservationFeatures ]] | None = None ,
516485 data : Data | None = None ,
517486 fixed_features : ObservationFeatures | None = None ,
518- # TODO: Consider naming `arms_per_node` smtg like `arms_per_node_override`,
519- # to convey its manually-specified nature (if it's not specified, GS selects
520- # what to do on its own).
521- arms_per_node : dict [str , int ] | None = None ,
522487 first_generation_in_multi : bool = True ,
523488 ) -> list [GeneratorRun ]:
524489 """Produces a List of GeneratorRuns for a single trial, either ``Trial`` or
@@ -548,12 +513,6 @@ def _gen_with_multiple_nodes(
548513 passed down to the underlying nodes. Note: if provided this will
549514 override any algorithmically determined fixed features so it is
550515 important to specify all necessary fixed features.
551- arms_per_node: An optional map from node name to the number of arms to
552- generate from that node. If not provided, will default to the number
553- of arms specified in the node's ``InputConstructors`` or n if no
554- ``InputConstructors`` are defined on the node. We expect either n or
555- arms_per_node to be provided, but not both, and this is an advanced
556- argument that should only be used by advanced users.
557516
558517 Returns:
559518 A list of ``GeneratorRuns`` for a single trial.
@@ -570,7 +529,6 @@ def _gen_with_multiple_nodes(
570529 pending_observations if pending_observations is not None else {}
571530 )
572531 self .experiment = experiment
573- self ._validate_arms_per_node (arms_per_node = arms_per_node )
574532 pack_gs_gen_kwargs = {
575533 "grs_this_gen" : grs_this_gen ,
576534 "fixed_features" : fixed_features ,
@@ -596,7 +554,6 @@ def _gen_with_multiple_nodes(
596554 pending_observations = pending_observations ,
597555 skip_fit = not (first_generation_in_multi or transitioned ),
598556 n = n ,
599- arms_per_node = arms_per_node ,
600557 ** pack_gs_gen_kwargs ,
601558 )
602559 except DataRequiredError as err :
0 commit comments