Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions examples/configs/gs_content_safety/config/config.yml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
models:
- type: main
engine: nvidia_ai_endpoints
engine: nim
model: meta/llama-3.3-70b-instruct

- type: content_safety
engine: nvidia_ai_endpoints
engine: nim
model: nvidia/llama-3.1-nemoguard-8b-content-safety

rails:
Expand All @@ -16,5 +16,5 @@ rails:
- content safety check output $model=content_safety
streaming:
enabled: True
chunk_size: 200
context_size: 50
chunk_size: 5
context_size: 1
186 changes: 186 additions & 0 deletions nemoguardrails/guardrails/iorails.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,13 @@
"""

import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator
from typing import Optional, Union

from nemoguardrails.exceptions import StreamingNotSupportedError
from nemoguardrails.guardrails.guardrails_types import (
LLMMessage,
LLMMessages,
Expand All @@ -34,13 +38,18 @@
)
from nemoguardrails.guardrails.model_manager import ModelManager
from nemoguardrails.guardrails.rails_manager import RailsManager
from nemoguardrails.rails.llm.buffer import get_buffer_strategy
from nemoguardrails.rails.llm.config import RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler

# Module-level logger; request-scoped messages below prefix themselves with the request ID.
log = logging.getLogger(__name__)

# Text pushed to the stream in place of an LLM response when the input rails
# report the request as unsafe.
REFUSAL_MESSAGE = "I'm sorry, I can't respond to that."

# Default concurrency budget for streaming requests (separate from the AsyncWorkQueue for generate_async)
# Used as the initial value of the per-engine asyncio.Semaphore that gates stream_async().
STREAM_MAX_CONCURRENCY = 256


class IORails:
"""Workflow engine for accelerated Input/Output rails inference."""
Expand All @@ -55,6 +64,9 @@ def __init__(self, config: RailsConfig) -> None:
# Rails Manager is responsible for running rails by making calls to Model Manager
self.rails_manager = RailsManager(config, self.model_manager)

# Semaphore for streaming concurrency control / load shedding
self._stream_semaphore = asyncio.Semaphore(STREAM_MAX_CONCURRENCY)

async def start(self) -> None:
"""Start the IORails engine. Call this during service startup."""
if self._running:
Expand Down Expand Up @@ -140,3 +152,177 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
elapsed_ms = (time.monotonic() - t0) * 1000
log.info("[%s] generate_async completed time=%.1fms", req_id, elapsed_ms)
reset_request_id(token)

def _validate_streaming_with_output_rails(self) -> None:
"""Raise if output rails exist but streaming is not enabled for them."""
if len(self.config.rails.output.flows) > 0 and (
not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled
):
raise StreamingNotSupportedError(
"stream_async() cannot be used when output rails are configured but "
"rails.output.streaming.enabled is False. Either set "
"rails.output.streaming.enabled to True in your configuration, or use "
"generate_async() instead of stream_async()."
)

def stream_async(
    self,
    messages: LLMMessages,
    options: Optional[Union[dict, GenerationOptions]] = None,
    include_metadata: Optional[bool] = False,
) -> AsyncIterator[Union[str, dict]]:
    """Stream LLM response tokens with input/output rails applied.

    Returns an async iterator that yields string chunks (or dicts when
    ``include_metadata=True``). Input rails run before any tokens are
    streamed. If output rails are configured and streaming is enabled,
    tokens are buffered and checked using the same ``RollingBuffer`` /
    ``stream_first`` semantics as LLMRails.

    No work starts until the returned iterator is first advanced; the
    streaming slot is held from that point until the iterator finishes or
    is closed, at which point the background generation task is cancelled
    if it is still running.

    Args:
        messages: Conversation messages in OpenAI format.
        options: Optional GenerationOptions, or a dict with an
            ``llm_params`` key (llm_params are forwarded to the main
            LLM call).
        include_metadata: When True, chunks are dicts with ``text`` and
            ``metadata`` keys instead of plain strings.

    Returns:
        An async iterator of string chunks (or dicts).

    Raises:
        StreamingNotSupportedError: If output rails are present but
            ``rails.output.streaming.enabled`` is False. Raised by this
            call itself, before an iterator is returned.
        asyncio.QueueFull: If the streaming concurrency limit is
            reached (load shedding). Raised when the returned iterator
            is first advanced, not by this call.
    """
    self._validate_streaming_with_output_rails()

    # Extract llm_params from whichever form of options was supplied. The
    # signature accepts a plain dict as well, so honour it instead of
    # silently dropping the caller's parameters.
    llm_params = None
    if isinstance(options, GenerationOptions):
        llm_params = options.llm_params
    elif isinstance(options, dict):
        llm_params = options.get("llm_params")
    llm_kwargs: dict = llm_params if llm_params else {}

    streaming_handler = StreamingHandler(include_metadata=include_metadata)

    async def _generation_task():
        """Background task: input rails → stream LLM chunks → push to handler.

        Inherits the request ID from the caller context via create_task().
        """
        req_id = get_request_id()
        t0 = time.monotonic()
        try:
            log.info("[%s] stream_async generation task started", req_id)
            log.debug("[%s] stream_async messages=%s", req_id, truncate(messages))

            # Step 1: Input rails (non-streaming)
            log.info("[%s] Running input rails", req_id)
            input_result = await self.rails_manager.is_input_safe(messages)
            if not input_result.is_safe:
                log.info("[%s] Input blocked: %s", req_id, input_result.reason)
                await streaming_handler.push_chunk(REFUSAL_MESSAGE)
                await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
                return

            # Step 2: Stream main LLM
            log.info("[%s] Streaming main LLM", req_id)
            async for chunk in self.model_manager.stream_async("main", messages, **llm_kwargs):
                await streaming_handler.push_chunk(chunk)

            await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
        except Exception as e:
            # Surface the failure to the consumer as a JSON error chunk and
            # terminate the stream so the reader does not hang.
            # (asyncio.CancelledError is not an Exception and passes through.)
            elapsed_ms = (time.monotonic() - t0) * 1000
            log.error(
                "[%s] stream_async generation task failed time=%.1fms",
                req_id,
                elapsed_ms,
                exc_info=True,
            )
            error_payload = json.dumps(
                {"error": {"message": str(e), "type": "generation_error", "code": "generation_failed"}}
            )
            await streaming_handler.push_chunk(error_payload)
            await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
        finally:
            elapsed_ms = (time.monotonic() - t0) * 1000
            log.info("[%s] stream_async time=%.1fms", req_id, elapsed_ms)

    async def _wrapped_iterator():
        """Wrap the base iterator with semaphore-based concurrency control."""
        # Load shedding: fail fast instead of queueing when no slot is free.
        # locked() is the public "would acquire() block?" check, so there is
        # no need to read the semaphore's private counter.
        if self._stream_semaphore.locked():
            raise asyncio.QueueFull("Streaming concurrency limit reached")

        await self._stream_semaphore.acquire()
        try:
            # Set request ID here so both the generation task (via create_task
            # context copy) and output rails (running in this coroutine) share it.
            token = set_new_request_id()
            try:
                task = asyncio.create_task(_generation_task())
                try:
                    # Determine base iterator: with or without output rails
                    output_streaming = self.config.rails.output.streaming
                    if output_streaming and output_streaming.enabled and self.config.rails.output.flows:
                        base_iterator = self._run_output_rails_in_streaming(
                            streaming_handler=streaming_handler,
                            messages=messages,
                        )
                    else:
                        base_iterator = streaming_handler

                    async for chunk in base_iterator:
                        if chunk is not None:
                            yield chunk
                finally:
                    # The consumer may stop early (client disconnect, output
                    # rail block, error). Cancel the producer so it does not
                    # keep streaming from the LLM with nobody reading.
                    if not task.done():
                        task.cancel()
                    # gather(return_exceptions=True) absorbs the task's own
                    # cancellation/failure but still lets a cancellation of
                    # *this* coroutine propagate.
                    (outcome,) = await asyncio.gather(task, return_exceptions=True)
                    if isinstance(outcome, Exception):
                        log.error(
                            "[%s] stream_async generation task raised during shutdown: %r",
                            get_request_id(),
                            outcome,
                        )
            finally:
                reset_request_id(token)
        finally:
            # Released last so the slot stays occupied until the generation
            # task has really finished, and is never leaked on any exit path.
            self._stream_semaphore.release()

    return _wrapped_iterator()

async def _run_output_rails_in_streaming(
    self,
    streaming_handler: AsyncIterator[Union[str, dict]],
    messages: LLMMessages,
) -> AsyncIterator[Union[str, dict]]:
    """Apply output rails to a chunk stream, one buffered batch at a time.

    The configured buffer strategy groups incoming chunks into batches; each
    batch carries the chunks destined for the user and the context that is
    handed to the output rails. Ordering follows ``stream_first``:

    - ``stream_first=True``: the batch's chunks are yielded, then the rails
      run on the batch context.
    - ``stream_first=False``: the rails run first and the chunks are yielded
      only when the result is safe.

    In both modes an unsafe result yields a single JSON error string
    (``guardrails_violation`` / ``content_blocked``) and ends the stream.

    Args:
        streaming_handler: Source of raw chunks to buffer.
        messages: Conversation messages passed to the output rails check.

    Yields:
        User-facing chunks, or one JSON-encoded error string when blocked.
    """
    # Unpack streaming config and get the buffer strategy
    streaming_cfg = self.config.rails.output.streaming
    emit_before_check = streaming_cfg.stream_first
    strategy = get_buffer_strategy(streaming_cfg)

    async for batch in strategy(streaming_handler):
        outgoing = batch.user_output_chunks
        text_to_check = strategy.format_chunks(batch.processing_context)

        if emit_before_check:
            for piece in outgoing:
                yield piece

        # Run output rails on the accumulated context
        request_id = get_request_id()
        log.info("[%s] Running output rails", request_id)
        verdict = await self.rails_manager.is_output_safe(messages, text_to_check)

        if verdict.is_safe:
            # Safe batch: release the held-back chunks if they were not
            # already streamed, then move on to the next batch.
            if not emit_before_check:
                for piece in outgoing:
                    yield piece
            continue

        # Unsafe batch: report the violation to the consumer and stop.
        log.info("[%s] Output blocked: %s", request_id, verdict.reason)
        violation = {
            "error": {
                "message": f"Blocked by output rails: {verdict.reason}",
                "type": "guardrails_violation",
                "code": "content_blocked",
            }
        }
        yield json.dumps(violation)
        return
Loading
Loading