
Port https://github.com/NovaSky-AI/SkyRL/pull/1081 to skyrl folder #1131

Merged
pcmoritz merged 2 commits into NovaSky-AI:main from pcmoritz:port-1081 on Feb 16, 2026

Conversation

@pcmoritz
Collaborator

@pcmoritz pcmoritz commented Feb 15, 2026

See #1081


Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request ports changes from another repository, refactoring the Llama3 model to use StackedDecoderLayers. This is a positive change that encapsulates the layer-stacking logic, improving code clarity and enabling performance optimizations like using jax.lax.scan. The addition of an is_training flag to bypass KV cache generation during training is also a valuable memory optimization. The implementation is solid, and I have one minor suggestion to enhance code conciseness.
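The pattern the review describes can be sketched in plain Python. All names and signatures below are illustrative stand-ins, not the actual skyrl/tx implementation: the real StackedDecoderLayers is built on flax.nnx, and encapsulating the stacking is what allows the sequential loop to be swapped for jax.lax.scan.

```python
# Illustrative sketch of the StackedDecoderLayers pattern (hypothetical names;
# the real skyrl/tx code uses flax.nnx and can replace the loop with
# jax.lax.scan).

class ToyDecoderLayer:
    """Stand-in for Llama3DecoderLayer: just adds a per-layer bias."""

    def __init__(self, bias: int):
        self.bias = bias

    def __call__(self, hidden: int) -> int:
        return hidden + self.bias


class StackedDecoderLayers:
    """Builds num_layers layers from a factory and applies them in order."""

    def __init__(self, create_layer, num_layers: int):
        self.layers = [create_layer(i) for i in range(num_layers)]

    def __call__(self, hidden: int, is_training: bool = False) -> int:
        # In the real model, is_training=True would skip KV-cache allocation;
        # these toy layers keep no cache, so the flag is a no-op here.
        for layer in self.layers:
            hidden = layer(hidden)
        return hidden


# The factory argument mirrors the create_layer helper from the PR snippet.
layers = StackedDecoderLayers(lambda i: ToyDecoderLayer(bias=1), num_layers=4)
print(layers(10))  # four layers each add 1, so this prints 14
```

Because the model only talks to the stack through the factory and `__call__`, the internals (Python loop today, scan tomorrow) can change without touching the Llama3 model definition.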

Comment thread: skyrl/tx/models/llama3.py, lines +216 to +219
def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer:
    return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs)

self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)
Contributor


Severity: medium

For conciseness, you can use a lambda function directly as an argument instead of defining a separate create_layer helper function. This is a common pattern when a small function is needed for a single use, making the code more compact.

Suggested change

Before:
    def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer:
        return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs)

    self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)

After:
    self.layers = StackedDecoderLayers(
        lambda rngs: Llama3DecoderLayer(config, dtype=dtype, rngs=rngs),
        config.num_hidden_layers,
        rngs,
    )

Contributor

@devin-ai-integration devin-ai-integration Bot left a comment


✅ Devin Review: No Issues Found

Devin Review analyzed this PR and found no potential bugs to report.

View in Devin Review to see 4 additional findings.


@pcmoritz pcmoritz closed this Feb 15, 2026
@pcmoritz
Collaborator Author

We decided to do a little surgery first to make sure the skyrl folder preserves the git history.

@pcmoritz pcmoritz reopened this Feb 16, 2026
@pcmoritz pcmoritz merged commit dac7a69 into NovaSky-AI:main Feb 16, 2026
2 of 5 checks passed
@pcmoritz pcmoritz changed the title Port https://github.com/NovaSky-AI/SkyRL/pull/1081 to skyrl Port https://github.com/NovaSky-AI/SkyRL/pull/1081 to skyrl folder Feb 16, 2026