Skip to content

Commit bc71edd

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
Refactor factories to use classes and use configurablemixin (facebookresearch#793)
Summary: Factory classes to create mean and covariance modules has been refactored to use ConfigurableMixin. This means they follow the same from_config API as all other classes that can be initialized from a config. The old functions are now deprecated and will warn. Differential Revision: D74040481
1 parent 5a5594c commit bc71edd

28 files changed

Lines changed: 894 additions & 254 deletions

aepsych/config.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import typing
1414
import warnings
1515
from types import ModuleType, NoneType, UnionType
16-
from typing import Any, Callable, ClassVar, Mapping, Sequence, TypeVar
16+
from typing import Any, Callable, ClassVar, Literal, Mapping, Sequence, TypeVar
1717

1818
import botorch
1919
import gpytorch
@@ -603,6 +603,31 @@ def _sort_types(annotations):
603603
else:
604604
value = object_cls
605605

606+
# Literal (supporting strings, ints, floats)
607+
elif typing.get_origin(annotation) is Literal:
608+
literal_args = typing.get_args(annotation)
609+
for arg in literal_args:
610+
if isinstance(arg, str):
611+
attempt = config.get(name, key, fallback=None)
612+
elif isinstance(arg, int):
613+
attempt = config.getint(name, key, fallback=None)
614+
elif isinstance(arg, float):
615+
attempt = config.getfloat(name, key, fallback=None)
616+
else:
617+
raise NotImplementedError(
618+
f"Literal types in {annotation} not supported yet!"
619+
)
620+
621+
if attempt is None:
622+
continue
623+
624+
if attempt in literal_args:
625+
value = attempt
626+
else:
627+
raise RuntimeError(
628+
f"Value {attempt} is not in the Literal type {annotation} for the option {key}!"
629+
)
630+
606631
# Callable
607632
elif annotation is Callable:
608633
value = config.getobj(name, key)

aepsych/factory/__init__.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,9 @@
88
import sys
99

1010
from ..config import Config
11-
from .default import default_mean_covar_factory
12-
from .ordinal import ordinal_mean_covar_factory
13-
from .pairwise import pairwise_mean_covar_factory
14-
from .song import song_mean_covar_factory
11+
from .default import default_mean_covar_factory, DefaultMeanCovarFactory
12+
from .pairwise import pairwise_mean_covar_factory, PairwiseMeanCovarFactory
13+
from .song import song_mean_covar_factory, SongMeanCovarFactory
1514

1615
"""AEPsych factory functions.
1716
These functions generate a gpytorch Mean and Kernel objects from
@@ -23,10 +22,12 @@
2322
"""
2423

2524
__all__ = [
25+
"DefaultMeanCovarFactory",
2626
"default_mean_covar_factory",
27-
"ordinal_mean_covar_factory",
28-
"song_mean_covar_factory",
2927
"pairwise_mean_covar_factory",
28+
"PairwiseMeanCovarFactory",
29+
"SongMeanCovarFactory",
30+
"song_mean_covar_factory",
3031
]
3132

3233
Config.register_module(sys.modules[__name__])

aepsych/factory/default.py

Lines changed: 242 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,25 @@
55
# This source code is licensed under the license found in the
66
# LICENSE file in the root directory of this source tree.
77

8+
import abc
89
import math
910
import warnings
1011
from configparser import NoOptionError
12+
from copy import deepcopy
13+
from typing import Any, Literal
1114

1215
import gpytorch
1316
import torch
14-
from aepsych.config import Config
17+
from aepsych.config import Config, ConfigurableMixin
18+
from aepsych.utils import get_dims
1519
from scipy.stats import norm
1620

17-
from .utils import __default_invgamma_concentration, __default_invgamma_rate
21+
from .utils import (
22+
__default_invgamma_concentration,
23+
__default_invgamma_rate,
24+
DEFAULT_INVGAMMA_CONC,
25+
DEFAULT_INVGAMMA_RATE,
26+
)
1827

1928
# The gamma lengthscale prior is taken from
2029
# https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model
@@ -23,6 +32,232 @@
2332
# https://arxiv.org/html/2402.02229v3
2433

2534

35+
class MeanCovarFactory(ConfigurableMixin, abc.ABC):
36+
def __init__(self, dim: int, stimuli_per_trial: int = 1, *args, **kwargs) -> None:
37+
"""Abstract base class for mean and covariance function factories.
38+
39+
Args:
40+
dim (int): Dimensionality of the parameter space.
41+
stimuli_per_trial (int, optional): Number of stimuli per trial. Defaults to 1.
42+
"""
43+
self.dim = dim
44+
self.stimuli_per_trial = stimuli_per_trial
45+
46+
self.mean_module = self._make_mean_module()
47+
self.covar_module = self._make_covar_module()
48+
49+
def get_mean(self) -> gpytorch.means.Mean:
50+
return deepcopy(self.mean_module)
51+
52+
def get_covar(self) -> gpytorch.kernels.Kernel:
53+
return deepcopy(self.covar_module)
54+
55+
@abc.abstractmethod
56+
def _make_mean_module(self) -> gpytorch.means.Mean:
57+
pass
58+
59+
@abc.abstractmethod
60+
def _make_covar_module(self) -> gpytorch.kernels.Kernel:
61+
pass
62+
63+
@classmethod
64+
def get_config_options(
65+
cls,
66+
config: Config,
67+
name: str | None = None,
68+
options: dict[str, Any] | None = None,
69+
) -> dict[str, Any]:
70+
"""Get configuration options for the MeanCovarFactory.
71+
72+
Args:
73+
config (Config): Config object to find options in.
74+
name (str, optional): Name of the factory. Defaults to the class name.
75+
options (dict, optional): Options to start with. Defaults to None.
76+
77+
Returns:
78+
dict[str, Any]: Options to use to initialize the factory.
79+
"""
80+
name = name or cls.__name__
81+
options = super().get_config_options(config, name, options)
82+
83+
if "dim" not in options:
84+
options["dim"] = get_dims(config)
85+
86+
return options
87+
88+
89+
class DefaultMeanCovarFactory(MeanCovarFactory):
90+
def __init__(
91+
self,
92+
dim: int,
93+
stimuli_per_trial: int = 1,
94+
zero_mean: bool = False,
95+
target: float | None = None,
96+
cov_kernel: gpytorch.kernels.Kernel = gpytorch.kernels.RBFKernel,
97+
active_dims: list[int] | None = None,
98+
lengthscale_prior: Literal["invgamma", "gamma", "lognormal"] | None = None,
99+
ls_loc: torch.Tensor | float | None = None,
100+
ls_scale: torch.Tensor | float | None = None,
101+
fixed_kernel_amplitude: bool | None = None,
102+
outputscale_prior: Literal["box", "gamma"] = "box",
103+
) -> None:
104+
"""Factory that makes mean and covariance functions for generic GPs.
105+
After initialization, copies of the mean and covariance functions can be made with
106+
`get_mean` and `get_covar`.
107+
108+
Args:
109+
dim (int, optional): Dimensionality of the parameter space. Must be provided.
110+
stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1.
111+
zero_mean (bool, optional): Whether to use zero for the mean module. Defaults to False.
112+
target (float, optional): Target for the mean module. Defaults to None.
113+
cov_kernel (gpytorch.kernels.Kernel, optional): Covariance kernel to use. Defaults to RBF
114+
kernel.
115+
active_dims (list[int], optional): List of dimensions to use in the covariance function. Defaults to None,
116+
which uses all dimensions.
117+
lengthscale_prior (Literal["invgamma", "gamma", "lognormal"], optional): Prior to use for
118+
lengthscale. Defaults to "lognormal" if stimuli_per_trial == 1, else "gamma".
119+
ls_loc (torch.Tensor | float, optional): Location parameter for lengthscale prior.
120+
Defaults to sqrt(2.0).
121+
ls_scale (torch.Tensor | float, optional): Scale parameter for lengthscale prior.
122+
Defaults to sqrt(3.0).
123+
fixed_kernel_amplitude (bool, optional): Whether to allow the covariance kernel to scale.
124+
Defaults to True if stimuli_per_trial == 1, else False.
125+
outputscale_prior (Literal["box", "gamma"], optional): Prior to use to scale the covariance kernel.
126+
Defaults to "box".
127+
"""
128+
self.zero_mean = zero_mean
129+
self.target = target
130+
self.cov_kernel = cov_kernel
131+
self.active_dims = active_dims
132+
self.lengthscale_prior = lengthscale_prior
133+
self.ls_loc = ls_loc
134+
self.ls_scale = ls_scale
135+
self.fixed_kernel_amplitude = fixed_kernel_amplitude
136+
self.outputscale_prior = outputscale_prior
137+
138+
super().__init__(dim, stimuli_per_trial)
139+
140+
def get_mean(self) -> gpytorch.means.Mean:
141+
return deepcopy(self.mean_module)
142+
143+
def get_covar(self) -> gpytorch.kernels.Kernel:
144+
return deepcopy(self.covar_module)
145+
146+
def _make_mean_module(self) -> gpytorch.means.Mean:
147+
# Make mean module
148+
if self.zero_mean:
149+
mean = gpytorch.means.ZeroMean()
150+
else:
151+
mean = gpytorch.means.ConstantMean()
152+
153+
if self.target is not None:
154+
if self.zero_mean:
155+
warnings.warn(
156+
"Specified both `zero_mean = True` and `target`. Zero mean will be overwritten by target fixed mean!",
157+
UserWarning,
158+
stacklevel=2,
159+
)
160+
161+
mean.constant.requires_grad_(False)
162+
mean.constant.copy_(torch.tensor(norm.ppf(self.target)))
163+
164+
return mean
165+
166+
def _make_covar_module(self) -> gpytorch.kernels.Kernel:
167+
# Make covariance module
168+
if self.ls_loc is None:
169+
self.ls_loc = torch.tensor(math.sqrt(2.0), dtype=torch.float64)
170+
elif not isinstance(self.ls_loc, torch.Tensor):
171+
self.ls_loc = torch.tensor(self.ls_loc, dtype=torch.float64)
172+
173+
if self.ls_scale is None:
174+
self.ls_scale = torch.tensor(math.sqrt(3.0), dtype=torch.float64)
175+
elif not isinstance(self.ls_scale, torch.Tensor):
176+
self.ls_scale = torch.tensor(self.ls_scale, dtype=torch.float64)
177+
178+
if self.fixed_kernel_amplitude is None:
179+
self.fixed_kernel_amplitude = True if self.stimuli_per_trial == 1 else False
180+
181+
if self.lengthscale_prior == "invgamma":
182+
ls_prior = gpytorch.priors.GammaPrior(
183+
concentration=DEFAULT_INVGAMMA_CONC,
184+
rate=DEFAULT_INVGAMMA_RATE,
185+
transform=lambda x: 1 / x,
186+
)
187+
ls_prior_mode = ls_prior.rate / (ls_prior.concentration + 1)
188+
189+
elif self.lengthscale_prior == "gamma" or (
190+
self.lengthscale_prior is None and self.stimuli_per_trial != 1
191+
):
192+
ls_prior = gpytorch.priors.GammaPrior(concentration=3.0, rate=6.0)
193+
ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate
194+
195+
elif self.lengthscale_prior == "lognormal" or (
196+
self.lengthscale_prior is None and self.stimuli_per_trial == 1
197+
):
198+
ls_prior = gpytorch.priors.LogNormalPrior(
199+
self.ls_loc + math.log(self.dim) / 2, self.ls_scale
200+
)
201+
ls_prior_mode = torch.exp(self.ls_loc - self.ls_scale**2)
202+
else:
203+
raise RuntimeError(
204+
f"Lengthscale_prior should be invgamma, gamma, or lognormal, got {self.lengthscale_prior}"
205+
)
206+
207+
ls_constraint = gpytorch.constraints.GreaterThan(
208+
lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
209+
)
210+
211+
covar = self.cov_kernel(
212+
lengthscale_prior=ls_prior,
213+
lengthscale_constraint=ls_constraint,
214+
ard_num_dims=self.dim,
215+
active_dims=self.active_dims,
216+
)
217+
if not self.fixed_kernel_amplitude:
218+
if self.outputscale_prior == "gamma":
219+
os_prior = gpytorch.priors.GammaPrior(concentration=2.0, rate=0.15)
220+
elif self.outputscale_prior == "box":
221+
os_prior = gpytorch.priors.SmoothedBoxPrior(a=1, b=4)
222+
else:
223+
raise RuntimeError(
224+
f"Outputscale_prior should be gamma or box, got {self.outputscale_prior}"
225+
)
226+
227+
covar = gpytorch.kernels.ScaleKernel(
228+
covar,
229+
outputscale_prior=os_prior,
230+
outputscale_constraint=gpytorch.constraints.GreaterThan(1e-4),
231+
)
232+
233+
return covar
234+
235+
@classmethod
236+
def get_config_options(
237+
cls,
238+
config: Config,
239+
name: str | None = None,
240+
options: dict[str, Any] | None = None,
241+
) -> dict[str, Any]:
242+
"""Get configuration options for the MeanCovarFactory.
243+
244+
Args:
245+
config (Config): Config object to find options in.
246+
name (str, optional): Name of the factory. Defaults to the class name.
247+
options (dict, optional): Options to start with. Defaults to None.
248+
249+
Returns:
250+
dict[str, Any]: Options to use to initialize the factory.
251+
"""
252+
name = name or cls.__name__
253+
options = super().get_config_options(config, name, options)
254+
255+
if "dim" not in options:
256+
options["dim"] = get_dims(config)
257+
258+
return options
259+
260+
26261
def default_mean_covar_factory(
27262
config: Config | None = None,
28263
dim: int | None = None,
@@ -41,6 +276,11 @@ def default_mean_covar_factory(
41276
tuple[gpytorch.means.Mean, gpytorch.kernels.Kernel]: Instantiated
42277
ConstantMean and ScaleKernel with priors based on bounds.
43278
"""
279+
warnings.warn(
280+
"default_mean_covar_factory is deprecated, use the DefaultMeanCovarFactory class instead!",
281+
DeprecationWarning,
282+
stacklevel=2,
283+
)
44284

45285
assert (config is not None) or (
46286
dim is not None

aepsych/factory/ordinal.py

Lines changed: 0 additions & 43 deletions
This file was deleted.

0 commit comments

Comments
 (0)