Skip to content

Commit 697b009

Browse files
committed
fix: tighten intent matching and reject invalid query actions
1 parent bab72ba commit 697b009

6 files changed

Lines changed: 122 additions & 23 deletions

File tree

custom_components/ai_hub/conversation.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,21 @@
2929
MATCH_ALL: Literal["*"] = "*"
3030

3131

32+
def _extract_expansion_rule_tokens(rule_value: str) -> set[str]:
33+
"""Extract plain tokens from a simple expansion rule string."""
34+
cleaned = (
35+
rule_value.replace("(", "|")
36+
.replace(")", "|")
37+
.replace("[", "|")
38+
.replace("]", "|")
39+
.replace("<", "|")
40+
.replace(">", "|")
41+
.replace("{", "|")
42+
.replace("}", "|")
43+
)
44+
return {token.strip() for token in cleaned.split("|") if token.strip()}
45+
46+
3247
async def async_setup_entry(
3348
hass: HomeAssistant,
3449
config_entry: ConfigEntry,
@@ -227,6 +242,11 @@ async def _async_handle_local_and_builtin_intents(
227242
and has_explicit_outcome
228243
)
229244
)
245+
if self._is_query_like_utterance(user_input.text):
246+
is_acceptable_type = (
247+
response_type == intent.IntentResponseType.QUERY_ANSWER
248+
and not is_follow_up_prompt
249+
)
230250

