-
Notifications
You must be signed in to change notification settings - Fork 372
Expand file tree
/
Copy pathutils.py
More file actions
996 lines (878 loc) · 39.9 KB
/
utils.py
File metadata and controls
996 lines (878 loc) · 39.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# pyre-strict
import warnings
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from functools import cached_property
from logging import Logger
from typing import Any, cast
import torch
from ax.core.data import MAP_KEY
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning, UnsupportedError, UserInputError
from ax.generators.torch.utils import extract_objectives
from ax.generators.torch_base import TorchOptConfig
from ax.generators.types import TConfig
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.logei import (
qLogNoisyExpectedImprovement,
qLogProbabilityOfFeasibility,
)
from botorch.acquisition.multi_objective.logei import (
qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.multi_objective.parego import qLogNParEGO
from botorch.acquisition.preference import qExpectedUtilityOfBestOption
from botorch.exceptions.errors import BotorchError, CandidateGenerationError
from botorch.fit import fit_fully_bayesian_model_nuts, fit_gpytorch_mll
from botorch.models import PairwiseLaplaceMarginalLogLikelihood
from botorch.models.fully_bayesian import AbstractFullyBayesianSingleTaskGP
from botorch.models.fully_bayesian_multitask import SaasFullyBayesianMultiTaskGP
from botorch.models.gp_regression import SingleTaskGP
from botorch.models.gp_regression_fidelity import SingleTaskMultiFidelityGP
from botorch.models.gp_regression_mixed import MixedSingleTaskGP
from botorch.models.gpytorch import BatchedMultiOutputGPyTorchModel, GPyTorchModel
from botorch.models.heterogeneous_mtgp import HeterogeneousMTGP
from botorch.models.model import Model, ModelList
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP
from botorch.models.transforms.input import (
InputTransform,
LearnedFeatureImputation,
Normalize,
)
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.optim.parameter_constraints import (
evaluate_feasibility,
get_constraint_tolerance,
)
from botorch.optim.utils import columnwise_clamp
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.datasets import MultiTaskDataset, RankingDataset, SupervisedDataset
from botorch.utils.dispatcher import Dispatcher
from botorch.utils.types import _DefaultType, DEFAULT
from gpytorch.kernels.kernel import Kernel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.mlls.marginal_log_likelihood import MarginalLogLikelihood
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
# For multi-objective problems (without a learned objective),
# `choose_botorch_acqf_class` picks qLogNParEGO instead of
# qLogNoisyExpectedHypervolumeImprovement once the number of objectives
# (rows of `objective_weights`) is at least this value.
MIN_NUM_OBJECTIVES_PAREGO = 5
# Module-level logger for this file.
logger: Logger = get_logger(__name__)
@dataclass
class ModelConfig:
    """Configuration for the BoTorch Model used in Surrogate.

    Args:
        botorch_model_class: ``Model`` class to be used as the underlying
            BoTorch model. If None is provided, a model class (either one for
            all outcomes or a ModelList with separate models for each outcome)
            will be selected automatically based on the datasets at `construct`
            time.
        model_options: Dictionary of options / kwargs for the BoTorch
            ``Model`` constructed during ``Surrogate.fit``.
            Note that the corresponding attribute will later be updated to include any
            additional kwargs passed into ``BoTorchGenerator.fit``.
        mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
        mll_options: Dictionary of options / kwargs for the MLL.
        outcome_transform_classes: List of BoTorch outcome transforms classes. Passed
            down to the BoTorch ``Model``. Multiple outcome transforms can be chained
            together using ``ChainedOutcomeTransform``.
        outcome_transform_options: Outcome transform classes kwargs. The keys are
            class string names and the values are dictionaries of outcome transform
            kwargs. For example,

            ```
            outcome_transform_classes = [Standardize]
            outcome_transform_options = {
                "Standardize": {"m": 1},
            }
            ```

            For more options see `botorch/models/transforms/outcome.py`.
        input_transform_classes: List of BoTorch input transforms classes.
            Passed down to the BoTorch ``Model``. Multiple input transforms
            will be chained together using ``ChainedInputTransform``.
            If `DEFAULT`, a default set of input transforms may be constructed
            based on the search space digest (in `_construct_default_input_transforms`).
            To disable this behavior, pass in `input_transform_classes=None`.
        input_transform_options: Input transform classes kwargs. The keys are
            class string names and the values are dictionaries of input transform
            kwargs. For example,

            ```
            input_transform_classes = [Normalize, Round]
            input_transform_options = {
                "Normalize": {"d": 3},
                "Round": {"integer_indices": [0], "categorical_features": {1: 2}},
            }
            ```

            For more input options see `botorch/models/transforms/input.py`.
        covar_module_class: Covariance module class. This gets initialized after
            parsing the ``covar_module_options`` in ``covar_module_argparse``,
            and gets passed to the model constructor as ``covar_module``.
        covar_module_options: Covariance module kwargs.
        likelihood_class: ``Likelihood`` class. This gets initialized with
            ``likelihood_options`` and gets passed to the model constructor.
        likelihood_options: Likelihood options.
        name: Name of the model config. If set, it is used as the config's
            ``identifier``.
    """

    botorch_model_class: type[Model] | None = None
    model_options: dict[str, Any] = field(default_factory=dict)
    mll_class: type[MarginalLogLikelihood] | None = None
    mll_options: dict[str, Any] = field(default_factory=dict)
    input_transform_classes: list[type[InputTransform]] | _DefaultType | None = DEFAULT
    input_transform_options: dict[str, dict[str, Any]] | None = field(
        default_factory=dict
    )
    outcome_transform_classes: list[type[OutcomeTransform]] | None = None
    outcome_transform_options: dict[str, dict[str, Any]] = field(default_factory=dict)
    covar_module_class: type[Kernel] | None = None
    covar_module_options: dict[str, Any] = field(default_factory=dict)
    likelihood_class: type[Likelihood] | None = None
    likelihood_options: dict[str, Any] = field(default_factory=dict)
    name: str | None = None

    @cached_property
    def identifier(self) -> str:
        """Returns an identifier for the model config.

        This is ``name`` when it is set, otherwise ``str(self)`` (the dataclass
        repr). The value is computed on first access and cached on the
        instance, so mutating the config afterwards does not change it.
        """
        return self.name if self.name is not None else str(self)
def use_model_list(
    datasets: Sequence[SupervisedDataset],
    search_space_digest: SearchSpaceDigest,
    model_configs: list[ModelConfig] | None = None,
    metric_to_model_configs: dict[str, list[ModelConfig]] | None = None,
    allow_batched_models: bool = True,
) -> bool:
    """Decide whether the outcomes should be modeled with a ``ModelList``.

    Args:
        datasets: The datasets to be modeled (before any block-design merge).
        search_space_digest: Used to infer a model class (via
            ``choose_model_class``) when no model config specifies one.
        model_configs: Model configs shared by all outcomes.
        metric_to_model_configs: Per-metric model configs.
        allow_batched_models: If False, a batched multi-output model is never
            used, even when every dataset shares the same inputs.

    Returns:
        True if a ``ModelList`` should be used, False if a single model is used.
    """
    configs = model_configs or []
    per_metric_configs = metric_to_model_configs or {}

    # A single outcome never needs a model list.
    if len(datasets) == 1 and datasets[0].Y.shape[-1] == 1:
        return False
    # Several configs (or any per-metric config) mean outcomes may be modeled
    # with different models. A non-empty mapping alone is already decisive.
    if len(configs) > 1 or len(per_metric_configs) > 0:
        return True
    # Mixed dataset classes cannot share one model.
    if len({type(ds) for ds in datasets}) > 1:
        return True
    # Only some datasets carry observed noise.
    num_with_yvar = sum(ds.Yvar is not None for ds in datasets)
    if 0 < num_with_yvar < len(datasets):
        return True

    candidate_classes = {config.botorch_model_class for config in configs}
    # Any unspecified class has to be inferred from the data.
    if not candidate_classes or None in candidate_classes:
        candidate_classes |= {
            choose_model_class(dataset=ds, search_space_digest=search_space_digest)
            for ds in datasets
        }
    candidate_classes.discard(None)
    if len(candidate_classes) > 1:
        return True

    # From here on, a single model class is used for every outcome.
    model_class = none_throws(next(iter(candidate_classes)))
    if getattr(model_class, "_supports_batched_models", None) is False:
        # Such models cannot batch several metrics together.
        return len(datasets) > 1 or datasets[0].Y.shape[-1] > 1
    if len(datasets) == 1:
        # This runs before datasets are merged for batched models. A lone
        # dataset (e.g. a contextual model with per-context outcomes) should
        # be modeled jointly by one model.
        return False
    if issubclass(model_class, BatchedMultiOutputGPyTorchModel) and all(
        torch.equal(datasets[0].X, ds.X) for ds in datasets[1:]
    ):
        return not allow_batched_models
    # Differing inputs: fall back to a model list.
    return True
def _ensure_input_transform(
    model_config: ModelConfig,
    transform_cls: type[InputTransform],
    position: int | None = None,
) -> None:
    """Make sure ``transform_cls`` is one of the config's input transforms.

    If ``model_config.input_transform_classes`` is not a list (``DEFAULT`` or
    ``None``), it is replaced by ``[transform_cls]``. If it is a list that
    lacks ``transform_cls``, the class is inserted at ``position`` (or appended
    when ``position`` is None); an existing entry is left untouched.

    ``input_transform_options`` always ends up as a dict with an entry keyed
    by ``transform_cls.__name__`` (an empty dict unless one already existed).
    The config is modified in place.
    """
    current = model_config.input_transform_classes
    if not isinstance(current, list):
        model_config.input_transform_classes = [transform_cls]
    elif transform_cls not in current:
        if position is None:
            current.append(transform_cls)
        else:
            current.insert(position, transform_cls)

    options = model_config.input_transform_options or {}
    options.setdefault(transform_cls.__name__, {})
    model_config.input_transform_options = options
def copy_model_config_with_default_values(
    model_config: ModelConfig,
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
) -> ModelConfig:
    """Return a deep copy of ``model_config`` with defaults filled in.

    - ``botorch_model_class`` is resolved through ``choose_model_class``
      (which respects an explicitly specified class).
    - For a ``MultiTaskDataset`` with heterogeneous features, ``Normalize`` is
      ensured at position 0. ``LearnedFeatureImputation`` is also ensured
      unless the model class is a ``HeterogeneousMTGP`` subclass.
    - A missing ``mll_class`` defaults to
      ``PairwiseLaplaceMarginalLogLikelihood`` for ``PairwiseGP`` and to
      ``ExactMarginalLogLikelihood`` otherwise.
    - Outcome transforms are dropped for ``PairwiseGP``.

    The input ``model_config`` is not modified.
    """
    config = deepcopy(model_config)
    # Always resolve the class so heterogeneous datasets are handled.
    config.botorch_model_class = choose_model_class(
        dataset=dataset,
        search_space_digest=search_space_digest,
        specified_model_class=config.botorch_model_class,
    )

    is_heterogeneous_mt = (
        isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features
    )
    if is_heterogeneous_mt:
        _ensure_input_transform(config, Normalize, position=0)
        chosen_class = config.botorch_model_class
        handles_heterogeneity = chosen_class is not None and issubclass(
            chosen_class, HeterogeneousMTGP
        )
        if chosen_class is not None and not handles_heterogeneity:
            _ensure_input_transform(config, LearnedFeatureImputation)

    is_pairwise = config.botorch_model_class is PairwiseGP
    if config.mll_class is None:
        if is_pairwise:
            config.mll_class = PairwiseLaplaceMarginalLogLikelihood
        else:
            config.mll_class = ExactMarginalLogLikelihood

    # PairwiseGP does not use outcome transforms.
    if is_pairwise and config.outcome_transform_classes is not None:
        config.outcome_transform_classes = None
    return config
def choose_model_class(
    dataset: SupervisedDataset,
    search_space_digest: SearchSpaceDigest,
    specified_model_class: type[Model] | None = None,
) -> type[Model]:
    """Chooses a BoTorch `Model` class for the given dataset and search space digest.

    The search space digest is validated first, even when a model class is
    specified.

    Selection order:
        1. Heterogeneous multi-task data (task features present and a
           ``MultiTaskDataset`` with heterogeneous features):
           ``specified_model_class`` if given, else ``MultiTaskGP``.
        2. ``specified_model_class`` if given.
        3. ``PairwiseGP`` for a ``RankingDataset``.
        4. ``MultiTaskGP`` if there are task features.
        5. ``SingleTaskMultiFidelityGP`` if there are fidelity features.
        6. ``MixedSingleTaskGP`` if there are categorical features.
        7. ``SingleTaskGP`` otherwise.

    Args:
        dataset: The dataset on which the model will be fitted.
        search_space_digest: The digest of the search space the model will be
            fitted within.
        specified_model_class: If provided, this model class will be used.

    Returns:
        A BoTorch `Model` class.

    Raises:
        NotImplementedError: If there is more than one fidelity feature, more
            than one task feature, or both task and fidelity features.
    """
    if len(search_space_digest.fidelity_features) > 1:
        raise NotImplementedError(
            "Only a single fidelity feature supported "
            f"(got: {search_space_digest.fidelity_features})."
        )
    if len(search_space_digest.task_features) > 1:
        raise NotImplementedError(
            f"Only a single task feature supported "
            f"(got: {search_space_digest.task_features})."
        )
    if search_space_digest.task_features and search_space_digest.fidelity_features:
        raise NotImplementedError(
            "Multi-task multi-fidelity optimization not yet supported."
        )

    # Heterogeneous multi-task datasets: respect an explicitly specified class,
    # otherwise default to MultiTaskGP (LearnedFeatureImputation handles
    # missing features).
    if (
        search_space_digest.task_features
        and isinstance(dataset, MultiTaskDataset)
        and dataset.has_heterogeneous_features
    ):
        model_class = (
            specified_model_class if specified_model_class is not None else MultiTaskGP
        )
        logger.debug(f"Chose BoTorch model class: {model_class}.")
        return model_class

    # If a model class was specified and no override is needed, use it.
    if specified_model_class is not None:
        logger.debug(f"Using specified BoTorch model class: {specified_model_class}.")
        return specified_model_class

    # Preference learning case.
    if isinstance(dataset, RankingDataset):
        model_class = PairwiseGP
    # Multi-task case (when `task_features` is specified).
    elif search_space_digest.task_features:
        model_class = MultiTaskGP
    # Single-task multi-fidelity cases.
    elif search_space_digest.fidelity_features:
        model_class = SingleTaskMultiFidelityGP
    # Mixed optimization case. Categorical features in the search space digest
    # indicate that downstream in the stack we chose not to perform continuous
    # relaxation on those features.
    elif search_space_digest.categorical_features:
        model_class = MixedSingleTaskGP
    # Single-task single-fidelity cases.
    else:
        model_class = SingleTaskGP
    logger.debug(f"Chose BoTorch model class: {model_class}.")
    return model_class
def _objective_threshold_to_outcome_constraints(
    objective_weights: Tensor,
    objective_thresholds: Tensor,
) -> tuple[Tensor, Tensor]:
    """Express objective thresholds as linear outcome constraints ``(A, b)``.

    Thresholds are already maximization-aligned, so for objective ``i`` the
    requirement is ``objective_weights[i] @ Y >= t_i``. In the ``A f(x) <= b``
    convention this becomes ``-objective_weights[i] @ Y <= -t_i``.

    Objectives whose threshold is NaN are skipped. NaN thresholds can appear
    when thresholds were inferred but no feasible point exists; in that case
    qLogProbabilityOfFeasibility is used instead.

    Args:
        objective_weights: ``(n_objectives, n_outcomes)`` objective weights.
        objective_thresholds: ``(n_objectives,)`` maximization-aligned
            objective thresholds.

    Returns:
        ``A`` of shape ``(k, n_outcomes)`` and ``b`` of shape ``(k, 1)``, where
        ``k`` is the number of non-NaN thresholds.
    """
    keep = objective_thresholds.isnan().logical_not()
    A = objective_weights[keep].neg()
    b = objective_thresholds[keep].neg().unsqueeze(-1)
    return A, b
def _has_feasible_observation(
    datasets: Sequence[SupervisedDataset],
    torch_opt_config: TorchOptConfig,
) -> bool:
    """Return whether any observed point is feasible.

    A point is feasible when it satisfies every outcome constraint (if any)
    and, for MOO, meets every non-NaN objective threshold (if any). Returns
    True without touching the data when there is nothing to check.
    """
    has_outcome_constraints = torch_opt_config.outcome_constraints is not None
    has_objective_thresholds = (
        torch_opt_config.is_moo and torch_opt_config.objective_thresholds is not None
    )
    if not (has_outcome_constraints or has_objective_thresholds):
        return True
    # NOTE: `convert_to_block_design` will drop points that are only
    # observed by some of the metrics which is natural as we are using
    # observed values to determine feasibility.
    dataset = convert_to_block_design(
        datasets=datasets,
        force=True,
        fixed_features=torch_opt_config.fixed_features,
        fix_map_key_to_target=True,
    )[0]
    Y = dataset.Y
    # Start with all points considered feasible.
    is_feasible = torch.ones(Y.shape[0], dtype=torch.bool, device=Y.device)
    if has_outcome_constraints:
        con_tfs = (
            get_outcome_constraint_transforms(torch_opt_config.outcome_constraints)
            or []
        )
        # torch.stack raises on an empty list (e.g. a constraint matrix with
        # zero rows); with no constraints nothing is excluded.
        if len(con_tfs) > 0:
            con_observed = torch.stack([con(Y) for con in con_tfs], dim=-1)
            is_feasible = is_feasible & (con_observed <= 0).all(dim=-1)
    if has_objective_thresholds:
        obj_weights = torch_opt_config.objective_weights
        obj_thresholds = none_throws(torch_opt_config.objective_thresholds)
        obj_idcs, weights = extract_objectives(obj_weights)
        non_nan_mask = ~obj_thresholds.isnan()
        if non_nan_mask.any():
            # Convert observations to maximization-aligned objective values and
            # compare against the (already aligned) thresholds.
            weighted_Y = Y[:, obj_idcs] * weights
            is_feasible = is_feasible & (
                weighted_Y[:, non_nan_mask] >= obj_thresholds[non_nan_mask]
            ).all(dim=-1)
    return bool(is_feasible.any().item())


def choose_botorch_acqf_class(
    search_space_digest: SearchSpaceDigest,
    torch_opt_config: TorchOptConfig,
    datasets: Sequence[SupervisedDataset] | None,
    use_p_feasible: bool = True,
) -> type[AcquisitionFunction]:
    """Chooses the most suitable BoTorch `AcquisitionFunction` class.

    Args:
        search_space_digest: The search space digest (currently unused).
        torch_opt_config: The torch optimization config.
        datasets: The datasets that were used to fit the model.
        use_p_feasible: Whether we dispatch to `qLogProbabilityOfFeasibility` when
            there are no feasible points in the training data.

    Returns:
        A BoTorch `AcquisitionFunction` class. The current logic chooses between:
            - `qExpectedUtilityOfBestOption` for single-objective
              preference-only optimization (PBO / LILO) where the sole
              objective is ``pairwise_pref_query`` with no constraints.
            - `qLogProbabilityOfFeasibility` if no observed point simultaneously
              satisfies all outcome constraints (if any) and dominates all
              objective thresholds (if any, for MOO).
            - `qLogNoisyExpectedImprovement` for single-objective optimization
              (or MOO with a learned objective).
            - `qLogNoisyExpectedHypervolumeImprovement` for multi-objective
              optimization with fewer than `MIN_NUM_OBJECTIVES_PAREGO` objectives.
            - `qLogNParEGO` for multi-objective optimization with at least
              `MIN_NUM_OBJECTIVES_PAREGO` objectives, to prevent slow optimization.
    """
    # Preferential BO (PBO): single-objective preference-only optimization
    # (e.g. LILO). The sole objective is pairwise_pref_query with no
    # constraints, so we optimize the preference model directly via qEUBO.
    # This must come before the PFeasibility check to prevent PBO experiments
    # from reaching convert_to_block_design (which is incompatible with
    # RankingDatasets).
    if (
        not torch_opt_config.is_moo
        and torch_opt_config.outcome_constraints is None
        and Keys.PAIRWISE_PREFERENCE_QUERY.value in torch_opt_config.opt_config_metrics
    ):
        acqf_class = qExpectedUtilityOfBestOption
    elif (
        use_p_feasible
        and datasets is not None
        and not _has_feasible_observation(
            datasets=datasets, torch_opt_config=torch_opt_config
        )
    ):
        # NOTE: Adding a new acqf class here requires a
        # corresponding update in
        # get_botorch_objective_and_transform.
        acqf_class = qLogProbabilityOfFeasibility
    elif torch_opt_config.is_moo and not torch_opt_config.use_learned_objective:
        # For MOO problems with many objectives, use ParEGO to prevent slow
        # optimization.
        if torch_opt_config.objective_weights.shape[0] >= MIN_NUM_OBJECTIVES_PAREGO:
            acqf_class = qLogNParEGO
        else:
            acqf_class = qLogNoisyExpectedHypervolumeImprovement
    else:
        acqf_class = qLogNoisyExpectedImprovement
    logger.debug(f"Chose BoTorch acquisition function class: {acqf_class}.")
    return acqf_class
def construct_acquisition_and_optimizer_options(
    acqf_options: TConfig,
    botorch_acqf_options: TConfig,
    model_gen_options: TConfig | None = None,
    botorch_acqf_classes_with_options: (
        list[tuple[type[AcquisitionFunction], TConfig]] | None
    ) = None,
) -> tuple[
    TConfig, TConfig, TConfig, list[tuple[type[AcquisitionFunction], TConfig]] | None
]:
    """Extract acquisition and optimizer options from `model_gen_options`.

    Args:
        acqf_options: Base Ax acquisition options. Shallow-copied, not mutated.
        botorch_acqf_options: Base BoTorch acquisition function options.
            Shallow-copied, not mutated.
        model_gen_options: Optional options that may contain
            ``Keys.ACQF_KWARGS``, ``Keys.AX_ACQUISITION_KWARGS`` and
            ``Keys.OPTIMIZER_KWARGS`` (plus candidate-generation keys, which
            are accepted but not consumed here).
        botorch_acqf_classes_with_options: Optional ``(class, options)`` pairs.
            With exactly one pair, the ``ACQF_KWARGS`` options are merged into
            a deep copy of its options. With more than one, those options are
            ignored with a warning. When None or empty, they are merged into
            ``botorch_acqf_options`` instead.

    Returns:
        A four-tuple of Ax acquisition options, BoTorch acquisition function
        options, optimizer options, and ``botorch_acqf_classes_with_options``.

    Raises:
        ValueError: If ``model_gen_options`` contains keys outside the allowed set.
    """
    acq_options = acqf_options.copy()
    botorch_acqf_options = botorch_acqf_options.copy()
    opt_options = {}
    if model_gen_options:
        allowed_keys = {
            Keys.OPTIMIZER_KWARGS.value,
            Keys.ACQF_KWARGS.value,
            Keys.AX_ACQUISITION_KWARGS.value,
            # Keys for candidate generation
            "in_sample",
            "sampling_strategy_class",
            "sampling_strategy_kwargs",
        }
        extra_keys_in_model_gen_options = set(model_gen_options.keys()) - allowed_keys
        if len(extra_keys_in_model_gen_options) > 0:
            raise ValueError(
                "Found forbidden keys in `model_gen_options`: "
                f"{extra_keys_in_model_gen_options}."
            )
        # Look up with `.value`, matching the allowed-keys check above.
        new_botorch_acqf_options: dict[str, Any] = assert_is_instance(
            model_gen_options.get(Keys.ACQF_KWARGS.value, {}),
            dict,
        )
        # A truthiness check (not `is not None`) keeps an empty list from
        # reaching the `[0]` index below.
        if new_botorch_acqf_options and botorch_acqf_classes_with_options:
            if len(botorch_acqf_classes_with_options) > 1:
                warnings.warn(
                    message="botorch_acqf_options are being ignored, due to using "
                    "MultiAcquisition. Specify options for each acquisition function "
                    "via botorch_acqf_classes_with_options.",
                    category=AxWarning,
                    stacklevel=4,
                )
            else:
                # Copy so the caller's list / options dict is not mutated.
                botorch_acqf_classes_with_options = deepcopy(
                    botorch_acqf_classes_with_options
                )
                botorch_acqf_classes_with_options[0][1].update(new_botorch_acqf_options)
        else:
            botorch_acqf_options.update(new_botorch_acqf_options)
        acq_options.update(
            assert_is_instance(
                model_gen_options.get(Keys.AX_ACQUISITION_KWARGS.value, {}),
                dict,
            )
        )
        # TODO: Add this if all acq. functions accept the `subset_model`
        # kwarg or opt for kwarg filtering.
        # acq_options[SUBSET_MODEL] = model_gen_options.get(SUBSET_MODEL)
        opt_options = assert_is_instance(
            model_gen_options.get(Keys.OPTIMIZER_KWARGS.value, {}),
            dict,
        ).copy()
    return (
        acq_options,
        botorch_acqf_options,
        opt_options,
        botorch_acqf_classes_with_options,
    )
def _fix_map_key_to_target(
    Xs: list[Tensor],
    feature_names: list[str],
    fixed_features: dict[int, float] | None,
) -> list[Tensor]:
    """Overwrite the MAP_KEY column of each tensor with its fixed target value.

    This keeps points from being discarded merely because metrics were
    observed at different progressions.

    Args:
        Xs: The tensors to fix.
        feature_names: Feature names for the tensors' last dimension.
        fixed_features: Mapping from feature index to fixed value.

    Returns:
        ``Xs`` itself (unchanged) if MAP_KEY is not a feature or its index is
        not in ``fixed_features``. Otherwise, a new list of clones whose
        MAP_KEY column is set to the fixed value.
    """
    if MAP_KEY not in feature_names:
        return Xs
    map_index = feature_names.index(MAP_KEY)
    if fixed_features is None or map_index not in fixed_features:
        return Xs
    target = fixed_features[map_index]
    fixed_Xs = []
    for X in Xs:
        # Clone so the caller's tensors are left untouched.
        X_fixed = X.clone()
        X_fixed[..., map_index] = target
        fixed_Xs.append(X_fixed)
    return fixed_Xs
def convert_to_block_design(
    datasets: Sequence[SupervisedDataset],
    force: bool = False,
    fixed_features: dict[int, float] | None = None,
    fix_map_key_to_target: bool = False,
) -> list[SupervisedDataset]:
    """Converts a list of datasets to a single block-design dataset that contains
    all outcomes.

    Args:
        datasets: A list of datasets to merge.
        force: If True, will force conversion of data not complying to a block
            design to block design by dropping observations that are not shared
            between outcomes.
            If only a subset of the outcomes have noise observations, all noise
            observations will be dropped.
        fixed_features: A dictionary mapping feature indices to fixed values. Used
            to fix MAP_KEY to the target value if `fix_map_key_to_target` is True.
        fix_map_key_to_target: If True, will fix MAP_KEY to the target value in
            the datasets before merging them.
            NOTE: This should not be done for modeling. It is only implemented to
            support acquisition related utilities.

    Returns:
        A single element list containing the merged dataset. Its outcomes are
        the concatenation of all datasets' outcomes, in order.

    Raises:
        UnsupportedError: If ``force`` is False and either only some datasets
            have noise observations or the inputs differ across datasets.
        ValueError: If the feature names differ across datasets.
    """
    has_yvar = [ds.Yvar is not None for ds in datasets]
    if any(has_yvar) and not all(has_yvar):
        if not force:
            raise UnsupportedError(
                "Cannot convert mixed data with and without variance "
                "observations to `block design`."
            )
        logger.debug(
            "Only a subset of datasets have noise observations. "
            "Dropping all noise observations since `force=True`. "
        )
    use_yvar = all(has_yvar)

    for dset in datasets[1:]:
        if dset.feature_names != datasets[0].feature_names:
            raise ValueError(
                "Feature names must be the same across all datasets, "
                f"got {dset.feature_names} and {datasets[0].feature_names}"
            )

    Xs = [ds.X for ds in datasets]
    if fix_map_key_to_target:
        Xs = _fix_map_key_to_target(
            Xs=Xs,
            feature_names=datasets[0].feature_names,
            fixed_features=fixed_features,
        )
    # Flatten the outcome names in a single pass; `sum(lists, [])` is quadratic.
    outcome_names = [name for ds in datasets for name in ds.outcome_names]

    is_block_design = len({X.shape for X in Xs}) == 1 and all(
        torch.equal(X, Xs[0]) for X in Xs[1:]
    )
    if is_block_design:
        # Data complies to block design, can concat with impunity.
        X = Xs[0]
        Ys = [ds.Y for ds in datasets]
        Yvars = [none_throws(ds.Yvar) for ds in datasets] if use_yvar else None
    else:
        if not force:
            raise UnsupportedError(
                "Cannot convert data to non-block design data. "
                "To force this and drop data not shared between "
                "outcomes use `force=True`."
            )
        logger.debug(
            "Forcing conversion of data not complying to a block design "
            "to block design by dropping observations that are not shared "
            "between outcomes."
        )
        X, idcs_shared = _get_shared_rows(Xs=Xs)
        Ys = [ds.Y[i] for ds, i in zip(datasets, idcs_shared)]
        Yvars = (
            [none_throws(ds.Yvar)[i] for ds, i in zip(datasets, idcs_shared)]
            if use_yvar
            else None
        )

    return [
        SupervisedDataset(
            X=X,
            Y=torch.cat(Ys, dim=-1),
            Yvar=None if Yvars is None else torch.cat(Yvars, dim=-1),
            feature_names=datasets[0].feature_names,
            outcome_names=outcome_names,
        )
    ]
def _get_shared_rows(Xs: list[Tensor]) -> tuple[Tensor, list[Tensor]]:
    """Extract the rows shared by all tensors in ``Xs``.

    Args:
        Xs: A list of m two-dimensional tensors with shapes
            `(n_1 x d), ..., (n_m x d)`. It is not required that
            the `n_i` are the same.

    Returns:
        A two-tuple containing (i) a Tensor with the unique rows that are
        shared between all the Tensors in `Xs` (in the sorted order produced
        by ``torch.unique``), and (ii) a list of index tensors, one per element
        of `Xs`, aligned with (i): ``Xs[i][idcs[i]]`` equals the shared-row
        tensor row by row. If a shared row occurs several times in ``Xs[i]``,
        the index of its last occurrence is used.

    Raises:
        UserInputError: If any element of ``Xs`` is not two-dimensional.
    """
    if any(X.ndim != 2 for X in Xs):
        raise UserInputError("All inputs must be two-dimensional.")
    # Start from the unique rows of the shortest tensor and keep those found in
    # every tensor (including itself, which also drops rows containing NaN,
    # since NaN never compares equal).
    Xs_sorted = sorted(Xs, key=len)
    X_shared = Xs_sorted[0].unique(dim=0)
    for X in Xs_sorted:
        X_shared = X_shared[(X_shared == X.unsqueeze(-2)).all(dim=-1).any(dim=-2)]

    idcs_shared = []
    for X in Xs:
        if X_shared.shape[0] == 0:
            idcs_shared.append(
                torch.empty(0, dtype=torch.long, device=X_shared.device)
            )
            continue
        # matches[j, i] is True iff X_shared[j] equals X[i].
        matches = (X_shared.unsqueeze(-2) == X.unsqueeze(-3)).all(dim=-1)
        positions = torch.arange(X.shape[0], dtype=torch.long, device=X_shared.device)
        # For every shared row, take the position of its last occurrence in X.
        # Indexing follows X_shared's order, so the rows selected from each
        # dataset line up with X_shared (and with each other). Every shared row
        # occurs in X, so -1 is never selected.
        idcs_shared.append(torch.where(matches, positions, -1).amax(dim=-1))
    return X_shared, idcs_shared
def subset_state_dict(
    state_dict: Mapping[str, Tensor],
    submodel_index: int,
) -> OrderedDict[str, Tensor]:
    """Extract one submodel's entries from a model-list state dict.

    Keys of the form ``models.<submodel_index>.<rest>`` are kept and renamed
    to ``<rest>``. All other keys are dropped, and the original order is
    preserved.

    Args:
        state_dict: A state dict.
        submodel_index: The index of the submodel to extract.

    Returns:
        The state dict for the submodel.
    """
    prefix = f"models.{submodel_index}."
    return OrderedDict(
        (key.removeprefix(prefix), value)
        for key, value in state_dict.items()
        if key.startswith(prefix)
    )
# ----------------------- Model fitting helpers ----------------------- #
# Dispatcher for fitting a BoTorch model. Fitting routines are registered per
# model class with `@fit_botorch_model.register(...)` (with `object` as the
# catch-all below) and invoked as
# `fit_botorch_model(model, mll_class=..., mll_options=...)`.
# NOTE(review): dispatch keys are produced by `_argparse_type_encoder`;
# presumably the type of the first argument — confirm in ax.utils.common.typeutils.
fit_botorch_model = Dispatcher(name="fit_botorch_model", encoder=_argparse_type_encoder)
@fit_botorch_model.register(ModelList)
def _fit_botorch_model_list(
    model: Model,
    mll_class: type[MarginalLogLikelihood],
    mll_options: dict[str, Any] | None = None,
) -> None:
    """Fit every submodel of a ``ModelList`` independently.

    Each submodel is dispatched through ``fit_botorch_model`` with the same
    ``mll_class`` and ``mll_options``.
    """
    submodels = cast(list[Model], model.models)
    for submodel in submodels:
        fit_botorch_model(submodel, mll_class=mll_class, mll_options=mll_options)
@fit_botorch_model.register(GPyTorchModel)
@fit_botorch_model.register(PairwiseGP)
def _fit_botorch_model_gpytorch(
    model: GPyTorchModel | PairwiseGP,
    mll_class: type[MarginalLogLikelihood],
    mll_options: dict[str, Any] | None = None,
) -> None:
    """Fit a GPyTorch-based BoTorch model (``GPyTorchModel`` or ``PairwiseGP``).

    The model is wrapped in ``mll_class(likelihood=model.likelihood,
    model=model, **mll_options)``, and the resulting MLL is passed to
    ``fit_gpytorch_mll``. A ``None`` or empty ``mll_options`` adds no extra
    keyword arguments.
    """
    extra_kwargs = {} if not mll_options else mll_options
    marginal_ll = mll_class(likelihood=model.likelihood, model=model, **extra_kwargs)
    fit_gpytorch_mll(marginal_ll)
@fit_botorch_model.register(
    (AbstractFullyBayesianSingleTaskGP, SaasFullyBayesianMultiTaskGP)
)
def _fit_botorch_model_fully_bayesian_nuts(
    model: AbstractFullyBayesianSingleTaskGP | SaasFullyBayesianMultiTaskGP,
    mll_class: type[MarginalLogLikelihood],
    mll_options: dict[str, Any] | None = None,
) -> None:
    """Fit a fully Bayesian model with ``fit_fully_bayesian_model_nuts``.

    ``mll_class`` is not used here. It is accepted so the signature matches the
    other ``fit_botorch_model`` implementations. ``mll_options`` are forwarded
    as keyword arguments to ``fit_fully_bayesian_model_nuts``. The progress bar
    is disabled (``disable_progbar=True``) unless the caller sets
    ``disable_progbar`` explicitly. The caller's ``mll_options`` dict is not
    modified.
    """
    # Build a fresh dict instead of calling ``setdefault`` on the caller's dict.
    # The same ``mll_options`` object can be shared, e.g. the ModelList
    # implementation passes one dict to every submodel. Mutating it would leak
    # ``disable_progbar`` into sibling fits (such as an MLL constructor) and
    # into the caller's options.
    nuts_options = {"disable_progbar": True, **(mll_options or {})}
    fit_fully_bayesian_model_nuts(model, **nuts_options)
@fit_botorch_model.register(object)
def _fit_botorch_model_not_implemented(
    model: Model,
    mll_class: type[MarginalLogLikelihood],
    mll_options: dict[str, Any] | None = None,
) -> None:
    """Fallback registered for ``object``: always raises.

    Presumably reached for model types that have no more specific
    ``fit_botorch_model`` registration.

    Raises:
        NotImplementedError: Always. The message names the model class and
            explains how to register a fitting routine for it.
    """
    cls_name = model.__class__.__name__
    message = (
        f"fit_botorch_model is not implemented for {cls_name}. "
        "You can register a model fitting routine for it by adding new case "
        "to the `fit_botorch_model` dispatcher. To do so, decorate a function "
        "that accepts `model`, `mll_class` and `mll_options` inputs with "
        f"`@fit_botorch_model.register({cls_name})`."
    )
    raise NotImplementedError(message)
@dataclass(frozen=True)
class CVFold:
    """One immutable cross-validation fold.

    Attributes:
        train_dataset: The training dataset for the fold.
        test_X: The test inputs for the fold.
        test_Y: The test outputs for the fold.
    """

    # Training data for this fold.
    train_dataset: SupervisedDataset
    # Held-out inputs for this fold.
    test_X: Tensor
    # Held-out outputs corresponding to ``test_X``.
    test_Y: Tensor
def get_cv_fold(
    dataset: SupervisedDataset, X: Tensor, Y: Tensor, idcs: Tensor
) -> CVFold:
    """Build a cross-validation fold that holds out the rows at ``idcs``.

    Args:
        dataset: The full dataset. Its ``clone`` is called with a boolean
            ``mask`` that is False at ``idcs`` and True for every other row.
        X: Inputs, indexed by row along the first dimension.
        Y: Outputs, indexed by row in the same way as ``X``.
        idcs: Indices of the held-out (test) rows.

    Returns:
        A ``CVFold`` whose training dataset is the masked clone of ``dataset``
        and whose test inputs/outputs are ``X[idcs]`` and ``Y[idcs]``.
    """
    keep_mask = X.new_ones(X.shape[0], dtype=torch.bool)
    keep_mask[idcs] = False
    train_part = dataset.clone(mask=keep_mask)
    return CVFold(train_dataset=train_part, test_X=X[idcs], test_Y=Y[idcs])
def get_all_task_values_from_ssd(search_space_digest: SearchSpaceDigest) -> list[int]:
    """List every integer task value allowed by a search space digest.

    Only the first entry of ``search_space_digest.task_features`` is used. The
    result is every integer from that feature's lower bound to its upper bound,
    inclusive, with each bound truncated by ``int``.

    Args:
        search_space_digest: The search space digest.

    Returns:
        A list of all task values.
    """
    first_task_feature = search_space_digest.task_features[0]
    feature_bounds = search_space_digest.bounds[first_task_feature]
    lowest = int(feature_bounds[0])
    stop = int(feature_bounds[1] + 1)
    return [*range(lowest, stop)]
def _format_discrete_value(val: float, allowed_values: Sequence[float]) -> str:
    """Format a discrete value for display alongside its allowed values.

    If every allowed value is integral, ``val`` is rounded and shown as an int.
    Otherwise it is shown as a float with 4 decimal places. Non-finite values
    (``nan``, ``inf``, ``-inf``) are shown via ``str``. This function is used
    while building an error message, so it must never raise.

    Args:
        val: The (possibly invalid) value to display.
        allowed_values: The allowed discrete values for the parameter.

    Returns:
        A human-readable string for ``val``.
    """
    if all(float(v).is_integer() for v in allowed_values):
        try:
            return str(int(round(val)))
        except (ValueError, OverflowError):
            # round() raises ValueError for nan and OverflowError for +/-inf.
            # Fall back to the plain representation so the caller's real
            # error is not masked.
            return str(val)
    return f"{val:.4f}"
def _raise_on_linear_constraint_violation(
    is_feasible: Tensor, constraint_kind: str, num_constraints: int
) -> None:
    """Raise ``CandidateGenerationError`` if any candidate is flagged infeasible.

    Shared by the inequality and equality checks of ``validate_candidates`` so
    both report violations with the same message layout.

    Args:
        is_feasible: Boolean feasibility flags, as returned by
            ``evaluate_feasibility`` for the candidates. Indices are taken
            along the first dimension.
        constraint_kind: Word inserted into the message (``"inequality"`` or
            ``"equality"``).
        num_constraints: Number of constraints checked, reported in the message.
    """
    if is_feasible.all():
        return
    infeasible_indices = torch.where(~is_feasible)[0].tolist()
    raise CandidateGenerationError(
        f"Candidates violate {constraint_kind} constraints. "
        f"Infeasible candidate indices: {infeasible_indices}. "
        f"Number of constraints: {num_constraints}."
    )


def validate_candidates(
    candidates: Tensor,
    bounds: Tensor,
    discrete_choices: Mapping[int, Sequence[float]] | None,
    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None,
    feature_names: list[str] | None = None,
    task_features: list[int] | None = None,
    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
) -> None:
    """Validate candidates satisfy bounds, discrete, and linear constraints.

    Checks run in order (bounds, discrete values, inequality constraints,
    equality constraints). The first failing check raises.

    Args:
        candidates: A `n x d`-dim Tensor of candidates to validate.
        bounds: A `2 x d`-dim Tensor of lower and upper bounds.
        discrete_choices: A mapping from parameter indices to allowed discrete values.
        inequality_constraints: A list of tuples (indices, coefficients, rhs),
            representing inequality constraints of the form
            `sum_i (X[indices[i]] * coefficients[i]) >= rhs`.
        feature_names: Optional list of feature names for better error messages.
        task_features: Optional list of task feature indices to skip discrete value
            validation for. Task features can be fixed to new task values via
            fixed_features that are not in the search space's discrete_choices.
        equality_constraints: A list of tuples (indices, coefficients, rhs),
            representing equality constraints of the form
            `sum_i (X[indices[i]] * coefficients[i]) = rhs`.

    Raises:
        CandidateGenerationError: If any candidate violates constraints.
    """
    # 1. Bounds validation
    try:
        columnwise_clamp(
            candidates, lower=bounds[0], upper=bounds[1], raise_on_violation=True
        )
    except BotorchError as e:
        # Chain explicitly so the BotorchError is reported as the direct cause,
        # not as an unrelated error raised while handling it.
        raise CandidateGenerationError(f"Candidate violates bounds: {e}") from e
    # 2. Discrete value validation
    task_features_set = set(task_features) if task_features else set()
    if discrete_choices:
        tol = get_constraint_tolerance(candidates.dtype)
        for dim, allowed_values in discrete_choices.items():
            # Skip task features as they can be fixed to new task values via
            # fixed_features that are not in the search space's discrete_choices
            if dim in task_features_set:
                continue
            allowed = torch.tensor(
                allowed_values, device=candidates.device, dtype=candidates.dtype
            )
            candidate_vals = candidates[..., dim].flatten()
            # Vectorized check: (num_candidates, num_allowed) -> any match per candidate
            is_valid = torch.isclose(
                candidate_vals.unsqueeze(-1), allowed.unsqueeze(0), atol=tol
            ).any(dim=-1)
            if not is_valid.all():
                # Report only the first offending value.
                invalid_idx = int(torch.where(~is_valid)[0][0].item())
                val_float = candidate_vals[invalid_idx].item()
                dim_name = feature_names[dim] if feature_names else f"dim {dim}"
                raise CandidateGenerationError(
                    f"Invalid discrete value "
                    f"{_format_discrete_value(val_float, allowed_values)} for "
                    f"{dim_name}. Allowed: {list(allowed_values)}"
                )
    # 3. Inequality constraint validation
    if inequality_constraints:
        is_feasible = evaluate_feasibility(
            X=candidates.unsqueeze(-2),  # Add q dimension
            inequality_constraints=inequality_constraints,
        )
        _raise_on_linear_constraint_violation(
            is_feasible, "inequality", len(inequality_constraints)
        )
    # 4. Equality constraint validation
    if equality_constraints:
        is_feasible = evaluate_feasibility(
            X=candidates.unsqueeze(-2),  # Add q dimension
            equality_constraints=equality_constraints,
        )
        _raise_on_linear_constraint_violation(
            is_feasible, "equality", len(equality_constraints)
        )