 from torch.optim.adam import Adam

 from sbi.sbi_types import TorchTransform
-from sbi.utils.torchutils import atleast_2d


 def warn_if_zscoring_changes_data(x: Tensor, duplicate_tolerance: float = 0.1) -> None:
@@ -242,21 +241,6 @@ def biject_transform_zuko(
     )


-def warn_empirical_prior_memory_risk(context: Optional[str] = None) -> None:
-    """Emit a standardized warning about empirical-prior memory/VRAM risks.
-
-    Args:
-        context: Optional context string to append to the warning.
-    """
-    base = (
-        "Empirical prior memory/VRAM risk: empirical priors retain all simulations "
-        "as support and may trigger operations over large supports. This can "
-        "significantly increase memory usage and cause out-of-memory (OOM) errors."
-    )
-    message = f"{base} Context: {context}" if context else base
-    warnings.warn(message, stacklevel=2)
-
-
 def z_standardization(
     batch_t: Tensor,
     structured_dims: bool = False,
@@ -752,13 +736,6 @@ def mcmc_transform(
         (or z-scored) to constrained (or non-z-scored) space.
     """
     if enable_transform:
-        if isinstance(prior, (ImproperEmpirical, Empirical)):
-            warn_empirical_prior_memory_risk(
-                "disabled parameter transforms to avoid sampling-based moments"
-            )
-            return torch_tf.IndependentTransform(
-                torch_tf.identity_transform, reinterpreted_batch_ndims=1
-            )

         def prior_mean_std_transform(prior, device):
             try:
@@ -850,7 +827,12 @@ def check_transform(
 ) -> None:
     """Check validity of transformed and re-transformed samples."""

-    theta = prior.sample(torch.Size((2,)))
+    # check transform with prior samples
+    try:
+        theta = prior.sample(torch.Size((2,)))
+    except NotImplementedError:
+        # Prior has no sampling method, use the prior mean instead
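+        # .repeat() prepends a batch dimension, yielding two identical parameter
+        # sets of shape (2, *prior.mean.shape) for the checks below.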
+        theta = prior.mean.repeat(2, *[1] * prior.mean.dim())

     theta_unconstrained = transform.inv(theta)
     assert (
@@ -881,9 +863,15 @@ class ImproperEmpirical(Empirical):
     def __init__(self, values: Tensor, log_weights: Optional[Tensor] = None):
         super().__init__(values, log_weights=log_weights)
         # Warn if extremely large to inform about memory/serialization cost.
-        support_size = values.shape[0]
-        if support_size > 10_000_000:  # 10M still works well on modern hardware.
-            warn_empirical_prior_memory_risk(f">10M support size (size={support_size})")
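+        # Precompute first and second moments so that .mean, .variance and .stddev
+        # are available without sampling (e.g. when building parameter transforms).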
+        self._mean = self._compute_mean(values, log_weights)
+        self._variance = self._compute_variance(values, log_weights)
+
+    def sample(self, sample_shape=torch.Size()):
+        raise NotImplementedError(
+            "Sampling from ImproperEmpirical is not supported. If you are using "
+            "likelihood or ratio estimation, or multi-round inference, you need to "
+            "define a prior distribution."
+        )

     def log_prob(self, value: Tensor) -> Tensor:
         """
@@ -895,8 +883,79 @@ def log_prob(self, value: Tensor) -> Tensor:
         Returns:
             Tensor of as many ones as there were parameter sets.
         """
-        value = atleast_2d(value)
-        return zeros(value.shape[0])
+        raise NotImplementedError(
+            "Evaluating log_prob from ImproperEmpirical is not supported. If you are "
+            "using likelihood or ratio estimation, or multi-round inference, you need "
+            "to define a prior distribution."
+        )
+
+    def _compute_mean(self, values: Tensor, weights: Optional[Tensor] = None) -> Tensor:
+        """
+        Return the mean of the empirical distribution.
+
+        Args:
+            values: The empirical samples.
+            weights: Optional weights for the samples.
+
+        Returns:
+            The mean of the empirical distribution.
+        """
+        if weights is None:
+            return torch.mean(values, dim=0)
+        else:
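+            # `weights` holds the log-weights passed to __init__; softmax converts
+            # them into normalized probabilities that sum to one.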
+            normalized_weights = torch.nn.functional.softmax(weights, dim=0)
+            return torch.sum(normalized_weights.unsqueeze(-1) * values, dim=0)
+
+    def _compute_variance(
+        self, values: Tensor, weights: Optional[Tensor] = None
+    ) -> Tensor:
+        """
+        Return the variance of the empirical distribution.
+
+        Args:
+            values: The empirical samples.
+            weights: Optional weights for the samples.
+
+        Returns:
+            The variance of the empirical distribution.
+        """
+        if weights is None:
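+            # torch.var applies Bessel's (N - 1) correction by default.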
+            variance = torch.var(values, dim=0)
+        else:
+            normalized_weights = torch.nn.functional.softmax(weights, dim=0)
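+            # Relies on self._mean having already been set in __init__ before this
+            # method is called.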
+            variance = torch.sum(
+                normalized_weights.unsqueeze(-1) * (values - self._mean) ** 2,
+                dim=0,
+            )
+            # bias correction
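+            # For normalized weights, dividing by (1 - sum(w_i**2)) gives the unbiased
+            # weighted-variance estimator; for uniform weights this reduces to the
+            # usual N / (N - 1) Bessel correction.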
+            variance = variance / (1 - torch.sum(normalized_weights**2))
+        return variance
+
+    @property
+    def mean(self) -> Tensor:
+        return self._mean
+
+    @property
+    def variance(self) -> Tensor:
+        return self._variance
+
+    @property
+    def stddev(self) -> Tensor:
+        return torch.sqrt(self._variance)
+
+    def to(self, device: Union[str, torch.device]) -> None:
+        """
+        Move the distribution to a different device.
+
+        Args:
+            device: The device to move the distribution to.
+
+        Returns:
+            None; the distribution's buffers are moved in place.
+        """
+        self._mean = self._mean.to(device)
+        self._variance = self._variance.to(device)
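+        # Assumes the parent Empirical class also provides a .to() that moves its
+        # underlying sample buffer.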
+        super().to(device)


 def mog_log_prob(