2121"""
2222
2323import asyncio
24+ import json
2425import logging
2526import time
27+ from collections .abc import AsyncIterator
28+ from contextlib import suppress
29+ from typing import Optional , Union
2630
31+ from nemoguardrails .exceptions import StreamingNotSupportedError
2732from nemoguardrails .guardrails .guardrails_types import (
2833 LLMMessage ,
2934 LLMMessages ,
3439)
3540from nemoguardrails .guardrails .model_manager import ModelManager
3641from nemoguardrails .guardrails .rails_manager import RailsManager
42+ from nemoguardrails .rails .llm .buffer import get_buffer_strategy
3743from nemoguardrails .rails .llm .config import RailsConfig
3844from nemoguardrails .rails .llm .options import GenerationOptions
45+ from nemoguardrails .streaming import END_OF_STREAM , StreamingHandler
3946
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Fixed assistant reply sent in place of an LLM response when input rails block a request.
REFUSAL_MESSAGE = "I'm sorry, I can't respond to that."

# Default concurrency budget for streaming requests (separate from the AsyncWorkQueue for generate_async).
# IORails sizes its streaming semaphore with this value; stream_async() rejects a request with
# asyncio.QueueFull instead of waiting when every slot is taken.
STREAM_MAX_CONCURRENCY = 256

# Error type used by _generation_task when pushing error JSON into the stream.
# _run_output_rails_in_streaming matches on this value to pass such chunks through unchecked.
_GENERATION_ERROR_TYPE = "generation_error"
56+
4457
4558class IORails :
4659 """Workflow engine for accelerated Input/Output rails inference."""
@@ -55,6 +68,15 @@ def __init__(self, config: RailsConfig) -> None:
5568 # Rails Manager is responsible for running rails by making calls to Model Manager
5669 self .rails_manager = RailsManager (config , self .model_manager )
5770
71+ # Semaphore for streaming concurrency control / load shedding
72+ self ._stream_semaphore = asyncio .Semaphore (STREAM_MAX_CONCURRENCY )
73+
74+ @property
75+ def _has_streaming_output_rails (self ) -> bool :
76+ """True when output rails are configured and streaming is enabled for them."""
77+ streaming = self .config .rails .output .streaming
78+ return streaming is not None and streaming .enabled and len (self .config .rails .output .flows ) > 0
79+
5880 async def start (self ) -> None :
5981 """Start the IORails engine. Call this during service startup."""
6082 if self ._running :
@@ -98,6 +120,8 @@ async def _run_sync_iorails():
98120
99121 async def generate_async (self , messages : LLMMessages , ** kwargs ) -> LLMMessage :
100122 """Run input rails, generation, and output rails. Return response if safe."""
123+ await self .start ()
124+
101125 token = set_new_request_id ()
102126 req_id = get_request_id ()
103127 t0 = time .monotonic ()
@@ -113,13 +137,13 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
113137 return {"role" : "assistant" , "content" : REFUSAL_MESSAGE }
114138
115139 # Step 2: Generate response from main LLM
116- # If we got an `options=GenerationOptions`, then unpack GenerationOptions.llm_params and add
117- # that to the main LLM call
118140 log .info ("[%s] Calling main LLM" , req_id )
119141 llm_kwargs = {}
120- if kwargs .get ("options" ) and isinstance (kwargs ["options" ], GenerationOptions ):
121- generation_options = kwargs ["options" ]
122- llm_kwargs = generation_options .llm_params if generation_options .llm_params else {}
142+ options = kwargs .get ("options" )
143+ if options and isinstance (options , dict ):
144+ options = GenerationOptions (** options )
145+ if isinstance (options , GenerationOptions ) and options .llm_params :
146+ llm_kwargs = options .llm_params
123147
124148 response_text = await self .model_manager .generate_async ("main" , messages , ** llm_kwargs )
125149 log .debug ("[%s] Main LLM response: %s" , req_id , truncate (response_text ))
@@ -140,3 +164,220 @@ async def generate_async(self, messages: LLMMessages, **kwargs) -> LLMMessage:
140164 elapsed_ms = (time .monotonic () - t0 ) * 1000
141165 log .info ("[%s] generate_async completed time=%.1fms" , req_id , elapsed_ms )
142166 reset_request_id (token )
167+
168+ def _validate_streaming_with_output_rails (self ) -> None :
169+ """Raise if output rails exist but streaming is not enabled for them."""
170+ if len (self .config .rails .output .flows ) > 0 and not self ._has_streaming_output_rails :
171+ raise StreamingNotSupportedError (
172+ "stream_async() cannot be used when output rails are configured but "
173+ "rails.output.streaming.enabled is False. Either set "
174+ "rails.output.streaming.enabled to True in your configuration, or use "
175+ "generate_async() instead of stream_async()."
176+ )
177+
def stream_async(
    self,
    messages: LLMMessages,
    options: Optional[Union[dict, GenerationOptions]] = None,
    include_metadata: Optional[bool] = False,
) -> AsyncIterator[Union[str, dict]]:
    """Stream LLM response tokens with input/output rails applied.

    Configuration is validated synchronously when this method is called; all
    other work (engine start-up, concurrency admission, input rails, LLM
    streaming) happens lazily once the returned iterator is consumed.

    Input rails run before any token is streamed. If they block the request,
    the stream consists of ``REFUSAL_MESSAGE`` only. If output rails are
    configured with streaming enabled, chunks are routed through
    ``_run_output_rails_in_streaming`` before reaching the caller. A failure
    inside the generation task is reported in-band as a JSON error chunk
    (``error.type == "generation_error"``) rather than as an exception.

    Args:
        messages: Conversation messages in OpenAI format.
        options: Optional GenerationOptions, or a dict used to build one.
            Only ``llm_params`` is used; it is forwarded to the main LLM call.
        include_metadata: Passed to the ``StreamingHandler``; not allowed
            together with streaming output rails.

    Returns:
        An async iterator of string chunks (or dicts).

    Raises:
        StreamingNotSupportedError: At call time, if output rails are present
            but ``rails.output.streaming.enabled`` is False.
        ValueError: At call time, if ``include_metadata=True`` with output
            rails streaming enabled (BufferStrategy requires plain string chunks).
        asyncio.QueueFull: On the first iteration of the returned iterator
            (not at call time), if the streaming concurrency limit is reached
            (load shedding).
    """
    self._validate_streaming_with_output_rails()

    if include_metadata and self._has_streaming_output_rails:
        raise ValueError(
            "include_metadata=True is not supported when output rails streaming is enabled. "
            "BufferStrategy requires plain string chunks. Use include_metadata=False or "
            "disable output rails streaming."
        )

    # Extract llm_params from GenerationOptions if provided
    llm_kwargs: dict = {}
    if options and isinstance(options, dict):
        options = GenerationOptions(**options)
    if isinstance(options, GenerationOptions) and options.llm_params:
        llm_kwargs = options.llm_params

    streaming_handler = StreamingHandler(include_metadata=include_metadata)

    async def _generation_task():
        """Background task: input rails → stream LLM chunks → push to handler.

        Inherits the request ID from the caller context via create_task().
        """
        req_id = get_request_id()
        t0 = time.monotonic()
        try:
            # Step 1: Input rails (non-streaming)
            log.info("[%s] Running input rails", req_id)
            input_result = await self.rails_manager.is_input_safe(messages)
            if not input_result.is_safe:
                log.info("[%s] Input blocked: %s", req_id, input_result.reason)
                await streaming_handler.push_chunk(REFUSAL_MESSAGE)
                await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
                return

            # Step 2: Stream main LLM
            log.info("[%s] Streaming main LLM", req_id)
            async for chunk in self.model_manager.stream_async("main", messages, **llm_kwargs):
                await streaming_handler.push_chunk(chunk)

            await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
        except Exception as e:
            # Report the failure in-band so the consumer sees a terminated stream
            # instead of waiting forever. Cancellation (a BaseException) is
            # deliberately not caught here.
            elapsed_ms = (time.monotonic() - t0) * 1000
            log.exception("[%s] generation task failed time=%.1fms", req_id, elapsed_ms)
            error_payload = json.dumps(
                {"error": {"message": str(e), "type": _GENERATION_ERROR_TYPE, "code": "generation_failed"}}
            )
            await streaming_handler.push_chunk(error_payload)
            await streaming_handler.push_chunk(END_OF_STREAM)  # type: ignore[arg-type]
        finally:
            elapsed_ms = (time.monotonic() - t0) * 1000
            log.info("[%s] generation task completed time=%.1fms", req_id, elapsed_ms)

    async def _wrapped_iterator():
        """Wrap the base iterator with semaphore-based concurrency control."""
        # Ensure engines are running (idempotent if already started).
        await self.start()

        # Non-blocking admission: fail fast instead of queueing when all slots
        # are taken. Entering the semaphore right after the locked() check does
        # not suspend when a slot is free, so no other coroutine can interleave
        # between the check and the acquisition.
        if self._stream_semaphore.locked():
            raise asyncio.QueueFull("Streaming concurrency limit reached")

        # `async with` guarantees the slot is released on every exit path,
        # including a failure in the request-id setup or in the final logging.
        async with self._stream_semaphore:
            token = set_new_request_id()
            req_id = get_request_id()
            t0 = time.monotonic()
            try:
                log.info("[%s] stream_async called", req_id)
                log.debug("[%s] stream_async messages=%s", req_id, truncate(messages))

                task = asyncio.create_task(_generation_task())
                try:
                    # Determine base iterator: with or without output rails
                    if self._has_streaming_output_rails:
                        base_iterator = self._run_output_rails_in_streaming(
                            streaming_handler=streaming_handler,
                            messages=messages,
                        )
                    else:
                        base_iterator = streaming_handler

                    async for chunk in base_iterator:
                        if chunk is not None:
                            yield chunk
                finally:
                    try:
                        # The consumer may stop early (or output rails may block),
                        # so make sure the producer does not outlive the stream.
                        if not task.done():
                            task.cancel()
                        with suppress(asyncio.CancelledError):
                            await task
                    finally:
                        try:
                            reset_request_id(token)
                        except ValueError:
                            # GeneratorExit triggers cleanup in a different context
                            # where the token is no longer valid — safe to ignore.
                            pass
            except Exception:
                elapsed_ms = (time.monotonic() - t0) * 1000
                log.exception("[%s] stream_async failed time=%.1fms", req_id, elapsed_ms)
                raise
            finally:
                elapsed_ms = (time.monotonic() - t0) * 1000
                log.info("[%s] stream_async completed time=%.1fms", req_id, elapsed_ms)

    return _wrapped_iterator()
325+
async def _run_output_rails_in_streaming(
    self,
    streaming_handler: AsyncIterator[Union[str, dict]],
    messages: LLMMessages,
) -> AsyncIterator[Union[str, dict]]:
    """Buffer streamed chunks and run output rails on each batch.

    Batches come from the buffer strategy selected by ``get_buffer_strategy``
    for ``rails.output.streaming``. The ``stream_first`` setting decides the
    ordering for each batch:

    - ``stream_first=True``: yield the batch's chunks immediately, then run
      output rails. If unsafe, emit an error chunk and stop.
    - ``stream_first=False``: run output rails first and yield the batch's
      chunks only if it is safe.

    A batch containing a generation-error chunk (JSON object whose
    ``error.type`` equals ``_GENERATION_ERROR_TYPE``) ends the stream: only
    that error chunk is yielded, and output rails are not run on the batch.

    Args:
        streaming_handler: Async source of raw chunks from the generation task.
        messages: Conversation messages passed to the output rails check.

    Yields:
        Chunks cleared for the caller, or a single JSON error chunk
        (``guardrails_violation`` / ``generation_error``) after which the
        generator stops.
    """

    # Unpack streaming config and get the buffer strategy
    output_streaming_config = self.config.rails.output.streaming
    stream_first = output_streaming_config.stream_first
    buffer_strategy = get_buffer_strategy(output_streaming_config)

    async for chunk_batch in buffer_strategy(streaming_handler):
        user_output_chunks = chunk_batch.user_output_chunks
        bot_response_chunk = buffer_strategy.format_chunks(chunk_batch.processing_context)

        # If the batch contains a generation error from _generation_task,
        # yield it directly and stop — don't feed error JSON through output rails.
        for chunk in user_output_chunks:
            try:
                parsed = json.loads(chunk)
            except (json.JSONDecodeError, TypeError):
                # Ordinary text (or a non-string chunk): not an error payload.
                continue
            # Model output can itself be valid JSON of any shape, e.g.
            # '{"error": null}' or '{"error": "x"}'. Only a dict-valued "error"
            # carrying our marker type counts; anything else is regular content.
            error = parsed.get("error") if isinstance(parsed, dict) else None
            if isinstance(error, dict) and error.get("type") == _GENERATION_ERROR_TYPE:
                yield chunk
                return

        if stream_first:
            for chunk in user_output_chunks:
                yield chunk

        # Run output rails on the accumulated context
        req_id = get_request_id()
        log.info("[%s] Running output rails", req_id)
        output_result = await self.rails_manager.is_output_safe(messages, bot_response_chunk)
        if not output_result.is_safe:
            log.info("[%s] Output blocked: %s", req_id, output_result.reason)
            error_data = {
                "error": {
                    "message": f"Blocked by output rails: {output_result.reason}",
                    "type": "guardrails_violation",
                    "code": "content_blocked",
                }
            }
            yield json.dumps(error_data)
            return

        if not stream_first:
            for chunk in user_output_chunks:
                yield chunk
0 commit comments