Commit 2f930ed

committed: "save"
1 parent 3721f50 commit 2f930ed

File tree

6 files changed: +80 -18 lines

src/llm/apis/openai_completions.cpp

Lines changed: 67 additions & 0 deletions
@@ -17,6 +17,7 @@
 #include "openai_completions.hpp"
 
 #include <cmath>
+#include <limits>
 #include <memory>
 #include "src/port/rapidjson_stringbuffer.hpp"
 #include "src/port/rapidjson_writer.hpp"
@@ -44,6 +45,51 @@ namespace ovms {
 
 constexpr size_t DEFAULT_MAX_STOP_WORDS = 16; // same as deep-seek
 
+namespace {
+
+ov::genai::JsonContainer rapidJsonValueToJsonContainer(const rapidjson::Value& value) {
+    if (value.IsNull()) {
+        return ov::genai::JsonContainer(nullptr);
+    }
+    if (value.IsBool()) {
+        return ov::genai::JsonContainer(value.GetBool());
+    }
+    if (value.IsInt64()) {
+        return ov::genai::JsonContainer(value.GetInt64());
+    }
+    if (value.IsUint64()) {
+        auto uintValue = value.GetUint64();
+        if (uintValue <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) {
+            return ov::genai::JsonContainer(static_cast<int64_t>(uintValue));
+        }
+        return ov::genai::JsonContainer(static_cast<double>(uintValue));
+    }
+    if (value.IsDouble()) {
+        return ov::genai::JsonContainer(value.GetDouble());
+    }
+    if (value.IsString()) {
+        return ov::genai::JsonContainer(std::string(value.GetString(), value.GetStringLength()));
+    }
+    if (value.IsArray()) {
+        ov::genai::JsonContainer arrayContainer = ov::genai::JsonContainer::array();
+        for (const auto& item : value.GetArray()) {
+            arrayContainer.push_back(rapidJsonValueToJsonContainer(item));
+        }
+        return arrayContainer;
+    }
+    if (value.IsObject()) {
+        ov::genai::JsonContainer objectContainer = ov::genai::JsonContainer::object();
+        for (auto member = value.MemberBegin(); member != value.MemberEnd(); ++member) {
+            const std::string key(member->name.GetString(), member->name.GetStringLength());
+            objectContainer[key] = rapidJsonValueToJsonContainer(member->value);
+        }
+        return objectContainer;
+    }
+    throw std::invalid_argument("Unsupported JSON value type");
+}
+
+}  // namespace
+
 absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() {
     // prompt: string
     auto it = doc.FindMember("prompt");
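Editor's note: a minimal, hypothetical driver for the helper above. The commit keeps it in an anonymous namespace, so this sketch assumes it has been made visible to the calling code. It illustrates the integer-width handling: rapidjson parses 18446744073709551615 (UINT64_MAX) as uint64, which exceeds int64_t's range, so the helper widens it to double.

#include <rapidjson/document.h>

#include <openvino/genai/json_container.hpp>

int main() {
    rapidjson::Document doc;
    // "big" is UINT64_MAX: IsInt64() is false, IsUint64() is true.
    doc.Parse(R"({"retries": 3, "big": 18446744073709551615, "tags": ["weather"]})");
    // 3 stays an int64_t; "big" is widened to double; the array recurses element-wise.
    ov::genai::JsonContainer container = rapidJsonValueToJsonContainer(doc);
    return 0;
}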
@@ -430,6 +476,23 @@ absl::Status OpenAIChatCompletionsHandler::parseTools() {
     }
 
     request.toolChoice = tool_choice;
+    request.tools = std::nullopt;
+    if (it != doc.MemberEnd() && !it->value.IsNull()) {
+        try {
+            request.tools = rapidJsonValueToJsonContainer(it->value);
+        } catch (const std::exception& e) {
+            SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Direct tools conversion to JsonContainer failed: {}. Falling back to JSON string conversion.", e.what());
+            try {
+                rapidjson::StringBuffer toolsBuffer;
+                rapidjson::Writer<rapidjson::StringBuffer> toolsWriter(toolsBuffer);
+                it->value.Accept(toolsWriter);
+                request.tools = ov::genai::JsonContainer::from_json_string(toolsBuffer.GetString());
+            } catch (const std::exception& fallbackEx) {
+                SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Fallback tools conversion failed: {}", fallbackEx.what());
+                return absl::InvalidArgumentError(absl::StrCat("Invalid tools payload: ", fallbackEx.what()));
+            }
+        }
+    }
     if (jsonChanged) {
         StringBuffer buffer;
         Writer<StringBuffer> writer(buffer);
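Editor's note: for reference, a tools payload that takes the direct conversion path above. It is the same get_weather definition this commit deletes from the VLM servable further down; from_json_string is used here only as a compact way to build the container the parser would produce.

#include <openvino/genai/json_container.hpp>

int main() {
    // One OpenAI-style function tool; mirrors the hardcoded example removed below.
    ov::genai::JsonContainer tools = ov::genai::JsonContainer::from_json_string(R"([
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get current weather by city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"]
                }
            }
        }
    ])");
    return 0;
}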
@@ -466,6 +529,10 @@ std::optional<std::string> OpenAIChatCompletionsHandler::getResponseFormat() con
     return request.responseFormat;
 }
 
+const std::optional<ov::genai::JsonContainer>& OpenAIChatCompletionsHandler::getTools() const {
+    return request.tools;
+}
+
 std::string convertOpenAIResponseFormatToStructuralTagStringFormat(const rapidjson::Value& openAIFormat) {
     // Build the new object: {"type": "structural_tag", "format": <openAIFormat>}
     // If response_format has {"json_schema": {"schema": {...}}}, flatten it to {"json_schema": {...}}

src/llm/apis/openai_completions.hpp

Lines changed: 1 addition & 0 deletions
@@ -102,6 +102,7 @@ class OpenAIChatCompletionsHandler {
     ov::genai::ChatHistory& getChatHistory();
     std::optional<int> getMaxTokens() const;
    std::optional<std::string> getResponseFormat() const;
+    const std::optional<ov::genai::JsonContainer>& getTools() const;
 
     bool isStream() const;
     std::string getModel() const;

src/llm/apis/openai_request.hpp

Lines changed: 3 additions & 0 deletions
@@ -25,6 +25,7 @@
 #include <vector>
 
 #include <openvino/runtime/tensor.hpp>
+#include <openvino/genai/json_container.hpp>
 #include <openvino/genai/tokenizer.hpp>
 
 #include "src/port/rapidjson_document.hpp"
@@ -78,6 +79,8 @@ struct OpenAIChatCompletionsRequest {
     std::optional<std::string> responseFormat{std::nullopt};
     // Map that holds tool names and schemas for their arguments
     ToolsSchemas_t toolNameSchemaMap;
+    // Full tools payload in JSON form for passing directly to tokenizer chat template.
+    std::optional<ov::genai::JsonContainer> tools{std::nullopt};
     // Holds value for tool_choice field as described in https://platform.openai.com/docs/api-reference/chat/create#chat_create-tool_choice
     std::string toolChoice;

src/llm/servable.cpp

Lines changed: 1 addition & 0 deletions
@@ -156,6 +156,7 @@ absl::Status GenAiServable::parseRequest(std::shared_ptr<GenAiServableExecutionC
     return absl::OkStatus();
 }
 
+// Continuous batching LLM
 absl::Status GenAiServable::prepareInputs(std::shared_ptr<GenAiServableExecutionContext>& executionContext) {
     if (executionContext->apiHandler == nullptr) {
         return absl::Status(absl::StatusCode::kInvalidArgument, "API handler is not initialized");

src/llm/visual_language_model/continuous_batching/servable.cpp

Lines changed: 7 additions & 18 deletions
@@ -62,6 +62,7 @@ std::shared_ptr<GenAiServableProperties> VisualLanguageModelServable::getPropert
     return properties;
 }
 
+// Continuous Batching VLM
 absl::Status VisualLanguageModelServable::prepareInputs(std::shared_ptr<GenAiServableExecutionContext>& executionContext) {
     auto vlmExecutionContext = std::static_pointer_cast<VisualLanguageModelServableExecutionContext>(executionContext);
     if (vlmExecutionContext->apiHandler == nullptr) {
@@ -93,24 +94,12 @@ absl::Status VisualLanguageModelServable::prepareInputs(std::shared_ptr<GenAiSer
         }
 
         constexpr bool add_generation_prompt = true; // confirm it should be hardcoded
-        ov::genai::JsonContainer tools = ov::genai::JsonContainer::from_json_string(R"([
-            {
-                "type": "function",
-                "function": {
-                    "name": "get_weather",
-                    "description": "Get current weather by city",
-                    "parameters": {
-                        "type": "object",
-                        "properties": {
-                            "city": {"type": "string"}
-                        },
-                        "required": ["city"]
-                    }
-                }
-            }
-        ])");
-        vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools);
-        //vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {});
+        const auto& tools = vlmExecutionContext->apiHandler->getTools();
+        if (tools.has_value()) {
+            vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools);
+        } else {
+            vlmExecutionContext->inputText = properties->tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {});
+        }
     } else {
         return absl::InvalidArgumentError("Unsupported endpoint");
     }
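Editor's note: a sketch of the call pattern this change enables, assuming a tokenizer directory "model/" with a chat template, and that ov::genai::ChatHistory is a vector of string-to-string maps; the four-argument apply_chat_template overload is the one used in the hunk above.

#include <string>

#include <openvino/genai/json_container.hpp>
#include <openvino/genai/tokenizer.hpp>

int main() {
    ov::genai::Tokenizer tokenizer("model/");  // assumed local tokenizer directory

    ov::genai::ChatHistory chatHistory;
    chatHistory.push_back({{"role", "user"}, {"content", "What is the weather in Paris?"}});

    auto tools = ov::genai::JsonContainer::from_json_string(
        R"([{"type": "function", "function": {"name": "get_weather"}}])");

    constexpr bool add_generation_prompt = true;
    // Empty third argument keeps the tokenizer's built-in chat template, as in the diff.
    std::string prompt = tokenizer.apply_chat_template(chatHistory, add_generation_prompt, {}, tools);
    return 0;
}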

src/llm/visual_language_model/legacy/servable.cpp

Lines changed: 1 addition & 0 deletions
@@ -222,6 +222,7 @@ absl::Status VisualLanguageModelLegacyServable::preparePartialResponse(std::shar
     return absl::OkStatus();
 }
 
+// Legacy VLM
 absl::Status VisualLanguageModelLegacyServable::prepareInputs(std::shared_ptr<GenAiServableExecutionContext>& executionContext) {
    auto vlmExecutionContext = std::static_pointer_cast<VisualLanguageModelLegacyServableExecutionContext>(executionContext);
     if (vlmExecutionContext->apiHandler == nullptr) {
