Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,13 +302,6 @@ DEFINE_int32(npu_phy_id, -1, "npu phy id");

DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port.");

// --- function call config ---

DEFINE_string(tool_call_parser,
"",
"Specify the parser for handling tool-call interactions(e.g. "
"qwen25, qwen3, kimi_k2, deepseekv3, glm45, glm47).");

// --- speculative config ---

DEFINE_int32(num_speculative_tokens, 0, "Number of speculative tokens.");
Expand Down Expand Up @@ -428,7 +421,14 @@ DEFINE_bool(enable_beam_search_kernel,
DEFINE_string(reasoning_parser,
"",
"Specify the reasoning parser for handling reasoning "
"interactions(e.g. glm45, glm47, qwen3, deepseek-r1).");
"interactions(e.g. auto, glm45, glm47, qwen3, deepseek-r1).");

// --- function call config ---

DEFINE_string(tool_call_parser,
"",
"Specify the parser for handling tool-call interactions(e.g. "
"auto, qwen25, qwen3, kimi_k2, deepseekv3, glm45, glm47).");

// --- qwen3 reranker config ---

Expand Down
4 changes: 2 additions & 2 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,6 @@ DECLARE_bool(enable_customize_mla_kernel);

DECLARE_bool(enable_atb_comm_multiprocess);

DECLARE_string(tool_call_parser);

DECLARE_bool(enable_atb_spec_kernel);

DECLARE_bool(enable_block_copy_kernel);
Expand Down Expand Up @@ -217,6 +215,8 @@ DECLARE_bool(enable_qwen3_reranker);

DECLARE_string(reasoning_parser);

DECLARE_string(tool_call_parser);

DECLARE_bool(enable_shm);

DECLARE_bool(use_contiguous_input_buffer);
Expand Down
2 changes: 1 addition & 1 deletion xllm/core/common/help_formatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ const OptionCategory kCommonOptions = {"COMMON OPTIONS",

const OptionCategory kMoeModelOptions = {
"MOE MODEL OPTIONS",
{"dp_size", "ep_size", "enable_mla", "expert_parallel_degree"}};
{"dp_size", "ep_size", "expert_parallel_degree"}};

