4444from gpytorch .kernels .kernel import ProductKernel
4545from gpytorch .kernels .scale_kernel import ScaleKernel
4646from gpytorch .likelihoods .likelihood import Likelihood
47+ from gpytorch .module import Module
4748from gpytorch .priors .torch_priors import GammaPrior
4849from torch import Tensor
4950
@@ -72,6 +73,7 @@ def __init__(
7273 data_fidelities : Sequence [int ] | None = None ,
7374 linear_truncated : bool = True ,
7475 nu : float = 2.5 ,
76+ covar_module : Module | None = None ,
7577 likelihood : Likelihood | None = None ,
7678 outcome_transform : OutcomeTransform | _DefaultType | None = DEFAULT ,
7779 input_transform : InputTransform | None = None ,
@@ -93,6 +95,8 @@ def __init__(
9395 of the default kernel.
9496 nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2, or
9597 5/2. Only used when ``linear_truncated=True``.
98+ covar_module: The module for computing the covariance matrix between
99+ the non-fidelity features. Defaults to ``RBFKernel``.
96100 likelihood: A likelihood. If omitted, use a standard GaussianLikelihood
97101 with inferred noise level.
98102 outcome_transform: An outcome transform that is applied to the
@@ -130,6 +134,7 @@ def __init__(
130134 data_fidelities = data_fidelities ,
131135 linear_truncated = linear_truncated ,
132136 nu = nu ,
137+ data_covar_module = covar_module ,
133138 )
134139 super ().__init__ (
135140 train_X = train_X ,
@@ -143,9 +148,10 @@ def __init__(
143148 # Used for subsetting along the output dimension. See Model.subset_output.
144149 self ._subset_batch_dict = {
145150 "mean_module.raw_constant" : - 1 ,
146- "covar_module.raw_outputscale" : - 1 ,
147151 ** subset_batch_dict ,
148152 }
153+ if linear_truncated :
154+ self ._subset_batch_dict ["covar_module.raw_outputscale" ] = - 1
149155 if train_Yvar is None :
150156 self ._subset_batch_dict ["likelihood.noise_covar.raw_noise" ] = - 2
151157 self .to (train_X )
@@ -174,7 +180,8 @@ def _setup_multifidelity_covar_module(
174180 data_fidelities : Sequence [int ] | None ,
175181 linear_truncated : bool ,
176182 nu : float ,
177- ) -> tuple [ScaleKernel , dict ]:
183+ data_covar_module : Module | None = None ,
184+ ) -> tuple [ScaleKernel | ProductKernel , dict [str , int ]]:
178185 """Helper function to get the covariance module and associated subset_batch_dict
179186 for the multifidelity setting.
180187
@@ -190,7 +197,8 @@ def _setup_multifidelity_covar_module(
190197 of the default kernel.
191198 nu: The smoothness parameter for the Matern kernel: either 1/2, 3/2, or
192199 5/2. Only used when ``linear_truncated=True``.
193-
200+ data_covar_module: The module for computing the covariance matrix between
201+ the non-fidelity features. Defaults to ``RBFKernel``.
194202 Returns:
195203 The covariance module and subset_batch_dict.
196204 """
@@ -205,6 +213,12 @@ def _setup_multifidelity_covar_module(
205213
206214 kernels = []
207215
216+ if linear_truncated and data_covar_module is not None :
217+ raise ValueError (
218+ "Non-fidelity covariance module cannot be specified when using a linear "
219+ "truncated kernel."
220+ )
221+
208222 if linear_truncated :
209223 leading_dims = [iteration_fidelity ] if iteration_fidelity is not None else []
210224 trailing_dims = (
@@ -225,13 +239,18 @@ def _setup_multifidelity_covar_module(
225239 if iteration_fidelity is not None :
226240 non_active_dims .add (iteration_fidelity )
227241 active_dimsX = sorted (set (range (dim )) - non_active_dims )
228- kernels .append (
229- get_covar_module_with_dim_scaled_prior (
230- ard_num_dims = len (active_dimsX ),
231- batch_shape = aug_batch_shape ,
232- active_dims = active_dimsX ,
242+
243+ if data_covar_module is None :
244+ kernels .append (
245+ get_covar_module_with_dim_scaled_prior (
246+ ard_num_dims = len (active_dimsX ),
247+ batch_shape = aug_batch_shape ,
248+ active_dims = active_dimsX ,
249+ )
233250 )
234- )
251+ else :
252+ kernels .append (data_covar_module )
253+
235254 if iteration_fidelity is not None :
236255 kernels .append (
237256 ExponentialDecayKernel (
@@ -255,11 +274,15 @@ def _setup_multifidelity_covar_module(
255274
256275 kernel = ProductKernel (* kernels )
257276
258- covar_module = ScaleKernel (
259- kernel , batch_shape = aug_batch_shape , outputscale_prior = GammaPrior (2.0 , 0.15 )
260- )
277+ if linear_truncated :
278+ covar_module = ScaleKernel (
279+ kernel , batch_shape = aug_batch_shape , outputscale_prior = GammaPrior (2.0 , 0.15 )
280+ )
281+ key_prefix = "covar_module.base_kernel.kernels"
282+ else :
283+ covar_module = kernel
284+ key_prefix = "covar_module.kernels"
261285
262- key_prefix = "covar_module.base_kernel.kernels"
263286 if linear_truncated :
264287 subset_batch_dict = {}
265288 for i in range (len (kernels )):
@@ -271,9 +294,15 @@ def _setup_multifidelity_covar_module(
271294 }
272295 )
273296 else :
274- subset_batch_dict = {
275- f"{ key_prefix } .0.raw_lengthscale" : - 3 ,
276- }
297+ subset_batch_dict = {}
298+
299+ # Only set the subset_batch_dict if using the default kernel. See SingleTaskGP.
300+ if data_covar_module is None :
301+ subset_batch_dict .update (
302+ {
303+ f"{ key_prefix } .0.raw_lengthscale" : - 3 ,
304+ }
305+ )
277306
278307 if iteration_fidelity is not None :
279308 subset_batch_dict .update (
0 commit comments