Spaces:
Sleeping
Sleeping
File size: 5,977 Bytes
dc6fe88 79168b4 a4607e0 dc6fe88 a4607e0 dc6fe88 79168b4 edef444 f160d1e edef444 dc6fe88 79168b4 edef444 dc6fe88 e06da36 dc6fe88 79168b4 dc6fe88 e06da36 dc6fe88 79168b4 edef444 79168b4 e06da36 79168b4 e06da36 79168b4 e06da36 dc6fe88 79168b4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 | # Path: src/nl2sql/sql_agent.py
# SQL Agent for handling NL2SQL conversion with Auto-Correct functionality
from src.database.db_manager import get_db_connection, get_schema_context
from langchain_core.prompts import PromptTemplate
from src.nl2sql.hf_engine import get_llm
# Craft the Prompt Template to instruct LLM on its persona
# Initial-generation prompt: the LLM sees the retrieved schema and the user's
# question and must answer with bare SQL only (no prose, no markdown fences).
SQL_PROMPT_TEMPLATE = (
    "You are an expert SQLite developer.\n"
    "Your task is to write a syntactically correct SQLite query to answer the user's question based strictly on the provided database schema.\n"
    "Return ONLY the raw SQL query.\n"
    "Do not include any explanations, markdown formatting, or code blocks.\n"
    "Schema Context:\n"
    "{schema}\n"
    "User Question:\n"
    "{question}\n"
    "SQL Query:"
)
# Retry prompt: on a database error, the failing SQL and the exact error text
# are fed back to the LLM together with the schema and the original question.
REFINEMENT_PROMPT_TEMPLATE = """You are an expert SQLite developer.
You previously generated a SQL query to answer the user's question, but it resulted in an error when executed on the database.
Schema Context:
{schema}
User Question:
{question}
Previously Generated SQL:
{failed_sql}
Database Error Message:
{error_message}
Your task is to fix the SQL query based on the exact error message and the schema.
Pay close attention to column names, table relationships, and SQLite syntax.
Return ONLY the raw, corrected SQL query.
Do not include any explanations, markdown formatting, or code blocks.
Corrected SQL Query:"""
# Prompt for turning query results into a brief natural-language answer
# Summarization prompt: receives the user's question and the stringified
# result rows. The explicit empty-result instruction keeps the model from
# inventing an answer when the query matched nothing.
NL_RESPONSE_TEMPLATE = """You are a helpful data assistant.
The user asked the following question: "{question}"
The database returned the following results: {results}
Provide a direct, natural language answer to the user's question using ONLY the provided data.
Keep it brief. Do not explain the SQL query or mention the database schema.
If the database returns no rows, say that no matching data was found.
If the database returns more than 5 rows, DO NOT list the items individually. Instead, provide a brief summary sentence.
Answer:"""
# LangChain prompt objects. Each declares exactly the placeholders that its
# template text contains, so a missing key fails fast at invoke time.

# First-pass SQL generation.
prompt_template = PromptTemplate(
    template=SQL_PROMPT_TEMPLATE,
    input_variables=["schema", "question"],
)

# Error-driven SQL repair on retries.
refinement_prompt = PromptTemplate(
    template=REFINEMENT_PROMPT_TEMPLATE,
    input_variables=["schema", "question", "failed_sql", "error_message"],
)

# Natural-language summary of the query results.
nl_response_template = PromptTemplate(
    template=NL_RESPONSE_TEMPLATE,
    input_variables=["question", "results"],
)
# Clean the output
def clean_sql(raw_sql: str) -> str:
    """
    Strip markdown code fences that the LLM may have wrapped around its SQL.

    Handles:
      * an opening fence with or without a language tag; ``sql``, ``SQL`` and
        ``sqlite`` are matched case-insensitively;
      * a closing fence;
      * any commentary the model appended after the closing fence.

    Text without an opening fence only has surrounding whitespace and a
    trailing fence removed. The result can be passed straight to the
    SQLite cursor.

    Args:
        raw_sql: Raw text returned by the LLM.

    Returns:
        The bare SQL text with surrounding whitespace removed.
    """
    cleaned = raw_sql.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned[3:]
        # Drop an optional language tag. "sqlite" must be tested before "sql"
        # so that "```sqlite" does not leave a stray "ite" in the query.
        lowered = cleaned.lower()
        for tag in ("sqlite", "sql"):
            if lowered.startswith(tag):
                cleaned = cleaned[len(tag):]
                break
        # Anything after the closing fence is explanation, not SQL.
        cleaned = cleaned.split("```", 1)[0]
    elif cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    return cleaned.strip()
# Function to handle NL2SQL conversion
def nl2sql_agent(user_question: str, max_retries: int = 3, model_id: str = None) -> dict:
"""
Complete flow execution with Auto-correction:
Get Schema context -> Generate SQL query -> Execute SQL query -> If Error, Refine & Retry ->Return results
"""
# Fetch database schema context using RAG
print(f"Fetching RAG schema context for: '{user_question}'...")
schema = get_schema_context(question = user_question)
# Generate SQL query using the schema context and user question
llm = get_llm(model_id=model_id)
# LangChain Pipeline: Pipe prompt into LLM
chain = prompt_template | llm
refinement_chain = refinement_prompt | llm
nl_chain = nl_response_template | llm
current_sql = ""
error_message = ""
# Auto-correction Loop
for attempt in range(1, max_retries + 1):
if attempt == 1:
print(f"Generating initial SQL query using {model_id or 'default model'}...")
raw_response = chain.invoke({
"schema": schema,
"question": user_question
})
else:
print(f"\n--- Attempt {attempt}/{max_retries}: Refining SQL query based on error ---")
print(f"Feeding error back to LLM: {error_message}")
raw_response = refinement_chain.invoke({
"schema": schema,
"question": user_question,
"failed_sql": current_sql,
"error_message": error_message
})
# Parse & clean the generated SQL query
generated_sql = clean_sql(raw_response)
current_sql = generated_sql
print(f"Generated SQL: \n{generated_sql}")
# Execute the generated SQL query and fetch results
connection = get_db_connection()
if not connection:
return {
"query": generated_sql,
"error": "Could not establish database connection",
"status": "failed"
}
try:
cursor = connection.cursor()
cursor.execute(generated_sql)
results = cursor.fetchall()
if attempt > 1:
print(f"SQL query executed successfully after {attempt} attempts.")
# Generate natural language response based on the results
print("Generating natural language response based on query results...")
nl_response = nl_chain.invoke({
"question": user_question,
"results": str(results)
})
return {
"query": generated_sql,
"results": results,
"nl_response": nl_response,
"status": "success",
"attempts": attempt
}
except Exception as e:
error_message = str(e)
print(f"Error executing SQL: {error_message}")
if attempt == max_retries:
print("Max retries reached. Returning error.")
finally:
connection.close()
return {
"query": current_sql,
"error": error_message,
"status": f"Error executing SQL after {max_retries} attempts"
} |