Skip to content

Commit 0c93a38

Browse files
Carl Hvarfner authored and meta-codesync[bot] committed
Add setting_closure to PositiveIndexKernel for prior support (meta-pytorch#3267)
Summary: Pull Request resolved: meta-pytorch#3267. Adds a `_set_lower_triangle_corr` method, wired as the `setting_closure` in `register_prior`, enabling `sample_from_prior` / `sample_all_priors` to work with task correlation priors (e.g. BetaPrior). Values are clamped to [0, 1] since PositiveIndexKernel enforces positive correlations. Reviewed By: sdaulton. Differential Revision: D99841562. fbshipit-source-id: bf15dfdadfa917aa34c7c86b0ff29436d2ab7717
1 parent 8418343 commit 0c93a38

2 files changed

Lines changed: 197 additions & 35 deletions

File tree

botorch/models/kernels/positive_index.py

Lines changed: 44 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,6 @@ def __init__(
3030
num_tasks: int,
3131
rank: int = 1,
3232
task_prior: Prior | None = None,
33-
diag_prior: Prior | None = None,
3433
normalize_covar_matrix: bool = False,
3534
var_constraint: Interval | None = None,
3635
target_task_index: int = 0,
@@ -43,7 +42,6 @@ def __init__(
4342
num_tasks (int): Total number of indices.
4443
rank (int): Rank of the covariance matrix parameterization.
4544
task_prior (Prior, optional): Prior for the covariance matrix.
46-
diag_prior (Prior, optional): Prior for the diagonal elements.
4745
normalize_covar_matrix (bool): Whether to normalize the covariance matrix.
4846
target_task_index (int): Index of the task whose diagonal element should be
4947
normalized to 1. Defaults to 0 (first task).
@@ -88,11 +86,11 @@ def __init__(
8886
f"{type(task_prior).__name__}"
8987
)
9088
self.register_prior(
91-
"IndexKernelPrior", task_prior, lambda m: m._lower_triangle_corr
89+
"IndexKernelPrior",
90+
task_prior,
91+
lambda m: m._lower_triangle_corr,
92+
lambda m, v: m._set_lower_triangle_corr(v),
9293
)
93-
if diag_prior is not None:
94-
self.register_prior("ScalePrior", diag_prior, lambda m: m._diagonal)
95-
9694
self.register_constraint("raw_covar_factor", GreaterThan(0.0))
9795

9896
def _covar_factor_params(self, m):
@@ -127,15 +125,49 @@ def _lower_triangle_corr(self):
127125

128126
return low_tri
129127

130-
@property
131-
def _diagonal(self):
132-
return torch.diagonal(self.covar_matrix, dim1=-2, dim2=-1)
128+
def _set_lower_triangle_corr(self, value):
129+
"""Set covar_factor to produce the given lower-triangle correlations.
130+
131+
Assembles a symmetric correlation matrix from the lower-triangle values,
132+
then projects it to the nearest positive-definite correlation matrix via
133+
eigenvalue clamping before Cholesky decomposition. This guarantees the
134+
setter never fails even when independently sampled correlation values
135+
do not form a PD matrix.
136+
137+
Args:
138+
value: Tensor of lower-triangle correlation values.
139+
"""
140+
n = self.num_tasks
141+
eps = 1e-6
142+
n_lower = n * (n - 1) // 2
143+
lower_row, lower_col = torch.tril_indices(n, n, offset=-1)
144+
# Expand under-batched input (e.g. scalar from sample_from_prior)
145+
if value.dim() == 0 or (value.dim() == 1 and value.shape[0] != n_lower):
146+
value = value.unsqueeze(-1).expand(*self.batch_shape, n_lower)
147+
elif value.shape[:-1] != self.batch_shape:
148+
value = value.expand(*self.batch_shape, n_lower)
149+
batch_shape = value.shape[:-1]
150+
corr = (
151+
torch.eye(n, dtype=value.dtype, device=value.device)
152+
.expand(*batch_shape, n, n)
153+
.clone()
154+
)
155+
corr[..., lower_row, lower_col] = value.clamp(0.0, 1.0)
156+
corr[..., lower_col, lower_row] = value.clamp(0.0, 1.0)
157+
# Project to nearest PD correlation matrix via eigenvalue clamping
158+
eigvals, eigvecs = torch.linalg.eigh(corr)
159+
eigvals = eigvals.clamp(min=eps)
160+
corr = eigvecs @ torch.diag_embed(eigvals) @ eigvecs.transpose(-1, -2)
161+
# Re-normalize diagonals to 1
162+
d = corr.diagonal(dim1=-1, dim2=-2).sqrt()
163+
corr = corr / (d.unsqueeze(-1) * d.unsqueeze(-2))
164+
chol = torch.linalg.cholesky(corr)
165+
rank = self.raw_covar_factor.shape[-1]
166+
self._set_covar_factor(chol[..., :, :rank].clamp(min=eps))
133167

134168
def _eval_covar_matrix(self):
135169
cf = self.covar_factor
136-
covar = cf @ cf.transpose(-1, -2) + self.var * torch.eye(
137-
self.num_tasks, dtype=cf.dtype, device=cf.device
138-
)
170+
covar = cf @ cf.transpose(-1, -2) + torch.diag_embed(self.var)
139171
# Normalize by the target task's diagonal element
140172
if self.unit_scale_for_target:
141173
norm_factor = covar[..., self.target_task_index, self.target_task_index]

test/models/kernels/test_positive_index.py

Lines changed: 153 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -6,8 +6,10 @@
66

77
import torch
88
from botorch.models.kernels.positive_index import PositiveIndexKernel
9+
from botorch.models.utils.priors import BetaPrior
10+
from botorch.optim.utils import sample_all_priors
911
from botorch.utils.testing import BotorchTestCase
10-
from gpytorch.priors import NormalPrior
12+
from gpytorch.priors import NormalPrior, UniformPrior
1113

1214

1315
class TestPositiveIndexKernel(BotorchTestCase):
@@ -125,18 +127,15 @@ def test_positive_index_kernel(self):
125127
with self.subTest("with_priors", dtype=dtype):
126128
num_tasks = 4
127129
task_prior = NormalPrior(0, 1)
128-
diag_prior = NormalPrior(1, 0.1)
129130

130131
kernel = PositiveIndexKernel(
131132
num_tasks=num_tasks,
132133
rank=2,
133134
task_prior=task_prior,
134-
diag_prior=diag_prior,
135135
initialize_to_mode=False,
136136
).to(dtype=dtype)
137137
prior_names = [p[0] for p in kernel.named_priors()]
138138
self.assertIn("IndexKernelPrior", prior_names)
139-
self.assertIn("ScalePrior", prior_names)
140139

141140
# Test batch forward
142141
with self.subTest("batch_forward", dtype=dtype):
@@ -154,25 +153,6 @@ def test_positive_index_kernel(self):
154153
# Check that batch dimensions are preserved
155154
self.assertEqual(result.shape[0], 2)
156155

157-
# Test diagonal property (default target_task_index=0)
158-
with self.subTest("diagonal", dtype=dtype):
159-
kernel = PositiveIndexKernel(num_tasks=4, rank=2).to(dtype=dtype)
160-
diag = kernel._diagonal
161-
162-
self.assertEqual(diag.shape, torch.Size([4]))
163-
# First diagonal element should be 1.0 (default target_task_index=0)
164-
self.assertAllClose(diag[0], torch.tensor(1.0, dtype=dtype), atol=1e-4)
165-
166-
# Test diagonal property with custom target_task_index
167-
kernel = PositiveIndexKernel(
168-
num_tasks=4, rank=2, target_task_index=1
169-
).to(dtype=dtype)
170-
diag = kernel._diagonal
171-
172-
self.assertEqual(diag.shape, torch.Size([4]))
173-
# Second diagonal element should be 1.0 (target_task_index=1)
174-
self.assertAllClose(diag[1], torch.tensor(1.0, dtype=dtype), atol=1e-4)
175-
176156
# Test lower triangle property
177157
with self.subTest("lower_triangle", dtype=dtype):
178158
num_tasks = 5
@@ -222,3 +202,153 @@ def test_positive_index_kernel(self):
222202
new_value = torch.ones(3, 2, dtype=dtype) * 3.0
223203
kernel._covar_factor_closure(kernel, new_value)
224204
self.assertAllClose(kernel.covar_factor, new_value, atol=1e-5)
205+
206+
# Test _set_lower_triangle_corr produces valid covariance
207+
with self.subTest("set_lower_triangle_corr", dtype=dtype):
208+
kernel = PositiveIndexKernel(num_tasks=3, rank=3).to(dtype=dtype)
209+
target_corr = torch.tensor([0.8, 0.5, 0.6], dtype=dtype)
210+
kernel._set_lower_triangle_corr(target_corr)
211+
212+
# Covariance matrix should be PD and symmetric
213+
covar = kernel.covar_matrix
214+
eigvals = torch.linalg.eigvalsh(covar)
215+
self.assertTrue((eigvals > 0).all())
216+
self.assertAllClose(covar, covar.T, atol=1e-5)
217+
218+
# Recovered correlations should be positive
219+
recovered = kernel._lower_triangle_corr
220+
self.assertTrue((recovered >= 0).all())
221+
self.assertTrue((recovered <= 1).all())
222+
223+
# Test _set_lower_triangle_corr with batch shape
224+
with self.subTest("set_lower_triangle_corr_batch", dtype=dtype):
225+
batch_shape = torch.Size([2])
226+
kernel = PositiveIndexKernel(
227+
num_tasks=3, rank=3, batch_shape=batch_shape
228+
).to(dtype=dtype)
229+
target_corr = torch.rand(*batch_shape, 3, dtype=dtype)
230+
kernel._set_lower_triangle_corr(target_corr)
231+
covar = kernel.covar_matrix
232+
eigvals = torch.linalg.eigvalsh(covar)
233+
self.assertTrue((eigvals > 0).all())
234+
self.assertEqual(covar.shape, torch.Size([2, 3, 3]))
235+
236+
# Test sample_all_priors with batch shape
237+
with self.subTest("sample_all_priors_batch", dtype=dtype):
238+
batch_shape = torch.Size([2])
239+
task_prior = UniformPrior(0.0, 1.0)
240+
kernel = PositiveIndexKernel(
241+
num_tasks=3,
242+
rank=3,
243+
task_prior=task_prior,
244+
batch_shape=batch_shape,
245+
).to(dtype=dtype)
246+
sample_all_priors(kernel)
247+
covar = kernel.covar_matrix
248+
eigvals = torch.linalg.eigvalsh(covar)
249+
self.assertTrue((eigvals > 0).all())
250+
self.assertEqual(covar.shape, torch.Size([2, 3, 3]))
251+
252+
# Test _set_lower_triangle_corr with scalar input (under-batched)
253+
with self.subTest("set_lower_triangle_corr_scalar", dtype=dtype):
254+
batch_shape = torch.Size([2])
255+
kernel = PositiveIndexKernel(
256+
num_tasks=3, rank=3, batch_shape=batch_shape
257+
).to(dtype=dtype)
258+
# Scalar value — exercises dim()==0 branch
259+
kernel._set_lower_triangle_corr(torch.tensor(0.5, dtype=dtype))
260+
covar = kernel.covar_matrix
261+
eigvals = torch.linalg.eigvalsh(covar)
262+
self.assertTrue((eigvals > 0).all())
263+
self.assertEqual(covar.shape, torch.Size([2, 3, 3]))
264+
265+
# Test _set_lower_triangle_corr with unbatched input on batched kernel
266+
with self.subTest(
267+
"set_lower_triangle_corr_unbatched_on_batch", dtype=dtype
268+
):
269+
batch_shape = torch.Size([2])
270+
kernel = PositiveIndexKernel(
271+
num_tasks=3, rank=3, batch_shape=batch_shape
272+
).to(dtype=dtype)
273+
# 1D input with correct n_lower but no batch — exercises expand branch
274+
target_corr = torch.rand(3, dtype=dtype)
275+
kernel._set_lower_triangle_corr(target_corr)
276+
covar = kernel.covar_matrix
277+
eigvals = torch.linalg.eigvalsh(covar)
278+
self.assertTrue((eigvals > 0).all())
279+
self.assertEqual(covar.shape, torch.Size([2, 3, 3]))
280+
281+
# Test _set_lower_triangle_corr with boundary values
282+
with self.subTest("set_lower_triangle_corr_boundary", dtype=dtype):
283+
kernel = PositiveIndexKernel(num_tasks=2, rank=2).to(dtype=dtype)
284+
kernel._set_lower_triangle_corr(torch.tensor([0.0], dtype=dtype))
285+
self.assertTrue(kernel._lower_triangle_corr.isfinite().all())
286+
kernel._set_lower_triangle_corr(torch.tensor([0.999], dtype=dtype))
287+
self.assertTrue(kernel._lower_triangle_corr.isfinite().all())
288+
289+
# Test _set_lower_triangle_corr with non-PD input
290+
with self.subTest("set_lower_triangle_corr_non_pd", dtype=dtype):
291+
kernel = PositiveIndexKernel(num_tasks=3, rank=3).to(dtype=dtype)
292+
# [0.99, 0.01, 0.99] does not form a PD correlation matrix
293+
non_pd_corr = torch.tensor([0.99, 0.01, 0.99], dtype=dtype)
294+
kernel._set_lower_triangle_corr(non_pd_corr)
295+
covar = kernel.covar_matrix
296+
eigvals = torch.linalg.eigvalsh(covar)
297+
self.assertTrue((eigvals > 0).all())
298+
299+
# Test roundtrip accuracy for full-rank
300+
with self.subTest("set_lower_triangle_corr_roundtrip", dtype=dtype):
301+
kernel = PositiveIndexKernel(
302+
num_tasks=3, rank=3, unit_scale_for_target=False
303+
).to(dtype=dtype)
304+
# Set var to small known value to isolate correlation effect
305+
kernel.initialize(raw_var=torch.full((3,), -5.0, dtype=dtype))
306+
target_corr = torch.tensor([0.8, 0.5, 0.6], dtype=dtype)
307+
kernel._set_lower_triangle_corr(target_corr)
308+
recovered = kernel._lower_triangle_corr
309+
self.assertAllClose(recovered, target_corr, atol=0.05)
310+
311+
# Test sample_all_priors with task_prior
312+
with self.subTest("sample_all_priors_unbatched", dtype=dtype):
313+
task_prior = UniformPrior(0.0, 1.0)
314+
kernel = PositiveIndexKernel(
315+
num_tasks=3,
316+
rank=3,
317+
task_prior=task_prior,
318+
).to(dtype=dtype)
319+
320+
corr_before = kernel._lower_triangle_corr.clone()
321+
sample_all_priors(kernel)
322+
323+
corr_after = kernel._lower_triangle_corr
324+
self.assertFalse(torch.allclose(corr_before, corr_after))
325+
326+
covar = kernel.covar_matrix
327+
eigvals = torch.linalg.eigvalsh(covar)
328+
self.assertTrue((eigvals > 0).all())
329+
330+
# Test with BetaPrior
331+
with self.subTest("beta_prior", dtype=dtype):
332+
task_prior = BetaPrior(1.2, 0.9)
333+
kernel = PositiveIndexKernel(
334+
num_tasks=4,
335+
rank=4,
336+
task_prior=task_prior,
337+
).to(dtype=dtype)
338+
sample_all_priors(kernel)
339+
covar = kernel.covar_matrix
340+
eigvals = torch.linalg.eigvalsh(covar)
341+
self.assertTrue((eigvals > 0).all())
342+
343+
# Test sample_all_priors
344+
with self.subTest("sample_all_priors", dtype=dtype):
345+
task_prior = UniformPrior(0.0, 1.0)
346+
kernel = PositiveIndexKernel(
347+
num_tasks=3,
348+
rank=3,
349+
task_prior=task_prior,
350+
).to(dtype=dtype)
351+
sample_all_priors(kernel)
352+
covar = kernel.covar_matrix
353+
eigvals = torch.linalg.eigvalsh(covar)
354+
self.assertTrue((eigvals > 0).all())

0 commit comments

Comments
 (0)