Skip to content

Commit 06f72c1

Browse files
committed
feat(model): add qwen3.5
1 parent 734be12 commit 06f72c1

9 files changed

Lines changed: 374 additions & 35 deletions

File tree

AGENTS.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
mac env:
2+
python: .venv/bin/python
3+
4+
gpu env:
5+
SSH into the H200 server, then run `conda activate parallax`
6+
7+
8+
start a local server:
9+
python src/parallax/launch.py --model-path <MODEL_NAME> --log-level DEBUG
10+
11+
test the server:
12+
curl --location 'http://localhost:3000/v1/chat/completions' --header 'Content-Type: application/json' --data '{
13+
"max_tokens": 1024,
14+
"messages": [
15+
{
16+
"role": "user",
17+
"content": "hello"
18+
}
19+
],
20+
"stream": true
21+
}'

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"numpy>=1.26",
2626
"pyzmq>=25.0",
2727
"psutil>=5.9.5",
28+
"requests",
2829
"httpx[socks]>=0.26.0",
2930
"aiohttp",
3031
"uvicorn",

src/backend/server/static_config.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
"Qwen/Qwen3-Next-80B-A3B-Instruct-FP8": "mlx-community/Qwen3-Next-80B-A3B-Instruct-8bit",
8585
"Qwen/Qwen3-Next-80B-A3B-Thinking": "mlx-community/Qwen3-Next-80B-A3B-Thinking-4bit",
8686
"Qwen/Qwen3-Next-80B-A3B-Thinking-FP8": "mlx-community/Qwen3-Next-80B-A3B-Thinking-8bit",
87+
# Qwen 3.6 Series
88+
"Qwen/Qwen3.6-27B": "mlx-community/Qwen3.6-27B-mxfp4",
8789
# Qwen 3 Large MoE Models
8890
"Qwen/Qwen3-235B-A22B-Instruct-2507-FP8": "mlx-community/Qwen3-235B-A22B-Instruct-2507-8bit",
8991
"Qwen/Qwen3-235B-A22B-Thinking-2507-FP8": "mlx-community/Qwen3-235B-A22B-Thinking-2507-8bit",
@@ -100,34 +102,39 @@
100102
NODE_JOIN_COMMAND_PUBLIC_NETWORK = """parallax join -s {scheduler_addr} """
101103

102104

103-
def get_model_info(model_name, use_hfcache: bool = False):
104-
config = load_config_only(model_name, local_files_only=use_hfcache)
105-
105+
def get_param_bytes_per_element(config, model_name: str) -> float:
    """Estimate how many bytes one model parameter occupies.

    The quantization section is read from ``config["quantization_config"]``
    or, when that is missing or empty, from ``config["quantization"]``.
    The method name comes from the top-level ``quant_method`` key, falling
    back to the section's ``quant_method`` or ``mode`` entry.

    Args:
        config: Mapping-like model configuration (only ``.get`` is used).
        model_name: Model identifier, used only in the warning message.

    Returns:
        ``bits / 8`` when the quantization section declares ``bits``;
        otherwise 2 for unquantized models, 1 for ``fp8``, 0.5 for the
        listed 4-bit methods, and 1 (with a warning) for anything else.
    """
    quant_section = config.get("quantization_config") or config.get("quantization")

    method = config.get("quant_method", None)
    if method is None and quant_section is not None:
        method = quant_section.get("quant_method") or quant_section.get("mode")

    # An explicit bit width wins over whatever the method name implies.
    bits = quant_section.get("bits") if quant_section is not None else None
    if bits is not None:
        return bits / 8

    if method is None:
        return 2
    if method == "fp8":
        return 1
    if method in ("mxfp4", "int4", "awq", "gptq", "compressed-tensors"):
        return 0.5

    # Unknown method: warn and fall back to one byte per element.
    logger.warning(
        f"model_name:{model_name} quant_method {method} not supported in get_model_info method"
    )
    return 1
125+
126+
127+
def get_model_info(model_name, use_hfcache: bool = False):
128+
config = load_config_only(model_name, local_files_only=use_hfcache)
129+
130+
param_bytes_per_element = get_param_bytes_per_element(config, model_name)
122131

123132
mlx_param_bytes_per_element = param_bytes_per_element
124133
mlx_model_name = MODELS.get(model_name, model_name)
125134

