Skip to content

Commit b78f48b

Browse files
authored
feat(iorails): IORails support for streaming output rails (#1765)
1 parent e2095b8 commit b78f48b

7 files changed

Lines changed: 1185 additions & 48 deletions

File tree

nemoguardrails/guardrails/iorails.py

Lines changed: 246 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,14 @@
2121
"""
2222

2323
import asyncio
24+
import json
2425
import logging
2526
import time
27+
from collections.abc import AsyncIterator
28+
from contextlib import suppress
29+
from typing import Optional, Union
2630

31+
from nemoguardrails.exceptions import StreamingNotSupportedError
2732
from nemoguardrails.guardrails.guardrails_types import (
2833
LLMMessage,
2934
LLMMessages,
@@ -34,13 +39,21 @@
3439
)
3540
from nemoguardrails.guardrails.model_manager import ModelManager
3641
from nemoguardrails.guardrails.rails_manager import RailsManager
42+
from nemoguardrails.rails.llm.buffer import get_buffer_strategy
3743
from nemoguardrails.rails.llm.config import RailsConfig
3844
from nemoguardrails.rails.llm.options import GenerationOptions
45+
from nemoguardrails.streaming import END_OF_STREAM, StreamingHandler
3946

4047
log = logging.getLogger(__name__)
4148

4249
REFUSAL_MESSAGE = "I'm sorry, I can't respond to that."
4350

51+
# Default concurrency budget for streaming requests (separate from the AsyncWorkQueue for generate_async)
52+
STREAM_MAX_CONCURRENCY = 256
53+
54+
# Error type used by _generation_task when pushing error JSON into the stream
55+
_GENERATION_ERROR_TYPE = "generation_error"
56+
4457

4558
class IORails:
4659
"""Workflow engine for accelerated Input/Output rails inference."""
@@ -55,6 +68,15 @@ def __init__(self, config: RailsConfig) -> None:
5568
# Rails Manager is responsible for running rails by making calls to Model Manager
5669
self.rails_manager = RailsManager(config, self.model_manager)
5770

71+
# Semaphore for streaming concurrency control / load shedding
72+
self._stream_semaphore = asyncio.Semaphore(STREAM_MAX_CONCURRENCY)
73+
74+
@property
75+
def _has_streaming_output_rails(self) -> bool:
76+
"""True when output rails are configured and streaming is enabled for them."""
77+
streaming = self.config.rails.output.streaming
78+
return streaming is not None and streaming.enabled and len(self.config.rails.output.flows) > 0
79+
5880
async def start(self) -> None:
5981
"""Start the IORails engine. Call this during service startup."""
6082
if self._running:
@@ -98,6 +120,8 @@ async def _run_sync_iorails():
98120

99121
async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
100122
"""Run input rails, generation, and output rails. Return response if safe."""
123+
await self.start()
124+
101125
token = set_new_request_id()
102126
req_id = get_request_id()
103127
t0 = time.monotonic()
@@ -113,13 +137,13 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
113137
return {"role": "assistant", "content": REFUSAL_MESSAGE}
114138

115139
# Step 2: Generate response from main LLM
116-
# If we got an `options=GenerationOptions`, then unpack GenerationOptions.llm_params and add
117-
# that to the main LLM call
118140
log.info("[%s] Calling main LLM", req_id)
119141
llm_kwargs = {}
120-
if kwargs.get("options") and isinstance(kwargs["options"], GenerationOptions):
121-
generation_options = kwargs["options"]
122-
llm_kwargs = generation_options.llm_params if generation_options.llm_params else {}
142+
options = kwargs.get("options")
143+
if options and isinstance(options, dict):
144+
options = GenerationOptions(**options)
145+
if isinstance(options, GenerationOptions) and options.llm_params:
146+
llm_kwargs = options.llm_params
123147

124148
response_text = await self.model_manager.generate_async("main", messages, **llm_kwargs)
125149
log.debug("[%s] Main LLM response: %s", req_id, truncate(response_text))
@@ -140,3 +164,220 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
140164
elapsed_ms = (time.monotonic() - t0) * 1000
141165
log.info("[%s] generate_async completed time=%.1fms", req_id, elapsed_ms)
142166
reset_request_id(token)
167+
168+
def _validate_streaming_with_output_rails(self) -> None:
169+
"""Raise if output rails exist but streaming is not enabled for them."""
170+
if len(self.config.rails.output.flows) > 0 and not self._has_streaming_output_rails:
171+
raise StreamingNotSupportedError(
172+
"stream_async() cannot be used when output rails are configured but "
173+
"rails.output.streaming.enabled is False. Either set "
174+
"rails.output.streaming.enabled to True in your configuration, or use "
175+
"generate_async() instead of stream_async()."
176+
)
177+
178+
def stream_async(
179+
self,
180+
messages: LLMMessages,
181+
options: Optional[Union[dict, GenerationOptions]] = None,
182+
include_metadata: Optional[bool] = False,
183+
) -> AsyncIterator[Union[str, dict]]:
184+
"""Stream LLM response tokens with input/output rails applied.
185+
186+
Returns an async iterator that yields string chunks (or dicts when
187+
``include_metadata=True``). Input rails run before any tokens are
188+
streamed. If output rails are configured and streaming is enabled,
189+
tokens are buffered and checked using the same ``RollingBuffer`` /
190+
``stream_first`` semantics as LLMRails.
191+
192+
Args:
193+
messages: Conversation messages in OpenAI format.
194+
options: Optional GenerationOptions (llm_params are forwarded to
195+
the main LLM call).
196+
include_metadata: When True, chunks are dicts with ``text`` and
197+
``metadata`` keys instead of plain strings.
198+
199+
Returns:
200+
An async iterator of string chunks (or dicts).
201+
202+
Raises:
203+
StreamingNotSupportedError: If output rails are present but
204+
``rails.output.streaming.enabled`` is False.
205+
ValueError: If ``include_metadata=True`` with output rails
206+
streaming enabled (BufferStrategy requires plain string chunks).
207+
asyncio.QueueFull: If the streaming concurrency limit is
208+
reached (load shedding).
209+
"""
210+
self._validate_streaming_with_output_rails()
211+
212+
if include_metadata and self._has_streaming_output_rails:
213+
raise ValueError(
214+
"include_metadata=True is not supported when output rails streaming is enabled. "
215+
"BufferStrategy requires plain string chunks. Use include_metadata=False or "
216+
"disable output rails streaming."
217+
)
218+
219+
# Extract llm_params from GenerationOptions if provided
220+
llm_kwargs: dict = {}
221+
if options and isinstance(options, dict):
222+
options = GenerationOptions(**options)
223+
if isinstance(options, GenerationOptions) and options.llm_params:
224+
llm_kwargs = options.llm_params
225+
226+
streaming_handler = StreamingHandler(include_metadata=include_metadata)
227+
228+
async def _generation_task():
229+
"""Background task: input rails → stream LLM chunks → push to handler.
230+
231+
Inherits the request ID from the caller context via create_task().
232+
"""
233+
req_id = get_request_id()
234+
t0 = time.monotonic()
235+
try:
236+
# Step 1: Input rails (non-streaming)
237+
log.info("[%s] Running input rails", req_id)
238+
input_result = await self.rails_manager.is_input_safe(messages)
239+
if not input_result.is_safe:
240+
log.info("[%s] Input blocked: %s", req_id, input_result.reason)
241+
await streaming_handler.push_chunk(REFUSAL_MESSAGE)
242+
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore[arg-type]
243+
return
244+
245+
# Step 2: Stream main LLM
246+
log.info("[%s] Streaming main LLM", req_id)
247+
async for chunk in self.model_manager.stream_async("main", messages, **llm_kwargs):
248+
await streaming_handler.push_chunk(chunk)
249+
250+
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore[arg-type]
251+
except Exception as e:
252+
elapsed_ms = (time.monotonic() - t0) * 1000
253+
log.error(
254+
"[%s] generation task failed time=%.1fms",
255+
req_id,
256+
elapsed_ms,
257+
exc_info=True,
258+
)
259+
error_payload = json.dumps(
260+
{"error": {"message": str(e), "type": _GENERATION_ERROR_TYPE, "code": "generation_failed"}}
261+
)
262+
await streaming_handler.push_chunk(error_payload)
263+
await streaming_handler.push_chunk(END_OF_STREAM) # type: ignore[arg-type]
264+
finally:
265+
elapsed_ms = (time.monotonic() - t0) * 1000
266+
log.info("[%s] generation task completed time=%.1fms", req_id, elapsed_ms)
267+
268+
async def _wrapped_iterator():
269+
"""Wrap the base iterator with semaphore-based concurrency control."""
270+
# Ensure engines are running (idempotent if already started).
271+
await self.start()
272+
273+
# Non-blocking acquire; raises immediately if all slots are taken.
274+
# locked() returns True when the semaphore value is 0. Because there
275+
# is no await between the check and acquire(), no other coroutine can
276+
# interleave in asyncio's cooperative model, so this is race-free.
277+
if self._stream_semaphore.locked():
278+
raise asyncio.QueueFull("Streaming concurrency limit reached")
279+
await self._stream_semaphore.acquire()
280+
281+
token = set_new_request_id()
282+
req_id = get_request_id()
283+
t0 = time.monotonic()
284+
try:
285+
log.info("[%s] stream_async called", req_id)
286+
log.debug("[%s] stream_async messages=%s", req_id, truncate(messages))
287+
288+
task = asyncio.create_task(_generation_task())
289+
try:
290+
# Determine base iterator: with or without output rails
291+
if self._has_streaming_output_rails:
292+
base_iterator = self._run_output_rails_in_streaming(
293+
streaming_handler=streaming_handler,
294+
messages=messages,
295+
)
296+
else:
297+
base_iterator = streaming_handler
298+
299+
async for chunk in base_iterator:
300+
if chunk is not None:
301+
yield chunk
302+
finally:
303+
try:
304+
if not task.done():
305+
task.cancel()
306+
with suppress(asyncio.CancelledError):
307+
await task
308+
finally:
309+
try:
310+
reset_request_id(token)
311+
except ValueError:
312+
# GeneratorExit triggers cleanup in a different context
313+
# where the token is no longer valid — safe to ignore.
314+
pass
315+
except Exception:
316+
elapsed_ms = (time.monotonic() - t0) * 1000
317+
log.error("[%s] stream_async failed time=%.1fms", req_id, elapsed_ms, exc_info=True)
318+
raise
319+
finally:
320+
elapsed_ms = (time.monotonic() - t0) * 1000
321+
log.info("[%s] stream_async completed time=%.1fms", req_id, elapsed_ms)
322+
self._stream_semaphore.release()
323+
324+
return _wrapped_iterator()
325+
326+
async def _run_output_rails_in_streaming(
327+
self,
328+
streaming_handler: AsyncIterator[Union[str, dict]],
329+
messages: LLMMessages,
330+
) -> AsyncIterator[Union[str, dict]]:
331+
"""Buffer streamed chunks and run output rails on each batch.
332+
333+
Uses the same ``RollingBuffer`` and ``stream_first`` semantics as
334+
LLMRails:
335+
- ``stream_first=True``: yield chunks immediately, then run output
336+
rails. If unsafe, inject an error and stop.
337+
- ``stream_first=False``: run output rails first, only yield chunks
338+
if safe.
339+
"""
340+
341+
# Unpack streaming config and get the buffer strategy
342+
output_streaming_config = self.config.rails.output.streaming
343+
stream_first = output_streaming_config.stream_first
344+
buffer_strategy = get_buffer_strategy(output_streaming_config)
345+
346+
async for chunk_batch in buffer_strategy(streaming_handler):
347+
user_output_chunks = chunk_batch.user_output_chunks
348+
bot_response_chunk = buffer_strategy.format_chunks(chunk_batch.processing_context)
349+
350+
# If the batch contains a generation error from _generation_task,
351+
# yield it directly and stop — don't feed error JSON through output rails.
352+
for chunk in user_output_chunks:
353+
try:
354+
parsed = json.loads(chunk)
355+
if isinstance(parsed, dict) and parsed.get("error", {}).get("type") == _GENERATION_ERROR_TYPE:
356+
yield chunk
357+
return
358+
except (json.JSONDecodeError, TypeError):
359+
pass
360+
361+
if stream_first:
362+
for chunk in user_output_chunks:
363+
yield chunk
364+
365+
# Run output rails on the accumulated context
366+
req_id = get_request_id()
367+
log.info("[%s] Running output rails", req_id)
368+
output_result = await self.rails_manager.is_output_safe(messages, bot_response_chunk)
369+
if not output_result.is_safe:
370+
log.info("[%s] Output blocked: %s", req_id, output_result.reason)
371+
error_data = {
372+
"error": {
373+
"message": f"Blocked by output rails: {output_result.reason}",
374+
"type": "guardrails_violation",
375+
"code": "content_blocked",
376+
}
377+
}
378+
yield json.dumps(error_data)
379+
return
380+
381+
if not stream_first:
382+
for chunk in user_output_chunks:
383+
yield chunk

0 commit comments

Comments
 (0)