Agent farm submission: defog-ai_sql-eval_230 #232

Closed
wants to merge 2 commits into from
130 changes: 44 additions & 86 deletions runners/anthropic_runner.py
@@ -128,7 +128,7 @@ def process_row(row, model_name, args):


def run_anthropic_eval(args):
# get params from args
"""Run evaluation using Anthropic"""
questions_file_list = args.questions_file
prompt_file_list = args.prompt_file
output_file_list = args.output_file
@@ -145,97 +145,55 @@ def run_anthropic_eval(args):
print(
f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
)
question_query_df = prepare_questions_df(
df = prepare_questions_df(
questions_file, db_type, num_questions, k_shot, cot_table_alias
)
input_rows = question_query_df.to_dict("records")
output_rows = []
with ThreadPoolExecutor(args.parallel_threads) as executor:
futures = []
for row in input_rows:
generated_query_fut = executor.submit(
process_row,
row=row,
model_name=args.model,
args=args,
)
futures.append(generated_query_fut)

total_tried = 0
total_correct = 0
for f in (pbar := tqdm(as_completed(futures), total=len(futures))):
total_tried += 1
i = futures.index(f)
row = input_rows[i]
result_dict = f.result()
query_gen = result_dict["query"]
reason = result_dict["reason"]
err = result_dict["err"]
# save custom metrics
if "latency_seconds" in result_dict:
row["latency_seconds"] = result_dict["latency_seconds"]
if "tokens_used" in result_dict:
row["tokens_used"] = result_dict["tokens_used"]
row["generated_query"] = query_gen
row["reason"] = reason
row["error_msg"] = err
# save failures into relevant columns in the dataframe
if "GENERATION ERROR" in err:
row["error_query_gen"] = 1
elif "TIMEOUT" in err:
row["timeout"] = 1
else:
expected_query = row["query"]
db_name = row["db_name"]
db_type = row["db_type"]
try:
is_correct = compare_query_results(
query_gold=expected_query,
query_gen=query_gen,
db_name=db_name,
db_type=db_type,
db_creds=db_creds_all[db_type],
question=row["question"],
query_category=row["query_category"],
decimal_points=args.decimal_points,
)
if is_correct:
total_correct += 1
row["is_correct"] = 1
row["error_msg"] = ""
else:
row["is_correct"] = 0
row["error_msg"] = "INCORRECT RESULTS"
except Exception as e:
row["error_db_exec"] = 1
row["error_msg"] = f"EXECUTION ERROR: {str(e)}"
output_rows.append(row)
pbar.set_description(
f"Accuracy: {round(total_correct/total_tried * 100, 2)}% ({total_correct}/{total_tried})"
)

# save results to csv
output_rows, total_correct, total_tried = run_eval_in_threadpool(
df, args.model, process_row, args
)

# Convert to DataFrame and save results
output_df = pd.DataFrame(output_rows)
output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
# get directory of output_file and create if not exist
if "prompt" in output_df.columns:
del output_df["prompt"]

# Get stats by query category
agg_stats = (
output_df.groupby("query_category")
.agg(
num_rows=("db_name", "count"),
mean_correct=("is_correct", "mean"),
mean_error_db_exec=("error_db_exec", "mean"),
)
.reset_index()
)
print(agg_stats)

# Create output directory if needed
output_dir = os.path.dirname(output_file)
if not os.path.exists(output_dir):
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

output_df.to_csv(output_file, index=False, float_format="%.2f")

# get average rate of correct results
avg_subset = output_df["is_correct"].sum() / len(output_df)
print(f"Average correct rate: {avg_subset:.2f}")

results = output_df.to_dict("records")
# upload results
with open(prompt_file, "r") as f:
prompt = f.read()
if args.upload_url is not None:
upload_results(
results=results,
url=args.upload_url,
runner_type="anthropic",
prompt=prompt,
args=args,
)
# Print summary stats
print(f"Total questions: {total_tried}")
print(f"Total correct: {total_correct}")
print(f"Accuracy: {total_correct/total_tried:.3f}")

# Upload results if URL provided
try:
if hasattr(args, "upload_url") and args.upload_url:
with open(prompt_file, "r") as f:
prompt = f.read()
upload_results(
results=output_df.to_dict("records"),
url=args.upload_url,
runner_type="anthropic",
prompt=prompt,
args=args,
)
except Exception as e:
print(f"Error uploading results: {e}")
144 changes: 64 additions & 80 deletions runners/api_runner.py
@@ -206,50 +206,46 @@ def process_row(


def run_api_eval(args):
# get params from args
"""Run evaluation using API"""
questions_file_list = args.questions_file
prompt_file_list = args.prompt_file
output_file_list = args.output_file
num_questions = args.num_questions
public_data = not args.use_private_data
api_url = args.api_url
api_type = args.api_type
output_file_list = args.output_file
k_shot = args.k_shot
num_beams = args.num_beams
max_workers = args.parallel_threads
cot_table_alias = args.cot_table_alias
db_type = args.db_type
decimal_points = args.decimal_points
logprobs = args.logprobs
cot_table_alias = args.cot_table_alias
sql_lora_path = args.adapter if args.adapter else None
sql_lora_name = args.adapter_name if args.adapter_name else None
run_name = args.run_name if args.run_name else None
run_name = getattr(args, "run_name", None)
sql_lora_path = getattr(args, "adapter", None)

if sql_lora_path:
print("Using LoRA adapter at:", sql_lora_path)

# Logprobs visualization directory handling
if logprobs:
# check that the eval-visualizer/public directory exists
if not os.path.exists("./eval-visualizer"):
# throw error
raise Exception(
"The eval-visualizer directory does not exist. Please clone it with `git clone https://github.com/defog-ai/eval-visualizer/` before running sql-eval with the --logprobs flag."
"The eval-visualizer directory does not exist. Please clone it with "
"`git clone https://github.com/defog-ai/eval-visualizer/` before running "
"sql-eval with the --logprobs flag."
)

if not os.path.exists("./eval-visualizer/public"):
os.makedirs("./eval-visualizer/public")

for questions_file, prompt_file, output_file in zip(
questions_file_list, prompt_file_list, output_file_list
):
print(f"Using prompt file {prompt_file}")
# get questions
print("Preparing questions...")
print(
f"Using {'all' if num_questions is None else num_questions} question(s) from {questions_file}"
)
df = prepare_questions_df(
questions_file, db_type, num_questions, k_shot, cot_table_alias
)
# create a prompt for each question

# Create prompts with all parameters
df["prompt"] = df.apply(
lambda row: generate_prompt(
prompt_file,
@@ -262,65 +258,30 @@ def run_api_eval(args):
row["table_metadata_string"],
row["prev_invalid_sql"],
row["prev_error_msg"],
row["question_0"],
row["query_0"],
row["question_1"],
row["query_1"],
row["cot_instructions"],
row["cot_pregen"],
row.get("question_0", ""),
row.get("query_0", ""),
row.get("question_1", ""),
row.get("query_1", ""),
row.get("cot_instructions", ""),
row.get("cot_pregen", False),
public_data,
args.num_columns,
args.num_columns if hasattr(args, "num_columns") else 40,
args.shuffle_metadata,
row["table_aliases"],
row.get("table_aliases", ""),
),
axis=1,
)

total_tried = 0
total_correct = 0
output_rows = []

with ThreadPoolExecutor(max_workers=max_workers) as executor:
futures = []
for row in df.to_dict("records"):
futures.append(
executor.submit(
process_row,
row,
api_url,
api_type,
num_beams,
decimal_points,
logprobs,
sql_lora_path,
sql_lora_name,
)
)

with tqdm(as_completed(futures), total=len(futures)) as pbar:
for f in pbar:
row = f.result()
output_rows.append(row)
if row["correct"]:
total_correct += 1
total_tried += 1
pbar.update(1)
pbar.set_description(
f"Correct so far: {total_correct}/{total_tried} ({100*total_correct/total_tried:.2f}%)"
)
output_rows, total_correct, total_tried = run_eval_in_threadpool(
df, args.api_url, process_row, args
)

output_df = pd.DataFrame(output_rows)

print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
# get directory of output_file and create if not exist
output_dir = os.path.dirname(output_file)
if not os.path.exists(output_dir):
os.makedirs(output_dir)

results = output_df.to_dict("records")

# Handle logprobs visualization
if logprobs:
results = output_df.to_dict("records")
print(
f"Writing logprobs to JSON file at eval-visualizer/public/{output_file.split('/')[-1].replace('.csv', '.json')}"
)
@@ -330,27 +291,50 @@
) as f:
json.dump(results, f)

del output_df["prompt"]
# Get stats by query category
agg_stats = (
output_df.groupby("query_category")
.agg(
num_rows=("db_name", "count"),
mean_correct=("correct", "mean"),
mean_error_db_exec=("error_db_exec", "mean"),
)
.reset_index()
)
print(agg_stats)

# Clean up and save results
if "prompt" in output_df.columns:
del output_df["prompt"]

output_dir = os.path.dirname(output_file)
if output_dir and not os.path.exists(output_dir):
os.makedirs(output_dir)

try:
output_df.to_csv(output_file, index=False, float_format="%.2f")
except:
output_df.to_pickle(output_file)

# upload results
# with open(prompt_file, "r") as f:
# prompt = f.read()

if args.run_name is None:
# Handle run naming and result upload
if run_name is None:
run_name = output_file.split("/")[-1].replace(".csv", "")
print(
"Run name not provided. Using a output filename for run name:", run_name
"Run name not provided. Using output filename for run name:", run_name
)

if args.upload_url is not None:
upload_results(
results=results,
url=args.upload_url,
runner_type="api_runner",
args=args,
run_name=run_name,
)
print(f"Total questions: {total_tried}")
print(f"Total correct: {total_correct}")
print(f"Accuracy: {total_correct/total_tried:.3f}")

try:
if hasattr(args, "upload_url") and args.upload_url:
upload_results(
results=output_df.to_dict("records"),
url=args.upload_url,
runner_type="api_runner",
args=args,
run_name=run_name,
)
except Exception as e:
print(f"Error uploading results: {e}")