126135
if mlx_model_name != model_name:
127136
mlx_config = load_config_only(mlx_model_name, local_files_only=use_hfcache)
128-
mlx_quant_dict = mlx_config.get("quantization_config", None)
129-
if mlx_quant_dict and "bits" in mlx_quant_dict:
130-
mlx_param_bytes_per_element = mlx_quant_dict["bits"] / 8
137+
mlx_param_bytes_per_element = get_param_bytes_per_element(mlx_config, mlx_model_name)
131138

132139
# get local experts
133140
num_local_experts = config.get("num_local_experts", None)

src/parallax/models/qwen3_5.py

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
"""
2+
Defines the Qwen3.5 text block for Parallax.
3+
"""
4+
5+
from typing import Any, List, Optional
6+
7+
import mlx.core as mx
8+
import mlx.nn as nn
9+
from mlx_lm.models.gated_delta import gated_delta_update
10+
from mlx_lm.models.qwen3_5 import DecoderLayer as MLXQwen35Block
11+
from mlx_lm.models.qwen3_5 import GatedDeltaNet as MLXQwen35GatedDeltaNet
12+
from mlx_lm.models.qwen3_5 import TextModelArgs
13+
14+
from parallax.models.qwen3_next import ParallaxQwen3NextAttention
15+
from parallax.server.cache.base import BaseCache
16+
17+
18+
class ParallaxQwen35GatedDeltaNet(MLXQwen35GatedDeltaNet):
    """Gated delta-net layer that keeps its states in a Parallax cache.

    Instead of receiving states as call arguments, the layer reads them from
    ``cache`` via ``read_states`` and stores the updated ones via
    ``write_states``, both addressed by ``state_slot_mapping``.
    """

    def __call__(
        self,
        x: mx.array,
        cache: Optional[BaseCache] = None,
        state_slot_mapping: Optional[mx.array] = None,
        **kwargs,
    ):
        """Run the layer on ``x`` and persist the updated states.

        Args:
            x: Input with three axes, unpacked as (batch, target_len, _).
            cache: Cache object providing ``read_states`` / ``write_states``.
                NOTE(review): defaults to ``None`` but is dereferenced
                unconditionally at the end of the call — callers must pass it.
            state_slot_mapping: Passed through to the cache to select slots.
            **kwargs: Accepted for signature compatibility; not used here.

        Returns:
            The output projection of the normalized result, reshaped to
            (batch, target_len, -1).
        """
        batch, target_len, _ = x.shape
        history_len = self.conv_kernel_size - 1

        qkv = self.in_proj_qkv(x)
        z = self.in_proj_z(x).reshape(batch, target_len, self.num_v_heads, self.head_v_dim)
        b = self.in_proj_b(x)
        a = self.in_proj_a(x)

        if target_len != 1:
            # Multi-token input: start from an all-zero conv history and no
            # recurrent state (presumably the prefill path — confirm with caller).
            conv_state = mx.zeros((batch, history_len, self.conv_dim), dtype=x.dtype)
            state = None
        else:
            # Single-token input: resume from the states stored in the cache.
            conv_state, state = cache.read_states(state_slot_mapping)

        conv_input = mx.concatenate([conv_state, qkv], axis=1)
        # The trailing window of the conv input is what gets written back below.
        next_conv_state = conv_input[:, -history_len:]
        conv_out = nn.silu(self.conv1d(conv_input))

        q_part, k_part, v_part = mx.split(conv_out, [self.key_dim, 2 * self.key_dim], -1)
        q = q_part.reshape(batch, target_len, self.num_k_heads, self.head_k_dim)
        k = k_part.reshape(batch, target_len, self.num_k_heads, self.head_k_dim)
        v = v_part.reshape(batch, target_len, self.num_v_heads, self.head_v_dim)

        inv_scale = k.shape[-1] ** -0.5
        q = (inv_scale**2) * mx.fast.rms_norm(q, None, 1e-6)
        k = inv_scale * mx.fast.rms_norm(k, None, 1e-6)

        out, state = gated_delta_update(
            q,
            k,
            v,
            a,
            b,
            self.A_log,
            self.dt_bias,
            state,
            use_kernel=not self.training,
        )

        cache.write_states(state_slot_mapping, next_conv_state, state)

        gated = self.norm(out, z)
        return self.out_proj(gated.reshape(batch, target_len, -1))
