@@ -480,6 +480,53 @@ def model(data, labels=None):
480480 assert "nn/w" in samples
481481
482482
483+ @pytest .mark .skipif (sys .version_info [:2 ] == (3 , 9 ), reason = "Skipping on Python 3.9" )
484+ def test_random_nnx_module_mcmc_sequence_params ():
485+ from flax import nnx
486+
487+ class MLP (nnx .Module ):
488+ def __init__ (self , din , dout , hidden_layers , * , rngs , activation = jax .nn .relu ):
489+ self .activation = activation
490+ self .layers = []
491+ layer_dims = [din ] + hidden_layers + [dout ]
492+ for in_dim , out_dim in zip (layer_dims [:- 1 ], layer_dims [1 :]):
493+ self .layers .append (nnx .Linear (in_dim , out_dim , rngs = rngs ))
494+
495+ def __call__ (self , x ):
496+ for layer in self .layers [:- 1 ]:
497+ x = self .activation (layer (x ))
498+ return self .layers [- 1 ](x )
499+
500+ N , dim = 3000 , 3
501+ data = random .normal (random .PRNGKey (0 ), (N , dim ))
502+ true_coefs = np .arange (1.0 , dim + 1.0 )
503+ logits = np .sum (true_coefs * data , axis = - 1 )
504+ labels = dist .Bernoulli (logits = logits ).sample (random .PRNGKey (1 ))
505+
506+ rng_key = random .PRNGKey (0 )
507+ nn_module = MLP (
508+ din = dim , dout = 1 , hidden_layers = [8 , 8 ], rngs = nnx .Rngs (params = rng_key )
509+ )
510+
511+ def prior (name , shape ):
512+ return dist .Cauchy () if name == "bias" else dist .Normal ()
513+
514+ def model (data , labels = None ):
515+ # Use the pre-initialized module with eager initialization
516+ nn = random_nnx_module ("nn" , nn_module , prior = prior )
517+ logits = nn (data ).squeeze (- 1 )
518+ return numpyro .sample ("obs" , dist .Bernoulli (logits = logits ), obs = labels )
519+
520+ nuts_kernel = NUTS (model )
521+ mcmc = MCMC (nuts_kernel , num_warmup = 1 , num_samples = 1 , progress_bar = False )
522+ mcmc .run (random .PRNGKey (0 ), data , labels )
523+ samples = mcmc .get_samples ()
524+
525+ # check both layers have parameters in the samples
526+ assert "nn/layers.0.bias" in samples
527+ assert "nn/layers.1.bias" in samples
528+
529+
483530@pytest .mark .skipif (sys .version_info [:2 ] == (3 , 9 ), reason = "Skipping on Python 3.9" )
484531def test_eqx_module ():
485532 import equinox as eqx
@@ -606,3 +653,59 @@ def model(data, labels=None):
606653 samples = mcmc .get_samples ()
607654 assert "nn/bias" in samples
608655 assert "nn/weight" in samples
656+
657+
658+ @pytest .mark .skipif (sys .version_info [:2 ] == (3 , 9 ), reason = "Skipping on Python 3.9" )
659+ def test_random_eqx_module_mcmc_sequence_params ():
660+ import equinox as eqx
661+
662+ class MLP (eqx .Module ):
663+ layers : list
664+
665+ def __init__ (
666+ self ,
667+ in_size : int ,
668+ out_size : int ,
669+ hidden_layers : list [int ],
670+ key : jax .random .PRNGKey ,
671+ ):
672+ keys = jax .random .split (key , len (hidden_layers ))
673+ self .layers = []
674+
675+ # Create all linear layers
676+ self .layers = []
677+ layer_dims = [in_size ] + list (hidden_layers ) + [out_size ]
678+ for i , (in_dim , out_dim ) in enumerate (zip (layer_dims [:- 1 ], layer_dims [1 :])):
679+ self .layers .append (eqx .nn .Linear (in_dim , out_dim , key = keys [i ]))
680+
681+ def __call__ (self , x ):
682+ for layer in self .layers [:- 1 ]:
683+ x = jax .nn .relu (layer (x ))
684+ return self .layers [- 1 ](x ) # Final layer, no activation
685+
686+ N , dim = 3000 , 3
687+ data = random .normal (random .PRNGKey (0 ), (N , dim ))
688+ true_coefs = np .arange (1.0 , dim + 1.0 )
689+ logits = np .sum (true_coefs * data , axis = - 1 )
690+ labels = dist .Bernoulli (logits = logits ).sample (random .PRNGKey (1 ))
691+
692+ rng_key = random .PRNGKey (0 )
693+ nn_module = MLP (in_size = dim , out_size = 1 , hidden_layers = [8 , 8 ], key = rng_key )
694+
695+ def prior (name , shape ):
696+ return dist .Cauchy () if name == "bias" else dist .Normal ()
697+
698+ def model (data , labels = None ):
699+ # Use the pre-initialized module with eager initialization
700+ nn = random_eqx_module ("nn" , nn_module , prior = prior )
701+ logits = jax .vmap (nn )(data ).squeeze (- 1 )
702+ return numpyro .sample ("obs" , dist .Bernoulli (logits = logits ), obs = labels )
703+
704+ nuts_kernel = NUTS (model )
705+ mcmc = MCMC (nuts_kernel , num_warmup = 1 , num_samples = 1 , progress_bar = False )
706+ mcmc .run (random .PRNGKey (0 ), data , labels )
707+ samples = mcmc .get_samples ()
708+
709+ # check both layers have parameters in the samples
710+ assert "nn/layers[0].bias" in samples
711+ assert "nn/layers[1].bias" in samples
0 commit comments