diff --git a/LICENSES/Apache-2.0.txt b/LICENSES/Apache-2.0.txt new file mode 100644 index 00000000..d6456956 --- /dev/null +++ b/LICENSES/Apache-2.0.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/THIRD_PARTY_LICENSES.md b/THIRD_PARTY_LICENSES.md new file mode 100644 index 00000000..f732bb73 --- /dev/null +++ b/THIRD_PARTY_LICENSES.md @@ -0,0 +1,37 @@ +# Third-Party License Notices + +This repository is licensed under AGPL-3.0 (see `LICENSE`). 
+ +The following files include code adapted from the vLLM project and are +licensed under Apache License 2.0: + +- `endpoints/OAI/reasoning/abs_reasoning_parsers.py` +- `endpoints/OAI/reasoning/basic_parsers.py` +- `endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py` +- `endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py` +- `endpoints/OAI/reasoning/ernie45_reasoning_parser.py` +- `endpoints/OAI/reasoning/exaone4_reasoning_parser.py` +- `endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py` +- `endpoints/OAI/reasoning/gptoss_reasoning_parser.py` +- `endpoints/OAI/reasoning/granite_reasoning_parser.py` +- `endpoints/OAI/reasoning/holo2_reasoning_parser.py` +- `endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py` +- `endpoints/OAI/reasoning/identity_reasoning_parser.py` +- `endpoints/OAI/reasoning/kimi_k2_reasoning_parser.py` +- `endpoints/OAI/reasoning/minimax_m2_reasoning_parser.py` +- `endpoints/OAI/reasoning/mistral_reasoning_parser.py` +- `endpoints/OAI/reasoning/olmo3_reasoning_parser.py` +- `endpoints/OAI/reasoning/qwen3_reasoning_parser.py` +- `endpoints/OAI/reasoning/seedoss_reasoning_parser.py` +- `endpoints/OAI/reasoning/step3_reasoning_parser.py` +- `endpoints/OAI/reasoning/step3p5_reasoning_parser.py` +- `endpoints/OAI/reasoning/__init__.py` +- `endpoints/OAI/utils/parser_options.py` +- `endpoints/OAI/utils/tools.py` +- `templates/tool_calls/qwen3_coder.jinja` + +Source project: +- vLLM: https://github.com/vllm-project/vllm + +The Apache-2.0 license text is provided at: +- `LICENSES/Apache-2.0.txt` diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py index 6e59dbe3..d9341f10 100644 --- a/backends/exllamav2/model.py +++ b/backends/exllamav2/model.py @@ -1426,12 +1426,15 @@ async def generate_gen( full_response += chunk chunk_tokens = result.get("token_ids") + token_ids = [] if chunk_tokens is not None: + token_ids = chunk_tokens.flatten().tolist() generated_tokens += chunk_tokens.size(dim=0) generation = { "request_id": 
request_id, "text": chunk, + "token_ids": token_ids, "prompt_tokens": context_len, "generated_tokens": generated_tokens, "offset": len(full_response), diff --git a/backends/exllamav3/model.py b/backends/exllamav3/model.py index 50c30450..9780f940 100644 --- a/backends/exllamav3/model.py +++ b/backends/exllamav3/model.py @@ -996,6 +996,7 @@ async def generate_gen( max_rq_tokens=self.max_rq_tokens, filters=grammar_handler.filters, ) + self.active_job_ids[request_id] = job generated_tokens = 0 full_response = "" @@ -1013,8 +1014,21 @@ async def generate_gen( if chunk: chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk)) full_response += chunk + + # Extract token IDs as a plain list for downstream consumers if isinstance(chunk_tokens, torch.Tensor): + token_id_list = chunk_tokens.flatten().tolist() generated_tokens += chunk_tokens.size(dim=0) + elif isinstance(chunk_tokens, tuple): + first = chunk_tokens[0] + if isinstance(first, torch.Tensor): + token_id_list = first.flatten().tolist() + else: + token_id_list = list(first) + generated_tokens += len(token_id_list) + else: + token_id_list = list(chunk_tokens) + generated_tokens += len(token_id_list) # Increase penalty range to generated token amount # TODO: @@ -1024,6 +1038,7 @@ async def generate_gen( generation = { "request_id": request_id, "text": chunk, + "token_ids": token_id_list, "prompt_tokens": context_len, "generated_tokens": generated_tokens, "offset": len(full_response), @@ -1044,8 +1059,6 @@ async def generate_gen( yield finish_chunk break - # Assign the active job to the request ID - self.active_job_ids[request_id] = job except asyncio.CancelledError: await job.cancel() diff --git a/common/config_models.py b/common/config_models.py index 0e71734c..c13434d5 100644 --- a/common/config_models.py +++ b/common/config_models.py @@ -284,21 +284,54 @@ class ModelConfig(BaseConfigModel): ), ge=1, ) - prompt_template: Optional[str] = Field( - None, - description=( - "Set the prompt template for this 
model. (default: None)\n" - "If empty, attempts to look for the model's chat template.\n" + prompt_template: Optional[str] = Field( + None, + description=( + "Set the prompt template for this model. (default: None)\n" + "If empty, attempts to look for the model's chat template.\n" "If a model contains multiple templates in its tokenizer_config.json,\n" "set prompt_template to the name of the template you want to use.\n" - "NOTE: Only works with chat completion message lists!" - ), - ) - vision: Optional[bool] = Field( - False, - description=( - "Enables vision support if the model supports it. (default: False)" - ), + "NOTE: Only works with chat completion message lists!" + ), + ) + reasoning_parser: Optional[str] = Field( + None, + description=( + "Reasoning parser key used to split output into reasoning/content.\n" + "Compatible with vLLM parser naming (e.g. exaone4, deepseek_r1).\n" + "If omitted, defaults to 'basic'." + ), + ) + enable_auto_tool_choice: Optional[bool] = Field( + False, + description=( + "Enable auto tool choice for chat completions (default: False).\n" + "Equivalent to vLLM's --enable-auto-tool-choice.\n" + "Requires tool_call_parser to be set." + ), + ) + tool_call_parser: Optional[str] = Field( + None, + description=( + "Tool parser key for model-generated tool call output.\n" + "Equivalent to vLLM's --tool-call-parser.\n" + "Built-in parser keys include: hermes, llama/llama3_json/llama4_json,\n" + "openai, pythonic, qwen3_coder, qwen3_xml,\n" + "deepseek_v3, deepseek_v31, deepseek_v32." + ), + ) + exclude_tools_when_tool_choice_none: Optional[bool] = Field( + False, + description=( + "Exclude tool definitions from prompt when tool_choice='none'.\n" + "Equivalent to vLLM's --exclude-tools-when-tool-choice-none." + ), + ) + vision: Optional[bool] = Field( + False, + description=( + "Enables vision support if the model supports it. 
(default: False)" + ), ) _metadata: Metadata = PrivateAttr(Metadata()) diff --git a/common/templating.py b/common/templating.py index cc0cceb1..dda06d85 100644 --- a/common/templating.py +++ b/common/templating.py @@ -12,6 +12,7 @@ from jinja2.ext import loopcontrols from jinja2.sandbox import ImmutableSandboxedEnvironment from loguru import logger +from markupsafe import Markup from packaging import version @@ -24,12 +25,17 @@ class TemplateLoadError(Exception): pass +VALID_TOOL_CALL_FORMATS = {"json", "xml", "auto"} + + @dataclass class TemplateMetadata: """Represents the parsed metadata from a template.""" stop_strings: List[str] = field(default_factory=list) tool_start: Optional[str] = None + tool_end: Optional[str] = None + tool_call_format: str = "json" class PromptTemplate: @@ -46,6 +52,22 @@ class PromptTemplate: ) metadata: Optional[TemplateMetadata] = None + @staticmethod + def _tojson_compat(value, indent=None, ensure_ascii=True): + """Compatibility JSON filter for chat templates. + + Some model templates call ``tojson(ensure_ascii=False)`` while the + bundled Jinja filter may not accept that keyword in sandboxed mode. + """ + return Markup( + json.dumps( + value, + indent=indent, + ensure_ascii=ensure_ascii, + separators=(",", ": "), + ) + ) + async def extract_metadata(self, template_vars: dict): """ Returns deserialized template metadata from a chat template. 
@@ -76,6 +98,22 @@ async def extract_metadata(self, template_vars: dict): if isinstance(template_module.tool_start, str): template_metadata.tool_start = template_module.tool_start + if hasattr(template_module, "tool_end"): + if isinstance(template_module.tool_end, str): + template_metadata.tool_end = template_module.tool_end + + if hasattr(template_module, "tool_call_format"): + fmt = template_module.tool_call_format + if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS: + template_metadata.tool_call_format = fmt + logger.debug(f"Template tool_call_format: {fmt}") + else: + logger.warning( + f"Invalid tool_call_format '{fmt}' in template, " + f"defaulting to 'json'. " + f"Valid values: {VALID_TOOL_CALL_FORMATS}" + ) + self.metadata = template_metadata return template_metadata @@ -107,6 +145,7 @@ def raise_exception(message): self.environment.globals["strftime_now"] = strftime_now self.environment.globals["raise_exception"] = raise_exception + self.environment.filters["tojson"] = self._tojson_compat return self.environment.from_string(template_str) diff --git a/config_sample.yml b/config_sample.yml index 0b65f9e8..bd03afb2 100644 --- a/config_sample.yml +++ b/config_sample.yml @@ -153,6 +153,28 @@ model: # NOTE: Only works with chat completion message lists! prompt_template: + # Reasoning parser key for splitting hidden reasoning and final content. + # Compatible keys include: basic, exaone4, deepseek_r1, deepseek_v3. + # If omitted, TabbyAPI defaults to `basic`. + reasoning_parser: + + # Enable automatic tool selection (default: False). + # Equivalent to vLLM --enable-auto-tool-choice. + # Requires tool_call_parser to be set. + enable_auto_tool_choice: false + + # Tool parser key for model-generated tool call text. + # Equivalent to vLLM --tool-call-parser. + # Built-in values include: + # hermes, llama (alias of llama3_json), llama3_json, llama4_json, + # openai, pythonic, qwen3_coder, qwen3_xml, + # deepseek_v3, deepseek_v31, deepseek_v32. 
+ tool_call_parser: + + # Exclude tool definitions from prompt when tool_choice='none'. + # Equivalent to vLLM --exclude-tools-when-tool-choice-none. + exclude_tools_when_tool_choice_none: false + # Enables vision support if the model supports it. (default: False) vision: false diff --git a/docs/02.-Server-options.md b/docs/02.-Server-options.md index 98cee556..4da7766b 100644 --- a/docs/02.-Server-options.md +++ b/docs/02.-Server-options.md @@ -53,27 +53,31 @@ Note: These are experimental flags that may be removed at any point. Note: Most of the options here will only apply on initial model load/startup (ephemeral). They will not persist unless you add the option name to `use_as_default`. -| Config Option | Type (Default) | Description | -| --------------------- | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | -| model_dir | String ("models") | Directory to look for models.

Note: Persisted across subsequent load requests | -| inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) | -| use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.

Note: Persisted across subsequent load requests | -| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | -| model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. | -| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.

Note: Also applies to the `draft` sub-block | -| max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. | -| tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.

Note: `gpu_split_auto` is ignored when this is enabled. | -| gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. | -| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.

Represented as an array of MB per GPU used. | -| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. | -| rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)

Note: If the model has YaRN support, this option will not apply. | -| rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.

Note: If the model has YaRN support, this option will not apply. | -| cache_mode | String ("FP16") | Cache mode for the model.

Options: FP16, Q8, Q6, Q4 | -| cache_size | Int (max_seq_len) | Size of the K/V cache

Note: If using CFG, the cache size should be 2 * max_seq_len. | -| chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | -| max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | -| prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | -| vision | Bool (False) | Enable vision support for the provided model (if it exists). | +| Config Option | Type (Default) | Description | +| ------------------------------------ | -------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | +| model_dir | String ("models") | Directory to look for models.

Note: Persisted across subsequent load requests | +| inline_model_loading | Bool (False) | Enables ability to switch models using the `model` argument in a generation request. More info in [Usage](https://github.com/theroyallab/tabbyAPI/wiki/03.-Usage#inline-loading) | +| use_dummy_models | Bool (False) | Send a dummy OAI model card when calling the `/v1/models` endpoint. Used for clients which enforce specific OAI models.

Note: Persisted across subsequent load requests | +| dummy_model_names | List[String] (["gpt-3.5-turbo"]) | List of dummy names to send on model endpoint requests | +| model_name | String (None) | Folder name of a model to load. The below parameters will not apply unless this is filled out. | +| use_as_default | List[String] ([]) | Keys to use by default when loading models. For example, putting `cache_mode` in this array will make every model load with that value unless specified by the API request.

Note: Also applies to the `draft` sub-block | +| max_seq_len | Float (None) | Maximum sequence length of the model. Uses the value from config.json if not specified here. Also called the max context length. | +| tensor_parallel | Bool (False) | Enables tensor parallelism. Automatically falls back to autosplit if GPU split isn't provided.

Note: `gpu_split_auto` is ignored when this is enabled. | +| gpu_split_auto | Bool (True) | Automatically split the model across multiple GPUs. Manual GPU split isn't used if this is enabled. | +| autosplit_reserve | List[Int] ([96]) | Amount of empty VRAM to reserve when loading with autosplit.

Represented as an array of MB per GPU used. | +| gpu_split | List[Float] ([]) | Float array of GBs to split a model between GPUs. | +| rope_scale | Float (1.0) | Adjustment for rope scale (or compress_pos_emb)

Note: If the model has YaRN support, this option will not apply. | +| rope_alpha | Float (None) | Adjustment for rope alpha. Leave blank to automatically calculate based on the max_seq_len.

Note: If the model has YaRN support, this option will not apply. | +| cache_mode | String ("FP16") | Cache mode for the model.

Options: FP16, Q8, Q6, Q4 | +| cache_size | Int (max_seq_len) | Size of the K/V cache

Note: If using CFG, the cache size should be 2 * max_seq_len. | +| chunk_size | Int (2048) | Amount of tokens per chunk with ingestion. A lower value reduces VRAM usage at the cost of ingestion speed. | +| max_batch_size | Int (None) | The absolute maximum amount of prompts to process at one time. This value is automatically adjusted based on cache size. | +| prompt_template | String (None) | Name of a jinja2 chat template to apply for this model. Must be located in the `templates` directory. | +| reasoning_parser | String (None) | Reasoning parser key used to split reasoning and final answer text (vLLM-compatible names, default parser behavior is `basic`). | +| enable_auto_tool_choice | Bool (False) | Enables vLLM-style automatic tool choice handling. Equivalent to `--enable-auto-tool-choice` and requires `tool_call_parser`. | +| tool_call_parser | String (None) | vLLM-compatible tool parser key used to parse model-emitted tool calls. Equivalent to `--tool-call-parser`. | +| exclude_tools_when_tool_choice_none | Bool (False) | Excludes tool definitions from the prompt when `tool_choice` is `"none"`. Equivalent to `--exclude-tools-when-tool-choice-none`. | +| vision | Bool (False) | Enable vision support for the provided model (if it exists). | ### Draft Model Options diff --git a/docs/04.-Chat-Completions.md b/docs/04.-Chat-Completions.md index 647ee92d..96044599 100644 --- a/docs/04.-Chat-Completions.md +++ b/docs/04.-Chat-Completions.md @@ -31,4 +31,11 @@ Now let's pass the custom var in the following template: I'm going to say {{ test_var }} ``` -Running render on this template will now result in: `I'm going to say hello!` \ No newline at end of file +Running render on this template will now result in: `I'm going to say hello!` + +### Reasoning controls + +TabbyAPI supports reasoning parser output separation with vLLM-compatible parser keys via `model.reasoning_parser` in `config.yml`. 
+ +- `include_reasoning` request field: include or suppress reasoning output in responses +- `enable_thinking` / `thinking` request fields: accepted as top-level aliases and forwarded to template vars (`template_vars.enable_thinking`, `template_vars.thinking`) diff --git a/docs/10.-Tool-Calling.md b/docs/10.-Tool-Calling.md index 83e379a5..aaed88cf 100644 --- a/docs/10.-Tool-Calling.md +++ b/docs/10.-Tool-Calling.md @@ -12,11 +12,30 @@ TabbyAPI's tool calling implementation aligns with the [OpenAI Standard](https:/ TabbyAPI's tool implementation supports: - Tool calling when streaming - Calling multiple tools per turn +- `tool_choice` values: `none`, `auto`, `required`, and named function choice +- vLLM-style parser selection via `model.tool_call_parser` Current limitations: -- No support for `tool_choice` parameter (always assumed to be auto) - `strict` parameter not yet supported (OAI format ensured, but dtype and argument name choices not yet enforced) +### vLLM-compatible options + +The following model config options are available to align behavior with vLLM: + +- `enable_auto_tool_choice`: equivalent to `--enable-auto-tool-choice` +- `tool_call_parser`: equivalent to `--tool-call-parser` +- `exclude_tools_when_tool_choice_none`: equivalent to `--exclude-tools-when-tool-choice-none` + +`tool_choice="auto"` requires both `enable_auto_tool_choice: true` and `tool_call_parser` to be set. + +Supported parser keys include: +- `hermes` +- `llama` (alias of `llama3_json`), `llama3_json`, `llama4_json` +- `openai` +- `pythonic` +- `qwen3_coder`, `qwen3_xml` +- `deepseek_v3`, `deepseek_v31`, `deepseek_v32` + ## Model Support TabbyAPI exposes controls within the `prompt_template` to accommodate models specifically tuned for tool calling and those that aren't. By default, TabbyAPI includes `chatml_with_headers_tool_calling.jinja`, a generic template built to support the Llama 3.1 family and other models following the ChatML (with headers) format. 
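
The config options documented above can be combined in the `model` block of `config.yml`. A minimal sketch — the model folder name is illustrative, and the parser keys are examples drawn from the supported lists above:

```yaml
model:
  # Illustrative model folder name
  model_name: my-tool-calling-model

  # Split hidden reasoning from final content (vLLM-compatible key)
  reasoning_parser: deepseek_r1

  # Both of these are needed before tool_choice="auto" will work
  enable_auto_tool_choice: true
  tool_call_parser: hermes

  # Keep tool definitions in the prompt even when tool_choice='none'
  exclude_tools_when_tool_choice_none: false
```

With this in place, a standard OpenAI-style chat completion request with a `tools` array and `tool_choice: "auto"` should be parsed by the configured parser rather than returned as raw text.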
diff --git a/endpoints/OAI/reasoning/__init__.py b/endpoints/OAI/reasoning/__init__.py new file mode 100644 index 00000000..d7df6ee1 --- /dev/null +++ b/endpoints/OAI/reasoning/__init__.py @@ -0,0 +1,95 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.abs_reasoning_parsers import ( + DeltaMessage, + ReasoningParser, + ReasoningParserManager, +) +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from endpoints.OAI.reasoning.deepseek_v3_reasoning_parser import DeepSeekV3ReasoningParser +from endpoints.OAI.reasoning.ernie45_reasoning_parser import Ernie45ReasoningParser +from endpoints.OAI.reasoning.exaone4_reasoning_parser import Exaone4ReasoningParser +from endpoints.OAI.reasoning.glm4_moe_reasoning_parser import ( + Glm4MoeModelReasoningParser, +) +from endpoints.OAI.reasoning.gptoss_reasoning_parser import GptOssReasoningParser +from endpoints.OAI.reasoning.granite_reasoning_parser import GraniteReasoningParser +from endpoints.OAI.reasoning.holo2_reasoning_parser import Holo2ReasoningParser +from endpoints.OAI.reasoning.hunyuan_a13b_reasoning_parser import ( + HunyuanA13BReasoningParser, +) +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser +from endpoints.OAI.reasoning.kimi_k2_reasoning_parser import KimiK2ReasoningParser +from endpoints.OAI.reasoning.minimax_m2_reasoning_parser import ( + MiniMaxM2AppendThinkReasoningParser, + MiniMaxM2ReasoningParser, +) +from endpoints.OAI.reasoning.mistral_reasoning_parser import MistralReasoningParser +from endpoints.OAI.reasoning.olmo3_reasoning_parser import Olmo3ReasoningParser +from endpoints.OAI.reasoning.qwen3_reasoning_parser import Qwen3ReasoningParser +from endpoints.OAI.reasoning.seedoss_reasoning_parser import SeedOSSReasoningParser +from 
endpoints.OAI.reasoning.step3_reasoning_parser import Step3ReasoningParser +from endpoints.OAI.reasoning.step3p5_reasoning_parser import Step3p5ReasoningParser + + +@ReasoningParserManager.register_module("identity") +class _IdentityParser(IdentityReasoningParser): + pass + + +@ReasoningParserManager.register_module("basic") +class _BasicParser(DeepSeekR1ReasoningParser): + pass + + +ReasoningParserManager.reasoning_parsers.update( + { + "deepseek_r1": DeepSeekR1ReasoningParser, + "deepseek_v3": DeepSeekV3ReasoningParser, + "ernie45": Ernie45ReasoningParser, + "exaone4": Exaone4ReasoningParser, + "glm45": Glm4MoeModelReasoningParser, + "openai_gptoss": GptOssReasoningParser, + "granite": GraniteReasoningParser, + "holo2": Holo2ReasoningParser, + "hunyuan_a13b": HunyuanA13BReasoningParser, + "kimi_k2": KimiK2ReasoningParser, + "minimax_m2": MiniMaxM2ReasoningParser, + "minimax_m2_append_think": MiniMaxM2AppendThinkReasoningParser, + "mistral": MistralReasoningParser, + "olmo3": Olmo3ReasoningParser, + "qwen3": Qwen3ReasoningParser, + "seed_oss": SeedOSSReasoningParser, + "step3": Step3ReasoningParser, + "step3p5": Step3p5ReasoningParser, + } +) + + +__all__ = [ + "BaseThinkingReasoningParser", + "DeltaMessage", + "DeepSeekR1ReasoningParser", + "DeepSeekV3ReasoningParser", + "Ernie45ReasoningParser", + "Exaone4ReasoningParser", + "Glm4MoeModelReasoningParser", + "GptOssReasoningParser", + "GraniteReasoningParser", + "Holo2ReasoningParser", + "HunyuanA13BReasoningParser", + "IdentityReasoningParser", + "KimiK2ReasoningParser", + "MiniMaxM2AppendThinkReasoningParser", + "MiniMaxM2ReasoningParser", + "MistralReasoningParser", + "Olmo3ReasoningParser", + "Qwen3ReasoningParser", + "ReasoningParser", + "ReasoningParserManager", + "SeedOSSReasoningParser", + "Step3ReasoningParser", + "Step3p5ReasoningParser", +] diff --git a/endpoints/OAI/reasoning/abs_reasoning_parsers.py b/endpoints/OAI/reasoning/abs_reasoning_parsers.py new file mode 100644 index 00000000..b81983d0 --- 
/dev/null +++ b/endpoints/OAI/reasoning/abs_reasoning_parsers.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import ABC, abstractmethod +from collections.abc import Callable, Sequence +from dataclasses import dataclass +from typing import Any + + +@dataclass +class DeltaMessage: + content: str | None = None + reasoning: str | None = None + + +class ReasoningParser(ABC): + def __init__(self, tokenizer: Any, *args, **kwargs): + self.model_tokenizer = tokenizer + + @property + def vocab(self) -> dict[str, int]: + return self.model_tokenizer.get_vocab() + + @abstractmethod + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + pass + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + return self.is_reasoning_end(input_ids) + + @abstractmethod + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + pass + + @abstractmethod + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + pass + + @abstractmethod + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + pass + + def prepare_structured_tag(self, original_tag: str | None, tool_server: Any | None): + return original_tag + + +class ReasoningParserManager: + reasoning_parsers: dict[str, type[ReasoningParser]] = {} + + @classmethod + def list_registered(cls) -> list[str]: + return sorted(cls.reasoning_parsers.keys()) + + @classmethod + def get_reasoning_parser(cls, name: str) -> type[ReasoningParser]: + parser = cls.reasoning_parsers.get(name) + if parser is None: + registered = ", ".join(cls.list_registered()) + raise KeyError( + f"Reasoning parser '{name}' not found. 
Available parsers: {registered}" + ) + return parser + + @classmethod + def register_module( + cls, + module_name: str, + ) -> Callable[[type[ReasoningParser]], type[ReasoningParser]]: + def _decorator(module: type[ReasoningParser]) -> type[ReasoningParser]: + cls.reasoning_parsers[module_name] = module + return module + + return _decorator diff --git a/endpoints/OAI/reasoning/basic_parsers.py b/endpoints/OAI/reasoning/basic_parsers.py new file mode 100644 index 00000000..f2dfc0c7 --- /dev/null +++ b/endpoints/OAI/reasoning/basic_parsers.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from abc import abstractmethod +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + +class BaseThinkingReasoningParser(ReasoningParser): + @property + @abstractmethod + def start_token(self) -> str: + raise NotImplementedError + + @property + @abstractmethod + def end_token(self) -> str: + raise NotImplementedError + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + if self.start_token_id is None or self.end_token_id is None: + raise RuntimeError( + f"{self.__class__.__name__} could not locate think tokens in tokenizer" + ) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + for token_id in reversed(input_ids): + if token_id == self.start_token_id: + return False + if token_id == self.end_token_id: + return True + return False + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if self.end_token_id not in input_ids[:-1]: + return [] + return input_ids[input_ids.index(self.end_token_id) + 1 :] + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + 
model_output_parts = model_output.partition(self.start_token) + model_output = ( + model_output_parts[2] if model_output_parts[1] else model_output_parts[0] + ) + + if self.end_token not in model_output: + return model_output or None, None + + reasoning, _, content = model_output.partition(self.end_token) + return reasoning or None, content or None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] in [self.start_token_id, self.end_token_id] + ): + return None + + if self.start_token_id in previous_token_ids: + if self.end_token_id in delta_token_ids: + end_index = delta_text.find(self.end_token) + reasoning = delta_text[:end_index] or None + content = delta_text[end_index + len(self.end_token) :] or None + return DeltaMessage(reasoning=reasoning, content=content) + if self.end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text or None) + return DeltaMessage(reasoning=delta_text or None) + + if self.start_token_id in delta_token_ids: + if self.end_token_id in delta_token_ids: + start_index = delta_text.find(self.start_token) + end_index = delta_text.find(self.end_token) + reasoning = delta_text[start_index + len(self.start_token) : end_index] + content = delta_text[end_index + len(self.end_token) :] + return DeltaMessage(reasoning=reasoning or None, content=content or None) + return DeltaMessage(reasoning=delta_text or None) + + return DeltaMessage(content=delta_text or None) diff --git a/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py b/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py new file mode 100644 index 00000000..3b93bb17 --- /dev/null +++ b/endpoints/OAI/reasoning/deepseek_r1_reasoning_parser.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +# SPDX-License-Identifier: 
Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser + + +class DeepSeekR1ReasoningParser(BaseThinkingReasoningParser): + @property + def start_token(self) -> str: + return "<think>" + + @property + def end_token(self) -> str: + return "</think>" + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + ret = super().extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) + + if ( + ret is not None + and self.start_token_id not in previous_token_ids + and self.start_token_id not in delta_token_ids + ): + if self.end_token_id in delta_token_ids: + end_index = delta_text.find(self.end_token) + reasoning = delta_text[:end_index] or None + content = delta_text[end_index + len(self.end_token) :] or None + return DeltaMessage(reasoning=reasoning, content=content) + if self.end_token_id in previous_token_ids: + return DeltaMessage(content=delta_text or None) + return DeltaMessage(reasoning=delta_text or None) + + return ret diff --git a/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py b/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py new file mode 100644 index 00000000..5e8deb73 --- /dev/null +++ b/endpoints/OAI/reasoning/deepseek_v3_reasoning_parser.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import ( + DeepSeekR1ReasoningParser, +) +from 
endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +class DeepSeekV3ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", False)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", False)) + thinking = thinking or enable_thinking + + if thinking: + self._parser = DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + else: + self._parser = IdentityReasoningParser(tokenizer, *args, **kwargs) + + def is_reasoning_end(self, input_ids: list[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/endpoints/OAI/reasoning/ernie45_reasoning_parser.py b/endpoints/OAI/reasoning/ernie45_reasoning_parser.py new file mode 100644 index 00000000..bff91166 --- /dev/null +++ b/endpoints/OAI/reasoning/ernie45_reasoning_parser.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage +from 
endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser + + +class Ernie45ReasoningParser(BaseThinkingReasoningParser): + response_start_token: str = "<response>" + response_end_token: str = "</response>" + newline_token: str = "<0x0A>" + + @property + def start_token(self) -> str: + return "<think>" + + @property + def end_token(self) -> str: + return "</think>" + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + self.response_start_token_id = self.vocab.get(self.response_start_token) + self.response_end_token_id = self.vocab.get(self.response_end_token) + self.newline_token_id = self.vocab.get(self.newline_token) + self.parser_token_ids = [self.end_token_id, self.response_end_token_id] + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + if len(delta_token_ids) == 1 and ( + delta_token_ids[0] + in [ + self.start_token_id, + self.end_token_id, + self.response_start_token_id, + self.response_end_token_id, + ] + ): + return None + + if self.end_token_id in delta_token_ids: + think_end_index = delta_text.find(self.end_token) + reasoning = delta_text[:think_end_index] + content = delta_text[think_end_index + len(self.end_token) :].lstrip("\n") + response_start_idx = content.find(self.response_start_token) + response_end_idx = content.rfind(self.response_end_token) + if response_start_idx != -1: + content = content[response_start_idx + len(self.response_start_token) :] + if response_end_idx != -1: + content = content[:response_end_idx] + return DeltaMessage(reasoning=reasoning, content=content or None) + + if self.end_token_id in previous_token_ids: + content = delta_text + if self.response_start_token_id in delta_token_ids: + content = content.lstrip("\n") + response_start_idx = content.find(self.response_start_token) + content = 
content[response_start_idx + len(self.response_start_token) :] + response_end_idx = content.rfind(self.response_end_token) + if response_end_idx != -1: + content = content[:response_end_idx] + elif self.response_end_token_id in delta_token_ids: + response_end_idx = content.rfind(self.response_end_token) + content = content[:response_end_idx] + + if previous_token_ids and previous_token_ids[-1] in self.parser_token_ids: + if delta_token_ids and delta_token_ids[0] == self.newline_token_id: + content = content.lstrip("\n") + if len(previous_token_ids) > 1 and previous_token_ids[-2] == self.end_token_id: + if delta_token_ids and delta_token_ids[0] == self.newline_token_id: + content = content.lstrip("\n") + + return DeltaMessage(content=content or None) + + return DeltaMessage(reasoning=delta_text) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + reasoning, content = super().extract_reasoning(model_output, request) + if content: + start_idx = content.find(self.response_start_token) + end_idx = content.rfind(self.response_end_token) + if start_idx != -1 and end_idx != -1 and start_idx < end_idx: + content = content[start_idx + len(self.response_start_token) : end_idx] + return reasoning, content or None diff --git a/endpoints/OAI/reasoning/exaone4_reasoning_parser.py b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py new file mode 100644 index 00000000..5e2dd43f --- /dev/null +++ b/endpoints/OAI/reasoning/exaone4_reasoning_parser.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import ( + DeltaMessage, + ReasoningParser, + ReasoningParserManager, +) + + +@ReasoningParserManager.register_module("exaone4") +class Exaone4ReasoningParser(ReasoningParser): + """ + Reasoning parser 
for EXAONE 4.x models. + + Behavior notes: + - EXAONE uses `enable_thinking` (not `thinking`) to control reasoning mode. + - Templates may prefill `<think>`, so streamed/output text can start directly + with reasoning text and close at `</think>`. + """ + + start_token = "<think>" + end_token = "</think>" + # Tool-call starts supported by ToolCallProcessor parser families. + # We use these as fallback reasoning boundaries when a model emits + # tool syntax without closing </think>. + # NOTE: "<tool_call>" and "<function_call>" are assumed hermes/XML-style + # markers; adjust them to match the deployed parser family. + tool_start_markers = ( + "<tool_call>", + "<function_call>", + "<|tool▁call▁begin|>", + "<|DSML|function_calls>", + "<|DSML|invoke", + "<|python_tag|>", + ) + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + self.thinking_enabled = bool(chat_kwargs.get("enable_thinking", False)) + self.start_token_id = self.vocab.get(self.start_token) + self.end_token_id = self.vocab.get(self.end_token) + + def _strip_reasoning_tokens(self, text: str) -> str: + if not text: + return "" + return text.replace(self.start_token, "").replace(self.end_token, "") + + def _trailing_overlap_len(self, text: str, token: str) -> int: + """Longest suffix overlap of text with token prefix.""" + max_len = min(len(text), len(token) - 1) + for size in range(max_len, 0, -1): + if text.endswith(token[:size]): + return size + return 0 + + def _find_first_marker(self, text: str, markers: Sequence[str]) -> tuple[int, str] | None: + first_idx = -1 + first_marker = "" + for marker in markers: + idx = text.find(marker) + if idx == -1: + continue + if first_idx == -1 or idx < first_idx: + first_idx = idx + first_marker = marker + if first_idx == -1: + return None + return first_idx, first_marker + + def _max_trailing_overlap_len(self, text: str, markers: Sequence[str]) -> int: + overlap = 0 + for marker in markers: + overlap = max(overlap, self._trailing_overlap_len(text, marker)) + return overlap + + def _split_reasoning_content_streaming( + self, text: str + ) -> tuple[str 
| None, str | None]: + """Split text into reasoning/content for streaming-safe diffing. + + Important: when end token is not yet complete, withhold a trailing + overlap with `</think>` or tool-call prefixes to avoid leaking + partial control-tag bytes into reasoning output. This prevents + boundary-split regressions such as `</think>answer` and + `<tool_call>{...}`. + """ + if not self.thinking_enabled: + content = self._strip_reasoning_tokens(text) + return None, content or None + + body = text + if self.start_token in body: + _, _, body = body.partition(self.start_token) + + if self.end_token in body: + reasoning, _, content = body.partition(self.end_token) + return reasoning or None, self._strip_reasoning_tokens(content) or None + + marker_match = self._find_first_marker(body, self.tool_start_markers) + if marker_match is not None: + marker_index, _ = marker_match + reasoning = body[:marker_index] + content = body[marker_index:] + return reasoning or None, self._strip_reasoning_tokens(content) or None + + reasoning = body.replace(self.start_token, "") + overlap = max( + self._trailing_overlap_len(reasoning, self.end_token), + self._max_trailing_overlap_len(reasoning, self.tool_start_markers), + ) + if overlap: + reasoning = reasoning[:-overlap] + return reasoning or None, None + + def _delta_from_previous(self, previous: str | None, current: str | None) -> str | None: + if current is None: + return None + previous_text = previous or "" + if current.startswith(previous_text): + delta = current[len(previous_text) :] + else: + # Fallback for recovery paths where prefix alignment breaks. 
+ delta = current + return delta or None + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + if not self.thinking_enabled: + return True + if self.end_token_id is None: + return False + return any(token_id == self.end_token_id for token_id in reversed(input_ids)) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + if not self.thinking_enabled: + return input_ids + if self.end_token_id is None or self.end_token_id not in input_ids[:-1]: + return [] + return input_ids[input_ids.index(self.end_token_id) + 1 :] + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + if not self.thinking_enabled: + content = self._strip_reasoning_tokens(model_output) + return None, content or None + + if self.start_token in model_output: + _, _, model_output = model_output.partition(self.start_token) + + if self.end_token in model_output: + reasoning, _, content = model_output.partition(self.end_token) + content = self._strip_reasoning_tokens(content) + return reasoning or None, content or None + + reasoning = model_output.replace(self.start_token, "") + return reasoning or None, None + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: list[int], + current_token_ids: list[int], + delta_token_ids: list[int], + ) -> DeltaMessage | None: + if not delta_text and not delta_token_ids: + return None + + if not self.thinking_enabled: + prev_reasoning, prev_content = self._split_reasoning_content_streaming( + previous_text + ) + cur_reasoning, cur_content = self._split_reasoning_content_streaming( + current_text + ) + content_delta = self._delta_from_previous(prev_content, cur_content) + if content_delta is None: + return None + return DeltaMessage(content=content_delta) + + if len(delta_token_ids) == 1 and ( + (self.start_token_id is not None and delta_token_ids[0] == self.start_token_id) + or (self.end_token_id is not None and 
delta_token_ids[0] == self.end_token_id) + ): + return None + + prev_reasoning, prev_content = self._split_reasoning_content_streaming(previous_text) + cur_reasoning, cur_content = self._split_reasoning_content_streaming(current_text) + + reasoning_delta = self._delta_from_previous(prev_reasoning, cur_reasoning) + content_delta = self._delta_from_previous(prev_content, cur_content) + + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning=reasoning_delta, content=content_delta) diff --git a/endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py b/endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py new file mode 100644 index 00000000..9368f2c3 --- /dev/null +++ b/endpoints/OAI/reasoning/glm4_moe_reasoning_parser.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.holo2_reasoning_parser import Holo2ReasoningParser + + +class Glm4MoeModelReasoningParser(Holo2ReasoningParser): + pass diff --git a/endpoints/OAI/reasoning/gptoss_reasoning_parser.py b/endpoints/OAI/reasoning/gptoss_reasoning_parser.py new file mode 100644 index 00000000..3e454bed --- /dev/null +++ b/endpoints/OAI/reasoning/gptoss_reasoning_parser.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import json +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + +NO_FUNC_REASONING_TAG = { + "type": "structural_tag", + "format": { + "type": "triggered_tags", + "tags": [ + { + "begin": "<|channel|>analysis<|message|>", + "content": {"type": "any_text"}, + "end": "<|end|>", + } + ], + "triggers": ["<|channel|>analysis"], + "stop_after_first": False, + }, +} + + +class 
GptOssReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + def _split_harmony(self, text: str) -> tuple[str | None, str | None]: + # Minimal harmony-compatible splitter without vLLM parser dependency. + analysis_tag = "<|channel|>analysis<|message|>" + final_tag = "<|channel|>final<|message|>" + end_tag = "<|end|>" + + a_idx = text.find(analysis_tag) + f_idx = text.find(final_tag) + if a_idx == -1 and f_idx == -1: + return None, text or None + + reasoning = None + content = None + + if a_idx != -1: + a_start = a_idx + len(analysis_tag) + a_end = text.find(end_tag, a_start) + if a_end == -1: + a_end = f_idx if f_idx != -1 else len(text) + reasoning = text[a_start:a_end] or None + + if f_idx != -1: + f_start = f_idx + len(final_tag) + f_end = text.find(end_tag, f_start) + if f_end == -1: + f_end = len(text) + content = text[f_start:f_end] or None + + return reasoning, content + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return "<|channel|>final<|message|>" in text + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + _, content = self._split_harmony(self.model_tokenizer.decode(input_ids)) + if content is None: + return [] + return self.model_tokenizer.encode(content) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + prev_reasoning, prev_content = self._split_harmony(previous_text) + cur_reasoning, cur_content = self._split_harmony(current_text) + + reasoning_delta = None + content_delta = None + if cur_reasoning is not None: + prev_r = prev_reasoning or "" + reasoning_delta = ( + cur_reasoning[len(prev_r) :] + if cur_reasoning.startswith(prev_r) + else cur_reasoning + ) or None + if cur_content is not 
None: + prev_c = prev_content or "" + content_delta = ( + cur_content[len(prev_c) :] + if cur_content.startswith(prev_c) + else cur_content + ) or None + + if reasoning_delta is None and content_delta is None: + return None + return DeltaMessage(reasoning=reasoning_delta, content=content_delta) + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + return self._split_harmony(model_output) + + def prepare_structured_tag( + self, original_tag: str | None, tool_server: Any | None + ) -> str | None: + if original_tag is not None: + return original_tag + return json.dumps(NO_FUNC_REASONING_TAG) diff --git a/endpoints/OAI/reasoning/granite_reasoning_parser.py b/endpoints/OAI/reasoning/granite_reasoning_parser.py new file mode 100644 index 00000000..c60c3fac --- /dev/null +++ b/endpoints/OAI/reasoning/granite_reasoning_parser.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +try: + import regex as re +except ImportError: + import re + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + + +class GraniteReasoningParser(ReasoningParser): + """ + Reasoning parser for IBM Granite. + + IBM granite models currently use "Here is my thought process:" + and "Here is my response:" to separate its thinking / response outputs. + """ + + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + # NOTE: There have been some observed occurrences of quantized + # instances of the current models using "Here's" instead of "Here is", + # so to be safe, we match on both. 
+ self.think_start_expr = r"(?:Here's|Here is) my thought process:" + self.response_start_expr = r"(?:Here's|Here is) my response:" + + self.reasoning_regex = re.compile( + rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL + ) + + self.valid_think_starts = [ + "Here's my thought process:", + "Here is my thought process:", + ] + self.valid_response_starts = ["Here's my response:", "Here is my response:"] + + # Substrings to match for sequence boundaries on raw text + self.seq_boundary_end = ":" + self.seq_boundary_start = "Here" + + # The longest any thinking / start of response message can be + self.longest_think_start = max( + len(think_start) for think_start in self.valid_think_starts + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return any(resp in text for resp in self.valid_response_starts) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + text = self.model_tokenizer.decode(input_ids) + _, content = self.extract_reasoning(text, None) + if not content: + return [] + return self.model_tokenizer.encode(content) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + """Extract the reasoning content & content sections, respectively. + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + Args: + model_output (str): Output of the model to be parsed. + request (Any): Request being processed. + + Returns: + tuple[str | None, str | None]: Tuple pair containing the + reasoning content and non-reasoning content. 
+ """ + re_match = self.reasoning_regex.findall(model_output) + if not re_match: + return None, model_output + reasoning, response_content = re_match[0] + if not response_content: + return reasoning, None + return reasoning, response_content + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + """Extract the reasoning content / content emitted by granite models; + If the sequence doesn't match what we expect, i.e., the model generates + something else, all content is considered non-reasoning content. + + NOTE: Granite models do not use a special token to start their reasoning + and response sections; instead they have token sequences, e.g., + + Here is my thought process: Foo Here is my response: Bar + + This increases the complexity of correctly handling streams, since we + need to watch for specific sequences and correctly parse them without + dropping content that is potentially overlapping & spanning multiple + delta messages. + + Args: + previous_text (str): Previous text outside of this delta message. + current_text (str): Previous text + delta text. + delta_text (str): Text to consider and parse content from. + previous_token_ids (Sequence[int]): Token IDs of previous_text. + current_token_ids (Sequence[int]): Token IDs of current_text. + delta_token_ids (Sequence[int]): Token IDs of delta_text. + + Returns: + Union[DeltaMessage, None] + DeltaMessage with either reasoning content or content, or None. + """ + reasoning, resp_seq_len, content = self._get_content_sections(current_text) + # Either we haven't finished the start of the reasoning sequence, + # or the model is generating something unexpected. 
+ if not reasoning: + delta_message = self._get_delta_message_with_no_reasoning_bounds( + current_text, delta_text + ) + # We have a start of reasoning message, but have not yet finished + # the start of response sequence. + elif not content: + delta_message = self._get_delta_message_with_no_response_bounds( + current_text, reasoning, delta_text + ) + # We've finished both the start of reasoning and start of response seq. + else: + # This should never happen since we matched on the response + assert resp_seq_len is not None + delta_message = self._get_delta_message_with_both_bounds( + delta_text, reasoning, content, current_text, resp_seq_len + ) + if not delta_message.content and not delta_message.reasoning: + return None + return delta_message + + #### Implementation details of stream parsing for granite models + def _is_reasoning_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start reasoning seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible reasoning start seqs match. + """ + return any( + think_start.startswith(text) for think_start in self.valid_think_starts + ) + + def _is_response_start_substr(self, text: str) -> bool: + """Check if a text matches one of the possible start response seqs. + + Args: + text (str): Text to check for leading substr. + + Returns: + bool: True if any of the possible response start seqs match. + """ + return any( + response_start.startswith(text) + for response_start in self.valid_response_starts + ) + + def _get_delta_message_with_no_reasoning_bounds( + self, + current_text: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has not yet completed + its start of reasoning sequence. + + Args: + current_text (str): The full previous + delta text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + prev_longest_length = len(current_text) - len(delta_text) + is_substr = self._is_reasoning_start_substr(current_text) + was_substr = self._is_reasoning_start_substr(current_text[:prev_longest_length]) + + # Check if we just generated something NOT in the special token seq; + # if so, add everything that we previously skipped with this delta + # message and append everything to content in the future. + if was_substr and not is_substr: + return DeltaMessage( + reasoning=None, + content=current_text, + ) + if is_substr: + # Might still be in the special token sequence; return nothing + return DeltaMessage(reasoning=None, content=None) + # Otherwise the sequence has already been broken and we already + # corrected; just return the delta text as normal content. + return DeltaMessage(reasoning=None, content=delta_text) + + def _get_delta_message_with_no_response_bounds( + self, + current_text: str, + reasoning: str, + delta_text: str, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content with no (response) content. NOTE that we may have overlapping + tokens with the start of reasoning / start of response sequences on + either side of the delta text. + + Args: + current_text (str): The full previous + delta text. + reasoning (str): reasoning content from current_text. + delta_text (str): Text to consider and parse content from. + + Returns: + DeltaMessage: Message containing the parsed content. + """ + # If we have no reasoning content or explicitly end with the start of + # response sequence, we are in transition to the response; need to be + # careful here, since the final token (:) will match the reasoning + # content and fully parse it out; we should not pass the : back. 
+ ends_with_start_response_seq = any( + current_text.endswith(response_start) + for response_start in self.valid_response_starts + ) + if reasoning is None or ends_with_start_response_seq: + return DeltaMessage(reasoning=None, content=None) + + # Consider previous / current text only within context of the reasoning + previous_text = reasoning[: -len(delta_text)] + current_text = reasoning + + # We need to be careful about adding unfinished response sequences; + # Find the place at which we MIGHT be starting a response sequence + prev_idx = previous_text.rfind(self.seq_boundary_start) + delta_idx = delta_text.rfind(self.seq_boundary_start) + + # Check the state of potential start of response substring matches. + prev_was_substr = ( + self._is_response_start_substr(previous_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_continues_substr = ( + self._is_response_start_substr(current_text[prev_idx:]) + if prev_idx >= 0 + else False + ) + delta_new_substr = ( + self._is_response_start_substr(delta_text[delta_idx:]) + if delta_idx >= 0 + else False + ) + + # Delta only contains potential continued response sequence text. + if delta_continues_substr: + return DeltaMessage(reasoning=None, content=None) + + if not prev_was_substr: + # Delta may be starting a new response seq but has other text too. + if delta_new_substr: + return DeltaMessage(reasoning=delta_text[:delta_idx], content=None) + # Normal case for most reasoning text (no potential special seqs). 
+ return DeltaMessage(reasoning=delta_text, content=None) + # The substring that previously seemed to be a potential response + # seq wasn't one; we need to add the content to the delta message, + # and also slice off the potential response sequence + elif delta_new_substr: + reasoning = previous_text[prev_idx:] + delta_text[:delta_idx] + return DeltaMessage(reasoning=reasoning, content=None) + # No new substring yet, and we broke our old one; take the whole delta + return DeltaMessage( + reasoning=previous_text[prev_idx:] + delta_text, + content=None, + ) + + def _get_delta_message_with_both_bounds( + self, + delta_text: str, + reasoning: str, + response_content: str, + current_text: str, + response_seq_len: int, + ) -> DeltaMessage: + """Parse the delta message when the current text has both reasoning + content and normal (response) content. + + Args: + delta_text: Text to consider and parse content from. + reasoning: reasoning content from current_text. + response_content: response content from current_text. + current_text: The full previous + delta text. + response_seq_len: Len of the complete response sequence used. + + Returns: + DeltaMessage: Message containing the parsed content. 
+ """ + # Always have content; take length to the end + delta_content = delta_text[-len(response_content) :] + reasoning_end_idx = len(delta_text) - (len(response_content) + response_seq_len) + + if reasoning_end_idx < 0: + delta_reasoning = None + else: + # Get the starting offset + start_reasoning_idx = ( + len(reasoning) + response_seq_len + len(response_content) - 1 + ) + delta_offset = len(current_text) - len(delta_text) + start_offset = start_reasoning_idx - delta_offset + if start_offset < 0: + start_offset = 0 + delta_reasoning = delta_text[start_offset:reasoning_end_idx] + + return DeltaMessage( + reasoning=delta_reasoning, + content=delta_content, + ) + + def _get_content_sections( + self, current_text: str + ) -> tuple[str | None, int | None, str | None]: + """Parse the text to extract the reasoning content / content + if we have them. + + Args: + current_text (str): The full previous + delta text. + + Returns: + tuple[Optional[str], Optional[int], Optional[str]]: Tuple of len 3 + containing the reasoning content, the length of the response seq + (if there is one) and the non-reasoning content. 
+        """
+        current_chunk_start = 0
+        start_reasoning = None
+        parsed_content = False
+        delimiter_idxs = [
+            idx
+            for idx, char in enumerate(current_text)
+            if char == self.seq_boundary_end
+        ]
+
+        for current_chunk_end in delimiter_idxs:
+            current_chunk = current_text[current_chunk_start:current_chunk_end]
+            # Check to see if the start of reasoning seq is complete
+            if start_reasoning is None:
+                for think_start in self.valid_think_starts:
+                    if current_chunk == think_start[:-1]:
+                        start_reasoning = current_chunk_end + 1
+                        current_chunk_start = current_chunk_end + 1
+                        break
+
+            # Check to see if the start of response seq is complete
+            elif not parsed_content:
+                for response_start in self.valid_response_starts:
+                    if current_chunk[-len(response_start) + 1 :] == response_start[:-1]:
+                        # Mark end of reasoning and start response content
+                        # after the start of response sequence.
+                        end_reasoning = current_chunk_end - len(response_start)
+                        reasoning = current_text[start_reasoning:end_reasoning]
+                        response_content = current_text[current_chunk_end + 1 :]
+                        return reasoning, len(response_start), response_content
+
+        if start_reasoning and not parsed_content:
+            return current_text[start_reasoning:], None, None
+        return None, None, None
diff --git a/endpoints/OAI/reasoning/holo2_reasoning_parser.py b/endpoints/OAI/reasoning/holo2_reasoning_parser.py
new file mode 100644
index 00000000..cdd4d356
--- /dev/null
+++ b/endpoints/OAI/reasoning/holo2_reasoning_parser.py
@@ -0,0 +1,59 @@
+from __future__ import annotations
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Any
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser
+from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser
+from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser
+
+
+class 
Holo2ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", True)) + enable_thinking = bool(chat_kwargs.get("enable_thinking", True)) + thinking = thinking and enable_thinking + self._parser = ( + DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + if thinking + else IdentityReasoningParser(tokenizer, *args, **kwargs) + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + return self._parser.extract_reasoning_streaming( + previous_text, + current_text, + delta_text, + previous_token_ids, + current_token_ids, + delta_token_ids, + ) diff --git a/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py b/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py new file mode 100644 index 00000000..b3ae3c72 --- /dev/null +++ b/endpoints/OAI/reasoning/hunyuan_a13b_reasoning_parser.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +try: + 
+    import regex as re
+except ImportError:
+    import re
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser
+
+
+class HunyuanA13BReasoningParser(ReasoningParser):
+    """
+    Reasoning parser for the Hunyuan A13B model.
+
+    This class implements a reasoning parser specifically designed
+    for the Hunyuan A13B model. It is responsible for parsing and
+    extracting structured reasoning and answer segments from model
+    outputs that follow a specific pattern.
+
+    Key Features:
+    - For non-stream output, it recognizes and extracts reasoning ("think")
+      and answer ("answer") sections from text using regular expressions.
+    - For stream processing, it requires token id sequences to change the
+      reasoning state and other states, so it maintains internal state to
+      manage parsing across multiple tokens.
+
+    Marker strings and their token ids:
+        think start: "<think>\n": [14023, 771, 397]
+        think ends: "\n</think>\n<answer>\n": [198, 524, 27963, 397, 27, 9399, 397]
+        response ends: "\n</answer>": [198, 524, 9399, 29]
+    """
+
+    def __init__(self, tokenizer: Any, *args, **kwargs):
+        super().__init__(tokenizer, *args, **kwargs)
+        self.think_start_expr = r"<think>\n"
+        self.think_end_expr = r"\n</think>\n"
+
+        self.response_start_expr = r"\n</think>\n<answer>\n"
+        self.response_end_expr = r"\n</answer>"
+
+        self.full_match_reasoning_regex = re.compile(
+            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?"
+            rf"(.*?){self.response_end_expr}",
+            re.DOTALL,
+        )
+
+        self.half_match_reasoning_regex = re.compile(
+            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)", re.DOTALL
+        )
+
+        self.think_start_ids = [14023, 771, 397]
+        self.think_start_ids_fast = [14023, 771, 1363]
+        self.response_start_ids = [198, 524, 27963, 397, 27, 9399, 397]
+        self.response_start_ids_fast = [524, 27963, 397, 27, 9399, 397]
+        self.response_end_ids = [198, 524, 9399, 29]
+        self.fast_think_ids = [14023, 771, 1363, 524, 27963, 397, 27, 9399, 397]
+
+        # When the state changes, send out all the text buffered in the
+        # last state.
+        self.buffered_text = []
+        self.buffered_ids = []
+
+        self.all_states = ["reasoning", "response"]
+
+        self.current_state = "idle"
+        self.expected_sequence = self.think_start_ids
+        # This side sequence applies only to the think start, which has two
+        # possible ways to begin.
+        self.expected_sequence_side = self.think_start_ids_fast
+        self.sequence_index = 0
+        self.token_buffer = []
+        self.text_buffer = ""
+
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+        return self.current_state == "response"
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        # For hunyuan streaming reason parsing, the stream parser runs first,
+        # and the same token is then passed to is_reasoning_end and
+        # extract_content_ids. That id is not part of the content, so just
+        # return [] here.
+        return []
+
+    def extract_reasoning(
+        self, model_output: str, request: Any
+    ) -> tuple[str | None, str | None]:
+        """Extract the reasoning content & content sections, respectively.
+        If the sequence doesn't match what we expect, i.e., the model generates
+        something else, all content is considered non-reasoning content.
+
+        Args:
+            model_output (str): Output of the model to be parsed.
+            request (Any): Request being processed.
+
+        Returns:
+            tuple[Optional[str], Optional[str]]: Tuple pair containing the
+            reasoning content and non-reasoning content.
+        """
+
+        re_match = self.full_match_reasoning_regex.findall(model_output)
+        if re_match:
+            reasoning, response_content = re_match[0]
+            if len(reasoning) == 0:
+                reasoning = None
+            if len(response_content) == 0:
+                response_content = None
+            return reasoning, response_content
+
+        fallback_regex = self.half_match_reasoning_regex
+        fallback_match = fallback_regex.findall(model_output)
+        if fallback_match:
+            reasoning, response_content = fallback_match[0]
+
+            if response_content.endswith(self.response_end_expr):
+                response_content = response_content[: -len(self.response_end_expr)]
+
+            if len(reasoning) == 0:
+                reasoning = None
+            if len(response_content) == 0:
+                response_content = None
+
+            return reasoning, response_content
+
+        return None, model_output
+
+    def _is_strict_increasing_subsequence(
+        self, subsequence: Sequence[int], sequence: Sequence[int]
+    ) -> bool:
+        if not subsequence:
+            return False
+
+        sub_idx = 0
+        for num in sequence:
+            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
+                sub_idx += 1
+        return sub_idx == len(subsequence)
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        """Extract content using a token id sequence state machine."""
+        # Define sequences
+        think_start_sequence = self.think_start_ids
+        response_start_sequence = self.response_start_ids
+        response_end_sequence = self.response_end_ids
+
+        if not delta_token_ids:
+            return None
+
+        # Process the last token of the delta
+        token = delta_token_ids[-1]
+
+        def check_token_with_sequence(token):
+            if self.current_state == "idle" or self.current_state == "think":
+                return (
+                    token == self.expected_sequence[self.sequence_index]
+                    or token == self.expected_sequence_side[self.sequence_index]
+                )
+            else:
+                return token == self.expected_sequence[self.sequence_index]
+
+        def check_last_token(token):
+            if self.current_state == "idle" or self.current_state == "think":
+                # Only return true here if it was judged using a side sequence.
+                if (
+                    self.sequence_index - 1 < len(self.expected_sequence_side)
+                    and token == self.expected_sequence_side[self.sequence_index - 1]
+                ):
+                    return self.sequence_index == len(self.expected_sequence_side)
+                else:
+                    return self.sequence_index == len(self.expected_sequence)
+            else:
+                return self.sequence_index == len(self.expected_sequence)
+
+        # Check if token matches expected sequence
+        token_in_state_seq = check_token_with_sequence(token)
+
+        if token_in_state_seq:
+            # Store matching token
+            self.token_buffer.append(token)
+            self.text_buffer += delta_text
+            self.sequence_index += 1
+            # State changes: idle -> think -> response -> idle.
+
+            # Check if sequence fully matched
+            if check_last_token(token):
+                # State transition
+                if self.current_state == "idle":
+                    self.current_state = "think"
+                    self.expected_sequence = response_start_sequence
+                    self.expected_sequence_side = self.response_start_ids_fast
+                elif self.current_state == "think":
+                    self.current_state = "response"
+                    self.expected_sequence = response_end_sequence
+                elif self.current_state == "response":
+                    self.current_state = "idle"
+                    self.expected_sequence = think_start_sequence
+                    self.expected_sequence_side = self.think_start_ids_fast
+
+                # Reset matching state
+                self.sequence_index = 0
+                self.token_buffer = []
+                self.text_buffer = ""
+                # Do not send content for state transition texts.
+        else:
+            # Sequence broken - handle buffered content
+            if self.token_buffer:
+                # Send buffered tokens
+                buffered_content = self.text_buffer + delta_text
+                # Reset matching state
+                self.sequence_index = 0
+                self.token_buffer = []
+                self.text_buffer = ""
+
+                # Return content based on current state
+                if self.current_state == "think":
+                    return DeltaMessage(reasoning=buffered_content, content=None)
+                else:
+                    return DeltaMessage(reasoning=None, content=buffered_content)
+            else:
+                # No buffered content, send normally
+                if self.current_state == "think":
+                    return DeltaMessage(reasoning=delta_text, content=None)
+                else:
+                    return DeltaMessage(reasoning=None, content=delta_text)
+
+        # If no content to send in this delta
+        return None
diff --git a/endpoints/OAI/reasoning/identity_reasoning_parser.py b/endpoints/OAI/reasoning/identity_reasoning_parser.py
new file mode 100644
index 00000000..52aa4052
--- /dev/null
+++ b/endpoints/OAI/reasoning/identity_reasoning_parser.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from typing import Any
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser
+
+
+class IdentityReasoningParser(ReasoningParser):
+    def is_reasoning_end(self, input_ids: list[int]) -> bool:
+        return True
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        return input_ids
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: list[int],
+        current_token_ids: list[int],
+        delta_token_ids: list[int],
+    ) -> DeltaMessage | None:
+        if not delta_text:
+            return None
+        return DeltaMessage(content=delta_text)
+
+    def extract_reasoning(
+        self,
+        model_output: str,
+        request: Any,
+    ) -> tuple[str | None, str | None]:
+        return None, model_output
diff --git 
a/endpoints/OAI/reasoning/kimi_k2_reasoning_parser.py b/endpoints/OAI/reasoning/kimi_k2_reasoning_parser.py new file mode 100644 index 00000000..664d5163 --- /dev/null +++ b/endpoints/OAI/reasoning/kimi_k2_reasoning_parser.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from collections.abc import Sequence +from typing import Any + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser +from endpoints.OAI.reasoning.deepseek_r1_reasoning_parser import DeepSeekR1ReasoningParser +from endpoints.OAI.reasoning.identity_reasoning_parser import IdentityReasoningParser + + +class KimiK2ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + chat_kwargs = kwargs.get("chat_template_kwargs", {}) or {} + thinking = bool(chat_kwargs.get("thinking", True)) + self._parser = ( + DeepSeekR1ReasoningParser(tokenizer, *args, **kwargs) + if thinking + else IdentityReasoningParser(tokenizer, *args, **kwargs) + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + return self._parser.is_reasoning_end(input_ids) + + def is_reasoning_end_streaming( + self, input_ids: Sequence[int], delta_ids: Sequence[int] + ) -> bool: + return self._parser.is_reasoning_end_streaming(input_ids, delta_ids) + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return self._parser.extract_content_ids(input_ids) + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + return self._parser.extract_reasoning(model_output, request) + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + return 
self._parser.extract_reasoning_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+        )
diff --git a/endpoints/OAI/reasoning/minimax_m2_reasoning_parser.py b/endpoints/OAI/reasoning/minimax_m2_reasoning_parser.py
new file mode 100644
index 00000000..ca036d05
--- /dev/null
+++ b/endpoints/OAI/reasoning/minimax_m2_reasoning_parser.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Any
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser
+from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+
+class MiniMaxM2ReasoningParser(BaseThinkingReasoningParser):
+    @property
+    def start_token(self) -> str:
+        return "<think>"
+
+    @property
+    def end_token(self) -> str:
+        return "</think>"
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.end_token_id:
+            return None
+
+        if self.end_token_id in previous_token_ids:
+            return DeltaMessage(content=delta_text)
+
+        if self.end_token_id in delta_token_ids:
+            end_index = delta_text.find(self.end_token)
+            reasoning = delta_text[:end_index]
+            content = delta_text[end_index + len(self.end_token) :]
+            return DeltaMessage(reasoning=reasoning or None, content=content or None)
+
+        return DeltaMessage(reasoning=delta_text)
+
+
+class MiniMaxM2AppendThinkReasoningParser(ReasoningParser):
+    def __init__(self, tokenizer: Any, *args, **kwargs):
+        super().__init__(tokenizer, *args, **kwargs)
+        self.end_token_id = self.vocab.get("</think>")
+
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+        return any(input_id == self.end_token_id for input_id in reversed(input_ids))
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        return input_ids
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        if len(previous_token_ids) == 0:
+            delta_text = "<think>" + delta_text
+        return DeltaMessage(content=delta_text)
+
+    def extract_reasoning(
+        self, model_output: str, request: Any
+    ) -> tuple[str | None, str | None]:
+        return None, "<think>" + model_output
diff --git a/endpoints/OAI/reasoning/mistral_reasoning_parser.py b/endpoints/OAI/reasoning/mistral_reasoning_parser.py
new file mode 100644
index 00000000..99436dad
--- /dev/null
+++ b/endpoints/OAI/reasoning/mistral_reasoning_parser.py
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Any
+
+from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+
+class MistralReasoningParser(BaseThinkingReasoningParser):
+    @property
+    def start_token(self) -> str:
+        return "[THINK]"
+
+    @property
+    def end_token(self) -> str:
+        return "[/THINK]"
+
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+        has_eot = False
+        for token_id in reversed(input_ids):
+            if token_id == self.start_token_id:
+                return has_eot
+            if token_id == self.end_token_id:
+                has_eot = True
+        return False
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        has_bot = False
+        has_eot = False
+        bot_idx = -1
+        eot_idx = -1
+        for i, token_id in enumerate(input_ids):
+            if token_id == self.start_token_id and not has_bot:
+                has_bot = True
+                bot_idx = i
+            elif token_id == self.end_token_id:
+                has_eot = True
+                eot_idx = i
+                break
+
+        if has_bot and not 
has_eot: + return input_ids[:bot_idx] + if not has_bot and not has_eot: + return input_ids + if has_bot and has_eot: + return input_ids[:bot_idx] + input_ids[eot_idx + 1 :] + return input_ids[:eot_idx] + input_ids[eot_idx + 1 :] + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + if not model_output: + return None, "" + + prefix, bot, post_bot = model_output.partition(self.start_token) + has_bot = bool(bot) + has_valid_eot = has_bot and self.end_token in post_bot + + if has_bot and has_valid_eot: + reasoning, _, post_eot = post_bot.partition(self.end_token) + content = prefix + post_eot + return reasoning or None, content or None + if has_bot: + return post_bot or None, prefix or None + + if self.end_token in prefix: + pre_eot, _, post_eot = prefix.partition(self.end_token) + return None, (pre_eot + post_eot) or None + + return None, prefix diff --git a/endpoints/OAI/reasoning/olmo3_reasoning_parser.py b/endpoints/OAI/reasoning/olmo3_reasoning_parser.py new file mode 100644 index 00000000..6e5f0948 --- /dev/null +++ b/endpoints/OAI/reasoning/olmo3_reasoning_parser.py @@ -0,0 +1,166 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import dataclasses as dt +import enum +from collections.abc import Sequence +from typing import Any + +try: + import regex as re +except ImportError: # pragma: no cover + import re + +from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser + + +class Olmo3ReasoningState(enum.Enum): + REASONING = 1 + CONTENT = 2 + + +@dt.dataclass(frozen=True) +class Indices: + start: int + end: int + + def __len__(self): + return self.end - self.start + + +def string_overlap(a: str, b: str) -> tuple[Indices | None, Indices | None]: + a, b, swap = (a, b, False) if len(a) < len(b) else (b, a, True) + + if a in b: + ind_a = Indices(0, len(a)) + ind_b = Indices(b.index(a), 
b.index(a) + len(a))
+        return (ind_b, ind_a) if swap else (ind_a, ind_b)
+
+    for i in range(len(a) - 1, 0, -1):
+        if a[-i:] == b[:i]:
+            ind_a = Indices(len(a) - i, len(a))
+            ind_b = Indices(0, i)
+            return (ind_b, ind_a) if swap else (ind_a, ind_b)
+
+    for i in range(len(a) - 1, 0, -1):
+        if b[-i:] == a[:i]:
+            ind_a = Indices(0, i)
+            ind_b = Indices(len(b) - i, len(b))
+            return (ind_b, ind_a) if swap else (ind_a, ind_b)
+
+    return None, None
+
+
+@dt.dataclass
+class Olmo3ReasoningBuffer:
+    think_start: str = "<think>"
+    think_end: str = "</think>"
+    buffer: str = ""
+    state: Olmo3ReasoningState = Olmo3ReasoningState.REASONING
+
+    def process_buffer(self) -> DeltaMessage | None:
+        start_think_idx = self.buffer.find(self.think_start)
+        if start_think_idx >= 0:
+            self.state = Olmo3ReasoningState.REASONING
+            pretext, self.buffer = (
+                self.buffer[:start_think_idx],
+                self.buffer[start_think_idx + len(self.think_start) :],
+            )
+            if start_think_idx > 0:
+                return DeltaMessage(content=pretext)
+
+        end_think_idx = self.buffer.rfind(self.think_end)
+        if end_think_idx >= 0:
+            self.state = Olmo3ReasoningState.CONTENT
+            pretext, self.buffer = (
+                self.buffer[:end_think_idx],
+                self.buffer[end_think_idx + len(self.think_end) :],
+            )
+            if end_think_idx > 0:
+                return DeltaMessage(reasoning=pretext)
+
+        if self.state == Olmo3ReasoningState.REASONING:
+            text_buffer, self.buffer = self.buffer, ""
+            return DeltaMessage(reasoning=text_buffer)
+
+        if self.state == Olmo3ReasoningState.CONTENT:
+            text_buffer, self.buffer = self.buffer, ""
+            return DeltaMessage(content=text_buffer)
+
+        return None
+
+    def add_text(self, delta_text: str) -> DeltaMessage | None:
+        self.buffer += delta_text
+        delta_message: DeltaMessage | None = None
+
+        _, overlap_think_start = string_overlap(delta_text, self.think_start)
+        _, overlap_think_end = string_overlap(delta_text, self.think_end)
+
+        partial_overlap_start = overlap_think_start is not None and len(
+            overlap_think_start
+        ) < len(self.think_start)
partial_overlap_end = overlap_think_end is not None and len(overlap_think_end) < len( + self.think_end + ) + + if partial_overlap_start and self.think_start in self.buffer and not partial_overlap_end: + delta_message = self.process_buffer() + elif partial_overlap_end and self.think_end in self.buffer: + delta_message = self.process_buffer() + elif partial_overlap_start or partial_overlap_end: + return None + else: + delta_message = self.process_buffer() + + return delta_message + + +class Olmo3ReasoningParser(ReasoningParser): + def __init__(self, tokenizer: Any, *args, **kwargs): + super().__init__(tokenizer, *args, **kwargs) + + self.think_start = r"" + self.think_end = r"" + + reasoning_expr = ( + rf"^(?:{self.think_start})?(?P.*?)" + rf"{self.think_end}(?P.*)$" + ) + self.reasoning_regex = re.compile(reasoning_expr, re.DOTALL) + self.buffer = Olmo3ReasoningBuffer( + think_start=self.think_start, think_end=self.think_end + ) + + def is_reasoning_end(self, input_ids: Sequence[int]) -> bool: + text = self.model_tokenizer.decode(input_ids) + return self.think_end in text + + def extract_content_ids(self, input_ids: list[int]) -> list[int]: + return [] + + def extract_reasoning( + self, + model_output: str, + request: Any, + ) -> tuple[str | None, str | None]: + re_match = self.reasoning_regex.match(model_output) + if re_match: + reasoning = re_match.group("reasoning") or None + content = re_match.group("content") or None + return reasoning, content + return None, model_output + + def extract_reasoning_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + ) -> DeltaMessage | None: + delta_message = self.buffer.add_text(delta_text) + if delta_message is None and self.buffer.think_end in self.buffer.buffer: + delta_message = self.buffer.process_buffer() + return delta_message diff --git 
a/endpoints/OAI/reasoning/qwen3_reasoning_parser.py b/endpoints/OAI/reasoning/qwen3_reasoning_parser.py new file mode 100644 index 00000000..7e6a941b --- /dev/null +++ b/endpoints/OAI/reasoning/qwen3_reasoning_parser.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Any + +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser + + +class Qwen3ReasoningParser(BaseThinkingReasoningParser): + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" + + def extract_reasoning( + self, model_output: str, request: Any + ) -> tuple[str | None, str | None]: + if self.start_token not in model_output or self.end_token not in model_output: + return None, model_output + + _, _, tail = model_output.partition(self.start_token) + reasoning, _, content = tail.partition(self.end_token) + return reasoning or None, content or None diff --git a/endpoints/OAI/reasoning/seedoss_reasoning_parser.py b/endpoints/OAI/reasoning/seedoss_reasoning_parser.py new file mode 100644 index 00000000..6d7a964c --- /dev/null +++ b/endpoints/OAI/reasoning/seedoss_reasoning_parser.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser + + +class SeedOSSReasoningParser(BaseThinkingReasoningParser): + @property + def start_token(self) -> str: + return "" + + @property + def end_token(self) -> str: + return "" diff --git a/endpoints/OAI/reasoning/step3_reasoning_parser.py b/endpoints/OAI/reasoning/step3_reasoning_parser.py new file mode 100644 index 00000000..bedaf620 --- /dev/null +++ b/endpoints/OAI/reasoning/step3_reasoning_parser.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +# 
SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Any
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage, ReasoningParser
+
+
+class Step3ReasoningParser(ReasoningParser):
+    def __init__(self, tokenizer: Any, *args, **kwargs):
+        super().__init__(tokenizer, *args, **kwargs)
+        self.think_end_token = "</think>"
+        self.think_end_token_id = self.vocab.get(self.think_end_token)
+        if self.think_end_token_id is None:
+            raise RuntimeError(
+                "Step3 reasoning parser could not locate think end token in tokenizer"
+            )
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        if len(delta_token_ids) == 1 and delta_token_ids[0] == self.think_end_token_id:
+            return None
+
+        if self.think_end_token_id in delta_token_ids:
+            end_index = delta_text.find(self.think_end_token)
+            reasoning = delta_text[:end_index]
+            content = delta_text[end_index + len(self.think_end_token) :]
+            return DeltaMessage(reasoning=reasoning, content=content or None)
+
+        if self.think_end_token_id in previous_token_ids:
+            return DeltaMessage(content=delta_text)
+
+        return DeltaMessage(reasoning=delta_text)
+
+    def extract_reasoning(
+        self, model_output: str, request: Any
+    ) -> tuple[str | None, str | None]:
+        if self.think_end_token not in model_output:
+            return model_output or None, None
+
+        end_index = model_output.find(self.think_end_token)
+        reasoning = model_output[:end_index]
+        content = model_output[end_index + len(self.think_end_token) :]
+        return reasoning or None, content or None
+
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+        return self.think_end_token_id in input_ids
+
+    def is_reasoning_end_streaming(
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
+    )
-> bool:
+        return self.think_end_token_id in delta_ids
+
+    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
+        if self.think_end_token_id not in input_ids[:-1]:
+            return []
+        return input_ids[input_ids.index(self.think_end_token_id) + 1 :]
diff --git a/endpoints/OAI/reasoning/step3p5_reasoning_parser.py b/endpoints/OAI/reasoning/step3p5_reasoning_parser.py
new file mode 100644
index 00000000..9d56282e
--- /dev/null
+++ b/endpoints/OAI/reasoning/step3p5_reasoning_parser.py
@@ -0,0 +1,118 @@
+from __future__ import annotations
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from collections.abc import Sequence
+from typing import Any
+
+from endpoints.OAI.reasoning.abs_reasoning_parsers import DeltaMessage
+from endpoints.OAI.reasoning.basic_parsers import BaseThinkingReasoningParser
+
+
+class Step3p5ReasoningParser(BaseThinkingReasoningParser):
+    @property
+    def start_token(self) -> str:
+        return "<think>"
+
+    @property
+    def end_token(self) -> str:
+        return "</think>"
+
+    def __init__(self, tokenizer: Any, *args, **kwargs):
+        super().__init__(tokenizer, *args, **kwargs)
+        self._pending_reasoning_newline = False
+        self.end_offset = 1
+
+    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
+        if self.end_token_id in input_ids and self.end_offset > 0:
+            self.end_offset -= 1
+            return False
+        return self.end_offset < 1
+
+    def is_reasoning_end_streaming(
+        self, input_ids: Sequence[int], delta_ids: Sequence[int]
+    ) -> bool:
+        if self.end_token_id in input_ids and self.end_offset > 0:
+            self.end_offset -= 1
+            return False
+        return self.end_offset < 1
+
+    def extract_reasoning(
+        self, model_output: str, request: Any
+    ) -> tuple[str | None, str | None]:
+        reasoning, content = super().extract_reasoning(model_output, request)
+        if reasoning is not None:
+            reasoning = reasoning.removesuffix("\n")
+        if content is not None:
+            content = content.removeprefix("\n")
+        return reasoning or None, content
or None
+
+    def extract_reasoning_streaming(
+        self,
+        previous_text: str,
+        current_text: str,
+        delta_text: str,
+        previous_token_ids: Sequence[int],
+        current_token_ids: Sequence[int],
+        delta_token_ids: Sequence[int],
+    ) -> DeltaMessage | None:
+        if previous_text.endswith(self.end_token) and delta_text:
+            if delta_text == "\n":
+                return None
+            if delta_text.startswith("\n"):
+                remaining = delta_text.removeprefix("\n")
+                return DeltaMessage(content=remaining) if remaining else None
+
+        ret = super().extract_reasoning_streaming(
+            previous_text,
+            current_text,
+            delta_text,
+            previous_token_ids,
+            current_token_ids,
+            delta_token_ids,
+        )
+        if ret is None:
+            return None
+
+        if (
+            self.start_token_id not in previous_token_ids
+            and self.start_token_id not in delta_token_ids
+        ):
+            if self.end_token_id in delta_token_ids:
+                end_index = delta_text.find(self.end_token)
+                reasoning = delta_text[:end_index]
+                content = delta_text[end_index + len(self.end_token) :]
+                ret = DeltaMessage(reasoning=reasoning, content=content or None)
+            elif self.end_token_id in previous_token_ids:
+                ret = DeltaMessage(content=delta_text)
+            else:
+                ret = DeltaMessage(reasoning=delta_text)
+
+        reasoning_to_output = ret.reasoning
+        content_to_output = ret.content
+
+        if reasoning_to_output is not None:
+            if self._pending_reasoning_newline:
+                reasoning_to_output = "\n" + reasoning_to_output
+                self._pending_reasoning_newline = False
+
+            if reasoning_to_output.endswith("\n"):
+                reasoning_to_output = reasoning_to_output.removesuffix("\n")
+                if self.end_token in delta_text:
+                    self._pending_reasoning_newline = False
+                else:
+                    self._pending_reasoning_newline = True
+
+        if content_to_output is not None:
+            self.end_offset -= 1
+            self._pending_reasoning_newline = False
+            if self.end_token in delta_text and content_to_output.startswith("\n"):
+                content_to_output = content_to_output.removeprefix("\n")
+
+        reasoning_to_output = reasoning_to_output or None
+        content_to_output = content_to_output or None
+ if reasoning_to_output is None and content_to_output is None: + return None + + return DeltaMessage(reasoning=reasoning_to_output, content=content_to_output) diff --git a/endpoints/OAI/types/chat_completion.py b/endpoints/OAI/types/chat_completion.py index 52523149..61112d4b 100644 --- a/endpoints/OAI/types/chat_completion.py +++ b/endpoints/OAI/types/chat_completion.py @@ -1,10 +1,10 @@ -from pydantic import AliasChoices, BaseModel, Field, field_validator +from pydantic import AliasChoices, BaseModel, Field, field_validator, model_validator from time import time from typing import Literal, Union, List, Optional, Dict from uuid import uuid4 from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest -from endpoints.OAI.types.tools import ToolSpec, ToolCall +from endpoints.OAI.types.tools import NamedToolChoice, ToolSpec, ToolCall class ChatCompletionLogprob(BaseModel): @@ -30,6 +30,8 @@ class ChatCompletionMessagePart(BaseModel): class ChatCompletionMessage(BaseModel): role: str = "user" content: Optional[Union[str, List[ChatCompletionMessagePart]]] = None + reasoning: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCall]] = None tool_call_id: Optional[str] = None @@ -49,7 +51,7 @@ class ChatCompletionStreamChoice(BaseModel): # Index is 0 since we aren't using multiple choices index: int = 0 finish_reason: Optional[str] = None - delta: Union[ChatCompletionMessage, dict] = {} + delta: Union[ChatCompletionMessage, dict] = Field(default_factory=dict) logprobs: Optional[ChatCompletionLogprobs] = None @@ -59,18 +61,25 @@ class ChatCompletionRequest(CommonCompletionRequest): prompt_template: Optional[str] = None add_generation_prompt: Optional[bool] = True template_vars: Optional[dict] = Field( - default={}, + default_factory=dict, validation_alias=AliasChoices("template_vars", "chat_template_kwargs"), description="Aliases: chat_template_kwargs", ) + enable_thinking: Optional[bool] = None + thinking: 
Optional[bool] = None response_prefix: Optional[str] = None model: Optional[str] = None + include_reasoning: Optional[bool] = True # tools is follows the format OAI schema, functions is more flexible # both are available in the chat template. tools: Optional[List[ToolSpec]] = None functions: Optional[List[Dict]] = None + tool_choice: Optional[ + Union[Literal["none", "auto", "required"], NamedToolChoice] + ] = None + parallel_tool_calls: Optional[bool] = True # Chat completions requests do not have a BOS token preference. Backend # respects the tokenization config for the individual model. @@ -81,6 +90,20 @@ def force_bos_token(cls, v): """Always disable add_bos_token with chat completions.""" return None + @model_validator(mode="after") + def apply_thinking_aliases(self): + """Support clients that send thinking flags at the top-level.""" + template_vars = dict(self.template_vars or {}) + + if self.enable_thinking is not None and "enable_thinking" not in template_vars: + template_vars["enable_thinking"] = self.enable_thinking + + if self.thinking is not None and "thinking" not in template_vars: + template_vars["thinking"] = self.thinking + + self.template_vars = template_vars + return self + class ChatCompletionResponse(BaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{uuid4().hex}") diff --git a/endpoints/OAI/types/tools.py b/endpoints/OAI/types/tools.py index b5b9611f..1e572663 100644 --- a/endpoints/OAI/types/tools.py +++ b/endpoints/OAI/types/tools.py @@ -1,5 +1,5 @@ from pydantic import BaseModel, Field -from typing import Dict, Literal +from typing import Dict, Literal, Optional from uuid import uuid4 @@ -28,8 +28,28 @@ class Tool(BaseModel): class ToolCall(BaseModel): - """Represents an OAI tool description.""" + """Represents an OAI tool call. 
+ + The ``index`` field is optional so it can be omitted in non-streaming + responses (where OpenAI does not include it) via ``exclude_none=True``, + while being set explicitly for streaming deltas where it is required + by strict validators like the Vercel AI SDK. + """ - id: str = Field(default_factory=lambda: str(uuid4()).replace("-", "")[:9]) + id: str = Field(default_factory=lambda: f"call_{uuid4().hex[:24]}") function: Tool type: Literal["function"] = "function" + index: Optional[int] = None + + +class NamedToolFunction(BaseModel): + """Represents a named function reference for tool_choice.""" + + name: str + + +class NamedToolChoice(BaseModel): + """Represents a named tool choice (forces a specific function call).""" + + function: NamedToolFunction + type: Literal["function"] = "function" diff --git a/endpoints/OAI/utils/chat_completion.py b/endpoints/OAI/utils/chat_completion.py index b559bb2b..a3f302ef 100644 --- a/endpoints/OAI/utils/chat_completion.py +++ b/endpoints/OAI/utils/chat_completion.py @@ -1,8 +1,10 @@ """Chat completion utilities for OAI server.""" import asyncio +import json import pathlib from asyncio import CancelledError +from dataclasses import dataclass, field from typing import List, Optional from fastapi import HTTPException, Request from jinja2 import TemplateError @@ -16,7 +18,9 @@ handle_request_error, request_disconnect_loop, ) +from common.tabby_config import config from common.utils import unwrap +from endpoints.OAI.reasoning import ReasoningParserManager from endpoints.OAI.types.chat_completion import ( ChatCompletionLogprobs, ChatCompletionLogprob, @@ -28,24 +32,198 @@ ChatCompletionStreamChoice, ) from endpoints.OAI.types.common import UsageStats +from endpoints.OAI.types.tools import NamedToolChoice, ToolCall from endpoints.OAI.utils.completion import _parse_gen_request_id, _stream_collector +from endpoints.OAI.utils.parser_options import ( + list_tool_call_parsers, + parser_uses_native_tool_generation, + 
resolve_tool_call_format, +) from endpoints.OAI.utils.tools import ToolCallProcessor, TOOL_CALL_SCHEMA +@dataclass +class _StreamReasoningState: + text: str = "" + token_ids: List[int] = field(default_factory=list) + + +class _TokenizerAdapter: + """Expose the minimal tokenizer interface required by reasoning parsers.""" + + def __init__(self): + self._vocab = None + + def get_vocab(self) -> dict[str, int]: + if self._vocab is not None: + return self._vocab + + tokenizer = model.container.tokenizer + if hasattr(tokenizer, "get_vocab"): + self._vocab = tokenizer.get_vocab() + return self._vocab + + pieces = tokenizer.get_id_to_piece_list(True) + vocab: dict[str, int] = {} + for token_id, piece in enumerate(pieces): + if piece not in vocab: + vocab[piece] = token_id + self._vocab = vocab + return vocab + + +def _token_ids_from_generation(generation: dict) -> List[int]: + token_ids = generation.get("token_ids") + if token_ids is None: + return [] + if isinstance(token_ids, list): + return token_ids + if hasattr(token_ids, "flatten"): + return token_ids.flatten().tolist() + return list(token_ids) + + +def _build_reasoning_parser(request_data: ChatCompletionRequest): + parser_key = unwrap(config.model.reasoning_parser, "basic") or "basic" + try: + parser_cls = ReasoningParserManager.get_reasoning_parser(parser_key) + except KeyError as exc: + raise HTTPException(400, str(exc)) from exc + + template_kwargs = unwrap(request_data.template_vars, {}) + try: + return parser_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + except RuntimeError as exc: + # Keep compatibility for models that do not expose thinking tags. + if parser_key == "basic": + logger.warning( + "Reasoning parser 'basic' could not initialize ({}). 
" + "Falling back to identity parser.", + str(exc), + ) + identity_cls = ReasoningParserManager.get_reasoning_parser("identity") + return identity_cls(_TokenizerAdapter(), chat_template_kwargs=template_kwargs) + raise HTTPException(400, str(exc)) from exc + + +def _validate_and_get_tool_call_format( + request_data: ChatCompletionRequest, default_format: str +) -> str: + tool_choice = request_data.tool_choice + parser_key = config.model.tool_call_parser + enable_auto = bool(config.model.enable_auto_tool_choice) + parser_names = list_tool_call_parsers() + + if parser_key and parser_key not in parser_names: + parsers_str = ", ".join(sorted(parser_names)) + raise HTTPException( + 400, + f"invalid tool call parser: {parser_key} (choose from {{{parsers_str}}})", + ) + + if tool_choice == "auto" and (not enable_auto or not parser_key): + raise HTTPException( + 400, + '"auto" tool choice requires --enable-auto-tool-choice and ' + "--tool-call-parser to be set", + ) + + if tool_choice not in (None, "none", "auto") and parser_key is None: + raise HTTPException( + 400, + f'tool_choice="{tool_choice}" requires --tool-call-parser to be set', + ) + + if ( + tool_choice == "none" + and config.model.exclude_tools_when_tool_choice_none + and request_data.tools + ): + request_data.tools = None + + resolved_format = resolve_tool_call_format(parser_key, default_format) + if not resolved_format: + raise HTTPException( + 400, + f"Could not resolve format for tool_call_parser={parser_key}", + ) + return resolved_format + + +def _serialize_stream_chunk(chunk) -> str: + """Serialize a streaming chunk with OpenAI-compatible field handling. + + Uses exclude_none=True to strip irrelevant null fields (tool_calls, + tool_call_id, logprobs, usage) while ensuring finish_reason is always + present on each choice (as null when not set), matching OpenAI's + observed streaming behavior. 
+    """
+    d = chunk.model_dump(exclude_none=True)
+    for choice in d.get("choices", []):
+        if "finish_reason" not in choice:
+            choice["finish_reason"] = None
+    return json.dumps(d, ensure_ascii=False)
+
+
 def _create_response(
-    request_id: str, generations: List[dict], model_name: Optional[str]
+    request_id: str,
+    generations: List[dict],
+    model_name: Optional[str],
+    tool_call_format: str = "json",
+    tool_choice=None,
 ):
     """Create a chat completion response from the provided text."""
     choices = []
+    parser_key = config.model.tool_call_parser
     for index, generation in enumerate(generations):
+        reasoning = generation.get("reasoning")
+        reasoning_content = generation.get("reasoning_content")
         message = ChatCompletionMessage(
-            role="assistant", content=unwrap(generation.get("text"), "")
+            role="assistant",
+            content=generation.get("text"),
+            reasoning=reasoning,
+            reasoning_content=reasoning_content,
         )

-        tool_calls = generation["tool_calls"]
-        if tool_calls:
-            message.tool_calls = ToolCallProcessor.from_json(tool_calls)
+        tool_calls_raw = generation.get("tool_calls")
+        if tool_calls_raw:
+            parsed = ToolCallProcessor.parse(
+                tool_calls_raw,
+                format=tool_call_format,
+                parser_key=parser_key,
+            )
+            if parsed and isinstance(tool_choice, NamedToolChoice):
+                parsed = ToolCallProcessor.filter_by_name(
+                    parsed, tool_choice.function.name
+                )
+            if parsed:
+                message.tool_calls = parsed
+                message.content = None
+            else:
+                logger.warning(
+                    "Tool call text present but parsing returned no results "
+                    f"(format={tool_call_format})"
+                )
+
+        # Fallback: detect bare XML tool calls in content that were not
+        # caught by the two-pass system (model never emitted tool_start)
+        if (
+            tool_call_format in ("xml", "auto")
+            and not message.tool_calls
+            and message.content
+            and "<tool_call>" in message.content
+        ):
+            parsed = ToolCallProcessor.parse(
+                message.content,
+                format=tool_call_format,
+                parser_key=parser_key,
+            )
+            if parsed:
+                message.tool_calls = parsed
+                message.content = None
+
+def _build_tool_call_chunks(
+    tool_calls: List[ToolCall],
+    request_id: str,
+    model_name: Optional[str],
+    choice_index: int = 0,
+) -> List[ChatCompletionStreamChunk]:
+    """Build the OpenAI-standard streaming sequence for tool calls.
+
+    Emits two chunks:
+    1.
Tool-call chunk: role="assistant", complete tool_calls with + index/id/type/name/arguments (all data in one chunk). + 2. Finish chunk: empty delta, finish_reason="tool_calls". + + Complete arguments are sent in a single chunk rather than streamed + incrementally, which is valid per OpenAI's spec (clients concatenate + argument strings across deltas) and maximizes compatibility with + clients that may not implement multi-chunk tool-call assembly. + + The tool_calls are placed directly into a ChatCompletionMessage + (not a raw dict) so Pydantic validates them as ToolCall objects + with the index field preserved (ToolCall declares index as Optional[int]). + """ + chunk_id = f"chatcmpl-{request_id}" + + # Set index on each tool call for streaming + for idx, tc in enumerate(tool_calls): + tc.index = idx + + # Chunk 1: Complete tool call data + tool_call_message = ChatCompletionMessage( + role="assistant", + tool_calls=tool_calls, + ) + tool_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[ + ChatCompletionStreamChoice( + index=choice_index, + delta=tool_call_message, + finish_reason=None, + ) + ], + model=model_name, + ) + + # Chunk 2: Finish signal + # Use model_construct to prevent Pydantic's smart Union from + # coercing the empty dict {} into ChatCompletionMessage(role="user") + finish_choice = ChatCompletionStreamChoice.model_construct( + index=choice_index, + delta={}, + finish_reason="tool_calls", + logprobs=None, + ) + finish_chunk = ChatCompletionStreamChunk( + id=chunk_id, + choices=[finish_choice], + model=model_name, + ) + + return [tool_chunk, finish_chunk] + + async def _append_template_metadata(data: ChatCompletionRequest, template_vars: dict): """Adding metadata is a one-time process.""" @@ -237,6 +478,24 @@ async def format_messages_with_template( message_dicts.append(message.model_dump(exclude_none=True)) + # Pre-template: convert tool_call arguments from JSON strings to dicts. + # OpenAI-compatible clients (Kilo, Roo, etc.) 
send arguments as JSON + # strings per the OAI spec, but Qwen3-Coder's template calls + # .items() on arguments which requires a dict/mapping. + for msg in message_dicts: + if msg.get("tool_calls"): + for tc in msg["tool_calls"]: + func = tc.get("function", {}) + args = func.get("arguments") + if isinstance(args, str): + try: + func["arguments"] = json.loads(args) + except (json.JSONDecodeError, ValueError): + logger.warning( + "Failed to parse tool_call arguments JSON " + "string to dict, keeping as string" + ) + # Get all special tokens special_tokens_dict = model.container.get_special_tokens() @@ -319,10 +578,19 @@ async def stream_generate_chat_completion( gen_queue = asyncio.Queue() gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format disconnect_task = asyncio.create_task(request_disconnect_loop(request)) try: logger.info(f"Received chat completion streaming request {request.state.id}") + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) + reasoning_parser = _build_reasoning_parser(data) + reasoning_states = [_StreamReasoningState() for _ in range(0, data.n)] + force_tool_pass = data.tool_choice == "required" or isinstance( + data.tool_choice, NamedToolChoice + ) for idx in range(0, data.n): task_gen_params = data.model_copy(deep=True) @@ -342,18 +610,67 @@ async def stream_generate_chat_completion( gen_tasks.append(gen_task) - # Text accumulation for tool calls - current_generation_text = "" - # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: 
+ get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() + # Stream collector will push an exception to the queue if it fails + if isinstance(generation, Exception): + raise generation + + if "text" in generation and generation.get("finish_reason") is None: + idx = generation["index"] + state = reasoning_states[idx] + + delta_text = unwrap(generation.get("text"), "") + delta_token_ids = _token_ids_from_generation(generation) + + current_text = state.text + delta_text + current_token_ids = state.token_ids + delta_token_ids + + delta_message = reasoning_parser.extract_reasoning_streaming( + state.text, + current_text, + delta_text, + state.token_ids, + current_token_ids, + delta_token_ids, + ) + + state.text = current_text + state.token_ids = current_token_ids + + if delta_message is None: + continue + + generation["text"] = delta_message.content + if data.include_reasoning: + generation["reasoning"] = delta_message.reasoning + generation["reasoning_content"] = delta_message.reasoning + else: + generation["reasoning"] = None + generation["reasoning_content"] = None + if generation["text"] is None: + continue # Handle options if a tool model is present - if tool_start: + if (tool_start or force_tool_pass) and data.tool_choice != "none": if "stop_str" in generation: generations = await generate_tool_calls( prompt, @@ -361,21 +678,64 @@ async def stream_generate_chat_completion( data, [generation], request, + tool_call_format=tool_call_format, ) # Only one generation present in this case generation = generations[0] - elif "text" in generation: - current_generation_text += generation["text"] - # Stream collector will push an exception to the queue if it fails - if isinstance(generation, Exception): - raise generation + # Emit proper three-phase tool-call streaming sequence + if "tool_calls" in generation: + tool_calls_raw = generation["tool_calls"] + parsed = 
ToolCallProcessor.parse( + tool_calls_raw, + format=tool_call_format, + parser_key=config.model.tool_call_parser, + ) + if parsed and isinstance(data.tool_choice, NamedToolChoice): + parsed = ToolCallProcessor.filter_by_name( + parsed, data.tool_choice.function.name + ) + if parsed: + for tc_chunk in _build_tool_call_chunks( + parsed, + request.state.id, + model_path.name, + choice_index=generation.get("index", 0), + ): + yield _serialize_stream_chunk(tc_chunk) + + # Handle completion and usage after tool calls + if ( + all(task.done() for task in gen_tasks) + and gen_queue.empty() + ): + if ( + data.stream_options + and data.stream_options.include_usage + ): + usage_chunk = _create_stream_chunk( + request.state.id, + generation, + model_path.name, + is_usage_chunk=True, + ) + yield _serialize_stream_chunk(usage_chunk) + + logger.info( + "Finished chat completion streaming " + f"request {request.state.id}" + ) + yield "[DONE]" + break + continue response = _create_stream_chunk( - request.state.id, generation, model_path.name + request.state.id, + generation, + model_path.name, ) - yield response.model_dump_json() + yield _serialize_stream_chunk(response) # Check if all tasks are completed if all(task.done() for task in gen_tasks) and gen_queue.empty(): @@ -387,7 +747,7 @@ async def stream_generate_chat_completion( model_path.name, is_usage_chunk=True, ) - yield usage_chunk.model_dump_json() + yield _serialize_stream_chunk(usage_chunk) logger.info( f"Finished chat completion streaming request {request.state.id}" @@ -398,13 +758,16 @@ async def stream_generate_chat_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect("Chat completion generation cancelled by user.") + handle_request_disconnect("Chat completion generation cancelled by user.") + except HTTPException as exc: + yield get_generator_error(str(exc.detail)) except Exception: yield get_generator_error( "Chat 
completion aborted. Please check the server console." ) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_chat_completion( @@ -416,6 +779,10 @@ async def generate_chat_completion( ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) try: logger.info(f"Received chat completion request {request.state.id}") @@ -437,16 +804,46 @@ async def generate_chat_completion( generations = await asyncio.gather(*gen_tasks) # Check all the generations and see if a tool call is required - if tool_start: + force_tool_pass = data.tool_choice == "required" or isinstance( + data.tool_choice, NamedToolChoice + ) + if tool_start or force_tool_pass: generations = await generate_tool_calls( - prompt, embeddings, data, generations, request + prompt, + embeddings, + data, + generations, + request, + tool_call_format=tool_call_format, ) - response = _create_response(request.state.id, generations, model_path.name) + reasoning_parser = _build_reasoning_parser(data) + for generation in generations: + reasoning, content = reasoning_parser.extract_reasoning( + unwrap(generation.get("text"), ""), + data, + ) + + if not data.include_reasoning: + reasoning = None + + generation["reasoning"] = reasoning + generation["reasoning_content"] = reasoning + generation["text"] = content + + response = _create_response( + request.state.id, + generations, + model_path.name, + tool_call_format=tool_call_format, + tool_choice=data.tool_choice, + ) logger.info(f"Finished chat completion request {request.state.id}") return response + except HTTPException: + raise except Exception as exc: error_message = handle_request_error( f"Chat completion {request.state.id} aborted. 
" @@ -462,29 +859,88 @@ async def generate_tool_calls( prompt: str, embeddings: MultimodalEmbeddingWrapper, data: ChatCompletionRequest, - generations: List[str], + generations: List[dict], request: Request, + tool_call_format: Optional[str] = None, ): gen_tasks: List[asyncio.Task] = [] tool_start = model.container.prompt_template.metadata.tool_start + if tool_call_format is None: + default_tool_call_format = model.container.prompt_template.metadata.tool_call_format + tool_call_format = _validate_and_get_tool_call_format( + data, default_tool_call_format + ) + tool_choice = data.tool_choice + parser_key = config.model.tool_call_parser + use_native_generation = parser_uses_native_tool_generation( + parser_key, tool_call_format + ) + + if tool_choice == "none": + return generations # Tracks which generations asked for a tool call tool_idx: List[int] = [] # Copy to make sure the parent JSON schema doesn't get modified tool_data = data.model_copy(deep=True) - tool_data.json_schema = TOOL_CALL_SCHEMA + + if use_native_generation: + # Native syntax mode: let the model generate its natural tool-call + # representation without JSON schema constraint. 
+ logger.debug( + "generate_tool_calls: Using parser '{}' in native mode " + "(format={}, no JSON schema constraint)", + parser_key or "template-default", + tool_call_format, + ) + + # Remove tool_start from stop strings so the model can emit + # multiple sequential blocks without stopping early + if ( + tool_start + and isinstance(tool_data.stop, list) + and tool_start in tool_data.stop + ): + tool_data.stop = [s for s in tool_data.stop if s != tool_start] + logger.debug( + f"generate_tool_calls: Removed '{tool_start}' from " + f"second-pass stop strings" + ) + else: + # JSON mode: constrained generation (existing behavior) + tool_data.json_schema = TOOL_CALL_SCHEMA for idx, gen in enumerate(generations): - if gen["stop_str"] != tool_start: + stop_str = gen.get("stop_str") + should_generate = stop_str == tool_start + + # Force tool generation if tool_choice requires it + if not should_generate and ( + tool_choice == "required" or isinstance(tool_choice, NamedToolChoice) + ): + should_generate = True + + if not should_generate: continue - logger.info(f"Detected tool call in chat completion request {request.state.id}") + logger.info( + f"Detected tool call in chat completion request " + f"{request.state.id} (format={tool_call_format})" + ) - # Append the existing generation text if present + # Build per-generation prompt (avoid mutating shared prompt) + tool_prompt = prompt precursor_text = gen.get("full_text") if precursor_text: - prompt = prompt + precursor_text + tool_prompt = tool_prompt + precursor_text + + # For native generation mode: append tool_start back to prompt. + # The stop string was consumed by the first pass and not included + # in full_text, but the model expects to continue after tool_start. + # Include a trailing newline to match the canonical template format. 
+ if use_native_generation and tool_start: + tool_prompt = tool_prompt + tool_start + "\n" gen_request_id = gen.get("request_id") tool_request_id = f"{gen_request_id}-tool" @@ -493,7 +949,7 @@ async def generate_tool_calls( asyncio.create_task( model.container.generate( tool_request_id, - prompt, + tool_prompt, tool_data, mm_embeddings=embeddings, ) @@ -507,6 +963,12 @@ async def generate_tool_calls( # Map tool calls to their appropriate generation for gen_idx, tool_call in zip(tool_idx, tool_calls, strict=True): - generations[gen_idx]["tool_calls"] = tool_call["text"] + raw_text = tool_call["text"] + + if use_native_generation and tool_start: + # Prepend tool_start to reconstruct complete native payload. + raw_text = tool_start + "\n" + raw_text + + generations[gen_idx]["tool_calls"] = raw_text return generations diff --git a/endpoints/OAI/utils/completion.py b/endpoints/OAI/utils/completion.py index f66d381d..c11a25bf 100644 --- a/endpoints/OAI/utils/completion.py +++ b/endpoints/OAI/utils/completion.py @@ -225,11 +225,24 @@ async def stream_generate_completion( # Consumer loop while True: + # Fast path: items already queued — no task overhead + if not gen_queue.empty(): + generation = gen_queue.get_nowait() + else: + # Slow path: queue empty — race get against disconnect + get_task = asyncio.create_task(gen_queue.get()) + done, _ = await asyncio.wait( + [get_task, disconnect_task], + return_when=asyncio.FIRST_COMPLETED, + ) + if disconnect_task in done: + get_task.cancel() + raise CancelledError() + generation = get_task.result() + if disconnect_task.done(): raise CancelledError() - generation = await gen_queue.get() - # Stream collector will push an exception to the queue if it fails if isinstance(generation, Exception): raise generation @@ -245,15 +258,16 @@ async def stream_generate_completion( except CancelledError: # Get out if the request gets disconnected - if not abort_event.is_set(): - abort_event.set() - handle_request_disconnect( - f"Completion 
generation {request.state.id} cancelled by user." - ) + handle_request_disconnect( + f"Completion generation {request.state.id} cancelled by user." + ) except Exception: yield get_generator_error( f"Completion {request.state.id} aborted. Please check the server console." ) + finally: + abort_event.set() + disconnect_task.cancel() async def generate_completion( diff --git a/endpoints/OAI/utils/parser_options.py b/endpoints/OAI/utils/parser_options.py new file mode 100644 index 00000000..e1da446c --- /dev/null +++ b/endpoints/OAI/utils/parser_options.py @@ -0,0 +1,99 @@ +"""Parser option helpers for vLLM-compatible chat settings.""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from typing import Dict, Set + + +# Mirrors vLLM parser keys to keep CLI/config ergonomics familiar. +# Source of truth: vllm/tool_parsers/__init__.py::_TOOL_PARSERS_TO_REGISTER +# Format is the fallback parsing mode supported by ToolCallProcessor. +TOOL_CALL_PARSER_FORMATS: Dict[str, str] = { + "deepseek_v3": "json", + "deepseek_v31": "json", + "deepseek_v32": "json", + "ernie45": "json", + "glm45": "json", + "glm47": "json", + "granite-20b-fc": "json", + "granite": "json", + "hermes": "json", + "hunyuan_a13b": "json", + "internlm": "json", + "jamba": "json", + "kimi_k2": "json", + "llama3_json": "json", + "llama4_json": "json", + "llama4_pythonic": "json", + "longcat": "json", + "minimax_m2": "json", + "minimax": "json", + "mistral": "json", + "olmo3": "json", + "openai": "json", + "phi4_mini_json": "json", + "pythonic": "json", + "qwen3_coder": "xml", + "qwen3_xml": "xml", + "seed_oss": "json", + "step3": "json", + "step3p5": "json", + "xlam": "json", + "gigachat3": "json", + "functiongemma": "json", + # Convenience alias for mixed/inferred content + "auto": "auto", +} + +# Compatibility aliases accepted by this server. +# Keys are user-facing parser names, values are canonical parser keys. 
+TOOL_CALL_PARSER_ALIASES: Dict[str, str] = { + "llama": "llama3_json", +} + +# Parsers that should generate tool calls in their native syntax on tool pass +# (no JSON schema constraint). Most JSON-style parsers should stay constrained. +NATIVE_TOOL_GENERATION_PARSERS: Set[str] = { + "auto", + "deepseek_v3", + "deepseek_v31", + "deepseek_v32", + "llama4_pythonic", + "pythonic", + "qwen3_coder", + "qwen3_xml", +} + + +def resolve_tool_call_parser_key(tool_call_parser: str | None) -> str | None: + """Normalize a user parser key to its canonical key.""" + if not tool_call_parser: + return None + return TOOL_CALL_PARSER_ALIASES.get(tool_call_parser, tool_call_parser) + + +def list_tool_call_parsers() -> Set[str]: + return set(TOOL_CALL_PARSER_FORMATS.keys()).union(TOOL_CALL_PARSER_ALIASES.keys()) + + +def resolve_tool_call_format( + tool_call_parser: str | None, fallback_format: str +) -> str: + """Resolve effective parser format from configured parser key.""" + if not tool_call_parser: + return fallback_format + parser_key = resolve_tool_call_parser_key(tool_call_parser) + return TOOL_CALL_PARSER_FORMATS.get(parser_key, "") + + +def parser_uses_native_tool_generation( + tool_call_parser: str | None, fallback_format: str +) -> bool: + """Whether tool pass should use native model format (unconstrained).""" + if not tool_call_parser: + return fallback_format in ("xml", "auto") + parser_key = resolve_tool_call_parser_key(tool_call_parser) + if parser_key in NATIVE_TOOL_GENERATION_PARSERS: + return True + return resolve_tool_call_format(parser_key, fallback_format) in ("xml", "auto") diff --git a/endpoints/OAI/utils/tools.py b/endpoints/OAI/utils/tools.py index c1ebdedf..bf6e2a0d 100644 --- a/endpoints/OAI/utils/tools.py +++ b/endpoints/OAI/utils/tools.py @@ -1,8 +1,16 @@ +"""Tool call processing utilities for OAI server.""" + +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import ast import json +import re 
from loguru import logger
-from typing import List
+from typing import Any, Callable, Dict, List, Tuple

-from endpoints.OAI.types.tools import ToolCall
+from endpoints.OAI.types.tools import ToolCall, Tool
+from endpoints.OAI.utils.parser_options import resolve_tool_call_parser_key

 TOOL_CALL_SCHEMA = {
@@ -27,24 +35,807 @@
     },
 }

+# ---------------------------------------------------------------------------
+# XML parsing regex patterns
+# Derived from vLLM's Qwen3CoderToolParser and the official Qwen parser.
+# These handle both complete and partially-closed tags.
+# ---------------------------------------------------------------------------
+
+# Matches complete <tool_call>...</tool_call> blocks
+TOOL_CALL_BLOCK_RE = re.compile(
+    r"<tool_call>(.*?)</tool_call>",
+    re.DOTALL,
+)
+
+# Matches <function=NAME>BODY</function> blocks
+FUNCTION_RE = re.compile(
+    r"<function=([^>]+)>(.*?)</function>",
+    re.DOTALL,
+)
+
+# Matches <parameter=NAME>VALUE
+# Terminates on: </parameter>, next <parameter=, or </function>
+PARAMETER_RE = re.compile(
+    r"<parameter=([^>]+)>(.*?)"
+    r"(?:</parameter>|(?=<parameter=)|(?=</function>))",
+    re.DOTALL,
+)
+
+# Think block patterns
+THINK_BLOCK_RE = re.compile(r"<think>.*?</think>\s*", re.DOTALL)
+THINK_UNCLOSED_RE = re.compile(r"<think>(?!.*</think>).*$", re.DOTALL)
+
+# Markdown code fence patterns
+CODE_FENCE_RE = re.compile(r"^```(?:json)?\s*", re.MULTILINE)
+CODE_FENCE_END_RE = re.compile(r"\s*```\s*$", re.MULTILINE)
+
+# DeepSeek family patterns
+DEEPSEEK_V31_CALL_RE = re.compile(
+    r"<｜tool▁call▁begin｜>(?P<name>.*?)<｜tool▁sep｜>(?P<args>.*?)<｜tool▁call▁end｜>",
+    re.DOTALL,
+)
+DEEPSEEK_V3_CALL_RE = re.compile(
+    r"<｜tool▁call▁begin｜>(?P<type>.*?)<｜tool▁sep｜>(?P<name>.*?)\n```json\n(?P<args>.*?)\n```(?:\s*)<｜tool▁call▁end｜>",  # noqa: E501
+    re.DOTALL,
+)
+DEEPSEEK_V32_INVOKE_RE = re.compile(
+    r'<｜DSML｜invoke\s+name="(?P<name>[^"]+)"\s*>(?P<body>.*?)</｜DSML｜invoke>',
+    re.DOTALL,
+)
+DEEPSEEK_V32_PARAM_RE = re.compile(
+    r'<｜DSML｜parameter\s+name="(?P<name>[^"]+)"(?:\s+string="(?P<string>true|false)")?\s*>(?P<value>.*?)</｜DSML｜parameter>',  # noqa: E501
+    re.DOTALL,
+)
+
+
+def _strip_think_blocks(text: str) -> str:
+    """Strip <think>...</think> blocks from text.
+
+    Handles both complete and unclosed blocks (quantization can cause
+    the model to never close a think tag).
+    """
+    original = text
+
+    # Complete blocks first
+    text = THINK_BLOCK_RE.sub("", text)
+
+    # Unclosed block (think started but never closed — strip to end)
+    text = THINK_UNCLOSED_RE.sub("", text)
+
+    if text != original:
+        if THINK_UNCLOSED_RE.search(original):
+            logger.warning(
+                "XML Parser: Stripped unclosed <think> block "
+                "(possible quantization degradation)"
+            )
+        else:
+            logger.debug("XML Parser: Stripped <think> block(s) from output")
+
+    return text
+
+
+def _coerce_param_value(raw: str) -> Any:
+    """Coerce a raw parameter value string to the appropriate Python type.
+
+    Strategy (safe, no eval()):
+    1. Strip leading/trailing newlines (official template emits \\n
+       after opening tag and before closing tag).
+    2. Try json.loads — handles objects, arrays, numbers, bools, null.
+    3. Fall back to plain string.
+    """
+    # Strip template-inserted newlines around values
+    if raw.startswith("\n"):
+        raw = raw[1:]
+    if raw.endswith("\n"):
+        raw = raw[:-1]
+
+    stripped = raw.strip()
+
+    # Empty string
+    if not stripped:
+        return ""
+
+    # Try JSON parse (handles objects, arrays, numbers, booleans, null)
+    try:
+        return json.loads(stripped)
+    except (json.JSONDecodeError, ValueError):
+        pass
+
+    # Fall back to string — never eval()
+    return stripped
+
+
+class ToolCallProcessor:
+    _PARSER_DISPATCHER: Dict[str, Callable[[str], List[ToolCall]]] = {}
+
+    # ------------------------------------------------------------------
+    # JSON normalization helpers
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _normalize_tool_calls(raw) -> list:
+        """Normalize model-emitted tool call payloads into OAI-like objects.
+ + Accepted forms: + - [{"type":"function","function":{"name":...,"arguments":{...}}}] + - [{"name":...,"arguments":{...}}] + - {"name":...,"arguments":{...}} + """ + if isinstance(raw, dict): + raw = [raw] + if not isinstance(raw, list): + raise ValueError("tool_calls payload is not list/dict") + + normalized: list = [] + for item in raw: + if not isinstance(item, dict): + continue + + if "function" in item and isinstance(item["function"], dict): + fn = item["function"] + name = fn.get("name") + arguments = fn.get("arguments", {}) + else: + name = item.get("name") + arguments = item.get("arguments", {}) + + if name is None: + continue + + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {"input": arguments} + + normalized.append( + { + "type": "function", + "function": { + "name": name, + "arguments": arguments if isinstance(arguments, dict) else {}, + }, + } + ) + return normalized + + @staticmethod + def _safe_json_loads(payload: str) -> list: + """Best-effort JSON parse for model-emitted tool payloads. + + Handles: clean JSON, markdown-fenced JSON, JSON substrings in + surrounding text, flat {name, arguments} dicts, and single objects. 
+ """ + # Direct parse + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(payload)) + except (json.JSONDecodeError, ValueError): + pass + + # Clean up common model artifacts (markdown fences, whitespace) + cleaned = payload.strip() + cleaned = CODE_FENCE_RE.sub("", cleaned) + cleaned = CODE_FENCE_END_RE.sub("", cleaned) + cleaned = cleaned.strip() + + # Try cleaned + try: + return ToolCallProcessor._normalize_tool_calls(json.loads(cleaned)) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON array substring + start = cleaned.find("[") + end = cleaned.rfind("]") + if start != -1 and end != -1 and end > start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[start : end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + # Find JSON object substring + obj_start = cleaned.find("{") + obj_end = cleaned.rfind("}") + if obj_start != -1 and obj_end != -1 and obj_end > obj_start: + try: + return ToolCallProcessor._normalize_tool_calls( + json.loads(cleaned[obj_start : obj_end + 1]) + ) + except (json.JSONDecodeError, ValueError): + pass + + raise json.JSONDecodeError( + "Could not extract valid JSON from payload", payload, 0 + ) + + @staticmethod + def _build_tool_calls_from_normalized(raw: Any) -> List[ToolCall]: + """Normalize dict/list payload and build ToolCall models.""" + normalized = ToolCallProcessor._normalize_tool_calls(raw) + for tool_call in normalized: + tool_call["function"]["arguments"] = json.dumps( + tool_call["function"]["arguments"], ensure_ascii=False + ) + return [ToolCall(**tool_call) for tool_call in normalized] + + @staticmethod + def _decode_json_sequence(text: str) -> List[Any]: + """Decode multiple JSON values from a single string.""" + decoder = json.JSONDecoder() + values: List[Any] = [] + idx = 0 + while idx < len(text): + while idx < len(text) and text[idx] in " \t\r\n,;": + idx += 1 + if idx >= len(text): + break + if text.startswith("<|python_tag|>", idx): + idx += 
len("<|python_tag|>")
+                continue
+            try:
+                value, end = decoder.raw_decode(text[idx:])
+            except json.JSONDecodeError:
+                break
+            values.append(value)
+            idx += end
+        return values
+
+    @staticmethod
+    def _coerce_argument_payload(arguments_raw: str) -> str:
+        """Normalize raw argument payload to a JSON string where possible."""
+        payload = arguments_raw.strip()
+        if not payload:
+            return "{}"
+        try:
+            return json.dumps(json.loads(payload), ensure_ascii=False)
+        except (json.JSONDecodeError, ValueError, TypeError):
+            return payload
+
+    @staticmethod
+    def _ast_to_literal(node: ast.AST) -> Any:
+        """Safely convert AST literal nodes to Python primitives."""
+        if isinstance(node, ast.Constant):
+            return node.value
+        if isinstance(node, ast.List):
+            return [ToolCallProcessor._ast_to_literal(item) for item in node.elts]
+        if isinstance(node, ast.Tuple):
+            return [ToolCallProcessor._ast_to_literal(item) for item in node.elts]
+        if isinstance(node, ast.Dict):
+            result = {}
+            for key, value in zip(node.keys, node.values):
+                literal_key = ToolCallProcessor._ast_to_literal(key)  # type: ignore[arg-type]
+                if not isinstance(literal_key, str):
+                    raise ValueError("pythonic parser requires string dict keys")
+                result[literal_key] = ToolCallProcessor._ast_to_literal(value)
+            return result
+        if isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
+            return -ToolCallProcessor._ast_to_literal(node.operand)
+        raise ValueError(f"unsupported pythonic AST node: {type(node).__name__}")
+
+    # ------------------------------------------------------------------
+    # JSON parsing
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def from_hermes(raw_text: str) -> List[ToolCall]:
+        """Parse Hermes-style JSON tool calls (often wrapped in <tool_call> tags)."""
+        text = _strip_think_blocks(raw_text)
+        wrapped_calls = []
+        for match in TOOL_CALL_BLOCK_RE.finditer(text):
+            inner = match.group(1).strip()
+            if not inner:
+                continue
+            try:
+                parsed = json.loads(inner)
+
except (json.JSONDecodeError, ValueError): + continue + wrapped_calls.extend(ToolCallProcessor._build_tool_calls_from_normalized(parsed)) + + if wrapped_calls: + return wrapped_calls + + return ToolCallProcessor.from_json(text) + + @staticmethod + def from_llama(raw_text: str) -> List[ToolCall]: + """Parse Llama JSON tool calls (single/multiple JSON objects).""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + if not decoded: + return [] + + flattened = [] + for item in decoded: + if isinstance(item, list): + flattened.extend(item) + else: + flattened.append(item) + + return ToolCallProcessor._build_tool_calls_from_normalized(flattened) + + @staticmethod + def from_openai(raw_text: str) -> List[ToolCall]: + """Best-effort parser for OpenAI/Harmony-style text payloads.""" + text = _strip_think_blocks(raw_text).strip() + try: + parsed = ToolCallProcessor.from_json(text) + if parsed: + return parsed + except (json.JSONDecodeError, ValueError, KeyError): + pass + + decoded = ToolCallProcessor._decode_json_sequence(text) + tool_calls: List[ToolCall] = [] + normalized_items = [] + for value in decoded: + candidates = value if isinstance(value, list) else [value] + for item in candidates: + if not isinstance(item, dict): + continue + + nested = item.get("tool_calls") + if nested: + try: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(nested) + ) + except (ValueError, KeyError, TypeError): + pass + + recipient = item.get("recipient") + content = item.get("content") + if isinstance(recipient, str) and recipient.startswith("functions."): + fn_name = recipient.split("functions.", 1)[1] + if isinstance(content, str): + payload = 
ToolCallProcessor._coerce_argument_payload(content) + elif content is None: + payload = "{}" + else: + payload = json.dumps(content, ensure_ascii=False) + tool_calls.append( + ToolCall(function=Tool(name=fn_name, arguments=payload)) + ) + continue + + if "name" in item: + normalized_items.append(item) + + if normalized_items: + tool_calls.extend( + ToolCallProcessor._build_tool_calls_from_normalized(normalized_items) + ) + + return tool_calls + + @staticmethod + def from_pythonic(raw_text: str) -> List[ToolCall]: + """Parse Pythonic list-of-calls tool syntax.""" + text = _strip_think_blocks(raw_text).strip() + if text.startswith("<|python_tag|>"): + text = text[len("<|python_tag|>") :].lstrip() + if not text: + return [] + + if not text.startswith("[") and re.match(r"^[A-Za-z_]\w*\s*\(", text): + text = f"[{text}]" + + expression = ast.parse(text, mode="eval").body + call_nodes = expression.elts if isinstance(expression, ast.List) else [expression] + + tool_calls = [] + for node in call_nodes: + if not isinstance(node, ast.Call) or not isinstance(node.func, ast.Name): + continue + args_dict: Dict[str, Any] = {} + if node.args: + args_dict["_args"] = [ + ToolCallProcessor._ast_to_literal(argument) + for argument in node.args + ] + for keyword in node.keywords: + if keyword.arg is None: + continue + args_dict[keyword.arg] = ToolCallProcessor._ast_to_literal(keyword.value) + + tool_calls.append( + ToolCall( + function=Tool( + name=node.func.id, + arguments=json.dumps(args_dict, ensure_ascii=False), + ) + ) + ) + + return tool_calls + + @staticmethod + def from_deepseek_v31(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.1 tool call syntax.""" + tool_calls = [] + for match in DEEPSEEK_V31_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + return tool_calls + + 
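As a quick sanity check on the argument coercion the DeepSeek parsers above rely on, here is a standalone sketch that mirrors the body of `ToolCallProcessor._coerce_argument_payload` from this diff (re-declared as a free function so it runs without the server code):

```python
import json


def coerce_argument_payload(arguments_raw: str) -> str:
    # Mirrors ToolCallProcessor._coerce_argument_payload: normalize a raw
    # model-emitted argument payload to a compact JSON string where possible,
    # otherwise pass the raw text through unchanged.
    payload = arguments_raw.strip()
    if not payload:
        return "{}"
    try:
        # Round-trip through json to canonicalize whitespace/ordering
        return json.dumps(json.loads(payload), ensure_ascii=False)
    except (json.JSONDecodeError, ValueError, TypeError):
        # Not valid JSON: leave it for downstream validation
        return payload


print(coerce_argument_payload(' {"city": "Paris"} '))  # {"city": "Paris"}
print(coerce_argument_payload(""))                     # {}
print(coerce_argument_payload("not json"))             # not json
```

Note the deliberate asymmetry: empty payloads become `"{}"` (a valid empty argument object), while non-JSON payloads are returned verbatim rather than rejected, so the caller can still surface them in the response.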
@staticmethod + def from_deepseek_v3(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3 tool call syntax.""" + tool_calls = [] + for match in DEEPSEEK_V3_CALL_RE.finditer(raw_text): + name = match.group("name").strip() + if not name: + continue + arguments = ToolCallProcessor._coerce_argument_payload(match.group("args")) + tool_calls.append(ToolCall(function=Tool(name=name, arguments=arguments))) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + + @staticmethod + def from_deepseek_v32(raw_text: str) -> List[ToolCall]: + """Parse DeepSeek v3.2 DSML tool call syntax.""" + tool_calls = [] + for invoke in DEEPSEEK_V32_INVOKE_RE.finditer(raw_text): + function_name = invoke.group("name").strip() + if not function_name: + continue + + params: Dict[str, Any] = {} + body = invoke.group("body") + for param in DEEPSEEK_V32_PARAM_RE.finditer(body): + key = param.group("name").strip() + value_raw = param.group("value") + is_string = param.group("string") == "true" + if is_string: + value = value_raw.strip("\n") + else: + value = _coerce_param_value(value_raw) + params[key] = value + + tool_calls.append( + ToolCall( + function=Tool( + name=function_name, + arguments=json.dumps(params, ensure_ascii=False), + ) + ) + ) + + if tool_calls: + return tool_calls + + return ToolCallProcessor.from_deepseek_v31(raw_text) + @staticmethod def from_json(tool_calls_str: str) -> List[ToolCall]: - """Postprocess tool call JSON to a parseable class""" + """Postprocess tool call JSON to a parseable class. + + Handles clean JSON arrays, markdown-fenced output, flat dicts, + and other common model output variations via _safe_json_loads. 
+        """
+        logger.debug(f"JSON Parser: Parsing tool calls ({len(tool_calls_str)} chars)")

-        tool_calls = json.loads(tool_calls_str)
+        tool_calls = ToolCallProcessor._safe_json_loads(tool_calls_str)

         for tool_call in tool_calls:
             tool_call["function"]["arguments"] = json.dumps(
                 tool_call["function"]["arguments"]
             )

-        return [ToolCall(**tool_call) for tool_call in tool_calls]
+        result = [ToolCall(**tool_call) for tool_call in tool_calls]
+        logger.debug(f"JSON Parser: Successfully parsed {len(result)} tool call(s)")
+        return result
+
+    # ------------------------------------------------------------------
+    # XML parsing (Qwen3-Coder / GLM-4.5 style)
+    # ------------------------------------------------------------------

     @staticmethod
-    def dump(tool_calls: List[ToolCall]) -> List[dict]:
+    def from_xml(raw_text: str) -> List[ToolCall]:
+        """Parse Qwen3-Coder XML-format tool calls into ToolCall objects.
+
+        Handles:
+        - Wrapped: <tool_call><function=...>...</function></tool_call>
+        - Bare: <function=...>...</function> (missing <tool_call> wrapper)
+        - Multiple sequential tool call blocks
+        - <think> blocks (stripped)
+        - Multi-line parameter values
+        - Missing closing tags
         """
-        Convert ToolCall objects to a list of dictionaries.
+        logger.debug(f"XML Parser: Parsing tool calls ({len(raw_text)} chars)")
+
+        # Stage 1: Strip think blocks
+        text = _strip_think_blocks(raw_text)
+
+        # Stage 2: Check for incomplete XML at end (generation cutoff)
+        stripped_end = text.rstrip()
+        if re.search(r"<[^>]*$", stripped_end):
+            logger.debug("XML Parser: Trimming incomplete trailing tag")
+            text = re.sub(r"<[^>]*$", "", text)
+
+        # Stage 3: Extract function blocks
+        # First, find all wrapped <tool_call>...</tool_call> blocks
+        wrapped_positions = [
+            (m.start(), m.end()) for m in TOOL_CALL_BLOCK_RE.finditer(text)
+        ]
+
+        # Collect function blocks from inside wrapped regions
+        function_blocks = []
+        for match in TOOL_CALL_BLOCK_RE.finditer(text):
+            inner = match.group(1)
+            for func_match in FUNCTION_RE.finditer(inner):
+                function_blocks.append((func_match.group(1), func_match.group(2)))
+
+        # Find bare <function=...> blocks NOT inside any wrapped region
+        for func_match in FUNCTION_RE.finditer(text):
+            pos = func_match.start()
+            is_wrapped = any(start <= pos < end for start, end in wrapped_positions)
+            if not is_wrapped:
+                logger.debug(
+                    "XML Parser: Found bare <function=...> block without "
+                    "<tool_call> wrapper"
+                )
+                function_blocks.append((func_match.group(1), func_match.group(2)))
+
+        if not function_blocks:
+            logger.warning("XML Parser: No <function=...> blocks found")
+            return []
+
+        # Stage 4: Parse each function block into a ToolCall
+        tool_calls = []
+        for func_name_raw, func_body in function_blocks:
+            func_name = func_name_raw.strip()
+
+            # Extract parameters
+            params = {}
+            for param_match in PARAMETER_RE.finditer(func_body):
+                key = param_match.group(1).strip()
+                value_raw = param_match.group(2)
+                value = _coerce_param_value(value_raw)
+                params[key] = value
+
+            arguments_json = json.dumps(params, ensure_ascii=False)
+
+            tool_call = ToolCall(
+                function=Tool(name=func_name, arguments=arguments_json)
+            )
+            tool_calls.append(tool_call)
+
+        logger.debug(f"XML Parser: Successfully parsed {len(tool_calls)} tool call(s)")
+        return tool_calls
+
+    # ------------------------------------------------------------------
+    # Auto-detect parsing (JSON → JSON-in-tool_call → XML)
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def from_auto(raw_text: str) -> List[ToolCall]:
+        """Auto-detect format and parse.
+
+        Tries in order:
+        1. Pure JSON (standard TabbyAPI / Llama)
+        2. JSON inside <tool_call> wrappers (Qwen3-Instruct style)
+        3. XML with <function=...> tags (Qwen3-Coder style)
+        """
+        logger.debug("Auto Parser: Attempting format auto-detection")
+
+        # Attempt 1: Pure JSON array
+        try:
+            result = ToolCallProcessor.from_json(raw_text)
+            logger.debug("Auto Parser: Detected JSON format")
+            return result
+        except (json.JSONDecodeError, ValueError, KeyError) as e:
+            logger.debug(f"Auto Parser: Not JSON ({e}), trying next format")

+        # Attempt 2: JSON inside <tool_call> wrappers (Qwen3-Instruct)
+        try:
+            all_tool_calls = []
+            for match in TOOL_CALL_BLOCK_RE.finditer(raw_text):
+                inner = match.group(1).strip()
+                if inner.startswith("{") or inner.startswith("["):
+                    parsed = json.loads(inner)
+                    if isinstance(parsed, dict):
+                        parsed = [parsed]
+                    if isinstance(parsed, list):
+                        for tc in parsed:
+                            name = tc.get("name", "")
+                            arguments = tc.get("arguments", {})
+                            if not isinstance(arguments, str):
+                                arguments = json.dumps(arguments)
+                            all_tool_calls.append(
+                                ToolCall(function=Tool(name=name, arguments=arguments))
+                            )
+            if all_tool_calls:
+                logger.debug(
+                    "Auto Parser: Detected JSON-inside-tool_call "
+                    f"format ({len(all_tool_calls)} call(s))"
+                )
+                return all_tool_calls
+        except (json.JSONDecodeError, ValueError, KeyError) as e:
+            logger.debug(f"Auto Parser: Not JSON-in-tool_call ({e}), trying XML")
+
+        # Attempt 3: XML format (Qwen3-Coder style)
+        result = ToolCallProcessor.from_xml(raw_text)
+        if result:
+            logger.debug("Auto Parser: Detected XML format")
+        else:
+            logger.warning("Auto Parser: All format detection attempts failed")
+        return result
+
+    # ------------------------------------------------------------------
+    # Dispatcher
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def _parser_dispatcher() -> Dict[str, Callable[[str], List[ToolCall]]]:
+        """Registry for parser-key-specific handlers."""
+        if not ToolCallProcessor._PARSER_DISPATCHER:
+            ToolCallProcessor._PARSER_DISPATCHER = {
+                "deepseek_v3":
ToolCallProcessor.from_deepseek_v3, + "deepseek_v31": ToolCallProcessor.from_deepseek_v31, + "deepseek_v32": ToolCallProcessor.from_deepseek_v32, + "hermes": ToolCallProcessor.from_hermes, + "llama": ToolCallProcessor.from_llama, + "llama3_json": ToolCallProcessor.from_llama, + "llama4_json": ToolCallProcessor.from_llama, + "openai": ToolCallProcessor.from_openai, + "pythonic": ToolCallProcessor.from_pythonic, + "qwen3_coder": ToolCallProcessor.from_xml, + "qwen3_xml": ToolCallProcessor.from_xml, + } + return ToolCallProcessor._PARSER_DISPATCHER + + @staticmethod + def parse( + tool_calls_str: str, format: str = "json", parser_key: str | None = None + ) -> List[ToolCall]: + """Dispatch tool call parsing to the appropriate format handler. + + Args: + tool_calls_str: Raw tool call text from model generation. + format: One of ``"json"``, ``"xml"``, ``"auto"``. + parser_key: Optional vLLM-compatible parser key. + + Returns: + List of parsed ToolCall objects. Empty list on parse failure + (never raises). + """ + try: + if parser_key: + canonical_key = resolve_tool_call_parser_key(parser_key) or parser_key + parser = ToolCallProcessor._parser_dispatcher().get(canonical_key) + if parser: + try: + parsed = parser(tool_calls_str) + except Exception as exc: + logger.warning( + "Parser '{}' failed: {}. 
Falling back to format '{}'.",
+                        canonical_key,
+                        str(exc),
+                        format,
+                    )
+                else:
+                    if parsed:
+                        return parsed
+
+            if format == "xml":
+                return ToolCallProcessor.from_xml(tool_calls_str)
+            elif format == "auto":
+                return ToolCallProcessor.from_auto(tool_calls_str)
+            else:
+                return ToolCallProcessor.from_json(tool_calls_str)
+        except Exception as e:
+            logger.error(
+                f"ToolCallProcessor.parse: Failed to parse tool calls "
+                f"(format={format}): {e}"
+            )
+            return []
+
+    # ------------------------------------------------------------------
+    # Filtering
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def filter_by_name(
+        tool_calls: List[ToolCall], function_name: str
+    ) -> List[ToolCall]:
+        """Filter parsed tool calls to only those matching a function name."""
+        filtered = [tc for tc in tool_calls if tc.function.name == function_name]
+        if not filtered:
+            logger.warning(
+                f"filter_by_name: No tool calls matched '{function_name}' "
+                f"(had {len(tool_calls)} call(s))"
+            )
+        return filtered
+
+    # ------------------------------------------------------------------
+    # Content / tool-call separation
+    # ------------------------------------------------------------------
+
+    @staticmethod
+    def extract_content_and_tools(
+        raw_text: str,
+    ) -> Tuple[str, List[ToolCall]]:
+        """Separate plain text content from XML tool call blocks.
+
+        Used when the model mixes reasoning text with tool calls, e.g.:
+        ``"I'll help with that: <tool_call>...</tool_call>"``
+
+        Returns:
+            Tuple of (remaining_content, tool_calls).
+ """ + text = _strip_think_blocks(raw_text) + + # Collect all XML regions to exclude from content + xml_regions = [] + + # Wrapped tool call blocks + for match in TOOL_CALL_BLOCK_RE.finditer(text): + xml_regions.append((match.start(), match.end())) + + # Bare function blocks not inside wrappers + for match in FUNCTION_RE.finditer(text): + pos = match.start() + is_wrapped = any(start <= pos < end for start, end in xml_regions) + if not is_wrapped: + xml_regions.append((match.start(), match.end())) + + # Sort and extract content (everything outside XML regions) + xml_regions.sort() + content_parts = [] + last_end = 0 + for start, end in xml_regions: + if start > last_end: + part = text[last_end:start].strip() + if part: + content_parts.append(part) + last_end = end + if last_end < len(text): + part = text[last_end:].strip() + if part: + content_parts.append(part) + + content = " ".join(content_parts).strip() + + # Parse tool calls from the full text + tool_calls = ToolCallProcessor.from_xml(text) + + logger.debug( + f"extract_content_and_tools: Found {len(tool_calls)} tool " + f"call(s), content={'yes' if content else 'no'} " + f"({len(content)} chars)" + ) + + return content, tool_calls + + # ------------------------------------------------------------------ + # Serialisation helpers (unchanged from original) + # ------------------------------------------------------------------ + + @staticmethod + def dump(tool_calls: List[ToolCall]) -> List[dict]: + """Convert ToolCall objects to a list of dictionaries. Args: tool_calls (List[ToolCall]): List of ToolCall objects to convert @@ -65,8 +856,7 @@ def dump(tool_calls: List[ToolCall]) -> List[dict]: @staticmethod def to_json(tool_calls: List[ToolCall]) -> str: - """ - Convert ToolCall objects to JSON string representation. + """Convert ToolCall objects to JSON string representation. 
Args: tool_calls (List[ToolCall]): List of ToolCall objects to convert diff --git a/templates/tool_calls/qwen3_coder.jinja b/templates/tool_calls/qwen3_coder.jinja new file mode 100644 index 00000000..0df78172 --- /dev/null +++ b/templates/tool_calls/qwen3_coder.jinja @@ -0,0 +1,125 @@ +{# SPDX-License-Identifier: Apache-2.0 #} +{# SPDX-FileCopyrightText: Copyright contributors to the vLLM project #} +{# TabbyAPI Metadata #} +{%- set tool_call_format = "xml" -%} +{%- set tool_start = "" -%} +{%- set tool_end = "" -%} +{%- set stop_strings = ["<|im_start|>", "<|im_end|>"] -%} + +{% macro render_extra_keys(json_dict, handled_keys) %} + {%- if json_dict is mapping %} + {%- for json_key in json_dict if json_key not in handled_keys %} + {%- if json_dict[json_key] is string %} + {{-'\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | string) ~ '' }} + {%- else %} + {{- '\n<' ~ json_key ~ '>' ~ (json_dict[json_key] | tojson | safe) ~ '' }} + {%- endif %} + {%- endfor %} + {%- endif %} +{%- endmacro %} + +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = [] %} +{%- endif %} + +{%- if system_message is defined %} + {{- "<|im_start|>system\n" + system_message }} +{%- else %} + {%- if tools is iterable and tools | length > 0 %} + {{- "<|im_start|>system\nYou are Qwen, a helpful AI assistant that can interact with a computer to solve tasks." 
}} + {%- endif %} +{%- endif %} +{%- if tools is iterable and tools | length > 0 %} + {{- "\n\n# Tools\n\nYou have access to the following functions:\n\n" }} + {{- "" }} + {%- for tool in tools %} + {%- if tool.function is defined %} + {%- set tool = tool.function %} + {%- endif %} + {{- "\n\n" ~ tool.name ~ "" }} + {%- if tool.description is defined %} + {{- '\n' ~ (tool.description | trim) ~ '' }} + {%- endif %} + {{- '\n' }} + {%- if tool.parameters is defined and tool.parameters is mapping and tool.parameters.properties is defined and tool.parameters.properties is mapping %} + {%- for param_name, param_fields in tool.parameters.properties|items %} + {{- '\n' }} + {{- '\n' ~ param_name ~ '' }} + {%- if param_fields.type is defined %} + {{- '\n' ~ (param_fields.type | string) ~ '' }} + {%- endif %} + {%- if param_fields.description is defined %} + {{- '\n' ~ (param_fields.description | trim) ~ '' }} + {%- endif %} + {%- set handled_keys = ['name', 'type', 'description'] %} + {{- render_extra_keys(param_fields, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {%- endif %} + {%- set handled_keys = ['type', 'properties'] %} + {{- render_extra_keys(tool.parameters, handled_keys) }} + {{- '\n' }} + {%- set handled_keys = ['type', 'name', 'description', 'parameters'] %} + {{- render_extra_keys(tool, handled_keys) }} + {{- '\n' }} + {%- endfor %} + {{- "\n" }} + {{- '\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n\n\n\nvalue_1\n\n\nThis is the value for the second parameter\nthat can span\nmultiple lines\n\n\n\n\n\nReminder:\n- Function calls MUST follow the specified format: an inner block must be nested within XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function 
calls\n</IMPORTANT>' }}
+{%- endif %}
+{%- if system_message is defined %}
+    {{- '<|im_end|>\n' }}
+{%- else %}
+    {%- if tools is iterable and tools | length > 0 %}
+        {{- '<|im_end|>\n' }}
+    {%- endif %}
+{%- endif %}
+{%- for message in loop_messages %}
+    {%- if message.role == "assistant" and message.tool_calls is defined and message.tool_calls is iterable and message.tool_calls | length > 0 %}
+        {{- '<|im_start|>' + message.role }}
+        {%- if message.content is defined and message.content is string and message.content | trim | length > 0 %}
+            {{- '\n' + message.content | trim + '\n' }}
+        {%- endif %}
+        {%- for tool_call in message.tool_calls %}
+            {%- if tool_call.function is defined %}
+                {%- set tool_call = tool_call.function %}
+            {%- endif %}
+            {{- '\n<tool_call>\n<function=' ~ tool_call.name ~ '>\n' }}
+            {%- if tool_call.arguments is defined %}
+                {%- for args_name, args_value in tool_call.arguments|items %}
+                    {{- '\n<parameter=' ~ args_name ~ '>' }}
+                    {%- set args_value = args_value if args_value is string else args_value | tojson | safe %}
+                    {{- args_value }}
+                    {{- '\n</parameter>\n' }}
+                {%- endfor %}
+            {%- endif %}
+            {{- '</function>\n</tool_call>' }}
+        {%- endfor %}
+        {{- '<|im_end|>\n' }}
+    {%- elif message.role == "user" or message.role == "system" or message.role == "assistant" %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>' + '\n' }}
+    {%- elif message.role == "tool" %}
+        {%- if loop.previtem and loop.previtem.role != "tool" %}
+            {{- '<|im_start|>user' }}
+        {%- endif %}
+        {{- '\n<tool_response>\n' }}
+        {{- message.content }}
+        {{- '\n</tool_response>' }}
+        {%- if not loop.last and loop.nextitem.role != "tool" %}
+            {{- '<|im_end|>\n' }}
+        {%- elif loop.last %}
+            {{- '<|im_end|>\n' }}
+        {%- endif %}
+    {%- else %}
+        {{- '<|im_start|>' + message.role + '\n' + message.content + '<|im_end|>\n' }}
+    {%- endif %}
+{%- endfor %}
+{%- if add_generation_prompt %}
+    {{- '<|im_start|>assistant\n' }}
+{%- endif %}
diff --git a/tests/exaone4_reasoning_parser_test.py b/tests/exaone4_reasoning_parser_test.py
new file mode 100644
index 00000000..135b37cb
--- /dev/null
+++
b/tests/exaone4_reasoning_parser_test.py
@@ -0,0 +1,203 @@
+"""Tests for Exaone4 reasoning parser behavior."""
+
+from endpoints.OAI.reasoning.exaone4_reasoning_parser import Exaone4ReasoningParser
+
+
+class _FakeTokenizer:
+    def get_vocab(self):
+        return {
+            "<think>": 101,
+            "</think>": 102,
+        }
+
+
+def _parser(enable_thinking: bool) -> Exaone4ReasoningParser:
+    return Exaone4ReasoningParser(
+        _FakeTokenizer(),
+        chat_template_kwargs={"enable_thinking": enable_thinking},
+    )
+
+
+def test_non_thinking_mode_emits_content_only():
+    parser = _parser(enable_thinking=False)
+
+    reasoning, content = parser.extract_reasoning("hello", request=None)
+    assert reasoning is None
+    assert content == "hello"
+
+    reasoning, content = parser.extract_reasoning("hello", request=None)
+    assert reasoning is None
+    assert content == "hello"
+
+    delta = parser.extract_reasoning_streaming(
+        previous_text="",
+        current_text="hello",
+        delta_text="hello",
+        previous_token_ids=[],
+        current_token_ids=[1],
+        delta_token_ids=[1],
+    )
+    assert delta is not None
+    assert delta.reasoning is None
+    assert delta.content == "hello"
+
+
+def test_thinking_mode_extract_reasoning_and_content_non_streaming():
+    parser = _parser(enable_thinking=True)
+
+    reasoning, content = parser.extract_reasoning(
+        "<think>reason</think>answer", request=None
+    )
+    assert reasoning == "reason"
+    assert content == "answer"
+
+    reasoning, content = parser.extract_reasoning("reason</think>answer", request=None)
+    assert reasoning == "reason"
+    assert content == "answer"
+
+
+def test_thinking_mode_without_end_token_is_reasoning_only():
+    parser = _parser(enable_thinking=True)
+
+    reasoning, content = parser.extract_reasoning("reasoning only", request=None)
+    assert reasoning == "reasoning only"
+    assert content is None
+
+
+def test_thinking_streaming_prefill_flow_without_start_token():
+    parser = _parser(enable_thinking=True)
+
+    first = parser.extract_reasoning_streaming(
+        previous_text="",
+        current_text="reason ",
+        delta_text="reason ",
+        previous_token_ids=[],
+        current_token_ids=[11],
+        delta_token_ids=[11],
+    )
+    assert first is not None
+    assert first.reasoning == "reason "
+    assert first.content is None
+
+    second = parser.extract_reasoning_streaming(
+        previous_text="reason ",
+        current_text="reason more</think>final",
+        delta_text="more</think>final",
+        previous_token_ids=[11],
+        current_token_ids=[11, 12, 102, 13],
+        delta_token_ids=[12, 102, 13],
+    )
+    assert second is not None
+    assert second.reasoning == "more"
+    assert second.content == "final"
+
+    third = parser.extract_reasoning_streaming(
+        previous_text="reason more</think>final",
+        current_text="reason more</think>final!",
+        delta_text="!",
+        previous_token_ids=[11, 12, 102, 13],
+        current_token_ids=[11, 12, 102, 13, 14],
+        delta_token_ids=[14],
+    )
+    assert third is not None
+    assert third.reasoning is None
+    assert third.content == "!"
+
+
+def test_thinking_streaming_handles_split_end_token_boundary():
+    parser = _parser(enable_thinking=True)
+
+    first = parser.extract_reasoning_streaming(
+        previous_text="",
+        current_text="analysis </thi",
+        delta_text="analysis </thi",
+        previous_token_ids=[],
+        current_token_ids=[11],
+        delta_token_ids=[11],
+    )
+    assert first is not None
+    assert first.reasoning == "analysis "
+    assert first.content is None
+
+    second = parser.extract_reasoning_streaming(
+        previous_text="analysis </thi",
+        current_text='analysis </think>{"name":"lookup","arguments":{}}',
+        delta_text='nk>{"name":"lookup","arguments":{}}',
+        previous_token_ids=[11],
+        current_token_ids=[11, 12],
+        delta_token_ids=[12],
+    )
+    assert second is not None
+    assert second.reasoning is None
+    assert second.content == '{"name":"lookup","arguments":{}}'
+
+
+def test_thinking_streaming_handles_split_deepseek_tool_boundary_without_end_token():
+    parser = _parser(enable_thinking=True)
+
+    first = parser.extract_reasoning_streaming(
+        previous_text="",
+        current_text="analysis <|tool▁call▁b",
+        delta_text="analysis <|tool▁call▁b",
+        previous_token_ids=[],
+        current_token_ids=[11],
+        delta_token_ids=[11],
+    )
+    assert first is not None
+    assert first.reasoning == "analysis "
+    assert first.content is None
+
+    second = parser.extract_reasoning_streaming(
+        previous_text="analysis <|tool▁call▁b",
+        current_text=(
+            "analysis <|tool▁call▁begin|>lookup<|tool▁sep|>{\"q\":\"tabby\"}"
+            "<|tool▁call▁end|>"
+        ),
+        delta_text='egin|>lookup<|tool▁sep|>{"q":"tabby"}<|tool▁call▁end|>',
+        previous_token_ids=[11],
+        current_token_ids=[11, 12],
+        delta_token_ids=[12],
+    )
+    assert second is not None
+    assert second.reasoning is None
+    assert second.content == (
'<|tool▁call▁begin|>lookup<|tool▁sep|>{"q":"tabby"}<|tool▁call▁end|>'
+    )
+
+
+def test_thinking_mode_content_ids_and_end_detection():
+    parser = _parser(enable_thinking=True)
+
+    assert parser.is_reasoning_end([1, 2, 102]) is True
+    assert parser.is_reasoning_end([1, 2, 3]) is False
+
+    assert parser.extract_content_ids([10, 101, 20, 102, 30, 31]) == [30, 31]
+    assert parser.extract_content_ids([10, 101, 20]) == []
diff --git a/tests/parser_options_test.py b/tests/parser_options_test.py
new file mode 100644
index 00000000..0d328bc0
--- /dev/null
+++ b/tests/parser_options_test.py
@@ -0,0 +1,39 @@
+"""Tests for vLLM-compatible parser option mapping."""
+
+from endpoints.OAI.utils.parser_options import (
+    list_tool_call_parsers,
+    parser_uses_native_tool_generation,
+    resolve_tool_call_parser_key,
+    resolve_tool_call_format,
+)
+
+
+def test_parser_key_registry_contains_core_vllm_keys():
+    parser_keys = list_tool_call_parsers()
+
+    assert "openai" in parser_keys
+    assert "qwen3_coder" in parser_keys
+    assert "qwen3_xml" in parser_keys
+    assert "mistral" in parser_keys
+    assert "deepseek_v3" in parser_keys
+    assert "llama" in parser_keys
+
+
+def test_resolve_tool_call_format_uses_vllm_mapping():
+    assert resolve_tool_call_format("openai", "json") == "json"
+    assert resolve_tool_call_format("qwen3_coder", "json") == "xml"
+    assert resolve_tool_call_format("auto", "json") == "auto"
+    assert resolve_tool_call_format("llama", "json") == "json"
+    assert resolve_tool_call_parser_key("llama") == "llama3_json"
+
+
+def test_resolve_tool_call_format_falls_back_and_rejects_unknown():
+    assert resolve_tool_call_format(None, "json") == "json"
+    assert resolve_tool_call_format("unknown_parser", "json") == ""
+
+
+def test_native_generation_flags_cover_native_syntax_parsers():
+    assert parser_uses_native_tool_generation("qwen3_coder", "json") is True
+    assert parser_uses_native_tool_generation("deepseek_v31", "json") is True
+    assert
parser_uses_native_tool_generation("pythonic", "json") is True
+    assert parser_uses_native_tool_generation("hermes", "json") is False
diff --git a/tests/tool_parser_test.py b/tests/tool_parser_test.py
new file mode 100644
index 00000000..6d8b9372
--- /dev/null
+++ b/tests/tool_parser_test.py
@@ -0,0 +1,211 @@
+"""Tests for tool call parsing helpers."""
+
+import json
+
+from endpoints.OAI.utils.tools import ToolCallProcessor
+
+
+def _arguments_dict(tool_call):
+    return json.loads(tool_call.function.arguments)
+
+
+def test_from_json_handles_markdown_fences_and_flat_shape():
+    payload = """```json
+[{"name": "get_weather", "arguments": {"city": "Seoul"}}]
+```"""
+
+    parsed = ToolCallProcessor.from_json(payload)
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "get_weather"
+    assert _arguments_dict(parsed[0]) == {"city": "Seoul"}
+
+
+def test_from_xml_parses_qwen3_coder_style_blocks():
+    payload = (
+        "internal reasoning"
+        "<tool_call><function=get_weather>"
+        "<parameter=city>\nSeoul\n</parameter>"
+        "<parameter=days>\n3\n</parameter>"
+        "</function></tool_call>"
+    )
+
+    parsed = ToolCallProcessor.from_xml(payload)
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "get_weather"
+    assert _arguments_dict(parsed[0]) == {"city": "Seoul", "days": 3}
+
+
+def test_from_auto_parses_json_inside_tool_call_wrapper():
+    payload = (
+        "<tool_call>"
+        '{"name": "search", "arguments": {"query": "tabbyapi"}}'
+        "</tool_call>"
+    )
+
+    parsed = ToolCallProcessor.from_auto(payload)
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "search"
+    assert _arguments_dict(parsed[0]) == {"query": "tabbyapi"}
+
+
+def test_extract_content_and_tools_splits_content_from_xml_calls():
+    payload = (
+        "I will call a tool now. "
+        "<tool_call><function=search>\n<parameter=query>\ntabby\n</parameter>"
+        "</function></tool_call>"
+        " Done."
+    )
+
+    content, parsed = ToolCallProcessor.extract_content_and_tools(payload)
+
+    assert "I will call a tool now." in content
+    assert "Done."
in content
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "search"
+
+
+def test_filter_by_name_keeps_only_requested_function():
+    payload = (
+        "["
+        '{"name": "a", "arguments": {}},'
+        '{"name": "b", "arguments": {}}'
+        "]"
+    )
+    parsed = ToolCallProcessor.from_json(payload)
+
+    filtered = ToolCallProcessor.filter_by_name(parsed, "b")
+
+    assert len(filtered) == 1
+    assert filtered[0].function.name == "b"
+
+
+def test_parse_with_hermes_parser_handles_wrapped_json():
+    payload = (
+        "<tool_call>"
+        '{"name":"weather","arguments":{"city":"Seoul"}}'
+        "</tool_call>"
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="hermes")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "weather"
+    assert _arguments_dict(parsed[0]) == {"city": "Seoul"}
+
+
+def test_parse_with_llama_parser_handles_sequential_json():
+    payload = (
+        "<|python_tag|>"
+        '{"name":"a","arguments":{"x":1}};'
+        '{"name":"b","arguments":{"y":2}}'
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="llama")
+
+    assert len(parsed) == 2
+    assert parsed[0].function.name == "a"
+    assert _arguments_dict(parsed[0]) == {"x": 1}
+    assert parsed[1].function.name == "b"
+    assert _arguments_dict(parsed[1]) == {"y": 2}
+
+
+def test_parse_with_pythonic_parser_extracts_function_calls():
+    payload = "[get_weather(city='San Francisco', days=3)]"
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="pythonic")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "get_weather"
+    assert _arguments_dict(parsed[0]) == {"city": "San Francisco", "days": 3}
+
+
+def test_parse_with_deepseek_v31_parser():
+    payload = (
+        "<|tool▁calls▁begin|>"
+        '<|tool▁call▁begin|>foo<|tool▁sep|>{"x":1}<|tool▁call▁end|>'
+        "<|tool▁calls▁end|>"
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v31")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "foo"
+    assert _arguments_dict(parsed[0]) == {"x": 1}
+
+
+def
test_parse_with_deepseek_v3_parser():
+    payload = (
+        "<|tool▁calls▁begin|>"
+        "<|tool▁call▁begin|>function<|tool▁sep|>lookup\n"
+        "```json\n"
+        '{"q":"tabbyapi"}'
+        "\n```\n"
+        "<|tool▁call▁end|>"
+        "<|tool▁calls▁end|>"
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v3")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "lookup"
+    assert _arguments_dict(parsed[0]) == {"q": "tabbyapi"}
+
+
+def test_parse_with_deepseek_v32_parser():
+    payload = (
+        "<|DSML|function_calls>"
+        '<|DSML|invoke name="get_weather">'
+        '<|DSML|parameter name="location" string="true">Seoul'
+        '<|DSML|parameter name="days" string="false">3'
+        ""
+        ""
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="deepseek_v32")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "get_weather"
+    assert _arguments_dict(parsed[0]) == {"location": "Seoul", "days": 3}
+
+
+def test_parse_with_openai_parser_handles_functions_recipient():
+    payload = (
+        '[{"recipient":"functions.get_weather","content":"{\\"city\\":\\"Seoul\\"}"}]'
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="openai")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "get_weather"
+    assert _arguments_dict(parsed[0]) == {"city": "Seoul"}
+
+
+def test_parser_key_dispatch_overrides_format_for_qwen3_xml():
+    payload = (
+        "<tool_call><function=search>"
+        "<parameter=q>\ntabby\n</parameter>"
+        "</function></tool_call>"
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="json", parser_key="qwen3_xml")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "search"
+    assert _arguments_dict(parsed[0]) == {"q": "tabby"}
+
+
+def test_parser_failure_falls_back_to_format_parser():
+    payload = (
+        "<tool_call><function=lookup>"
+        "<parameter=id>\n42\n</parameter>"
+        "</function></tool_call>"
+    )
+
+    parsed = ToolCallProcessor.parse(payload, format="xml", parser_key="openai")
+
+    assert len(parsed) == 1
+    assert parsed[0].function.name == "lookup"
+    assert _arguments_dict(parsed[0]) == {"id": 42}
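The XML tests above expect `ToolCallProcessor.from_xml` to turn `<function=...>`/`<parameter=...>` blocks into tool calls whose argument values are coerced to JSON types where possible (`"3"` becomes the integer `3`). As a rough standalone sketch of that behavior — `parse_qwen3_coder_calls` is a hypothetical helper, not the repository's actual implementation — the core extraction can be done with two regexes:

```python
import json
import re

def parse_qwen3_coder_calls(text: str) -> list[dict]:
    """Extract qwen3_coder-style XML tool calls into name/arguments dicts."""
    calls = []
    for func_match in re.finditer(
        r"<function=([^>]+)>(.*?)</function>", text, re.DOTALL
    ):
        name, body = func_match.group(1), func_match.group(2)
        arguments = {}
        for param_match in re.finditer(
            r"<parameter=([^>]+)>\n?(.*?)\n?</parameter>", body, re.DOTALL
        ):
            key, raw = param_match.group(1), param_match.group(2)
            try:
                # Numbers, booleans, and JSON objects parse cleanly...
                arguments[key] = json.loads(raw)
            except json.JSONDecodeError:
                # ...anything else stays a plain string.
                arguments[key] = raw
        calls.append({"name": name, "arguments": arguments})
    return calls


payload = (
    "<tool_call><function=get_weather>"
    "<parameter=city>\nSeoul\n</parameter>"
    "<parameter=days>\n3\n</parameter>"
    "</function></tool_call>"
)
print(parse_qwen3_coder_calls(payload))
# → [{'name': 'get_weather', 'arguments': {'city': 'Seoul', 'days': 3}}]
```

The real `from_xml` also has to handle surrounding free text (as in `test_extract_content_and_tools_splits_content_from_xml_calls`); the non-greedy `<function=...>` match above simply ignores anything outside the blocks, which is one plausible way to get that behavior.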