2 changes: 2 additions & 0 deletions LLM/language_model.py
@@ -393,6 +393,8 @@ def process(self, prompt: str | tuple) -> Iterator[tuple]:
yield ("__END_OF_RESPONSE__", None, None)

def on_session_end(self) -> None:
# reset() also clears init_chat_message, so a previous session's
# instructions cannot persist into the next one.
self.chat.reset()
self._last_instructions = None
self.tools = None
132 changes: 99 additions & 33 deletions LLM/openai_api_language_model.py
@@ -12,6 +12,11 @@
from LLM.chat import Chat
from LLM.utils import remove_unspeechable
from api.openai_realtime.runtime_config import RuntimeConfig
from LLM.tool_call.qwen3coder_tool_parser import (
Qwen3CoderToolParser,
process_printable_text_qwen_xml,
strip_qwen_tool_markup_for_chat,
)
from LLM.voice_prompt import build_voice_system_prompt

logger = logging.getLogger(__name__)
@@ -27,6 +32,26 @@
"ko": "korean",
}

def _vllm_normalize_list_part(part: dict) -> dict:
"""
Normalize input_image part to detail="auto"
"""
t = part.get("type")
if t == "input_image":
part["detail"] = "auto"
return part


def _vllm_normalize_content(content: list | str) -> list | str:
"""Normalize chat message content for vLLM's strict Responses API request validators."""
if isinstance(content, list):
return [
_vllm_normalize_list_part(p) if isinstance(p, dict) else p
for p in content
]
else:
return content
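# Example (hypothetical content shape): a vision turn such as
#   [{"type": "input_text", "text": "describe this"},
#    {"type": "input_image", "image_url": "https://example.com/cat.png"}]
# comes back with detail="auto" added to the input_image part; plain string
# content is passed through unchanged.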


class OpenApiModelHandler(BaseHandler):
"""
Expand Down Expand Up @@ -114,7 +139,7 @@ def process(self, prompt):
# Generation is deferred until __GENERATE_RESPONSE__ (from response.create).
if isinstance(prompt, tuple) and len(prompt) == 3 and prompt[0] == "__ADD_TO_CONTEXT__":
_, role, content = prompt
self.chat.append({"type": "message", "role": role, "content": content})
self.chat.append({"type": "message", "role": role, "content": _vllm_normalize_content(content)})
logger.debug("Added to LLM context (role=%s)", role)
return

Expand Down Expand Up @@ -156,16 +181,14 @@ def process(self, prompt):
})

optional_kwargs = {}
parser = None
if self.tools is not None:
optional_kwargs["tools"] = self.tools
parser = Qwen3CoderToolParser(tools=self.tools)
if self.tools_choice is not None:
optional_kwargs["tool_choice"] = self.tools_choice

# CancelScope.is_stale(gen) is only checked when the stream iterator advances; a
# blocked read inside httpx cannot be aborted by cancel_scope.cancel() from the
# websocket router. The mitigation is request_timeout_s, which surfaces as
# httpx.ReadTimeout below. A future option is to run this API call in a child
# process and terminate() it on session end, at the cost of IPC and lifecycle
# management.
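# tool_choice == "required" is handled by a single non-streaming request in the
# sync branch below, so streaming is disabled in that case.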
request_stream = self.stream and self.tools_choice != "required"
gen = self.cancel_scope.generation if self.cancel_scope else None
response: Response | Stream[ResponseStreamEvent] | None = None
tools: list[dict[str, str]] = []
@@ -176,12 +199,12 @@
response = self.client.responses.create(
model=self.model_name,
input=self.chat.to_list(),
stream=self.stream,
stream=request_stream,
extra_body=self._extra_body,
timeout=self.request_timeout,
**optional_kwargs,
)
if self.stream:
if request_stream:
cancelled = False
printable_text = ""
for event in response:
@@ -190,35 +213,48 @@
cancelled = True
break
if event.type == "response.output_text.delta":
new_text = remove_unspeechable(event.delta)
new_text = event.delta
clean_text += new_text
printable_text += new_text
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
for s in sentences[:-1]:
yield s, language_code, []
printable_text = sentences[-1]
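# Assumed contract of process_printable_text_qwen_xml: it consumes the buffered
# text and returns complete speakable chunks, the accumulated tool calls parsed
# from Qwen XML markup, and the still-incomplete remainder to keep buffering.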
if parser is not None:
chunks, tools, printable_text = process_printable_text_qwen_xml(
printable_text, tools, parser,
)
for s in chunks:
yield remove_unspeechable(s), language_code, []
else:
sentences = sent_tokenize(printable_text)
if len(sentences) > 1:
for s in sentences[:-1]:
yield remove_unspeechable(s), language_code, []
printable_text = sentences[-1]
elif event.type == "response.output_item.done":
if event.item.type == "function_call":
tools.append(event.item.model_dump())
elif event.item.type == "message":
self.chat.append({
"type": "message",
"role": event.item.role,
"content": event.item.content,
})
elif event.type == "response.completed":
usage = getattr(event.response, "usage", None)
if usage:
input_tokens = usage.input_tokens or 0
output_tokens = usage.output_tokens or 0
if not cancelled:
if printable_text.strip() or tools:
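# Persist only the speakable assistant text in chat history: Qwen tool-call
# markup is stripped here, while the tool calls themselves travel in `tools`.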
assistant_speech = remove_unspeechable(
strip_qwen_tool_markup_for_chat(clean_text),
)
if assistant_speech:
self.chat.append({
"type": "message",
"role": "assistant",
"content": assistant_speech,
})
printable_text = remove_unspeechable(
strip_qwen_tool_markup_for_chat(printable_text).strip(),
)
if printable_text or tools:
logger.debug(f"Clean text: {clean_text}")
logger.info(f"Tools: {tools}")
yield printable_text.strip(), language_code, tools
logger.debug(f"Tools: {tools}")
yield printable_text, language_code, tools
else:
# Non-streaming Response (stream=False, or tool_choice == "required" forces a single synchronous request).
if gen is not None and self.cancel_scope.is_stale(gen):
logger.info("LLM generation cancelled (interruption)")
else:
@@ -230,20 +266,48 @@
if message.type == "function_call":
tools.append(message.model_dump())
elif message.type == "message":
self.chat.append({
"type": "message",
"role": message.role,
"content": message.content,
})
for chunk in message.content:
if chunk.type == "output_text":
clean_text += remove_unspeechable(chunk.text)
else:
logger.warning(f"Not supported message type: {message.type}")
logger.debug(f"Clean text: {clean_text}")
logger.info(f"Tools: {tools}")
if clean_text.strip() or tools:
yield clean_text.strip(), language_code, tools
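# With the Qwen parser, the full response text is parsed in one pass; speakable
# chunks and any trailing remainder are joined into a single utterance.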
if parser is not None:
chunks, tools, printable_text = process_printable_text_qwen_xml(
clean_text, tools, parser,
)
chunk_parts = [remove_unspeechable(s).strip() for s in chunks]
chunk_joined = " ".join(p for p in chunk_parts if p)
printable_text = remove_unspeechable(
strip_qwen_tool_markup_for_chat(printable_text).strip(),
)
combined = " ".join(
p for p in (chunk_joined, printable_text) if p
).strip()
assistant_speech = remove_unspeechable(
strip_qwen_tool_markup_for_chat(clean_text),
)
if assistant_speech:
self.chat.append({
"type": "message",
"role": "assistant",
"content": assistant_speech,
})
logger.debug(f"Clean text: {clean_text}")
logger.info(f"Tools: {tools}")
if combined or tools:
yield combined, language_code, tools
else:
logger.debug(f"Clean text: {clean_text}")
logger.info(f"Tools: {tools}")
clean_text = remove_unspeechable(clean_text)
if clean_text.strip():
self.chat.append({
"type": "message",
"role": "assistant",
"content": clean_text.strip(),
})
if clean_text.strip() or tools:
yield clean_text.strip(), language_code, tools
except httpx.ReadTimeout:
logger.warning(
"OpenAI API read timed out after %.1fs; ending the current response",
@@ -263,6 +327,8 @@ def process(self, prompt):
yield ("__END_OF_RESPONSE__", None, None)

def on_session_end(self):
# reset() also clears init_chat_message, so a previous session's
# instructions cannot persist into the next one.
self.chat.reset()
self._last_instructions = None
self.tools = None