Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion aepsych/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import typing
import warnings
from types import ModuleType, NoneType, UnionType
from typing import Any, Callable, ClassVar, Mapping, Sequence, TypeVar
from typing import Any, Callable, ClassVar, Literal, Mapping, Sequence, TypeVar

import botorch
import gpytorch
Expand Down Expand Up @@ -603,6 +603,31 @@ def _sort_types(annotations):
else:
value = object_cls

# Literal (supporting strings, ints, floats)
elif typing.get_origin(annotation) is Literal:
literal_args = typing.get_args(annotation)
for arg in literal_args:
if isinstance(arg, str):
attempt = config.get(name, key, fallback=None)
elif isinstance(arg, int):
attempt = config.getint(name, key, fallback=None)
elif isinstance(arg, float):
attempt = config.getfloat(name, key, fallback=None)
else:
raise NotImplementedError(
f"Literal types in {annotation} not supported yet!"
)

if attempt is None:
continue

if attempt in literal_args:
value = attempt
else:
raise RuntimeError(
f"Value {attempt} is not in the Literal type {annotation} for the option {key}!"
)

# Callable
elif annotation is Callable:
value = config.getobj(name, key)
Expand Down
13 changes: 7 additions & 6 deletions aepsych/factory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@
import sys

from ..config import Config
from .default import default_mean_covar_factory
from .ordinal import ordinal_mean_covar_factory
from .pairwise import pairwise_mean_covar_factory
from .song import song_mean_covar_factory
from .default import default_mean_covar_factory, DefaultMeanCovarFactory
from .pairwise import pairwise_mean_covar_factory, PairwiseMeanCovarFactory
from .song import song_mean_covar_factory, SongMeanCovarFactory

"""AEPsych factory functions.
These functions generate a gpytorch Mean and Kernel objects from
Expand All @@ -23,10 +22,12 @@
"""

__all__ = [
"DefaultMeanCovarFactory",
"default_mean_covar_factory",
"ordinal_mean_covar_factory",
"song_mean_covar_factory",
"pairwise_mean_covar_factory",
"PairwiseMeanCovarFactory",
"SongMeanCovarFactory",
"song_mean_covar_factory",
]

Config.register_module(sys.modules[__name__])
244 changes: 242 additions & 2 deletions aepsych/factory/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,25 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import abc
import math
import warnings
from configparser import NoOptionError
from copy import deepcopy
from typing import Any, Literal

import gpytorch
import torch
from aepsych.config import Config
from aepsych.config import Config, ConfigurableMixin
from aepsych.utils import get_dims
from scipy.stats import norm

from .utils import __default_invgamma_concentration, __default_invgamma_rate
from .utils import (
__default_invgamma_concentration,
__default_invgamma_rate,
DEFAULT_INVGAMMA_CONC,
DEFAULT_INVGAMMA_RATE,
)

# The gamma lengthscale prior is taken from
# https://betanalpha.github.io/assets/case_studies/gaussian_processes.html#323_Informative_Prior_Model
Expand All @@ -23,6 +32,232 @@
# https://arxiv.org/html/2402.02229v3


class MeanCovarFactory(ConfigurableMixin, abc.ABC):
def __init__(self, dim: int, stimuli_per_trial: int = 1, *args, **kwargs) -> None:
"""Abstract base class for mean and covariance function factories.

Args:
dim (int): Dimensionality of the parameter space.
stimuli_per_trial (int, optional): Number of stimuli per trial. Defaults to 1.
"""
self.dim = dim
self.stimuli_per_trial = stimuli_per_trial

self.mean_module = self._make_mean_module()
self.covar_module = self._make_covar_module()

def get_mean(self) -> gpytorch.means.Mean:
return deepcopy(self.mean_module)

def get_covar(self) -> gpytorch.kernels.Kernel:
return deepcopy(self.covar_module)

@abc.abstractmethod
def _make_mean_module(self) -> gpytorch.means.Mean:
pass

