Skip to content

Commit 3ace963

Browse files
committed
fix(test): integrate configurable endpoints with test
test(test-apis): fixed basic e2e test for test-apis Signed-off-by: NathanD <relvox@gmail.com>
1 parent b868a1d commit 3ace963

9 files changed

Lines changed: 520 additions & 437 deletions

File tree

src/parlant/api/app.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@
5959
from parlant.core.engines.alpha.tool_calling.tool_caller import ToolCaller
6060
from parlant.core.evaluations import EvaluationStore, EvaluationListener
6161
from parlant.core.journeys import JourneyStore
62-
from parlant.core.tools import LocalToolService
6362
from parlant.core.utterances import UtteranceStore
6463
from parlant.core.relationships import RelationshipStore
6564
from parlant.core.guidelines import GuidelineStore
@@ -449,24 +448,29 @@ async def configure_test_router(
449448
) -> AsyncIterator[FastAPI]:
450449
test_router_guideline_matching = (
451450
guideline_matcher_test_api.create_test_guideline_matching_router(
452-
guideline_matcher=container[GuidelineMatcher],
453451
agent_store=container[AgentStore],
454452
customer_store=container[CustomerStore],
455453
context_variable_store=container[ContextVariableStore],
456-
guideline_store=container[GuidelineStore],
457-
glossary_store=container[GlossaryStore],
458454
session_store=container[SessionStore],
455+
glossary_store=container[GlossaryStore],
456+
guideline_store=container[GuidelineStore],
457+
guideline_matcher=container[GuidelineMatcher],
459458
)
460459
)
461460
app.include_router(test_router_guideline_matching, prefix="/test/alpha/guideline-matching")
462461

463462
test_router_tool_call_inference = (
464463
tool_call_inference_test_api.create_test_tool_call_inference_router(
465-
tool_caller=container[ToolCaller],
466464
agent_store=container[AgentStore],
467465
customer_store=container[CustomerStore],
466+
context_variable_store=container[ContextVariableStore],
468467
session_store=container[SessionStore],
469-
tool_service=container[LocalToolService],
468+
glossary_store=container[GlossaryStore],
469+
guideline_store=container[GuidelineStore],
470+
service_registry=container[ServiceRegistry],
471+
journey_store=container[JourneyStore],
472+
tool_caller=container[ToolCaller],
473+
logger=container[Logger],
470474
)
471475
)
472476
app.include_router(test_router_tool_call_inference, prefix="/test/alpha/tool-call-inference")

src/parlant/api/guideline_matcher_test_api.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -81,20 +81,20 @@ class GuidelineMatchingParamsDTO(DefaultBaseModel):
8181
agent_id: AgentIdPath
8282
customer_id: CustomerIdPath
8383
context_variables: Sequence[ContextVariableIdPath]
84-
events: Sequence[EventIdPath]
84+
interaction_history: Sequence[EventIdPath]
8585
terms: Sequence[TermIdPath]
8686
staged_events: Sequence[EmittedEventDTO]
8787
guidelines: Sequence[GuidelineIdPath]
8888

8989

9090
def create_test_guideline_matching_router(
91-
guideline_matcher: GuidelineMatcher,
9291
agent_store: AgentStore,
9392
customer_store: CustomerStore,
9493
context_variable_store: ContextVariableStore,
95-
guideline_store: GuidelineStore,
96-
glossary_store: GlossaryStore,
9794
session_store: SessionStore,
95+
glossary_store: GlossaryStore,
96+
guideline_store: GuidelineStore,
97+
guideline_matcher: GuidelineMatcher,
9898
) -> APIRouter:
9999
test_router = APIRouter()
100100

