Commit 02dcf1d

resolve comments
1 parent a6a1a00 commit 02dcf1d

11 files changed: +93 -69 lines changed

README.md

Lines changed: 8 additions & 11 deletions
@@ -1418,13 +1418,10 @@ will be automatically deallocated. This can increase the number of BLS requests
 that you can execute in your model without running into the out of GPU or
 shared memory error.
 
-Starting from the 25.04 release, you can use the `infer_responses.cancel()` function
-on a BLS decoupled response iterator to stop the response stream, which cancels
-the request to the decoupled model. This is useful for stopping long inference
-requests, such as those from auto-generative large language models, which may
-run for an indeterminate amount of time and consume significant server resources.
-The response iterator can be generated from `infer_request.exec(decoupled=True)`
-and `infer_request.async_exec(decoupled=True)` functions:
+### Cancelling decoupled BLS requests
+A decoupled BLS inference request may be cancelled by calling the `cancel()`
+method on the response iterator returned from the method executing the BLS
+inference request. For example,
 
 ```python
 import triton_python_backend_utils as pb_utils
@@ -1433,12 +1430,12 @@ class TritonPythonModel:
     ...
     def execute(self, requests):
         ...
-        inference_request = pb_utils.InferenceRequest(
+        infer_request = pb_utils.InferenceRequest(
             model_name='model_name',
             requested_output_names=['REQUESTED_OUTPUT'],
             inputs=[<pb_utils.Tensor object>])
 
-        # Execute the inference_request and wait for the response. Here we are
+        # Execute the infer_request and wait for the response. Here we are
         # running a BLS request on a decoupled model, hence setting the parameter
         # 'decoupled' to 'True'.
         infer_responses = infer_request.exec(decoupled=True)
@@ -1449,14 +1446,14 @@ class TritonPythonModel:
             # vLLM backend uses the CANCELLED error code when a request is cancelled.
             # TensorRT-LLM backend does not use error codes; instead, it sends the
             # TRITONSERVER_RESPONSE_COMPLETE_FINAL flag to the iterator.
-            if inference_response.has_error():
+            if infer_response.has_error():
                 if infer_response.error().code() == pb_utils.TritonError.CANCELLED:
                     print("request has been cancelled.")
                 break
 
             # Collect the output tensor from the model's response
             output = pb_utils.get_output_tensor_by_name(
-                inference_response, 'REQUESTED_OUTPUT')
+                infer_response, 'REQUESTED_OUTPUT')
             response_tensors_received.append(output)
 
             # Check if we have received enough inference output tensors

src/infer_payload.cc

Lines changed: 32 additions & 0 deletions
@@ -32,6 +32,7 @@ InferPayload::InferPayload(
     const bool is_decoupled,
     std::function<void(std::unique_ptr<InferResponse>)> callback)
     : is_decoupled_(is_decoupled), is_promise_set_(false), callback_(callback),
+      is_request_deleted_(false),
       request_address_(reinterpret_cast<intptr_t>(nullptr))
 {
   promise_.reset(new std::promise<std::unique_ptr<InferResponse>>());
@@ -104,4 +105,35 @@ InferPayload::GetRequestAddress()
   return request_address_;
 }
 
+void
+InferPayload::SetRequestDeleted()
+{
+  std::unique_lock<std::mutex> lock(request_deletion_mutex_);
+  is_request_deleted_ = true;
+}
+
+void
+InferPayload::SetRequestCancellationFunc(
+    const std::function<void(intptr_t)>& request_cancel_func)
+{
+  request_cancel_func_ = request_cancel_func;
+}
+
+void
+InferPayload::SafeCancelRequest()
+{
+  std::unique_lock<std::mutex> lock(request_deletion_mutex_);
+  if (is_request_deleted_) {
+    return;
+  }
+
+  if (request_address_ == 0L) {
+    return;
+  }
+
+  if (request_cancel_func_) {
+    request_cancel_func_(request_address_);
+  }
+}
+
 }}} // namespace triton::backend::python

src/infer_payload.h

Lines changed: 7 additions & 0 deletions
@@ -62,8 +62,12 @@ class InferPayload : public std::enable_shared_from_this<InferPayload> {
   void SetResponseAllocUserp(
       const ResponseAllocatorUserp& response_alloc_userp);
   std::shared_ptr<ResponseAllocatorUserp> ResponseAllocUserp();
+  void SetRequestDeleted();
   void SetRequestAddress(intptr_t request_address);
   intptr_t GetRequestAddress();
+  void SetRequestCancellationFunc(
+      const std::function<void(intptr_t)>& request_cancel_func);
+  void SafeCancelRequest();
 
  private:
   std::unique_ptr<std::promise<std::unique_ptr<InferResponse>>> promise_;
@@ -72,7 +76,10 @@ class InferPayload : public std::enable_shared_from_this<InferPayload> {
   bool is_promise_set_;
   std::function<void(std::unique_ptr<InferResponse>)> callback_;
   std::shared_ptr<ResponseAllocatorUserp> response_alloc_userp_;
+  std::mutex request_deletion_mutex_;
+  bool is_request_deleted_;
   intptr_t request_address_;
+  std::function<void(intptr_t)> request_cancel_func_;
 };
 
 }}} // namespace triton::backend::python
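
Taken together, the new members give each `InferPayload` a self-contained cancellation path: the backend registers a cancel function when the request is created (see `request_executor.cc` below), the request's release callback calls `SetRequestDeleted()`, and `SafeCancelRequest()` only invokes the cancel function while the request is still alive. Below is a minimal standalone sketch of that coordination, with the Triton types replaced by a toy `Payload` class; it is illustrative only, not the backend's actual classes.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <mutex>

// Toy stand-in for InferPayload's new cancellation bookkeeping.
class Payload {
 public:
  void SetRequestAddress(intptr_t request_address)
  {
    request_address_ = request_address;
  }

  void SetRequestCancellationFunc(const std::function<void(intptr_t)>& func)
  {
    request_cancel_func_ = func;
  }

  // Called from the request release callback, before the request is freed.
  void SetRequestDeleted()
  {
    std::unique_lock<std::mutex> lock(request_deletion_mutex_);
    is_request_deleted_ = true;
  }

  // Called when the stub asks to cancel; a no-op once the request is released.
  void SafeCancelRequest()
  {
    std::unique_lock<std::mutex> lock(request_deletion_mutex_);
    if (is_request_deleted_ || request_address_ == 0 || !request_cancel_func_) {
      return;
    }
    request_cancel_func_(request_address_);
  }

 private:
  std::mutex request_deletion_mutex_;
  bool is_request_deleted_ = false;
  intptr_t request_address_ = 0;
  std::function<void(intptr_t)> request_cancel_func_;
};

int main()
{
  Payload payload;
  payload.SetRequestAddress(0x1234);  // pretend request handle
  payload.SetRequestCancellationFunc(
      [](intptr_t addr) { std::cout << "cancelling request " << addr << "\n"; });

  payload.SafeCancelRequest();  // cancels: the request is still alive
  payload.SetRequestDeleted();  // the release callback has run
  payload.SafeCancelRequest();  // no-op: the request has been released
  return 0;
}
```

Holding `request_deletion_mutex_` across both the deleted-flag check and the cancel call is what keeps the cancellation from racing with the request's release.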

src/ipc_message.h

Lines changed: 1 addition & 1 deletion
@@ -68,7 +68,7 @@ typedef enum PYTHONSTUB_commandtype_enum {
   PYTHONSTUB_UnloadModelRequest,
   PYTHONSTUB_ModelReadinessRequest,
   PYTHONSTUB_IsRequestCancelled,
-  PYTHONSTUB_CancelBLSDecoupledInferRequest
+  PYTHONSTUB_CancelBLSInferRequest
 } PYTHONSTUB_CommandType;
 
 ///

src/pb_bls_cancel.cc

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ PbBLSCancel::Cancel()
       return;
     }
 
-    stub->EnqueueCancelBLSDecoupledRequest(this);
+    stub->EnqueueCancelBLSRequest(this);
     updating_ = true;
   }
   cv_.wait(lk, [this] { return !updating_; });

src/pb_stub.cc

Lines changed: 5 additions & 6 deletions
@@ -1137,9 +1137,8 @@ Stub::ServiceStubToParentRequests()
         utils_msg_payload->command_type == PYTHONSTUB_IsRequestCancelled) {
       SendIsCancelled(utils_msg_payload);
     } else if (
-        utils_msg_payload->command_type ==
-        PYTHONSTUB_CancelBLSDecoupledInferRequest) {
-      SendCancelBLSDecoupledRequest(utils_msg_payload);
+        utils_msg_payload->command_type == PYTHONSTUB_CancelBLSInferRequest) {
+      SendCancelBLSRequest(utils_msg_payload);
     } else {
       std::cerr << "Error when sending message via stub_to_parent message "
                    "buffer - unknown command\n";
@@ -1226,7 +1225,7 @@ Stub::EnqueueCleanupId(void* id, const PYTHONSTUB_CommandType& command_type)
 }
 
 void
-Stub::SendCancelBLSDecoupledRequest(
+Stub::SendCancelBLSRequest(
     std::unique_ptr<UtilsMessagePayload>& utils_msg_payload)
 {
   PbBLSCancel* pb_bls_cancel =
@@ -1256,11 +1255,11 @@ Stub::SendCancelBLSDecoupledRequest(
 }
 
 void
-Stub::EnqueueCancelBLSDecoupledRequest(PbBLSCancel* pb_bls_cancel)
+Stub::EnqueueCancelBLSRequest(PbBLSCancel* pb_bls_cancel)
 {
   std::unique_ptr<UtilsMessagePayload> utils_msg_payload =
       std::make_unique<UtilsMessagePayload>(
-          PYTHONSTUB_CancelBLSDecoupledInferRequest,
+          PYTHONSTUB_CancelBLSInferRequest,
           reinterpret_cast<void*>(pb_bls_cancel));
   EnqueueUtilsMessage(std::move(utils_msg_payload));
 }

src/pb_stub.h

Lines changed: 2 additions & 2 deletions
@@ -322,12 +322,12 @@ class Stub {
   void EnqueueCleanupId(void* id, const PYTHONSTUB_CommandType& command_type);
 
   /// Send the id to the python backend for object cleanup
-  void SendCancelBLSDecoupledRequest(
+  void SendCancelBLSRequest(
       std::unique_ptr<UtilsMessagePayload>& utils_msg_payload);
 
   /// Add infer payload id to queue. This is used for retrieving the request
   /// address from the infer_payload
-  void EnqueueCancelBLSDecoupledRequest(PbBLSCancel* pb_bls_cancel);
+  void EnqueueCancelBLSRequest(PbBLSCancel* pb_bls_cancel);
 
   /// Add request cancellation query to queue
   void EnqueueIsCancelled(PbCancel* pb_cancel);

src/python_be.cc

Lines changed: 4 additions & 4 deletions
@@ -765,8 +765,8 @@ ModelInstanceState::StubToParentMQMonitor()
         boost::asio::post(*thread_pool_, std::move(task));
         break;
       }
-      case PYTHONSTUB_CancelBLSDecoupledInferRequest: {
-        ProcessCancelBLSDecoupledRequest(message);
+      case PYTHONSTUB_CancelBLSInferRequest: {
+        ProcessCancelBLSRequest(message);
         break;
       }
       default: {
@@ -860,7 +860,7 @@ ModelInstanceState::ProcessCleanupRequest(
 }
 
 void
-ModelInstanceState::ProcessCancelBLSDecoupledRequest(
+ModelInstanceState::ProcessCancelBLSRequest(
     const std::unique_ptr<IPCMessage>& message)
 {
   AllocatedSharedMemory<CancelBLSRequestMessage> message_shm =
@@ -876,7 +876,7 @@ ModelInstanceState::ProcessCancelBLSDecoupledRequest(
   {
     std::lock_guard<std::mutex> lock(infer_payload_mu_);
     if (infer_payload_.find(id) != infer_payload_.end()) {
-      request_executor_->Cancel(infer_payload_[id]);
+      infer_payload_[id]->SafeCancelRequest();
    }
   }
   message_payload->is_cancelled = true;

src/python_be.h

Lines changed: 2 additions & 3 deletions
@@ -403,9 +403,8 @@ class ModelInstanceState : public BackendModelInstance {
   // Process the decoupled cleanup request for InferPayload and ResponseFactory
   void ProcessCleanupRequest(const std::unique_ptr<IPCMessage>& message);
 
-  // Process cancelling a BLS decoupled request
-  void ProcessCancelBLSDecoupledRequest(
-      const std::unique_ptr<IPCMessage>& message);
+  // Process cancelling a BLS request
+  void ProcessCancelBLSRequest(const std::unique_ptr<IPCMessage>& message);
 
   // Process request cancellation query
   void ProcessIsRequestCancelled(const std::unique_ptr<IPCMessage>& message);

src/request_executor.cc

Lines changed: 25 additions & 37 deletions
@@ -69,12 +69,15 @@ InferRequestComplete(
     TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
 {
   if (request != nullptr) {
-    auto request_executor = reinterpret_cast<RequestExecutor*>(userp);
-    request_executor->EraseRequestAddress(reinterpret_cast<intptr_t>(request));
+    RequestCompletionUserp* completion_userp =
+        reinterpret_cast<RequestCompletionUserp*>(userp);
+    completion_userp->infer_payload->SetRequestDeleted();
 
     LOG_IF_ERROR(
         TRITONSERVER_InferenceRequestDelete(request),
         "Failed to delete inference request.");
+
+    delete completion_userp;
   }
 }
 
@@ -322,6 +325,18 @@ ResponseAlloc(
   return nullptr; // Success
 }
 
+void
+InferRequestCancel(intptr_t request_address)
+{
+  if (request_address == 0L) {
+    return;
+  }
+
+  TRITONSERVER_InferenceRequest* irequest =
+      reinterpret_cast<TRITONSERVER_InferenceRequest*>(request_address);
+  THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestCancel(irequest));
+}
+
 TRITONSERVER_Error*
 OutputBufferQuery(
     TRITONSERVER_ResponseAllocator* allocator, void* userp,
@@ -364,6 +379,7 @@ RequestExecutor::Infer(
   bool is_ready = false;
   const char* model_name = infer_request->ModelName().c_str();
   TRITONSERVER_InferenceRequest* irequest = nullptr;
+  RequestCompletionUserp* completion_userp = nullptr;
 
   try {
     int64_t model_version = infer_request->ModelVersion();
@@ -415,8 +431,10 @@ RequestExecutor::Infer(
     THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetTimeoutMicroseconds(
         irequest, infer_request->Timeout()));
 
+    completion_userp = new RequestCompletionUserp(infer_payload);
     THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestSetReleaseCallback(
-        irequest, InferRequestComplete, reinterpret_cast<void*>(this)));
+        irequest, InferRequestComplete,
+        reinterpret_cast<void*>(completion_userp)));
 
     TRITONSERVER_InferenceTrace* trace = nullptr;
     if (infer_request->GetTrace().TritonTrace() != nullptr) {
@@ -485,22 +503,20 @@ RequestExecutor::Infer(
           reinterpret_cast<void*>(infer_payload->ResponseAllocUserp().get()),
           InferResponseComplete, reinterpret_cast<void*>(infer_payload.get())));
 
-      {
-        std::lock_guard<std::mutex> lk(on_going_request_addresses_mu_);
-        on_going_request_addresses_.insert(
-            reinterpret_cast<intptr_t>(irequest));
-      }
      // Store the inference request address submitted to the Triton server for
      // retrieval
      infer_payload->SetRequestAddress(reinterpret_cast<intptr_t>(irequest));
+      infer_payload->SetRequestCancellationFunc(InferRequestCancel);
 
      THROW_IF_TRITON_ERROR(
          TRITONSERVER_ServerInferAsync(server_, irequest, trace));
    }
  }
  catch (const PythonBackendException& pb_exception) {
-    EraseRequestAddress(reinterpret_cast<intptr_t>(irequest));
    infer_payload->SetRequestAddress(0L);
+    if (completion_userp != nullptr) {
+      delete completion_userp;
+    }
 
    LOG_IF_ERROR(
        TRITONSERVER_InferenceRequestDelete(irequest),
@@ -514,34 +530,6 @@ RequestExecutor::Infer(
   return response_future;
 }
 
-void
-RequestExecutor::Cancel(std::shared_ptr<InferPayload>& infer_payload)
-{
-  intptr_t request_address = infer_payload->GetRequestAddress();
-  if (request_address == 0L) {
-    return;
-  }
-
-  {
-    std::lock_guard<std::mutex> lk(on_going_request_addresses_mu_);
-    if (on_going_request_addresses_.find(request_address) !=
-        on_going_request_addresses_.end()) {
-      TRITONSERVER_InferenceRequest* irequest =
-          reinterpret_cast<TRITONSERVER_InferenceRequest*>(request_address);
-      THROW_IF_TRITON_ERROR(TRITONSERVER_InferenceRequestCancel(irequest));
-    }
-  }
-}
-
-void
-RequestExecutor::EraseRequestAddress(intptr_t request_address)
-{
-  if (request_address != 0L) {
-    std::unique_lock<std::mutex> lk(on_going_request_addresses_mu_);
-    on_going_request_addresses_.erase(request_address);
-  }
-}
-
 RequestExecutor::~RequestExecutor()
 {
   if (response_allocator_ != nullptr) {
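
This change removes the executor-wide `on_going_request_addresses_` set: each request's release callback now receives a heap-allocated `RequestCompletionUserp` that owns a `shared_ptr` to its `InferPayload`, marks the payload deleted, and is then freed (the error path in `Infer()` deletes it as well). Below is a compact sketch of that userp lifetime pattern against a mocked C-style completion callback; `CompletionUserp`, `MockRequestComplete`, and `run_request` are illustrative names, not Triton APIs.

```cpp
#include <iostream>
#include <memory>

struct Payload {
  bool deleted = false;
  void SetRequestDeleted() { deleted = true; }
};

// Mirrors RequestCompletionUserp: a small heap object that keeps the payload
// alive for the duration of the request via a shared_ptr.
struct CompletionUserp {
  std::shared_ptr<Payload> payload;
  explicit CompletionUserp(const std::shared_ptr<Payload>& p) : payload(p) {}
};

// Stand-in for a C-style release callback that only receives a void* userp.
void MockRequestComplete(void* userp)
{
  auto* completion = static_cast<CompletionUserp*>(userp);
  completion->payload->SetRequestDeleted();  // mark before freeing the request
  delete completion;                         // the userp is freed exactly once
}

void run_request(const std::shared_ptr<Payload>& payload)
{
  auto* userp = new CompletionUserp(payload);
  // ... the request would be submitted here; on completion (or on an error
  // path, as in Infer()) whoever owns the userp last must delete it.
  MockRequestComplete(userp);
}

int main()
{
  auto payload = std::make_shared<Payload>();
  run_request(payload);
  std::cout << "deleted = " << std::boolalpha << payload->deleted << "\n";
  return 0;
}
```

Because the userp holds its own `shared_ptr`, the payload remains valid inside the callback even if every other owner has already dropped its reference.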

src/request_executor.h

Lines changed: 6 additions & 4 deletions
@@ -38,19 +38,21 @@ namespace triton { namespace backend { namespace python {
 TRITONSERVER_Error* CreateTritonErrorFromException(
     const PythonBackendException& pb_exception);
 
+struct RequestCompletionUserp {
+  std::shared_ptr<InferPayload> infer_payload;
+  RequestCompletionUserp(std::shared_ptr<InferPayload>& infer_payload)
+      : infer_payload(infer_payload){};
+};
+
 class RequestExecutor {
   TRITONSERVER_ResponseAllocator* response_allocator_ = nullptr;
   TRITONSERVER_Server* server_;
   std::unique_ptr<SharedMemoryManager>& shm_pool_;
-  std::mutex on_going_request_addresses_mu_;
-  std::unordered_set<intptr_t> on_going_request_addresses_;
 
  public:
   std::future<std::unique_ptr<InferResponse>> Infer(
       std::shared_ptr<InferRequest>& infer_request,
       std::shared_ptr<InferPayload>& infer_payload);
-  void EraseRequestAddress(intptr_t request_address);
-  void Cancel(std::shared_ptr<InferPayload>& infer_payload);
 
   RequestExecutor(
       std::unique_ptr<SharedMemoryManager>& shm_pool,
