 from skyrl.tx.layers.rotary_embedding import apply_rope
 from skyrl.tx.layers.layernorm import RMSNorm
 from skyrl.tx.layers.attention import dot_product_attention
+from skyrl.tx.layers.stacked import StackedDecoderLayers
 from skyrl.tx.utils.logits_processor import LogitsProcessorMixin, LMHead
 from skyrl.tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput
 from skyrl.tx.utils.generator import GeneratorMixin, KVCache
@@ -211,9 +212,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
-        self.layers = nnx.List(
-            [Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)]
-        )
+
+        def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer:
+            return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs)
+
+        self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)

     def __call__(
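
StackedDecoderLayers itself is not shown in this diff. As a rough sketch of what its constructor could look like, assuming it follows Flax NNX's scan-over-layers pattern (the `stacked` attribute name and the exact `nnx.split_rngs`/`nnx.vmap` wiring here are assumptions, not the repository's code):

from flax import nnx

class StackedDecoderLayers(nnx.Module):
    """Sketch: hold all decoder layers as one module whose parameters
    carry a leading num_layers axis (the scan-over-layers pattern)."""

    def __init__(self, create_layer, num_layers: int, rngs: nnx.Rngs):
        self.num_layers = num_layers

        # Give each stacked layer its own RNG stream, then vectorize the
        # factory so every parameter gains a leading num_layers dimension.
        @nnx.split_rngs(splits=num_layers)
        @nnx.vmap(in_axes=(0,), out_axes=0)
        def create_stacked(rngs: nnx.Rngs):
            return create_layer(rngs)

        self.stacked = create_stacked(rngs)

Stacking parameters this way means XLA traces and compiles one layer body rather than num_hidden_layers separate copies, which is the usual motivation for replacing an nnx.List of per-layer modules.
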
@@ -225,36 +228,32 @@ def __call__(
         output_hidden_states: bool | None = None,
         adapter_indices: jax.Array | None = None,
         kv_cache: KVCache | None = None,
+        is_training: bool = False,
     ) -> ModelOutput:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )

         hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices)
-        all_hidden_states: list[jax.Array] = []
-        updated_keys, updated_values = [], []
-
-        for layer_idx, layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states.append(hidden_states)
-
-            hidden_states, (k, v) = layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                positions=positions,
-                adapter_indices=adapter_indices,
-                kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]),
-            )
-            updated_keys.append(k)
-            updated_values.append(v)
+
+        hidden_states, all_hidden_states, new_kv_cache = self.layers(
+            hidden_states,
+            attention_mask=attention_mask,
+            positions=positions,
+            adapter_indices=adapter_indices,
+            kv_cache=kv_cache,
+            output_hidden_states=output_hidden_states,
+            gradient_checkpointing=self.config.gradient_checkpointing,
+            is_training=is_training,
+        )

         hidden_states = self.norm(hidden_states)
         if output_hidden_states:
             all_hidden_states.append(hidden_states)

         return ModelOutput(
             last_hidden_state=hidden_states,
-            kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask),
+            kv_cache=new_kv_cache,
             hidden_states=all_hidden_states if output_hidden_states else None,
         )

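The single self.layers(...) call above absorbs the old Python loop. A sketch of what that forward pass could look like, under the same assumptions as the constructor sketch: one nnx.scan carries hidden_states through the stacked parameters, the per-layer inputs and KV tensors come back stacked along the scan axis, and wrapping the body in nnx.remat when `gradient_checkpointing and is_training` would give the checkpointing behavior the new arguments suggest. None of this is the repository's actual implementation:

from flax import nnx

def __call__(self, hidden_states, *, attention_mask, positions, adapter_indices,
             kv_cache, output_hidden_states, gradient_checkpointing, is_training):
    def body(carry, layer):
        # Decode-time per-layer cache slices (kv_cache.keys / kv_cache.values)
        # would be scanned alongside the layers; omitted here for brevity.
        out, (k, v) = layer(
            carry,
            attention_mask=attention_mask,
            positions=positions,
            adapter_indices=adapter_indices,
        )
        # Emit the layer *input* so the stacked outputs line up with the old
        # per-layer all_hidden_states list; k/v stack to (num_layers, ...).
        return out, (carry, k, v)

    if gradient_checkpointing and is_training:
        # Recompute activations in the backward pass instead of storing them.
        body = nnx.remat(body)

    scan = nnx.scan(body, in_axes=(nnx.Carry, 0), out_axes=(nnx.Carry, 0))
    hidden_states, (all_hs, keys, values) = scan(hidden_states, self.stacked)

    all_hidden_states = list(all_hs) if output_hidden_states else []
    new_kv_cache = KVCache.update(kv_cache, list(keys), list(values), positions, attention_mask)
    return hidden_states, all_hidden_states, new_kv_cache
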
@@ -299,6 +298,7 @@ def __call__(
         output_hidden_states: bool | None = None,
         adapter_indices: jax.Array | None = None,
         kv_cache: KVCache | None = None,
+        is_training: bool = False,
     ) -> CausalLMOutput:
         if positions is None:
             positions = jnp.arange(attention_mask.shape[1])[None, :]
@@ -310,6 +310,7 @@ def __call__(
             output_hidden_states=output_hidden_states,
             adapter_indices=adapter_indices,
             kv_cache=kv_cache,
+            is_training=is_training,
         )

         return CausalLMOutput(
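
A hedged usage sketch of the new is_training flag: rematerialization only pays its recompute cost when gradients are needed, so a training step would pass True while decoding keeps the default fast path (model, tokens, mask, and cache are placeholder names):

train_out = model(tokens, attention_mask=mask, is_training=True)   # remat active (if enabled in config)
decode_out = model(tokens, attention_mask=mask, kv_cache=cache)    # is_training=False: no recompute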