diff --git a/runners/anthropic_runner.py b/runners/anthropic_runner.py index 2081afb..264f2e1 100644 --- a/runners/anthropic_runner.py +++ b/runners/anthropic_runner.py @@ -128,7 +128,7 @@ def process_row(row, model_name, args): def run_anthropic_eval(args): - # get params from args + """Run evaluation using Anthropic""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file output_file_list = args.output_file @@ -145,97 +145,55 @@ def run_anthropic_eval(args): print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" ) - question_query_df = prepare_questions_df( + df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - futures = [] - for row in input_rows: - generated_query_fut = executor.submit( - process_row, - row=row, - model_name=args.model, - args=args, - ) - futures.append(generated_query_fut) - - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - elif "TIMEOUT" in err: - row["timeout"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - try: - is_correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[db_type], - question=row["question"], - query_category=row["query_category"], - decimal_points=args.decimal_points, - ) - if is_correct: - total_correct += 1 - row["is_correct"] = 1 - row["error_msg"] = "" - else: - row["is_correct"] = 0 - row["error_msg"] = "INCORRECT RESULTS" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"EXECUTION ERROR: {str(e)}" - output_rows.append(row) - pbar.set_description( - f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" - ) - # save results to csv + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) + + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("is_correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") - # get average rate of correct results - avg_subset = output_df["is_correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="anthropic", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="anthropic", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/api_runner.py b/runners/api_runner.py index 6b2afdf..a8b57f1 100644 --- a/runners/api_runner.py +++ b/runners/api_runner.py @@ -206,34 +206,30 @@ def process_row( def run_api_eval(args): - # get params from args + """Run evaluation using API""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - api_url = args.api_url - api_type = args.api_type - output_file_list = args.output_file k_shot = args.k_shot - num_beams = args.num_beams - max_workers = args.parallel_threads + cot_table_alias = args.cot_table_alias db_type = args.db_type - decimal_points = args.decimal_points logprobs = args.logprobs - cot_table_alias = args.cot_table_alias - sql_lora_path = args.adapter if args.adapter else None - sql_lora_name = args.adapter_name if args.adapter_name else None - run_name = args.run_name if args.run_name else None + run_name = getattr(args, "run_name", None) + sql_lora_path = getattr(args, "adapter", None) + if sql_lora_path: print("Using LoRA adapter at:", sql_lora_path) + + # Logprobs visualization directory handling if logprobs: - # check that the eval-visualizer/public directory exists if not os.path.exists("./eval-visualizer"): - # thorow error raise Exception( - "The eval-visualizer directory does not exist. Please clone it with `git clone https://github.com/defog-ai/eval-visualizer/` before running sql-eval with the --logprobs flag." + "The eval-visualizer directory does not exist. Please clone it with " + "`git clone https://github.com/defog-ai/eval-visualizer/` before running " + "sql-eval with the --logprobs flag." ) - if not os.path.exists("./eval-visualizer/public"): os.makedirs("./eval-visualizer/public") @@ -241,7 +237,6 @@ def run_api_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -249,7 +244,8 @@ def run_api_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -262,65 +258,30 @@ def run_api_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, - row["table_aliases"], + row.get("table_aliases", ""), ), axis=1, ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append( - executor.submit( - process_row, - row, - api_url, - api_type, - num_beams, - decimal_points, - logprobs, - sql_lora_path, - sql_lora_name, - ) - ) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.api_url, process_row, args + ) output_df = pd.DataFrame(output_rows) - - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist - output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - results = output_df.to_dict("records") + # Handle logprobs visualization if logprobs: + results = output_df.to_dict("records") print( f"Writing logprobs to JSON file at eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}" ) @@ -330,27 +291,50 @@ def run_api_eval(args): ) as f: json.dump(results, f) - del output_df["prompt"] + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Clean up and save results + if "prompt" in output_df.columns: + del output_df["prompt"] + + output_dir = os.path.dirname(output_file) + if output_dir and not os.path.exists(output_dir): + os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - # upload results - # with open(prompt_file, "r") as f: - # prompt = f.read() - - if args.run_name is None: + # Handle run naming and result upload + if run_name is None: run_name = output_file.split("/")[-1].replace(".csv", "") print( - "Run name not provided. Using a output filename for run name:", run_name + "Run name not provided. Using output filename for run name:", run_name ) - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - args=args, - run_name=run_name, - ) + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + try: + if hasattr(args, "upload_url") and args.upload_url: + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="api_runner", + args=args, + run_name=run_name, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/base_runner.py b/runners/base_runner.py new file mode 100644 index 0000000..4669ed0 --- /dev/null +++ b/runners/base_runner.py @@ -0,0 +1,117 @@ +import json +from time import time +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pandas as pd +import sqlparse +from tqdm import tqdm + +from eval.eval import compare_query_results +from utils.creds import db_creds_all +from utils.dialects import convert_postgres_ddl_to_dialect +from utils.gen_prompt import to_prompt_schema +from utils.questions import prepare_questions_df +from utils.reporting import upload_results + + +def generate_base_prompt( + prompt_file, + question, + db_name, + db_type, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + prev_invalid_sql="", + prev_error_msg="", + public_data=True, + shuffle=True, +): + """ + Base prompt generation logic used by all runners. + """ + if public_data: + from defog_data.metadata import dbs + import defog_data.supplementary as sup + else: + from defog_data_private.metadata import dbs + import defog_data_private.supplementary as sup + + if table_metadata_string == "": + md = dbs[db_name]["table_metadata"] + pruned_metadata_ddl = to_prompt_schema(md, shuffle) + pruned_metadata_ddl = convert_postgres_ddl_to_dialect( + postgres_ddl=pruned_metadata_ddl, + to_dialect=db_type, + db_name=db_name, + ) + column_join = sup.columns_join.get(db_name, {}) + join_list = [] + for values in column_join.values(): + if isinstance(values[0], tuple): + for col_pair in values: + col_1, col_2 = col_pair + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + else: + col_1, col_2 = values[0] + join_str = f"{col_1} can be joined with {col_2}" + if join_str not in join_list: + join_list.append(join_str) + join_str = ( + "\nHere is a list of joinable columns:\n" + "\n".join(join_list) + if join_list + else "" + ) + pruned_metadata_str = pruned_metadata_ddl + join_str + else: + pruned_metadata_str = table_metadata_string + + return { + "prompt_file": prompt_file, + "question": question, + "db_type": db_type, + "instructions": instructions, + "table_metadata_string": pruned_metadata_str, + "k_shot_prompt": k_shot_prompt, + "glossary": glossary, + "prev_invalid_sql": prev_invalid_sql, + "prev_error_msg": prev_error_msg, + } + + +def extract_sql_from_response(content): + """Extract SQL from between ```sql blocks and format it.""" + try: + generated_query = content.split("```sql", 1)[-1].split("```", 1)[0].strip() + return sqlparse.format(generated_query, reindent=True, keyword_case="upper") + except: + return content + + +def run_eval_in_threadpool(df, model_name, process_row_func, args): + """Common threadpool execution pattern for all runners.""" + total_tried = 0 + total_correct = 0 + output_rows = [] + + print(f"Running evaluation using {model_name}...") + with ThreadPoolExecutor(max_workers=args.parallel_threads) as executor: + futures = [] + for row in df.to_dict("records"): + futures.append(executor.submit(process_row_func, row, model_name, args)) + + with tqdm(as_completed(futures), total=len(futures)) as pbar: + for f in pbar: + row = f.result() + output_rows.append(row) + if row.get("correct", 0): + total_correct += 1 + total_tried += 1 + pbar.set_description( + f"Acc: {total_correct}/{total_tried}={total_correct/total_tried:.3f}" + ) + + return output_rows, total_correct, total_tried diff --git a/runners/bedrock_runner.py b/runners/bedrock_runner.py index 806402a..4c46769 100644 --- a/runners/bedrock_runner.py +++ b/runners/bedrock_runner.py @@ -1,99 +1,109 @@ -import boto3 import json import os -from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import Optional - -from eval.eval import compare_query_results +import boto3 +from time import time import pandas as pd + +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time from utils.reporting import upload_results +from eval.eval import compare_query_results bedrock = boto3.client(service_name="bedrock-runtime") -def process_row(row, model_id, decimal_points): +def process_row(row, model_id, args): + """Process a single row using AWS Bedrock""" start_time = time() - - body = json.dumps( - { - "prompt": row["prompt"], - "max_gen_len": 600, - "temperature": 0, - "top_p": 1, - } - ) - - accept = "application/json" - contentType = "application/json" - response = bedrock.invoke_model( - body=body, modelId=model_id, accept=accept, contentType=contentType - ) - model_response = json.loads(response["body"].read()) - - generated_query = model_response["generation"] - end_time = time() - - generated_query = ( - generated_query.split("```sql")[-1].split("```")[0].split(";")[0].strip() + ";" - ) - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=decimal_points, + # Bedrock-specific request payload + body = json.dumps( + { + "prompt": row["prompt"], + "max_gen_len": 600, + "temperature": 0, + "top_p": 1, + } ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row + accept = "application/json" + contentType = "application/json" + response = bedrock.invoke_model( + body=body, modelId=model_id, accept=accept, contentType=contentType + ) + model_response = json.loads(response["body"].read()) + generated_query = model_response["generation"] + end_time = time() + + # Bedrock-specific SQL extraction + generated_query = extract_sql_from_response(generated_query) + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + row["tokens_used"] = None # Bedrock doesn't provide token count + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row + except Exception as e: + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = None + return row def run_bedrock_eval(args): - # get params from args + """Run evaluation using AWS Bedrock""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type - decimal_points = args.decimal_points - model_id = args.model cot_table_alias = args.cot_table_alias for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -101,7 +111,8 @@ def run_bedrock_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -114,64 +125,68 @@ def run_bedrock_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append( - executor.submit(process_row, row, model_id, decimal_points) - ) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="bedrock", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/deepseek_runner.py b/runners/deepseek_runner.py index 323c0c1..8fb0172 100644 --- a/runners/deepseek_runner.py +++ b/runners/deepseek_runner.py @@ -1,96 +1,110 @@ import os -from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict - -from eval.eval import compare_query_results +from time import time import pandas as pd + +from openai import OpenAI + +from runners.base_runner import generate_base_prompt, run_eval_in_threadpool from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time -from openai import OpenAI from utils.reporting import upload_results - +from eval.eval import compare_query_results client = OpenAI( base_url="https://api.deepseek.com", api_key=os.environ.get("DEEPSEEK_API_KEY") ) -def process_row(row: Dict, model: str): +def process_row(row: Dict, model: str, args): + """Process a single row using Deepseek""" start_time = time() - messages = row["prompt"] - if model != "deepseek-reasoner": - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - temperature=0.0, - ) - else: - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - ) - content = response.choices[0].message.content - generated_query = content.replace("```sql", "").replace("```", "").strip() - end_time = time() - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + messages = row["prompt"] + # Deepseek-specific handling + if model != "deepseek-reasoner": + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=800, + temperature=0.0, + ) + else: + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=800, + ) + content = response.choices[0].message.content + # Deepseek-specific SQL extraction + generated_query = content.replace("```sql", "").replace("```", "").strip() + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + row["tokens_used"] = None # Deepseek doesn't provide token count + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] - return row + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row + except Exception as e: + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = None + return row def run_deepseek_eval(args): - # get params from args + """Run evaluation using Deepseek""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type - decimal_points = args.decimal_points - model = args.model cot_table_alias = args.cot_table_alias for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): + # Deepseek-specific JSON validation if not prompt_file.endswith(".json"): raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}") + print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -98,8 +112,8 @@ def run_deepseek_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question - # note that the prompt for together ai uses the openai chat API + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -112,63 +126,68 @@ def run_deepseek_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, - row["table_aliases"], + row.get("table_aliases", ""), ), axis=1, ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="deepseek", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/gemini_runner.py b/runners/gemini_runner.py index cd292c1..db9c923 100644 --- a/runners/gemini_runner.py +++ b/runners/gemini_runner.py @@ -1,18 +1,17 @@ -import os from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed - +import os import pandas as pd -import sqlparse -from tqdm import tqdm -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.gen_prompt import to_prompt_schema +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from utils.questions import prepare_questions_df -from utils.reporting import upload_results from utils.llm import chat_gemini +from utils.creds import db_creds_all +from utils.reporting import upload_results +from eval.eval import compare_query_results def generate_prompt( @@ -29,55 +28,33 @@ def generate_prompt( public_data=True, shuffle=True, ): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - # raise Exception("Replace this with your private data import") - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup + """Gemini-specific prompt handling""" + # Get base prompt data + base_data = generate_base_prompt( + prompt_file, + question, + db_name, + db_type, + instructions, + k_shot_prompt, + glossary, + table_metadata_string, + prev_invalid_sql, + prev_error_msg, + public_data, + shuffle, + ) + # Load and format Gemini text prompt with open(prompt_file, "r") as f: prompt = f.read() - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - # get join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - + # Format the prompt with all parameters prompt = prompt.format( user_question=question, db_type=db_type, instructions=instructions, - table_metadata_string=pruned_metadata_str, + table_metadata_string=base_data["table_metadata_string"], k_shot_prompt=k_shot_prompt, glossary=glossary, prev_invalid_sql=prev_invalid_sql, @@ -87,68 +64,65 @@ def generate_prompt( def process_row(row, model_name, args): + """Process a single row using Gemini""" start_time = time() + # Prompt already in row from DataFrame preprocessing messages = [{"role": "user", "content": row["prompt"]}] try: response = chat_gemini(messages=messages, model=model_name, temperature=0.0) - generated_query = ( - response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() - ) - try: - generated_query = sqlparse.format( - generated_query, - strip_comments=True, - strip_whitespace=True, - keyword_case="upper", - ) - except: - pass + generated_query = extract_sql_from_response(response.content) + + # Gemini-specific result handling row["generated_query"] = generated_query row["latency_seconds"] = response.time row["tokens_used"] = response.input_tokens + response.output_tokens - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"GENERATION ERROR: {e}" - return row - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - exact_match = correct = 0 + # Verify results with exact_match + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[db_type], - question=question, - query_category=query_category, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row + return row + except Exception as e: + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = 0 + return row def run_gemini_eval(args): - # get params from args + """Run evaluation using Gemini""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - model_name = args.model - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type cot_table_alias = args.cot_table_alias @@ -164,6 +138,7 @@ def run_gemini_eval(args): questions_file, db_type, num_questions, k_shot, cot_table_alias ) + # Gemini-specific: preprocess prompts into DataFrame df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -182,49 +157,51 @@ def run_gemini_eval(args): axis=1, ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - print(f"Running evaluation using {model_name}...") - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model_name, args)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row.get("correct", 0): - total_correct += 1 - total_tried += 1 - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - results = output_df.to_dict("records") + output_df.to_csv(output_file, index=False, float_format="%.2f") - if args.upload_url is not None: - with open(prompt_file, "r") as f: - prompt = f.read() + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() upload_results( - results=results, + results=output_df.to_dict("records"), url=args.upload_url, - runner_type="api_runner", + runner_type="gemini", prompt=prompt, args=args, ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/hf_runner.py b/runners/hf_runner.py index 9046a65..1e601e9 100644 --- a/runners/hf_runner.py +++ b/runners/hf_runner.py @@ -1,22 +1,21 @@ import os from typing import Optional - -from eval.eval import compare_query_results -import pandas as pd import torch +import gc +import pandas as pd from transformers import ( AutoTokenizer, AutoModelForCausalLM, pipeline, ) +from tqdm import tqdm +from psycopg2.extensions import QueryCanceledError + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from psycopg2.extensions import QueryCanceledError -from time import time -import gc from utils.reporting import upload_results +from eval.eval import compare_query_results device_map = "mps" if torch.backends.mps.is_available() else "auto" @@ -62,15 +61,23 @@ def get_tokenizer_model(model_name: Optional[str], adapter_path: Optional[str]): return tokenizer, model +def extract_hf_sql(text: str, has_sql_tag: bool) -> str: + """HuggingFace-specific SQL extraction""" + if not has_sql_tag: + return text.split("```")[0].split(";")[0].strip() + ";" + else: + return text.split("[/SQL]")[0].split(";")[0].strip() + ";" + + def run_hf_eval(args): - # get params from args + """Run evaluation using HuggingFace models""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_name = args.model adapter_path = args.adapter - output_file_list = args.output_file k_shot = args.k_shot db_type = args.db_type num_beams = args.num_beams @@ -94,8 +101,6 @@ def run_hf_eval(args): print("model loaded\nnow generating and evaluating predictions...") - # from here, we generate and evaluate predictions - # eos_token_id = tokenizer.convert_tokens_to_ids(["```"])[0] pipe = pipeline( "text-generation", model=model, tokenizer=tokenizer, batch_size=args.batch_size ) @@ -104,7 +109,6 @@ def run_hf_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -112,7 +116,8 @@ def run_hf_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -125,15 +130,16 @@ def run_hf_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) @@ -165,30 +171,16 @@ def chunk_dataframe(df, chunk_size): top_p=None, ) gc.collect() - torch.cuda.empty_cache() - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() for row, result in zip(batch.to_dict("records"), generated_queries): total_tried += 1 - # we set return_full_text to False so that we don't get the prompt text in the generated text - # this simplifies our postprocessing to deal with just the truncation of the end of the query - - if "[SQL]" not in row["prompt"]: - generated_query = ( - result[0]["generated_text"] - .split("```")[0] - .split(";")[0] - .strip() - + ";" - ) - else: - generated_query = ( - result[0]["generated_text"] - .split("[/SQL]")[0] - .split(";")[0] - .strip() - + ";" - ) + has_sql_tag = "[SQL]" in row["prompt"] + generated_query = extract_hf_sql( + result[0]["generated_text"], has_sql_tag + ) gc.collect() if torch.cuda.is_available(): @@ -203,8 +195,6 @@ def chunk_dataframe(df, chunk_size): question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[db_type] try: exact_match, correct = compare_query_results( @@ -212,14 +202,21 @@ def chunk_dataframe(df, chunk_size): query_gen=generated_query, db_name=db_name, db_type=db_type, - db_creds=db_creds, + db_creds=db_creds_all[db_type], question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, + decimal_points=( + args.decimal_points + if hasattr(args, "decimal_points") + else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) + row["is_correct"] = int( + correct + ) # For base runner compatibility row["error_msg"] = "" if correct: total_correct += 1 @@ -236,25 +233,47 @@ def chunk_dataframe(df, chunk_size): f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + output_df.to_csv(output_file, index=False, float_format="%.2f") - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="hf_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="hf_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/llama_cpp_runner.py b/runners/llama_cpp_runner.py index 0297ca0..16131a7 100644 --- a/runners/llama_cpp_runner.py +++ b/runners/llama_cpp_runner.py @@ -1,84 +1,97 @@ import os - -from eval.eval import compare_query_results +from time import time import pandas as pd +from llama_cpp import Llama +from tqdm import tqdm + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time from utils.reporting import upload_results -from llama_cpp import Llama +from eval.eval import compare_query_results def process_row(llm, row, args): + """Process a single row using Llama.cpp""" start_time = time() - prompt = row["prompt"] - generated_query = ( - llm( + try: + prompt = row["prompt"] + response = llm( prompt, max_tokens=512, temperature=0, top_p=1, echo=False, repeat_penalty=1.0, - )["choices"][0]["text"] - .split(";")[0] - .split("```")[0] - .strip() - + ";" - ) - end_time = time() - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + # Llama.cpp-specific SQL extraction + generated_query = ( + response["choices"][0]["text"].split(";")[0].split("```")[0].strip() + ";" + ) + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + return row def run_llama_cpp_eval(args): - # get params from args + """Run evaluation using Llama.cpp""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_path = args.model - output_file_list = args.output_file k_shot = args.k_shot db_type = args.db_type cot_table_alias = args.cot_table_alias + # Load Llama.cpp model llm = Llama(model_path=model_path, n_gpu_layers=-1, n_ctx=4096) for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -86,7 +99,8 @@ def run_llama_cpp_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -99,19 +113,21 @@ def run_llama_cpp_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) + # Process rows with direct iteration (no threading) total_tried = 0 total_correct = 0 output_rows = [] @@ -120,7 +136,7 @@ def run_llama_cpp_eval(args): for row in df.to_dict("records"): row = process_row(llm, row, args) output_rows.append(row) - if row["correct"]: + if row.get("correct", 0): total_correct += 1 total_tried += 1 pbar.update(1) @@ -128,28 +144,50 @@ def run_llama_cpp_eval(args): f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="llama_cpp_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="llama_cpp_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/mistral_runner.py b/runners/mistral_runner.py index 4abdf81..97b287e 100644 --- a/runners/mistral_runner.py +++ b/runners/mistral_runner.py @@ -1,18 +1,19 @@ import os from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed +import pandas as pd from mistralai.client import MistralClient from mistralai.models.chat_completion import ChatMessage -import pandas as pd -from tqdm import tqdm -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.gen_prompt import to_prompt_schema -from utils.dialects import convert_postgres_ddl_to_dialect +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from utils.questions import prepare_questions_df +from utils.creds import db_creds_all from utils.reporting import upload_results +from eval.eval import compare_query_results api_key = os.environ.get("MISTRAL_API_KEY") client = MistralClient(api_key=api_key) @@ -32,141 +33,126 @@ def generate_prompt( public_data=True, shuffle=True, ): + """Mistral-specific prompt handling with System/User format""" + # Get base prompt data + base_data = generate_base_prompt( + prompt_file, + question, + db_name, + db_type, + instructions, + k_shot_prompt, + glossary, + table_metadata_string, + prev_invalid_sql, + prev_error_msg, + public_data, + shuffle, + ) + + # Load and parse Mistral-specific prompt format with open(prompt_file, "r") as f: prompt = f.read() # Check that System and User prompts are in the prompt file if "System:" not in prompt or "User:" not in prompt: raise ValueError("Invalid prompt file. Please use prompt_mistral.md") + sys_prompt = prompt.split("System:")[1].split("User:")[0].strip() user_prompt = prompt.split("User:")[1].strip() - if table_metadata_string == "": - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup - - md = dbs[db_name]["table_metadata"] - metadata_ddl = to_prompt_schema(md, shuffle) - metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - # get join_str from column_join - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - # add to join_list - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string - + # Format the user prompt with parameters user_prompt = user_prompt.format( user_question=question, instructions=instructions, - table_metadata_string=pruned_metadata_str, + table_metadata_string=base_data["table_metadata_string"], k_shot_prompt=k_shot_prompt, glossary=glossary, prev_invalid_sql=prev_invalid_sql, prev_error_msg=prev_error_msg, ) - messages = [ - ChatMessage( - role="system", - content=sys_prompt, - ), - ChatMessage( - role="user", - content=user_prompt, - ), + + # Return Mistral-specific message format + return [ + ChatMessage(role="system", content=sys_prompt), + ChatMessage(role="user", content=user_prompt), ] - return messages def process_row(row, model, args): + """Process a single row using Mistral""" start_time = time() - chat_response = client.chat( - model=model, - messages=row["prompt"], - temperature=0, - max_tokens=600, - ) - end_time = time() - generated_query = chat_response.choices[0].message.content - try: - # replace all backslashes with empty string - generated_query = generated_query.replace("\\", "") - - generated_query = generated_query.split(";")[0].split("```sql")[-1].strip() - generated_query = [i for i in generated_query.split("```") if i.strip() != ""][ - 0 - ] + ";" - except Exception as e: - print(e) + chat_response = client.chat( + model=model, + messages=row["prompt"], + temperature=0, + max_tokens=600, + ) + end_time = time() generated_query = chat_response.choices[0].message.content - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, - ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + # Mistral-specific SQL extraction with backslash handling + try: + generated_query = generated_query.replace("\\", "") + generated_query = generated_query.split(";")[0].split("```sql")[-1].strip() + generated_query = [ + i for i in generated_query.split("```") if i.strip() != "" + ][0] + ";" + except Exception as e: + print(e) + generated_query = chat_response.choices[0].message.content + + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] - return row + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row + except Exception as e: + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + return row def run_mistral_eval(args): - # get params from args + """Run evaluation using Mistral""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - model = args.model - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type cot_table_alias = args.cot_table_alias @@ -174,7 +160,6 @@ def run_mistral_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -182,7 +167,8 @@ def run_mistral_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Mistral-specific: preprocess prompts into DataFrame df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -201,48 +187,51 @@ def run_mistral_eval(args): axis=1, ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model, args)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row.get("correct", 0): - total_correct += 1 - total_tried += 1 - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + + output_df.to_csv(output_file, index=False, float_format="%.2f") + + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided try: - output_df.to_csv(output_file, index=False, float_format="%.2f") - except: - output_df.to_pickle(output_file) - - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="mistral_runner", - prompt=prompt, - args=args, - ) + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="mistral", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/mlx_runner.py b/runners/mlx_runner.py index e773008..0f1601a 100644 --- a/runners/mlx_runner.py +++ b/runners/mlx_runner.py @@ -1,78 +1,91 @@ import os - -from eval.eval import compare_query_results +from time import time import pandas as pd +from tqdm import tqdm +from mlx_lm import load, generate + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time from utils.reporting import upload_results -from mlx_lm import load, generate +from eval.eval import compare_query_results def process_row(model, tokenizer, row, args): + """Process a single row using MLX""" start_time = time() - prompt = row["prompt"] - - generated_query = ( - generate(model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True) - .split(";")[0] - .split("```")[0] - .strip() - + ";" - ) - end_time = time() - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, + prompt = row["prompt"] + + # MLX-specific generation + generated_text = generate( + model, tokenizer, prompt=prompt, max_tokens=512, temp=0, verbose=True ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" + generated_query = generated_text.split(";")[0].split("```")[0].strip() + ";" + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] + + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" - return row + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + return row def run_mlx_eval(args): - # get params from args + """Run evaluation using MLX""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_path = args.model - output_file_list = args.output_file k_shot = args.k_shot db_type = args.db_type cot_table_alias = args.cot_table_alias + # MLX-specific model loading model, tokenizer = load(model_path) for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -80,7 +93,8 @@ def run_mlx_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -93,19 +107,21 @@ def run_mlx_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) + # Process rows sequentially with tqdm total_tried = 0 total_correct = 0 output_rows = [] @@ -114,7 +130,7 @@ def run_mlx_eval(args): for row in df.to_dict("records"): row = process_row(model, tokenizer, row, args) output_rows.append(row) - if row["correct"]: + if row.get("correct", 0): total_correct += 1 total_tried += 1 pbar.update(1) @@ -122,28 +138,50 @@ def run_mlx_eval(args): f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="mlx_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="mlx_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/openai_runner.py b/runners/openai_runner.py index 5d207ef..862613a 100644 --- a/runners/openai_runner.py +++ b/runners/openai_runner.py @@ -1,19 +1,18 @@ -import os from time import time -from concurrent.futures import ThreadPoolExecutor, as_completed import json - +import os import pandas as pd -import sqlparse -from tqdm import tqdm -from eval.eval import compare_query_results -from utils.creds import db_creds_all -from utils.dialects import convert_postgres_ddl_to_dialect -from utils.gen_prompt import to_prompt_schema +from runners.base_runner import ( + generate_base_prompt, + extract_sql_from_response, + run_eval_in_threadpool, +) from utils.questions import prepare_questions_df -from utils.reporting import upload_results from utils.llm import chat_openai +from utils.creds import db_creds_all +from utils.reporting import upload_results +from eval.eval import compare_query_results def generate_prompt( @@ -30,45 +29,28 @@ def generate_prompt( public_data=True, shuffle=True, ): - if public_data: - from defog_data.metadata import dbs - import defog_data.supplementary as sup - else: - from defog_data_private.metadata import dbs - import defog_data_private.supplementary as sup + """OpenAI-specific prompt handling""" + # Get base prompt data + base_data = generate_base_prompt( + prompt_file, + question, + db_name, + db_type, + instructions, + k_shot_prompt, + glossary, + table_metadata_string, + prev_invalid_sql, + prev_error_msg, + public_data, + shuffle, + ) + # Load and format OpenAI-specific JSON prompt with open(prompt_file, "r") as f: prompt = json.load(f) - if table_metadata_string == "": - md = dbs[db_name]["table_metadata"] - pruned_metadata_ddl = to_prompt_schema(md, shuffle) - pruned_metadata_ddl = convert_postgres_ddl_to_dialect( - postgres_ddl=pruned_metadata_ddl, - to_dialect=db_type, - db_name=db_name, - ) - column_join = sup.columns_join.get(db_name, {}) - join_list = [] - for values in column_join.values(): - if isinstance(values[0], tuple): - for col_pair in values: - col_1, col_2 = col_pair - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - else: - col_1, col_2 = values[0] - join_str = f"{col_1} can be joined with {col_2}" - if join_str not in join_list: - join_list.append(join_str) - if len(join_list) > 0: - join_str = "\nHere is a list of joinable columns:\n" + "\n".join(join_list) - else: - join_str = "" - pruned_metadata_str = pruned_metadata_ddl + join_str - else: - pruned_metadata_str = table_metadata_string + pruned_metadata_str = base_data["table_metadata_string"] if prompt[0]["role"] == "system": prompt[0]["content"] = prompt[0]["content"].format( @@ -81,7 +63,7 @@ def generate_prompt( k_shot_prompt=k_shot_prompt, ) else: - prompt[0]["content"] = prompt[1]["content"].format( + prompt[0]["content"] = prompt[0]["content"].format( db_type=db_type, user_question=question, instructions=instructions, @@ -92,6 +74,7 @@ def generate_prompt( def process_row(row, model_name, args): + """Process a single row using OpenAI""" start_time = time() messages = generate_prompt( prompt_file=args.prompt_file[0], @@ -109,34 +92,57 @@ def process_row(row, model_name, args): ) try: response = chat_openai(messages=messages, model=model_name, temperature=0.0) - generated_query = ( - response.content.split("```sql", 1)[-1].split("```", 1)[0].strip() - ) - try: - generated_query = sqlparse.format( - generated_query, reindent=True, keyword_case="upper" - ) - except: - pass - return { - "query": generated_query, + generated_query = extract_sql_from_response(response.content) + + result = { + "generated_query": generated_query, "reason": "", - "err": "", + "error_msg": "", "latency_seconds": time() - start_time, "tokens_used": response.input_tokens + response.output_tokens, } + + # Verify results + expected_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + try: + is_correct = compare_query_results( + query_gold=expected_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=row["question"], + query_category=row["query_category"], + decimal_points=( + args.decimal_points if hasattr(args, "decimal_points") else 2 + ), + ) + if is_correct: + row["is_correct"] = 1 + else: + row["is_correct"] = 0 + result["error_msg"] = "INCORRECT RESULTS" + except Exception as e: + row["error_db_exec"] = 1 + result["error_msg"] = f"EXECUTION ERROR: {str(e)}" + + # Update row with result data + row.update(result) + return row except Exception as e: - return { - "query": "", - "reason": "", - "err": f"GENERATION ERROR: {str(e)}", - "latency_seconds": time() - start_time, - "tokens_used": 0, - } + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["reason"] = "" + row["error_msg"] = f"GENERATION ERROR: {str(e)}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = 0 + return row def run_openai_eval(args): - # get params from args + """Run evaluation using OpenAI""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file output_file_list = args.output_file @@ -153,78 +159,21 @@ def run_openai_eval(args): print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" ) - question_query_df = prepare_questions_df( + df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - input_rows = question_query_df.to_dict("records") - output_rows = [] - with ThreadPoolExecutor(args.parallel_threads) as executor: - futures = [] - for row in input_rows: - generated_query_fut = executor.submit( - process_row, - row=row, - model_name=args.model, - args=args, - ) - futures.append(generated_query_fut) - total_tried = 0 - total_correct = 0 - for f in (pbar := tqdm(as_completed(futures), total=len(futures))): - total_tried += 1 - i = futures.index(f) - row = input_rows[i] - result_dict = f.result() - query_gen = result_dict["query"] - reason = result_dict["reason"] - err = result_dict["err"] - # save custom metrics - if "latency_seconds" in result_dict: - row["latency_seconds"] = result_dict["latency_seconds"] - if "tokens_used" in result_dict: - row["tokens_used"] = result_dict["tokens_used"] - row["generated_query"] = query_gen - row["reason"] = reason - row["error_msg"] = err - # save failures into relevant columns in the dataframe - if "GENERATION ERROR" in err: - row["error_query_gen"] = 1 - else: - expected_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - try: - is_correct = compare_query_results( - query_gold=expected_query, - query_gen=query_gen, - db_name=db_name, - db_type=db_type, - question=row["question"], - query_category=row["query_category"], - db_creds=db_creds_all[db_type], - ) - if is_correct: - total_correct += 1 - row["is_correct"] = 1 - row["error_msg"] = "" - else: - row["is_correct"] = 0 - row["error_msg"] = "INCORRECT RESULTS" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"EXECUTION ERROR: {str(e)}" - output_rows.append(row) - pbar.set_description( - f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) - # save results to csv + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) if "prompt" in output_df.columns: del output_df["prompt"] - # get num rows, mean correct, mean error_db_exec for each query_category + + # Get stats by query category agg_stats = ( output_df.groupby("query_category") .agg( @@ -235,26 +184,30 @@ def run_openai_eval(args): .reset_index() ) print(agg_stats) - # get directory of output_file and create if not exist + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) - output_df.to_csv(output_file, index=False, float_format="%.2f") - # get average rate of correct results - avg_subset = output_df["correct"].sum() / len(output_df) - print(f"Average correct rate: {avg_subset:.2f}") + output_df.to_csv(output_file, index=False, float_format="%.2f") - results = output_df.to_dict("records") + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="openai", - prompt=prompt, - args=args, - ) + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="openai", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/together_runner.py b/runners/together_runner.py index 0414e57..35a22fb 100644 --- a/runners/together_runner.py +++ b/runners/together_runner.py @@ -1,96 +1,109 @@ import os -from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Dict - -from eval.eval import compare_query_results +from time import time import pandas as pd +from copy import deepcopy + +from together import Together + +from runners.base_runner import generate_base_prompt, run_eval_in_threadpool from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -from tqdm import tqdm -from time import time -from together import Together from utils.reporting import upload_results - +from eval.eval import compare_query_results client = Together(api_key=os.environ.get("TOGETHER_API_KEY")) def process_row(row: Dict, model: str): + """Process a single row using Together""" start_time = time() - if model.startswith("meta-llama"): - stop = ["<|eot_id|>", "<|eom_id|>"] - else: - print( - "Undefined stop token(s). Please specify the stop token(s) for the model." - ) - stop = [] - messages = row["prompt"] - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=800, - temperature=0.0, - stop=stop, - stream=False, - ) - content = response.choices[0].message.content - generated_query = content.split("```", 1)[0].strip() - end_time = time() - - row["generated_query"] = generated_query - row["latency_seconds"] = end_time - start_time - row["tokens_used"] = None - golden_query = row["query"] - db_name = row["db_name"] - db_type = row["db_type"] - question = row["question"] - query_category = row["query_category"] - table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - try: - exact_match, correct = compare_query_results( - query_gold=golden_query, - query_gen=generated_query, - db_name=db_name, - db_type=db_type, - db_creds=db_creds_all[row["db_type"]], - question=question, - query_category=query_category, - table_metadata_string=table_metadata_string, + # Together-specific stop tokens + if model.startswith("meta-llama"): + stop = ["<|eot_id|>", "<|eom_id|>"] + else: + print( + "Undefined stop token(s). Please specify the stop token(s) for the model." + ) + stop = [] + + messages = row["prompt"] + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=800, + temperature=0.0, + stop=stop, + stream=False, ) - row["exact_match"] = int(exact_match) - row["correct"] = int(correct) - row["error_msg"] = "" - except Exception as e: - row["error_db_exec"] = 1 - row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + # Together-specific SQL extraction + content = response.choices[0].message.content + generated_query = content.split("```", 1)[0].strip() + end_time = time() + + # Store results + row["generated_query"] = generated_query + row["latency_seconds"] = end_time - start_time + row["tokens_used"] = None # Together doesn't provide token count + + # Verify results + golden_query = row["query"] + db_name = row["db_name"] + db_type = row["db_type"] + question = row["question"] + query_category = row["query_category"] + table_metadata_string = row["table_metadata_string"] - return row + try: + exact_match, correct = compare_query_results( + query_gold=golden_query, + query_gen=generated_query, + db_name=db_name, + db_type=db_type, + db_creds=db_creds_all[db_type], + question=question, + query_category=query_category, + table_metadata_string=table_metadata_string, + ) + row["exact_match"] = int(exact_match) + row["correct"] = int(correct) + row["is_correct"] = int(correct) # For compatibility with base runner + row["error_msg"] = "" + except Exception as e: + row["error_db_exec"] = 1 + row["error_msg"] = f"QUERY EXECUTION ERROR: {e}" + + return row + except Exception as e: + row["error_query_gen"] = 1 + row["generated_query"] = "" + row["error_msg"] = f"GENERATION ERROR: {e}" + row["latency_seconds"] = time() - start_time + row["tokens_used"] = None + return row def run_together_eval(args): - # get params from args + """Run evaluation using Together""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data - output_file_list = args.output_file k_shot = args.k_shot - max_workers = args.parallel_threads db_type = args.db_type - decimal_points = args.decimal_points - model = args.model cot_table_alias = args.cot_table_alias for questions_file, prompt_file, output_file in zip( questions_file_list, prompt_file_list, output_file_list ): + # Together-specific JSON validation if not prompt_file.endswith(".json"): raise ValueError(f"Prompt file must be a JSON file. Got {prompt_file}") + print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -98,8 +111,8 @@ def run_together_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question - # note that the prompt for together ai uses the openai chat API + + # Together-specific: use full generate_prompt with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -112,63 +125,68 @@ def run_together_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, - row["table_aliases"], + row.get("table_aliases", ""), ), axis=1, ) - total_tried = 0 - total_correct = 0 - output_rows = [] - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - for row in df.to_dict("records"): - futures.append(executor.submit(process_row, row, model)) - - with tqdm(as_completed(futures), total=len(futures)) as pbar: - for f in pbar: - row = f.result() - output_rows.append(row) - if row["correct"]: - total_correct += 1 - total_tried += 1 - pbar.update(1) - pbar.set_description( - f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" - ) + output_rows, total_correct, total_tried = run_eval_in_threadpool( + df, args.model, process_row, args + ) + # Convert to DataFrame and save results output_df = pd.DataFrame(output_rows) - del output_df["prompt"] - print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean()) output_df = output_df.sort_values(by=["db_name", "query_category", "question"]) - # get directory of output_file and create if not exist + if "prompt" in output_df.columns: + del output_df["prompt"] + + # Get stats by query category + agg_stats = ( + output_df.groupby("query_category") + .agg( + num_rows=("db_name", "count"), + mean_correct=("correct", "mean"), + mean_error_db_exec=("error_db_exec", "mean"), + ) + .reset_index() + ) + print(agg_stats) + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + try: output_df.to_csv(output_file, index=False, float_format="%.2f") except: output_df.to_pickle(output_file) - results = output_df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="api_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=output_df.to_dict("records"), + url=args.upload_url, + runner_type="together", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}") diff --git a/runners/vllm_runner.py b/runners/vllm_runner.py index 59ed962..57bd960 100644 --- a/runners/vllm_runner.py +++ b/runners/vllm_runner.py @@ -1,40 +1,45 @@ -import json import os -from typing import List import sqlparse +import time +import torch +import pandas as pd +from typing import List +from tqdm import tqdm + from vllm import LLM, SamplingParams from vllm.lora.request import LoRARequest -from eval.eval import compare_query_results -import pandas as pd +from transformers import AutoTokenizer + from utils.gen_prompt import generate_prompt from utils.questions import prepare_questions_df from utils.creds import db_creds_all -import time -import torch -from transformers import AutoTokenizer -from tqdm import tqdm from utils.reporting import upload_results +from eval.eval import compare_query_results def run_vllm_eval(args): - # get params from args + """Run evaluation using VLLM with batching""" questions_file_list = args.questions_file prompt_file_list = args.prompt_file + output_file_list = args.output_file num_questions = args.num_questions public_data = not args.use_private_data model_name = args.model - output_file_list = args.output_file num_beams = args.num_beams k_shot = args.k_shot db_type = args.db_type cot_table_alias = args.cot_table_alias + + # VLLM-specific LoRA handling enable_lora = True if args.adapter else False lora_request = LoRARequest("sql_adapter", 1, args.adapter) if args.adapter else None - # initialize model only once as it takes a while + # Initialize VLLM model and tokenizer print(f"Preparing {model_name}") tokenizer = AutoTokenizer.from_pretrained(model_name) tokenizer.pad_token_id = tokenizer.eos_token_id + + # VLLM-specific model initialization if not args.quantized: llm = LLM( model=model_name, @@ -66,7 +71,6 @@ def run_vllm_eval(args): questions_file_list, prompt_file_list, output_file_list ): print(f"Using prompt file {prompt_file}") - # get questions print("Preparing questions...") print( f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}" @@ -74,7 +78,8 @@ def run_vllm_eval(args): df = prepare_questions_df( questions_file, db_type, num_questions, k_shot, cot_table_alias ) - # create a prompt for each question + + # Create prompts with all parameters df["prompt"] = df.apply( lambda row: generate_prompt( prompt_file, @@ -87,15 +92,16 @@ def run_vllm_eval(args): row["table_metadata_string"], row["prev_invalid_sql"], row["prev_error_msg"], - row["question_0"], - row["query_0"], - row["question_1"], - row["query_1"], - row["cot_instructions"], - row["cot_pregen"], + row.get("question_0", ""), + row.get("query_0", ""), + row.get("question_1", ""), + row.get("query_1", ""), + row.get("cot_instructions", ""), + row.get("cot_pregen", False), public_data, - args.num_columns, + args.num_columns if hasattr(args, "num_columns") else 40, args.shuffle_metadata, + row.get("table_aliases", ""), ), axis=1, ) @@ -112,22 +118,22 @@ def chunk_dataframe(df, chunk_size): df_chunks.append(df_i) return df_chunks + # VLLM-specific batch processing df_chunks = chunk_dataframe(df, args.batch_size) - total_tried = 0 total_correct = 0 output_rows = [] - print(f"Generating completions") - + print("Generating completions") for batch in (pbar := tqdm(df_chunks, total=len(df))): prompts = batch["prompt"].tolist() print(f"Generating completions for {len(prompts)} prompts") + + # VLLM-specific token handling prompt_tokens = [] prompt_token_sizes = [] for prompt in prompts: token_ids = tokenizer.encode(prompt, add_special_tokens=False) - # add bos token if not already present in prompt if token_ids[0] != tokenizer.bos_token_id: token_ids = [tokenizer.bos_token_id] + token_ids prompt_tokens.append(token_ids) @@ -135,8 +141,8 @@ def chunk_dataframe(df, chunk_size): print( f"Average prompt size: {sum(prompt_token_sizes)/len(prompt_token_sizes):.0f}" ) + start_time = time.time() - # outputs = llm.generate(prompts, sampling_params) # if you prefer to use prompts instead of token_ids outputs = llm.generate( sampling_params=sampling_params, prompt_token_ids=prompt_tokens, @@ -147,6 +153,7 @@ def chunk_dataframe(df, chunk_size): f"Generated {len(outputs)} completions in {time.time() - start_time:.2f} seconds" ) time_taken = time.time() - start_time + for row, output in zip(batch.to_dict("records"), outputs): generated_query = ( output.outputs[0].text.split(";")[0].split("```")[0].strip() + ";" @@ -158,28 +165,33 @@ def chunk_dataframe(df, chunk_size): row["tokens_used"] = len(output.outputs[0].token_ids) row["latency_seconds"] = time_taken / len(batch) + # Verify results golden_query = row["query"] db_name = row["db_name"] db_type = row["db_type"] question = row["question"] query_category = row["query_category"] table_metadata_string = row["table_metadata_string"] - exact_match = correct = 0 - db_creds = db_creds_all[db_type] + try: exact_match, correct = compare_query_results( query_gold=golden_query, query_gen=generated_query, db_name=db_name, db_type=db_type, - db_creds=db_creds, + db_creds=db_creds_all[db_type], question=question, query_category=query_category, table_metadata_string=table_metadata_string, - decimal_points=args.decimal_points, + decimal_points=( + args.decimal_points + if hasattr(args, "decimal_points") + else 2 + ), ) row["exact_match"] = int(exact_match) row["correct"] = int(correct) + row["is_correct"] = int(correct) # For base runner compatibility row["error_msg"] = "" if correct: total_correct += 1 @@ -189,31 +201,47 @@ def chunk_dataframe(df, chunk_size): total_tried += 1 output_rows.append(row) + pbar.update(len(batch)) pbar.set_description( - f"Correct so far: {total_correct}/{(total_tried)} ({100*total_correct/(total_tried):.2f}%)" + f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)" ) + + # Process results df = pd.DataFrame(output_rows) - del df["prompt"] - print(df.groupby("query_category")[["exact_match", "correct"]].mean()) + if "prompt" in df.columns: + del df["prompt"] + + # Get stats by query category + agg_stats = df.groupby("query_category")[["exact_match", "correct"]].mean() + print(agg_stats) df = df.sort_values(by=["db_name", "query_category", "question"]) print(f"Average tokens generated: {df['tokens_used'].mean():.1f}") - # get directory of output_file and create if not exist + + # Create output directory if needed output_dir = os.path.dirname(output_file) - if not os.path.exists(output_dir): + if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir) + df.to_csv(output_file, index=False, float_format="%.2f") print(f"Saved results to {output_file}") - results = df.to_dict("records") - # upload results - with open(prompt_file, "r") as f: - prompt = f.read() - if args.upload_url is not None: - upload_results( - results=results, - url=args.upload_url, - runner_type="vllm_runner", - prompt=prompt, - args=args, - ) + # Print summary stats + print(f"Total questions: {total_tried}") + print(f"Total correct: {total_correct}") + print(f"Accuracy: {total_correct/total_tried:.3f}") + + # Upload results if URL provided + try: + if hasattr(args, "upload_url") and args.upload_url: + with open(prompt_file, "r") as f: + prompt = f.read() + upload_results( + results=df.to_dict("records"), + url=args.upload_url, + runner_type="vllm_runner", + prompt=prompt, + args=args, + ) + except Exception as e: + print(f"Error uploading results: {e}")