@@ -976,7 +976,7 @@ def check_decode_mem(self, buf_multiplier=1, selected_indices: list[int] | None
976976 self ._evict_tree_cache_if_needed (num_tokens )
977977 return self ._is_available_size_sufficient (num_tokens )
978978
979- def retract_decode (self , server_args : ServerArgs ):
979+ def retract_decode (self , server_args : ServerArgs ) -> tuple [ list [ Req ], float , list [ Req ]] :
980980 """Retract the decoding requests when there is not enough memory."""
981981 sorted_indices = list (range (len (self .reqs )))
982982
@@ -990,19 +990,9 @@ def retract_decode(self, server_args: ServerArgs):
990990
991991 retracted_reqs = []
992992 first_iter = True
993- while (not self .check_decode_mem (selected_indices = sorted_indices )) or first_iter :
993+ while first_iter or (not self .check_decode_mem (selected_indices = sorted_indices )):
994994 if len (sorted_indices ) == 1 :
995- # Corner case: only one request left
996- if self .is_hybrid :
997- full_available_size = self .token_to_kv_pool_allocator .full_available_size ()
998- swa_available_size = self .token_to_kv_pool_allocator .swa_available_size ()
999- assert (
1000- full_available_size > 0 and swa_available_size > 0
1001- ), f"No space left for only one request in SWA mode { full_available_size = } , { swa_available_size = } "
1002- else :
1003- assert (
1004- self .token_to_kv_pool_allocator .available_size () > 0
1005- ), f"No space left for only one request, { self .token_to_kv_pool_allocator .available_size ()= } "
995+ # Keep at least one request in the loop; handle OOM below.
1006996 break
1007997
1008998 first_iter = False
@@ -1011,11 +1001,24 @@ def retract_decode(self, server_args: ServerArgs):
10111001 retracted_reqs .append (req )
10121002 self .release_req (idx , len (sorted_indices ), server_args )
10131003
1014- if len (retracted_reqs ) == 0 :
1015- # Corner case: only one request left
1016- raise ValueError (
1017- "Failed to retract any request. No space left for only one request."
1018- )
1004+ # If the last remaining request still can't fit, abort it gracefully
1005+ # instead of crashing the scheduler (follows upstream sglang).
1006+ reqs_to_abort : list [Req ] = []
1007+ if len (sorted_indices ) <= 1 and not self .check_decode_mem (selected_indices = sorted_indices ):
1008+ last_idx = sorted_indices .pop ()
1009+ last_req = self .reqs [last_idx ]
1010+ last_req .to_finish = FINISH_ABORT (
1011+ "Out of memory even after retracting all other requests "
1012+ "in the decode batch. Aborting the last request." ,
1013+ HTTPStatus .INTERNAL_SERVER_ERROR ,
1014+ "InternalServerError" ,
1015+ )
1016+ reqs_to_abort .append (last_req )
1017+ self .release_req (last_idx , 0 , server_args )
1018+ logger .warning (
1019+ "retract_decode: aborted last request %s due to OOM" ,
1020+ last_req .rid ,
1021+ )
10191022
10201023 self .filter_batch (keep_indices = sorted_indices )
10211024
@@ -1025,10 +1028,12 @@ def retract_decode(self, server_args: ServerArgs):
10251028
10261029 new_estimate_ratio = (
10271030 total_decoded_tokens + global_config .retract_decode_steps * len (self .reqs )
1028- ) / total_max_new_tokens
1031+ ) / (
1032+ total_max_new_tokens + 1
1033+ ) # +1 to avoid zero division when all reqs aborted
10291034 new_estimate_ratio = min (1.0 , new_estimate_ratio )
10301035
1031- return retracted_reqs , new_estimate_ratio
1036+ return retracted_reqs , new_estimate_ratio , reqs_to_abort
10321037
10331038 def release_req (self , idx : int , remaing_req_count : int , server_args : ServerArgs ):
10341039 req = self .reqs [idx ]
0 commit comments