Skip to content
Merged
155 changes: 143 additions & 12 deletions src/lightspeed_evaluation/core/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,23 @@
logger = logging.getLogger(__name__)


def _is_too_many_requests_error(exception: BaseException) -> bool:
    """Return True when *exception* is an HTTP 429 (Too Many Requests) error.

    Used as a tenacity `retry_if_exception` predicate, so it must accept
    any BaseException and answer False for non-HTTP errors.
    """
    if not isinstance(exception, httpx.HTTPStatusError):
        return False
    return exception.response.status_code == 429
# Statuses worth retrying: 429 rate limiting plus transient gateway errors.
# 500 Internal Server Error is deliberately excluded — it may indicate a
# permanent server-side bug rather than a transient condition.
_RETRYABLE_STATUS_CODES = frozenset({429, 502, 503, 504})


def _is_retryable_server_error(exception: BaseException) -> bool:
    """Check if exception is a retryable HTTP error (429 or transient 5xx).

    Only 502 Bad Gateway, 503 Service Unavailable, and 504 Gateway Timeout
    are retried among the 5xx family.

    Args:
        exception: The exception to check.

    Returns:
        True if the exception is a retryable HTTP status error.
    """
    return (
        isinstance(exception, httpx.HTTPStatusError)
        and exception.response.status_code in _RETRYABLE_STATUS_CODES
    )


