@@ -7,7 +7,7 @@
 import pytest
 import torch
 from flax import nnx
-from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from transformers.models.deepseek_v3.modeling_deepseek_v3 import (
     DeepseekV3MoE as HFDeepseekV3MoE,
 )
@@ -40,7 +40,7 @@ def test_deepseekv3(tp: int):
     with tempfile.TemporaryDirectory() as tmp:
         hf_model.save_pretrained(tmp, safe_serialization=True)

-        base_config = PretrainedConfig.from_pretrained(model_name)
+        base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
         config = DeepseekV3Config(base_config, max_lora_adapters=32, max_lora_rank=32, shard_attention_heads=True)
         # EP axis required for MoE expert sharding
         mesh = jax.make_mesh((1, 1, tp), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
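The mesh built in this hunk is what gives the "ep" axis meaning: jax.make_mesh lays the devices out as a 1×1×tp grid named ("fsdp", "ep", "tp"), and the comment notes that MoE expert weights must be split along "ep". A minimal sketch of that idea, assuming a hypothetical expert-weight array (the shape and variable names below are illustrative, not from this PR):

import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Same mesh layout as the test: all devices on "tp", size-1 "fsdp" and
# "ep" axes, every axis left as Auto so unannotated arrays stay unconstrained.
mesh = jax.make_mesh(
    (1, 1, jax.device_count()),
    ("fsdp", "ep", "tp"),
    axis_types=(jax.sharding.AxisType.Auto,) * 3,
)

# Hypothetical MoE expert weights with shape (num_experts, d_model, d_ff).
# Placing the leading experts dimension on "ep" is the expert sharding the
# comment above refers to; "tp" would instead split head or FFN dimensions.
experts = jnp.zeros((8, 16, 32))
experts = jax.device_put(experts, NamedSharding(mesh, PartitionSpec("ep", None, None)))
print(experts.sharding)  # NamedSharding splitting the experts dim over "ep"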
@@ -87,7 +87,7 @@ def test_deepseekv3_moe_layer(ep: int, tp: int):
     hf_model = AutoModelForCausalLM.from_pretrained(
         model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32
     )
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     config = DeepseekV3Config(base_config, max_lora_adapters=0, max_lora_rank=0, shard_attention_heads=True)

     # Initial deepseek layers don't have MoE
@@ -136,7 +136,7 @@ def test_deepseekv3_moe_layer_lora(ep: int, tp: int):
     hf_model = AutoModelForCausalLM.from_pretrained(
         model_name, attn_implementation="eager", use_safetensors=True, torch_dtype=torch.float32
     )
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
     config = DeepseekV3Config(base_config, max_lora_adapters=3, max_lora_rank=4, shard_attention_heads=True)

     hf_moe_layer = hf_model.model.layers[1].mlp
@@ -211,7 +211,7 @@ def test_deepseekv3_gradient_checkpointing():
     that gradient checkpointing works correctly with heterogeneous layer types.
     """
     model_name = "yujiepan/deepseek-v3-tiny-random"
-    base_config = PretrainedConfig.from_pretrained(model_name)
+    base_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

     batch_size, seq_len = 2, 8
     mesh = jax.make_mesh((1, 1, 1), ("fsdp", "ep", "tp"), axis_types=(jax.sharding.AxisType.Auto,) * 3)
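The substitution repeated in all four hunks changes what kind of object base_config is. PretrainedConfig.from_pretrained instantiates the generic base class, which carries only the raw keys present in the checkpoint's config.json; AutoConfig.from_pretrained dispatches on that file's "model_type" field and returns the model-specific config class with its class-level defaults filled in, and trust_remote_code=True additionally lets the dispatch resolve config classes shipped inside the checkpoint repo. A minimal sketch of the observable difference (requires Hub access; the printed class names are the point):

from transformers import AutoConfig, PretrainedConfig

model_name = "yujiepan/deepseek-v3-tiny-random"

# Generic loader: a plain PretrainedConfig holding only the literal
# key/value pairs from config.json, with no DeepSeek-V3 defaults.
generic = PretrainedConfig.from_pretrained(model_name)
print(type(generic).__name__)  # PretrainedConfig

# Auto loader: dispatches on config.json's "model_type" and returns the
# model-specific class; trust_remote_code=True also permits config
# classes defined in the checkpoint repo itself.
auto = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
print(type(auto).__name__)  # e.g. DeepseekV3Config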
|