@@ -135,15 +135,11 @@ async def match_guidelines(
135135

136136
session_id = sessions[0].id
137137
events: list[Event] = []
138-
for event_id in params.events:
138+
for event_id in params.interaction_history:
139139
event = await session_store.read_event(session_id, event_id)
140140
events.append(event)
141141

142-
terms: list[Term] = []
143-
for term_id in params.terms:
144-
if glossary_store:
145-
term = await glossary_store.read_term(term_id)
146-
terms.append(term)
142+
terms: list[Term] = [await glossary_store.read_term(term_id) for term_id in params.terms]
147143

148144
staged_events: list[EmittedEvent] = []
149145
for staged_event in params.staged_events:
@@ -163,6 +159,7 @@ async def match_guidelines(
163159

164160
guideline_matching_result = await guideline_matcher.match_guidelines(
165161
agent=agent,
162+
session=sessions[0],
166163
customer=customer,
167164
context_variables=context_variables,
168165
interaction_history=events,
@@ -192,8 +189,6 @@ async def match_guidelines(
192189
score=match.score,
193190
rationale=match.rationale,
194191
guideline_previously_applied=match.guideline_previously_applied,
195-
guideline_is_continuous=match.guideline_is_continuous,
196-
should_reapply=match.should_reapply,
197192
)
198193
for match in batch
199194
]

src/parlant/api/tool_call_inference_test_api.py

Lines changed: 87 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,37 @@
1717

1818
from parlant.api.agents import AgentIdPath
1919
from parlant.api.common import JSONSerializableDTO, ToolNameField, apigen_config
20+
from parlant.api.context_variables import ContextVariableIdPath
2021
from parlant.api.customers import CustomerIdPath
21-
from parlant.api.sessions import EventCorrelationIdField, EventIdPath, EventKindDTO, EventSourceDTO
22+
from parlant.api.glossary import TermIdPath
23+
from parlant.api.journeys import JourneyIdPath
24+
from parlant.api.sessions import (
25+
EventCorrelationIdField,
26+
EventIdPath,
27+
EventKindDTO,
28+
EventSourceDTO,
29+
)
2230
from parlant.core.agents import AgentStore
2331
from parlant.core.common import DefaultBaseModel
24-
from parlant.core.context_variables import ContextVariable, ContextVariableValue
32+
from parlant.core.context_variables import (
33+
ContextVariable,
34+
ContextVariableStore,
35+
ContextVariableValue,
36+
)
2537
from parlant.core.customers import CustomerStore
2638
from parlant.core.emissions import EmittedEvent
27-
from parlant.core.engines.alpha.guideline_matching.guideline_match import GuidelineMatch
39+
from parlant.core.engines.alpha.guideline_matching.guideline_match import (
40+
GuidelineMatch,
41+
GuidelineMatchDTO,
42+
)
2843
from parlant.core.engines.alpha.tool_calling.tool_caller import ToolCall, ToolCaller
29-
from parlant.core.glossary import Term
44+
from parlant.core.glossary import GlossaryStore, Term
45+
from parlant.core.guidelines import GuidelineStore
46+
from parlant.core.journeys import Journey, JourneyStore
47+
from parlant.core.loggers import Logger
3048
from parlant.core.sessions import Event, EventKind, EventSource, SessionStore
31-
from parlant.core.tools import Tool, ToolContext, ToolId, ToolService
49+
from parlant.core.tools import Tool, ToolContext, ToolId
50+
from parlant.core.services.tools.service_registry import ServiceRegistry
3251

3352

3453
API_GROUP = "engine_test"
@@ -69,17 +88,39 @@ class EmittedEventDTO(DefaultBaseModel):
6988
class ToolCallInferenceParamsDTO(DefaultBaseModel):
7089
agent_id: AgentIdPath
7190
customer_id: CustomerIdPath
72-
events: Sequence[EventIdPath]
91+
context_variables: Sequence[ContextVariableIdPath]
92+
interaction_history: Sequence[EventIdPath]
93+
terms: Sequence[TermIdPath]
94+
ordinary_guideline_matches: Sequence[GuidelineMatchDTO]
95+
tool_enabled_guideline_matches: Sequence[GuidelineMatchDTO]
96+
journeys: Sequence[JourneyIdPath]
7397
staged_events: Sequence[EmittedEventDTO]
7498
available_tools: Sequence[ToolNameField]
7599

76100

101+
async def _convert_guideline_match(
102+
guideline_store: GuidelineStore,
103+
dto: GuidelineMatchDTO,
104+
) -> GuidelineMatch:
105+
return GuidelineMatch(
106+
await guideline_store.read_guideline(dto.guideline_id),
107+
dto.score,
108+
dto.rationale,
109+
dto.guideline_previously_applied,
110+
)
111+
112+
77113
def create_test_tool_call_inference_router(
78-
tool_caller: ToolCaller,
79114
agent_store: AgentStore,
80115
customer_store: CustomerStore,
116+
context_variable_store: ContextVariableStore,
81117
session_store: SessionStore,
82-
tool_service: ToolService,
118+
glossary_store: GlossaryStore,
119+
guideline_store: GuidelineStore,
120+
service_registry: ServiceRegistry,
121+
journey_store: JourneyStore,
122+
tool_caller: ToolCaller,
123+
logger: Logger,
83124
) -> APIRouter:
84125
test_router = APIRouter()
85126

@@ -104,6 +145,13 @@ async def infer_tool_calls(
104145
agent = await agent_store.read_agent(params.agent_id)
105146
customer = await customer_store.read_customer(params.customer_id)
106147

148+
context_variables: list[tuple[ContextVariable, ContextVariableValue]] = []
149+
for context_variable_id in params.context_variables:
150+
context_variable = await context_variable_store.read_variable(context_variable_id)
151+
context_variable_values = await context_variable_store.list_values(context_variable_id)
152+
for _, context_variable_value in context_variable_values:
153+
context_variables.append((context_variable, context_variable_value))
154+
107155
sessions = await session_store.list_sessions(params.agent_id, params.customer_id)
108156
if len(sessions) == 0:
109157
raise HTTPException(
@@ -113,16 +161,16 @@ async def infer_tool_calls(
113161

114162
session_id = sessions[0].id
115163
events: list[Event] = []
116-
for event_id in params.events:
164+
for event_id in params.interaction_history:
117165
event = await session_store.read_event(session_id, event_id)
118166
events.append(event)
119167

120168
staged_events: list[EmittedEvent] = []
121169
for staged_event in params.staged_events:
122170
staged_events.append(
123171
EmittedEvent(
124-
source=EventSource(staged_event.source),
125-
kind=EventKind(staged_event.kind),
172+
source=EventSource(staged_event.source.value),
173+
kind=EventKind(staged_event.kind.value),
126174
correlation_id=staged_event.correlation_id,
127175
data=staged_event.data,
128176
)
@@ -133,28 +181,49 @@ async def infer_tool_calls(
133181
for tool_id_str in params.available_tools:
134182
tool_id = ToolId.from_string(tool_id_str)
135183
tool_ids.append(tool_id)
136-
tool = await tool_service.read_tool(tool_id.tool_name)
137-
available_tools.append(tool)
138184

139-
context_variables: list[tuple[ContextVariable, ContextVariableValue]] = []
140-
terms: list[Term] = []
141-
ordinary_guideline_matches: list[GuidelineMatch] = []
142-
tool_enabled_guideline_matches: dict[GuidelineMatch, list[ToolId]] = {}
185+
try:
186+
tool_service = await service_registry.read_tool_service(tool_id.service_name)
187+
tool = await tool_service.read_tool(tool_id.tool_name)
188+
available_tools.append(tool)
189+
except Exception as e:
190+
raise HTTPException(
191+
status_code=status.HTTP_404_NOT_FOUND,
192+
detail=f"Tool '{tool_id.tool_name}' not found in service '{tool_id.service_name}': {str(e)}",
193+
)
194+
195+
terms: list[Term] = [await glossary_store.read_term(term_id) for term_id in params.terms]
196+
197+
ordinary_guideline_matches: list[GuidelineMatch] = [
198+
await _convert_guideline_match(guideline_store, match_dto)
199+
for match_dto in params.ordinary_guideline_matches
200+
]
201+
202+
tool_enabled_guideline_matches: dict[GuidelineMatch, list[ToolId]] = {
203+
await _convert_guideline_match(guideline_store, match_dto): [
204+
ToolId.from_string(id) for id in (match_dto.associated_tool_ids or [])
205+
]
206+
for match_dto in params.tool_enabled_guideline_matches
207+
}
143208

144209
tool_context = ToolContext(
145210
agent_id=agent.id,
146211
customer_id=customer.id,
147212
session_id=session_id,
148213
)
149214

215+
journeys: list[Journey] = [
216+
await journey_store.read_journey(journey_id) for journey_id in params.journeys
217+
]
218+
150219
tool_call_inference_result = await tool_caller.infer_tool_calls(
151220
agent=agent,
152221
context_variables=context_variables,
153222
interaction_history=events,
154223
terms=terms,
155224
ordinary_guideline_matches=ordinary_guideline_matches,
156225
tool_enabled_guideline_matches=tool_enabled_guideline_matches,
157-
journeys=[],
226+
journeys=journeys,
158227
staged_events=staged_events,
159228
tool_context=tool_context,
160229
)

src/parlant/core/engines/alpha/guideline_matching/guideline_match.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,14 @@
1616

1717
from dataclasses import dataclass
1818
from enum import Enum
19+
from typing import Annotated, Optional, TypeAlias
1920

21+
from pydantic import Field
22+
23+
24+
from parlant.api.common import GuidelineIdField
25+
from parlant.api.sessions import GuidelineMatchRationaleField, GuidelineMatchScoreField, ToolIdField
26+
from parlant.core.common import DefaultBaseModel
2027
from parlant.core.guidelines import Guideline
2128

2229

@@ -39,3 +46,56 @@ class GuidelineMatch:
3946
class AnalyzedGuideline:
4047
guideline: Guideline
4148
is_previously_applied: bool
49+
50+
51+
GuidelinePreviouslyAppliedField: TypeAlias = Annotated[
52+
PreviouslyAppliedType,
53+
Field(
54+
default=PreviouslyAppliedType.NO,
55+
description="Status of the guideline's previous application.",
56+
examples=[
57+
PreviouslyAppliedType.NO,
58+
PreviouslyAppliedType.PARTIALLY,
59+
PreviouslyAppliedType.FULLY,
60+
PreviouslyAppliedType.IRRELEVANT,
61+
],
62+
),
63+
]
64+
65+
GuidelineIsContinuousField: TypeAlias = Annotated[
66+
bool,
67+
Field(
68+
default=False,
69+
description="Whether the guideline is continuous.",
70+
examples=[True, False],
71+
),
72+
]
73+
74+
GuidelineShouldReapplyField: TypeAlias = Annotated[
75+
bool,
76+
Field(
77+
default=False,
78+
description="Whether the guideline should be reapplied.",
79+
examples=[True, False],
80+
),
81+
]
82+
83+
guideline_match_dto_example = {
84+
"guideline_id": "123xyz",
85+
"score": "9",
86+
"rationale": "customer asks for weather today",
87+
"guideline_previously_applied": "no",
88+
"guideline_is_continuous": "False",
89+
"should_reapply": "False",
90+
}
91+
92+
93+
class GuidelineMatchDTO(
94+
DefaultBaseModel,
95+
json_schema_extra={"example": guideline_match_dto_example},
96+
):
97+
guideline_id: GuidelineIdField
98+
score: GuidelineMatchScoreField
99+
rationale: GuidelineMatchRationaleField
100+
guideline_previously_applied: GuidelinePreviouslyAppliedField
101+
associated_tool_ids: Optional[list[ToolIdField]] = None

0 commit comments

Comments
 (0)