nl2sql-api / backend /src /scripts /evaluation_mode.py
dvwn's picture
Update evaluation_mode.py & interactive_mode.py version 1.1.0
c96208b
Raw
History Blame Contribute Delete
7.03 kB
# Path: src/scripts/evaluation_mode.py
# Evaluation script for Hugging Face SQL generation.
import json
import sqlglot
from pathlib import Path
import pandas as pd
from src.database.db_manager import get_db_connection
from src.nl2sql.hf_engine import get_models
from src.nl2sql.sql_agent import nl2sql_agent
from src.scripts.taxonomy_report import print_taxonomyReport
# JSON file holding the evaluation cases. The path is relative, so it is
# resolved against the current working directory when the script runs.
TEST_CASES_PATH = Path("src") / "scripts" / "test_cases.json"
def _normalize_cell(value):
    """
    Return a comparison-friendly version of a single result cell.

    - Missing values (None, NaN, pd.NA, NaT) become ``None``. SQL NULLs surface
      as NaN in float columns, and because ``NaN != NaN`` identical results
      would otherwise compare unequal.
    - Floats (Python or NumPy) are rounded to 4 decimal places and returned as
      Python floats, to avoid precision mismatches.
    - Integers (Python or NumPy, excluding bool) become Python ints, so the
      string form used as the sort key does not depend on the NumPy scalar type.
    - Anything else is returned unchanged.
    """
    if pd.api.types.is_scalar(value) and pd.isna(value):
        return None
    if pd.api.types.is_float(value):
        return round(float(value), 4)
    if pd.api.types.is_integer(value):
        return int(value)
    return value


def _normalize_dataframe(dataframe: pd.DataFrame) -> list:
    """
    Standardize a query result for Execution Accuracy (EX) comparison.

    - Converts the dataframe to a list of row tuples, ignoring column names
      (column order and column count still matter).
    - Normalizes every cell with ``_normalize_cell`` (NULLs -> None, floats
      rounded to 4 decimals, NumPy scalars -> Python scalars).
    - Sorts the rows so the comparison is order-agnostic.

    Returns ``[]`` for ``None`` or an empty dataframe.
    """
    if dataframe is None or dataframe.empty:
        return []
    rows = [
        tuple(_normalize_cell(value) for value in row)
        for row in dataframe.itertuples(index=False, name=None)
    ]
    # Rows can mix None, strings and numbers, which are not mutually orderable;
    # sorting on their string form always works and is deterministic.
    rows.sort(key=str)
    return rows
# Semantic safety net
def extract_tables(sql: str) -> set:
    """
    Extract the set of base table names referenced by a SQL query.

    Used as a semantic safety net: it catches false positives where EX passes
    but the model queried different tables than the gold query.

    - Names are lower-cased. Only ``Table.name`` is kept, so aliases and
      schema/catalog qualifiers are dropped.
    - CTE names are excluded. A query that wraps a table in ``WITH t AS (...)``
      is therefore compared on the tables it actually reads, not on the
      arbitrary CTE label.
    - Returns an empty set for empty input, or when the query cannot be parsed
      (best effort).
    """
    if not sql:
        return set()
    try:
        parsed = sqlglot.parse_one(sql, read=None)
        # Names introduced by WITH clauses are query-local labels, not tables.
        cte_names = {
            cte.alias.lower()
            for cte in parsed.find_all(sqlglot.exp.CTE)
            if cte.alias
        }
        return {
            table.name.lower()
            for table in parsed.find_all(sqlglot.exp.Table)
            if table.name and table.name.lower() not in cte_names
        }
    except Exception:
        # Unparseable SQL: fall back to "no tables" rather than failing the run.
        return set()
# EX: Compare generated SQL results with expected results
def calculate_ex(df_generated: pd.DataFrame, df_gold: pd.DataFrame) -> bool:
    """
    Execution Accuracy (EX): decide whether two query results hold the same data.

    Both frames are passed through ``_normalize_dataframe`` (generated first,
    then gold), and the resulting sorted row lists are compared for equality.
    A missing frame (``None``) on either side counts as a failure. So does any
    error raised during normalization or comparison; that error is printed.
    """
    if df_generated is not None and df_gold is not None:
        try:
            return _normalize_dataframe(df_generated) == _normalize_dataframe(df_gold)
        except Exception as error:
            print(f"EX Evaluation Error: {error}")
    return False
def calculate_esm(generated_sql: str, gold_sql: str) -> bool:
    """
    Exact Set Match (ESM): structural comparison of two SQL strings.

    Both queries are parsed with ``sqlglot.parse_one`` (no explicit dialect).
    The resulting expression trees are compared with sqlglot's expression
    equality, so differences in whitespace and layout do not matter.

    Returns False if either string is empty/None. Also returns False if parsing
    fails, after printing the error.
    """
    if generated_sql and gold_sql:
        try:
            # Parse in the same order as before: generated first, then gold.
            generated_tree, gold_tree = (
                sqlglot.parse_one(text, read=None) for text in (generated_sql, gold_sql)
            )
            return generated_tree == gold_tree
        except Exception as error:
            print(f"ESM Evaluation Error: {error}")
    return False
def _load_test_cases():
    """
    Return the parsed JSON test cases from TEST_CASES_PATH.

    Returns None, after printing an error, when the file does not exist.
    """
    if not TEST_CASES_PATH.exists():
        print(f"Error: Could not find test cases at {TEST_CASES_PATH}")
        return None
    with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
        return json.load(handle)


