Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 35 additions & 3 deletions llama_deploy/apiserver/routers/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def create_deployment_task(
)

run_kwargs = json.loads(task_definition.input) if task_definition.input else {}
session_id = session_id or task_definition.session_id
result = await deployment.run_workflow(
service_id=service_id, session_id=session_id, **run_kwargs
)
Expand Down Expand Up @@ -123,6 +124,7 @@ async def create_deployment_task_nowait(
)

run_kwargs = json.loads(task_definition.input) if task_definition.input else {}
session_id = session_id or task_definition.session_id
handler_id, session_id = deployment.run_workflow_no_wait(
service_id=service_id, session_id=session_id, **run_kwargs
)
Expand Down Expand Up @@ -176,8 +178,13 @@ async def event_stream(handler: WorkflowHandler) -> AsyncGenerator[str, None]:
await asyncio.sleep(0.01)
await handler

try:
deployment_handler = deployment._handlers[task_id]
except KeyError:
raise HTTPException(status_code=404, detail="Task not found")

return StreamingResponse(
event_stream(deployment._handlers[task_id]),
event_stream(deployment_handler),
media_type="application/x-ndjson",
)

Expand All @@ -190,8 +197,27 @@ async def get_task_result(
) -> TaskResult | None:
"""Get the task result associated with a task and session."""

handler = deployment._handlers[task_id]
return TaskResult(task_id=task_id, history=[], result=await handler)
try:
handler = deployment._handlers[task_id]
except KeyError:
raise HTTPException(status_code=404, detail="Task not found")
result = await handler
if not isinstance(result, str):
result = str(result)
return TaskResult(task_id=task_id, history=[], result=result)


@deployments_router.post("/{deployment_name}/tasks/delete")
async def delete_task(
deployment: Annotated[Deployment, Depends(deployment)], task_id: str
) -> None:
"""Get the active sessions in a deployment and service."""

if task_id not in deployment._handlers:
raise HTTPException(status_code=404, detail="Task not found")

deployment._handlers.pop(task_id) # noqa: ignore
deployment._handler_inputs.pop(task_id, None) # noqa: ignore


@deployments_router.get("/{deployment_name}/tasks")
Expand Down Expand Up @@ -224,6 +250,9 @@ async def get_session(
) -> SessionDefinition:
"""Get the definition of a session by ID."""

if session_id not in deployment._contexts:
raise HTTPException(status_code=404, detail="Session not found")

return SessionDefinition(session_id=session_id)


Expand All @@ -246,6 +275,9 @@ async def delete_session(
) -> None:
"""Get the active sessions in a deployment and service."""

if session_id not in deployment._contexts:
raise HTTPException(status_code=404, detail="Session not found")

deployment._contexts.pop(session_id)


Expand Down
23 changes: 21 additions & 2 deletions llama_deploy/client/models/apiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,25 @@ class TaskCollection(Collection):
description="The ID of the deployment these tasks belong to."
)

async def delete(self, task_id: str) -> None:
"""Deletes the session with the provided `session_id`.

Args:
task_id: The id of the task that will be removed

Raises:
HTTPException: If the session couldn't be found with the id provided.
"""
delete_url = f"{self.client.api_server_url}/deployments/{self.deployment_id}/tasks/delete"

await self.client.request(
"POST",
delete_url,
params={"task_id": task_id},
verify=not self.client.disable_ssl,
timeout=self.client.timeout,
)

