Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,12 +467,15 @@ def generation_logits(self) -> torch.Tensor | None:

@property
def log_probs(self) -> list[TokenLogprobs] | None:
    """Per-position logprobs for the generated tokens, or None.

    Returns:
        The ``log_probs`` payload of the backing storage, or None when the
        storage is unset/empty or does not expose a ``log_probs`` attribute.
        Using explicit guards (instead of a chained ``and`` expression)
        guarantees callers see ``None`` rather than a leaked ``False``.
    """
    if not self._log_probs or not hasattr(self._log_probs, 'log_probs'):
        return None
    return self._log_probs.log_probs

@property
def cum_log_probs(self) -> list[float] | None:
    """Cumulative log-probability per beam, or None.

    Returns:
        The ``cum_log_probs`` payload of the backing storage, or None when
        the storage is unset/empty or lacks a ``cum_log_probs`` attribute.
        Mirrors the guard structure of ``log_probs`` so both properties
        return ``None`` (never a leaked falsy intermediate) when data is
        unavailable.
    """
    if not self._log_probs or not hasattr(self._log_probs, 'cum_log_probs'):
        return None
    return self._log_probs.cum_log_probs

@property
def mm_embedding_handles(self) -> List[Dict[str, Any]] | None:
Expand Down
13 changes: 13 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2821,6 +2821,19 @@ def _prepare_disagg_gen_transmission_complete(self, scheduled_batch):
for beam in range(0, beam_width):
req.add_new_token(first_gen_tokens[beam], beam)

# Prepend logprobs for first_gen_tokens if transferred from prefill.
disagg_params = getattr(req, 'py_disaggregated_params', None)
if (disagg_params is not None
and getattr(disagg_params, 'first_gen_log_probs',
None) is not None):
if beam_width != 1:
raise ValueError(
"first_gen_log_probs transfer currently assumes "
"beam_width == 1; beam search is not supported "
"with disaggregated logprobs propagation.")
req.py_result.append_log_probs(
[disagg_params.first_gen_log_probs])

@nvtx_range("_recv_disagg_gen_cache")
def _recv_disagg_gen_cache(self, new_gen_reqs):

Expand Down
3 changes: 3 additions & 0 deletions tensorrt_llm/disaggregated_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class DisaggregatedParams:
draft_tokens (List[int]): The draft tokens of the generation request
disagg_request_id (int): The disaggregated request id, if set, both context and generation requests will use it
as underlying request id.
first_gen_log_probs (List): The logprobs for first_gen_tokens, produced during prefill.
Each entry corresponds to one output sequence (beam) and is a dict[int, Logprob] mapping a token id to its Logprob.

