Skip to content

Commit b017b9c

Browse files
authored
Ensure LinaerKernel stores ard_num_dims property. (#2635)
[Fixes #2633]
1 parent da70269 commit b017b9c

File tree

2 files changed

+2
-1
lines changed

2 files changed

+2
-1
lines changed

gpytorch/kernels/linear_kernel.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def __init__(
5252
variance_constraint: Optional[Interval] = None,
5353
**kwargs,
5454
):
55-
super().__init__(**kwargs)
55+
super().__init__(ard_num_dims=ard_num_dims, **kwargs)
5656
if variance_constraint is None:
5757
variance_constraint = Positive()
5858
self.register_parameter(

test/kernels/test_linear_kernel.py

+1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ class TestLinearKernelARD(TestLinearKernel):
9898
def test_kernel_ard(self) -> None:
9999
self.kernel_kwargs = {"ard_num_dims": 2}
100100
kernel = self.create_kernel_no_ard()
101+
self.assertEqual(kernel.ard_num_dims, 2)
101102
self.assertEqual(kernel.variance.shape, torch.Size([1, 2]))
102103

103104
def test_computes_linear_function_rectangular(self):

0 commit comments

Comments
 (0)