
Conversation


@rainyfly rainyfly commented Dec 26, 2025

Motivation

Background

When a preemption occurs, the serving layer must synchronize with the engine on the preempted slot, confirming that every token the engine generated for that slot has been received, before the request may be scheduled back into the waiting queue for rescheduling.
If the serving layer preempts a request without synchronizing with the engine, it may see N need-prefill tokens when it reschedules the preempted request for prefill, then asynchronously receive one more token before the request reaches the engine, so the engine ultimately sees N+1 need-prefill tokens. This mismatch causes the request on that slot to hang.
To solve this synchronization problem, the serving layer's token processor currently relies on token_id -1: once the engine has emitted all valid tokens, it returns -1 while handling the preempted task, signaling that every valid token previously generated for that slot has been received; only then is the preempted request put back into the waiting queue for rescheduling.
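
A minimal sketch of this handshake (all names here are hypothetical, chosen for illustration; the real logic lives in fastdeploy/output/token_processor.py):

```python
RECYCLE_TOKEN_ID = -1  # the current, overloaded sync sentinel

def process_step_output(step_tokens, tasks, waiting_queue):
    """Consume one engine step: one token id per slot (illustrative shape)."""
    for slot, token_id in enumerate(step_tokens):
        task = tasks[slot]
        if task is None:
            continue  # empty slot; the engine emits -1 here too
        if token_id == RECYCLE_TOKEN_ID:
            if task.preempted:
                # All valid tokens for this slot have been drained; it is
                # now safe to put the request back for rescheduling.
                waiting_queue.append(task)
                tasks[slot] = None
            continue
        task.output_token_ids.append(token_id)  # ordinary valid token
```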

In practice, however, an emitted -1 carries two meanings:

  1. The slot holds no request; an invalid slot by itself produces -1.
  2. The blocks already allocated to the slot have been fully filled, or the slot is in a chunked-prefill step; these also produce -1.

Using -1 as the preemption sync token is therefore not robust: in a scenario like the following, synchronization can break because a -1 produced by case 1 is taken for a -1 produced by case 2.
For example, extra logic added to the token_processor slows it down so that it can no longer keep up with the tokens the engine generates at each step. If the engine emits two or more tokens per step while the token_processor handles only one, the tokens the token_processor has seen fall far behind what the engine has actually generated.

Solution

To make the synchronization required at preemption time more robust, token_id -1 is no longer used as the sync token; a dedicated token_id is added for synchronization instead. This eliminates the sync failures that arise when a -1 cannot be attributed to case 1 versus case 2.
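
A matching sketch of the fix, reusing the hypothetical processor shape above; per the review summary below, this PR adds PREEMPTED_TOKEN_ID = -9 in fastdeploy/config.py:

```python
PREEMPTED_TOKEN_ID = -9  # dedicated sync sentinel (fastdeploy/config.py)

def process_step_output(step_tokens, tasks, waiting_queue):
    for slot, token_id in enumerate(step_tokens):
        task = tasks[slot]
        if task is None:
            continue
        if token_id == PREEMPTED_TOKEN_ID:
            # Emitted only when the engine recycles a preempted slot, never
            # for an empty slot or a filled-block / chunked-prefill step, so
            # a lagging processor cannot confuse the two kinds of -1.
            waiting_queue.append(task)
            tasks[slot] = None
        elif token_id >= 0:
            task.output_token_ids.append(token_id)
        # a plain -1 now carries no preemption meaning and is simply skipped
```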

Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If there are no unit tests, please state the reason in this PR.
  • Provide accuracy results.
  • If the current PR is being submitted to the release branch, make sure it has already been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings December 26, 2025 09:58

paddle-bot bot commented Dec 26, 2025

Thanks for your contribution!


Copilot AI left a comment


Pull request overview

This PR fixes a synchronization issue that occurs during request preemption by introducing a dedicated PREEMPTED_TOKEN_ID (-9) to distinguish preemption synchronization signals from other uses of token_id -1. Previously, -1 was ambiguously used for both invalid slots and completed block allocations, which could cause synchronization problems when the token processor processes tokens slower than the engine generates them.

Key changes:

  • Introduces PREEMPTED_TOKEN_ID = -9 constant to explicitly signal preemption completion
  • Adds preempted_idx list tracking in all model runners (GPU, XPU, Metax) to record which slots have been preempted
  • Updates token processor logic to specifically check for PREEMPTED_TOKEN_ID before rescheduling preempted requests
  • Adds guard in cache_output_tokens to only cache during decoding phase, preventing issues when requests are preempted during prefill
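
A hedged sketch of the engine-side half of this handshake, based only on the summary above (preempted_idx and PREEMPTED_TOKEN_ID are names from the PR; the function name and list-based shapes are illustrative):

```python
PREEMPTED_TOKEN_ID = -9  # from fastdeploy/config.py per this PR

def mark_preempted_slots(sampled_token_ids, preempted_idx):
    """After sampling and input updates, overwrite the output token of every
    preempted slot with the sentinel so the token processor can sync on it."""
    for slot in preempted_idx:
        sampled_token_ids[slot] = PREEMPTED_TOKEN_ID
    return sampled_token_ids

# e.g. slots 1 and 3 were preempted this step:
print(mark_preempted_slots([101, 7, 42, 5], preempted_idx=[1, 3]))
# [101, -9, 42, -9]
```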

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Summary per file:

  • fastdeploy/config.py: defines the new PREEMPTED_TOKEN_ID = -9 constant for unambiguous preemption signaling
  • fastdeploy/worker/xpu_model_runner.py: initializes and tracks preempted indices in share_inputs for the XPU model runner
  • fastdeploy/worker/metax_model_runner.py: initializes and tracks preempted indices in share_inputs for the Metax model runner
  • fastdeploy/worker/gpu_model_runner.py: initializes and tracks preempted indices in share_inputs for the GPU model runner
  • fastdeploy/output/token_processor.py: updates reschedule logic to check for the specific PREEMPTED_TOKEN_ID; removes the old batch-based reschedule method; adjusts the token_id comparison from <= 0 to < 0
  • fastdeploy/model_executor/pre_and_post_process.py: sets PREEMPTED_TOKEN_ID for preempted slots after token generation and input updates
  • fastdeploy/engine/sched/resource_manager_v1.py: adds a condition to only cache output tokens during the decoding phase, preventing issues when a request is preempted during prefill
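
And a hedged sketch of the resource-manager guard in the last item (cache_output_tokens is the real method name per the summary; the class, phase attribute, and body are assumptions):

```python
class ResourceManagerSketch:
    """Illustrative stand-in for resource_manager_v1, not the real class."""

    def cache_output_tokens(self, request, token_ids):
        # Only cache during the decoding phase: a request preempted during
        # prefill has no finalized outputs yet, and caching them could feed
        # stale tokens back into the rescheduled request.
        if getattr(request, "phase", None) != "decode":  # hypothetical attr
            return
        request.cached_output_token_ids.extend(token_ids)
```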

@codecov-commenter
Copy link

codecov-commenter commented Dec 26, 2025

Codecov Report

❌ Patch coverage is 46.66667% with 8 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@91a2b13). Learn more about missing BASE report.

Files with missing lines (patch coverage):

  • fastdeploy/output/token_processor.py: 40.00%, 0 missing and 3 partials ⚠️
  • fastdeploy/engine/sched/resource_manager_v1.py: 0.00%, 2 missing ⚠️
  • fastdeploy/model_executor/pre_and_post_process.py: 50.00%, 1 missing and 1 partial ⚠️
  • fastdeploy/worker/gpu_model_runner.py: 66.66%, 1 missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5796   +/-   ##
==========================================
  Coverage           ?   65.50%           
==========================================
  Files              ?      337           
  Lines              ?    43084           
  Branches           ?     6638           
==========================================
  Hits               ?    28221           
  Misses             ?    12760           
  Partials           ?     2103           
Flag coverage Δ:

  • GPU: 65.50% <46.66%> (?)

Flags with carried forward coverage won't be shown.


@rainyfly rainyfly changed the title [Optim][Bugfix] Robust sync status when preempted happens [Optim] Robust sync status when preempted happens Dec 29, 2025
