diff --git a/tensorrt_llm/llmapi/reasoning_parser.py b/tensorrt_llm/llmapi/reasoning_parser.py index eb10011d16e..1c8d77ff09b 100644 --- a/tensorrt_llm/llmapi/reasoning_parser.py +++ b/tensorrt_llm/llmapi/reasoning_parser.py @@ -69,6 +69,12 @@ def parse(self, text: str) -> ReasoningParserResult: def parse_delta(self, delta_text: str) -> ReasoningParserResult: raise NotImplementedError + def finish(self) -> ReasoningParserResult: + """Called when the stream ends. Subclasses may override to flush + buffered state or reclassify accumulated content. The default + implementation returns an empty result.""" + return ReasoningParserResult() + @register_reasoning_parser("deepseek-r1", reasoning_at_start=True) @register_reasoning_parser("qwen3") @@ -185,6 +191,12 @@ def __init__(self, "force_nonempty_content", False) is True super().__init__(reasoning_at_start=reasoning_at_start, chat_template_kwargs=chat_template_kwargs) + # Workaround: the model sometimes does not send closing think tags + # which affects downstream applications. This is addressed by + # optionally accumulating reasoning tokens and returning them as + # content at the end of streaming. + self._accumulated_reasoning = "" + self._found_closing_tag = False def _maybe_swap_content( self, result: ReasoningParserResult) -> ReasoningParserResult: @@ -195,5 +207,50 @@ def _maybe_swap_content( reasoning_content="") return result + def parse_delta(self, delta_text: str) -> ReasoningParserResult: + """Wraps the parent parse_delta to track accumulated reasoning when + force_nonempty_content is set. 
When the closing tag is found + (in_reasoning transitions from True to False), the accumulation + is cleared to free memory.""" + was_in_reasoning = self.in_reasoning + result = super().parse_delta(delta_text) + if self._force_nonempty_content: + if result.reasoning_content: + self._accumulated_reasoning += result.reasoning_content + if was_in_reasoning and not self.in_reasoning: + self._found_closing_tag = True + self._accumulated_reasoning = "" + return result + + def finish(self) -> ReasoningParserResult: + """Called when the stream ends. + + If no closing think tag was found and force_nonempty_content is + set, returns the full accumulated reasoning as content so the + response is never empty. If no closing tag was found and + force_nonempty_content is not set, returns any remaining buffer + as reasoning_content since we are still in reasoning mode. + + If the closing tag was already found (or reasoning was never + entered), flushes any remaining buffer as content.""" + if self.in_reasoning and not self._found_closing_tag: + remaining = self._buffer + self._buffer = "" + if self._force_nonempty_content: + all_content = self._accumulated_reasoning + remaining + self._accumulated_reasoning = "" + self.in_reasoning = False + return ReasoningParserResult(content=all_content) + self._accumulated_reasoning = "" + self.in_reasoning = False + if remaining: + return ReasoningParserResult(reasoning_content=remaining) + return ReasoningParserResult() + remaining = self._buffer + self._buffer = "" + if remaining: + return ReasoningParserResult(content=remaining) + return ReasoningParserResult() + def parse(self, text: str) -> ReasoningParserResult: return self._maybe_swap_content(super().parse(text)) diff --git a/tensorrt_llm/serve/postprocess_handlers.py b/tensorrt_llm/serve/postprocess_handlers.py index bacce813b6b..2b1c83e91d5 100644 --- a/tensorrt_llm/serve/postprocess_handlers.py +++ b/tensorrt_llm/serve/postprocess_handlers.py @@ -12,7 +12,8 @@ from ..executor.result 
import Logprob, TokenLogprobs from ..llmapi import SamplingParams from ..llmapi.reasoning_parser import (BaseReasoningParser, - ReasoningParserFactory) + ReasoningParserFactory, + ReasoningParserResult) from ..llmapi.tokenizer import TransformersTokenizer # yapf: disable from .chat_utils import make_tool_call_id @@ -111,8 +112,11 @@ def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer, return chat_logprobs -def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, - streaming: bool) -> Tuple[str, str]: +def apply_reasoning_parser(args: ChatPostprocArgs, + output_index: int, + text: str, + streaming: bool, + finished: bool = False) -> Tuple[str, str]: reasoning_parser = None if args.reasoning_parser is not None: if output_index not in args.reasoning_parser_dict: @@ -127,6 +131,13 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str, result = reasoning_parser.parse(text) else: result = reasoning_parser.parse_delta(text) + if finished: + finish_result = reasoning_parser.finish() + result = ReasoningParserResult( + content=result.content + finish_result.content, + reasoning_content=result.reasoning_content + + finish_result.reasoning_content, + ) content, reasoning_content = result.content, result.reasoning_content else: content, reasoning_content = text, "" @@ -214,7 +225,11 @@ def yield_first_chat(num_tokens: int, delta_text = output.text_diff delta_text, reasoning_delta_text = apply_reasoning_parser( - args, i, delta_text, True) + args, + i, + delta_text, + True, + finished=(output.finish_reason is not None)) if args.tool_choice and type( args.tool_choice) is ChatCompletionNamedToolChoiceParam: diff --git a/tensorrt_llm/serve/responses_utils.py b/tensorrt_llm/serve/responses_utils.py index 35a999ac9a5..104228d4df3 100644 --- a/tensorrt_llm/serve/responses_utils.py +++ b/tensorrt_llm/serve/responses_utils.py @@ -46,7 +46,8 @@ from tensorrt_llm.llmapi import SamplingParams from 
tensorrt_llm.llmapi.llm import RequestOutput from tensorrt_llm.llmapi.reasoning_parser import (BaseReasoningParser, - ReasoningParserFactory) + ReasoningParserFactory, + ReasoningParserResult) from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer from tensorrt_llm.logger import logger from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines @@ -927,6 +928,7 @@ def _apply_reasoning_parser( text: str, streaming: bool, reasoning_parser_dict: Optional[dict[int, BaseReasoningParser]] = None, + finished: bool = False, ) -> Tuple[str, str]: reasoning_parser: Optional[BaseReasoningParser] = None if reasoning_parser_id is not None: @@ -946,6 +948,13 @@ def _apply_reasoning_parser( result = reasoning_parser.parse(text) else: result = reasoning_parser.parse_delta(text) + if finished: + finish_result = reasoning_parser.finish() + result = ReasoningParserResult( + content=result.content + finish_result.content, + reasoning_content=result.reasoning_content + + finish_result.reasoning_content, + ) content, reasoning_content = result.content, result.reasoning_content else: content, reasoning_content = text, "" @@ -1490,6 +1499,14 @@ def _should_send_done_events( should_send_reasoning_done = True reasoning_content = full_reasoning + # No closing tag: reasoning was streamed but re-parse shows everything as + # content (no closing tag found). Close the reasoning section so the text + # section can be properly opened and closed. 
+ if not full_reasoning and full_text and finished_generation: + if streaming_events_helper and streaming_events_helper.is_reasoning_sent: + should_send_reasoning_done = True + reasoning_content = full_text + return should_send_reasoning_done, should_send_text_done, reasoning_content, text_content @@ -1525,6 +1542,7 @@ def check_parser(parser_id: Optional[str], text=delta_text, streaming=True, reasoning_parser_dict=reasoning_parser_dict, + finished=finished_generation, ) if delta_text: @@ -1595,6 +1613,37 @@ def check_parser(parser_id: Optional[str], streaming_events_helper.is_output_item_added_sent = False streaming_events_helper.is_text_sent = False + # Handle no-closing-tag case: reasoning was streamed but finish() moved + # all accumulated reasoning to content. Emit the full text section + # lifecycle (added → delta → done) since the reasoning section was just + # closed and generation is finished. + if (finished_generation and delta_text and should_send_reasoning_done + and not should_send_text_done): + streaming_events_helper.is_text_sent = True + yield from streaming_events_helper.get_message_output_added_events() + yield streaming_events_helper.get_text_delta_event(delta_text, []) + text_content_obj = ResponseOutputText( + text=delta_text, + annotations=[], + type="output_text", + logprobs=None, + ) + text_item = ResponseOutputMessage( + id=streaming_events_helper.item_id, + content=[text_content_obj], + role="assistant", + status="completed", + type="message", + ) + yield streaming_events_helper.get_text_done_event(delta_text, []) + yield streaming_events_helper.get_content_part_done_event( + text_content_obj) + yield streaming_events_helper.get_output_item_done_event(text_item) + streaming_events_helper.output_index_increment() + streaming_events_helper.is_output_item_added_sent = False + streaming_events_helper.is_text_sent = False + delta_text = "" + # Send delta events for ongoing content if delta_text: if delta_text.strip(): diff --git 
a/tests/unittest/llmapi/test_reasoning_parser.py b/tests/unittest/llmapi/test_reasoning_parser.py index 23517097eea..0c7639b74dd 100644 --- a/tests/unittest/llmapi/test_reasoning_parser.py +++ b/tests/unittest/llmapi/test_reasoning_parser.py @@ -160,3 +160,36 @@ def test_nano_v3_reasoning_parser_stream(delta_texts: list, content: list, print(f"delta_text: {delta_text}, result: {result}") assert result.content == content[i] assert result.reasoning_content == reasoning_context[i] + + +@pytest.mark.parametrize(("delta_texts", "finish_content", "finish_reasoning", + "chat_template_kwargs"), [ + (["a", "b"], "", "", None), + ([R1_END, "a", "b"], "", "", None), + (["a", R1_END, "b"], "", "", None), + (["a", "b"], "", "", { + "enable_thinking": False + }), + ([f"{R1_START}a", "b"], "", "", { + "enable_thinking": False + }), + (["a", "b"], "", "", { + "force_nonempty_content": False + }), + (["a", "b"], "ab", "", { + "force_nonempty_content": True + }), + ([R1_END, "a", "b"], "", "", { + "force_nonempty_content": True + }), + ]) +def test_nano_v3_reasoning_parser_finish(delta_texts: list, finish_content: str, + finish_reasoning: str, + chat_template_kwargs: dict): + reasoning_parser = ReasoningParserFactory.create_reasoning_parser( + "nano-v3", chat_template_kwargs) + for delta_text in delta_texts: + reasoning_parser.parse_delta(delta_text) + result = reasoning_parser.finish() + assert result.content == finish_content + assert result.reasoning_content == finish_reasoning