Skip to content

Commit ee3f9c0

Browse files
committed
feat: support auto-selection and pre-check for function call and reasoning paser.
1 parent 99253f9 commit ee3f9c0

File tree

11 files changed

+218
-102
lines changed

11 files changed

+218
-102
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,6 @@ DEFINE_int32(npu_phy_id, -1, "npu phy id");
302302

303303
DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port.");
304304

305-
// --- function call config ---
306-
307-
DEFINE_string(tool_call_parser,
308-
"",
309-
"Specify the parser for handling tool-call interactions(e.g. "
310-
"qwen25, qwen3, kimi_k2, deepseekv3, glm45, glm47).");
311-
312305
// --- speculative config ---
313306

314307
DEFINE_int32(num_speculative_tokens, 0, "Number of speculative tokens.");
@@ -428,7 +421,14 @@ DEFINE_bool(enable_beam_search_kernel,
428421
DEFINE_string(reasoning_parser,
429422
"",
430423
"Specify the reasoning parser for handling reasoning "
431-
"interactions(e.g. glm45, glm47, qwen3, deepseek-r1).");
424+
"interactions(e.g. auto, glm45, glm47, qwen3, deepseek-r1).");
425+
426+
// --- function call config ---
427+
428+
DEFINE_string(tool_call_parser,
429+
"",
430+
"Specify the parser for handling tool-call interactions(e.g. "
431+
"auto, qwen25, qwen3, kimi_k2, deepseekv3, glm45, glm47).");
432432

433433
// --- qwen3 reranker config ---
434434

xllm/core/common/global_flags.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,6 @@ DECLARE_bool(enable_customize_mla_kernel);
133133

134134
DECLARE_bool(enable_atb_comm_multiprocess);
135135

136-
DECLARE_string(tool_call_parser);
137-
138136
DECLARE_bool(enable_atb_spec_kernel);
139137

140138
DECLARE_bool(enable_block_copy_kernel);
@@ -217,6 +215,8 @@ DECLARE_bool(enable_qwen3_reranker);
217215

218216
DECLARE_string(reasoning_parser);
219217

218+
DECLARE_string(tool_call_parser);
219+
220220
DECLARE_bool(enable_shm);
221221

222222
DECLARE_bool(use_contiguous_input_buffer);

xllm/core/common/help_formatter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ const OptionCategory kCommonOptions = {"COMMON OPTIONS",
5050

5151
const OptionCategory kMoeModelOptions = {
5252
"MOE MODEL OPTIONS",
53-
{"dp_size", "ep_size", "enable_mla", "expert_parallel_degree"}};
53+
{"dp_size", "ep_size", "expert_parallel_degree"}};
5454

