Skip to content

Commit 0ab7a78

Browse files
fix: all options for z-scoring in vectorfield nets (#1681)
1 parent 1b1edfd commit 0ab7a78

File tree

1 file changed

+10
-5
lines changed

1 file changed

+10
-5
lines changed

sbi/neural_nets/net_builders/vector_field_nets.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from sbi.utils.nn_utils import get_numel
1919
from sbi.utils.sbiutils import (
2020
standardizing_net,
21+
z_score_parser,
2122
z_standardization,
2223
)
2324
from sbi.utils.user_input_checks import check_data_device
@@ -129,12 +130,16 @@ def build_vector_field_estimator(
129130
raise ValueError(f"Unknown architecture: {net}")
130131

131132
# Z-score setup
132-
mean_0, std_0 = z_standardization(batch_x, z_score_x == "structured")
133+
z_score_x_bool, structured_x = z_score_parser(z_score_x)
134+
if z_score_x_bool:
135+
mean_0, std_0 = z_standardization(batch_x, structured_x)
136+
else:
137+
mean_0, std_0 = 0, 1
138+
139+
z_score_y_bool, structured_y = z_score_parser(z_score_y)
133140
embedding_net_y = (
134-
nn.Sequential(
135-
standardizing_net(batch_y, z_score_y == "structured"), embedding_net
136-
)
137-
if z_score_y
141+
nn.Sequential(standardizing_net(batch_y, structured_y), embedding_net)
142+
if z_score_y_bool
138143
else embedding_net
139144
)
140145

0 commit comments

Comments
 (0)