|
17 | 17 | #include "openai_completions.hpp" |
18 | 18 |
|
19 | 19 | #include <cmath> |
| 20 | +#include <limits> |
20 | 21 | #include <memory> |
21 | 22 | #include "src/port/rapidjson_stringbuffer.hpp" |
22 | 23 | #include "src/port/rapidjson_writer.hpp" |
@@ -44,6 +45,51 @@ namespace ovms { |
44 | 45 |
|
45 | 46 | constexpr size_t DEFAULT_MAX_STOP_WORDS = 16; // same as deep-seek |
46 | 47 |
|
| 48 | +namespace { |
| 49 | + |
| 50 | +ov::genai::JsonContainer rapidJsonValueToJsonContainer(const rapidjson::Value& value) { |
| 51 | + if (value.IsNull()) { |
| 52 | + return ov::genai::JsonContainer(nullptr); |
| 53 | + } |
| 54 | + if (value.IsBool()) { |
| 55 | + return ov::genai::JsonContainer(value.GetBool()); |
| 56 | + } |
| 57 | + if (value.IsInt64()) { |
| 58 | + return ov::genai::JsonContainer(value.GetInt64()); |
| 59 | + } |
| 60 | + if (value.IsUint64()) { |
| 61 | + auto uintValue = value.GetUint64(); |
| 62 | + if (uintValue <= static_cast<uint64_t>(std::numeric_limits<int64_t>::max())) { |
| 63 | + return ov::genai::JsonContainer(static_cast<int64_t>(uintValue)); |
| 64 | + } |
| 65 | + return ov::genai::JsonContainer(static_cast<double>(uintValue)); |
| 66 | + } |
| 67 | + if (value.IsDouble()) { |
| 68 | + return ov::genai::JsonContainer(value.GetDouble()); |
| 69 | + } |
| 70 | + if (value.IsString()) { |
| 71 | + return ov::genai::JsonContainer(std::string(value.GetString(), value.GetStringLength())); |
| 72 | + } |
| 73 | + if (value.IsArray()) { |
| 74 | + ov::genai::JsonContainer arrayContainer = ov::genai::JsonContainer::array(); |
| 75 | + for (const auto& item : value.GetArray()) { |
| 76 | + arrayContainer.push_back(rapidJsonValueToJsonContainer(item)); |
| 77 | + } |
| 78 | + return arrayContainer; |
| 79 | + } |
| 80 | + if (value.IsObject()) { |
| 81 | + ov::genai::JsonContainer objectContainer = ov::genai::JsonContainer::object(); |
| 82 | + for (auto member = value.MemberBegin(); member != value.MemberEnd(); ++member) { |
| 83 | + const std::string key(member->name.GetString(), member->name.GetStringLength()); |
| 84 | + objectContainer[key] = rapidJsonValueToJsonContainer(member->value); |
| 85 | + } |
| 86 | + return objectContainer; |
| 87 | + } |
| 88 | + throw std::invalid_argument("Unsupported JSON value type"); |
| 89 | +} |
| 90 | + |
| 91 | +} // namespace |
| 92 | + |
47 | 93 | absl::Status OpenAIChatCompletionsHandler::parseCompletionsPart() { |
48 | 94 | // prompt: string |
49 | 95 | auto it = doc.FindMember("prompt"); |
@@ -430,6 +476,23 @@ absl::Status OpenAIChatCompletionsHandler::parseTools() { |
430 | 476 | } |
431 | 477 |
|
432 | 478 | request.toolChoice = tool_choice; |
| 479 | + request.tools = std::nullopt; |
| 480 | + if (it != doc.MemberEnd() && !it->value.IsNull()) { |
| 481 | + try { |
| 482 | + request.tools = rapidJsonValueToJsonContainer(it->value); |
| 483 | + } catch (const std::exception& e) { |
| 484 | + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Direct tools conversion to JsonContainer failed: {}. Falling back to JSON string conversion.", e.what()); |
| 485 | + try { |
| 486 | + rapidjson::StringBuffer toolsBuffer; |
| 487 | + rapidjson::Writer<rapidjson::StringBuffer> toolsWriter(toolsBuffer); |
| 488 | + it->value.Accept(toolsWriter); |
| 489 | + request.tools = ov::genai::JsonContainer::from_json_string(toolsBuffer.GetString()); |
| 490 | + } catch (const std::exception& fallbackEx) { |
| 491 | + SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Fallback tools conversion failed: {}", fallbackEx.what()); |
| 492 | + return absl::InvalidArgumentError(absl::StrCat("Invalid tools payload: ", fallbackEx.what())); |
| 493 | + } |
| 494 | + } |
| 495 | + } |
433 | 496 | if (jsonChanged) { |
434 | 497 | StringBuffer buffer; |
435 | 498 | Writer<StringBuffer> writer(buffer); |
@@ -466,6 +529,10 @@ std::optional<std::string> OpenAIChatCompletionsHandler::getResponseFormat() con |
466 | 529 | return request.responseFormat; |
467 | 530 | } |
468 | 531 |
|
| 532 | +const std::optional<ov::genai::JsonContainer>& OpenAIChatCompletionsHandler::getTools() const { |
| 533 | + return request.tools; |
| 534 | +} |
| 535 | + |
469 | 536 | std::string convertOpenAIResponseFormatToStructuralTagStringFormat(const rapidjson::Value& openAIFormat) { |
470 | 537 | // Build the new object: {"type": "structural_tag", "format": <openAIFormat>} |
471 | 538 | // If response_format has {"json_schema": {"schema": {...}}}, flatten it to {"json_schema": {...}} |
|
0 commit comments