Skip to content

Commit afe6157

Browse files
authored
[None][fix] Harmony Parser Delta Grouping + Reuse Report + Better Test Coverage (NVIDIA#12467)
Signed-off-by: dongfengy <[email protected]>
1 parent fd7cc85 commit afe6157

File tree

3 files changed

+1296
-16
lines changed

3 files changed

+1296
-16
lines changed

tensorrt_llm/serve/harmony_adapter.py

Lines changed: 95 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,8 @@
2525
ChatCompletionStreamResponse,
2626
ChatCompletionToolsParam, ChatMessage,
2727
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
28-
UsageInfo, to_disaggregated_params)
28+
PromptTokensDetails, UsageInfo,
29+
to_disaggregated_params)
2930

3031
# yapf: enable
3132

@@ -102,8 +103,13 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
102103
"""
103104
Process a batch of tokens while maintaining parsing state.
104105
Returns OpenAI-compatible deltas for this batch.
106+
107+
Consecutive deltas of the same type (e.g., tool call arguments for the
108+
same function, reasoning tokens, content tokens) are merged into a
109+
single delta to reduce SSE overhead and avoid inflating client-side
110+
token counts with repeated JSON wrappers.
105111
"""
106-
deltas = []
112+
raw_deltas = []
107113
self.tokens_processed += len(tokens)
108114

109115
for token in tokens:
@@ -131,7 +137,7 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
131137
# Send closing token for previous channel
132138
closing_delta = self._create_closing_token_delta()
133139
if closing_delta:
134-
deltas.append(closing_delta)
140+
raw_deltas.append(closing_delta)
135141

136142
# Reset channel state for new channel
137143
self.channel_started = False
@@ -141,9 +147,62 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
141147
if self.parser.last_content_delta:
142148
delta = self._create_delta_from_parser_state()
143149
if delta:
144-
deltas.append(delta)
150+
raw_deltas.append(delta)
145151

146-
return deltas
152+
return self._merge_consecutive_deltas(raw_deltas)
153+
154+
@staticmethod
155+
def _merge_consecutive_deltas(
156+
deltas: list[dict[str, Any]]) -> list[dict[str, Any]]:
157+
"""Merge consecutive deltas of the same type to reduce SSE overhead.
158+
159+
For example, 20 consecutive tool_calls deltas with the same tool id
160+
are merged into 1 delta with concatenated arguments.
161+
"""
162+
if len(deltas) <= 1:
163+
return deltas
164+
165+
merged: list[dict[str, Any]] = []
166+
for delta in deltas:
167+
if not merged:
168+
merged.append(delta)
169+
continue
170+
171+
prev = merged[-1]
172+
173+
# Merge consecutive reasoning deltas
174+
if "reasoning" in delta and "reasoning" in prev and len(
175+
delta) == 1 and len(prev) == 1:
176+
prev["reasoning"] += delta["reasoning"]
177+
continue
178+
179+
# Merge consecutive content deltas (both must have same keys)
180+
if ("content" in delta and "content" in prev
181+
and delta.keys() == prev.keys()):
182+
prev["content"] += delta["content"]
183+
continue
184+
185+
# Merge consecutive tool_calls deltas for the same tool call
186+
if ("tool_calls" in delta and "tool_calls" in prev
187+
and "content" not in delta and "content" not in prev
188+
and "reasoning" not in delta and "reasoning" not in prev):
189+
prev_tc = prev["tool_calls"]
190+
curr_tc = delta["tool_calls"]
191+
# Both have exactly 1 tool call with the same id
192+
if (len(prev_tc) == 1 and len(curr_tc) == 1
193+
and prev_tc[0].get("id") == curr_tc[0].get("id")):
194+
# Concatenate arguments
195+
prev_args = prev_tc[0].get("function",
196+
{}).get("arguments", "")
197+
curr_args = curr_tc[0].get("function",
198+
{}).get("arguments", "")
199+
prev_tc[0].setdefault(
200+
"function", {})["arguments"] = prev_args + curr_args
201+
continue
202+
203+
merged.append(delta)
204+
205+
return merged
147206

148207
def process_token_batch_to_messages(self,
149208
tokens: list[int]) -> list[Message]:
@@ -1573,10 +1632,15 @@ def get_harmony_adapter() -> HarmonyAdapter:
15731632

15741633

15751634
def handle_streaming_response(tools: List[ChatCompletionToolsParam],
1576-
tool_choice: str, result: GenerationResult,
1577-
model: str, request_id: str, done: bool,
1635+
tool_choice: str,
1636+
result: GenerationResult,
1637+
model: str,
1638+
request_id: str,
1639+
done: bool,
15781640
num_prompt_tokens: int,
1579-
first_iteration: bool) -> List[str]:
1641+
first_iteration: bool,
1642+
stream_options=None,
1643+
cached_tokens: int = 0) -> List[str]:
15801644
output = result.outputs[0]
15811645

