1717
1818from parlant .api .agents import AgentIdPath
1919from parlant .api .common import JSONSerializableDTO , ToolNameField , apigen_config
20+ from parlant .api .context_variables import ContextVariableIdPath
2021from parlant .api .customers import CustomerIdPath
21- from parlant .api .sessions import EventCorrelationIdField , EventIdPath , EventKindDTO , EventSourceDTO
22+ from parlant .api .glossary import TermIdPath
23+ from parlant .api .journeys import JourneyIdPath
24+ from parlant .api .sessions import (
25+ EventCorrelationIdField ,
26+ EventIdPath ,
27+ EventKindDTO ,
28+ EventSourceDTO ,
29+ )
2230from parlant .core .agents import AgentStore
2331from parlant .core .common import DefaultBaseModel
24- from parlant .core .context_variables import ContextVariable , ContextVariableValue
32+ from parlant .core .context_variables import (
33+ ContextVariable ,
34+ ContextVariableStore ,
35+ ContextVariableValue ,
36+ )
2537from parlant .core .customers import CustomerStore
2638from parlant .core .emissions import EmittedEvent
27- from parlant .core .engines .alpha .guideline_matching .guideline_match import GuidelineMatch
39+ from parlant .core .engines .alpha .guideline_matching .guideline_match import (
40+ GuidelineMatch ,
41+ GuidelineMatchDTO ,
42+ )
2843from parlant .core .engines .alpha .tool_calling .tool_caller import ToolCall , ToolCaller
29- from parlant .core .glossary import Term
44+ from parlant .core .glossary import GlossaryStore , Term
45+ from parlant .core .guidelines import GuidelineStore
46+ from parlant .core .journeys import Journey , JourneyStore
47+ from parlant .core .loggers import Logger
3048from parlant .core .sessions import Event , EventKind , EventSource , SessionStore
31- from parlant .core .tools import Tool , ToolContext , ToolId , ToolService
49+ from parlant .core .tools import Tool , ToolContext , ToolId
50+ from parlant .core .services .tools .service_registry import ServiceRegistry
3251
3352
3453API_GROUP = "engine_test"
@@ -69,17 +88,39 @@ class EmittedEventDTO(DefaultBaseModel):
6988class ToolCallInferenceParamsDTO (DefaultBaseModel ):
7089 agent_id : AgentIdPath
7190 customer_id : CustomerIdPath
72- events : Sequence [EventIdPath ]
91+ context_variables : Sequence [ContextVariableIdPath ]
92+ interaction_history : Sequence [EventIdPath ]
93+ terms : Sequence [TermIdPath ]
94+ ordinary_guideline_matches : Sequence [GuidelineMatchDTO ]
95+ tool_enabled_guideline_matches : Sequence [GuidelineMatchDTO ]
96+ journeys : Sequence [JourneyIdPath ]
7397 staged_events : Sequence [EmittedEventDTO ]
7498 available_tools : Sequence [ToolNameField ]
7599
76100
101+ async def _convert_guideline_match (
102+ guideline_store : GuidelineStore ,
103+ dto : GuidelineMatchDTO ,
104+ ) -> GuidelineMatch :
105+ return GuidelineMatch (
106+ await guideline_store .read_guideline (dto .guideline_id ),
107+ dto .score ,
108+ dto .rationale ,
109+ dto .guideline_previously_applied ,
110+ )
111+
112+
77113def create_test_tool_call_inference_router (
78- tool_caller : ToolCaller ,
79114 agent_store : AgentStore ,
80115 customer_store : CustomerStore ,
116+ context_variable_store : ContextVariableStore ,
81117 session_store : SessionStore ,
82- tool_service : ToolService ,
118+ glossary_store : GlossaryStore ,
119+ guideline_store : GuidelineStore ,
120+ service_registry : ServiceRegistry ,
121+ journey_store : JourneyStore ,
122+ tool_caller : ToolCaller ,
123+ logger : Logger ,
83124) -> APIRouter :
84125 test_router = APIRouter ()
85126
@@ -104,6 +145,13 @@ async def infer_tool_calls(
104145 agent = await agent_store .read_agent (params .agent_id )
105146 customer = await customer_store .read_customer (params .customer_id )
106147
148+ context_variables : list [tuple [ContextVariable , ContextVariableValue ]] = []
149+ for context_variable_id in params .context_variables :
150+ context_variable = await context_variable_store .read_variable (context_variable_id )
151+ context_variable_values = await context_variable_store .list_values (context_variable_id )
152+ for _ , context_variable_value in context_variable_values :
153+ context_variables .append ((context_variable , context_variable_value ))
154+
107155 sessions = await session_store .list_sessions (params .agent_id , params .customer_id )
108156 if len (sessions ) == 0 :
109157 raise HTTPException (
@@ -113,16 +161,16 @@ async def infer_tool_calls(
113161
114162 session_id = sessions [0 ].id
115163 events : list [Event ] = []
116- for event_id in params .events :
164+ for event_id in params .interaction_history :
117165 event = await session_store .read_event (session_id , event_id )
118166 events .append (event )
119167
120168 staged_events : list [EmittedEvent ] = []
121169 for staged_event in params .staged_events :
122170 staged_events .append (
123171 EmittedEvent (
124- source = EventSource (staged_event .source ),
125- kind = EventKind (staged_event .kind ),
172+ source = EventSource (staged_event .source . value ),
173+ kind = EventKind (staged_event .kind . value ),
126174 correlation_id = staged_event .correlation_id ,
127175 data = staged_event .data ,
128176 )
@@ -133,28 +181,49 @@ async def infer_tool_calls(
133181 for tool_id_str in params .available_tools :
134182 tool_id = ToolId .from_string (tool_id_str )
135183 tool_ids .append (tool_id )
136- tool = await tool_service .read_tool (tool_id .tool_name )
137- available_tools .append (tool )
138184
139- context_variables : list [tuple [ContextVariable , ContextVariableValue ]] = []
140- terms : list [Term ] = []
141- ordinary_guideline_matches : list [GuidelineMatch ] = []
142- tool_enabled_guideline_matches : dict [GuidelineMatch , list [ToolId ]] = {}
185+ try :
186+ tool_service = await service_registry .read_tool_service (tool_id .service_name )
187+ tool = await tool_service .read_tool (tool_id .tool_name )
188+ available_tools .append (tool )
189+ except Exception as e :
190+ raise HTTPException (
191+ status_code = status .HTTP_404_NOT_FOUND ,
192+ detail = f"Tool '{ tool_id .tool_name } ' not found in service '{ tool_id .service_name } ': { str (e )} " ,
193+ )
194+
195+ terms : list [Term ] = [await glossary_store .read_term (term_id ) for term_id in params .terms ]
196+
197+ ordinary_guideline_matches : list [GuidelineMatch ] = [
198+ await _convert_guideline_match (guideline_store , match_dto )
199+ for match_dto in params .ordinary_guideline_matches
200+ ]
201+
202+ tool_enabled_guideline_matches : dict [GuidelineMatch , list [ToolId ]] = {
203+ await _convert_guideline_match (guideline_store , match_dto ): [
204+ ToolId .from_string (id ) for id in (match_dto .associated_tool_ids or [])
205+ ]
206+ for match_dto in params .tool_enabled_guideline_matches
207+ }
143208
144209 tool_context = ToolContext (
145210 agent_id = agent .id ,
146211 customer_id = customer .id ,
147212 session_id = session_id ,
148213 )
149214
215+ journeys : list [Journey ] = [
216+ await journey_store .read_journey (journey_id ) for journey_id in params .journeys
217+ ]
218+
150219 tool_call_inference_result = await tool_caller .infer_tool_calls (
151220 agent = agent ,
152221 context_variables = context_variables ,
153222 interaction_history = events ,
154223 terms = terms ,
155224 ordinary_guideline_matches = ordinary_guideline_matches ,
156225 tool_enabled_guideline_matches = tool_enabled_guideline_matches ,
157- journeys = [] ,
226+ journeys = journeys ,
158227 staged_events = staged_events ,
159228 tool_context = tool_context ,
160229 )
0 commit comments