
pmap state #2121

Answered by marcvanzee
yuanqing-wang asked this question in Q&A
May 12, 2022 · 1 comment · 3 replies

You should use flax.jax_utils.replicate to copy the train state onto every device. Also, it is safer to use jax.device_count() than to hardcode 8 in your array. Finally, the Dense layer expects at least two dimensions. This code should work:

```python
import jax
import jax.numpy as jnp

from flax import linen as nn
from flax import jax_utils
import optax
from flax.training.train_state import TrainState

model = nn.Dense(1)
# One row per device instead of a hardcoded 8.
x = jnp.ones((jax.device_count(), 3))
params = model.init(jax.random.PRNGKey(0), x)
tx = optax.adam(learning_rate=1e-3)
state = TrainState.create(
    apply_fn=model.apply, params=params, tx=tx,
)
# Copy the state onto every local device (adds a leading device axis to each leaf).
state = jax_utils.replicate(state)

def loss_fn(state, x):
    return (model.apply(state.params, x) ** 2.0).mean()

jax.p…
```

Replies: 1 comment · 3 replies (from @zhenlan0426, @cgarciae, and @zhenlan0426)

Answer selected by marcvanzee
Category: Q&A · Labels: None yet · 4 participants

This discussion was converted from issue #2120 on May 13, 2022 09:53.