Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions python/aibrix/aibrix/metadata/api/v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,61 @@
# Top-level JSON fields required in every batch input line (presumably
# enforced by the per-line validation elsewhere in this module).
REQUIRED_FIELDS = ["custom_id", "method", "url", "body"]
# HTTP methods accepted in a batch input line's "method" field.
VALID_HTTP_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH"}

# Endpoint-specific required body fields
# These define the minimum required fields in the request body for each endpoint type.
# NOTE: keys are the canonical endpoint paths (e.g. "/v1/chat/completions") and
# lookups against this mapping are exact-match — callers should pass the
# canonical path, not a raw per-line URL that may carry a host, query string,
# or trailing slash.
ENDPOINT_REQUIRED_BODY_FIELDS: dict[str, list[str]] = {
    BatchJobEndpoint.CHAT_COMPLETIONS.value: ["model", "messages"],
    BatchJobEndpoint.COMPLETIONS.value: ["model", "prompt"],
    BatchJobEndpoint.EMBEDDINGS.value: ["model", "input"],
    BatchJobEndpoint.RERANK.value: ["model", "query", "documents"],
}


def _validate_request_body_for_endpoint(
    body: dict, endpoint: str, line_num: int
) -> Optional[str]:
    """Validate request body fields are appropriate for the given endpoint.

    The ``endpoint`` value may come straight from a batch input line's
    ``url`` field, so it is normalized before lookup: any scheme/host
    prefix, query string, and fragment are dropped and a trailing slash is
    stripped. Without this, a benign or malicious variation such as
    ``/v1/embeddings?x=1``, ``/v1/rerank/``, or a full URL like
    ``http://host/v1/chat/completions`` would fail the exact-match lookup
    and silently bypass body validation entirely.

    Args:
        body: The request body dictionary
        endpoint: The API endpoint string (e.g., "/v1/chat/completions")
        line_num: Line number for error reporting

    Returns:
        Error message string if validation fails, None if valid
    """
    # Local import keeps the module's top-of-file import block untouched.
    from urllib.parse import urlsplit

    # Reduce a possibly-full URL to its canonical path form for lookup.
    path = urlsplit(endpoint).path
    normalized = path.rstrip("/") or path

    required_fields = ENDPOINT_REQUIRED_BODY_FIELDS.get(normalized)
    if required_fields is None:
        # Unknown endpoint, skip body validation
        return None

    # Presence check: every endpoint-specific required field must exist.
    for field in required_fields:
        if field not in body:
            return (
                f"Line {line_num}: Request body for endpoint '{endpoint}' "
                f"is missing required field '{field}'"
            )

    # Endpoint-specific type validation. Error messages say "array" on
    # purpose: batch input lines are JSON, and the existing test suite
    # asserts on that wording.
    if normalized == BatchJobEndpoint.CHAT_COMPLETIONS.value:
        if not isinstance(body.get("messages"), list):
            return f"Line {line_num}: 'messages' must be an array for {endpoint}"
    elif normalized == BatchJobEndpoint.COMPLETIONS.value:
        prompt = body.get("prompt")
        if not isinstance(prompt, (str, list)):
            return f"Line {line_num}: 'prompt' must be a string or array for {endpoint}"
    elif normalized == BatchJobEndpoint.EMBEDDINGS.value:
        input_val = body.get("input")
        if not isinstance(input_val, (str, list)):
            return f"Line {line_num}: 'input' must be a string or array for {endpoint}"
    elif normalized == BatchJobEndpoint.RERANK.value:
        if not isinstance(body.get("query"), str):
            return f"Line {line_num}: 'query' must be a string for {endpoint}"
        if not isinstance(body.get("documents"), list):
            return f"Line {line_num}: 'documents' must be an array for {endpoint}"

    return None


