From 2fbcf18c7a37d690d1e60f618ea2e5ac588b84ca Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Tue, 3 Feb 2026 13:35:58 +0530 Subject: [PATCH 1/7] update vector field builder to pass z-scoring stats to estimator --- sbi/neural_nets/net_builders/vector_field_nets.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sbi/neural_nets/net_builders/vector_field_nets.py b/sbi/neural_nets/net_builders/vector_field_nets.py index d1f135386..c0668c23a 100644 --- a/sbi/neural_nets/net_builders/vector_field_nets.py +++ b/sbi/neural_nets/net_builders/vector_field_nets.py @@ -134,7 +134,7 @@ def build_vector_field_estimator( if z_score_x_bool: mean_0, std_0 = z_standardization(batch_x, structured_x) else: - mean_0, std_0 = 0, 1 + mean_0, std_0 = 0.0, 1.0 z_score_y_bool, structured_y = z_score_parser(z_score_y) embedding_net_y = ( @@ -149,6 +149,8 @@ def build_vector_field_estimator( input_shape=batch_x[0].shape, condition_shape=batch_y[0].shape, embedding_net=embedding_net_y, + mean_1=mean_0, + std_1=std_0, ) elif estimator_type == "score": # Choose the appropriate score estimator based on SDE type From 3883de29afb352766b07e54a9035821dea836765 Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Tue, 3 Feb 2026 13:39:30 +0530 Subject: [PATCH 2/7] implement time-dependent z-scoring logic in FlowMatchingEstimator --- .../estimators/flowmatching_estimator.py | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/sbi/neural_nets/estimators/flowmatching_estimator.py b/sbi/neural_nets/estimators/flowmatching_estimator.py index dfbb59da9..6da2a00c2 100644 --- a/sbi/neural_nets/estimators/flowmatching_estimator.py +++ b/sbi/neural_nets/estimators/flowmatching_estimator.py @@ -2,7 +2,7 @@ # under the Apache License Version 2.0, see import warnings -from typing import Optional +from typing import Optional, Union import torch import torch.nn as nn @@ -56,6 +56,8 @@ def __init__( condition_shape: torch.Size, embedding_net: Optional[nn.Module] = None, noise_scale: float = 1e-3, + mean_1: Union[Tensor, float] = 0.0, + std_1: Union[Tensor, float] = 1.0, **kwargs, ) -> None: r"""Creates a vector field estimator for Flow Matching. @@ -67,6 +69,8 @@ def __init__( embedding_net: Embedding network for the condition. noise_scale: Scale of the noise added to the vector field (:math:`\sigma_{min}` in [2]_). + mean_1: Mean of the data at t=1 (used for z-scoring). + std_1: Standard deviation of the data at t=1 (used for z-scoring). zscore_transform_input: Whether to z-score the input. This is ignored and will be removed. """ @@ -88,6 +92,9 @@ def __init__( ) self.noise_scale = noise_scale + self.register_buffer("mean_1", torch.as_tensor(mean_1)) + self.register_buffer("std_1", torch.as_tensor(std_1)) + def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor: """Forward pass of the FlowMatchingEstimator. @@ -127,8 +134,15 @@ def forward(self, input: Tensor, condition: Tensor, time: Tensor) -> Tensor: ) time = time.reshape(-1) - # call the network to get the estimated vector field - v = self.net(input, condition_emb, time) + t_view = time.view(-1, *([1] * (input.ndim - 1))) + mu_t = t_view * self.mean_1 + var_t = (t_view * self.std_1) ** 2 + (1 - t_view) ** 2 + + std_t = torch.sqrt(var_t) + input_norm = (input - mu_t) / std_t + + v_out = self.net(input_norm, condition_emb, time) + v = v_out * std_t v = v.reshape(*batch_shape + self.input_shape) return v From dc8dc8638c7d82948b6890ee412504b4940072ec Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Tue, 3 Feb 2026 13:42:49 +0530 Subject: [PATCH 3/7] add integration test for FMPE time-dependent z-scoring --- tests/linearGaussian_vector_field_test.py | 27 +++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/linearGaussian_vector_field_test.py b/tests/linearGaussian_vector_field_test.py index ccfe93276..6efdcd0b3 100644 --- a/tests/linearGaussian_vector_field_test.py +++ b/tests/linearGaussian_vector_field_test.py @@ -248,6 +248,33 @@ def test_vfinference_with_different_models(vector_field_type, model): check_c2st(samples, target_samples, alg=f"fmpe_{model}") +@pytest.mark.parametrize("vector_field_type", ["fmpe"]) +def test_fmpe_time_dependent_z_scoring_integration(vector_field_type): + num_dim = 2 + prior = BoxUniform(9.0 * ones(num_dim), 11.0 * ones(num_dim)) + + def simulator(theta): + return theta + torch.randn_like(theta) * 0.1 + + inference = FMPE(prior, z_score_x='structured', show_progress_bars=False) + theta = prior.sample((200,)) + x = simulator(theta) + density_estimator = inference.append_simulations(theta, x).train(max_num_epochs=1) + + assert hasattr(density_estimator, "mean_1") + assert hasattr(density_estimator, "std_1") + assert torch.all(density_estimator.mean_1 > 8.0) + + batch_size = 10 + t = torch.rand(batch_size) + theta_test = torch.randn(batch_size, num_dim) + cond_test = zeros(batch_size, num_dim) + v_pred = density_estimator.ode_fn(theta_test, cond_test, t) + + assert v_pred.shape == (batch_size, num_dim) + assert not torch.isnan(v_pred).any() + + # ------------------------------------------------------------------------------ # -------------------------------- SLOW TESTS ---------------------------------- # ------------------------------------------------------------------------------ From 994e08df54afc4b8a2f2760c0f4295d0620b8509 Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Wed, 4 Feb 2026 01:13:06 +0530 Subject: [PATCH 4/7] trigger ci From b1d7a3bc978f1f5bad363bb6eb9e456b7378a4e8 Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Wed, 4 Feb 2026 01:22:32 +0530 Subject: [PATCH 5/7] re-trigger ci after ready for review From 6498c3ff0026a4d271b91d5a8f22fd15eb0d4e83 Mon Sep 17 00:00:00 2001 From: Satwik Sai Prakash Sahoo Date: Wed, 4 Feb 2026 01:39:26 +0530 Subject: [PATCH 6/7] again trigger the ci to pass the flaky test From da03aadc3e2f9b352c8673a0e092df7815782b65 Mon Sep 17 00:00:00 2001 From: Jan Date: Thu, 5 Feb 2026 20:37:23 +0530 Subject: [PATCH 7/7] fix serialization issue in test cases. this was an old bug that surfaced now likely because codecov was trying to serialize things. --- tests/torchutils_test.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/torchutils_test.py b/tests/torchutils_test.py index 10e0dff64..b60c36237 100644 --- a/tests/torchutils_test.py +++ b/tests/torchutils_test.py @@ -93,8 +93,13 @@ def test_searchsorted(self): right_boundaries = bin_locations[:-1] + 0.1 mid_points = bin_locations[:-1] + 0.05 - for inputs in [left_boundaries, right_boundaries, mid_points]: - with self.subTest(inputs=inputs): + test_cases = [ + ("left_boundaries", left_boundaries), + ("right_boundaries", right_boundaries), + ("mid_points", mid_points), + ] + for name, inputs in test_cases: + with self.subTest(name=name): idx = torchutils.searchsorted(bin_locations[None, :], inputs) self.assertEqual(idx, torch.arange(0, 9))