Skip to content

Commit 36b2835

Browse files
committed
feat: support auto-selection and pre-check for function call and reasoning paser.
1 parent 3c20ea1 commit 36b2835

File tree

11 files changed

+152
-63
lines changed

11 files changed

+152
-63
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -302,13 +302,6 @@ DEFINE_int32(npu_phy_id, -1, "npu phy id");
302302

303303
DEFINE_int32(transfer_listen_port, 26000, "The KVCacheTranfer listen port.");
304304

305-
// --- function call config ---
306-
307-
DEFINE_string(tool_call_parser,
308-
"",
309-
"Specify the parser for handling tool-call interactions(e.g. "
310-
"qwen25, qwen3, kimi_k2, deepseekv3, glm45, glm47).");
311-
312305
// --- speculative config ---
313306

314307
DEFINE_int32(num_speculative_tokens, 0, "Number of speculative tokens.");
@@ -428,7 +421,14 @@ DEFINE_bool(enable_beam_search_kernel,
428421
DEFINE_string(reasoning_parser,
429422
"",
430423
"Specify the reasoning parser for handling reasoning "
431-
"interactions(e.g. glm45, glm47, qwen3, deepseek-r1).");
424+
"interactions(e.g. auto, glm45, glm47, qwen3, deepseek-r1).");
425+
426+
// --- function call config ---
427+
428+
DEFINE_string(tool_call_parser,
429+
"",
430+
"Specify the parser for handling tool-call interactions(e.g. "
431+
"auto, qwen25, qwen3, kimi_k2, deepseekv3, glm45, glm47).");
432432

433433
// --- qwen3 reranker config ---
434434

xllm/core/common/global_flags.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,6 @@ DECLARE_bool(enable_customize_mla_kernel);
133133

134134
DECLARE_bool(enable_atb_comm_multiprocess);
135135

136-
DECLARE_string(tool_call_parser);
137-
138136
DECLARE_bool(enable_atb_spec_kernel);
139137

140138
DECLARE_bool(enable_block_copy_kernel);
@@ -217,6 +215,8 @@ DECLARE_bool(enable_qwen3_reranker);
217215

218216
DECLARE_string(reasoning_parser);
219217

218+
DECLARE_string(tool_call_parser);
219+
220220
DECLARE_bool(enable_shm);
221221

222222
DECLARE_bool(use_contiguous_input_buffer);

xllm/core/common/help_formatter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ const OptionCategory kCommonOptions = {"COMMON OPTIONS",
5050

5151
const OptionCategory kMoeModelOptions = {
5252
"MOE MODEL OPTIONS",
53-
{"dp_size", "ep_size", "enable_mla", "expert_parallel_degree"}};
53+
{"dp_size", "ep_size", "expert_parallel_degree"}};
5454

