Skip to content

Commit eec2862

Browse files
JasonKChowfacebook-github-bot
authored andcommitted
Add DataAllocator (facebookresearch#789)
Summary: The DataAllocator inducing point allocator just returns the input as inducing points. Differential Revision: D73885655
1 parent fcfd175 commit eec2862

3 files changed

Lines changed: 134 additions & 1 deletion

File tree

aepsych/models/inducing_points/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,14 @@
88
import sys
99

1010
from ...config import Config
11+
from .data import DataAllocator
1112
from .fixed import FixedAllocator, FixedPlusAllocator
1213
from .greedy_variance_reduction import GreedyVarianceReduction
1314
from .kmeans import KMeansAllocator
1415
from .sobol import SobolAllocator
1516

1617
__all__ = [
18+
"DataAllocator",
1719
"FixedAllocator",
1820
"FixedPlusAllocator",
1921
"GreedyVarianceReduction",
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
5+
# This source code is licensed under the license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
import warnings
9+
10+
import torch
11+
from aepsych.models.inducing_points.base import BaseAllocator, EMPTY_SIZE
12+
13+
14+
class DataAllocator(BaseAllocator):
15+
def __init__(
16+
self,
17+
dim: int,
18+
) -> None:
19+
"""Initialize the DataAllocator. This allocator simply returns the input
20+
data to use as the inducing points.
21+
22+
Args:
23+
dim (int): Dimensionality of the search space.
24+
"""
25+
super().__init__(dim=dim)
26+
27+
def allocate_inducing_points(
28+
self,
29+
inputs: torch.Tensor | None = None,
30+
covar_module: torch.nn.Module | None = None,
31+
num_inducing: int = 100,
32+
input_batch_shape: torch.Size = EMPTY_SIZE,
33+
) -> torch.Tensor:
34+
"""Allocate inducing points by returning the inputs as the inducing points.
35+
36+
Args:
37+
inputs (torch.Tensor): Input tensor, cloned and returned as inducing points.
38+
covar_module (torch.nn.Module, optional): Kernel covariance module; included for API compatibility, but not used here.
39+
num_inducing (int, optional): The number of inducing points to generate. This parameter is ignored by DataAllocator,
40+
which always returns all input points.
41+
input_batch_shape (torch.Size, optional): Batch shape; included for API compatibility, but not used here.
42+
43+
Returns:
44+
torch.Tensor: The input data as inducing points.
45+
"""
46+
if inputs is None: # Dummy points
47+
return self._allocate_dummy_points(num_inducing=num_inducing)
48+
49+
if num_inducing < inputs.shape[0]:
50+
warnings.warn(
51+
f"DataAllocator ignores num_inducing={num_inducing} and returns all input points.",
52+
UserWarning,
53+
stacklevel=2,
54+
)
55+
56+
self.last_allocator_used = self.__class__
57+
return inputs.clone().detach()

tests/test_points_allocators.py

Lines changed: 75 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,14 @@
1111
from aepsych.config import Config
1212
from aepsych.models.gp_classification import GPClassificationModel
1313
from aepsych.models.inducing_points import (
14+
DataAllocator,
1415
FixedAllocator,
1516
FixedPlusAllocator,
1617
GreedyVarianceReduction,
1718
KMeansAllocator,
1819
SobolAllocator,
1920
)
20-
from aepsych.strategy import Strategy
21+
from aepsych.strategy import SequentialStrategy, Strategy
2122
from aepsych.transforms.parameters import ParameterTransforms, transform_options
2223
from sklearn.datasets import make_classification
2324

@@ -482,6 +483,79 @@ def test_fixed_plus_allocator_dimension_mismatch(self):
482483
main_allocator=KMeansAllocator,
483484
)
484485

486+
def test_data_allocator(self):
487+
"""Test basic functionality of DataAllocator."""
488+
allocator = DataAllocator(dim=2)
489+
inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
490+
491+
# Test that it returns the input data and sets last_allocator_used
492+
inducing_points = allocator.allocate_inducing_points(
493+
inputs=inputs, num_inducing=10
494+
)
495+
self.assertTrue(torch.equal(inducing_points, inputs))
496+
self.assertIs(allocator.last_allocator_used, DataAllocator)
497+
self.assertIsNot(inducing_points, inputs) # Should be a clone
498+
499+
# Test when no inputs are provided we get dummy points
500+
inducing_points = allocator.allocate_inducing_points(num_inducing=10)
501+
self.assertEqual(inducing_points.shape, (10, 2))
502+
self.assertTrue(torch.all(inducing_points == 0))
503+
504+
# Test warning when num_inducing is less than inputs
505+
with self.assertWarns(UserWarning) as w:
506+
inducing_points = allocator.allocate_inducing_points(
507+
inputs=inputs, num_inducing=2
508+
)
509+
510+
self.assertEqual(len(w.warnings), 1)
511+
self.assertIn("DataAllocator ignores num_inducing=2", w.warning.args[0])
512+
self.assertTrue(torch.all(inducing_points == inputs))
513+
514+
def test_data_allocator_config_smoketest(self):
515+
"""Test DataAllocator integration with model and config."""
516+
# Test with config
517+
config_str = """
518+
[common]
519+
parnames = [par1]
520+
stimuli_per_trial = 1
521+
outcome_types = [binary]
522+
strategy_names = [init_strat, opt_strat]
523+
524+
[par1]
525+
par_type = continuous
526+
lower_bound = 0
527+
upper_bound = 1
528+
529+
[init_strat]
530+
generator = SobolGenerator
531+
min_asks = 2
532+
533+
[opt_strat]
534+
generator = OptimizeAcqfGenerator
535+
min_asks = 1
536+
model = GPClassificationModel
537+
538+
[GPClassificationModel]
539+
inducing_point_method = DataAllocator
540+
inducing_size = 2
541+
542+
[OptimizeAcqfGenerator]
543+
acqf = MCLevelSetEstimation
544+
"""
545+
546+
config = Config()
547+
config.update(config_str=config_str)
548+
strat = SequentialStrategy.from_config(config)
549+
550+
for response in [0, 1]:
551+
point = strat.gen()
552+
strat.add_data(point, torch.tensor([response]))
553+
554+
point = strat.gen()
555+
self.assertTrue(
556+
torch.all(strat.model.variational_strategy.inducing_points == strat.x)
557+
)
558+
485559

486560
if __name__ == "__main__":
487561
unittest.main()

0 commit comments

Comments
 (0)