File size: 5,977 Bytes
dc6fe88
79168b4
a4607e0
dc6fe88
a4607e0
dc6fe88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79168b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
edef444
 
 
 
 
 
 
f160d1e
edef444
 
 
dc6fe88
 
 
 
 
79168b4
 
 
 
 
edef444
 
 
 
 
dc6fe88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e06da36
dc6fe88
79168b4
 
dc6fe88
 
 
 
 
 
e06da36
dc6fe88
 
 
79168b4
edef444
79168b4
 
 
 
 
 
 
e06da36
79168b4
 
 
 
 
 
 
 
 
 
 
 
 
 
e06da36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79168b4
e06da36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dc6fe88
79168b4
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# Path: src/nl2sql/sql_agent.py
# SQL Agent for handling NL2SQL conversion with Auto-Correct functionality
import re

from langchain_core.prompts import PromptTemplate

from src.database.db_manager import get_db_connection, get_schema_context
from src.nl2sql.hf_engine import get_llm

# Prompt for the first attempt: asks the LLM for a single raw SQLite query
# built from the retrieved schema context and the user's question.
SQL_PROMPT_TEMPLATE = """You are an expert SQLite developer.
Your task is to write a syntactically correct SQLite query to answer the user's question based strictly on the provided database schema.
Return ONLY the raw SQL query.
Do not include any explanations, markdown formatting, or code blocks.

Schema Context:
{schema}

User Question:
{question}

SQL Query:"""

# Prompt for retry attempts: shows the LLM its failed query together with the
# database error message so it can correct the query.
REFINEMENT_PROMPT_TEMPLATE = """You are an expert SQLite developer.
You previously generated a SQL query to answer the user's question, but it resulted in an error when executed on the database.

Schema Context:
{schema}

User Question:
{question}

Previous Generated SQL:
{failed_sql}

Database Error Message:
{error_message}

Your task is to fix the SQL query based on the exact error message and the schema.
Pay close attention to column names, table relationships, and SQLite syntax.
Return ONLY the raw, corrected SQL query. 
Do not include any explanations, markdown formatting, or code blocks.

Corrected SQL Query:"""

# Prompt that turns the query results into a short natural-language answer.
NL_RESPONSE_TEMPLATE = """You are a helpful data assistant.
The user asked the following question: "{question}"
The database returned the following results: {results}

Provide a direct, natural language answer to the user's question using ONLY the provided data.
Keep it brief. Do not explain the SQL query or mention the database schema.
If the database returns more than 5 rows, DO NOT list the items individually. Instead, provide a brief summary sentence.

Answer:"""

# LangChain prompt objects used by the agent. Each one declares exactly the
# placeholders its template string references.
prompt_template = PromptTemplate(
    template=SQL_PROMPT_TEMPLATE,
    input_variables=["schema", "question"],
)

refinement_prompt = PromptTemplate(
    template=REFINEMENT_PROMPT_TEMPLATE,
    input_variables=["schema", "question", "failed_sql", "error_message"],
)

nl_response_template = PromptTemplate(
    template=NL_RESPONSE_TEMPLATE,
    input_variables=["question", "results"],
)

# Opening markdown fence with an optional language tag. The tag is consumed
# only when it ends the fence line (```sql\n, ```SQLite\n, ```\n, ...) or when
# it is a known SQL tag followed by a word boundary (```sql SELECT ...). This
# way an untagged one-line fence such as ```SELECT 1``` keeps its first keyword.
_OPENING_FENCE_RE = re.compile(
    r"^```(?:[\w+-]*[ \t]*\r?\n|(?:sqlite3?|sql)\b)?",
    re.IGNORECASE,
)


# Clean the output
def clean_sql(raw_sql: str) -> str:
    """
    Strip markdown code fences if the LLM wrapped its SQL in a code block.

    Accepts an opening fence with any language tag in any case (```sql,
    ```SQL, ```sqlite) or no tag at all. When an opening fence is present,
    everything from the closing fence onward is discarded, including any
    explanation the model appended after it. A lone trailing fence with no
    opening fence is also removed. The result is meant to be passed directly
    to the SQLite cursor.

    Args:
        raw_sql: Raw text returned by the LLM.

    Returns:
        The SQL text with fences and surrounding whitespace removed
        (an empty string for blank input).
    """
    cleaned = raw_sql.strip()

    opening = _OPENING_FENCE_RE.match(cleaned)
    if opening:
        cleaned = cleaned[opening.end():]
        # Cut at the closing fence so trailing commentary never reaches execute().
        closing = cleaned.find("```")
        if closing != -1:
            cleaned = cleaned[:closing]
    elif cleaned.endswith("```"):
        # Stray closing fence without a matching opening one.
        cleaned = cleaned[:-3]

    return cleaned.strip()

