Skip to content

Commit 468d098

Browse files
authored
feat(sdk): allow Conversation.switch_profile to accept an inline LLM (#3018)
1 parent 6ffda69 commit 468d098

4 files changed

Lines changed: 209 additions & 11 deletions

File tree

openhands-agent-server/openhands/agent_server/conversation_router.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,28 @@ async def switch_conversation_profile(
319319
return Success()
320320

321321

322+
@conversation_router.post(
323+
"/{conversation_id}/switch_llm",
324+
responses={404: {"description": "Conversation not found"}},
325+
)
326+
async def switch_conversation_llm(
327+
conversation_id: UUID,
328+
llm: LLM = Body(..., embed=True), # noqa: B008
329+
conversation_service: ConversationService = Depends(get_conversation_service),
330+
) -> Success:
331+
"""Swap the conversation's LLM to a caller-supplied object.
332+
333+
Used by app-servers that own the LLM directly and don't push profiles
334+
to the agent-server's filesystem (see #3017).
335+
"""
336+
event_service = await conversation_service.get_event_service(conversation_id)
337+
if event_service is None:
338+
raise HTTPException(status.HTTP_404_NOT_FOUND)
339+
conversation = event_service.get_conversation()
340+
conversation.switch_llm(llm)
341+
return Success()
342+
343+
322344
@conversation_router.patch(
323345
"/{conversation_id}", responses={404: {"description": "Item not found"}}
324346
)

openhands-sdk/openhands/sdk/conversation/impl/local_conversation.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -632,11 +632,33 @@ def _pin_prompt_cache_key(self) -> None:
632632
if self.agent.llm._prompt_cache_key is None:
633633
self.agent.llm._prompt_cache_key = str(self._state.id)
634634

635+
def switch_llm(self, llm: LLM) -> None:
636+
"""Swap the agent's LLM to the given object.
637+
638+
The caller owns ``llm.usage_id``; it is the registry key. If an
639+
entry with that key already exists, the cached LLM is reused and
640+
the passed ``llm`` is dropped — matching the rest of the
641+
registry's "first-write-wins" contract.
642+
643+
Args:
644+
llm: LLM to install on the agent.
645+
"""
646+
try:
647+
new_llm = self.llm_registry.get(llm.usage_id)
648+
except KeyError:
649+
new_llm = llm
650+
self.llm_registry.add(new_llm)
651+
with self._state:
652+
self.agent = self.agent.model_copy(update={"llm": new_llm})
653+
self._state.agent = self.agent
654+
self._pin_prompt_cache_key()
655+
635656
def switch_profile(self, profile_name: str) -> None:
636-
"""Switch the agent's LLM to a named profile.
657+
"""Switch the agent's LLM to a profile loaded from disk.
637658
638-
Loads the profile from the LLMProfileStore (cached in the registry
639-
after the first load) and updates the agent and conversation state.
659+
Loads the profile from :class:`LLMProfileStore` (cached in the
660+
registry under ``profile:{profile_name}`` after first load) and
661+
delegates the swap to :meth:`switch_llm`.
640662
641663
Args:
642664
profile_name: Name of a profile previously saved via LLMProfileStore.
@@ -647,15 +669,11 @@ def switch_profile(self, profile_name: str) -> None:
647669
"""
648670
usage_id = f"profile:{profile_name}"
649671
try:
650-
new_llm = self.llm_registry.get(usage_id)
672+
cached = self.llm_registry.get(usage_id)
651673
except KeyError:
652-
new_llm = self._profile_store.load(profile_name)
653-
new_llm = new_llm.model_copy(update={"usage_id": usage_id})
654-
self.llm_registry.add(new_llm)
655-
with self._state:
656-
self.agent = self.agent.model_copy(update={"llm": new_llm})
657-
self._state.agent = self.agent
658-
self._pin_prompt_cache_key()
674+
loaded = self._profile_store.load(profile_name)
675+
cached = loaded.model_copy(update={"usage_id": usage_id})
676+
self.switch_llm(cached)
659677

660678
@observe(name="conversation.send_message")
661679
def send_message(self, message: str | Message, sender: str | None = None) -> None:

tests/agent_server/test_conversation_router.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1602,6 +1602,69 @@ def test_switch_conversation_profile_corrupted_profile(
16021602
client.app.dependency_overrides.clear()
16031603

16041604

1605+
def test_switch_conversation_llm_success(
1606+
client, mock_conversation_service, mock_event_service, sample_conversation_id
1607+
):
1608+
"""The /switch_llm endpoint forwards the inline LLM to switch_llm,
1609+
bypassing the profile store (#3017).
1610+
"""
1611+
mock_conversation = MagicMock()
1612+
mock_conversation_service.get_event_service.return_value = mock_event_service
1613+
mock_event_service.get_conversation.return_value = mock_conversation
1614+
1615+
client.app.dependency_overrides[get_conversation_service] = (
1616+
lambda: mock_conversation_service
1617+
)
1618+
1619+
llm_payload = {
1620+
"model": "openai/gpt-4o",
1621+
"api_key": "sk-test",
1622+
"usage_id": "caller-supplied-id",
1623+
}
1624+
1625+
try:
1626+
response = client.post(
1627+
f"/api/conversations/{sample_conversation_id}/switch_llm",
1628+
json={"llm": llm_payload},
1629+
)
1630+
1631+
assert response.status_code == 200
1632+
mock_conversation.switch_llm.assert_called_once()
1633+
forwarded_llm = mock_conversation.switch_llm.call_args.args[0]
1634+
assert isinstance(forwarded_llm, LLM)
1635+
assert forwarded_llm.model == "openai/gpt-4o"
1636+
assert forwarded_llm.usage_id == "caller-supplied-id"
1637+
finally:
1638+
client.app.dependency_overrides.clear()
1639+
1640+
1641+
def test_switch_conversation_llm_not_found(
1642+
client, mock_conversation_service, sample_conversation_id
1643+
):
1644+
"""The /switch_llm endpoint returns 404 when the conversation is missing."""
1645+
mock_conversation_service.get_event_service.return_value = None
1646+
1647+
client.app.dependency_overrides[get_conversation_service] = (
1648+
lambda: mock_conversation_service
1649+
)
1650+
1651+
try:
1652+
response = client.post(
1653+
f"/api/conversations/{sample_conversation_id}/switch_llm",
1654+
json={
1655+
"llm": {
1656+
"model": "openai/gpt-4o",
1657+
"api_key": "sk-test",
1658+
"usage_id": "x",
1659+
}
1660+
},
1661+
)
1662+
1663+
assert response.status_code == 404
1664+
finally:
1665+
client.app.dependency_overrides.clear()
1666+
1667+
16051668
def test_fork_conversation_success(
16061669
client, mock_conversation_service, sample_conversation_info, sample_conversation_id
16071670
):

tests/sdk/conversation/test_switch_model.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,3 +117,98 @@ def test_switch_then_send_message(profile_store):
117117
# send_message triggers _ensure_agent_ready which re-registers agent LLMs;
118118
# the switched LLM must not cause a duplicate registration error.
119119
conv.send_message("hello")
120+
121+
122+
@pytest.fixture()
123+
def empty_profile_store(tmp_path, monkeypatch):
124+
"""Empty profile dir — simulates the agent-server sandbox where the
125+
app-server has never uploaded profile JSON. This is the real failure
126+
mode #3017 is fixing.
127+
"""
128+
profile_dir = tmp_path / "profiles"
129+
profile_dir.mkdir()
130+
monkeypatch.setattr(llm_profile_store, "_DEFAULT_PROFILE_DIR", profile_dir)
131+
return profile_dir
132+
133+
134+
def test_switch_llm_swaps_when_store_empty(empty_profile_store):
135+
"""Real app-server case (#3017): profile is unknown to the sandbox FS,
136+
the app-server supplies the LLM directly, and the swap succeeds.
137+
"""
138+
conv = _make_conversation()
139+
inline = _make_llm("inline-model", "caller-supplied-id")
140+
141+
conv.switch_llm(inline)
142+
143+
assert conv.agent.llm.model == "inline-model"
144+
# State must agree — agent_server reads agent.llm via _state.
145+
assert conv.state.agent.llm.model == "inline-model"
146+
# Caller's usage_id is preserved as the registry key.
147+
assert conv.agent.llm.usage_id == "caller-supplied-id"
148+
assert conv.llm_registry.get("caller-supplied-id").model == "inline-model"
149+
# Cache-key must be repinned (regression guard for #2918 on the new path).
150+
assert conv.agent.llm._prompt_cache_key == str(conv.id)
151+
152+
153+
def test_switch_llm_then_send_message(empty_profile_store):
154+
"""send_message triggers _ensure_agent_ready, which re-registers agent
155+
LLMs in the registry. switch_llm adds an entry under the caller's
156+
usage_id; this must not collide with the agent's own LLM
157+
re-registration on the next send_message().
158+
"""
159+
conv = _make_conversation()
160+
conv.switch_llm(_make_llm("inline-model", "x"))
161+
conv.send_message("hello")
162+
163+
164+
def test_switch_between_two_llms(empty_profile_store):
165+
"""Consecutive switch_llm calls under distinct usage_ids each register
166+
their own slot and end up as the agent's LLM.
167+
"""
168+
conv = _make_conversation()
169+
170+
conv.switch_llm(_make_llm("model-a", "x"))
171+
assert conv.agent.llm.model == "model-a"
172+
173+
conv.switch_llm(_make_llm("model-b", "y"))
174+
assert conv.agent.llm.model == "model-b"
175+
176+
177+
def test_switch_llm_does_not_consult_store(empty_profile_store, monkeypatch):
178+
"""switch_llm must not hit LLMProfileStore.load — the caller is
179+
authoritative. Guards against a regression where the inline path
180+
silently falls through to disk IO.
181+
"""
182+
calls: list[str] = []
183+
184+
def _spy_load(self, name):
185+
calls.append(name)
186+
raise FileNotFoundError(name)
187+
188+
monkeypatch.setattr(LLMProfileStore, "load", _spy_load)
189+
190+
conv = _make_conversation()
191+
conv.switch_llm(_make_llm("inline-model", "x"))
192+
193+
assert calls == [], f"profile store was consulted: {calls}"
194+
195+
196+
def test_switch_profile_delegates_to_switch_llm(profile_store, monkeypatch):
197+
"""switch_profile loads from disk and delegates to switch_llm; the LLM
198+
handed off carries the canonical ``profile:{name}`` usage_id.
199+
"""
200+
conv = _make_conversation()
201+
seen: list[LLM] = []
202+
real_switch_llm = conv.switch_llm
203+
204+
def _spy(llm):
205+
seen.append(llm)
206+
real_switch_llm(llm)
207+
208+
monkeypatch.setattr(conv, "switch_llm", _spy)
209+
210+
conv.switch_profile("fast")
211+
212+
assert len(seen) == 1
213+
assert seen[0].usage_id == "profile:fast"
214+
assert seen[0].model == "fast-model"

0 commit comments

Comments
 (0)