15821646
# Convert tools to dictionary format for harmony adapter (standard pattern)
@@ -1591,12 +1655,20 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
15911655
else:
15921656
tools_for_parser = tools_dict
15931657

1658+
include_usage = True
1659+
if stream_options is not None:
1660+
include_usage = stream_options.include_usage
1661+
15941662
def end_streaming(res):
15951663
# Clean up state
15961664
harmony_adapter.cleanup_stream_state(request_id)
15971665

1666+
if not include_usage:
1667+
return
1668+
15981669
# Append usage info
1599-
usage_info = _create_usage_info(num_prompt_tokens, result.outputs)
1670+
usage_info = _create_usage_info(num_prompt_tokens, result.outputs,
1671+
cached_tokens)
16001672

16011673
final_usage_chunk = ChatCompletionStreamResponse(choices=[],
16021674
model=model,
@@ -1688,8 +1760,11 @@ def end_streaming(res):
16881760

16891761

16901762
def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
1691-
tool_choice: str, outputs: List, model: str,
1692-
num_prompt_tokens: int):
1763+
tool_choice: str,
1764+
outputs: List,
1765+
model: str,
1766+
num_prompt_tokens: int,
1767+
cached_tokens: int = 0):
16931768
"""Handle non-streaming response with harmony format."""
16941769
# Parse harmony output to OpenAI format
16951770
# Convert tools to dictionary format for harmony adapter (standard pattern)
@@ -1736,7 +1811,7 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
17361811
response_message = {"role": "assistant", "content": ""}
17371812

17381813
# Create usage info from metrics (RequestOutput doesn't have usage in v1)
1739-
usage_info = _create_usage_info(num_prompt_tokens, outputs)
1814+
usage_info = _create_usage_info(num_prompt_tokens, outputs, cached_tokens)
17401815

17411816
# Create response
17421817
response = ChatCompletionResponse(
@@ -1783,15 +1858,19 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
17831858
return reason
17841859

17851860

1786-
def _create_usage_info(num_prompt_tokens, outputs) -> UsageInfo:
1861+
def _create_usage_info(num_prompt_tokens,
1862+
outputs,
1863+
cached_tokens: int = 0) -> UsageInfo:
17871864
"""Create usage info from RequestOutput following serving_chat.py pattern."""
17881865
# Calculate completion tokens from all outputs
17891866
num_generated_tokens = sum(len(output.token_ids) for output in outputs)
17901867

17911868
# Create usage info
1792-
usage = UsageInfo(prompt_tokens=num_prompt_tokens,
1793-
completion_tokens=num_generated_tokens,
1794-
total_tokens=num_prompt_tokens + num_generated_tokens)
1869+
usage = UsageInfo(
1870+
prompt_tokens=num_prompt_tokens,
1871+
completion_tokens=num_generated_tokens,
1872+
total_tokens=num_prompt_tokens + num_generated_tokens,
1873+
prompt_tokens_details=PromptTokensDetails(cached_tokens=cached_tokens))
17951874
return usage
17961875

17971876

tensorrt_llm/serve/postprocess_handlers.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -571,6 +571,7 @@ class ChatCompletionPostprocArgs(PostprocArgs):
571571
tool_choice: Optional[Union[Literal["none", "auto"],
572572
ChatCompletionNamedToolChoiceParam]]
573573
request_id: Optional[int] = None
574+
stream_options: Optional[StreamOptions] = None
574575
chat_template_kwargs: Optional[dict[str, Any]] = None
575576

576577
@classmethod
@@ -579,6 +580,7 @@ def from_request(cls, request: ChatCompletionRequest):
579580
model=request.model,
580581
tools=request.tools,
581582
tool_choice=request.tool_choice,
583+
stream_options=request.stream_options if request.stream else None,
582584
chat_template_kwargs=request.chat_template_kwargs,
583585
)
584586

@@ -593,6 +595,7 @@ def chat_harmony_post_processor(
593595
outputs=rsp.outputs,
594596
model=args.model,
595597
num_prompt_tokens=args.num_prompt_tokens,
598+
cached_tokens=rsp.cached_tokens,
596599
)
597600
return response
598601

@@ -614,6 +617,8 @@ def chat_harmony_streaming_post_processor(
614617
done=rsp._done,
615618
num_prompt_tokens=args.num_prompt_tokens,
616619
first_iteration=args.first_iteration,
620+
stream_options=args.stream_options,
621+
cached_tokens=rsp.cached_tokens,
617622
)
618623
args.first_iteration = False
619624
return response

0 commit comments

Comments
 (0)