Skip to content

Commit ad8c9a7

Browse files
committed
fix: reliable priority extraction and rewritten intent prompt
Priority extraction fix:
- The LLM now returns *_mentioned booleans alongside *_priority values. Post-processing trusts a priority only when the matching *_mentioned=true; it resets the priority to "medium" otherwise. This prevents the LLM (qwen2.5:7b) from inferring priorities from use-case type rather than explicit user statements. The SLO profiles already handle use-case-appropriate targets.

Prompt rewrite (for smaller LLMs):
- Replace verbose prose prompt with a short, directive format using ordered pattern-matching rules. The prompt is now self-contained (schema embedded inline), so the INTENT_EXTRACTION_SCHEMA constant and the schema_description parameter on extract_structured_data() are removed.
- Remove experience_class, complexity_priority, and additional_context from the LLM prompt. experience_class is inferred deterministically from use_case in post-processing; complexity_priority and additional_context were never consumed downstream.

Post-processing hardening:
- Case-insensitive normalization for domain_specialization, experience_class, and *_mentioned booleans (handles string "True").
- Lowercase use_case before alias/fuzzy lookup so mixed-case LLM responses like "Text_Summarization" are handled correctly.
- Add logger.warning when an unrecognized use_case cannot be resolved by alias map or fuzzy match.
- Priority value aliases (e.g. "very_high" -> "high") applied before validation.
- Remove stale complexity_priority from test helper _base_intent().
- Add 4 unit tests for case-insensitive normalization.

Assisted-by: Claude <noreply@anthropic.com>
Signed-off-by: Andre Fredette <afredette@redhat.com>
1 parent cb1a6fa commit ad8c9a7

6 files changed

Lines changed: 167 additions & 142 deletions

File tree

src/planner/intent_extraction/extractor.py

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import difflib
44
import logging
5+
import re
56
from datetime import datetime
67
from pathlib import Path
78
from typing import get_args
89

910
from planner.llm.ollama_client import OllamaClient
10-
from planner.llm.prompts import INTENT_EXTRACTION_SCHEMA, build_intent_extraction_prompt
11+
from planner.llm.prompts import build_intent_extraction_prompt
1112
from planner.shared.schemas import ConversationMessage, DeploymentIntent
1213

