Skip to content

Commit 02b912d

Browse files
committed
fix: extract hosted response enum roles
1 parent 1c38020 commit 02b912d

2 files changed

Lines changed: 115 additions & 6 deletions

File tree

lib/src/holiday_peak_lib/agents/hosted.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def run(
118118
managed by the host server, and we do not yet pass extra client
119119
kwargs through to ``handle()``.
120120
"""
121-
_ = session, kwargs
121+
messages = _resolve_run_messages(messages, kwargs)
122+
_ = session
122123
if stream:
123124
return self._run_streaming(messages)
124125
return self._run_once(messages)
@@ -185,12 +186,27 @@ def _extract_user_text(messages: Any) -> str:
185186
return ""
186187

187188

189+
def _resolve_run_messages(messages: Any, kwargs: dict[str, Any]) -> Any:
190+
"""Return the message payload from known ``SupportsAgentRun`` call shapes."""
191+
if messages is not None:
192+
return messages
193+
for field_name in ("messages", "input", "inputs"):
194+
value = kwargs.get(field_name)
195+
if value is not None:
196+
return value
197+
return messages
198+
199+
188200
def _message_role(message: Any) -> str | None:
189201
"""Return a normalized role from dict-like or object-like messages."""
190202
role = _message_field(message, "role")
191203
if role is None:
192204
return None
193-
return str(role).lower()
205+
role_value = getattr(role, "value", role)
206+
normalized = str(role_value).lower()
207+
if normalized.startswith("messagerole."):
208+
return normalized.rsplit(".", maxsplit=1)[-1]
209+
return normalized
194210

195211

196212
def _message_field(message: Any, field_name: str) -> Any:

lib/tests/test_agents_hosted.py

Lines changed: 97 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,9 @@
22
33
These tests cover the *translation logic* (free-form text -> handle dict ->
44
AgentResponse) without requiring the optional
5-
``agent-framework-foundry-hosting`` package to be installed. End-to-end
6-
mount tests live under ``.tmp/probe_mount.py`` and the planned
7-
``inventory-health-check`` integration test once the SDK is available in
8-
the lib venv.
5+
``agent-framework-foundry-hosting`` package to be installed. When that SDK is
6+
present, the mounted ``/responses`` path is also exercised through FastAPI's
7+
``TestClient``.
98
"""
109

1110
from __future__ import annotations
@@ -19,6 +18,7 @@
1918
_extract_text_from_handle_result,
2019
_extract_user_text,
2120
_HostedAgentRunAdapter,
21+
_resolve_run_messages,
2222
)
2323

2424

@@ -52,6 +52,12 @@ def __init__(
5252
self.input = input_value
5353

5454

55+
class _EnumRoleLike:
56+
"""Minimal enum-like role object matching SDK role.value behavior."""
57+
58+
value = "user"
59+
60+
5561
def test_extract_user_text_pulls_last_text_message() -> None:
5662
msgs = [
5763
Message(role="user", contents=[Content(type="text", text="earlier")]),
@@ -63,6 +69,21 @@ def test_extract_user_text_pulls_last_text_message() -> None:
6369
assert _extract_user_text(msgs) == "latest input text"
6470

6571

72+
def test_extract_user_text_reads_content_from_text_contract() -> None:
73+
messages = [Message(role="user", contents=[Content.from_text("check health for SKU-1234")])]
74+
75+
assert _extract_user_text(messages) == "check health for SKU-1234"
76+
77+
78+
def test_extract_user_text_handles_enum_like_user_role() -> None:
79+
message = _MessageLike(
80+
role=_EnumRoleLike(),
81+
contents=[Content.from_text("check health for SKU-1234")],
82+
)
83+
84+
assert _extract_user_text([message]) == "check health for SKU-1234"
85+
86+
6687
@pytest.mark.parametrize(
6788
("message", "expected"),
6889
[
@@ -121,6 +142,13 @@ def test_extract_user_text_handles_empty_inputs() -> None:
121142
assert _extract_user_text([]) == ""
122143

123144

145+
def test_resolve_run_messages_accepts_sdk_keyword_messages() -> None:
146+
messages = [Message(role="user", contents=["kwarg"])]
147+
148+
assert _resolve_run_messages(None, {"messages": messages}) is messages
149+
assert _resolve_run_messages(["positional"], {"messages": messages}) == ["positional"]
150+
151+
124152
def test_extract_text_from_handle_result_prefers_known_keys() -> None:
125153
assert _extract_text_from_handle_result({"text": "t-value"}) == "t-value"
126154
assert _extract_text_from_handle_result({"response": "r-value"}) == "r-value"
@@ -167,6 +195,23 @@ async def translator(text: str) -> dict[str, Any]:
167195
assert text == "hello-from-handle"
168196

169197

198+
@pytest.mark.asyncio
199+
async def test_hosted_run_adapter_uses_messages_from_kwargs() -> None:
200+
agent = _RecordingAgent()
201+
202+
async def translator(text: str) -> dict[str, Any]:
203+
return {"prompt": text}
204+
205+
adapter = _HostedAgentRunAdapter(agent, translator)
206+
response = await adapter.run(messages=[Message(role="user", contents=["positional"])])
207+
assert agent.last_request == {"prompt": "positional"}
208+
assert response.messages
209+
210+
await adapter.run(input=[Message(role="user", contents=["kwarg-input"])])
211+
212+
assert agent.last_request == {"prompt": "kwarg-input"}
213+
214+
170215
@pytest.mark.asyncio
171216
async def test_hosted_run_adapter_streams_single_update() -> None:
172217
"""``run(stream=True)`` must return an async iterator (NOT a coroutine).
@@ -301,3 +346,51 @@ def test_serve_hosted_honors_explicit_prefix_when_sdk_present() -> None:
301346
}
302347
paths.discard(None)
303348
assert "/v1/responses" in paths
349+
350+
351+
@pytest.mark.parametrize(
352+
"body",
353+
[
354+
{"model": "inventory-health-check", "input": "check health for SKU-1234"},
355+
{
356+
"model": "inventory-health-check",
357+
"input": [
358+
{
359+
"type": "message",
360+
"role": "user",
361+
"content": "check health for SKU-1234",
362+
}
363+
],
364+
},
365+
{
366+
"model": "inventory-health-check",
367+
"input": [
368+
{
369+
"type": "message",
370+
"role": "user",
371+
"content": [{"type": "input_text", "text": "check health for SKU-1234"}],
372+
}
373+
],
374+
},
375+
],
376+
)
377+
def test_serve_hosted_responses_post_preserves_prompt_when_sdk_present(
378+
body: dict[str, Any],
379+
) -> None:
380+
pytest.importorskip("agent_framework_foundry_hosting")
381+
from fastapi import FastAPI
382+
from fastapi.testclient import TestClient
383+
384+
agent = _RecordingAgent()
385+
386+
async def translator(text: str) -> dict[str, Any]:
387+
return {"prompt": text}
388+
389+
app = FastAPI()
390+
agent.serve_hosted(app, request_translator=translator)
391+
client = TestClient(app)
392+
393+
response = client.post("/responses", json=body)
394+
395+
assert response.status_code == 200
396+
assert agent.last_request == {"prompt": "check health for SKU-1234"}

0 commit comments

Comments
 (0)