Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions tensorrt_llm/llmapi/reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def parse(self, text: str) -> ReasoningParserResult:
def parse_delta(self, delta_text: str) -> ReasoningParserResult:
raise NotImplementedError

def finish(self) -> ReasoningParserResult:
    """Hook invoked once the token stream is exhausted.

    Subclasses may override this to flush buffered state or to
    reclassify accumulated content at end of stream. The base
    implementation has nothing pending, so it returns an empty
    :class:`ReasoningParserResult`.
    """
    return ReasoningParserResult()


@register_reasoning_parser("deepseek-r1", reasoning_at_start=True)
@register_reasoning_parser("qwen3")
Expand Down Expand Up @@ -185,6 +191,12 @@ def __init__(self,
"force_nonempty_content", False) is True
super().__init__(reasoning_at_start=reasoning_at_start,
chat_template_kwargs=chat_template_kwargs)
# Workaround: the model sometimes does not send closing think tags
# which affects downstream applications. This is addressed by
# optionally accumulating reasoning tokens and returning them as
# content at the end of streaming.
self._accumulated_reasoning = ""
self._found_closing_tag = False

def _maybe_swap_content(
self, result: ReasoningParserResult) -> ReasoningParserResult:
Expand All @@ -195,5 +207,50 @@ def _maybe_swap_content(
reasoning_content="")
return result

def parse_delta(self, delta_text: str) -> ReasoningParserResult:
    """Delegate to the parent ``parse_delta`` while tracking reasoning.

    When ``force_nonempty_content`` is enabled, every emitted piece of
    ``reasoning_content`` is remembered so that ``finish()`` can return
    it as plain content if the stream ends without a closing think tag.
    As soon as the closing tag is observed (``in_reasoning`` flips from
    True to False) the accumulator is dropped to free memory.
    """
    previously_reasoning = self.in_reasoning
    result = super().parse_delta(delta_text)
    if not self._force_nonempty_content:
        # Tracking is only needed for the workaround; pass through.
        return result
    if result.reasoning_content:
        self._accumulated_reasoning += result.reasoning_content
    closed_now = previously_reasoning and not self.in_reasoning
    if closed_now:
        # Closing tag arrived; accumulated reasoning is no longer needed.
        self._found_closing_tag = True
        self._accumulated_reasoning = ""
    return result

def finish(self) -> ReasoningParserResult:
    """Flush leftover parser state when the stream ends.

    Three outcomes are possible:

    * Still inside reasoning, no closing tag seen, and
      ``force_nonempty_content`` is set: the full accumulated
      reasoning (plus any buffered tail) is returned as ``content``
      so the response is never empty.
    * Still inside reasoning without ``force_nonempty_content``: any
      buffered tail is returned as ``reasoning_content``, since we
      never left reasoning mode.
    * Otherwise (closing tag already found, or reasoning was never
      entered): any buffered tail is returned as ``content``.
    """
    leftover = self._buffer
    self._buffer = ""
    unterminated = self.in_reasoning and not self._found_closing_tag
    if not unterminated:
        # Reasoning section already closed (or never opened): the
        # buffered tail belongs to the regular content section.
        if leftover:
            return ReasoningParserResult(content=leftover)
        return ReasoningParserResult()
    # No closing think tag ever arrived.
    self.in_reasoning = False
    if self._force_nonempty_content:
        everything = self._accumulated_reasoning + leftover
        self._accumulated_reasoning = ""
        return ReasoningParserResult(content=everything)
    self._accumulated_reasoning = ""
    if leftover:
        return ReasoningParserResult(reasoning_content=leftover)
    return ReasoningParserResult()

def parse(self, text: str) -> ReasoningParserResult:
    """Parse the complete (non-streaming) text, then apply the
    content-swap workaround when it is configured."""
    parsed = super().parse(text)
    return self._maybe_swap_content(parsed)
