1+ from langchain .agents .middleware import PIIDetectionError
12import ujson
23from langchain_core .language_models import BaseChatModel
34from langchain_core .runnables import RunnableConfig
78from langgraph .types import StreamMode
89from deepagents import SubAgent
910
11+ from src .schemas .contexts import ContextSchema
1012from src .contexts .service import ServiceContext
11- from src .schemas .entities import LLMInput , LLMRequest
13+ from src .schemas .entities import LLMInput
1214from src .constants import APP_LOG_LEVEL
1315from src .flows import construct_agent
1416from src .services .db import get_checkpoint_db
1517from src .utils .messages import from_message_to_dict
1618from langchain_core .messages import (
17- AIMessage ,
1819 AIMessageChunk ,
19- BaseMessage ,
20+ HumanMessage ,
2021 ToolMessage ,
2122)
2223from src .utils .logger import log_to_file , logger
@@ -189,9 +190,9 @@ async def stream_generator(
189190 {"messages" : input .messages },
190191 stream_mode = ["messages" , "values" ],
191192 config = config ,
192- context = None
193+ context = ContextSchema ( model = agent . model )
193194 ):
194- # Serialize and yield each chunk as SSEq
195+ # Serialize and yield each chunk as SSE
195196 stream_chunk = handle_multi_mode (chunk )
196197 if stream_chunk :
197198 stream_type = stream_chunk [0 ]
@@ -202,10 +203,16 @@ async def stream_generator(
202203 log_to_file (str (data ), agent .model ) and APP_LOG_LEVEL == "DEBUG"
203204 logger .debug (f"data: { str (data )} " )
204205 yield f"data: { data } \n \n "
206+ except PIIDetectionError as e :
207+ # Yield error as SSE if streaming fails
208+ logger .warning (f"Sensitive data detected in the query: { e } " )
209+ # raise HTTPException(status_code=500, detail=str(e))
210+ error_msg = ujson .dumps (("error" , str (e )))
211+ yield f"data: { error_msg } \n \n "
205212
206213 except Exception as e :
207214 # Yield error as SSE if streaming fails
208- logger .exception ("Error in event_generator : %s" , e )
215+ logger .exception ("Error in stream_generator : %s" , e )
209216 # raise HTTPException(status_code=500, detail=str(e))
210217 error_msg = ujson .dumps (("error" , str (e )))
211218 yield f"data: { error_msg } \n \n "
@@ -214,12 +221,20 @@ async def stream_generator(
214221 final_state = await agent .graph .aget_state (config )
215222 configurable = final_state .config .get ("configurable" , {})
216223 messages = final_state .values .get ('messages' , [])
224+
225+ # Get the last HumanMessage
226+ last_human_message = None
227+ for message in reversed (messages ):
228+ if isinstance (message , HumanMessage ):
229+ last_human_message = message
230+ break
231+
217232 await service_context .thread_service .update (
218233 thread_id = configurable .get ("thread_id" ),
219234 data = {
220235 "thread_id" : configurable .get ("thread_id" ),
221236 "checkpoint_id" : configurable .get ("checkpoint_id" ),
222- "messages" : [messages [ - 1 ] .model_dump ()],
237+ "messages" : [last_human_message .model_dump ()] if last_human_message else [ ],
223238 "files" : files_map ,
224239 "updated_at" : get_time (),
225240 }
0 commit comments