Skip to content

Commit 65678a7

Browse files
committed
feat: add return_partial_on_timeout option to rejection samplers
1 parent fa36b0f commit 65678a7

File tree

5 files changed

+96
-7
lines changed

5 files changed

+96
-7
lines changed

sbi/inference/posteriors/direct_posterior.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@ def sample(
141141
show_progress_bars: bool = True,
142142
reject_outside_prior: bool = True,
143143
max_sampling_time: Optional[float] = None,
144+
return_partial_on_timeout: bool = False,
144145
) -> Tensor:
145146
r"""Return samples from posterior distribution $p(\theta|x)$.
146147
@@ -159,6 +160,9 @@ def sample(
159160
If exceeded, sampling is aborted and a RuntimeError is raised. Only
160161
applies when `reject_outside_prior=True` (no effect otherwise since
161162
direct sampling is fast).
163+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
164+
return the samples collected so far instead of raising a RuntimeError.
165+
A warning will be issued. Only applies when `reject_outside_prior=True`.
162166
"""
163167
num_samples = torch.Size(sample_shape).numel()
164168
x = self._x_else_default_x(x)
@@ -198,6 +202,7 @@ def sample(
198202
proposal_sampling_kwargs={"condition": x},
199203
alternative_method="build_posterior(..., sample_with='mcmc')",
200204
max_sampling_time=max_sampling_time,
205+
return_partial_on_timeout=return_partial_on_timeout,
201206
)[0]
202207
else:
203208
# Bypass rejection sampling entirely.
@@ -217,6 +222,7 @@ def sample_batched(
217222
show_progress_bars: bool = True,
218223
reject_outside_prior: bool = True,
219224
max_sampling_time: Optional[float] = None,
225+
return_partial_on_timeout: bool = False,
220226
) -> Tensor:
221227
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
222228
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -236,6 +242,9 @@ def sample_batched(
236242
max_sampling_time: Optional maximum allowed sampling time in seconds.
237243
If exceeded, sampling is aborted and a RuntimeError is raised. Only
238244
applies when `reject_outside_prior=True`.
245+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
246+
return the samples collected so far instead of raising a RuntimeError.
247+
A warning will be issued. Only applies when `reject_outside_prior=True`.
239248
240249
Returns:
241250
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -282,6 +291,7 @@ def sample_batched(
282291
proposal_sampling_kwargs={"condition": x},
283292
alternative_method="build_posterior(..., sample_with='mcmc')",
284293
max_sampling_time=max_sampling_time,
294+
return_partial_on_timeout=return_partial_on_timeout,
285295
)[0]
286296
else:
287297
# Bypass rejection sampling entirely.

sbi/inference/posteriors/rejection_posterior.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ def sample(
139139
show_progress_bars: bool = True,
140140
reject_outside_prior: bool = True,
141141
max_sampling_time: Optional[float] = None,
142+
return_partial_on_timeout: bool = False,
142143
):
143144
r"""Return samples from posterior $p(\theta|x)$ via rejection sampling.
144145
@@ -157,6 +158,9 @@ def sample(
157158
If exceeded, sampling is aborted and a RuntimeError is raised. Only
158159
applies when `reject_outside_prior=True` (no effect otherwise since
159160
direct sampling from the proposal is fast).
161+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
162+
return the samples collected so far instead of raising a RuntimeError.
163+
A warning will be issued. Only applies when `reject_outside_prior=True`.
160164
161165
Returns:
162166
Samples from posterior.
@@ -202,6 +206,7 @@ def sample(
202206
num_iter_to_find_max=num_iter_to_find_max,
203207
m=m,
204208
max_sampling_time=max_sampling_time,
209+
return_partial_on_timeout=return_partial_on_timeout,
205210
device=self._device,
206211
)
207212
else:

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def sample(
164164
show_progress_bars: bool = True,
165165
reject_outside_prior: bool = True,
166166
max_sampling_time: Optional[float] = None,
167+
return_partial_on_timeout: bool = False,
167168
) -> Tensor:
168169
r"""Return samples from posterior distribution $p(\theta|x)$.
169170
@@ -209,6 +210,9 @@ def sample(
209210
If exceeded, sampling is aborted and a RuntimeError is raised. Only
210211
applies when `reject_outside_prior=True` (no effect otherwise since
211212
direct sampling does not use rejection).
213+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
214+
return the samples collected so far instead of raising a RuntimeError.
215+
A warning will be issued. Only applies when `reject_outside_prior=True`.
212216
"""
213217

214218
if sample_with is None:
@@ -235,6 +239,7 @@ def sample(
235239
show_progress_bars=show_progress_bars,
236240
max_sampling_batch_size=max_sampling_batch_size,
237241
max_sampling_time=max_sampling_time,
242+
return_partial_on_timeout=return_partial_on_timeout,
238243
)
239244
else:
240245
# Bypass rejection sampling entirely.
@@ -259,6 +264,7 @@ def sample(
259264
max_sampling_batch_size=max_sampling_batch_size,
260265
proposal_sampling_kwargs=proposal_sampling_kwargs,
261266
max_sampling_time=max_sampling_time,
267+
return_partial_on_timeout=return_partial_on_timeout,
262268
)
263269
else:
264270
# Bypass rejection sampling entirely.
@@ -459,6 +465,7 @@ def sample_batched(
459465
show_progress_bars: bool = True,
460466
reject_outside_prior: bool = True,
461467
max_sampling_time: Optional[float] = None,
468+
return_partial_on_timeout: bool = False,
462469
) -> Tensor:
463470
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
464471
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -488,6 +495,9 @@ def sample_batched(
488495
max_sampling_time: Optional maximum allowed sampling time in seconds.
489496
If exceeded, sampling is aborted and a RuntimeError is raised. Only
490497
applies when `reject_outside_prior=True`.
498+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
499+
return the samples collected so far instead of raising a RuntimeError.
500+
A warning will be issued. Only applies when `reject_outside_prior=True`.
491501
492502
Returns:
493503
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -525,6 +535,7 @@ def sample_batched(
525535
show_progress_bars=show_progress_bars,
526536
max_sampling_batch_size=max_sampling_batch_size,
527537
max_sampling_time=max_sampling_time,
538+
return_partial_on_timeout=return_partial_on_timeout,
528539
)
529540
else:
530541
# Bypass rejection sampling.
@@ -553,6 +564,7 @@ def sample_batched(
553564
max_sampling_batch_size=max_sampling_batch_size,
554565
proposal_sampling_kwargs=proposal_sampling_kwargs,
555566
max_sampling_time=max_sampling_time,
567+
return_partial_on_timeout=return_partial_on_timeout,
556568
)
557569
else:
558570
# Bypass rejection sampling.

sbi/samplers/rejection/rejection.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def rejection_sample(
2626
num_iter_to_find_max: int = 100,
2727
m: float = 1.2,
2828
max_sampling_time: Optional[float] = None,
29+
return_partial_on_timeout: bool = False,
2930
device: str = "cpu",
3031
) -> Tuple[Tensor, Tensor]:
3132
r"""Return samples from a `potential_fn` obtained via rejection sampling.
@@ -57,11 +58,14 @@ def rejection_sample(
5758
value will ensure that the samples are indeed from the correct
5859
distribution, but will increase the fraction of rejected samples and thus
5960
computation time.
60-
device: Device on which to sample.
6161
max_sampling_time: Optional maximum allowed sampling time (in seconds).
6262
If this time is exceeded, rejection sampling is aborted and a RuntimeError
63-
is raised. This prevents jobs from stalling indefinitely when the
64-
acceptance rate is extremely low.
63+
is raised (unless `return_partial_on_timeout=True`). This prevents jobs
64+
from stalling indefinitely when the acceptance rate is extremely low.
65+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded, return
66+
the samples collected so far instead of raising a RuntimeError. A warning
67+
will be issued indicating the partial return. Default is False.
68+
device: Device on which to sample.
6569
6670
Returns:
6771
Accepted samples and acceptance rate as scalar Tensor.
@@ -143,6 +147,16 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:
143147
max_sampling_time is not None
144148
and (time.time() - start_time) > max_sampling_time
145149
):
150+
num_collected = sum(s.shape[0] for s in accepted)
151+
if return_partial_on_timeout and num_collected > 0:
152+
pbar.close()
153+
warnings.warn(
154+
f"Timeout exceeded after collecting {num_collected}/"
155+
f"{num_samples} samples. Returning partial results.",
156+
stacklevel=2,
157+
)
158+
samples = torch.cat(accepted)
159+
return samples, as_tensor(acceptance_rate)
146160
raise RuntimeError(
147161
"Sampling aborted early because rejection sampling exceeded "
148162
"max_sampling_time. This is likely due to extremely low "
@@ -225,6 +239,7 @@ def accept_reject_sample(
225239
proposal_sampling_kwargs: Optional[Dict] = None,
226240
alternative_method: Optional[str] = None,
227241
max_sampling_time: Optional[float] = None,
242+
return_partial_on_timeout: bool = False,
228243
**kwargs,
229244
) -> Tuple[Tensor, Tensor]:
230245
r"""Returns samples from a proposal according to a acception criterion.
@@ -264,12 +279,16 @@ def accept_reject_sample(
264279
alternative_method: An alternative method for sampling from the restricted
265280
proposal. E.g., for SNPE, we suggest to sample with MCMC if the rejection
266281
rate is too high. Used only for printing during a potential warning.
282+
max_sampling_time: Optional maximum allowed sampling time (in seconds).
283+
If exceeded, the sampling loop is interrupted and a RuntimeError is raised
284+
(unless `return_partial_on_timeout=True`). This prevents infinite or
285+
excessively slow rejection sampling runs, e.g. in cases of heavy leakage
286+
or extremely low acceptance rates.
287+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded, return
288+
the samples collected so far instead of raising a RuntimeError. A warning
289+
will be issued indicating the partial return. Default is False.
267290
kwargs: Absorb additional unused arguments that can be passed to
268291
`rejection_sample()`. Warn if not empty.
269-
max_sampling_time: Optional maximum allowed sampling time (in seconds).
270-
If exceeded, the sampling loop is interrupted and a RuntimeError is raised.
271-
This prevents infinite or excessively slow rejection sampling runs, e.g.
272-
in cases of heavy leakage or extremely low acceptance rates.
273292
274293
Returns:
275294
Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and
@@ -318,6 +337,24 @@ def accept_reject_sample(
318337
max_sampling_time is not None
319338
and (time.time() - start_time) > max_sampling_time
320339
):
340+
# Check if we have any samples collected
341+
num_collected = min(
342+
sum(s.shape[0] for s in accepted[i]) for i in range(num_xos)
343+
)
344+
if return_partial_on_timeout and num_collected > 0:
345+
pbar.close()
346+
warnings.warn(
347+
f"Timeout exceeded after collecting {num_collected}/{num_samples}"
348+
f" samples. Returning partial results.",
349+
stacklevel=2,
350+
)
351+
# Return partial samples with proper shape
352+
samples = [
353+
torch.cat(accepted[i], dim=0)[:num_collected]
354+
for i in range(num_xos)
355+
]
356+
samples = torch.stack(samples, dim=1)
357+
return samples, as_tensor(acceptance_rate, device=samples.device)
321358
raise RuntimeError(
322359
"Sampling aborted early because rejection sampling exceeded "
323360
"max_sampling_time. This is likely due to extremely low "

tests/rejection_sampling_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ def to(self, device):
123123
assert torch.all(no_reject == 5.0)
124124

125125

126+
def test_accept_reject_sample_partial_return():
127+
"""Test that return_partial_on_timeout returns collected samples."""
128+
129+
def accept_rare_fn(x):
130+
# Accept only 1% of samples to ensure we don't finish
131+
return torch.rand(x.shape[0]) < 0.01
132+
133+
proposal = DummyProposal()
134+
with warnings.catch_warnings(record=True) as w:
135+
warnings.simplefilter("always")
136+
samples, acceptance = accept_reject_sample(
137+
proposal=proposal,
138+
accept_reject_fn=accept_rare_fn,
139+
num_samples=10000, # Request many samples
140+
max_sampling_time=0.001, # Very short timeout
141+
return_partial_on_timeout=True,
142+
)
143+
# Should have some samples (not all 10000)
144+
assert samples.shape[0] > 0
145+
assert samples.shape[0] < 10000
146+
# Should have issued a warning
147+
assert len(w) == 1
148+
assert "partial results" in str(w[0].message).lower()
149+
150+
126151
def test_warn_if_outside_prior_support():
127152
"""Test the warning utility for samples outside prior support."""
128153
prior = Uniform(torch.zeros(2), torch.ones(2))

0 commit comments

Comments
 (0)