Skip to content

Commit be8019d

Browse files
committed
fix: address PR review feedback (asamal4 + CodeRabbit)
- Revert DEFAULT_LLM_RETRIES from 5 to 3
- Narrow retry codes to (429, 502, 503, 504), exclude 500
- Use RLSAPI native fields (name/args) in _rlsapi_infer_query
- Fix RAG chunk accumulation across multiple mcp_call results
- Redact prompt from debug log, log only metadata
- Add comment about extra_request_params not forwarded to /infer
- Fix tool result capture: use content with status fallback
- Update endpoint_type description to include infer
- Move skip tests from TestFilterByScope to TestDataValidator
- Fix MockerFixture import in test_validator.py
- Fix --metrics filter: handle turn_metrics=None by materializing system defaults before filtering; add conversation_metrics filter
- Add metrics=None to runner test fixture for --metrics support
- Add tests for metrics filter materialization

Signed-off-by: Ellis Low <elow@redhat.com>
1 parent ce00854 commit be8019d

7 files changed

Lines changed: 208 additions & 58 deletions

File tree

src/lightspeed_evaluation/core/api/client.py

Lines changed: 28 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,11 @@
2828

2929

3030
def _is_retryable_server_error(exception: BaseException) -> bool:
31-
"""Check if exception is a retryable HTTP error (429 or 5xx).
31+
"""Check if exception is a retryable HTTP error (429 or transient 5xx).
32+
33+
Only 502 Bad Gateway, 503 Service Unavailable, and 504 Gateway Timeout
34+
are retried. 500 Internal Server Error is excluded as it may indicate
35+
permanent server bugs.
3236
3337
Args:
3438
exception: The exception to check.
@@ -39,7 +43,7 @@ def _is_retryable_server_error(exception: BaseException) -> bool:
3943
if not isinstance(exception, httpx.HTTPStatusError):
4044
return False
4145
status = exception.response.status_code
42-
return status == 429 or 500 <= status < 600
46+
return status in (429, 502, 503, 504)
4347

4448

4549
class APIClient:
@@ -352,6 +356,10 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
352356
raise APIError("HTTP client not initialized")
353357
try:
354358
request_data = api_request.model_dump(exclude_none=True)
359+
# `extra_request_params` are not forwarded to `/infer` — the
360+
# endpoint only accepts `question` and `include_metadata`.
361+
# Other params (model, provider, etc.) are not part of the
362+
# RLSAPI `/infer` API contract.
355363
infer_request: dict[str, object] = {
356364
"question": request_data.pop("query"),
357365
"include_metadata": True,
@@ -361,7 +369,13 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
361369
"RLSAPI infer request URL: /api/lightspeed/%s/infer",
362370
self.config.version,
363371
)
364-
logger.debug("RLSAPI infer request body: %s", infer_request)
372+
logger.debug(
373+
"RLSAPI infer request: version=%s, include_metadata=%s, "
374+
"question_length=%d",
375+
self.config.version,
376+
True,
377+
len(str(infer_request.get("question", ""))),
378+
)
365379

