Port https://github.com/NovaSky-AI/SkyRL/pull/1095 to skyrl folder #1129
The PR adds a shared test helper, `tests/tx/models/conftest.py` (new file, +30 lines):

```python
from flax import nnx
import jax
import jax.numpy as jnp
from transformers import AutoConfig

from skyrl.tx.models.configs import ModelConfig
from skyrl.tx.models.types import ModelForCausalLM
from skyrl.tx.utils.models import load_safetensors, resolve_model_path


def load_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, ...],
    *,
    mesh_shape: tuple[int, ...] | None = None,
    **config_kwargs,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Create a JAX model and load weights from the HuggingFace cache."""
    weights_dir = resolve_model_path(model_name)
    base_config = AutoConfig.from_pretrained(model_name)
    config = config_cls(base_config, shard_attention_heads=True, **config_kwargs)
    if mesh_shape is None:
        mesh_shape = (1,) * len(mesh_axes)
    mesh = jax.make_mesh(mesh_shape, mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * len(mesh_axes))
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
        load_safetensors(weights_dir, config, model)
    return config, model
```
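For orientation, `resolve_model_path` is imported from `skyrl.tx.utils.models` and, per the docstring, maps a model name to local weight files in the HuggingFace cache. A minimal sketch of what such a helper could look like, assuming it wraps `huggingface_hub` (the actual SkyRL implementation may differ):

```python
# Hypothetical sketch only -- not the skyrl.tx.utils.models implementation.
from huggingface_hub import snapshot_download

def resolve_model_path_sketch(model_name: str) -> str:
    # Fetch (or reuse the already-cached copy of) the weights and config,
    # returning the local snapshot directory containing *.safetensors files.
    return snapshot_download(model_name, allow_patterns=["*.safetensors", "*.json"])
```

Constructing the model under `jax.set_mesh` before calling `load_safetensors` presumably lets the weights be materialized with the mesh's sharding in effect, which is why the mesh setup lives inside the helper rather than in each test.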
In the model tests, the imports shrink accordingly:

```diff
@@ -1,17 +1,14 @@
-import tempfile
-
-from flax import nnx
-import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer

 from skyrl.tx.models.configs import Llama3Config, ModelConfig, Qwen3Config
 from skyrl.tx.models.llama3 import Llama3ForCausalLM
 from skyrl.tx.models.qwen3 import Qwen3ForCausalLM
 from skyrl.tx.models.types import ModelForCausalLM
-from skyrl.tx.utils.models import load_safetensors
+
+from tests.tx.models.conftest import load_model

 MODEL_PARAMS = [
     ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
```
The now-redundant local helper is deleted:

```diff
@@ -20,32 +17,6 @@
 MODEL_IDS = ["llama3", "qwen3"]


-def load_model(
-    tmp_dir: str,
-    model_name: str,
-    config_cls: type[ModelConfig],
-    model_cls: type[ModelForCausalLM],
-    mesh_axes: tuple[str, str],
-    *,
-    loss_chunk_size: int = 0,
-) -> ModelForCausalLM:
-    """Load model from pre-saved weights directory."""
-    base_config = AutoConfig.from_pretrained(model_name)
-    config = config_cls(
-        base_config,
-        max_lora_adapters=1,
-        max_lora_rank=1,
-        shard_attention_heads=True,
-        loss_chunk_size=loss_chunk_size,
-        gradient_checkpointing=False,
-    )
-    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2)
-    with jax.set_mesh(mesh):
-        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
-        load_safetensors(tmp_dir, config, model)
-    return model
-
-
 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
 def test_compute_logits(
     model_name: str,
```
`test_compute_logits` no longer round-trips the HF weights through a temporary directory; it reads them straight from the HuggingFace cache via the shared helper:

```diff
@@ -59,22 +30,27 @@ def test_compute_logits(
     inputs = ["The capital of France is", "Hello world"]
     batch = tokenizer(inputs, return_tensors="pt", padding=True)

-    with tempfile.TemporaryDirectory() as tmp:
-        # Load HF model, get logits, save weights, then delete to free memory
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
-        hf_logits = hf_outputs.logits.detach().numpy()
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model, hf_outputs
-
-        # Load our model from saved weights
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes)
-
-        # Get our logits via compute_logits
-        outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
-        our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))
-
-        np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)
+    # Load HF model, get logits, then delete to free memory
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
+    hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
+    hf_logits = hf_outputs.logits.detach().numpy()
+    del hf_model, hf_outputs
+
+    _, model = load_model(
+        model_name,
+        config_cls,
+        model_cls,
+        mesh_axes,
+        max_lora_adapters=1,
+        max_lora_rank=1,
+        gradient_checkpointing=False,
+    )
+
+    # Get our logits via compute_logits
+    outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
+    our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))
+
+    np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)


 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
```
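Since the shared helper now returns a `(config, model)` tuple, call sites that only need the model unpack with `_`; a test that also wants config fields can keep both. A usage sketch, with the parameters taken from `MODEL_PARAMS` above and the kwargs matching the test's:

```python
# Usage sketch of the shared conftest helper.
config, model = load_model(
    "unsloth/Llama-3.2-1B",
    Llama3Config,
    Llama3ForCausalLM,
    ("fsdp", "tp"),
    max_lora_adapters=1,
    max_lora_rank=1,
    gradient_checkpointing=False,
)
```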
`test_chunked_logprobs` gets the same treatment, with the shared config kwargs factored out:

```diff
@@ -94,22 +70,18 @@ def test_chunked_logprobs(
     attention_mask = jnp.array(batch.attention_mask.numpy())
     target_ids = jnp.roll(input_ids, -1, axis=1)

-    with tempfile.TemporaryDirectory() as tmp:
-        # Save HF weights once
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model
-
-        # Load non-chunked model, compute logprobs, then delete
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
-        del model, outputs
-
-        # Load chunked model, compute logprobs
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False)
+
+    # Load non-chunked model, compute logprobs, then delete
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    del model, outputs
+
+    # Load chunked model, compute logprobs
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))

     np.testing.assert_allclose(
         logprobs_chunked,
```
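The closing assertion checks that chunked and non-chunked log-probs agree. As a rough illustration of what a `loss_chunk_size`-style option typically does (not SkyRL's actual `compute_logprobs` implementation), chunking the sequence axis avoids ever materializing the full `[batch, seq, vocab]` logits tensor:

```python
# Hypothetical sketch -- not the actual skyrl.tx compute_logprobs implementation.
import jax
import jax.numpy as jnp

def chunked_logprobs_sketch(hidden, lm_head, target_ids, chunk_size):
    # hidden: [B, T, H] final hidden states; lm_head: [H, V]; target_ids: [B, T].
    # Only one [B, chunk_size, V] logits slice is live at a time, cutting peak memory.
    chunks = []
    for start in range(0, hidden.shape[1], chunk_size):
        h = hidden[:, start:start + chunk_size, :]
        logits = h @ lm_head  # [B, chunk, V]
        logp = jax.nn.log_softmax(logits, axis=-1)
        tgt = target_ids[:, start:start + chunk_size]
        chunks.append(jnp.take_along_axis(logp, tgt[..., None], axis=-1)[..., 0])
    return jnp.concatenate(chunks, axis=1)  # [B, T], equal to the unchunked result
```

Mathematically the chunked result is identical to the unchunked one, so the test's `assert_allclose` only needs tolerances for floating-point reduction-order noise.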