1+ import argparse
2+
13import distrax
24import haiku as hk
35import jax
@@ -49,9 +51,7 @@ def _flow(method, **kwargs):
4951 layer = MaskedAutoregressive (
5052 bijector_fn = _bijector_fn ,
5153 conditioner = MADE (
52- 2 ,
53- [32 , 32 , 2 * 2 ],
54- 2 ,
54+ 2 , [32 , 32 ], 2 ,
5555 w_init = hk .initializers .TruncatedNormal (0.01 ),
5656 b_init = jnp .zeros ,
5757 ),
@@ -104,7 +104,7 @@ def loss_fn(params):
104104 return params , losses
105105
106106
107- def run ():
107+ def run (n_iter , model ):
108108 n = 10000
109109 thetas = distrax .Normal (jnp .zeros (2 ), jnp .full (2 , 10 )).sample (
110110 seed = random .PRNGKey (0 ), sample_shape = (n ,)
@@ -114,8 +114,8 @@ def run():
114114 )
115115 data = named_dataset (y , thetas )
116116
117- model = make_model (2 )
118- params , losses = train (hk .PRNGSequence (2 ), data , model )
117+ model = make_model (2 , model )
118+ params , losses = train (hk .PRNGSequence (2 ), data , model , n_iter )
119119 samples = model .apply (
120120 params ,
121121 random .PRNGKey (2 ),
@@ -129,4 +129,10 @@ def run():
129129
130130
131131if __name__ == "__main__" :
132- run ()
132+ parser = argparse .ArgumentParser ()
133+ parser .add_argument ("--n-iter" , type = int , default = 1_000 )
134+ parser .add_argument ("--model" , type = str , default = "coupling" )
135+ args = parser .parse_args ()
136+ run (args .n_iter , args .model )
137+
138+
0 commit comments