Port https://github.com/NovaSky-AI/SkyRL/pull/1081 to skyrl folder #1131
Conversation
Code Review
This pull request ports changes from another repository, refactoring the Llama3 model to use StackedDecoderLayers. This is a positive change that encapsulates the layer-stacking logic, improving code clarity and enabling performance optimizations like using jax.lax.scan. The addition of an is_training flag to bypass KV cache generation during training is also a valuable memory optimization. The implementation is solid, and I have one minor suggestion to enhance code conciseness.
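To illustrate the layer-stacking idea the review refers to, here is a minimal, hypothetical sketch (not the PR's actual `StackedDecoderLayers` implementation): parameters for all layers are stacked along a leading axis, and `jax.lax.scan` applies one layer step per entry instead of a Python loop, which keeps the traced graph small. The `layer_body` function and shapes below are invented stand-ins for a real decoder layer.

```python
import jax
import jax.numpy as jnp

def layer_body(x, layer_params):
    # Stand-in for one decoder layer: a linear transform with a residual
    # connection. scan carries the activations; no per-step output is kept.
    w = layer_params["w"]
    return x + x @ w, None

num_layers, hidden = 4, 8
key = jax.random.PRNGKey(0)
# Weights for all layers stacked along a leading axis (one slice per layer).
stacked_params = {"w": jax.random.normal(key, (num_layers, hidden, hidden)) * 0.01}

x = jnp.ones((hidden,))
# scan unrolls over the leading axis, threading activations layer to layer.
y, _ = jax.lax.scan(layer_body, x, stacked_params)
print(y.shape)  # (8,)
```

Encapsulating this in a `StackedDecoderLayers` module, as the PR does, keeps the scan plumbing out of the model definition itself.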
```python
def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer:
    return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs)

self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)
```
For conciseness, you can use a lambda function directly as an argument instead of defining a separate create_layer helper function. This is a common pattern when a small function is needed for a single use, making the code more compact.
```diff
-def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer:
-    return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs)
-self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)
+self.layers = StackedDecoderLayers(
+    lambda rngs: Llama3DecoderLayer(config, dtype=dtype, rngs=rngs),
+    config.num_hidden_layers,
+    rngs,
+)
```
We decided to do a little surgery first to make sure the
See #1081