
Commit 2c216c2

fix: do not save simulations in empirical prior (#1700)
1 parent c765e11

File tree

4 files changed: +89 −57 lines

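In short: after this commit, ImproperEmpirical precomputes its mean and variance at construction time and no longer acts as a sampleable store of all simulations. A toy stand-in that illustrates the storage pattern (a sketch of the idea, not sbi's actual class):

import torch

class MomentsOnlyPrior:
    """Toy stand-in: keep summary moments, not the simulations themselves."""

    def __init__(self, values: torch.Tensor):
        # Only O(parameter_dim) state is retained; `values` is not stored,
        # so serializing this prior no longer scales with simulation count.
        self._mean = values.mean(dim=0)
        self._variance = values.var(dim=0)

    @property
    def mean(self) -> torch.Tensor:
        return self._mean

    @property
    def stddev(self) -> torch.Tensor:
        return self._variance.sqrt()

prior = MomentsOnlyPrior(torch.randn(10_000, 3))
print(prior.mean.shape, prior.stddev.shape)  # torch.Size([3]) torch.Size([3])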

sbi/inference/potentials/posterior_based_potential.py

Lines changed: 0 additions & 5 deletions

@@ -17,9 +17,7 @@
 )
 from sbi.sbi_types import TorchTransform
 from sbi.utils.sbiutils import (
-    ImproperEmpirical,
     mcmc_transform,
-    warn_empirical_prior_memory_risk,
     within_support,
 )
 from sbi.utils.torchutils import ensure_theta_batched
@@ -102,9 +100,6 @@ def to(self, device: Union[str, torch.device]) -> None:
         self.device = device
         self.posterior_estimator.to(device)
         if self.prior is not None:
-            is_empirical = isinstance(self.prior, ImproperEmpirical)
-            if is_empirical and torch.device(device).type == "cuda":
-                warn_empirical_prior_memory_risk("moving empirical prior to CUDA")
             self.prior.to(device)  # type: ignore
         if self._x_o is not None:
             self._x_o = self._x_o.to(device)

sbi/inference/trainers/npe/npe_base.py

Lines changed: 0 additions & 9 deletions

@@ -53,7 +53,6 @@
 from sbi.utils.sbiutils import (
     ImproperEmpirical,
     mask_sims_from_prior,
-    warn_empirical_prior_memory_risk,
 )
 from sbi.utils.torchutils import assert_all_finite
 
@@ -470,18 +469,10 @@ def _get_potential_function(
             to unconstrained space.
         """
 
-        is_empirical = isinstance(prior, ImproperEmpirical)
-        if is_empirical:
-            warn_empirical_prior_memory_risk(
-                "disabling parameter transforms for empirical prior"
-            )
-
         potential_fn, theta_transform = posterior_estimator_based_potential(
             posterior_estimator=estimator,
             prior=prior,
             x_o=None,
-            # Disable transforms if prior is empirical to avoid sampling issues.
-            enable_transform=not is_empirical,
         )
         return potential_fn, theta_transform
 
sbi/utils/sbiutils.py

Lines changed: 88 additions & 29 deletions

@@ -35,7 +35,6 @@
 from torch.optim.adam import Adam
 
 from sbi.sbi_types import TorchTransform
-from sbi.utils.torchutils import atleast_2d
 
 
 def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -> None:
@@ -242,21 +241,6 @@ def biject_transform_zuko(
     )
 
 
-def warn_empirical_prior_memory_risk(context: Optional[str] = None) -> None:
-    """Emit a standardized warning about empirical-prior memory/VRAM risks.
-
-    Args:
-        context: Optional context string to append to the warning.
-    """
-    base = (
-        "Empirical prior memory/VRAM risk: empirical priors retain all simulations "
-        "as support and may trigger operations over large supports. This can "
-        "significantly increase memory usage and cause out-of-memory (OOM) errors."
-    )
-    message = f"{base} Context: {context}" if context else base
-    warnings.warn(message, stacklevel=2)
-
-
 def z_standardization(
     batch_t: Tensor,
     structured_dims: bool = False,
@@ -752,13 +736,6 @@ def mcmc_transform(
         (or z-scored) to constrained (or non-z-scored) space.
     """
     if enable_transform:
-        if isinstance(prior, (ImproperEmpirical, Empirical)):
-            warn_empirical_prior_memory_risk(
-                "disabled parameter transforms to avoid sampling-based moments"
-            )
-            return torch_tf.IndependentTransform(
-                torch_tf.identity_transform, reinterpreted_batch_ndims=1
-            )
 
         def prior_mean_std_transform(prior, device):
             try:
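With the empirical special case removed, mcmc_transform always follows its regular path, which (see the prior_mean_std_transform context above) builds an affine z-scoring transform from the prior's mean and standard deviation; this is why the class below gains moment properties. A minimal sketch of such a transform in plain torch (an illustration of the mechanism, not sbi's exact code):

import torch
from torch.distributions.transforms import AffineTransform, IndependentTransform

mean = torch.tensor([0.5, -1.0])
stddev = torch.tensor([2.0, 0.3])

# Forward maps z-scored values onto the prior's scale; .inv z-scores them.
transform = IndependentTransform(
    AffineTransform(loc=mean, scale=stddev), reinterpreted_batch_ndims=1
)

theta = torch.tensor([[0.5, -1.0], [2.5, -0.7]])
print(transform.inv(theta))  # first row maps to zeros: it equals the mean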
@@ -850,7 +827,12 @@ def check_transform(
 ) -> None:
     """Check validity of transformed and re-transformed samples."""
 
-    theta = prior.sample(torch.Size((2,)))
+    # check transform with prior samples
+    try:
+        theta = prior.sample(torch.Size((2,)))
+    except NotImplementedError:
+        # Prior has no sampling method, use the prior mean instead
+        theta = prior.mean.repeat(2, *[1] * prior.mean.dim())
 
     theta_unconstrained = transform.inv(theta)
     assert (
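The new fallback builds a two-row batch from the prior mean. Tensor.repeat with more arguments than the tensor has dimensions prepends the extra ones, so the expression works for event shapes of any rank; a quick runnable check:

import torch

mean_1d = torch.randn(3)
print(mean_1d.repeat(2, *[1] * mean_1d.dim()).shape)  # torch.Size([2, 3])

mean_2d = torch.randn(4, 5)
print(mean_2d.repeat(2, *[1] * mean_2d.dim()).shape)  # torch.Size([2, 4, 5])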
@@ -881,9 +863,15 @@ class ImproperEmpirical(Empirical):
     def __init__(self, values: Tensor, log_weights: Optional[Tensor] = None):
         super().__init__(values, log_weights=log_weights)
         # Warn if extremely large to inform about memory/serialization cost.
-        support_size = values.shape[0]
-        if support_size > 10_000_000:  # 10M still works well on modern hardware.
-            warn_empirical_prior_memory_risk(f">10M support size (size={support_size})")
+        self._mean = self._compute_mean(values, log_weights)
+        self._variance = self._compute_variance(values, log_weights)
+
+    def sample(self, sample_shape=torch.Size()):
+        raise NotImplementedError(
+            "Sampling from ImproperEmpirical is not supported. If you are using "
+            "likelihood or ratio estimation, or multi-round inference, you need to "
+            "define a prior distribution."
+        )
 
     def log_prob(self, value: Tensor) -> Tensor:
         """
@@ -895,8 +883,79 @@ def log_prob(self, value: Tensor) -> Tensor:
         Returns:
             Tensor of as many ones as there were parameter sets.
         """
-        value = atleast_2d(value)
-        return zeros(value.shape[0])
+        raise NotImplementedError(
+            "Evaluating log_prob from ImproperEmpirical is not supported. If you are "
+            "using likelihood or ratio estimation, or multi-round inference, you need "
+            "to define a prior distribution."
+        )
+
+    def _compute_mean(self, values: Tensor, weights: Optional[Tensor] = None) -> Tensor:
+        """
+        Return the mean of the empirical distribution.
+
+        Args:
+            values: The empirical samples.
+            weights: Optional weights for the samples.
+
+        Returns:
+            The mean of the empirical distribution.
+        """
+        if weights is None:
+            return torch.mean(values, dim=0)
+        else:
+            normalized_weights = torch.nn.functional.softmax(weights, dim=0)
+            return torch.sum(normalized_weights.unsqueeze(-1) * values, dim=0)
+
+    def _compute_variance(
+        self, values: Tensor, weights: Optional[Tensor] = None
+    ) -> Tensor:
+        """
+        Return the variance of the empirical distribution.
+
+        Args:
+            values: The empirical samples.
+            weights: Optional weights for the samples.
+
+        Returns:
+            The variance of the empirical distribution.
+        """
+        if weights is None:
+            variance = torch.var(values, dim=0)
+        else:
+            normalized_weights = torch.nn.functional.softmax(weights, dim=0)
+            variance = torch.sum(
+                normalized_weights.unsqueeze(-1) * (values - self._mean) ** 2,
+                dim=0,
+            )
+            # bias correction
+            variance = variance / (1 - torch.sum(normalized_weights**2))
+        return variance
+
+    @property
+    def mean(self) -> Tensor:
+        return self._mean
+
+    @property
+    def variance(self) -> Tensor:
+        return self._variance
+
+    @property
+    def stddev(self) -> Tensor:
+        return torch.sqrt(self._variance)
+
+    def to(self, device: Union[str, torch.device]) -> None:
+        """
+        Move the distribution to a different device.
+
+        Args:
+            device: The device to move the distribution to.
+        """
+        self._mean = self._mean.to(device)
+        self._variance = self._variance.to(device)
+        super().to(device)
 
 
 def mog_log_prob(
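For intuition, the weighted-moment formulas in _compute_mean/_compute_variance can be reproduced in plain torch and checked against the unweighted estimators: uniform log-weights give equal weights after softmax, and the bias correction then reduces to the familiar n/(n-1) factor, recovering torch.var:

import torch

values = torch.randn(1000, 3)
log_weights = torch.zeros(1000)  # uniform weights after softmax

w = torch.nn.functional.softmax(log_weights, dim=0)
mean = torch.sum(w.unsqueeze(-1) * values, dim=0)
var = torch.sum(w.unsqueeze(-1) * (values - mean) ** 2, dim=0)
var = var / (1 - torch.sum(w**2))  # = biased variance * n / (n - 1) here

assert torch.allclose(mean, values.mean(dim=0), atol=1e-5)
assert torch.allclose(var, values.var(dim=0), atol=1e-4)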

tests/sbiutils_test.py

Lines changed: 1 addition & 14 deletions

@@ -1,7 +1,6 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
 
-import warnings
 from typing import Tuple
 
 import matplotlib.pyplot as plt
@@ -21,7 +20,7 @@
 from sbi.inference.trainers.npe.npe_a import NPE_A_MDN
 from sbi.neural_nets import classifier_nn, likelihood_nn, posterior_nn
 from sbi.utils import BoxUniform, get_kde
-from sbi.utils.sbiutils import ImproperEmpirical, mcmc_transform, z_score_parser
+from sbi.utils.sbiutils import z_score_parser
 
 
 def test_conditional_density_1d():
@@ -555,15 +554,3 @@ def test_z_scoring_structured(z_x, z_theta, builder):
     # plt.plot(x_zstructured.T)
     # plt.title('z-scored: structured dims');
     # plt.show()
-
-
-def test_mcmc_transform_emits_warning_for_improper_empirical():
-    values = torch.randn(100, 3)
-    logw = torch.zeros(values.shape[0])
-    prior = ImproperEmpirical(values, log_weights=logw)
-    with warnings.catch_warnings(record=True) as w:
-        warnings.simplefilter("always")
-        _ = mcmc_transform(prior, enable_transform=True)
-    assert any("Empirical prior memory/VRAM risk" in str(ww.message) for ww in w), (
-        "Expected generic empirical prior memory/VRAM risk warning."
-    )
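No replacement test ships with this commit; a hypothetical test of the new contract (name, tolerances, and assertions are mine, not part of the diff) could look like:

import pytest
import torch

from sbi.utils.sbiutils import ImproperEmpirical


def test_improper_empirical_moments_and_no_sampling():
    values = torch.randn(500, 3)
    prior = ImproperEmpirical(values, log_weights=torch.zeros(500))

    # Moments are precomputed and match the plain estimators for uniform weights.
    assert torch.allclose(prior.mean, values.mean(dim=0), atol=1e-5)
    assert torch.allclose(prior.variance, values.var(dim=0), atol=1e-4)

    # Sampling and density evaluation are explicitly unsupported.
    with pytest.raises(NotImplementedError):
        prior.sample()
    with pytest.raises(NotImplementedError):
        prior.log_prob(values[:1])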
