Skip to content

Commit 6a3a80c

Browse files
committed
refactor score utils, small fixes.
1 parent a2809c5 commit 6a3a80c

File tree

5 files changed

+125
-73
lines changed

5 files changed

+125
-73
lines changed

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 54 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# This file is part of sbi, a toolkit for simulation-based inference. sbi is licensed
22
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
33

4+
import math
45
import warnings
56
from typing import Dict, Literal, Optional, Union
67

@@ -150,7 +151,9 @@ def sample(
150151
corrector_params: Optional[Dict] = None,
151152
steps: int = 500,
152153
ts: Optional[Tensor] = None,
153-
iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
154+
iid_method: Optional[
155+
Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"]
156+
] = None,
154157
iid_params: Optional[Dict] = None,
155158
max_sampling_batch_size: int = 10_000,
156159
sample_with: Optional[str] = None,
@@ -201,19 +204,22 @@ def sample(
201204
x = reshape_to_batch_event(x, self.vector_field_estimator.condition_shape)
202205
is_iid = x.shape[0] > 1
203206
self.potential_fn.set_x(
204-
x, x_is_iid=is_iid, iid_method=iid_method, iid_params=iid_params
207+
x,
208+
x_is_iid=is_iid,
209+
iid_method=iid_method or self.potential_fn.iid_method,
210+
iid_params=iid_params,
205211
)
206212

207213
num_samples = torch.Size(sample_shape).numel()
208214

209215
if sample_with == "ode":
210-
samples = rejection.accept_reject_sample(
216+
samples, _ = rejection.accept_reject_sample(
211217
proposal=self.sample_via_ode,
212218
accept_reject_fn=lambda theta: within_support(self.prior, theta),
213219
num_samples=num_samples,
214220
show_progress_bars=show_progress_bars,
215221
max_sampling_batch_size=max_sampling_batch_size,
216-
)[0]
222+
)
217223
elif sample_with == "sde":
218224
proposal_sampling_kwargs = {
219225
"predictor": predictor,
@@ -225,14 +231,14 @@ def sample(
225231
"max_sampling_batch_size": max_sampling_batch_size,
226232
"show_progress_bars": show_progress_bars,
227233
}
228-
samples = rejection.accept_reject_sample(
234+
samples, _ = rejection.accept_reject_sample(
229235
proposal=self._sample_via_diffusion,
230236
accept_reject_fn=lambda theta: within_support(self.prior, theta),
231237
num_samples=num_samples,
232238
show_progress_bars=show_progress_bars,
233239
max_sampling_batch_size=max_sampling_batch_size,
234240
proposal_sampling_kwargs=proposal_sampling_kwargs,
235-
)[0]
241+
)
236242
else:
237243
raise ValueError(
238244
f"Expected sample_with to be 'ode' or 'sde', but got {sample_with}."
@@ -282,13 +288,16 @@ def _sample_via_diffusion(
282288
"The vector field estimator does not support the 'sde' sampling method."
283289
)
284290

285-
num_samples = torch.Size(sample_shape).numel()
291+
total_samples_needed = torch.Size(sample_shape).numel()
286292

287-
max_sampling_batch_size = (
293+
# Determine effective batch size for sampling
294+
effective_batch_size = (
288295
self.max_sampling_batch_size
289296
if max_sampling_batch_size is None
290297
else max_sampling_batch_size
291298
)
299+
# Ensure we don't use larger batches than total samples needed
300+
effective_batch_size = min(effective_batch_size, total_samples_needed)
292301

293302
# TODO: the time schedule should be provided by the estimator, see issue #1437
294303
if ts is None:
@@ -297,28 +306,47 @@ def _sample_via_diffusion(
297306
ts = torch.linspace(t_max, t_min, steps)
298307
ts = ts.to(self.device)
299308

309+
# Initialize the diffusion sampler
300310
diffuser = Diffuser(
301311
self.potential_fn,
302312
predictor=predictor,
303313
corrector=corrector,
304314
predictor_params=predictor_params,
305315
corrector_params=corrector_params,
306316
)
307-
max_sampling_batch_size = min(max_sampling_batch_size, num_samples)
308-
samples = []
309-
num_iter = num_samples // max_sampling_batch_size
310-
num_iter = (
311-
num_iter + 1 if (num_samples % max_sampling_batch_size) != 0 else num_iter
312-
)
313-
for _ in range(num_iter):
314-
samples.append(
315-
diffuser.run(
316-
num_samples=max_sampling_batch_size,
317-
ts=ts,
318-
show_progress_bars=show_progress_bars,
319-
)
317+
318+
# Calculate how many batches we need
319+
num_batches = math.ceil(total_samples_needed / effective_batch_size)
320+
321+
# Generate samples in batches
322+
all_samples = []
323+
samples_generated = 0
324+
325+
for _ in range(num_batches):
326+
# Calculate how many samples to generate in this batch
327+
remaining_samples = total_samples_needed - samples_generated
328+
current_batch_size = min(effective_batch_size, remaining_samples)
329+
330+
# Generate samples for this batch
331+
batch_samples = diffuser.run(
332+
num_samples=current_batch_size,
333+
ts=ts,
334+
show_progress_bars=show_progress_bars,
335+
)
336+
337+
all_samples.append(batch_samples)
338+
samples_generated += current_batch_size
339+
340+
# Concatenate all batches and ensure we return exactly the requested number
341+
samples = torch.cat(all_samples, dim=0)[:total_samples_needed]
342+
343+
# Check for NaN values
344+
if torch.isnan(samples).any():
345+
raise RuntimeError(
346+
"NaN values detected during diffusion sampling. This may indicate"
347+
" numerical instability in the vector field or improper time "
348+
"scheduling."
320349
)
321-
samples = torch.cat(samples, dim=0)[:num_samples]
322350

323351
return samples
324352

@@ -443,14 +471,14 @@ def sample_batched(
443471
max_sampling_batch_size = capped
444472

445473
if self.sample_with == "ode":
446-
samples = rejection.accept_reject_sample(
474+
samples, _ = rejection.accept_reject_sample(
447475
proposal=self.sample_via_ode,
448476
accept_reject_fn=lambda theta: within_support(self.prior, theta),
449477
num_samples=num_samples,
450478
num_xos=batch_size,
451479
show_progress_bars=show_progress_bars,
452480
max_sampling_batch_size=max_sampling_batch_size,
453-
)[0]
481+
)
454482
samples = samples.reshape(
455483
sample_shape + batch_shape + self.vector_field_estimator.input_shape
456484
)
@@ -465,15 +493,15 @@ def sample_batched(
465493
"max_sampling_batch_size": max_sampling_batch_size,
466494
"show_progress_bars": show_progress_bars,
467495
}
468-
samples = rejection.accept_reject_sample(
496+
samples, _ = rejection.accept_reject_sample(
469497
proposal=self._sample_via_diffusion,
470498
accept_reject_fn=lambda theta: within_support(self.prior, theta),
471499
num_samples=num_samples,
472500
num_xos=batch_size,
473501
show_progress_bars=show_progress_bars,
474502
max_sampling_batch_size=max_sampling_batch_size,
475503
proposal_sampling_kwargs=proposal_sampling_kwargs,
476-
)[0]
504+
)
477505
samples = samples.reshape(
478506
sample_shape + batch_shape + self.vector_field_estimator.input_shape
479507
)

sbi/inference/potentials/score_fn_iid.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def estimate_posterior_precision(
658658
precision_est_budget = min(int(prior.event_shape[0] * 1000), 5000)
659659

660660
thetas = posterior.sample_batched(
661-
torch.Size([precision_est_budget]),
661+
sample_shape=torch.Size([precision_est_budget]),
662662
x=conditions,
663663
show_progress_bars=False,
664664
steps=precision_initial_sampler_steps,
@@ -737,7 +737,7 @@ def ensure_lam_positive_definite(
737737
denoising_posterior_precision: torch.Tensor,
738738
N: int,
739739
precision_nugget: float = 0.1,
740-
) -> (torch.Tensor, torch.Tensor):
740+
) -> tuple[torch.Tensor, torch.Tensor]:
741741
r"""
742742
Ensure that the matrix is positive definite.
743743

sbi/inference/potentials/vector_field_potential.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,42 +21,6 @@
2121
from sbi.utils.torchutils import ensure_theta_batched
2222

2323

24-
def vector_field_estimator_based_potential(
25-
vector_field_estimator: ConditionalVectorFieldEstimator,
26-
prior: Optional[Distribution],
27-
x_o: Optional[Tensor],
28-
enable_transform: bool = True,
29-
**kwargs,
30-
) -> Tuple["VectorFieldBasedPotential", TorchTransform]:
31-
r"""Returns the potential function gradient for vector field estimators.
32-
33-
Args:
34-
vector_field_estimator: The neural network modelling the vector field.
35-
prior: The prior distribution.
36-
x_o: The observed data at which to evaluate the vector field.
37-
enable_transform: Whether to enable transforms. Not supported yet.
38-
**kwargs: Additional keyword arguments passed to
39-
`VectorFieldBasedPotential`.
40-
Returns:
41-
The potential function and a transformation that maps
42-
to unconstrained space.
43-
"""
44-
device = str(next(vector_field_estimator.parameters()).device)
45-
46-
potential_fn = VectorFieldBasedPotential(
47-
vector_field_estimator, prior, x_o, device=device, **kwargs
48-
)
49-
50-
if prior is not None:
51-
theta_transform = mcmc_transform(
52-
prior, device=device, enable_transform=enable_transform
53-
)
54-
else:
55-
theta_transform = torch.distributions.transforms.identity_transform
56-
57-
return potential_fn, theta_transform
58-
59-
6024
class VectorFieldBasedPotential(BasePotential):
6125
def __init__(
6226
self,
@@ -130,7 +94,7 @@ def set_x(
13094
self,
13195
x_o: Optional[Tensor],
13296
x_is_iid: Optional[bool] = False,
133-
iid_method: Literal["fnpe", "gauss", "auto_gauss", "jac_gauss"] = "auto_gauss",
97+
iid_method: Optional[str] = None,
13498
iid_params: Optional[Dict[str, Any]] = None,
13599
**ode_kwargs,
136100
):
@@ -149,7 +113,7 @@ def set_x(
149113
ode_kwargs: Additional keyword arguments for the neural ODE.
150114
"""
151115
super().set_x(x_o, x_is_iid)
152-
self.iid_method = iid_method
116+
self.iid_method = iid_method or self.iid_method
153117
self.iid_params = iid_params
154118
# NOTE: Once IID potential evaluation is supported. This needs to be adapted.
155119
# See #1450.
@@ -281,6 +245,42 @@ def rebuild_flow(self, **kwargs) -> NormalizingFlow:
281245
return flow
282246

283247

248+
def vector_field_estimator_based_potential(
249+
vector_field_estimator: ConditionalVectorFieldEstimator,
250+
prior: Optional[Distribution],
251+
x_o: Optional[Tensor],
252+
enable_transform: bool = True,
253+
**kwargs,
254+
) -> Tuple[VectorFieldBasedPotential, TorchTransform]:
255+
r"""Returns the potential function gradient for vector field estimators.
256+
257+
Args:
258+
vector_field_estimator: The neural network modelling the vector field.
259+
prior: The prior distribution.
260+
x_o: The observed data at which to evaluate the vector field.
261+
enable_transform: Whether to enable transforms. Not supported yet.
262+
**kwargs: Additional keyword arguments passed to
263+
`VectorFieldBasedPotential`.
264+
Returns:
265+
The potential function and a transformation that maps
266+
to unconstrained space.
267+
"""
268+
device = str(next(vector_field_estimator.parameters()).device)
269+
270+
potential_fn = VectorFieldBasedPotential(
271+
vector_field_estimator, prior, x_o, device=device, **kwargs
272+
)
273+
274+
if prior is not None:
275+
theta_transform = mcmc_transform(
276+
prior, device=device, enable_transform=enable_transform
277+
)
278+
else:
279+
theta_transform = torch.distributions.transforms.identity_transform
280+
281+
return potential_fn, theta_transform
282+
283+
284284
class DifferentiablePotentialFunction(torch.autograd.Function):
285285
"""
286286
A wrapper of `VectorFieldBasedPotential` with a custom autograd function

sbi/samplers/rejection/rejection.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,9 @@ def accept_reject_sample(
269269
pbar = tqdm(
270270
disable=not show_progress_bars,
271271
total=num_samples,
272-
desc=f"Drawing {num_samples} posterior samples for {num_xos} observations",
272+
desc=f"Drawing {num_samples} samples for {num_xos} observation" + "s"
273+
if num_xos > 1
274+
else "",
273275
)
274276

275277
accepted = [[] for _ in range(num_xos)]
@@ -280,6 +282,7 @@ def accept_reject_sample(
280282
sampling_batch_size = min(num_samples, max_sampling_batch_size)
281283
num_sampled_total = torch.zeros(num_xos)
282284
num_samples_possible = 0
285+
283286
while num_remaining > 0:
284287
# Sample and reject.
285288
candidates = proposal(
@@ -288,6 +291,7 @@ def accept_reject_sample(
288291
)
289292
# SNPE-style rejection-sampling when the proposal is the neural net.
290293
are_accepted = accept_reject_fn(candidates)
294+
291295
# Reshape necessary in certain cases which do not follow the shape conventions
292296
# of the "DensityEstimator" class.
293297
are_accepted = are_accepted.reshape(sampling_batch_size, num_xos)

sbi/samplers/score/diffuser.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -143,22 +143,42 @@ def run(
143143
Returns:
144144
Tensor: Samples from the distribution(s).
145145
"""
146+
# Initialize samples from the base distribution
146147
samples = self.initialize(num_samples).to(ts.device)
148+
149+
# Set up progress bar for time-stepping through the diffusion process
150+
total_time_steps = ts.numel() - 1 # We skip the first time point
147151
pbar = tqdm(
148152
range(1, ts.numel()),
149153
disable=not show_progress_bars,
150-
desc=f"Drawing {num_samples} posterior samples",
154+
desc=f"Generating {num_samples} posterior samples in {total_time_steps} "
155+
"diffusion steps.",
151156
)
152157

153158
if save_intermediate:
154159
intermediate_samples = [samples]
155160

156-
for i in pbar:
157-
t1 = ts[i - 1]
158-
t0 = ts[i]
159-
samples = self.predictor(samples, t1, t0)
161+
# Step through the diffusion process from t_max to t_min
162+
for time_step_idx in pbar:
163+
# Get current and next time points (going backwards in time)
164+
t_current = ts[time_step_idx - 1] # Previous time point
165+
t_next = ts[time_step_idx] # Current time point
166+
167+
# Apply predictor step
168+
samples = self.predictor(samples, t_current, t_next)
169+
170+
# Check for NaN values after predictor
171+
if torch.isnan(samples).any():
172+
raise RuntimeError(
173+
f"NaN values detected after predictor step "
174+
f"{time_step_idx}/{total_time_steps}. "
175+
f"This may indicate numerical instability in the vector field."
176+
)
177+
178+
# Apply corrector step if available
160179
if self.corrector is not None:
161-
samples = self.corrector(samples, t0, t1)
180+
samples = self.corrector(samples, t_next, t_current)
181+
162182
if save_intermediate:
163183
intermediate_samples.append(samples)
164184

0 commit comments

Comments (0)