async def run(self, task: TaskDefinition) -> Any:
"""Runs a task and returns the results once it's done.

Expand Down Expand Up @@ -222,8 +241,8 @@ async def list(self) -> list[Task]:
items = {
"id": task_model_class(
client=self.client,
id=task_def.task_id,
session_id=task_def.session_id,
id=task_def["task_id"],
session_id=task_def["session_id"],
deployment_id=self.deployment_id,
)
for task_def in r.json()
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ dev = [

[project]
name = "llama-deploy"
version = "0.9.1"
version = "0.9.2"
description = ""
authors = [
{name = "Logan Markewich", email = "[email protected]"},
Expand Down
94 changes: 90 additions & 4 deletions tests/apiserver/routers/test_deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import pytest
import respx
from fastapi.testclient import TestClient
from llama_index.core.agent.workflow import AgentOutput
from llama_index.core.base.llms.types import ChatMessage
from workflows.context import JsonSerializer
from workflows.events import Event

Expand Down Expand Up @@ -148,6 +150,13 @@ def test_run_deployment_task(
)
assert response.status_code == 200

deployment.reset_mock()
response = http_client.post(
"/deployments/test-deployment/tasks/run/",
json={"input": "{}", "session_id": "84"},
)
assert response.status_code == 200


def test_create_deployment_task(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
Expand Down Expand Up @@ -175,6 +184,13 @@ def test_create_deployment_task(
)
assert response.status_code == 200

deployment.reset_mock()
response = http_client.post(
"/deployments/test-deployment/tasks/create/",
json={"input": "{}", "session_id": "84"},
)
assert response.status_code == 200


def test_send_event_not_found(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
Expand Down Expand Up @@ -228,6 +244,15 @@ def test_get_event_not_found(
)
assert response.status_code == 404

deployment = mock.AsyncMock()
deployment._handlers = {}
mock_manager.get_deployment.return_value = deployment
response = http_client.get(
"/deployments/test-deployment/tasks/test_task_id/events",
params={"session_id": "42", "task_id": "84"},
)
assert response.status_code == 404


@pytest.mark.asyncio
async def test_get_event_stream(
Expand Down Expand Up @@ -319,7 +344,10 @@ async def await_impl(): # type:ignore
def test_get_task_result_not_found(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
) -> None:
mock_manager.get_deployment.return_value = None
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
deployment._handlers = {}
mock_manager.get_deployment.return_value = deployment
response = http_client.get(
"/deployments/test-deployment/tasks/test_task_id/results/?session_id=42",
)
Expand Down Expand Up @@ -377,6 +405,60 @@ async def await_impl(): # type:ignore
assert response.status_code == 200
assert TaskResult(**response.json()).result == "test_result"

# Mock the handler to return an AgentOutput
class MockAgentOutputHandler:
def __await__(self): # type:ignore
async def await_impl(): # type:ignore
return AgentOutput(
response=ChatMessage(content="test_result"),
current_agent_name="test_agent",
tool_calls=[],
raw=None,
)

return await_impl().__await__()

mock_agent_output_handler = MockAgentOutputHandler()
deployment._handlers = {"test_task_id": mock_agent_output_handler}

mock_manager.get_deployment.return_value = deployment

response = http_client.get(
"/deployments/test-deployment/tasks/test_task_id/results/?session_id=42",
)
assert response.status_code == 200
assert TaskResult(**response.json()).result == "test_result"


def test_delete_task_not_found(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
) -> None:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
mock_manager.get_deployment.return_value = deployment
response = http_client.post(
"/deployments/test-deployment/tasks/delete/?task_id=42",
)
assert response.status_code == 404
assert response.json() == {"detail": "Task not found"}


def test_delete_task(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
) -> None:
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
deployment._handlers = {"42": mock.MagicMock()} # Mock handlers to be deleted
deployment._handler_inputs = {"42": "foo"}
mock_manager.get_deployment.return_value = deployment

response = http_client.post(
"/deployments/test-deployment/tasks/delete/?task_id=42",
)
assert response.status_code == 200
assert "42" not in deployment._handlers
assert "42" not in deployment._handler_inputs


def test_get_sessions_not_found(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
Expand Down Expand Up @@ -406,12 +488,14 @@ def test_get_sessions(
def test_delete_session_not_found(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
) -> None:
mock_manager.get_deployment.return_value = None
deployment = mock.AsyncMock()
deployment.default_service = "TestService"
mock_manager.get_deployment.return_value = deployment
response = http_client.post(
"/deployments/test-deployment/sessions/delete/?session_id=42",
)
assert response.status_code == 404
assert response.json() == {"detail": "Deployment not found"}
assert response.json() == {"detail": "Session not found"}


def test_delete_session(
Expand All @@ -431,7 +515,8 @@ def test_delete_session(
def test_get_session_not_found(
http_client: TestClient, data_path: Path, mock_manager: MagicMock
) -> None:
mock_manager.get_deployment.return_value = None
deployment = mock.AsyncMock()
mock_manager.get_deployment.return_value = deployment
response = http_client.get(
"/deployments/test-deployment/sessions/foo",
)
Expand All @@ -444,6 +529,7 @@ def test_get_session(
deployment = mock.AsyncMock()
mock_manager.get_deployment.return_value = deployment
session = mock.AsyncMock(id="foo")
deployment._contexts = {"foo": session}
deployment.client.core.sessions.get.return_value = session
response = http_client.get("/deployments/test-deployment/sessions/foo")
assert response.status_code == 200
Expand Down
25 changes: 20 additions & 5 deletions tests/client/models/test_apiserver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import io
from typing import Any
from typing import Any, Dict
from unittest import mock

import httpx
Expand Down Expand Up @@ -109,6 +109,23 @@ async def test_task_results(client: Any) -> None:
)


@pytest.mark.asyncio
async def test_task_collection_delete(client: Any) -> None:
coll = TaskCollection(
client=client,
items={},
deployment_id="a_deployment",
)
await coll.delete("a_task")
client.request.assert_awaited_with(
"POST",
"http://localhost:4501/deployments/a_deployment/tasks/delete",
params={"task_id": "a_task"},
timeout=120.0,
verify=True,
)


@pytest.mark.asyncio
async def test_task_collection_run(client: Any) -> None:
client.request.return_value = mock.MagicMock(json=lambda: "some result")
Expand Down Expand Up @@ -174,10 +191,8 @@ async def test_task_collection_create(client: Any) -> None:
@pytest.mark.asyncio
async def test_task_deployment_tasks(client: Any) -> None:
d = Deployment(client=client, id="a_deployment")
res: list[TaskDefinition] = [
TaskDefinition(
input='{"arg": "input"}', task_id="a_task", session_id="a_session"
)
res: list[Dict[str, str]] = [
{"input": '{"arg": "input"}', "task_id": "a_task", "session_id": "a_session"}
]
client.request.return_value = mock.MagicMock(json=lambda: res)

Expand Down
Loading
Loading