-
Notifications
You must be signed in to change notification settings - Fork 372
Expand file tree
/
Copy pathsurrogate.py
More file actions
1309 lines (1198 loc) · 55 KB
/
surrogate.py
File metadata and controls
1309 lines (1198 loc) · 55 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
from __future__ import annotations
import inspect
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass, field, replace
from logging import Logger
from typing import Any, cast
import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import AxError, UnsupportedError, UserInputError
from ax.exceptions.model import ModelError
from ax.generators.torch.botorch_modular.input_constructors.covar_modules import (
covar_module_argparse,
)
from ax.generators.torch.botorch_modular.input_constructors.input_transforms import (
input_transform_argparse,
)
from ax.generators.torch.botorch_modular.input_constructors.outcome_transform import (
outcome_transform_argparse,
)
from ax.generators.torch.botorch_modular.utils import (
convert_to_block_design,
copy_model_config_with_default_values,
fit_botorch_model,
get_all_task_values_from_ssd,
get_cv_fold,
ModelConfig,
subset_state_dict,
use_model_list,
)
from ax.generators.torch.utils import (
_to_equality_constraints,
_to_inequality_constraints,
pick_best_out_of_sample_point_acqf_class,
predict_from_model,
)
from ax.generators.torch_base import TorchOptConfig
from ax.generators.types import TConfig
from ax.generators.utils import best_in_sample_point
from ax.utils.common.base import Base
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import _argparse_type_encoder
from ax.utils.stats.model_fit_stats import (
DIAGNOSTIC_FN_DIRECTIONS,
DIAGNOSTIC_FNS,
ModelFitMetricDirection,
RANK_CORRELATION,
)
from botorch.exceptions.errors import ModelFittingError
from botorch.models.gpytorch import GPyTorchModel
from botorch.models.model import Model, ModelList
from botorch.models.model_list_gp_regression import ModelListGP
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import (
ChainedInputTransform,
InputTransform,
LearnedFeatureImputation,
Normalize,
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
from botorch.posteriors.fully_bayesian import GaussianMixturePosterior
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.settings import validate_input_scaling
from botorch.utils.containers import SliceContainer
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.evaluation import AIC, BIC, compute_in_sample_model_fit_metric, MLL
from botorch.utils.transforms import normalize_indices
from botorch.utils.types import _DefaultType
from gpytorch.models.exact_gp import ExactGP
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
from torch.nn import Module
# Message for the `ModelError` raised when `Surrogate.training_data` is accessed
# before any training data has been set by `Surrogate.fit`.
NOT_YET_FIT_MSG = (
    "Underlying BoTorch `Model` has not yet received its training_data. "
    "Please fit the model first."
)

logger: Logger = get_logger(__name__)

# Optimization direction of every criterion usable as
# `SurrogateSpec.eval_criterion` during model selection: the diagnostics from
# `ax.utils.stats.model_fit_stats`, extended with the in-sample fit metrics
# MLL (maximized) and AIC / BIC (minimized) from `botorch.utils.evaluation`.
MODEL_SELECTION_METRIC_DIRECTIONS: dict[str, ModelFitMetricDirection] = dict(
    DIAGNOSTIC_FN_DIRECTIONS
)
MODEL_SELECTION_METRIC_DIRECTIONS.update(
    {
        MLL: ModelFitMetricDirection.MAXIMIZE,
        AIC: ModelFitMetricDirection.MINIMIZE,
        BIC: ModelFitMetricDirection.MINIMIZE,
    }
)
def _extract_model_kwargs(
    search_space_digest: SearchSpaceDigest, botorch_model_class: type[Model]
) -> dict[str, list[int] | dict[int, dict[int | float, list[int]]] | int]:
    """Collect the feature-related kwargs for a BoTorch model's ``construct_inputs``.

    Args:
        search_space_digest: A ``SearchSpaceDigest`` describing the features.
        botorch_model_class: The BoTorch model class the kwargs are built for.

    Returns:
        A dict that may contain, in this order, ``categorical_features``,
        ``fidelity_features``, ``task_feature`` and ``hierarchical_dependencies``.
        Each key is present only when the digest provides a non-empty value for
        it. A task feature located in the last column of the digest's bounds is
        reported as ``-1`` (to support heterogeneous search spaces).

    Raises:
        NotImplementedError: If fidelity and task features are both present, or
            if there is more than one task feature.
        ModelFittingError: If there is no task feature but
            ``botorch_model_class`` is a ``MultiTaskGP`` subclass whose
            constructor accepts ``task_feature``.
    """
    model_params = inspect.signature(botorch_model_class).parameters
    fidelity_features = search_space_digest.fidelity_features
    task_features = search_space_digest.task_features
    has_fidelity = len(fidelity_features) > 0
    num_tasks = len(task_features)

    if has_fidelity and num_tasks > 0:
        raise NotImplementedError(
            "Multi-Fidelity GP models with task_features are currently not supported."
        )
    if num_tasks > 1:
        raise NotImplementedError("Multiple task features are not supported.")
    # Model selection (`Surrogate.model_selection`) skips MTGPs when there is no
    # task feature. Some MTGP subclasses do not take `task_feature`, so only
    # error out for those whose signature actually requires it.
    needs_task_feature = (
        issubclass(botorch_model_class, MultiTaskGP)
        and "task_feature" in model_params
    )
    if num_tasks == 0 and needs_task_feature:
        raise ModelFittingError("Cannot fit MultiTaskGP without task feature.")

    kwargs: dict[str, list[int] | dict[int, dict[int | float, list[int]]] | int] = {}
    categorical_features = search_space_digest.categorical_features
    if len(categorical_features) > 0:
        kwargs["categorical_features"] = categorical_features
    if has_fidelity:
        kwargs["fidelity_features"] = fidelity_features
    if num_tasks == 1:
        task_index = task_features[0]
        # A task column in the last position is addressed as -1 so that
        # heterogeneous search spaces are supported.
        is_last_column = task_index == len(search_space_digest.bounds) - 1
        kwargs["task_feature"] = -1 if is_last_column else task_index
    # Regular BoTorch models do not accept `hierarchical_dependencies`; when the
    # HSS is not flattened, the user must choose a model that does.
    hierarchical_dependencies = search_space_digest.hierarchical_dependencies
    if hierarchical_dependencies:
        kwargs["hierarchical_dependencies"] = hierarchical_dependencies
    return kwargs
def _make_botorch_input_transform(
    input_transform_classes: list[type[InputTransform]] | _DefaultType,
    input_transform_options: dict[str, dict[str, Any]],
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
) -> InputTransform | None:
    """Build the BoTorch input transform described by a ``ModelConfig``.

    Args:
        input_transform_classes: Either ``DEFAULT`` (a ``_DefaultType``), in which
            case the default transforms are constructed and
            ``input_transform_options`` is not used, or an explicit list of
            ``InputTransform`` classes.
        input_transform_options: Per-class keyword options, keyed by class name.
        dataset: Training data used to derive transform arguments.
        search_space_digest: Search space digest used to derive transform
            arguments.

    Returns:
        ``None`` when no transform is produced, the transform itself when exactly
        one is produced, and otherwise a ``ChainedInputTransform`` whose members
        are named ``tf0``, ``tf1``, ... in order.
    """
    if isinstance(input_transform_classes, _DefaultType):
        built = _construct_default_input_transforms(
            search_space_digest=search_space_digest, dataset=dataset
        )
    else:
        built = _construct_specified_input_transforms(
            input_transform_classes=input_transform_classes,
            dataset=dataset,
            search_space_digest=search_space_digest,
            input_transform_options=input_transform_options,
        )

    if not built:
        return None
    if len(built) == 1:
        return built[0]
    named_transforms = {f"tf{idx}": tf for idx, tf in enumerate(built)}
    return ChainedInputTransform(**named_transforms)
def _construct_specified_input_transforms(
    input_transform_classes: list[type[InputTransform]],
    input_transform_options: dict[str, dict[str, Any]],
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
) -> list[InputTransform]:
    """Constructs a list of input transforms from input transform classes and
    options provided in ``ModelConfig``.

    Args:
        input_transform_classes: The ``InputTransform`` subclasses to instantiate,
            in order.
        input_transform_options: Mapping from transform class name to extra
            keyword options for that class; classes without an entry get ``{}``.
        dataset: Training data forwarded to ``input_transform_argparse``.
        search_space_digest: Search space digest forwarded to
            ``input_transform_argparse``.

    Returns:
        One instantiated transform per entry of ``input_transform_classes``, in
        the same order.

    Raises:
        UserInputError: If ``input_transform_classes`` is not a list of
            ``InputTransform`` subclasses.
    """
    if not (
        isinstance(input_transform_classes, list)
        and all(
            # `issubclass` raises a bare `TypeError` for non-class arguments
            # (e.g. a transform instance), so check for a class first in order
            # to report the intended `UserInputError`.
            isinstance(c, type) and issubclass(c, InputTransform)
            for c in input_transform_classes
        )
    ):
        raise UserInputError(
            "Expected a list of input transform classes. "
            f"Got {input_transform_classes=}."
        )
    input_transform_kwargs = [
        input_transform_argparse(
            transform_class,
            dataset=dataset,
            search_space_digest=search_space_digest,
            input_transform_options=deepcopy(  # In case of in-place modifications.
                input_transform_options.get(transform_class.__name__, {})
            ),
        )
        for transform_class in input_transform_classes
    ]
    return [
        # pyre-ignore[45]: Concrete subclasses are passed at runtime.
        transform_class(**single_input_transform_kwargs)
        for transform_class, single_input_transform_kwargs in zip(
            input_transform_classes, input_transform_kwargs
        )
    ]
def _construct_default_input_transforms(
    search_space_digest: SearchSpaceDigest,
    dataset: SupervisedDataset,
) -> list[InputTransform]:
    """Construct the default input transforms for the given search space digest.

    Currently the only default transform is ``Normalize``, configured with the
    options returned by ``input_transform_argparse``. It is omitted, and an
    empty list returned, when the bounds it would act on (restricted to the
    ``indices`` chosen by ``input_transform_argparse``, if any) already equal
    the unit interval [0, 1] in every dimension.
    """
    normalize_kwargs = input_transform_argparse(
        Normalize,
        dataset=dataset,
        search_space_digest=search_space_digest,
    )
    bounds = normalize_kwargs.get("bounds")
    indices = normalize_kwargs.get("indices")
    if bounds is not None:
        active_bounds = bounds if indices is None else bounds[:, indices]
        lower, upper = active_bounds
        # Normalizing data that already lives in [0, 1] would be a no-op.
        if torch.allclose(lower, torch.zeros_like(lower)) and torch.allclose(
            upper, torch.ones_like(upper)
        ):
            return []
    return [Normalize(**normalize_kwargs)]
def _make_botorch_outcome_transform(
    outcome_transform_classes: list[type[OutcomeTransform]],
    outcome_transform_options: dict[str, dict[str, Any]],
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
) -> OutcomeTransform | None:
    """
    Makes a BoTorch outcome transform from the provided classes and options.

    Args:
        outcome_transform_classes: The ``OutcomeTransform`` subclasses to
            instantiate, in order.
        outcome_transform_options: Mapping from transform class name to extra
            keyword options for that class; classes without an entry get ``{}``.
        dataset: Training data forwarded to ``outcome_transform_argparse``.
        search_space_digest: Search space digest forwarded to
            ``outcome_transform_argparse``.

    Returns:
        ``None`` if ``outcome_transform_classes`` is empty, the single transform
        if there is exactly one, and otherwise a ``ChainedOutcomeTransform`` whose
        members are named ``otf0``, ``otf1``, ... in the given order.

    Raises:
        UserInputError: If ``outcome_transform_classes`` is not a list of
            ``OutcomeTransform`` subclasses.
    """
    if not (
        isinstance(outcome_transform_classes, list)
        and all(
            # `issubclass` raises a bare `TypeError` for non-class arguments
            # (e.g. a transform instance), so check for a class first in order
            # to report the intended `UserInputError`.
            isinstance(c, type) and issubclass(c, OutcomeTransform)
            for c in outcome_transform_classes
        )
    ):
        raise UserInputError("Expected a list of outcome transforms.")
    if len(outcome_transform_classes) == 0:
        return None
    outcome_transform_kwargs = [
        outcome_transform_argparse(
            transform_class,
            outcome_transform_options=deepcopy(  # In case of in-place modifications.
                outcome_transform_options.get(transform_class.__name__, {})
            ),
            dataset=dataset,
            search_space_digest=search_space_digest,
        )
        for transform_class in outcome_transform_classes
    ]
    outcome_transforms = [
        # pyre-ignore[45]: Concrete subclasses are passed at runtime.
        transform_class(**single_outcome_transform_kwargs)
        for transform_class, single_outcome_transform_kwargs in zip(
            outcome_transform_classes, outcome_transform_kwargs
        )
    ]
    if len(outcome_transforms) == 1:
        return outcome_transforms[0]
    return ChainedOutcomeTransform(
        **{f"otf{i}": otf for i, otf in enumerate(outcome_transforms)}
    )
def _construct_submodules(
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    botorch_model_class: type[Model],
) -> dict[str, Module | None]:
    """Build the optional submodules requested by a ``ModelConfig``.

    For each of ``covar_module``, ``likelihood``, ``input_transform`` and
    ``outcome_transform`` whose class(es) are set on ``model_config``, the
    submodule is constructed and stored under that name. When no outcome
    transform classes are configured but the model accepts
    ``outcome_transform``, that entry is set to ``None`` explicitly.

    Args:
        model_config: Source of the submodule classes and their options.
        dataset: Training data used to derive submodule arguments.
        search_space_digest: Search space digest used to derive transform
            arguments.
        botorch_model_class: The BoTorch model class the submodules are for.

    Returns:
        A dict mapping constructor argument names to submodules (or ``None``).

    Raises:
        UserInputError: If a submodule is configured that the constructor of
            ``botorch_model_class`` does not accept (per
            ``inspect.getfullargspec(...).args``).
    """
    accepted_args: list[str] = inspect.getfullargspec(botorch_model_class).args

    def _error_if_arg_not_supported(arg_name: str) -> None:
        if arg_name in accepted_args:
            return
        raise UserInputError(
            f"The BoTorch model class {botorch_model_class.__name__} does not "
            f"support the input {arg_name}."
        )

    submodules: dict[str, Module | None] = {}

    # Each class is bound to a local before the `None` check so that the type
    # checker can narrow it (attribute reads are not narrowed).
    covar_class = model_config.covar_module_class
    if covar_class is not None:
        _error_if_arg_not_supported("covar_module")
        covar_kwargs = covar_module_argparse(
            covar_class,
            dataset=dataset,
            botorch_model_class=botorch_model_class,
            **deepcopy(model_config.covar_module_options),
        )
        # pyre-ignore[45]: Concrete subclasses are passed at runtime.
        submodules["covar_module"] = covar_class(**covar_kwargs)

    likelihood_class = model_config.likelihood_class
    if likelihood_class is not None:
        _error_if_arg_not_supported("likelihood")
        likelihood_kwargs = deepcopy(model_config.likelihood_options)
        # pyre-ignore[45]: Concrete subclasses are passed at runtime.
        submodules["likelihood"] = likelihood_class(**likelihood_kwargs)

    input_tf_classes = model_config.input_transform_classes
    if input_tf_classes is not None:
        _error_if_arg_not_supported("input_transform")
        submodules["input_transform"] = _make_botorch_input_transform(
            input_transform_classes=input_tf_classes or [],
            input_transform_options=model_config.input_transform_options or {},
            dataset=dataset,
            search_space_digest=search_space_digest,
        )

    outcome_tf_classes = model_config.outcome_transform_classes
    if outcome_tf_classes is not None:
        _error_if_arg_not_supported("outcome_transform")
        submodules["outcome_transform"] = _make_botorch_outcome_transform(
            outcome_transform_classes=outcome_tf_classes,
            outcome_transform_options=model_config.outcome_transform_options or {},
            dataset=dataset,
            search_space_digest=search_space_digest,
        )
    elif "outcome_transform" in accepted_args:
        # Temporary: explicitly disable the model's own default outcome
        # transform until all BoTorch models use `Standardize` by default
        # (TODO [T197435440]). After that, `Surrogate` should use `DEFAULT` for
        # these args so users can opt out by passing `None`.
        submodules["outcome_transform"] = None
    return submodules
@dataclass(frozen=True)
class SurrogateSpec:
    """
    Fields in the SurrogateSpec dataclass correspond to arguments in
    ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which
    outcomes the Surrogate is responsible for modeling.
    When ``BoTorchGenerator.fit`` is called, these fields will be used to construct the
    requisite Surrogate objects.
    If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate.

    Args:
        model_configs: List of model configs. Each model config is a specification of
            a surrogate model. Defaults to a single ``ModelConfig`` with all defaults.
            When more than one config applies to a metric, ``Surrogate.fit`` performs
            model selection among them.
        metric_to_model_configs: Dictionary mapping metric signatures to a list of model
            configs for that metric. For a metric listed here, these configs are used
            instead of ``model_configs``.
        eval_criterion: The name of the evaluation criterion used for model
            selection. Valid names are the keys of
            ``MODEL_SELECTION_METRIC_DIRECTIONS``: the criteria defined in
            ``ax.utils.stats.model_fit_stats`` plus ``MLL``, ``AIC`` and ``BIC``
            from ``botorch.utils.evaluation``. Defaults to rank correlation.
        outcomes: List of outcome names.
        allow_batched_models: Set to true to fit the models in a batch if supported.
            Set to false to fit individual models to each metric in a loop.
        use_posterior_predictive: Whether to use posterior predictive in
            cross-validation.
        num_folds: The number of folds to use in cross-validation. If None, then
            leave-one-out.
    """

    model_configs: list[ModelConfig] = field(default_factory=lambda: [ModelConfig()])
    metric_to_model_configs: dict[str, list[ModelConfig]] = field(default_factory=dict)
    eval_criterion: str = RANK_CORRELATION
    outcomes: list[str] = field(default_factory=list)
    allow_batched_models: bool = True
    use_posterior_predictive: bool = False
    num_folds: int | None = 10
class Surrogate(Base):
"""
**All classes in 'botorch_modular' directory are under
construction, incomplete, and should be treated as alpha
versions only.**
Ax wrapper for BoTorch ``Model``, subcomponent of ``BoTorchGenerator``
and is not meant to be used outside of it.
Args:
surrogate_spec: A ``SurrogateSpec`` that specifies the option to use when
constructing the surrogate models. See the docstring of ``SurrogateSpec``
for supported options and ``ModelConfig`` for additional details.
allow_batched_models: Set to true to fit the models in a batch if supported.
Set to false to fit individual models to each metric in a loop.
refit_on_cv: Whether to refit the model on the cross-validation folds.
warm_start_refit: Whether to warm-start refitting from the current state_dict
during cross-validation. If refit_on_cv is True, generally one
would set this to be False, so that no information is leaked between or
across folds.
metric_to_best_model_config: Dictionary mapping a metric signature to the best
model config. This is only used by `BoTorchGenerator.cross_validate` and
for logging what model was used.
"""
def __init__(
    self,
    surrogate_spec: SurrogateSpec | None = None,
    allow_batched_models: bool = True,
    refit_on_cv: bool = False,
    warm_start_refit: bool = True,
    metric_to_best_model_config: dict[str, ModelConfig] | None = None,
) -> None:
    """Initialize the surrogate; see the class docstring for the arguments.

    NOTE: ``allow_batched_models`` is only used when ``surrogate_spec`` is
    ``None``, in which case it is forwarded to the default ``SurrogateSpec``.
    Otherwise ``surrogate_spec.allow_batched_models`` is the value that
    ``fit`` uses.
    """
    if surrogate_spec is None:
        surrogate_spec = SurrogateSpec(allow_batched_models=allow_batched_models)
    self.surrogate_spec: SurrogateSpec = surrogate_spec
    # Store the last dataset used to fit the model for a given metric(s).
    # If the new dataset is identical, we will skip model fitting for that metric.
    # The keys are `tuple(dataset.outcome_names)`, i.e. tuples of any length.
    self._last_datasets: dict[tuple[str, ...], SupervisedDataset] = {}
    # Store a reference from a tuple of metric signatures to the BoTorch Model
    # corresponding to those metrics. In most cases this will be a one-tuple,
    # though we need n-tuples for LCE-M models. This will be used to skip model
    # construction & fitting if the datasets are identical.
    self._submodels: dict[tuple[str, ...], Model] = {}
    self.metric_to_best_model_config: dict[str, ModelConfig] = (
        metric_to_best_model_config or {}
    )
    # Store a reference to search space digest used while fitting the cached models.
    # We will re-fit the models if the search space digest changes.
    self._last_search_space_digest: SearchSpaceDigest | None = None
    # These are later updated during model fitting.
    self._training_data: list[SupervisedDataset] | None = None
    self._outcomes: list[str] | None = None
    self._model: Model | None = None
    self.refit_on_cv: bool = refit_on_cv
    self.warm_start_refit: bool = warm_start_refit
    # Updated during model selection; the outer key is the outcome name
    # (see `model_selection`).
    self._model_name_to_eval: dict[str, dict[str, float]] = {}
    self._model_name_to_model: dict[str, dict[str, Model]] = {}
def __repr__(self) -> str:
    """Show the concrete class name together with the surrogate spec."""
    cls_name = type(self).__name__
    return f"<{cls_name} surrogate_spec={self.surrogate_spec}>"
@property
def model(self) -> Model:
    """The BoTorch model built by ``fit``.

    Raises:
        ModelError: If the model has not been constructed yet.
    """
    built_model = self._model
    if built_model is not None:
        return built_model
    raise ModelError(
        "BoTorch `Model` has not yet been constructed, please fit the "
        "surrogate first (done via `BoTorchGenerator.fit`)."
    )
@property
def training_data(self) -> list[SupervisedDataset]:
    """The datasets most recently passed to ``fit``.

    Raises:
        ModelError: If no training data has been set yet.
    """
    data = self._training_data
    if data is not None:
        return data
    raise ModelError(NOT_YET_FIT_MSG)
@property
def Xs(self) -> list[Tensor]:
    """Training inputs, one entry per modeled output.

    Each dataset's input tensor is repeated once per column of its ``Y``. For a
    ``RankingDataset`` the underlying d-dim input values are used instead of
    the augmented 2*d-dim ``dataset.X``.
    """
    # Handles multi-output models. TODO: Improve this!
    per_output_inputs: list[Tensor] = []
    for ds in self.training_data:
        if isinstance(ds, RankingDataset):
            inputs = assert_is_instance(ds._X, SliceContainer).values
        else:
            inputs = ds.X
        per_output_inputs.extend([inputs] * ds.Y.shape[-1])
    return per_output_inputs
@property
def dtype(self) -> torch.dtype:
    """Dtype of the first training dataset's input tensor."""
    first_dataset = self.training_data[0]
    return first_dataset.X.dtype
@property
def device(self) -> torch.device:
    """Device of the first training dataset's input tensor."""
    first_dataset = self.training_data[0]
    return first_dataset.X.device
def clone_reset(self) -> Surrogate:
    """Create a new instance of this class from this surrogate's serialized
    constructor kwargs."""
    init_kwargs = self._serialize_attributes_as_kwargs()
    return type(self)(**init_kwargs)
def _construct_model(
    self,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    model_config: ModelConfig,
    state_dict: Mapping[str, Tensor] | None,
    refit: bool,
) -> Model:
    """Constructs the underlying BoTorch ``Model`` using the training data.

    If the dataset and the BoTorch model class are identical to those used
    while training the cached sub-model for the same outcomes, we skip model
    construction and fitting and return the cached sub-model.

    Args:
        dataset: Training data for the model (for one outcome for
            the default `Surrogate`, with the exception of batched
            multi-output case, where training data is formatted with just
            one X and concatenated Ys).
        search_space_digest: Search space digest used to set up model arguments.
        model_config: The model config. Missing values (including
            ``botorch_model_class``) are filled in via
            ``copy_model_config_with_default_values``.
        state_dict: Optional state dict to load. This should be subsetted for
            the current submodel being constructed.
        refit: Whether to re-optimize model parameters.

    Returns:
        The newly constructed (and, if applicable, fitted) model, or the cached
        sub-model when it can be reused.
    """
    outcome_names = tuple(dataset.outcome_names)
    # Fill in default values for model_configs given dataset
    model_config = copy_model_config_with_default_values(
        model_config=model_config,
        dataset=dataset,
        search_space_digest=search_space_digest,
    )
    botorch_model_class = none_throws(model_config.botorch_model_class)
    # Reuse the cache only if both the data and the model class match.
    # Otherwise a config with a different model class (e.g. during model
    # selection) would silently receive a previously fitted model of another
    # class.
    if (
        self._dataset_matches_cache(dataset=dataset)
        and self._submodels[outcome_names].__class__ == botorch_model_class
    ):
        return self._submodels[outcome_names]
    formatted_model_inputs = submodel_input_constructor(
        botorch_model_class,  # Do not pass as kwarg since this is used to dispatch.
        model_config=model_config,
        dataset=dataset,
        search_space_digest=search_space_digest,
        surrogate=self,
    )
    # pyre-ignore[45]: Concrete subclasses are passed at runtime.
    model = botorch_model_class(**formatted_model_inputs)
    # Load the given parameters when not refitting, or as a warm start for the
    # refit if `warm_start_refit` is enabled.
    if state_dict is not None and (not refit or self.warm_start_refit):
        model.load_state_dict(state_dict)
    # Optimize parameters when there is nothing to load or a refit is requested.
    if state_dict is None or refit:
        fit_botorch_model(
            model,  # Intentionally not using named args for dispatcher.
            mll_class=model_config.mll_class,
            mll_options=model_config.mll_options,
        )
    return model
def _dataset_matches_cache(
    self,
    dataset: SupervisedDataset,
) -> bool:
    """Check whether ``dataset`` equals the dataset last used for its outcomes.

    Returns ``False`` when no sub-model is cached under
    ``tuple(dataset.outcome_names)``. Otherwise returns the result of comparing
    ``dataset`` with the dataset stored for those outcomes.
    """
    cache_key = tuple(dataset.outcome_names)
    if cache_key not in self._submodels:
        return False
    return dataset == self._last_datasets[cache_key]
def fit(
    self,
    datasets: Sequence[SupervisedDataset],
    search_space_digest: SearchSpaceDigest,
    candidate_metadata: list[list[TCandidateMetadata]] | None = None,
    state_dict: Mapping[str, Tensor] | None = None,
    refit: bool = True,
    repeat_model_selection_if_dataset_changed: bool = True,
) -> None:
    """Fits the underlying BoTorch ``Model`` to ``m`` outcomes.

    NOTE: ``state_dict`` and ``refit`` keyword arguments control how the
    underlying BoTorch ``Model`` will be fit: whether its parameters will
    be reoptimized and whether it will be warm-started from a given state.
    There are three possibilities:
    * ``fit(state_dict=None)``: fit model from scratch (optimize model
      parameters and set its training data used for inference),
    * ``fit(state_dict=some_state_dict, refit=True)``: warm-start refit
      with a state dict of parameters (still re-optimize model parameters
      and set the training data); the state dict is only loaded before the
      refit when ``warm_start_refit`` is True,
    * ``fit(state_dict=some_state_dict, refit=False)``: load model parameters
      without refitting, but set new training data (used in cross-validation,
      for example).

    Args:
        datasets: A list of ``SupervisedDataset`` containers, each
            corresponding to the data of one metric (outcome), to be passed
            to ``Model.construct_inputs`` in BoTorch.
        search_space_digest: A ``SearchSpaceDigest`` object containing
            metadata on the features in the datasets.
        candidate_metadata: Model-produced metadata for candidates, in
            the order corresponding to the Xs. Only forwarded to
            ``model_selection``.
        state_dict: Optional state dict to load. When a model list is used it
            is subsetted per submodel.
        refit: Whether to re-optimize model parameters.
        repeat_model_selection_if_dataset_changed: Whether to repeat model
            selection, ignoring previously found best config, if the dataset
            for the corresponding outcomes has changed. This is typically
            set to `True` when called from ``BoTorchGenerator.fit`` but set
            to `False` when called from ``BoTorchGenerator.cross_validate``.
            During cross_validation, we want to evaluate the quality of the
            previously selected best model, rather than repeating model selection
            for each fold.

    Raises:
        UnsupportedError: If model selection over multiple model configs is
            required for a dataset containing more than one outcome.
    """
    self._discard_cached_model_and_data_if_search_space_digest_changed(
        search_space_digest=search_space_digest
    )
    # Deepcopy so that we are not making in-place changes when fill default values
    metric_to_model_configs = deepcopy(self.surrogate_spec.metric_to_model_configs)
    # To determine whether to use ModelList under the hood, we need to check for
    # the batched multi-output case, so we first see which model would be chosen
    # given the Yvars and the properties of data.
    should_use_model_list = use_model_list(
        datasets=datasets,
        search_space_digest=search_space_digest,
        model_configs=self.surrogate_spec.model_configs,
        allow_batched_models=self.surrogate_spec.allow_batched_models,
        metric_to_model_configs=metric_to_model_configs,
    )
    # A single batched model needs all outcomes in one block-design dataset.
    if not should_use_model_list and len(datasets) > 1:
        try:
            datasets = convert_to_block_design(datasets=datasets, force=False)
        except UnsupportedError as e:
            # If the block design conversion fails, use model-list.
            logger.warning(
                "Conversion to block design failed. Using model-list instead. "
                f"Original error: {e}"
            )
            should_use_model_list = True
    self._training_data = list(datasets)  # So that it can be modified if needed.
    feature_names_set = set(search_space_digest.feature_names)
    models = []
    outcome_names = []
    for i, dataset in enumerate(datasets):
        # Select the part of `state_dict` that belongs to this submodel.
        submodel_state_dict = None
        if state_dict is not None:
            if should_use_model_list:
                submodel_state_dict = subset_state_dict(
                    state_dict=state_dict, submodel_index=i
                )
            else:
                submodel_state_dict = state_dict
        outcome_name_tuple = tuple(dataset.outcome_names)
        first_outcome_name = outcome_name_tuple[0]
        # If no model config is specified for the (likely aux) preference dataset,
        # fallback to the default model config to avoid model fitting failure
        if (
            isinstance(dataset, RankingDataset)
            and first_outcome_name not in metric_to_model_configs
        ):
            metric_to_model_configs[first_outcome_name] = [ModelConfig()]
        # Metric-specific configs take precedence over the spec-wide ones.
        model_configs = (
            metric_to_model_configs[first_outcome_name]
            if first_outcome_name in metric_to_model_configs
            else self.surrogate_spec.model_configs
        )
        # Case 1: There is either 1 model config, or we don't want to re-do
        # model selection and we know what the previous best model was.
        if (
            not repeat_model_selection_if_dataset_changed
            or self._dataset_matches_cache(dataset=dataset)
        ):
            # Re-use the best model config, if the dataset hasn't changed or
            # `repeat_model_selection_if_dataset_changed` is set to `False`.
            model_config = self.metric_to_best_model_config.get(first_outcome_name)
        else:
            model_config = None
        # Model selection is not performed if the best `ModelConfig` has already
        # been identified (as specified in `metric_to_best_model_config`).
        # The reason for doing this is to support the following flow:
        # - Fit model to data and perform model selection, refitting on each fold
        #   if `refit_on_cv=True`. This will set the best ModelConfig in
        #   metric_to_best_model_config.
        # - Evaluate the choice of model/visualize its performance via
        #   `Adapter.cross_validate`. This also will refit on each fold if
        #   `refit_on_cv=True`, but we wouldn't want to perform model selection
        #   on each fold, but rather show the performance of the selected
        #   `ModelConfig` since that is what will be used.
        if len(model_configs) == 1 or (model_config is not None):
            best_model_config = model_config or model_configs[0]
            model = self._construct_model(
                dataset=dataset,
                search_space_digest=search_space_digest,
                model_config=best_model_config,
                state_dict=submodel_state_dict,
                refit=refit,
            )
        else:
            # Case 2: There is more than 1 model config and we want to refit
            # or don't know what the previous best model was
            if len(dataset.outcome_names) > 1:
                raise UnsupportedError(
                    "Multiple model configs are not supported with datasets that"
                    " contain multiple outcomes. Each dataset must contain only "
                    "one outcome."
                )
            model, best_model_config = self.model_selection(
                dataset=dataset,
                model_configs=model_configs,
                search_space_digest=search_space_digest,
                candidate_metadata=candidate_metadata,
            )
        # Only update the outcome names and models if the dataset input
        # matches the feature names from the SSD. In heterogeneous TL,
        # _expand_ssd_to_joint_space adds source-only features to the SSD,
        # so the target MultiTaskDataset's feature_names will be a strict
        # subset -- the missing names are source-only params.
        if set(dataset.feature_names) == feature_names_set or (
            isinstance(dataset, MultiTaskDataset)
            and set(dataset.feature_names).issubset(feature_names_set)
        ):
            models.append(model)
            outcome_names.extend(dataset.outcome_names)
        # store best model config, model, and dataset
        for metric_signature in dataset.outcome_names:
            self.metric_to_best_model_config[metric_signature] = none_throws(
                best_model_config
            )
        self._submodels[outcome_name_tuple] = model
        self._last_datasets[outcome_name_tuple] = dataset
    # Combine submodels: a model list (GP-specific when every submodel is a
    # GPyTorch model) or the single batched model.
    if should_use_model_list:
        if all(isinstance(model, GPyTorchModel) for model in models):
            self._model = ModelListGP(*models)
        else:
            self._model = ModelList(*models)
    else:
        self._model = models[0]
    self._outcomes = outcome_names  # In the order of input datasets
def model_selection(
    self,
    dataset: SupervisedDataset,
    model_configs: list[ModelConfig],
    search_space_digest: SearchSpaceDigest,
    candidate_metadata: list[list[TCandidateMetadata]] | None = None,
) -> tuple[Model, ModelConfig]:
    """Perform model selection over a list of model configs.

    Fits one botorch ``Model`` per entry of ``model_configs`` on the full
    ``dataset`` and scores it with ``self.surrogate_spec.eval_criterion``.
    In-sample criteria (AIC, BIC, MLL) are computed directly on the fitted
    model; every other criterion is computed with ``self.cross_validate``.
    Configs whose fitting raises ``ModelFittingError`` are logged and
    skipped. Per-config scores and models are recorded in
    ``self._model_name_to_eval`` and ``self._model_name_to_model`` under the
    dataset's first outcome name. This method does not itself update
    ``self.metric_to_best_model_config``.

    Selection rules: on ties the earliest config wins, and a NaN score
    never displaces a non-NaN score (a NaN-scored model is only returned
    if no fitted config produced a non-NaN score).

    Args:
        dataset: Training data for the model. Must not be a
            ``MultiTaskDataset`` with heterogeneous features.
        model_configs: The model configs to select among.
        search_space_digest: Search space digest.
        candidate_metadata: Model-produced metadata for candidates. Not used
            by this method.

    Returns:
        A two element tuple containing:
            - The best model according to the eval_criterion.
            - The ModelConfig for the best model.

    Raises:
        UnsupportedError: If ``dataset`` is a ``MultiTaskDataset`` with
            heterogeneous features.
        AxError: If none of the model configs could be fit.
    """
    if isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features:
        raise UnsupportedError(
            "Model selection is not supported for datasets with heterogeneous "
            "features."
        )
    criterion = self.surrogate_spec.eval_criterion
    maximize = (
        MODEL_SELECTION_METRIC_DIRECTIONS[criterion]
        == ModelFitMetricDirection.MAXIMIZE
    )
    best_eval_metric = float("-inf") if maximize else float("inf")
    best_model = None
    best_model_config = None
    outcome_name = dataset.outcome_names[0]
    self._model_name_to_eval[outcome_name] = {}
    self._model_name_to_model[outcome_name] = {}
    # Loop over model configs: fit a model for each config, score it, and
    # keep the best one according to the specified criterion.
    for model_config in model_configs:
        try:
            # Fit the model to all of the data.
            model = self._construct_model(
                dataset=dataset,
                search_space_digest=search_space_digest,
                model_config=model_config,
                state_dict=None,
                refit=True,
            )
            if criterion in (AIC, BIC, MLL):
                # In-sample criteria are computed from the fitted model itself.
                eval_metric = compute_in_sample_model_fit_metric(
                    model=assert_is_instance(model, ExactGP),
                    criterion=criterion,
                )
            else:
                # Out-of-sample criteria come from cross-validation; each fold
                # is given the full-data fit's state dict.
                state_dict = cast(OrderedDict[str, Tensor], model.state_dict())
                eval_metric = self.cross_validate(
                    dataset=dataset,
                    search_space_digest=search_space_digest,
                    model_config=model_config,
                    state_dict=state_dict,
                )
        except ModelFittingError as e:
            logger.warning(
                f"Model {model_config} failed to fit with error {e}. Skipping."
            )
            continue
        self._model_name_to_eval[outcome_name][model_config.identifier] = (
            eval_metric
        )
        self._model_name_to_model[outcome_name][model_config.identifier] = model
        # NaN compares False against everything, so it is handled explicitly:
        # a NaN score only wins when nothing has been selected yet, and any
        # non-NaN score replaces a NaN incumbent. Strict comparisons make the
        # earliest config win ties for both directions.
        if np.isnan(eval_metric):
            is_better = best_model is None
        elif best_model is None or np.isnan(best_eval_metric):
            is_better = True
        elif maximize:
            is_better = eval_metric > best_eval_metric
        else:
            is_better = eval_metric < best_eval_metric
        if is_better:
            best_eval_metric = eval_metric
            best_model = model
            best_model_config = model_config
    if best_model is None:
        raise AxError(
            "No model configs were able to fit the data. Please check your "
            "model configs and/or data."
        )
    return none_throws(best_model), none_throws(best_model_config)
def cross_validate(
    self,
    dataset: SupervisedDataset,
    model_config: ModelConfig,
    search_space_digest: SearchSpaceDigest,
    state_dict: OrderedDict[str, Tensor] | None = None,
) -> float:
    """Cross-validate a single model config on a single outcome.

    The (target-task) observations are split by ``np.array_split`` into
    ``self.surrogate_spec.num_folds`` index folds. When ``num_folds`` is unset
    or larger than the number of observations, there is one fold per
    observation. For every fold, a model is built on the remaining points
    and used to predict the held-out points. The pooled predictions are then
    scored with the diagnostic function registered for
    ``self.surrogate_spec.eval_criterion``.

    Args:
        dataset: Training data for the model (for one outcome for
            the default `Surrogate`, with the exception of batched
            multi-output case, where training data is formatted with just
            one X and concatenated Ys).
        model_config: The model_config.
        search_space_digest: Search space digest used to set up model arguments.
        state_dict: Optional state dict to load.

    Returns:
        The eval criterion value for the given model config.
    """
    # For multi-task data, only the target task's points are held out/scored.
    target = (
        dataset.datasets[dataset.target_outcome_name]
        if isinstance(dataset, MultiTaskDataset)
        else dataset
    )
    X = target.X
    Y = target.Y
    n_obs = X.shape[0]
    requested_folds = self.surrogate_spec.num_folds
    n_folds = (
        n_obs if requested_folds is None or requested_folds > n_obs else requested_folds
    )
    fold_indices = np.array_split(
        ary=torch.arange(n_obs, device=X.device),
        indices_or_sections=n_folds,
    )

    means: list[Tensor] = []
    variances: list[Tensor] = []
    observed: list[Tensor] = []
    for idcs in fold_indices:
        fold = get_cv_fold(
            dataset=dataset,
            X=X,
            Y=Y,
            idcs=torch.as_tensor(idcs, device=X.device),
        )
        # Holding points out means the remaining training data no longer
        # passes the input scaling checks; silence them to avoid spurious
        # user-facing warnings.
        with validate_input_scaling(False):
            fold_model = self._construct_model(
                dataset=fold.train_dataset,
                search_space_digest=search_space_digest,
                model_config=model_config,
                state_dict=state_dict,
                refit=self.refit_on_cv,
            )
        with torch.no_grad():
            raw_posterior = fold_model.posterior(
                fold.test_X,
                observation_noise=self.surrogate_spec.use_posterior_predictive,
            )
        # TODO: support non-GPyTorch posteriors
        posterior = assert_is_instance(raw_posterior, GPyTorchPosterior)
        is_mixture = isinstance(posterior, GaussianMixturePosterior)
        means.append(posterior.mixture_mean if is_mixture else posterior.mean)
        variances.append(
            posterior.mixture_variance if is_mixture else posterior.variance
        )
        observed.append(fold.test_Y)

    # Pool the per-fold results.
    all_means = torch.cat(means)
    all_variances = torch.cat(variances)
    all_observed = torch.cat(observed)
    diag_fn = DIAGNOSTIC_FNS[none_throws(self.surrogate_spec.eval_criterion)]
    return diag_fn(
        y_obs=all_observed.view(-1).cpu().numpy(),
        y_pred=all_means.view(-1).cpu().numpy(),
        se_pred=np.sqrt(all_variances.view(-1).cpu().numpy()),
    )
def _discard_cached_model_and_data_if_search_space_digest_changed(
    self, search_space_digest: SearchSpaceDigest
) -> None:
    """Invalidate cached fits when the search space digest has changed.

    If a digest was recorded by a previous call and differs from
    ``search_space_digest``, the cached submodels, last-seen datasets and
    per-metric best model configs are reset to empty dicts. In every case,
    ``search_space_digest`` is then recorded as
    ``self._last_search_space_digest`` for the next comparison.
    """
    previous_digest = self._last_search_space_digest
    digest_changed = (
        previous_digest is not None and search_space_digest != previous_digest
    )
    if digest_changed:
        logger.debug(
            "Discarding all previously trained models due to a change "
            "in the search space digest."
        )
        self._submodels = {}
        self._last_datasets = {}
        self.metric_to_best_model_config = {}
    self._last_search_space_digest = search_space_digest
def predict(
    self, X: Tensor, use_posterior_predictive: bool = False
) -> tuple[Tensor, Tensor]:
    """Predict outcomes at ``X`` with the surrogate's underlying model.

    This delegates to ``predict_from_model`` using ``self.model``.

    Args:
        X: A ``n x d`` tensor of input parameters.
        use_posterior_predictive: If True, predictions come from the
            posterior predictive (i.e. they include observation noise).

    Returns:
        Tensor: The predicted posterior mean as an ``n x o``-dim tensor.
        Tensor: The predicted posterior covariance as a ``n x o x o``-dim tensor.
    """
    fitted_model = self.model
    return predict_from_model(
        model=fitted_model,
        X=X,
        use_posterior_predictive=use_posterior_predictive,
    )
def best_in_sample_point(
self,
search_space_digest: SearchSpaceDigest,
torch_opt_config: TorchOptConfig,
options: TConfig | None = None,
) -> tuple[Tensor, float]:
"""Finds the best observed point and the corresponding observed outcome
values.