Skip to content

Commit 75c0ed4

Browse files
andycylmetameta-codesync[bot]
authored andcommitted
Prefer predictive adapter over fallback in GenerationNode._fitted_adapter (#5199)
Summary: Pull Request resolved: #5199 When BoTorch candidate generation fails after MAX_GEN_ATTEMPTS (e.g. due to search space exhaustion), the generation node falls back to Sobol for that particular gen call. This fallback overwrites `_generator_spec_to_gen_from` with a RandomAdapter, which cannot make predictions. The problem is that downstream analysis code reads `GenerationStrategy.adapter` to generate model-dependent plots (cross-validation, sensitivity, surface, modeled arm effects, etc.). Since the adapter now points to the Sobol fallback's RandomAdapter, all these analyses fail with "does not support predictions" or "TorchAdapter is required" errors -- even though the original fitted TorchAdapter is still preserved on the generator spec. This diff fixes `GenerationNode._fitted_adapter` to check: if the current adapter cannot predict, look for a fitted predictive adapter among the original `generator_specs` and prefer that instead. This is safe because the original TorchAdapter is never destroyed by the fallback -- it's just shadowed by the `_generator_spec_to_gen_from` override. Reviewed By: ItsMrLin Differential Revision: D99358260 fbshipit-source-id: 30821af699806be50c5af88589b9209080c58801
1 parent bbb0d71 commit 75c0ed4

2 files changed

Lines changed: 53 additions & 1 deletion

File tree

ax/generation_strategy/generation_node.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,32 @@ def _fitted_adapter(self) -> Adapter | None:
349349
"""Private property to return optional fitted_adapter from
350350
self.generator_spec_to_gen_from for convenience. If no model is fit,
351351
this will return None.
352+
353+
If the current adapter (e.g. from a Sobol fallback after
354+
``_try_gen_with_fallback``) cannot predict, prefer a predictive adapter
355+
from the original ``generator_specs`` when available. This ensures that
356+
analysis code which relies on model predictions (e.g. cross-validation,
357+
sensitivity, surface plots) can still use the fitted surrogate model
358+
even after a transient fallback during candidate generation.
352359
"""
353360
try:
354361
# Using the private attribute since using the non-private `fitted_adapter`
355362
# property will raise a UserInputError if there is no fitted model.
356-
return self.generator_spec_to_gen_from._fitted_adapter
363+
adapter = self.generator_spec_to_gen_from._fitted_adapter
357364
except ModelError:
358365
# ModelError is raised if there are no fitted adapters to select from.
359366
return None
360367

368+
if adapter is not None and not adapter.can_predict:
369+
for spec in self.generator_specs:
370+
if (
371+
spec._fitted_adapter is not None
372+
and spec._fitted_adapter.can_predict
373+
):
374+
return spec._fitted_adapter
375+
376+
return adapter
377+
361378
def __repr__(self) -> str:
362379
"""String representation of this ``GenerationNode`` (note that it
363380
will abridge some aspects of ``TransitionCriterion`` and

ax/generation_strategy/tests/test_generation_node.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,41 @@ def test_gen_with_no_trial_type(self) -> None:
285285
self.assertIsNotNone(gr)
286286
self.assertNotIn("trial_type", none_throws(gr.gen_metadata))
287287

288+
@mock_botorch_optimize
289+
def test_fitted_adapter_prefers_predictive_over_fallback(self) -> None:
290+
"""After a Sobol fallback, _fitted_adapter should still return the
291+
original predictive TorchAdapter rather than the fallback's
292+
RandomAdapter. This ensures analysis code can generate model-dependent
293+
plots even after a transient fallback during candidate generation."""
294+
node = GenerationNode(
295+
name="test",
296+
generator_specs=[
297+
GeneratorSpec(
298+
generator_enum=Generators.BOTORCH_MODULAR,
299+
generator_kwargs={},
300+
generator_gen_kwargs={},
301+
),
302+
],
303+
)
304+
node._fit(experiment=self.branin_experiment)
305+
original_adapter = none_throws(node._fitted_adapter)
306+
self.assertTrue(original_adapter.can_predict)
307+
308+
# Simulate fallback: fit a Sobol fallback spec and override
309+
# _generator_spec_to_gen_from, mimicking _try_gen_with_fallback.
310+
fallback_spec = GeneratorSpec(
311+
generator_enum=Generators.SOBOL,
312+
generator_key_override="Fallback_Sobol",
313+
)
314+
fallback_spec.fit(experiment=self.branin_experiment)
315+
self.assertFalse(none_throws(fallback_spec._fitted_adapter).can_predict)
316+
node._generator_spec_to_gen_from = fallback_spec
317+
318+
# _fitted_adapter should still return the original predictive adapter.
319+
adapter_after_fallback = none_throws(node._fitted_adapter)
320+
self.assertTrue(adapter_after_fallback.can_predict)
321+
self.assertIs(adapter_after_fallback, original_adapter)
322+
288323
@mock_botorch_optimize
289324
def test_generator_gen_kwargs_deepcopy(self) -> None:
290325
sampler = SobolQMCNormalSampler(torch.Size([1]))

0 commit comments

Comments
 (0)