File tree Expand file tree Collapse file tree 1 file changed +10
-5
lines changed
sbi/neural_nets/net_builders Expand file tree Collapse file tree 1 file changed +10
-5
lines changed Original file line number Diff line number Diff line change 1818from sbi .utils .nn_utils import get_numel
1919from sbi .utils .sbiutils import (
2020 standardizing_net ,
21+ z_score_parser ,
2122 z_standardization ,
2223)
2324from sbi .utils .user_input_checks import check_data_device
@@ -129,12 +130,16 @@ def build_vector_field_estimator(
129130 raise ValueError (f"Unknown architecture: { net } " )
130131
131132 # Z-score setup
132- mean_0 , std_0 = z_standardization (batch_x , z_score_x == "structured" )
133+ z_score_x_bool , structured_x = z_score_parser (z_score_x )
134+ if z_score_x_bool :
135+ mean_0 , std_0 = z_standardization (batch_x , structured_x )
136+ else :
137+ mean_0 , std_0 = 0 , 1
138+
139+ z_score_y_bool , structured_y = z_score_parser (z_score_y )
133140 embedding_net_y = (
134- nn .Sequential (
135- standardizing_net (batch_y , z_score_y == "structured" ), embedding_net
136- )
137- if z_score_y
141+ nn .Sequential (standardizing_net (batch_y , structured_y ), embedding_net )
142+ if z_score_y_bool
138143 else embedding_net
139144 )
140145
You can’t perform that action at this time.
0 commit comments