Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
251 changes: 246 additions & 5 deletions nemoguardrails/guardrails/iorails.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,14 @@
"""

import asyncio
import json
import logging
import time
from collections.abc import AsyncIterator
from contextlib import suppress
from typing import Optional, Union

from nemoguardrails.exceptions import StreamingNotSupportedError
from nemoguardrails.guardrails.guardrails_types import (
LLMMessage,
LLMMessages,
Expand All @@ -34,13 +39,21 @@
)
from nemoguardrails.guardrails.model_manager import ModelManager
from nemoguardrails.guardrails.rails_manager import RailsManager
from nemoguardrails.rails.llm.buffer import get_buffer_strategy
from nemoguardrails.rails.llm.config import RailsConfig
from nemoguardrails.rails.llm.options import GenerationOptions
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Canned assistant reply sent in place of a model response when the input rails block a request.
REFUSAL_MESSAGE = "I'm sorry, I can't respond to that."

# Default concurrency budget for streaming requests (separate from the AsyncWorkQueue for generate_async).
# Used as the initial value of the per-engine streaming semaphore.
STREAM_MAX_CONCURRENCY = 256

# Error type used by _generation_task when pushing error JSON into the stream.
# The output-rails streaming path compares against this value to recognise such chunks.
_GENERATION_ERROR_TYPE = "generation_error"


class IORails:
"""Workflow engine for accelerated Input/Output rails inference."""
Expand All @@ -55,6 +68,15 @@ def __init__(self, config: RailsConfig) -> None:
# Rails Manager is responsible for running rails by making calls to Model Manager
self.rails_manager = RailsManager(config, self.model_manager)

# Semaphore for streaming concurrency control / load shedding
self._stream_semaphore = asyncio.Semaphore(STREAM_MAX_CONCURRENCY)

@property
def _has_streaming_output_rails(self) -> bool:
    """Tell whether streamed output must pass through output rails.

    This holds only when a streaming section exists under
    ``rails.output``, that section is enabled, and at least one output
    flow is configured.
    """
    streaming_cfg = self.config.rails.output.streaming
    if streaming_cfg is None:
        # No streaming section at all: output rails cannot run in streaming mode.
        return False
    return streaming_cfg.enabled and len(self.config.rails.output.flows) > 0

async def start(self) -> None:
"""Start the IORails engine. Call this during service startup."""
if self._running:
Expand Down Expand Up @@ -98,6 +120,8 @@ async def _run_sync_iorails():

async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
"""Run input rails, generation, and output rails. Return response if safe."""
await self.start()

token = set_new_request_id()
req_id = get_request_id()
t0 = time.monotonic()
Expand All @@ -113,13 +137,13 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
return {"role": "assistant", "content": REFUSAL_MESSAGE}

# Step 2: Generate response from main LLM
# If we got an `options=GenerationOptions`, then unpack GenerationOptions.llm_params and add
# that to the main LLM call
log.info("[%s] Calling main LLM", req_id)
llm_kwargs = {}
if kwargs.get("options") and isinstance(kwargs["options"], GenerationOptions):
generation_options = kwargs["options"]
llm_kwargs = generation_options.llm_params if generation_options.llm_params else {}
options = kwargs.get("options")
if options and isinstance(options, dict):
options = GenerationOptions(**options)
if isinstance(options, GenerationOptions) and options.llm_params:
llm_kwargs = options.llm_params

response_text = await self.model_manager.generate_async("main", messages, **llm_kwargs)
log.debug("[%s] Main LLM response: %s", req_id, truncate(response_text))
Expand All @@ -140,3 +164,220 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
elapsed_ms = (time.monotonic() - t0) * 1000
log.info("[%s] generate_async completed time=%.1fms", req_id, elapsed_ms)
reset_request_id(token)

def _validate_streaming_with_output_rails(self) -> None:
    """Reject streaming when output rails exist but cannot run in streaming mode.

    Raises:
        StreamingNotSupportedError: If at least one output flow is
            configured while ``_has_streaming_output_rails`` is false.
    """
    output_flows_configured = len(self.config.rails.output.flows) > 0
    if not output_flows_configured or self._has_streaming_output_rails:
        # Either there is nothing to check, or the rails can run on the stream.
        return

    raise StreamingNotSupportedError(
        "stream_async() cannot be used when output rails are configured but "
        "rails.output.streaming.enabled is False. Either set "
        "rails.output.streaming.enabled to True in your configuration, or use "
        "generate_async() instead of stream_async()."
    )

def stream_async(
    self,
    messages: LLMMessages,
    options: Optional[Union[dict, GenerationOptions]] = None,
    include_metadata: Optional[bool] = False,
) -> AsyncIterator[Union[str, dict]]:
    """Return an async iterator over the rail-guarded, streamed LLM response.

    The configuration checks run synchronously when this method is called.
    All other work (starting the engine, taking a concurrency slot,
    running input rails, streaming the main LLM) happens once the returned
    iterator is consumed. Input rails run before any token is streamed.
    If output rails are configured and streaming is enabled, tokens are
    buffered and checked using the same ``RollingBuffer`` /
    ``stream_first`` semantics as LLMRails.

    Args:
        messages: Conversation messages in OpenAI format.
        options: Optional GenerationOptions, or a dict used to build one
            (llm_params are forwarded to the main LLM call).
        include_metadata: When True, chunks are dicts with ``text`` and
            ``metadata`` keys instead of plain strings.

    Returns:
        An async iterator of string chunks (or dicts).

    Raises:
        StreamingNotSupportedError: If output rails are present but
            ``rails.output.streaming.enabled`` is False.
        ValueError: If ``include_metadata=True`` with output rails
            streaming enabled (BufferStrategy requires plain string chunks).
        asyncio.QueueFull: Raised while iterating, if the streaming
            concurrency limit is reached (load shedding).
    """
    self._validate_streaming_with_output_rails()

    if include_metadata and self._has_streaming_output_rails:
        raise ValueError(
            "include_metadata=True is not supported when output rails streaming is enabled. "
            "BufferStrategy requires plain string chunks. Use include_metadata=False or "
            "disable output rails streaming."
        )

    # Accept options as a plain dict, then pull out the parameters meant for the main LLM.
    if options and isinstance(options, dict):
        options = GenerationOptions(**options)
    has_llm_params = isinstance(options, GenerationOptions) and options.llm_params
    llm_kwargs: dict = options.llm_params if has_llm_params else {}

    streaming_handler = StreamingHandler(include_metadata=include_metadata)

    def _ms_since(start: float) -> float:
        """Milliseconds elapsed since the ``time.monotonic()`` reading *start*."""
        return (time.monotonic() - start) * 1000

    async def _generation_task():
        """Producer: run input rails, then push LLM chunks into the handler.

        Inherits the request ID from the caller context via create_task().
        Every path that completes inside this task ends by pushing
        END_OF_STREAM.
        """
        req_id = get_request_id()
        started = time.monotonic()
        try:
            # Step 1: Input rails (non-streaming)
            log.info("[%s] Running input rails", req_id)
            verdict = await self.rails_manager.is_input_safe(messages)
            if verdict.is_safe:
                # Step 2: Stream main LLM
                log.info("[%s] Streaming main LLM", req_id)
                async for piece in self.model_manager.stream_async("main", messages, **llm_kwargs):
                    await streaming_handler.push_chunk(piece)
            else:
                log.info("[%s] Input blocked: %s", req_id, verdict.reason)
                await streaming_handler.push_chunk(REFUSAL_MESSAGE)

            await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
        except Exception as exc:
            log.error(
                "[%s] generation task failed time=%.1fms",
                req_id,
                _ms_since(started),
                exc_info=True,
            )
            # Report the failure in-band as a JSON chunk, then close the stream.
            failure = {"error": {"message": str(exc), "type": _GENERATION_ERROR_TYPE, "code": "generation_failed"}}
            await streaming_handler.push_chunk(json.dumps(failure))
            await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
        finally:
            log.info("[%s] generation task completed time=%.1fms", req_id, _ms_since(started))

    async def _wrapped_iterator():
        """Consumer: take a concurrency slot, then relay chunks to the caller."""
        # Ensure engines are running (idempotent if already started).
        await self.start()

        # Load shedding: fail fast rather than queueing behind other streams.
        # No await separates the locked() check from acquire(), so no other
        # coroutine can take the last slot in between.
        if self._stream_semaphore.locked():
            raise asyncio.QueueFull("Streaming concurrency limit reached")
        await self._stream_semaphore.acquire()

        token = set_new_request_id()
        req_id = get_request_id()
        started = time.monotonic()
        try:
            log.info("[%s] stream_async called", req_id)
            log.debug("[%s] stream_async messages=%s", req_id, truncate(messages))

            producer = asyncio.create_task(_generation_task())
            try:
                # Route chunks through output rails only when they are enabled for streaming.
                chunk_source = (
                    self._run_output_rails_in_streaming(
                        streaming_handler=streaming_handler,
                        messages=messages,
                    )
                    if self._has_streaming_output_rails
                    else streaming_handler
                )

                async for piece in chunk_source:
                    if piece is None:
                        continue
                    yield piece
            finally:
                try:
                    # Stop a producer that is still running (e.g. the consumer quit early).
                    if not producer.done():
                        producer.cancel()
                        with suppress(asyncio.CancelledError):
                            await producer
                finally:
                    # GeneratorExit triggers cleanup in a different context
                    # where the token is no longer valid; safe to ignore.
                    with suppress(ValueError):
                        reset_request_id(token)
        except Exception:
            log.error("[%s] stream_async failed time=%.1fms", req_id, _ms_since(started), exc_info=True)
            raise
        finally:
            log.info("[%s] stream_async completed time=%.1fms", req_id, _ms_since(started))
            self._stream_semaphore.release()

    return _wrapped_iterator()

async def _run_output_rails_in_streaming(
    self,
    streaming_handler: AsyncIterator[Union[str, dict]],
    messages: LLMMessages,
) -> AsyncIterator[Union[str, dict]]:
    """Buffer streamed chunks and run output rails on each batch.

    Uses the same ``RollingBuffer`` and ``stream_first`` semantics as
    LLMRails:
    - ``stream_first=True``: yield chunks immediately, then run output
      rails. If unsafe, inject an error and stop.
    - ``stream_first=False``: run output rails first, only yield chunks
      if safe.

    Args:
        streaming_handler: Source of streamed chunks, consumed through the
            configured buffer strategy.
        messages: Conversation messages, passed to the output rails
            together with each batch's formatted text.

    Yields:
        The user-facing chunks of each batch. A generation-error chunk
        produced by ``_generation_task`` is yielded as-is and ends the
        stream. If the output rails flag a batch, a JSON error string with
        type ``guardrails_violation`` is yielded and the stream ends.
    """

    # Unpack streaming config and get the buffer strategy
    output_streaming_config = self.config.rails.output.streaming
    stream_first = output_streaming_config.stream_first
    buffer_strategy = get_buffer_strategy(output_streaming_config)

    async for chunk_batch in buffer_strategy(streaming_handler):
        user_output_chunks = chunk_batch.user_output_chunks
        bot_response_chunk = buffer_strategy.format_chunks(chunk_batch.processing_context)

        # If the batch contains a generation error from _generation_task,
        # yield it directly and stop; error JSON must not be fed through output rails.
        for chunk in user_output_chunks:
            try:
                parsed = json.loads(chunk)
            except (json.JSONDecodeError, TypeError):
                # Ordinary text (or a non-string chunk): not an error payload.
                continue
            # A chunk may be valid JSON without being our error envelope, e.g.
            # {"error": null} or {"error": "x"}. Check the type of the "error"
            # value before calling .get() on it, so such chunks are treated as
            # regular content instead of raising AttributeError.
            error_info = parsed.get("error") if isinstance(parsed, dict) else None
            if isinstance(error_info, dict) and error_info.get("type") == _GENERATION_ERROR_TYPE:
                yield chunk
                return

        if stream_first:
            for chunk in user_output_chunks:
                yield chunk

        # Run output rails on the accumulated context
        req_id = get_request_id()
        log.info("[%s] Running output rails", req_id)
        output_result = await self.rails_manager.is_output_safe(messages, bot_response_chunk)
        if not output_result.is_safe:
            log.info("[%s] Output blocked: %s", req_id, output_result.reason)
            error_data = {
                "error": {
                    "message": f"Blocked by output rails: {output_result.reason}",
                    "type": "guardrails_violation",
                    "code": "content_blocked",
                }
            }
            yield json.dumps(error_data)
            return

        if not stream_first:
            for chunk in user_output_chunks:
                yield chunk
Loading
Loading