
How do I pass control input to python-stateful backend? #824

Open
@lionsheep0724

Description


Hi all, I'm planning to serve a VAD model (silero-vad), which is stateful by nature. Since my service handles multiple concurrent user streams, I need to keep state per session (i.e. each session's state must be isolated, and requests from the same stream must be routed to the same model instance). AFAIK, stateful models require sequence batching so the scheduler can bind each request to the appropriate model instance, so my config.pbtxt is as below. However, the client raises tritonclient.utils.InferenceServerException: [StatusCode.INVALID_ARGUMENT] inference request to model 'vad_model' must specify a non-zero or non-empty correlation ID, regardless of my headers.
How do I pass the control inputs to the server? My Docker image is nvcr.io/nvidia/tritonserver:23.10-py3.
I also tried passing them as grpcclient.InferInput inputs, but that did not work either.

  1. config.pbtxt
name: "vad_model"
backend: "python"

parameters: [
  {
  key: "model_dir"
  value: {
    string_value: "/workspace/models/vad_model/1/vad_model"
  }
}
]

input [
  {
    name: "AUDIO_CHUNK" 
    data_type: TYPE_STRING
    dims: [ 1 ]
  }
]

output [
  {
    name: "EVENT"
    data_type: TYPE_STRING
    dims: [ 1 ]
  },
  {
    name: "TIMESTAMP"
    data_type: TYPE_FP32
    dims: [ 1 ]
  }
]

sequence_batching {
  max_sequence_idle_microseconds: 1000000
  control_input [
    {
      name: "START"
      control [
        {
          kind: CONTROL_SEQUENCE_START
          fp32_false_true: [ 0, 1 ]
        }
      ]
    },
    {
      name: "END"
      control [
        {
          kind: CONTROL_SEQUENCE_END
          fp32_false_true: [ 0, 1 ]
        }
      ]
    },
    {
      name: "CORRID"
      control [
        {
          kind: CONTROL_SEQUENCE_CORRID
          data_type: TYPE_STRING
        }
      ]
    }
  ]
}

instance_group: [
  {
    kind: KIND_CPU,
    count: 8
  }
]
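For context, here is a minimal, framework-free sketch of the per-session bookkeeping the sequence batcher enables. The names (SessionStateStore, step) are illustrative, not part of Triton's API; inside a real Python-backend model.py, the START, END, and CORRID values arrive as ordinary input tensors readable with pb_utils.get_input_tensor_by_name(request, "START") and so on, alongside AUDIO_CHUNK.

```python
class SessionStateStore:
    """Illustrative per-stream state keyed by correlation ID (not Triton API)."""

    def __init__(self):
        self._states = {}  # correlation ID -> model state for that stream

    def step(self, corrid, start, end, chunk):
        """Handle one request: reset on START, drop state on END."""
        if start:
            self._states[corrid] = []          # fresh state for a new stream
        state = self._states.setdefault(corrid, [])
        state.append(chunk)                    # stand-in for running the model
        result = len(state)                    # stand-in for the model output
        if end:
            del self._states[corrid]           # release state when stream ends
        return result
```

Because the sequence batcher (with direct scheduling) pins each correlation ID to one model instance for the lifetime of the sequence, state kept like this is never shared across concurrent streams.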
  2. client code
import base64

import numpy as np
import tritonclient.grpc as grpcclient

triton_client = grpcclient.InferenceServerClient(url="tritonserver:8001")

# process_chunk holds the raw audio bytes for this request
base64_process_chunk = base64.b64encode(process_chunk).decode("utf-8")
input_audio = grpcclient.InferInput("AUDIO_CHUNK", [1], "BYTES")
input_audio.set_data_from_numpy(np.array([base64_process_chunk], dtype=np.object_))

output_event = grpcclient.InferRequestedOutput("EVENT")
output_timestamp = grpcclient.InferRequestedOutput("TIMESTAMP")
headers = {"correlation_id": "123"}
vad_result = triton_client.infer(
    model_name="vad_model",
    inputs=[input_audio],
    outputs=[output_event, output_timestamp],
    model_version="1",
    headers=headers,
)
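For what it's worth, my understanding is that the correlation ID is not read from gRPC headers at all: tritonclient.grpc's InferenceServerClient.infer() takes dedicated sequence_id, sequence_start, and sequence_end keyword arguments that the server maps onto the CORRID/START/END control inputs. A minimal sketch (build_sequence_kwargs is an illustrative helper of my own, not a Triton API):

```python
def build_sequence_kwargs(corrid, first_chunk, last_chunk):
    """Per-request sequence parameters to splat into triton_client.infer()."""
    return {
        "sequence_id": corrid,          # feeds the CORRID control input
        "sequence_start": first_chunk,  # feeds the START control input
        "sequence_end": last_chunk,     # feeds the END control input
    }
```

Usage would look like triton_client.infer(model_name="vad_model", inputs=[input_audio], outputs=[output_event, output_timestamp], **build_sequence_kwargs("123", True, False)). Note that whether sequence_id may be a string (to match a TYPE_STRING CORRID control) depends on the client/server version; older clients only accept non-zero integers.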