75+
76+
77+
class ParallaxQwen35Block(MLXQwen35Block):
    """Qwen3.5 decoder layer wired to Parallax attention and cache objects.

    After the parent constructor runs, the mixer it created is replaced:
    linear layers get ``ParallaxQwen35GatedDeltaNet`` and the others get
    ``ParallaxQwen3NextAttention``.
    """

    def __init__(self, args: TextModelArgs, layer_idx: int, local_layer_idx: int):
        """Build the layer.

        Args:
            args: Text model arguments forwarded to the parent and the mixer.
            layer_idx: Index forwarded to the parent constructor.
            local_layer_idx: Index used to pick this layer's entry in ``cache``.
        """
        super().__init__(args, layer_idx)
        self.layer_idx = layer_idx
        self.local_layer_idx = local_layer_idx
        if self.is_linear:
            self.linear_attn = ParallaxQwen35GatedDeltaNet(args)
        else:
            self.self_attn = ParallaxQwen3NextAttention(args)

    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[List[Any]] = None,
        block_tables: Optional[mx.array] = None,
        context_lengths: Optional[mx.array] = None,
        slot_mapping: Optional[mx.array] = None,
        **kwargs,
    ):
        """Apply the mixer and the MLP, each with a residual connection.

        Args:
            x: Layer input.
            mask: Forwarded to the attention mixer only.
            cache: Sequence of per-layer caches, indexed by ``local_layer_idx``.
                NOTE(review): defaults to ``None`` but is indexed
                unconditionally — callers must pass it.
            block_tables: Forwarded to the attention mixer only.
            context_lengths: Forwarded to the attention mixer only.
            slot_mapping: Forwarded to the attention mixer only.
            **kwargs: Extra arguments forwarded to the mixer; on linear layers
                ``state_slot_mapping`` is taken out and passed positionally.

        Returns:
            The layer output after both residual additions.
        """
        normed = self.input_layernorm(x)
        layer_cache = cache[self.local_layer_idx]

        if not self.is_linear:
            r = self.self_attn(
                normed,
                mask,
                layer_cache,
                block_tables=block_tables,
                context_lengths=context_lengths,
                slot_mapping=slot_mapping,
                **kwargs,
            )
        else:
            state_slot_mapping = kwargs.pop("state_slot_mapping", None)
            r = self.linear_attn(normed, layer_cache, state_slot_mapping, **kwargs)

        h = x + r
        mlp_out = self.mlp(self.post_attention_layernorm(h))
        return h + mlp_out

    @classmethod
    def get_architecture(cls):
        """Return the architecture name this block is registered under."""
        return "Qwen3_5ForConditionalGeneration"
121+
122+
123+
EntryClass = ParallaxQwen35Block

src/parallax/server/http_server.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import uvicorn
2929
import zmq
3030
import zmq.asyncio
31-
from fastapi.responses import ORJSONResponse, StreamingResponse
31+
from fastapi.responses import JSONResponse, StreamingResponse
3232
from mlx_lm.tokenizer_utils import StreamingDetokenizer
3333
from mlx_lm.utils import load_config
3434
from pydantic import BaseModel
@@ -101,6 +101,7 @@ class HTTPRequestInfo:
101101
# tool calling support
102102
tool_state: Optional[ToolCallState] = None
103103
tool_calls: List[Dict[str, Any]] = field(default_factory=list)
104+
enable_thinking: bool = True
104105

105106

