Skip to content

Commit fea8421

Browse files
authored
feat: add option to return partial results in rejection samplers (#1720)
* feat: add return_partial_on_timeout option to rejection samplers
* review comments
1 parent 299f3d0 commit fea8421

File tree

5 files changed

+99
-7
lines changed

5 files changed

+99
-7
lines changed

sbi/inference/posteriors/direct_posterior.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def sample(
140140
show_progress_bars: bool = True,
141141
reject_outside_prior: bool = True,
142142
max_sampling_time: Optional[float] = None,
143+
return_partial_on_timeout: bool = False,
143144
) -> Tensor:
144145
r"""Draw samples from the approximate posterior distribution $p(\theta|x)$.
145146
@@ -159,6 +160,10 @@ def sample(
159160
If exceeded, sampling is aborted and a RuntimeError is raised. Only
160161
applies when `reject_outside_prior=True` (no effect otherwise since
161162
direct sampling is fast).
163+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
164+
return the samples collected so far instead of raising a RuntimeError.
165+
A warning will be issued. Only applies when `reject_outside_prior=True`
166+
(default).
162167
"""
163168
num_samples = torch.Size(sample_shape).numel()
164169
x = self._x_else_default_x(x)
@@ -191,6 +196,7 @@ def sample(
191196
proposal_sampling_kwargs={"condition": x},
192197
alternative_method="build_posterior(..., sample_with='mcmc')",
193198
max_sampling_time=max_sampling_time,
199+
return_partial_on_timeout=return_partial_on_timeout,
194200
)[0]
195201
else:
196202
# Bypass rejection sampling entirely.
@@ -210,6 +216,7 @@ def sample_batched(
210216
show_progress_bars: bool = True,
211217
reject_outside_prior: bool = True,
212218
max_sampling_time: Optional[float] = None,
219+
return_partial_on_timeout: bool = False,
213220
) -> Tensor:
214221
r"""Draw samples from the posteriors for a batch of different xs.
215222
@@ -230,6 +237,9 @@ def sample_batched(
230237
max_sampling_time: Optional maximum allowed sampling time in seconds.
231238
If exceeded, sampling is aborted and a RuntimeError is raised. Only
232239
applies when `reject_outside_prior=True`.
240+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
241+
return the samples collected so far instead of raising a RuntimeError.
242+
A warning will be issued. Only applies when `reject_outside_prior=True`.
233243
234244
Returns:
235245
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -276,6 +286,7 @@ def sample_batched(
276286
proposal_sampling_kwargs={"condition": x},
277287
alternative_method="build_posterior(..., sample_with='mcmc')",
278288
max_sampling_time=max_sampling_time,
289+
return_partial_on_timeout=return_partial_on_timeout,
279290
)[0]
280291
else:
281292
# Bypass rejection sampling entirely.

sbi/inference/posteriors/rejection_posterior.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def sample(
138138
show_progress_bars: bool = True,
139139
reject_outside_prior: bool = True,
140140
max_sampling_time: Optional[float] = None,
141+
return_partial_on_timeout: bool = False,
141142
):
142143
r"""Draw samples from the approximate posterior via rejection sampling.
143144
@@ -164,6 +165,10 @@ def sample(
164165
If exceeded, sampling is aborted and a RuntimeError is raised. Only
165166
applies when `reject_outside_prior=True` (no effect otherwise since
166167
direct sampling from the proposal is fast).
168+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
169+
return the samples collected so far instead of raising a RuntimeError.
170+
A warning will be issued. Only applies when `reject_outside_prior=True`
171+
(default).
167172
168173
Returns:
169174
Samples from posterior.
@@ -203,6 +208,7 @@ def sample(
203208
num_iter_to_find_max=num_iter_to_find_max,
204209
m=m,
205210
max_sampling_time=max_sampling_time,
211+
return_partial_on_timeout=return_partial_on_timeout,
206212
device=self._device,
207213
)
208214
else:

sbi/inference/posteriors/vector_field_posterior.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ def sample(
164164
show_progress_bars: bool = True,
165165
reject_outside_prior: bool = True,
166166
max_sampling_time: Optional[float] = None,
167+
return_partial_on_timeout: bool = False,
167168
) -> Tensor:
168169
r"""Return samples from posterior distribution $p(\theta|x)$.
169170
@@ -209,6 +210,10 @@ def sample(
209210
If exceeded, sampling is aborted and a RuntimeError is raised. Only
210211
applies when `reject_outside_prior=True` (no effect otherwise since
211212
direct sampling does not use rejection).
213+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
214+
return the samples collected so far instead of raising a RuntimeError.
215+
A warning will be issued. Only applies when `reject_outside_prior=True`
216+
(default).
212217
"""
213218

214219
if sample_with is None:
@@ -235,6 +240,7 @@ def sample(
235240
show_progress_bars=show_progress_bars,
236241
max_sampling_batch_size=max_sampling_batch_size,
237242
max_sampling_time=max_sampling_time,
243+
return_partial_on_timeout=return_partial_on_timeout,
238244
)
239245
else:
240246
# Bypass rejection sampling entirely.
@@ -259,6 +265,7 @@ def sample(
259265
max_sampling_batch_size=max_sampling_batch_size,
260266
proposal_sampling_kwargs=proposal_sampling_kwargs,
261267
max_sampling_time=max_sampling_time,
268+
return_partial_on_timeout=return_partial_on_timeout,
262269
)
263270
else:
264271
# Bypass rejection sampling entirely.
@@ -459,6 +466,7 @@ def sample_batched(
459466
show_progress_bars: bool = True,
460467
reject_outside_prior: bool = True,
461468
max_sampling_time: Optional[float] = None,
469+
return_partial_on_timeout: bool = False,
462470
) -> Tensor:
463471
r"""Given a batch of observations [x_1, ..., x_B] this function samples from
464472
posteriors $p(\theta|x_1)$, ... ,$p(\theta|x_B)$, in a batched (i.e. vectorized)
@@ -488,6 +496,9 @@ def sample_batched(
488496
max_sampling_time: Optional maximum allowed sampling time in seconds.
489497
If exceeded, sampling is aborted and a RuntimeError is raised. Only
490498
applies when `reject_outside_prior=True`.
499+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded,
500+
return the samples collected so far instead of raising a RuntimeError.
501+
A warning will be issued. Only applies when `reject_outside_prior=True`.
491502
492503
Returns:
493504
Samples from the posteriors of shape (*sample_shape, B, *input_shape)
@@ -525,6 +536,7 @@ def sample_batched(
525536
show_progress_bars=show_progress_bars,
526537
max_sampling_batch_size=max_sampling_batch_size,
527538
max_sampling_time=max_sampling_time,
539+
return_partial_on_timeout=return_partial_on_timeout,
528540
)
529541
else:
530542
# Bypass rejection sampling.
@@ -553,6 +565,7 @@ def sample_batched(
553565
max_sampling_batch_size=max_sampling_batch_size,
554566
proposal_sampling_kwargs=proposal_sampling_kwargs,
555567
max_sampling_time=max_sampling_time,
568+
return_partial_on_timeout=return_partial_on_timeout,
556569
)
557570
else:
558571
# Bypass rejection sampling.

sbi/samplers/rejection/rejection.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ def rejection_sample(
2626
num_iter_to_find_max: int = 100,
2727
m: float = 1.2,
2828
max_sampling_time: Optional[float] = None,
29+
return_partial_on_timeout: bool = False,
2930
device: str = "cpu",
3031
) -> Tuple[Tensor, Tensor]:
3132
r"""Return samples from a `potential_fn` obtained via rejection sampling.
@@ -57,11 +58,14 @@ def rejection_sample(
5758
value will ensure that the samples are indeed from the correct
5859
distribution, but will increase the fraction of rejected samples and thus
5960
computation time.
60-
device: Device on which to sample.
6161
max_sampling_time: Optional maximum allowed sampling time (in seconds).
6262
If this time is exceeded, rejection sampling is aborted and a RuntimeError
63-
is raised. This prevents jobs from stalling indefinitely when the
64-
acceptance rate is extremely low.
63+
is raised (unless `return_partial_on_timeout=True`). This prevents jobs
64+
from stalling indefinitely when the acceptance rate is extremely low.
65+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded, return
66+
the samples collected so far instead of raising a RuntimeError. A warning
67+
will be issued indicating the partial return. If no samples have been collected when the timeout is hit, a RuntimeError is still raised. Default is False.
68+
device: Device on which to sample.
6569
6670
Returns:
6771
Accepted samples and acceptance rate as scalar Tensor.
@@ -143,6 +147,16 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:
143147
max_sampling_time is not None
144148
and (time.time() - start_time) > max_sampling_time
145149
):
150+
num_collected = sum(s.shape[0] for s in accepted)
151+
if return_partial_on_timeout and num_collected > 0:
152+
pbar.close()
153+
warnings.warn(
154+
f"Timeout exceeded after collecting {num_collected}/"
155+
f"{num_samples} samples. Returning partial results.",
156+
stacklevel=2,
157+
)
158+
samples = torch.cat(accepted)
159+
return samples, as_tensor(acceptance_rate)
146160
raise RuntimeError(
147161
"Sampling aborted early because rejection sampling exceeded "
148162
"max_sampling_time. This is likely due to extremely low "
@@ -225,6 +239,7 @@ def accept_reject_sample(
225239
proposal_sampling_kwargs: Optional[Dict] = None,
226240
alternative_method: Optional[str] = None,
227241
max_sampling_time: Optional[float] = None,
242+
return_partial_on_timeout: bool = False,
228243
**kwargs,
229244
) -> Tuple[Tensor, Tensor]:
230245
r"""Returns samples from a proposal according to an acceptance criterion.
@@ -264,12 +279,16 @@ def accept_reject_sample(
264279
alternative_method: An alternative method for sampling from the restricted
265280
proposal. E.g., for SNPE, we suggest to sample with MCMC if the rejection
266281
rate is too high. Used only for printing during a potential warning.
282+
max_sampling_time: Optional maximum allowed sampling time (in seconds).
283+
If exceeded, the sampling loop is interrupted and a RuntimeError is raised
284+
unless `return_partial_on_timeout=True`. This prevents infinite or
285+
excessively slow rejection sampling runs, e.g. in cases of heavy leakage
286+
or extremely low acceptance rates.
287+
return_partial_on_timeout: If True and `max_sampling_time` is exceeded, return
288+
the samples collected so far instead of raising a RuntimeError. A warning
289+
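will be issued indicating the partial return. If any observation has no accepted samples when the timeout is hit, a RuntimeError is still raised. Default is False.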
267290
kwargs: Absorb additional unused arguments that can be passed to
268291
`rejection_sample()`. Warn if not empty.
269-
max_sampling_time: Optional maximum allowed sampling time (in seconds).
270-
If exceeded, the sampling loop is interrupted and a RuntimeError is raised.
271-
This prevents infinite or excessively slow rejection sampling runs, e.g.
272-
in cases of heavy leakage or extremely low acceptance rates.
273292
274293
Returns:
275294
Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and
@@ -318,6 +337,24 @@ def accept_reject_sample(
318337
max_sampling_time is not None
319338
and (time.time() - start_time) > max_sampling_time
320339
):
340+
# Check if we have any samples collected
341+
num_collected = min(
342+
sum(s.shape[0] for s in accepted[i]) for i in range(num_xos)
343+
)
344+
if return_partial_on_timeout and num_collected > 0:
345+
pbar.close()
346+
warnings.warn(
347+
f"Timeout exceeded after collecting {num_collected}/{num_samples}"
348+
f" samples. Returning partial results.",
349+
stacklevel=2,
350+
)
351+
# Return partial samples with proper shape
352+
samples = [
353+
torch.cat(accepted[i], dim=0)[:num_collected]
354+
for i in range(num_xos)
355+
]
356+
samples = torch.stack(samples, dim=1)
357+
return samples, as_tensor(acceptance_rate, device=samples.device)
321358
raise RuntimeError(
322359
"Sampling aborted early because rejection sampling exceeded "
323360
"max_sampling_time. This is likely due to extremely low "

tests/rejection_sampling_test.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,31 @@ def to(self, device):
123123
assert torch.all(no_reject == 5.0)
124124

125125

126+
def test_accept_reject_sample_partial_return():
127+
"""Test that return_partial_on_timeout returns collected samples."""
128+
129+
def accept_rare_fn(x):
130+
# Accept only 1% of samples to ensure we don't finish
131+
return torch.rand(x.shape[0]) < 0.01
132+
133+
proposal = DummyProposal()
134+
with warnings.catch_warnings(record=True) as w:
135+
warnings.simplefilter("always")
136+
samples, acceptance = accept_reject_sample(
137+
proposal=proposal,
138+
accept_reject_fn=accept_rare_fn,
139+
num_samples=10000, # Request many samples
140+
max_sampling_time=0.001, # Very short timeout
141+
return_partial_on_timeout=True,
142+
)
143+
# Should have some samples (not all 10000)
144+
assert samples.shape[0] > 0
145+
assert samples.shape[0] < 10000
146+
# Should have issued a warning
147+
assert len(w) == 1
148+
assert "partial results" in str(w[0].message).lower()
149+
150+
126151
def test_warn_if_outside_prior_support():
127152
"""Test the warning utility for samples outside prior support."""
128153
prior = Uniform(torch.zeros(2), torch.ones(2))

0 commit comments

Comments
 (0)