@abc.abstractmethod
def _make_covar_module(self) -> gpytorch.kernels.Kernel:
pass

@classmethod
def get_config_options(
cls,
config: Config,
name: str | None = None,
options: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Get configuration options for the MeanCovarFactory.

Args:
config (Config): Config object to find options in.
name (str, optional): Name of the factory. Defaults to the class name.
options (dict, optional): Options to start with. Defaults to None.

Returns:
dict[str, Any]: Options to use to initialize the factory.
"""
name = name or cls.__name__
options = super().get_config_options(config, name, options)

if "dim" not in options:
options["dim"] = get_dims(config)

return options


class DefaultMeanCovarFactory(MeanCovarFactory):
def __init__(
self,
dim: int,
stimuli_per_trial: int = 1,
zero_mean: bool = False,
target: float | None = None,
cov_kernel: gpytorch.kernels.Kernel = gpytorch.kernels.RBFKernel,
active_dims: list[int] | None = None,
lengthscale_prior: Literal["invgamma", "gamma", "lognormal"] | None = None,
ls_loc: torch.Tensor | float | None = None,
ls_scale: torch.Tensor | float | None = None,
fixed_kernel_amplitude: bool | None = None,
outputscale_prior: Literal["box", "gamma"] = "box",
) -> None:
"""Factory that makes mean and covariance functions for generic GPs.
After initialization, copies of the mean and covariance functions can be made with
`get_mean` and `get_covar`.

Args:
dim (int, optional): Dimensionality of the parameter space. Must be provided.
stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1.
zero_mean (bool, optional): Whether to use zero for the mean module. Defaults to False.
target (float, optional): Target for the mean module. Defaults to None.
cov_kernel (gpytorch.kernels.Kernel, optional): Covariance kernel to use. Defaults to RBF
kernel.
active_dims (list[int], optional): List of dimensions to use in the covariance function. Defaults to None,
which uses all dimensions.
lengthscale_prior (Literal["invgamma", "gamma", "lognormal"], optional): Prior to use for
lengthscale. Defaults to "lognormal" if stimuli_per_trial == 1, else "gamma".
ls_loc (torch.Tensor | float, optional): Location parameter for lengthscale prior.
Defaults to sqrt(2.0).
ls_scale (torch.Tensor | float, optional): Scale parameter for lengthscale prior.
Defaults to sqrt(3.0).
fixed_kernel_amplitude (bool, optional): Whether to allow the covariance kernel to scale.
Defaults to True if stimuli_per_trial == 1, else False.
outputscale_prior (Literal["box", "gamma"], optional): Prior to use to scale the covariance kernel.
Defaults to "box".
"""
self.zero_mean = zero_mean
self.target = target
self.cov_kernel = cov_kernel
self.active_dims = active_dims
self.lengthscale_prior = lengthscale_prior
self.ls_loc = ls_loc
self.ls_scale = ls_scale
self.fixed_kernel_amplitude = fixed_kernel_amplitude
self.outputscale_prior = outputscale_prior

super().__init__(dim, stimuli_per_trial)

def get_mean(self) -> gpytorch.means.Mean:
return deepcopy(self.mean_module)

def get_covar(self) -> gpytorch.kernels.Kernel:
return deepcopy(self.covar_module)

def _make_mean_module(self) -> gpytorch.means.Mean:
# Make mean module
if self.zero_mean:
mean = gpytorch.means.ZeroMean()
else:
mean = gpytorch.means.ConstantMean()

if self.target is not None:
if self.zero_mean:
warnings.warn(
"Specified both `zero_mean = True` and `target`. Zero mean will be overwritten by target fixed mean!",
UserWarning,
stacklevel=2,
)

mean.constant.requires_grad_(False)
mean.constant.copy_(torch.tensor(norm.ppf(self.target)))

return mean

def _make_covar_module(self) -> gpytorch.kernels.Kernel:
# Make covariance module
if self.ls_loc is None:
self.ls_loc = torch.tensor(math.sqrt(2.0), dtype=torch.float64)
elif not isinstance(self.ls_loc, torch.Tensor):
self.ls_loc = torch.tensor(self.ls_loc, dtype=torch.float64)

if self.ls_scale is None:
self.ls_scale = torch.tensor(math.sqrt(3.0), dtype=torch.float64)
elif not isinstance(self.ls_scale, torch.Tensor):
self.ls_scale = torch.tensor(self.ls_scale, dtype=torch.float64)

if self.fixed_kernel_amplitude is None:
self.fixed_kernel_amplitude = True if self.stimuli_per_trial == 1 else False

if self.lengthscale_prior == "invgamma":
ls_prior = gpytorch.priors.GammaPrior(
concentration=DEFAULT_INVGAMMA_CONC,
rate=DEFAULT_INVGAMMA_RATE,
transform=lambda x: 1 / x,
)
ls_prior_mode = ls_prior.rate / (ls_prior.concentration + 1)

elif self.lengthscale_prior == "gamma" or (
self.lengthscale_prior is None and self.stimuli_per_trial != 1
):
ls_prior = gpytorch.priors.GammaPrior(concentration=3.0, rate=6.0)
ls_prior_mode = (ls_prior.concentration - 1) / ls_prior.rate

elif self.lengthscale_prior == "lognormal" or (
self.lengthscale_prior is None and self.stimuli_per_trial == 1
):
ls_prior = gpytorch.priors.LogNormalPrior(
self.ls_loc + math.log(self.dim) / 2, self.ls_scale
)
ls_prior_mode = torch.exp(self.ls_loc - self.ls_scale**2)
else:
raise RuntimeError(
f"Lengthscale_prior should be invgamma, gamma, or lognormal, got {self.lengthscale_prior}"
)

ls_constraint = gpytorch.constraints.GreaterThan(
lower_bound=1e-4, transform=None, initial_value=ls_prior_mode
)

covar = self.cov_kernel(
lengthscale_prior=ls_prior,
lengthscale_constraint=ls_constraint,
ard_num_dims=self.dim,
active_dims=self.active_dims,
)
if not self.fixed_kernel_amplitude:
if self.outputscale_prior == "gamma":
os_prior = gpytorch.priors.GammaPrior(concentration=2.0, rate=0.15)
elif self.outputscale_prior == "box":
os_prior = gpytorch.priors.SmoothedBoxPrior(a=1, b=4)
else:
raise RuntimeError(
f"Outputscale_prior should be gamma or box, got {self.outputscale_prior}"
)

covar = gpytorch.kernels.ScaleKernel(
covar,
outputscale_prior=os_prior,
outputscale_constraint=gpytorch.constraints.GreaterThan(1e-4),
)

return covar

@classmethod
def get_config_options(
cls,
config: Config,
name: str | None = None,
options: dict[str, Any] | None = None,
) -> dict[str, Any]:
"""Get configuration options for the MeanCovarFactory.

Args:
config (Config): Config object to find options in.
name (str, optional): Name of the factory. Defaults to the class name.
options (dict, optional): Options to start with. Defaults to None.

Returns:
dict[str, Any]: Options to use to initialize the factory.
"""
name = name or cls.__name__
options = super().get_config_options(config, name, options)

if "dim" not in options:
options["dim"] = get_dims(config)

return options


def default_mean_covar_factory(
config: Config | None = None,
dim: int | None = None,
Expand All @@ -41,6 +276,11 @@ def default_mean_covar_factory(
tuple[gpytorch.means.Mean, gpytorch.kernels.Kernel]: Instantiated
ConstantMean and ScaleKernel with priors based on bounds.
"""
warnings.warn(
"default_mean_covar_factory is deprecated, use the DefaultMeanCovarFactory class instead!",
DeprecationWarning,
stacklevel=2,
)

assert (config is not None) or (
dim is not None
Expand Down
43 changes: 0 additions & 43 deletions aepsych/factory/ordinal.py

This file was deleted.

Loading