Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion wren-ai-service/src/globals.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,13 +213,17 @@ def create_service_container(
),
"db_schema_retrieval": _db_schema_retrieval_pipeline,
"sql_generation": generation.SQLGeneration(
**pipe_components["question_recommendation_sql_generation"],
**pipe_components[
"question_recommendation_sql_generation"
],
),
"sql_pairs_retrieval": _sql_pair_retrieval_pipeline,
"instructions_retrieval": _instructions_retrieval_pipeline,
"sql_functions_retrieval": _sql_functions_retrieval_pipeline,
"sql_correction": _sql_correction_pipeline,
},
allow_sql_functions_retrieval=settings.allow_sql_functions_retrieval,
max_sql_correction_retries=settings.max_sql_correction_retries,
**query_cache,
),
sql_pairs_service=services.SqlPairsService(
Expand Down
41 changes: 39 additions & 2 deletions wren-ai-service/src/web/v1/services/question_recommendation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
self,
pipelines: Dict[str, BasicPipeline],
allow_sql_functions_retrieval: bool = True,
max_sql_correction_retries: int = 3,
maxsize: int = 1_000_000,
ttl: int = 120,
):
Expand All @@ -39,6 +40,7 @@ def __init__(
maxsize=maxsize, ttl=ttl
)
self._allow_sql_functions_retrieval = allow_sql_functions_retrieval
self._max_sql_correction_retries = max_sql_correction_retries

def _handle_exception(
self,
Expand Down Expand Up @@ -127,10 +129,45 @@ async def _instructions_retrieval() -> list[dict]:

post_process = generated_sql["post_process"]

# If initial generation fails, try correction loop similar to ask flow
if len(post_process["valid_generation_result"]) == 0:
return post_process
failed_dry_run_result = post_process.get(
"invalid_generation_result"
)
current_sql_correction_retries = 0

while (
failed_dry_run_result
and current_sql_correction_retries
< self._max_sql_correction_retries
):
current_sql_correction_retries += 1

sql_correction_results = await self._pipelines[
"sql_correction"
].run(
contexts=table_ddls,
invalid_generation_result=failed_dry_run_result,
instructions=instructions,
project_id=project_id,
)

post_process = sql_correction_results["post_process"]
if valid_generation_result := post_process.get(
"valid_generation_result"
):
valid_sql = valid_generation_result["sql"]
break

failed_dry_run_result = post_process.get(
"invalid_generation_result"
)
else:
# Still no valid SQL after corrections
return post_process

valid_sql = post_process["valid_generation_result"]["sql"]
else:
valid_sql = post_process["valid_generation_result"]["sql"]

# Partial update the resource
current = self._cache[request_id]
Expand Down