Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions sbi/neural_nets/estimators/flowmatching_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

import warnings
from typing import Optional
from typing import Optional, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -56,6 +56,8 @@ def __init__(
condition_shape: torch.Size,
embedding_net: Optional[nn.Module] = None,
noise_scale: float = 1e-3,
mean_1: Union[Tensor, float] = 0.0,
std_1: Union[Tensor, float] = 1.0,
**kwargs,
) -> None:
r"""Creates a vector field estimator for Flow Matching.
Expand All @@ -67,6 +69,8 @@ def __init__(
embedding_net: Embedding network for the condition.
noise_scale: Scale of the noise added to the vector field
(:math:`\sigma_{min}` in [2]_).
mean_1: Mean of the data at t=1 (used for z-scoring).
std_1: Standard deviation of the data at t=1 (used for z-scoring).
zscore_transform_input: Whether to z-score the input.
This is ignored and will be removed.
"""
Expand All @@ -88,6 +92,9 @@ def __init__(
)
self.noise_scale = noise_scale

self.register_buffer("mean_1", torch.as_tensor(mean_1))
self.register_buffer("std_1", torch.as_tensor(std_1))

def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
"""Forward pass of the FlowMatchingEstimator.

Expand Down Expand Up @@ -127,8 +134,15 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor:
)
time = time.reshape(-1)

# call the network to get the estimated vector field
v = self.net(input, condition_emb, time)
t_view = time.view(-1, *([1] * (input.ndim - 1)))
Copy link
Contributor

@manuelgloeckler manuelgloeckler Feb 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Assuming a Gaussian target at t=1 with the mu1 and std1 given here, the exact marginal velocity would have the following form:

# ---- marginal Gaussian stats (alpha=t, sigma=1-t, diag C = s1^2) ----
mu_t  = t_view * m                                          # \bar{mu}_t
var_t = (t_view.square() * s1_sq) + one_minus_t.square()     # diag(S_t)
std_t = var_t.sqrt().clamp_min(self.eps)

# ---- z-score scaling for net (as currently) ----
x_centered = x - mu_t
x_norm = x_centered / std_t                                 # c_in * (x - mu_t)

resid_norm = self.net(x_norm, condition_emb, t)             # f_theta(...)
resid = resid_norm * std_t                                  # c_out * f_theta

# ---- Gaussian posterior mean E[x1 | xt=x] under diag prior ----
# k_t = alpha * C / S_t  with alpha=t and C=s1^2 (diagonal)
k_t = (t_view * s1_sq) / var_t
x1_hat = m + k_t * x_centered                               # m + k_t (x - t m)

# ---- Gaussian affine baseline: a(t)=t, b(t)=1-t ----
u_gauss = (t_view * x) + (one_minus_t * x1_hat)

Although this is only with respect to the "prior" (i.e., not the posterior), it might still be reasonable.

mu_t = t_view * self.mean_1
var_t = (t_view * self.std_1) ** 2 + (1 - t_view) ** 2

std_t = torch.sqrt(var_t)
input_norm = (input - mu_t) / std_t

v_out = self.net(input_norm, condition_emb, time)
v = v_out * std_t
v = v.reshape(*batch_shape + self.input_shape)

return v
Expand Down
4 changes: 3 additions & 1 deletion sbi/neural_nets/net_builders/vector_field_nets.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def build_vector_field_estimator(
if z_score_x_bool:
mean_0, std_0 = z_standardization(batch_x, structured_x)
else:
mean_0, std_0 = 0, 1
mean_0, std_0 = 0.0, 1.0

z_score_y_bool, structured_y = z_score_parser(z_score_y)
embedding_net_y = (
Expand All @@ -149,6 +149,8 @@ def build_vector_field_estimator(
input_shape=batch_x[0].shape,
condition_shape=batch_y[0].shape,
embedding_net=embedding_net_y,
mean_1=mean_0,
std_1=std_0,
)
elif estimator_type == "score":
# Choose the appropriate score estimator based on SDE type
Expand Down
27 changes: 27 additions & 0 deletions tests/linearGaussian_vector_field_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,33 @@ def test_vfinference_with_different_models(vector_field_type, model):
check_c2st(samples, target_samples, alg=f"fmpe_{model}")


@pytest.mark.parametrize("vector_field_type", ["fmpe"])
def test_fmpe_time_dependent_z_scoring_integration(vector_field_type):
    """Smoke-test that FMPE training registers the t=1 z-scoring buffers
    (``mean_1``/``std_1``) and that the resulting ODE function evaluates
    cleanly on a batch of inputs."""
    dim = 2
    # Prior far from zero so z-scoring stats are clearly distinguishable.
    prior = BoxUniform(9.0 * ones(dim), 11.0 * ones(dim))

    def simulator(parameters):
        return parameters + torch.randn_like(parameters) * 0.1

    trainer = FMPE(prior, z_score_x='structured', show_progress_bars=False)
    thetas = prior.sample((200,))
    observations = simulator(thetas)
    density_estimator = trainer.append_simulations(
        thetas, observations
    ).train(max_num_epochs=1)

    # The trained estimator must expose the registered t=1 statistics.
    for buffer_name in ("mean_1", "std_1"):
        assert hasattr(density_estimator, buffer_name)
    # Data lives in [9, 11], so the estimated mean must exceed 8.
    assert torch.all(density_estimator.mean_1 > 8.0)

    n_eval = 10
    times = torch.rand(n_eval)
    theta_eval = torch.randn(n_eval, dim)
    conditions = zeros(n_eval, dim)
    velocity = density_estimator.ode_fn(theta_eval, conditions, times)

    assert velocity.shape == (n_eval, dim)
    assert not torch.isnan(velocity).any()


# ------------------------------------------------------------------------------
# -------------------------------- SLOW TESTS ----------------------------------
# ------------------------------------------------------------------------------
Expand Down
9 changes: 7 additions & 2 deletions tests/torchutils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,13 @@ def test_searchsorted(self):
right_boundaries = bin_locations[:-1] + 0.1
mid_points = bin_locations[:-1] + 0.05

for inputs in [left_boundaries, right_boundaries, mid_points]:
with self.subTest(inputs=inputs):
test_cases = [
("left_boundaries", left_boundaries),
("right_boundaries", right_boundaries),
("mid_points", mid_points),
]
for name, inputs in test_cases:
with self.subTest(name=name):
idx = torchutils.searchsorted(bin_locations[None, :], inputs)
self.assertEqual(idx, torch.arange(0, 9))

Expand Down
Loading