async def _validate_batch_input_file(
storage: BaseStorage, file_id: str, endpoint: str
Expand Down Expand Up @@ -136,6 +191,13 @@ async def _validate_batch_input_file(
f"batch endpoint '{endpoint}'",
)

# Validate request body has required fields for the endpoint
body_error = _validate_request_body_for_endpoint(
request["body"], request_url, line_num
)
Comment on lines +196 to +197
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-medium medium

The _validate_batch_input_file function passes the untrusted request_url from the input file to _validate_request_body_for_endpoint instead of using the validated endpoint from the batch request. Since _validate_batch_input_file only checks if request_url ends with the expected endpoint (line 187), an attacker can provide a full URL (e.g., http://attacker.com/v1/chat/completions) that passes the check but fails to match any key in ENDPOINT_REQUIRED_BODY_FIELDS (which contains only paths like /v1/chat/completions). This causes _validate_request_body_for_endpoint to skip validation (line 74), allowing malformed request bodies to bypass the early validation check.

            body_error = _validate_request_body_for_endpoint(
                request["body"], endpoint, line_num
            )

if body_error:
return 0, body_error

# Check if file was empty
if request_count == 0:
return 0, "Batch input file is empty or contains only empty lines"
Expand Down
200 changes: 199 additions & 1 deletion python/aibrix/tests/batch/test_batch_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for batch API endpoint support."""
"""Unit tests for batch API endpoint support and body validation."""

import pytest

from aibrix.batch.job_entity import BatchJobEndpoint
from aibrix.metadata.api.v1.batch import _validate_request_body_for_endpoint


Comment on lines +20 to 22
Copy link

Copilot AI Mar 4, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests are importing and depending on a private function (_validate_request_body_for_endpoint). This increases coupling and makes refactors harder (renames/moves break tests even if behavior is unchanged). Consider promoting this validator to a non-underscored helper (or moving it into a small shared validation module) and importing that public surface from tests.

Suggested change
from aibrix.metadata.api.v1.batch import _validate_request_body_for_endpoint
import aibrix.metadata.api.v1.batch as batch_module
_validate_request_body_for_endpoint = batch_module._validate_request_body_for_endpoint

Copilot uses AI. Check for mistakes.
def test_chat_completions_endpoint_supported():
Expand Down Expand Up @@ -84,3 +85,200 @@ def test_endpoint_count():
"""Test that we have exactly 4 supported endpoints."""
endpoints = list(BatchJobEndpoint)
assert len(endpoints) == 4


# ---- Endpoint-specific body validation tests ----


class TestEndpointBodyValidation:
    """Tests for endpoint-specific request body validation."""

    @staticmethod
    def _validate(payload, endpoint):
        """Run the validator against *payload* at line number 1."""
        return _validate_request_body_for_endpoint(payload, endpoint, 1)

    def test_chat_completions_valid_body(self):
        """A well-formed chat completions body yields no error."""
        payload = {
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "Hello"}],
        }
        assert self._validate(payload, "/v1/chat/completions") is None

    def test_chat_completions_missing_messages(self):
        """Omitting 'messages' from a chat completions body is rejected."""
        err = self._validate({"model": "gpt-3.5-turbo"}, "/v1/chat/completions")
        assert err is not None
        assert "messages" in err

    def test_chat_completions_messages_not_array(self):
        """A non-array 'messages' value is rejected with an 'array' message."""
        payload = {"model": "gpt-3.5-turbo", "messages": "not an array"}
        err = self._validate(payload, "/v1/chat/completions")
        assert err is not None
        assert "messages" in err
        assert "array" in err

    def test_completions_valid_body_string_prompt(self):
        """A string 'prompt' is accepted for completions."""
        payload = {"model": "gpt-3.5-turbo", "prompt": "Hello world"}
        assert self._validate(payload, "/v1/completions") is None

    def test_completions_valid_body_array_prompt(self):
        """An array 'prompt' is accepted for completions."""
        payload = {"model": "gpt-3.5-turbo", "prompt": ["Hello", "World"]}
        assert self._validate(payload, "/v1/completions") is None

    def test_completions_missing_prompt(self):
        """Omitting 'prompt' from a completions body is rejected."""
        err = self._validate({"model": "gpt-3.5-turbo"}, "/v1/completions")
        assert err is not None
        assert "prompt" in err

    def test_completions_invalid_prompt_type(self):
        """A numeric 'prompt' is rejected for completions."""
        err = self._validate(
            {"model": "gpt-3.5-turbo", "prompt": 123}, "/v1/completions"
        )
        assert err is not None
        assert "prompt" in err

    def test_embeddings_valid_body_string_input(self):
        """A string 'input' is accepted for embeddings."""
        payload = {"model": "text-embedding-ada-002", "input": "Hello world"}
        assert self._validate(payload, "/v1/embeddings") is None

    def test_embeddings_valid_body_array_input(self):
        """An array 'input' is accepted for embeddings."""
        payload = {
            "model": "text-embedding-ada-002",
            "input": ["Hello", "World"],
        }
        assert self._validate(payload, "/v1/embeddings") is None

    def test_embeddings_missing_input(self):
        """Omitting 'input' from an embeddings body is rejected."""
        err = self._validate({"model": "text-embedding-ada-002"}, "/v1/embeddings")
        assert err is not None
        assert "input" in err

    def test_embeddings_missing_model(self):
        """Omitting 'model' from an embeddings body is rejected."""
        err = self._validate({"input": "Hello world"}, "/v1/embeddings")
        assert err is not None
        assert "model" in err

    def test_rerank_valid_body(self):
        """A well-formed rerank body yields no error."""
        payload = {
            "model": "reranker-v1",
            "query": "What is AI?",
            "documents": ["AI is...", "Machine learning is..."],
        }
        assert self._validate(payload, "/v1/rerank") is None

    def test_rerank_missing_query(self):
        """Omitting 'query' from a rerank body is rejected."""
        err = self._validate(
            {"model": "reranker-v1", "documents": ["doc1", "doc2"]}, "/v1/rerank"
        )
        assert err is not None
        assert "query" in err

    def test_rerank_missing_documents(self):
        """Omitting 'documents' from a rerank body is rejected."""
        err = self._validate(
            {"model": "reranker-v1", "query": "What is AI?"}, "/v1/rerank"
        )
        assert err is not None
        assert "documents" in err

    def test_rerank_invalid_documents_type(self):
        """A non-array 'documents' value is rejected with an 'array' message."""
        payload = {
            "model": "reranker-v1",
            "query": "What is AI?",
            "documents": "not an array",
        }
        err = self._validate(payload, "/v1/rerank")
        assert err is not None
        assert "documents" in err
        assert "array" in err

    def test_unknown_endpoint_passes(self):
        """Endpoints with no configured schema skip body validation."""
        assert self._validate({"anything": "goes"}, "/v1/unknown") is None

    @pytest.mark.parametrize(
        "endpoint,body",
        [
            (
                "/v1/chat/completions",
                {
                    "model": "gpt-3.5-turbo",
                    "messages": [{"role": "user", "content": "Hi"}],
                },
            ),
            (
                "/v1/completions",
                {"model": "gpt-3.5-turbo", "prompt": "Hi"},
            ),
            (
                "/v1/embeddings",
                {"model": "text-embedding-ada-002", "input": "Hi"},
            ),
            (
                "/v1/rerank",
                {
                    "model": "reranker-v1",
                    "query": "Hi",
                    "documents": ["doc1"],
                },
            ),
        ],
        ids=["chat_completions", "completions", "embeddings", "rerank"],
    )
    def test_all_endpoints_accept_valid_bodies(self, endpoint, body):
        """Parametrized test: all endpoints accept their valid bodies."""
        err = self._validate(body, endpoint)
        assert err is None, f"Unexpected error for {endpoint}: {err}"
Loading
Loading