def _generate_sql(case_id, question, model_id: str) -> str:
    """
    Ask the NL2SQL agent for SQL answering *question*.

    Returns the agent's "query" value. Returns "" if the agent raises, returns
    nothing, or returns a null/empty query.
    """
    try:
        agent_response = nl2sql_agent(user_question=question, model_id=model_id)
    except Exception as error:
        # One failed generation (e.g. a transient model/API error) must not abort
        # the whole evaluation run; the case is simply scored as a failure.
        print(f"Agent error for ID {case_id}: {error}")
        return ""
    return (agent_response or {}).get("query") or ""


def _execution_match(case_id, generated_sql: str, gold_sql: str) -> bool:
    """
    Execute both queries and return the Execution Accuracy (EX) verdict for one case.

    A data match only counts when both queries also reference the same tables
    (via extract_tables). Any error while executing either query is printed and
    scored as a failure. The connection is always closed.

    Raises:
        RuntimeError: if get_db_connection() returns None.
    """
    connection = get_db_connection()
    if connection is None:
        raise RuntimeError("Unable to connect to the SQLite database.")
    try:
        # NOTE(review): generated_sql is model output executed as-is; consider a
        # read-only connection so a bad generation cannot modify the database.
        df_generated = pd.read_sql_query(generated_sql, connection)
        df_gold = pd.read_sql_query(gold_sql, connection)
        # Trap the false positive (empty set): an empty gold result is a weak test case.
        if df_gold.empty:
            print(f"[!]WARNING: Gold SQL for ID {case_id} returned an empty response.")
        if not calculate_ex(df_generated, df_gold):
            return False
        # Semantic safety net: matching data read from different tables is rejected.
        gen_tables = extract_tables(generated_sql)
        gold_tables = extract_tables(gold_sql)
        if gen_tables != gold_tables:
            print(f"[X] FALSE POSITIVE (ID{case_id}): Data matched, tables not")
            print(f"\nGenerated SQL tables: {gen_tables} | Gold SQL tables: {gold_tables}")
            return False
        return True
    except Exception as error:
        print(f"Error executing SQL for ID {case_id}: {error}")
        return False
    finally:
        connection.close()


def run_evaluation(model_id: str):
    """
    Evaluate one model on every test case and report EX / ESM accuracy.

    For each case in TEST_CASES_PATH, the NL2SQL agent generates SQL. That SQL
    is scored with Exact Set Match (calculate_esm) and with Execution Accuracy,
    which executes both the generated and the gold SQL. Per-case results are
    written to ``sql_eval_<model_id>.json`` in the current working directory,
    with "/" and ":" in the model id replaced by "_". A taxonomy report is then
    printed.

    Returns early, after printing an error, if the test-case file is missing.
    Raises RuntimeError if the database connection cannot be opened.
    """
    print(f"\nRunning SQL evaluation for model: {model_id}")
    print("\n" + "-" * 50)
    test_cases = _load_test_cases()
    if test_cases is None:
        return

    results = []
    ex_count = 0
    esm_count = 0
    print(f"Running evaluation on {len(test_cases)} test cases...\n")
    for case in test_cases:
        case_id = case.get("id")
        question = case.get("question")
        gold_sql = case.get("gold_sql")
        taxonomy = case.get("taxonomy", "Unknown")
        # Guard the preview slice: a case without a question must not crash the run.
        print(f"Testing ID {case_id}: {(question or '')[:40]}...")

        # The agent handles RAG retrieval and SQL generation.
        generated_sql = _generate_sql(case_id, question, model_id)

        esm_result = calculate_esm(generated_sql, gold_sql)
        if esm_result:
            esm_count += 1
        ex_result = _execution_match(case_id, generated_sql, gold_sql)
        if ex_result:
            ex_count += 1

        results.append({
            "id": case_id,
            "question": question,
            "taxonomy": taxonomy,
            "ex_pass": ex_result,
            "esm_pass": esm_result,
            "generated_sql": generated_sql,
            "gold_sql": gold_sql,
        })

    # Summary statistics
    total = len(test_cases)
    ex_accuracy = (ex_count / total) * 100 if total > 0 else 0
    esm_accuracy = (esm_count / total) * 100 if total > 0 else 0
    print("\nEVALUATION SUMMARY")
    print("-" * 40)
    print(f"Model Evaluated: {model_id}")
    print(f"Total Test Cases: {total}")
    print(f"Execution Accuracy (EX): {ex_accuracy:.2f}% ({ex_count}/{total})")
    print(f"Exact Set Match (ESM): {esm_accuracy:.2f}% ({esm_count}/{total})")

    safe_model_name = model_id.replace("/", "_").replace(":", "_")
    output_file = Path(f"sql_eval_{safe_model_name}.json")
    with output_file.open("w", encoding="utf-8") as handle:
        json.dump(results, handle, indent=4)
    print_taxonomyReport(results)
if __name__ == "__main__":
    # Populate os.environ from the nearest .env file before any model is queried
    # (presumably for credentials used by get_models / the agent — confirm).
    from dotenv import find_dotenv, load_dotenv

    load_dotenv(find_dotenv())
    # Evaluate each configured model in turn.
    for model in get_models():
        run_evaluation(model)