Skip to content

Commit 123a3bd

Browse files
committed
Refactor deep copy logic for generation logits in LlmRequest.
- Adjusted the condition for deep copying generation logits, ensuring only necessary copies are made. - Updated comments for clarity on the copying process of generation logits and their indices to enhance understanding of the logic. Signed-off-by: Stefan Niebler <82932102+stnie@users.noreply.github.com>
1 parent dbe70c1 commit 123a3bd

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -654,10 +654,16 @@ def create_response(self,
654654
for log_prob in self.py_result.log_probs:
655655
log_prob.clear()
656656

657-
# Perform a deep copy of py_result._generation_logits
657+
# Perform copies of py_result._generation_logits
658658
if need_deep_copy_generation_logits:
659-
py_result._generation_logits = deepcopy(
659+
# shallow copy of generation_logits to avoid copying the logits tensor
660+
py_result._generation_logits = copy(
660661
self.py_result._generation_logits)
662+
# deep copy the indices to avoid the race condition
663+
# In streaming mode LogitsStorage only accesses either the last
664+
# or second to last pair of indices. Therefore, copying only these two pairs is sufficient.
665+
py_result._generation_logits._logits_indices = py_result._generation_logits._logits_indices[
666+
-2:]
661667
else:
662668
py_result = self.py_result
663669

0 commit comments

Comments
 (0)