Spaces:
Sleeping
Sleeping
| # Path: src/scripts/interactive_mode.py | |
| # Interactive mode: Runs the predefined test-case questions through the agent and records its responses | |
| import csv | |
| import json | |
| import pandas as pd | |
| from pathlib import Path | |
| from tabulate import tabulate | |
| from src.database.db_manager import get_db_connection | |
| from src.nl2sql.sql_agent import nl2sql_agent | |
| from src.nl2sql.hf_engine import get_models | |
| TEST_CASES_PATH = Path("src/scripts/test_cases.json") | |
def get_query_data(sql_query: str) -> pd.DataFrame:
    """
    Run ``sql_query`` against the database and return the rows as a DataFrame.

    An empty DataFrame is returned when the query is falsy or the "N/A"
    placeholder, when no database connection is obtained, or when executing
    the query raises (the error is printed rather than propagated).
    """
    has_query = bool(sql_query) and sql_query != "N/A"
    if not has_query:
        return pd.DataFrame()

    db = get_db_connection()
    if not db:
        return pd.DataFrame()

    try:
        return pd.read_sql_query(sql_query, db)
    except Exception as exc:
        print(f"Error executing SQL query: {exc}")
        return pd.DataFrame()
    finally:
        # Runs on both the success and the failure path.
        db.close()
def verify_data(df_gold: pd.DataFrame, df_generated: pd.DataFrame) -> bool:
    """
    Guardrail check: does the generated result set match the gold result set?

    Rows are compared by value only: missing values become "", every cell is
    stringified and stripped, and row order is ignored. Column names are not
    compared. Two empty frames are deliberately reported as a mismatch so the
    case is surfaced as a potential issue.
    """

    def _normalised_rows(frame: pd.DataFrame) -> list:
        # Tuples are used so the rows can be sorted and compared directly.
        cells = frame.fillna("").astype(str).values.tolist()
        return sorted(tuple(cell.strip() for cell in row) for row in cells)

    both_empty = df_gold.empty and df_generated.empty
    if both_empty or len(df_gold) != len(df_generated):
        return False

    try:
        return _normalised_rows(df_gold) == _normalised_rows(df_generated)
    except Exception as exc:
        print(f"Error during data verification: {exc}")
        return False
def run_interactiveMode(model_id: str):
    """
    Run the predefined question set through the NL2SQL agent and save a CSV report.

    Each test case loaded from ``TEST_CASES_PATH`` is sent to ``nl2sql_agent``.
    The SQL the agent produced and the case's gold SQL are both executed with
    ``get_query_data`` and the two result sets are compared with
    ``verify_data``. One row per evaluated case is written to
    ``Q&A_report_<model>.csv`` (a path relative to the current working
    directory).

    Args:
        model_id: Identifier of the model handed to the agent. It is also
            used in the report file name, with "/", ":" and spaces replaced
            by "_".

    Returns:
        None. Returns early, after printing an error, when the test-case
        file does not exist.
    """
    # Declared up front (instead of being read from the first result row) so
    # the report can still be written when no case produced a row.
    fieldnames = [
        "id",
        "model_id",
        "question",
        "status",
        "data_returned_correct",
        "attempts",
        "nl_response",
        "sql_generated",
        "error",
    ]

    print("\n========= Interactive NL2SQL Mode =========")
    print(f"Running Interactive Question Answering Evaluation on Model: {model_id}")

    if not TEST_CASES_PATH.exists():
        print(f"Error: Could not find test cases at {TEST_CASES_PATH}")
        return

    with TEST_CASES_PATH.open("r", encoding="utf-8") as handle:
        test_cases = json.load(handle)

    questAns_results = []
    for case in test_cases:
        case_id = case.get("id")
        question = case.get("question")
        gold_sql = case.get("gold_sql")

        if not question:
            # Without a question there is nothing to send to the agent, and
            # slicing None for the preview below would raise TypeError.
            print(f"\n\nSkipping ID {case_id}: test case has no question.")
            continue

        print(f"\n\nTesting ID {case_id}: {question[:40]}...")
        response = nl2sql_agent(user_question=question, model_id=model_id)

        # Extract metadata
        status = response.get('status')
        nl_answer = response.get('nl_response', 'N/A')
        sql_query = response.get('query', 'N/A')
        error_msg = response.get('error', '')
        attempts = response.get('attempts', 0)

        # Data cross-check: run both queries and compare what they return
        df_gold = get_query_data(gold_sql)
        df_generated = get_query_data(sql_query)
        is_data_accurate = verify_data(df_gold, df_generated)

        questAns_results.append({
            "id": case_id,
            "model_id": model_id,
            "question": question,
            "status": status,
            "data_returned_correct": is_data_accurate,
            "attempts": attempts,
            "nl_response": nl_answer,
            "sql_generated": sql_query,
            "error": error_msg,
        })

    # Save to CSV for human-readable and easy comparison
    safe_model_name = model_id.replace("/", "_").replace(":", "_").replace(" ", "_")
    output_csv = Path(f"Q&A_report_{safe_model_name}.csv")
    with output_csv.open("w", newline='', encoding="utf-8") as f:
        dict_writer = csv.DictWriter(f, fieldnames=fieldnames)
        dict_writer.writeheader()
        dict_writer.writerows(questAns_results)

    print(f"\nInteractive evaluation completed. Results saved to: {output_csv}")
if __name__ == "__main__":
    # Script entry point: load environment variables from the nearest .env
    # file, then run the evaluation once for each model from get_models().
    from dotenv import find_dotenv, load_dotenv

    load_dotenv(find_dotenv())

    for model in get_models():
        run_interactiveMode(model)