class APIClient:
Expand All @@ -59,10 +70,11 @@ def __init__(
retry_decorator = self._create_retry_decorator()
self._standard_query_with_retry = retry_decorator(self._standard_query)
self._streaming_query_with_retry = retry_decorator(self._streaming_query)
self._rlsapi_infer_query_with_retry = retry_decorator(self._rlsapi_infer_query)

def _create_retry_decorator(self) -> Any:
return retry(
retry=retry_if_exception(_is_too_many_requests_error),
retry=retry_if_exception(_is_retryable_server_error),
stop=stop_after_attempt(
self.config.num_retries + 1
), # +1 to account for the initial attempt
Expand Down Expand Up @@ -186,6 +198,8 @@ def query(

if self.config.endpoint_type == "streaming":
response = self._streaming_query_with_retry(api_request)
elif self.config.endpoint_type == "infer":
response = self._rlsapi_infer_query_with_retry(api_request)
else:
response = self._standard_query_with_retry(api_request)

Expand All @@ -196,7 +210,7 @@ def query(
except RetryError as e:
raise APIError(
f"Maximum retry attempts ({self.config.num_retries}) reached "
"due to persistent rate limiting (HTTP 429)."
"due to retryable server errors (HTTP 429/5xx)."
) from e

def _prepare_request(
Expand Down Expand Up @@ -285,8 +299,7 @@ def _standard_query(self, api_request: APIRequest) -> APIResponse:
except httpx.TimeoutException as e:
raise self._handle_timeout_error("standard", self.config.timeout) from e
except httpx.HTTPStatusError as e:
# Re-raise 429 errors without conversion to allow retry decorator to handle them
if e.response.status_code == 429:
if _is_retryable_server_error(e):
raise
raise self._handle_http_error(e) from e
except ValueError as e:
Expand All @@ -313,8 +326,7 @@ def _streaming_query(self, api_request: APIRequest) -> APIResponse:
except httpx.TimeoutException as e:
raise self._handle_timeout_error("streaming", self.config.timeout) from e
except httpx.HTTPStatusError as e:
# Re-raise 429 errors without conversion to allow retry decorator to handle them
if e.response.status_code == 429:
if _is_retryable_server_error(e):
raise
raise self._handle_http_error(e) from e
except ValueError as e:
Expand All @@ -324,6 +336,125 @@ def _streaming_query(self, api_request: APIRequest) -> APIResponse:
except Exception as e:
raise self._handle_unexpected_error(e, "streaming query") from e

def _rlsapi_infer_query(self, api_request: APIRequest) -> APIResponse:
    """Query the RLSAPI /infer endpoint for tool call and RAG metadata.

    The infer endpoint uses a different request/response format than the
    standard query/streaming endpoints: the outgoing "query" field is sent
    as "question", and tool calls plus RAG chunks are parsed out of the
    nested "data" payload of the response.

    Args:
        api_request: The prepared API request.

    Returns:
        APIResponse with response text, tool calls, and RAG contexts.

    Raises:
        APIError: If the request fails or the response is invalid.
        httpx.HTTPStatusError: Re-raised unconverted for retryable
            statuses (429/502/503/504) so the retry decorator can act.
    """
    if not self.client:
        raise APIError("HTTP client not initialized")
    try:
        request_data = api_request.model_dump(exclude_none=True)
        # `extra_request_params` are not forwarded to `/infer` — the
        # endpoint only accepts `question` and `include_metadata`.
        # Other params (model, provider, etc.) are not part of the
        # RLSAPI `/infer` API contract.
        infer_request: dict[str, object] = {
            "question": request_data.pop("query"),
            "include_metadata": True,
        }

        logger.debug(
            "RLSAPI infer request URL: /api/lightspeed/%s/infer",
            self.config.version,
        )
        logger.debug(
            "RLSAPI infer request: version=%s, include_metadata=%s, "
            "question_length=%d",
            self.config.version,
            True,
            len(str(infer_request.get("question", ""))),
        )

        response = self.client.post(
            f"/api/lightspeed/{self.config.version}/infer",
            json=infer_request,
        )
        response.raise_for_status()

        response_data = response.json()
        self._flatten_infer_data(response_data)

        if "response" not in response_data:
            raise APIError("API response missing 'response' field")

        self._format_infer_tool_calls(response_data)

        return APIResponse.from_raw_response(response_data)

    except httpx.TimeoutException as e:
        raise self._handle_timeout_error("infer", self.config.timeout) from e
    except httpx.HTTPStatusError as e:
        if _is_retryable_server_error(e):
            raise
        raise self._handle_http_error(e) from e
    except ValueError as e:
        raise self._handle_validation_error(e) from e
    except APIError:
        raise
    except Exception as e:
        raise self._handle_unexpected_error(e, "infer query") from e

def _flatten_infer_data(self, response_data: dict) -> None:
    """Lift fields of the nested `data` payload onto the top level.

    Maps `text` -> `response` and `request_id` -> `conversation_id`,
    copies token counts and raw `tool_calls`, and derives `rag_chunks`
    from `tool_results`. Mutates `response_data` in place; a response
    without a `data` key is left untouched.
    """
    if "data" not in response_data:
        return
    data = response_data["data"]
    if "text" in data:
        response_data["response"] = data["text"]
    if "request_id" in data:
        response_data["conversation_id"] = data["request_id"]
    # Copied verbatim when present.
    for key in ("input_tokens", "output_tokens", "tool_calls"):
        if key in data:
            response_data[key] = data[key]
    if "tool_results" in data:
        response_data["rag_chunks"] = self._extract_rag_chunks(
            data["tool_results"]
        )

def _extract_rag_chunks(self, tool_results: list) -> list[dict[str, str]]:
    """Split `mcp_call` tool-result contents into RAG chunk dicts.

    Chunks are delimited by "---" inside each result's content string.

    Args:
        tool_results: Raw `tool_results` list from the infer response.

    Returns:
        List of `{"content": <chunk>}` dicts, possibly empty.
    """
    rag_chunks: list[dict[str, str]] = []
    for result in tool_results or []:
        if result.get("type") != "mcp_call":
            continue
        content = result.get("content")
        if not content:
            # Defensive fix: a malformed mcp_call result without a
            # "content" key previously raised KeyError; skip it.
            continue
        rag_chunks.extend({"content": chunk} for chunk in content.split("---"))
    return rag_chunks

def _format_infer_tool_calls(self, response_data: dict) -> None:
    """Reshape raw tool calls into the evaluator's expected structure.

    Each raw call becomes `{"tool_name", "arguments"}` enriched with
    `result`/`status` from the matching (same id) entry in
    `data.tool_results` when available. Mutates `response_data` in
    place; no-op when there are no tool calls.
    """
    raw_tool_calls = response_data.get("tool_calls")
    if not raw_tool_calls:
        return
    # Defensive fix: an explicit `tool_results: null` previously caused
    # a TypeError when iterated; treat it the same as "absent".
    tool_results = response_data.get("data", {}).get("tool_results")
    formatted_tool_calls = []
    for tool_call in raw_tool_calls:
        if not isinstance(tool_call, dict):
            # Non-dict entries are silently dropped, as before.
            continue
        formatted_tool: dict[str, object] = {
            "tool_name": tool_call.get("name", ""),
            "arguments": tool_call.get("args", {}),
        }
        if tool_results:
            matching_result = next(
                (
                    r
                    for r in tool_results
                    if r.get("id") == tool_call.get("id")
                ),
                None,
            )
            if matching_result:
                formatted_tool["result"] = matching_result.get(
                    "content", matching_result.get("status", "")
                )
                formatted_tool["status"] = matching_result.get("status", "")
        # Each call is wrapped in its own single-element list —
        # presumably downstream expects per-step groups; confirm.
        formatted_tool_calls.append([formatted_tool])
    response_data["tool_calls"] = formatted_tool_calls

def _handle_response_errors(self, response: httpx.Response) -> None:
"""Handle HTTP response errors for streaming endpoint."""
if response.status_code != 200:
Expand Down
2 changes: 1 addition & 1 deletion src/lightspeed_evaluation/core/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
DEFAULT_API_VERSION = "v1"
DEFAULT_API_TIMEOUT = 300
DEFAULT_ENDPOINT_TYPE = "streaming"
SUPPORTED_ENDPOINT_TYPES = ["streaming", "query"]
SUPPORTED_ENDPOINT_TYPES = ["streaming", "query", "infer"]
DEFAULT_API_CACHE_DIR = ".caches/api_cache"

DEFAULT_API_NUM_RETRIES = 3
Comment thread
Lifto marked this conversation as resolved.
Expand Down
8 changes: 8 additions & 0 deletions src/lightspeed_evaluation/core/models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,14 @@ class EvaluationData(BaseModel):
min_length=1,
description="Tag for grouping and filtering conversations",
)
# Opt-out flag: conversations with skip=True are dropped by the data
# loader before validation and evaluation run.
skip: bool = Field(
    default=False,
    description="Skip this conversation during evaluation",
)
# Free-text rationale for the skip; informational only — the loader
# never reads it. NOTE(review): presumably surfaced in reports/diffs
# for maintainers; confirm intended audience.
skip_reason: Optional[str] = Field(
    default=None,
    description="Why this conversation is skipped (documentation only)",
)

# Conversation-level metrics
conversation_metrics: Optional[list[str]] = Field(
Expand Down
4 changes: 2 additions & 2 deletions src/lightspeed_evaluation/core/models/system.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ class APIConfig(BaseModel):
)
endpoint_type: str = Field(
default=DEFAULT_ENDPOINT_TYPE,
description="API endpoint type (streaming or query)",
description="API endpoint type (streaming, query, or infer)",
)
timeout: int = Field(
default=DEFAULT_API_TIMEOUT, ge=1, description="Request timeout in seconds"
Expand Down Expand Up @@ -301,7 +301,7 @@ class APIConfig(BaseModel):
ge=0,
description=(
"Maximum number of retry attempts for API calls on "
"429 Too Many Requests errors"
"retryable server errors (HTTP 429/5xx)"
),
)

Expand Down
39 changes: 39 additions & 0 deletions src/lightspeed_evaluation/core/system/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ def __init__(
self.api_enabled = api_enabled
self.original_data_path: Optional[str] = None
self.fail_on_invalid_data = fail_on_invalid_data
self._system_config = system_config
self._turn_level_metrics: set[str] = (
system_config.turn_level_metric_names if system_config else set()
)
Expand All @@ -171,6 +172,7 @@ def load_evaluation_data(
data_path: str,
tags: Optional[list[str]] = None,
conv_ids: Optional[list[str]] = None,
metrics: Optional[list[str]] = None,
) -> list[EvaluationData]:
"""Load, filter, and validate evaluation data from YAML file.

Expand All @@ -184,6 +186,7 @@ def load_evaluation_data(
data_path: Path to the evaluation data YAML file
tags: Optional list of tags to filter by
conv_ids: Optional list of conversation group IDs to filter by
metrics: Optional list of metrics to run (filters each turn's turn_metrics)

Returns:
Filtered and validated list of Evaluation Data
Expand Down Expand Up @@ -230,6 +233,42 @@ def load_evaluation_data(
# Filter by scope before validation
evaluation_data = self._filter_by_scope(evaluation_data, tags, conv_ids)

# Remove skipped conversations
evaluation_data = [e for e in evaluation_data if not e.skip]

# Filter turn_metrics and conversation_metrics if --metrics was specified
if metrics:
metrics_set = set(metrics)
for eval_data in evaluation_data:
for turn in eval_data.turns:
if turn.turn_metrics is not None:
turn.turn_metrics = [
m for m in turn.turn_metrics if m in metrics_set
]
elif self._system_config is not None:
turn_defaults = (
self._system_config.default_turn_metrics_metadata
)
turn.turn_metrics = [
m
for m, meta in turn_defaults.items()
if meta.get("default", False) and m in metrics_set
]

if eval_data.conversation_metrics is not None:
eval_data.conversation_metrics = [
m for m in eval_data.conversation_metrics if m in metrics_set
]
elif self._system_config is not None:
conv_defaults = (
self._system_config.default_conversation_metrics_metadata
)
eval_data.conversation_metrics = [
m
for m, meta in conv_defaults.items()
if meta.get("default", False) and m in metrics_set
]

# Semantic validation (metrics availability and requirements)
if not self._validate_evaluation_data(evaluation_data):
raise DataValidationError("Evaluation data validation failed")
Expand Down
7 changes: 7 additions & 0 deletions src/lightspeed_evaluation/runner/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def run_evaluation( # pylint: disable=too-many-locals
eval_args.eval_data,
tags=eval_args.tags,
conv_ids=eval_args.conv_ids,
metrics=eval_args.metrics,
)

print(
Expand Down Expand Up @@ -236,6 +237,12 @@ def main() -> int:
default=None,
help="Filter by conversation group IDs (run only specified conversations)",
)
parser.add_argument(
"--metrics",
nargs="+",
default=None,
help="Filter to only run specified metrics (e.g. custom:answer_correctness)",
)
parser.add_argument(
"--cache-warmup",
action="store_true",
Expand Down
13 changes: 13 additions & 0 deletions tests/unit/core/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def basic_api_config_streaming_endpoint() -> APIConfig:
)


@pytest.fixture
def basic_api_config_infer_endpoint() -> APIConfig:
    """Create test API config for infer endpoint."""
    settings: dict = {
        "enabled": True,
        "api_base": "http://localhost:8080",
        "version": "v1",
        "endpoint_type": "infer",
        "timeout": 30,
        "cache_enabled": False,
    }
    return APIConfig(**settings)


@pytest.fixture
def mock_response(mocker: MockerFixture) -> Any:
"""Create a mock streaming response."""
Expand Down
Loading
Loading