5555
const OptionCategory kDisaggregatedPrefillDecodeOptions = {
5656
"DISAGGREGATED PREFILL-DECODE OPTIONS",

xllm/function_call/function_call_parser.cpp

Lines changed: 58 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -17,46 +17,74 @@ limitations under the License.
1717

1818
#include <iostream>
1919
#include <stdexcept>
20+
#include <unordered_map>
2021

2122
#include "core/util/uuid.h"
2223
#include "deepseekv3_detector.h"
2324
#include "glm45_detector.h"
2425
#include "glm47_detector.h"
2526
#include "kimik2_detector.h"
2627
#include "qwen25_detector.h"
28+
2729
namespace xllm {
2830
namespace function_call {
2931

30-
const std::unordered_map<std::string, std::string>
31-
FunctionCallParser::kToolCallParserMap = {
32-
{"qwen25", "qwen25"},
33-
{"qwen3", "qwen25"},
34-
{"kimi_k2", "kimi_k2"},
35-
{"deepseekv3", "deepseekv3"},
36-
{"glm45", "glm45"},
37-
{"glm47", "glm47"},
38-
// TODO
39-
// {"llama3", "llama3"},
40-
// {"mistral", "mistral"},
41-
// {"pythonic", "pythonic"},
42-
// {"qwen3_coder", "qwen3_coder"},
43-
// {"step3", "step3"},
32+
namespace {
33+
34+
const std::unordered_map<std::string, std::vector<std::string>>
35+
AutoToolCallParserMap = {
36+
{"qwen25", {"qwen2", "qwen3"}},
37+
{"kimi_k2", {"kimi_k2"}},
38+
{"deepseekv3", {"deepseek_v3"}},
39+
{"glm45", {"glm4_moe"}},
40+
{"glm47", {"glm4_moe"}},
4441
};
4542

43+
} // namespace
44+
45+
std::string FunctionCallParser::get_parser_auto(const std::string& parser,
46+
const std::string& model_type) {
47+
if (parser.empty()) {
48+
return "";
49+
}
50+
if (parser == "auto") {
51+
// find the tool call parser that supports the model type
52+
for (const auto& [key, value] : AutoToolCallParserMap) {
53+
if (std::find(value.begin(), value.end(), model_type) != value.end()) {
54+
LOG(INFO) << "Using tool call parser: " << key
55+
<< " for model type: " << model_type;
56+
return key;
57+
}
58+
}
59+
LOG(FATAL) << "Unsupported model type for auto tool call parser: "
60+
<< model_type;
61+
return "";
62+
} else {
63+
// check if the tool call parser is supported
64+
if (parser == "qwen2" || parser == "qwen3") {
65+
return "qwen25";
66+
}
67+
if (AutoToolCallParserMap.find(parser) != AutoToolCallParserMap.end()) {
68+
return parser;
69+
}
70+
LOG(FATAL) << "Unsupported tool call parser: " << parser
71+
<< ". Supported parsers are: " << []() {
72+
std::string supported = "qwen2, qwen3";
73+
for (const auto& [key, value] : AutoToolCallParserMap) {
74+
supported += ", " + key;
75+
}
76+
return supported;
77+
}();
78+
return "";
79+
}
80+
}
81+
4682
FunctionCallParser::FunctionCallParser(const std::vector<JsonTool>& tools,
4783
const std::string& tool_call_parser)
4884
: tools_(tools) {
4985
detector_ = create_detector(tool_call_parser);
5086
CHECK(detector_ != nullptr)
51-
<< "Unsupported tool_call_parser: " << tool_call_parser
52-
<< ". Supported parsers are: " << [this]() {
53-
std::string supported;
54-
for (const auto& [key, value] : kToolCallParserMap) {
55-
if (!supported.empty()) supported += ", ";
56-
supported += key;
57-
}
58-
return supported;
59-
}();
87+
<< "Unsupported tool_call_parser: " << tool_call_parser;
6088
}
6189

6290
bool FunctionCallParser::has_tool_call(const std::string& text) const {
@@ -82,38 +110,26 @@ StreamingParseResult FunctionCallParser::parse_streaming_increment(
82110

83111
std::unique_ptr<BaseFormatDetector> FunctionCallParser::create_detector(
84112
const std::string& tool_call_parser) {
85-
auto it = kToolCallParserMap.find(tool_call_parser);
86-
if (it == kToolCallParserMap.end()) {
113+
if (tool_call_parser.empty()) {
87114
return nullptr;
88115
}
89116

90-
if (it->second == "qwen25") {
117+
if (tool_call_parser == "qwen25") {
91118
return std::make_unique<Qwen25Detector>();
92119
}
93-
94-
if (it->second == "kimi_k2") {
120+
if (tool_call_parser == "kimi_k2") {
95121
return std::make_unique<KimiK2Detector>();
96122
}
97-
98-
if (it->second == "deepseekv3") {
123+
if (tool_call_parser == "deepseekv3") {
99124
return std::make_unique<DeepSeekV3Detector>();
100125
}
101-
102-
if (it->second == "glm45") {
126+
if (tool_call_parser == "glm45") {
103127
return std::make_unique<Glm45Detector>();
104128
}
105-
106-
if (it->second == "glm47") {
129+
if (tool_call_parser == "glm47") {
107130
return std::make_unique<Glm47Detector>();
108131
}
109-
110-
// if (tool_call_parser == "llama3") {
111-
// return std::make_unique<Llama32Detector>();
112-
// }
113-
// if (tool_call_parser == "mistral") {
114-
// return std::make_unique<MistralDetector>();
115-
// }
116-
132+
LOG(ERROR) << "Unsupported tool call parser: " << tool_call_parser;
117133
return nullptr;
118134
}
119135

xllm/function_call/function_call_parser.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ limitations under the License.
1818
#include <memory>
1919
#include <string>
2020
#include <tuple>
21-
#include <unordered_map>
2221
#include <vector>
2322

2423
#include "base_format_detector.h"
@@ -29,8 +28,6 @@ namespace function_call {
2928

3029
class FunctionCallParser {
3130
public:
32-
static const std::unordered_map<std::string, std::string> kToolCallParserMap;
33-
3431
FunctionCallParser(const std::vector<JsonTool>& tools,
3532
const std::string& tool_call_parser);
3633

@@ -54,6 +51,9 @@ class FunctionCallParser {
5451

5552
BaseFormatDetector* get_detector() const { return detector_.get(); }
5653

54+
static std::string get_parser_auto(const std::string& parser,
55+
const std::string& model_type);
56+
5757
private:
5858
std::unique_ptr<BaseFormatDetector> create_detector(
5959
const std::string& tool_call_parser);

xllm/function_call/qwen25_detector_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ TEST_F(Qwen25DetectorTest, PerformanceWithManyToolCalls) {
355355

356356
// Test basic streaming functionality
357357
TEST_F(Qwen25StreamingTest, BasicStreamingParsing) {
358-
FunctionCallParser parser(tools_, "qwen3");
358+
FunctionCallParser parser(tools_, "qwen25");
359359

360360
// Simulate streaming chunks
361361
std::vector<std::string> chunks = {
@@ -398,7 +398,7 @@ TEST_F(Qwen25StreamingTest, BasicStreamingParsing) {
398398

399399
// Test multiple tool calls streaming
400400
TEST_F(Qwen25StreamingTest, MultipleToolCallsStreaming) {
401-
FunctionCallParser parser(tools_, "qwen3");
401+
FunctionCallParser parser(tools_, "qwen25");
402402

403403
// Simulate realistic token-level streaming chunks with multiple tool calls
404404
std::vector<std::string> chunks = {"Let",
@@ -493,7 +493,7 @@ TEST_F(Qwen25StreamingTest, MultipleToolCallsStreaming) {
493493

494494
// Test partial token handling
495495
TEST_F(Qwen25StreamingTest, PartialTokenHandling) {
496-
FunctionCallParser parser(tools_, "qwen3");
496+
FunctionCallParser parser(tools_, "qwen25");
497497

498498
// Simulate realistic partial tokens being streamed - testing edge cases where
499499
// tokens are split

xllm/parser/detector_registry.cpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ limitations under the License.
1818

1919
#include <glog/logging.h>
2020

21+
#include <algorithm>
22+
2123
#include "absl/strings/str_join.h"
2224

2325
namespace xllm {
@@ -33,6 +35,16 @@ namespace {
3335
{name, [](bool stream, bool force) { \
3436
return std::make_unique<ReasoningDetector>(start, end, force, stream); \
3537
}}
38+
39+
// Maps reasoning_parser name to supported model_types
40+
const std::unordered_map<std::string, std::string> AutoReasoningParserMap = {
41+
{"deepseek_v3", "deepseek-v3"},
42+
{"qwen3", "qwen3"},
43+
{"glm4_moe", "glm45"},
44+
{"kimi_k2", "kimi"},
45+
{"step3", "step3"},
46+
};
47+
3648
} // namespace
3749

3850
DetectorRegistry::DetectorRegistry() {
@@ -67,4 +79,21 @@ std::unique_ptr<ReasoningDetector> DetectorRegistry::getDetector(
6779
return it->second(stream_reasoning, force_reasoning);
6880
};
6981

82+
std::string DetectorRegistry::getSupportedParsers() const {
83+
std::string supported;
84+
for (const auto& pair : factories_) {
85+
supported += pair.first + ", ";
86+
}
87+
return supported.substr(0, supported.size() - 2);
88+
}
89+
90+
std::string DetectorRegistry::getParserNameByModelType(
91+
const std::string& model_type) const {
92+
auto it = AutoReasoningParserMap.find(model_type);
93+
if (it != AutoReasoningParserMap.end()) {
94+
return it->second;
95+
}
96+
LOG(ERROR) << "Unsupported model type for reasoning parser: " << model_type;
97+
return "";
98+
}
7099
} // namespace xllm

xllm/parser/detector_registry.h

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,16 @@ class DetectorRegistry {
3737
bool stream_reasoning,
3838
bool force_reasoning);
3939

40-
bool hasDetector(const std::string& model_type) const {
41-
return factories_.find(model_type) != factories_.end();
40+
bool hasDetector(const std::string& parser_name) const {
41+
return factories_.find(parser_name) != factories_.end();
4242
}
4343

44+
std::string getSupportedParsers() const;
45+
46+
// Get the reasoning parser name for auto mode based on model_type
47+
// Returns empty string if not found
48+
std::string getParserNameByModelType(const std::string& model_type) const;
49+
4450
private:
4551
DetectorRegistry();
4652

xllm/parser/reasoning_parser.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ limitations under the License.
1616

1717
#include "xllm/parser/reasoning_parser.h"
1818

19+
#include <glog/logging.h>
20+
1921
namespace xllm {
22+
2023
ReasoningParser::ReasoningParser(const std::string& model_type,
2124
bool stream_reasoning,
2225
bool force_reasoning) {
@@ -34,4 +37,31 @@ ReasoningResult ReasoningParser::parse_stream_chunk(
3437
const_cast<std::string&>(chunk_text));
3538
}
3639

40+
std::string ReasoningParser::get_parser_auto(const std::string& parser,
41+
const std::string& model_type) {
42+
if (parser.empty()) {
43+
return "";
44+
}
45+
auto& registry = DetectorRegistry::getInstance();
46+
if (parser == "auto") {
47+
// find the reasoning parser that supports the model type
48+
std::string parser_name = registry.getParserNameByModelType(model_type);
49+
if (parser_name.empty()) {
50+
LOG(FATAL) << "Unsupported model type for auto reasoning parser: "
51+
<< model_type;
52+
}
53+
LOG(INFO) << "Using reasoning parser: " << parser_name
54+
<< " for model type: " << model_type;
55+
return parser_name;
56+
} else {
57+
// check if the reasoning parser is supported
58+
if (registry.hasDetector(parser)) {
59+
return parser;
60+
}
61+
LOG(FATAL) << "Unsupported reasoning parser: " << parser
62+
<< ". Supported parsers are: " << registry.getSupportedParsers();
63+
return "";
64+
}
65+
}
66+
3767
} // namespace xllm

xllm/parser/reasoning_parser.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ class ReasoningParser {
3434
// Streaming call: incremental parsing
3535
ReasoningResult parse_stream_chunk(const std::string& chunk_text);
3636

37+
static std::string get_parser_auto(const std::string& parser,
38+
const std::string& model_type);
39+
3740
private:
3841
std::unique_ptr<ReasoningDetector> detector_;
3942
};

0 commit comments

Comments
 (0)