Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions sbi/inference/posteriors/vector_field_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,8 +313,8 @@ def _sample_via_diffusion(
corrector: The corrector for the diffusion-based sampler. Either of
[None].
steps: Number of steps to take for the Euler-Maruyama method.
ts: Time points at which to evaluate the diffusion process. If None, call
the solve schedule specific to the score estimator.
ts: Time points at which to evaluate the diffusion process. If None,
uses the solve_schedule() specific to the estimator.
max_sampling_batch_size: Maximum batch size for sampling.
sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
`.sample()`.
Expand Down
1 change: 1 addition & 0 deletions sbi/inference/trainers/vfpe/base_vf_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ def default_calibration_kernel(x):
)

if isinstance(validation_times, int):
# Use nugget to offset from boundaries for numerical stability
validation_times = self._neural_net.solve_schedule(
validation_times,
t_min=self._neural_net.t_min + validation_times_nugget,
Expand Down
25 changes: 9 additions & 16 deletions sbi/neural_nets/estimators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,28 +485,21 @@ def solve_schedule(
t_min: Optional[float] = None,
t_max: Optional[float] = None,
) -> Tensor:
"""
Time grid used during sampling (solving steps) and loss evaluation steps.
"""Time schedule used during sampling. Can be overridden by subclasses.

This grid is deterministic and decreasing. Can be overriden by subclasses.
Return by default a uniform time stepping between t_max and t_min.
Returns a uniform time stepping between t_max and t_min by default.

Args:
steps: number of discretization steps
t_min: The minimum time value. Defaults to self.t_min.
t_max: The maximum time value. Defaults to self.t_max.
steps: Number of discretization steps.
t_min: Minimum time value. Defaults to self.t_min.
t_max: Maximum time value. Defaults to self.t_max.

Returns:
Tensor: A tensor of time steps within the range [t_max, t_min].
Tensor of time steps from t_max to t_min.
"""
if t_min is None:
t_min = self.t_min
if t_max is None:
t_max = self.t_max

times = torch.linspace(t_max, t_min, steps, device=self._mean_base.device)

return times
t_min = self.t_min if t_min is None else t_min
t_max = self.t_max if t_max is None else t_max
return torch.linspace(t_max, t_min, steps, device=self._mean_base.device)


class UnconditionalEstimator(nn.Module, ABC):
Expand Down
Loading
Loading