Skip to content

Commit 448ea3f

Browse files
kylejcaron and kylejcaron authored
Fixed parameter updates for sequence layers in nn modules (#2024)
* fixed naming for sequence layers in nn modules
* removed print statement
* skipped nnx/eqx tests on py39

---------

Co-authored-by: kylejcaron <kylejcaron@bitbucket.org>
1 parent dc98129 commit 448ea3f

2 files changed

Lines changed: 104 additions & 1 deletion

File tree

numpyro/contrib/module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def _update_params(params, new_params, prior, prefix=""):
231231
A helper to recursively set prior to new_params.
232232
"""
233233
for name, item in params.items():
234-
flatten_name = ".".join([prefix, name]) if prefix else name
234+
flatten_name = ".".join([str(prefix), str(name)]) if prefix else str(name)
235235
if isinstance(item, dict):
236236
assert not isinstance(prior, dict) or flatten_name not in prior
237237
new_item = new_params[name]

test/contrib/test_module.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,53 @@ def model(data, labels=None):
480480
assert "nn/w" in samples
481481

482482

483+
@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
def test_random_nnx_module_mcmc_sequence_params():
    """Run a one-step NUTS chain over an NNX MLP whose layers are held in a list.

    The test passes when the returned samples contain bias sites for the
    list-indexed layers, i.e. ``"nn/layers.0.bias"`` and ``"nn/layers.1.bias"``.
    """
    from flax import nnx

    class MLP(nnx.Module):
        """Feed-forward network storing its ``nnx.Linear`` layers in ``self.layers``."""

        def __init__(self, din, dout, hidden_layers, *, rngs, activation=jax.nn.relu):
            self.activation = activation
            self.layers = []
            widths = [din] + hidden_layers + [dout]
            # Consecutive width pairs: (din, h1), (h1, h2), ..., (hk, dout).
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                linear = nnx.Linear(fan_in, fan_out, rngs=rngs)
                self.layers.append(linear)

        def __call__(self, x):
            # The activation is applied after every layer except the last one.
            hidden, head = self.layers[:-1], self.layers[-1]
            for linear in hidden:
                x = self.activation(linear(x))
            return head(x)

    # Synthetic binary-classification data generated from a linear logit.
    num_points, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (num_points, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    init_key = random.PRNGKey(0)
    nn_module = MLP(
        din=dim,
        dout=1,
        hidden_layers=[8, 8],
        rngs=nnx.Rngs(params=init_key),
    )

    def prior(name, shape):
        # Cauchy prior for parameters named "bias", standard Normal otherwise.
        if name == "bias":
            return dist.Cauchy()
        return dist.Normal()

    def model(data, labels=None):
        # Use the pre-initialized module with eager initialization
        net = random_nnx_module("nn", nn_module, prior=prior)
        predicted_logits = net(data).squeeze(-1)
        likelihood = dist.Bernoulli(logits=predicted_logits)
        return numpyro.sample("obs", likelihood, obs=labels)

    sampler = MCMC(NUTS(model), num_warmup=1, num_samples=1, progress_bar=False)
    sampler.run(random.PRNGKey(0), data, labels)
    samples = sampler.get_samples()

    # Parameters of list-indexed layers must show up under their flattened names.
    for site in ("nn/layers.0.bias", "nn/layers.1.bias"):
        assert site in samples
528+
529+
483530
@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
484531
def test_eqx_module():
485532
import equinox as eqx
@@ -606,3 +653,59 @@ def model(data, labels=None):
606653
samples = mcmc.get_samples()
607654
assert "nn/bias" in samples
608655
assert "nn/weight" in samples
656+
657+
658+
@pytest.mark.skipif(sys.version_info[:2] == (3, 9), reason="Skipping on Python 3.9")
def test_random_eqx_module_mcmc_sequence_params():
    """Run a one-step NUTS chain over an Equinox MLP whose layers are held in a list.

    The test passes when the returned samples contain bias sites for the
    list-indexed layers, i.e. ``"nn/layers[0].bias"`` and ``"nn/layers[1].bias"``.
    """
    import equinox as eqx

    class MLP(eqx.Module):
        # Linear layers in application order; the last one is the output layer.
        layers: list

        def __init__(
            self,
            in_size: int,
            out_size: int,
            hidden_layers: list[int],
            key: jax.random.PRNGKey,
        ):
            layer_dims = [in_size] + list(hidden_layers) + [out_size]
            # One PRNG key per Linear layer. There are len(hidden_layers) + 1
            # layers; splitting into only len(hidden_layers) keys leaves the
            # final layer indexing past the end of the key array.
            keys = jax.random.split(key, len(layer_dims) - 1)
            self.layers = [
                eqx.nn.Linear(in_dim, out_dim, key=layer_key)
                for in_dim, out_dim, layer_key in zip(
                    layer_dims[:-1], layer_dims[1:], keys
                )
            ]

        def __call__(self, x):
            for layer in self.layers[:-1]:
                x = jax.nn.relu(layer(x))
            return self.layers[-1](x)  # Final layer, no activation

    # Synthetic binary-classification data generated from a linear logit.
    N, dim = 3000, 3
    data = random.normal(random.PRNGKey(0), (N, dim))
    true_coefs = np.arange(1.0, dim + 1.0)
    logits = np.sum(true_coefs * data, axis=-1)
    labels = dist.Bernoulli(logits=logits).sample(random.PRNGKey(1))

    rng_key = random.PRNGKey(0)
    nn_module = MLP(in_size=dim, out_size=1, hidden_layers=[8, 8], key=rng_key)

    def prior(name, shape):
        # Cauchy prior for parameters named "bias", standard Normal otherwise.
        return dist.Cauchy() if name == "bias" else dist.Normal()

    def model(data, labels=None):
        # Use the pre-initialized module with eager initialization
        nn = random_eqx_module("nn", nn_module, prior=prior)
        # The module is applied per example, hence the vmap over the batch axis.
        logits = jax.vmap(nn)(data).squeeze(-1)
        return numpyro.sample("obs", dist.Bernoulli(logits=logits), obs=labels)

    nuts_kernel = NUTS(model)
    mcmc = MCMC(nuts_kernel, num_warmup=1, num_samples=1, progress_bar=False)
    mcmc.run(random.PRNGKey(0), data, labels)
    samples = mcmc.get_samples()

    # Parameters of list-indexed layers must show up under their flattened names.
    assert "nn/layers[0].bias" in samples
    assert "nn/layers[1].bias" in samples

0 commit comments

Comments
 (0)