Skip to content

Commit 6db7a75

Browse files
authored
Merge pull request trustyai-explainability#54 from trustyai-explainability/AddHeaderPassthrough
Feat: Add header extraction from shield params
2 parents b217ba1 + 6a1c292 commit 6db7a75

File tree

4 files changed

+34
-3
lines changed

4 files changed

+34
-3
lines changed

llama_stack_provider_trustyai_fms/detectors/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def _extract_detector_params(self) -> dict[str, Any]:
325325

326326
return detector_params
327327

328-
def _prepare_headers(self) -> Headers:
328+
def _prepare_headers(self, params: dict[str, Any] = None) -> Headers:
329329
"""Prepare request headers based on configuration"""
330330
headers: Headers = {
331331
"accept": "application/json",
@@ -342,6 +342,11 @@ def _prepare_headers(self) -> Headers:
342342
elif self.config.auth_token:
343343
headers["Authorization"] = f"Bearer {self.config.auth_token}"
344344

345+
if params is not None:
346+
for k, v in params.items():
347+
if k.lower() == "headers" and isinstance(v, dict):
348+
headers.update(v)
349+
345350
return headers
346351

347352
def _prepare_request_payload(

llama_stack_provider_trustyai_fms/detectors/chat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ async def _call_detector_api(
199199
"""Call chat detector API with proper endpoint selection"""
200200
try:
201201
request = self._prepare_chat_request(messages, params)
202-
headers = self._prepare_headers()
202+
headers = self._prepare_headers(params)
203203

204204
logger.info("Making detector API request")
205205
logger.debug(f"Request headers: {headers}")

llama_stack_provider_trustyai_fms/detectors/content.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ async def _call_detector_api(
121121
"""Call detector API with proper endpoint selection"""
122122
try:
123123
request = self._prepare_content_request(content, params)
124-
headers = self._prepare_headers()
124+
headers = self._prepare_headers(params)
125125

126126
logger.info("Making detector API request")
127127
logger.debug(f"Request headers: {headers}")

tests/unit/test_detectors.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,32 @@ async def test_http_request_success(self, initialized_detector, httpx_mock):
285285
result.violation is not None
286286
) # Should detect violation because score > threshold
287287

288+
@pytest.mark.asyncio
289+
async def test_headers_extracted_from_params(self, initialized_detector):
290+
detector = initialized_detector
291+
292+
# Patch _make_request to capture headers
293+
called = {}
294+
295+
async def fake_make_request(request, headers, timeout=None):
296+
called["headers"] = headers
297+
# Return a valid detection so the rest of the code works
298+
return [{"score": 0.9, "label": "label", "detection_type": "content"}]
299+
300+
detector._make_request = fake_make_request
301+
302+
messages = [UserMessage(content="Test content", role="user")]
303+
params = {"headers": {"X-Test-Header": "header-value", "another": "foo"}}
304+
await detector.run_shield(
305+
shield_id=detector.config.detector_id, messages=messages, params=params
306+
)
307+
308+
# Check that headers were extracted and passed
309+
assert "X-Test-Header" in called["headers"]
310+
assert called["headers"]["X-Test-Header"] == "header-value"
311+
assert "another" in called["headers"]
312+
assert called["headers"]["another"] == "foo"
313+
288314
@pytest.mark.asyncio
289315
async def test_http_request_no_violation(self, initialized_detector, httpx_mock):
290316
detector = initialized_detector

0 commit comments

Comments
 (0)