Skip to content

Commit 6c8e102

Browse files
author
Antti Hautaniemi
committed
Enable langgraph astream usage
1 parent d52ba1b commit 6c8e102

3 files changed

Lines changed: 119 additions & 11 deletions

File tree

django_ai_assistant/helpers/assistants.py

Lines changed: 87 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
import abc
22
import inspect
33
import re
4-
from typing import Annotated, Any, ClassVar, Dict, Sequence, Type, TypedDict, cast
4+
from typing import (
5+
Annotated,
6+
Any,
7+
AsyncIterable,
8+
AsyncIterator,
9+
ClassVar,
10+
Dict,
11+
Literal,
12+
Sequence,
13+
Type,
14+
TypedDict,
15+
cast,
16+
overload,
17+
)
518

619
from langchain_core.language_models import BaseChatModel
720
from langchain_core.messages import (
@@ -415,16 +428,20 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
415428
)
416429

417430
@with_cast_id
418-
def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
431+
def as_graph(
432+
self, thread_id: Any | None = None, thread: Any | None = None
433+
) -> Runnable[dict, dict]:
419434
"""Create the LangGraph graph for the assistant.\n
420435
This graph is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
421436
`as_graph` uses many other methods to create the graph for the assistant.
422437
Prefer to override the other methods to customize the graph for the assistant.
423438
Only override this method if you need to customize the graph at a lower level.
424439
440+
If both arguments are `None`, an in-memory chat message history is used.
441+
425442
Args:
426443
thread_id (Any | None): The thread ID for the chat message history.
427-
If `None`, an in-memory chat message history is used.
444+
thread (Any | None): The thread object for the chat message history.
428445
429446
Returns:
430447
the compiled graph
@@ -434,10 +451,8 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
434451
llm = self.get_llm()
435452
tools = self.get_tools()
436453
llm_with_tools = llm.bind_tools(tools) if tools else llm
437-
if thread_id:
454+
if thread is None and thread_id is not None:
438455
thread = Thread.objects.get(id=thread_id)
439-
else:
440-
thread = None
441456

442457
def custom_add_messages(left: list[BaseMessage], right: list[BaseMessage]):
443458
result = add_messages(left, right) # type: ignore
@@ -550,28 +565,62 @@ def record_response(state: AgentState):
550565

551566
return workflow.compile()
552567

@overload
def invoke(
    self,
    *args: Any,
    thread_id: Any | None = None,
    thread: Any | None = None,
    mode: Literal["invoke"] = "invoke",
    **kwargs: Any,
) -> dict: ...

@overload
def invoke(
    self,
    *args: Any,
    # Defaults mirror the implementation so callers (e.g. `astream`) may
    # pass only `thread=...` without a type error; previously `thread_id`
    # had no default here and was a *required* keyword under this overload.
    thread_id: Any | None = None,
    thread: Any | None = None,
    mode: Literal["astream"],
    **kwargs: Any,
) -> AsyncIterator[dict]: ...

@with_cast_id
def invoke(
    self,
    *args: Any,
    thread_id: Any | None = None,
    thread: Any | None = None,
    mode: Literal["invoke", "astream"] = "invoke",
    **kwargs: Any,
) -> dict | AsyncIterator[dict]:
    """Invoke the assistant LangChain graph with the given arguments and keyword arguments.\n
    This is the lower-level method to run the assistant.\n
    The graph is created by the `as_graph` method.\n

    If thread_id and thread are `None`, an in-memory chat message history is used.

    Args:
        *args: Positional arguments to pass to the graph.
            To add a new message, use a dict like `{"input": "user message"}`.
            If thread already has a `HumanMessage` in the end, you can invoke without args.
        thread_id (Any | None): The thread ID for the chat message history.
        thread (Any | None): The thread object for the chat message history.
        mode (invoke | astream): call named graph method
        **kwargs: Keyword arguments to pass to the graph.

    Returns:
        dict: The output of the assistant graph,
            structured like `{"output": "assistant response", "history": ...}`,
            or an async iterator over graph outputs when `mode="astream"`.
    """
    graph = self.as_graph(thread_id=thread_id, thread=thread)
    config = kwargs.pop("config", {})
    config["max_concurrency"] = config.pop("max_concurrency", self.tool_max_concurrency)
    # Guard before dispatching so an unsupported mode fails loudly instead of
    # calling an arbitrary attribute on the compiled graph.
    if mode not in ("invoke", "astream"):
        raise NotImplementedError(f"mode={mode!r}")
    return getattr(graph, mode)(*args, config=config, **kwargs)
575624

576625
@with_cast_id
577626
def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> Any:
@@ -595,6 +644,34 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> Any:
595644
**kwargs,
596645
)["output"]
597646

@with_cast_id
async def astream(
    self, message: str, thread: Any | None = None, **kwargs: Any
) -> AsyncIterable[Any]:
    """Async-stream the assistant with the given message and thread.\n
    This is the higher-level method to run the assistant.\n

    Args:
        message (str): The user message to pass to the assistant.
        thread (Any | None): The thread object for the chat message history.
            If `None`, an in-memory chat message history is used.
        **kwargs: Additional keyword arguments to pass to the graph.

    Yields:
        Any: The assistant response to the user message.
    """
    stream = self.invoke(
        {
            "input": message,
        },
        thread=thread,
        mode="astream",
        stream_mode="messages",
        **kwargs,
    )
    async for chunk, meta in stream:
        # Only surface non-empty content produced by the agent node;
        # tool and bookkeeping nodes are filtered out.
        if meta.get("langgraph_node") != "agent":
            continue
        content = chunk.content
        if content:
            yield content
598675
def _run_as_tool(self, message: str, **kwargs: Any) -> Any:
599676
return self.run(message, thread_id=None, **kwargs)
600677

django_ai_assistant/helpers/django_messages.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import asyncio
12
from typing import TYPE_CHECKING
23

34
from django.db import connections, transaction
@@ -14,7 +15,7 @@
1415

1516

1617
@transaction.atomic
17-
def save_django_messages(messages: list[BaseMessage], thread: "Thread") -> list["DjangoMessage"]:
18+
def _save_django_messages(messages: list[BaseMessage], thread: "Thread") -> list["DjangoMessage"]:
1819
"""
1920
Save a list of messages to the Django database.
2021
Note: Changes the message objects in place by changing each message.id to the Django ID.
@@ -61,3 +62,12 @@ def save_django_messages(messages: list[BaseMessage], thread: "Thread") -> list[
6162

6263
DjangoMessage.objects.bulk_update(created_messages, ["message"])
6364
return created_messages
def save_django_messages(messages: list[BaseMessage], thread: "Thread") -> None:
    """Save messages to the Django database, sync- and async-context aware.

    When called with no running event loop, saves synchronously.
    When called from inside a running event loop, off-loads the blocking
    ORM work to the default executor so the loop is not blocked.

    Note: the executor path is fire-and-forget — this function returns
    before the save completes. NOTE(review): executor threads open their
    own DB connections; confirm they are cleaned up by the caller/setup.

    Args:
        messages: The messages to persist (mutated in place by the
            underlying save, which rewrites message IDs).
        thread: The thread to associate the messages with.
    """
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running loop: plain sync context, save directly.
        _save_django_messages(messages, thread)
    else:
        future = loop.run_in_executor(None, _save_django_messages, messages, thread)
        # Previously the Future was discarded, so any exception raised while
        # saving was silently dropped. Retrieving the result in a done-callback
        # re-raises failures into the loop's exception handler, which logs them.
        future.add_done_callback(lambda f: f.result())

tests/test_helpers/test_assistants.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import sys
12
from typing import List, TypedDict
23
from unittest.mock import patch
34

45
import pytest
6+
import vcr
57
from langchain_core.documents import Document
68
from langchain_core.messages import (
79
AIMessage,
@@ -163,6 +165,25 @@ def test_AIAssistant_invoke():
163165
]
164166

165167

@pytest.mark.skipif(sys.version_info < (3, 11), reason="Not supported on Python 3.10")
@pytest.mark.django_db(transaction=True)
@pytest.mark.asyncio
@vcr.use_cassette("cassettes/test_AIAssistant_invoke.yaml")
async def test_AIAssistant_astream():
    # Collect all streamed chunks and verify their concatenation matches
    # the full response recorded in the cassette.
    thread = await Thread.objects.acreate(name="Recife Temperature Chat")
    assistant = AIAssistant.get_cls("temperature_assistant")()
    chunks = []
    async for piece in assistant.astream(
        "What is the temperature today in Recife?",
        thread=thread,
    ):
        chunks.append(piece)
    assert "".join(chunks) == "The temperature today in Recife is 32 degrees Celsius."
185+
186+
166187
def test_AIAssistant_run_handles_optional_thread_id_param():
167188
assistant = AIAssistant.get_cls("temperature_assistant")()
168189

0 commit comments

Comments
 (0)