multimodal_embedding_handles (List[Dict[str, Any]]): The resulting multimodal embedding handles from ViT.
multimodal_hashes (List[List[int]]): The multimodal hashes of each multimodal item in the request.
Expand All @@ -37,6 +39,7 @@ class DisaggregatedParams:
request_type: Optional[str] = None
# P-D Disaggregated Params
first_gen_tokens: Optional[List[int]] = None
first_gen_log_probs: Optional[List] = None
ctx_request_id: Optional[int] = None
opaque_state: Optional[bytes] = None
draft_tokens: Optional[List[int]] = None
Expand Down
43 changes: 36 additions & 7 deletions tensorrt_llm/executor/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from ..disaggregated_params import DisaggregatedParams
from ..llmapi.tracer import global_tracer
from ..llmapi.utils import AsyncQueue, print_traceback_on_error
from ..logger import logger
from ..metrics import MetricNames, MetricsCollector, RequestEventTiming
from ..sampling_params import LogprobParams, SamplingParams
from .utils import ErrorResponse, has_event_loop, is_llm_response
Expand Down Expand Up @@ -289,16 +290,33 @@ def _handle_sequence(self,
output.logprobs += response_tensors.log_probs[src_idx]

# overcome some WAR in the cpp executor
if finish_reasons[
src_idx] != tllm.FinishReason.CANCELLED and self.use_trtllm_sampler:
# Check if logprobs is a list (not a dict or other structure)
if len(output.logprobs) > output.length:
if finish_reasons[src_idx] != tllm.FinishReason.CANCELLED:
if self.use_trtllm_sampler and len(
output.logprobs) > output.length:
# LlmResult holds a reference to LogProbStorage, which may be updated by the worker before the result is serialized.
# Therefore, we treat extra logprobs/logits as expected and only consume what's needed.
output.logprobs = output.logprobs[:output.length]
assert len(
output.logprobs
) == output.length, f"logprobs length: {len(output.logprobs)} != output.length: {output.length}"

is_generation_only = (self.disaggregated_params is not None
and self.disaggregated_params.request_type
== "generation_only")
if is_generation_only:
assert len(output.logprobs) >= output.length - 1, (
f"logprobs length: {len(output.logprobs)} < "
f"output.length - 1: {output.length - 1}")
if len(output.logprobs) < output.length:
logger.warning(
"Disaggregated serving: the response contains "
"%d logprob entries instead of %d because "
"logprobs for the first generated token were "
"not transferred from the context server. "
"Enable logprobs on both the prefill and "
"decode servers to receive complete results.",
len(output.logprobs), output.length)
else:
assert len(output.logprobs) == output.length, (
f"logprobs length: {len(output.logprobs)} != "
f"output.length: {output.length}")

if response_tensors.generation_logits is not None:
output.generation_logits = response_tensors.generation_logits[
Expand Down Expand Up @@ -450,6 +468,17 @@ def _handle_response(self,
response_result.sequence_index,
logprobs_result, req_perf_metrics_dict)

# For context_only responses, carry the first gen token's logprobs
# so the generation_only side can prepend them.
if (context_phase_params is not None
and self._disaggregated_params is not None):
first_gen_lp = [
out.logprobs[0] for out in self._outputs if out.logprobs
]
if first_gen_lp:
self._disaggregated_params.first_gen_log_probs = \
first_gen_lp

if response_result.context_logits is not None:
self._context_logits = response_result.context_logits

Expand Down
51 changes: 51 additions & 0 deletions tensorrt_llm/serve/openai_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ class ResponseFormat(OpenAIBaseModel):
class DisaggregatedParams(OpenAIBaseModel):
request_type: str
first_gen_tokens: Optional[List[int]] = None
first_gen_log_probs: Optional[List] = None
ctx_request_id: Optional[int] = None
encoded_opaque_state: Optional[str] = None
draft_tokens: Optional[List[int]] = None
Expand Down Expand Up @@ -1086,13 +1087,61 @@ def decode_opaque_state(encoded_opaque_state: Optional[str]) -> Optional[bytes]:
return base64.b64decode(encoded_opaque_state)


def _serialize_first_gen_log_probs(
        first_gen_log_probs: Optional[list], ) -> Optional[List]:
    """Serialize list[dict[int, Logprob]] to JSON-safe list[list[dict]].

    ``None`` passes through unchanged; anything that is not a list (or a
    list whose entries are not dicts) raises ``ValueError`` so malformed
    payloads fail loudly at the protocol boundary.
    """
    if first_gen_log_probs is None:
        return None
    if not isinstance(first_gen_log_probs, list):
        raise ValueError("first_gen_log_probs must be a list")
    serialized = []
    for idx, token_map in enumerate(first_gen_log_probs):
        if not isinstance(token_map, dict):
            raise ValueError(
                f"first_gen_log_probs[{idx}] must be a dict, got {type(token_map)}"
            )
        # Flatten each {token_id: Logprob} mapping into a JSON-friendly
        # list of plain dicts.
        serialized.append([{
            "token_id": token_id,
            "logprob": info.logprob,
            "rank": info.rank
        } for token_id, info in token_map.items()])
    return serialized


def _deserialize_first_gen_log_probs(
        serialized: Optional[List], ) -> Optional[list]:
    """Deserialize JSON list[list[dict]] back to list[dict[int, Logprob]].

    Inverse of ``_serialize_first_gen_log_probs``: every serialized entry
    must be a list of dicts carrying at least ``token_id`` and ``logprob``
    (``rank`` is optional); otherwise ``ValueError`` is raised.
    """
    if serialized is None:
        return None
    # Imported lazily — presumably to avoid a circular import at module
    # load time (TODO confirm).
    from tensorrt_llm.executor.result import Logprob
    deserialized = []
    for pos_idx, entries in enumerate(serialized):
        if not isinstance(entries, list):
            raise ValueError(
                f"first_gen_log_probs[{pos_idx}] must be a list, got {type(entries)}"
            )
        token_map = {}
        for item_idx, item in enumerate(entries):
            if not isinstance(item, dict):
                raise ValueError(
                    f"first_gen_log_probs[{pos_idx}][{item_idx}] must be a dict")
            if "token_id" not in item or "logprob" not in item:
                raise ValueError(
                    f"first_gen_log_probs[{pos_idx}][{item_idx}] missing required keys "
                    "'token_id' and/or 'logprob'")
            token_map[item["token_id"]] = Logprob(logprob=item["logprob"],
                                                  rank=item.get("rank"))
        deserialized.append(token_map)
    return deserialized


def to_disaggregated_params(
tllm_disagg_params: LlmDisaggregatedParams) -> DisaggregatedParams:
if tllm_disagg_params is None:
return None
return DisaggregatedParams(
request_type=tllm_disagg_params.request_type,
first_gen_tokens=tllm_disagg_params.first_gen_tokens,
first_gen_log_probs=_serialize_first_gen_log_probs(
tllm_disagg_params.first_gen_log_probs),
ctx_request_id=tllm_disagg_params.ctx_request_id,
encoded_opaque_state=encode_opaque_state(
tllm_disagg_params.opaque_state),
Expand All @@ -1111,6 +1160,8 @@ def to_llm_disaggregated_params(
return LlmDisaggregatedParams(
request_type=disaggregated_params.request_type,
first_gen_tokens=disaggregated_params.first_gen_tokens,
first_gen_log_probs=_deserialize_first_gen_log_probs(
disaggregated_params.first_gen_log_probs),
ctx_request_id=disaggregated_params.ctx_request_id,
opaque_state=decode_opaque_state(
disaggregated_params.encoded_opaque_state),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -510,5 +510,111 @@ def test_disaggregated_spec_dec_batch_slot_limit(model, spec_dec_model_path,
print("All workers terminated.")


@pytest.mark.parametrize("model", ["TinyLlama-1.1B-Chat-v1.0"])
@pytest.mark.parametrize("generation_overlap", [False, True])
def test_disaggregated_logprobs(model, generation_overlap):
    """Verify that logprobs propagate correctly from prefill to decode.

    Ensures first_gen_log_probs is carried in DisaggregatedParams
    so the generation_only worker receives one logprob per token.
    """
    # Worker 0 acts as the context (prefill) server with overlap disabled;
    # worker 1 is the generation (decode) server, whose overlap scheduler
    # is toggled by the `generation_overlap` parameter.
    worker_pytorch_configs = [
        dict(disable_overlap_scheduler=True),
        dict(disable_overlap_scheduler=not generation_overlap),
    ]

    kv_cache_configs = [KvCacheConfig(max_tokens=2048 * 8) for _ in range(2)]
    cache_transceiver_configs = [
        CacheTransceiverConfig(backend="DEFAULT") for _ in range(2)
    ]
    model_names = [model_path(model) for _ in range(2)]
    ranks = [0, 1]
    # One (kv_cache, transceiver, pytorch_config, model, rank) tuple per worker.
    worker_args = list(
        zip(kv_cache_configs, cache_transceiver_configs, worker_pytorch_configs,
            model_names, ranks))

    port_name = mpi_publish_name()
    max_tokens = 10
    prompt = "What is the capital of Germany?"

    # UCX transport overrides — NOTE(review): presumably needed for
    # single-node CI environments without InfiniBand; verify against the
    # sibling disaggregated tests.
    with MPIPoolExecutor(max_workers=2,
                         env={
                             "UCX_TLS": "^ib,gdr_copy",
                             "UCX_MM_ERROR_HANDLING": "y"
                         }) as executor:
        futures = []
        try:
            for worker_arg in worker_args:
                future = executor.submit(worker_entry_point, *worker_arg)
                futures.append(future)
        except Exception as e:
            print(f"Error in worker {worker_arg}: {e}")
            raise e

        intercomm = None
        try:
            intercomm = mpi_initialize_intercomm(port_name)
            # Wait for both workers to signal readiness before sending work.
            for _ in range(2):
                intercomm.recv(tag=MPI_READY)

            # --- Context-only phase (prefill) with logprobs ---
            ctx_requests = [(prompt,
                             SamplingParams(max_tokens=max_tokens,
                                            ignore_eos=True,
                                            logprobs=1),
                             DisaggregatedParams(request_type="context_only"))]

            ctx_responses = send_requests_to_worker(ctx_requests, 0, intercomm)
            ctx_output = ctx_responses[0][0]

            assert ctx_output.disaggregated_params is not None
            assert ctx_output.disaggregated_params.request_type == "context_only"
            # Prefill emits exactly the first generated token.
            assert len(ctx_output.token_ids) == 1

            # The context phase must populate first_gen_log_probs.
            dp = ctx_output.disaggregated_params
            assert dp.first_gen_log_probs is not None, (
                "first_gen_log_probs should be populated by the context phase")
            assert len(dp.first_gen_log_probs) >= 1
            for lp_entry in dp.first_gen_log_probs:
                # Each entry maps token id -> Logprob-like object.
                assert isinstance(lp_entry, dict)
                for token_id, logprob_obj in lp_entry.items():
                    assert isinstance(token_id, int)
                    assert logprob_obj.logprob <= 0.0, (
                        "Log probabilities must be non-positive")

            # --- Generation-only phase (decode) with logprobs ---
            # Reuse the prefill's DisaggregatedParams (including
            # first_gen_log_probs) for the decode request.
            dp.request_type = "generation_only"
            gen_requests = [(prompt,
                             SamplingParams(max_tokens=max_tokens,
                                            ignore_eos=True,
                                            logprobs=1), dp)]

            gen_responses = send_requests_to_worker(gen_requests, 1, intercomm)
            gen_output = gen_responses[0][0]

            # Without first_gen_log_probs propagation this either crashes
            # (AttributeError) or returns fewer logprobs than tokens.
            assert gen_output.logprobs is not None, (
                "Generation phase should return logprobs")
            assert len(gen_output.logprobs) == len(gen_output.token_ids), (
                f"Expected one logprob per token: got {len(gen_output.logprobs)}"
                f" logprobs for {len(gen_output.token_ids)} tokens")

            for pos_idx, lp_entry in enumerate(gen_output.logprobs):
                assert isinstance(
                    lp_entry, dict), (f"logprobs[{pos_idx}] should be a dict")
                for token_id, logprob_obj in lp_entry.items():
                    assert isinstance(token_id, int)
                    assert logprob_obj.logprob <= 0.0

        except Exception as e:
            print(f"Exception encountered: {e}", flush=True)
            raise e
        finally:
            # Always tear down workers, even on assertion failure.
            mpi_send_termination_request(intercomm)
            for future in futures:
                future.result()


if __name__ == "__main__":
pytest.main()
2 changes: 2 additions & 0 deletions tests/integration/test_lists/qa/llm_function_core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[False-True-Qwen3-8B-FP8]
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-False-Qwen3-8B-FP8]
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_simple_qwen3[True-True-Qwen3-8B-FP8]
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_logprobs[False-TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated_single_gpu.py::test_disaggregated_logprobs[True-TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2pp2_gentp2pp2[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_genpp4[TinyLlama-1.1B-Chat-v1.0]
disaggregated/test_disaggregated.py::test_disaggregated_kv_cache_time_output[TinyLlama-1.1B-Chat-v1.0]
Expand Down
Loading