
Commit dac7a69

Port #1081 to skyrl (#1131)
See #1081
1 parent: d38b99c · commit: dac7a69

1 file changed

Lines changed: 21 additions & 20 deletions

File tree

skyrl/tx/models/llama3.py

@@ -8,6 +8,7 @@
 from skyrl.tx.layers.rotary_embedding import apply_rope
 from skyrl.tx.layers.layernorm import RMSNorm
 from skyrl.tx.layers.attention import dot_product_attention
+from skyrl.tx.layers.stacked import StackedDecoderLayers
 from skyrl.tx.utils.logits_processor import LogitsProcessorMixin, LMHead
 from skyrl.tx.models.types import CausalLMOutput, ModelForCausalLM, ModelOutput
 from skyrl.tx.utils.generator import GeneratorMixin, KVCache
@@ -211,9 +212,11 @@ def __init__(self, config: LlamaConfig, *, dtype: jnp.dtype, rngs: nnx.Rngs) ->
             embedding_init=nnx.initializers.normal(),
             rngs=rngs,
         )
-        self.layers = nnx.List(
-            [Llama3DecoderLayer(config, dtype=dtype, rngs=rngs) for _ in range(config.num_hidden_layers)]
-        )
+
+        def create_layer(rngs: nnx.Rngs) -> Llama3DecoderLayer:
+            return Llama3DecoderLayer(config, dtype=dtype, rngs=rngs)
+
+        self.layers = StackedDecoderLayers(create_layer, config.num_hidden_layers, rngs)
         self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps, dtype=dtype, rngs=rngs)
 
     def __call__(
@@ -225,36 +228,32 @@ def __call__(
         output_hidden_states: bool | None = None,
         adapter_indices: jax.Array | None = None,
         kv_cache: KVCache | None = None,
+        is_training: bool = False,
     ) -> ModelOutput:
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
         )
 
         hidden_states = self.embed_tokens(input_ids, adapter_indices=adapter_indices)
-        all_hidden_states: list[jax.Array] = []
-        updated_keys, updated_values = [], []
-
-        for layer_idx, layer in enumerate(self.layers):
-            if output_hidden_states:
-                all_hidden_states.append(hidden_states)
-
-            hidden_states, (k, v) = layer(
-                hidden_states,
-                attention_mask=attention_mask,
-                positions=positions,
-                adapter_indices=adapter_indices,
-                kv_cache=kv_cache and (kv_cache.keys[layer_idx], kv_cache.values[layer_idx]),
-            )
-            updated_keys.append(k)
-            updated_values.append(v)
+
+        hidden_states, all_hidden_states, new_kv_cache = self.layers(
+            hidden_states,
+            attention_mask=attention_mask,
+            positions=positions,
+            adapter_indices=adapter_indices,
+            kv_cache=kv_cache,
+            output_hidden_states=output_hidden_states,
+            gradient_checkpointing=self.config.gradient_checkpointing,
+            is_training=is_training,
+        )
 
         hidden_states = self.norm(hidden_states)
         if output_hidden_states:
             all_hidden_states.append(hidden_states)
 
         return ModelOutput(
             last_hidden_state=hidden_states,
-            kv_cache=KVCache.update(kv_cache, updated_keys, updated_values, positions, attention_mask),
+            kv_cache=new_kv_cache,
             hidden_states=all_hidden_states if output_hidden_states else None,
         )
 
@@ -299,6 +298,7 @@ def __call__(
         output_hidden_states: bool | None = None,
         adapter_indices: jax.Array | None = None,
         kv_cache: KVCache | None = None,
+        is_training: bool = False,
     ) -> CausalLMOutput:
         if positions is None:
             positions = jnp.arange(attention_mask.shape[1])[None, :]
@@ -310,6 +310,7 @@
             output_hidden_states=output_hidden_states,
             adapter_indices=adapter_indices,
             kv_cache=kv_cache,
+            is_training=is_training,
        )
 
        return CausalLMOutput(
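
For context, this change swaps the explicit Python loop over decoder layers for a StackedDecoderLayers module. The sketch below illustrates the general technique such a module typically builds on: layer parameters are stacked along a leading axis and applied with jax.lax.scan, with optional gradient checkpointing via jax.checkpoint. It is a minimal, self-contained illustration under those assumptions, not the skyrl implementation; mlp_layer, run_stacked_layers, and use_checkpoint are hypothetical names.

    import jax
    import jax.numpy as jnp

    def mlp_layer(params, hidden):
        # Stand-in for one decoder layer: a dense projection with a residual connection.
        return hidden + jnp.tanh(hidden @ params["w"] + params["b"])

    def run_stacked_layers(stacked_params, hidden, *, use_checkpoint=False):
        # Each leaf of `stacked_params` has a leading axis of size num_layers;
        # lax.scan slices out one layer's weights per step and threads `hidden` through.
        layer_fn = jax.checkpoint(mlp_layer) if use_checkpoint else mlp_layer

        def step(carry, layer_params):
            return layer_fn(layer_params, carry), None

        hidden, _ = jax.lax.scan(step, hidden, stacked_params)
        return hidden

    # Example: 4 layers of width 8, a batch of 2 vectors.
    num_layers, width = 4, 8
    key = jax.random.PRNGKey(0)
    stacked = {
        "w": 0.02 * jax.random.normal(key, (num_layers, width, width)),
        "b": jnp.zeros((num_layers, width)),
    }
    out = run_stacked_layers(stacked, jnp.ones((2, width)), use_checkpoint=True)
    print(out.shape)  # (2, 8)

Compared with a plain for-loop over an nnx.List, scanning over stacked parameters typically keeps the traced and compiled program size independent of the number of layers, and per-layer checkpointing bounds activation memory during training, which lines up with the new gradient_checkpointing and is_training arguments in the diff above.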
