22import pytest
33import os
44
5- from agents .triage_agent import run_workflow , TriageState
5+ from agents .triage_agent import run_workflow , TriageState , create_triage_agent
6+ from agents .metrics_middleware import MetricsMiddleware
67from agents .observability import setup_observability
78from common .models import TriageOutputSchema , Resolution , BackportData
89
@@ -14,8 +15,14 @@ def __init__(self, input, expected_output):
1415 self .metrics : dict = None
1516
1617 async def run (self ) -> TriageState :
17- return await run_workflow (self .input , False )
18-
18+ metrics_middleware = MetricsMiddleware ()
19+ def testing_factory (gateway_tools ):
20+ triage_agent = create_triage_agent (gateway_tools )
21+ triage_agent .middlewares .append (metrics_middleware )
22+ return triage_agent
23+ finished_state = await run_workflow (self .input , False , testing_factory )
24+ self .metrics = metrics_middleware .get_metrics ()
25+ return finished_state
1926
2027test_cases = [
2128 TriageAgentTestCase (input = "RHEL-15216" ,
@@ -88,5 +95,4 @@ def verify_result(real_output: TriageOutputSchema, expected_output: TriageOutput
8895 assert real_output .data .fix_version == expected_output .data .fix_version
8996
9097 finished_state = await test_case .run ()
91- test_case .metrics = finished_state .metrics
9298 verify_result (finished_state .triage_result , test_case .expected_output )
0 commit comments