13
13
from botorch .models import SingleTaskGP
14
14
from botorch .models .fully_bayesian import SaasFullyBayesianSingleTaskGP
15
15
from botorch .models .transforms .outcome import Standardize
16
+ from botorch .sampling .normal import IIDNormalSampler
16
17
from botorch .utils .testing import BotorchTestCase
17
18
18
19
20
+ def get_model (
21
+ train_X ,
22
+ train_Y ,
23
+ standardize_model ,
24
+ ** tkwargs ,
25
+ ):
26
+ num_objectives = train_Y .shape [- 1 ]
27
+
28
+ if standardize_model :
29
+ outcome_transform = Standardize (m = num_objectives )
30
+ else :
31
+ outcome_transform = None
32
+
33
+ model = SingleTaskGP (
34
+ train_X = train_X ,
35
+ train_Y = train_Y ,
36
+ outcome_transform = outcome_transform ,
37
+ )
38
+
39
+ return model
40
+
41
+
19
42
def _get_mcmc_samples (num_samples : int , dim : int , infer_noise : bool , ** tkwargs ):
20
43
21
44
mcmc_samples = {
@@ -28,7 +51,7 @@ def _get_mcmc_samples(num_samples: int, dim: int, infer_noise: bool, **tkwargs):
28
51
return mcmc_samples
29
52
30
53
31
- def get_model (
54
+ def get_fully_bayesian_model (
32
55
train_X ,
33
56
train_Y ,
34
57
num_models ,
@@ -72,21 +95,26 @@ def test_q_bayesian_active_learning_by_disagreement(self):
72
95
tkwargs = {"device" : self .device }
73
96
num_objectives = 1
74
97
num_models = 3
98
+ input_dim = 2
99
+
100
+ X_pending_list = [None , torch .rand (2 , input_dim )]
75
101
for (
76
102
dtype ,
77
103
standardize_model ,
78
104
infer_noise ,
105
+ X_pending ,
79
106
) in product (
80
107
(torch .float , torch .double ),
81
108
(False , True ), # standardize_model
82
109
(True ,), # infer_noise - only one option avail in PyroModels
110
+ X_pending_list ,
83
111
):
112
+ X_pending = X_pending .to (** tkwargs ) if X_pending is not None else None
84
113
tkwargs ["dtype" ] = dtype
85
- input_dim = 2
86
114
train_X = torch .rand (4 , input_dim , ** tkwargs )
87
115
train_Y = torch .rand (4 , num_objectives , ** tkwargs )
88
116
89
- model = get_model (
117
+ model = get_fully_bayesian_model (
90
118
train_X ,
91
119
train_Y ,
92
120
num_models ,
@@ -96,32 +124,40 @@ def test_q_bayesian_active_learning_by_disagreement(self):
96
124
)
97
125
98
126
# test acquisition
99
- X_pending_list = [None , torch .rand (2 , input_dim , ** tkwargs )]
100
- for i in range (len (X_pending_list )):
101
- X_pending = X_pending_list [i ]
102
-
103
- acq = qBayesianActiveLearningByDisagreement (
104
- model = model ,
105
- X_pending = X_pending ,
106
- )
107
-
108
- test_Xs = [
109
- torch .rand (4 , 1 , input_dim , ** tkwargs ),
110
- torch .rand (4 , 3 , input_dim , ** tkwargs ),
111
- torch .rand (4 , 5 , 1 , input_dim , ** tkwargs ),
112
- torch .rand (4 , 5 , 3 , input_dim , ** tkwargs ),
113
- ]
114
-
115
- for j in range (len (test_Xs )):
116
- acq_X = acq .forward (test_Xs [j ])
117
- acq_X = acq (test_Xs [j ])
118
- # assess shape
119
- self .assertTrue (acq_X .shape == test_Xs [j ].shape [:- 2 ])
127
+ acq = qBayesianActiveLearningByDisagreement (
128
+ model = model ,
129
+ X_pending = X_pending ,
130
+ )
131
+
132
+ acq2 = qBayesianActiveLearningByDisagreement (
133
+ model = model , sampler = IIDNormalSampler (torch .Size ([9 ]))
134
+ )
135
+ self .assertIsInstance (acq2 .sampler , IIDNormalSampler )
136
+
137
+ test_Xs = [
138
+ torch .rand (4 , 1 , input_dim , ** tkwargs ),
139
+ torch .rand (4 , 3 , input_dim , ** tkwargs ),
140
+ torch .rand (4 , 5 , 1 , input_dim , ** tkwargs ),
141
+ torch .rand (4 , 5 , 3 , input_dim , ** tkwargs ),
142
+ torch .rand (5 , 13 , input_dim , ** tkwargs ),
143
+ ]
144
+
145
+ for j in range (len (test_Xs )):
146
+ acq_X = acq .forward (test_Xs [j ])
147
+ acq_X = acq (test_Xs [j ])
148
+ # assess shape
149
+ self .assertTrue (acq_X .shape == test_Xs [j ].shape [:- 2 ])
150
+
151
+ self .assertTrue (torch .all (acq_X > 0 ))
120
152
121
153
# Support with non-fully bayesian models is not possible. Thus, we
122
154
# throw an error.
123
- non_fully_bayesian_model = SingleTaskGP (train_X , train_Y )
124
- with self .assertRaises (ValueError ):
155
+ non_fully_bayesian_model = get_model (train_X , train_Y , False )
156
+ with self .assertRaisesRegex (
157
+ ValueError ,
158
+ "Fully Bayesian acquisition functions require a "
159
+ "SaasFullyBayesianSingleTaskGP to run." ,
160
+ ):
125
161
acq = qBayesianActiveLearningByDisagreement (
126
162
model = non_fully_bayesian_model ,
127
163
)
0 commit comments