class SoftKNNClassifierModel(GenericDeterministicModel):
    """
    Soft K-Nearest Neighbors classifier wrapped as a BoTorch deterministic model.

    The predicted probability of class 1 at a query point is the kernel-weighted
    fraction of class-1 training points, with Gaussian weights
    ``w_j = exp(-0.5 * sum_k ((x_k - x_jk) / sigma_k) ** 2)``. The weights are
    normalised with a softmax over the log-weights, so the prediction remains
    well defined even when every individual weight would underflow to zero
    (queries far away from all training data, very small ``sigma``, float32).

    ``sigma`` is either a fixed scalar bandwidth or, with
    ``learnable_sigma=True``, a per-dimension bandwidth fitted by minimising the
    leave-one-out (LOO) binary cross-entropy on the training data with Adam.

    Example:
        >>> from botorch.models.classifier import SoftKNNClassifierModel
        >>> from botorch.utils.datasets import SupervisedDataset
        >>> import torch
        >>>
        >>> X = torch.rand(100, 5, dtype=torch.float64)
        >>> y = torch.randint(0, 2, (100, 1), dtype=torch.float64)
        >>> dataset = SupervisedDataset(X=X, Y=y)
        >>>
        >>> # Fixed sigma
        >>> model_inputs = SoftKNNClassifierModel.construct_inputs(
        ...     training_data=dataset,
        ...     sigma=0.3
        ... )
        >>> model = SoftKNNClassifierModel(**model_inputs)
        >>>
        >>> # Learnable per-dimension sigma
        >>> model_inputs = SoftKNNClassifierModel.construct_inputs(
        ...     training_data=dataset,
        ...     learnable_sigma=True,
        ...     sigma_epochs=100
        ... )
        >>> model = SoftKNNClassifierModel(**model_inputs)
    """

    def __init__(
        self,
        train_X: Tensor,
        train_Y: Tensor,
        sigma: float = 0.1,
        learnable_sigma: bool = False,
        sigma_lr: float = 0.1,
        sigma_epochs: int = 100,
        input_transform: InputTransform | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize SoftKNNClassifierModel.

        Args:
            train_X: Training features tensor of shape (n, d).
            train_Y: Training labels tensor of shape (n,) or (n, 1). Entries
                equal to 1 are treated as class 1, everything else as class 0.
            sigma: Gaussian kernel bandwidth; also the initial value of every
                per-dimension bandwidth when `learnable_sigma` is True
                (default: 0.1). Must be positive.
            learnable_sigma: If True, learn per-dimension sigma via LOO
                cross-validation (default: False).
            sigma_lr: Learning rate for sigma optimization (default: 0.1).
            sigma_epochs: Training epochs for sigma (default: 100).
            input_transform: Optional InputTransform applied to the training
                inputs here and registered on the model so that it can be
                applied to test inputs as well.
            **kwargs: Additional arguments (ignored).

        Raises:
            ValueError: If `sigma` is not positive, or if the number of labels
                does not match the number of training points.
        """
        # `not sigma > 0` (rather than `sigma <= 0`) also rejects NaN.
        if not sigma > 0:
            raise ValueError(f"sigma must be positive, got {sigma}.")

        # Bring the training inputs into the space in which distances are
        # measured. The result is detached so that the repeated backward passes
        # of the sigma fit never traverse (and free) a graph belonging to the
        # transform's own parameters.
        if input_transform is not None:
            with torch.no_grad():
                train_X_t = input_transform(train_X)
        else:
            train_X_t = train_X
        train_X_t = train_X_t.detach()

        # Flatten the labels; `reshape` (unlike `view`) also handles
        # non-contiguous tensors.
        train_Y = train_Y.detach().reshape(-1)
        if train_Y.numel() != train_X_t.shape[0]:
            raise ValueError(
                f"Expected one label per training point, got {train_Y.numel()} "
                f"labels for {train_X_t.shape[0]} points."
            )
        # Class-1 indicator, computed once and reused for fitting and prediction.
        is_class1 = (train_Y == 1.0).to(train_X_t)

        learned_sigma_tensor: Tensor | None = None
        sigma_final: Tensor | float = sigma  # scalar unless learned below
        if learnable_sigma:
            learned_sigma_tensor = self._fit_loo_sigma(
                train_X=train_X_t,
                is_class1=is_class1,
                sigma=sigma,
                lr=sigma_lr,
                epochs=sigma_epochs,
            )
            sigma_final = learned_sigma_tensor  # [d]

        def predict_proba_fn(X: Tensor) -> Tensor:
            """Map inputs of shape (..., d) to class-1 probabilities (..., 1)."""
            batch_shape = X.shape[:-1]
            X_flat = X.reshape(-1, X.shape[-1])
            # Match the dtype / device of the query.
            ref_X = train_X_t.to(X_flat)
            if isinstance(sigma_final, Tensor):
                scale = sigma_final.to(X_flat)
            else:
                scale = sigma_final
            # Bandwidth-scaled pairwise differences, [m, n, d]. A scalar sigma
            # and a per-dimension sigma both broadcast over the last dim.
            scaled_diffs = (X_flat.unsqueeze(1) - ref_X.unsqueeze(0)) / scale
            log_weights = -0.5 * scaled_diffs.square().sum(dim=-1)  # [m, n]
            # softmax == normalised Gaussian weights, but shifted by the row
            # maximum so distant queries do not underflow to 0 / 0.
            weights = torch.softmax(log_weights, dim=-1)
            probs_flat = weights @ is_class1.to(X_flat)
            return probs_flat.reshape(*batch_shape, 1)

        # Initialize parent with the prediction function
        super().__init__(f=predict_proba_fn, num_outputs=1)

        # Register input_transform as a submodule (must happen after the
        # parent initializer has run).
        if input_transform is not None:
            self.input_transform = input_transform

        # Expose learned sigma (None when a fixed sigma is used) for inspection
        self.learned_sigma = learned_sigma_tensor

    @staticmethod
    def _fit_loo_sigma(
        train_X: Tensor,
        is_class1: Tensor,
        sigma: float,
        lr: float,
        epochs: int,
    ) -> Tensor:
        """Fit per-dimension bandwidths by minimising the LOO cross-entropy.

        Args:
            train_X: Detached (transformed) training inputs of shape (n, d).
            is_class1: Float class-1 indicator of shape (n,), same dtype and
                device as `train_X`.
            sigma: Positive initial value shared by all dimensions.
            lr: Adam learning rate.
            epochs: Number of full-batch Adam steps.

        Returns:
            A detached tensor of shape (d,) with the fitted bandwidths. With
            fewer than two training points there are no leave-one-out
            neighbours, so the initial value is returned unchanged.
        """
        n, d = train_X.shape
        # Optimise log(sigma) so the bandwidths stay positive.
        log_sigma = torch.nn.Parameter(
            torch.full(
                (d,), float(sigma), dtype=train_X.dtype, device=train_X.device
            ).log()
        )
        if n < 2:
            return log_sigma.detach().exp()

        # Loop-invariant quantities, computed once instead of every epoch.
        sq_diffs = (train_X.unsqueeze(1) - train_X.unsqueeze(0)).square()  # [n, n, d]
        self_pairs = torch.eye(n, dtype=torch.bool, device=train_X.device)
        eps = 1e-7

        optimizer = torch.optim.Adam([log_sigma], lr=lr, foreach=True)
        # The fit needs autograd even if the model is built under no_grad().
        with torch.enable_grad():
            for _ in range(epochs):
                optimizer.zero_grad()
                inv_var = torch.exp(-2.0 * log_sigma)  # 1 / sigma_k^2, [d]
                log_weights = -0.5 * (sq_diffs @ inv_var)  # [n, n]
                # LOO: a point must not vote for itself.
                log_weights = log_weights.masked_fill(self_pairs, float("-inf"))
                prob_class1 = torch.softmax(log_weights, dim=-1) @ is_class1
                # Binary cross-entropy, clamped away from log(0).
                prob_class1 = prob_class1.clamp(eps, 1 - eps)
                loss = -torch.mean(
                    is_class1 * torch.log(prob_class1)
                    + (1 - is_class1) * torch.log(1 - prob_class1)
                )
                loss.backward()
                optimizer.step()

        return log_sigma.detach().exp()

    @classmethod
    def construct_inputs(
        cls,
        training_data: SupervisedDataset,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Construct inputs for SoftKNNClassifierModel from training data.

        Only collects the arguments for `__init__`; the input transform is
        applied and the prediction closure is built there. Keys in `kwargs`
        other than the ones listed below are dropped.

        Args:
            training_data: SupervisedDataset with X (features) and Y (labels).
            sigma: Initial Gaussian kernel bandwidth (default: 0.1).
            learnable_sigma: If True, learn per-dimension sigma via LOO
                cross-validation (default: False).
            sigma_lr: Learning rate for sigma optimization (default: 0.1).
            sigma_epochs: Training epochs for sigma (default: 100).
            input_transform: Optional InputTransform applied to both training
                and test inputs before distance computation.

        Returns:
            Dictionary with detached copies of the training data and the model
            parameters, keyed by the `__init__` argument names.
        """
        option_defaults: dict[str, Any] = {
            "sigma": 0.1,
            "learnable_sigma": False,
            "sigma_lr": 0.1,
            "sigma_epochs": 100,
            "input_transform": None,
        }
        inputs: dict[str, Any] = {
            "train_X": training_data.X.detach().clone(),
            "train_Y": training_data.Y.detach().clone(),
        }
        for name, default in option_defaults.items():
            inputs[name] = kwargs.get(name, default)
        return inputs
class TestSoftKNNClassifierModel(BotorchTestCase):
    """Unit tests for construction, prediction and sigma learning."""

    def _make_data(self, n: int = 20, d: int = 3) -> tuple[torch.Tensor, torch.Tensor]:
        """Return seeded random features in [0, 1)^d and binary labels."""
        torch.manual_seed(0)
        features = torch.rand(n, d, dtype=torch.float64)
        labels = torch.randint(0, 2, (n, 1), dtype=torch.float64)
        return features, labels

    def _make_dataset(self) -> tuple[SupervisedDataset, torch.Tensor, torch.Tensor]:
        """Wrap the default random data in a named SupervisedDataset."""
        features, labels = self._make_data()
        names = [f"x{i}" for i in range(features.shape[-1])]
        dataset = SupervisedDataset(
            X=features, Y=labels, feature_names=names, outcome_names=["y"]
        )
        return dataset, features, labels

    def _unit_cube_normalize(self, d: int) -> Normalize:
        """Build a Normalize transform with bounds [0, 1]^d."""
        lower, upper = torch.zeros(d), torch.ones(d)
        return Normalize(d=d, bounds=torch.stack([lower, upper]))

    def _assert_probabilities(self, values: torch.Tensor) -> None:
        """Check that every entry lies in the closed interval [0, 1]."""
        self.assertTrue((values >= 0).all())
        self.assertTrue((values <= 1).all())

    def test_basic_construction(self) -> None:
        features, labels = self._make_data()
        clf = SoftKNNClassifierModel(train_X=features, train_Y=labels, sigma=0.3)
        self.assertEqual(clf.num_outputs, 1)
        # A fixed sigma means nothing was learned.
        self.assertIsNone(clf.learned_sigma)

    def test_forward_shape_and_range(self) -> None:
        features, labels = self._make_data()
        clf = SoftKNNClassifierModel(train_X=features, train_Y=labels, sigma=0.3)
        query = torch.rand(5, 3, dtype=torch.float64)
        probs = clf(query)
        self.assertEqual(probs.shape, torch.Size([5, 1]))
        self._assert_probabilities(probs)

    def test_batched_input(self) -> None:
        features, labels = self._make_data()
        clf = SoftKNNClassifierModel(train_X=features, train_Y=labels, sigma=0.3)
        query = torch.rand(2, 5, 3, dtype=torch.float64)
        probs = clf(query)
        self.assertEqual(probs.shape, torch.Size([2, 5, 1]))

    def test_posterior(self) -> None:
        features, labels = self._make_data()
        clf = SoftKNNClassifierModel(train_X=features, train_Y=labels, sigma=0.3)
        query = torch.rand(5, 3, dtype=torch.float64)
        post = clf.posterior(query)
        self.assertIsInstance(post, EnsemblePosterior)
        self.assertEqual(post.mean.shape, torch.Size([5, 1]))

    def test_learnable_sigma(self) -> None:
        features, labels = self._make_data()
        dim = features.shape[-1]
        clf = SoftKNNClassifierModel(
            train_X=features,
            train_Y=labels,
            learnable_sigma=True,
            sigma_epochs=10,
        )
        self.assertIsNotNone(clf.learned_sigma)
        self.assertEqual(clf.learned_sigma.shape, torch.Size([dim]))
        # Forward should still work
        query = torch.rand(5, dim, dtype=torch.float64)
        probs = clf(query)
        self.assertEqual(probs.shape, torch.Size([5, 1]))
        self._assert_probabilities(probs)

    def test_construct_inputs(self) -> None:
        dataset, features, labels = self._make_dataset()
        built = SoftKNNClassifierModel.construct_inputs(
            training_data=dataset,
            sigma=0.5,
            learnable_sigma=True,
            sigma_epochs=50,
        )
        for key in ("train_X", "train_Y"):
            self.assertIn(key, built)
        self.assertEqual(built["sigma"], 0.5)
        self.assertTrue(built["learnable_sigma"])
        self.assertEqual(built["sigma_epochs"], 50)
        self.assertTrue(torch.equal(built["train_X"], features))
        self.assertTrue(torch.equal(built["train_Y"], labels))
        # Round-trip: construct model from inputs
        clf = SoftKNNClassifierModel(**built)
        self.assertEqual(clf.num_outputs, 1)

    def test_construct_inputs_does_not_mutate_kwargs(self) -> None:
        dataset, _, _ = self._make_dataset()
        options = {"sigma": 0.5, "extra_key": "value"}
        SoftKNNClassifierModel.construct_inputs(training_data=dataset, **options)
        # The caller's dict must still hold both entries.
        for key in ("sigma", "extra_key"):
            self.assertIn(key, options)

    def test_input_transform(self) -> None:
        features, labels = self._make_data()
        dim = features.shape[-1]
        clf = SoftKNNClassifierModel(
            train_X=features,
            train_Y=labels,
            sigma=0.3,
            input_transform=self._unit_cube_normalize(dim),
        )
        query = torch.rand(5, dim, dtype=torch.float64)
        post = clf.posterior(query)
        self.assertEqual(post.mean.shape, torch.Size([5, 1]))

    def test_all_same_class(self) -> None:
        features = torch.rand(10, 3, dtype=torch.float64)
        labels = torch.ones(10, 1, dtype=torch.float64)
        clf = SoftKNNClassifierModel(train_X=features, train_Y=labels, sigma=0.3)
        query = torch.rand(5, 3, dtype=torch.float64)
        probs = clf(query)
        # Only class-1 training data, so predictions should be ~1.0
        self.assertTrue((probs > 0.99).all())

    def test_learnable_sigma_with_input_transform(self) -> None:
        features, labels = self._make_data()
        dim = features.shape[-1]
        clf = SoftKNNClassifierModel(
            train_X=features,
            train_Y=labels,
            learnable_sigma=True,
            sigma_epochs=10,
            input_transform=self._unit_cube_normalize(dim),
        )
        self.assertIsNotNone(clf.learned_sigma)
        self.assertEqual(clf.learned_sigma.shape, torch.Size([dim]))
        query = torch.rand(5, dim, dtype=torch.float64)
        post = clf.posterior(query)
        self.assertEqual(post.mean.shape, torch.Size([5, 1]))
        self._assert_probabilities(post.mean)

    def test_single_training_point(self) -> None:
        features = torch.rand(1, 3, dtype=torch.float64)
        labels = torch.tensor([[1.0]], dtype=torch.float64)
        clf = SoftKNNClassifierModel(train_X=features, train_Y=labels, sigma=0.3)
        query = torch.rand(5, 3, dtype=torch.float64)
        probs = clf(query)
        self.assertEqual(probs.shape, torch.Size([5, 1]))
        # A single class-1 point, so predictions should be ~1.0
        self.assertTrue((probs > 0.99).all())