1314
logger = logging.getLogger(__name__)
@@ -76,15 +77,13 @@ def extract_intent(
7677
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
7778
prompt_file = PROMPTS_DIR / f"intent_extraction_{timestamp}.txt"
7879

79-
full_prompt_with_schema = f"{prompt}\n\n{INTENT_EXTRACTION_SCHEMA}"
80-
8180
with open(prompt_file, "w") as f:
8281
f.write("=" * 80 + "\n")
8382
f.write("INTENT EXTRACTION PROMPT\n")
8483
f.write(f"Generated: {datetime.now().isoformat()}\n")
8584
f.write(f"User Message: {user_message}\n")
8685
f.write("=" * 80 + "\n\n")
87-
f.write(full_prompt_with_schema)
86+
f.write(prompt)
8887
f.write("\n\n" + "=" * 80 + "\n")
8988
f.write("Copy everything above this line to test in other LLMs\n")
9089
f.write("=" * 80 + "\n")
@@ -93,8 +92,6 @@ def extract_intent(
9392
logger.info("=" * 80)
9493
logger.info("[FULL INTENT EXTRACTION PROMPT - START]")
9594
logger.info(prompt)
96-
logger.info("[SCHEMA BEING USED]")
97-
logger.info(INTENT_EXTRACTION_SCHEMA)
9895
logger.info("[FULL INTENT EXTRACTION PROMPT - END]")
9996
logger.info(f"💾 Prompt saved to: {prompt_file}")
10097
logger.info("=" * 80)
@@ -103,15 +100,14 @@ def extract_intent(
103100
# Extract structured data from LLM
104101
extracted = self.llm_client.extract_structured_data(
105102
prompt,
106-
INTENT_EXTRACTION_SCHEMA,
107103
temperature=0.3, # Lower temperature for more consistent extraction
108104
)
109105

110106
# Log extracted intent
111107
logger.info(f"[EXTRACTED INTENT] {extracted}")
112108

113109
# Validate and parse into Pydantic model
114-
intent = self._parse_extracted_intent(extracted)
110+
intent = self._parse_extracted_intent(extracted, user_message)
115111
logger.info(f"Extracted intent: use_case={intent.use_case}, users={intent.user_count}")
116112

117113
return intent
@@ -120,12 +116,13 @@ def extract_intent(
120116
logger.error(f"Failed to extract intent: {e}")
121117
raise ValueError(f"Intent extraction failed: {e}") from e
122118

123-
def _parse_extracted_intent(self, raw_data: dict) -> DeploymentIntent:
119+
def _parse_extracted_intent(self, raw_data: dict, user_message: str = "") -> DeploymentIntent:
124120
"""
125121
Parse and validate raw LLM output into DeploymentIntent.
126122
127123
Args:
128124
raw_data: Raw dict from LLM
125+
user_message: Original user message for priority validation
129126
130127
Returns:
131128
Validated DeploymentIntent
@@ -134,20 +131,21 @@ def _parse_extracted_intent(self, raw_data: dict) -> DeploymentIntent:
134131
ValueError: If data is invalid
135132
"""
136133
# Handle common LLM mistakes
137-
cleaned_data = self._clean_llm_output(raw_data)
134+
cleaned_data = self._clean_llm_output(raw_data, user_message)
138135

139136
try:
140137
return DeploymentIntent(**cleaned_data)
141138
except Exception as e:
142139
logger.error(f"Failed to parse intent from: {cleaned_data}")
143140
raise ValueError(f"Invalid intent data: {e}") from e
144141

145-
def _clean_llm_output(self, data: dict) -> dict:
142+
def _clean_llm_output(self, data: dict, user_message: str = "") -> dict:
146143
"""
147144
Clean common LLM output mistakes.
148145
149146
Args:
150147
data: Raw LLM output
148+
user_message: Original user message for priority validation
151149
152150
Returns:
153151
Cleaned data dict
@@ -161,7 +159,8 @@ def _clean_llm_output(self, data: dict) -> dict:
161159
cleaned["use_case"] = cleaned["use_case"].split("|")[0].strip()
162160

163161
# Normalize hallucinated use_case values
164-
use_case = cleaned.get("use_case", "")
162+
use_case = cleaned.get("use_case", "").lower()
163+
cleaned["use_case"] = use_case
165164
valid_use_cases = list(get_args(DeploymentIntent.model_fields["use_case"].annotation))
166165
if use_case not in valid_use_cases:
167166
mapped = _USE_CASE_ALIASES.get(use_case)
@@ -173,6 +172,14 @@ def _clean_llm_output(self, data: dict) -> dict:
173172
if close:
174173
logger.info("Fuzzy-matched use_case '%s' -> '%s'", use_case, close[0])
175174
cleaned["use_case"] = close[0]
175+
else:
176+
logger.warning(
177+
"Unrecognized use_case '%s' — no alias or fuzzy match found", use_case
178+
)
179+
180+
# Normalize experience_class to lowercase if provided by LLM
181+
if "experience_class" in cleaned and isinstance(cleaned["experience_class"], str):
182+
cleaned["experience_class"] = cleaned["experience_class"].lower()
176183

177184
# Infer experience_class if not provided
178185
if "experience_class" not in cleaned or not cleaned.get("experience_class"):
@@ -203,8 +210,6 @@ def _clean_llm_output(self, data: dict) -> dict:
203210
# Fix user_count if it's a descriptive string instead of integer
204211
if "user_count" in cleaned and isinstance(cleaned["user_count"], str):
205212
# Extract integer from strings like "thousands of users (estimated: 5,000 - 10,000)"
206-
import re
207-
208213
user_count_str = cleaned["user_count"]
209214

210215
# Try to find numbers with commas or ranges
@@ -249,31 +254,46 @@ def _clean_llm_output(self, data: dict) -> dict:
249254
f"Could not parse user_count from '{user_count_str}', defaulting to 1000"
250255
)
251256

252-
# Ensure domain_specialization is a list
257+
# Ensure domain_specialization is a list with lowercase values
253258
if "domain_specialization" in cleaned:
254259
if isinstance(cleaned["domain_specialization"], str):
255-
# Convert single string to list
256-
cleaned["domain_specialization"] = [cleaned["domain_specialization"]]
257-
elif "|" in str(cleaned.get("domain_specialization", "")):
258-
# Handle "general|code" format
260+
if "|" in cleaned["domain_specialization"]:
261+
# Handle "general|code" format
262+
cleaned["domain_specialization"] = [
263+
d.strip().lower() for d in cleaned["domain_specialization"].split("|")
264+
]
265+
else:
266+
# Convert single string to list
267+
cleaned["domain_specialization"] = [cleaned["domain_specialization"].lower()]
268+
elif isinstance(cleaned["domain_specialization"], list):
259269
cleaned["domain_specialization"] = [
260-
d.strip() for d in cleaned["domain_specialization"].split("|")
270+
d.lower() if isinstance(d, str) else d for d in cleaned["domain_specialization"]
261271
]
262272

263273
# Ensure priority fields have valid values (default to "medium" if invalid/missing)
264274
valid_priorities = ["low", "medium", "high"]
275+
# Map common LLM variations to valid values before discarding
276+
_priority_aliases = {
277+
"very_high": "high",
278+
"very high": "high",
279+
"critical": "high",
280+
"very_low": "low",
281+
"very low": "low",
282+
"none": "low",
283+
}
265284
for priority_field in [
266285
"accuracy_priority",
267286
"cost_priority",
268287
"latency_priority",
269-
"complexity_priority",
270288
]:
271289
if priority_field in cleaned:
272290
# Normalize to lowercase and validate
273291
priority_value = str(cleaned[priority_field]).lower().strip()
292+
priority_value = _priority_aliases.get(priority_value, priority_value)
274293
if priority_value not in valid_priorities:
275294
logger.info(
276-
f"Invalid {priority_field}='{cleaned[priority_field]}', defaulting to 'medium'"
295+
f"Invalid {priority_field}='{cleaned[priority_field]}', "
296+
f"defaulting to 'medium'"
277297
)
278298
cleaned[priority_field] = "medium"
279299
else:
@@ -282,6 +302,27 @@ def _clean_llm_output(self, data: dict) -> dict:
282302
# Field not provided by LLM, default to medium
283303
cleaned[priority_field] = "medium"
284304

305+
# Enforce explicit-only priority extraction.
306+
# The LLM returns *_mentioned booleans alongside *_priority values.
307+
# Trust the LLM's priority only when it reports the user mentioned the
308+
# topic. Otherwise reset to medium — the LLM is likely inferring from
309+
# use-case type rather than from what the user said. The SLO profiles
310+
# already handle use-case-appropriate targets.
311+
for prefix in ("accuracy", "cost", "latency"):
312+
mentioned_key = f"{prefix}_mentioned"
313+
priority_key = f"{prefix}_priority"
314+
mentioned_raw = cleaned.pop(mentioned_key, False)
315+
mentioned = (
316+
str(mentioned_raw).lower() == "true"
317+
if isinstance(mentioned_raw, str)
318+
else bool(mentioned_raw)
319+
)
320+
if not mentioned and cleaned.get(priority_key, "medium") != "medium":
321+
logger.info(
322+
f"Resetting {priority_key} from '{cleaned[priority_key]}' to 'medium' "
323+
f"(LLM reported {mentioned_key}=false)"
324+
)
325+
cleaned[priority_key] = "medium"
285326
# Remove any unexpected fields that aren't in the schema
286327
valid_fields = DeploymentIntent.model_fields.keys()
287328
cleaned = {k: v for k, v in cleaned.items() if k in valid_fields}
@@ -302,11 +343,7 @@ def infer_missing_fields(self, intent: DeploymentIntent) -> DeploymentIntent:
302343
if intent.domain_specialization == ["general"]:
303344
if intent.use_case in ["code_generation_detailed", "code_completion"]:
304345
intent.domain_specialization = ["general", "code"]
305-
elif intent.use_case == "translation" or (
306-
"multilingual" in intent.additional_context.lower()
307-
if intent.additional_context
308-
else False
309-
):
346+
elif intent.use_case == "translation":
310347
intent.domain_specialization = ["general", "multilingual"]
311348

312349
return intent

src/planner/llm/ollama_client.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -127,29 +127,19 @@ def generate_completion(
127127
def extract_structured_data(
128128
self,
129129
prompt: str,
130-
schema_description: str,
131130
temperature: float = 0.3,
132131
) -> dict[str, Any]:
133132
"""
134133
Extract structured data from prompt using JSON format.
135134
136135
Args:
137-
prompt: Input prompt describing what to extract
138-
schema_description: Description of expected JSON schema
136+
prompt: Input prompt (should include schema and instructions)
139137
temperature: Lower temperature for more consistent extraction
140138
141139
Returns:
142140
Parsed JSON dict
143141
"""
144-
full_prompt = f"""{prompt}
145-
146-
{schema_description}
147-
148-
Return ONLY valid JSON matching the schema above. Do not include any explanation or additional text."""
149-
150-
response_text = self.generate_completion(
151-
full_prompt, format_json=True, temperature=temperature
152-
)
142+
response_text = self.generate_completion(prompt, format_json=True, temperature=temperature)
153143

154144
try:
155145
result: dict[str, Any] = json.loads(response_text)

0 commit comments

Comments
 (0)