Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 95 additions & 16 deletions tensorrt_llm/serve/harmony_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
ChatCompletionStreamResponse,
ChatCompletionToolsParam, ChatMessage,
DeltaFunctionCall, DeltaMessage, DeltaToolCall,
UsageInfo, to_disaggregated_params)
PromptTokensDetails, UsageInfo,
to_disaggregated_params)

# yapf: enable

Expand Down Expand Up @@ -102,8 +103,13 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
"""
Process a batch of tokens while maintaining parsing state.
Returns OpenAI-compatible deltas for this batch.

Consecutive deltas of the same type (e.g., tool call arguments for the
same function, reasoning tokens, content tokens) are merged into a
single delta to reduce SSE overhead and avoid inflating client-side
token counts with repeated JSON wrappers.
"""
deltas = []
raw_deltas = []
self.tokens_processed += len(tokens)

for token in tokens:
Expand Down Expand Up @@ -131,7 +137,7 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
# Send closing token for previous channel
closing_delta = self._create_closing_token_delta()
if closing_delta:
deltas.append(closing_delta)
raw_deltas.append(closing_delta)

# Reset channel state for new channel
self.channel_started = False
Expand All @@ -141,9 +147,62 @@ def process_token_batch(self, tokens: list[int]) -> list[dict[str, Any]]:
if self.parser.last_content_delta:
delta = self._create_delta_from_parser_state()
if delta:
deltas.append(delta)
raw_deltas.append(delta)

return deltas
return self._merge_consecutive_deltas(raw_deltas)

