11import abc
22import inspect
33import re
4- from typing import Annotated , Any , ClassVar , Dict , Sequence , Type , TypedDict , cast
4+ from typing import (
5+ Annotated ,
6+ Any ,
7+ AsyncIterable ,
8+ AsyncIterator ,
9+ ClassVar ,
10+ Dict ,
11+ Literal ,
12+ Sequence ,
13+ Type ,
14+ TypedDict ,
15+ cast ,
16+ overload ,
17+ )
518
619from langchain_core .language_models import BaseChatModel
720from langchain_core .messages import (
@@ -415,16 +428,20 @@ def get_history_aware_retriever(self) -> Runnable[dict, RetrieverOutput]:
415428 )
416429
417430 @with_cast_id
418- def as_graph (self , thread_id : Any | None = None ) -> Runnable [dict , dict ]:
431+ def as_graph (
432+ self , thread_id : Any | None = None , thread : Any | None = None
433+ ) -> Runnable [dict , dict ]:
419434 """Create the LangGraph graph for the assistant.\n
420435 This graph is an agent that supports chat history, tool calling, and RAG (if `has_rag=True`).\n
421436 `as_graph` uses many other methods to create the graph for the assistant.
422437 Prefer to override the other methods to customize the graph for the assistant.
423438 Only override this method if you need to customize the graph at a lower level.
424439
440+ If both arguments are `None`, an in-memory chat message history is used.
441+
425442 Args:
426443 thread_id (Any | None): The thread ID for the chat message history.
427- If `None`, an in-memory chat message history is used .
444+ thread (Any | None): The thread object for the chat message history.
428445
429446 Returns:
430447 the compiled graph
@@ -434,10 +451,8 @@ def as_graph(self, thread_id: Any | None = None) -> Runnable[dict, dict]:
434451 llm = self .get_llm ()
435452 tools = self .get_tools ()
436453 llm_with_tools = llm .bind_tools (tools ) if tools else llm
437- if thread_id :
454+ if thread is None and thread_id is not None :
438455 thread = Thread .objects .get (id = thread_id )
439- else :
440- thread = None
441456
442457 def custom_add_messages (left : list [BaseMessage ], right : list [BaseMessage ]):
443458 result = add_messages (left , right ) # type: ignore
@@ -550,28 +565,62 @@ def record_response(state: AgentState):
550565
551566 return workflow .compile ()
552567
568+ @overload
569+ def invoke (
570+ self ,
571+ * args : Any ,
572+ thread_id : Any | None ,
573+ thread : Any | None = None ,
574+ mode : Literal ["invoke" ] = "invoke" ,
575+ ** kwargs : Any ,
576+ ) -> dict :
577+ ...
578+
579+ @overload
580+ def invoke (
581+ self ,
582+ * args : Any ,
583+ thread_id : Any | None ,
584+ thread : Any | None = None ,
585+ mode : Literal ["astream" ],
586+ ** kwargs : Any ,
587+ ) -> AsyncIterator [dict ]:
588+ ...
589+
553590 @with_cast_id
554- def invoke (self , * args : Any , thread_id : Any | None , ** kwargs : Any ) -> dict :
591+ def invoke (
592+ self ,
593+ * args : Any ,
594+ thread_id : Any | None = None ,
595+ thread : Any | None = None ,
596+ mode : Literal ["invoke" , "astream" ] = "invoke" ,
597+ ** kwargs : Any ,
598+ ) -> dict | AsyncIterator [dict ]:
555599 """Invoke the assistant LangChain graph with the given arguments and keyword arguments.\n
556600 This is the lower-level method to run the assistant.\n
557601 The graph is created by the `as_graph` method.\n
558602
603+ If thread_id and thread are `None`, an in-memory chat message history is used.
604+
559605 Args:
560606 *args: Positional arguments to pass to the graph.
561607 To add a new message, use a dict like `{"input": "user message"}`.
562608 If thread already has a `HumanMessage` in the end, you can invoke without args.
563609 thread_id (Any | None): The thread ID for the chat message history.
564- If `None`, an in-memory chat message history is used.
610+ thread (Any | None): The thread object for the chat message history.
611+ mode (invoke | astream): call named graph method
565612 **kwargs: Keyword arguments to pass to the graph.
566613
567614 Returns:
568615 dict: The output of the assistant graph,
569616 structured like `{"output": "assistant response", "history": ...}`.
570617 """
571- graph = self .as_graph (thread_id )
618+ graph = self .as_graph (thread_id = thread_id , thread = thread )
572619 config = kwargs .pop ("config" , {})
573620 config ["max_concurrency" ] = config .pop ("max_concurrency" , self .tool_max_concurrency )
574- return graph .invoke (* args , config = config , ** kwargs )
621+ if mode not in ("invoke" , "astream" ):
622+ raise NotImplementedError (f"mode={ mode !r} " )
623+ return getattr (graph , mode )(* args , config = config , ** kwargs )
575624
576625 @with_cast_id
577626 def run (self , message : str , thread_id : Any | None = None , ** kwargs : Any ) -> Any :
@@ -595,6 +644,34 @@ def run(self, message: str, thread_id: Any | None = None, **kwargs: Any) -> Any:
595644 ** kwargs ,
596645 )["output" ]
597646
647+ @with_cast_id
648+ async def astream (
649+ self , message : str , thread : Any | None = None , ** kwargs : Any
650+ ) -> AsyncIterable [Any ]:
651+ """Async-stream the assistant with the given message and thread.\n
652+ This is the higher-level method to run the assistant.\n
653+
654+ Args:
655+ message (str): The user message to pass to the assistant.
656+ thread (Any | None): The thread object for the chat message history.
657+ If `None`, an in-memory chat message history is used.
658+ **kwargs: Additional keyword arguments to pass to the graph.
659+
660+ Yields:
661+ Any: The assistant response to the user message.
662+ """
663+ async for output , metadata in self .invoke (
664+ {
665+ "input" : message ,
666+ },
667+ thread = thread ,
668+ mode = "astream" ,
669+ stream_mode = "messages" ,
670+ ** kwargs ,
671+ ):
672+ if metadata .get ("langgraph_node" ) == "agent" and (content := output .content ):
673+ yield content
674+
598675 def _run_as_tool (self , message : str , ** kwargs : Any ) -> Any :
599676 return self .run (message , thread_id = None , ** kwargs )
600677
0 commit comments