Commit 8173aa7

Port #1095 to skyrl folder (#1129)
See #1095
1 parent 924d244 commit 8173aa7

3 files changed

Lines changed: 169 additions & 179 deletions


skyrl/tests/tx/models/conftest.py

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
+from flax import nnx
+import jax
+import jax.numpy as jnp
+from transformers import AutoConfig
+
+from skyrl.tx.models.configs import ModelConfig
+from skyrl.tx.models.types import ModelForCausalLM
+from skyrl.tx.utils.models import load_safetensors, resolve_model_path
+
+
+def load_model(
+    model_name: str,
+    config_cls: type[ModelConfig],
+    model_cls: type[ModelForCausalLM],
+    mesh_axes: tuple[str, ...],
+    *,
+    mesh_shape: tuple[int, ...] | None = None,
+    **config_kwargs,
+) -> tuple[ModelConfig, ModelForCausalLM]:
+    """Create a JAX model and load weights from the HuggingFace cache."""
+    weights_dir = resolve_model_path(model_name)
+    base_config = AutoConfig.from_pretrained(model_name)
+    config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
+    if mesh_shape is None:
+        mesh_shape = (1,) * len(mesh_axes)
+    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_axes))
+    with jax.set_mesh(mesh):
+        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
+        load_safetensors(weights_dir, config, model)
+    return config, model
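
For context, the tests below end up invoking this shared helper roughly as follows. This is a minimal usage sketch based on the first MODEL_PARAMS entry in test_models_common.py; the keyword arguments are forwarded to the config via **config_kwargs, and mesh_shape falls back to a single-device (1, 1) mesh for the two axes:

    from skyrl.tx.models.configs import Llama3Config
    from skyrl.tx.models.llama3 import Llama3ForCausalLM

    from tests.tx.models.conftest import load_model

    # Builds a (1, 1) "fsdp"/"tp" mesh by default and loads cached HF weights.
    config, model = load_model(
        "unsloth/Llama-3.2-1B",
        Llama3Config,
        Llama3ForCausalLM,
        ("fsdp", "tp"),
        max_lora_adapters=1,
        max_lora_rank=1,
        gradient_checkpointing=False,
    )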

skyrl/tests/tx/models/test_models_common.py

Lines changed: 34 additions & 62 deletions
@@ -1,17 +1,14 @@
-import tempfile
-
-from flax import nnx
-import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer

 from skyrl.tx.models.configs import Llama3Config, ModelConfig, Qwen3Config
 from skyrl.tx.models.llama3 import Llama3ForCausalLM
 from skyrl.tx.models.qwen3 import Qwen3ForCausalLM
 from skyrl.tx.models.types import ModelForCausalLM
-from skyrl.tx.utils.models import load_safetensors
+
+from tests.tx.models.conftest import load_model

 MODEL_PARAMS = [
     ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
@@ -20,32 +17,6 @@
 MODEL_IDS = ["llama3", "qwen3"]


-def load_model(
-    tmp_dir: str,
-    model_name: str,
-    config_cls: type[ModelConfig],
-    model_cls: type[ModelForCausalLM],
-    mesh_axes: tuple[str, str],
-    *,
-    loss_chunk_size: int = 0,
-) -> ModelForCausalLM:
-    """Load model from pre-saved weights directory."""
-    base_config = AutoConfig.from_pretrained(model_name)
-    config = config_cls(
-        base_config,
-        max_lora_adapters=1,
-        max_lora_rank=1,
-        shard_attention_heads=True,
-        loss_chunk_size=loss_chunk_size,
-        gradient_checkpointing=False,
-    )
-    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2)
-    with jax.set_mesh(mesh):
-        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
-        load_safetensors(tmp_dir, config, model)
-    return model
-
-
 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
 def test_compute_logits(
     model_name: str,
@@ -59,22 +30,27 @@ def test_compute_logits(
     inputs = ["The capital of France is", "Hello world"]
     batch = tokenizer(inputs, return_tensors="pt", padding=True)

-    with tempfile.TemporaryDirectory() as tmp:
-        # Load HF model, get logits, save weights, then delete to free memory
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
-        hf_logits = hf_outputs.logits.detach().numpy()
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model, hf_outputs
-
-        # Load our model from saved weights
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes)
+    # Load HF model, get logits, then delete to free memory
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
+    hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
+    hf_logits = hf_outputs.logits.detach().numpy()
+    del hf_model, hf_outputs
+
+    _, model = load_model(
+        model_name,
+        config_cls,
+        model_cls,
+        mesh_axes,
+        max_lora_adapters=1,
+        max_lora_rank=1,
+        gradient_checkpointing=False,
+    )

-        # Get our logits via compute_logits
-        outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
-        our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))
+    # Get our logits via compute_logits
+    outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
+    our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))

-        np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)
+    np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)


 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
@@ -94,22 +70,18 @@ def test_chunked_logprobs(
     attention_mask = jnp.array(batch.attention_mask.numpy())
     target_ids = jnp.roll(input_ids, -1, axis=1)

-    with tempfile.TemporaryDirectory() as tmp:
-        # Save HF weights once
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model
-
-        # Load non-chunked model, compute logprobs, then delete
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
-        del model, outputs
-
-        # Load chunked model, compute logprobs
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False)
+
+    # Load non-chunked model, compute logprobs, then delete
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    del model, outputs
+
+    # Load chunked model, compute logprobs
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))

     np.testing.assert_allclose(
         logprobs_chunked,
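
The chunked test relies on compute_logprobs returning the same values regardless of loss_chunk_size. The implementation itself is not shown in this diff; the sketch below only illustrates the general technique, under the assumption that chunking runs over the sequence axis, and with hypothetical names (lm_head, chunked_logprobs) that are not part of the skyrl API:

    import jax
    import jax.numpy as jnp

    def chunked_logprobs(hidden, lm_head, target_ids, chunk_size):
        """Per-token logprobs computed chunk by chunk over the sequence axis,
        so the full [batch, seq, vocab] logits array is never materialized."""
        out = []
        for start in range(0, hidden.shape[1], chunk_size):
            h = hidden[:, start:start + chunk_size, :]    # [batch, chunk, hidden]
            logits = h @ lm_head                          # [batch, chunk, vocab]
            logp = jax.nn.log_softmax(logits, axis=-1)
            tgt = target_ids[:, start:start + chunk_size, None]
            out.append(jnp.take_along_axis(logp, tgt, axis=-1)[..., 0])
        return jnp.concatenate(out, axis=1)               # equals the unchunked result

Numerically, per-chunk log_softmax matches the full computation because softmax normalizes only over the vocab axis, which is complete within each chunk; that is why the test can compare the two variants with tight tolerances.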
