Skip to content

Commit c61a3db

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
Mixed factory for mixed models (#795)
Summary: Pull Request resolved: #795 Added factories to create mixed models supporting continuous and discrete (Categorical) parameters together. Categorical parameters can now be enabled in configs. A mixed variable acquisition function generator has been added to support active learning. Differential Revision: D74196201
1 parent 84fa231 commit c61a3db

8 files changed

Lines changed: 680 additions & 14 deletions

File tree

aepsych/config.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,9 +240,6 @@ def update(
240240
self._check_param_settings(par_name)
241241

242242
if self[par_name]["par_type"] == "categorical":
243-
raise NotImplementedError(
244-
"Categorical parameters not supported yet"
245-
)
246243
choices = self.getlist(par_name, "choices", element_type=str)
247244
lb[i] = "0"
248245
ub[i] = str(len(choices) - 1)

aepsych/factory/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from ..config import Config
1111
from .default import default_mean_covar_factory, DefaultMeanCovarFactory
12+
from .mixed import MixedMeanCovarFactory
1213
from .pairwise import pairwise_mean_covar_factory, PairwiseMeanCovarFactory
1314
from .song import song_mean_covar_factory, SongMeanCovarFactory
1415

@@ -24,6 +25,7 @@
2425
__all__ = [
2526
"DefaultMeanCovarFactory",
2627
"default_mean_covar_factory",
28+
"MixedMeanCovarFactory",
2729
"pairwise_mean_covar_factory",
2830
"PairwiseMeanCovarFactory",
2931
"SongMeanCovarFactory",

aepsych/factory/mixed.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
from copy import deepcopy
9+
from typing import Any, Literal
10+
11+
import botorch
12+
import gpytorch
13+
import torch
14+
from aepsych.config import Config
15+
from aepsych.factory.default import (
16+
DefaultMeanCovarFactory,
17+
)
18+
from aepsych.factory.utils import temporary_attributes
19+
20+
21+
class MixedMeanCovarFactory(DefaultMeanCovarFactory):
22+
def __init__(
23+
self,
24+
dim: int,
25+
discrete_params: dict[int, int],
26+
stimuli_per_trial: int = 1,
27+
discrete_param_ranks: dict[int, int] | None = None,
28+
discrete_kernel: Literal["index", "categorical"] = "categorical",
29+
zero_mean: bool = False,
30+
target: float | None = None,
31+
cov_kernel: gpytorch.kernels.Kernel = gpytorch.kernels.RBFKernel,
32+
active_dims: list[int] | None = None,
33+
lengthscale_prior: Literal["invgamma", "gamma", "lognormal"] | None = None,
34+
ls_loc: torch.Tensor | float | None = None,
35+
ls_scale: torch.Tensor | float | None = None,
36+
fixed_kernel_amplitude: bool | None = None,
37+
outputscale_prior: Literal["box", "gamma"] = "box",
38+
) -> None:
39+
"""Factory that makes mean and covariance functions for generic GPs.
40+
After initialization, copies of the mean and covariance functions can be made with
41+
`get_mean` and `get_covar`.
42+
43+
Args:
44+
dim (int, optional): Dimensionality of the parameter space. Must be provided.
45+
stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1.
46+
zero_mean (bool, optional): Whether to use zero for the mean module. Defaults to False.
47+
target (float, optional): Target for the mean module. Defaults to None.
48+
cov_kernel (gpytorch.kernels.Kernel, optional): Covariance kernel to use. Defaults to RBF
49+
kernel.
50+
active_dims (list[int], optional): List of dimensions to use in the covariance function. Defaults to None,
51+
which uses all dimensions.
52+
lengthscale_prior (Literal["invgamma", "gamma", "lognormal"], optional): Prior to use for
53+
lengthscale. Defaults to "lognormal" if stimuli_per_trial == 1, else "gamma".
54+
ls_loc (torch.Tensor | float, optional): Location parameter for lengthscale prior.
55+
Defaults to sqrt(2.0).
56+
ls_scale (torch.Tensor | float, optional): Scale parameter for lengthscale prior.
57+
Defaults to sqrt(3.0).
58+
fixed_kernel_amplitude (bool, optional): Whether to allow the covariance kernel to scale.
59+
Defaults to True if stimuli_per_trial == 1, else False.
60+
outputscale_prior (Literal["box", "gamma"], optional): Prior to use to scale the covariance kernel.
61+
Defaults to "box".
62+
"""
63+
discrete_param_ranks = discrete_param_ranks or discrete_params.copy()
64+
65+
# Check if the keys in both dictionaries match
66+
if set(discrete_params.keys()) != set(discrete_param_ranks.keys()):
67+
raise ValueError("discrete parameter indices and ranks should match")
68+
69+
if discrete_kernel not in ("index", "categorical"):
70+
raise ValueError(
71+
"only index or categorical kernels supported for discrete kernel"
72+
)
73+
74+
self.discrete_params = discrete_params
75+
self.discrete_param_ranks = discrete_param_ranks or discrete_params.copy()
76+
self.discrete_kernel = discrete_kernel
77+
self.zero_mean = zero_mean
78+
self.target = target
79+
self.cov_kernel = cov_kernel
80+
self.active_dims = active_dims
81+
self.lengthscale_prior = lengthscale_prior
82+
self.ls_loc = ls_loc
83+
self.ls_scale = ls_scale
84+
self.fixed_kernel_amplitude = fixed_kernel_amplitude
85+
self.outputscale_prior = outputscale_prior
86+
87+
super().__init__(dim, stimuli_per_trial)
88+
89+
def _make_covar_module(self) -> gpytorch.kernels.Kernel:
90+
# Make covariance module
91+
cont_dims = self.active_dims or list(range(self.dim))
92+
cont_dims = [idx for idx in cont_dims if idx not in self.discrete_params.keys()]
93+
with temporary_attributes(
94+
self, dim=len(cont_dims), fixed_kernel_amplitude=True, active_dims=cont_dims
95+
):
96+
cont_kernel = super()._make_covar_module()
97+
98+
if self.discrete_kernel == "index":
99+
discrete_kernels = []
100+
for idx in self.discrete_params.keys():
101+
discrete_kernels.append(
102+
gpytorch.kernels.IndexKernel(
103+
num_tasks=self.discrete_params[idx],
104+
rank=self.discrete_param_ranks[idx],
105+
active_dims=(idx,),
106+
ard_num_dims=1,
107+
prior=gpytorch.priors.LKJCovariancePrior(
108+
n=self.discrete_param_ranks[idx],
109+
eta=1.5,
110+
sd_prior=gpytorch.priors.GammaPrior(1.0, 0.15),
111+
),
112+
)
113+
)
114+
add_kernel = gpytorch.kernels.AdditiveKernel(
115+
deepcopy(cont_kernel), *deepcopy(discrete_kernels)
116+
)
117+
prod_kernel = gpytorch.kernels.ProductKernel(
118+
deepcopy(cont_kernel), *deepcopy(discrete_kernels)
119+
)
120+
return add_kernel * prod_kernel
121+
elif self.discrete_kernel == "categorical":
122+
constraint = gpytorch.constraints.GreaterThan(lower_bound=1e-4)
123+
discrete_kernel = botorch.models.kernels.CategoricalKernel(
124+
active_dims=tuple(self.discrete_params.keys()),
125+
ard_num_dims=len(self.discrete_params),
126+
lengthscale_constraint=constraint,
127+
)
128+
129+
if not self.fixed_kernel_amplitude:
130+
discrete_kernel = gpytorch.kernels.ScaleKernel(discrete_kernel)
131+
cont_kernel = gpytorch.kernels.ScaleKernel(cont_kernel)
132+
133+
add_kernel = deepcopy(cont_kernel) + deepcopy(discrete_kernel)
134+
prod_kernel = deepcopy(cont_kernel) * deepcopy(discrete_kernel)
135+
136+
return add_kernel * prod_kernel
137+
else:
138+
raise ValueError("discrete kernel must be index or categorical")
139+
140+
@classmethod
141+
def get_config_options(
142+
cls,
143+
config: Config,
144+
name: str | None = None,
145+
options: dict[str, Any] | None = None,
146+
) -> dict[str, Any]:
147+
"""Get configuration options for the MeanCovarFactory.
148+
149+
Args:
150+
config (Config): Config object to find options in.
151+
name (str, optional): Name of the factory. Defaults to the class name.
152+
options (dict, optional): Options to start with. Defaults to None.
153+
154+
Returns:
155+
dict[str, Any]: Options to use to initialize the factory.
156+
"""
157+
name = name or cls.__name__
158+
options = super().get_config_options(config, name, options)
159+
160+
# Figure out discrete parameters
161+
par_names = config.getlist("common", "parnames", element_type=str)
162+
discrete_params = {}
163+
discrete_ranks = {}
164+
for i, par_name in enumerate(par_names):
165+
if config.get(par_name, "par_type") == "categorical":
166+
discrete_params[i] = len(
167+
config.getlist(par_name, "choices", element_type=str)
168+
)
169+
discrete_ranks[i] = config.getint(
170+
par_name, "rank", fallback=discrete_params[i]
171+
)
172+
173+
if len(discrete_params) == 0:
174+
raise ValueError("No categorical parameters found")
175+
176+
options["discrete_params"] = discrete_params
177+
options["discrete_param_ranks"] = discrete_ranks
178+
179+
return options

aepsych/generators/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
from .epsilon_greedy_generator import EpsilonGreedyGenerator
1414
from .independent_acqf_generator import IndependentOptimizeAcqfGenerator
1515
from .manual_generator import ManualGenerator, SampleAroundPointsGenerator
16-
from .optimize_acqf_generator import OptimizeAcqfGenerator
16+
from .optimize_acqf_generator import MixedOptimizeAcqfGenerator, OptimizeAcqfGenerator
1717
from .random_generator import RandomGenerator
1818
from .semi_p import IntensityAwareSemiPGenerator
1919
from .sobol_generator import SobolGenerator
2020

2121
__all__ = [
2222
"OptimizeAcqfGenerator",
23+
"MixedOptimizeAcqfGenerator",
2324
"RandomGenerator",
2425
"SobolGenerator",
2526
"EpsilonGreedyGenerator",

aepsych/generators/optimize_acqf_generator.py

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
import time
99
import warnings
10+
from itertools import product
1011
from typing import Any
1112

1213
import torch
1314
from aepsych.acquisition.lookahead import LookaheadAcquisitionFunction
15+
from aepsych.config import Config
1416
from aepsych.generators.base import AcqfGenerator
1517
from aepsych.models.base import AEPsychModelMixin
1618
from aepsych.utils_logging import getLogger
1719
from botorch.acquisition import AcquisitionFunction
18-
from botorch.optim import optimize_acqf
20+
from botorch.optim import optimize_acqf, optimize_acqf_mixed
1921

2022
logger = getLogger()
2123

@@ -157,3 +159,125 @@ def _gen(
157159

158160
logger.info(f"Gen done, time={time.time() - starttime}")
159161
return new_candidate
162+
163+
164+
class MixedOptimizeAcqfGenerator(OptimizeAcqfGenerator):
165+
"""A variant of OptimizeAcqfGenerator that supports mixed parameter types
166+
(namely continuous and categorical parameters)."""
167+
168+
def __init__(
169+
self,
170+
lb: torch.Tensor,
171+
ub: torch.Tensor,
172+
categorical_parameters: dict[int, int],
173+
acqf: AcquisitionFunction,
174+
acqf_kwargs: dict[str, Any] | None = None,
175+
restarts: int = 10,
176+
samps: int = 1024,
177+
max_gen_time: float | None = None,
178+
stimuli_per_trial: int = 1,
179+
) -> None:
180+
"""Initialize OptimizeAcqfGenerator.
181+
Args:
182+
lb (torch.Tensor): Lower bounds for the optimization.
183+
ub (torch.Tensor): Upper bounds for the optimization.
184+
categorical_parameters (dict[int, int]): A dictionary mapping the indices of the categorical
185+
parameters to the number of categories.
186+
acqf (AcquisitionFunction): Acquisition function to use.
187+
acqf_kwargs (dict[str, object], optional): Extra arguments to
188+
pass to acquisition function. Defaults to no arguments.
189+
restarts (int): Number of restarts for acquisition function optimization. Defaults to 10.
190+
samps (int): Number of samples for quasi-random initialization of the acquisition function optimizer. Defaults to 1000.
191+
max_gen_time (float, optional): Maximum time (in seconds) to optimize the acquisition function. Defaults to None.
192+
stimuli_per_trial (int): Number of stimuli per trial. Defaults to 1.
193+
"""
194+
super().__init__(
195+
lb=lb,
196+
ub=ub,
197+
acqf=acqf,
198+
acqf_kwargs=acqf_kwargs,
199+
restarts=restarts,
200+
samps=samps,
201+
max_gen_time=max_gen_time,
202+
stimuli_per_trial=stimuli_per_trial,
203+
)
204+
205+
# Make every possible combination of categorical values in a list
206+
cat_indices = list(categorical_parameters.keys())
207+
cat_values = [range(n) for n in categorical_parameters.values()]
208+
categorical_combos = []
209+
for combo in product(*cat_values):
210+
# Unpack combo into a dictionary
211+
categorical_combos.append(dict(zip(cat_indices, [float(x) for x in combo])))
212+
213+
self.categorical_combos = categorical_combos
214+
215+
def _gen(
216+
self,
217+
num_points: int,
218+
model: AEPsychModelMixin,
219+
acqf: AcquisitionFunction,
220+
fixed_features: dict[int, float] | None = None,
221+
**gen_options: dict[str, Any],
222+
) -> torch.Tensor:
223+
"""
224+
Generates the next query points by optimizing the acquisition function.
225+
226+
Args:
227+
num_points (int): Number of points to query.
228+
model (AEPsychModelMixin): Fitted model of the data.
229+
acqf (AcquisitionFunction): Acquisition function.
230+
fixed_features (dict[int, float], optional): The values where the specified
231+
parameters should be at when generating. Should be a dictionary where
232+
the keys are the indices of the parameters to fix and the values are the
233+
values to fix them at.
234+
gen_options (dict[str, Any]): Additional options for generating points, such as custom configurations.
235+
236+
Returns:
237+
torch.Tensor: Next set of points to evaluate, with shape [num_points x dim].
238+
"""
239+
if fixed_features is not None and len(fixed_features) > 0:
240+
raise NotImplementedError(
241+
"Fixed features are not supported for mixed parameter types."
242+
)
243+
logger.info("Starting gen...")
244+
starttime = time.time()
245+
246+
new_candidate, _ = optimize_acqf_mixed(
247+
acq_function=acqf,
248+
bounds=torch.stack([self.lb, self.ub]),
249+
q=num_points,
250+
fixed_features_list=self.categorical_combos,
251+
num_restarts=self.restarts,
252+
raw_samples=self.samps,
253+
timeout_sec=self.max_gen_time,
254+
**gen_options,
255+
)
256+
257+
logger.info(f"Gen done, time={time.time() - starttime}")
258+
return new_candidate
259+
260+
@classmethod
261+
def get_config_options(
262+
cls,
263+
config: Config,
264+
name: str | None = None,
265+
options: dict[str, Any] | None = None,
266+
) -> dict[str, Any]:
267+
options = super().get_config_options(config, name, options)
268+
269+
# Figure out discrete parameters
270+
par_names = config.getlist("common", "parnames", element_type=str)
271+
discrete_params = {}
272+
for i, par_name in enumerate(par_names):
273+
if config.get(par_name, "par_type") == "categorical":
274+
discrete_params[i] = len(
275+
config.getlist(par_name, "choices", element_type=str)
276+
)
277+
278+
if len(discrete_params) == 0:
279+
raise ValueError("No categorical parameters found")
280+
281+
options["categorical_parameters"] = discrete_params
282+
283+
return options

0 commit comments

Comments
 (0)