231251
if (
232252
not has_error
@@ -300,6 +320,33 @@ async def _async_handle_local_and_builtin_intents(
300320

301321
return None
302322

323+
def _is_query_like_utterance(self, text: str) -> bool:
324+
"""Return whether the utterance matches configured query markers."""
325+
config = self._config_cache.get_config() or {}
326+
expansion_rules = config.get("expansion_rules", {})
327+
if not isinstance(expansion_rules, dict):
328+
return False
329+
330+
query_rule_names = (
331+
"how_many_is",
332+
"what_is",
333+
"which",
334+
"how_is",
335+
"state",
336+
"currently",
337+
)
338+
text_lower = text.lower().strip()
339+
340+
for rule_name in query_rule_names:
341+
rule_value = expansion_rules.get(rule_name)
342+
if not isinstance(rule_value, str) or not rule_value.strip():
343+
continue
344+
tokens = _extract_expansion_rule_tokens(rule_value)
345+
if any(token in text_lower for token in tokens):
346+
return True
347+
348+
return False
349+
303350
async def _async_handle_llm_message(
304351
self,
305352
user_input: conversation.ConversationInput,

custom_components/ai_hub/intents/handlers.py

Lines changed: 36 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -99,26 +99,28 @@ def should_handle(self, text: str) -> bool:
9999
text_clean = text.strip()
100100
text_lower = text.lower().strip()
101101

102+
area_names, device_types = self._parse_device_and_area(text_lower, global_config)
103+
target_entities = self._match_named_entities(text_lower, device_types or None, area_names)
104+
102105
# 规则1: 检查明确的全局关键词 - HA不支持的功能
103106
global_keywords = global_config.get('global_keywords', [])
104107
has_global_keyword = any(keyword in text_lower for keyword in global_keywords)
105108

106109
# 规则2: 检查显式动作/参数指令
107110
has_action_word = self._has_explicit_action_word(text_lower, global_config)
108111
has_parameter_command = self._has_parameter_command(text, text_lower, global_config)
109-
has_target_hint = self._has_target_hint(text_lower, global_config)
112+
has_resolved_target = bool(device_types or target_entities)
110113
is_short_text = len(text_clean) <= 4
111114

112-
# 关键判断:
113-
# 1. 全局控制直接本地处理
114-
# 2. 显式参数控制 + 明确目标,本地处理以补足中文能力
115-
# 3. 显式开关控制 + 明确目标,本地处理以补足中文泛化开关
116-
should_handle = has_global_keyword or (
117-
has_target_hint and (has_action_word or has_parameter_command)
118-
)
115+
# 本地增强意图必须完整命中可执行控制目标,不能靠模糊猜测。
116+
should_handle = has_resolved_target and (has_action_word or has_parameter_command)
119117

120-
# 对于有动作词的短文本,如果缺少全局关键词,则不处理
121-
if has_action_word and is_short_text and not has_global_keyword and not has_target_hint:
118+
# 全局控制也必须解析出具体设备类型,避免“打开所有”之类半句被错误接管。
119+
if has_global_keyword and not device_types and not target_entities:
120+
should_handle = False
121+
122+
# 对于短文本,如果没有完整命中目标,则不处理。
123+
if (has_action_word or has_parameter_command) and is_short_text and not has_resolved_target:
122124
should_handle = False
123125

124126
_LOGGER.debug("Local intent check: '%s' -> %s", text, should_handle)
@@ -203,14 +205,6 @@ def _has_explicit_action(
203205

204206
return any(action in text_without_state for action in action_words)
205207

206-
def _has_target_hint(self, text_lower: str, global_config: dict) -> bool:
207-
"""Check whether the text contains a concrete area/device hint."""
208-
area_names, device_types = self._parse_device_and_area(text_lower, global_config)
209-
if area_names or device_types:
210-
return True
211-
212-
return bool(self._match_named_entities(text_lower))
213-
214208
def _parse_device_and_area(self, text_lower: str, global_config: dict) -> tuple:
215209
"""解析设备类型和区域."""
216210
area_names = []
@@ -399,7 +393,17 @@ async def _execute_control(
399393
fail_msg=fail_msg,
400394
)
401395

402-
return self._create_response(language, message)
396+
success_results = [
397+
{
398+
"type": "entity",
399+
"name": self._get_device_friendly_name(device_id),
400+
"id": device_id,
401+
}
402+
for device_id in all_devices
403+
if self.hass.states.get(device_id) is not None
404+
]
405+
406+
return self._create_response(language, message, success_results=success_results)
403407

404408
except Exception as e:
405409
message = self._format_response_message('error', error=str(e))
@@ -506,9 +510,20 @@ def _format_failure_message(self, error_count: int, failed_devices: list) -> str
506510
template = failure_config.get('many_devices', '')
507511
return template.format(error_count=len(unique_failed), failed_list=failed_list)
508512

509-
def _create_response(self, language: str, message: str, is_error: bool = False):
513+
def _create_response(
514+
self,
515+
language: str,
516+
message: str,
517+
is_error: bool = False,
518+
success_results: list[dict[str, Any]] | None = None,
519+
):
510520
"""创建响应结果."""
511-
return create_intent_result(language, message, is_error=is_error)
521+
return create_intent_result(
522+
language,
523+
message,
524+
is_error=is_error,
525+
success_results=success_results,
526+
)
512527

513528
# ========== 参数控制方法 ==========
514529

custom_components/ai_hub/intents/response_utils.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,29 @@ def format_response_message(template: str, **kwargs: Any) -> str:
2828
return template.format(**values)
2929

3030

31-
def create_intent_result(language: str, message: str, *, is_error: bool = False) -> dict[str, Any]:
31+
def create_intent_result(
32+
language: str,
33+
message: str,
34+
*,
35+
is_error: bool = False,
36+
success_results: list[dict[str, Any]] | None = None,
37+
failed_results: list[dict[str, Any]] | None = None,
38+
) -> dict[str, Any]:
3239
"""Create the standard local intent service result payload."""
3340
response = intent.IntentResponse(language=language)
3441
if is_error:
3542
response.async_set_error(intent.IntentResponseErrorCode.UNKNOWN, message)
3643
else:
3744
response.async_set_speech(message)
3845

46+
if success_results or failed_results:
47+
response.response_type = intent.IntentResponseType.ACTION_DONE
48+
response.data = {
49+
"success": success_results or [],
50+
"failed": failed_results or [],
51+
"targets": success_results or [],
52+
}
53+
3954
return {
4055
"response": response,
4156
"success": not is_error,

custom_components/ai_hub/manifest.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,5 @@
99
"iot_class": "cloud_polling",
1010
"issue_tracker": "https://github.com/ha-china/ai_hub/issues",
1111
"requirements": ["edge-tts==7.2.7", "aiofiles", "aiohttp"],
12-
"version": "v2026.04.5"
12+
"version": "v2026.04.6"
1313
}

tests/test_api_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ def test_local_intent_handler_matches_area_scoped_all_lights_commands():
237237

238238
assert handler.should_handle("打开客厅所有的灯") is True
239239
assert handler.should_handle("关闭客厅所有的灯") is True
240+
assert handler.should_handle("打开") is False
241+
assert handler.should_handle("关闭") is False
240242

241243

242244
def test_error_message_extraction():

tests/test_conversation.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from custom_components.ai_hub.providers.openai_compatible import OpenAICompatibleProvider
3030
from custom_components.ai_hub.providers.ollama_compatible import OllamaCompatibleProvider
3131
from custom_components.ai_hub.http import resolve_provider_name
32+
from custom_components.ai_hub.intents.response_utils import create_intent_result
3233

3334

3435
@pytest.fixture
@@ -330,6 +331,25 @@ def test_keeps_backward_compatibility_with_legacy_nested_structure(self):
330331
assert cache.get_error_message("llm_config_error") == "配置错误"
331332

332333

334+
class TestLocalIntentResponse:
335+
"""Tests for local intent structured responses."""
336+
337+
def test_create_intent_result_can_include_success_results(self):
338+
"""Local intent responses should preserve structured success targets."""
339+
result = create_intent_result(
340+
"zh-CN",
341+
"已打开书房灯",
342+
success_results=[
343+
{"type": "entity", "name": "书房灯", "id": "light.study"}
344+
],
345+
)
346+
347+
response = result["response"]
348+
assert response.data["success"] == [
349+
{"type": "entity", "name": "书房灯", "id": "light.study"}
350+
]
351+
352+
333353
class TestEntityStreamingSelection:
334354
"""Tests for provider streaming selection."""
335355

0 commit comments

Comments
 (0)