Description
Hi all, I'm planning to serve a VAD model (silero-vad), which is stateful by nature. Since my service handles multiple concurrent user streams, I need to keep per-session state (i.e., each session's state must be isolated, and requests from the same stream must go to the same model instance). AFAIK, stateful models need sequence batching, which routes each request to the appropriate model instance, so my config.pbtxt is as below. But regardless of my headers, the client gets this error:

`tritonclient.utils.InferenceServerException: [StatusCode.INVALID_ARGUMENT] inference request to model 'vad_model' must specify a non-zero or non-empty correlation ID`

How do I pass the control inputs to the server? My Docker image is `nvcr.io/nvidia/tritonserver:23.10-py3`. I also tried passing them as `grpcclient.InferInput` tensors, but that did not work either.
- config.pbtxt
name: "vad_model"
backend: "python"
parameters: [
{
key: "model_dir"
value: {
string_value: "/workspace/models/vad_model/1/vad_model"
}
}
]
input [
{
name: "AUDIO_CHUNK"
data_type: TYPE_STRING
dims: [ 1 ]
}
]
output [
{
name: "EVENT"
data_type: TYPE_STRING
dims: [ 1 ]
},
{
name: "TIMESTAMP"
data_type: TYPE_FP32
dims: [ 1 ]
}
]
sequence_batching {
max_sequence_idle_microseconds: 1000000
control_input [
{
name: "START"
control [
{
kind: CONTROL_SEQUENCE_START
fp32_false_true: [ 0, 1 ]
}
]
},
{
name: "END"
control [
{
kind: CONTROL_SEQUENCE_END
fp32_false_true: [ 0, 1 ]
}
]
},
{
name: "CORRID"
control [
{
kind: CONTROL_SEQUENCE_CORRID
data_type: TYPE_STRING
}
]
}
]
}
instance_group: [
{
kind: KIND_CPU,
count: 8
}
]
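For context, my understanding of the server side: with the Python backend, the control inputs declared under `sequence_batching` arrive as ordinary input tensors on every request, so `model.py` can key per-stream state off `CORRID`. A rough sketch of that pattern (not my actual model code; the VAD call is elided and the state/outputs are placeholders):

```python
import numpy as np
import triton_python_backend_utils as pb_utils


class TritonPythonModel:
    def initialize(self, args):
        # One state object per live sequence, keyed by correlation ID.
        self.states = {}

    def execute(self, requests):
        responses = []
        for request in requests:
            # Control inputs declared in sequence_batching are delivered
            # as regular input tensors alongside AUDIO_CHUNK.
            start = pb_utils.get_input_tensor_by_name(request, "START").as_numpy()[0]
            end = pb_utils.get_input_tensor_by_name(request, "END").as_numpy()[0]
            corrid = pb_utils.get_input_tensor_by_name(request, "CORRID").as_numpy()[0]
            if isinstance(corrid, bytes):
                corrid = corrid.decode("utf-8")

            if start:
                self.states[corrid] = {}  # fresh per-stream VAD state (placeholder)
            state = self.states[corrid]

            # ... run silero-vad on the decoded chunk using `state` ...
            event = np.array(["no_event"], dtype=np.object_)  # placeholder output
            timestamp = np.array([0.0], dtype=np.float32)     # placeholder output

            if end:
                del self.states[corrid]

            responses.append(pb_utils.InferenceResponse(output_tensors=[
                pb_utils.Tensor("EVENT", event),
                pb_utils.Tensor("TIMESTAMP", timestamp),
            ]))
        return responses
```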
- client code
```python
import base64

import numpy as np
import tritonclient.grpc as grpcclient

triton_client = grpcclient.InferenceServerClient(url="tritonserver:8001")

# process_chunk holds the raw audio bytes for this request.
base64_process_chunk = base64.b64encode(process_chunk).decode("utf-8")
input_audio = grpcclient.InferInput("AUDIO_CHUNK", [1], "BYTES")
input_audio.set_data_from_numpy(np.array([base64_process_chunk], dtype=np.object_))

output_event = grpcclient.InferRequestedOutput("EVENT")
output_timestamp = grpcclient.InferRequestedOutput("TIMESTAMP")

# My attempt at passing the correlation ID; the server seems to ignore it.
headers = {"correlation_id": "123"}

vad_result = triton_client.infer(
    model_name="vad_model",
    model_version="1",
    inputs=[input_audio],
    outputs=[output_event, output_timestamp],
    headers=headers,
)
```
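From the `tritonclient` API reference, it looks like the sequence controls are meant to be passed as keyword arguments to `infer()` rather than as HTTP headers or extra `InferInput` tensors. A minimal sketch of what I'd expect to work, assuming `sequence_id` accepts a string here (since `CORRID` is declared as `TYPE_STRING`):

```python
# base64_process_chunk and input_audio prepared exactly as above.
# sequence_id maps to the CORRID control input; sequence_start/sequence_end
# drive the START/END controls declared in sequence_batching.
vad_result = triton_client.infer(
    model_name="vad_model",
    model_version="1",
    inputs=[input_audio],
    outputs=[grpcclient.InferRequestedOutput("EVENT"),
             grpcclient.InferRequestedOutput("TIMESTAMP")],
    sequence_id="123",     # non-empty correlation ID for this stream
    sequence_start=True,   # True only on the first request of the stream
    sequence_end=False,    # True only on the last request of the stream
)
```

If that's the intended mechanism, `sequence_start=True` would be sent only on the first chunk of a stream and `sequence_end=True` on the last, with the same `sequence_id` reused for every request in between. Is this the right way to pass the control inputs?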