Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ dependencies = [
"SQLAlchemy~=2.0.40",
"pygls~=2.0.0a2",
"duckdb~=1.2.2",
"tree-sitter==0.20.1",
"tree-sitter-languages==1.5.0",
]

[project.urls]
Expand All @@ -48,22 +50,32 @@ python="3.10"
path = ".venv"

dependencies = [
"pylint~=3.2.2",
"pylint-pytest==2.0.0a0",
"coverage[toml]~=7.8.0",
"pytest~=8.3.5",
"pytest-cov>=5.0.0,<6.0.0",
"pytest-asyncio~=0.26.0",
"pytest-xdist~=3.5.0",
"black~=25.1.0",
"ruff~=0.11.6",
"databricks-connect==15.1",
"types-pyYAML~=6.0.12",
"types-pytz~=2025.2",
"databricks-labs-pylint~=0.4.0",
"mypy~=1.10.0",
"numpy==1.26.4",
"pandas==1.4.1",
"pylint~=3.2.2",
"pylint-pytest==2.0.0a0",
"coverage[toml]~=7.8.0",
"pytest~=8.3.5",
"pytest-cov>=5.0.0,<6.0.0",
"pytest-asyncio~=0.26.0",
"pytest-xdist~=3.5.0",
"black~=25.1.0",
"ruff~=0.11.6",
"databricks-connect==15.1",
"types-pyYAML~=6.0.12",
"types-pytz~=2025.2",
"databricks-labs-pylint~=0.4.0",
"mypy~=1.10.0",
"numpy==1.26.4",
"pandas==1.4.1",
"langchain",
"langchain-experimental",
"databricks-langchain",
"langchain-community",
"langchain-text-splitters",
"tree-sitter==0.20.1",
"tree-sitter-languages==1.5.0",
"langgraph",
"databricks-agents==0.16.0",
"databricks-vectorsearch==0.49"
]

[project.entry-points.databricks]
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from databricks.labs.remorph.agents.code_explainer.explainer.sql import SQLExplainer

__all__ = [
"SQLExplainer",
]
62 changes: 62 additions & 0 deletions src/databricks/labs/remorph/agents/code_explainer/execute.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import pprint
import logging

from databricks.labs.remorph.agents.code_explainer.parser import SqlParser
from databricks.labs.remorph.agents.code_explainer.explainer import SQLExplainer
from databricks.labs.remorph.agents.code_explainer.intent import CompareIntent

logger = logging.getLogger(__name__)


def _run(source_doc: str, target_doc: str = '', compare_intent: bool = False, format_flag: bool = False) -> None:
    """Explain the source SQL and optionally compare its intent with a target SQL.

    The source document is parsed with ``SqlParser`` and explained with
    ``SQLExplainer``. When ``compare_intent`` is true the target document is
    parsed and explained as well, and both explanations are handed to
    ``CompareIntent``. Only the explanation of the first parsed document on
    each side is printed and compared.

    Args:
        source_doc: Value passed to ``SqlParser`` for the source SQL.
        target_doc: Value passed to ``SqlParser`` for the converted SQL; only
            used when ``compare_intent`` is true.
        compare_intent: Whether to explain the target and compare intents.
        format_flag: Forwarded to ``SQLExplainer`` to request structured output.
    """
    source_documents = SqlParser(source_doc).parse()

    if not source_documents:
        logger.warning("No code found in the source document.")
        print("[WARN]::No code found in the source document.")
        return

    target_documents = SqlParser(target_doc).parse() if compare_intent else None

    explainer = SQLExplainer(endpoint_name="databricks-llama-4-maverick", format_flag=format_flag)
    source_explanations = explainer.explain_documents(source_documents)
    target_explanations = explainer.explain_documents(target_documents) if target_documents else []

    banner = "****" * 50

    if source_explanations:
        print(banner)
        print("Source SQL Code Explanation:")
        pprint.pprint(source_explanations[0].get('explanation', "__NOT_FOUND__"))

    if not target_documents:
        # A comparison was requested but there is nothing to compare against:
        # report it the same way the empty-source case is reported instead of
        # returning silently.
        if compare_intent:
            logger.warning("No code found in the target document; skipping intent comparison.")
            print("[WARN]::No code found in the target document; skipping intent comparison.")
        return

    if target_explanations:
        print(banner)
        print("Target SQL Code Explanation:")
        pprint.pprint(target_explanations[0].get('explanation', "__NOT_FOUND__"))

    if source_explanations and target_explanations:
        print(banner)
        print("Comparing Code intent of Source SQL and converted ")

        intent_compare = CompareIntent(
            source_intent=source_explanations[0].get('explanation', {}),
            target_intent=target_explanations[0].get('explanation', {}),
            endpoint_name="databricks-claude-3-7-sonnet",
        )

        print(banner)
        print(intent_compare.compare())