366380
response = self.client.post(
367381
f"/api/lightspeed/{self.config.version}/infer",
@@ -385,12 +399,12 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
385399
response_data["tool_calls"] = data["tool_calls"]
386400
if "tool_results" in data:
387401
tool_results = data["tool_results"]
402+
rag_chunks: list[dict[str, str]] = []
388403
for result in tool_results:
389404
if result.get("type") == "mcp_call":
390405
content = result["content"].split("---")
391-
response_data["rag_chunks"] = [
392-
{"content": chunk} for chunk in content
393-
]
406+
rag_chunks.extend([{"content": chunk} for chunk in content])
407+
response_data["rag_chunks"] = rag_chunks
394408

395409
if "response" not in response_data:
396410
raise APIError("API response missing 'response' field")
@@ -402,16 +416,8 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
402416
for tool_call in raw_tool_calls:
403417
if isinstance(tool_call, dict):
404418
formatted_tool: dict[str, object] = {
405-
"tool_name": (
406-
tool_call.get("tool_name")
407-
or tool_call.get("name")
408-
or ""
409-
),
410-
"arguments": (
411-
tool_call.get("arguments")
412-
or tool_call.get("args")
413-
or {}
414-
),
419+
"tool_name": tool_call.get("name", ""),
420+
"arguments": tool_call.get("args", {}),
415421
}
416422
if "tool_results" in response_data.get("data", {}):
417423
tool_call_id = tool_call.get("id")
@@ -424,7 +430,12 @@ def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
424430
None,
425431
)
426432
if matching_result:
427-
formatted_tool["result"] = matching_result["status"]
433+
formatted_tool["result"] = matching_result.get(
434+
"content", matching_result.get("status", "")
435+
)
436+
formatted_tool["status"] = matching_result.get(
437+
"status", ""
438+
)
428439
formatted_tool_calls.append([formatted_tool])
429440

430441
response_data["tool_calls"] = formatted_tool_calls

src/lightspeed_evaluation/core/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@
7070
DEFAULT_SSL_CERT_FILE = None
7171
DEFAULT_LLM_TEMPERATURE = 0.0
7272
DEFAULT_LLM_MAX_TOKENS = 512
73-
DEFAULT_LLM_RETRIES = 5
73+
DEFAULT_LLM_RETRIES = 3
7474
DEFAULT_LLM_CACHE_DIR = ".caches/llm_cache"
7575

7676
DEFAULT_EMBEDDING_PROVIDER = "openai"

