|
| 1 | +"""HITL Manager for Breeze Buddy - handles voice-native confirmation requests.""" |
| 2 | + |
| 3 | +import asyncio |
| 4 | +import uuid |
| 5 | +from dataclasses import dataclass, field |
| 6 | +from difflib import SequenceMatcher |
| 7 | +from typing import Any, Dict, List, Optional |
| 8 | + |
| 9 | +from pipecat.frames.frames import ( |
| 10 | + BotStartedSpeakingFrame, |
| 11 | + BotStoppedSpeakingFrame, |
| 12 | + TranscriptionFrame, |
| 13 | + TTSSpeakFrame, |
| 14 | + UserStoppedSpeakingFrame, |
| 15 | +) |
| 16 | + |
| 17 | +from app.ai.voice.agents.breeze_buddy.template.context import TemplateContext |
| 18 | +from app.ai.voice.agents.breeze_buddy.template.interruption import ( |
| 19 | + _apply_interruption_config, |
| 20 | + _get_user_aggregator, |
| 21 | +) |
| 22 | +from app.ai.voice.agents.breeze_buddy.template.types import ( |
| 23 | + HITLConfig, |
| 24 | + InterruptionConfig, |
| 25 | + InterruptionMode, |
| 26 | +) |
| 27 | +from app.core.config.static import BREEZE_BUDDY_HITL_DEFAULT_TIMEOUT |
| 28 | +from app.core.logger import logger |
| 29 | + |
| 30 | + |
| 31 | +def _similarity(a: str, b: str) -> float: |
| 32 | + return SequenceMatcher(None, a, b).ratio() |
| 33 | + |
| 34 | + |
| 35 | +def _matches_response(text: str, keywords: List[str]) -> bool: |
| 36 | + """Match a single text chunk against keywords.""" |
| 37 | + text = text.lower().strip() |
| 38 | + |
| 39 | + # Fast substring match |
| 40 | + for keyword in keywords: |
| 41 | + if keyword.lower() in text: |
| 42 | + return True |
| 43 | + |
| 44 | + # Fuzzy match (safer thresholds) |
| 45 | + for keyword in keywords: |
| 46 | + kw = keyword.lower() |
| 47 | + if abs(len(text) - len(kw)) <= 2: |
| 48 | + threshold = 0.9 if len(kw) <= 3 else 0.8 |
| 49 | + if _similarity(text, kw) >= threshold: |
| 50 | + return True |
| 51 | + |
| 52 | + return False |
| 53 | + |
| 54 | + |
| 55 | +@dataclass |
| 56 | +class PendingConfirmation: |
| 57 | + confirmation_id: str |
| 58 | + config: HITLConfig |
| 59 | + function_name: str |
| 60 | + arguments: Dict[str, Any] |
| 61 | + event: asyncio.Event = field(default_factory=asyncio.Event) |
| 62 | + response: Optional[Dict[str, Any]] = None |
| 63 | + retry_count: int = 0 |
| 64 | + max_retries: int = 3 |
| 65 | + ask_again: bool = False |
| 66 | + |
| 67 | + |
| 68 | +class BreezeBuddyHITLManager: |
| 69 | + def __init__(self): |
| 70 | + self._pending: Dict[str, PendingConfirmation] = {} |
| 71 | + self._active_confirmation_id: Optional[str] = None |
| 72 | + |
| 73 | + def is_confirmation_active(self) -> bool: |
| 74 | + return self._active_confirmation_id is not None |
| 75 | + |
| 76 | + async def consume_transcription_if_hitl(self, text: str) -> bool: |
| 77 | + """ |
| 78 | + Returns True if text is consumed by HITL (and should NOT go to LLM). |
| 79 | + """ |
| 80 | + confirmation_id = self._active_confirmation_id |
| 81 | + if not confirmation_id: |
| 82 | + return False |
| 83 | + |
| 84 | + pending = self._pending.get(confirmation_id) |
| 85 | + if not pending: |
| 86 | + return False |
| 87 | + |
| 88 | + text = text.strip() |
| 89 | + if not text: |
| 90 | + return True |
| 91 | + |
| 92 | + logger.info(f"HITL: Heard '{text}'") |
| 93 | + |
| 94 | + if _matches_response(text, pending.config.accepted_responses): |
| 95 | + pending.response = {"approved": True, "reason": "user_approved"} |
| 96 | + pending.event.set() |
| 97 | + return True |
| 98 | + |
| 99 | + if _matches_response(text, pending.config.rejected_responses): |
| 100 | + pending.response = {"approved": False, "reason": "user_rejected"} |
| 101 | + pending.event.set() |
| 102 | + return True |
| 103 | + |
| 104 | + # Unmatched response - increment retry |
| 105 | + pending.retry_count += 1 |
| 106 | + if pending.retry_count >= pending.max_retries: |
| 107 | + pending.response = {"approved": False, "reason": "max_retries_exceeded"} |
| 108 | + pending.event.set() |
| 109 | + logger.warning(f"HITL: Max retries exceeded for {pending.function_name}") |
| 110 | + return True |
| 111 | + |
| 112 | + # Signal to re-prompt and continue listening |
| 113 | + pending.ask_again = True |
| 114 | + pending.event.set() |
| 115 | + return True |
| 116 | + |
| 117 | + async def request_confirmation( |
| 118 | + self, |
| 119 | + context: TemplateContext, |
| 120 | + config: HITLConfig, |
| 121 | + function_name: str, |
| 122 | + arguments: Dict[str, Any], |
| 123 | + ) -> Dict[str, Any]: |
| 124 | + confirmation_id = str(uuid.uuid4()) |
| 125 | + |
| 126 | + pending = PendingConfirmation( |
| 127 | + confirmation_id=confirmation_id, |
| 128 | + config=config, |
| 129 | + function_name=function_name, |
| 130 | + arguments=arguments, |
| 131 | + ) |
| 132 | + self._pending[confirmation_id] = pending |
| 133 | + self._active_confirmation_id = confirmation_id |
| 134 | + |
| 135 | + timeout = float(config.timeout_seconds or BREEZE_BUDDY_HITL_DEFAULT_TIMEOUT) |
| 136 | + logger.info( |
| 137 | + f"HITL: Requesting confirmation for {function_name}, timeout={timeout}s" |
| 138 | + ) |
| 139 | + |
| 140 | + try: |
| 141 | + return await self._voice_confirmation(context, pending) |
| 142 | + except asyncio.TimeoutError: |
| 143 | + logger.warning(f"HITL timeout for {function_name}") |
| 144 | + return {"approved": False, "reason": "timeout"} |
| 145 | + finally: |
| 146 | + if self._active_confirmation_id == confirmation_id: |
| 147 | + self._active_confirmation_id = None |
| 148 | + self._cleanup(confirmation_id) |
| 149 | + |
| 150 | + async def _voice_confirmation( |
| 151 | + self, |
| 152 | + context: TemplateContext, |
| 153 | + pending: PendingConfirmation, |
| 154 | + ) -> Dict[str, Any]: |
| 155 | + task = context.task |
| 156 | + config = pending.config |
| 157 | + |
| 158 | + if not task: |
| 159 | + logger.error("No task available for HITL confirmation") |
| 160 | + return {"approved": False, "reason": "no_task"} |
| 161 | + |
| 162 | + message = config.confirmation_message or self._default_message( |
| 163 | + pending.function_name, pending.arguments |
| 164 | + ) |
| 165 | + |
| 166 | + timeout = float(config.timeout_seconds or BREEZE_BUDDY_HITL_DEFAULT_TIMEOUT) |
| 167 | + |
| 168 | + logger.info( |
| 169 | + f"HITL: Starting confirmation for {pending.function_name}, timeout={timeout}s" |
| 170 | + ) |
| 171 | + |
| 172 | + task.add_reached_upstream_filter((TranscriptionFrame, UserStoppedSpeakingFrame)) |
| 173 | + |
| 174 | + user_aggregator = _get_user_aggregator(context.bot) |
| 175 | + original_config = None |
| 176 | + |
| 177 | + if user_aggregator: |
| 178 | + original_config = getattr(user_aggregator, "interruption_config", None) |
| 179 | + disable_config = InterruptionConfig(mode=InterruptionMode.DISABLED_DISCARD) |
| 180 | + |
| 181 | + await _apply_interruption_config( |
| 182 | + user_aggregator, |
| 183 | + disable_config, |
| 184 | + has_vad=context.bot.vad_analyzer is not None, |
| 185 | + call_sid=context.call_sid or "unknown", |
| 186 | + label="hitl_confirmation", |
| 187 | + bot=context.bot, |
| 188 | + user_speech_timeout=0.0, |
| 189 | + ) |
| 190 | + logger.info( |
| 191 | + f"HITL: Interruptions disabled (user_speech_timeout=0.0) for {timeout}s confirmation window" |
| 192 | + ) |
| 193 | + |
| 194 | + # Two-phase TTS wait: BotStarted → BotStopped for THIS confirmation prompt |
| 195 | + tts_complete = asyncio.Event() |
| 196 | + bot_started = False |
| 197 | + |
| 198 | + async def on_bot_tts_window(task_obj, frame): |
| 199 | + nonlocal bot_started |
| 200 | + if tts_complete.is_set(): |
| 201 | + return # This handler's job is done; avoid stale reactions |
| 202 | + if isinstance(frame, BotStartedSpeakingFrame): |
| 203 | + bot_started = True |
| 204 | + elif isinstance(frame, BotStoppedSpeakingFrame) and bot_started: |
| 205 | + tts_complete.set() |
| 206 | + |
| 207 | + # Ensure we receive BotStartedSpeakingFrame too |
| 208 | + task.add_reached_downstream_filter((BotStartedSpeakingFrame,)) |
| 209 | + task.add_event_handler("on_frame_reached_downstream", on_bot_tts_window) |
| 210 | + |
| 211 | + try: |
| 212 | + await task.queue_frame(TTSSpeakFrame(text=message)) |
| 213 | + |
| 214 | + try: |
| 215 | + await asyncio.wait_for(tts_complete.wait(), timeout=10) |
| 216 | + logger.info("HITL: Confirmation prompt finished speaking") |
| 217 | + except asyncio.TimeoutError: |
| 218 | + logger.warning("HITL: TTS timeout") |
| 219 | + |
| 220 | + logger.info("HITL: Listening for response...") |
| 221 | + |
| 222 | + while pending.retry_count < pending.max_retries: |
| 223 | + try: |
| 224 | + pending.event.clear() |
| 225 | + await asyncio.wait_for(pending.event.wait(), timeout=timeout) |
| 226 | + except asyncio.TimeoutError: |
| 227 | + logger.warning("HITL: Response timeout") |
| 228 | + pending.response = {"approved": False, "reason": "timeout"} |
| 229 | + break |
| 230 | + |
| 231 | + # Check if we need to ask again |
| 232 | + if pending.ask_again: |
| 233 | + pending.ask_again = False |
| 234 | + retry_message = ( |
| 235 | + f"I didn't understand. {message} Please say yes or no." |
| 236 | + ) |
| 237 | + tts_complete.clear() |
| 238 | + bot_started = False |
| 239 | + await task.queue_frame(TTSSpeakFrame(text=retry_message)) |
| 240 | + try: |
| 241 | + await asyncio.wait_for(tts_complete.wait(), timeout=10) |
| 242 | + except asyncio.TimeoutError: |
| 243 | + logger.warning("HITL: Retry TTS timeout") |
| 244 | + logger.info("HITL: Listening for retry response...") |
| 245 | + continue |
| 246 | + |
| 247 | + # Response matched (approved or rejected) |
| 248 | + break |
| 249 | + |
| 250 | + result = pending.response or {"approved": False, "reason": "no_response"} |
| 251 | + logger.info(f"HITL: Confirmation result={result}") |
| 252 | + return result |
| 253 | + |
| 254 | + finally: |
| 255 | + # Critical: Restore interruptions AFTER HITL response is complete |
| 256 | + # and has been delivered to LLM state machine |
| 257 | + if user_aggregator and original_config: |
| 258 | + logger.info( |
| 259 | + f"HITL: Restoring interruptions after confirmation " |
| 260 | + f"(result={pending.response})" |
| 261 | + ) |
| 262 | + await _apply_interruption_config( |
| 263 | + user_aggregator, |
| 264 | + original_config, |
| 265 | + has_vad=context.bot.vad_analyzer is not None, |
| 266 | + call_sid=context.call_sid or "unknown", |
| 267 | + label="hitl_restore", |
| 268 | + bot=context.bot, |
| 269 | + ) |
| 270 | + logger.info("HITL: Interruptions restored") |
| 271 | + |
| 272 | + def _cleanup(self, confirmation_id: str): |
| 273 | + self._pending.pop(confirmation_id, None) |
| 274 | + |
| 275 | + def _default_message(self, function_name: str, arguments: Dict) -> str: |
| 276 | + action = function_name.replace("_", " ") |
| 277 | + return f"Should I {action}? Please say yes or no." |
| 278 | + |
| 279 | + |
| 280 | +_hitl_manager: Optional[BreezeBuddyHITLManager] = None |
| 281 | + |
| 282 | + |
| 283 | +def get_hitl_manager() -> BreezeBuddyHITLManager: |
| 284 | + global _hitl_manager |
| 285 | + if _hitl_manager is None: |
| 286 | + _hitl_manager = BreezeBuddyHITLManager() |
| 287 | + return _hitl_manager |
0 commit comments