nl2sql-api / backend /src /scripts /interactive_mode.py
dvwn's picture
Update evaluation_mode.py & interactive_mode.py version 1.1.0
c96208b
Raw
History Blame Contribute Delete
4.38 kB
# Path: src/scripts/interactive_mode.py
# Interactive mode (automated): runs the predefined test-case questions through the agent and records its natural-language responses
import csv
import json
import pandas as pd
from pathlib import Path
from tabulate import tabulate
from src.database.db_manager import get_db_connection
from src.nl2sql.sql_agent import nl2sql_agent
from src.nl2sql.hf_engine import get_models
# Location of the JSON file holding the predefined test cases that
# run_interactiveMode loads; the path is relative, so it is resolved against
# the current working directory.
TEST_CASES_PATH = Path("src/scripts/test_cases.json")
def get_query_data(sql_query: str) -> pd.DataFrame:
    """Run *sql_query* against the database and return its rows as a DataFrame.

    An empty DataFrame is returned when the query is missing (falsy, or the
    ``"N/A"`` placeholder), when no database connection could be obtained,
    or when executing the query raises. Once a connection has been opened
    it is always closed before returning.
    """
    has_query = bool(sql_query) and sql_query != "N/A"
    if not has_query:
        return pd.DataFrame()

    db_conn = get_db_connection()
    if not db_conn:
        return pd.DataFrame()

    # Start from the fallback value so every path below shares one exit.
    result = pd.DataFrame()
    try:
        result = pd.read_sql_query(sql_query, db_conn)
    except Exception as exc:
        print(f"Error executing SQL query: {exc}")
    finally:
        db_conn.close()
    return result
def verify_data(df_gold: pd.DataFrame, df_generated: pd.DataFrame) -> bool:
    """Guardrail check: does the generated result set match the gold one?

    Every cell is converted to a whitespace-stripped string (missing values
    become ``""``) and the rows are compared as sorted lists, so row order
    is ignored. Returns ``False`` when both frames are empty, when the row
    counts differ, or when the comparison itself raises.
    """

    def _normalised_rows(frame):
        # Stringify every cell, then turn rows into tuples so they can be sorted.
        cells = frame.fillna("").astype(str).values.tolist()
        return sorted(tuple(cell.strip() for cell in row) for row in cells)

    # Two empty results are deliberately reported as a mismatch so the case
    # gets flagged as a potential issue.
    both_empty = df_gold.empty and df_generated.empty
    if both_empty or len(df_gold) != len(df_generated):
        return False

    try:
        return _normalised_rows(df_gold) == _normalised_rows(df_generated)
    except Exception as err:
        print(f"Error during data verification: {err}")
        return False
def run_interactiveMode(model_id: str):
    """Run the predefined test-case questions through the agent for one model.

    Loads the cases from ``TEST_CASES_PATH``, asks the NL2SQL agent each
    question, cross-checks the rows returned by the generated SQL against
    the rows returned by the case's gold SQL, and writes one CSV row per
    case to ``Q&A_report_<model>.csv`` in the current working directory.

    Args:
        model_id: Identifier of the model passed to the agent; also used
            (with ``/``, ``:`` and spaces replaced by ``_``) in the report
            file name.

    Returns:
        None. Prints an error and returns early when the test-case file is
        missing or when no case produced a result.
    """
    print("\n========= Interactive NL2SQL Mode =========")
    print(f"Running Interactive Question Answering Evaluation on Model: {model_id}")

    if not TEST_CASES_PATH.exists():
        print(f"Error: Could not find test cases at {TEST_CASES_PATH}")
        return

    with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
        test_cases = json.load(handle)

    questAns_results = []
    for case in test_cases:
        case_id = case.get("id")
        question = case.get("question")
        gold_sql = case.get("gold_sql")

        # A case without a question cannot be evaluated; slicing None for
        # the preview below would raise a TypeError.
        if not question:
            print(f"\n\nSkipping ID {case_id}: no question provided.")
            continue

        print(f"\n\nTesting ID {case_id}: {question[:40]}...")
        response = nl2sql_agent(user_question=question, model_id=model_id)

        # Extract metadata
        status = response.get('status')
        nl_answer = response.get('nl_response', 'N/A')
        sql_query = response.get('query', 'N/A')
        error_msg = response.get('error', '')
        attempts = response.get('attempts', 0)

        # Data cross-check: run both the gold and the generated SQL
        df_gold = get_query_data(gold_sql)
        df_generated = get_query_data(sql_query)

        # Verify accuracy
        is_data_accurate = verify_data(df_gold, df_generated)

        questAns_results.append({
            "id": case_id,
            "model_id": model_id,
            "question": question,
            "status": status,
            "data_returned_correct": is_data_accurate,
            "attempts": attempts,
            "nl_response": nl_answer,
            "sql_generated": sql_query,
            "error": error_msg,
        })

    # With no results there is no first row to take the CSV header from,
    # so stop before indexing into the list.
    if not questAns_results:
        print(f"Error: No test cases were evaluated from {TEST_CASES_PATH}")
        return

    # Save to CSV for human-readable and easy comparison
    safe_model_name = model_id.replace("/", "_").replace(":", "_").replace(" ", "_")
    output_csv = Path(f"Q&A_report_{safe_model_name}.csv")
    fieldnames = list(questAns_results[0])
    with output_csv.open("w", newline='', encoding="utf-8") as f:
        dict_writer = csv.DictWriter(f, fieldnames=fieldnames)
        dict_writer.writeheader()
        dict_writer.writerows(questAns_results)

    print(f"\nInteractive evaluation completed. Results saved to: {output_csv}")
if __name__ == "__main__":
    # Load variables from the nearest .env file before any model is queried.
    from dotenv import find_dotenv, load_dotenv

    load_dotenv(find_dotenv())

    # Evaluate each model returned by get_models(), one after another.
    for candidate_model in get_models():
        run_interactiveMode(candidate_model)