5555
const OptionCategory kDisaggregatedPrefillDecodeOptions = {
5656
"DISAGGREGATED PREFILL-DECODE OPTIONS",

xllm/function_call/function_call_parser.cpp

Lines changed: 83 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,101 @@ limitations under the License.
1717

1818
#include <iostream>
1919
#include <stdexcept>
20+
#include <unordered_map>
2021

22+
#include "absl/strings/str_join.h"
2123
#include "core/util/uuid.h"
2224
#include "deepseekv3_detector.h"
2325
#include "glm45_detector.h"
2426
#include "glm47_detector.h"
2527
#include "kimik2_detector.h"
2628
#include "qwen25_detector.h"
29+
2730
namespace xllm {
2831
namespace function_call {
2932

30-
const std::unordered_map<std::string, std::string>
31-
FunctionCallParser::kToolCallParserMap = {
32-
{"qwen25", "qwen25"},
33-
{"qwen3", "qwen25"},
34-
{"kimi_k2", "kimi_k2"},
35-
{"deepseekv3", "deepseekv3"},
36-
{"glm45", "glm45"},
37-
{"glm47", "glm47"},
38-
// TODO
39-
// {"llama3", "llama3"},
40-
// {"mistral", "mistral"},
41-
// {"pythonic", "pythonic"},
42-
// {"qwen3_coder", "qwen3_coder"},
43-
// {"step3", "step3"},
33+
namespace {
34+
35+
const std::unordered_map<std::string, std::vector<std::string>> auto_paser_map =
36+
{
37+
{"qwen25", {"qwen2", "qwen3"}},
38+
{"kimi_k2", {"kimi_k2"}},
39+
{"deepseekv3", {"deepseek_v3"}},
40+
// GLM-4.5 and GLM-4.7 are not supported for tool call parser
41+
// auto-selection
42+
// {"glm45", {"glm4_moe"}},
43+
// {"glm47", {"glm4_moe"}},
4444
};
4545

46+
std::string get_auto_paser_map_supported() {
47+
std::vector<std::string> keys;
48+
for (const auto& [key, value] : auto_paser_map) {
49+
for (const auto& v : value) {
50+
keys.push_back(v);
51+
}
52+
}
53+
return absl::StrJoin(keys, ", ");
54+
}
55+
56+
const std::unordered_map<std::string,
57+
std::function<std::unique_ptr<BaseFormatDetector>()>>
58+
detector_factories = {
59+
{"qwen25", [] { return std::make_unique<Qwen25Detector>(); }},
60+
{"kimi_k2", [] { return std::make_unique<KimiK2Detector>(); }},
61+
{"deepseekv3", [] { return std::make_unique<DeepSeekV3Detector>(); }},
62+
{"glm45", [] { return std::make_unique<Glm45Detector>(); }},
63+
{"glm47", [] { return std::make_unique<Glm47Detector>(); }},
64+
};
65+
66+
std::string get_supported_detector_factories() {
67+
std::vector<std::string> keys;
68+
for (const auto& [key, value] : detector_factories) {
69+
keys.push_back(key);
70+
}
71+
return absl::StrJoin(keys, ", ");
72+
}
73+
74+
} // namespace
75+
76+
std::string FunctionCallParser::get_parser_auto(const std::string& parser,
77+
const std::string& model_type) {
78+
if (parser.empty()) {
79+
return "";
80+
}
81+
if (parser == "auto") {
82+
// find the tool call parser that supports the model type
83+
for (const auto& [key, value] : auto_paser_map) {
84+
if (std::find(value.begin(), value.end(), model_type) != value.end()) {
85+
LOG(INFO) << "Using tool call parser: " << key
86+
<< " for model type: " << model_type;
87+
return key;
88+
}
89+
}
90+
LOG(FATAL) << "Unsupported model type for auto tool call parser: "
91+
<< model_type << ". Supported model types are: "
92+
<< get_auto_paser_map_supported();
93+
return "";
94+
} else {
95+
// check if the tool call parser is supported
96+
if (parser == "qwen2" || parser == "qwen3") {
97+
return "qwen25";
98+
}
99+
if (detector_factories.find(parser) != detector_factories.end()) {
100+
return parser;
101+
}
102+
LOG(FATAL) << "Unsupported tool call parser: " << parser
103+
<< ". Supported parsers are: "
104+
<< get_supported_detector_factories();
105+
return "";
106+
}
107+
}
108+
46109
FunctionCallParser::FunctionCallParser(const std::vector<JsonTool>& tools,
47110
const std::string& tool_call_parser)
48111
: tools_(tools) {
49112
detector_ = create_detector(tool_call_parser);
50113
CHECK(detector_ != nullptr)
51-
<< "Unsupported tool_call_parser: " << tool_call_parser
52-
<< ". Supported parsers are: " << [this]() {
53-
std::string supported;
54-
for (const auto& [key, value] : kToolCallParserMap) {
55-
if (!supported.empty()) supported += ", ";
56-
supported += key;
57-
}
58-
return supported;
59-
}();
114+
<< "Unsupported tool_call_parser: " << tool_call_parser;
60115
}
61116

62117
bool FunctionCallParser::has_tool_call(const std::string& text) const {
@@ -82,38 +137,15 @@ StreamingParseResult FunctionCallParser::parse_streaming_increment(
82137

83138
std::unique_ptr<BaseFormatDetector> FunctionCallParser::create_detector(
84139
const std::string& tool_call_parser) {
85-
auto it = kToolCallParserMap.find(tool_call_parser);
86-
if (it == kToolCallParserMap.end()) {
140+
if (tool_call_parser.empty()) {
87141
return nullptr;
88142
}
89143

90-
if (it->second == "qwen25") {
91-
return std::make_unique<Qwen25Detector>();
92-
}
93-
94-
if (it->second == "kimi_k2") {
95-
return std::make_unique<KimiK2Detector>();
96-
}
97-
98-
if (it->second == "deepseekv3") {
99-
return std::make_unique<DeepSeekV3Detector>();
144+
auto it = detector_factories.find(tool_call_parser);
145+
if (it != detector_factories.end()) {
146+
return it->second();
100147
}
101-
102-
if (it->second == "glm45") {
103-
return std::make_unique<Glm45Detector>();
104-
}
105-
106-
if (it->second == "glm47") {
107-
return std::make_unique<Glm47Detector>();
108-
}
109-
110-
// if (tool_call_parser == "llama3") {
111-
// return std::make_unique<Llama32Detector>();
112-
// }
113-
// if (tool_call_parser == "mistral") {
114-
// return std::make_unique<MistralDetector>();
115-
// }
116-
148+
LOG(ERROR) << "Unsupported tool call parser: " << tool_call_parser;
117149
return nullptr;
118150
}
119151

xllm/function_call/function_call_parser.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ limitations under the License.
1818
#include <memory>
1919
#include <string>
2020
#include <tuple>
21-
#include <unordered_map>
2221
#include <vector>
2322

2423
#include "base_format_detector.h"
@@ -29,8 +28,6 @@ namespace function_call {
2928

3029
class FunctionCallParser {
3130
public:
32-
static const std::unordered_map<std::string, std::string> kToolCallParserMap;
33-
3431
FunctionCallParser(const std::vector<JsonTool>& tools,
3532
const std::string& tool_call_parser);
3633

@@ -54,6 +51,9 @@ class FunctionCallParser {
5451

5552
BaseFormatDetector* get_detector() const { return detector_.get(); }
5653

54+
static std::string get_parser_auto(const std::string& parser,
55+
const std::string& model_type);
56+
5757
private:
5858
std::unique_ptr<BaseFormatDetector> create_detector(
5959
const std::string& tool_call_parser);

xllm/function_call/qwen25_detector_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ TEST_F(Qwen25DetectorTest, PerformanceWithManyToolCalls) {
355355

356356
// Test basic streaming functionality
357357
TEST_F(Qwen25StreamingTest, BasicStreamingParsing) {
358-
FunctionCallParser parser(tools_, "qwen3");
358+
FunctionCallParser parser(tools_, "qwen25");
359359

360360
// Simulate streaming chunks
361361
std::vector<std::string> chunks = {
@@ -398,7 +398,7 @@ TEST_F(Qwen25StreamingTest, BasicStreamingParsing) {
398398

399399
// Test multiple tool calls streaming
400400
TEST_F(Qwen25StreamingTest, MultipleToolCallsStreaming) {
401-
FunctionCallParser parser(tools_, "qwen3");
401+
FunctionCallParser parser(tools_, "qwen25");
402402

403403
// Simulate realistic token-level streaming chunks with multiple tool calls
404404
std::vector<std::string> chunks = {"Let",
@@ -493,7 +493,7 @@ TEST_F(Qwen25StreamingTest, MultipleToolCallsStreaming) {
493493

494494
// Test partial token handling
495495
TEST_F(Qwen25StreamingTest, PartialTokenHandling) {
496-
FunctionCallParser parser(tools_, "qwen3");
496+
FunctionCallParser parser(tools_, "qwen25");
497497

498498
// Simulate realistic partial tokens being streamed - testing edge cases where
499499
// tokens are split

xllm/parser/detector_registry.cpp

Lines changed: 62 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ limitations under the License.
1818

1919
#include <glog/logging.h>
2020

21+
#include <algorithm>
22+
2123
#include "absl/strings/str_join.h"
2224

2325
namespace xllm {
@@ -33,31 +35,49 @@ namespace {
3335
{name, [](bool stream, bool force) { \
3436
return std::make_unique<ReasoningDetector>(start, end, force, stream); \
3537
}}
36-
} // namespace
3738

38-
DetectorRegistry::DetectorRegistry() {
39-
factories_ = {
40-
REGISTER_DETECTOR_DEFAULT_FORCE(
41-
"deepseek-r1", "<think>", "</think>", true),
42-
REGISTER_DETECTOR("deepseek-v3", "<think>", "</think>"),
43-
REGISTER_DETECTOR("glm45", "<think>", "</think>"),
44-
REGISTER_DETECTOR("glm47", "<think>", "</think>"),
45-
REGISTER_DETECTOR_DEFAULT_FORCE("kimi", "◁think▷", "◁/think▷", false),
46-
REGISTER_DETECTOR("qwen3", "<think>", "</think>"),
47-
REGISTER_DETECTOR_DEFAULT_FORCE(
48-
"qwen3-thinking", "<think>", "</think>", true),
49-
REGISTER_DETECTOR("step3", "<think>", "</think>"),
50-
};
39+
// Maps reasoning_parser name to supported model_types
40+
const std::unordered_map<std::string, std::string> auto_paser_map = {
41+
// {"deepseek_v3", "deepseek-v3"},
42+
// {"qwen3", "qwen3"},
43+
{"glm4_moe", "glm45"},
44+
{"kimi_k2", "kimi"},
45+
{"step3", "step3"},
46+
};
47+
48+
std::string get_auto_paser_map_supported() {
49+
std::vector<std::string> keys;
50+
keys.reserve(auto_paser_map.size());
51+
for (const auto& pair : auto_paser_map) {
52+
keys.push_back(pair.first);
53+
}
54+
return absl::StrJoin(keys, ", ");
5155
}
5256

53-
std::unique_ptr<ReasoningDetector> DetectorRegistry::getDetector(
57+
const std::unordered_map<std::string, DetectorFactory> paser_factories = {
58+
REGISTER_DETECTOR_DEFAULT_FORCE("deepseek-r1", "<think>", "</think>", true),
59+
REGISTER_DETECTOR("deepseek-v3", "<think>", "</think>"),
60+
REGISTER_DETECTOR("glm45", "<think>", "</think>"),
61+
REGISTER_DETECTOR("glm47", "<think>", "</think>"),
62+
REGISTER_DETECTOR_DEFAULT_FORCE("kimi", "◁think▷", "◁/think▷", false),
63+
REGISTER_DETECTOR("qwen3", "<think>", "</think>"),
64+
REGISTER_DETECTOR_DEFAULT_FORCE("qwen3-thinking",
65+
"<think>",
66+
"</think>",
67+
true),
68+
REGISTER_DETECTOR("step3", "<think>", "</think>"),
69+
};
70+
71+
} // namespace
72+
73+
std::unique_ptr<ReasoningDetector> DetectorRegistry::get_detector(
5474
const std::string& model_type,
5575
bool stream_reasoning,
5676
bool force_reasoning) {
57-
auto it = factories_.find(model_type);
58-
if (it == factories_.end()) {
77+
auto it = paser_factories.find(model_type);
78+
if (it == paser_factories.end()) {
5979
std::vector<std::string> keys;
60-
for (const auto& pair : factories_) {
80+
for (const auto& pair : paser_factories) {
6181
keys.push_back(pair.first);
6282
}
6383
LOG(FATAL) << "Unsupported model type for reasoning parser: " << model_type
@@ -67,4 +87,28 @@ std::unique_ptr<ReasoningDetector> DetectorRegistry::getDetector(
6787
return it->second(stream_reasoning, force_reasoning);
6888
};
6989

90+
bool DetectorRegistry::has_detector(const std::string& parser_name) const {
91+
return paser_factories.find(parser_name) != paser_factories.end();
92+
}
93+
94+
std::string DetectorRegistry::get_supported_parsers() const {
95+
std::vector<std::string> keys;
96+
keys.reserve(paser_factories.size());
97+
for (const auto& pair : paser_factories) {
98+
keys.push_back(pair.first);
99+
}
100+
return absl::StrJoin(keys, ", ");
101+
}
102+
103+
std::string DetectorRegistry::get_parser_name_by_model_type(
104+
const std::string& model_type) const {
105+
auto it = auto_paser_map.find(model_type);
106+
if (it != auto_paser_map.end()) {
107+
return it->second;
108+
}
109+
LOG(FATAL) << "Unsupported model type for reasoning parser: " << model_type
110+
<< ". Supported model types are: "
111+
<< get_auto_paser_map_supported();
112+
return "";
113+
}
70114
} // namespace xllm

0 commit comments

Comments
 (0)