Skip to content

Commit 7f54e45

Browse files
committed
feat: add qwen3.5 moe
1 parent 1dfe18b commit 7f54e45

7 files changed

Lines changed: 235 additions & 88 deletions

File tree

src/backend/server/static_config.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -84,7 +84,11 @@
8484
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
8585
"Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit",
8686
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
87+
# Qwen 3.5 Series (dense and MoE)
88+
"Qwen/Qwen3.5-0.8B": "Qwen/Qwen3.5-0.8B",
89+
"Qwen/Qwen3.5-35B-A3B": "mlx-community/Qwen3.5-35B-A3B-4bit",
8790
# Qwen 3.6 Series
91+
"Qwen/Qwen3.6-35B-A3B": "mlx-community/Qwen3.6-35B-A3B-4bit",
8892
"Qwen/Qwen3.6-27B": "mlx-community/Qwen3.6-27B-mxfp4",
8993
# Qwen 3 Large MoE Models
9094
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit",

src/parallax/server/shard_loader.py

Lines changed: 18 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -33,10 +33,12 @@
3333
MODEL_CLASS_MAP = {
3434
"kimi_k2": "mlx_lm.models.deepseek_v3",
3535
"minimax_m2": "mlx_lm.models.minimax",
36+
"qwen3_5_moe": "mlx_lm.models.qwen3_5",
3637
}
3738

3839
ARCHITECTURE_CLASS_ALIASES = {
3940
"GlmMoeDsaForCausalLM": "DeepseekV32ForCausalLM",
41+
"Qwen3_5MoeForConditionalGeneration": "Qwen3_5ForConditionalGeneration",
4042
}
4143

4244

@@ -310,6 +312,21 @@ def _cast_weight_array(weight_array: mx.array, dtype: mx.Dtype) -> mx.array:
310312
return weight_array.astype(dtype)
311313
return weight_array
312314

