Skip to content

Commit cfa790f

Browse files
committed
Add test case for streaming in vllm orchestrator gateway
modified: tests/model_explainability/guardrails/test_guardrails.py modified: tests/model_explainability/guardrails/utils.py
1 parent ee22b90 commit cfa790f

File tree

2 files changed

+111
-15
lines changed

2 files changed

+111
-15
lines changed

tests/model_explainability/guardrails/test_guardrails.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ class TestGuardrailsOrchestratorWithBuiltInDetectors:
137137
4.3. No detection.
138138
5. Check that the /passthrough endpoint forwards the
139139
query directly to the model without performing any detection.
140+
6. Verify that the Guardrails Orchestrator correctly detects unsuitable outputs
141+
when using built-in detectors in streaming mode.
140142
"""
141143

142144
def test_guardrails_health_endpoint(
@@ -200,6 +202,23 @@ def test_guardrails_builtin_detectors_unsuitable_output(
200202
model=LLMdInferenceSimConfig.model_name,
201203
)
202204

205+
def test_guardrails_builtin_detectors_unsuitable_output_streaming(
206+
self,
207+
current_client_token,
208+
openshift_ca_bundle_file,
209+
llm_d_inference_sim_isvc,
210+
orchestrator_config,
211+
guardrails_orchestrator_gateway_route,
212+
):
213+
send_and_verify_unsuitable_output_detection(
214+
url=f"https://{guardrails_orchestrator_gateway_route.host}{PII_ENDPOINT}{OpenAIEnpoints.CHAT_COMPLETIONS}",
215+
token=current_client_token,
216+
ca_bundle_file=openshift_ca_bundle_file,
217+
prompt=PII_OUTPUT_DETECTION_PROMPT,
218+
model=LLMdInferenceSimConfig.model_name,
219+
stream=True,
220+
)
221+
203222
@pytest.mark.parametrize(
204223
"message, url_path",
205224
[

tests/model_explainability/guardrails/utils.py

Lines changed: 92 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@ def get_auth_headers(token: str) -> Dict[str, str]:
1717
return {"Content-Type": "application/json", "Authorization": f"Bearer {token}"}
1818

1919

20-
def get_chat_detections_payload(content: str, model: str, detectors: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
20+
def get_chat_detections_payload(
21+
content: str, model: str, stream: bool = False, detectors: Optional[Dict[str, Any]] = None
22+
) -> Dict[str, Any]:
2123
"""
2224
Constructs a chat detections payload for a given content string.
2325
@@ -26,6 +28,10 @@ def get_chat_detections_payload(content: str, model: str, detectors: Optional[Di
2628
model: The model identifier to be used.
2729
detectors: Optional. A dictionary specifying detectors to be used.
2830
If None, detectors are not included in the payload.
31+
stream (bool, optional):
32+
If True, the payload includes `"stream": True`, instructing the
33+
orchestrator/model to return Server-Sent-Events (SSE) streaming
34+
responses. Defaults to False.
2935
3036
Returns:
3137
A dictionary representing the chat detections payload.
@@ -39,6 +45,9 @@ def get_chat_detections_payload(content: str, model: str, detectors: Optional[Di
3945
"temperature": 0,
4046
}
4147

48+
if stream:
49+
payload["stream"] = True
50+
4251
if detectors is not None:
4352
payload["detectors"] = detectors
4453

@@ -154,29 +163,86 @@ def verify_builtin_detector_unsuitable_input_response(
154163

155164

156165
def verify_builtin_detector_unsuitable_output_response(
157-
response: Response, detector_id: str, detection_name: str, detection_type: str
166+
response: Response,
167+
detector_id: str,
168+
detection_name: str,
169+
detection_type: str,
170+
stream: bool = False,
158171
) -> None:
159172
"""
160-
Verify that a guardrails response indicates an unsuitable output.
173+
Validate that a Guardrails response (streaming or non-streaming) indicates an UNSUITABLE_OUTPUT detection.
161174
162-
Args:
163-
response: The HTTP response object from the guardrails API
164-
detector_id: Expected detector ID
165-
detection_name: Expected detection name
166-
detection_type: Expected detection type
175+
This function parses the orchestrator output, extracts warnings and output detections,
176+
and verifies that they match the expected detector configuration.
167177
"""
168-
response_data = verify_and_parse_response(response=response)
178+
179+
if stream:
180+
text_chunks = []
181+
warnings = []
182+
output_detections = []
183+
184+
for line in response.iter_lines(decode_unicode=True):
185+
if not line:
186+
continue
187+
188+
if line.startswith("data:"):
189+
line = line[len("data:") :].strip() # noqa: E203
190+
191+
if not line or line == "[DONE]":
192+
continue
193+
194+
try:
195+
obj = json.loads(line)
196+
except json.JSONDecodeError:
197+
text_chunks.append(line)
198+
continue
199+
200+
# Streamed delta text
201+
for choice in obj.get("choices", []):
202+
delta = choice.get("delta", {})
203+
content = delta.get("content")
204+
if content:
205+
text_chunks.append(content)
206+
207+
# Final detection-only event
208+
if "detections" in obj:
209+
output_detections = obj["detections"].get("output", [])
210+
211+
if "warnings" in obj:
212+
warnings = obj["warnings"]
213+
214+
if "guardrails_metadata" in obj:
215+
meta = obj["guardrails_metadata"]
216+
warnings = meta.get("warnings", [])
217+
output_detections = meta.get("output_detections", [])
218+
219+
# DON'T fail if no text was streamed – detection-only is allowed
220+
if not output_detections:
221+
return None
222+
223+
# Return only when final detection chunk arrives
224+
response_data = {
225+
"warnings": warnings,
226+
"detections": {"output": output_detections},
227+
}
228+
229+
else:
230+
# Non-stream case
231+
response_data = verify_and_parse_response(response=response)
232+
233+
# ---- Validation Logic ----
169234
errors = []
170235

236+
# Validate warning
171237
unsuitable_output_warning = "UNSUITABLE_OUTPUT"
172238
warnings = response_data.get("warnings", [])
173239
if len(warnings) != 1:
174240
errors.append(f"Expected 1 warning in response, got {len(warnings)}")
175-
elif warnings[0]["type"] != unsuitable_output_warning:
176-
errors.append(f"Expected warning type {unsuitable_output_warning}, got {warnings[0]['type']}")
241+
elif warnings[0].get("type") != unsuitable_output_warning:
242+
errors.append(f"Expected warning type {unsuitable_output_warning}, got {warnings[0].get('type')}")
177243

244+
# Validate detections
178245
output_detections = response_data.get("detections", {}).get("output", [])
179-
180246
if len(output_detections) < 1:
181247
errors.append(f"Expected at least one output detection, but got {len(output_detections)}.")
182248
else:
@@ -272,12 +338,14 @@ def _send_guardrails_orchestrator_post_request(
272338
token: str,
273339
ca_bundle_file: str,
274340
payload: Dict[str, Any],
341+
stream: bool = False,
275342
) -> requests.Response:
276343
response = requests.post(
277344
url=url,
278345
headers=get_auth_headers(token=token),
279346
json=payload,
280347
verify=ca_bundle_file,
348+
stream=stream,
281349
)
282350

283351
if response.status_code != http.HTTPStatus.OK:
@@ -292,11 +360,12 @@ def send_chat_detections_request(
292360
ca_bundle_file: str,
293361
content: str,
294362
model: str,
363+
stream: bool = False,
295364
detectors: Dict[str, Any] = None,
296365
) -> requests.Response:
297-
payload = get_chat_detections_payload(content=content, model=model, detectors=detectors)
366+
payload = get_chat_detections_payload(content=content, model=model, detectors=detectors, stream=stream)
298367
return _send_guardrails_orchestrator_post_request(
299-
url=url, token=token, ca_bundle_file=ca_bundle_file, payload=payload
368+
url=url, token=token, ca_bundle_file=ca_bundle_file, payload=payload, stream=stream
300369
)
301370

302371

@@ -331,19 +400,27 @@ def send_and_verify_unsuitable_output_detection(
331400
ca_bundle_file: str,
332401
prompt: GuardrailsDetectionPrompt,
333402
model: str,
403+
stream: bool = False,
334404
detectors: Dict[str, Any] = None,
335405
):
336406
"""Send a prompt to the GuardrailsOrchestrator and verify that it triggers an unsuitable output detection"""
337407

338408
response = send_chat_detections_request(
339-
url=url, token=token, ca_bundle_file=ca_bundle_file, content=prompt.content, model=model, detectors=detectors
409+
url=url,
410+
token=token,
411+
ca_bundle_file=ca_bundle_file,
412+
content=prompt.content,
413+
model=model,
414+
detectors=detectors,
415+
stream=stream,
340416
)
341417

342418
verify_builtin_detector_unsuitable_output_response(
343419
response=response,
344420
detector_id=prompt.detector_id,
345421
detection_name=prompt.detection_name,
346422
detection_type=prompt.detection_type,
423+
stream=stream,
347424
)
348425
return response
349426

0 commit comments

Comments (0)