-
Notifications
You must be signed in to change notification settings - Fork 541
feat(batch): add multi-endpoint body validation and testing #1982
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -45,6 +45,61 @@ | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| REQUIRED_FIELDS = ["custom_id", "method", "url", "body"] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| VALID_HTTP_METHODS = {"GET", "POST", "PUT", "DELETE", "PATCH"} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Endpoint-specific required body fields | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # These define the minimum required fields in the request body for each endpoint type | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ENDPOINT_REQUIRED_BODY_FIELDS: dict[str, list[str]] = { | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BatchJobEndpoint.CHAT_COMPLETIONS.value: ["model", "messages"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BatchJobEndpoint.COMPLETIONS.value: ["model", "prompt"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BatchJobEndpoint.EMBEDDINGS.value: ["model", "input"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| BatchJobEndpoint.RERANK.value: ["model", "query", "documents"], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _validate_request_body_for_endpoint( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| body: dict, endpoint: str, line_num: int | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) -> Optional[str]: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """Validate request body fields are appropriate for the given endpoint. | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Args: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| body: The request body dictionary | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| endpoint: The API endpoint string (e.g., "/v1/chat/completions") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| line_num: Line number for error reporting | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Returns: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Error message string if validation fails, None if valid | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| required_fields = ENDPOINT_REQUIRED_BODY_FIELDS.get(endpoint) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if required_fields is None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Unknown endpoint, skip body validation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return None | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| for field in required_fields: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if field not in body: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return ( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"Line {line_num}: Request body for endpoint '{endpoint}' " | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"is missing required field '{field}'" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| # Endpoint-specific type validation | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if endpoint == BatchJobEndpoint.CHAT_COMPLETIONS.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(body.get("messages"), list): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f"Line {line_num}: 'messages' must be an array for {endpoint}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif endpoint == BatchJobEndpoint.COMPLETIONS.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| prompt = body.get("prompt") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(prompt, (str, list)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f"Line {line_num}: 'prompt' must be a string or array for {endpoint}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif endpoint == BatchJobEndpoint.EMBEDDINGS.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| input_val = body.get("input") | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(input_val, (str, list)): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f"Line {line_num}: 'input' must be a string or array for {endpoint}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| elif endpoint == BatchJobEndpoint.RERANK.value: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(body.get("query"), str): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f"Line {line_num}: 'query' must be a string for {endpoint}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if not isinstance(body.get("documents"), list): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f"Line {line_num}: 'documents' must be an array for {endpoint}" | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+86
to
+99
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return f"Line {line_num}: 'messages' must be an array for {endpoint}" | |
| elif endpoint == BatchJobEndpoint.COMPLETIONS.value: | |
| prompt = body.get("prompt") | |
| if not isinstance(prompt, (str, list)): | |
| return f"Line {line_num}: 'prompt' must be a string or array for {endpoint}" | |
| elif endpoint == BatchJobEndpoint.EMBEDDINGS.value: | |
| input_val = body.get("input") | |
| if not isinstance(input_val, (str, list)): | |
| return f"Line {line_num}: 'input' must be a string or array for {endpoint}" | |
| elif endpoint == BatchJobEndpoint.RERANK.value: | |
| if not isinstance(body.get("query"), str): | |
| return f"Line {line_num}: 'query' must be a string for {endpoint}" | |
| if not isinstance(body.get("documents"), list): | |
| return f"Line {line_num}: 'documents' must be an array for {endpoint}" | |
| return f"Line {line_num}: 'messages' must be a list for {endpoint}" | |
| elif endpoint == BatchJobEndpoint.COMPLETIONS.value: | |
| prompt = body.get("prompt") | |
| if not isinstance(prompt, (str, list)): | |
| return f"Line {line_num}: 'prompt' must be a string or list for {endpoint}" | |
| elif endpoint == BatchJobEndpoint.EMBEDDINGS.value: | |
| input_val = body.get("input") | |
| if not isinstance(input_val, (str, list)): | |
| return f"Line {line_num}: 'input' must be a string or list for {endpoint}" | |
| elif endpoint == BatchJobEndpoint.RERANK.value: | |
| if not isinstance(body.get("query"), str): | |
| return f"Line {line_num}: 'query' must be a string for {endpoint}" | |
| if not isinstance(body.get("documents"), list): | |
| return f"Line {line_num}: 'documents' must be a list for {endpoint}" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The _validate_batch_input_file function passes the untrusted request_url from the input file to _validate_request_body_for_endpoint instead of using the validated endpoint from the batch request. Since _validate_batch_input_file only checks if request_url ends with the expected endpoint (line 187), an attacker can provide a full URL (e.g., http://attacker.com/v1/chat/completions) that passes the check but fails to match any key in ENDPOINT_REQUIRED_BODY_FIELDS (which contains only paths like /v1/chat/completions). This causes _validate_request_body_for_endpoint to skip validation (line 74), allowing malformed request bodies to bypass the early validation check.
body_error = _validate_request_body_for_endpoint(
request["body"], endpoint, line_num
)| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -12,11 +12,12 @@ | |||||||
| # See the License for the specific language governing permissions and | ||||||||
| # limitations under the License. | ||||||||
|
|
||||||||
| """Unit tests for batch API endpoint support.""" | ||||||||
| """Unit tests for batch API endpoint support and body validation.""" | ||||||||
|
|
||||||||
| import pytest | ||||||||
|
|
||||||||
| from aibrix.batch.job_entity import BatchJobEndpoint | ||||||||
| from aibrix.metadata.api.v1.batch import _validate_request_body_for_endpoint | ||||||||
|
|
||||||||
|
|
||||||||
|
Comment on lines
+20
to
22
|
||||||||
| from aibrix.metadata.api.v1.batch import _validate_request_body_for_endpoint | |
| import aibrix.metadata.api.v1.batch as batch_module | |
| _validate_request_body_for_endpoint = batch_module._validate_request_body_for_endpoint |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Body validation is skipped entirely when
endpointdoesn’t exactly match a key inENDPOINT_REQUIRED_BODY_FIELDS. Since the value passed in comes from the batch input line (request_url), any benign variation (e.g., trailing slash, query string) would bypass validation. Consider normalizing the endpoint before lookup (e.g., strip query/fragment, strip trailing/) and/or validating against the canonical batch job endpoint (theendpointparameter passed to_validate_batch_input_file) rather than the per-line URL string.