@@ -26,6 +26,7 @@ def rejection_sample(
2626 num_iter_to_find_max : int = 100 ,
2727 m : float = 1.2 ,
2828 max_sampling_time : Optional [float ] = None ,
29+ return_partial_on_timeout : bool = False ,
2930 device : str = "cpu" ,
3031) -> Tuple [Tensor , Tensor ]:
3132 r"""Return samples from a `potential_fn` obtained via rejection sampling.
@@ -57,11 +58,14 @@ def rejection_sample(
5758 value will ensure that the samples are indeed from the correct
5859 distribution, but will increase the fraction of rejected samples and thus
5960 computation time.
60- device: Device on which to sample.
6161 max_sampling_time: Optional maximum allowed sampling time (in seconds).
6262 If this time is exceeded, rejection sampling is aborted and a RuntimeError
63- is raised. This prevents jobs from stalling indefinitely when the
64- acceptance rate is extremely low.
63+ is raised (unless `return_partial_on_timeout=True`). This prevents jobs
64+ from stalling indefinitely when the acceptance rate is extremely low.
65+ return_partial_on_timeout: If True and `max_sampling_time` is exceeded, return
66+ the samples collected so far instead of raising a RuntimeError. A warning
67+ will be issued indicating the partial return. Default is False.
68+ device: Device on which to sample.
6569
6670 Returns:
6771 Accepted samples and acceptance rate as scalar Tensor.
@@ -143,6 +147,16 @@ def log_prob(self, theta: Tensor, **kwargs) -> Tensor:
143147 max_sampling_time is not None
144148 and (time .time () - start_time ) > max_sampling_time
145149 ):
150+ num_collected = sum (s .shape [0 ] for s in accepted )
151+ if return_partial_on_timeout and num_collected > 0 :
152+ pbar .close ()
153+ warnings .warn (
154+ f"Timeout exceeded after collecting { num_collected } /"
155+ f"{ num_samples } samples. Returning partial results." ,
156+ stacklevel = 2 ,
157+ )
158+ samples = torch .cat (accepted )
159+ return samples , as_tensor (acceptance_rate )
146160 raise RuntimeError (
147161 "Sampling aborted early because rejection sampling exceeded "
148162 "max_sampling_time. This is likely due to extremely low "
@@ -225,6 +239,7 @@ def accept_reject_sample(
225239 proposal_sampling_kwargs : Optional [Dict ] = None ,
226240 alternative_method : Optional [str ] = None ,
227241 max_sampling_time : Optional [float ] = None ,
242+ return_partial_on_timeout : bool = False ,
228243 ** kwargs ,
229244) -> Tuple [Tensor , Tensor ]:
230245 r"""Returns samples from a proposal according to a acception criterion.
@@ -264,12 +279,16 @@ def accept_reject_sample(
264279 alternative_method: An alternative method for sampling from the restricted
265280 proposal. E.g., for SNPE, we suggest to sample with MCMC if the rejection
266281 rate is too high. Used only for printing during a potential warning.
282+ max_sampling_time: Optional maximum allowed sampling time (in seconds).
283+ If exceeded, the sampling loop is interrupted and a RuntimeError is raised
284+ unless `return_partial_on_timeout=True`. This prevents infinite or
285+ excessively slow rejection sampling runs, e.g. in cases of heavy leakage
286+ or extremely low acceptance rates.
287+ return_partial_on_timeout: If True and `max_sampling_time` is exceeded, return
288+ the samples collected so far instead of raising a RuntimeError. A warning
289+ will be issued indicating the partial return. Default is False.
267290 kwargs: Absorb additional unused arguments that can be passed to
268291 `rejection_sample()`. Warn if not empty.
269- max_sampling_time: Optional maximum allowed sampling time (in seconds).
270- If exceeded, the sampling loop is interrupted and a RuntimeError is raised.
271- This prevents infinite or excessively slow rejection sampling runs, e.g.
272- in cases of heavy leakage or extremely low acceptance rates.
273292
274293 Returns:
275294 Accepted samples of shape `(sample_dim, batch_dim, *event_shape)`, and
@@ -318,6 +337,24 @@ def accept_reject_sample(
318337 max_sampling_time is not None
319338 and (time .time () - start_time ) > max_sampling_time
320339 ):
340+ # Check if we have any samples collected
341+ num_collected = min (
342+ sum (s .shape [0 ] for s in accepted [i ]) for i in range (num_xos )
343+ )
344+ if return_partial_on_timeout and num_collected > 0 :
345+ pbar .close ()
346+ warnings .warn (
347+ f"Timeout exceeded after collecting { num_collected } /{ num_samples } "
348+ f" samples. Returning partial results." ,
349+ stacklevel = 2 ,
350+ )
351+ # Return partial samples with proper shape
352+ samples = [
353+ torch .cat (accepted [i ], dim = 0 )[:num_collected ]
354+ for i in range (num_xos )
355+ ]
356+ samples = torch .stack (samples , dim = 1 )
357+ return samples , as_tensor (acceptance_rate , device = samples .device )
321358 raise RuntimeError (
322359 "Sampling aborted early because rejection sampling exceeded "
323360 "max_sampling_time. This is likely due to extremely low "
0 commit comments