src/lightspeed_evaluation/core/models/system.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ class APIConfig(BaseModel):
271271
)
272272
endpoint_type: str = Field(
273273
default=DEFAULT_ENDPOINT_TYPE,
274-
description="API endpoint type (streaming or query)",
274+
description="API endpoint type (streaming, query, or infer)",
275275
)
276276
timeout: int = Field(
277277
default=DEFAULT_API_TIMEOUT, ge=1, description="Request timeout in seconds"

src/lightspeed_evaluation/core/system/validator.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def __init__(
159159
self.api_enabled = api_enabled
160160
self.original_data_path: Optional[str] = None
161161
self.fail_on_invalid_data = fail_on_invalid_data
162+
self._system_config = system_config
162163
self._turn_level_metrics: set[str] = (
163164
system_config.turn_level_metric_names if system_config else set()
164165
)
@@ -235,15 +236,38 @@ def load_evaluation_data(
235236
# Remove skipped conversations
236237
evaluation_data = [e for e in evaluation_data if not e.skip]
237238

238-
# Filter turn_metrics if --metrics was specified
239+
# Filter turn_metrics and conversation_metrics if --metrics was specified
239240
if metrics:
240241
metrics_set = set(metrics)
241242
for eval_data in evaluation_data:
242243
for turn in eval_data.turns:
243-
if turn.turn_metrics:
244+
if turn.turn_metrics is not None:
244245
turn.turn_metrics = [
245246
m for m in turn.turn_metrics if m in metrics_set
246247
]
248+
elif self._system_config is not None:
249+
turn_defaults = (
250+
self._system_config.default_turn_metrics_metadata
251+
)
252+
turn.turn_metrics = [
253+
m
254+
for m, meta in turn_defaults.items()
255+
if meta.get("default", False) and m in metrics_set
256+
]
257+
258+
if eval_data.conversation_metrics is not None:
259+
eval_data.conversation_metrics = [
260+
m for m in eval_data.conversation_metrics if m in metrics_set
261+
]
262+
elif self._system_config is not None:
263+
conv_defaults = (
264+
self._system_config.default_conversation_metrics_metadata
265+
)
266+
eval_data.conversation_metrics = [
267+
m
268+
for m, meta in conv_defaults.items()
269+
if meta.get("default", False) and m in metrics_set
270+
]
247271

248272
# Semantic validation (metrics availability and requirements)
249273
if not self._validate_evaluation_data(evaluation_data):

tests/unit/core/api/test_client.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -697,7 +697,7 @@ def test_is_retryable_server_error(self, mocker: MockerFixture) -> None:
697697
)
698698

699699
resp_500 = mocker.Mock(status_code=500)
700-
assert _is_retryable_server_error(
700+
assert not _is_retryable_server_error(
701701
httpx.HTTPStatusError("", request=mocker.Mock(), response=resp_500)
702702
)
703703

@@ -814,23 +814,23 @@ def test_query_raises_api_error_after_max_retries(
814814

815815
assert mock_client.post.call_count == 4 # 3 retries + 1 initial attempt
816816

817-
def test_standard_query_retries_on_500_then_succeeds(
817+
def test_standard_query_retries_on_502_then_succeeds(
818818
self, basic_api_config_query_endpoint: APIConfig, mocker: MockerFixture
819819
) -> None:
820-
"""Test standard query retries on 500 error and succeeds on retry."""
821-
mock_response_500 = mocker.Mock(status_code=500, text="Internal server error")
822-
mock_response_500.raise_for_status.side_effect = httpx.HTTPStatusError(
823-
"500 error", request=mocker.Mock(), response=mock_response_500
820+
"""Test standard query retries on 502 error and succeeds on retry."""
821+
mock_response_502 = mocker.Mock(status_code=502, text="Bad gateway")
822+
mock_response_502.raise_for_status.side_effect = httpx.HTTPStatusError(
823+
"502 error", request=mocker.Mock(), response=mock_response_502
824824
)
825825

826826
mock_response_success = mocker.Mock(status_code=200)
827827
mock_response_success.json.return_value = {
828-
"response": "Success after 500 retry",
828+
"response": "Success after 502 retry",
829829
"conversation_id": "conv_123",
830830
}
831831

832832
mock_client = mocker.Mock()
833-
mock_client.post.side_effect = [mock_response_500, mock_response_success]
833+
mock_client.post.side_effect = [mock_response_502, mock_response_success]
834834
mock_client.headers = {}
835835

836836
mocker.patch(
@@ -841,7 +841,7 @@ def test_standard_query_retries_on_500_then_succeeds(
841841
client = APIClient(basic_api_config_query_endpoint)
842842
result = client.query("Test standard query")
843843

844-
assert result.response == "Success after 500 retry"
844+
assert result.response == "Success after 502 retry"
845845
assert mock_client.post.call_count == 2
846846

847847

@@ -903,7 +903,7 @@ def test_infer_query_formats_tool_calls(
903903
"name": "search_documentation",
904904
"args": {"q": "rhel"},
905905
},
906-
{"id": "tc2", "tool_name": "mcp_list_tools", "arguments": {}},
906+
{"id": "tc2", "name": "mcp_list_tools", "args": {}},
907907
],
908908
"tool_results": [
909909
{
@@ -938,9 +938,11 @@ def test_infer_query_formats_tool_calls(
938938
assert isinstance(result.tool_calls[0], list)
939939
assert result.tool_calls[0][0]["tool_name"] == "search_documentation"
940940
assert result.tool_calls[0][0]["arguments"] == {"q": "rhel"}
941-
assert result.tool_calls[0][0]["result"] == "success"
941+
assert result.tool_calls[0][0]["result"] == "result1"
942+
assert result.tool_calls[0][0]["status"] == "success"
942943
assert result.tool_calls[1][0]["tool_name"] == "mcp_list_tools"
943-
assert result.tool_calls[1][0]["result"] == "completed"
944+
assert result.tool_calls[1][0]["result"] == "tools"
945+
assert result.tool_calls[1][0]["status"] == "completed"
944946

945947
def test_infer_query_extracts_rag_chunks(
946948
self, basic_api_config_infer_endpoint: APIConfig, mocker: MockerFixture

0 commit comments

Comments
 (0)