106107
class HTTPHandler:
@@ -137,12 +138,29 @@ def __init__(
137138
self.tokenizer = load_tokenizer(model_path, eos_token_ids=config.get("eos_token_id", None))
138139
self.detokenizer_class, self.tokenmap = load_detokenizer(model_path, self.tokenizer)
139140

141+
@staticmethod
def _is_thinking_enabled(request: Dict) -> bool:
    """Return whether the request leaves "thinking" output enabled.

    Chat-template kwargs are read from the top-level
    ``chat_template_kwargs`` field and then overlaid with
    ``extra_body["chat_template_kwargs"]`` when present. Thinking counts as
    enabled unless the merged ``enable_thinking`` value is exactly ``False``.

    Either field may arrive as JSON ``null`` or another non-object value;
    such values are ignored instead of raising ``TypeError``.

    Args:
        request: Parsed request payload.

    Returns:
        ``False`` only when ``enable_thinking`` is explicitly ``False``.
    """
    merged_kwargs: Dict = {}

    top_level = request.get("chat_template_kwargs")
    if isinstance(top_level, dict):
        merged_kwargs.update(top_level)

    extra_body = request.get("extra_body")
    if isinstance(extra_body, dict):
        # extra_body entries take precedence over the top-level field.
        nested = extra_body.get("chat_template_kwargs")
        if isinstance(nested, dict):
            merged_kwargs.update(nested)

    return merged_kwargs.get("enable_thinking") is not False
148+
149+
def _get_initial_assistant_content(self, request_info: HTTPRequestInfo) -> str:
    """Return the text to place at the start of the assistant message.

    The decision is based on the lower-cased ``self.model_path_str``:
    paths containing ``minimax-m2`` always get ``"<think>"``; paths
    containing ``qwen3.6`` get it only when ``request_info.enable_thinking``
    is truthy. Every other case yields an empty string.

    Args:
        request_info: Per-request state; only ``enable_thinking`` is read.

    Returns:
        ``"<think>"`` or ``""``.
    """
    lowered_path = self.model_path_str.lower()
    opens_with_think = "minimax-m2" in lowered_path or (
        "qwen3.6" in lowered_path and request_info.enable_thinking
    )
    return "<think>" if opens_with_think else ""
156+
140157
def create_request(self, request: Dict):
141158
"""Creates a new request information"""
142159
rid = request["rid"]
143160
stream = request.get("stream", False)
144161
model = request.get("model", "default")
145162
return_probs = request.get("return_probs", False) # Check if probs requested
163+
enable_thinking = self._is_thinking_enabled(request)
146164
chat_object = "chat.completion.chunk" if stream else "chat.completion"
147165
detokenizer = self.detokenizer_class(self.tokenizer, self.tokenmap)
148166
create_time = time.time()
@@ -156,6 +174,7 @@ def create_request(self, request: Dict):
156174
update_time=update_time,
157175
detokenizer=detokenizer,
158176
return_probs=return_probs,
177+
enable_thinking=enable_thinking,
159178
)
160179
request_info.tool_state = ToolCallState.from_tokenizer(
161180
self.tokenizer, request.get("tools"), stream
@@ -206,9 +225,7 @@ def _generate_stream_chunk(self, rid, token, is_first=False, is_last=False):
206225

207226
if is_first:
208227
role = "assistant"
209-
content = ""
210-
if "minimax-m2" in self.model_path_str.lower():
211-
content = "<think>"
228+
content = self._get_initial_assistant_content(request_info)
212229
tool_calls = None
213230
elif is_last:
214231
role = None
@@ -318,7 +335,7 @@ def generate_non_stream_response(self, rid):
318335
choice = response["choices"][0]
319336
choice["message"] = {
320337
"role": "assistant",
321-
"content": request_info.text,
338+
"content": self._get_initial_assistant_content(request_info) + request_info.text,
322339
"reasoning_content": None,
323340
"tool_calls": request_info.tool_calls or None,
324341
}
@@ -464,7 +481,7 @@ def create_error_response(
464481
):
465482
"""Creates a json error response for the frontend."""
466483
error = ErrorResponse(message=message, type=err_type, code=status_code.value)
467-
return ORJSONResponse(content=error.model_dump(), status_code=error.code)
484+
return JSONResponse(content=error.model_dump(), status_code=error.code)
468485

469486

470487
# Fast API
@@ -548,7 +565,7 @@ async def v1_chat_completions(raw_request: fastapi.Request):
548565

549566
response = app.state.http_handler.generate_non_stream_response(request_id)
550567
app.state.http_handler.release_request(request_id)
551-
return ORJSONResponse(status_code=200, content=response)
568+
return JSONResponse(status_code=200, content=response)
552569
except Exception as e:
553570
# Handle any unexpected errors during processing
554571
logger.error(f"Error processing non-streaming request {request_id}: {e}")

0 commit comments

Comments
 (0)