[https://nvbugs/5926823][fix] Propagate logprobs from prefill to decode in disagg (#11727)
Conversation
c3ed32a to d2dd2f5
📝 Walkthrough

A new disaggregated logprobs workflow is introduced with a reproduction module, guard conditions to safely access logprob attributes, and serialization support across the TensorRT-LLM stack. Changes include a prefill HTTP server, a decode engine, logprob data threading through disaggregated parameters, and OpenAI protocol integration.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Client
    participant PrefillServer as Prefill HTTP Server
    participant DecodeEngine as Decode Engine
    participant LLM as TRT-LLM Executor
    Client->>PrefillServer: POST /prefill (prompt, max_tokens, logprobs)
    PrefillServer->>LLM: context_only generation
    LLM-->>PrefillServer: first_gen_tokens, first_gen_log_probs, opaque_state
    PrefillServer-->>Client: PrefillResponse (tokens + logprobs + state)
    Client->>DecodeEngine: generate_async with remote_prefill=true
    DecodeEngine->>DecodeEngine: _remote_prefill (deserialize opaque_state)
    DecodeEngine->>LLM: streaming generation with DisaggregatedParams
    LLM-->>DecodeEngine: stream chunks with logprobs
    DecodeEngine-->>Client: streamed output with first_gen_log_probs appended
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 7
🧹 Nitpick comments (2)
tensorrt_llm/disaggregated_params.py (1)
32-33: Define a concrete contract type for `first_gen_log_probs`.

`Optional[List]` is too loose for a field that crosses prefill/decode and protocol boundaries. Please use an explicit nested type alias (beam/token shape) to avoid downstream wrapping/shape ambiguity.

Also applies to: 42-42
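To make the suggested beam/token contract concrete, here is a minimal sketch; the alias and helper names are illustrative, not TRT-LLM's actual types:

```python
from typing import Dict, List, Optional

# Hypothetical aliases sketching the reviewer's suggested contract.
TokenLogprobs = Dict[int, float]                 # token id -> logprob (top-k at one position)
BeamLogprobs = List[TokenLogprobs]               # one top-k dict per generated token
FirstGenLogProbs = Optional[List[BeamLogprobs]]  # one entry per beam, or None

def shape_of(lp: FirstGenLogProbs):
    """Return (num_beams, tokens_in_first_beam) for a well-formed payload."""
    if lp is None:
        return None
    return (len(lp), len(lp[0]) if lp else 0)
```

With an explicit alias like this, downstream code can assert the payload shape instead of guessing whether a bare list means beams or tokens.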
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/disaggregated_params.py` around lines 32 - 33, Define explicit nested type aliases and use them to annotate first_gen_log_probs instead of Optional[List]; for example create aliases like TokenLogprobs = Dict[int, float] (token id -> logprob), BeamLogprobs = List[TokenLogprobs] (one list per beam), and FirstGenLogProbs = Optional[List[BeamLogprobs]] (or Optional[BeamLogprobs] if outer list is redundant), then replace the loose Optional[List] annotations for the first_gen_log_probs field(s) with FirstGenLogProbs; update both occurrences of first_gen_log_probs in this module so downstream code has an unambiguous beam/token shape contract.

cum_log_probs_repro.py (1)
64-67: Make the model path configurable instead of hardcoding a local absolute path.

Hardcoding `/home/...` makes the repro brittle for reviewers and CI environments.

💡 Suggested fix

```diff
 import argparse
 import base64
+import os
 import threading
@@
 CACHE_TRANSCEIVER = {"backend": "UCX", "max_tokens_in_buffer": 2048}
+DEFAULT_MODEL_PATH = "/home/scratch.bbuddharaju_gpu/random/hf_models/TinyLlama-1.1B-Chat-v1.0"
+MODEL_PATH = os.getenv("TRTLLM_MODEL_PATH", DEFAULT_MODEL_PATH)
@@
 self.llm = LLM(
-    model="/home/scratch.bbuddharaju_gpu/random/hf_models/TinyLlama-1.1B-Chat-v1.0",
+    model=MODEL_PATH,
     disable_overlap_scheduler=True,
     cache_transceiver_config=CACHE_TRANSCEIVER,
 )
@@
-self.llm = LLM(model="/home/scratch.bbuddharaju_gpu/random/hf_models/TinyLlama-1.1B-Chat-v1.0", cache_transceiver_config=CACHE_TRANSCEIVER)
+self.llm = LLM(model=MODEL_PATH, cache_transceiver_config=CACHE_TRANSCEIVER)
```

Also applies to: 135-135
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cum_log_probs_repro.py` around lines 64 - 67, Replace the hardcoded absolute model path used as the model= argument with a configurable value: read it from an environment variable (e.g. os.environ.get('MODEL_PATH')) or add a CLI/config parameter and fall back to a sensible default; update both places where model="/home/..." is passed (the model= keyword in the object/constructor calls around the current occurrences) so tests/CI/reviewers can override the path without editing the file.
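The env-var fallback the prompt describes can be sketched as below; the variable name `TRTLLM_MODEL_PATH` and the default path here are assumptions for illustration:

```python
import os

# Hypothetical default; the real repro keeps its own fallback path.
DEFAULT_MODEL_PATH = "/models/TinyLlama-1.1B-Chat-v1.0"

def resolve_model_path(env_var: str = "TRTLLM_MODEL_PATH") -> str:
    """Prefer an environment override so CI/reviewers need not edit the file."""
    return os.getenv(env_var, DEFAULT_MODEL_PATH)
```

The repro would then call `LLM(model=resolve_model_path(), ...)` in both constructor sites.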
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@cum_log_probs_repro.py`:
- Around line 51-52: The code currently flattens the nested token-position/top-k
structure when handling first_gen_log_probs, which loses the original grouping
when logprobs > 1; update the serialization and deserialization logic that
reads/writes first_gen_log_probs (and the similar blocks around the other
occurrences) to preserve the nested shape — keep a list of per-token objects
each containing the top-k list of SerializedLogprob entries instead of merging
them into a flat map; specifically, change the routines that iterate over
first_gen_log_probs to serialize each token's top-k array as-is and to
reconstruct the exact token-position → top-k list structure on load (ensure
types remain list[SerializedLogprob] | None and adjust any flattening helper
functions to operate on the nested lists rather than concatenating entries).
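A shape-preserving round trip as described above might look like the following sketch; the `Logprob` dataclass is a stand-in for the real type, not TRT-LLM's implementation:

```python
from dataclasses import dataclass
from typing import Dict, List, Optional

@dataclass
class Logprob:  # illustrative stand-in for the real logprob record
    logprob: float
    rank: Optional[int] = None

# first_gen_log_probs: one top-k dict (token id -> Logprob) per token position.
def serialize_nested(first_gen_log_probs: Optional[List[Dict[int, Logprob]]]):
    """Serialize each position's top-k list as-is, without flattening."""
    if first_gen_log_probs is None:
        return None
    return [
        [{"token_id": tid, "logprob": lp.logprob, "rank": lp.rank}
         for tid, lp in pos.items()]
        for pos in first_gen_log_probs
    ]

def deserialize_nested(payload):
    """Rebuild the exact token-position -> top-k mapping on load."""
    if payload is None:
        return None
    return [
        {item["token_id"]: Logprob(item["logprob"], item.get("rank"))
         for item in pos}
        for pos in payload
    ]
```

Because each position serializes to its own inner list, `logprobs > 1` survives the wire format instead of collapsing into one flat map.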
- Around line 1-8: Add the required NVIDIA Apache-2.0 copyright header to the
top of the new source file cum_log_probs_repro.py: insert the standard NVIDIA
copyright block (including the correct year and "NVIDIA CORPORATION &
AFFILIATES"), the Apache License, Version 2.0 notice and URL, and the short
license disclaimer sentence before any module docstring or code so the file
begins with the full header followed by the existing module docstring.
- Around line 79-82: The /prefill handler can run before the engine/LLM is
initialized; update the route in prefill (the async def prefill(req:
PrefillRequest) -> PrefillResponse) to guard against early requests by checking
a readiness condition (e.g., self.llm is not None or a boolean like
self.initialized or a threading.Event) and if not ready raise an HTTP 503
(HTTPException(status_code=503, detail="Engine not ready")) or return an
appropriate error response; only call self._generate_local_prefill(req) when the
readiness check passes. Ensure the readiness flag is set when start() finishes
assigning self.llm.
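The readiness guard can be sketched without the web framework; `EngineNotReady` stands in for `fastapi.HTTPException(status_code=503)`, and the class is a hypothetical skeleton, not the actual repro server:

```python
import threading

class EngineNotReady(Exception):
    """Stand-in for an HTTP 503 'Engine not ready' response."""

class PrefillServer:
    """Minimal sketch: readiness is flipped only after the engine exists."""

    def __init__(self):
        self.llm = None
        self._ready = threading.Event()

    def start(self, llm):
        self.llm = llm
        self._ready.set()  # set AFTER self.llm is assigned

    def prefill(self, req):
        if not self._ready.is_set():
            raise EngineNotReady("Engine not ready")  # would map to HTTP 503
        return ("ok", req)  # placeholder for _generate_local_prefill(req)
```

A `threading.Event` (rather than a bare boolean) keeps the check safe when `start()` runs on a background thread.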
In `@tensorrt_llm/_torch/pyexecutor/py_executor.py`:
- Around line 2744-2750: The code assumes transferred first_gen_log_probs is
single-beam by always wrapping it in a list before calling
req.py_result.append_log_probs; instead, inspect
disagg_params.first_gen_log_probs shape/ndim (or a beam-size marker) and only
wrap when it's a 1D/single-beam vector, pass it through unchanged when it's
already a 2D/beam-aware array, and otherwise raise a clear error to fail fast;
update the logic around disagg_params and first_gen_log_probs used before
calling req.py_result.append_log_probs to perform this shape check and
conditional wrapping (or explicit rejection) so dimensions stay consistent with
beam-aware payloads.
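One way to sketch the conditional wrapping, assuming a hypothetical convention where a flat list is a single beam and a list of lists is already beam-aware:

```python
def normalize_first_gen_log_probs(payload):
    """Wrap single-beam payloads once; pass beam-aware ones through; fail fast otherwise."""
    if not isinstance(payload, list) or not payload:
        raise ValueError(f"unexpected first_gen_log_probs payload: {payload!r}")
    if isinstance(payload[0], list):
        return payload   # already [beam][token]: do not wrap again
    return [payload]     # single beam: wrap exactly once
```

Calling this before `append_log_probs` would keep dimensions consistent regardless of whether the prefill side sent a single- or multi-beam payload.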
In `@tensorrt_llm/executor/result.py`:
- Around line 476-483: The current code copies first_gen_log_probs from only
self._outputs[0], which loses per-sequence data in beam/best_of flows; update
the block in the class/method that uses context_phase_params and
self._disaggregated_params so that you collect the first-generation token
logprobs for each produced sequence by iterating over self._outputs and
extracting the first token's logprobs from each output that has logprobs (e.g.,
build a list like [out.logprobs[0] for out in self._outputs if out.logprobs])
and assign that list to self._disaggregated_params.first_gen_log_probs, ensuring
the resulting list length matches the number of produced sequences.
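The per-sequence collection the comment describes is essentially a one-liner; the `SimpleNamespace` objects below are stand-ins for the real output objects:

```python
from types import SimpleNamespace

def collect_first_gen_log_probs(outputs):
    """One first-generation-token logprob entry per produced sequence,
    instead of copying only outputs[0] (which drops beam/best_of data)."""
    return [out.logprobs[0] for out in outputs if out.logprobs]

# Minimal stand-ins for CompletionOutput-like objects:
outs = [
    SimpleNamespace(logprobs=[{1: -0.5}, {2: -0.1}]),  # sequence 0
    SimpleNamespace(logprobs=[{3: -0.7}]),             # sequence 1
    SimpleNamespace(logprobs=None),                    # logprobs disabled
]
first = collect_first_gen_log_probs(outs)
```

The resulting list length tracks the number of sequences that actually produced logprobs, matching the invariant the review asks for.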
- Around line 307-324: The length assertions on output.logprobs should be
skipped when a sequence was cancelled: change the checks around output.logprobs/
output.length to first test finish_reasons[src_idx] != CANCELLED (or equivalent
enum/constant) and only run the strict assert/warning logic if not cancelled;
for cancelled sequences allow partial logprobs without raising and keep the
existing warning/handling for disaggregated cases intact (update the branch that
currently contains the assert lines referencing finish_reasons, src_idx,
output.logprobs, and output.length).
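The relaxed check can be sketched as follows; `CANCELLED` here is a string stand-in for the real finish-reason enum member:

```python
CANCELLED = "cancelled"  # stand-in for the real FinishReason enum member

def check_logprob_length(logprobs, expected_length, finish_reason):
    """Enforce the strict length invariant only for non-cancelled sequences."""
    if finish_reason == CANCELLED:
        return  # cancelled mid-stream: partial logprobs are legitimate
    assert len(logprobs) == expected_length, (
        f"logprobs length {len(logprobs)} != output length {expected_length}")
```

The existing disaggregated warning path would stay in the non-cancelled branch, untouched by this relaxation.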
In `@tensorrt_llm/serve/openai_protocol.py`:
- Around line 1090-1115: In _serialize_first_gen_log_probs and
_deserialize_first_gen_log_probs validate the nested structure at the protocol
boundary: ensure input is a list of lists/dicts with required keys ("token_id",
"logprob") and optional "rank", check types (token_id int, logprob number, rank
None or int), and raise ValueError with descriptive messages on mismatch rather
than letting AttributeError/KeyError propagate; in
_serialize_first_gen_log_probs verify each pos is a mapping and each lp has
attributes logprob and rank before serializing, and in
_deserialize_first_gen_log_probs verify each item is a dict containing
"token_id" and "logprob" (and that values are the right types) before
constructing Logprob, raising ValueError mentioning the offending position/item
when validation fails.
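The deserialization half of that boundary validation might look like this sketch; it returns plain dicts rather than constructing the real `Logprob` type, and all error messages are illustrative:

```python
def deserialize_first_gen_log_probs(payload):
    """Validate the nested structure before building logprob records,
    raising ValueError (not AttributeError/KeyError) on malformed input."""
    if payload is None:
        return None
    if not isinstance(payload, list):
        raise ValueError(f"expected list of positions, got {type(payload).__name__}")
    result = []
    for i, pos in enumerate(payload):
        if not isinstance(pos, list):
            raise ValueError(f"position {i}: expected list, got {type(pos).__name__}")
        entries = {}
        for j, item in enumerate(pos):
            if not isinstance(item, dict) or "token_id" not in item or "logprob" not in item:
                raise ValueError(f"position {i}, item {j}: missing token_id/logprob")
            if not isinstance(item["token_id"], int) or not isinstance(item["logprob"], (int, float)):
                raise ValueError(f"position {i}, item {j}: wrong field types")
            rank = item.get("rank")
            if rank is not None and not isinstance(rank, int):
                raise ValueError(f"position {i}, item {j}: rank must be int or None")
            entries[item["token_id"]] = {"logprob": item["logprob"], "rank": rank}
        result.append(entries)
    return result
```

Naming the offending position and item in each `ValueError` makes protocol-boundary failures diagnosable from logs alone.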
---
ℹ️ Review info
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
cum_log_probs_repro.py, tensorrt_llm/_torch/pyexecutor/llm_request.py, tensorrt_llm/_torch/pyexecutor/py_executor.py, tensorrt_llm/disaggregated_params.py, tensorrt_llm/executor/result.py, tensorrt_llm/serve/openai_protocol.py
@brb-nv do we have a test for this? If not, do we need to add one?
tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
Outdated
c263977 to 96b0815
/bot run --disable-fail-fast
96b0815 to 637cc1d
/bot run --disable-fail-fast
PR_Github #36868 [ run ] triggered by Bot. Commit:
PR_Github #36868 [ run ] completed with state
ba3cacf to 3ee8d7d
/bot run --disable-fail-fast
cd18351 to 82613cd
/bot run --disable-fail-fast
PR_Github #36972 [ run ] triggered by Bot. Commit:
…de in disagg Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
82613cd to 2b5d0b3
/bot run --disable-fail-fast
PR_Github #36993 [ run ] triggered by Bot. Commit:
/bot run --disable-fail-fast
PR_Github #37034 [ run ] triggered by Bot. Commit:
PR_Github #37034 [ run ] completed with state
/bot run --disable-fail-fast
/bot run --disable-fail-fast
PR_Github #37091 [ run ] triggered by Bot. Commit:
PR_Github #37092 [ run ] triggered by Bot. Commit:
PR_Github #37091 [ run ] completed with state
PR_Github #37092 [ run ] completed with state
…de in disagg (NVIDIA#11727) Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Description
This is a proposed fix for these bugs:
https://nvbugspro.nvidia.com/bug/5926823
https://nvbugspro.nvidia.com/bug/5926799
While both manifest differently, the core problem is that we don't transfer logprobs from prefill->decode in disagg.
Question for reviewers:
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
`/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...`

Provides a user-friendly way for developers to interact with a Jenkins server.

Run `/bot [-h|--help]` to print this help message. See details below for each supported subcommand.

Details

run

`run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug (experimental)]`

Launch build/test pipelines. All previously running jobs will be killed.

- `--reuse-test (optional)pipeline-id` (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- `--disable-reuse-test` (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests run regardless of previous successes.
- `--disable-fail-fast` (OPTIONAL): Disable fail fast on build/tests/infra failures.
- `--skip-test` (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: does NOT update GitHub check status.
- `--stage-list "A10-PyTorch-1, xxx"` (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: does NOT update GitHub check status.
- `--gpu-type "A30, H100_PCIe"` (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: does NOT update GitHub check status.
- `--test-backend "pytorch, cpp"` (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: does NOT update GitHub pipeline status.
- `--only-multi-gpu-test` (OPTIONAL): Only run the multi-GPU tests. Note: does NOT update GitHub check status.
- `--disable-multi-gpu-test` (OPTIONAL): Disable the multi-GPU tests. Note: does NOT update GitHub check status.
- `--add-multi-gpu-test` (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- `--post-merge` (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- `--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"` (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- `--detailed-log` (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- `--debug` (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: specify exactly one stage in the `stage-list` parameter to access the appropriate container environment. Note: does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/developer-guide/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill

`kill`: Kill all running builds associated with the pull request.

skip

`skip --comment COMMENT`: Skip testing for the latest commit on the pull request. `--comment "Reason for skipping build/test"` is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

`reuse-pipeline`: Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
Summary by CodeRabbit
Release Notes
New Features
Bug Fixes
Tests