Skip to content

Commit 5d01e78

Browse files
Resolve merge conflict
2 parents 0b71c1f + a978a0c commit 5d01e78

30 files changed

Lines changed: 1878 additions & 99 deletions

pyproject.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,14 +37,14 @@ dependencies = [
3737

3838
mac = [
3939
"torch==2.8.0",
40-
"mlx-lm==0.26.4",
41-
"mlx==0.28.0",
40+
"mlx-lm==0.28.0",
41+
"mlx==0.29.1",
4242
]
4343

4444
gpu = [
45-
"mlx-lm==0.26.4",
46-
"mlx[cpu]==0.28.0",
47-
"sglang[all]==0.5.1.post3",
45+
"mlx-lm==0.28.0",
46+
"mlx[cpu]==0.29.1",
47+
"sglang[all]==0.5.2",
4848
]
4949

5050
benchmark = [

src/backend/main.py

Lines changed: 66 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
1+
import asyncio
2+
import json
13
import time
24
import uuid
35

46
import uvicorn
57
from fastapi import FastAPI, Request
8+
from fastapi.responses import JSONResponse, StreamingResponse
69

710
from backend.server.request_handler import RequestHandler
811
from backend.server.scheduler_manage import SchedulerManage
912
from backend.server.server_args import parse_args
13+
from backend.server.static_config import get_model_list, get_node_join_command
1014
from parallax_utils.logging_config import get_logger
1115

1216
app = FastAPI()
@@ -27,6 +31,68 @@ async def hello():
2731
return {"message": "Hello, World!"}
2832

2933

34+
@app.get("/model/list")
35+
async def model_list():
36+
return JSONResponse(
37+
content={
38+
"type": "model_list",
39+
"data": get_model_list(),
40+
},
41+
status_code=200,
42+
)
43+
44+
45+
@app.post("/scheduler/init")
46+
async def scheduler_init(raw_request: Request):
47+
request_data = await raw_request.json()
48+
model_name = request_data.get("model_name")
49+
init_nodes_num = request_data.get("init_nodes_num")
50+
is_local_network = request_data.get("is_local_network")
51+
if scheduler_manage.is_running():
52+
# todo reinit
53+
pass
54+
else:
55+
scheduler_manage.run(model_name, init_nodes_num, is_local_network)
56+
return JSONResponse(
57+
content={
58+
"type": "scheduler_init",
59+
"data": None,
60+
},
61+
status_code=200,
62+
)
63+
64+
65+
@app.get("/node/join/command")
66+
async def node_join_command():
67+
model_name = scheduler_manage.get_model_name()
68+
is_local_network = scheduler_manage.get_is_local_network()
69+
70+
return JSONResponse(
71+
content={
72+
"type": "node_join_command",
73+
"data": get_node_join_command(model_name, "${scheduler_addr}", is_local_network),
74+
},
75+
status_code=200,
76+
)
77+
78+
79+
@app.get("/cluster/status")
80+
async def cluster_status():
81+
async def stream_cluster_status():
82+
while True:
83+
yield json.dumps(scheduler_manage.get_cluster_status(), ensure_ascii=False) + "\n"
84+
await asyncio.sleep(1)
85+
86+
return StreamingResponse(
87+
stream_cluster_status(),
88+
media_type="application/x-ndjson",
89+
headers={
90+
"Cache-Control": "no-cache",
91+
"Connection": "keep-alive",
92+
},
93+
)
94+
95+
3096
@app.post("/v1/completions")
3197
async def openai_v1_completions(raw_request: Request):
3298
request_data = await raw_request.json()
@@ -70,9 +136,6 @@ async def openai_v1_chat_completions(raw_request: Request):
70136
init_nodes_num = args.init_nodes_num
71137
if model_name is not None and init_nodes_num is not None:
72138
scheduler_manage.run(model_name, init_nodes_num)
73-
else:
74-
logger.error("model_name and init_nodes_num are not set")
75-
exit(1)
76139

77140
port = args.port
78141

src/backend/server/constants.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Cluster status constants
2+
CLUSTER_STATUS_WAITING = "waiting"
3+
CLUSTER_STATUS_AVAILABLE = "available"
4+
CLUSTER_STATUS_REBALANCING = "rebalancing"
5+
6+
# Node status constants
7+
NODE_STATUS_WAITING = "waiting"
8+
NODE_STATUS_AVAILABLE = "available"
9+
NODE_STATUS_FAILED = "failed"

src/backend/server/request_handler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from fastapi import HTTPException
55
from fastapi.responses import JSONResponse, StreamingResponse
66

7+
from backend.server.constants import NODE_STATUS_AVAILABLE
78
from parallax_utils.logging_config import get_logger
89

910
logger = get_logger(__name__)
@@ -37,7 +38,7 @@ async def _forward_request(
3738
)
3839
if (
3940
self.scheduler_manage is None
40-
or not self.scheduler_manage.get_schedule_status() == "success"
41+
or not self.scheduler_manage.get_schedule_status() == NODE_STATUS_AVAILABLE
4142
):
4243
return JSONResponse(
4344
content={"error": "Server is not ready"},

src/backend/server/rpc_connection_handler.py

Lines changed: 4 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,6 @@ def __init__(
2828

2929
@rpc_stream
3030
def node_join(self, message):
31-
# node = {
32-
# "call_url": "http://127.0.0.1:8000",
33-
# "node_id": "lattica peer id",
34-
# "hardware": {
35-
# "node_id": "lattica peer id",
36-
# "tflops_fp16": 100,
37-
# "memory_gb": 100,
38-
# "memory_bandwidth_gbps": 100,
39-
# },
40-
# "model_name": "",
41-
# "kv_cache_ratio": 0.3,
42-
# "param_hosting_ratio": 0.5,
43-
# "max_concurrent_requests": 16,
44-
# "max_sequence_length": 1024,
45-
# }
4631
logger.info(f"receive node_join request: {message}")
4732
try:
4833
node = self.build_node(message)
@@ -78,6 +63,7 @@ def node_update(self, message):
7863
current_requests=node.current_requests,
7964
layer_latency_ms=node.layer_latency_ms,
8065
new_rtt_to_nodes=node.rtt_to_nodes,
66+
is_active=node.is_active,
8167
)
8268
return {}
8369
except Exception as e:
@@ -110,6 +96,7 @@ def build_node(self, node_json: dict):
11096
param_hosting_ratio=node_json.get("param_hosting_ratio"),
11197
max_concurrent_requests=node_json.get("max_concurrent_requests"),
11298
max_sequence_length=node_json.get("max_sequence_length"),
99+
is_active=node_json.get("is_active", True),
113100
)
114101
if node_json.get("start_layer", None) is not None:
115102
node.start_layer = node_json.get("start_layer")
@@ -126,11 +113,13 @@ def build_node(self, node_json: dict):
126113
def build_hardware(self, hardware_json):
127114
node_id = hardware_json.get("node_id")
128115
tflops_fp16 = hardware_json.get("tflops_fp16")
116+
gpu_name = hardware_json.get("gpu_name")
129117
memory_gb = hardware_json.get("memory_gb")
130118
memory_bandwidth_gbps = hardware_json.get("memory_bandwidth_gbps")
131119
return NodeHardwareInfo(
132120
node_id=node_id,
133121
tflops_fp16=tflops_fp16,
122+
gpu_name=gpu_name,
134123
memory_gb=memory_gb,
135124
memory_bandwidth_gbps=memory_bandwidth_gbps,
136125
)

src/backend/server/scheduler_manage.py

Lines changed: 81 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
from lattica import Lattica
66

7+
from backend.server.constants import NODE_STATUS_AVAILABLE, NODE_STATUS_WAITING
78
from backend.server.rpc_connection_handler import RPCConnectionHandler
8-
from backend.server.static_config import get_model_info
9+
from backend.server.static_config import get_model_info, get_node_join_command
910
from parallax_utils.logging_config import get_logger
1011
from scheduling.node import RequestSignal
1112
from scheduling.scheduler import Scheduler
@@ -14,7 +15,8 @@
1415

1516

1617
class SchedulerManage:
17-
"""Coordinates the in-process scheduler and the P2P RPC layer.
18+
"""
19+
Coordinates the in-process scheduler and the P2P RPC layer.
1820
1921
This manager owns the `Scheduler` instance and the Lattica P2P node,
2022
wiring RPC calls from workers to scheduler events.
@@ -35,28 +37,81 @@ def __init__(
3537
self.host_maddrs = host_maddrs
3638
self.announce_maddrs = announce_maddrs
3739

40+
self.model_name = None
41+
self.init_nodes_num = None
3842
self.scheduler = None
3943
self.node_id = f"{dht_prefix}_announce"
4044
self.lattica = None
4145
self.stubs = {}
46+
self.is_local_network = False
4247

43-
def run(self, model_name, init_nodes_num):
44-
"""Start the scheduler and the P2P service for RPC handling."""
48+
def run(self, model_name, init_nodes_num, is_local_network=False):
49+
"""
50+
Start the scheduler and the P2P service for RPC handling.
51+
"""
4552
logger.info(
4653
f"SchedulerManage starting: model_name={model_name}, init_nodes_num={init_nodes_num}"
4754
)
55+
self.is_local_network = is_local_network
4856
self._start_scheduler(model_name, init_nodes_num)
4957
self._start_lattica()
5058

59+
def is_running(self):
60+
"""
61+
Returns True if the scheduler is running, False otherwise.
62+
"""
63+
return self.scheduler is not None
64+
65+
def get_model_name(self):
66+
return self.model_name
67+
68+
def get_init_nodes_num(self):
69+
return self.init_nodes_num
70+
71+
def get_is_local_network(self):
72+
return self.is_local_network
73+
74+
def get_cluster_status(self):
75+
return {
76+
"type": "cluster_status",
77+
"data": {
78+
"status": self.get_schedule_status(),
79+
"model_name": self.model_name,
80+
"init_nodes_num": self.init_nodes_num,
81+
"node_join_command": get_node_join_command(
82+
self.model_name, "${scheduler_addr}", self.is_local_network
83+
),
84+
"node_list": self.get_node_list(),
85+
},
86+
}
87+
88+
def get_node_list(self):
89+
if self.scheduler is None:
90+
return []
91+
92+
return [self.build_node_info(node) for node in self.scheduler.nodes]
93+
94+
def build_node_info(self, node):
95+
return {
96+
"node_id": node.node_id,
97+
"status": NODE_STATUS_AVAILABLE if node.is_active else NODE_STATUS_WAITING,
98+
"gpu_name": node.hardware.gpu_name,
99+
"gpu_memory": node.hardware.memory_gb,
100+
}
101+
51102
def _start_scheduler(self, model_name, init_nodes_num):
52-
"""Create the scheduler and start its background run loop if needed."""
103+
"""
104+
Create the scheduler and start its background run loop if needed.
105+
"""
53106
if self.scheduler is not None:
54107
logger.info("Scheduler already started; skipping re-initialization")
55108
return
56109

57-
mode_info = get_model_info(model_name)
58-
# 初始化 scheduler
59-
self.scheduler = Scheduler(mode_info, [], min_nodes_bootstrapping=init_nodes_num)
110+
self.model_name = model_name
111+
self.init_nodes_num = init_nodes_num
112+
113+
model_info = get_model_info(model_name)
114+
self.scheduler = Scheduler(model_info, [], min_nodes_bootstrapping=init_nodes_num)
60115

61116
# Run the scheduler's event/dispatch loops in background so the process
62117
# can continue to serve RPCs and HTTP traffic.
@@ -69,7 +124,9 @@ def _start_scheduler(self, model_name, init_nodes_num):
69124
logger.info("Scheduler background thread started (poll_interval=0.05)")
70125

71126
def _start_lattica(self):
72-
"""Initialize and start the Lattica P2P node used for RPCs."""
127+
"""
128+
Initialize and start the Lattica P2P node used for RPCs.
129+
"""
73130
logger.info(
74131
f"Starting Lattica with host_maddrs={self.host_maddrs}, mdns=False, dht_prefix={self.dht_prefix}"
75132
)
@@ -113,12 +170,12 @@ def get_routing_table(self, request_id, received_ts):
113170
request = RequestSignal(request_id, received_ts)
114171
self.scheduler.receive_request(request)
115172

116-
# 等待最长 5s, 但如果路由表已被设置(包括空列表),则立即返回
173+
# Wait up to 5 seconds, but return immediately if the routing table is set (including an empty list)
117174
start_time = time.time()
118175
while request.routing_table is None and (time.time() - start_time) < 5.0:
119176
time.sleep(0.05)
120177

121-
# 返回routing_table
178+
# Return the routing_table
122179
if request.routing_table is None:
123180
logger.info(
124181
f"Routing table not ready after {(time.time() - start_time):.2f}s for request_id={request_id}"
@@ -130,17 +187,26 @@ def get_routing_table(self, request_id, received_ts):
130187
return request.routing_table
131188

132189
def get_schedule_status(self):
133-
"""Return whether a full pipeline has been allocated across joined nodes."""
190+
"""
191+
Return whether a full pipeline has been allocated across joined nodes.
192+
"""
134193
if self.scheduler is None:
135194
logger.info("SchedulerManage status queried: waiting (scheduler not initialized)")
136-
return "waiting"
195+
return NODE_STATUS_WAITING
137196

138-
status = "success" if self.scheduler.layer_allocator.has_full_pipeline() else "waiting"
197+
# todo rebalance status
198+
status = (
199+
NODE_STATUS_AVAILABLE
200+
if self.scheduler.layer_allocator.has_full_pipeline()
201+
else NODE_STATUS_WAITING
202+
)
139203
logger.info(f"SchedulerManage status queried: {status}")
140204
return status
141205

142206
def get_call_url_by_node_id(self, node_id):
143-
"""Lookup the HTTP endpoint for a given node id managed by the RPC layer."""
207+
"""
208+
Lookup the HTTP endpoint for a given node id managed by the RPC layer.
209+
"""
144210
url = self.connection_handler.get_call_url_by_node_id(node_id)
145211
logger.info(f"Lookup call_url for node_id={node_id} -> {url}")
146212
return url

src/backend/server/server_args.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,9 @@ def parse_args() -> argparse.Namespace:
2929

3030
parser.add_argument("--port", type=int, default=5000, help="Port to listen on")
3131

32-
parser.add_argument(
33-
"--model-name", type=str, default="Qwen/Qwen3-0.6B-MLX-bf16", help="Model name"
34-
)
32+
parser.add_argument("--model-name", type=str, default=None, help="Model name")
3533

36-
parser.add_argument("--init-nodes-num", type=int, default=1, help="Number of initial nodes")
34+
parser.add_argument("--init-nodes-num", type=int, default=None, help="Number of initial nodes")
3735

3836
args = parser.parse_args()
3937

0 commit comments

Comments
 (0)