Skip to content

Commit 04b3472

Browse files
committed
Refactor LlmRequest and PyExecutor for improved state management and deep copy logic.
- Added a comment to clarify the condition for deep copying in LlmRequest. - Adjusted state-management logic in PyExecutor for generation requests to skip unnecessary checks. - Added assertions in unit tests to validate expected behavior in generation scenarios.

Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent dbe7722 commit 04b3472

File tree

3 files changed

+9
-6
lines changed

3 files changed

+9
-6
lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,10 @@ def create_response(self,
639639
result, is_final = super().create_serialized_result(
640640
use_fast_logits, mpi_world_rank)
641641

642+
# When using beam search we cannot incrementally update the logprobs in the result.
643+
# Instead we need to update all logprobs. In that case no deep copy is needed.
642644
need_deep_copy_logprobs = self.py_result.log_probs and self.sampling_config.beam_width <= 1
643-
need_deep_copy_generation_logits = self.py_result._generation_logits
645+
need_deep_copy_generation_logits = self.py_result._generation_logits is not None
644646
need_any_deep_copy = need_deep_copy_logprobs or need_deep_copy_generation_logits
645647
# Performs a deep copy of py_result._log_probs or py_result._generation_logits to eliminate race conditions
646648
# that may occur between IPC communication and the overriding of newly generated log_probs

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1757,11 +1757,10 @@ def _executor_loop_overlap(self):
17571757
# If set before, the response of a request may be incorrect, as it will
17581758
# use the wrong indices for generation logits when streaming is enabled.
17591759
for request in scheduled_batch.generation_requests:
1760-
if request.state != LlmRequestState.GENERATION_COMPLETE:
1761-
if not self.disable_overlap_scheduler and request.will_complete_next_iteration(
1762-
):
1763-
request.set_exclude_last_generation_logits(False)
1764-
request.state = LlmRequestState.GENERATION_TO_COMPLETE
1760+
if request.state != LlmRequestState.GENERATION_COMPLETE and request.will_complete_next_iteration(
1761+
):
1762+
request.set_exclude_last_generation_logits(False)
1763+
request.state = LlmRequestState.GENERATION_TO_COMPLETE
17651764

17661765
if can_queue:
17671766
if self.enable_iter_perf_stats:

tests/unittest/_torch/sampler/test_return_logits.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ def test_generation_with_return_logits(
179179
idx=idx,
180180
output=output,
181181
streaming=True)
182+
assert idx == sampling_params.max_tokens - 1
182183
else:
183184
for idx, output in enumerate(
184185
llm.generate(
@@ -197,3 +198,4 @@ def test_generation_with_return_logits(
197198
idx=idx,
198199
output=output,
199200
streaming=False)
201+
assert idx == len(prompts) - 1

0 commit comments

Comments
 (0)