-import tempfile
-
-from flax import nnx
-import jax
 import jax.numpy as jnp
 import numpy as np
 import pytest
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from skyrl.tx.models.configs import Llama3Config, ModelConfig, Qwen3Config
 from skyrl.tx.models.llama3 import Llama3ForCausalLM
 from skyrl.tx.models.qwen3 import Qwen3ForCausalLM
 from skyrl.tx.models.types import ModelForCausalLM
-from skyrl.tx.utils.models import load_safetensors
+
+from tests.tx.models.conftest import load_model
 
 MODEL_PARAMS = [
     ("unsloth/Llama-3.2-1B", Llama3Config, Llama3ForCausalLM, ("fsdp", "tp")),
 MODEL_IDS = ["llama3", "qwen3"]
 
 
-def load_model(
-    tmp_dir: str,
-    model_name: str,
-    config_cls: type[ModelConfig],
-    model_cls: type[ModelForCausalLM],
-    mesh_axes: tuple[str, str],
-    *,
-    loss_chunk_size: int = 0,
-) -> ModelForCausalLM:
-    """Load model from pre-saved weights directory."""
-    base_config = AutoConfig.from_pretrained(model_name)
-    config = config_cls(
-        base_config,
-        max_lora_adapters=1,
-        max_lora_rank=1,
-        shard_attention_heads=True,
-        loss_chunk_size=loss_chunk_size,
-        gradient_checkpointing=False,
-    )
-    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2)
-    with jax.set_mesh(mesh):
-        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
-        load_safetensors(tmp_dir, config, model)
-    return model
-
-
 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
 def test_compute_logits(
     model_name: str,
@@ -59,22 +30,27 @@ def test_compute_logits(
     inputs = ["The capital of France is", "Hello world"]
     batch = tokenizer(inputs, return_tensors="pt", padding=True)
 
-    with tempfile.TemporaryDirectory() as tmp:
-        # Load HF model, get logits, save weights, then delete to free memory
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
-        hf_logits = hf_outputs.logits.detach().numpy()
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model, hf_outputs
-
-        # Load our model from saved weights
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes)
+    # Load HF model, get logits, then delete to free memory
+    hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
+    hf_outputs = hf_model(batch.input_ids, attention_mask=batch.attention_mask)
+    hf_logits = hf_outputs.logits.detach().numpy()
+    del hf_model, hf_outputs
+
+    _, model = load_model(
+        model_name,
+        config_cls,
+        model_cls,
+        mesh_axes,
+        max_lora_adapters=1,
+        max_lora_rank=1,
+        gradient_checkpointing=False,
+    )
 
-        # Get our logits via compute_logits
-        outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
-        our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))
+    # Get our logits via compute_logits
+    outputs = model(batch.input_ids.numpy(), attention_mask=batch.attention_mask.numpy())
+    our_logits = np.asarray(model.compute_logits(outputs.last_hidden_state))
 
-        np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)
+    np.testing.assert_allclose(our_logits, hf_logits, rtol=3e-2, atol=3e-2)
 
 
 @pytest.mark.parametrize("model_name,config_cls,model_cls,mesh_axes", MODEL_PARAMS, ids=MODEL_IDS)
@@ -94,22 +70,18 @@ def test_chunked_logprobs(
     attention_mask = jnp.array(batch.attention_mask.numpy())
     target_ids = jnp.roll(input_ids, -1, axis=1)
 
-    with tempfile.TemporaryDirectory() as tmp:
-        # Save HF weights once
-        hf_model = AutoModelForCausalLM.from_pretrained(model_name, attn_implementation="eager", use_safetensors=True)
-        hf_model.save_pretrained(tmp, safe_serialization=True)
-        del hf_model
-
-        # Load non-chunked model, compute logprobs, then delete
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
-        del model, outputs
-
-        # Load chunked model, compute logprobs
-        model = load_model(tmp, model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size)
-        outputs = model(input_ids, attention_mask=attention_mask)
-        logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    common_kwargs = dict(max_lora_adapters=1, max_lora_rank=1, gradient_checkpointing=False)
+
+    # Load non-chunked model, compute logprobs, then delete
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=0, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_nonchunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
+    del model, outputs
+
+    # Load chunked model, compute logprobs
+    _, model = load_model(model_name, config_cls, model_cls, mesh_axes, loss_chunk_size=chunk_size, **common_kwargs)
+    outputs = model(input_ids, attention_mask=attention_mask)
+    logprobs_chunked = np.asarray(model.compute_logprobs(outputs.last_hidden_state, target_ids))
 
     np.testing.assert_allclose(
         logprobs_chunked,
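
The shared helper imported from `tests/tx/models/conftest.py` is not part of this diff. A minimal sketch of what it might look like is below, assuming it returns a `(config, model)` pair (inferred from the `_, model = load_model(...)` call sites) and pulls weights straight from the Hugging Face cache via `snapshot_download`; the exact signature, defaults, and weight-resolution logic in the real conftest may differ.

```python
# Hypothetical sketch of the conftest load_model helper; the return pair,
# defaults, and snapshot_download usage are assumptions, not the actual conftest.
import jax
import jax.numpy as jnp
from flax import nnx
from huggingface_hub import snapshot_download
from transformers import AutoConfig

from skyrl.tx.models.configs import ModelConfig
from skyrl.tx.models.types import ModelForCausalLM
from skyrl.tx.utils.models import load_safetensors


def load_model(
    model_name: str,
    config_cls: type[ModelConfig],
    model_cls: type[ModelForCausalLM],
    mesh_axes: tuple[str, str],
    *,
    max_lora_adapters: int = 1,
    max_lora_rank: int = 1,
    loss_chunk_size: int = 0,
    gradient_checkpointing: bool = False,
) -> tuple[ModelConfig, ModelForCausalLM]:
    """Build a tx model on a 1x1 mesh and load HF safetensors weights into it."""
    base_config = AutoConfig.from_pretrained(model_name)
    config = config_cls(
        base_config,
        max_lora_adapters=max_lora_adapters,
        max_lora_rank=max_lora_rank,
        shard_attention_heads=True,
        loss_chunk_size=loss_chunk_size,
        gradient_checkpointing=gradient_checkpointing,
    )
    # Assumption: weights are resolved from the local HF cache rather than a tmp dir.
    weights_dir = snapshot_download(model_name)
    mesh = jax.make_mesh((1, 1), mesh_axes, axis_types=(jax.sharding.AxisType.Auto,) * 2)
    with jax.set_mesh(mesh):
        model = model_cls(config, dtype=jnp.float32, rngs=nnx.Rngs(0))
        load_safetensors(weights_dir, config, model)
    return config, model
```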