Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 29 additions & 3 deletions tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,6 +519,23 @@ def request_finished(self, req: LlmRequest,

return saving_async

def abort_request(self, req: LlmRequest) -> None:
    """Drop every async-tracking reference to a request.

    Call this when a request is cancelled or fails; otherwise
    worker.get_finished() may block on transfers that will never
    complete.

    Args:
        req: The request whose tracking state should be discarded.
    """
    request_id = req.request_id
    tracked_stores = (
        self.new_async_requests,
        self.pending_async_requests,
        self.local_finished_async_requests,
    )
    for tracked in tracked_stores:
        for id_map in (tracked.loading, tracked.saving):
            id_map.pop(request_id, None)

    output_mgr = self.scheduler_output_manager
    output_mgr.requests.pop(request_id, None)
    output_mgr.external_loads.pop(request_id, None)

def get_finished(self) -> List[LlmRequest]:
"""
Process requests that have finished loading and saving.
Expand All @@ -533,9 +550,18 @@ def get_finished(self) -> List[LlmRequest]:
self.pending_async_requests.add_from(self.new_async_requests)

# Pass these newly finished requests into get_finished, and get the list of requests that have finished saving and loading.
(finished_saving,
finished_loading) = self.worker.get_finished(finished_gen_req_ids,
started_loading_req_ids)
# Wrap in try/except so that a worker failure still reaches the mpi_allgather barrier,
# preventing a collective stall where other ranks wait indefinitely for this rank.
try:
(finished_saving, finished_loading) = self.worker.get_finished(
finished_gen_req_ids, started_loading_req_ids)
except Exception as e:
logger.error(
f"KV connector worker get_finished() raised an exception: {e}. "
"Reporting no completions this iteration to unblock mpi_allgather."
)
Comment on lines +558 to +562
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

logger is not imported — this will raise NameError at runtime.

The exception handler references logger.error(...), but logger is not imported in this file. This will cause a NameError when an exception occurs, defeating the purpose of the exception handling.

🐛 Proposed fix: Add the logger import

Add at the top of the file with other imports:

from tensorrt_llm.logger import logger
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/kv_cache_connector.py` around lines 558 - 562,
The except block in get_finished() (the handler that calls logger.error on
exception 'e') references logger which is not imported; add the missing import
at the top of the module with the other imports by importing the shared logger
symbol (e.g., from tensorrt_llm.logger import logger) so the exception handler
and any other usages of logger resolve at runtime.

finished_saving = []
finished_loading = []

# Remove the requests from our pending list that have finished locally.
new_local_finished_async_requests = self.pending_async_requests.extract_by_id(
Expand Down
6 changes: 6 additions & 0 deletions tensorrt_llm/_torch/pyexecutor/py_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3175,6 +3175,8 @@ def _handle_errors(self,
]
self._enqueue_responses(list(error_responses.items()))
for request in failed_requests:
if self.kv_connector_manager is not None:
self.kv_connector_manager.abort_request(request)
self._terminate_request(request)

def _terminate_request(self, request: LlmRequest):
Expand Down Expand Up @@ -3202,6 +3204,10 @@ def _try_cancel_request(self, request) -> bool:
Returns:
bool: True if the request can be canceled (either successfully cancelled or doesn't need cancellation).
"""
if self.kv_connector_manager is not None:
self.kv_connector_manager.abort_request(request)
return True

if self.kv_cache_transceiver is None:
return True

Expand Down
60 changes: 60 additions & 0 deletions tests/unittest/_torch/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,66 @@ def test():
run_across_mpi(mpi_pool_executor, test, 2)


def test_abort_request_cleans_async_state():
    """abort_request must purge a request from every async tracking structure."""
    worker = MagicMock()
    scheduler = MagicMock()

    manager = KvCacheConnectorManager(worker, scheduler=scheduler)

    req = MagicMock()
    req.request_id = 42

    stores = [
        manager.new_async_requests,
        manager.pending_async_requests,
        manager.local_finished_async_requests,
    ]
    # Seed every tracking structure as if a transfer were in flight.
    for tracking in stores:
        tracking.loading[42] = req
        tracking.saving[42] = req
    manager.scheduler_output_manager.requests[42] = req
    manager.scheduler_output_manager.external_loads[42] = req

    manager.abort_request(req)

    for tracking in stores:
        assert 42 not in tracking.loading
        assert 42 not in tracking.saving
    assert 42 not in manager.scheduler_output_manager.requests
    assert 42 not in manager.scheduler_output_manager.external_loads


@pytest.mark.parametrize("mpi_pool_executor", [2], indirect=True)
@pytest.mark.threadleak(enabled=False)
def test_connector_manager_get_finished_exception_unblocks_allgather(
        mpi_pool_executor):
    """A worker.get_finished() failure on rank 0 must not stall the collective."""

    def test():
        worker = MagicMock()

        if mpi_rank() != 0:
            scheduler = None
            worker.get_finished.return_value = ([], [])
        else:
            # Rank 0 simulates a Rust/KVBM panic inside the worker.
            scheduler = MagicMock()
            scheduler.request_finished.return_value = True
            worker.get_finished.side_effect = RuntimeError(
                "simulated KVBM worker failure")

        manager = KvCacheConnectorManager(worker, scheduler=scheduler)

        req = MagicMock()
        req.request_id = 42
        manager.request_finished(req, [])

        # Must not raise; every rank reaches this point and sees no completions.
        assert manager.get_finished() == []

    run_across_mpi(mpi_pool_executor, test, 2)


def test_scheduler_output_num_scheduled_tokens_with_mtp():
"""Test that num_scheduled_tokens is correctly set for MTP (multi-token prediction)."""
NUM_DRAFT_TOKENS = 3
Expand Down
Loading