55# This source code is licensed under the license found in the
66# LICENSE file in the root directory of this source tree.
77
8+ import abc
89import math
910import warnings
1011from configparser import NoOptionError
12+ from copy import deepcopy
13+ from typing import Any , Literal
1114
1215import gpytorch
1316import torch
14- from aepsych .config import Config
17+ from aepsych .config import Config , ConfigurableMixin
18+ from aepsych .utils import get_dims
1519from scipy .stats import norm
1620
17- from .utils import __default_invgamma_concentration , __default_invgamma_rate
21+ from .utils import (
22+ __default_invgamma_concentration ,
23+ __default_invgamma_rate ,
24+ DEFAULT_INVGAMMA_CONC ,
25+ DEFAULT_INVGAMMA_RATE ,
26+ )
1827
1928# The gamma lengthscale prior is taken from
2029# https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model
2332# https://arxiv.org/html/2402.02229v3
2433
2534
35+ class MeanCovarFactory (ConfigurableMixin , abc .ABC ):
36+ def __init__ (self , dim : int , stimuli_per_trial : int = 1 , * args , ** kwargs ) -> None :
37+ """Abstract base class for mean and covariance function factories.
38+
39+ Args:
40+ dim (int): Dimensionality of the parameter space.
41+ stimuli_per_trial (int, optional): Number of stimuli per trial. Defaults to 1.
42+ """
43+ self .dim = dim
44+ self .stimuli_per_trial = stimuli_per_trial
45+
46+ self .mean_module = self ._make_mean_module ()
47+ self .covar_module = self ._make_covar_module ()
48+
49+ def get_mean (self ) -> gpytorch .means .Mean :
50+ return deepcopy (self .mean_module )
51+
52+ def get_covar (self ) -> gpytorch .kernels .Kernel :
53+ return deepcopy (self .covar_module )
54+
55+ @abc .abstractmethod
56+ def _make_mean_module (self ) -> gpytorch .means .Mean :
57+ pass
58+
59+ @abc .abstractmethod
60+ def _make_covar_module (self ) -> gpytorch .kernels .Kernel :
61+ pass
62+
63+ @classmethod
64+ def get_config_options (
65+ cls ,
66+ config : Config ,
67+ name : str | None = None ,
68+ options : dict [str , Any ] | None = None ,
69+ ) -> dict [str , Any ]:
70+ """Get configuration options for the MeanCovarFactory.
71+
72+ Args:
73+ config (Config): Config object to find options in.
74+ name (str, optional): Name of the factory. Defaults to the class name.
75+ options (dict, optional): Options to start with. Defaults to None.
76+
77+ Returns:
78+ dict[str, Any]: Options to use to initialize the factory.
79+ """
80+ name = name or cls .__name__
81+ options = super ().get_config_options (config , name , options )
82+
83+ if "dim" not in options :
84+ options ["dim" ] = get_dims (config )
85+
86+ return options
87+
88+
89+ class DefaultMeanCovarFactory (MeanCovarFactory ):
90+ def __init__ (
91+ self ,
92+ dim : int ,
93+ stimuli_per_trial : int = 1 ,
94+ zero_mean : bool = False ,
95+ target : float | None = None ,
96+ cov_kernel : gpytorch .kernels .Kernel = gpytorch .kernels .RBFKernel ,
97+ active_dims : list [int ] | None = None ,
98+ lengthscale_prior : Literal ["invgamma" , "gamma" , "lognormal" ] | None = None ,
99+ ls_loc : torch .Tensor | float | None = None ,
100+ ls_scale : torch .Tensor | float | None = None ,
101+ fixed_kernel_amplitude : bool | None = None ,
102+ outputscale_prior : Literal ["box" , "gamma" ] = "box" ,
103+ ) -> None :
104+ """Factory that makes mean and covariance functions for generic GPs.
105+ After initialization, copies of the mean and covariance functions can be made with
106+ `get_mean` and `get_covar`.
107+
108+ Args:
109+ dim (int, optional): Dimensionality of the parameter space. Must be provided.
110+ stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1.
111+ zero_mean (bool, optional): Whether to use zero for the mean module. Defaults to False.
112+ target (float, optional): Target for the mean module. Defaults to None.
113+ cov_kernel (gpytorch.kernels.Kernel, optional): Covariance kernel to use. Defaults to RBF
114+ kernel.
115+ active_dims (list[int], optional): List of dimensions to use in the covariance function. Defaults to None,
116+ which uses all dimensions.
117+ lengthscale_prior (Literal["invgamma", "gamma", "lognormal"], optional): Prior to use for
118+ lengthscale. Defaults to "lognormal" if stimuli_per_trial == 1, else "gamma".
119+ ls_loc (torch.Tensor | float, optional): Location parameter for lengthscale prior.
120+ Defaults to sqrt(2.0).
121+ ls_scale (torch.Tensor | float, optional): Scale parameter for lengthscale prior.
122+ Defaults to sqrt(3.0).
123+ fixed_kernel_amplitude (bool, optional): Whether to allow the covariance kernel to scale.
124+ Defaults to True if stimuli_per_trial == 1, else False.
125+ outputscale_prior (Literal["box", "gamma"], optional): Prior to use to scale the covariance kernel.
126+ Defaults to "box".
127+ """
128+ self .zero_mean = zero_mean
129+ self .target = target
130+ self .cov_kernel = cov_kernel
131+ self .active_dims = active_dims
132+ self .lengthscale_prior = lengthscale_prior
133+ self .ls_loc = ls_loc
134+ self .ls_scale = ls_scale
135+ self .fixed_kernel_amplitude = fixed_kernel_amplitude
136+ self .outputscale_prior = outputscale_prior
137+
138+ super ().__init__ (dim , stimuli_per_trial )
139+
140+ def get_mean (self ) -> gpytorch .means .Mean :
141+ return deepcopy (self .mean_module )
142+
143+ def get_covar (self ) -> gpytorch .kernels .Kernel :
144+ return deepcopy (self .covar_module )
145+
146+ def _make_mean_module (self ) -> gpytorch .means .Mean :
147+ # Make mean module
148+ if self .zero_mean :
149+ mean = gpytorch .means .ZeroMean ()
150+ else :
151+ mean = gpytorch .means .ConstantMean ()
152+
153+ if self .target is not None :
154+ if self .zero_mean :
155+ warnings .warn (
156+ "Specified both `zero_mean = True` and `target`. Zero mean will be overwritten by target fixed mean!" ,
157+ UserWarning ,
158+ stacklevel = 2 ,
159+ )
160+
161+ mean .constant .requires_grad_ (False )
162+ mean .constant .copy_ (torch .tensor (norm .ppf (self .target )))
163+
164+ return mean
165+
166+ def _make_covar_module (self ) -> gpytorch .kernels .Kernel :
167+ # Make covariance module
168+ if self .ls_loc is None :
169+ self .ls_loc = torch .tensor (math .sqrt (2.0 ), dtype = torch .float64 )
170+ elif not isinstance (self .ls_loc , torch .Tensor ):
171+ self .ls_loc = torch .tensor (self .ls_loc , dtype = torch .float64 )
172+
173+ if self .ls_scale is None :
174+ self .ls_scale = torch .tensor (math .sqrt (3.0 ), dtype = torch .float64 )
175+ elif not isinstance (self .ls_scale , torch .Tensor ):
176+ self .ls_scale = torch .tensor (self .ls_scale , dtype = torch .float64 )
177+
178+ if self .fixed_kernel_amplitude is None :
179+ self .fixed_kernel_amplitude = True if self .stimuli_per_trial == 1 else False
180+
181+ if self .lengthscale_prior == "invgamma" :
182+ ls_prior = gpytorch .priors .GammaPrior (
183+ concentration = DEFAULT_INVGAMMA_CONC ,
184+ rate = DEFAULT_INVGAMMA_RATE ,
185+ transform = lambda x : 1 / x ,
186+ )
187+ ls_prior_mode = ls_prior .rate / (ls_prior .concentration + 1 )
188+
189+ elif self .lengthscale_prior == "gamma" or (
190+ self .lengthscale_prior is None and self .stimuli_per_trial != 1
191+ ):
192+ ls_prior = gpytorch .priors .GammaPrior (concentration = 3.0 , rate = 6.0 )
193+ ls_prior_mode = (ls_prior .concentration - 1 ) / ls_prior .rate
194+
195+ elif self .lengthscale_prior == "lognormal" or (
196+ self .lengthscale_prior is None and self .stimuli_per_trial == 1
197+ ):
198+ ls_prior = gpytorch .priors .LogNormalPrior (
199+ self .ls_loc + math .log (self .dim ) / 2 , self .ls_scale
200+ )
201+ ls_prior_mode = torch .exp (self .ls_loc - self .ls_scale ** 2 )
202+ else :
203+ raise RuntimeError (
204+ f"Lengthscale_prior should be invgamma, gamma, or lognormal, got { self .lengthscale_prior } "
205+ )
206+
207+ ls_constraint = gpytorch .constraints .GreaterThan (
208+ lower_bound = 1e-4 , transform = None , initial_value = ls_prior_mode
209+ )
210+
211+ covar = self .cov_kernel (
212+ lengthscale_prior = ls_prior ,
213+ lengthscale_constraint = ls_constraint ,
214+ ard_num_dims = self .dim ,
215+ active_dims = self .active_dims ,
216+ )
217+ if not self .fixed_kernel_amplitude :
218+ if self .outputscale_prior == "gamma" :
219+ os_prior = gpytorch .priors .GammaPrior (concentration = 2.0 , rate = 0.15 )
220+ elif self .outputscale_prior == "box" :
221+ os_prior = gpytorch .priors .SmoothedBoxPrior (a = 1 , b = 4 )
222+ else :
223+ raise RuntimeError (
224+ f"Outputscale_prior should be gamma or box, got { self .outputscale_prior } "
225+ )
226+
227+ covar = gpytorch .kernels .ScaleKernel (
228+ covar ,
229+ outputscale_prior = os_prior ,
230+ outputscale_constraint = gpytorch .constraints .GreaterThan (1e-4 ),
231+ )
232+
233+ return covar
234+
235+ @classmethod
236+ def get_config_options (
237+ cls ,
238+ config : Config ,
239+ name : str | None = None ,
240+ options : dict [str , Any ] | None = None ,
241+ ) -> dict [str , Any ]:
242+ """Get configuration options for the MeanCovarFactory.
243+
244+ Args:
245+ config (Config): Config object to find options in.
246+ name (str, optional): Name of the factory. Defaults to the class name.
247+ options (dict, optional): Options to start with. Defaults to None.
248+
249+ Returns:
250+ dict[str, Any]: Options to use to initialize the factory.
251+ """
252+ name = name or cls .__name__
253+ options = super ().get_config_options (config , name , options )
254+
255+ if "dim" not in options :
256+ options ["dim" ] = get_dims (config )
257+
258+ return options
259+
260+
26261def default_mean_covar_factory (
27262 config : Config | None = None ,
28263 dim : int | None = None ,
@@ -41,6 +276,11 @@ def default_mean_covar_factory(
41276 tuple[gpytorch.means.Mean, gpytorch.kernels.Kernel]: Instantiated
42277 ConstantMean and ScaleKernel with priors based on bounds.
43278 """
279+ warnings .warn (
280+ "default_mean_covar_factory is deprecated, use the DefaultMeanCovarFactory class instead!" ,
281+ DeprecationWarning ,
282+ stacklevel = 2 ,
283+ )
44284
45285 assert (config is not None ) or (
46286 dim is not None
0 commit comments