diff --git a/demos/common/export_models/export_model.py b/demos/common/export_models/export_model.py index 5aa81b0c81..1c0e2e87b3 100644 --- a/demos/common/export_models/export_model.py +++ b/demos/common/export_models/export_model.py @@ -134,8 +134,14 @@ def add_common_arguments(parser): name: "S2tExecutor" input_side_packet: "STT_NODE_RESOURCES:s2t_servable" calculator: "S2tCalculator" + input_stream: "LOOPBACK:loopback" input_stream: "HTTP_REQUEST_PAYLOAD:input" + output_stream: "LOOPBACK:loopback" output_stream: "HTTP_RESPONSE_PAYLOAD:output" + input_stream_info: { + tag_index: 'LOOPBACK:0', + back_edge: true + } node_options: { [type.googleapis.com / mediapipe.S2tCalculatorOptions]: { models_path: "{{model_path}}", @@ -144,6 +150,16 @@ def add_common_arguments(parser): enable_word_timestamps: {% if not enable_word_timestamps %}false{% else %}true{% endif%}, } } + input_stream_handler { + input_stream_handler: "SyncSetInputStreamHandler", + options { + [mediapipe.SyncSetInputStreamHandlerOptions.ext] { + sync_set { + tag_index: "LOOPBACK:0" + } + } + } + } } """ diff --git a/src/BUILD b/src/BUILD index 5cbf3b01e6..817bdf4388 100644 --- a/src/BUILD +++ b/src/BUILD @@ -52,6 +52,16 @@ cc_library( copts = COMMON_STATIC_LIBS_COPTS, ) +ovms_cc_library( + name = "executor_base", + hdrs = ["executor_base.hpp"], + deps = [ + "//src:libovmslogging", + ], + visibility = ["//visibility:public"], + alwayslink = 1, +) + cc_shared_library( name = "ovms_shared", dynamic_deps = [], @@ -1231,6 +1241,12 @@ ovms_cc_library( ) +ovms_cc_library( + name = "sse_utils", + hdrs = ["sse_utils.hpp"], + visibility = ["//visibility:public"], +) + ovms_cc_library( name = "libovmsstatus", hdrs = ["status.hpp",], @@ -2243,6 +2259,7 @@ cc_test( "test/llm/visual_language_model/initialization_test.cpp", "test/audio/text2speech_test.cpp", "test/audio/speech2text_test.cpp", + "test/audio/s2t_streaming_test.cpp", ], "//:disable_mediapipe" : [ "test/disabled_mediapipe_test.cpp", diff --git a/src/audio/speech_to_text/BUILD b/src/audio/speech_to_text/BUILD index ca8410e31a..770d9c5a39 100644 --- a/src/audio/speech_to_text/BUILD +++ b/src/audio/speech_to_text/BUILD @@ -19,7 +19,50 @@ load("//:common_settings.bzl", "ovms_cc_library") ovms_cc_library( name = "s2t_servable", - hdrs = ["s2t_servable.hpp"], + srcs = ["s2t_servable.cpp"], + hdrs = [ + "s2t_executor.hpp", + "s2t_servable.hpp", + ], + deps = [ + "//src:executor_base", + "//src:httppayload", + "//src:libovmslogging", + "//src:libmodelconfigjsonparser", + "//src:libovmsstring_utils", + "s2t_calculator_cc_proto", + "@com_google_absl//absl/status", + "//third_party:genai", + ], + visibility = ["//visibility:public"], + alwayslink = 1, +) + +ovms_cc_library( + name = "s2t_streaming_handler", + srcs = [ + "s2t_streaming_handler.cpp", + "streaming_text_queue.cpp", + ], + hdrs = [ + "s2t_streaming_handler.hpp", + "streaming_text_queue.hpp", + ], + deps = [ + "@mediapipe//mediapipe/framework:calculator_framework", + "//src:httppayload", + "//src:libovmsclient_connection", + "//src:libovmslogging", + "//src:libovmsstring_utils", + "//src:libovmsstatus", + "//src:libmodelconfigjsonparser", + "//src:sse_utils", + "//src/port:rapidjson_stringbuffer", + "//src/port:rapidjson_writer", + ":s2t_servable", + "s2t_calculator_cc_proto", + "//third_party:genai", + ], visibility = ["//visibility:public"], alwayslink = 1, ) @@ -37,11 +80,11 @@ ovms_cc_library( "//src/port:rapidjson_stringbuffer", "//src/port:rapidjson_writer", ":s2t_servable", + ":s2t_streaming_handler", 
"//third_party:genai", "//src/audio:audio_utils", "//src:libmodelconfigjsonparser", "//src/mediapipe_internal:node_initializer", - "//src:libovmsstring_utils", ], visibility = ["//visibility:public"], alwayslink = 1, diff --git a/src/audio/speech_to_text/s2t_calculator.cc b/src/audio/speech_to_text/s2t_calculator.cc index be1debd83f..fe8a0b2ff3 100644 --- a/src/audio/speech_to_text/s2t_calculator.cc +++ b/src/audio/speech_to_text/s2t_calculator.cc @@ -13,7 +13,8 @@ // See the License for the specific language governing permissions and // limitations under the License. //***************************************************************************** -#include +#include +#include #pragma warning(push) #pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 6246 4456 6246) @@ -25,26 +26,19 @@ #pragma warning(pop) #include "src/audio/audio_utils.hpp" +#include "src/client_connection.hpp" #include "src/http_payload.hpp" #include "src/logging.hpp" -#include "src/stringutils.hpp" -#include -#include #pragma warning(push) #pragma warning(disable : 6001 4324 6385 6386) -#include "absl/strings/escaping.h" #include "absl/strings/str_cat.h" #pragma warning(pop) #include "src/port/rapidjson_writer.hpp" #include "src/port/rapidjson_stringbuffer.hpp" #include "s2t_servable.hpp" - -#ifdef _WIN32 -#include -#include -#endif +#include "s2t_streaming_handler.hpp" using namespace ovms; @@ -58,7 +52,7 @@ enum Endpoint { UNSUPPORTED }; -Endpoint getEndpoint(const std::string& url) { +static Endpoint getEndpoint(const std::string& url) { if (absl::StartsWith(url, "/v3/audio/transcriptions")) { return Endpoint::TRANSCRIPTIONS; } @@ -68,19 +62,35 @@ Endpoint getEndpoint(const std::string& url) { return Endpoint::UNSUPPORTED; } -size_t ISO_LANG_CODE_MAX = 3; +static absl::Status checkClientDisconnected(const ovms::HttpPayload& payload, const std::string& nodeName, const char* context) { + if (payload.client && payload.client->isDisconnected()) { + SPDLOG_LOGGER_DEBUG(s2t_calculator_logger, "Client disconnected {} [Node: {}]", context, nodeName); + return absl::CancelledError("Client disconnected"); + } + return absl::OkStatus(); +} class S2tCalculator : public CalculatorBase { static const std::string INPUT_TAG_NAME; static const std::string OUTPUT_TAG_NAME; + static const std::string LOOPBACK_TAG_NAME; + + bool hasLoopback_ = false; + S2tStreamingHandler streamingHandler_; public: static absl::Status GetContract(CalculatorContract* cc) { RET_CHECK(!cc->Inputs().GetTags().empty()); RET_CHECK(!cc->Outputs().GetTags().empty()); cc->Inputs().Tag(INPUT_TAG_NAME).Set(); - cc->InputSidePackets().Tag(STT_SESSION_SIDE_PACKET_TAG).Set(); // TODO: template? 
+        if (cc->Inputs().HasTag(LOOPBACK_TAG_NAME)) {
+            cc->Inputs().Tag(LOOPBACK_TAG_NAME).Set<bool>();
+        }
+        cc->InputSidePackets().Tag(STT_SESSION_SIDE_PACKET_TAG).Set<SttServableMap>();
        cc->Outputs().Tag(OUTPUT_TAG_NAME).Set<std::string>();
+        if (cc->Outputs().HasTag(LOOPBACK_TAG_NAME)) {
+            cc->Outputs().Tag(LOOPBACK_TAG_NAME).Set<bool>();
+        }
        return absl::OkStatus();
    }

@@ -91,12 +101,44 @@ class S2tCalculator : public CalculatorBase {

    absl::Status Open(CalculatorContext* cc) final {
        SPDLOG_LOGGER_DEBUG(s2t_calculator_logger, "SpeechToTextCalculator [Node: {}] Open start", cc->NodeName());
+        hasLoopback_ = cc->Inputs().HasTag(LOOPBACK_TAG_NAME);
        return absl::OkStatus();
    }

    absl::Status Process(CalculatorContext* cc) final {
        SPDLOG_LOGGER_DEBUG(s2t_calculator_logger, "SpeechToTextCalculator [Node: {}] Process start", cc->NodeName());
+        bool loopbackEmpty = !hasLoopback_ || cc->Inputs().Tag(LOOPBACK_TAG_NAME).IsEmpty();
+        if (cc->Inputs().Tag(INPUT_TAG_NAME).IsEmpty() && loopbackEmpty) {
+            return absl::OkStatus();
+        }
+
+        // --- LOOPBACK iteration: drain streaming queue ---
+        if (!loopbackEmpty) {
+            std::string ssePayload;
+            bool shouldContinueLoopback = false;
+            bool hasOutput = false;
+            auto status = streamingHandler_.processIteration(ssePayload, shouldContinueLoopback, hasOutput);
+            if (status != absl::OkStatus()) {
+                return status;
+            }
+
+            if (hasOutput) {
+                cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(new std::string{std::move(ssePayload)}, cc->InputTimestamp());
+            }
+            if (shouldContinueLoopback) {
+                auto now = std::chrono::system_clock::now();
+                const auto nextTimestamp = ::mediapipe::Timestamp(std::chrono::duration_cast<std::chrono::microseconds>(now.time_since_epoch()).count());
+                cc->Outputs().Tag(LOOPBACK_TAG_NAME).Add(new bool{true}, nextTimestamp);
+            }
+            return absl::OkStatus();
+        }
+
+        // --- First iteration: new request ---
+        if (cc->Inputs().Tag(INPUT_TAG_NAME).IsEmpty()) {
+            return absl::OkStatus();
+        }
+
        SttServableMap pipelinesMap = cc->InputSidePackets().Tag(STT_SESSION_SIDE_PACKET_TAG).Get<SttServableMap>();
        auto it = pipelinesMap.find(cc->NodeName());
        RET_CHECK(it != pipelinesMap.end()) << "Could not find initialized STT node named: " << cc->NodeName();
@@ -111,10 +153,15 @@ class S2tCalculator : public CalculatorBase {
        if (payload.multipartParser->hasParseError())
            return absl::InvalidArgumentError("Failed to parse multipart data");

-        std::string stream = payload.multipartParser->getFieldByName("stream");
-        if (!stream.empty()) {
-            return absl::InvalidArgumentError("streaming is not supported");
+        std::string streamField = payload.multipartParser->getFieldByName("stream");
+        bool requestStreaming = (streamField == "true") && hasLoopback_;
+        if (streamField == "true" && endpoint == Endpoint::TRANSLATIONS) {
+            return absl::InvalidArgumentError("streaming is not supported for translations endpoint");
+        }
+        if (streamField == "true" && !hasLoopback_) {
+            return absl::InvalidArgumentError("streaming is not supported for this graph configuration (LOOPBACK not configured)");
        }
+
        std::string_view file = payload.multipartParser->getFileContentByFieldName("file");
        if (file.empty()) {
            return absl::InvalidArgumentError(absl::StrCat("File parsing fails"));
@@ -132,100 +179,22 @@
        } catch (std::exception&) {
            return absl::InvalidArgumentError("Received input file is not valid wav nor mp3 audio file");
        }
-        rapidjson::StringBuffer buffer;
-        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
-        writer.StartObject();
-        writer.String("text");
-        if (endpoint == Endpoint::TRANSCRIPTIONS) {
+
+        if (requestStreaming) {
            ov::genai::WhisperGenerationConfig config = pipe->sttPipeline->get_generation_config();
-            std::string language = payload.multipartParser->getFieldByName("language");
-            if (language.size() > 0) {
-                if (language.size() > ISO_LANG_CODE_MAX) {
-                    return absl::InvalidArgumentError("Invalid language code.");
-                }
-                SPDLOG_LOGGER_TRACE(s2t_calculator_logger, "Received language: {}");
-                config.language = "<|" + language + "|>";
+            auto status = ovms::SttServable::updateTranscriptionConfig(config, pipe, payload);
+            if (status != absl::OkStatus()) {
+                return status;
            }
-            std::vector<std::string> timestampsTypes = payload.multipartParser->getArrayFieldByName("timestamp_granularities[]");
-            config.word_timestamps = false;
-            for (auto const& timestampsType : timestampsTypes) {
-                SPDLOG_LOGGER_TRACE(s2t_calculator_logger, "Received timestamp type: {}", timestampsType);
-                if (timestampsType == "segment") {
-                    config.return_timestamps = true;
-                } else if (timestampsType == "word") {
-                    if (!pipe->enableWordTimestamps)
-                        return absl::InvalidArgumentError("Word timestamps not supported for this model");
-                    config.word_timestamps = true;
-                } else {
-                    return absl::InvalidArgumentError("Invalid timestamp_granularities type. Allowed types: \"segment\", \"word\"");
-                }
-            }
-            std::string temperature = payload.multipartParser->getFieldByName("temperature");
-            if (temperature.size() > 0) {
-                SPDLOG_LOGGER_TRACE(s2t_calculator_logger, "Received temperature: {}", temperature);
-                auto temp = ovms::stof(temperature);
-                if (!temp.has_value()) {
-                    temp = stou32(temperature);
-                    if (!temp.has_value())
-                        return absl::InvalidArgumentError("Invalid temperature type.");
-                }
-                if (temp.value() < 0.0f || temp.value() > 2.0f)
-                    return absl::InvalidArgumentError("Temperature out of range(0.0, 2.0)");
-                config.temperature = temp.value();
-            } else {
-                config.temperature = 1.0;  // default value
-            }
-            std::unique_lock<std::mutex> lock(pipe->sttPipelineMutex);
-            const ov::genai::WhisperDecodedResults result = pipe->sttPipeline->generate(rawSpeech, config);
-            lock.unlock();
-            const std::string generatedText = result;  // word chunks concatenation to single string
-            writer.String(generatedText.c_str());
-            if (config.word_timestamps) {
-                writer.String("words");
-                writer.StartArray();
-                if (result.words.has_value()) {
-                    for (const auto& word : *result.words) {
-                        writer.StartObject();
-                        writer.String("word");
-                        writer.String(word.word.c_str());
-                        writer.String("start");
-                        writer.Double(word.start_ts);
-                        writer.String("end");
-                        writer.Double(word.end_ts);
-                        writer.EndObject();
-                    }
-                }
-                writer.EndArray();
-            }
-            if (config.return_timestamps) {
-                writer.String("segments");
-                writer.StartArray();
-                if (result.chunks.has_value()) {
-                    for (const auto& chunk : *result.chunks) {
-                        writer.StartObject();
-                        writer.String("text");
-                        writer.String(chunk.text.c_str());
-                        writer.String("start");
-                        writer.Double(chunk.start_ts);
-                        writer.String("end");
-                        writer.Double(chunk.end_ts);
-                        writer.EndObject();
-                    }
-                }
-                writer.EndArray();
+            status = streamingHandler_.start(pipe, payload, std::move(rawSpeech), config);
+            if (status != absl::OkStatus()) {
+                return status;
            }
+            cc->Outputs().Tag(LOOPBACK_TAG_NAME).Add(new bool{true}, cc->InputTimestamp());
+            return absl::OkStatus();
        }
-        if (endpoint == Endpoint::TRANSLATIONS) {
-            std::unique_lock<std::mutex> lock(pipe->sttPipelineMutex);
-            std::string generatedText = pipe->sttPipeline->generate(rawSpeech, ov::genai::task("translate"));
-            lock.unlock();
-            writer.String(generatedText.c_str());
-        }
-        writer.EndObject();
-        std::unique_ptr<std::string> output = std::make_unique<std::string>(buffer.GetString());
-        cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(output.release(), cc->InputTimestamp());
-        SPDLOG_LOGGER_DEBUG(s2t_calculator_logger, "SpeechToTextCalculator [Node: {}] Process end", cc->NodeName());
+        return processUnaryRequest(cc, pipe, endpoint, payload, rawSpeech);
    } catch (ov::AssertFailure& e) {
        return absl::InvalidArgumentError(e.what());
    } catch (...) {
@@ -233,10 +202,107 @@
        }
        return absl::OkStatus();
    }
+
+private:
+    absl::Status processUnaryRequest(CalculatorContext* cc, std::shared_ptr<ovms::SttServable> pipe,
+        Endpoint endpoint, const ovms::HttpPayload& payload, const std::vector<float>& rawSpeech) {
+        auto client = payload.client;
+        auto disconnectCallback = [client](std::string) -> ov::genai::StreamingStatus {
+            if (client && client->isDisconnected()) {
+                return ov::genai::StreamingStatus::CANCEL;
+            }
+            return ov::genai::StreamingStatus::RUNNING;
+        };
+        rapidjson::StringBuffer buffer;
+        rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+        writer.StartObject();
+        writer.String("text");
+        if (endpoint == Endpoint::TRANSCRIPTIONS) {
+            ov::genai::WhisperGenerationConfig config = pipe->sttPipeline->get_generation_config();
+            auto status = ovms::SttServable::updateTranscriptionConfig(config, pipe, payload);
+            if (status != absl::OkStatus())
+                return status;
+
+            std::unique_lock<std::mutex> lock(pipe->sttPipelineMutex);
+            auto disconnectStatus = checkClientDisconnected(payload, cc->NodeName(), "before transcription");
+            if (!disconnectStatus.ok())
+                return disconnectStatus;
+            const ov::genai::WhisperDecodedResults result = pipe->sttPipeline->generate(rawSpeech, config, disconnectCallback);
+            lock.unlock();
+            disconnectStatus = checkClientDisconnected(payload, cc->NodeName(), "after transcription");
+            if (!disconnectStatus.ok())
+                return disconnectStatus;
+            const std::string generatedText = result;
+            writer.String(generatedText.c_str());
+            serializeTimestamps(writer, result, config);
+        }
+        if (endpoint == Endpoint::TRANSLATIONS) {
+            ov::genai::WhisperGenerationConfig config = pipe->sttPipeline->get_generation_config();
+            config.task = "translate";
+            auto status = ovms::SttServable::parseTemperature(payload, config);
+            if (status != absl::OkStatus())
+                return status;
+            std::unique_lock<std::mutex> lock(pipe->sttPipelineMutex);
+            auto disconnectStatus = checkClientDisconnected(payload, cc->NodeName(), "before translation");
+            if (!disconnectStatus.ok())
+                return disconnectStatus;
+            const ov::genai::WhisperDecodedResults result = pipe->sttPipeline->generate(rawSpeech, config, disconnectCallback);
+            lock.unlock();
+            disconnectStatus = checkClientDisconnected(payload, cc->NodeName(), "after translation");
+            if (!disconnectStatus.ok())
+                return disconnectStatus;
+            const std::string generatedText = result;
+            writer.String(generatedText.c_str());
+        }
+        writer.EndObject();
+        auto output = std::make_unique<std::string>(buffer.GetString());
+        cc->Outputs().Tag(OUTPUT_TAG_NAME).Add(output.release(), cc->InputTimestamp());
+        SPDLOG_LOGGER_DEBUG(s2t_calculator_logger, "SpeechToTextCalculator [Node: {}] Process end", cc->NodeName());
+        return absl::OkStatus();
+    }
+
+    static void serializeTimestamps(rapidjson::Writer<rapidjson::StringBuffer>& writer,
+        const ov::genai::WhisperDecodedResults& result, const ov::genai::WhisperGenerationConfig& config) {
+        if (config.word_timestamps) {
+            writer.String("words");
+            writer.StartArray();
+            if (result.words.has_value()) {
+                for (const auto& word : *result.words) {
+                    writer.StartObject();
+                    writer.String("word");
+                    writer.String(word.word.c_str());
+                    writer.String("start");
+                    writer.Double(word.start_ts);
writer.String("end"); + writer.Double(word.end_ts); + writer.EndObject(); + } + } + writer.EndArray(); + } + if (config.return_timestamps) { + writer.String("segments"); + writer.StartArray(); + if (result.chunks.has_value()) { + for (const auto& chunk : *result.chunks) { + writer.StartObject(); + writer.String("text"); + writer.String(chunk.text.c_str()); + writer.String("start"); + writer.Double(chunk.start_ts); + writer.String("end"); + writer.Double(chunk.end_ts); + writer.EndObject(); + } + } + writer.EndArray(); + } + } }; const std::string S2tCalculator::INPUT_TAG_NAME{"HTTP_REQUEST_PAYLOAD"}; const std::string S2tCalculator::OUTPUT_TAG_NAME{"HTTP_RESPONSE_PAYLOAD"}; +const std::string S2tCalculator::LOOPBACK_TAG_NAME{"LOOPBACK"}; REGISTER_CALCULATOR(S2tCalculator); } // namespace mediapipe diff --git a/src/audio/speech_to_text/s2t_executor.hpp b/src/audio/speech_to_text/s2t_executor.hpp new file mode 100644 index 0000000000..9b25ad9c76 --- /dev/null +++ b/src/audio/speech_to_text/s2t_executor.hpp @@ -0,0 +1,92 @@ +//***************************************************************************** +// Copyright 2026 Intel Corporation +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +//***************************************************************************** +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "openvino/genai/whisper_pipeline.hpp" +#include "src/executor_base.hpp" +#include "src/logging.hpp" + +namespace ovms { + +struct SttServableExecutionContext { + std::vector rawSpeech; + ov::genai::WhisperGenerationConfig config; + std::function streamerCallback; + std::function onFinished; + std::promise finishedPromise; + std::future finished; + + SttServableExecutionContext( + std::vector rawSpeech, + ov::genai::WhisperGenerationConfig config, + std::function streamerCallback, + std::function onFinished) : + rawSpeech(std::move(rawSpeech)), + config(std::move(config)), + streamerCallback(std::move(streamerCallback)), + onFinished(std::move(onFinished)), + finished(finishedPromise.get_future()) {} +}; + +struct SttExecutor : public Executor> { + std::shared_ptr sttPipeline; + std::mutex& sttPipelineMutex; + + SttExecutor(std::shared_ptr sttPipeline, std::mutex& sttPipelineMutex) : + sttPipeline(std::move(sttPipeline)), + sttPipelineMutex(sttPipelineMutex) {} + + void processRequest() { + std::shared_ptr requestExecutionContext; + { + std::lock_guard lock(queueMutex); + if (requests.empty()) { + return; + } + requestExecutionContext = std::move(requests.front()); + requests.pop(); + } + try { + std::unique_lock lock(sttPipelineMutex); + auto result = sttPipeline->generate( + requestExecutionContext->rawSpeech, + requestExecutionContext->config, + requestExecutionContext->streamerCallback); + lock.unlock(); + requestExecutionContext->finishedPromise.set_value(std::move(result)); + } catch (...) 
+            requestExecutionContext->finishedPromise.set_exception(std::current_exception());
+        }
+        requestExecutionContext->onFinished();
+    }
+};
+
+class SttExecutorWrapper : public ExecutorWrapper<SttExecutor> {
+public:
+    SttExecutorWrapper(std::shared_ptr<ov::genai::WhisperPipeline> sttPipeline, std::mutex& sttPipelineMutex) :
+        ExecutorWrapper(s2t_calculator_logger, std::make_shared<SttExecutor>(std::move(sttPipeline), sttPipelineMutex)) {}
+};
+
+}  // namespace ovms
diff --git a/src/audio/speech_to_text/s2t_servable.cpp b/src/audio/speech_to_text/s2t_servable.cpp
new file mode 100644
index 0000000000..451a8bb316
--- /dev/null
+++ b/src/audio/speech_to_text/s2t_servable.cpp
@@ -0,0 +1,117 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "s2t_servable.hpp"
+
+#include <string>
+
+#pragma warning(push)
+#pragma warning(disable : 6386)
+#include "absl/status/status.h"
+#pragma warning(pop)
+#include "openvino/genai/whisper_pipeline.hpp"
+
+#include "src/audio/speech_to_text/s2t_calculator.pb.h"
+#include "src/http_payload.hpp"
+#include "src/json_parser.hpp"
+#include "src/logging.hpp"
+#include "src/stringutils.hpp"
+
+namespace ovms {
+
+namespace {
+constexpr size_t ISO_LANG_CODE_MAX = 3;
+}
+
+SttServable::SttServable(const ::mediapipe::S2tCalculatorOptions& nodeOptions, const std::string& graphPath) {
+    auto fsModelsPath = std::filesystem::path(nodeOptions.models_path());
+    if (fsModelsPath.is_relative()) {
+        parsedModelsPath = (std::filesystem::path(graphPath) / fsModelsPath);
+    } else {
+        parsedModelsPath = fsModelsPath;
+    }
+    ov::AnyMap config;
+    auto status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), config);
+    if (!status.ok()) {
+        SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", nodeOptions.plugin_config());
+        throw std::runtime_error("Error during plugin_config option parsing");
+    }
+    enableWordTimestamps = nodeOptions.enable_word_timestamps();
+    if (enableWordTimestamps && nodeOptions.target_device() == "NPU") {
+        config["STATIC_PIPELINE"] = true;
+    }
+    config["word_timestamps"] = enableWordTimestamps;
+    sttPipeline = std::make_shared<ov::genai::WhisperPipeline>(parsedModelsPath.string(), nodeOptions.target_device(), config);
+
+    streamingExecutor = std::make_unique<SttExecutorWrapper>(sttPipeline, sttPipelineMutex);
+}
+
+void SttServable::addRequest(std::shared_ptr<SttServableExecutionContext> executionContext) {
+    if (!streamingExecutor) {
+        throw std::runtime_error("Cannot schedule STT streaming job - executor not initialized");
+    }
+    streamingExecutor->addRequest(std::move(executionContext));
+}
+
+absl::Status SttServable::parseTemperature(const HttpPayload& payload, ov::genai::WhisperGenerationConfig& config) {
+    std::string temperatureStr = payload.multipartParser->getFieldByName("temperature");
+    if (temperatureStr.size() > 0) {
+        SPDLOG_LOGGER_TRACE(s2t_calculator_logger, "Received temperature: {}", temperatureStr);
+        auto temp = ovms::stof(temperatureStr);
+        if (!temp.has_value()) {
+            temp = ovms::stou32(temperatureStr);
+            if (!temp.has_value())
+                return absl::InvalidArgumentError("Invalid temperature type.");
+        }
+        config.temperature = temp.value();
+        if (config.temperature > 0) {
+            config.do_sample = true;
+        }
+    }
+    return absl::OkStatus();
+}
+
+absl::Status SttServable::updateTranscriptionConfig(ov::genai::WhisperGenerationConfig& config,
+    const std::shared_ptr<SttServable>& servable, const HttpPayload& payload) {
+    std::string language = payload.multipartParser->getFieldByName("language");
+    if (language.size() > 0) {
+        if (language.size() > ISO_LANG_CODE_MAX) {
+            return absl::InvalidArgumentError("Invalid language code.");
+        }
+        SPDLOG_LOGGER_TRACE(s2t_calculator_logger, "Received language: {}", language);
+        config.language = "<|" + language + "|>";
+    }
+    std::vector<std::string> timestampsTypes = payload.multipartParser->getArrayFieldByName("timestamp_granularities[]");
+    config.word_timestamps = false;
+    for (const auto& timestampsType : timestampsTypes) {
+        SPDLOG_LOGGER_TRACE(s2t_calculator_logger, "Received timestamp type: {}", timestampsType);
+        if (timestampsType == "segment") {
+            config.return_timestamps = true;
+        } else if (timestampsType == "word") {
+            if (!servable->enableWordTimestamps)
+                return absl::InvalidArgumentError("Word timestamps not supported for this model");
+            config.word_timestamps = true;
+        } else {
+            return absl::InvalidArgumentError("Invalid timestamp_granularities type. Allowed types: \"segment\", \"word\"");
+        }
+    }
+    auto status = parseTemperature(payload, config);
+    if (status != absl::OkStatus())
+        return status;
+    return absl::OkStatus();
+}
+
+}  // namespace ovms
diff --git a/src/audio/speech_to_text/s2t_servable.hpp b/src/audio/speech_to_text/s2t_servable.hpp
index 9ada8d0fae..c409089780 100644
--- a/src/audio/speech_to_text/s2t_servable.hpp
+++ b/src/audio/speech_to_text/s2t_servable.hpp
@@ -15,53 +15,53 @@
 //*****************************************************************************
 #pragma once

+#include
+#include
 #include
+#include
 #include
 #include
+#include
 #include

-#pragma warning(push)
-#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 4005 4456 6246)
-#pragma GCC diagnostic push
-#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
-#include "mediapipe/framework/calculator_graph.h"
-#pragma GCC diagnostic pop
-#pragma warning(pop)
-
-#include "openvino/genai/whisper_pipeline.hpp"
-#include "openvino/genai/speech_generation/text2speech_pipeline.hpp"
-#include "src/audio/speech_to_text/s2t_calculator.pb.h"
-#include "src/json_parser.hpp"
+#include "src/audio/speech_to_text/s2t_executor.hpp"
 #include "src/status.hpp"
-#include "src/logging.hpp"
+
+namespace absl {
+class Status;
+}  // namespace absl
+
+namespace mediapipe {
+class S2tCalculatorOptions;
+}  // namespace mediapipe
+
+namespace ov::genai {
+class WhisperPipeline;
+class WhisperGenerationConfig;
+}  // namespace ov::genai

 namespace ovms {

+struct HttpPayload;
+
 struct SttServable {
     std::filesystem::path parsedModelsPath;
     std::shared_ptr<ov::genai::WhisperPipeline> sttPipeline;
     std::mutex sttPipelineMutex;
     bool enableWordTimestamps;

-    SttServable(const ::mediapipe::S2tCalculatorOptions& nodeOptions, const std::string& graphPath) {
-        auto fsModelsPath = std::filesystem::path(nodeOptions.models_path());
-        if (fsModelsPath.is_relative()) {
-            parsedModelsPath = (std::filesystem::path(graphPath) / fsModelsPath);
-        } else {
-            parsedModelsPath = fsModelsPath;
-        }
-        ov::AnyMap config;
-        auto status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), config);
-        if (!status.ok()) {
-            SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", nodeOptions.plugin_config());
-            throw std::runtime_error("Error during plugin_config option parsing");
-        }
-        enableWordTimestamps = nodeOptions.enable_word_timestamps();
-        if (enableWordTimestamps && nodeOptions.target_device() == "NPU")
-            config["STATIC_PIPELINE"] = true;
-        config["word_timestamps"] = enableWordTimestamps;
-        sttPipeline = std::make_shared<ov::genai::WhisperPipeline>(parsedModelsPath.string(), nodeOptions.target_device(), config);
-    }
+    std::unique_ptr<SttExecutorWrapper> streamingExecutor;
+
+    SttServable(const ::mediapipe::S2tCalculatorOptions& nodeOptions, const std::string& graphPath);
+
+    ~SttServable() = default;
+
+    void addRequest(std::shared_ptr<SttServableExecutionContext> executionContext);
+
+    static absl::Status parseTemperature(const HttpPayload& payload, ov::genai::WhisperGenerationConfig& config);
+
+    static absl::Status updateTranscriptionConfig(ov::genai::WhisperGenerationConfig& config,
+        const std::shared_ptr<SttServable>& servable, const HttpPayload& payload);
 };

 using SttServableMap = std::unordered_map<std::string, std::shared_ptr<SttServable>>;
diff --git a/src/audio/speech_to_text/s2t_streaming_handler.cpp b/src/audio/speech_to_text/s2t_streaming_handler.cpp
new file mode 100644
index 0000000000..1581fcdefd
--- /dev/null
+++ b/src/audio/speech_to_text/s2t_streaming_handler.cpp
@@ -0,0 +1,150 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include "s2t_streaming_handler.hpp"
+
+#include <memory>
+#include <string>
+#include <utility>
+
+#include "src/port/rapidjson_stringbuffer.hpp"
+#include "src/port/rapidjson_writer.hpp"
+#include "src/http_payload.hpp"
+#include "src/logging.hpp"
+#include "src/sse_utils.hpp"
+#include "src/stringutils.hpp"
+#include "streaming_text_queue.hpp"
+#include "s2t_servable.hpp"
+
+namespace mediapipe {
+
+std::string S2tStreamingHandler::serializeDeltaEvent(const std::string& delta) {
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    writer.StartObject();
+    writer.String("type");
+    writer.String("transcript.text.delta");
+    writer.String("delta");
+    writer.String(delta.c_str());
+    writer.String("logprobs");
+    writer.StartArray();
+    writer.EndArray();
+    writer.EndObject();
+    return buffer.GetString();
+}
+
+std::string S2tStreamingHandler::serializeDoneEvent(const std::string& text) {
+    rapidjson::StringBuffer buffer;
+    rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
+    writer.StartObject();
+    writer.String("type");
+    writer.String("transcript.text.done");
+    writer.String("text");
+    writer.String(text.c_str());
+    writer.String("logprobs");
+    writer.StartArray();
+    writer.EndArray();
+    writer.EndObject();
+    return buffer.GetString();
+}
+
+absl::Status S2tStreamingHandler::start(std::shared_ptr<ovms::SttServable> pipe,
+    const ovms::HttpPayload& payload,
+    std::vector<float> rawSpeech,
+    const ov::genai::WhisperGenerationConfig& config) {
+    if (isStreaming_) {
+        return absl::FailedPreconditionError("Streaming request is already active");
+    }
+    isStreaming_ = true;
+    streamingQueue_ = std::make_shared<ovms::StreamingTextQueue>();
+    executionContext_.reset();
+
+    auto client = payload.client;
+    auto streamerCallback = [queue = streamingQueue_, client](std::string text) -> ov::genai::StreamingStatus {
+        if (client && client->isDisconnected()) {
+            queue->endStreaming();
+            return ov::genai::StreamingStatus::CANCEL;
+        }
+        if (!text.empty()) {
+            queue->push(std::move(text));
+        }
+        return ov::genai::StreamingStatus::RUNNING;
+    };
+
+    auto guardedStreamerCallback = [streamerCallback = std::move(streamerCallback), queue = streamingQueue_](std::string text) mutable -> ov::genai::StreamingStatus {
+        try {
+            return streamerCallback(std::move(text));
+        } catch (...) {
+            queue->endStreaming();
+            throw;
+        }
+    };
+    auto executionContext = std::make_shared<ovms::SttServableExecutionContext>(
+        std::move(rawSpeech),
+        config,
+        std::move(guardedStreamerCallback),
+        [queue = streamingQueue_]() { queue->endStreaming(); });
+    try {
+        pipe->addRequest(executionContext);
+        executionContext_ = std::move(executionContext);
+    } catch (const std::exception& e) {
+        isStreaming_ = false;
+        return absl::InternalError(e.what());
+    }
+
+    return absl::OkStatus();
+}
+
+absl::Status S2tStreamingHandler::processIteration(std::string& ssePayload,
+    bool& shouldContinueLoopback,
+    bool& hasOutput) {
+    hasOutput = false;
+    shouldContinueLoopback = false;
+
+    std::string chunk;
+    bool hasData = streamingQueue_->waitAndPop(chunk);
+
+    if (hasData) {
+        ssePayload = ovms::wrapTextInServerSideEventMessage(serializeDeltaEvent(chunk));
+        hasOutput = true;
+        shouldContinueLoopback = true;
+    } else {
+        // Generation complete — send final event and stop
+        std::string finalText;
+        try {
+            if (executionContext_ && executionContext_->finished.valid()) {
+                const ov::genai::WhisperDecodedResults result = executionContext_->finished.get();
+                finalText = result;
+            }
+        } catch (ov::AssertFailure& e) {
+            isStreaming_ = false;
+            executionContext_.reset();
+            return absl::InvalidArgumentError(e.what());
+        } catch (...) {
+            isStreaming_ = false;
+            executionContext_.reset();
+            return absl::InvalidArgumentError("Response generation failed");
+        }
+
+        ssePayload = ovms::wrapTextInServerSideEventMessage(serializeDoneEvent(finalText));
+        hasOutput = true;
+        isStreaming_ = false;
+        executionContext_.reset();
+    }
+
+    return absl::OkStatus();
+}
+
+}  // namespace mediapipe
diff --git a/src/audio/speech_to_text/s2t_streaming_handler.hpp b/src/audio/speech_to_text/s2t_streaming_handler.hpp
new file mode 100644
index 0000000000..fff5b0d7d0
--- /dev/null
+++ b/src/audio/speech_to_text/s2t_streaming_handler.hpp
@@ -0,0 +1,66 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#pragma once
+
+#include <memory>
+#include <string>
+#include <utility>
+#include <vector>
+
+#pragma warning(push)
+#pragma warning(disable : 4005 4309 6001 6385 6386 6326 6011 6246 4456 6246)
+#pragma GCC diagnostic push
+#pragma GCC diagnostic ignored "-Wdeprecated-declarations"
+#include "mediapipe/framework/calculator_framework.h"
+#include "mediapipe/framework/port/canonical_errors.h"
+#pragma GCC diagnostic pop
+#pragma warning(pop)
+
+#include "openvino/genai/whisper_pipeline.hpp"
+
+#include "src/http_payload.hpp"
+
+namespace ovms {
+class StreamingTextQueue;
+struct SttServable;
+struct SttServableExecutionContext;
+}  // namespace ovms
+
+namespace mediapipe {
+
+// Encapsulates all streaming state and logic for the S2tCalculator.
+// Manages the background generation thread, the text queue, SSE
+// serialization and LOOPBACK signaling.
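+//
+// Driving sequence (sketch, mirroring how S2tCalculator uses this class):
+//
+//   S2tStreamingHandler handler;
+//   // first Process() call: schedule generation on the executor thread
+//   handler.start(pipe, payload, std::move(rawSpeech), config);
+//   // every LOOPBACK iteration afterwards:
+//   std::string sse;
+//   bool continueLoopback = false, hasOutput = false;
+//   handler.processIteration(sse, continueLoopback, hasOutput);  // blocks on the queue
+//   // hasOutput -> emit `sse` on HTTP_RESPONSE_PAYLOAD,
+//   // continueLoopback -> re-arm the LOOPBACK stream for the next iteration.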
+class S2tStreamingHandler {
+public:
+    static std::string serializeDeltaEvent(const std::string& delta);
+    static std::string serializeDoneEvent(const std::string& text);
+
+    absl::Status start(std::shared_ptr<ovms::SttServable> pipe,
+        const ovms::HttpPayload& payload,
+        std::vector<float> rawSpeech,
+        const ov::genai::WhisperGenerationConfig& config);
+
+    absl::Status processIteration(std::string& ssePayload,
+        bool& shouldContinueLoopback,
+        bool& hasOutput);
+
+private:
+    bool isStreaming_ = false;
+    std::shared_ptr<ovms::StreamingTextQueue> streamingQueue_;
+    std::shared_ptr<ovms::SttServableExecutionContext> executionContext_;
+};
+
+}  // namespace mediapipe
diff --git a/src/audio/speech_to_text/streaming_text_queue.cpp b/src/audio/speech_to_text/streaming_text_queue.cpp
new file mode 100644
index 0000000000..652cc276c0
--- /dev/null
+++ b/src/audio/speech_to_text/streaming_text_queue.cpp
@@ -0,0 +1,46 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+
+#include "streaming_text_queue.hpp"
+
+#include <utility>
+
+namespace ovms {
+
+void StreamingTextQueue::push(std::string text) {
+    std::lock_guard<std::mutex> lock(mutex_);
+    queue_.push(std::move(text));
+    cv_.notify_one();
+}
+
+void StreamingTextQueue::endStreaming() {
+    std::lock_guard<std::mutex> lock(mutex_);
+    done_ = true;
+    cv_.notify_one();
+}
+
+bool StreamingTextQueue::waitAndPop(std::string& out) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return !queue_.empty() || done_; });
+    if (!queue_.empty()) {
+        out = std::move(queue_.front());
+        queue_.pop();
+        return true;
+    }
+    return false;  // done and empty
+}
+
+}  // namespace ovms
diff --git a/src/audio/speech_to_text/streaming_text_queue.hpp b/src/audio/speech_to_text/streaming_text_queue.hpp
new file mode 100644
index 0000000000..177eb331d2
--- /dev/null
+++ b/src/audio/speech_to_text/streaming_text_queue.hpp
@@ -0,0 +1,47 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#pragma once
+
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+#include <string>
+#include <utility>
+
+namespace ovms {
+
+// Thread-safe queue for streaming partial results from a background
+// generation thread to the MediaPipe LOOPBACK loop. Used by S2tCalculator
+// to bridge ov::genai streamer callbacks and the calculator's Process() cycle.
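+//
+// Producer/consumer contract (sketch):
+//
+//   StreamingTextQueue q;
+//   // producer (streamer callback):   q.push("chunk"); ... q.endStreaming();
+//   // consumer (LOOPBACK iteration):
+//   std::string chunk;
+//   while (q.waitAndPop(chunk)) { /* wrap chunk in an SSE delta event */ }
+//   // waitAndPop() returning false means generation finished and the queue drained.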
+class StreamingTextQueue {
+public:
+    void push(std::string text);
+
+    // Signals that generation has finished (successfully or with error).
+    void endStreaming();
+
+    // Blocks until a text chunk is available or generation is done.
+    // Returns true if a chunk was retrieved, false if done and queue is empty.
+    bool waitAndPop(std::string& out);
+
+private:
+    mutable std::mutex mutex_;
+    std::condition_variable cv_;
+    std::queue<std::string> queue_;
+    bool done_ = false;
+};
+
+}  // namespace ovms
diff --git a/src/audio/text_to_speech/t2s_calculator.cc b/src/audio/text_to_speech/t2s_calculator.cc
index 917a77b0d8..f8f4912f0d 100644
--- a/src/audio/text_to_speech/t2s_calculator.cc
+++ b/src/audio/text_to_speech/t2s_calculator.cc
@@ -25,6 +25,7 @@
 #pragma warning(pop)

 #include "src/audio/audio_utils.hpp"
+#include "src/client_connection.hpp"
 #include "src/http_payload.hpp"
 #include "src/logging.hpp"
 #include
@@ -51,6 +52,14 @@
 namespace mediapipe {

 const std::string TTS_SESSION_SIDE_PACKET_TAG = "TTS_NODE_RESOURCES";

+static absl::Status checkClientDisconnected(const ovms::HttpPayload& payload, const std::string& nodeName, const char* context) {
+    if (payload.client && payload.client->isDisconnected()) {
+        SPDLOG_LOGGER_DEBUG(t2s_calculator_logger, "Client disconnected {} [Node: {}]", context, nodeName);
+        return absl::CancelledError("Client disconnected");
+    }
+    return absl::OkStatus();
+}
+
 class T2sCalculator : public CalculatorBase {
     static const std::string INPUT_TAG_NAME;
     static const std::string OUTPUT_TAG_NAME;
@@ -111,8 +120,12 @@
        if (pipe->voices.find(voiceName.value()) == pipe->voices.end())
            return absl::InvalidArgumentError(absl::StrCat("Requested voice not available: ", voiceName.value()));
    }
+
    ov::genai::Text2SpeechDecodedResults generatedSpeech;
    std::unique_lock<std::mutex> lock(pipe->ttsPipelineMutex);
+    auto disconnectStatus = checkClientDisconnected(payload, cc->NodeName(), "before generation");
+    if (!disconnectStatus.ok())
+        return disconnectStatus;
    if (voiceName.has_value()) {
        generatedSpeech = pipe->ttsPipeline->generate(inputIt->value.GetString(),
            pipe->voices[voiceName.value()]);
@@ -125,6 +138,9 @@
    // copy results to release inference request
    generatedSpeech.speeches[0].copy_to(cpuTensor);
    lock.unlock();
+    disconnectStatus = checkClientDisconnected(payload, cc->NodeName(), "after generation");
+    if (!disconnectStatus.ok())
+        return disconnectStatus;
    void* ppData;
    size_t pDataSize;
    prepareAudioOutput(&ppData, pDataSize, bitsPerSample, speechSize, cpuTensor.data());
diff --git a/src/drogon_http_server.cpp b/src/drogon_http_server.cpp
index 210776ac8b..c4c976b64a 100644
--- a/src/drogon_http_server.cpp
+++ b/src/drogon_http_server.cpp
@@ -88,12 +88,13 @@ Status DrogonHttpServer::startAcceptingRequests() {
    drogon::app().disableSigtermHandling();
    drogon::app().setDefaultHandler([this](const drogon::HttpRequestPtr& req, std::function<void(const drogon::HttpResponsePtr&)>&& drogonResponseInitializeCallback) {
-        bool isTextGeneration = req->path().find("/completions") != std::string::npos ||
-                                req->path().find("/responses") != std::string::npos;
+        bool isStreamingEndpoint = req->path().find("/completions") != std::string::npos ||
+                                   req->path().find("/responses") != std::string::npos ||
+                                   req->path().find("/audio/transcriptions") != std::string::npos;
        // Here we need to schedule the request to the separate thread pool
        // in order to use disconnection callback of drogon.
-        if (isTextGeneration) {
+        if (isStreamingEndpoint) {
            this->pool->Schedule([this, req, drogonResponseInitializeCallback = std::move(drogonResponseInitializeCallback)]() mutable {
                SPDLOG_DEBUG("Request URI {} dispatched to streaming thread pool", req->path());
                this->dispatch(req, std::move(drogonResponseInitializeCallback));
diff --git a/src/executor_base.hpp b/src/executor_base.hpp
new file mode 100644
index 0000000000..50c7972c37
--- /dev/null
+++ b/src/executor_base.hpp
@@ -0,0 +1,121 @@
+//*****************************************************************************
+// Copyright 2026 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#pragma once
+
+#include <atomic>
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <queue>
+#include <stdexcept>
+#include <thread>
+#include <utility>
+
+#include <spdlog/spdlog.h>
+
+namespace ovms {
+
+template <typename RequestT>
+struct Executor {
+    using Request = RequestT;
+
+    std::condition_variable cv;
+    std::queue<RequestT> requests;
+    std::mutex queueMutex;
+
+    virtual void processRequest() = 0;
+
+    bool hasRequests() {
+        std::lock_guard<std::mutex> lock(queueMutex);
+        return !requests.empty();
+    }
+
+    size_t requestsQueueSize() {
+        std::lock_guard<std::mutex> lock(queueMutex);
+        return requests.size();
+    }
+
+    void waitForRequests(std::atomic<bool>* receivedEndSignal) {
+        std::unique_lock<std::mutex> lock(queueMutex);
+        cv.wait(lock, [this, receivedEndSignal] { return !requests.empty() || *receivedEndSignal; });
+    }
+
+    void notify() {
+        std::lock_guard<std::mutex> lock(queueMutex);
+        cv.notify_one();
+    }
+
+    void scheduleRequest(RequestT&& request) {
+        std::lock_guard<std::mutex> lock(queueMutex);
+        requests.push(std::move(request));
+        cv.notify_one();
+    }
+};
+
+template <typename RequestT>
+void runExecutorLoop(Executor<RequestT>* executor, std::atomic<bool>* receivedEndSignal, const std::shared_ptr<spdlog::logger>& logger) {
+    while (!(*receivedEndSignal)) {
+        SPDLOG_LOGGER_INFO(logger, "All requests: {};", executor->requestsQueueSize());
+        if (executor->hasRequests()) {
+            executor->processRequest();
+        } else {
+            executor->waitForRequests(receivedEndSignal);
+        }
+    }
+}
+
+template <typename ExecutorT>
+class ExecutorWrapper {
+    std::thread executorThread;
+
+    using Request = typename ExecutorT::Request;
+
+    static void run(Executor<Request>* exec, std::atomic<bool>* stop, std::shared_ptr<spdlog::logger> logger) {
+        try {
+            runExecutorLoop(exec, stop, logger);
+        } catch (const std::exception& e) {
+            SPDLOG_LOGGER_ERROR(logger, "Error occurred in executor: {}.", e.what());
+            exit(1);
+        }
+    }
+
+protected:
+    std::shared_ptr<ExecutorT> executor;
+    std::atomic<bool> finishExecutorThread = false;
+
+public:
+    ExecutorWrapper(std::shared_ptr<spdlog::logger> logger, std::shared_ptr<ExecutorT> executor) :
+        executor(std::move(executor)) {
+        executorThread = std::thread(run, this->executor.get(), &finishExecutorThread, logger);
+    }
+
+    void addRequest(Request request) {
+        if (finishExecutorThread) {
+            throw std::runtime_error("Cannot schedule request - executor is stopping");
+        }
+        executor->scheduleRequest(std::move(request));
+    }
+
+    ~ExecutorWrapper() {
+        finishExecutorThread = true;
+        executor->notify();
+        if (executorThread.joinable()) {
+            executorThread.join();
+        }
+    }
+};
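+
+// Illustrative sketch (not part of this PR): a concrete executor only derives
+// from Executor<RequestT> and implements processRequest(); ExecutorWrapper<T>
+// then owns the background thread that drains the queue:
+//
+//   struct PrintExecutor : Executor<std::string> {
+//       void processRequest() override {
+//           std::lock_guard<std::mutex> lock(queueMutex);
+//           if (requests.empty())
+//               return;
+//           SPDLOG_INFO("processing: {}", requests.front());
+//           requests.pop();
+//       }
+//   };
+//
+//   ExecutorWrapper<PrintExecutor> wrapper(logger, std::make_shared<PrintExecutor>());
+//   wrapper.addRequest("hello");  // handled on the executor thread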
+
+}  // namespace ovms
diff --git a/src/graph_export/graph_export.cpp b/src/graph_export/graph_export.cpp
index b98a18a966..ee273dbb0c 100644
--- a/src/graph_export/graph_export.cpp
+++ b/src/graph_export/graph_export.cpp
@@ -365,8 +365,14 @@ node {
        << exportSettings.modelName << R"("
    calculator: "S2tCalculator"
    input_side_packet: "STT_NODE_RESOURCES:s2t_servable"
+    input_stream: "LOOPBACK:loopback"
    input_stream: "HTTP_REQUEST_PAYLOAD:input"
+    output_stream: "LOOPBACK:loopback"
    output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+    input_stream_info: {
+        tag_index: 'LOOPBACK:0',
+        back_edge: true
+    }
    node_options: {
        [type.googleapis.com / mediapipe.S2tCalculatorOptions]: {
            models_path: ")"
@@ -379,6 +385,16 @@
    }
    oss << R"(}
    }
+    input_stream_handler {
+        input_stream_handler: "SyncSetInputStreamHandler",
+        options {
+            [mediapipe.SyncSetInputStreamHandlerOptions.ext] {
+                sync_set {
+                    tag_index: "LOOPBACK:0"
+                }
+            }
+        }
+    }
})";
 #if (MEDIAPIPE_DISABLE == 0)
    ::mediapipe::CalculatorGraphConfig config;
diff --git a/src/http_rest_api_handler.cpp b/src/http_rest_api_handler.cpp
index c00bb6386b..5a24028d57 100644
--- a/src/http_rest_api_handler.cpp
+++ b/src/http_rest_api_handler.cpp
@@ -509,6 +509,11 @@ static Status createV3HttpPayload(
        } else {
            SPDLOG_DEBUG("Model name from deduced from MultiPart field: {}", modelName);
        }
+        // Detect stream field in multipart form data (used by audio endpoints)
+        std::string streamField = multiPartParser->getFieldByName("stream");
+        if (streamField == "true") {
+            streamFieldVal = true;
+        }
        ensureJsonParserInErrorState(parsedJson);
    } else if (isApplicationJson) {
        {
diff --git a/src/llm/BUILD b/src/llm/BUILD
index 0ecab7e2fa..8fe6059d71 100644
--- a/src/llm/BUILD
+++ b/src/llm/BUILD
@@ -267,6 +267,7 @@ ovms_cc_library(
        "//third_party:openvino",
        "@mediapipe//mediapipe/framework:calculator_framework",
        "@com_github_tencent_rapidjson//:rapidjson",
+        "//src:executor_base",
        "//src:libmodelconfigjsonparser",
        "//src:libovmslogging",
        "//src:libovmsstatus",
@@ -281,6 +282,7 @@
        ":generation_config_builders",
        "//src:httppayload",
        "//src:libhttpclientconnection",
+        "//src:sse_utils",
        "//third_party:genai",] + select({
        "//:disable_python": [],
        "//:not_disable_python" : [":py_jinja_template_processor"],
diff --git a/src/llm/language_model/legacy/legacy_executor.cpp b/src/llm/language_model/legacy/legacy_executor.cpp
index 1422f64b8c..80699d418c 100644
--- a/src/llm/language_model/legacy/legacy_executor.cpp
+++ b/src/llm/language_model/legacy/legacy_executor.cpp
@@ -15,21 +15,17 @@
 //*****************************************************************************
 #include "legacy_executor.hpp"
+
+#include "../../../logging.hpp"
 #include "servable.hpp"
+
+#include <utility>
+
 namespace ovms {

 LegacyExecutor::LegacyExecutor(std::shared_ptr<ov::genai::LLMPipeline> pipe) {
     this->pipe = std::move(pipe);
 }

-bool LegacyExecutor::hasRequests() {
-    return requests.size() > 0;
-}
-
-size_t LegacyExecutor::requestsQueueSize() {
-    return requests.size();
-}
-
 void LegacyExecutor::processRequest() {
     OVMS_PROFILE_FUNCTION();
     auto& requestExecutionContext = requests.front();
@@ -52,50 +48,6 @@
     requests.pop();
 }

-void LegacyExecutor::waitForRequests(std::atomic<bool>* receivedEndSignal) {
-    std::unique_lock<std::mutex> lock(queueMutex);
-    cv.wait(lock, [this, receivedEndSignal] { return (requests.size() > 0 || *receivedEndSignal); });
-}
-
-void LegacyExecutor::addRequest(std::shared_ptr<LegacyServableExecutionContext> request) {
-    std::unique_lock<std::mutex> lock(queueMutex);
-    requests.push(request);
-    cv.notify_one();
-}
-
-void LegacyExecutor::notify() {
-    std::unique_lock<std::mutex> lock(queueMutex);
-    cv.notify_one();
-}
-
-void LegacyExecutorWrapper::run(LegacyExecutor* legacyExecutor, std::atomic<bool>* receivedEndSignal) {
-    // TODO add metrics
-    while (!(*receivedEndSignal)) {
-        try {
-            SPDLOG_LOGGER_INFO(llm_executor_logger, "All requests: {};", legacyExecutor->requestsQueueSize());
-            if (legacyExecutor->hasRequests()) {
-                legacyExecutor->processRequest();
-            } else {
-                legacyExecutor->waitForRequests(receivedEndSignal);
-            }
-        } catch (std::exception& e) {
-            SPDLOG_LOGGER_ERROR(llm_executor_logger, "Error occurred in LLM executor: {}.", e.what());
-            exit(1);
-        }
-    }
-}
-
 LegacyExecutorWrapper::LegacyExecutorWrapper(std::shared_ptr<ov::genai::LLMPipeline> pipe) :
-    legacyExecutor(std::move(pipe)) {
-    legacyExecutorThread = std::thread(LegacyExecutorWrapper::run, &legacyExecutor, &finishExecutorThread);
-}
-
-LegacyExecutorWrapper::~LegacyExecutorWrapper() {
-    finishExecutorThread = true;
-    legacyExecutor.notify();
-    legacyExecutorThread.join();
-}
-void LegacyExecutorWrapper::addRequest(std::shared_ptr<LegacyServableExecutionContext> request) {
-    legacyExecutor.addRequest(request);
-}
+    ExecutorWrapper(llm_executor_logger, std::make_shared<LegacyExecutor>(std::move(pipe))) {}

 }  // namespace ovms
diff --git a/src/llm/language_model/legacy/legacy_executor.hpp b/src/llm/language_model/legacy/legacy_executor.hpp
index 5bb3365bd2..100f3cbe0d 100644
--- a/src/llm/language_model/legacy/legacy_executor.hpp
+++ b/src/llm/language_model/legacy/legacy_executor.hpp
@@ -16,57 +16,27 @@

 #pragma once

-#include
-#include
-#include
 #include
-#include
-#include
-#include
-#include

 #include "openvino/genai/llm_pipeline.hpp"
 #include "../../../logging.hpp"
 #include "../../../profiler.hpp"
-
-#include
-#include
+#include "../../../executor_base.hpp"

 namespace ovms {

 struct LegacyServableExecutionContext;
-struct LegacyExecutor;

-struct LegacyExecutor {
-    std::condition_variable cv;
-    std::queue<std::shared_ptr<LegacyServableExecutionContext>> requests;
-    std::mutex queueMutex;
+struct LegacyExecutor : public Executor<std::shared_ptr<LegacyServableExecutionContext>> {
     std::shared_ptr<ov::genai::LLMPipeline> pipe;

     LegacyExecutor(std::shared_ptr<ov::genai::LLMPipeline> pipe);

-    bool hasRequests();
-
-    size_t requestsQueueSize();

     void processRequest();
-
-    void waitForRequests(std::atomic<bool>* receivedEndSignal);
-
-    void addRequest(std::shared_ptr<LegacyServableExecutionContext> request);
-
-    void notify();
 };

-class LegacyExecutorWrapper {
-    LegacyExecutor legacyExecutor;
-    std::thread legacyExecutorThread;
-    std::atomic<bool> finishExecutorThread = false;
-
-    static void run(LegacyExecutor* legacyExecutor, std::atomic<bool>* receivedEndSignal);
-
+class LegacyExecutorWrapper : public ExecutorWrapper<LegacyExecutor> {
 public:
     LegacyExecutorWrapper(std::shared_ptr<ov::genai::LLMPipeline> pipe);
-    ~LegacyExecutorWrapper();
-    void addRequest(std::shared_ptr<LegacyServableExecutionContext> request);
 };

 }  // namespace ovms
diff --git a/src/llm/language_model/legacy/servable.cpp b/src/llm/language_model/legacy/servable.cpp
index 6a3f422c8a..aadbf9e683 100644
--- a/src/llm/language_model/legacy/servable.cpp
+++ b/src/llm/language_model/legacy/servable.cpp
@@ -20,6 +20,7 @@
 #include

 #include "../../../logging.hpp"
+#include "../../../profiler.hpp"
 #include "../../../status.hpp"
 #include "../../apis/openai_completions.hpp"
 #include "../../apis/openai_responses.hpp"
diff --git a/src/llm/language_model/legacy/servable.hpp b/src/llm/language_model/legacy/servable.hpp
index 19af42df85..eae3c9580e 100644
--- a/src/llm/language_model/legacy/servable.hpp
+++ b/src/llm/language_model/legacy/servable.hpp
@@ -14,6 +14,7 @@
 // limitations under the License.
 //*****************************************************************************
 #pragma once
+#include
 #include
 #include
diff --git a/src/llm/servable.cpp b/src/llm/servable.cpp
index 3079a07603..e2ccd06e78 100644
--- a/src/llm/servable.cpp
+++ b/src/llm/servable.cpp
@@ -355,15 +346,6 @@ absl::Status GenAiServable::preparePartialResponse(std::shared_ptr<GenAiServableExecutionContext>& executionContext) {
     return absl::OkStatus();
 }

-std::string wrapTextInServerSideEventMessage(const std::string& text) {
-    std::stringstream ss;
-    ss << "data: " << text << "\n\n";
-    return ss.str();
-}
-
diff --git a/src/llm/servable.hpp b/src/llm/servable.hpp
--- a/src/llm/servable.hpp
+++ b/src/llm/servable.hpp
@@ ... @@ class GenAiServable {
     std::shared_ptr<OpenAIChatCompletionsHandler> apiHandler;
     std::shared_ptr<GenerationConfigBuilder> generationConfigBuilder;
@@ -141,7 +142,7 @@ class GenAiServable {
     loadRequest method implementation MUST fill executionContext payload and endpoint fields.
     Base implementation does that and makes sure URI matches either chat/completions or completions endpoint.
     */
-    virtual absl::Status loadRequest(std::shared_ptr<GenAiServableExecutionContext>& executionContext, const ovms::HttpPayload& payload);
+    virtual absl::Status loadRequest(std::shared_ptr<GenAiServableExecutionContext>& executionContext, const HttpPayload& payload);

     // Creates execution context for the request
     virtual std::shared_ptr<GenAiServableExecutionContext> createExecutionContext() = 0;
@@ -207,7 +208,6 @@
     */
     virtual absl::Status preparePartialResponse(std::shared_ptr<GenAiServableExecutionContext>& executionContext);
 };
-std::string wrapTextInServerSideEventMessage(const std::string& text);

 using GenAiServableMap = std::unordered_map<std::string, std::shared_ptr<GenAiServable>>;

-void logRequestDetails(const ovms::HttpPayload& payload);
+void logRequestDetails(const HttpPayload& payload);
 }  // namespace ovms
diff --git a/src/llm/visual_language_model/continuous_batching/servable.hpp b/src/llm/visual_language_model/continuous_batching/servable.hpp
index 88c78e3dcb..a50b5c340d 100644
--- a/src/llm/visual_language_model/continuous_batching/servable.hpp
+++ b/src/llm/visual_language_model/continuous_batching/servable.hpp
@@ -48,7 +48,7 @@ class VisualLanguageModelServable : public ContinuousBatchingServable {
     absl::Status addRequestToPipeline(std::shared_ptr<ContinuousBatchingServableExecutionContext>& executionContext) override;

     // Interface methods
-    absl::Status loadRequest(std::shared_ptr<GenAiServableExecutionContext>& executionContext, const ovms::HttpPayload& payload) override;
+    absl::Status loadRequest(std::shared_ptr<GenAiServableExecutionContext>& executionContext, const HttpPayload& payload) override;
     std::shared_ptr<GenAiServableExecutionContext> createExecutionContext() override;
     std::shared_ptr<GenAiServableProperties> getProperties() override;
     absl::Status prepareInputs(std::shared_ptr<GenAiServableExecutionContext>& executionContext) override;
diff --git a/src/llm/visual_language_model/legacy/servable.hpp b/src/llm/visual_language_model/legacy/servable.hpp
index 8828153e7a..8c07818bce 100644
--- a/src/llm/visual_language_model/legacy/servable.hpp
+++ b/src/llm/visual_language_model/legacy/servable.hpp
@@ -65,7 +65,7 @@ class VisualLanguageModelLegacyServable : public GenAiServable {
     }

     // Interface methods
-    absl::Status loadRequest(std::shared_ptr<GenAiServableExecutionContext>& executionContext, const ovms::HttpPayload& payload);
+    absl::Status loadRequest(std::shared_ptr<GenAiServableExecutionContext>& executionContext, const HttpPayload& payload);
     std::shared_ptr<GenAiServableExecutionContext> createExecutionContext() override;
     std::shared_ptr<GenAiServableProperties> getProperties() override;
     absl::Status parseRequest(std::shared_ptr<GenAiServableExecutionContext>& executionContext) override;
diff --git a/src/mediapipe_internal/mediapipegraphexecutor.hpp b/src/mediapipe_internal/mediapipegraphexecutor.hpp
index 4d0b069f43..0d87c86088 100644
--- a/src/mediapipe_internal/mediapipegraphexecutor.hpp
+++ b/src/mediapipe_internal/mediapipegraphexecutor.hpp
@@ -303,7 +303,7 @@ class MediapipeGraphExecutor {
 #endif
         inputSidePackets[LLM_SESSION_SIDE_PACKET_TAG] = mediapipe::MakePacket<GenAiServableMap>(this->sidePacketMaps.genAiServableMap).At(STARTING_TIMESTAMP);
         inputSidePackets[EMBEDDINGS_SESSION_SIDE_PACKET_TAG] = mediapipe::MakePacket<EmbeddingsServableMap>(this->sidePacketMaps.embeddingsServableMap).At(STARTING_TIMESTAMP);
-        // Add image generation side packet in case image generation allow for streaming
+        inputSidePackets[STT_SESSION_SIDE_PACKET_TAG] = mediapipe::MakePacket<SttServableMap>(this->sidePacketMaps.sttServableMap).At(STARTING_TIMESTAMP);
     }
     {
diff --git a/src/sse_utils.hpp b/src/sse_utils.hpp
new file mode 100644
index 0000000000..0aba0348b4
--- /dev/null
+++ b/src/sse_utils.hpp
@@ -0,0 +1,29 @@
+//*****************************************************************************
+// Copyright 2025 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#pragma once
+
+#include <sstream>
+#include <string>
+
+namespace ovms {
+
+inline std::string wrapTextInServerSideEventMessage(const std::string& text) {
+    std::stringstream ss;
+    ss << "data: " << text << "\n\n";
+    return ss.str();
+}
+
+}  // namespace ovms
diff --git a/src/test/audio/graph_stt.pbtxt b/src/test/audio/graph_stt.pbtxt
index b38886db6b..df91efcca4 100644
--- a/src/test/audio/graph_stt.pbtxt
+++ b/src/test/audio/graph_stt.pbtxt
@@ -19,8 +19,14 @@ node {
   name: "S2tExecutor"
   input_side_packet: "STT_NODE_RESOURCES:s2t_servable"
   calculator: "S2tCalculator"
+  input_stream: "LOOPBACK:loopback"
   input_stream: "HTTP_REQUEST_PAYLOAD:input"
+  output_stream: "LOOPBACK:loopback"
   output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+  input_stream_info: {
+    tag_index: 'LOOPBACK:0',
+    back_edge: true
+  }
   node_options: {
       [type.googleapis.com / mediapipe.S2tCalculatorOptions]: {
          models_path: "/ovms/src/test/llm_testing/openai/whisper-tiny",
@@ -28,4 +34,14 @@
          target_device: "CPU"
       }
   }
+  input_stream_handler {
+    input_stream_handler: "SyncSetInputStreamHandler",
+    options {
+      [mediapipe.SyncSetInputStreamHandlerOptions.ext] {
+        sync_set {
+          tag_index: "LOOPBACK:0"
+        }
+      }
+    }
+  }
 }
\ No newline at end of file
diff --git a/src/test/audio/graph_stt_word_timestamps.pbtxt b/src/test/audio/graph_stt_word_timestamps.pbtxt
index 732d62b352..3c830b06ba 100644
--- a/src/test/audio/graph_stt_word_timestamps.pbtxt
+++ b/src/test/audio/graph_stt_word_timestamps.pbtxt
@@ -19,8 +19,14 @@ node {
   name: "S2tExecutor"
   input_side_packet: "STT_NODE_RESOURCES:s2t_servable"
   calculator: "S2tCalculator"
+  input_stream: "LOOPBACK:loopback"
   input_stream: "HTTP_REQUEST_PAYLOAD:input"
+  output_stream: "LOOPBACK:loopback"
   output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+  input_stream_info: {
+    tag_index: 'LOOPBACK:0',
+    back_edge: true
+  }
   node_options: {
       [type.googleapis.com / mediapipe.S2tCalculatorOptions]: {
          models_path: "/ovms/src/test/llm_testing/openai/whisper-tiny",
@@ -29,4 +35,14 @@
          enable_word_timestamps: true
       }
   }
+  input_stream_handler {
+    input_stream_handler: "SyncSetInputStreamHandler",
+    options {
+      [mediapipe.SyncSetInputStreamHandlerOptions.ext] {
+        sync_set {
+          tag_index: "LOOPBACK:0"
+        }
+      }
+    }
+  }
 }
\ No newline at end of file
diff --git a/src/test/audio/s2t_streaming_test.cpp b/src/test/audio/s2t_streaming_test.cpp
new file mode 100644
index 0000000000..2366d96cb6
--- /dev/null
+++ b/src/test/audio/s2t_streaming_test.cpp
diff --git a/src/test/audio/s2t_streaming_test.cpp b/src/test/audio/s2t_streaming_test.cpp
new file mode 100644
index 0000000000..2366d96cb6
--- /dev/null
+++ b/src/test/audio/s2t_streaming_test.cpp
@@ -0,0 +1,200 @@
+//*****************************************************************************
+// Copyright 2026 Intel Corporation
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//*****************************************************************************
+#include <chrono>
+#include <future>
+#include <string>
+#include <thread>
+
+#include <gtest/gtest.h>
+
+#include "../../audio/speech_to_text/streaming_text_queue.hpp"
+#include "../../audio/speech_to_text/s2t_streaming_handler.hpp"
+#include "../../sse_utils.hpp"
+
+using ovms::StreamingTextQueue;
+
+// ====================== StreamingTextQueue Tests ======================
+
+TEST(StreamingTextQueueTest, PushAndPop) {
+    StreamingTextQueue queue;
+    queue.push("hello");
+    std::string out;
+    EXPECT_TRUE(queue.waitAndPop(out));
+    EXPECT_EQ(out, "hello");
+}
+
+TEST(StreamingTextQueueTest, FIFOOrder) {
+    StreamingTextQueue queue;
+    queue.push("first");
+    queue.push("second");
+    queue.push("third");
+    std::string out;
+    EXPECT_TRUE(queue.waitAndPop(out));
+    EXPECT_EQ(out, "first");
+    EXPECT_TRUE(queue.waitAndPop(out));
+    EXPECT_EQ(out, "second");
+    EXPECT_TRUE(queue.waitAndPop(out));
+    EXPECT_EQ(out, "third");
+}
+
+TEST(StreamingTextQueueTest, DoneWithEmptyQueue) {
+    StreamingTextQueue queue;
+    queue.endStreaming();
+    std::string out;
+    EXPECT_FALSE(queue.waitAndPop(out));
+}
+
+TEST(StreamingTextQueueTest, DoneAfterAllPopped) {
+    StreamingTextQueue queue;
+    queue.push("data");
+    queue.endStreaming();
+    std::string out;
+    EXPECT_TRUE(queue.waitAndPop(out));
+    EXPECT_EQ(out, "data");
+    EXPECT_FALSE(queue.waitAndPop(out));
+}
+
+TEST(StreamingTextQueueTest, WaitAndPopBlocksUntilPush) {
+    StreamingTextQueue queue;
+    std::string result;
+    auto future = std::async(std::launch::async, [&queue, &result]() {
+        return queue.waitAndPop(result);
+    });
+    // Give the consumer thread time to block
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+    queue.push("delayed");
+    EXPECT_TRUE(future.get());
+    EXPECT_EQ(result, "delayed");
+}
+
+TEST(StreamingTextQueueTest, WaitAndPopUnblocksOnDone) {
+    StreamingTextQueue queue;
+    std::string result;
+    auto future = std::async(std::launch::async, [&queue, &result]() {
+        return queue.waitAndPop(result);
+    });
+    std::this_thread::sleep_for(std::chrono::milliseconds(50));
+    queue.endStreaming();
+    EXPECT_FALSE(future.get());
+}
+
+TEST(StreamingTextQueueTest, ConcurrentProducerConsumer) {
+    StreamingTextQueue queue;
+    const int numItems = 100;
+    std::vector<std::string> received;
+
+    auto producer = std::async(std::launch::async, [&queue, numItems]() {
+        for (int i = 0; i < numItems; ++i) {
+            queue.push(std::to_string(i));
+        }
+        queue.endStreaming();
+    });
+
+    auto consumer = std::async(std::launch::async, [&queue, &received]() {
+        std::string out;
+        while (queue.waitAndPop(out)) {
+            received.push_back(out);
+        }
+    });
+
+    producer.get();
+    consumer.get();
+
+    ASSERT_EQ(static_cast<int>(received.size()), numItems);
+    for (int i = 0; i < numItems; ++i) {
+        EXPECT_EQ(received[i], std::to_string(i));
+    }
+}
+
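+// These tests pin down the StreamingTextQueue contract: waitAndPop() blocks
+// until a push() or endStreaming() arrives, returns true with the oldest
+// element while items remain, and returns false only once the queue is both
+// drained and marked done.
+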
+TEST(StreamingTextQueueTest, EmptyStringPush) {
+    StreamingTextQueue queue;
+    queue.push("");
+    queue.endStreaming();
+    std::string out;
+    EXPECT_TRUE(queue.waitAndPop(out));
+    EXPECT_EQ(out, "");
+    EXPECT_FALSE(queue.waitAndPop(out));
+}
+
+// ====================== SSE Utils Tests ======================
+
+TEST(SseUtilsTest, WrapSimpleMessage) {
+    std::string result = ovms::wrapTextInServerSideEventMessage("hello");
+    EXPECT_EQ(result, "data: hello\n\n");
+}
+
+TEST(SseUtilsTest, WrapJsonMessage) {
+    std::string result = ovms::wrapTextInServerSideEventMessage("{\"text\":\"hi\"}");
+    EXPECT_EQ(result, "data: {\"text\":\"hi\"}\n\n");
+}
+
+TEST(SseUtilsTest, WrapDoneMarker) {
+    std::string result = ovms::wrapTextInServerSideEventMessage("[DONE]");
+    EXPECT_EQ(result, "data: [DONE]\n\n");
+}
+
+TEST(SseUtilsTest, WrapEmptyMessage) {
+    std::string result = ovms::wrapTextInServerSideEventMessage("");
+    EXPECT_EQ(result, "data: \n\n");
+}
+
+// ====================== S2tStreamingHandler event serialization Tests ======================
+
+TEST(S2tStreamingHandlerTest, SerializeDeltaEventSimple) {
+    std::string result = mediapipe::S2tStreamingHandler::serializeDeltaEvent("hello world");
+    EXPECT_EQ(result, "{\"type\":\"transcript.text.delta\",\"delta\":\"hello world\",\"logprobs\":[]}");
+}
+
+TEST(S2tStreamingHandlerTest, SerializeDeltaEventEmpty) {
+    std::string result = mediapipe::S2tStreamingHandler::serializeDeltaEvent("");
+    EXPECT_EQ(result, "{\"type\":\"transcript.text.delta\",\"delta\":\"\",\"logprobs\":[]}");
+}
+
+TEST(S2tStreamingHandlerTest, SerializeDeltaEventSpecialCharacters) {
+    std::string result = mediapipe::S2tStreamingHandler::serializeDeltaEvent("say \"hello\" & ");
+    // rapidjson escapes quotes
+    EXPECT_NE(result.find("\"delta\""), std::string::npos);
+    EXPECT_NE(result.find("say \\\"hello\\\""), std::string::npos);
+}
+
+TEST(S2tStreamingHandlerTest, SerializeDeltaEventUnicode) {
+    std::string result = mediapipe::S2tStreamingHandler::serializeDeltaEvent("日本語テスト");
+    EXPECT_NE(result.find("\"delta\""), std::string::npos);
+}
+
+TEST(S2tStreamingHandlerTest, SerializeDoneEventSimple) {
+    std::string result = mediapipe::S2tStreamingHandler::serializeDoneEvent("hello world");
+    EXPECT_EQ(result, "{\"type\":\"transcript.text.done\",\"text\":\"hello world\",\"logprobs\":[]}");
+}
+
+TEST(S2tStreamingHandlerTest, SerializeDoneEventUnicode) {
+    std::string result = mediapipe::S2tStreamingHandler::serializeDoneEvent("日本語テスト");
+    EXPECT_NE(result.find("\"text\""), std::string::npos);
+}
+
+// ====================== Full SSE streaming chunk formatting ======================
+
+TEST(S2tStreamingHandlerTest, FullStreamingChunkFormat) {
+    std::string json = mediapipe::S2tStreamingHandler::serializeDeltaEvent("token");
+    std::string sse = ovms::wrapTextInServerSideEventMessage(json);
+    EXPECT_EQ(sse, "data: {\"type\":\"transcript.text.delta\",\"delta\":\"token\",\"logprobs\":[]}\n\n");
+}
+
+TEST(S2tStreamingHandlerTest, FullDoneEventFormat) {
+    std::string json = mediapipe::S2tStreamingHandler::serializeDoneEvent("all text");
+    std::string sse = ovms::wrapTextInServerSideEventMessage(json);
+    EXPECT_EQ(sse, "data: {\"type\":\"transcript.text.done\",\"text\":\"all text\",\"logprobs\":[]}\n\n");
+}
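streaming_text_queue.cpp itself is not part of this diff, so the queue these tests exercise could plausibly look like this (a sketch only, using the push/waitAndPop/endStreaming names from the test file, not the actual OVMS implementation):

    #include <condition_variable>
    #include <mutex>
    #include <queue>
    #include <string>

    namespace ovms {

    // Sketch of a producer/consumer queue matching the tested contract.
    class StreamingTextQueue {
        std::queue<std::string> items_;
        std::mutex mtx_;
        std::condition_variable cv_;
        bool done_ = false;

    public:
        void push(const std::string& text) {
            { std::lock_guard<std::mutex> lk(mtx_); items_.push(text); }
            cv_.notify_one();
        }
        void endStreaming() {
            { std::lock_guard<std::mutex> lk(mtx_); done_ = true; }
            cv_.notify_all();
        }
        // Returns false only once the queue is drained and streaming has ended.
        bool waitAndPop(std::string& out) {
            std::unique_lock<std::mutex> lk(mtx_);
            cv_.wait(lk, [this] { return !items_.empty() || done_; });
            if (items_.empty())
                return false;
            out = std::move(items_.front());
            items_.pop();
            return true;
        }
    };

    }  // namespace ovms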
"Content-Disposition: form-data;name=\"file\";\"filename=file\"" "\r\nContent-Type: application/octet-stream" "\r\ncontent-transfer-encoding: quoted-printable\r\n\r\n"; - std::unique_ptr imageBytes; + std::unique_ptr audioBytes; size_t fileSize; - readFile(getGenericFullPathForSrcTest("/ovms/src/test/audio/test.wav"), fileSize, imageBytes); - Speech2TextHttpTest::body.append(imageBytes.get(), fileSize); + readFile(getGenericFullPathForSrcTest("/ovms/src/test/audio/test.wav"), fileSize, audioBytes); + Speech2TextHttpTest::body.append(audioBytes.get(), fileSize); Speech2TextHttpTest::body.append("12345"); } void SetUp() { V3HttpTest::SetUp(); + ON_CALL(*writer, IsDisconnected()) + .WillByDefault(::testing::Return(false)); ASSERT_EQ(handler->parseRequestComponents(comp, "POST", endpoint, multipartHeader), ovms::StatusCode::OK); } @@ -75,6 +77,239 @@ std::unique_ptr Speech2TextHttpTest::t; std::string Speech2TextHttpTest::body; std::string Speech2TextHttpTest::modelNameForm; +// ====================== Speech2Text Streaming Tests ====================== + +class Speech2TextStreamingTest : public Speech2TextHttpTest { +protected: + // Builds a multipart body identical to the base fixture body but with an + // extra `stream=true` field appended. + static std::string streamingBody() { + const std::string streamField = "\r\n" + "Content-Disposition: form-data;name=\"stream\"\r\n" + "\r\n" + "true\r\n" + "--12345"; + return Speech2TextHttpTest::body + streamField; + } + + void SetUp() override { + Speech2TextHttpTest::SetUp(); + ON_CALL(*writer, PartialReplyBegin(::testing::_)) + .WillByDefault(testing::Invoke([](std::function fn) { fn(); })); + ON_CALL(*writer, IsDisconnected()) + .WillByDefault(::testing::Return(false)); + } +}; + +TEST_F(Speech2TextStreamingTest, streamingTranscriptionReceivesDeltaAndDoneEvents) { + auto req = drogon::HttpRequest::newHttpRequest(); + req->setMethod(drogon::Post); + req->addHeader("content-type", "multipart/form-data; boundary=\"12345\""); + req->setBody(streamingBody()); + std::shared_ptr multiPartParserWithRequest = std::make_shared(req); + + std::vector receivedChunks; + EXPECT_CALL(*writer, PartialReply(::testing::_)) + .WillRepeatedly([&receivedChunks](std::string chunk) { + receivedChunks.push_back(std::move(chunk)); + }); + EXPECT_CALL(*writer, PartialReplyEnd()).Times(1); + + std::string requestBody; + ASSERT_EQ( + handler->dispatchToProcessor(endpoint, requestBody, &response, comp, responseComponents, writer, multiPartParserWithRequest), + ovms::StatusCode::PARTIAL_END); + + ASSERT_FALSE(receivedChunks.empty()); + + const std::string dataPrefix = "data: "; + // Validate each delta chunk (all but the last) + for (size_t i = 0; i + 1 < receivedChunks.size(); ++i) { + const std::string& chunk = receivedChunks[i]; + ASSERT_GE(chunk.size(), dataPrefix.size()); + ASSERT_EQ(chunk.substr(0, dataPrefix.size()), dataPrefix) << "Chunk " << i << " missing SSE prefix"; + std::string json = chunk.substr(dataPrefix.size()); + // Trim trailing newlines + while (!json.empty() && (json.back() == '\n' || json.back() == '\r')) + json.pop_back(); + rapidjson::Document d; + ASSERT_EQ(d.Parse(json.c_str()).HasParseError(), false) << "Chunk " << i << " is not valid JSON"; + ASSERT_TRUE(d.HasMember("type")) << "Chunk " << i << " missing 'type'"; + EXPECT_STREQ(d["type"].GetString(), "transcript.text.delta") << "Chunk " << i; + ASSERT_TRUE(d.HasMember("delta")) << "Chunk " << i << " missing 'delta'"; + EXPECT_TRUE(d["delta"].IsString()) << "Chunk " << i << " 'delta' is not a 
string"; + ASSERT_TRUE(d.HasMember("logprobs")) << "Chunk " << i << " missing 'logprobs'"; + EXPECT_TRUE(d["logprobs"].IsArray()) << "Chunk " << i << " 'logprobs' is not an array"; + } + + // Validate the final done event + const std::string& lastChunk = receivedChunks.back(); + ASSERT_GE(lastChunk.size(), dataPrefix.size()); + EXPECT_EQ(lastChunk.substr(0, dataPrefix.size()), dataPrefix); + std::string lastJson = lastChunk.substr(dataPrefix.size()); + while (!lastJson.empty() && (lastJson.back() == '\n' || lastJson.back() == '\r')) + lastJson.pop_back(); + rapidjson::Document doneDoc; + ASSERT_EQ(doneDoc.Parse(lastJson.c_str()).HasParseError(), false) << "Done event is not valid JSON"; + ASSERT_TRUE(doneDoc.HasMember("type")); + EXPECT_STREQ(doneDoc["type"].GetString(), "transcript.text.done"); + ASSERT_TRUE(doneDoc.HasMember("text")); + EXPECT_TRUE(doneDoc["text"].IsString()); + EXPECT_FALSE(std::string(doneDoc["text"].GetString()).empty()); + ASSERT_TRUE(doneDoc.HasMember("logprobs")); + EXPECT_TRUE(doneDoc["logprobs"].IsArray()); +} + +TEST_F(Speech2TextStreamingTest, streamingTranscriptionDoneTextMatchesConcatenatedDeltas) { + auto req = drogon::HttpRequest::newHttpRequest(); + req->setMethod(drogon::Post); + req->addHeader("content-type", "multipart/form-data; boundary=\"12345\""); + req->setBody(streamingBody()); + std::shared_ptr multiPartParserWithRequest = std::make_shared(req); + + std::vector receivedChunks; + ON_CALL(*writer, PartialReply(::testing::_)) + .WillByDefault([&receivedChunks](std::string chunk) { + receivedChunks.push_back(std::move(chunk)); + }); + EXPECT_CALL(*writer, PartialReplyEnd()).Times(1); + + std::string requestBody; + ASSERT_EQ( + handler->dispatchToProcessor(endpoint, requestBody, &response, comp, responseComponents, writer, multiPartParserWithRequest), + ovms::StatusCode::PARTIAL_END); + + ASSERT_GE(receivedChunks.size(), 1u); + + const std::string dataPrefix = "data: "; + auto parseChunkJson = [&dataPrefix](const std::string& chunk) { + std::string json = chunk.substr(dataPrefix.size()); + while (!json.empty() && (json.back() == '\n' || json.back() == '\r')) + json.pop_back(); + return json; + }; + + // Collect all but the last delta text + std::string concatenatedDeltas; + for (size_t i = 0; i + 1 < receivedChunks.size(); ++i) { + rapidjson::Document d; + d.Parse(parseChunkJson(receivedChunks[i]).c_str()); + if (d.HasMember("delta") && d["delta"].IsString()) { + concatenatedDeltas += d["delta"].GetString(); + } + } + + // Get done event text + rapidjson::Document doneDoc; + doneDoc.Parse(parseChunkJson(receivedChunks.back()).c_str()); + ASSERT_EQ(doneDoc.HasParseError(), false); + ASSERT_TRUE(doneDoc.HasMember("text")); + const std::string doneText = doneDoc["text"].GetString(); + + EXPECT_EQ(concatenatedDeltas, doneText) + << "Concatenated deltas should equal the final 'done' text"; +} + +TEST_F(Speech2TextStreamingTest, streamingTranscriptionWithLanguage) { + auto req = drogon::HttpRequest::newHttpRequest(); + req->setMethod(drogon::Post); + req->addHeader("content-type", "multipart/form-data; boundary=\"12345\""); + const std::string streamAndLanguage = "\r\n" + "Content-Disposition: form-data;name=\"stream\"\r\n" + "\r\n" + "true\r\n" + "--12345\r\n" + "Content-Disposition: form-data;name=\"language\"\r\n" + "\r\n" + "en\r\n" + "--12345"; + req->setBody(Speech2TextHttpTest::body + streamAndLanguage); + std::shared_ptr multiPartParserWithRequest = std::make_shared(req); + + std::vector receivedChunks; + ON_CALL(*writer, PartialReply(::testing::_)) 
+    ON_CALL(*writer, PartialReply(::testing::_))
+        .WillByDefault([&receivedChunks](std::string chunk) {
+            receivedChunks.push_back(std::move(chunk));
+        });
+    EXPECT_CALL(*writer, PartialReplyEnd()).Times(1);
+
+    std::string requestBody;
+    ASSERT_EQ(
+        handler->dispatchToProcessor(endpoint, requestBody, &response, comp, responseComponents, writer, multiPartParserWithRequest),
+        ovms::StatusCode::PARTIAL_END);
+
+    ASSERT_FALSE(receivedChunks.empty());
+    // Verify last chunk is a done event
+    const std::string dataPrefix = "data: ";
+    std::string lastJson = receivedChunks.back().substr(dataPrefix.size());
+    while (!lastJson.empty() && (lastJson.back() == '\n' || lastJson.back() == '\r'))
+        lastJson.pop_back();
+    rapidjson::Document doneDoc;
+    ASSERT_EQ(doneDoc.Parse(lastJson.c_str()).HasParseError(), false);
+    ASSERT_TRUE(doneDoc.HasMember("type"));
+    EXPECT_STREQ(doneDoc["type"].GetString(), "transcript.text.done");
+}
+
+TEST_F(Speech2TextStreamingTest, streamingTranscriptionInvalidFileReturnsError) {
+    const std::string invalidBody = Speech2TextHttpTest::modelNameForm +
+        "--12345\r\n"
+        "Content-Disposition: form-data;name=\"stream\"\r\n"
+        "\r\n"
+        "true\r\n"
+        "--12345\r\n"
+        "Content-Disposition: form-data;name=\"file\";\"filename=file\""
+        "\r\nContent-Type: application/octet-stream"
+        "\r\ncontent-transfer-encoding: quoted-printable\r\n\r\n"
+        "INVALID_AUDIO12345";
+
+    auto req = drogon::HttpRequest::newHttpRequest();
+    req->setMethod(drogon::Post);
+    req->addHeader("content-type", "multipart/form-data; boundary=\"12345\"");
+    req->setBody(invalidBody);
+    std::shared_ptr<ovms::MultiPartParser> multiPartParserWithRequest = std::make_shared<ovms::DrogonMultiPartParser>(req);
+
+    EXPECT_CALL(*writer, PartialReplyWithStatus(::testing::_, ::testing::_))
+        .WillOnce([](std::string responseBody, ovms::HTTPStatusCode code) {
+            EXPECT_EQ(code, ovms::HTTPStatusCode::BAD_REQUEST);
+            rapidjson::Document d;
+            ASSERT_EQ(d.Parse(responseBody.c_str()).HasParseError(), false);
+            ASSERT_TRUE(d.HasMember("error"));
+        });
+    EXPECT_CALL(*writer, PartialReplyEnd()).Times(1);
+
+    std::string requestBody;
+    ASSERT_EQ(
+        handler->dispatchToProcessor(endpoint, requestBody, &response, comp, responseComponents, writer, multiPartParserWithRequest),
+        ovms::StatusCode::PARTIAL_END);
+}
+
+TEST_F(Speech2TextStreamingTest, streamingTranslationIsNotSupported) {
+    const std::string translationEndpoint = "/v3/audio/translations";
+    ASSERT_EQ(handler->parseRequestComponents(comp, "POST", translationEndpoint, multipartHeader), ovms::StatusCode::OK);
+
+    auto req = drogon::HttpRequest::newHttpRequest();
+    req->setMethod(drogon::Post);
+    req->addHeader("content-type", "multipart/form-data; boundary=\"12345\"");
+    req->setBody(streamingBody());
+    std::shared_ptr<ovms::MultiPartParser> multiPartParserWithRequest = std::make_shared<ovms::DrogonMultiPartParser>(req);
+
+    EXPECT_CALL(*writer, PartialReplyWithStatus(::testing::_, ::testing::_))
+        .WillOnce([](std::string responseBody, ovms::HTTPStatusCode code) {
+            EXPECT_EQ(code, ovms::HTTPStatusCode::BAD_REQUEST);
+            rapidjson::Document d;
+            ASSERT_EQ(d.Parse(responseBody.c_str()).HasParseError(), false);
+            ASSERT_TRUE(d.HasMember("error"));
+            ASSERT_TRUE(d["error"].IsString());
+            EXPECT_NE(std::string(d["error"].GetString()).find("streaming is not supported for translations endpoint"), std::string::npos);
+        });
+    EXPECT_CALL(*writer, PartialReplyEnd()).Times(1);
+
+    std::string requestBody;
+    ASSERT_EQ(
+        handler->dispatchToProcessor(translationEndpoint, requestBody, &response, comp, responseComponents, writer, multiPartParserWithRequest),
+        ovms::StatusCode::PARTIAL_END);
+}
+
 TEST_F(Speech2TextHttpTest, simplePositive) {
     auto req = drogon::HttpRequest::newHttpRequest();
     req->setMethod(drogon::Post);
@@ -179,10 +414,10 @@ TEST_F(Speech2TextHttpTest, positiveWordTimestamps) {
         "Content-Disposition: form-data;name=\"file\";\"filename=file\""
         "\r\nContent-Type: application/octet-stream"
         "\r\ncontent-transfer-encoding: quoted-printable\r\n\r\n";
-    std::unique_ptr<char[]> imageBytes;
+    std::unique_ptr<char[]> audioBytes;
     size_t fileSize;
-    readFile(getGenericFullPathForSrcTest("/ovms/src/test/audio/test.wav"), fileSize, imageBytes);
-    multipartBody.append(imageBytes.get(), fileSize);
+    readFile(getGenericFullPathForSrcTest("/ovms/src/test/audio/test.wav"), fileSize, audioBytes);
+    multipartBody.append(audioBytes.get(), fileSize);
     multipartBody.append("12345");
     req->setBody(multipartBody);
     std::shared_ptr<ovms::MultiPartParser> multiPartParserWithRequest = std::make_shared<ovms::DrogonMultiPartParser>(req);
@@ -219,10 +454,10 @@ TEST_F(Speech2TextHttpTest, positiveBothTimestampsTypes) {
         "Content-Disposition: form-data;name=\"file\";\"filename=file\""
         "\r\nContent-Type: application/octet-stream"
        "\r\ncontent-transfer-encoding: quoted-printable\r\n\r\n";
-    std::unique_ptr<char[]> imageBytes;
+    std::unique_ptr<char[]> audioBytes;
     size_t fileSize;
-    readFile(getGenericFullPathForSrcTest("/ovms/src/test/audio/test.wav"), fileSize, imageBytes);
-    multipartBody.append(imageBytes.get(), fileSize);
+    readFile(getGenericFullPathForSrcTest("/ovms/src/test/audio/test.wav"), fileSize, audioBytes);
+    multipartBody.append(audioBytes.get(), fileSize);
     multipartBody.append("12345");
     req->setBody(multipartBody);
     std::shared_ptr<ovms::MultiPartParser> multiPartParserWithRequest = std::make_shared<ovms::DrogonMultiPartParser>(req);
@@ -300,16 +535,16 @@ TEST_F(Speech2TextHttpTest, invalidLanguageTooLong) {
     EXPECT_EQ(status.string(), expectedMsg);
 }
 
-TEST_F(Speech2TextHttpTest, invalidTemperatureOutOfRange) {
+TEST_F(Speech2TextHttpTest, invalidTemperatureType) {
     auto req = drogon::HttpRequest::newHttpRequest();
     req->setMethod(drogon::Post);
     req->addHeader("content-type", "multipart/form-data; boundary=\"12345\"");
-    std::string language = "\r\n"
-        "Content-Disposition: form-data;name=\"temperature\"\r\n"
-        "\r\n"
-        "10.0\r\n"
-        "--12345";
-    req->setBody(Speech2TextHttpTest::body + language);
+    std::string temperature = "\r\n"
+        "Content-Disposition: form-data;name=\"temperature\"\r\n"
+        "\r\n"
+        "INVALID\r\n"
+        "--12345";
+    req->setBody(Speech2TextHttpTest::body + temperature);
     std::shared_ptr<ovms::MultiPartParser> multiPartParserWithRequest = std::make_shared<ovms::DrogonMultiPartParser>(req);
     std::string requestBody = "";
     auto status = handler->dispatchToProcessor(endpoint, requestBody, &response, comp, responseComponents, writer, multiPartParserWithRequest);
@@ -317,7 +552,7 @@
         status.getCode(), ovms::StatusCode::MEDIAPIPE_EXECUTION_ERROR);
     std::string expectedMsg = "Mediapipe execution failed. MP status - INVALID_ARGUMENT: CalculatorGraph::Run() failed: \n"
-        "Calculator::Process() for node \"S2tExecutor\" failed: Temperature out of range(0.0, 2.0)";
+        "Calculator::Process() for node \"S2tExecutor\" failed: Invalid temperature type.";
     EXPECT_EQ(status.string(), expectedMsg);
 }
 
diff --git a/src/test/graph_export_test.cpp b/src/test/graph_export_test.cpp
index 468d2e7917..70961e3493 100644
--- a/src/test/graph_export_test.cpp
+++ b/src/test/graph_export_test.cpp
@@ -422,8 +422,14 @@ node {
     name: "myModel"
     calculator: "S2tCalculator"
     input_side_packet: "STT_NODE_RESOURCES:s2t_servable"
+    input_stream: "LOOPBACK:loopback"
     input_stream: "HTTP_REQUEST_PAYLOAD:input"
+    output_stream: "LOOPBACK:loopback"
     output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+    input_stream_info: {
+        tag_index: 'LOOPBACK:0',
+        back_edge: true
+    }
     node_options: {
         [type.googleapis.com / mediapipe.S2tCalculatorOptions]: {
             models_path: "/model1/path"
@@ -431,6 +437,16 @@ node {
             plugin_config: '{"NUM_STREAMS":"2"}'
         }
     }
+    input_stream_handler {
+        input_stream_handler: "SyncSetInputStreamHandler",
+        options {
+            [mediapipe.SyncSetInputStreamHandlerOptions.ext] {
+                sync_set {
+                    tag_index: "LOOPBACK:0"
+                }
+            }
+        }
+    }
 }
 )";
 
@@ -441,14 +457,30 @@
 node {
     name: ""
     calculator: "S2tCalculator"
     input_side_packet: "STT_NODE_RESOURCES:s2t_servable"
+    input_stream: "LOOPBACK:loopback"
     input_stream: "HTTP_REQUEST_PAYLOAD:input"
+    output_stream: "LOOPBACK:loopback"
     output_stream: "HTTP_RESPONSE_PAYLOAD:output"
+    input_stream_info: {
+        tag_index: 'LOOPBACK:0',
+        back_edge: true
+    }
     node_options: {
         [type.googleapis.com / mediapipe.S2tCalculatorOptions]: {
             models_path: "./"
            target_device: "CPU"
         }
     }
+    input_stream_handler {
+        input_stream_handler: "SyncSetInputStreamHandler",
+        options {
+            [mediapipe.SyncSetInputStreamHandlerOptions.ext] {
+                sync_set {
+                    tag_index: "LOOPBACK:0"
+                }
+            }
+        }
+    }
 }
 )";
diff --git a/src/test/multipart_calculator_test.cpp b/src/test/multipart_calculator_test.cpp
index e18eae2cbc..6c2135c4a0 100644
--- a/src/test/multipart_calculator_test.cpp
+++ b/src/test/multipart_calculator_test.cpp
@@ -99,6 +99,7 @@ It has two lines.
     EXPECT_CALL(*multiPartParser, parse()).WillOnce(::testing::Return(true));
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("model"))).WillOnce(::testing::Return("multipart"));
+    EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("stream"))).WillOnce(::testing::Return(""));
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("email"))).WillOnce(::testing::Return("john@example.com"));
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("username"))).WillOnce(::testing::Return("john_doe"));
     EXPECT_CALL(*multiPartParser, getArrayFieldByName(::testing::Eq("some_param[]"))).WillOnce(::testing::Return(std::vector<std::string>{"val1", "val2"}));
@@ -151,6 +152,7 @@ It has two lines.
     EXPECT_CALL(*multiPartParser, parse()).WillOnce(::testing::Return(true));
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("model"))).WillOnce(::testing::Return(""));
+    EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("stream"))).WillOnce(::testing::Return(""));
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("email"))).WillOnce(::testing::Return("john@example.com"));
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("username"))).WillOnce(::testing::Return("john_doe"));
     EXPECT_CALL(*multiPartParser, getArrayFieldByName(::testing::Eq("some_param[]"))).WillOnce(::testing::Return(std::vector<std::string>{"val1", "val2"}));
@@ -196,8 +198,8 @@ It has two lines.
 ------WebKitFormBoundary7MA4YWxkTrZu0gW--)";
     EXPECT_CALL(*multiPartParser, parse()).WillOnce(::testing::Return(true));
-    EXPECT_CALL(*multiPartParser, getFieldByName(::testing::_)).Times(0);
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("model"))).WillOnce(::testing::Return(""));
+    EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("stream"))).WillOnce(::testing::Return(""));
     EXPECT_CALL(*multiPartParser, getFileContentByFieldName(::testing::_)).Times(0);
 
     // Default routing uses everything that comes after /v3/ as graph name
@@ -232,8 +234,8 @@ It has two lines.
 ------WebKitFormBoundary7MA4YWxkTrZu0gW--)";
     EXPECT_CALL(*multiPartParser, parse()).WillOnce(::testing::Return(true));
-    EXPECT_CALL(*multiPartParser, getFieldByName(::testing::_)).Times(0);
     EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("model"))).WillOnce(::testing::Return(""));
+    EXPECT_CALL(*multiPartParser, getFieldByName(::testing::Eq("stream"))).WillOnce(::testing::Return(""));
     EXPECT_CALL(*multiPartParser, getFileContentByFieldName(::testing::_)).Times(0);
 
     // Default routing uses everything that comes after /v3/ as graph name
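The new getFieldByName("stream") expectations show that every multipart request is now probed for a stream field, with an empty string meaning the field is absent. A hedged sketch of the resulting routing decision (the helper name and exact semantics are assumptions for illustration, not OVMS code):

    #include <string>

    // Hypothetical helper mirroring the mocked lookups above: an absent field
    // surfaces as an empty string and routes to the unary (single JSON
    // response) path; only an explicit "true" selects SSE streaming.
    static bool isStreamRequested(const std::string& streamFieldValue) {
        return streamFieldValue == "true";
    }

    // e.g. bool streaming = isStreamRequested(parser.getFieldByName("stream"));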