|
11 | 11 | from aepsych.config import Config |
12 | 12 | from aepsych.models.gp_classification import GPClassificationModel |
13 | 13 | from aepsych.models.inducing_points import ( |
| 14 | + DataAllocator, |
14 | 15 | FixedAllocator, |
15 | 16 | FixedPlusAllocator, |
16 | 17 | GreedyVarianceReduction, |
17 | 18 | KMeansAllocator, |
18 | 19 | SobolAllocator, |
19 | 20 | ) |
20 | | -from aepsych.strategy import Strategy |
| 21 | +from aepsych.strategy import SequentialStrategy, Strategy |
21 | 22 | from aepsych.transforms.parameters import ParameterTransforms, transform_options |
22 | 23 | from sklearn.datasets import make_classification |
23 | 24 |
|
@@ -482,6 +483,79 @@ def test_fixed_plus_allocator_dimension_mismatch(self): |
482 | 483 | main_allocator=KMeansAllocator, |
483 | 484 | ) |
484 | 485 |
|
| 486 | + def test_data_allocator(self): |
| 487 | + """Test basic functionality of DataAllocator.""" |
| 488 | + allocator = DataAllocator(dim=2) |
| 489 | + inputs = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]) |
| 490 | + |
| 491 | + # Test that it returns the input data and sets last_allocator_used |
| 492 | + inducing_points = allocator.allocate_inducing_points( |
| 493 | + inputs=inputs, num_inducing=10 |
| 494 | + ) |
| 495 | + self.assertTrue(torch.equal(inducing_points, inputs)) |
| 496 | + self.assertIs(allocator.last_allocator_used, DataAllocator) |
| 497 | + self.assertIsNot(inducing_points, inputs) # Should be a clone |
| 498 | + |
| 499 | + # Test when no inputs are provided we get dummy points |
| 500 | + inducing_points = allocator.allocate_inducing_points(num_inducing=10) |
| 501 | + self.assertEqual(inducing_points.shape, (10, 2)) |
| 502 | + self.assertTrue(torch.all(inducing_points == 0)) |
| 503 | + |
| 504 | + # Test warning when num_inducing is less than inputs |
| 505 | + with self.assertWarns(UserWarning) as w: |
| 506 | + inducing_points = allocator.allocate_inducing_points( |
| 507 | + inputs=inputs, num_inducing=2 |
| 508 | + ) |
| 509 | + |
| 510 | + self.assertEqual(len(w.warnings), 1) |
| 511 | + self.assertIn("DataAllocator ignores num_inducing=2", w.warning.args[0]) |
| 512 | + self.assertTrue(torch.all(inducing_points == inputs)) |
| 513 | + |
| 514 | + def test_data_allocator_config_smoketest(self): |
| 515 | + """Test DataAllocator integration with model and config.""" |
| 516 | + # Test with config |
| 517 | + config_str = """ |
| 518 | + [common] |
| 519 | + parnames = [par1] |
| 520 | + stimuli_per_trial = 1 |
| 521 | + outcome_types = [binary] |
| 522 | + strategy_names = [init_strat, opt_strat] |
| 523 | +
|
| 524 | + [par1] |
| 525 | + par_type = continuous |
| 526 | + lower_bound = 0 |
| 527 | + upper_bound = 1 |
| 528 | +
|
| 529 | + [init_strat] |
| 530 | + generator = SobolGenerator |
| 531 | + min_asks = 2 |
| 532 | +
|
| 533 | + [opt_strat] |
| 534 | + generator = OptimizeAcqfGenerator |
| 535 | + min_asks = 1 |
| 536 | + model = GPClassificationModel |
| 537 | +
|
| 538 | + [GPClassificationModel] |
| 539 | + inducing_point_method = DataAllocator |
| 540 | + inducing_size = 2 |
| 541 | +
|
| 542 | + [OptimizeAcqfGenerator] |
| 543 | + acqf = MCLevelSetEstimation |
| 544 | + """ |
| 545 | + |
| 546 | + config = Config() |
| 547 | + config.update(config_str=config_str) |
| 548 | + strat = SequentialStrategy.from_config(config) |
| 549 | + |
| 550 | + for response in [0, 1]: |
| 551 | + point = strat.gen() |
| 552 | + strat.add_data(point, torch.tensor([response])) |
| 553 | + |
| 554 | + point = strat.gen() |
| 555 | + self.assertTrue( |
| 556 | + torch.all(strat.model.variational_strategy.inducing_points == strat.x) |
| 557 | + ) |
| 558 | + |
485 | 559 |
|
486 | 560 | if __name__ == "__main__": |
487 | 561 | unittest.main() |
0 commit comments