Skip to content

Commit 1fc7148

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
Change index kernel in mixed factory to just be multitask (#799)
Summary: Pull Request resolved: #799 The index kernel for the mixed factory now acts more like a multitask kernel as opposed to the cateogrical kernel which acts more like a true categorical kernel. Reviewed By: tymmsc Differential Revision: D74489044 fbshipit-source-id: 1c03c8bb5fd2065da3ed7f3d02f73654666bbb39
1 parent ff6db0b commit 1fc7148

2 files changed

Lines changed: 4 additions & 23 deletions

File tree

aepsych/factory/mixed.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -109,13 +109,10 @@ def _make_covar_module(self) -> gpytorch.kernels.Kernel:
109109
),
110110
)
111111
)
112-
add_kernel = gpytorch.kernels.AdditiveKernel(
113-
deepcopy(cont_kernel), *deepcopy(discrete_kernels)
114-
)
115-
prod_kernel = gpytorch.kernels.ProductKernel(
112+
113+
return gpytorch.kernels.ProductKernel(
116114
deepcopy(cont_kernel), *deepcopy(discrete_kernels)
117115
)
118-
return add_kernel * prod_kernel
119116
elif self.discrete_kernel == "categorical":
120117
constraint = gpytorch.constraints.GreaterThan(lower_bound=1e-4)
121118
discrete_kernel = botorch.models.kernels.CategoricalKernel(

tests/test_mean_covar_factories.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -703,32 +703,16 @@ def test_mixed_from_config(self):
703703
self.assertEqual(model.dim, 4)
704704
self.assertIsInstance(covar, gpytorch.kernels.ProductKernel)
705705

706-
# Check the additive part
707-
add_kernel = covar.kernels[0]
708-
self.assertIsInstance(add_kernel.kernels[0], gpytorch.kernels.RBFKernel)
709-
self.assertSequenceEqual(add_kernel.kernels[0].active_dims, (0, 3))
710-
self.assertEqual(len(add_kernel.kernels[1:]), 2)
711-
for kernel, index, rank in zip(add_kernel.kernels[1:], (1, 2), (2, 3)):
712-
self.assertIsInstance(kernel, gpytorch.kernels.IndexKernel)
713-
self.assertEqual(kernel.active_dims.item(), index)
714-
self.assertEqual(kernel.covar_factor.shape[1], rank)
715-
716-
# Check the product part
717-
cont_kernel = covar.kernels[1]
706+
cont_kernel = covar.kernels[0]
718707
self.assertIsInstance(cont_kernel, gpytorch.kernels.RBFKernel)
719708
self.assertSequenceEqual(cont_kernel.active_dims, (0, 3))
720709

721-
index_kernels = covar.kernels[2:]
710+
index_kernels = covar.kernels[1:]
722711
for kernel, index, rank in zip(index_kernels, (1, 2), (2, 3)):
723712
self.assertIsInstance(kernel, gpytorch.kernels.IndexKernel)
724713
self.assertEqual(kernel.active_dims.item(), index)
725714
self.assertEqual(kernel.covar_factor.shape[1], rank)
726715

727-
# Check there's copies and not duplicates
728-
self.assertNotEqual(add_kernel.kernels[0], cont_kernel)
729-
self.assertNotEqual(add_kernel.kernels[1], index_kernels[0])
730-
self.assertNotEqual(add_kernel.kernels[2], index_kernels[1])
731-
732716
def test_mixed_acquisition(self):
733717
def f_1d(x):
734718
"""

0 commit comments

Comments
 (0)