66import pytest
77import torch
88from torch import eye , ones , zeros
9- from torch .distributions import MultivariateNormal
9+ from torch .distributions import Independent , MultivariateNormal , Uniform
1010
1111from sbi .inference import (
1212 NLE_A ,
@@ -98,13 +98,20 @@ def test_importance_posterior_sample_log_prob(snplre_method: type):
9898
9999@pytest .mark .parametrize ("snpe_method" , [NPE_A , NPE_C ])
100100@pytest .mark .parametrize ("x_o_batch_dim" , (0 , 1 , 2 ))
101+ @pytest .mark .parametrize ("prior" , ("mvn" , "uniform" ))
101102def test_batched_sample_log_prob_with_different_x (
102- snpe_method : type , x_o_batch_dim : bool
103+ snpe_method : type ,
104+ x_o_batch_dim : bool ,
105+ prior : str ,
103106):
104107 num_dim = 2
105108 num_simulations = 1000
106109
107- prior = MultivariateNormal (loc = zeros (num_dim ), covariance_matrix = eye (num_dim ))
110+ # We also want to test on bounded support! Which will invoke leakage correction.
111+ if prior == "mvn" :
112+ prior = MultivariateNormal (loc = zeros (num_dim ), covariance_matrix = eye (num_dim ))
113+ elif prior == "uniform" :
114+ prior = Independent (Uniform (- 1.0 * ones (num_dim ), 1.0 * ones (num_dim )), 1 )
108115 simulator = diagonal_linear_gaussian
109116
110117 inference = snpe_method (prior = prior )
@@ -116,6 +123,7 @@ def test_batched_sample_log_prob_with_different_x(
116123
117124 posterior = DirectPosterior (posterior_estimator = posterior_estimator , prior = prior )
118125
126+ torch .manual_seed (0 )
119127 samples = posterior .sample_batched ((10 ,), x_o )
120128 batched_log_probs = posterior .log_prob_batched (samples , x_o )
121129
@@ -126,6 +134,20 @@ def test_batched_sample_log_prob_with_different_x(
126134 ), "Sample shape wrong"
127135 assert batched_log_probs .shape == (10 , max (x_o_batch_dim , 1 )), "logprob shape wrong"
128136
137+ # Test consistency with non-batched log_prob
138+ # NOTE: Leakage factor is a MC estimate, so we need to relax the tolerance here.
139+ if x_o_batch_dim == 0 :
140+ log_probs = posterior .log_prob (samples , x = x_o )
141+ assert torch .allclose (
142+ log_probs , batched_log_probs [:, 0 ], atol = 1e-1 , rtol = 1e-1
143+ ), "Batched log probs different from non-batched log probs"
144+ else :
145+ for idx in range (x_o_batch_dim ):
146+ log_probs = posterior .log_prob (samples [:, idx ], x = x_o [idx ])
147+ assert torch .allclose (
148+ log_probs , batched_log_probs [:, idx ], atol = 1e-1 , rtol = 1e-1
149+ ), "Batched log probs different from non-batched log probs"
150+
129151
130152@pytest .mark .mcmc
131153@pytest .mark .parametrize ("snlre_method" , [NLE_A , NRE_A , NRE_B , NRE_C , NPE_C ])
0 commit comments