@@ -17,46 +17,101 @@ limitations under the License.
1717
1818#include < iostream>
1919#include < stdexcept>
20+ #include < unordered_map>
2021
22+ #include " absl/strings/str_join.h"
2123#include " core/util/uuid.h"
2224#include " deepseekv3_detector.h"
2325#include " glm45_detector.h"
2426#include " glm47_detector.h"
2527#include " kimik2_detector.h"
2628#include " qwen25_detector.h"
29+
2730namespace xllm {
2831namespace function_call {
2932
30- const std::unordered_map<std::string, std::string>
31- FunctionCallParser::kToolCallParserMap = {
32- {" qwen25" , " qwen25" },
33- {" qwen3" , " qwen25" },
34- {" kimi_k2" , " kimi_k2" },
35- {" deepseekv3" , " deepseekv3" },
36- {" glm45" , " glm45" },
37- {" glm47" , " glm47" },
38- // TODO
39- // {"llama3", "llama3"},
40- // {"mistral", "mistral"},
41- // {"pythonic", "pythonic"},
42- // {"qwen3_coder", "qwen3_coder"},
43- // {"step3", "step3"},
33+ namespace {
34+
35+ const std::unordered_map<std::string, std::vector<std::string>> auto_paser_map =
36+ {
37+ {" qwen25" , {" qwen2" , " qwen3" }},
38+ {" kimi_k2" , {" kimi_k2" }},
39+ {" deepseekv3" , {" deepseek_v3" }},
40+ // GLM-4.5 and GLM-4.7 are not supported for tool call parser
41+ // auto-selection
42+ // {"glm45", {"glm4_moe"}},
43+ // {"glm47", {"glm4_moe"}},
4444};
4545
46+ std::string get_auto_paser_map_supported () {
47+ std::vector<std::string> keys;
48+ for (const auto & [key, value] : auto_paser_map) {
49+ for (const auto & v : value) {
50+ keys.push_back (v);
51+ }
52+ }
53+ return absl::StrJoin (keys, " , " );
54+ }
55+
56+ const std::unordered_map<std::string,
57+ std::function<std::unique_ptr<BaseFormatDetector>()>>
58+ detector_factories = {
59+ {" qwen25" , [] { return std::make_unique<Qwen25Detector>(); }},
60+ {" kimi_k2" , [] { return std::make_unique<KimiK2Detector>(); }},
61+ {" deepseekv3" , [] { return std::make_unique<DeepSeekV3Detector>(); }},
62+ {" glm45" , [] { return std::make_unique<Glm45Detector>(); }},
63+ {" glm47" , [] { return std::make_unique<Glm47Detector>(); }},
64+ };
65+
66+ std::string get_supported_detector_factories () {
67+ std::vector<std::string> keys;
68+ for (const auto & [key, value] : detector_factories) {
69+ keys.push_back (key);
70+ }
71+ return absl::StrJoin (keys, " , " );
72+ }
73+
74+ } // namespace
75+
76+ std::string FunctionCallParser::get_parser_auto (const std::string& parser,
77+ const std::string& model_type) {
78+ if (parser.empty ()) {
79+ return " " ;
80+ }
81+ if (parser == " auto" ) {
82+ // find the tool call parser that supports the model type
83+ for (const auto & [key, value] : auto_paser_map) {
84+ if (std::find (value.begin (), value.end (), model_type) != value.end ()) {
85+ LOG (INFO) << " Using tool call parser: " << key
86+ << " for model type: " << model_type;
87+ return key;
88+ }
89+ }
90+ LOG (FATAL) << " Unsupported model type for auto tool call parser: "
91+ << model_type << " . Supported model types are: "
92+ << get_auto_paser_map_supported ();
93+ return " " ;
94+ } else {
95+ // check if the tool call parser is supported
96+ if (parser == " qwen2" || parser == " qwen3" ) {
97+ return " qwen25" ;
98+ }
99+ if (detector_factories.find (parser) != detector_factories.end ()) {
100+ return parser;
101+ }
102+ LOG (FATAL) << " Unsupported tool call parser: " << parser
103+ << " . Supported parsers are: "
104+ << get_supported_detector_factories ();
105+ return " " ;
106+ }
107+ }
108+
46109FunctionCallParser::FunctionCallParser (const std::vector<JsonTool>& tools,
47110 const std::string& tool_call_parser)
48111 : tools_(tools) {
49112 detector_ = create_detector (tool_call_parser);
50113 CHECK (detector_ != nullptr )
51- << " Unsupported tool_call_parser: " << tool_call_parser
52- << " . Supported parsers are: " << [this ]() {
53- std::string supported;
54- for (const auto & [key, value] : kToolCallParserMap ) {
55- if (!supported.empty ()) supported += " , " ;
56- supported += key;
57- }
58- return supported;
59- }();
114+ << " Unsupported tool_call_parser: " << tool_call_parser;
60115}
61116
62117bool FunctionCallParser::has_tool_call (const std::string& text) const {
@@ -82,38 +137,15 @@ StreamingParseResult FunctionCallParser::parse_streaming_increment(
82137
83138std::unique_ptr<BaseFormatDetector> FunctionCallParser::create_detector (
84139 const std::string& tool_call_parser) {
85- auto it = kToolCallParserMap .find (tool_call_parser);
86- if (it == kToolCallParserMap .end ()) {
140+ if (tool_call_parser.empty ()) {
87141 return nullptr ;
88142 }
89143
90- if (it->second == " qwen25" ) {
91- return std::make_unique<Qwen25Detector>();
92- }
93-
94- if (it->second == " kimi_k2" ) {
95- return std::make_unique<KimiK2Detector>();
96- }
97-
98- if (it->second == " deepseekv3" ) {
99- return std::make_unique<DeepSeekV3Detector>();
144+ auto it = detector_factories.find (tool_call_parser);
145+ if (it != detector_factories.end ()) {
146+ return it->second ();
100147 }
101-
102- if (it->second == " glm45" ) {
103- return std::make_unique<Glm45Detector>();
104- }
105-
106- if (it->second == " glm47" ) {
107- return std::make_unique<Glm47Detector>();
108- }
109-
110- // if (tool_call_parser == "llama3") {
111- // return std::make_unique<Llama32Detector>();
112- // }
113- // if (tool_call_parser == "mistral") {
114- // return std::make_unique<MistralDetector>();
115- // }
116-
148+ LOG (ERROR) << " Unsupported tool call parser: " << tool_call_parser;
117149 return nullptr ;
118150}
119151
0 commit comments