66
77import torch
88from botorch .models .kernels .positive_index import PositiveIndexKernel
9+ from botorch .models .utils .priors import BetaPrior
10+ from botorch .optim .utils import sample_all_priors
911from botorch .utils .testing import BotorchTestCase
10- from gpytorch .priors import NormalPrior
12+ from gpytorch .priors import NormalPrior , UniformPrior
1113
1214
1315class TestPositiveIndexKernel (BotorchTestCase ):
@@ -125,18 +127,15 @@ def test_positive_index_kernel(self):
125127 with self .subTest ("with_priors" , dtype = dtype ):
126128 num_tasks = 4
127129 task_prior = NormalPrior (0 , 1 )
128- diag_prior = NormalPrior (1 , 0.1 )
129130
130131 kernel = PositiveIndexKernel (
131132 num_tasks = num_tasks ,
132133 rank = 2 ,
133134 task_prior = task_prior ,
134- diag_prior = diag_prior ,
135135 initialize_to_mode = False ,
136136 ).to (dtype = dtype )
137137 prior_names = [p [0 ] for p in kernel .named_priors ()]
138138 self .assertIn ("IndexKernelPrior" , prior_names )
139- self .assertIn ("ScalePrior" , prior_names )
140139
141140 # Test batch forward
142141 with self .subTest ("batch_forward" , dtype = dtype ):
@@ -154,25 +153,6 @@ def test_positive_index_kernel(self):
154153 # Check that batch dimensions are preserved
155154 self .assertEqual (result .shape [0 ], 2 )
156155
157- # Test diagonal property (default target_task_index=0)
158- with self .subTest ("diagonal" , dtype = dtype ):
159- kernel = PositiveIndexKernel (num_tasks = 4 , rank = 2 ).to (dtype = dtype )
160- diag = kernel ._diagonal
161-
162- self .assertEqual (diag .shape , torch .Size ([4 ]))
163- # First diagonal element should be 1.0 (default target_task_index=0)
164- self .assertAllClose (diag [0 ], torch .tensor (1.0 , dtype = dtype ), atol = 1e-4 )
165-
166- # Test diagonal property with custom target_task_index
167- kernel = PositiveIndexKernel (
168- num_tasks = 4 , rank = 2 , target_task_index = 1
169- ).to (dtype = dtype )
170- diag = kernel ._diagonal
171-
172- self .assertEqual (diag .shape , torch .Size ([4 ]))
173- # Second diagonal element should be 1.0 (target_task_index=1)
174- self .assertAllClose (diag [1 ], torch .tensor (1.0 , dtype = dtype ), atol = 1e-4 )
175-
176156 # Test lower triangle property
177157 with self .subTest ("lower_triangle" , dtype = dtype ):
178158 num_tasks = 5
@@ -222,3 +202,153 @@ def test_positive_index_kernel(self):
222202 new_value = torch .ones (3 , 2 , dtype = dtype ) * 3.0
223203 kernel ._covar_factor_closure (kernel , new_value )
224204 self .assertAllClose (kernel .covar_factor , new_value , atol = 1e-5 )
205+
206+ # Test _set_lower_triangle_corr produces valid covariance
207+ with self .subTest ("set_lower_triangle_corr" , dtype = dtype ):
208+ kernel = PositiveIndexKernel (num_tasks = 3 , rank = 3 ).to (dtype = dtype )
209+ target_corr = torch .tensor ([0.8 , 0.5 , 0.6 ], dtype = dtype )
210+ kernel ._set_lower_triangle_corr (target_corr )
211+
212+ # Covariance matrix should be PD and symmetric
213+ covar = kernel .covar_matrix
214+ eigvals = torch .linalg .eigvalsh (covar )
215+ self .assertTrue ((eigvals > 0 ).all ())
216+ self .assertAllClose (covar , covar .T , atol = 1e-5 )
217+
218+ # Recovered correlations should be positive
219+ recovered = kernel ._lower_triangle_corr
220+ self .assertTrue ((recovered >= 0 ).all ())
221+ self .assertTrue ((recovered <= 1 ).all ())
222+
223+ # Test _set_lower_triangle_corr with batch shape
224+ with self .subTest ("set_lower_triangle_corr_batch" , dtype = dtype ):
225+ batch_shape = torch .Size ([2 ])
226+ kernel = PositiveIndexKernel (
227+ num_tasks = 3 , rank = 3 , batch_shape = batch_shape
228+ ).to (dtype = dtype )
229+ target_corr = torch .rand (* batch_shape , 3 , dtype = dtype )
230+ kernel ._set_lower_triangle_corr (target_corr )
231+ covar = kernel .covar_matrix
232+ eigvals = torch .linalg .eigvalsh (covar )
233+ self .assertTrue ((eigvals > 0 ).all ())
234+ self .assertEqual (covar .shape , torch .Size ([2 , 3 , 3 ]))
235+
236+ # Test sample_all_priors with batch shape
237+ with self .subTest ("sample_all_priors_batch" , dtype = dtype ):
238+ batch_shape = torch .Size ([2 ])
239+ task_prior = UniformPrior (0.0 , 1.0 )
240+ kernel = PositiveIndexKernel (
241+ num_tasks = 3 ,
242+ rank = 3 ,
243+ task_prior = task_prior ,
244+ batch_shape = batch_shape ,
245+ ).to (dtype = dtype )
246+ sample_all_priors (kernel )
247+ covar = kernel .covar_matrix
248+ eigvals = torch .linalg .eigvalsh (covar )
249+ self .assertTrue ((eigvals > 0 ).all ())
250+ self .assertEqual (covar .shape , torch .Size ([2 , 3 , 3 ]))
251+
252+ # Test _set_lower_triangle_corr with scalar input (under-batched)
253+ with self .subTest ("set_lower_triangle_corr_scalar" , dtype = dtype ):
254+ batch_shape = torch .Size ([2 ])
255+ kernel = PositiveIndexKernel (
256+ num_tasks = 3 , rank = 3 , batch_shape = batch_shape
257+ ).to (dtype = dtype )
258+ # Scalar value — exercises dim()==0 branch
259+ kernel ._set_lower_triangle_corr (torch .tensor (0.5 , dtype = dtype ))
260+ covar = kernel .covar_matrix
261+ eigvals = torch .linalg .eigvalsh (covar )
262+ self .assertTrue ((eigvals > 0 ).all ())
263+ self .assertEqual (covar .shape , torch .Size ([2 , 3 , 3 ]))
264+
265+ # Test _set_lower_triangle_corr with unbatched input on batched kernel
266+ with self .subTest (
267+ "set_lower_triangle_corr_unbatched_on_batch" , dtype = dtype
268+ ):
269+ batch_shape = torch .Size ([2 ])
270+ kernel = PositiveIndexKernel (
271+ num_tasks = 3 , rank = 3 , batch_shape = batch_shape
272+ ).to (dtype = dtype )
273+ # 1D input with correct n_lower but no batch — exercises expand branch
274+ target_corr = torch .rand (3 , dtype = dtype )
275+ kernel ._set_lower_triangle_corr (target_corr )
276+ covar = kernel .covar_matrix
277+ eigvals = torch .linalg .eigvalsh (covar )
278+ self .assertTrue ((eigvals > 0 ).all ())
279+ self .assertEqual (covar .shape , torch .Size ([2 , 3 , 3 ]))
280+
281+ # Test _set_lower_triangle_corr with boundary values
282+ with self .subTest ("set_lower_triangle_corr_boundary" , dtype = dtype ):
283+ kernel = PositiveIndexKernel (num_tasks = 2 , rank = 2 ).to (dtype = dtype )
284+ kernel ._set_lower_triangle_corr (torch .tensor ([0.0 ], dtype = dtype ))
285+ self .assertTrue (kernel ._lower_triangle_corr .isfinite ().all ())
286+ kernel ._set_lower_triangle_corr (torch .tensor ([0.999 ], dtype = dtype ))
287+ self .assertTrue (kernel ._lower_triangle_corr .isfinite ().all ())
288+
289+ # Test _set_lower_triangle_corr with non-PD input
290+ with self .subTest ("set_lower_triangle_corr_non_pd" , dtype = dtype ):
291+ kernel = PositiveIndexKernel (num_tasks = 3 , rank = 3 ).to (dtype = dtype )
292+ # [0.99, 0.01, 0.99] does not form a PD correlation matrix
293+ non_pd_corr = torch .tensor ([0.99 , 0.01 , 0.99 ], dtype = dtype )
294+ kernel ._set_lower_triangle_corr (non_pd_corr )
295+ covar = kernel .covar_matrix
296+ eigvals = torch .linalg .eigvalsh (covar )
297+ self .assertTrue ((eigvals > 0 ).all ())
298+
299+ # Test roundtrip accuracy for full-rank
300+ with self .subTest ("set_lower_triangle_corr_roundtrip" , dtype = dtype ):
301+ kernel = PositiveIndexKernel (
302+ num_tasks = 3 , rank = 3 , unit_scale_for_target = False
303+ ).to (dtype = dtype )
304+ # Set var to small known value to isolate correlation effect
305+ kernel .initialize (raw_var = torch .full ((3 ,), - 5.0 , dtype = dtype ))
306+ target_corr = torch .tensor ([0.8 , 0.5 , 0.6 ], dtype = dtype )
307+ kernel ._set_lower_triangle_corr (target_corr )
308+ recovered = kernel ._lower_triangle_corr
309+ self .assertAllClose (recovered , target_corr , atol = 0.05 )
310+
311+ # Test sample_all_priors with task_prior
312+ with self .subTest ("sample_all_priors_unbatched" , dtype = dtype ):
313+ task_prior = UniformPrior (0.0 , 1.0 )
314+ kernel = PositiveIndexKernel (
315+ num_tasks = 3 ,
316+ rank = 3 ,
317+ task_prior = task_prior ,
318+ ).to (dtype = dtype )
319+
320+ corr_before = kernel ._lower_triangle_corr .clone ()
321+ sample_all_priors (kernel )
322+
323+ corr_after = kernel ._lower_triangle_corr
324+ self .assertFalse (torch .allclose (corr_before , corr_after ))
325+
326+ covar = kernel .covar_matrix
327+ eigvals = torch .linalg .eigvalsh (covar )
328+ self .assertTrue ((eigvals > 0 ).all ())
329+
330+ # Test with BetaPrior
331+ with self .subTest ("beta_prior" , dtype = dtype ):
332+ task_prior = BetaPrior (1.2 , 0.9 )
333+ kernel = PositiveIndexKernel (
334+ num_tasks = 4 ,
335+ rank = 4 ,
336+ task_prior = task_prior ,
337+ ).to (dtype = dtype )
338+ sample_all_priors (kernel )
339+ covar = kernel .covar_matrix
340+ eigvals = torch .linalg .eigvalsh (covar )
341+ self .assertTrue ((eigvals > 0 ).all ())
342+
343+ # Test sample_all_priors
344+ with self .subTest ("sample_all_priors" , dtype = dtype ):
345+ task_prior = UniformPrior (0.0 , 1.0 )
346+ kernel = PositiveIndexKernel (
347+ num_tasks = 3 ,
348+ rank = 3 ,
349+ task_prior = task_prior ,
350+ ).to (dtype = dtype )
351+ sample_all_priors (kernel )
352+ covar = kernel .covar_matrix
353+ eigvals = torch .linalg .eigvalsh (covar )
354+ self .assertTrue ((eigvals > 0 ).all ())
0 commit comments