Merged
30 changes: 30 additions & 0 deletions skyrl/tests/tx/models/conftest.py
@@ -0,0 +1,30 @@
+from flax import nnx
+import jax
+import jax.numpy as jnp
+from transformers import AutoConfig
+
+from skyrl.tx.models.configs import ModelConfig
+from skyrl.tx.models.types import ModelForCausalLM
+from skyrl.tx.utils.models import load_safetensors, resolve_model_path
+
+
+def load_model(
+    model_name: str,
+    config_cls: type[ModelConfig],
+    model_cls: type[ModelForCausalLM],
+    mesh_axes: tuple[str, ...],
+    *,
+    mesh_shape: tuple[int, ...] | None = None,
+    **config_kwargs,
+) -> tuple[ModelConfig, ModelForCausalLM]:
+    """Create a JAX model and load weights from the HuggingFace cache."""
+    weights_dir = resolve_model_path(model_name)
+    base_config = AutoConfig.from_pretrained(model_name)
+    config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
+    if mesh_shape is None:
+        mesh_shape = (1,) * len(mesh_axes)
+    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_axes))
+    with jax.set_mesh(mesh):
+        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
+        load_safetensors(weights_dir, config, model)
+    return config, model
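
For orientation, the refactored tests below call this helper roughly as follows (the arguments mirror the llama3 entry in MODEL_PARAMS and the keyword arguments that test_compute_logits passes; this snippet is illustrative, not part of the diff):

config, model = load_model(
    "unsloth/Llama-3.2-1B",
    Llama3Config,
    Llama3ForCausalLM,
    ("fsdp", "tp"),
    max_lora_adapters=1,
    max_lora_rank=1,
    gradient_checkpointing=False,
)

Tests that only need the model can unpack the result as _, model = load_model(...).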
96 changes: 34 additions & 62 deletions skyrl/tests/tx/models/test_models_common.py
@@ -1,17 +1,14 @@
-import tempfile
-
-from flax import nnx
-import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer

 from skyrl.tx.models.configs import Llama3Config, ModelConfig, Qwen3Config
 from skyrl.tx.models.llama3 import Llama3ForCausalLM
 from skyrl.tx.models.qwen3 import Qwen3ForCausalLM
 from skyrl.tx.models.types import ModelForCausalLM
-from skyrl.tx.utils.models import load_safetensors
+
+from tests.tx.models.conftest import load_model
devin-ai-integration[bot] marked this conversation as resolved.

 MODEL_PARAMS = [
     ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
@@ -20,32 +17,6 @@
 MODEL_IDS = ["llama3", "qwen3"]


-def load_model(
-    tmp_dir: str,
-    model_name: str,
-    config_cls: type[ModelConfig],
-    model_cls: type[ModelForCausalLM],
-    mesh_axes: tuple[str, str],
-    *,
-    loss_chunk_size: int = 0,
-) -> ModelForCausalLM:
-    """Load model from pre-saved weights directory."""
-    base_config = AutoConfig.from_pretrained(model_name)
-    config = config_cls(
-        base_config,
-        max_lora_adapters=1,
-        max_lora_rank=1,
-        shard_attention_heads=True,
-        loss_chunk_size=loss_chunk_size,
-        gradient_checkpointing=False,
-    )
-    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2)
-    with jax.set_mesh(mesh):
-        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
-        load_safetensors(tmp_dir, config, model)
-    return model


 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
 def test_compute_logits(
     model_name: str,
@@ -59,22 +30,27 @@ def test_compute_logits(
     inputs = ["The capital of France is", "Hello world"]
     batch = tokenizer(inputs, return_tensors="pt", padding=True)

-    with tempfile.TemporaryDirectory() as tmp:
-        # Load HF model, get logits, save weights, then delete to free memory
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
-        hf_logits = hf_outputs.logits.detach().numpy()
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model, hf_outputs
-
-        # Load our model from saved weights
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes)
+    # Load HF model, get logits, then delete to free memory
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
+    hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
+    hf_logits = hf_outputs.logits.detach().numpy()
+    del hf_model, hf_outputs
+
+    _, model = load_model(
+        model_name,
+        config_cls,
+        model_cls,
+        mesh_axes,
+        max_lora_adapters=1,
+        max_lora_rank=1,
+        gradient_checkpointing=False,
+    )
Comment on lines +46 to +47
Contributor
high
The shard_attention_heads=True parameter seems to be missing from the call to load_model. The original load_model function in this file hardcoded this parameter, and its omission in the refactored code might change the model's configuration and affect the test's correctness. It should be added to maintain consistency with the previous behavior.
Suggested change:
-        gradient_checkpointing=False,
-    )
+        gradient_checkpointing=False,
+        shard_attention_heads=True,
+    )


-        # Get our logits via compute_logits
-        outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
-        our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))
+    # Get our logits via compute_logits
+    outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
+    our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))

-        np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)
+    np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)


 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
@@ -94,22 +70,18 @@ def test_chunked_logprobs(
     attention_mask = jnp.array(batch.attention_mask.numpy())
     target_ids = jnp.roll(input_ids, -1, axis=1)

-    with tempfile.TemporaryDirectory() as tmp:
-        # Save HF weights once
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model
-
-        # Load non-chunked model, compute logprobs, then delete
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
-        del model, outputs
-
-        # Load chunked model, compute logprobs
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False)
Contributor
high
The shard_attention_heads=True parameter seems to be missing from common_kwargs. The original model loading logic included this parameter. To ensure the test behaves as it did before the refactoring, it should be added to the common keyword arguments passed to load_model.
Suggested change:
-    common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False)
+    common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False, shard_attention_heads=True)

+    # Load non-chunked model, compute logprobs, then delete
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    del model, outputs
+
+    # Load chunked model, compute logprobs
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))

     np.testing.assert_allclose(
         logprobs_chunked,
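
For intuition about what loss_chunk_size toggles: a chunked logprob computation of this general shape splits the sequence into fixed-size chunks, projects each chunk of hidden states to logits, and gathers the target-token logprobs, so the full [batch, seq, vocab] logits tensor never materializes at once. Below is a minimal self-contained sketch; the function name and the explicit lm_head projection are assumptions for illustration, not the repository's compute_logprobs implementation:

import jax
import jax.numpy as jnp

def chunked_logprobs(hidden, lm_head, target_ids, chunk_size):
    # Hypothetical sketch, not the repository's implementation.
    # hidden: [B, S, H], lm_head: [H, V], target_ids: [B, S]; assumes S % chunk_size == 0.
    _, seq, _ = hidden.shape
    chunks = []
    for start in range(0, seq, chunk_size):
        h = hidden[:, start : start + chunk_size, :]             # [B, C, H]
        logits = h @ lm_head                                     # [B, C, V], freed after each chunk
        logz = jax.nn.logsumexp(logits, axis=-1)                 # [B, C] softmax normalizer
        tgt = target_ids[:, start : start + chunk_size, None]    # [B, C, 1]
        tok = jnp.take_along_axis(logits, tgt, axis=-1)[..., 0]  # [B, C] target logits
        chunks.append(tok - logz)                                # log-softmax at the targets
    return jnp.concatenate(chunks, axis=1)                       # [B, S]

Setting chunk_size to the full sequence length reproduces the non-chunked result in exact arithmetic; in float32 the two paths differ only by rounding, which is why the test compares them with assert_allclose tolerances rather than exact equality.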