Description
There is a VRAM leak when sending requests to a Triton server that hosts two models: an ONNX model and a Python backend model that makes a BLS request to the ONNX model. The leak occurs when sending tensors with a specific sequence of batch sizes. The issue does not occur when the same model is served with the TensorRT backend, and the leak only appears on a GPU slice (e.g. a MIG partition), not on a full GPU.
Triton Information
- Triton container: nvcr.io/nvidia/tritonserver:25.02-py3
- Backends: Python and ONNX Runtime
To Reproduce
- Get a GPU slice, e.g. a 20GB MIG slice of an A100.
- Set up a Triton server with two models:
  - A Python backend model that makes BLS requests to the ONNX model.
  - An ONNX model.
- Send tensor requests to the Python model:
  - Batch sizes such as [16, 7, 7, 7, ...] cause the memory leak.
  - Batch sizes like [16, 6, 6, 6, ...] do not cause any memory leak.
- Monitor the increase in VRAM usage in nvitop or a similar tool (see the polling sketch after request.py below).
Here is the model structure:
.
└── models
    ├── bls
    │   ├── 1
    │   │   └── model.py
    │   └── config.pbtxt
    └── onnx
        ├── 1
        └── config.pbtxt
Code for testing (request.py):
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype
import numpy as np

triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

batch_sizes = np.array([16, 7, 7, 7, 7, 7, 7], dtype=np.int32)

for i in range(5):
    input_tensors = [
        grpcclient.InferInput("input", batch_sizes.shape, np_to_triton_dtype(batch_sizes.dtype)),
    ]
    input_tensors[0].set_data_from_numpy(batch_sizes)
    response = triton_client.infer("bls", inputs=input_tensors)
    output = response.as_numpy("output")
    print(output.shape)
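As an alternative to watching nvitop by hand, the slice's VRAM can be polled from a separate process while request.py runs. This is a minimal sketch, assuming the nvidia-ml-py (pynvml) package is installed and that the slice is visible as device index 0; on a MIG setup the MIG device handle may need to be fetched instead, and the 5-second interval is arbitrary.
import time
import pynvml

pynvml.nvmlInit()
# Assumption: the 20GB slice is exposed as device index 0.
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

def used_mib():
    # Currently used device memory in MiB.
    return pynvml.nvmlDeviceGetMemoryInfo(handle).used / (1024 ** 2)

baseline = used_mib()
print(f"baseline: {baseline:.0f} MiB")
try:
    while True:
        time.sleep(5)
        current = used_mib()
        print(f"used: {current:.0f} MiB (delta {current - baseline:+.0f} MiB)")
except KeyboardInterrupt:
    pynvml.nvmlShutdown()
Run alongside request.py, the delta should stay flat for the [16, 6, 6, 6, ...] pattern and keep growing for the [16, 7, 7, 7, ...] pattern.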
bls/config.pbtxt:
backend: "python"
max_batch_size: 0
input [
{
name: "input"
data_type: TYPE_INT32
dims: [ -1 ]
}
]
output [
{
name: "output"
data_type: TYPE_FP32
dims: [ 1 ]
}
]
onnx/config.pbtxt:
platform: "onnxruntime_onnx"
max_batch_size: 16
instance_group [
{
kind: KIND_GPU
gpus: [ 0 ]
}
]
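Since the ONNX config above appears to rely on Triton's auto-complete for the inputs and outputs, it can be worth confirming that both models are loaded and inspecting the resolved configuration before running request.py. A small check along these lines, using the same gRPC client as request.py (not specific to the leak itself):
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

# Both models should report ready before sending test traffic.
for name in ("bls", "onnx"):
    print(name, "ready:", client.is_model_ready(name))

# Show the configuration Triton resolved for the ONNX model (inputs, outputs, dims).
print(client.get_model_config("onnx", as_json=True))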
bls/1/model.py:
import numpy as np
import triton_python_backend_utils as pb_utils
class TritonPythonModel:
    async def execute(self, requests):
        logger = pb_utils.Logger
        responses = []
        for request in requests:
            batch_sizes = pb_utils.get_input_tensor_by_name(request, "input").as_numpy().tolist()
            logger.log_info(str(batch_sizes))
            for n in batch_sizes:
                # Random [n, 240000] FP32 input for each requested batch size.
                input_x_np = np.random.rand(n, 240000).astype(np.float32)
                input_xlen_np = np.array([1] * n, dtype=np.float32).reshape(n, 1)
                input_x_pb = pb_utils.Tensor("x", input_x_np)
                input_xlen_pb = pb_utils.Tensor("xlen", input_xlen_np)
                # BLS request to the ONNX model, asking for the output in CPU memory.
                infer_request = pb_utils.InferenceRequest(
                    model_name='onnx',
                    requested_output_names=["logits"],
                    inputs=[input_x_pb, input_xlen_pb],
                    preferred_memory=pb_utils.PreferredMemory(pb_utils.TRITONSERVER_MEMORY_CPU)
                )
                resp = infer_request.exec()
                if resp.has_error():
                    raise pb_utils.TritonModelException(resp.error().message())
                output1 = pb_utils.get_output_tensor_by_name(resp, "logits").as_numpy()
                logger.log_info(str(output1.shape))
            output = pb_utils.Tensor("output", np.array([1]).astype(np.float32))
            responses.append(pb_utils.InferenceResponse(output_tensors=[output]))
        return responses
Link to download model.onnx.
I was not able to reproduce the issue with other models.
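For a server-side view of the same growth, Triton's Prometheus metrics endpoint can be polled instead of (or alongside) nvitop. This is a sketch assuming the default metrics port 8002 and that the nv_gpu_memory_used_bytes gauge is exported for the slice; if GPU memory metrics are not available on this MIG setup, the pynvml approach above still applies.
import time
import urllib.request

METRICS_URL = "http://localhost:8002/metrics"  # default Triton metrics endpoint

def gpu_memory_used_lines():
    # Keep only the GPU memory usage gauge lines from the Prometheus text output.
    text = urllib.request.urlopen(METRICS_URL).read().decode()
    return [line for line in text.splitlines() if line.startswith("nv_gpu_memory_used_bytes")]

while True:
    for line in gpu_memory_used_lines():
        print(line)
    time.sleep(5)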
Expected behavior
There should be no VRAM leak: GPU memory usage should remain stable across repeated requests, regardless of the batch-size sequence.