def intent(source_doc: str):
    """Print an explanation of the SQL in ``source_doc``.

    Delegates to ``_run`` with intent comparison and structured formatting
    both switched off.
    """
    _run(source_doc, compare_intent=False, format_flag=False)


def match_intent(source_doc: str, target_doc: str):
    """Explain ``source_doc`` and ``target_doc`` and print an intent comparison.

    Delegates to ``_run`` with intent comparison and structured formatting
    both switched on.
    """
    _run(source_doc, target_doc, compare_intent=True, format_flag=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from databricks.labs.remorph.agents.code_explainer.explainer.sql import SQLExplainer

__all__ = [
"SQLExplainer",
]
149 changes: 149 additions & 0 deletions src/databricks/labs/remorph/agents/code_explainer/explainer/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
from __future__ import annotations

from typing import Any


from langchain.prompts import ChatPromptTemplate
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain_core.documents import Document
from databricks_langchain import ChatDatabricks # type: ignore

from databricks.labs.remorph.agents.exceptions import RemorphAgentException
from databricks.labs.remorph.agents.code_explainer.parser import SqlParser


# Prompt sent to the LLM for every SQL segment. ``{sql_segment}`` and
# ``{format_instructions}`` are filled in by ``explain_document``.
_EXPLAIN_PROMPT = """
Analyze the following SQL code segment and explain its purpose and functionality:

```sql
{sql_segment}
```

Provide a comprehensive explanation including:
1. What kind of SQL statement this is (CREATE, SELECT, INSERT, etc.)
2. The logical flow and what this SQL code is doing
3. Key tables involved
4. SQL functions used in the code and their purpose
5. Any potential performance considerations
6. Identify if there's any potential migration challenges to convert this SQL to Databricks DBSQL.

{format_instructions}
"""

# (name, description) pairs for the structured response. Descriptions are
# joined with implicit string concatenation so that no source indentation
# ends up inside the text sent to the model.
_RESPONSE_FIELDS: tuple[tuple[str, str], ...] = (
    (
        "sql_type",
        "What kind of SQL statement this is (CREATE, SELECT, INSERT, etc.). "
        "If it's a create statement, include the database/schema (if available), table/view name and columns.",
    ),
    (
        "sql_flow",
        "The logical flow of the SQL and what this SQL code is doing. "
        "Include the key tables and operations involved (Filter, Join, Group by etc). "
        "Write the SQL flow in a way we explain the ETL logic",
    ),
    (
        "key_tables",
        "extract all the tables/views along with schema details mentioned in the SQL and output them as comma "
        "separated python list.",
    ),
    (
        "sql_functions",
        "List of SQL functions used in the SQL and their purpose. "
        "Extract the Output as comma separated python list each having a tuple of function name and its purpose.",
    ),
    (
        "performance_considerations",
        "Identify if there's any potential performance considerations and suggest optimization techniques.",
    ),
    (
        "migration_challenges",
        "Identify if there's any potential migration challenges to convert this SQL to Databricks DBSQL. "
        "Identify the SQL functions and features that are not supported in Databricks DBSQL and suggest alternatives.",
    ),
)


class SQLExplainer:
    """Explain SQL code using a Databricks-served chat model.

    Attributes set by ``__init__``: ``endpoint_name``, ``format_flag``,
    ``llm``, ``prompt_template`` and ``llm_chain`` (prompt piped into the LLM).
    """

    def __init__(self, endpoint_name: str = "databricks-llama-4-maverick", format_flag: bool = False):
        """Create the chat model, the prompt template and the chain.

        Args:
            endpoint_name: Serving endpoint passed to ``ChatDatabricks``.
            format_flag: When true, explanations are requested in, and parsed
                into, the structured schema built by ``format_output``.
        """
        self.endpoint_name = endpoint_name
        self.format_flag = format_flag

        # The model is created with temperature pinned to 0.0.
        self.llm = ChatDatabricks(endpoint=self.endpoint_name, extra_params={"temperature": 0.0})
        self.prompt_template = ChatPromptTemplate.from_template(_EXPLAIN_PROMPT)
        self.llm_chain = self.prompt_template | self.llm

    def format_output(self) -> StructuredOutputParser:
        """Build the parser that turns the LLM reply into a structured dict.

        Returns:
            A ``StructuredOutputParser`` with one response schema per entry:
            ``sql_type``, ``sql_flow``, ``key_tables``, ``sql_functions``,
            ``performance_considerations`` and ``migration_challenges``.

        Raises:
            RemorphAgentException: If the parser cannot be created from the schemas.
        """
        response_schema = [
            ResponseSchema(name=name, description=description) for name, description in _RESPONSE_FIELDS
        ]

        try:
            return StructuredOutputParser.from_response_schemas(response_schema)
        except Exception as e:
            raise RemorphAgentException("Response formatter failed to load.") from e

    def explain_document(self, doc: Document) -> dict[str, Any]:
        """Explain a single SQL document.

        Args:
            doc: Document whose ``page_content`` is the SQL to explain.

        Returns:
            A dict with keys ``explanation`` (parsed structure when
            ``format_flag`` is set, otherwise the raw reply content),
            ``metadata`` (the document's metadata) and ``original_sql``.
        """
        sql_segment = doc.page_content

        # The structured parser is only needed (and only built) when
        # formatted output was requested.
        output_parser = self.format_output() if self.format_flag else None
        format_instructions = output_parser.get_format_instructions() if output_parser is not None else ""

        response = self.llm_chain.invoke({"sql_segment": sql_segment, "format_instructions": format_instructions})

        explanation = output_parser.parse(response.content) if output_parser is not None else response.content

        return {
            "explanation": explanation,
            "metadata": doc.metadata,
            "original_sql": sql_segment,
        }

    def explain_documents(self, docs: list[Document]) -> list[dict[str, Any]]:
        """Explain every document in ``docs``, preserving order."""
        return [self.explain_document(doc) for doc in docs]

    def batch_explain(self, sql_file_path: str) -> list[dict[str, Any]]:
        """Parse ``sql_file_path`` with ``SqlParser`` and explain each segment.

        Raises:
            ValueError: If the parser returns no documents.
        """
        parser = SqlParser(file_path=sql_file_path)
        docs = parser.parse()
        if not docs:
            raise ValueError(f"Parsing failed for the SQL Code Segment: {sql_file_path}")
        return self.explain_documents(docs)
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from databricks.labs.remorph.agents.code_explainer.intent.compare import CompareIntent

__all__ = ["CompareIntent"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import json
from typing import Any


from databricks_langchain import ChatDatabricks # type: ignore
from langchain.prompts import ChatPromptTemplate


class CompareIntent:
def __init__(
self,
source_intent: dict[str, Any],
target_intent: dict[str, Any],
endpoint_name: str = "databricks-llama-4-maverick",
):
self.source_intent = source_intent
self.target_intent = target_intent
self.endpoint_name = endpoint_name
# self.explainer = SQLExplainer(endpoint_name="databricks-llama-4-maverick", format_flag=True)

"""Initialize the SQL Explainer Chain"""
self.llm = ChatDatabricks(endpoint=self.endpoint_name, extra_params={"temperature": 0.0})

# Create the prompt template
self.prompt_template = ChatPromptTemplate.from_template(
"""
Analyze the following two JSON segments each containing the code intents of source and migrated target SQL code:

Source SQL Intent:
```json
{source_intent}
```

Target SQL Intent:
```json
{target_intent}
```

Each JSON has following keys:
1. sql_type : The type of SQL statement (CREATE, SELECT, INSERT, etc.)
2. logical_flow : The logical flow and what this SQL code is doing
3. key_tables : Key tables involved
4. sql_functions : SQL functions used in the code and their purpose
5. performance_considerations : Any potential performance considerations
6. migration_challenges : Identify if there's any potential migration challenges to convert this SQL to Databricks DBSQL.

Consider the the first 5 keys to compare whether the source and target SQL code are similar or not.
If you think they are different, estimate the similarity score at the end of the response.
Also while analyzing the sql_functions, be mindful of the fact that depending on source and target SQL dialects,
the same function may have different names or parameters. Acknowledge that in your response while calculating the similarity.
"""
)

self.llm_chain = self.prompt_template | self.llm

def compare(self) -> str | list:
"""Compare the code intent of source and target SQL code."""
source_intent_str = json.dumps(self.source_intent)
target_intent_str = json.dumps(self.target_intent)
response = self.llm_chain.invoke({"source_intent": source_intent_str, "target_intent": target_intent_str})

return response.content
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from databricks.labs.remorph.agents.code_explainer.parser.sql import SqlParser
from databricks.labs.remorph.agents.code_explainer.parser.python import PythonParser


__all__ = ["SqlParser", "PythonParser"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from collections.abc import Iterator

from langchain_core.documents import Document
from langchain_community.document_loaders.generic import GenericLoader
from langchain_community.document_loaders.parsers.language import LanguageParser


class PythonParser:
    """Load Python source from the filesystem as LangChain documents."""

    def __init__(self, file_path: str):
        """Build a filesystem loader for ``file_path`` using the "python" language parser."""
        self.file_path = file_path
        language_parser = LanguageParser("python")
        self.loader = GenericLoader.from_filesystem(file_path, parser=language_parser)

    def parse(self) -> list[Document] | None:
        """Return everything the loader's ``load()`` produces."""
        documents = self.loader.load()
        return documents

    def lazy_parse(self) -> Iterator[Document] | None:
        """Return the loader's ``lazy_load()`` result, which yields one Document at a time."""
        document_stream = self.loader.lazy_load()
        return document_stream
Loading