Skip to content

Commit e94505c

Browse files
stniegreg-kwasniewski1
authored and committed
[https://nvbugs/5764627][fix] Fix generation logits with streaming and improve runtime of logits testcase. Also fixes https://nvbugs/5573238 (NVIDIA#10637)
Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent 966405b commit e94505c

File tree

4 files changed

+136
-158
lines changed

4 files changed

+136
-158
lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 29 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(
8484
self.use_device_memory = use_device_memory
8585
self.use_chunked_generation_logits = use_chunked_generation_logits
8686
self.chunk_size = chunk_size
87-
self._logits_indices = []
87+
self._logits_indices: list[tuple[int, int]] = []
8888

8989
# Lazily initialized by _init() upon first append()
9090
self._storage: torch.Tensor | None = None
@@ -768,13 +768,35 @@ def create_response(self,
768768
result, is_final = super().create_serialized_result(
769769
use_fast_logits, mpi_world_rank)
770770

771-
# Performs a deep copy of py_result._log_probs to eliminate race conditions that may occur between IPC communication and the overriding of newly generated log_probs in streaming mode.
772-
if self.streaming and self.py_result.log_probs and self.sampling_config.beam_width <= 1:
771+
# When using beam search we cannot incrementally update the logprobs in the result.
772+
# Instead we need to update all logprobs. In that case no deep copy is needed.
773+
need_deep_copy_logprobs = self.py_result.log_probs and self.sampling_config.beam_width <= 1
774+
need_deep_copy_generation_logits = self.py_result._generation_logits is not None
775+
need_any_deep_copy = need_deep_copy_logprobs or need_deep_copy_generation_logits
776+
# Performs a deep copy of py_result._log_probs or py_result._generation_logits to eliminate race conditions
777+
# that may occur between IPC communication and the overriding of newly generated log_probs
778+
# or the updating of py_result._generation_logits._logits_indices in streaming mode.
779+
if self.streaming and need_any_deep_copy:
773780
py_result = copy(self.py_result)
774-
py_result._log_probs = deepcopy(self.py_result._log_probs)
775-
776-
for log_prob in self.py_result.log_probs:
777-
log_prob.clear()
781+
# Move _log_probs to py_result and create a new empty LogProbStorage in self.py_result
782+
# This avoids performing a deepcopy
783+
if need_deep_copy_logprobs:
784+
py_result._log_probs = self.py_result._log_probs
785+
self.py_result._log_probs = LogProbStorage()
786+
# Initialize the storage and adjust the cum_log_probs to the previous value
787+
self.py_result._log_probs._init(py_result.log_probs)
788+
self.py_result._log_probs.cum_log_probs = py_result.cum_log_probs
789+
790+
# Perform copies of py_result._generation_logits
791+
if need_deep_copy_generation_logits:
792+
# shallow copy of generation_logits to avoid copying the logits tensor
793+
py_result._generation_logits = copy(
794+
self.py_result._generation_logits)
795+
# deep copy the indices to avoid the race condition
796+
# In streaming mode LogitsStorage only accesses either the last
797+
# or second to last pair of indices. Therefore, copying only these two pairs is sufficient.
798+
py_result._generation_logits._logits_indices = py_result._generation_logits._logits_indices[
799+
-2:]
778800
else:
779801
py_result = self.py_result
780802

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1369,6 +1369,9 @@ def _executor_loop_pp(self):
13691369
scheduled_batch, guided_decoder_failed_requests)
13701370

13711371
self._update_request_states(scheduled_batch)
1372+
if not self.disable_overlap_scheduler:
1373+
self._update_generation_requests_that_will_complete_next_iteration(
1374+
scheduled_batch.generation_requests)
13721375

13731376
if self.enable_iter_perf_stats:
13741377
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
@@ -2179,6 +2182,13 @@ def _executor_loop_overlap(self):
21792182
else:
21802183
self._enqueue_responses([])
21812184

2185+
# Call set_exclude_last_generation_logits after _process_previous_batch.
2186+
# If set before, the response of a request may be incorrect, as it will
2187+
# use the wrong indices for generation logits when streaming is enabled.
2188+
if can_queue:
2189+
self._update_generation_requests_that_will_complete_next_iteration(
2190+
scheduled_batch.generation_requests)
2191+
21822192
if can_queue:
21832193
if self.enable_iter_perf_stats:
21842194
iter_stats.inflight_batching_stats.num_ctx_tokens = self.model_engine.iter_states[
@@ -3013,6 +3023,19 @@ def forward(scheduled_requests, resource_manager, new_tensors_device,
30133023
self._handle_errors(error_msg)
30143024
return None
30153025

3026+
def _update_generation_requests_that_will_complete_next_iteration(
3027+
self, generation_requests: list[LlmRequest]):
3028+
""" Update the generation requests that will complete next iteration.
3029+
3030+
If overlap scheduling is enabled, we need to update the state of generation requests that will complete next iteration
3031+
and adjust the exclude_last_generation_logits flag accordingly.
3032+
"""
3033+
for request in generation_requests:
3034+
if request.state != LlmRequestState.GENERATION_COMPLETE and request.will_complete_next_iteration(
3035+
):
3036+
request.set_exclude_last_generation_logits(False)
3037+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
3038+
30163039
def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
30173040
# handle potential attention dp dummy request
30183041
if self.active_requests and self.active_requests[
@@ -3038,13 +3061,6 @@ def _update_request_states_tp(self, scheduled_requests: ScheduledRequests):
30383061
else:
30393062
request.state = LlmRequestState.GENERATION_IN_PROGRESS
30403063

3041-
for request in scheduled_requests.generation_requests:
3042-
if request.state != LlmRequestState.GENERATION_COMPLETE:
3043-
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
3044-
):
3045-
request.set_exclude_last_generation_logits(False)
3046-
request.state = LlmRequestState.GENERATION_TO_COMPLETE
3047-
30483064
def _update_request_states_star_attention(
30493065
self, scheduled_requests: ScheduledRequests):
30503066
for request in scheduled_requests.context_requests:

tests/integration/test_lists/waives.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -200,9 +200,6 @@ unittest/_torch/speculative/test_dynamic_spec_decode.py::test_dynamic_spec_decod
200200
triton_server/test_triton.py::test_gpt_disaggregated_serving_bls[gpt-disaggregated-serving-bls] SKIP (https://nvbugs/5582118)
201201
triton_server/test_triton.py::test_gpt_speculative_decoding[gpt-speculative-decoding] SKIP (https://nvbugs/5762854)
202202
accuracy/test_llm_api_pytorch.py::TestLlama3_1_8B_Instruct_RocketKV::test_auto_dtype SKIP (https://nvbugs/5762822)
203-
unittest/_torch/sampler/test_return_logits.py SKIP (https://nvbugs/5764627)
204-
unittest/_torch/sampler/test_logits_logprobs.py::test_generate_with_return_logits SKIP (https://nvbugs/5764627)
205-
unittest/_torch/sampler/test_logits_logprobs.py::test_generate_async_with_return_logits SKIP (https://nvbugs/5764627)
206203
examples/test_ray.py::test_ray_disaggregated_serving[tp2] SKIP (https://nvbugs/5612502)
207204
unittest/executor/test_rpc_proxy.py SKIP (https://nvbugs/5605741)
208205
unittest/executor/test_rpc_worker.py SKIP (https://nvbugs/5605741)

0 commit comments

Comments
 (0)