Skip to content

Commit 1c38020

Browse files
committed
fix: preserve hosted-agent response input text
1 parent cd48f46 commit 1c38020

2 files changed

Lines changed: 158 additions & 9 deletions

File tree

lib/src/holiday_peak_lib/agents/hosted.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -171,15 +171,91 @@ def get_session(self, service_session_id: str, *, session_id: str | None = None)
171171

172172

173173
def _extract_user_text(messages: Any) -> str:
174-
"""Pull the most recent user-message text from a MAF message sequence."""
174+
"""Pull the most recent user-message text from MAF/OpenAI shapes."""
175175
if not messages:
176176
return ""
177-
seq = messages if isinstance(messages, (list, tuple)) else [messages]
178-
for msg in reversed(seq):
179-
for content in getattr(msg, "contents", None) or []:
180-
text = getattr(content, "text", None)
177+
sequence = messages if isinstance(messages, (list, tuple)) else [messages]
178+
for message in reversed(sequence):
179+
role = _message_role(message)
180+
if role is not None and role != "user":
181+
continue
182+
text = _extract_message_text(message)
183+
if text:
184+
return text
185+
return ""
186+
187+
188+
def _message_role(message: Any) -> str | None:
189+
"""Return a normalized role from dict-like or object-like messages."""
190+
role = _message_field(message, "role")
191+
if role is None:
192+
return None
193+
return str(role).lower()
194+
195+
196+
def _message_field(message: Any, field_name: str) -> Any:
197+
"""Read one field from a dict-like or object-like message."""
198+
if isinstance(message, dict):
199+
return message.get(field_name)
200+
return getattr(message, field_name, None)
201+
202+
203+
def _extract_message_text(message: Any) -> str:
204+
"""Return the first non-empty text value from a single message."""
205+
if isinstance(message, str):
206+
return message
207+
208+
for field_name in ("contents", "content"):
209+
text = _extract_content_text(_message_field(message, field_name))
210+
if text:
211+
return text
212+
213+
text = _extract_input_text(_message_field(message, "input"))
214+
if text:
215+
return text
216+
217+
return _extract_content_text(_message_field(message, "text"))
218+
219+
220+
def _extract_input_text(value: Any) -> str:
221+
"""Extract text from Responses ``input`` values."""
222+
if _looks_like_message_sequence(value):
223+
return _extract_user_text(value)
224+
return _extract_content_text(value)
225+
226+
227+
def _looks_like_message_sequence(value: Any) -> bool:
228+
"""Return whether a value looks like a list of chat messages."""
229+
if not isinstance(value, (list, tuple)):
230+
return False
231+
return any(
232+
_message_field(item, "role") is not None
233+
or _message_field(item, "contents") is not None
234+
or _message_field(item, "content") is not None
235+
for item in value
236+
)
237+
238+
239+
def _extract_content_text(value: Any) -> str:
240+
"""Return the first non-empty text value from content parts."""
241+
if value is None:
242+
return ""
243+
if isinstance(value, str):
244+
return value
245+
if isinstance(value, (list, tuple)):
246+
for item in value:
247+
text = _extract_content_text(item)
181248
if text:
182-
return str(text)
249+
return text
250+
return ""
251+
for field_name in ("contents", "content", "input", "text"):
252+
field_value = _message_field(value, field_name)
253+
if field_name == "input":
254+
text = _extract_input_text(field_value)
255+
else:
256+
text = _extract_content_text(field_value)
257+
if text:
258+
return text
183259
return ""
184260

185261

lib/tests/test_agents_hosted.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from typing import Any
1414

1515
import pytest
16-
from agent_framework import Message
16+
from agent_framework import Content, Message
1717
from holiday_peak_lib.agents.base_agent import AgentDependencies, BaseRetailAgent
1818
from holiday_peak_lib.agents.hosted import (
1919
_extract_text_from_handle_result,
@@ -35,14 +35,87 @@ async def handle(self, request: dict[str, Any]) -> dict[str, Any]:
3535
return self.next_response
3636

3737

38+
class _MessageLike:
39+
def __init__(
40+
self,
41+
*,
42+
role: str | None = "user",
43+
contents: list[Any] | None = None,
44+
content: Any = None,
45+
text: str | None = None,
46+
input_value: Any = None,
47+
) -> None:
48+
self.role = role
49+
self.contents = contents
50+
self.content = content
51+
self.text = text
52+
self.input = input_value
53+
54+
3855
def test_extract_user_text_pulls_last_text_message() -> None:
3956
msgs = [
40-
Message(role="user", contents=["earlier"]),
41-
Message(role="user", contents=["latest input text"]),
57+
Message(role="user", contents=[Content(type="text", text="earlier")]),
58+
Message(
59+
role="user",
60+
contents=[Content(type="text", text="latest input text")],
61+
),
4262
]
4363
assert _extract_user_text(msgs) == "latest input text"
4464

4565

66+
@pytest.mark.parametrize(
67+
("message", "expected"),
68+
[
69+
(_MessageLike(content="plain object content"), "plain object content"),
70+
(
71+
_MessageLike(content=[{"type": "input_text", "text": "object part"}]),
72+
"object part",
73+
),
74+
({"role": "user", "content": "plain dict content"}, "plain dict content"),
75+
(
76+
{"role": "user", "content": [{"type": "input_text", "text": "dict part"}]},
77+
"dict part",
78+
),
79+
(_MessageLike(role=None, text="object direct text"), "object direct text"),
80+
({"text": "dict direct text"}, "dict direct text"),
81+
({"input": "direct input text"}, "direct input text"),
82+
(
83+
{
84+
"input": [
85+
{"role": "user", "content": "earlier nested input"},
86+
{
87+
"role": "user",
88+
"content": [{"type": "input_text", "text": "latest nested input"}],
89+
},
90+
]
91+
},
92+
"latest nested input",
93+
),
94+
],
95+
)
96+
def test_extract_user_text_handles_common_maf_and_openai_shapes(
97+
message: Any, expected: str
98+
) -> None:
99+
assert _extract_user_text(message) == expected
100+
101+
102+
def test_extract_user_text_prefers_most_recent_user_message() -> None:
103+
messages = [
104+
{"role": "user", "content": "older user text"},
105+
{"role": "assistant", "content": "assistant text should be ignored"},
106+
{"role": "user", "content": "latest user text"},
107+
]
108+
assert _extract_user_text(messages) == "latest user text"
109+
110+
111+
def test_extract_user_text_skips_later_non_user_messages() -> None:
112+
messages = [
113+
{"role": "user", "content": "latest user text"},
114+
{"role": "assistant", "content": "assistant text should be ignored"},
115+
]
116+
assert _extract_user_text(messages) == "latest user text"
117+
118+
46119
def test_extract_user_text_handles_empty_inputs() -> None:
47120
assert _extract_user_text(None) == ""
48121
assert _extract_user_text([]) == ""

0 commit comments

Comments
 (0)