# Function to handle NL2SQL conversion
def nl2sql_agent(user_question: str, max_retries: int = 3, model_id: str = None) -> dict:
    """
    Answer a natural-language question by generating and running SQL, with auto-correction.

    Flow: get schema context -> generate SQL query -> execute SQL query ->
    if it fails, feed the error back to the LLM and retry -> summarize the
    returned rows in natural language.

    Args:
        user_question: The user's question in natural language.
        max_retries: Total number of SQL generation attempts. The first
            attempt counts, so the default of 3 allows two refinements.
        model_id: Model identifier forwarded to ``get_llm``. ``None``
            presumably selects ``get_llm``'s default model.

    Returns:
        On success::

            {"query", "results", "nl_response", "status": "success", "attempts"}

        If no database connection could be opened::

            {"query", "error", "status": "failed", "attempts"}

        If every attempt failed::

            {"query", "error", "status": "Error executing SQL after N attempts", "attempts"}
    """
    # Fetch database schema context using RAG
    print(f"Fetching RAG schema context for: '{user_question}'...")
    schema = get_schema_context(question=user_question)

    llm = get_llm(model_id=model_id)

    # LangChain pipelines: each prompt is piped into the same LLM.
    chain = prompt_template | llm
    refinement_chain = refinement_prompt | llm
    nl_chain = nl_response_template | llm

    current_sql = ""
    error_message = ""

    # Auto-correction loop: attempt 1 generates from scratch; later attempts
    # refine the previous query using the error it produced.
    for attempt in range(1, max_retries + 1):
        if attempt == 1:
            print(f"Generating initial SQL query using {model_id or 'default model'}...")
            raw_response = chain.invoke({
                "schema": schema,
                "question": user_question
            })
        else:
            print(f"\n--- Attempt {attempt}/{max_retries}: Refining SQL query based on error ---")
            print(f"Feeding error back to LLM: {error_message}")
            raw_response = refinement_chain.invoke({
                "schema": schema,
                "question": user_question,
                "failed_sql": current_sql,
                "error_message": error_message
            })

        # Parse & clean the generated SQL query
        current_sql = clean_sql(raw_response)
        print(f"Generated SQL: \n{current_sql}")

        if not current_sql:
            # Nothing to execute. Count this as a failed attempt so the model is
            # asked again, rather than reporting success for an empty query.
            error_message = "The generated SQL query was empty."
            print(f"Error executing SQL: {error_message}")
            if attempt == max_retries:
                print("Max retries reached. Returning error.")
            continue

        connection = get_db_connection()
        if not connection:
            return {
                "query": current_sql,
                "error": "Could not establish database connection",
                "status": "failed",
                "attempts": attempt,
            }

        # NOTE(review): the LLM-generated SQL is executed verbatim. Unless
        # get_db_connection() opens the database read-only, a crafted question
        # could produce data-modifying statements. Confirm the connection mode.
        #
        # Only database work belongs in this try. Errors raised here are SQL
        # errors worth feeding back to the LLM. The finally closes the
        # connection before the (slow) natural-language LLM call below.
        try:
            cursor = connection.cursor()
            cursor.execute(current_sql)
            results = cursor.fetchall()
        except Exception as e:
            error_message = str(e)
            print(f"Error executing SQL: {error_message}")
            if attempt == max_retries:
                print("Max retries reached. Returning error.")
            continue
        finally:
            connection.close()

        if attempt > 1:
            print(f"SQL query executed successfully after {attempt} attempts.")

        # Generate natural language response based on the results
        print("Generating natural language response based on query results...")
        try:
            nl_response = nl_chain.invoke({
                "question": user_question,
                "results": str(results)
            })
        except Exception as e:
            # The SQL already succeeded, so this is not an SQL error. Do not
            # regenerate the query; return the rows with a fallback answer.
            print(f"Error generating natural language response: {e}")
            nl_response = "Could not generate a natural language answer; please refer to the query results."

        return {
            "query": current_sql,
            "results": results,
            "nl_response": nl_response,
            "status": "success",
            "attempts": attempt
        }

    return {
        "query": current_sql,
        "error": error_message,
        "status": f"Error executing SQL after {max_retries} attempts",
        "attempts": max(max_retries, 0),
    }