Skip to content

Commit adcbe11

Browse files
committed
feat: customize intent in workflows and update tests
- Update prompts to explicitly handle user-provided intent and prevent inference
- Refactor facet generator tests for multiple items and remove duplicates
- Fix function naming in question generator tests
- Update main logic and models to support intent customization
1 parent 7ead117 commit adcbe11

File tree

11 files changed

+91
-150
lines changed

11 files changed

+91
-150
lines changed

mcp/facet/facet_generator.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from model import context
44

55

6-
async def generate_facets_from_items(
6+
async def generate_facets(
77
facet_inputs_json: str, sql_dialect: str = "postgresql"
88
) -> str:
99
"""
@@ -26,25 +26,20 @@ async def generate_facets_from_items(
2626
final_facets = []
2727

2828
for item in item_list:
29-
question = item["question"]
30-
# Support both 'sql_snippet' (preferred) and 'facet' (legacy) keys
31-
facet_text = item.get("sql_snippet", item.get("facet"))
32-
if not facet_text:
33-
# Skip malformed items or raise error? For now, we might want to be robust
34-
# But if both are missing, we might have an issue.
35-
# Let's assume validation happened or we just let it fail later if None.
36-
# Actually, to be safe and avoid KeyError if strict logic elsewhere:
37-
raise KeyError("Each item must have a 'sql_snippet' or 'facet' key.")
3829

39-
intent = item.get(
40-
"intent", question
41-
) # Use provided intent or fallback to question
30+
sql_snippet = item.get("sql_snippet")
31+
if not sql_snippet:
32+
return '{"error": "Each item must have a \'sql_snippet\' key."}'
4233

43-
# 1. Extract value phrases from the question
44-
phrases = await parameterizer.extract_value_phrases(nl_query=question)
34+
intent = item.get("intent")
35+
if not intent:
36+
return '{"error": "Each item must have an \'intent\' key."}'
37+
38+
# 1. Extract value phrases from the intent (used as nl_query)
39+
phrases = await parameterizer.extract_value_phrases(nl_query=intent)
4540

4641
# 2. Generate the manifest
47-
manifest = question
42+
manifest = intent
4843
# Sort keys by length descending to replace longer phrases first
4944
sorted_phrases = sorted(phrases.keys(), key=len, reverse=True)
5045
for phrase in sorted_phrases:
@@ -54,12 +49,12 @@ async def generate_facets_from_items(
5449

5550
# 3. Parameterize the SQL and Intent
5651
parameterized_result = parameterizer.parameterize_sql_and_intent(
57-
phrases, facet_text, intent, db_dialect=db_dialect
52+
phrases, sql_snippet, intent, db_dialect=db_dialect
5853
)
5954

6055
# 4. Assemble the final facet object
6156
facet = context.Facet(
62-
sql_snippet=facet_text,
57+
sql_snippet=sql_snippet,
6358
intent=intent,
6459
manifest=manifest,
6560
parameterized=context.ParameterizedFacet(

mcp/main.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ async def generate_sql_pairs(
3535
has a "question" and a "sql" key.
3636
Example: '[{"question": "...", "sql": "..."}]'
3737
"""
38-
return await question_generator.generate_sql_pairs_from_schema(
38+
return await question_generator.generate_sql_pairs(
3939
db_schema, context, table_names, sql_dialect
4040
)
4141

@@ -45,20 +45,20 @@ async def generate_templates(
4545
template_inputs_json: str, sql_dialect: str = "postgresql"
4646
) -> str:
4747
"""
48-
Generates final templates from a list of user-approved question, SQL statement, and optional intent.
48+
Generates final templates from a list of user-approved template question, template SQL statement, and optional template intent.
4949
5050
Args:
5151
template_inputs_json: A JSON string representing a list of dictionaries (template inputs),
5252
where each dictionary has "question", "sql", and optional "intent" keys.
5353
Example (with intent): '[{"question": "How many users?", "sql": "SELECT count(*) FROM users", "intent": "Count total users"}]'
5454
Example (default intent): '[{"question": "List all items", "sql": "SELECT * FROM items"}]'
5555
sql_dialect: The SQL dialect to use for parameterization. Accepted
56-
values are 'postgresql', 'mysql', or 'googlesql'.
56+
values are 'postgresql' (default), 'mysql', or 'googlesql'.
5757
5858
Returns:
5959
A JSON string representing a ContextSet object.
6060
"""
61-
return await template_generator.generate_templates_from_items(
61+
return await template_generator.generate_templates(
6262
template_inputs_json, sql_dialect
6363
)
6464

@@ -68,20 +68,19 @@ async def generate_facets(
6868
facet_inputs_json: str, sql_dialect: str = "postgresql"
6969
) -> str:
7070
"""
71-
Generates final facets from a list of user-approved question, SQL snippet, and optional intent.
71+
Generates final facets from a list of user-approved facet intent and facet SQL snippet.
7272
7373
Args:
7474
facet_inputs_json: A JSON string representing a list of dictionaries (facet inputs),
75-
where each dictionary has "question", "sql_snippet", and optional "intent".
76-
Example (with intent): '[{"question": "expensive items", "sql_snippet": "price > 1000", "intent": "Filter by high price"}]'
77-
Example (default intent): '[{"question": "active users", "sql_snippet": "status = 'active'"}]'
75+
where each dictionary has "intent" and "sql_snippet".
76+
Example: '[{"intent": "high price", "sql_snippet": "price > 1000"}]'
7877
sql_dialect: The SQL dialect to use for parameterization. Accepted
79-
values are 'postgresql', 'mysql', or 'googlesql'.
78+
values are 'postgresql' (default), 'mysql', or 'googlesql'.
8079
8180
Returns:
8281
A JSON string representing a ContextSet object.
8382
"""
84-
return await facet_generator.generate_facets_from_items(
83+
return await facet_generator.generate_facets(
8584
facet_inputs_json, sql_dialect
8685
)
8786

@@ -211,19 +210,19 @@ def generate_upload_url(
211210

212211
@mcp.prompt
213212
def generate_bulk_templates() -> str:
214-
"""Initiates a guided workflow to generate Question/SQL pair templates."""
213+
"""Initiates a guided workflow to automatically generate templates based on the database schema."""
215214
return prompts.GENERATE_BULK_TEMPLATES_PROMPT
216215

217216

218217
@mcp.prompt
219218
def generate_targeted_templates() -> str:
220-
"""Initiates a guided workflow to generate specific Question/SQL pair templates."""
219+
"""Initiates a guided workflow to generate specific templates based on the user's input."""
221220
return prompts.GENERATE_TARGETED_TEMPLATES_PROMPT
222221

223222

224223
@mcp.prompt
225224
def generate_targeted_facets() -> str:
226-
"""Initiates a guided workflow to generate specific Phrase/SQL facet pair templates."""
225+
"""Initiates a guided workflow to generate specific facets based on the user's input."""
227226
return prompts.GENERATE_TARGETED_FACETS_PROMPT
228227

229228

mcp/model/context.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class ParameterizedFacet(BaseModel):
3333
parameterized_sql_snippet: str = Field(
3434
...,
3535
description="The SQL facet with placeholders (eg., ).",
36+
# "fragment" is deprecated, keep alias for backward compatibility
3637
validation_alias=AliasChoices(
3738
"parameterized_sql_snippet", "parameterized_fragment"
3839
),
@@ -48,7 +49,10 @@ class Facet(BaseModel):
4849
sql_snippet: str = Field(
4950
...,
5051
description="The corresponding, complete SQL facet.",
51-
validation_alias=AliasChoices("sql_snippet", "fragment"),
52+
# "fragment" is deprecated, keep alias for backward compatibility
53+
validation_alias=AliasChoices(
54+
"sql_snippet", "fragment"
55+
),
5256
)
5357
intent: str = Field(..., description="The user's specific intent.")
5458
manifest: str = Field(
@@ -66,5 +70,8 @@ class ContextSet(BaseModel):
6670
facets: List[Facet] | None = Field(
6771
None,
6872
description="A list of SQL facets.",
69-
validation_alias=AliasChoices("facets", "fragments"),
73+
# "fragments" is deprecated, keep alias for backward compatibility
74+
validation_alias=AliasChoices(
75+
"facets", "fragments"
76+
),
7077
)

mcp/prompts/bulk_templates.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
GENERATE_BULK_TEMPLATES_PROMPT = textwrap.dedent(
44
"""
5-
**Workflow for Generating Question/SQL Pair Templates**
5+
**Workflow for Automatically Generating Templates**
66
77
1. **Discover and Select Database:**
88
- Find all connected databases from the MCP Toolbox and `tools.yaml`.

mcp/prompts/targeted_facets.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,23 @@
22

33
GENERATE_TARGETED_FACETS_PROMPT = textwrap.dedent(
44
"""
5-
**Workflow for Generating Targeted Phrase/SQL Facet Pair Templates**
5+
**Workflow for Generating Targeted Facets**
66
77
1. **User Input Loop:**
8-
- Ask the user to provide a natural language phrase and its corresponding SQL facet.
9-
- **Optionally**, ask if they want to provide a specific "intent" for this pair. If not provided, the phrase will be used as the intent.
10-
- After capturing the pair, ask the user if they would like to add another one.
8+
- Ask the user to provide an intent and its corresponding SQL snippet.
9+
- **Important:** Do not infer the intent or SQL snippet. Wait for the user to provide them.
10+
- After capturing the intent and SQL snippet pair, ask the user if they would like to add another one.
1111
- Continue this loop until the user indicates they have no more pairs to add.
1212
1313
2. **Review and Confirmation:**
14-
- Present the complete list of user-provided Phrase/SQL facet pairs for confirmation.
15-
- **Use the following format for each pair:**
16-
**Pair [Number]**
17-
**Phrase:** [The natural language phrase]
18-
**Facet:**
14+
- Present the complete list of user-provided Intent/SQL snippet pairs for confirmation.
15+
- **Use the following format for each facet:**
16+
**Facet [Number]**
17+
**Intent:** [The intent]
18+
**SQL snippet:**
1919
```sql
20-
[The SQL facet, properly formatted]
20+
[The SQL snippet, properly formatted]
2121
```
22-
**Intent:** [The intent, if provided. Otherwise "Same as Phrase"]
2322
- Ask if any modifications are needed. If so, work with the user to refine the pairs.
2423
2524
3. **Final Facet Generation:**

mcp/prompts/targeted_templates.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,19 @@
22

33
GENERATE_TARGETED_TEMPLATES_PROMPT = textwrap.dedent(
44
"""
5-
**Workflow for Generating Targeted Question/SQL Pair Templates**
5+
**Workflow for Generating Targeted Templates**
66
77
1. **User Input Loop:**
88
- Ask the user to provide a natural language question and its corresponding SQL query.
99
- **Optionally**, ask if they want to provide a specific "intent" for this pair. If not provided, the question will be used as the intent.
10-
- After capturing the pair, ask the user if they would like to add another one.
11-
- Continue this loop until the user indicates they have no more pairs to add.
10+
- **Important:** Do not infer the question or SQL query. Wait for the user to provide them.
11+
- After capturing the inputs for a template, ask the user if they would like to add another one.
12+
- Continue this loop until the user indicates they have no more to add.
1213
1314
2. **Review and Confirmation:**
1415
- Present the complete list of user-provided Question/SQL pairs for confirmation.
1516
- **Use the following format for each pair:**
16-
**Pair [Number]**
17+
**Template [Number]**
1718
**Question:** [The natural language question]
1819
**SQL:**
1920
```sql

mcp/template/question_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class QuestionSQLPairs(BaseModel):
2222
pairs: List[QuestionSQLPair]
2323

2424

25-
async def generate_sql_pairs_from_schema(
25+
async def generate_sql_pairs(
2626
db_schema: str,
2727
context: str | None = None,
2828
table_names: List[str] | None = None,

mcp/template/template_generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from model import context
44

55

6-
async def generate_templates_from_items(
6+
async def generate_templates(
77
template_inputs_json: str, sql_dialect: str = "postgresql"
88
) -> str:
99
"""

0 commit comments

Comments
 (0)