@staticmethod
def _merge_consecutive_deltas(
deltas: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Merge consecutive deltas of the same type to reduce SSE overhead.

For example, 20 consecutive tool_calls deltas with the same tool id
are merged into 1 delta with concatenated arguments.
"""
if len(deltas) <= 1:
return deltas

merged: list[dict[str, Any]] = []
for delta in deltas:
if not merged:
merged.append(delta)
continue

prev = merged[-1]

# Merge consecutive reasoning deltas
if "reasoning" in delta and "reasoning" in prev and len(
delta) == 1 and len(prev) == 1:
prev["reasoning"] += delta["reasoning"]
continue

# Merge consecutive content deltas (both must have same keys)
if ("content" in delta and "content" in prev
and delta.keys() == prev.keys()):
prev["content"] += delta["content"]
continue

# Merge consecutive tool_calls deltas for the same tool call
if ("tool_calls" in delta and "tool_calls" in prev
and "content" not in delta and "content" not in prev
and "reasoning" not in delta and "reasoning" not in prev):
prev_tc = prev["tool_calls"]
curr_tc = delta["tool_calls"]
# Both have exactly 1 tool call with the same id
if (len(prev_tc) == 1 and len(curr_tc) == 1
and prev_tc[0].get("id") == curr_tc[0].get("id")):
# Concatenate arguments
prev_args = prev_tc[0].get("function",
{}).get("arguments", "")
curr_args = curr_tc[0].get("function",
{}).get("arguments", "")
prev_tc[0].setdefault(
"function", {})["arguments"] = prev_args + curr_args
continue

merged.append(delta)

return merged

def process_token_batch_to_messages(self,
tokens: list[int]) -> list[Message]:
Expand Down Expand Up @@ -1573,10 +1632,15 @@ def get_harmony_adapter() -> HarmonyAdapter:


def handle_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, result: GenerationResult,
model: str, request_id: str, done: bool,
tool_choice: str,
result: GenerationResult,
model: str,
request_id: str,
done: bool,
num_prompt_tokens: int,
first_iteration: bool) -> List[str]:
first_iteration: bool,
stream_options=None,
cached_tokens: int = 0) -> List[str]:
output = result.outputs[0]

# Convert tools to dictionary format for harmony adapter (standard pattern)
Expand All @@ -1591,12 +1655,20 @@ def handle_streaming_response(tools: List[ChatCompletionToolsParam],
else:
tools_for_parser = tools_dict

include_usage = True
if stream_options is not None:
include_usage = stream_options.include_usage

def end_streaming(res):
# Clean up state
harmony_adapter.cleanup_stream_state(request_id)

if not include_usage:
return

# Append usage info
usage_info = _create_usage_info(num_prompt_tokens, result.outputs)
usage_info = _create_usage_info(num_prompt_tokens, result.outputs,
cached_tokens)

final_usage_chunk = ChatCompletionStreamResponse(choices=[],
model=model,
Expand Down Expand Up @@ -1688,8 +1760,11 @@ def end_streaming(res):


def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
tool_choice: str, outputs: List, model: str,
num_prompt_tokens: int):
tool_choice: str,
outputs: List,
model: str,
num_prompt_tokens: int,
cached_tokens: int = 0):
"""Handle non-streaming response with harmony format."""
# Parse harmony output to OpenAI format
# Convert tools to dictionary format for harmony adapter (standard pattern)
Expand Down Expand Up @@ -1736,7 +1811,7 @@ def handle_non_streaming_response(tools: List[ChatCompletionToolsParam],
response_message = {"role": "assistant", "content": ""}

# Create usage info from metrics (RequestOutput doesn't have usage in v1)
usage_info = _create_usage_info(num_prompt_tokens, outputs)
usage_info = _create_usage_info(num_prompt_tokens, outputs, cached_tokens)

# Create response
response = ChatCompletionResponse(
Expand Down Expand Up @@ -1783,15 +1858,19 @@ def _determine_finish_reason(parsed_output: dict[str, Any],
return reason


def _create_usage_info(num_prompt_tokens,
                       outputs,
                       cached_tokens: int = 0) -> UsageInfo:
    """Create usage info from RequestOutput following serving_chat.py pattern.

    Args:
        num_prompt_tokens: Number of tokens in the request prompt.
        outputs: Generation outputs; each element must expose ``token_ids``
            (the tokens generated for that output).
        cached_tokens: Prompt tokens served from cache, surfaced through
            ``prompt_tokens_details`` (defaults to 0 so existing callers
            that omit it are unaffected).

    Returns:
        A populated ``UsageInfo``.
    """
    # Completion tokens are summed over all outputs (covers n > 1 sampling).
    num_generated_tokens = sum(len(output.token_ids) for output in outputs)

    return UsageInfo(
        prompt_tokens=num_prompt_tokens,
        completion_tokens=num_generated_tokens,
        total_tokens=num_prompt_tokens + num_generated_tokens,
        prompt_tokens_details=PromptTokensDetails(cached_tokens=cached_tokens))


Expand Down
5 changes: 5 additions & 0 deletions tensorrt_llm/serve/postprocess_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,7 @@ class ChatCompletionPostprocArgs(PostprocArgs):
tool_choice: Optional[Union[Literal["none", "auto"],
ChatCompletionNamedToolChoiceParam]]
request_id: Optional[int] = None
stream_options: Optional[StreamOptions] = None
chat_template_kwargs: Optional[dict[str, Any]] = None

@classmethod
Expand All @@ -579,6 +580,7 @@ def from_request(cls, request: ChatCompletionRequest):
model=request.model,
tools=request.tools,
tool_choice=request.tool_choice,
stream_options=request.stream_options if request.stream else None,
chat_template_kwargs=request.chat_template_kwargs,
)

Expand All @@ -593,6 +595,7 @@ def chat_harmony_post_processor(
outputs=rsp.outputs,
model=args.model,
num_prompt_tokens=args.num_prompt_tokens,
cached_tokens=rsp.cached_tokens,
)
return response

Expand All @@ -614,6 +617,8 @@ def chat_harmony_streaming_post_processor(
done=rsp._done,
num_prompt_tokens=args.num_prompt_tokens,
first_iteration=args.first_iteration,
stream_options=args.stream_options,
cached_tokens=rsp.cached_tokens,
)
args.first_iteration = False
return response
Expand Down
Loading
Loading