Skip to content
36 changes: 29 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/llm_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def __init__(
self.use_device_memory = use_device_memory
self.use_chunked_generation_logits = use_chunked_generation_logits
self.chunk_size = chunk_size
self._logits_indices = []
self._logits_indices: list[tuple[int, int]] = []

# Lazily initialized by _init() upon first append()
self._storage: torch.Tensor | None = None
Expand Down Expand Up @@ -765,13 +765,35 @@ def create_response(self,
result, is_final = super().create_serialized_result(
use_fast_logits, mpi_world_rank)

# Performs a deep copy of py_result._log_probs to eliminate race conditions that may occur between IPC communication and the overriding of newly generated log_probs in streaming mode.
if self.streaming and self.py_result.log_probs and self.sampling_config.beam_width <= 1:
# When using beam search we cannot incrementally update the logprobs in the result.
# Instead we need to update all logprobs. In that case no deep copy is needed.
need_deep_copy_logprobs = self.py_result.log_probs and self.sampling_config.beam_width <= 1
need_deep_copy_generation_logits = self.py_result._generation_logits is not None
need_any_deep_copy = need_deep_copy_logprobs or need_deep_copy_generation_logits
# Performs a deep copy of py_result._log_probs or py_result._generation_logits to eliminate race conditions
# that may occur between IPC communication and the overriding of newly generated log_probs
# or the updating of py_result._generation_logits._logits_indices in streaming mode.
if self.streaming and need_any_deep_copy:
py_result = copy(self.py_result)
py_result._log_probs = deepcopy(self.py_result._log_probs)

for log_prob in self.py_result.log_probs:
log_prob.clear()
# Move _log_probs to py_result and create a new empty LogProbStorage in self.py_result
# This avoids performing a deepcopy
if need_deep_copy_logprobs:
py_result._log_probs = self.py_result._log_probs
self.py_result._log_probs = LogProbStorage()
# Initialize the storage and adjust the cum_log_probs to the previous value
self.py_result._log_probs._init(py_result.log_probs)
self.py_result._log_probs.cum_log_probs = py_result.cum_log_probs

# Perform copies of py_result._generation_logits
if need_deep_copy_generation_logits:
# shallow copy of generation_logits to avoid copying the logits tensor
py_result._generation_logits = copy(
self.py_result._generation_logits)
# deep copy the indices to avoid the race condition
# In streaming mode LogitsStorage only accesses either the last
# or second to last pair of indices. Therefore, copying only these two pairs is sufficient.
py_result._generation_logits._logits_indices = py_result._generation_logits._logits_indices[
-2:]
else:
py_result = self.py_result

Expand Down
30 changes: 23 additions & 7 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,9 @@ def _executor_loop_pp(self):
scheduled_batch, guided_decoder_failed_requests)

self._update_request_states(scheduled_batch)
if not self.disable_overlap_scheduler:
self._update_generation_requests_that_will_complete_next_iteration(
scheduled_batch.generation_requests)

if self.enable_iter_perf_stats:
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
Expand Down Expand Up @@ -2179,6 +2182,13 @@ def _executor_loop_overlap(self):
else:
self._enqueue_responses([])

# Call set_exclude_last_generation_logits after _process_previous_batch.
# If set before, the response of a request may be incorrect, as it will
# use the wrong indices for generation logits when streaming is enabled.
if can_queue:
self._update_generation_requests_that_will_complete_next_iteration(
scheduled_batch.generation_requests)

if can_queue:
if self.enable_iter_perf_stats:
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
Expand Down Expand Up @@ -3000,6 +3010,19 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,
self._handle_errors(error_msg)
return None

def _update_generation_requests_that_will_complete_next_iteration(
        self, generation_requests: list[LlmRequest]):
    """Mark generation requests that will finish on the next iteration.

    Used when overlap scheduling is enabled: every request that is not yet
    complete but reports it will complete next iteration is switched to the
    GENERATION_TO_COMPLETE state, and its exclude_last_generation_logits
    flag is cleared so the final generation logits are not dropped.
    """
    done = LlmRequestState.GENERATION_COMPLETE
    for req in generation_requests:
        # Skip requests that are already finished.
        if req.state == done:
            continue
        # Only touch requests that report completion on the next step.
        if not req.will_complete_next_iteration():
            continue
        req.set_exclude_last_generation_logits(False)
        req.state = LlmRequestState.GENERATION_TO_COMPLETE

def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
# handle potential attention dp dummy request
if self.active_requests and self.active_requests[
Expand All @@ -3025,13 +3048,6 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
else:
request.state = LlmRequestState.GENERATION_IN_PROGRESS

for request in scheduled_requests.generation_requests:
if request.state != LlmRequestState.GENERATION_COMPLETE:
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
):
request.set_exclude_last_generation_logits(False)
request.state = LlmRequestState.GENERATION_TO_COMPLETE

def _update_request_states_star_attention(
self, scheduled_requests: ScheduledRequests):
for request in scheduled_requests.context_requests:
Expand Down
3 changes: 0 additions & 3 deletions tests/integration/test_lists/waives.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,6 @@ unittest/_torch/speculative/test_dynamic_spec_decode.py::test_dynamic_spec_decod
triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregated-serving-bls] SKIP (https://nvbugs/5582118)
triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
unittest/_torch/sampler/test_logits_logprobs.py::test_generate_with_return_logits SKIP (https://nvbugs/5764627)
unittest/_torch/sampler/test_logits_logprobs.py::test_generate_async_with_return_logits SKIP (https://nvbugs/5764627)
examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)
unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741)
unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741)
Expand Down
Loading
Loading