Skip to content

Commit 6bed15f

Browse files
authored
fix(swa): graceful abort instead of crash when decode OOM after retract (#944)
When the decode KV cache is exhausted in SWA hybrid mode, retract_decode previously crashed on an assertion failure if only one request remained and memory was still insufficient. This commit changes that behavior: the last remaining request is now aborted gracefully and an error response is returned to the client, following upstream sglang's approach.
1 parent 7907875 commit 6bed15f

2 files changed

Lines changed: 39 additions & 22 deletions

File tree

python/sgl_jax/srt/managers/schedule_batch.py

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -976,7 +976,7 @@ def check_decode_mem(self, buf_multiplier=1, selected_indices: list[int] | None
976976
self._evict_tree_cache_if_needed(num_tokens)
977977
return self._is_available_size_sufficient(num_tokens)
978978

979-
def retract_decode(self, server_args: ServerArgs):
979+
def retract_decode(self, server_args: ServerArgs) -> tuple[list[Req], float, list[Req]]:
980980
"""Retract the decoding requests when there is not enough memory."""
981981
sorted_indices = list(range(len(self.reqs)))
982982

@@ -990,19 +990,9 @@ def retract_decode(self, server_args: ServerArgs):
990990

991991
retracted_reqs = []
992992
first_iter = True
993-
while (not self.check_decode_mem(selected_indices=sorted_indices)) or first_iter:
993+
while first_iter or (not self.check_decode_mem(selected_indices=sorted_indices)):
994994
if len(sorted_indices) == 1:
995-
# Corner case: only one request left
996-
if self.is_hybrid:
997-
full_available_size = self.token_to_kv_pool_allocator.full_available_size()
998-
swa_available_size = self.token_to_kv_pool_allocator.swa_available_size()
999-
assert (
1000-
full_available_size > 0 and swa_available_size > 0
1001-
), f"No space left for only one request in SWA mode {full_available_size=}, {swa_available_size=}"
1002-
else:
1003-
assert (
1004-
self.token_to_kv_pool_allocator.available_size() > 0
1005-
), f"No space left for only one request, {self.token_to_kv_pool_allocator.available_size()=}"
995+
# Keep at least one request in the loop; handle OOM below.
1006996
break
1007997

1008998
first_iter = False
@@ -1011,11 +1001,24 @@ def retract_decode(self, server_args: ServerArgs):
10111001
retracted_reqs.append(req)
10121002
self.release_req(idx, len(sorted_indices), server_args)
10131003

1014-
if len(retracted_reqs) == 0:
1015-
# Corner case: only one request left
1016-
raise ValueError(
1017-
"Failed to retract any request. No space left for only one request."
1018-
)
1004+
# If the last remaining request still can't fit, abort it gracefully
1005+
# instead of crashing the scheduler (follows upstream sglang).
1006+
reqs_to_abort: list[Req] = []
1007+
if len(sorted_indices) <= 1 and not self.check_decode_mem(selected_indices=sorted_indices):
1008+
last_idx = sorted_indices.pop()
1009+
last_req = self.reqs[last_idx]
1010+
last_req.to_finish = FINISH_ABORT(
1011+
"Out of memory even after retracting all other requests "
1012+
"in the decode batch. Aborting the last request.",
1013+
HTTPStatus.INTERNAL_SERVER_ERROR,
1014+
"InternalServerError",
1015+
)
1016+
reqs_to_abort.append(last_req)
1017+
self.release_req(last_idx, 0, server_args)
1018+
logger.warning(
1019+
"retract_decode: aborted last request %s due to OOM",
1020+
last_req.rid,
1021+
)
10191022

10201023
self.filter_batch(keep_indices=sorted_indices)
10211024

@@ -1025,10 +1028,12 @@ def retract_decode(self, server_args: ServerArgs):
10251028

10261029
new_estimate_ratio = (
10271030
total_decoded_tokens + global_config.retract_decode_steps * len(self.reqs)
1028-
) / total_max_new_tokens
1031+
) / (
1032+
total_max_new_tokens + 1
1033+
) # +1 to avoid zero division when all reqs aborted
10291034
new_estimate_ratio = min(1.0, new_estimate_ratio)
10301035

1031-
return retracted_reqs, new_estimate_ratio
1036+
return retracted_reqs, new_estimate_ratio, reqs_to_abort
10321037

10331038
def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
10341039
req = self.reqs[idx]

python/sgl_jax/srt/managers/scheduler.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1298,13 +1298,22 @@ def update_running_batch(self, batch: ScheduleBatch) -> ScheduleBatch | None:
12981298
):
12991299
old_ratio = self.new_token_ratio
13001300

1301-
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
1301+
retracted_reqs, new_token_ratio, reqs_to_abort = batch.retract_decode(self.server_args)
13021302
num_retracted_reqs = len(retracted_reqs)
13031303
self.new_token_ratio = new_token_ratio
13041304

1305+
# Send abort responses so clients get an error instead of a hung connection
1306+
for req in reqs_to_abort:
1307+
abort_out = AbortReq(rid=req.rid)
1308+
if self._comm_backend is not None:
1309+
self._comm_backend.send_pyobj(abort_out)
1310+
else:
1311+
self.send_to_tokenizer.send_pyobj(abort_out)
1312+
13051313
logger.info(
1306-
"KV cache pool is full. Retract requests. #retracted_reqs: %d, #new_token_ratio: %.4f -> %.4f",
1314+
"KV cache pool is full. Retract requests. #retracted_reqs: %d, #aborted_reqs: %d, #new_token_ratio: %.4f -> %.4f",
13071315
num_retracted_reqs,
1316+
len(reqs_to_abort),
13081317
old_ratio,
13091318
self.new_token_ratio,
13101319
)
@@ -1319,6 +1328,9 @@ def update_running_batch(self, batch: ScheduleBatch) -> ScheduleBatch | None:
13191328
if batch.batch_size() < initial_bs:
13201329
batch.batch_is_full = False
13211330

1331+
if batch.is_empty():
1332+
return batch
1333+
13221334
# Update batch arrays
13231335
batch.prepare_for_decode()
13241336
return batch

0 commit comments

Comments
 (0)