@@ -454,16 +454,8 @@ def as_graph(
454454 if thread is None and thread_id is not None :
455455 thread = Thread .objects .get (id = thread_id )
456456
457- def custom_add_messages (left : list [BaseMessage ], right : list [BaseMessage ]):
458- result = add_messages (left , right ) # type: ignore
459- if thread :
460- # Save all messages, except the initial system message:
461- thread_messages = [m for m in result if not isinstance (m , SystemMessage )]
462- save_django_messages (cast (list [BaseMessage ], thread_messages ), thread = thread )
463- return result
464-
465457 class AgentState (TypedDict ):
466- messages : Annotated [list [AnyMessage ], custom_add_messages ]
458+ messages : Annotated [list [AnyMessage ], add_messages ]
467459 input : str | None # noqa: A003
468460 output : Any
469461
@@ -537,6 +529,10 @@ def record_response(state: AgentState):
537529 else :
538530 response = state ["messages" ][- 1 ].content
539531
532+ if thread :
533+ # Save all messages, except the initial system message:
534+ thread_messages = [m for m in state ["messages" ] if not isinstance (m , SystemMessage )]
535+ save_django_messages (cast (list [BaseMessage ], thread_messages ), thread = thread )
540536 return {"output" : response }
541537
542538 workflow = StateGraph (AgentState )
0 commit comments