@@ -76,8 +76,12 @@ class ZukoFlowType(Enum):
7676
7777def classifier_nn (
7878 model : str ,
79- z_score_theta : Optional [str ] = "independent" ,
80- z_score_x : Optional [str ] = "independent" ,
79+ z_score_theta : Optional [
80+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
81+ ] = "independent" ,
82+ z_score_x : Optional [
83+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
84+ ] = "independent" ,
8185 hidden_features : int = 50 ,
8286 embedding_net_theta : nn .Module = nn .Identity (),
8387 embedding_net_x : nn .Module = nn .Identity (),
@@ -151,8 +155,12 @@ def build_fn(batch_theta, batch_x):
151155
152156def likelihood_nn (
153157 model : str ,
154- z_score_theta : Optional [str ] = "independent" ,
155- z_score_x : Optional [str ] = "independent" ,
158+ z_score_theta : Optional [
159+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
160+ ] = "independent" ,
161+ z_score_x : Optional [
162+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
163+ ] = "independent" ,
156164 hidden_features : int = 50 ,
157165 num_transforms : int = 5 ,
158166 num_bins : int = 10 ,
@@ -226,8 +234,12 @@ def build_fn(batch_theta, batch_x):
226234
227235def posterior_nn (
228236 model : str ,
229- z_score_theta : Optional [str ] = "independent" ,
230- z_score_x : Optional [str ] = "independent" ,
237+ z_score_theta : Optional [
238+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
239+ ] = "independent" ,
240+ z_score_x : Optional [
241+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
242+ ] = "independent" ,
231243 hidden_features : int = 50 ,
232244 num_transforms : int = 5 ,
233245 num_bins : int = 10 ,
@@ -334,8 +346,12 @@ def posterior_score_nn(
334346 VectorFieldNet ,
335347 ] = "mlp" ,
336348 sde_type : str = "ve" ,
337- z_score_theta : Optional [str ] = "independent" ,
338- z_score_x : Optional [str ] = "independent" ,
349+ z_score_theta : Optional [
350+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
351+ ] = "independent" ,
352+ z_score_x : Optional [
353+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
354+ ] = "independent" ,
339355 hidden_features : int = 100 ,
340356 num_layers : int = 5 ,
341357 embedding_net : nn .Module = nn .Identity (),
@@ -436,8 +452,12 @@ def build_fn(batch_theta, batch_x):
436452# TODO: remove this function on next release
437453def flowmatching_nn (
438454 model : str ,
439- z_score_theta : Optional [str ] = "independent" ,
440- z_score_x : Optional [str ] = "independent" ,
455+ z_score_theta : Optional [
456+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
457+ ] = "independent" ,
458+ z_score_x : Optional [
459+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
460+ ] = "independent" ,
441461 hidden_features : int = 64 ,
442462 num_layers : int = 5 ,
443463 num_blocks : int = 5 ,
@@ -510,8 +530,12 @@ def posterior_flow_nn(
510530 Literal ["mlp" , "ada_mlp" , "transformer" , "transformer_cross_attn" ],
511531 VectorFieldNet ,
512532 ] = "mlp" ,
513- z_score_theta : Optional [str ] = None ,
514- z_score_x : Optional [str ] = "independent" ,
533+ z_score_theta : Optional [
534+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
535+ ] = None ,
536+ z_score_x : Optional [
537+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
538+ ] = "independent" ,
515539 hidden_features : int = 100 ,
516540 num_layers : int = 5 ,
517541 embedding_net : nn .Module = nn .Identity (),
@@ -592,7 +616,9 @@ def build_fn(batch_theta, batch_x):
592616
593617def marginal_nn (
594618 model : ZukoFlowType ,
595- z_score_x : Optional [str ] = "independent" ,
619+ z_score_x : Optional [
620+ Literal ["independent" , "structured" , "transform_to_unconstrained" , "none" ]
621+ ] = "independent" ,
596622 hidden_features : int = 50 ,
597623 num_transforms : int = 5 ,
598624 num_bins : int = 10 ,
0 commit comments