`pmap` state · #2121 · Answered by marcvanzee
yuanqing-wang asked this question in Q&A
How should I broadcast a training state to multiple devices and use it with `pmap`? But got an error.
Answered by marcvanzee on May 13, 2022
Replies: 1 comment · 3 replies
You should use `flax.jax_utils.replicate`. Also, it is safer to use `jax.device_count()` rather than hardcoding `8` in your array. Finally, the `Dense` layer expects at least two dimensions. This code should work:

```python
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState

model = nn.Dense(1)
# Leading axis matches the number of devices instead of a hardcoded 8.
x = jnp.ones((jax.device_count(), 3))
params = model.init(jax.random.PRNGKey(0), x)

tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)

# Broadcast the train state to every device.
state = jax_utils.replicate(state)

def loss_fn(state, x):
    return (model.apply(state.params, x) ** 2.0).mean()

# pmap maps over the leading (device) axis of both arguments.
jax.pmap(loss_fn)(state, x)
```

Answer selected by marcvanzee