
[Fix] Defer req_to_token pool-index free in overlap scheduling to prevent cross-stream data race#18803

Closed
nvcastet wants to merge 1 commit into sgl-project:main from nvcastet:fix_mtpv2_better

Conversation

@nvcastet
Collaborator

@nvcastet nvcastet commented Feb 13, 2026

Motivation

Fixes #18744

With MTPv2 overlap scheduling enabled (SGLANG_ENABLE_SPEC_V2=1), we see intermittent device-side assertion failures:

/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:111: operator(): block: [0,0,0], thread: [0,0,0]
Assertion `-sizes[i] <= index && index < sizes[i] && "index out of bounds"` failed.

Root cause: In overlap scheduling, process_batch_result(N-1) runs on the default stream concurrently with forward(N)'s draft extend on the forward stream. When a request finishes, release_kv_cache immediately returns its req_pool_idx to the free list. A new request can then recycle that pool index, and prepare_for_decode overwrites the req_to_token row on the default stream while the draft-extend attention in forward(N) is still reading it.

Iteration N:
  forward stream: forward(N) reads req_to_token[pool_idx=1] for req A     ← still running
  default stream: process_batch_result(N-1) → A finished → free pool_idx=1

Iteration N+1:
  default stream: new req B gets pool_idx=1 (recycled)
  default stream: prepare_for_decode writes req_to_token[pool_idx=1] for B ← RACE!
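
For readers less familiar with CUDA stream semantics, the hazard boils down to the pattern below. This is a standalone PyTorch sketch, not sglang code; the tensor shape and variable names are invented for illustration and a CUDA device is required.

```python
import torch

# Standalone illustration of the hazard: a side stream reads a row of a table
# while the default stream overwrites the same, recycled row.
req_to_token = torch.zeros(4, 8, dtype=torch.int64, device="cuda")
forward_stream = torch.cuda.Stream()

# Make the zero-filled allocation visible to the forward stream before use.
forward_stream.wait_stream(torch.cuda.current_stream())

# forward(N): read request A's row (pool_idx=1) on the forward stream.
with torch.cuda.stream(forward_stream):
    row_for_a = req_to_token[1].clone()

# Default stream: pool_idx=1 is freed, recycled for request B, and its row is
# overwritten. Nothing orders this write after the read above, so the read can
# observe B's data. Deferring the free keeps pool_idx=1 out of circulation
# until the forward stream is known to be done with it.
req_to_token[1] = torch.arange(8, dtype=torch.int64, device="cuda")

torch.cuda.synchronize()
```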

Modifications

Instead of freeing immediately, append finished requests to a deferred list (_deferred_kv_release_reqs). The list is flushed at the start of the next process_batch_result_decode, right after copy_done.synchronize(), which guarantees the forward stream has finished reading those rows. An additional flush during idle ensures slots are not held unnecessarily when no more batches are running.
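
A minimal sketch of that flow, for illustration only: the scheduler state, the Req fields, and the copy_done event handling are simplified, and only the names _deferred_kv_release_reqs, process_batch_result_decode, release_kv_cache, and req_pool_idx come from this PR. The remaining names are stand-ins, not the actual sglang implementation.

```python
from dataclasses import dataclass

@dataclass
class ReqSketch:
    req_pool_idx: int
    finished: bool = False

class SchedulerSketch:
    def __init__(self, num_slots: int):
        self.free_slots = list(range(num_slots))
        self._deferred_kv_release_reqs = []

    def process_batch_result_decode(self, batch_reqs, copy_done):
        # copy_done was recorded on the forward stream at the end of the
        # previous forward pass; once it completes, no kernel can still be
        # reading the req_to_token rows deferred last iteration, so their
        # pool indices are safe to recycle.
        copy_done.synchronize()
        self.flush_deferred_kv_release()

        for req in batch_reqs:
            if req.finished:
                # Defer instead of freeing immediately: forward(N) may still
                # be reading this request's req_to_token row.
                self._deferred_kv_release_reqs.append(req)

    def flush_deferred_kv_release(self):
        # Stand-in for release_kv_cache on the deferred requests.
        for req in self._deferred_kv_release_reqs:
            self.free_slots.append(req.req_pool_idx)
        self._deferred_kv_release_reqs.clear()
```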

@gemini-code-assist
Contributor

Summary of Changes

Hello @nvcastet, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request addresses a critical data race condition occurring in MTPv2 overlap scheduling. The race happens when a req_to_token pool index is prematurely freed and reused by a new request while a concurrent forward pass is still reading from the old request's data at that same index. The fix introduces a deferred freeing mechanism for req_to_token pool indices, ensuring they are not reallocated until the previous forward pass that might be using them has fully completed, thereby preventing assertion failures and ensuring data integrity.

Highlights

  • Deferred Freeing Mechanism: Introduced ReqToTokenPool.deferred_free(req) to temporarily hold pool indices, preventing immediate reallocation.
  • Flush Deferred Frees: Added ReqToTokenPool.flush_deferred_frees() to release the held indices back to the free list after the relevant forward pass completes.
  • Conditional KV Cache Release: Modified release_kv_cache to accept a defer_pool_free parameter, allowing conditional use of the new deferred freeing logic.
  • Overlap Scheduling Integration: Updated process_batch_result_decode to utilize the deferred freeing mechanism when MTPv2 overlap scheduling is enabled, ensuring proper synchronization.


Changelog
  • python/sglang/srt/managers/scheduler_output_processor_mixin.py
    • Added a call to self.req_to_token_pool.flush_deferred_frees() after result.copy_done.synchronize() to release deferred pool indices.
    • Modified calls to release_kv_cache to pass defer_pool_free=self.enable_overlap.
  • python/sglang/srt/mem_cache/common.py
    • Added a defer_pool_free boolean parameter to the release_kv_cache function.
    • Implemented conditional logic within release_kv_cache to either call tree_cache.req_to_token_pool.deferred_free(req) or tree_cache.req_to_token_pool.free(req) based on the defer_pool_free flag.
  • python/sglang/srt/mem_cache/memory_pool.py (see the sketch after this list)
    • Initialized a new _deferred_free_slots list to store indices awaiting deferred release.
    • Implemented deferred_free(self, req: Req) method to add a request's pool index to _deferred_free_slots.
    • Implemented flush_deferred_frees(self) method to move all indices from _deferred_free_slots to free_slots.
    • Updated the clear() method to also clear the _deferred_free_slots list.
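
An illustrative sketch of the memory_pool.py additions listed above. Method and attribute names follow this changelog; the real ReqToTokenPool also owns the req_to_token buffer, and deferred_free reportedly takes a Req rather than a bare index, so treat this as a simplified model only.

```python
class ReqToTokenPoolSketch:
    def __init__(self, size: int):
        self.size = size
        self.free_slots = list(range(size))
        self._deferred_free_slots = []

    def free(self, req_pool_idx: int):
        # Immediate free: safe only when no other stream can still read the row.
        self.free_slots.append(req_pool_idx)

    def deferred_free(self, req_pool_idx: int):
        # Park the index; it stays unallocatable until the next flush.
        self._deferred_free_slots.append(req_pool_idx)

    def flush_deferred_frees(self):
        # Called after copy_done.synchronize(): the forward stream is done
        # with these rows, so the indices can rejoin the free list.
        self.free_slots.extend(self._deferred_free_slots)
        self._deferred_free_slots.clear()

    def clear(self):
        # Reset both the free list and any pending deferred frees.
        self.free_slots = list(range(self.size))
        self._deferred_free_slots.clear()
```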

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request effectively addresses a data race in overlap scheduling by deferring the freeing of req_to_token pool indices. The implementation is sound, introducing deferred_free and flush_deferred_frees methods to ReqToTokenPool and integrating them into the existing logic with the enable_overlap flag. The synchronization point for flushing deferred frees is well-chosen, ensuring correctness. I have one suggestion to refactor a small piece of code to improve maintainability.

@nvcastet
Collaborator Author

/tag-and-rerun-ci

Collaborator

@ShangmingCai ShangmingCai left a comment


@nvcastet nvcastet marked this pull request as draft February 13, 2026 18:48

[Fix] Defer req_to_token pool-index free in overlap scheduling to prevent cross-stream data race

In overlap scheduling (MTPv2), `process_batch_result(N-1)` runs on the
default stream concurrently with `forward(N)` on the forward stream.
When a request finishes, `release_kv_cache` immediately returns its
`req_pool_idx` to the free list.  A new request can then recycle that
pool index and `prepare_for_decode` overwrites the `req_to_token` row on
the default stream while `forward(N)` still reads it — causing an
"index out of bounds" assertion in IndexKernel.cu.

Fix: defer the pool-index free by one overlap iteration.

Co-authored-by: Trevor Morris <tmorris@nvidia.com>
@nvcastet nvcastet marked this pull request as ready for review February 13, 2026 22:38
@nvcastet
Collaborator Author

/tag-and-rerun-ci

@nvcastet
Collaborator Author

This fix only hides the bug by changing the timing of the race condition. In fact, result.copy_done.synchronize() is called just before release_kv_cache for the same batch, so the forward pass has fully completed at that point.
The real fix, from @hnyls2002's and @trevor-m's investigations, is at #18958.

@nvcastet nvcastet closed this Feb 19, 2026


Development

Successfully merging this pull request may close these issues.

[Bug] index out of bounds error with Spec V2
