|
32 | 32 | predicted_hypervolume,
|
33 | 33 | validate_and_apply_final_transform,
|
34 | 34 | )
|
35 |
| -from ax.modelbridge.registry import get_model_from_generator_run, ModelRegistryBase |
| 35 | +from ax.modelbridge.registry import ModelRegistryBase |
36 | 36 | from ax.modelbridge.torch import TorchModelBridge
|
37 | 37 | from ax.modelbridge.transforms.derelativize import Derelativize
|
38 | 38 | from ax.models.torch.botorch_moo_defaults import (
|
@@ -387,35 +387,15 @@ def _get_hypervolume(
|
387 | 387 | )
|
388 | 388 |
|
389 | 389 | if use_model_predictions:
|
390 |
| - current_model = generation_strategy._curr.model_spec_to_gen_from.model_enum |
391 |
| - # Cover for the case where source of `self._curr.model` was not a `Models` |
392 |
| - # enum but a factory function, in which case we cannot do |
393 |
| - # `get_model_from_generator_run` (since we don't have model type and inputs |
394 |
| - # recorded on the generator run. |
395 |
| - models_enum = ( |
396 |
| - current_model.__class__ |
397 |
| - if isinstance(current_model, ModelRegistryBase) |
398 |
| - else None |
399 |
| - ) |
400 |
| - |
401 |
| - if models_enum is None: |
402 |
| - raise ValueError( |
403 |
| - f"Model {current_model} is not in the ModelRegistry, cannot " |
404 |
| - "calculate predicted hypervolume." |
405 |
| - ) |
406 |
| - |
407 |
| - model = get_model_from_generator_run( |
408 |
| - generator_run=none_throws(generation_strategy.last_generator_run), |
409 |
| - experiment=experiment, |
410 |
| - data=experiment.fetch_data(trial_indices=trial_indices), |
411 |
| - models_enum=models_enum, |
412 |
| - ) |
| 390 | + # Make sure that the model is fitted. If model is fitted already, |
| 391 | + # this should be a no-op. |
| 392 | + generation_strategy._fit_current_model(data=None) |
| 393 | + model = generation_strategy.model |
413 | 394 | if not isinstance(model, TorchModelBridge):
|
414 | 395 | raise ValueError(
|
415 |
| - f"Model {current_model} is not of type TorchModelBridge, cannot " |
| 396 | + f"Model {model} is not of type TorchModelBridge, cannot " |
416 | 397 | "calculate predicted hypervolume."
|
417 | 398 | )
|
418 |
| - |
419 | 399 | return predicted_hypervolume(
|
420 | 400 | modelbridge=model, optimization_config=optimization_config
|
421 | 401 | )
|
|
0 commit comments