const OptionCategory kDisaggregatedPrefillDecodeOptions = {
"DISAGGREGATED PREFILL-DECODE OPTIONS",
Expand Down
134 changes: 83 additions & 51 deletions xllm/function_call/function_call_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,46 +17,101 @@ limitations under the License.

#include <iostream>
#include <stdexcept>
#include <unordered_map>

#include "absl/strings/str_join.h"
#include "core/util/uuid.h"
#include "deepseekv3_detector.h"
#include "glm45_detector.h"
#include "glm47_detector.h"
#include "kimik2_detector.h"
#include "qwen25_detector.h"

namespace xllm {
namespace function_call {

const std::unordered_map<std::string, std::string>
FunctionCallParser::kToolCallParserMap = {
{"qwen25", "qwen25"},
{"qwen3", "qwen25"},
{"kimi_k2", "kimi_k2"},
{"deepseekv3", "deepseekv3"},
{"glm45", "glm45"},
{"glm47", "glm47"},
// TODO
// {"llama3", "llama3"},
// {"mistral", "mistral"},
// {"pythonic", "pythonic"},
// {"qwen3_coder", "qwen3_coder"},
// {"step3", "step3"},
namespace {

// Auto-selection table for the tool call parser. Each key is a tool call
// parser name; its value lists the model types that resolve to that parser
// when the parser setting is "auto" (consulted by
// FunctionCallParser::get_parser_auto).
// NOTE(review): "paser" in the identifier looks like a typo for "parser";
// left unchanged because other code refers to this name.
const std::unordered_map<std::string, std::vector<std::string>> auto_paser_map =
{
{"qwen25", {"qwen2", "qwen3"}},
{"kimi_k2", {"kimi_k2"}},
{"deepseekv3", {"deepseek_v3"}},
// GLM-4.5 and GLM-4.7 are not supported for tool call parser
// auto-selection
// {"glm45", {"glm4_moe"}},
// {"glm47", {"glm4_moe"}},
};

// Returns a ", "-separated list of every model type listed in auto_paser_map.
// The list follows the map's iteration order, which is unspecified for an
// unordered_map.
std::string get_auto_paser_map_supported() {
  std::vector<std::string> model_types;
  for (const auto& entry : auto_paser_map) {
    const std::vector<std::string>& supported = entry.second;
    // Append this parser's model types in their declared order.
    model_types.insert(model_types.end(), supported.begin(), supported.end());
  }
  return absl::StrJoin(model_types, ", ");
}

// Factory table keyed by tool call parser name. Each value is a callable that
// builds a new detector instance (as a std::unique_ptr<BaseFormatDetector>)
// for that parser.
const std::unordered_map<std::string,
std::function<std::unique_ptr<BaseFormatDetector>()>>
detector_factories = {
{"qwen25", [] { return std::make_unique<Qwen25Detector>(); }},
{"kimi_k2", [] { return std::make_unique<KimiK2Detector>(); }},
{"deepseekv3", [] { return std::make_unique<DeepSeekV3Detector>(); }},
{"glm45", [] { return std::make_unique<Glm45Detector>(); }},
{"glm47", [] { return std::make_unique<Glm47Detector>(); }},
};

// Returns a ", "-separated list of the parser names registered in
// detector_factories, in the map's (unspecified) iteration order.
std::string get_supported_detector_factories() {
  std::vector<std::string> parser_names;
  for (auto it = detector_factories.begin(); it != detector_factories.end();
       ++it) {
    parser_names.push_back(it->first);
  }
  return absl::StrJoin(parser_names, ", ");
}

} // namespace

// Resolves the configured tool call parser setting to a detector name.
//
// @param parser      The configured parser setting: empty, "auto", or an
//                    explicit parser name.
// @param model_type  Model type used for the lookup when parser is "auto".
// @return "" when parser is empty; otherwise the key of the matching entry in
//         detector_factories / auto_paser_map ("qwen2" and "qwen3" are mapped
//         to "qwen25"). Unsupported input is reported through LOG(FATAL).
std::string FunctionCallParser::get_parser_auto(const std::string& parser,
                                                const std::string& model_type) {
  // No parser configured: nothing to resolve.
  if (parser.empty()) {
    return "";
  }

  if (parser != "auto") {
    // Explicit parser name: accept the qwen aliases first, then any name that
    // has a registered detector factory.
    const bool is_qwen_alias = (parser == "qwen2") || (parser == "qwen3");
    if (is_qwen_alias) {
      return "qwen25";
    }
    if (detector_factories.count(parser) != 0) {
      return parser;
    }
    LOG(FATAL) << "Unsupported tool call parser: " << parser
               << ". Supported parsers are: "
               << get_supported_detector_factories();
    // Kept from the original so every path returns a value; presumably
    // unreachable if LOG(FATAL) terminates the process.
    return "";
  }

  // "auto": pick the parser whose supported model types include model_type.
  for (auto it = auto_paser_map.begin(); it != auto_paser_map.end(); ++it) {
    const std::vector<std::string>& model_types = it->second;
    if (std::find(model_types.begin(), model_types.end(), model_type) ==
        model_types.end()) {
      continue;
    }
    LOG(INFO) << "Using tool call parser: " << it->first
              << " for model type: " << model_type;
    return it->first;
  }

  LOG(FATAL) << "Unsupported model type for auto tool call parser: "
             << model_type << ". Supported model types are: "
             << get_auto_paser_map_supported();
  return "";
}

FunctionCallParser::FunctionCallParser(const std::vector<JsonTool>& tools,
const std::string& tool_call_parser)
: tools_(tools) {
detector_ = create_detector(tool_call_parser);
CHECK(detector_ != nullptr)
<< "Unsupported tool_call_parser: " << tool_call_parser
<< ". Supported parsers are: " << [this]() {
std::string supported;
for (const auto& [key, value] : kToolCallParserMap) {
if (!supported.empty()) supported += ", ";
supported += key;
}
return supported;
}();
<< "Unsupported tool_call_parser: " << tool_call_parser;
}

bool FunctionCallParser::has_tool_call(const std::string& text) const {
Expand All @@ -82,38 +137,15 @@ StreamingParseResult FunctionCallParser::parse_streaming_increment(

std::unique_ptr<BaseFormatDetector> FunctionCallParser::create_detector(
const std::string& tool_call_parser) {
auto it = kToolCallParserMap.find(tool_call_parser);
if (it == kToolCallParserMap.end()) {
if (tool_call_parser.empty()) {
return nullptr;
}

if (it->second == "qwen25") {
return std::make_unique<Qwen25Detector>();
}

if (it->second == "kimi_k2") {
return std::make_unique<KimiK2Detector>();
}

if (it->second == "deepseekv3") {
return std::make_unique<DeepSeekV3Detector>();
auto it = detector_factories.find(tool_call_parser);
if (it != detector_factories.end()) {
return it->second();
}

if (it->second == "glm45") {
return std::make_unique<Glm45Detector>();
}

if (it->second == "glm47") {
return std::make_unique<Glm47Detector>();
}

// if (tool_call_parser == "llama3") {
// return std::make_unique<Llama32Detector>();
// }
// if (tool_call_parser == "mistral") {
// return std::make_unique<MistralDetector>();
// }

LOG(ERROR) << "Unsupported tool call parser: " << tool_call_parser;
return nullptr;
}

Expand Down
6 changes: 3 additions & 3 deletions xllm/function_call/function_call_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ limitations under the License.
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "base_format_detector.h"
Expand All @@ -29,8 +28,6 @@ namespace function_call {

class FunctionCallParser {
public:
static const std::unordered_map<std::string, std::string> kToolCallParserMap;

FunctionCallParser(const std::vector<JsonTool>& tools,
const std::string& tool_call_parser);

Expand All @@ -54,6 +51,9 @@ class FunctionCallParser {

BaseFormatDetector* get_detector() const { return detector_.get(); }

static std::string get_parser_auto(const std::string& parser,
const std::string& model_type);

private:
std::unique_ptr<BaseFormatDetector> create_detector(
const std::string& tool_call_parser);
Expand Down
6 changes: 3 additions & 3 deletions xllm/function_call/qwen25_detector_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ TEST_F(Qwen25DetectorTest, PerformanceWithManyToolCalls) {

// Test basic streaming functionality
TEST_F(Qwen25StreamingTest, BasicStreamingParsing) {
FunctionCallParser parser(tools_, "qwen3");
FunctionCallParser parser(tools_, "qwen25");

// Simulate streaming chunks
std::vector<std::string> chunks = {
Expand Down Expand Up @@ -398,7 +398,7 @@ TEST_F(Qwen25StreamingTest, BasicStreamingParsing) {

// Test multiple tool calls streaming
TEST_F(Qwen25StreamingTest, MultipleToolCallsStreaming) {
FunctionCallParser parser(tools_, "qwen3");
FunctionCallParser parser(tools_, "qwen25");

// Simulate realistic token-level streaming chunks with multiple tool calls
std::vector<std::string> chunks = {"Let",
Expand Down Expand Up @@ -493,7 +493,7 @@ TEST_F(Qwen25StreamingTest, MultipleToolCallsStreaming) {

// Test partial token handling
TEST_F(Qwen25StreamingTest, PartialTokenHandling) {
FunctionCallParser parser(tools_, "qwen3");
FunctionCallParser parser(tools_, "qwen25");

// Simulate realistic partial tokens being streamed - testing edge cases where
// tokens are split
Expand Down
80 changes: 62 additions & 18 deletions xllm/parser/detector_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ limitations under the License.

#include <glog/logging.h>

#include <algorithm>

#include "absl/strings/str_join.h"

namespace xllm {
Expand All @@ -33,31 +35,49 @@ namespace {
{name, [](bool stream, bool force) { \
return std::make_unique<ReasoningDetector>(start, end, force, stream); \
}}
} // namespace

DetectorRegistry::DetectorRegistry() {
factories_ = {
REGISTER_DETECTOR_DEFAULT_FORCE(
"deepseek-r1", "<think>", "</think>", true),
REGISTER_DETECTOR("deepseek-v3", "<think>", "</think>"),
REGISTER_DETECTOR("glm45", "<think>", "</think>"),
REGISTER_DETECTOR("glm47", "<think>", "</think>"),
REGISTER_DETECTOR_DEFAULT_FORCE("kimi", "◁think▷", "◁/think▷", false),
REGISTER_DETECTOR("qwen3", "<think>", "</think>"),
REGISTER_DETECTOR_DEFAULT_FORCE(
"qwen3-thinking", "<think>", "</think>", true),
REGISTER_DETECTOR("step3", "<think>", "</think>"),
};
// Maps a model_type (key) to the reasoning parser name (value) that is chosen
// for it; looked up by DetectorRegistry::get_parser_name_by_model_type.
// Commented-out entries are model types that are currently not mapped.
const std::unordered_map<std::string, std::string> auto_paser_map = {
// {"deepseek_v3", "deepseek-v3"},
// {"qwen3", "qwen3"},
{"glm4_moe", "glm45"},
{"kimi_k2", "kimi"},
{"step3", "step3"},
};

// Returns a ", "-separated list of the keys of auto_paser_map (the model types
// it can be queried with), in the map's (unspecified) iteration order.
std::string get_auto_paser_map_supported() {
  std::vector<std::string> model_types;
  model_types.reserve(auto_paser_map.size());
  for (auto it = auto_paser_map.begin(); it != auto_paser_map.end(); ++it) {
    model_types.push_back(it->first);
  }
  return absl::StrJoin(model_types, ", ");
}

std::unique_ptr<ReasoningDetector> DetectorRegistry::getDetector(
const std::unordered_map<std::string, DetectorFactory> paser_factories = {
REGISTER_DETECTOR_DEFAULT_FORCE("deepseek-r1", "<think>", "</think>", true),
REGISTER_DETECTOR("deepseek-v3", "<think>", "</think>"),
REGISTER_DETECTOR("glm45", "<think>", "</think>"),
REGISTER_DETECTOR("glm47", "<think>", "</think>"),
REGISTER_DETECTOR_DEFAULT_FORCE("kimi", "◁think▷", "◁/think▷", false),
REGISTER_DETECTOR("qwen3", "<think>", "</think>"),
REGISTER_DETECTOR_DEFAULT_FORCE("qwen3-thinking",
"<think>",
"</think>",
true),
REGISTER_DETECTOR("step3", "<think>", "</think>"),
};

} // namespace

std::unique_ptr<ReasoningDetector> DetectorRegistry::get_detector(
const std::string& model_type,
bool stream_reasoning,
bool force_reasoning) {
auto it = factories_.find(model_type);
if (it == factories_.end()) {
auto it = paser_factories.find(model_type);
if (it == paser_factories.end()) {
std::vector<std::string> keys;
for (const auto& pair : factories_) {
for (const auto& pair : paser_factories) {
keys.push_back(pair.first);
}
LOG(FATAL) << "Unsupported model type for reasoning parser: " << model_type
Expand All @@ -67,4 +87,28 @@ std::unique_ptr<ReasoningDetector> DetectorRegistry::getDetector(
return it->second(stream_reasoning, force_reasoning);
};

// Reports whether a reasoning detector factory is registered under
// parser_name in paser_factories.
bool DetectorRegistry::has_detector(const std::string& parser_name) const {
  return paser_factories.count(parser_name) != 0;
}

// Returns a ", "-separated list of the parser names registered in
// paser_factories, in the map's (unspecified) iteration order.
std::string DetectorRegistry::get_supported_parsers() const {
  std::vector<std::string> parser_names;
  parser_names.reserve(paser_factories.size());
  for (auto it = paser_factories.begin(); it != paser_factories.end(); ++it) {
    parser_names.push_back(it->first);
  }
  return absl::StrJoin(parser_names, ", ");
}

// Looks up the reasoning parser name mapped to model_type in auto_paser_map.
//
// @param model_type  Model type to look up.
// @return The mapped parser name. An unmapped model type is reported through
//         LOG(FATAL), listing the supported model types.
std::string DetectorRegistry::get_parser_name_by_model_type(
    const std::string& model_type) const {
  const auto match = auto_paser_map.find(model_type);
  if (match == auto_paser_map.end()) {
    LOG(FATAL) << "Unsupported model type for reasoning parser: " << model_type
               << ". Supported model types are: "
               << get_auto_paser_map_supported();
    // Kept from the original so every path returns a value; presumably
    // unreachable if LOG(FATAL) terminates the process.
    return "";
  }
  return match->second;
}
} // namespace xllm
Loading
Loading