Add run_moderation to the remote provider (#21)
Conversation
Reviewer's Guide

This PR adds moderation support to the remote safety provider by implementing a new run_moderation method with helper functions for shield resolution and input conversion, includes a warning for backward compatibility, bumps the package version, and adds comprehensive unit tests for the new moderation workflow.

Sequence diagram for the new moderation workflow in run_moderation:

```mermaid
sequenceDiagram
    participant Client
    participant DetectorProvider
    participant ShieldsService
    participant Shield
    participant ModerationObject
    Client->>DetectorProvider: run_moderation(input, model)
    DetectorProvider->>DetectorProvider: _get_shield_id_from_model(model)
    DetectorProvider->>ShieldsService: list_shields()
    ShieldsService-->>DetectorProvider: shields_response
    DetectorProvider->>DetectorProvider: _convert_input_to_messages(input)
    DetectorProvider->>Shield: run_shield(shield_id, messages)
    Shield-->>DetectorProvider: shield_response
    DetectorProvider->>ModerationObject: Build ModerationObject with results
    DetectorProvider-->>Client: ModerationObject
```
Entity relationship diagram for ModerationObject and ModerationObjectResults:

```mermaid
erDiagram
    MODERATION_OBJECT {
        string id
        string model
    }
    MODERATION_OBJECT_RESULTS {
        boolean flagged
        object categories
        object category_applied_input_types
        object category_scores
        string user_message
        object metadata
    }
    MODERATION_OBJECT ||--o{ MODERATION_OBJECT_RESULTS : contains
```
Class diagram for new and updated moderation-related types:

```mermaid
classDiagram
    class DetectorProvider {
        +run_moderation(input: str | list[str], model: str): ModerationObject
        +_get_shield_id_from_model(model: str): str
        +_convert_input_to_messages(texts: str | list[str]): List[Message]
    }
    class ModerationObject {
        +id: str
        +model: str
        +results: List[ModerationObjectResults]
    }
    class ModerationObjectResults {
        +flagged: bool
        +categories: dict
        +category_applied_input_types: dict
        +category_scores: dict
        +user_message: str
        +metadata: dict
    }
    class UserMessage {
        +content: str
    }
    DetectorProvider --> ModerationObject
    ModerationObject --> ModerationObjectResults
    DetectorProvider --> UserMessage
```
```python
provider.run_shield = AsyncMock(return_value=FakeShieldResponse())

result = await provider.run_moderation(["bad message", "good message"], "test_model")
assert len(result.results) == 2
```

Check notice (Code scanning / Bandit): Use of assert detected. The enclosed code will be removed when compiling to optimised byte code. (Note: test)
The overlapping snippets that follow each carry the same Bandit assert notice; consolidated:

```python
result = await provider.run_moderation(["bad message", "good message"], "test_model")
assert len(result.results) == 2
assert result.results[0].flagged is True
assert result.results[1].flagged is False
assert result.results[0].user_message == "bad message"
assert result.results[1].user_message == "good message"

provider._convert_input_to_messages = MagicMock(return_value=[MagicMock(content="msg")])

result = await provider.run_moderation(["msg"], "test_model")
assert len(result.results) == 1
```
Hey there - I've reviewed your changes - here's some feedback:

- Avoid using `input` as a parameter name in run_moderation to prevent shadowing Python's built-in; consider renaming it to something like `texts` or `inputs`.
- run_moderation currently calls list_shields on every invocation and does a linear scan of results_metadata for each message; consider caching shield_id per model and indexing results_metadata by message_index to improve performance.
- The run_moderation method is quite large in base.py; extracting it into a dedicated helper or service class would improve readability and maintainability.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- Avoid using `input` as a parameter name in run_moderation to prevent shadowing Python’s built‐in; consider renaming it to something like `texts` or `inputs`.
- run_moderation currently calls list_shields on every invocation and does a linear scan of results_metadata for each message; consider caching shield_id per model and indexing results_metadata by message_index to improve performance.
- The run_moderation method is quite large in base.py—extracting it into a dedicated helper or service class would improve readability and maintainability.
## Individual Comments
### Comment 1
<location> `tests/test_moderation.py:4` </location>
<code_context>
+import pytest
+from unittest.mock import AsyncMock, MagicMock
+
+@pytest.mark.asyncio
+async def test_run_moderation_flagged():
+ from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
</code_context>
<issue_to_address>
Missing test for empty input and single string input edge cases.
Please add tests for empty list input and single string input to ensure run_moderation handles these cases correctly.
</issue_to_address>
```python
@pytest.mark.asyncio
async def test_run_moderation_flagged():
```

suggestion (testing): Missing test for empty input and single string input edge cases.
Please add tests for empty list input and single string input to ensure run_moderation handles these cases correctly.
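A hedged sketch of the requested edge-case tests, using a local stub instead of the real DetectorProvider (the stub's dict-shaped result is an assumption for illustration; the real provider returns a ModerationObject):

```python
import asyncio

class StubProvider:
    """Minimal stand-in for DetectorProvider (hypothetical, for illustration)."""

    async def run_moderation(self, input, model):
        # Mirror the PR's normalization: a bare string becomes a one-element list.
        texts = [input] if isinstance(input, str) else list(input)
        results = [{"flagged": False, "user_message": t} for t in texts]
        return {"id": "mod-1", "model": model, "results": results}

async def check_edge_cases():
    provider = StubProvider()

    # Empty list input: expect an empty results list, not an error.
    empty = await provider.run_moderation([], "test_model")
    assert empty["results"] == []

    # Single string input: expect exactly one result wrapping the message.
    single = await provider.run_moderation("one message", "test_model")
    assert len(single["results"]) == 1
    assert single["results"][0]["user_message"] == "one message"

asyncio.run(check_edge_cases())
```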
```python
if isinstance(input, str):
    inputs = [input]
else:
    inputs = input
```

suggestion (code-quality): Replace if statement with if expression (assign-if-exp):

```python
inputs = [input] if isinstance(input, str) else input
```
Sure! I'm generating a new review now.

Hey @m-misiura, I've posted a new review for you!
These overlapping snippets repeat the same Bandit assert notice:

```python
result = await provider.run_moderation(["msg"], "test_model")
assert len(result.results) == 1
assert result.results[0].flagged is False
assert "fail" in result.results[0].metadata["error"]
```

```python
provider._convert_input_to_messages = MagicMock(return_value=[])
provider.run_shield = AsyncMock()
result = await provider.run_moderation([], "test_model")
assert len(result.results) == 0
```

```python
    ]})
))
result = await provider.run_moderation("one message", "test_model")
assert len(result.results) == 1
assert result.results[0].user_message == "one message"
```
…on` and avoiding shadowing Python's built-in function `input` inside the method body
Hey there - I've reviewed your changes - here's some feedback:

- Instead of lazily creating `_model_to_shield_id` via `hasattr` checks, initialize that cache in the provider's `__init__` to make the code clearer and avoid repeated attribute lookups.
- Catching all exceptions in `run_moderation` and hiding them in result metadata can make debugging harder; consider logging unexpected errors or narrowing the except clause to known failure modes.
- The `run_moderation` method combines shield lookup, input conversion, and result assembly in one block; extracting the result-building logic into a helper would improve readability and maintainability.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- Instead of lazily creating `_model_to_shield_id` via `hasattr` checks, initialize that cache in the provider’s `__init__` to make the code clearer and avoid repeated attribute lookups.
- Catching all exceptions in `run_moderation` and hiding them in result metadata can make debugging harder—consider logging unexpected errors or narrowing the except clause to known failure modes.
- The `run_moderation` method combines shield lookup, input conversion, and result assembly in one block—extracting the result-building logic into a helper would improve readability and maintainability.
## Individual Comments
### Comment 1
<location> `llama_stack_provider_trustyai_fms/detectors/base.py:1806` </location>
<code_context>
+ if _HAS_MODERATION:
</code_context>
<issue_to_address>
Conditional method definition may lead to missing attributes.
Since the presence of run_moderation depends on _HAS_MODERATION, code that expects this method may fail in environments where moderation is unavailable. To ensure a consistent class interface, define run_moderation unconditionally and raise NotImplementedError when moderation is not supported.
</issue_to_address>
### Comment 2
<location> `llama_stack_provider_trustyai_fms/detectors/base.py:1839` </location>
<code_context>
+ if result:
+ cat = result.get("detection_type")
+ score = result.get("score")
+ if isinstance(cat, str) and score is not None:
+ is_violation = result.get("status") == "violation"
+ categories[cat] = is_violation
</code_context>
<issue_to_address>
Only one category per message is supported.
Currently, only one detection_type and score are processed per message. If the API can return multiple categories, update the logic to handle all relevant categories.
Suggested implementation:
```python
if result:
    # Support multiple categories per message
    detected_categories = result.get("categories")
    detected_scores = result.get("scores")
    detected_statuses = result.get("statuses")
    # Fallback for single category format
    if isinstance(detected_statuses, dict) and isinstance(detected_scores, dict):
        for cat, status in detected_statuses.items():
            score = detected_scores.get(cat)
            if isinstance(cat, str) and score is not None:
                is_violation = status == "violation"
                categories[cat] = is_violation
                category_scores[cat] = float(score)
                category_applied_input_types[cat] = ["text"]
                if is_violation:
                    flagged = True
    else:
        cat = result.get("detection_type")
        score = result.get("score")
        if isinstance(cat, str) and score is not None:
            is_violation = result.get("status") == "violation"
            categories[cat] = is_violation
            category_scores[cat] = float(score)
            category_applied_input_types[cat] = ["text"]
            flagged = is_violation
    meta = result
```
- You may need to adjust the keys (`categories`, `scores`, `statuses`) to match the actual API response format if they differ.
- If the API returns a list of category objects instead of dicts, iterate accordingly.
- Ensure that the rest of the code (e.g., how `ModerationObjectResults` uses these dicts) supports multiple categories.
</issue_to_address>
### Comment 3
<location> `tests/test_moderation.py:9` </location>
<code_context>
+ from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider
+
+ provider = DetectorProvider(detectors={})
+ provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
+ provider._convert_input_to_messages = MagicMock(return_value=[
+ MagicMock(content="bad message"), MagicMock(content="good message")
</code_context>
<issue_to_address>
Consider adding a test for multiple shields found for a model.
Please add a test that triggers the multiple shields exception and verifies the error is correctly reflected in the moderation results metadata.
Suggested implementation:
```python
import pytest
from unittest.mock import AsyncMock, MagicMock

class MultipleShieldsFoundError(Exception):
    pass

@pytest.mark.asyncio
async def test_run_moderation_multiple_shields_error():
    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider

    provider = DetectorProvider(detectors={})
    # Simulate multiple shields found by raising the error
    provider._get_shield_id_from_model = AsyncMock(
        side_effect=MultipleShieldsFoundError("Multiple shields found for model")
    )
    provider._convert_input_to_messages = MagicMock(return_value=[
        MagicMock(content="test message")
    ])
    # Run moderation and check the error surfaces in the results metadata
    result = await provider.run_moderation("test input", "test_model")
    assert "Multiple shields found for model" in result.results[0].metadata["error"]

@pytest.mark.asyncio
async def test_run_moderation_flagged():
```
- If `MultipleShieldsFoundError` is defined elsewhere in your codebase, import it instead of defining it in the test file.
- Ensure that `provider.run_moderation` correctly catches the exception and records the error in the results metadata. If not, you may need to update the implementation to handle this case.
</issue_to_address>
```python
if _HAS_MODERATION:
    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
        """
        Runs moderation for each input message.
        Returns a ModerationObject with one ModerationObjectResults per input.
        """
        texts = input  # Avoid shadowing the built-in 'input'
        try:
            # Shield ID caching for performance
            if not hasattr(self, "_model_to_shield_id"):
```

issue: Conditional method definition may lead to missing attributes.
Since the presence of run_moderation depends on _HAS_MODERATION, code that expects this method may fail in environments where moderation is unavailable. To ensure a consistent class interface, define run_moderation unconditionally and raise NotImplementedError when moderation is not supported.
```python
if result:
    cat = result.get("detection_type")
    score = result.get("score")
    if isinstance(cat, str) and score is not None:
```

suggestion: Only one category per message is supported.
Currently, only one detection_type and score are processed per message. If the API can return multiple categories, update the logic to handle all relevant categories.

Suggested implementation:

```python
if result:
    # Support multiple categories per message
    detected_categories = result.get("categories")
    detected_scores = result.get("scores")
    detected_statuses = result.get("statuses")
    # Fallback for single category format
    if isinstance(detected_statuses, dict) and isinstance(detected_scores, dict):
        for cat, status in detected_statuses.items():
            score = detected_scores.get(cat)
            if isinstance(cat, str) and score is not None:
                is_violation = status == "violation"
                categories[cat] = is_violation
                category_scores[cat] = float(score)
                category_applied_input_types[cat] = ["text"]
                if is_violation:
                    flagged = True
    else:
        cat = result.get("detection_type")
        score = result.get("score")
        if isinstance(cat, str) and score is not None:
            is_violation = result.get("status") == "violation"
            categories[cat] = is_violation
            category_scores[cat] = float(score)
            category_applied_input_types[cat] = ["text"]
            flagged = is_violation
    meta = result
```

- You may need to adjust the keys (`categories`, `scores`, `statuses`) to match the actual API response format if they differ.
- If the API returns a list of category objects instead of dicts, iterate accordingly.
- Ensure that the rest of the code (e.g., how `ModerationObjectResults` uses these dicts) supports multiple categories.
```python
from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider

provider = DetectorProvider(detectors={})
provider._get_shield_id_from_model = AsyncMock(return_value="test_shield")
```

suggestion (testing): Consider adding a test for multiple shields found for a model.
Please add a test that triggers the multiple shields exception and verifies the error is correctly reflected in the moderation results metadata.

Suggested implementation:

```python
import pytest
from unittest.mock import AsyncMock, MagicMock

class MultipleShieldsFoundError(Exception):
    pass

@pytest.mark.asyncio
async def test_run_moderation_multiple_shields_error():
    from llama_stack_provider_trustyai_fms.detectors.base import DetectorProvider

    provider = DetectorProvider(detectors={})
    # Simulate multiple shields found by raising the error
    provider._get_shield_id_from_model = AsyncMock(
        side_effect=MultipleShieldsFoundError("Multiple shields found for model")
    )
    provider._convert_input_to_messages = MagicMock(return_value=[
        MagicMock(content="test message")
    ])
    # Run moderation and check the error surfaces in the results metadata
    result = await provider.run_moderation("test input", "test_model")
    assert "Multiple shields found for model" in result.results[0].metadata["error"]

@pytest.mark.asyncio
async def test_run_moderation_flagged():
```

- If `MultipleShieldsFoundError` is defined elsewhere in your codebase, import it instead of defining it in the test file.
- Ensure that `provider.run_moderation` correctly catches the exception and records the error in the results metadata. If not, you may need to update the implementation to handle this case.
```python
    )
)
if _HAS_MODERATION:
    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
```

issue (code-quality): We've found these issues:

- Move assignment closer to its usage within a block (`move-assign-in-block`)
- Use named expression to simplify assignment and conditional (`use-named-expression`)
- Low code quality found in DetectorProvider.run_moderation - 25% (`low-code-quality`)

Explanation: The quality score for this function is below the quality threshold of 25%. This score is a combination of the method length, cognitive complexity and working memory.

How can you solve this? It might be worth refactoring this function to make it shorter and more readable.

- Reduce the function length by extracting pieces of functionality out into their own functions. This is the most important thing you can do - ideally a function should be less than 10 lines.
- Reduce nesting, perhaps by introducing guard clauses to return early.
- Ensure that variables are tightly scoped, so that code using related concepts sits together within the function rather than being scattered.
```python
if isinstance(texts, str):
    inputs = [texts]
else:
    inputs = texts
```

suggestion (code-quality): Replace if statement with if expression (assign-if-exp):

```python
inputs = [texts] if isinstance(texts, str) else texts
```
What does this PR do?
With the changes in upstream llama stack >= 0.2.18, there is a need to add the `run_moderation` method, else the provider will break (see this PR and this discussion).

To ensure backward compatibility, e.g. with llama stack == 0.2.14, the imports of `ModerationObject` and `ModerationObjectResults` are placed inside a try-except statement.

Test plan
I added some tests for `run_moderation` using Mocks. I also tested manually against a live server.

Here is a run_moderation response from an inline provider (codeshield)

Here is a run_moderation response from the trustyai_fms provider

In addition: quay.io/rh-ee-mmisiura/lls:run_moderation

Summary by Sourcery
Enable moderation support in the trustyai_fms provider by implementing run_moderation with backward compatibility fallbacks and accompanying unit tests
New Features:
Enhancements:
Tests:
Chores: