Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 213 additions & 0 deletions nemoguardrails/guardrails/iorails.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@
"""

import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator
from contextlib import suppress
from typing import Optional, Union

from nemoguardrails.exceptions import StreamingNotSupportedError
from nemoguardrails.guardrails.guardrails_types import (
LLMMessage,
LLMMessages,
Expand All @@ -34,13 +39,18 @@
)
from nemoguardrails.guardrails.model_manager import ModelManager
from nemoguardrails.guardrails.rails_manager import RailsManager
from nemoguardrails.rails.llm.buffer import get_buffer_strategy
from nemoguardrails.rails.llm.config import RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler

log = logging.getLogger(__name__)

REFUSAL_MESSAGE = "I'm sorry, I can't respond to that."

# Default concurrency budget for streaming requests (separate from the AsyncWorkQueue for generate_async)
STREAM_MAX_CONCURRENCY = 256


class IORails:
"""Workflow engine for accelerated Input/Output rails inference."""
Expand All @@ -55,6 +65,9 @@ def __init__(self, config: RailsConfig) -> None:
# Rails Manager is responsible for running rails by making calls to Model Manager
self.rails_manager = RailsManager(config, self.model_manager)

# Semaphore for streaming concurrency control / load shedding
self._stream_semaphore = asyncio.Semaphore(STREAM_MAX_CONCURRENCY)

async def start(self) -> None:
"""Start the IORails engine. Call this during service startup."""
if self._running:
Expand Down Expand Up @@ -140,3 +153,203 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
elapsed_ms = (time.monotonic() - t0) * 1000
log.info("[%s] generate_async completed time=%.1fms", req_id, elapsed_ms)
reset_request_id(token)

def _validate_streaming_with_output_rails(self) -> None:
"""Raise if output rails exist but streaming is not enabled for them."""
if len(self.config.rails.output.flows) > 0 and (
not self.config.rails.output.streaming or not self.config.rails.output.streaming.enabled
):
raise StreamingNotSupportedError(
"stream_async() cannot be used when output rails are configured but "
"rails.output.streaming.enabled is False. Either set "
"rails.output.streaming.enabled to True in your configuration, or use "
"generate_async() instead of stream_async()."
)

def stream_async(
self,
messages: LLMMessages,
options: Optional[Union[dict, GenerationOptions]] = None,
include_metadata: Optional[bool] = False,
) -> AsyncIterator[Union[str, dict]]:
"""Stream LLM response tokens with input/output rails applied.

Returns an async iterator that yields string chunks (or dicts when
``include_metadata=True``). Input rails run before any tokens are
streamed. If output rails are configured and streaming is enabled,
tokens are buffered and checked using the same ``RollingBuffer`` /
``stream_first`` semantics as LLMRails.

Args:
messages: Conversation messages in OpenAI format.
options: Optional GenerationOptions (llm_params are forwarded to
the main LLM call).
include_metadata: When True, chunks are dicts with ``text`` and
``metadata`` keys instead of plain strings.

Returns:
An async iterator of string chunks (or dicts).

Raises:
StreamingNotSupportedError: If output rails are present but
``rails.output.streaming.enabled`` is False.
ValueError: If ``include_metadata=True`` with output rails
streaming enabled (BufferStrategy requires plain string chunks).
asyncio.QueueFull: If the streaming concurrency limit is
reached (load shedding).
"""
self._validate_streaming_with_output_rails()

Comment thread
tgasser-nv marked this conversation as resolved.
output_streaming = self.config.rails.output.streaming
has_output_rails = output_streaming and output_streaming.enabled and len(self.config.rails.output.flows) > 0
if include_metadata and has_output_rails:
raise ValueError(
"include_metadata=True is not supported when output rails streaming is enabled. "
"BufferStrategy requires plain string chunks. Use include_metadata=False or "
"disable output rails streaming."
)

# Extract llm_params from GenerationOptions if provided
llm_kwargs: dict = {}
if options and isinstance(options, GenerationOptions):
llm_kwargs = options.llm_params if options.llm_params else {}
Comment thread
tgasser-nv marked this conversation as resolved.
Outdated

streaming_handler = StreamingHandler(include_metadata=include_metadata)

async def _generation_task():
"""Background task: input rails → stream LLM chunks → push to handler.

