diff --git a/mcp/main.py b/mcp/main.py index 6f144b1..5a976a7 100644 --- a/mcp/main.py +++ b/mcp/main.py @@ -3,6 +3,7 @@ import textwrap from template import question_generator, template_generator from facet import facet_generator +from value_index import generator as vi_generator from model import context import datetime import os @@ -87,6 +88,39 @@ async def generate_facets( ) +@mcp.tool +async def generate_value_indices( + table_name: str, + column_name: str, + concept_type: str, + match_function: str, + db_engine: str, + db_version: Optional[str] = None, + description: Optional[str] = None, +) -> str: + """ + Generates a single Value Index configuration. + + Args: + table_name: The name of the table. + column_name: The name of the column. + concept_type: The semantic type (e.g., 'City'). + match_function: The match function to use (e.g., 'EXACT_MATCH_STRINGS'). + db_engine: The database engine (postgresql, mysql, etc.). + db_version: The database version (optional). + Returns: + A JSON string representing a ContextSet object with the new value index. + """ + if db_version and not db_version.strip(): + db_version = None + + # Ensure we pass a string, defaulting to 'postgresql' if None is provided. + dialect = db_engine if db_engine is not None else "postgresql" + return vi_generator.generate_value_index( + table_name, column_name, concept_type, match_function, dialect, db_version, description + ) + + @mcp.tool def save_context_set( context_set_json: str, @@ -129,18 +163,18 @@ def attach_context_set( Attaches a ContextSet to an existing JSON file. This tool reads an existing JSON file containing a ContextSet, - appends new templates/facets to it, and writes the updated ContextSet + appends new templates/facets/value_indices to it, and writes the updated ContextSet back to the file. Exceptions are propagated to the caller. Args: - context_set_json: The JSON string output from the `generate_templates` or `generate_facets` tool. + context_set_json: The JSON string output from the generation tools. file_path: The **absolute path** to the existing template file. Returns: A confirmation message with the path to the updated file. """ - existing_content_dict = {"templates": [], "facets": []} + existing_content_dict = {"templates": [], "facets": [], "value_indices": []} if os.path.getsize(file_path) > 0: with open(file_path, "r") as f: existing_content_dict = json.load(f) @@ -159,10 +193,15 @@ def attach_context_set( if new_context.facets: existing_context.facets.extend(new_context.facets) + if existing_context.value_indices is None: + existing_context.value_indices = [] + if new_context.value_indices: + existing_context.value_indices.extend(new_context.value_indices) + with open(file_path, "w") as f: json.dump(existing_context.model_dump(), f, indent=2) - return f"Successfully attached templates to {file_path}" + return f"Successfully attached context to {file_path}" @mcp.tool @@ -408,6 +447,72 @@ def generate_targeted_facets() -> str: """ ) +@mcp.prompt +def generate_targeted_value_indices() -> str: + """Initiates a guided workflow to generate specific Value Index configurations.""" + return textwrap.dedent( + """ + **Workflow for Generating Targeted Value Indices** + + 1. **Database Configuration:** + - Ask the user for the **Database Engine** (e.g., `postgresql`, `mysql`, `spanner`).. + - Ask the user for the **Database Version**. + - Tell them they can enter default to use the default version. + + 2. **User Input Loop:** + - Ask the user to provide the following details for a value index: + - **Table Name** + - **Column Name** + - **Concept Type** (e.g., "City", "Product ID") + - **Match Function** (e.g., `EXACT_MATCH_STRINGS`, `FUZZY_MATCH_STRINGS`) + - **Description** (optional): A description of the value index. + - After capturing the details, ask the user if they would like to add another one. + - Continue this loop until the user indicates they have no more indices to add. + + 3. **Review and Confirmation:** + - Present the complete list of user-provided index definitions for confirmation. + - **Use the following format for each index:** + **Index [Number]** + **Table:** [Table Name] + **Column:** [Column Name] + **Concept:** [Concept Type] + **Function:** [Match Function] + **Description:** [Description] + - Ask if any modifications are needed. If so, work with the user to refine the list. + + 4. **Final Generation:** + - Once approved, call the `generate_value_indices` tool for each index defined. + - **Important:** Pass the `db_engine` and `db_version` collected in Step 1 to the tool. + - Combine all generated Value Index configurations into a single JSON structure (ContextSet). + + 5. **Save Value Indices:** + - Ask the user to choose one of the following options: + 1. Create a new context set file. + 2. Append value indices to an existing context set file. + + - **If creating a new file:** + - You will need to ask the user for the database instance and database name to create the filename. + - Call the `save_context_set` tool. You will need to provide the database instance, database name, the JSON content from the previous step, and the root directory where the Gemini CLI is running. + + - **If appending to an existing file:** + - Ask the user to provide the path to the existing context set file. + - Call the `attach_context_set` tool with the JSON content and the absolute file path. + + 6. **Generate Upload URL (Optional):** + - After the file is saved, ask the user if they want to generate a URL to upload the context set file. + - If the user confirms, you must collect the necessary database context from them. This includes: + - **Database Type:** 'alloydb', 'cloudsql', or 'spanner'. + - **Project ID:** The Google Cloud project ID. + - **And depending on the database type:** + - For 'alloydb': Location and Cluster ID. + - For 'cloudsql': Instance ID. + - For 'spanner': Instance ID and Database ID. + - Once you have the required information, call the `generate_upload_url` tool to provide the upload URL to the user. + + Start the workflow. + """ + ) + if __name__ == "__main__": mcp.run() # Uses STDIO transport by default diff --git a/mcp/model/context.py b/mcp/model/context.py index 8c40272..e12051a 100644 --- a/mcp/model/context.py +++ b/mcp/model/context.py @@ -56,9 +56,17 @@ class Facet(BaseModel): ) parameterized: ParameterizedFacet +class ValueIndex(BaseModel): + """Represents a single, complete value index.""" + + query: str = Field(..., description="The parameterized SQL query (using $value).") + concept_type: str = Field( + ..., description="The semantic type (e.g., 'City', 'Product ID')." + ) + description: Optional[str] = Field(None, description="Optional description.") class ContextSet(BaseModel): - """A set of templates and facets.""" + """A set of templates, facets and value indexes.""" templates: Optional[List[Template]] = Field( None, description="A list of complete templates." @@ -68,4 +76,8 @@ class ContextSet(BaseModel): description="A list of SQL facets.", validation_alias=AliasChoices("facets", "fragments"), ) + value_indices: Optional[List[ValueIndex]] = Field( + None, + description="A list of value index.", + ) diff --git a/mcp/pyproject.toml b/mcp/pyproject.toml index fff2a84..2260c21 100644 --- a/mcp/pyproject.toml +++ b/mcp/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "db-context-enrichment" -version = "0.2.0" +version = "0.3.0" description = "A FastMCP server for generating natural language to SQL templates from database schemas." readme = "README.md" requires-python = ">=3.12" @@ -21,7 +21,7 @@ test = [ [tool.setuptools] py-modules = ["main"] -packages = ["template", "facet", "common", "model"] +packages = ["template", "facet", "value_index", "common", "model"] [[tool.uv.index]] url = "https://pypi.org/simple" diff --git a/mcp/tests/value_index/__init__.py b/mcp/tests/value_index/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/mcp/tests/value_index/generator_test.py b/mcp/tests/value_index/generator_test.py new file mode 100644 index 0000000..b96fe44 --- /dev/null +++ b/mcp/tests/value_index/generator_test.py @@ -0,0 +1,101 @@ +import pytest +import json +from unittest.mock import patch +from value_index.generator import generate_value_index +from value_index import match_templates +from model.context import ContextSet + + +def test_generate_value_index_postgres_default(): + # Test generating a standard Postgres exact match (uses default version) + result_json = generate_value_index( + table_name="users", + column_name="country_code", + concept_type="Country", + match_function="EXACT_MATCH_STRINGS", + db_engine="postgresql", + ) + + # Validate JSON structure and content + context_set = ContextSet.model_validate_json(result_json) + + # Check List Structure + assert context_set.value_indices is not None + assert len(context_set.value_indices) == 1 + + # Check ValueIndex Object + vi = context_set.value_indices[0] + assert vi.concept_type == "Country" + + # Check SQL Parameterization + # Should replace {table} and {column} but keep $value + assert "users.country_code" in vi.query + assert "FROM users T" in vi.query + assert "$value" in vi.query + + +def test_generate_value_index_invalid_dialect(): + # Test error handling for unknown database engine + with pytest.raises(ValueError, match="Dialect 'invalid_db' not supported"): + generate_value_index( + table_name="t", + column_name="c", + concept_type="C", + match_function="EXACT_MATCH_STRINGS", + db_engine="invalid_db" + ) + + +def test_generate_value_index_invalid_function(): + # Test error handling for unknown match function + with pytest.raises(ValueError, match="Match function 'BAD_FUNC' not found"): + generate_value_index( + table_name="t", + column_name="c", + concept_type="C", + match_function="BAD_FUNC", + db_engine="postgresql" + ) + + +def test_generate_value_index_specific_version_success(): + # Mock the registry to test specific version logic without relying on real data + fake_registry = { + "postgresql": { + "99.0": { + "TEST_FUNC": { + "sql_template": "SELECT {table}.{column} FROM {table} WHERE version=99", + "description": "Test Description" + } + } + } + } + + # Inject the fake registry into the module + with patch.dict(match_templates.MATCH_TEMPLATES, fake_registry, clear=True): + result_json = generate_value_index( + table_name="users", + column_name="age", + concept_type="Age", + match_function="TEST_FUNC", + db_engine="postgresql", + db_version="99.0" + ) + + context_set = ContextSet.model_validate_json(result_json) + vi = context_set.value_indices[0] + assert "WHERE version=99" in vi.query + assert "users.age" in vi.query + + +def test_generate_value_index_specific_version_not_found(): + # Verify strict version checking: Should raise Error, NOT fallback to default + with pytest.raises(ValueError, match="Version '999.0' not found"): + generate_value_index( + table_name="t", + column_name="c", + concept_type="C", + match_function="EXACT_MATCH_STRINGS", + db_engine="postgresql", + db_version="999.0" + ) \ No newline at end of file diff --git a/mcp/value_index/generator.py b/mcp/value_index/generator.py new file mode 100644 index 0000000..e828328 --- /dev/null +++ b/mcp/value_index/generator.py @@ -0,0 +1,54 @@ +from typing import Optional +from model import context +from value_index import match_templates + + +def generate_value_index( + table_name: str, + column_name: str, + concept_type: str, + match_function: str, + db_engine: str, + db_version: Optional[str] = None, + description: Optional[str] = None, +) -> str: + """ + Generates a single Value Index configuration based on specific inputs. + + Args: + table_name: The name of the table. + column_name: The name of the column. + concept_type: The semantic type (e.g., 'City'). + match_function: The match function to use (e.g., 'EXACT_MATCH_STRINGS'). + db_engine: The database engine (e.g., 'postgresql'). + db_version: The specific database version (optional). + + Returns: + A JSON string representation of a ContextSet containing the generated index. + """ + template_def = match_templates.get_match_template( + dialect=db_engine, + function_name=match_function, + version=db_version, + ) + raw_sql = template_def["sql_template"] + + # Replace {table}, {column}, {concept_type} with the user's inputs. + # $value remains as a placeholder. + value_index_query = raw_sql.format( + table=table_name, + column=column_name, + concept_type=concept_type, + ) + + # Wrap this single index in a list because ContextSet expects a list. + vi = context.ValueIndex( + concept_type=concept_type, + query=value_index_query, + description=description, + ) + + # Return as ContextSet JSON + return context.ContextSet(value_indices=[vi]).model_dump_json( + indent=2, exclude_none=True + ) \ No newline at end of file diff --git a/mcp/value_index/match_templates.py b/mcp/value_index/match_templates.py new file mode 100644 index 0000000..178cf82 --- /dev/null +++ b/mcp/value_index/match_templates.py @@ -0,0 +1,84 @@ +from typing import Dict, Any, Optional + +# If a user does not specify a version, we default to these versions. +DEFAULT_VERSIONS: Dict[str, str] = { + "postgresql": "16" +} + +# Structure: Dialect -> Version -> Function Name -> Template Data +MATCH_TEMPLATES: Dict[str, Dict[str, Dict[str, Any]]] = { + "postgresql": { + "16": { + "EXACT_MATCH_STRINGS": { + "description": "Exact match (Standard SQL)", + "sql_template": ( + "SELECT $value as value, '{table}.{column}' as columns, " + "'{concept_type}' as concept_type, 0 as distance, " + "'' as context FROM {table} T WHERE T.{column} = $value LIMIT 1" + ), + }, + "FUZZY_MATCH_STRINGS": { + "description": "Fuzzy match using standard levenshtein (requires fuzzystrmatch extension)", + "sql_template": ( + "SELECT T.{column} as value, '{table}.{column}' as columns, " + "'{concept_type}' as concept_type, levenshtein(T.{column}, $value) as distance, " + "'' as context FROM {table} T ORDER BY distance LIMIT 10" + ), + }, + } + }, +} + + +def get_match_template( + dialect: str, function_name: str, version: Optional[str] = None +) -> dict: + """ + Retrieves a specific match template using a strict version lookup strategy. + + Args: + dialect: The database dialect (e.g., 'postgresql'). + function_name: The name of the match function (e.g., 'EXACT_MATCH_STRINGS'). + version: The specific database version. If None/Empty, uses 'default'. + + Returns: + A dictionary containing the template definition. + + Raises: + ValueError: If dialect, version, or function is not found. + """ + dialect_key = dialect.lower() + engine_config = MATCH_TEMPLATES.get(dialect_key) + + if not engine_config: + raise ValueError(f"Dialect '{dialect}' not supported.") + + if not version: + version = "default" + + if version.lower() == "default": + target_version_key = DEFAULT_VERSIONS.get(dialect_key) + if not target_version_key: + raise ValueError( + f"Configuration Error: No default version defined for dialect '{dialect}'." + ) + else: + target_version_key = version + + version_config = engine_config.get(target_version_key) + + if not version_config: + available = list(engine_config.keys()) + raise ValueError( + f"Version '{target_version_key}' not found for dialect '{dialect}'. " + f"Available versions: {available}" + ) + + template = version_config.get(function_name) + if not template: + raise ValueError( + f"Match function '{function_name}' not found for {dialect} " + f"(version: {target_version_key})." + ) + + return template \ No newline at end of file