Skip to content

Commit 480e1f2

Browse files
feat: added HITL feature
1 parent 72116d9 commit 480e1f2

8 files changed

Lines changed: 413 additions & 5 deletions

File tree

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from app.ai.voice.agents.breeze_buddy.features.hitl.manager import (
2+
BreezeBuddyHITLManager,
3+
get_hitl_manager,
4+
)
5+
6+
__all__ = ["BreezeBuddyHITLManager", "get_hitl_manager"]
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
"""HITL Manager for Breeze Buddy - handles voice-native confirmation requests."""
2+
3+
import asyncio
4+
import uuid
5+
from dataclasses import dataclass, field
6+
from difflib import SequenceMatcher
7+
from typing import Any, Dict, List, Optional
8+
9+
from pipecat.frames.frames import (
10+
BotStartedSpeakingFrame,
11+
BotStoppedSpeakingFrame,
12+
TranscriptionFrame,
13+
TTSSpeakFrame,
14+
UserStoppedSpeakingFrame,
15+
)
16+
17+
from app.ai.voice.agents.breeze_buddy.template.context import TemplateContext
18+
from app.ai.voice.agents.breeze_buddy.template.interruption import (
19+
_apply_interruption_config,
20+
_get_user_aggregator,
21+
)
22+
from app.ai.voice.agents.breeze_buddy.template.types import (
23+
HITLConfig,
24+
InterruptionConfig,
25+
InterruptionMode,
26+
)
27+
from app.core.config.static import BREEZE_BUDDY_HITL_DEFAULT_TIMEOUT
28+
from app.core.logger import logger
29+
30+
31+
def _similarity(a: str, b: str) -> float:
32+
return SequenceMatcher(None, a, b).ratio()
33+
34+
35+
def _matches_response(text: str, keywords: List[str]) -> bool:
36+
"""Match a single text chunk against keywords."""
37+
text = text.lower().strip()
38+
39+
# Fast substring match
40+
for keyword in keywords:
41+
if keyword.lower() in text:
42+
return True
43+
44+
# Fuzzy match (safer thresholds)
45+
for keyword in keywords:
46+
kw = keyword.lower()
47+
if abs(len(text) - len(kw)) <= 2:
48+
threshold = 0.9 if len(kw) <= 3 else 0.8
49+
if _similarity(text, kw) >= threshold:
50+
return True
51+
52+
return False
53+
54+
55+
@dataclass
56+
class PendingConfirmation:
57+
confirmation_id: str
58+
config: HITLConfig
59+
function_name: str
60+
arguments: Dict[str, Any]
61+
event: asyncio.Event = field(default_factory=asyncio.Event)
62+
response: Optional[Dict[str, Any]] = None
63+
retry_count: int = 0
64+
max_retries: int = 3
65+
ask_again: bool = False
66+
67+
68+
class BreezeBuddyHITLManager:
69+
def __init__(self):
70+
self._pending: Dict[str, PendingConfirmation] = {}
71+
self._active_confirmation_id: Optional[str] = None
72+
73+
def is_confirmation_active(self) -> bool:
74+
return self._active_confirmation_id is not None
75+
76+
async def consume_transcription_if_hitl(self, text: str) -> bool:
77+
"""
78+
Returns True if text is consumed by HITL (and should NOT go to LLM).
79+
"""
80+
confirmation_id = self._active_confirmation_id
81+
if not confirmation_id:
82+
return False
83+
84+
pending = self._pending.get(confirmation_id)
85+
if not pending:
86+
return False
87+
88+
text = text.strip()
89+
if not text:
90+
return True
91+
92+
logger.info(f"HITL: Heard '{text}'")
93+
94+
if _matches_response(text, pending.config.accepted_responses):
95+
pending.response = {"approved": True, "reason": "user_approved"}
96+
pending.event.set()
97+
return True
98+
99+
if _matches_response(text, pending.config.rejected_responses):
100+
pending.response = {"approved": False, "reason": "user_rejected"}
101+
pending.event.set()
102+
return True
103+
104+
# Unmatched response - increment retry
105+
pending.retry_count += 1
106+
if pending.retry_count >= pending.max_retries:
107+
pending.response = {"approved": False, "reason": "max_retries_exceeded"}
108+
pending.event.set()
109+
logger.warning(f"HITL: Max retries exceeded for {pending.function_name}")
110+
return True
111+
112+
# Signal to re-prompt and continue listening
113+
pending.ask_again = True
114+
pending.event.set()
115+
return True
116+
117+
async def request_confirmation(
118+
self,
119+
context: TemplateContext,
120+
config: HITLConfig,
121+
function_name: str,
122+
arguments: Dict[str, Any],
123+
) -> Dict[str, Any]:
124+
confirmation_id = str(uuid.uuid4())
125+
126+
pending = PendingConfirmation(
127+
confirmation_id=confirmation_id,
128+
config=config,
129+
function_name=function_name,
130+
arguments=arguments,
131+
)
132+
self._pending[confirmation_id] = pending
133+
self._active_confirmation_id = confirmation_id
134+
135+
timeout = float(config.timeout_seconds or BREEZE_BUDDY_HITL_DEFAULT_TIMEOUT)
136+
logger.info(
137+
f"HITL: Requesting confirmation for {function_name}, timeout={timeout}s"
138+
)
139+
140+
try:
141+
return await self._voice_confirmation(context, pending)
142+
except asyncio.TimeoutError:
143+
logger.warning(f"HITL timeout for {function_name}")
144+
return {"approved": False, "reason": "timeout"}
145+
finally:
146+
if self._active_confirmation_id == confirmation_id:
147+
self._active_confirmation_id = None
148+
self._cleanup(confirmation_id)
149+
150+
async def _voice_confirmation(
151+
self,
152+
context: TemplateContext,
153+
pending: PendingConfirmation,
154+
) -> Dict[str, Any]:
155+
task = context.task
156+
config = pending.config
157+
158+
if not task:
159+
logger.error("No task available for HITL confirmation")
160+
return {"approved": False, "reason": "no_task"}
161+
162+
message = config.confirmation_message or self._default_message(
163+
pending.function_name, pending.arguments
164+
)
165+
166+
timeout = float(config.timeout_seconds or BREEZE_BUDDY_HITL_DEFAULT_TIMEOUT)
167+
168+
logger.info(
169+
f"HITL: Starting confirmation for {pending.function_name}, timeout={timeout}s"
170+
)
171+
172+
task.add_reached_upstream_filter((TranscriptionFrame, UserStoppedSpeakingFrame))
173+
174+
user_aggregator = _get_user_aggregator(context.bot)
175+
original_config = None
176+
177+
if user_aggregator:
178+
original_config = getattr(user_aggregator, "interruption_config", None)
179+
disable_config = InterruptionConfig(mode=InterruptionMode.DISABLED_DISCARD)
180+
181+
await _apply_interruption_config(
182+
user_aggregator,
183+
disable_config,
184+
has_vad=context.bot.vad_analyzer is not None,
185+
call_sid=context.call_sid or "unknown",
186+
label="hitl_confirmation",
187+
bot=context.bot,
188+
user_speech_timeout=0.0,
189+
)
190+
logger.info(
191+
f"HITL: Interruptions disabled (user_speech_timeout=0.0) for {timeout}s confirmation window"
192+
)
193+
194+
# Two-phase TTS wait: BotStarted → BotStopped for THIS confirmation prompt
195+
tts_complete = asyncio.Event()
196+
bot_started = False
197+
198+
async def on_bot_tts_window(task_obj, frame):
199+
nonlocal bot_started
200+
if tts_complete.is_set():
201+
return # This handler's job is done; avoid stale reactions
202+
if isinstance(frame, BotStartedSpeakingFrame):
203+
bot_started = True
204+
elif isinstance(frame, BotStoppedSpeakingFrame) and bot_started:
205+
tts_complete.set()
206+
207+
# Ensure we receive BotStartedSpeakingFrame too
208+
task.add_reached_downstream_filter((BotStartedSpeakingFrame,))
209+
task.add_event_handler("on_frame_reached_downstream", on_bot_tts_window)
210+
211+
try:
212+
await task.queue_frame(TTSSpeakFrame(text=message))
213+
214+
try:
215+
await asyncio.wait_for(tts_complete.wait(), timeout=10)
216+
logger.info("HITL: Confirmation prompt finished speaking")
217+
except asyncio.TimeoutError:
218+
logger.warning("HITL: TTS timeout")
219+
220+
logger.info("HITL: Listening for response...")
221+
222+
while pending.retry_count < pending.max_retries:
223+
try:
224+
pending.event.clear()
225+
await asyncio.wait_for(pending.event.wait(), timeout=timeout)
226+
except asyncio.TimeoutError:
227+
logger.warning("HITL: Response timeout")
228+
pending.response = {"approved": False, "reason": "timeout"}
229+
break
230+
231+
# Check if we need to ask again
232+
if pending.ask_again:
233+
pending.ask_again = False
234+
retry_message = (
235+
f"I didn't understand. {message} Please say yes or no."
236+
)
237+
tts_complete.clear()
238+
bot_started = False
239+
await task.queue_frame(TTSSpeakFrame(text=retry_message))
240+
try:
241+
await asyncio.wait_for(tts_complete.wait(), timeout=10)
242+
except asyncio.TimeoutError:
243+
logger.warning("HITL: Retry TTS timeout")
244+
logger.info("HITL: Listening for retry response...")
245+
continue
246+
247+
# Response matched (approved or rejected)
248+
break
249+
250+
result = pending.response or {"approved": False, "reason": "no_response"}
251+
logger.info(f"HITL: Confirmation result={result}")
252+
return result
253+
254+
finally:
255+
# Critical: Restore interruptions AFTER HITL response is complete
256+
# and has been delivered to LLM state machine
257+
if user_aggregator and original_config:
258+
logger.info(
259+
f"HITL: Restoring interruptions after confirmation "
260+
f"(result={pending.response})"
261+
)
262+
await _apply_interruption_config(
263+
user_aggregator,
264+
original_config,
265+
has_vad=context.bot.vad_analyzer is not None,
266+
call_sid=context.call_sid or "unknown",
267+
label="hitl_restore",
268+
bot=context.bot,
269+
)
270+
logger.info("HITL: Interruptions restored")
271+
272+
def _cleanup(self, confirmation_id: str):
273+
self._pending.pop(confirmation_id, None)
274+
275+
def _default_message(self, function_name: str, arguments: Dict) -> str:
276+
action = function_name.replace("_", " ")
277+
return f"Should I {action}? Please say yes or no."
278+
279+
280+
_hitl_manager: Optional[BreezeBuddyHITLManager] = None
281+
282+
283+
def get_hitl_manager() -> BreezeBuddyHITLManager:
284+
global _hitl_manager
285+
if _hitl_manager is None:
286+
_hitl_manager = BreezeBuddyHITLManager()
287+
return _hitl_manager