Inherits the request ID from the caller context via create_task().
"""
req_id = get_request_id()
t0 = time.monotonic()
try:
log.info("[%s] stream_async generation task started", req_id)
log.debug("[%s] stream_async messages=%s", req_id, truncate(messages))

# Step 1: Input rails (non-streaming)
log.info("[%s] Running input rails", req_id)
input_result = await self.rails_manager.is_input_safe(messages)
if not input_result.is_safe:
log.info("[%s] Input blocked: %s", req_id, input_result.reason)
await streaming_handler.push_chunk(REFUSAL_MESSAGE)
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore[arg-type]
return

# Step 2: Stream main LLM
log.info("[%s] Streaming main LLM", req_id)
async for chunk in self.model_manager.stream_async("main", messages, **llm_kwargs):
await streaming_handler.push_chunk(chunk)

await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore[arg-type]
except Exception as e:
elapsed_ms = (time.monotonic() - t0) * 1000
log.error(
"[%s] stream_async generation task failed time=%.1fms",
req_id,
elapsed_ms,
exc_info=True,
)
error_payload = json.dumps(
{"error": {"message": str(e), "type": "generation_error", "code": "generation_failed"}}
)
await streaming_handler.push_chunk(error_payload)
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore[arg-type]
finally:
elapsed_ms = (time.monotonic() - t0) * 1000
log.info("[%s] stream_async time=%.1fms", req_id, elapsed_ms)

async def _wrapped_iterator():
"""Wrap the base iterator with semaphore-based concurrency control."""
# Non-blocking acquire; raises immediately if all slots are taken.
# locked() returns True when the semaphore value is 0. Because there
# is no await between the check and acquire(), no other coroutine can
# interleave in asyncio's cooperative model, so this is race-free.
if self._stream_semaphore.locked():
raise asyncio.QueueFull("Streaming concurrency limit reached")
await self._stream_semaphore.acquire()
Comment thread
tgasser-nv marked this conversation as resolved.

try:
# Set request ID here so both the generation task (via create_task
# context copy) and output rails (running in this coroutine) share it.
token = set_new_request_id()
task = asyncio.create_task(_generation_task())
try:
# Determine base iterator: with or without output rails
output_streaming = self.config.rails.output.streaming
if output_streaming and output_streaming.enabled and len(self.config.rails.output.flows) > 0:
Comment thread
tgasser-nv marked this conversation as resolved.
Outdated
base_iterator = self._run_output_rails_in_streaming(
streaming_handler=streaming_handler,
messages=messages,
)
else:
base_iterator = streaming_handler

async for chunk in base_iterator:
if chunk is not None:
yield chunk
finally:
try:
if not task.done():
task.cancel()
with suppress(asyncio.CancelledError):
await task
finally:
try:
reset_request_id(token)
except ValueError:
# GeneratorExit triggers cleanup in a different context
# where the token is no longer valid — safe to ignore.
pass
finally:
self._stream_semaphore.release()

return _wrapped_iterator()

async def _run_output_rails_in_streaming(
self,
streaming_handler: AsyncIterator[Union[str, dict]],
messages: LLMMessages,
) -> AsyncIterator[Union[str, dict]]:
"""Buffer streamed chunks and run output rails on each batch.

Uses the same ``RollingBuffer`` and ``stream_first`` semantics as
LLMRails:
- ``stream_first=True``: yield chunks immediately, then run output
rails. If unsafe, inject an error and stop.
- ``stream_first=False``: run output rails first, only yield chunks
if safe.
"""

# Unpack streaming config and get the buffer strategy
output_streaming_config = self.config.rails.output.streaming
stream_first = output_streaming_config.stream_first
buffer_strategy = get_buffer_strategy(output_streaming_config)

async for chunk_batch in buffer_strategy(streaming_handler):
Comment thread
coderabbitai[bot] marked this conversation as resolved.
user_output_chunks = chunk_batch.user_output_chunks
bot_response_chunk = buffer_strategy.format_chunks(chunk_batch.processing_context)

if stream_first:
for chunk in user_output_chunks:
yield chunk

# Run output rails on the accumulated context
req_id = get_request_id()
log.info("[%s] Running output rails", req_id)
output_result = await self.rails_manager.is_output_safe(messages, bot_response_chunk)
if not output_result.is_safe:
log.info("[%s] Output blocked: %s", req_id, output_result.reason)
error_data = {
"error": {
"message": f"Blocked by output rails: {output_result.reason}",
"type": "guardrails_violation",
"code": "content_blocked",
}
}
yield json.dumps(error_data)
return

if not stream_first:
for chunk in user_output_chunks:
yield chunk
Loading
Loading