Skip to content

Commit e2095b8

Browse files
authored
chore(iorails): Refactor RailsManager and Nemoguard Actions (#1762)
1 parent c7233ce commit e2095b8

10 files changed

Lines changed: 1442 additions & 1306 deletions
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Content safety rail actions for IORails."""
17+
18+
from typing import Any, Optional
19+
20+
from nemoguardrails.guardrails.guardrails_types import LLMMessages, RailResult
21+
from nemoguardrails.guardrails.rail_action import RailAction
22+
23+
_MAX_TOKENS = 3
24+
_TEMPERATURE = 1e-20
25+
26+
27+
class ContentSafetyInputAction(RailAction):
28+
"""Check user input for content safety violations."""
29+
30+
action_name = "content safety check input"
31+
requires_model = True
32+
33+
def _extract_messages(self, messages: LLMMessages, bot_response: Optional[str]) -> dict[str, Any]:
34+
return {"user_input": self._last_user_content(messages)}
35+
36+
def _create_prompt(self, model_type: Optional[str], extracted: dict[str, Any]) -> list[dict]:
37+
prompt_task_key = f"content_safety_check_input $model={model_type}"
38+
content_safety_config = self.task_manager.config.rails.config.content_safety
39+
if content_safety_config is None:
40+
raise RuntimeError("content_safety config is required for content safety rail")
41+
reasoning_enabled = content_safety_config.reasoning.enabled
42+
43+
prompt = self.task_manager.render_task_prompt(
44+
task=prompt_task_key,
45+
context={"user_input": extracted["user_input"], "reasoning_enabled": reasoning_enabled},
46+
)
47+
return self._prompt_to_messages(prompt)
48+
49+
async def _get_response(self, model_type: Optional[str], prompt: Any) -> str:
50+
prompt_task_key = f"content_safety_check_input $model={model_type}"
51+
52+
stop = self.task_manager.get_stop_tokens(task=prompt_task_key)
53+
max_tokens = self.task_manager.get_max_tokens(task=prompt_task_key) or _MAX_TOKENS
54+
kwargs: dict = {"temperature": _TEMPERATURE, "max_tokens": max_tokens}
55+
if stop:
56+
kwargs["stop"] = stop
57+
58+
response_text = await self._get_llm_response(model_type, prompt, **kwargs)
59+
60+
# Parse via LLMTaskManager's registered output parser
61+
return self.task_manager.parse_task_output(task=prompt_task_key, output=response_text) # type: ignore[arg-type]
62+
63+
def _parse_response(self, response: Any) -> RailResult:
64+
return _content_safety_to_rail_result(response)
65+
66+
67+
class ContentSafetyOutputAction(RailAction):
68+
"""Check bot response for content safety violations."""
69+
70+
action_name = "content safety check output"
71+
72+
def _extract_messages(self, messages: LLMMessages, bot_response: Optional[str]) -> dict[str, Any]:
73+
if not bot_response:
74+
raise RuntimeError("bot_response is required for content safety output check")
75+
return {
76+
"user_input": self._last_user_content(messages),
77+
"bot_response": bot_response,
78+
}
79+
80+
def _create_prompt(self, model_type: Optional[str], extracted: dict[str, Any]) -> list[dict]:
81+
prompt_task_key = f"content_safety_check_output $model={model_type}"
82+
content_safety_config = self.task_manager.config.rails.config.content_safety
83+
if content_safety_config is None:
84+
raise RuntimeError("content_safety config is required for content safety rail")
85+
reasoning_enabled = content_safety_config.reasoning.enabled
86+
87+
prompt = self.task_manager.render_task_prompt(
88+
task=prompt_task_key,
89+
context={
90+
"user_input": extracted["user_input"],
91+
"bot_response": extracted["bot_response"],
92+
"reasoning_enabled": reasoning_enabled,
93+
},
94+
)
95+
return self._prompt_to_messages(prompt)
96+
97+
async def _get_response(self, model_type: Optional[str], prompt: Any) -> str:
98+
prompt_task_key = f"content_safety_check_output $model={model_type}"
99+
100+
stop = self.task_manager.get_stop_tokens(task=prompt_task_key)
101+
max_tokens = self.task_manager.get_max_tokens(task=prompt_task_key) or _MAX_TOKENS
102+
kwargs: dict = {"temperature": _TEMPERATURE, "max_tokens": max_tokens}
103+
if stop:
104+
kwargs["stop"] = stop
105+
106+
response_text = await self._get_llm_response(model_type, prompt, **kwargs)
107+
return self.task_manager.parse_task_output(task=prompt_task_key, output=response_text) # type: ignore[arg-type]
108+
109+
def _parse_response(self, response: Any) -> RailResult:
110+
return _content_safety_to_rail_result(response)
111+
112+
113+
def _content_safety_to_rail_result(parsed: object) -> RailResult:
114+
"""Convert nemoguard parser output to RailResult.
115+
116+
nemoguard_parse_prompt_safety / nemoguard_parse_response_safety return:
117+
[True] -> safe
118+
[False, "S1: Violence", ...] -> unsafe with categories
119+
"""
120+
if isinstance(parsed, (list, tuple)):
121+
if parsed and parsed[0] is True:
122+
return RailResult(is_safe=True)
123+
if parsed and parsed[0] is False:
124+
if len(parsed) > 1:
125+
categories = ", ".join(str(c) for c in parsed[1:])
126+
return RailResult(is_safe=False, reason=f"Safety categories: {categories}")
127+
return RailResult(is_safe=False, reason="Unknown")
128+
raise RuntimeError(f"Unexpected content safety parse result: {parsed}")
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Jailbreak detection rail action for IORails."""
17+
18+
from typing import Any, Optional
19+
20+
from nemoguardrails.guardrails.guardrails_types import LLMMessages, RailResult
21+
from nemoguardrails.guardrails.rail_action import RailAction
22+
23+
24+
class JailbreakDetectionAction(RailAction):
25+
"""Detect jailbreak attempts via the NIM jailbreak detection API."""
26+
27+
action_name = "jailbreak detection model"
28+
requires_model = False
29+
30+
def _extract_messages(self, messages: LLMMessages, bot_response: Optional[str]) -> dict[str, Any]:
31+
return {"user_input": self._last_user_content(messages)}
32+
33+
def _create_prompt(self, model_type: Optional[str], extracted: dict[str, Any]) -> dict[str, str]:
34+
return {"input": extracted["user_input"]}
35+
36+
async def _get_response(self, model_type: Optional[str], prompt: Any) -> dict:
37+
return await self._get_api_response("jailbreak_detection", prompt)
38+
39+
def _parse_response(self, response: Any) -> RailResult:
40+
if "jailbreak" not in response:
41+
raise RuntimeError(f"Jailbreak response missing 'jailbreak' field: {response}")
42+
43+
score = response.get("score", "unknown")
44+
if response["jailbreak"]:
45+
return RailResult(is_safe=False, reason=f"Score: {score}")
46+
return RailResult(is_safe=True, reason=f"Score: {score}")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# SPDX-FileCopyrightText: Copyright (c) 2023-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""Topic safety rail action for IORails."""
17+
18+
from typing import Any, Optional
19+
20+
from nemoguardrails.guardrails.guardrails_types import LLMMessages, RailResult
21+
from nemoguardrails.guardrails.rail_action import RailAction
22+
from nemoguardrails.library.topic_safety.actions import (
23+
TOPIC_SAFETY_MAX_TOKENS,
24+
TOPIC_SAFETY_OUTPUT_RESTRICTION,
25+
TOPIC_SAFETY_TEMPERATURE,
26+
)
27+
28+
29+
class TopicSafetyInputAction(RailAction):
30+
"""Check whether user input is on-topic per configured guidelines."""
31+
32+
action_name = "topic safety check input"
33+
requires_model = True
34+
35+
def _extract_messages(self, messages: LLMMessages, bot_response: Optional[str]) -> dict[str, Any]:
36+
return {"messages": messages}
37+
38+
def _create_prompt(self, model_type: Optional[str], extracted: dict[str, Any]) -> list[dict]:
39+
task_key = f"topic_safety_check_input $model={model_type}"
40+
41+
system_prompt = self.task_manager.render_task_prompt(task=task_key)
42+
if isinstance(system_prompt, list):
43+
raise RuntimeError(f"Topic safety prompt must be a string template, got messages: {task_key}")
44+
45+
system_prompt = system_prompt.strip()
46+
if not system_prompt.endswith(TOPIC_SAFETY_OUTPUT_RESTRICTION):
47+
system_prompt = f"{system_prompt}\n\n{TOPIC_SAFETY_OUTPUT_RESTRICTION}"
48+
49+
return [{"role": "system", "content": system_prompt}, *extracted["messages"]]
50+
51+
async def _get_response(self, model_type: Optional[str], prompt: Any) -> str:
52+
task_key = f"topic_safety_check_input $model={model_type}"
53+
54+
stop = self.task_manager.get_stop_tokens(task=task_key)
55+
max_tokens = self.task_manager.get_max_tokens(task=task_key) or TOPIC_SAFETY_MAX_TOKENS
56+
kwargs: dict = {"temperature": TOPIC_SAFETY_TEMPERATURE, "max_tokens": max_tokens}
57+
if stop:
58+
kwargs["stop"] = stop
59+
60+
return await self._get_llm_response(model_type, prompt, **kwargs)
61+
62+
def _parse_response(self, response: Any) -> RailResult:
63+
if response.lower().strip() == "off-topic":
64+
return RailResult(is_safe=False, reason="Topic safety: off-topic")
65+
return RailResult(is_safe=True)

0 commit comments

Comments
 (0)