Skip to content

Commit d059120

Browse files
saitcakmakfacebook-github-bot
authored andcommitted
Utilize existing model from GS in BestPointMixin._get_hypervolume (#3285)
Summary: Pull Request resolved: #3285 This was previously re-constructing the model using `get_model_from_generator_run`, which is a helper that I want to deprecate. Since the GS is readily available, we can utilize the model from the GS rather than re-constructing & fitting it from scratch. Reviewed By: esantorella Differential Revision: D68836172 fbshipit-source-id: af8300f6a80f6802977d6bdb482eab2dc982421e
1 parent 9be04d6 commit d059120

File tree

1 file changed

+6
-26
lines changed

1 file changed

+6
-26
lines changed

ax/service/utils/best_point_mixin.py

+6-26
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
predicted_hypervolume,
3333
validate_and_apply_final_transform,
3434
)
35-
from ax.modelbridge.registry import get_model_from_generator_run, ModelRegistryBase
35+
from ax.modelbridge.registry import ModelRegistryBase
3636
from ax.modelbridge.torch import TorchModelBridge
3737
from ax.modelbridge.transforms.derelativize import Derelativize
3838
from ax.models.torch.botorch_moo_defaults import (
@@ -387,35 +387,15 @@ def _get_hypervolume(
387387
)
388388

389389
if use_model_predictions:
390-
current_model = generation_strategy._curr.model_spec_to_gen_from.model_enum
391-
# Cover for the case where source of `self._curr.model` was not a `Models`
392-
# enum but a factory function, in which case we cannot do
393-
# `get_model_from_generator_run` (since we don't have model type and inputs
394-
# recorded on the generator run.
395-
models_enum = (
396-
current_model.__class__
397-
if isinstance(current_model, ModelRegistryBase)
398-
else None
399-
)
400-
401-
if models_enum is None:
402-
raise ValueError(
403-
f"Model {current_model} is not in the ModelRegistry, cannot "
404-
"calculate predicted hypervolume."
405-
)
406-
407-
model = get_model_from_generator_run(
408-
generator_run=none_throws(generation_strategy.last_generator_run),
409-
experiment=experiment,
410-
data=experiment.fetch_data(trial_indices=trial_indices),
411-
models_enum=models_enum,
412-
)
390+
# Make sure that the model is fitted. If model is fitted already,
391+
# this should be a no-op.
392+
generation_strategy._fit_current_model(data=None)
393+
model = generation_strategy.model
413394
if not isinstance(model, TorchModelBridge):
414395
raise ValueError(
415-
f"Model {current_model} is not of type TorchModelBridge, cannot "
396+
f"Model {model} is not of type TorchModelBridge, cannot "
416397
"calculate predicted hypervolume."
417398
)
418-
419399
return predicted_hypervolume(
420400
modelbridge=model, optimization_config=optimization_config
421401
)

0 commit comments

Comments
 (0)