Commit ccafca3

Kartik-Sama, manuelgloeckler, and Copilot authored
Naive approach for IID potential evaluation for score estimators (#1508)
* Changes to compute log_prob of NPSE under iid observations
* Improved variable naming, and fixed duplicate log_probs being returned
* Cleaned unnecessary comment
* Added docstrings for score-based log_prob test function
* Init path for simulators modified to make use of true_posterior_linear_gaussian_mvn_prior in tests
* Removed redundant print statement
* Moved test of log_prob calculation for NPSE to an existing test file
* Fixing test to recent variant
* Test running but failing for some other reason
* Test is basically passing, some torch autograd is just failing in some cases
* Fixing bug on merge
* Fix linting
* Update tests/linearGaussian_vector_field_test.py: multiline string (Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>)
* Multiline string
* Tolerance
* Slight improvements

---------

Co-authored-by: manuelgloeckler <manu.gloeckler@hotmail.de>
Co-authored-by: manuelgloeckler <38903899+manuelgloeckler@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 54a7b74 commit ccafca3
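
The "naive approach" of the title corresponds to the usual Bayes-rule identity for conditionally iid observations: each single-observation posterior carries one factor of the prior, so the surplus copies are divided out again. Up to an additive constant that does not depend on theta (the gap between the summed per-observation evidences and the joint evidence),

\[
\log p(\theta \mid x_{1:n}) \;=\; \sum_{i=1}^{n} \log p(\theta \mid x_i) \;-\; (n-1)\,\log p(\theta) \;+\; \text{const}.
\]

This is exactly the `iid_posteriors_prob - (num_iid - 1) * self.prior.log_prob(...)` expression added to the potential below; each log p(theta | x_i) comes from its own probability-flow ODE, so evaluation cost grows linearly with the number of iid observations.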

File tree

3 files changed: +103 −29 lines changed


sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 4 additions & 1 deletion
@@ -404,7 +404,10 @@ def log_prob(
         `(len(θ),)`-shaped log posterior probability $\log p(\theta|x)$ for θ in the
         support of the prior, -∞ (corresponding to 0 probability) outside.
         """
-        self.potential_fn.set_x(self._x_else_default_x(x), **(ode_kwargs or {}))
+        x = self._x_else_default_x(x)
+        x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
+        is_iid = x.shape[0] > 1
+        self.potential_fn.set_x(x, x_is_iid=is_iid, **(ode_kwargs or {}))

         theta = ensure_theta_batched(torch.as_tensor(theta))
         return self.potential_fn(
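
With this change, `log_prob` flags the observation as iid whenever the conditioning `x` has a leading batch dimension larger than one, instead of raising `NotImplementedError`. A minimal usage sketch follows (not part of the diff; the toy prior, simulator, and training budget are illustrative assumptions, and it assumes a recent sbi version that exposes `NPSE` — only the standard workflow and the `log_prob(theta, x=...)` call mirror sbi's API):

import torch
from sbi.inference import NPSE

# Toy 2-D task: x = theta + noise (illustrative stand-in for a simulator).
prior = torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
theta = prior.sample((2000,))
x = theta + 0.1 * torch.randn_like(theta)

inference = NPSE(prior=prior)
inference.append_simulations(theta, x).train()
posterior = inference.build_posterior()

# Five iid observations of the same underlying parameter.
x_iid = 0.1 * torch.randn(5, 2)
test_theta = prior.sample((10,))

# x_iid.shape[0] > 1, so the potential evaluates the iid combination:
# sum of per-observation flow log-probs minus (num_iid - 1) * prior.log_prob.
log_probs = posterior.log_prob(test_theta, x=x_iid)
print(log_probs.shape)  # torch.Size([10])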

sbi/inference/potentials/vector_field_potential.py

Lines changed: 53 additions & 28 deletions
@@ -1,7 +1,7 @@
 # This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
 # under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

-from typing import Any, Dict, Literal, Optional, Tuple, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -115,10 +115,10 @@ def set_x(
         super().set_x(x_o, x_is_iid)
         self.iid_method = iid_method or self.iid_method
         self.iid_params = iid_params
-        # NOTE: Once IID potential evaluation is supported. This needs to be adapted.
-        # See #1450.
         if not x_is_iid and (self._x_o is not None):
             self.flow = self.rebuild_flow(**ode_kwargs)
+        elif self._x_o is not None:
+            self.flows = self.rebuild_flows_for_batch(**ode_kwargs)

     def __call__(
         self,
@@ -135,26 +135,6 @@ def __call__(
         Returns:
             The potential function, i.e., the log probability of the posterior.
         """
-        # TODO: incorporate iid setting. See issue #1450 and PR #1508
-        if self.x_is_iid:
-            if (
-                self.vector_field_estimator.MARGINALS_DEFINED
-                and self.vector_field_estimator.SCORE_DEFINED
-            ):
-                raise NotImplementedError(
-                    "Potential function evaluation in the "
-                    "IID setting is not yet supported"
-                    " for vector field based methods. "
-                    "Sampling does however work via `.sample`. "
-                    "If you intended to evaluate the posterior "
-                    "given a batch of (non-iid) "
-                    "x use `log_prob_batched`."
-                )
-            else:
-                raise NotImplementedError(
-                    "IID is not supported for this vector field estimator "
-                    "since the required methods (marginals or score) are not defined."
-                )

         theta = ensure_theta_batched(torch.as_tensor(theta))
         theta_density_estimator = reshape_to_sample_batch_event(
@@ -163,7 +143,31 @@ def __call__(
         self.vector_field_estimator.eval()

         with torch.set_grad_enabled(track_gradients):
-            log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)
+            if self.x_is_iid:
+                assert self.prior is not None, (
+                    "Prior is required for evaluating log_prob with iid observations."
+                )
+                assert self.flows is not None, (
+                    "Flows for each iid x are required for evaluating log_prob."
+                )
+                num_iid = self.x_o.shape[0]  # number of iid samples
+                iid_posteriors_prob = torch.sum(
+                    torch.stack(
+                        [
+                            flow.log_prob(theta_density_estimator).squeeze(-1)
+                            for flow in self.flows
+                        ],
+                        dim=0,
+                    ),
+                    dim=0,
+                )
+                # Apply the adjustment for iid observations i.e. we have to subtract
+                # (num_iid-1) times the log prior.
+                log_probs = iid_posteriors_prob - (num_iid - 1) * self.prior.log_prob(
+                    theta_density_estimator
+                ).squeeze(-1)
+            else:
+                log_probs = self.flow.log_prob(theta_density_estimator).squeeze(-1)
             # Force probability to be zero outside prior support.
             in_prior_support = within_support(self.prior, theta)

@@ -208,8 +212,8 @@ def gradient(

         if self._x_o is None:
             raise ValueError(
-                "No observed data x_o is available. Please reinitialize \
-                    the potential or manually set self._x_o."
+                "No observed data x_o is available. Please reinitialize"
+                "the potential or manually set self._x_o."
             )

         with torch.set_grad_enabled(track_gradients):
@@ -239,8 +243,8 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
         """
         if self._x_o is None:
             raise ValueError(
-                "No observed data x_o is available. Please reinitialize \
-                    the potential or manually set self._x_o."
+                "No observed data x_o is available. Please reinitialize"
+                "the potential or manually set self._x_o."
             )
         x_density_estimator = reshape_to_batch_event(
             self.x_o, event_shape=self.vector_field_estimator.condition_shape
@@ -249,6 +253,27 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
         flow = self.neural_ode(x_density_estimator, **kwargs)
         return flow

+    def rebuild_flows_for_batch(self, **kwargs) -> List[NormalizingFlow]:
+        """
+        Rebuilds the continuous normalizing flows for each iid in x_o. This is used when
+        a new default x_o is set, or to evaluate the log probs at higher precision.
+        """
+        if self._x_o is None:
+            raise ValueError(
+                "No observed data x_o is available. Please reinitialize "
+                "the potential or manually set self._x_o."
+            )
+        flows = []
+        for i in range(self._x_o.shape[0]):
+            iid_x = self._x_o[i]
+            x_density_estimator = reshape_to_batch_event(
+                iid_x, event_shape=self.vector_field_estimator.condition_shape
+            )
+
+            flow = self.neural_ode(x_density_estimator, **kwargs)
+            flows.append(flow)
+        return flows
+

 def vector_field_estimator_based_potential(
     vector_field_estimator: ConditionalVectorFieldEstimator,
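
The correction applied in `__call__` above can be checked in closed form on a conjugate model. The sketch below is standalone and not part of the commit; the 1-D Gaussian prior and likelihood are assumptions chosen so that every posterior is available analytically. It shows that summing single-observation posterior log-densities and subtracting (n - 1) times the log prior reproduces the joint posterior up to a constant that is the same for every theta; in the potential above, the single-observation posteriors come from the per-observation flows in `self.flows` rather than closed-form Gaussians.

import torch
from torch.distributions import Normal

prior = Normal(0.0, 1.0)             # theta ~ N(0, 1)
lik_std = 0.5                        # x_i | theta ~ N(theta, 0.5**2)
x = torch.tensor([0.3, -0.1, 0.4])   # n = 3 iid observations
n = x.shape[0]

def single_obs_posterior(x_i):
    # Conjugate posterior for a single observation.
    var = 1.0 / (1.0 / prior.scale**2 + 1.0 / lik_std**2)
    mean = var * x_i / lik_std**2
    return Normal(mean, var**0.5)

# Exact joint posterior given all n observations.
joint_var = 1.0 / (1.0 / prior.scale**2 + n / lik_std**2)
joint_mean = joint_var * x.sum() / lik_std**2
joint_posterior = Normal(joint_mean, joint_var**0.5)

theta = torch.linspace(-2.0, 2.0, 5)
naive = sum(single_obs_posterior(x_i).log_prob(theta) for x_i in x)
naive = naive - (n - 1) * prior.log_prob(theta)
exact = joint_posterior.log_prob(theta)

# The gap is constant in theta: it is the (theta-independent) difference
# between the summed per-observation evidences and the joint evidence.
print(naive - exact)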

tests/linearGaussian_vector_field_test.py

Lines changed: 46 additions & 0 deletions
@@ -588,3 +588,49 @@ def simulator(theta):

     max_err = np.max(error)
     assert max_err < 0.0027
+
+
+@pytest.mark.slow
+@pytest.mark.parametrize("vector_field_type", ["ve", "vp", "fmpe"])
+@pytest.mark.parametrize("prior_type", ["gaussian"])
+@pytest.mark.parametrize("iid_batch_size", [1, 2, 5])
+def test_iid_log_prob(vector_field_type, prior_type, iid_batch_size):
+    '''
+    Tests the log-probability computation of the score-based posterior.
+
+    '''
+
+    vector_field_trained_model = train_vector_field_model(vector_field_type, prior_type)
+
+    # Prior Gaussian
+    prior = vector_field_trained_model["prior"]
+    vf_estimator = vector_field_trained_model["estimator"]
+    inference = vector_field_trained_model["inference"]
+    likelihood_shift = vector_field_trained_model["likelihood_shift"]
+    likelihood_cov = vector_field_trained_model["likelihood_cov"]
+    prior_mean = vector_field_trained_model["prior_mean"]
+    prior_cov = vector_field_trained_model["prior_cov"]
+    num_dim = vector_field_trained_model["num_dim"]
+    num_posterior_samples = 1000
+
+    # Ground truth theta
+    theta_o = zeros(num_dim)
+    x_o = linear_gaussian(
+        theta_o.repeat(iid_batch_size, 1),
+        likelihood_shift=likelihood_shift,
+        likelihood_cov=likelihood_cov,
+    )
+    true_posterior = true_posterior_linear_gaussian_mvn_prior(
+        x_o, likelihood_shift, likelihood_cov, prior_mean, prior_cov
+    )
+
+    approx_posterior = inference.build_posterior(vf_estimator, prior=prior)
+    posterior_samples = true_posterior.sample((num_posterior_samples,))
+    true_prob = true_posterior.log_prob(posterior_samples)
+    approx_prob = approx_posterior.log_prob(posterior_samples, x=x_o)
+
+    diff = torch.abs(true_prob - approx_prob)
+    assert diff.mean() < 0.3 * iid_batch_size, (
+        f"Probs diff: {diff.mean()} too big "
+        f"for number of samples {num_posterior_samples}"
+    )
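
The tolerance is allowed to grow with `iid_batch_size`: every additional observation contributes its own flow-approximation error, and the naive combination is in any case only defined up to an additive, theta-independent evidence offset. To run the new test locally, something along these lines should work (an assumption about the repository's standard pytest setup; the test is marked `slow`, so the marker may need to be selected explicitly):

pytest tests/linearGaussian_vector_field_test.py -k test_iid_log_prob -m slow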
