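Hey @therooler, this should work:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import pretty_repr

x = jnp.ones((10, 50, 32))  # (batch, time, features)

# variable_axes={'params': 0} stacks the cell's params along a new leading (time) axis,
# variable_broadcast=False stops one param set from being shared across all steps,
# and split_rngs={'params': True} gives every step its own init key.
lstm = nn.RNN(
    nn.LSTMCell(64),
    variable_axes={'params': 0},
    split_rngs={'params': True},
    variable_broadcast=False,
)
variables = lstm.init(jax.random.key(0), x)
y = lstm.apply(variables, x)

# Every param leaf now has a leading axis of size 50 (one slice per time step).
print(pretty_repr(jax.tree_util.tree_map(lambda a: a.shape, variables)))
print(y.shape)  # (10, 50, 64)
```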
-
Hey, I'm trying to figure out how to make an RNN where each cell has its own set of parameters. The documentation for the weight-shared version is clear to me, but I'm having trouble figuring out how to give each cell its own set of parameters. I have looked into the documentation on lifting, but I am a little lost on where to start.