315+
@staticmethod
316+
def _load_mlx_lm_module_and_args(model_type: str, config: Dict[str, Any]):
317+
if model_type in MODEL_CLASS_MAP:
318+
model_class = MODEL_CLASS_MAP[model_type]
319+
else:
320+
model_class = f"mlx_lm.models.{model_type}"
321+
322+
arch_module = importlib.import_module(model_class)
323+
if hasattr(arch_module, "TextModelArgs"):
324+
model_args_class = getattr(arch_module, "TextModelArgs")
325+
else:
326+
model_args_class = getattr(arch_module, "ModelArgs")
327+
328+
return arch_module, model_args_class.from_dict(config)
329+
313330
def load(
314331
self, lazy: bool = False, strict: bool = False, use_selective_download: bool = True
315332
) -> Tuple[nn.Module, Dict[str, Any], Any]:
@@ -366,18 +383,8 @@ def load(
366383
if not model_type:
367384
raise ValueError("model_type not found in config.json")
368385

369-
if model_type in MODEL_CLASS_MAP:
370-
model_class = MODEL_CLASS_MAP[model_type]
371-
else:
372-
model_class = f"mlx_lm.models.{model_type}"
373-
374386
try:
375-
arch_module = importlib.import_module(model_class)
376-
if model_type == "qwen3_5" and hasattr(arch_module, "TextModelArgs"):
377-
model_args_class = getattr(arch_module, "TextModelArgs")
378-
else:
379-
model_args_class = getattr(arch_module, "ModelArgs")
380-
model_args = model_args_class.from_dict(config)
387+
arch_module, model_args = self._load_mlx_lm_module_and_args(model_type, config)
381388
self.arch_module = arch_module
382389
self.model_args = model_args
383390

src/parallax/utils/model_download.py

Lines changed: 80 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,15 +1,17 @@
1+
import json
12
import logging
23
import os
34
from pathlib import Path
4-
from typing import Optional
5+
from typing import Dict, List, Optional, Set
56

67
from huggingface_hub import hf_hub_download as _hf_hub_download
78
from huggingface_hub import snapshot_download as _snapshot_download
89
from modelscope import snapshot_download as _ms_snapshot_download
910
from modelscope.hub.file_download import model_file_download as _ms_model_file_download
1011

1112
from parallax.utils.weight_filter_utils import (
12-
determine_needed_weight_files_for_download,
13+
normalize_language_model_weight_key,
14+
should_include_weight_key,
1315
)
1416

1517
logger = logging.getLogger(__name__)
@@ -96,7 +98,7 @@ def selective_model_download(
9698
if start_layer is not None and end_layer is not None:
9799
logger.debug(f"Determining required weight files for layers [{start_layer}, {end_layer})")
98100

99-
needed_weight_files = determine_needed_weight_files_for_download(
101+
needed_weight_files = _determine_needed_weight_files_for_download(
100102
model_path=model_path,
101103
start_layer=start_layer,
102104
end_layer=end_layer,
@@ -149,5 +151,80 @@ def selective_model_download(
149151
]
150152

151153

154+
def _determine_needed_weight_files_for_download(
155+
model_path: Path,
156+
start_layer: int,
157+
end_layer: int,
158+
config: Optional[Dict] = None,
159+
) -> List[str]:
160+
is_first_shard = start_layer == 0
161+
162+
is_last_shard = False
163+
if config:
164+
num_hidden_layers = config.get("num_hidden_layers", 0)
165+
is_last_shard = end_layer >= num_hidden_layers
166+
else:
167+
config_file = model_path / "config.json"
168+
if config_file.exists():
169+
from parallax.utils.utils import normalize_model_config
170+
171+
with open(config_file, "r") as f:
172+
cfg = normalize_model_config(json.load(f))
173+
num_hidden_layers = cfg.get("num_hidden_layers", 0)
174+
is_last_shard = end_layer >= num_hidden_layers
175+
176+
index_file = model_path / "model.safetensors.index.json"
177+
178+
if not index_file.exists():
179+
logger.debug(f"Index file not found at {index_file}, checking for single weight file")
180+
# For non-sharded models, look for single weight file
181+
single_weight_files = [
182+
"model.safetensors",
183+
"pytorch_model.bin",
184+
"model.bin",
185+
]
186+
for weight_file in single_weight_files:
187+
if (model_path / weight_file).exists():
188+
logger.debug(f"Found single weight file: {weight_file}")
189+
return [weight_file]
190+
191+
logger.debug("No weight files found (neither index nor single file)")
192+
return []
193+
194+
with open(index_file, "r") as f:
195+
index_data = json.load(f)
196+
197+
weight_map = index_data.get("weight_map", {})
198+
if not weight_map:
199+
logger.debug("weight_map is empty in index file")
200+
return []
201+
202+
tie_word_embeddings = False
203+
if config:
204+
tie_word_embeddings = config.get("tie_word_embeddings", False)
205+
206+
needed_files: Set[str] = set()
207+
208+
for key, filename in weight_map.items():
209+
if filename in needed_files:
210+
continue
211+
key = normalize_language_model_weight_key(key)
212+
if should_include_weight_key(
213+
key=key,
214+
start_layer=start_layer,
215+
end_layer=end_layer,
216+
is_first_shard=is_first_shard,
217+
is_last_shard=is_last_shard,
218+
tie_word_embeddings=tie_word_embeddings,
219+
):
220+
needed_files.add(filename)
221+
222+
result = sorted(list(needed_files))
223+
logger.debug(
224+
f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})"
225+
)
226+
return result
227+
228+
152229
def _use_modelscope() -> bool:
153230
return _USE_MODELSCOPE_ENV in os.environ

src/parallax/utils/utils.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -303,7 +303,9 @@ def load_config_only(name: str, local_files_only: bool = False):
303303
def normalize_model_config(config: dict) -> dict:
304304
"""Expose nested text model fields at the top level for VLM-style configs."""
305305
text_config = config.get("text_config")
306-
if config.get("model_type") == "qwen3_5" and isinstance(text_config, dict):
306+
if config.get("model_type") in {"qwen3_5", "qwen3_5_moe"} and isinstance(
307+
text_config, dict
308+
):
307309
normalized = {**config, **text_config}
308310
normalized["model_type"] = config["model_type"]
309311
normalized["architectures"] = config.get("architectures", normalized.get("architectures"))

src/parallax/utils/weight_filter_utils.py

Lines changed: 0 additions & 73 deletions
Original file line number | Diff line number | Diff line change
@@ -116,76 +116,3 @@ def filter_weight_files_by_layer_range_for_load(
116116
)
117117

118118
return filtered_files
119-
120-
121-
def determine_needed_weight_files_for_download(
122-
model_path: Path,
123-
start_layer: int,
124-
end_layer: int,
125-
config: Optional[Dict] = None,
126-
) -> List[str]:
127-
is_first_shard = start_layer == 0
128-
129-
is_last_shard = False
130-
if config:
131-
num_hidden_layers = config.get("num_hidden_layers", 0)
132-
is_last_shard = end_layer >= num_hidden_layers
133-
else:
134-
config_file = model_path / "config.json"
135-
if config_file.exists():
136-
with open(config_file, "r") as f:
137-
cfg = json.load(f)
138-
num_hidden_layers = cfg.get("num_hidden_layers", 0)
139-
is_last_shard = end_layer >= num_hidden_layers
140-
141-
index_file = model_path / "model.safetensors.index.json"
142-
143-
if not index_file.exists():
144-
logger.debug(f"Index file not found at {index_file}, checking for single weight file")
145-
# For non-sharded models, look for single weight file
146-
single_weight_files = [
147-
"model.safetensors",
148-
"pytorch_model.bin",
149-
"model.bin",
150-
]
151-
for weight_file in single_weight_files:
152-
if (model_path / weight_file).exists():
153-
logger.debug(f"Found single weight file: {weight_file}")
154-
return [weight_file]
155-
156-
logger.debug("No weight files found (neither index nor single file)")
157-
return []
158-
159-
with open(index_file, "r") as f:
160-
index_data = json.load(f)
161-
162-
weight_map = index_data.get("weight_map", {})
163-
if not weight_map:
164-
logger.debug("weight_map is empty in index file")
165-
return []
166-
167-
tie_word_embeddings = False
168-
if config:
169-
tie_word_embeddings = config.get("tie_word_embeddings", False)
170-
171-
needed_files: Set[str] = set()
172-
173-
for key, filename in weight_map.items():
174-
if filename in needed_files:
175-
continue
176-
key = normalize_language_model_weight_key(key)
177-
if should_include_weight_key(
178-
key=key,
179-
start_layer=start_layer,
180-
end_layer=end_layer,
181-
is_first_shard=is_first_shard,
182-
is_last_shard=is_last_shard,
183-
tie_word_embeddings=tie_word_embeddings,
184-
):
185-
needed_files.add(filename)
186-
187-
result = sorted(list(needed_files))
188-
logger.debug(
189-
f"Determined {len(result)} weight files needed for layers [{start_layer}, {end_layer})"
190-
)
191-
return result

tests/test_shard_loader.py

Lines changed: 87 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
Tests for the shard_loader module.
33
"""
44

5+
import json
56
import sys
67
from unittest.mock import Mock, patch
78

@@ -14,6 +15,8 @@
1415
MLXModelLoader,
1516
normalize_language_model_weight_key,
1617
)
18+
from parallax.utils.model_download import _determine_needed_weight_files_for_download
19+
from parallax.utils.utils import normalize_model_config
1720
from parallax.utils.weight_filter_utils import should_include_weight_key
1821

1922

@@ -27,10 +30,20 @@ def test_normalize_nested_language_model_weight_keys():
2730
normalize_language_model_weight_key("model.language_model.layers.12.mlp.up_proj.weight")
2831
== "model.layers.12.mlp.up_proj.weight"
2932
)
33+
assert (
34+
normalize_language_model_weight_key(
35+
"language_model.model.layers.12.mlp.switch_mlp.up_proj.weight"
36+
)
37+
== "model.layers.12.mlp.switch_mlp.up_proj.weight"
38+
)
3039
assert (
3140
normalize_language_model_weight_key("model.language_model.norm.weight")
3241
== "model.norm.weight"
3342
)
43+
assert (
44+
normalize_language_model_weight_key("language_model.model.norm.weight")
45+
== "model.norm.weight"
46+
)
3447
assert (
3548
normalize_language_model_weight_key("model.language_model.lm_head.weight")
3649
== "lm_head.weight"
@@ -111,6 +124,80 @@ def test_mlx_lm_sanitize_uses_local_layer_keys_for_shards():
111124
) == ["layers.0.linear_attn.conv1d.weight"]
112125

113126

127+
def test_qwen35_moe_uses_qwen35_text_args_and_sanitizer_module():
128+
loader = MLXModelLoader("test_model_path")
129+
config = normalize_model_config(
130+
{
131+
"model_type": "qwen3_5_moe",
132+
"architectures": ["Qwen3_5MoeForConditionalGeneration"],
133+
"text_config": {
134+
"model_type": "qwen3_5_moe_text",
135+
"hidden_size": 2048,
136+
"num_hidden_layers": 40,
137+
"num_attention_heads": 16,
138+
"num_key_value_heads": 2,
139+
"vocab_size": 248320,
140+
"num_experts": 256,
141+
"num_experts_per_tok": 8,
142+
"moe_intermediate_size": 512,
143+
},
144+
}
145+
)
146+
147+
sanitizer_module, model_args = loader._load_mlx_lm_module_and_args(
148+
"qwen3_5_moe", config
149+
)
150+
151+
assert MODEL_CLASS_MAP["qwen3_5_moe"] == "mlx_lm.models.qwen3_5"
152+
assert sanitizer_module.__name__ == "mlx_lm.models.qwen3_5"
153+
assert model_args.num_hidden_layers == 40
154+
assert model_args.hidden_size == 2048
155+
assert model_args.num_experts == 256
156+
assert model_args.num_experts_per_tok == 8
157+
assert model_args.moe_intermediate_size == 512
158+
159+
160+
def test_register_block_class_includes_qwen35_moe():
161+
loader = MLXModelLoader("test_model_path")
162+
163+
assert "Qwen3_5MoeForConditionalGeneration" in loader.block_class_map
164+
165+
166+
def test_selective_download_uses_nested_qwen35_moe_num_layers(tmp_path):
167+
(tmp_path / "config.json").write_text(
168+
json.dumps(
169+
{
170+
"model_type": "qwen3_5_moe",
171+
"text_config": {
172+
"num_hidden_layers": 40,
173+
"tie_word_embeddings": False,
174+
},
175+
}
176+
)
177+
)
178+
(tmp_path / "model.safetensors.index.json").write_text(
179+
json.dumps(
180+
{
181+
"weight_map": {
182+
"language_model.model.layers.39.linear_attn.in_proj_qkv.weight": (
183+
"layers-39.safetensors"
184+
),
185+
"language_model.model.norm.weight": "final.safetensors",
186+
"language_model.lm_head.weight": "final.safetensors",
187+
}
188+
}
189+
)
190+
)
191+
192+
needed_files = _determine_needed_weight_files_for_download(
193+
tmp_path,
194+
start_layer=39,
195+
end_layer=40,
196+
)
197+
198+
assert needed_files == ["final.safetensors", "layers-39.safetensors"]
199+
200+
114201
@pytest.mark.skipif(sys.platform != "darwin", reason="MLX tests require macOS")
115202
class TestMLXModelLoader:
116203
"""Test MLXModelLoader functionality."""

0 commit comments

Comments
 (0)