Skip to content

Commit f8d4cd3

Browse files
committed
add ad-hoc wait for large moe
1 parent e584f67 commit f8d4cd3

File tree

5 files changed

+92
-2
lines changed

5 files changed

+92
-2
lines changed

docker/patch/latest/sglang.patch

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,77 @@ index 1c541914c..6ed0e522d 100644
173173

174174
async def init_weights_send_group_for_remote_instance(
175175
self,
176+
diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py
177+
index 3d901ceb5..af9554b9a 100644
178+
--- a/python/sglang/srt/managers/tokenizer_manager.py
179+
+++ b/python/sglang/srt/managers/tokenizer_manager.py
180+
@@ -1060,6 +1060,9 @@ class TokenizerManager(TokenizerCommunicatorMixin):
181+
async with self.is_pause_cond:
182+
self.is_pause = True
183+
self.abort_request(abort_all=True)
184+
+ # Abort a second time after a short delay to ensure all in-flight requests are aborted.
185+
+ await asyncio.sleep(1)
186+
+ self.abort_request(abort_all=True)
187+
188+
async def continue_generation(self):
189+
async with self.is_pause_cond:
190+
@@ -1514,12 +1517,13 @@ class TokenizerManager(TokenizerCommunicatorMixin):
191+
return
192+
193+
if len(recv_obj.input_token_logprobs_val) > 0:
194+
- state.input_token_logprobs_val.extend(
195+
- recv_obj.input_token_logprobs_val[recv_obj_index]
196+
- )
197+
- state.input_token_logprobs_idx.extend(
198+
- recv_obj.input_token_logprobs_idx[recv_obj_index]
199+
- )
200+
+ if recv_obj.input_token_logprobs_val[recv_obj_index]:
201+
+ state.input_token_logprobs_val.extend(
202+
+ recv_obj.input_token_logprobs_val[recv_obj_index]
203+
+ )
204+
+ state.input_token_logprobs_idx.extend(
205+
+ recv_obj.input_token_logprobs_idx[recv_obj_index]
206+
+ )
207+
state.output_token_logprobs_val.extend(
208+
recv_obj.output_token_logprobs_val[recv_obj_index]
209+
)
210+
@@ -1731,14 +1735,24 @@ class TokenizerManager(TokenizerCommunicatorMixin):
211+
state.finished = True
212+
if recv_obj.finished_reason:
213+
out = {
214+
+ "text": "",
215+
+ "output_ids": [],
216+
"meta_info": {
217+
"id": recv_obj.rid,
218+
"finish_reason": recv_obj.finished_reason,
219+
+ "prompt_tokens": 0,
220+
+ "completion_tokens": 0,
221+
+ "model_version": self.server_args.weight_version,
222+
+ "cached_tokens": 0,
223+
+ "e2e_latency": 0,
224+
+ "output_token_logprobs": [[]],
225+
+ "input_token_logprobs": [[]],
226+
},
227+
}
228+
else:
229+
out = {
230+
"text": "",
231+
+ "output_ids": [],
232+
"meta_info": {
233+
"id": origin_rid,
234+
"finish_reason": {
235+
@@ -1747,6 +1761,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
236+
},
237+
"prompt_tokens": 0,
238+
"completion_tokens": 0,
239+
+ "model_version": self.server_args.weight_version,
240+
+ "cached_tokens": 0,
241+
+ "e2e_latency": 0,
242+
+ "output_token_logprobs": [[]],
243+
+ "input_token_logprobs": [[]],
244+
},
245+
}
246+
state.out_list.append(out)
176247
diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
177248
index 0a1cededd..0093fe2a8 100644
178249
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py

scripts/run-glm4.5-355B-A32B.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ ROLLOUT_ARGS=(
5050
--num-steps-per-rollout 4
5151
--balance-data
5252
--rollout-stop-token-ids 151329 151336 151338
53+
54+
# fault tolerance settings
55+
--rollout-health-check-first-wait 300
5356
)
5457

5558
EVAL_ARGS=(

scripts/run-qwen3-235B-A22B.sh

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ ROLLOUT_ARGS=(
6262

6363
--global-batch-size 64
6464
--balance-data
65+
66+
# fault tolerance settings
67+
--rollout-health-check-first-wait 300
6568
)
6669

6770
EVAL_ARGS=(

slime/ray/rollout.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,10 @@ def __init__(self, args, pg, wandb_run_id):
5858
# fault tolerance
5959
self._health_monitor_thread = None
6060
self._health_monitor_stop_event = None
61-
self._health_check_interval = getattr(args, "rollout_health_check_interval", 10.0)
62-
self._health_check_timeout = getattr(args, "rollout_health_check_timeout", 5.0)
61+
self._health_check_interval = args.rollout_health_check_interval
62+
self._health_check_timeout = args.rollout_health_check_timeout
63+
self._health_check_is_first = True
64+
self._health_check_first_wait = args.rollout_health_check_first_wait
6365

6466
def get_rollout_engines_and_lock(self):
6567
return self.rollout_engines, self.rollout_engine_lock, self.num_new_engines
@@ -135,6 +137,11 @@ def _stop_health_monitor(self) -> None:
135137

136138
def _health_monitor_loop(self) -> None:
137139
assert self._health_monitor_stop_event is not None
140+
# TODO: wait for the large MoE model to be ready before the first health check; this fixed initial delay is a hack.
141+
if self._health_check_is_first:
142+
if self._health_monitor_stop_event.wait(self._health_check_first_wait):
143+
return
144+
self._health_check_is_first = False
138145
while not self._health_monitor_stop_event.is_set():
139146
self._run_health_checks()
140147
if self._health_monitor_stop_event.wait(self._health_check_interval):

slime/utils/arguments.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,12 @@ def add_rollout_arguments(parser):
228228
default=5.0,
229229
help="Timeout in seconds to wait for a rollout engine /health_generate response before killing it.",
230230
)
231+
parser.add_argument(
232+
"--rollout-health-check-first-wait",
233+
type=float,
234+
default=300.0,
235+
help="Time to wait for the compilation before the actual health check.",
236+
)
231237

232238
# sampling
233239
parser.add_argument(

0 commit comments

Comments
 (0)