Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions skyrl/skyrl/tx/layers/stacked.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,12 @@ def __init__(
# Create first layer to get structure and shapes
first_layer = create_layer_fn(nnx.Rngs(layer_keys[0]))
graphdef, first_state = nnx.split(first_layer)
flat_first, treedef = jax.tree_util.tree_flatten(first_state)
flat_first, state_treedef = jax.tree_util.tree_flatten(first_state)

# Build a treedef with stacked partition metadata so tree_unflatten
# reconstructs Variables with the correct leading-layer sharding axis.
stacked_first_state = nnx.spmd.add_axis(first_state, 0, {nnx.PARTITION_NAME: None})
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

For consistency with nnx.spmd.add_axis and to prevent a potential AttributeError if PARTITION_NAME is not exposed at the top-level nnx module, it's safer to qualify it with the spmd submodule.

Suggested change
stacked_first_state = nnx.spmd.add_axis(first_state, 0, {nnx.PARTITION_NAME: None})
stacked_first_state = nnx.spmd.add_axis(first_state, 0, {nnx.spmd.PARTITION_NAME: None})

_, stacked_treedef = jax.tree_util.tree_flatten(stacked_first_state)

# Pre-allocate stacked arrays with correct sharding
stacked_flat = []
Expand All @@ -109,25 +114,13 @@ def copy_to_slice(stacked, arr, idx):
else:
layer = create_layer_fn(nnx.Rngs(layer_keys[layer_idx]))
_, state = nnx.split(layer)
flat, layer_treedef = jax.tree_util.tree_flatten(state)
assert layer_treedef == treedef, "Layer state structure mismatch while stacking decoder layers."
flat, current_treedef = jax.tree_util.tree_flatten(state)
assert current_treedef == state_treedef, "Layer state structure mismatch while stacking decoder layers."
for i, arr in enumerate(flat):
stacked_flat[i] = copy_to_slice(stacked_flat[i], arr, layer_idx)

# Reconstruct state from stacked arrays
stacked_state = jax.tree_util.tree_unflatten(treedef, stacked_flat)

# Sync NNX sharding metadata with actual array sharding.
# The arrays have correct stacked sharding from device_put, but NNX APIs
# (nnx.get_partition_spec, nnx.Optimizer) read from 'sharding_names' metadata.
for _, var in nnx.to_flat_state(stacked_state):
if not isinstance(var, nnx.Variable):
continue
array = var[...]
if hasattr(array, "sharding"):
array_sharding = array.sharding
if hasattr(array_sharding, "spec"):
var.set_metadata("sharding_names", tuple(array_sharding.spec))
stacked_state = jax.tree_util.tree_unflatten(stacked_treedef, stacked_flat)

self._stacked = nnx.merge(graphdef, stacked_state)

Expand Down
Loading