
Commit 8438493

erictang000 and claude committed
Switch tx model tests to AutoConfig + truncate layer_types under shrink
transformers 5.4 turned PretrainedConfig into a @strict @dataclass with class validators. Two patterns broke under transformers 5.8:

1. `PretrainedConfig.from_pretrained(model_name)` no longer round-trips model-specific config fields; with rope_parameters set but max_position_embeddings missing, validation fails. Switch every test caller to `AutoConfig.from_pretrained` (mirroring the production-side fix already adopted from PR #1561 in skyrl/backends/jax.py).

2. validate_layer_type asserts `len(layer_types) == num_hidden_layers`. tests/tx/utils/test_models.py:create_test_model shrinks num_hidden_layers to 1 to keep the test cheap, but layer_types is inherited from the real Qwen3-0.6B config (28 entries), so the wrapping Qwen3Config validator raises. Truncate layer_types alongside the num_hidden_layers override.

Verified locally on the cpu jax suite (CI=true, CUDA hidden to match the GitHub Actions cpu environment): all previously-failing tests in test_deepseekv3.py, test_deepseekv3_lora_training.py, test_llama3_lora_training.py, test_qwen3.py, test_qwen3_config.py, and test_models.py now pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 837439f · commit 8438493

8 files changed

Lines changed: 28 additions & 21 deletions
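
As a minimal sketch of the first fix (model name taken from the DeepSeek-V3 tests below; the before/after behavior is as described in the commit message, not independently verified):

    from transformers import AutoConfig

    model_name = "yujiepan/deepseek-v3-tiny-random"

    # Before: PretrainedConfig.from_pretrained(model_name) returned a bare
    # PretrainedConfig, dropping model-specific fields; under transformers 5.8
    # the strict validators then fail (rope_parameters set, but
    # max_position_embeddings missing).
    # After: AutoConfig dispatches on the checkpoint's model_type and returns
    # the concrete config subclass, which round-trips every field.
    base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)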

tests/tx/models/test_deepseekv3.py

Lines changed: 5 additions & 5 deletions
@@ -7,7 +7,7 @@
 import pytest
 import torch
 from flax import nnx
-from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
     DeepseekV3MoE as HFDeepseekV3MoE,
 )
@@ -40,7 +40,7 @@ def test_deepseekv3(tp: int):
     with tempfile.TemporaryDirectory() as tmp:
         hf_model.save_pretrained(tmp, safe_serialization=True)
 
-        base_config = PretrainedConfig.from_pretrained(model_name)
+        base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         config = DeepseekV3Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True)
         # EP axis required for MoE expert sharding
         mesh = jax.make_mesh((1, 1, tp), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
@@ -87,7 +87,7 @@ def test_deepseekv3_moe_layer(ep: int, tp: int):
     hf_model = AutoModelForCausalLM.from_pretrained(
         model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32
     )
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     config = DeepseekV3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True)
 
     # Initial deepseek layers don't have MoE
@@ -136,7 +136,7 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int):
     hf_model = AutoModelForCausalLM.from_pretrained(
         model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32
     )
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     config = DeepseekV3Config(base_config, max_lora_adapters=3, max_lora_rank=4, shard_attention_heads=True)
 
     hf_moe_layer = hf_model.model.layers[1].mlp
@@ -211,7 +211,7 @@ def test_deepseekv3_gradient_checkpointing():
     that gradient checkpointing works correctly with heterogeneous layer types.
     """
     model_name = "yujiepan/deepseek-v3-tiny-random"
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
 
     batch_size, seq_len = 2, 8
     mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)

tests/tx/models/test_deepseekv3_lora_training.py

Lines changed: 3 additions & 3 deletions
@@ -3,7 +3,7 @@
 import optax
 from flax import nnx
 from huggingface_hub import snapshot_download
-from transformers import PretrainedConfig
+from transformers import AutoConfig
 
 from skyrl.tinker.types import LoraConfig
 from skyrl.tx.layers.lora import init_lora_adapter
@@ -19,7 +19,7 @@
 
 def test_lora_training_moe_rank_normalized():
     base_model = "yujiepan/deepseek-v3-tiny-random"
-    base_config = PretrainedConfig.from_pretrained(base_model, trust_remote_code=True)
+    base_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
     config = DeepseekV3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)
 
     checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])
@@ -100,7 +100,7 @@ def loss_for_lora(lora_params):
 
 def test_lora_training_high_rank():
     base_model = "yujiepan/deepseek-v3-tiny-random"
-    base_config = PretrainedConfig.from_pretrained(base_model, trust_remote_code=True)
+    base_config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)
     config = DeepseekV3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)
 
     checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])

tests/tx/models/test_llama3_lora_training.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import optax
 from flax import nnx
 from huggingface_hub import snapshot_download
-from transformers import PretrainedConfig
+from transformers import AutoConfig
 
 from skyrl.tinker.types import LoraConfig
 from skyrl.tx.layers.lora import init_lora_adapter
@@ -19,7 +19,7 @@
 
 def test_lora_training():
     base_model = "unsloth/Llama-3.2-1B"
-    base_config = PretrainedConfig.from_pretrained(base_model)
+    base_config = AutoConfig.from_pretrained(base_model)
     config = Llama3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)
 
     checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])

tests/tx/models/test_qwen3.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@
 import torch
 from flax import nnx
 from peft import LoraConfig, get_peft_model
-from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.models.qwen3_moe.modeling_qwen3_moe import (
     Qwen3MoeSparseMoeBlock as HFQwen3MoeSparseMoeBlock,
 )
@@ -71,7 +71,7 @@ def test_qwen3_moe_layer(ep: int, tp: int):
     hf_model = AutoModelForCausalLM.from_pretrained(
         model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32
     )
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name)
     config = Qwen3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True)
 
     hf_moe_layer = hf_model.model.layers[0].mlp
@@ -126,7 +126,7 @@ def test_qwen3_moe_layer_lora(ep: int, tp: int):
     hf_model = AutoModelForCausalLM.from_pretrained(
         model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32
     )
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name)
     config = Qwen3Config(base_config, max_lora_adapters=3, max_lora_rank=4, shard_attention_heads=True)
 
     hf_moe_layer = hf_model.model.layers[0].mlp

tests/tx/models/test_qwen3_5_lora_training.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import optax
 from flax import nnx
 from huggingface_hub import snapshot_download
-from transformers import PretrainedConfig
+from transformers import AutoConfig
 
 from skyrl.tinker.types import LoraConfig
 from skyrl.tx.layers.lora import init_lora_adapter
@@ -19,7 +19,7 @@
 
 def test_lora_training():
     base_model = "Qwen/Qwen3.5-0.8B"
-    base_config = PretrainedConfig.from_pretrained(base_model)
+    base_config = AutoConfig.from_pretrained(base_model)
     config = Qwen3_5Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)
 
     checkpoint_path = snapshot_download(base_model, allow_patterns=["*.safetensors"])

tests/tx/models/test_qwen3_config.py

Lines changed: 3 additions & 3 deletions
@@ -1,13 +1,13 @@
 """Tests for Qwen3Config."""
 
-from transformers import PretrainedConfig
+from transformers import AutoConfig
 
 from skyrl.tx.models.configs import Qwen3Config
 
 
 def test_config_wraps_pretrained_config():
     """Test that Qwen3Config wraps a PretrainedConfig and adds LoRA params."""
-    hf_config = PretrainedConfig.from_pretrained("Qwen/Qwen3-0.6B")
+    hf_config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
     config = Qwen3Config(hf_config, max_lora_adapters=8, max_lora_rank=16, shard_attention_heads=False)
 
     # Check LoRA params were set
@@ -23,7 +23,7 @@ def test_config_wraps_pretrained_config():
 
 def test_config_preserves_moe_config():
     """Test that MoE-specific configs are preserved."""
-    hf_config = PretrainedConfig.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM")
+    hf_config = AutoConfig.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForCausalLM")
     config = Qwen3Config(hf_config, max_lora_adapters=3, max_lora_rank=4, shard_attention_heads=True)
 
     # Check that MoE-specific attributes are preserved

tests/tx/models/test_qwen3_generate.py

Lines changed: 3 additions & 3 deletions
@@ -8,7 +8,7 @@
 import pytest
 import torch
 from flax import nnx
-from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from skyrl.tinker import types
 from skyrl.tx.models.configs import Qwen3Config
@@ -45,7 +45,7 @@ def test_qwen3_generate():
     # Generate with our implementation (batched with right-padding)
     with tempfile.TemporaryDirectory() as tmp:
         hf_model.save_pretrained(tmp, safe_serialization=True)
-        base_config = PretrainedConfig.from_pretrained(model_name)
+        base_config = AutoConfig.from_pretrained(model_name)
         config = Qwen3Config(base_config, max_lora_adapters=2, max_lora_rank=32, shard_attention_heads=True)
 
         mesh = jax.make_mesh((1, 1), ("fsdp", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 2)
@@ -131,7 +131,7 @@ def test_qwen3_generate_speed():
     hf_model = AutoModelForCausalLM.from_pretrained(
         model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32
     )
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name)
     config = Qwen3Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True)
 
     inputs = [

tests/tx/utils/test_models.py

Lines changed: 7 additions & 0 deletions
@@ -34,6 +34,13 @@ def create_test_model(base_model_name: str, rank: int, alpha: int, adapter_index
     base_config.intermediate_size = 128
     base_config.num_attention_heads = 2
     base_config.num_key_value_heads = 2
+    # transformers >=5.4 has a strict validator (validate_layer_type) that
+    # asserts len(layer_types) == num_hidden_layers when layer_types is set.
+    # When we shrink num_hidden_layers above, also truncate layer_types so
+    # validation still passes.
+    layer_types = getattr(base_config, "layer_types", None)
+    if layer_types is not None:
+        base_config.layer_types = list(layer_types[: base_config.num_hidden_layers])
 
     config = Qwen3Config(base_config, max_lora_adapters=5, max_lora_rank=32, shard_attention_heads=True)
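
For context, a minimal sketch of the failure mode this hunk guards against, assuming the Qwen3-0.6B base config described in the commit message (layer_types present with 28 entries):

    from transformers import AutoConfig

    base_config = AutoConfig.from_pretrained("Qwen/Qwen3-0.6B")
    assert len(base_config.layer_types) == base_config.num_hidden_layers  # 28 == 28

    # Shrinking only the layer count desynchronizes the two fields, and the
    # strict validate_layer_type check raises the next time the config is
    # wrapped (e.g. by Qwen3Config). Truncating layer_types restores the invariant.
    base_config.num_hidden_layers = 1
    base_config.layer_types = list(base_config.layer_types[: base_config.num_hidden_layers])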