app/ai/voice/agents/breeze_buddy/processors/transcription_gate.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
)
4646
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
4747

48+
from app.ai.voice.agents.breeze_buddy.features.hitl.manager import get_hitl_manager
4849
from app.ai.voice.agents.breeze_buddy.template.types import (
4950
KeywordFilterConfig,
5051
KeywordMatchType,
@@ -187,6 +188,23 @@ async def process_frame(self, frame: Frame, direction: FrameDirection):
187188

188189
# ---- Transcription suppression logic -------------------------
189190
elif isinstance(frame, (TranscriptionFrame, InterimTranscriptionFrame)):
191+
# HITL gate:
192+
# - drop interim frames while HITL is active (avoid turn-start interruptions)
193+
# - consume final transcription via HITL matcher
194+
hitl_manager = get_hitl_manager()
195+
196+
if isinstance(frame, InterimTranscriptionFrame):
197+
if hitl_manager.is_confirmation_active():
198+
logger.debug(
199+
"TranscriptionGate: dropping interim transcription (HITL active)"
200+
)
201+
return
202+
203+
if isinstance(frame, TranscriptionFrame):
204+
consumed = await hitl_manager.consume_transcription_if_hitl(frame.text)
205+
if consumed:
206+
logger.debug("TranscriptionGate: consumed transcription (HITL)")
207+
return
190208
# Mode 1: hard mute — drop unconditionally (both final and interim).
191209
# InterimTranscriptionFrame must also be dropped: TranscriptionUserTurnStartStrategy
192210
# with use_interim=True fires on interim frames, which triggers an interruption even

app/ai/voice/agents/breeze_buddy/template/builder.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -552,8 +552,8 @@ def _build_function_schema(self, func: FlowFunction) -> FlowsFunctionSchema:
552552
Pipecat FlowsFunctionSchema object with unified handler
553553
"""
554554
logger.debug(
555-
f"Building function schema for: {func.name}, "
556-
f"transition_to={func.transition_to}, hooks={func.hooks}"
555+
f"BUILDER: Building function schema for: {func.name}, "
556+
f"transition_to={func.transition_to}, hooks={func.hooks}, hitl={func.hitl_config}"
557557
)
558558

559559
# Get the wrapped unified handler from handler_map
@@ -566,19 +566,29 @@ def _build_function_schema(self, func: FlowFunction) -> FlowsFunctionSchema:
566566
hooks = [hook.model_dump() for hook in func.hooks] if func.hooks else []
567567
logger.debug(f"Using hooks for {func.name}: {hooks}")
568568

569-
# Create a wrapper handler matching FlowsFunctionSchema expected signature.
569+
# Serialize HITL config if present
570+
hitl_config = func.hitl_config.model_dump() if func.hitl_config else None
571+
logger.info(f"BUILDER: hitl_config for {func.name} = {hitl_config}")
572+
570573
# In flows 1.0, ConsolidatedFunctionResult is (FlowResult | None, NodeConfig | None);
571574
# the legacy str-node-name variant of next_node was removed.
575+
# Create a wrapper handler matching FlowsFunctionSchema expected signature
576+
# Signature: (llm_args: Dict[str, Any], flow_manager: FlowManager) -> Awaitable[FlowResult | tuple]
572577
async def wrapper_handler(
573578
llm_args: Dict[str, Any], _flow_manager: FlowManager
574579
) -> FlowResult | tuple[FlowResult | None, NodeConfig | None]:
575580
# Call the wrapped unified transition handler
576581
# The with_context wrapper expects llm_args as first positional arg
582+
logger.info(
583+
f"DEBUG: Calling transition handler with hitl_config={hitl_config}"
584+
)
585+
577586
result = await cast(Callable[..., Any], wrapped_unified_handler)(
578587
llm_args,
579588
transition_to=func.transition_to,
580589
hooks=hooks,
581590
function_name=func.name,
591+
hitl_config=hitl_config,
582592
)
583593
return result
584594

@@ -590,6 +600,9 @@ async def wrapper_handler(
590600
handler=wrapper_handler,
591601
properties=func.properties,
592602
required=func.required,
603+
timeout_secs=(
604+
30.0 if func.hitl_config and func.hitl_config.enabled else None
605+
),
593606
)
594607

595608
def _build_action(self, action: FlowAction) -> Dict[str, Any]:

app/ai/voice/agents/breeze_buddy/template/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ async def wrapper(*args, **kwargs):
380380
hooks = kwargs.pop("hooks", None)
381381
function_name = kwargs.pop("function_name", None)
382382
function_config = kwargs.pop("function_config", None)
383+
hitl_config = kwargs.pop("hitl_config", None)
383384

384385
is_transition_handler = hooks is not None or function_name is not None
385386
is_global_function_handler = function_config is not None
@@ -415,6 +416,7 @@ async def wrapper(*args, **kwargs):
415416
transition_to=transition_to,
416417
hooks=hooks,
417418
function_name=function_name,
419+
hitl_config=hitl_config,
418420
)
419421
else:
420422
# Action handlers don't need hooks/function_name

0 commit comments

Comments
 (0)