diff --git a/sbi/inference/posteriors/vector_field_posterior.py b/sbi/inference/posteriors/vector_field_posterior.py
index 8d81fc2b8..9379a7b16 100644
--- a/sbi/inference/posteriors/vector_field_posterior.py
+++ b/sbi/inference/posteriors/vector_field_posterior.py
@@ -306,8 +306,8 @@ def _sample_via_diffusion(
             corrector: The corrector for the diffusion-based sampler.
                 Either of [None].
             steps: Number of steps to take for the Euler-Maruyama method.
-            ts: Time points at which to evaluate the diffusion process. If None, a
-                linear grid between t_max and t_min is used.
+            ts: Time points at which to evaluate the diffusion process. If None,
+                the vector field estimator's `solve_schedule` method is used.
             max_sampling_batch_size: Maximum batch size for sampling.
             sample_with: Deprecated - use `.build_posterior(sample_with=...)` prior to
                 `.sample()`.
@@ -333,11 +333,8 @@ def _sample_via_diffusion(
         # Ensure we don't use larger batches than total samples needed
        effective_batch_size = min(effective_batch_size, total_samples_needed)

-        # TODO: the time schedule should be provided by the estimator, see issue #1437
         if ts is None:
-            t_max = self.vector_field_estimator.t_max
-            t_min = self.vector_field_estimator.t_min
-            ts = torch.linspace(t_max, t_min, steps)
+            ts = self.vector_field_estimator.solve_schedule(steps)
         ts = ts.to(self.device)

         # Initialize the diffusion sampler
diff --git a/sbi/inference/trainers/vfpe/base_vf_inference.py b/sbi/inference/trainers/vfpe/base_vf_inference.py
index b0db7f766..d7aad91cd 100644
--- a/sbi/inference/trainers/vfpe/base_vf_inference.py
+++ b/sbi/inference/trainers/vfpe/base_vf_inference.py
@@ -313,10 +313,10 @@ def default_calibration_kernel(x):
         )

         if isinstance(validation_times, int):
-            validation_times = torch.linspace(
-                self._neural_net.t_min + validation_times_nugget,
-                self._neural_net.t_max - validation_times_nugget,
+            validation_times = self._neural_net.solve_schedule(
                 validation_times,
+                t_min=self._neural_net.t_min + validation_times_nugget,
+                t_max=self._neural_net.t_max - validation_times_nugget,
             )

         loss_args = LossArgsVF(
diff --git a/sbi/neural_nets/estimators/base.py b/sbi/neural_nets/estimators/base.py
index e76979f72..f8e116ccb 100644
--- a/sbi/neural_nets/estimators/base.py
+++ b/sbi/neural_nets/estimators/base.py
@@ -479,6 +479,35 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """
         raise NotImplementedError("Diffusion is not implemented for this estimator.")

+    def solve_schedule(
+        self,
+        steps: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Time grid used during sampling (solver steps) and for loss evaluation.
+
+        This grid is deterministic and decreasing. Can be overridden by
+        subclasses. By default, returns a uniform grid between t_max and t_min.
+
+        Args:
+            steps: Number of discretization steps.
+            t_min: The minimum time value. Defaults to self.t_min.
+            t_max: The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of time steps within the range [t_max, t_min].
+        """
+        if t_min is None:
+            t_min = self.t_min
+        if t_max is None:
+            t_max = self.t_max
+
+        times = torch.linspace(t_max, t_min, steps, device=self._mean_base.device)
+
+        return times
+

 class UnconditionalEstimator(nn.Module, ABC):
     r"""Base class for unconditional estimators that estimate properties of
diff --git a/sbi/neural_nets/estimators/score_estimator.py b/sbi/neural_nets/estimators/score_estimator.py
index bb2011ccc..f9c24ed44 100644
--- a/sbi/neural_nets/estimators/score_estimator.py
+++ b/sbi/neural_nets/estimators/score_estimator.py
@@ -66,6 +66,8 @@ def __init__(
         condition_shape: torch.Size,
         embedding_net: Optional[nn.Module] = None,
         weight_fn: Union[str, Callable] = "max_likelihood",
+        beta_min: float = 0.01,
+        beta_max: float = 10.0,
         mean_0: Union[Tensor, float] = 0.0,
         std_0: Union[Tensor, float] = 1.0,
         t_min: float = 1e-3,
@@ -111,6 +113,10 @@ def __init__(
             t_max=t_max,
         )

+        # Min/max values for the noise variance beta.
+        self.beta_min = beta_min
+        self.beta_max = beta_max
+
         # Set lambdas (variance weights) function.
         self._set_weight_fn(weight_fn)
         self.register_buffer("mean_0", mean_0.clone().detach())
@@ -228,7 +234,8 @@ def loss(
         Args:
             input: Input variable i.e. theta.
             condition: Conditioning variable.
-            times: SDE time variable in [t_min, t_max]. Uniformly sampled if None.
+            times: SDE time variable in [t_min, t_max]. If None, times are sampled
+                via the estimator's `train_schedule` method.
             control_variate: Whether to use a control variate to reduce the variance
                 of the stochastic loss estimator.
             control_variate_threshold: Threshold for the control variate. If the std
@@ -240,13 +247,11 @@ def loss(
             MSE between target score and network output, scaled by the weight function.
         """

-        # Sample diffusion times.
+        # Sample diffusion times, one per batch element.
         if times is None:
-            times = (
-                torch.rand(input.shape[0], device=input.device)
-                * (self.t_max - self.t_min)
-                + self.t_min
-            )
+            times = self.train_schedule(input.shape[0])
+            times = times.to(input.device)
+
         # Sample noise.
         eps = torch.randn_like(input)

@@ -390,6 +395,86 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """
         raise NotImplementedError

+    def noise_schedule(self, times: Tensor) -> Tensor:
+        """
+        Create a mapping from time to noise magnitude (beta/sigma) used in the SDE.
+
+        This method acts as a fallback in case subclasses do not implement their
+        own schedule. By default, a linear beta schedule is used, evaluated at the
+        normalized times t ∈ [0, 1].
+
+        Args:
+            times: SDE times in [0, 1].
+
+        Returns:
+            Tensor: The beta schedule evaluated at the given times.
+
+        """
+        return self.beta_min + (self.beta_max - self.beta_min) * times
+
+    def train_schedule(
+        self,
+        num_samples: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Sample diffusion times used during training.
+
+        Can be overridden by subclasses. By default, times are sampled uniformly
+        within the range [t_min, t_max]. The returned tensor is placed on the
+        same device as the stored network.
+
+        Args:
+            num_samples: Number of samples to generate.
+            t_min: The minimum time value. Defaults to self.t_min.
+            t_max: The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of time variables sampled in the range [t_min, t_max].
+
+        """
+        if t_min is None:
+            t_min = self.t_min
+        if t_max is None:
+            t_max = self.t_max
+
+        return (
+            torch.rand(num_samples, device=self._mean_base.device) * (t_max - t_min)
+            + t_min
+        )
+
+    def solve_schedule(
+        self,
+        num_steps: int,
+        t_min: Optional[float] = None,
+        t_max: Optional[float] = None,
+    ) -> Tensor:
+        """
+        Time grid used to solve the reverse SDE and evaluate the loss function.
+
+        This grid is deterministic and decreasing. By default, a uniform time
+        grid within the range [t_max, t_min] is used. The returned tensor is
+        placed on the same device as the stored network.
+
+        Args:
+            num_steps: Number of time steps to generate.
+            t_min: The minimum time value. Defaults to self.t_min.
+            t_max: The maximum time value. Defaults to self.t_max.
+
+        Returns:
+            Tensor: A tensor of time steps within the range [t_max, t_min].
+
+        """
+        if t_min is None:
+            t_min = self.t_min
+        if t_max is None:
+            t_max = self.t_max
+
+        times = torch.linspace(t_max, t_min, num_steps, device=self._mean_base.device)
+
+        return times
+
     def _set_weight_fn(self, weight_fn: Union[str, Callable]):
         """Set the weight function.

@@ -480,8 +565,6 @@ def __init__(
         t_min: float = 1e-3,
         t_max: float = 1.0,
     ) -> None:
-        self.beta_min = beta_min
-        self.beta_max = beta_max
         super().__init__(
             net,
             input_shape,
@@ -490,6 +573,8 @@ def __init__(
             weight_fn=weight_fn,
             mean_0=mean_0,
             std_0=std_0,
+            beta_min=beta_min,
+            beta_max=beta_max,
             t_min=t_min,
             t_max=t_max,
         )
@@ -525,17 +610,6 @@ def std_fn(self, times: Tensor) -> Tensor:
             std = std.unsqueeze(-1)
         return torch.sqrt(std)

-    def _beta_schedule(self, times: Tensor) -> Tensor:
-        """Linear beta schedule for mean scaling in variance preserving SDEs.
-
-        Args:
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Beta schedule at a given time.
-        """
-        return self.beta_min + (self.beta_max - self.beta_min) * times
-
     def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """Drift function for variance preserving SDEs.

@@ -546,7 +620,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Returns:
             Drift function at a given time.
         """
-        phi = -0.5 * self._beta_schedule(times)
+        phi = -0.5 * self.noise_schedule(times)
         while len(phi.shape) < len(input.shape):
             phi = phi.unsqueeze(-1)
         return phi * input
@@ -561,7 +635,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Returns:
             Drift function at a given time.
         """
-        g = torch.sqrt(self._beta_schedule(times))
+        g = torch.sqrt(self.noise_schedule(times))
         while len(g.shape) < len(input.shape):
             g = g.unsqueeze(-1)
         return g
@@ -604,14 +678,14 @@ def __init__(
         t_min: float = 1e-2,
         t_max: float = 1.0,
     ) -> None:
-        self.beta_min = beta_min
-        self.beta_max = beta_max
         super().__init__(
             net,
             input_shape,
             condition_shape,
             embedding_net=embedding_net,
             weight_fn=weight_fn,
+            beta_min=beta_min,
+            beta_max=beta_max,
             mean_0=mean_0,
             std_0=std_0,
             t_min=t_min,
@@ -649,18 +723,6 @@ def std_fn(self, times: Tensor) -> Tensor:
             std = std.unsqueeze(-1)
         return std

-    def _beta_schedule(self, times: Tensor) -> Tensor:
-        """Linear beta schedule for mean scaling in sub-variance preserving SDEs.
-        (Same as for variance preserving SDEs.)
-
-        Args:
-            times: SDE time variable in [0,1].
-
-        Returns:
-            Beta schedule at a given time.
-        """
-        return self.beta_min + (self.beta_max - self.beta_min) * times
-
     def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         """Drift function for sub-variance preserving SDEs.

@@ -671,7 +733,7 @@ def drift_fn(self, input: Tensor, times: Tensor) -> Tensor:
         Returns:
             Drift function at a given time.
""" - phi = -0.5 * self._beta_schedule(times).to(input.device) + phi = -0.5 * self.noise_schedule(times) while len(phi.shape) < len(input.shape): phi = phi.unsqueeze(-1) @@ -690,7 +752,7 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: """ g = torch.sqrt( torch.abs( - self._beta_schedule(times) + self.noise_schedule(times) * ( 1 - torch.exp( @@ -788,16 +850,18 @@ def std_fn(self, times: Tensor) -> Tensor: std = std.unsqueeze(-1) return std - def _sigma_schedule(self, times: Tensor) -> Tensor: - """Geometric sigma schedule for variance exploding SDEs. + def noise_schedule(self, times: Tensor) -> Tensor: + """Noise schedule used in the SDE drift and diffusion coefficients. + Note that for VE, this method returns the same as std_fn(). Args: times: SDE time variable in [0,1]. Returns: - Sigma schedule at a given time. + Noise magnitude (sigma) at a given time. """ - return self.sigma_min * (self.sigma_max / self.sigma_min) ** times + std = self.sigma_min * (self.sigma_max / self.sigma_min) ** times + return std def drift_fn(self, input: Tensor, times: Tensor) -> Tensor: """Drift function for variance exploding SDEs. @@ -821,11 +885,10 @@ def diffusion_fn(self, input: Tensor, times: Tensor) -> Tensor: Returns: Diffusion function at a given time. """ - g = self._sigma_schedule(times) * math.sqrt( - (2 * math.log(self.sigma_max / self.sigma_min)) - ) + sigma_scale = self.sigma_max / self.sigma_min + sigmas = self.noise_schedule(times) + g = sigmas * math.sqrt((2 * math.log(sigma_scale))) while len(g.shape) < len(input.shape): g = g.unsqueeze(-1) - return g.to(input.device) diff --git a/tests/linearGaussian_vector_field_test.py b/tests/linearGaussian_vector_field_test.py index ccfe93276..40d685499 100644 --- a/tests/linearGaussian_vector_field_test.py +++ b/tests/linearGaussian_vector_field_test.py @@ -69,7 +69,7 @@ def test_c2st_vector_field_on_linearGaussian( x_o = zeros(1, num_dim) num_samples = 1000 - num_simulations = 2500 + num_simulations = 2600 if vector_field_type == "ve" else 2500 # likelihood_mean will be likelihood_shift+theta likelihood_shift = -1.0 * ones(num_dim) diff --git a/tests/vf_estimator_test.py b/tests/vf_estimator_test.py index 4c53594f1..b9de41dc7 100644 --- a/tests/vf_estimator_test.py +++ b/tests/vf_estimator_test.py @@ -185,3 +185,66 @@ def _build_vector_field_estimator_and_tensors( ) condition = condition return estimator, inputs, condition + + +@pytest.mark.parametrize( + "estimator_type,sde_type", + [ + ("score", "vp"), + ("score", "subvp"), + ("score", "ve"), + ("flow", None), + ], +) +def test_train_schedule(estimator_type, sde_type): + """Test on shapes and bounds for train and solve schedules + of vector field estimators (flow or score) + """ + embedding_net = torch.nn.Identity() + t_min = torch.tensor([0.0]) + t_max = torch.tensor([1.0]) + + if estimator_type == "flow": + estimator = build_flow_matching_estimator( + torch.randn(100, 1), + torch.randn(100, 1), + embedding_net=embedding_net, + ) + + else: + estimator = build_score_matching_estimator( + torch.randn(100, 1), + torch.randn(100, 1), + embedding_net=embedding_net, + sde_type=sde_type, + ) + # Train schedule only defined for score estimators + # Schedule with default bounds + train_schedule_default = estimator.train_schedule(300) + assert train_schedule_default.shape == torch.Size((300,)) + assert train_schedule_default.max() <= estimator.t_max + assert train_schedule_default.min() >= estimator.t_min + + # Schedule with given bounds + train_schedule = estimator.train_schedule(300, 
+    assert train_schedule.shape == torch.Size((300,))
+    assert train_schedule.max() <= t_max.item()
+    assert train_schedule.min() >= t_min.item()
+
+    # Solve schedule with default bounds
+    solve_schedule_default = estimator.solve_schedule(300)
+    assert torch.allclose(solve_schedule_default[0], torch.tensor([estimator.t_max]))
+    assert torch.allclose(solve_schedule_default[-1], torch.tensor([estimator.t_min]))
+    assert solve_schedule_default.shape == torch.Size((300,))
+    assert torch.all(solve_schedule_default[:-1] - solve_schedule_default[1:] >= 0)
+
+    # Solve schedule with given bounds
+    solve_schedule = estimator.solve_schedule(
+        300, t_max=t_max.item(), t_min=t_min.item()
+    )
+    assert torch.allclose(solve_schedule[0], t_max)
+    assert torch.allclose(solve_schedule[-1], t_min)
+    assert solve_schedule.shape == torch.Size((300,))
+    assert torch.all(solve_schedule[:-1] - solve_schedule[1:] >= 0)
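
Reviewer note, not part of the patch: the contract this diff introduces is that `train_schedule` draws random times for the denoising loss, while `solve_schedule` returns a deterministic, decreasing grid for the reverse-time solver; both accept optional `t_min`/`t_max` overrides and can be redefined per estimator. The following is a minimal, self-contained sketch of that contract in plain torch; `ToyEstimator`, `QuadraticToyEstimator`, and the quadratic grid are illustrative names, not sbi classes.

    import torch

    class ToyEstimator:
        # Default time bounds, mirroring the estimator attributes in the diff.
        t_min, t_max = 1e-3, 1.0

        def train_schedule(self, num_samples, t_min=None, t_max=None):
            # Uniform sampling in [t_min, t_max], as in the default above.
            t_min = self.t_min if t_min is None else t_min
            t_max = self.t_max if t_max is None else t_max
            return torch.rand(num_samples) * (t_max - t_min) + t_min

        def solve_schedule(self, num_steps, t_min=None, t_max=None):
            # Deterministic, decreasing grid from t_max down to t_min.
            t_min = self.t_min if t_min is None else t_min
            t_max = self.t_max if t_max is None else t_max
            return torch.linspace(t_max, t_min, num_steps)

    class QuadraticToyEstimator(ToyEstimator):
        # Example override: a grid that concentrates solver steps near t_min.
        def solve_schedule(self, num_steps, t_min=None, t_max=None):
            t_min = self.t_min if t_min is None else t_min
            t_max = self.t_max if t_max is None else t_max
            u = torch.linspace(1.0, 0.0, num_steps)
            return t_min + (t_max - t_min) * u**2

    ts = QuadraticToyEstimator().solve_schedule(100)
    assert torch.allclose(ts[0], torch.tensor(ToyEstimator.t_max))
    assert torch.allclose(ts[-1], torch.tensor(ToyEstimator.t_min))
    assert torch.all(ts[:-1] >= ts[1:])  # decreasing, as the samplers expect

Because the posterior sampler and the validation-loss code only call `solve_schedule(steps)` and leave the bounds at their defaults, an override like the one above is picked up everywhere without further plumbing.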