
GPU VRAM Leak with Python Backend BLS Requests to ORT Backend #301

Open
@WoodieDudy

Description

There is a VRAM leak when sending requests to a Triton server that hosts two models: an ONNX model and a Python backend model that makes a BLS request to the ONNX model. The leak occurs when the requests use a specific sequence of batch sizes. The issue does not occur when using the TensorRT backend with the same model, and the leak only happens on a GPU slice.

Triton Information

  • Triton container: nvcr.io/nvidia/tritonserver:25.02-py3
  • Backend: Python and ONNX Runtime

To Reproduce

  1. Get a GPU slice, e.g. an A100 20GB slice.

  2. Set up a Triton server with two models:

    • A Python backend model that makes BLS requests to the ONNX model.
    • An ONNX model.
  3. Send tensor requests to the Python model.

    • Batch sizes such as [16, 7, 7, 7, ...] cause a memory leak.
    • Batch sizes like [16, 6, 6, 6, ...] do not cause a memory leak.
  4. Monitor the increase in VRAM usage with nvitop or a similar tool (a minimal NVML polling sketch is included after this list).

  5. Here is the model repository structure:

    .
    └── models
        ├── bls
        │   ├── 1
        │   │   └── model.py
        │   └── config.pbtxt
        └── onnx
            ├── 1
            └── config.pbtxt
    
  6. Code for testing (request.py):

    import tritonclient.grpc as grpcclient
    from tritonclient.utils import np_to_triton_dtype
    import numpy as np
    
    triton_client = grpcclient.InferenceServerClient(url="localhost:8001")
    
    batch_sizes = np.array([16, 7, 7, 7, 7, 7, 7], dtype=np.int32)
    
    for i in range(5):
        input_tensors = [
            grpcclient.InferInput("input", batch_sizes.shape, np_to_triton_dtype(batch_sizes.dtype)),
        ]
        input_tensors[0].set_data_from_numpy(batch_sizes)
    
        response = triton_client.infer("bls", inputs=input_tensors)
    
        output = response.as_numpy("output")
        print(output.shape)

bls/config.pbtxt:

backend: "python"
max_batch_size: 0

input [
  {
    name: "input"
    data_type: TYPE_INT32
    dims: [ -1 ]
  }
]
output [
  {
    name: "output"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]

onnx/config.pbtxt:

platform: "onnxruntime_onnx"
max_batch_size: 16

instance_group [
  {
    kind: KIND_GPU
    gpus: [ 0 ]
  }
]

bls/1/model.py:

import numpy as np
import triton_python_backend_utils as pb_utils

class TritonPythonModel:            
    async def execute(self, requests):
        logger = pb_utils.Logger
        responses = []
        for request in requests:

            batch_sizes = pb_utils.get_input_tensor_by_name(request, "input").as_numpy().tolist()
            logger.log_info(str(batch_sizes))

            for n in batch_sizes:
                input_x_np = np.random.rand(n, 240000).astype(np.float32)
                input_xlen_np = np.array([1] * n, dtype=np.float32).reshape(n, 1)

                input_x_pb = pb_utils.Tensor("x", input_x_np)
                input_xlen_pb = pb_utils.Tensor("xlen", input_xlen_np)

                infer_request = pb_utils.InferenceRequest(
                    model_name='onnx',
                    requested_output_names=["logits"],
                    inputs=[input_x_pb, input_xlen_pb],
                    preferred_memory=pb_utils.PreferredMemory(pb_utils.TRITONSERVER_MEMORY_CPU)
                )
                resp = infer_request.exec()
                if resp.has_error():
                    raise pb_utils.TritonModelException(resp.error().message())
                output1 = pb_utils.get_output_tensor_by_name(resp, "logits").as_numpy()
                logger.log_info(str(output1.shape))

            output = pb_utils.Tensor("output", np.array([1]).astype(np.float32))
            responses.append(pb_utils.InferenceResponse(output_tensors = [output]))          

        return responses
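
Not part of the original reproduction, but a hedged diagnostic sketch: the snippet below sends the same per-item requests directly to the onnx model over gRPC, using the input names, shapes, and the "logits" output taken from model.py above. It could help check whether the BLS path is required for the leak to appear.

import numpy as np
import tritonclient.grpc as grpcclient
from tritonclient.utils import np_to_triton_dtype

triton_client = grpcclient.InferenceServerClient(url="localhost:8001")

batch_sizes = [16, 7, 7, 7, 7, 7, 7]

for _ in range(5):
    for n in batch_sizes:
        # Same tensor shapes that bls/1/model.py builds before calling the onnx model
        x = np.random.rand(n, 240000).astype(np.float32)
        xlen = np.ones((n, 1), dtype=np.float32)

        inputs = [
            grpcclient.InferInput("x", x.shape, np_to_triton_dtype(x.dtype)),
            grpcclient.InferInput("xlen", xlen.shape, np_to_triton_dtype(xlen.dtype)),
        ]
        inputs[0].set_data_from_numpy(x)
        inputs[1].set_data_from_numpy(xlen)

        response = triton_client.infer("onnx", inputs=inputs)
        print(response.as_numpy("logits").shape)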

Link to download model.onnx.
I was not able to reproduce the issue with other models.

Expected behavior
VRAM usage should stay stable across repeated requests; there should be no memory leak.
