Skip to content

Commit 24a2991

Browse files
gufengcgufengc
andauthored
chore(sglang): upgrade sglang to 0.5.12 (#450)
Co-authored-by: gufengc <gufeng@graident.network>
1 parent 3222958 commit 24a2991

7 files changed

Lines changed: 126 additions & 259 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ mac = [
5050
]
5151

5252
gpu = [
53-
"sglang[all]==0.5.7",
53+
"sglang[all]==0.5.12",
54+
"accelerate",
5455
"mlx-lm==0.28.4",
5556
"mlx[cpu]==0.30.0",
5657
]

src/parallax/server/executor/base_executor.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,10 +618,12 @@ def _handle_raw_request(self, raw_request: Dict):
618618
if self.tokenizer.chat_template:
619619
messages = raw_request["messages"]
620620
process_message_content(messages)
621-
chat_template_kwargs = raw_request.get("chat_template_kwargs", {})
621+
chat_template_kwargs = dict(raw_request.get("chat_template_kwargs", {}))
622622
# check extra_body for backward compatibility
623623
if "extra_body" in raw_request and "chat_template_kwargs" in raw_request["extra_body"]:
624624
chat_template_kwargs.update(raw_request["extra_body"]["chat_template_kwargs"])
625+
# Transformers 5.x defaults return_dict=True, but Parallax expects list[int].
626+
chat_template_kwargs["return_dict"] = False
625627

626628
prompt = self.tokenizer.apply_chat_template(
627629
messages,

src/parallax/server/executor/sglang_executor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def __init__(
126126
"dp_rank": dp_rank,
127127
"dp_size": dp_size,
128128
"nccl_port": nccl_port,
129-
"using_hfcache": use_hfcache,
129+
"use_hfcache": use_hfcache,
130130
"enable_lora": self.enable_lora,
131131
"max_lora_rank": self.max_lora_rank,
132132
"lora_target_modules": self.lora_target_modules,

src/parallax/sglang/batch_info.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
ScheduleBatch -> ModelWorkerBatch -> ForwardBatch
66
"""
77

8-
from types import SimpleNamespace
98
from typing import List, Optional
109

1110
import torch
1211
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
12+
from sglang.srt.mem_cache.cache_init_params import CacheInitParams
13+
from sglang.srt.mem_cache.chunk_cache import ChunkCache
14+
from sglang.srt.mem_cache.radix_cache import RadixCache as PageRadixCache
1315
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
1416
from sglang.srt.model_executor.model_runner import ModelRunner
1517
from sglang.srt.sampling.sampling_batch_info import (
@@ -18,7 +20,6 @@
1820
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
1921
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
2022

21-
from parallax.server.executor.sglang_executor import PageRadixCache
2223
from parallax.server.request import Request
2324
from parallax.server.sampling.sampling_params import (
2425
SamplingParams as ParallaxSamplingParams,
@@ -95,23 +96,23 @@ def form_sgl_batch_prefill(
9596
) -> ForwardBatch:
9697
"""Initialize a prefill ScheduleBatch -> ModelWorkerBatch -> ForwardBatch workflow"""
9798

98-
sgl_reqs = transform_requests_to_sglang(requests, page_tree_cache)
99+
tree_cache = page_tree_cache
100+
if tree_cache is None:
101+
cache_params = CacheInitParams(
102+
disable=True,
103+
req_to_token_pool=model_runner.req_to_token_pool,
104+
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
105+
page_size=model_runner.server_args.page_size,
106+
)
107+
tree_cache = ChunkCache(cache_params)
99108

100-
def dummy_evict(*args):
101-
pass
109+
sgl_reqs = transform_requests_to_sglang(requests, tree_cache)
102110

103-
dummy_tree_cache = SimpleNamespace(
104-
page_size=model_runner.server_args.page_size,
105-
device=model_runner.device,
106-
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
107-
evictable_size=0,
108-
)
109-
dummy_tree_cache.evict = dummy_evict
110111
schedule_batch = ScheduleBatch.init_new(
111112
reqs=sgl_reqs,
112113
req_to_token_pool=model_runner.req_to_token_pool,
113114
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
114-
tree_cache=page_tree_cache if page_tree_cache is not None else dummy_tree_cache,
115+
tree_cache=tree_cache,
115116
model_config=model_runner.model_config,
116117
enable_overlap=False,
117118
spec_algorithm=SpeculativeAlgorithm.NONE,

src/parallax/sglang/model_runner.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
init_distributed_environment,
2222
set_custom_all_reduce,
2323
set_mscclpp_all_reduce,
24+
set_torch_symm_mem_all_reduce,
2425
)
2526
from sglang.srt.layers.dp_attention import (
2627
get_attention_tp_group,
@@ -118,6 +119,8 @@ def init_torch_distributed(self):
118119
backend = "gloo"
119120
elif self.device == "npu":
120121
backend = "hccl"
122+
else:
123+
backend = "gloo"
121124

122125
before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
123126
if not self.server_args.enable_p2p_check:
@@ -129,6 +132,7 @@ def init_torch_distributed(self):
129132
dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
130133
set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
131134
set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
135+
set_torch_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)
132136

133137
if not self.is_draft_worker:
134138
if self.device == "cpu":
@@ -153,14 +157,21 @@ def init_torch_distributed(self):
153157
local_rank=self.gpu_id,
154158
distributed_init_method=dist_init_method,
155159
timeout=self.server_args.dist_timeout,
160+
moe_a2a_backend=self.server_args.moe_a2a_backend,
161+
recovered_rank=self.server_args.elastic_ep_rejoin,
156162
)
157163

158164
# Use monkey patch modified function
159165
sglang.srt.distributed.parallel_state.initialize_model_parallel(
160166
tensor_model_parallel_size=self.tp_size,
161167
pipeline_model_parallel_size=self.pp_size,
162168
expert_model_parallel_size=self.moe_ep_size,
169+
attention_data_parallel_size=self.dp_size,
170+
attention_context_model_parallel_size=self.attn_cp_size,
171+
moe_data_model_parallel_size=self.moe_dp_size,
163172
duplicate_tp_group=self.server_args.enable_pdmux,
173+
enable_symm_mem=self.server_args.enable_symm_mem,
174+
recovered_rank=self.server_args.elastic_ep_rejoin,
164175
pp_start_layer=self.pp_start_layer,
165176
pp_end_layer=self.pp_end_layer,
166177
hidden_layers=self.model_config.num_hidden_layers,
@@ -225,6 +236,7 @@ def form_sgl_server_args(
225236
lora_eviction_policy: Optional[str] = "lru",
226237
lora_backend: Optional[str] = "triton",
227238
max_lora_chunk_size: Optional[int] = 128,
239+
max_num_tokens_per_batch: int = 16384,
228240
):
229241
"""Creates a SGL ServerArgs object"""
230242
sgl_server_args = ServerArgs(
@@ -247,6 +259,7 @@ def form_sgl_server_args(
247259
lora_backend=lora_backend,
248260
max_lora_chunk_size=max_lora_chunk_size,
249261
dp_size=dp_size,
262+
max_total_tokens=max_num_tokens_per_batch,
250263
)
251264
return sgl_server_args
252265

@@ -338,6 +351,7 @@ def initialize_sgl_model_runner(
338351
lora_eviction_policy,
339352
lora_backend,
340353
max_lora_chunk_size,
354+
max_num_tokens_per_batch=max_num_tokens_per_batch,
341355
)
342356
initialize_moe_config(server_args)
343357
quant_method = None

0 commit comments

Comments
 (0)