Skip to content

Commit 333e620

Browse files
committed
update code style
1 parent c3aceb6 commit 333e620

File tree

4 files changed

+450
-5
lines changed

4 files changed

+450
-5
lines changed

fastdeploy/model_executor/layers/backends/xpu/moe/ep.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -493,10 +493,7 @@ def dispatch(
493493
# - if valid_token_num is NOT None, it means that we CAN accurately know
494494
# the size of the tensor, but the disadvantage is that it will interrupt
495495
# the process of kernel launch.
496-
if recv_expert_count is None:
497-
valid_token_num = -1
498-
else:
499-
valid_token_num = paddle.sum(recv_expert_count).item()
496+
valid_token_num = -1 # The EP operator will automatically calculate valid_token_num
500497

501498
if isinstance(recv_hidden_states, tuple):
502499
recv_x = recv_hidden_states[0]
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""
16+
EP4TP4 all2all测试 - Expert Parallel + Tensor Parallel (all2all通信)
17+
18+
测试配置:
19+
- 模型: ERNIE-4.5-300B-A47B-Paddle
20+
- 量化: wint4
21+
- Tensor Parallel: 4
22+
- Expert Parallel: 启用
23+
- Data Parallel: 1
24+
- 注意: 不使用 --disable-sequence-parallel-moe,启用all2all通信
25+
"""
26+
27+
28+
import openai
29+
import pytest
30+
from conftest import (
31+
get_model_path,
32+
get_port_num,
33+
print_logs_on_failure,
34+
restore_env,
35+
setup_ep_env,
36+
start_server,
37+
)
38+
39+
40+
def test_ep4tp4_all2all(xpu_env):
    """EP4TP4 all2all communication test.

    Launches a server with tensor parallel = 4 and expert parallelism
    enabled, deliberately WITHOUT ``--disable-sequence-parallel-moe`` so the
    all2all communication path is exercised, then sends one non-streaming
    chat completion and checks the reply contains an expected keyword.

    Args:
        xpu_env: pytest fixture preparing the XPU environment
            (provided by conftest; used only for its setup side effects).
    """
    print("\n============================开始 EP4TP4 all2all 测试!============================")

    # Set EP-related environment variables; keep the originals so they can
    # be restored in the finally block below.
    original_env = setup_ep_env()

    try:
        # Resolve test configuration from the shared conftest helpers.
        port_num = get_port_num()
        model_path = get_model_path()

        # Build server launch arguments.
        # NOTE: unlike the EP4TP4 online test, --disable-sequence-parallel-moe
        # is intentionally omitted here to enable all2all communication.
        server_args = [
            "--model",
            f"{model_path}/ERNIE-4.5-300B-A47B-Paddle",
            "--port",
            str(port_num),
            "--tensor-parallel-size",
            "4",
            "--enable-expert-parallel",
            "--enable-prefix-caching",
            "--data-parallel-size",
            "1",
            "--max-model-len",
            "32768",
            "--max-num-seqs",
            "64",
            "--quantization",
            "wint4",
            "--engine-worker-queue-port",
            str(port_num + 10),
            "--metrics-port",
            str(port_num + 2),
            "--gpu-memory-utilization",
            "0.9",
            "--graph-optimization-config",
            '{"use_cudagraph":true}',
        ]

        # Start the server; fail fast if it does not come up.
        if not start_server(server_args):
            pytest.fail("EP4TP4 all2all服务启动失败")

        # FIX: connect to the loopback address. "0.0.0.0" is a wildcard
        # *bind* address, not a portable connect target (connecting to it
        # only works on some platforms as an implementation quirk).
        ip = "127.0.0.1"
        client = openai.Client(base_url=f"http://{ip}:{port_num}/v1", api_key="EMPTY_API_KEY")

        # One non-streaming chat completion; top_p=0 makes decoding greedy
        # so the reply is reproducible across runs.
        response = client.chat.completions.create(
            model="default",
            messages=[
                {"role": "user", "content": "你好,你是谁?"},
            ],
            temperature=1,
            top_p=0,
            max_tokens=64,
            stream=False,
        )

        print(f"\n模型回复: {response.choices[0].message.content}")

        # Validate that the reply mentions at least one expected
        # self-identification keyword.
        assert any(
            keyword in response.choices[0].message.content for keyword in ["人工智能", "文心一言", "百度", "智能助手"]
        ), f"响应内容不符合预期: {response.choices[0].message.content}"

        print("\nEP4TP4 all2all测试通过!")

    except Exception as e:
        # Dump server logs to aid debugging, then mark the test as failed.
        print(f"\nEP4TP4 all2all测试失败: {str(e)}")
        print_logs_on_failure()
        pytest.fail(f"EP4TP4 all2all测试失败: {str(e)}")

    finally:
        # Always restore the environment variables changed by setup_ep_env().
        restore_env(original_env)
119+
120+
121+
# Allow running this test file directly (outside a pytest invocation),
# with verbose output (-v) and stdout capture disabled (-s).
if __name__ == "__main__":
    pytest.main([__file__, "-v", "-s"])

0 commit comments

Comments
 (0)