|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 3 | +# All rights reserved. |
| 4 | + |
| 5 | +# This source code is licensed under the license found in the |
| 6 | +# LICENSE file in the root directory of this source tree. |
| 7 | + |
| 8 | +from copy import deepcopy |
| 9 | +from typing import Any, Literal |
| 10 | + |
| 11 | +import botorch |
| 12 | +import gpytorch |
| 13 | +import torch |
| 14 | +from aepsych.config import Config |
| 15 | +from aepsych.factory.default import ( |
| 16 | + DefaultMeanCovarFactory, |
| 17 | +) |
| 18 | +from aepsych.factory.utils import temporary_attributes |
| 19 | + |
| 20 | + |
class MixedMeanCovarFactory(DefaultMeanCovarFactory):
    """Mean/covariance factory for spaces with continuous and discrete dimensions.

    The covariance module is assembled from a kernel over the continuous
    dimensions (built by the parent factory) and kernel(s) over the discrete
    dimensions, combined as ``(sum of kernels) * (product of kernels)``.
    """

    def __init__(
        self,
        dim: int,
        discrete_params: dict[int, int],
        stimuli_per_trial: int = 1,
        discrete_param_ranks: dict[int, int] | None = None,
        discrete_kernel: Literal["index", "categorical"] = "categorical",
        zero_mean: bool = False,
        target: float | None = None,
        cov_kernel: gpytorch.kernels.Kernel = gpytorch.kernels.RBFKernel,
        active_dims: list[int] | None = None,
        lengthscale_prior: Literal["invgamma", "gamma", "lognormal"] | None = None,
        ls_loc: torch.Tensor | float | None = None,
        ls_scale: torch.Tensor | float | None = None,
        fixed_kernel_amplitude: bool | None = None,
        outputscale_prior: Literal["box", "gamma"] = "box",
    ) -> None:
        """Factory that makes mean and covariance functions for generic GPs.
        After initialization, copies of the mean and covariance functions can be made with
        `get_mean` and `get_covar`.

        Args:
            dim (int): Dimensionality of the parameter space. Must be provided.
            discrete_params (dict[int, int]): Maps the index of each discrete dimension to its
                number of levels.
            stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1.
            discrete_param_ranks (dict[int, int], optional): Maps the index of each discrete
                dimension to the rank used by the index kernel for that dimension. Its keys must
                match those of `discrete_params`. Defaults to a copy of `discrete_params`.
            discrete_kernel (Literal["index", "categorical"]): Which kernel to use over the
                discrete dimensions. Defaults to "categorical".
            zero_mean (bool, optional): Whether to use zero for the mean module. Defaults to False.
            target (float, optional): Target for the mean module. Defaults to None.
            cov_kernel (gpytorch.kernels.Kernel, optional): Covariance kernel to use. Defaults to RBF
                kernel.
            active_dims (list[int], optional): List of dimensions to use in the covariance function. Defaults to None,
                which uses all dimensions.
            lengthscale_prior (Literal["invgamma", "gamma", "lognormal"], optional): Prior to use for
                lengthscale. Defaults to "lognormal" if stimuli_per_trial == 1, else "gamma".
            ls_loc (torch.Tensor | float, optional): Location parameter for lengthscale prior.
                Defaults to sqrt(2.0).
            ls_scale (torch.Tensor | float, optional): Scale parameter for lengthscale prior.
                Defaults to sqrt(3.0).
            fixed_kernel_amplitude (bool, optional): Whether to allow the covariance kernel to scale.
                Defaults to True if stimuli_per_trial == 1, else False.
            outputscale_prior (Literal["box", "gamma"], optional): Prior to use to scale the covariance kernel.
                Defaults to "box".

        Raises:
            ValueError: If the keys of `discrete_params` and `discrete_param_ranks` differ, or if
                `discrete_kernel` is neither "index" nor "categorical".
        """
        # Resolve the default exactly once; a falsy (None or empty) ranks
        # mapping falls back to one rank entry per discrete parameter.
        discrete_param_ranks = discrete_param_ranks or discrete_params.copy()

        # Every discrete parameter needs a rank, and vice versa.
        if set(discrete_params) != set(discrete_param_ranks):
            raise ValueError("discrete parameter indices and ranks should match")

        if discrete_kernel not in ("index", "categorical"):
            raise ValueError(
                "only index or categorical kernels supported for discrete kernel"
            )

        self.discrete_params = discrete_params
        self.discrete_param_ranks = discrete_param_ranks
        self.discrete_kernel = discrete_kernel
        self.zero_mean = zero_mean
        self.target = target
        self.cov_kernel = cov_kernel
        self.active_dims = active_dims
        self.lengthscale_prior = lengthscale_prior
        self.ls_loc = ls_loc
        self.ls_scale = ls_scale
        self.fixed_kernel_amplitude = fixed_kernel_amplitude
        self.outputscale_prior = outputscale_prior

        # NOTE(review): only dim and stimuli_per_trial are forwarded; confirm
        # the parent initializer does not reset the attributes assigned above.
        super().__init__(dim, stimuli_per_trial)

    def _make_covar_module(self) -> gpytorch.kernels.Kernel:
        """Build the covariance module over continuous and discrete dimensions.

        Returns:
            gpytorch.kernels.Kernel: The product of the additive and the
                multiplicative combination of the continuous kernel with the
                discrete kernel(s).

        Raises:
            ValueError: If `self.discrete_kernel` is neither "index" nor
                "categorical".
        """
        # The continuous kernel only acts on the non-discrete dimensions.
        cont_dims = self.active_dims or list(range(self.dim))
        cont_dims = [idx for idx in cont_dims if idx not in self.discrete_params]

        # Temporarily override the attributes the parent builder is given so it
        # produces a kernel restricted to the continuous dimensions with
        # fixed_kernel_amplitude=True; any scaling is applied below instead.
        with temporary_attributes(
            self, dim=len(cont_dims), fixed_kernel_amplitude=True, active_dims=cont_dims
        ):
            cont_kernel = super()._make_covar_module()

        if self.discrete_kernel == "index":
            discrete_kernels = []
            for idx, num_levels in self.discrete_params.items():
                discrete_kernels.append(
                    gpytorch.kernels.IndexKernel(
                        num_tasks=num_levels,
                        rank=self.discrete_param_ranks[idx],
                        active_dims=(idx,),
                        ard_num_dims=1,
                        # The prior is evaluated on the kernel's
                        # num_tasks x num_tasks covariance matrix, so its size
                        # must be the number of levels, not the rank.
                        prior=gpytorch.priors.LKJCovariancePrior(
                            n=num_levels,
                            eta=1.5,
                            sd_prior=gpytorch.priors.GammaPrior(1.0, 0.15),
                        ),
                    )
                )

            # Deep copies keep the additive and product terms from sharing
            # kernel instances (and therefore parameters).
            add_kernel = gpytorch.kernels.AdditiveKernel(
                deepcopy(cont_kernel), *deepcopy(discrete_kernels)
            )
            prod_kernel = gpytorch.kernels.ProductKernel(
                deepcopy(cont_kernel), *deepcopy(discrete_kernels)
            )
            return add_kernel * prod_kernel
        elif self.discrete_kernel == "categorical":
            constraint = gpytorch.constraints.GreaterThan(lower_bound=1e-4)
            # A single categorical kernel spans all discrete dimensions.
            discrete_kernel = botorch.models.kernels.CategoricalKernel(
                active_dims=tuple(self.discrete_params),
                ard_num_dims=len(self.discrete_params),
                lengthscale_constraint=constraint,
            )

            if not self.fixed_kernel_amplitude:
                discrete_kernel = gpytorch.kernels.ScaleKernel(discrete_kernel)
                cont_kernel = gpytorch.kernels.ScaleKernel(cont_kernel)

            # Deep copies keep the additive and product terms from sharing
            # kernel instances (and therefore parameters).
            add_kernel = deepcopy(cont_kernel) + deepcopy(discrete_kernel)
            prod_kernel = deepcopy(cont_kernel) * deepcopy(discrete_kernel)

            return add_kernel * prod_kernel
        else:
            raise ValueError("discrete kernel must be index or categorical")

    @classmethod
    def get_config_options(
        cls,
        config: Config,
        name: str | None = None,
        options: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """Get configuration options for the MeanCovarFactory.

        On top of the parent's options, this collects every parameter listed in
        the "common" section's `parnames` whose `par_type` is "categorical" and
        records its positional index, its number of choices, and its rank.

        Args:
            config (Config): Config object to find options in.
            name (str, optional): Name of the factory. Defaults to the class name.
            options (dict, optional): Options to start with. Defaults to None.

        Returns:
            dict[str, Any]: Options to use to initialize the factory.

        Raises:
            ValueError: If the config declares no categorical parameters.
        """
        name = name or cls.__name__
        options = super().get_config_options(config, name, options)

        # Figure out discrete parameters
        par_names = config.getlist("common", "parnames", element_type=str)
        discrete_params: dict[int, int] = {}
        discrete_ranks: dict[int, int] = {}
        for i, par_name in enumerate(par_names):
            if config.get(par_name, "par_type") != "categorical":
                continue

            num_choices = len(config.getlist(par_name, "choices", element_type=str))
            discrete_params[i] = num_choices
            # The rank falls back to the number of choices when not configured.
            discrete_ranks[i] = config.getint(par_name, "rank", fallback=num_choices)

        if not discrete_params:
            raise ValueError("No categorical parameters found")

        options["discrete_params"] = discrete_params
        options["discrete_param_ranks"] = discrete_ranks

        return options
0 commit comments