23 changes: 19 additions & 4 deletions tensorrt_llm/serve/postprocess_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from ..executor.result import Logprob, TokenLogprobs
from ..llmapi import SamplingParams
from ..llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
ReasoningParserFactory,
ReasoningParserResult)
from ..llmapi.tokenizer import TransformersTokenizer
# yapf: disable
from .chat_utils import make_tool_call_id
Expand Down Expand Up @@ -111,8 +112,11 @@ def create_logprobs(token_ids: List[int], tokenizer: TransformersTokenizer,
return chat_logprobs


def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
streaming: bool) -> Tuple[str, str]:
def apply_reasoning_parser(args: ChatPostprocArgs,
output_index: int,
text: str,
streaming: bool,
finished: bool = False) -> Tuple[str, str]:
reasoning_parser = None
if args.reasoning_parser is not None:
if output_index not in args.reasoning_parser_dict:
Expand All @@ -127,6 +131,13 @@ def apply_reasoning_parser(args: ChatPostprocArgs, output_index: int, text: str,
result = reasoning_parser.parse(text)
else:
result = reasoning_parser.parse_delta(text)
if finished:
finish_result = reasoning_parser.finish()
result = ReasoningParserResult(
content=result.content + finish_result.content,
reasoning_content=result.reasoning_content +
finish_result.reasoning_content,
)
content, reasoning_content = result.content, result.reasoning_content
else:
content, reasoning_content = text, ""
Expand Down Expand Up @@ -214,7 +225,11 @@ def yield_first_chat(num_tokens: int,
delta_text = output.text_diff

delta_text, reasoning_delta_text = apply_reasoning_parser(
args, i, delta_text, True)
args,
i,
delta_text,
True,
finished=(output.finish_reason is not None))

if args.tool_choice and type(
args.tool_choice) is ChatCompletionNamedToolChoiceParam:
Expand Down
51 changes: 50 additions & 1 deletion tensorrt_llm/serve/responses_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@
from tensorrt_llm.llmapi import SamplingParams
from tensorrt_llm.llmapi.llm import RequestOutput
from tensorrt_llm.llmapi.reasoning_parser import (BaseReasoningParser,
ReasoningParserFactory)
ReasoningParserFactory,
ReasoningParserResult)
from tensorrt_llm.llmapi.tokenizer import TokenizerBase, TransformersTokenizer
from tensorrt_llm.logger import logger
from tensorrt_llm.serve.chat_utils import parse_chat_messages_coroutines
Expand Down Expand Up @@ -927,6 +928,7 @@ def _apply_reasoning_parser(
text: str,
streaming: bool,
reasoning_parser_dict: Optional[dict[int, BaseReasoningParser]] = None,
finished: bool = False,
) -> Tuple[str, str]:
reasoning_parser: Optional[BaseReasoningParser] = None
if reasoning_parser_id is not None:
Expand All @@ -946,6 +948,13 @@ def _apply_reasoning_parser(
result = reasoning_parser.parse(text)
else:
result = reasoning_parser.parse_delta(text)
if finished:
finish_result = reasoning_parser.finish()
result = ReasoningParserResult(
content=result.content + finish_result.content,
reasoning_content=result.reasoning_content +
finish_result.reasoning_content,
)
content, reasoning_content = result.content, result.reasoning_content
else:
content, reasoning_content = text, ""
Expand Down Expand Up @@ -1490,6 +1499,14 @@ def _should_send_done_events(
should_send_reasoning_done = True
reasoning_content = full_reasoning

# No closing tag: reasoning was streamed but re-parse shows everything as
# content (no </think> found). Close the reasoning section so the text
# section can be properly opened and closed.
if not full_reasoning and full_text and finished_generation:
if streaming_events_helper and streaming_events_helper.is_reasoning_sent:
should_send_reasoning_done = True
reasoning_content = full_text

return should_send_reasoning_done, should_send_text_done, reasoning_content, text_content


Expand Down Expand Up @@ -1525,6 +1542,7 @@ def check_parser(parser_id: Optional[str],
text=delta_text,
streaming=True,
reasoning_parser_dict=reasoning_parser_dict,
finished=finished_generation,
)

if delta_text:
Expand Down Expand Up @@ -1595,6 +1613,37 @@ def check_parser(parser_id: Optional[str],
streaming_events_helper.is_output_item_added_sent = False
streaming_events_helper.is_text_sent = False

# Handle no-closing-tag case: reasoning was streamed but finish() moved
# all accumulated reasoning to content. Emit the full text section
# lifecycle (added → delta → done) since the reasoning section was just
# closed and generation is finished.
if (finished_generation and delta_text and should_send_reasoning_done
and not should_send_text_done):
streaming_events_helper.is_text_sent = True
yield from streaming_events_helper.get_message_output_added_events()
yield streaming_events_helper.get_text_delta_event(delta_text, [])
text_content_obj = ResponseOutputText(
text=delta_text,
annotations=[],
type="output_text",
logprobs=None,
)
text_item = ResponseOutputMessage(
id=streaming_events_helper.item_id,
content=[text_content_obj],
role="assistant",
status="completed",
type="message",
)
yield streaming_events_helper.get_text_done_event(delta_text, [])
yield streaming_events_helper.get_content_part_done_event(
text_content_obj)
yield streaming_events_helper.get_output_item_done_event(text_item)
streaming_events_helper.output_index_increment()
streaming_events_helper.is_output_item_added_sent = False
streaming_events_helper.is_text_sent = False
delta_text = ""

# Send delta events for ongoing content
if delta_text:
if delta_text.strip():
Expand Down
33 changes: 33 additions & 0 deletions tests/unittest/llmapi/test_reasoning_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,3 +160,36 @@ def test_nano_v3_reasoning_parser_stream(delta_texts: list, content: list,
print(f"delta_text: {delta_text}, result: {result}")
assert result.content == content[i]
assert result.reasoning_content == reasoning_context[i]


# Each case lists: the streamed deltas, the expected finish().content,
# the expected finish().reasoning_content, and the chat_template_kwargs
# used to construct the parser.
@pytest.mark.parametrize(
    ("delta_texts", "finish_content", "finish_reasoning",
     "chat_template_kwargs"),
    [
        # No closing tag, default kwargs: finish() has nothing extra to
        # flush (deltas were presumably already emitted during streaming).
        (["a", "b"], "", "", None),
        # Closing tag seen (first or mid-stream): finish() is empty.
        ([R1_END, "a", "b"], "", "", None),
        (["a", R1_END, "b"], "", "", None),
        # Thinking disabled: no reasoning section, nothing to flush.
        (["a", "b"], "", "", {
            "enable_thinking": False
        }),
        ([f"{R1_START}a", "b"], "", "", {
            "enable_thinking": False
        }),
        # force_nonempty_content explicitly False behaves like default.
        (["a", "b"], "", "", {
            "force_nonempty_content": False
        }),
        # force_nonempty_content True and no closing tag: finish()
        # returns all accumulated reasoning as content ("ab").
        (["a", "b"], "ab", "", {
            "force_nonempty_content": True
        }),
        # force_nonempty_content True but the closing tag was seen:
        # the accumulator was cleared, so finish() is empty.
        ([R1_END, "a", "b"], "", "", {
            "force_nonempty_content": True
        }),
    ])
def test_nano_v3_reasoning_parser_finish(delta_texts: list, finish_content: str,
                                         finish_reasoning: str,
                                         chat_template_kwargs: dict):
    """Verify the nano-v3 parser's end-of-stream finish() flushing."""
    reasoning_parser = ReasoningParserFactory.create_reasoning_parser(
        "nano-v3", chat_template_kwargs)
    # Stream every delta first; finish() is only meaningful afterwards.
    for delta_text in delta_texts:
        reasoning_parser.parse_delta(delta_text)
    result = reasoning_parser.finish()
    assert result.content == finish_content
    assert result.reasoning_content == finish_reasoning
Loading