Commit 91b7f17

Merge pull request #917 from crestalnetwork/codex/migrate-langchain-to-version-1.0
Refactor agent middleware for LangChain v1 migration
2 parents: acc2400 + 36ba6d8

File tree: 3 files changed (+285, -309 lines)

intentkit/core/engine.py (115 additions, 101 deletions)

@@ -20,28 +20,29 @@
 
 import sqlalchemy
 from epyxid import XID
-from langchain_core.language_models import BaseChatModel
+from langchain.agents import create_agent as create_langchain_agent
 from langchain_core.messages import (
     BaseMessage,
     HumanMessage,
 )
 from langchain_core.tools import BaseTool
 from langgraph.errors import GraphRecursionError
 from langgraph.graph.state import CompiledStateGraph
-from langgraph.prebuilt import create_react_agent
-from langgraph.runtime import Runtime
 from sqlalchemy import func, update
 from sqlalchemy.exc import SQLAlchemyError
 
 from intentkit.abstracts.graph import AgentContext, AgentError, AgentState
 from intentkit.config.config import config
 from intentkit.core.chat import clear_thread_memory
 from intentkit.core.credit import expense_message, expense_skill
-from intentkit.core.node import PreModelNode, post_model_node
-from intentkit.core.prompt import (
-    create_formatted_prompt_function,
-    explain_prompt,
+from intentkit.core.node import (
+    CreditCheckMiddleware,
+    DynamicPromptMiddleware,
+    SummarizationMiddleware,
+    ToolBindingMiddleware,
+    TrimMessagesMiddleware,
 )
+from intentkit.core.prompt import explain_prompt
 from intentkit.models.agent import Agent, AgentTable
 from intentkit.models.agent_data import AgentData, AgentQuota
 from intentkit.models.app_setting import AppSetting, SystemMessageType
@@ -53,7 +54,7 @@
 )
 from intentkit.models.credit import CreditAccount, OwnerType
 from intentkit.models.db import get_langgraph_checkpointer, get_session
-from intentkit.models.llm import LLMModelInfo, LLMProvider, create_llm_model
+from intentkit.models.llm import LLMModelInfo, create_llm_model
 from intentkit.models.skill import AgentSkillData, ChatSkillData, Skill
 from intentkit.models.user import User
 from intentkit.utils.error import IntentKitAPIError
@@ -158,53 +159,37 @@ async def build_agent(
     tools = list({tool.name: tool for tool in tools}.values())
     private_tools = list({tool.name: tool for tool in private_tools}.values())
 
-    # Create the formatted_prompt function using the refactored prompt module
-    formatted_prompt = create_formatted_prompt_function(agent, agent_data)
-
-    # bind tools to llm
-    async def select_model(
-        state: AgentState, runtime: Runtime[AgentContext]
-    ) -> BaseChatModel:
-        llm_params = {}
-        context = runtime.context
-        if context.search or agent.has_search():
-            if llm_model.info.supports_search:
-                if llm_model.info.provider == LLMProvider.OPENAI:
-                    tools.append({"type": "web_search"})
-                    private_tools.append({"type": "web_search"})
-                    if llm_model.model_name == "gpt-5-mini":
-                        llm_params["reasoning_effort"] = "medium"
-                if llm_model.info.provider == LLMProvider.XAI:
-                    llm_params["search_parameters"] = {"mode": "auto"}
-            # TODO: else use a search skill
-        # build llm now
-        llm = await llm_model.create_instance(llm_params)
-        if context.is_private:
-            return llm.bind_tools(private_tools)
-        return llm.bind_tools(tools)
-
     for tool in private_tools:
         logger.info(
             f"[{agent.id}] loaded tool: {tool.name if isinstance(tool, BaseTool) else tool}"
         )
 
-    # Pre model hook
-    summarize_llm = await create_llm_model(model_name="gpt-5-mini")
-    summarize_model = await summarize_llm.create_instance()
-    pre_model_hook = PreModelNode(
-        model=summarize_model,
-        short_term_memory_strategy=agent.short_term_memory_strategy,
-        max_tokens=llm_model.info.context_length // 2,
-        max_summary_tokens=2048,
-    )
+    base_model = await llm_model.create_instance()
+
+    middleware = [
+        ToolBindingMiddleware(llm_model, tools, private_tools),
+        DynamicPromptMiddleware(agent, agent_data),
+    ]
+
+    if agent.short_term_memory_strategy == "trim":
+        middleware.append(TrimMessagesMiddleware(max_summary_tokens=2048))
+    elif agent.short_term_memory_strategy == "summarize":
+        summarize_llm = await create_llm_model(model_name="gpt-5-mini")
+        summarize_model = await summarize_llm.create_instance()
+        middleware.append(
+            SummarizationMiddleware(
+                model=summarize_model,
+                max_tokens_before_summary=llm_model.info.context_length // 2,
+            )
+        )
+
+    if config.payment_enabled:
+        middleware.append(CreditCheckMiddleware())
 
-    # Create ReAct Agent using the LLM and CDP Agentkit tools.
-    executor = create_react_agent(
-        model=select_model,
+    executor = create_langchain_agent(
+        model=base_model,
         tools=private_tools,
-        prompt=formatted_prompt,
-        pre_model_hook=pre_model_hook,
-        post_model_hook=post_model_node if config.payment_enabled else None,
+        middleware=middleware,
         state_schema=AgentState,
         context_schema=AgentContext,
         checkpointer=memory,
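
Note: the concrete middleware classes used above (ToolBindingMiddleware, DynamicPromptMiddleware, and friends) live in intentkit/core/node.py and are not part of this file's diff. For orientation only, a custom middleware under LangChain v1 is a small class with optional hooks that run around each model call, which is what replaces the old prompt/pre_model_hook/post_model_hook arguments. A minimal hypothetical skeleton, assuming the AgentMiddleware base class from langchain.agents.middleware and hooks that receive the graph state plus runtime and may return a partial state update:

import logging

from langchain.agents.middleware import AgentMiddleware

logger = logging.getLogger(__name__)


class LoggingMiddleware(AgentMiddleware):
    """Illustrative only: log message counts around each model call."""

    def before_model(self, state, runtime):
        # Runs where the old pre_model_hook did; return a partial state
        # update to adjust what the model sees, or None to leave it as is.
        logger.info("calling model with %d messages", len(state["messages"]))
        return None

    def after_model(self, state, runtime):
        # Runs where the old post_model_hook did.
        logger.info("model returned a %s", type(state["messages"][-1]).__name__)
        return None

Such objects are handed to the executor through the middleware= argument, exactly as the hunk above does with the intentkit classes.
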
@@ -521,23 +506,65 @@ def get_agent() -> Agent:
     cached_tool_step = None
     try:
         async for chunk in executor.astream(
-            {"messages": messages}, context=context, config=stream_config
+            {"messages": messages},
+            context=context,
+            config=stream_config,
+            stream_mode=["updates", "custom"],
         ):
             this_time = time.perf_counter()
             logger.debug(f"stream chunk: {chunk}", extra={"thread_id": thread_id})
-            if "agent" in chunk and "messages" in chunk["agent"]:
-                if len(chunk["agent"]["messages"]) != 1:
+
+            if isinstance(chunk, dict) and "credit_check" in chunk:
+                credit_payload = chunk.get("credit_check", {})
+                content = credit_payload.get("message")
+                if content:
+                    credit_message_create = ChatMessageCreate(
+                        id=str(XID()),
+                        agent_id=user_message.agent_id,
+                        chat_id=user_message.chat_id,
+                        user_id=user_message.user_id,
+                        author_id=user_message.agent_id,
+                        author_type=AuthorType.AGENT,
+                        model=agent.model,
+                        thread_type=user_message.author_type,
+                        reply_to=user_message.id,
+                        message=content,
+                        input_tokens=0,
+                        output_tokens=0,
+                        time_cost=this_time - last,
+                    )
+                    last = this_time
+                    credit_message = await credit_message_create.save()
+                    yield credit_message
+
+                error_message_create = await ChatMessageCreate.from_system_message(
+                    SystemMessageType.INSUFFICIENT_BALANCE,
+                    agent_id=user_message.agent_id,
+                    chat_id=user_message.chat_id,
+                    user_id=user_message.user_id,
+                    author_id=user_message.agent_id,
+                    thread_type=user_message.author_type,
+                    reply_to=user_message.id,
+                    time_cost=0,
+                )
+                error_message = await error_message_create.save()
+                yield error_message
+                return
+
+            if not isinstance(chunk, dict):
+                continue
+
+            if "model" in chunk and "messages" in chunk["model"]:
+                if len(chunk["model"]["messages"]) != 1:
                     logger.error(
-                        "unexpected agent message: " + str(chunk["agent"]["messages"]),
+                        "unexpected model message: " + str(chunk["model"]["messages"]),
                        extra={"thread_id": thread_id},
                    )
-                msg = chunk["agent"]["messages"][0]
+                msg = chunk["model"]["messages"][0]
                 if hasattr(msg, "tool_calls") and msg.tool_calls:
-                    # tool calls, save for later use, if it is deleted by post_model_hook, will not be used.
                     cached_tool_step = msg
                 if hasattr(msg, "content") and msg.content:
                     content = _extract_text_content(msg.content)
-                    # agent message
                     chat_message_create = ChatMessageCreate(
                         id=str(XID()),
                         agent_id=user_message.agent_id,
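
Note: stream_mode=["updates", "custom"] adds LangGraph's custom stream channel alongside the per-node updates, which is how a top-level {"credit_check": ...} chunk like the one handled above can reach this loop. One way code running inside the graph could publish such a payload, assuming langgraph.config.get_stream_writer and treating the payload shape as this repo's convention rather than a LangGraph requirement:

from langgraph.config import get_stream_writer


def emit_credit_notice(message: str) -> None:
    """Hypothetical helper: push a user-facing notice onto the custom channel."""
    writer = get_stream_writer()  # only valid while a graph run is in progress
    writer({"credit_check": {"message": message}})
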
@@ -562,16 +589,13 @@ def get_agent() -> Agent:
                         time_cost=this_time - last,
                     )
                     last = this_time
-                    # handle message and payment in one transaction
                     async with get_session() as session:
-                        # payment
                         if payment_enabled:
                             amount = await model.calculate_cost(
                                 chat_message_create.input_tokens,
                                 chat_message_create.output_tokens,
                             )
 
-                            # Check for web_search_call in additional_kwargs
                             if (
                                 hasattr(msg, "additional_kwargs")
                                 and msg.additional_kwargs
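
Note: model.calculate_cost above converts token counts into a credit amount. The pricing logic itself is outside this diff; purely as an illustration of the kind of per-token arithmetic such a method performs, with made-up rates and names:

from decimal import Decimal


def calculate_cost(input_tokens: int, output_tokens: int) -> Decimal:
    """Hypothetical sketch: price per 1K tokens, split by direction."""
    input_price_per_1k = Decimal("0.0005")   # made-up rate
    output_price_per_1k = Decimal("0.0015")  # made-up rate
    return (
        Decimal(input_tokens) / 1000 * input_price_per_1k
        + Decimal(output_tokens) / 1000 * output_price_per_1k
    )
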
@@ -675,10 +699,8 @@ def get_agent() -> Agent:
                         time_cost=this_time - last,
                     )
                     last = this_time
-                    # save message and credit in one transaction
                     async with get_session() as session:
                         if payment_enabled:
-                            # message payment, only first call in a group has message bill
                             if have_first_call_in_cache:
                                 message_amount = await model.calculate_cost(
                                     skill_message_create.input_tokens,
@@ -698,7 +720,6 @@ def get_agent() -> Agent:
                                 skill_message_create.credit_cost = (
                                     message_payment_event.total_amount
                                 )
-                            # skill payment
                             for skill_call in skill_calls:
                                 if not skill_call["success"]:
                                     continue
@@ -723,42 +744,40 @@ def get_agent() -> Agent:
                     skill_message = await skill_message_create.save_in_session(session)
                     await session.commit()
                     yield skill_message
-            elif "pre_model_hook" in chunk:
-                pass
-            elif "post_model_hook" in chunk:
-                logger.debug(
-                    f"post_model_hook: {chunk}",
-                    extra={"thread_id": thread_id},
-                )
-                if chunk["post_model_hook"] and "error" in chunk["post_model_hook"]:
+            else:
+                for node_name, update in chunk.items():
                     if (
-                        chunk["post_model_hook"]["error"]
-                        == AgentError.INSUFFICIENT_CREDITS
+                        node_name.endswith("CreditCheckMiddleware.after_model")
+                        and isinstance(update, dict)
+                        and update.get("error") == AgentError.INSUFFICIENT_CREDITS
                     ):
-                        if "messages" in chunk["post_model_hook"]:
-                            msg = chunk["post_model_hook"]["messages"][-1]
-                            content = msg.content
-                            if isinstance(msg.content, list):
-                                # in new version, content item maybe a list
-                                content = msg.content[0]
-                            post_model_message_create = ChatMessageCreate(
-                                id=str(XID()),
-                                agent_id=user_message.agent_id,
-                                chat_id=user_message.chat_id,
-                                user_id=user_message.user_id,
-                                author_id=user_message.agent_id,
-                                author_type=AuthorType.AGENT,
-                                model=agent.model,
-                                thread_type=user_message.author_type,
-                                reply_to=user_message.id,
-                                message=content,
-                                input_tokens=0,
-                                output_tokens=0,
-                                time_cost=this_time - last,
-                            )
-                            last = this_time
-                            post_model_message = await post_model_message_create.save()
-                            yield post_model_message
+                        ai_messages = [
+                            message
+                            for message in update.get("messages", [])
+                            if isinstance(message, BaseMessage)
+                        ]
+                        content = ""
+                        if ai_messages:
+                            content = _extract_text_content(ai_messages[-1].content)
+                        post_model_message_create = ChatMessageCreate(
+                            id=str(XID()),
+                            agent_id=user_message.agent_id,
+                            chat_id=user_message.chat_id,
+                            user_id=user_message.user_id,
+                            author_id=user_message.agent_id,
+                            author_type=AuthorType.AGENT,
+                            model=agent.model,
+                            thread_type=user_message.author_type,
+                            reply_to=user_message.id,
+                            message=content,
+                            input_tokens=0,
+                            output_tokens=0,
+                            time_cost=this_time - last,
+                        )
+                        last = this_time
+                        post_model_message = await post_model_message_create.save()
+                        yield post_model_message
+
                         error_message_create = (
                             await ChatMessageCreate.from_system_message(
                                 SystemMessageType.INSUFFICIENT_BALANCE,
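
Note: the node name matched here ("...CreditCheckMiddleware.after_model") is how middleware hooks show up in the "updates" stream. The real CreditCheckMiddleware lives in intentkit/core/node.py and is not shown in this diff; purely as a hypothetical sketch of the shape this handler implies, assuming the AgentMiddleware after_model hook described earlier, an AgentState that carries an error field, and an invented balance check on the runtime context:

from intentkit.abstracts.graph import AgentError
from langchain.agents.middleware import AgentMiddleware


class CreditCheckMiddleware(AgentMiddleware):
    """Hypothetical sketch, not the intentkit.core.node implementation."""

    def after_model(self, state, runtime):
        # `has_sufficient_balance` is an invented helper for illustration.
        if runtime.context.has_sufficient_balance():
            return None
        # Flagging the state lets the engine's "updates" handler above react.
        return {"error": AgentError.INSUFFICIENT_CREDITS}
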
@@ -773,12 +792,7 @@ def get_agent() -> Agent:
                         )
                         error_message = await error_message_create.save()
                         yield error_message
-            else:
-                error_traceback = traceback.format_exc()
-                logger.error(
-                    f"unexpected message type: {str(chunk)}\n{error_traceback}",
-                    extra={"thread_id": thread_id},
-                )
+                        return
     except SQLAlchemyError as e:
         error_traceback = traceback.format_exc()
         logger.error(
