Skip to content

Commit 52d25f1

Browse files
Carl Hvarfner authored and
facebook-github-bot committed
Weighted & batched EnsemblePosterior Ax support (facebook#4201)
Summary: Adds more predict support for `EnsemblePosterior`, and makes it more coherent with the predict support that is already in place for `GaussianMixturePosterior` by using the newly added (previous diff) `mixture` attributes for `EnsemblePosterior`. In the case where an `EnsemblePosterior` has batch dimensions in addition to the ensemble dimension, we compute the mean and variance over both batch and ensemble dimensions, instead of just the latter. Moreover, this allows us to properly account for non-uniform weights across batch and ensemble dimensions, which is needed for sampling-based benchmarking in ensembled Fully Bayesian models. Removed tests that involved previous mixture moment calculations, as these are now in BoTorch. Pull Request resolved: facebook#4201 Reviewed By: saitcakmak Differential Revision: D80972578 fbshipit-source-id: 179799158979078958b9e9ba958d6ef58d996873
1 parent 8d3c1c9 commit 52d25f1

2 files changed

Lines changed: 11 additions & 81 deletions

File tree

ax/generators/torch/tests/test_utils.py

Lines changed: 6 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -736,85 +736,21 @@ def test_model_config(self) -> None:
736736

737737
def test_predict_from_model_ensemble_posterior(self) -> None:
738738
"""Test predict_from_model with EnsemblePosterior support."""
739-
# Create test data
740739
X = torch.rand(2, 3)
741740

742-
# Create a mock EnsemblePosterior with ndim > 2
741+
# Create a mock EnsemblePosterior
743742
mock_posterior = Mock(spec=EnsemblePosterior)
744-
745-
# Set up posterior values with shape (num_models, batch_shape, output_shape)
746-
# This simulates an ensemble of 5 models with 2 test points and 2 outputs
747-
posterior_values = torch.rand(5, 2, 2) # (5 models, 2 points, 2 outputs)
748-
mock_posterior.values = posterior_values
743+
mock_posterior.mixture_mean = torch.rand(2, 2)
744+
mock_posterior.mixture_variance = torch.rand(2, 2)
749745

750746
# Create a mock model
751747
mock_model = Mock()
752748
mock_model.posterior.return_value = mock_posterior
753749

754-
# Test with use_posterior_predictive=False
750+
# Test prediction
755751
mean, cov = predict_from_model(mock_model, X, use_posterior_predictive=False)
756752

757-
# Verify the model.posterior was called correctly
758-
mock_model.posterior.assert_called_once()
759-
760-
# Verify output shapes
753+
# Verify shapes
761754
self.assertEqual(mean.shape, (2, 2)) # (n_points, n_outputs)
762755
self.assertEqual(cov.shape, (2, 2, 2)) # (n_points, n_outputs, n_outputs)
763-
764-
# Verify mean calculation (should be mean over ensemble dimension)
765-
expected_mean = posterior_values.mean(dim=0) # Average over first dimension
766-
self.assertTrue(torch.allclose(mean, expected_mean))
767-
768-
# Verify variance calculation (should be variance over ensemble dimension)
769-
expected_var = posterior_values.var(dim=0)
770-
# Check that the diagonal of the covariance matches expected variance
771-
self.assertTrue(
772-
torch.allclose(torch.diagonal(cov, dim1=-2, dim2=-1), expected_var)
773-
)
774-
775-
# Test with use_posterior_predictive=True
776-
mock_model.reset_mock()
777-
predict_from_model(mock_model, X, use_posterior_predictive=True)
778-
mock_model.posterior.assert_called_once()
779-
780-
mock_posterior2 = Mock(spec=EnsemblePosterior)
781-
# Shape: (num_models, batch1, batch2, *output_shape) - ndim = 5
782-
posterior_values_5d = torch.rand(2, 3, 4, 2, 2)
783-
mock_posterior2.values = posterior_values_5d
784-
mock_model2 = Mock()
785-
mock_model2.posterior.return_value = mock_posterior2
786-
787-
X2 = torch.rand(4, 3)
788-
mean2, cov2 = predict_from_model(
789-
mock_model2, X2, use_posterior_predictive=False
790-
)
791-
792-
# Should average over the first three dimensions (all except last 2)
793-
expected_mean_5d = posterior_values_5d.mean(dim=(0, 1, 2))
794-
expected_var_5d = posterior_values_5d.var(dim=(0, 1, 2))
795-
796-
self.assertTrue(torch.allclose(mean2, expected_mean_5d))
797-
self.assertTrue(
798-
torch.allclose(torch.diagonal(cov2, dim1=-2, dim2=-1), expected_var_5d)
799-
)
800-
801-
# Test case where ensemble size is 1 or non-existent
802-
# (variance should be zero, not NaN)
803-
posterior_values_singles = [
804-
torch.rand(1, 2, 2),
805-
torch.rand(1, 1, 2, 2),
806-
] # Single ensemble model
807-
mock_model.reset_mock()
808-
mock_posterior_single = Mock(spec=EnsemblePosterior)
809-
mock_model.posterior.return_value = mock_posterior_single
810-
for i, posterior_values_single in enumerate(posterior_values_singles):
811-
with self.subTest(i=i, shape=posterior_values_single.shape):
812-
mock_posterior_single.values = posterior_values_single
813-
mean_single, cov_single = predict_from_model(
814-
mock_model, X, use_posterior_predictive=False
815-
)
816-
# Variance should be zero (not NaN) when ensemble size is 1
817-
var_single = torch.diagonal(cov_single, dim1=-2, dim2=-1)
818-
self.assertTrue(
819-
torch.allclose(var_single, torch.zeros_like(var_single))
820-
)
756+
self.assertTrue(torch.all(cov >= 0)) # Ensure covariance is positive

ax/generators/torch/utils.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -535,17 +535,11 @@ def predict_from_model(
535535
if isinstance(posterior, GaussianMixturePosterior):
536536
mean = posterior.mixture_mean.cpu().detach()
537537
var = posterior.mixture_variance.cpu().detach().clamp_min(0)
538-
elif isinstance(posterior, EnsemblePosterior) and posterior.values.ndim > 2:
539-
# Compute dimensions to average over (all except last 2)
540-
# Not using the built-in EnsemblePosterior.variance() since that
541-
# does not allow us to compute variance over _all_ batch dimensions
542-
avg_dims = tuple(range(posterior.values.ndim - 2))
543-
mean = posterior.values.mean(dim=avg_dims).cpu().detach()
544-
var = posterior.values.var(dim=avg_dims).cpu().detach()
545-
546-
# Replace NaN values with zero (occurs when ensemble size is 1)
547-
if posterior.values[..., 0, 0].numel() == 1:
548-
var = torch.zeros_like(var)
538+
elif isinstance(posterior, EnsemblePosterior):
539+
# Always use mixture_mean and mixture_variance for ensemble
540+
# predictions - provides prediction from mixture, not just average
541+
mean = posterior.mixture_mean.cpu().detach()
542+
var = posterior.mixture_variance.cpu().detach().clamp_min(0)
549543
else:
550544
try:
551545
mean = posterior.mean.cpu().detach() # type: ignore

0 commit comments

Comments
 (0)