Skip to content

Commit

Permalink
added basic instruct questions
Browse files Browse the repository at this point in the history
upgraded openai's runner and standard prompt generator to bypass pruning when num_columns = 0
  • Loading branch information
wongjingping committed Mar 28, 2024
1 parent 313725e commit b0ab6e3
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 17 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -294,8 +294,8 @@ You can use the following flags in the command line to change the configurations
|--------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
| -f, --prompt_file | Markdown file with the prompt used for query generation. You can pass in a list of prompts to test sequentially without reloading the script. |
| -b, --num_beams | Indicates the number of beams you want to use for beam search at inference. Only available for `hf_runner`, `vllm_runner`, and `api_runner`. |
| -c, --num_columns | Number of columns. |
| -s, --shuffle_metadata | Shuffle metadata. |
| -c, --num_columns | Number of columns, default 20. To not prune the columns, set it to 0. |
| -s, --shuffle_metadata | Shuffle metadata, default False. This shuffles the order of the tables within the schema and the order of the columns within each table but does not shift columns between tables (to preserve the structure of the database). |
| -k, --k_shot | Used when you want to include k-shot examples in your prompt. Make sure that the column 'k_shot_prompt' exists in your questions_file. |

### Execution-related parameters
Expand Down
41 changes: 41 additions & 0 deletions data/instruct_basic_postgres.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
db_name,query_category,question,query
broker,basic_join_date_group_order_limit,Which are the top 3 countries by number of transactions on 1st April 2023?,"SELECT c.sbCustCountry, COUNT(*) AS num_transactions FROM sbTransaction t JOIN sbCustomer c ON t.sbTxCustId = c.sbCustId WHERE t.sbTxDateTime >= '2023-04-01' AND t.sbTxDateTime < '2023-04-02' GROUP BY c.sbCustCountry ORDER BY num_transactions DESC LIMIT 3;"
broker,basic_join_date_group_order_limit,What is the transaction type (buy or sell) with the highest total amount for stock tickers in the last 7 days?,"SELECT t.sbTxType, SUM(t.sbTxAmount) AS total_amount FROM sbTransaction t JOIN sbTicker tk ON t.sbTxTickerId = tk.sbTickerId WHERE tk.sbTickerType = 'stock' AND t.sbTxDateTime >= NOW() - INTERVAL '7 days' GROUP BY t.sbTxType ORDER BY total_amount DESC LIMIT 1;"
broker,basic_join_group_order_limit,What are the top 5 ticker symbols by total transaction shares?,"SELECT tk.sbTickerSymbol, SUM(t.sbTxShares) AS total_shares FROM sbTransaction t JOIN sbTicker tk ON t.sbTxTickerId = tk.sbTickerId GROUP BY tk.sbTickerSymbol ORDER BY total_shares DESC LIMIT 5;"
broker,basic_join_group_order_limit,What are the top 3 customer states by average transaction amount?,"SELECT c.sbCustState, AVG(t.sbTxAmount) AS avg_amount FROM sbTransaction t JOIN sbCustomer c ON t.sbTxCustId = c.sbCustId GROUP BY c.sbCustState ORDER BY avg_amount DESC LIMIT 3;"
broker,basic_join_distinct,What are the distinct customer IDs who have transacted in Apple (AAPL) stock?,SELECT DISTINCT t.sbTxCustId FROM sbTransaction t JOIN sbTicker tk ON t.sbTxTickerId = tk.sbTickerId WHERE tk.sbTickerSymbol = 'AAPL';
broker,basic_join_distinct,What are the distinct ticker IDs that have been transacted by customers from USA?,SELECT DISTINCT t.sbTxTickerId FROM sbTransaction t JOIN sbCustomer c ON t.sbTxCustId = c.sbCustId WHERE c.sbCustCountry = 'USA';
broker,basic_group_order_limit,What is the most frequent transaction type (buy or sell) by number of transactions?,"SELECT sbTxType, COUNT(*) AS num_transactions FROM sbTransaction GROUP BY sbTxType ORDER BY num_transactions DESC LIMIT 1;"
broker,basic_group_order_limit,What are the top 2 transaction statuses by total transaction amount?,"SELECT sbTxStatus, SUM(sbTxAmount) AS total_amount FROM sbTransaction GROUP BY sbTxStatus ORDER BY total_amount DESC LIMIT 2;"
broker,basic_left_join,What are the customer IDs and names who have not made any transactions?,"SELECT c.sbCustId, c.sbCustName FROM sbCustomer c LEFT JOIN sbTransaction t ON c.sbCustId = t.sbTxCustId WHERE t.sbTxCustId IS NULL;"
broker,basic_left_join,What are the ticker IDs and symbols that do not have any daily price data?,"SELECT tk.sbTickerId, tk.sbTickerSymbol FROM sbTicker tk LEFT JOIN sbDailyPrice dp ON tk.sbTickerId = dp.sbDpTickerId WHERE dp.sbDpTickerId IS NULL;"
car_dealership,basic_join_date_group_order_limit,"What are the top 3 car models by total revenue sold in March 2023, showing the make, model, number of sales, and total revenue for each?","SELECT c.make, c.model, COUNT(*) AS total_sales, SUM(s.sale_price) AS revenue FROM sales s JOIN cars c ON s.car_id = c.id WHERE s.sale_date BETWEEN '2023-03-01' AND '2023-03-31' GROUP BY c.make, c.model ORDER BY revenue DESC LIMIT 3;"
car_dealership,basic_join_date_group_order_limit,"Who are the top 5 salespersons by number of sales in the past 30 days, and what is the total revenue each of them generated?","SELECT sp.first_name, sp.last_name, COUNT(*) AS total_sales, SUM(s.sale_price) AS revenue FROM sales s JOIN salespersons sp ON s.salesperson_id = sp.id WHERE s.sale_date >= NOW() - INTERVAL '30 days' GROUP BY sp.first_name, sp.last_name ORDER BY total_sales DESC LIMIT 5;"
car_dealership,basic_join_group_order_limit,"What are the top 5 states by total revenue, and how many unique customers made purchases from each state?","SELECT c.state, COUNT(DISTINCT s.customer_id) AS unique_customers, SUM(s.sale_price) AS revenue FROM sales s JOIN customers c ON s.customer_id = c.id GROUP BY c.state ORDER BY revenue DESC LIMIT 5;"
car_dealership,basic_join_group_order_limit,"What are the 10 car models with the highest average sale price, showing the make, model, year and average price for each?","SELECT c.make, c.model, c.year, AVG(s.sale_price) AS avg_price FROM sales s JOIN cars c ON s.car_id = c.id GROUP BY c.make, c.model, c.year ORDER BY avg_price DESC LIMIT 10;"
car_dealership,basic_join_distinct,"Provide the distinct customer IDs, first names and last names of all customers who have made a purchase.","SELECT DISTINCT c.id AS customer_id, c.first_name, c.last_name FROM customers c JOIN sales s ON c.id = s.customer_id;"
car_dealership,basic_join_distinct,"List the distinct IDs, first names and last names of all salespersons who have made a sale.","SELECT DISTINCT s.id AS salesperson_id, s.first_name, s.last_name FROM salespersons s JOIN sales sa ON s.id = sa.salesperson_id;"
car_dealership,basic_group_order_limit,What are the top 3 most frequently used payment methods for payments received?,"SELECT payment_method, COUNT(*) AS COUNT FROM payments_received GROUP BY payment_method ORDER BY COUNT DESC LIMIT 3;"
car_dealership,basic_group_order_limit,What are the 5 most expensive car engine types on average?,"SELECT engine_type, AVG(price) AS avg_price FROM cars GROUP BY engine_type ORDER BY avg_price DESC LIMIT 5;"
car_dealership,basic_left_join,"Which cars (IDs, makes, models, years) have not been sold yet?","SELECT c.id AS car_id, c.make, c.model, c.year FROM cars c LEFT JOIN sales s ON c.id = s.car_id WHERE s.car_id IS NULL;"
car_dealership,basic_left_join,"Which salespersons (IDs, first names, last names) have not made any sales yet?","SELECT sp.id AS salesperson_id, sp.first_name, sp.last_name FROM salespersons sp LEFT JOIN sales s ON sp.id = s.salesperson_id WHERE s.salesperson_id IS NULL;"
derm_treatment,basic_join_date_group_order_limit,"What are the top 3 doctor specialties by number of treatments in 2022, and what is the average PASI score at day 100 for each specialty?","SELECT d.specialty, COUNT(*) AS num_treatments, AVG(o.day100_pasi_score) AS avg_pasi_score FROM treatments t JOIN doctors d ON t.doc_id = d.doc_id JOIN outcomes o ON t.treatment_id = o.treatment_id WHERE t.start_dt BETWEEN '2022-01-01' AND '2022-12-31' GROUP BY d.specialty ORDER BY num_treatments DESC LIMIT 3;"
derm_treatment,basic_join_date_group_order_limit,"For treatments started in the past 6 months, which patient gender had the most number of distinct patients treated, and what was the average itch VAS score at day 30 for that gender?","SELECT p.gender, COUNT(DISTINCT t.patient_id) AS num_patients, AVG(o.day30_itch_vas) AS avg_itch_score FROM treatments t JOIN patients p ON t.patient_id = p.patient_id JOIN outcomes o ON t.treatment_id = o.treatment_id WHERE t.start_dt >= DATE_TRUNC('month', CURRENT_DATE - INTERVAL '6 months') GROUP BY p.gender ORDER BY num_patients DESC LIMIT 1;"
derm_treatment,basic_join_group_order_limit,"What are the top 3 drugs by number of treatments, and what is the average lesion count at day 100 for each drug?","SELECT dr.drug_name, COUNT(*) AS num_treatments, AVG(o.day100_lesion_cnt) AS avg_lesion_count FROM treatments t JOIN drugs dr ON t.drug_id = dr.drug_id JOIN outcomes o ON t.treatment_id = o.treatment_id GROUP BY dr.drug_name ORDER BY num_treatments DESC LIMIT 3;"
derm_treatment,basic_join_group_order_limit,"What are the top 5 diagnoses by average TEWL at day 30, and how many distinct patients were treated for each diagnosis?","SELECT di.diag_name, COUNT(DISTINCT t.patient_id) AS num_patients, AVG(o.day30_tewl) AS avg_tewl FROM treatments t JOIN diagnoses di ON t.diag_id = di.diag_id JOIN outcomes o ON t.treatment_id = o.treatment_id GROUP BY di.diag_name ORDER BY avg_tewl DESC LIMIT 5;"
derm_treatment,basic_join_distinct,"Provide the distinct patient IDs, first names and last names of patients who had any adverse events during their treatments.","SELECT DISTINCT p.patient_id, p.first_name, p.last_name FROM patients p JOIN treatments t ON p.patient_id = t.patient_id JOIN adverse_events ae ON t.treatment_id = ae.treatment_id;"
derm_treatment,basic_join_distinct,"List the distinct doctor IDs, names and specialties who prescribed biologic drugs to their patients.","SELECT DISTINCT d.doc_id, d.first_name, d.last_name, d.specialty FROM doctors d JOIN treatments t ON d.doc_id = t.doc_id JOIN drugs dr ON t.drug_id = dr.drug_id WHERE dr.drug_type = 'biologic';"
derm_treatment,basic_group_order_limit,Which insurance type covers the highest number of patients?,"SELECT ins_type, COUNT(*) AS num_patients FROM patients GROUP BY ins_type ORDER BY num_patients DESC LIMIT 1;"
derm_treatment,basic_group_order_limit,What is the drug type with the least number of drugs?,"SELECT drug_type, COUNT(*) AS num_drugs FROM drugs GROUP BY drug_type ORDER BY num_drugs LIMIT 1;"
derm_treatment,basic_left_join,Find the patient IDs and names of patients who have not received any treatments.,"SELECT p.patient_id, p.first_name, p.last_name FROM patients p LEFT JOIN treatments t ON p.patient_id = t.patient_id WHERE t.patient_id IS NULL;"
derm_treatment,basic_left_join,List the drug IDs and names that have not been used in any treatments.,"SELECT dr.drug_id, dr.drug_name FROM drugs dr LEFT JOIN treatments t ON dr.drug_id = t.drug_id WHERE t.drug_id IS NULL;"
ewallet,basic_join_date_group_order_limit,"What are the top 3 merchants by total transaction amount between June 1st and June 7th 2023? Return the merchant name, total number of transactions, and total transaction amount.","SELECT m.name AS merchant_name, COUNT(*) AS total_transactions, SUM(t.amount) AS total_amount FROM consumer_div.wallet_transactions_daily t JOIN consumer_div.merchants m ON t.receiver_id = m.mid WHERE t.receiver_type = 1 AND t.created_at BETWEEN '2023-06-01' AND '2023-06-07' GROUP BY m.name ORDER BY total_amount DESC LIMIT 3;"
ewallet,basic_join_date_group_order_limit,"How many distinct users made transactions each month for the past 6 months? Return the month and total number of users, ordered by most recent month first.","SELECT DATE_TRUNC('month', t.created_at) AS MONTH, COUNT(DISTINCT t.sender_id) AS total_users FROM consumer_div.wallet_transactions_daily t WHERE t.sender_type = 0 AND t.created_at >= NOW() - INTERVAL '6 months' GROUP BY MONTH ORDER BY MONTH DESC LIMIT 6;"
ewallet,basic_join_group_order_limit,"What are the top 5 most used coupon codes? Return the coupon code, total number of uses, and total discount amount.","SELECT c.code AS coupon_code, COUNT(*) AS total_uses, SUM(t.amount) AS total_discount FROM consumer_div.wallet_transactions_daily t JOIN consumer_div.coupons c ON t.coupon_id = c.cid WHERE t.type = 'credit' GROUP BY c.code ORDER BY total_uses DESC LIMIT 5;"
ewallet,basic_join_group_order_limit,"What are the top 10 countries by total wallet balance for active users? Return the country, total number of users, and total wallet balance.","SELECT u.country, COUNT(*) AS total_users, SUM(b.balance) AS total_balance FROM consumer_div.users u JOIN consumer_div.wallet_user_balance_daily b ON u.uid = b.user_id WHERE u.status = 'active' GROUP BY u.country ORDER BY total_balance DESC LIMIT 10;"
ewallet,basic_join_distinct,Get the distinct user IDs of all active users who have made a wallet transaction.,SELECT DISTINCT t.sender_id AS user_id FROM consumer_div.wallet_transactions_daily t JOIN consumer_div.users u ON t.sender_id = u.uid WHERE t.sender_type = 0 AND u.status = 'active';
ewallet,basic_join_distinct,Get the distinct user IDs of all business users who have received a notification.,SELECT DISTINCT n.user_id FROM consumer_div.notifications n JOIN consumer_div.users u ON n.user_id = u.uid WHERE u.user_type = 'business';
ewallet,basic_group_order_limit,What are the top 2 device types by total number of user sessions? Return the device type and total session count.,"SELECT device_type, COUNT(*) AS total_sessions FROM consumer_div.user_sessions GROUP BY device_type ORDER BY total_sessions DESC LIMIT 2;"
ewallet,basic_group_order_limit,What are the top 3 user statuses by total number of users? Return the status and total user count.,"SELECT status, COUNT(*) AS total_users FROM consumer_div.users GROUP BY status ORDER BY total_users DESC LIMIT 3;"
ewallet,basic_left_join,Get the user ID and username of all users who do not have a wallet balance.,"SELECT u.uid AS user_id, u.username FROM consumer_div.users u LEFT JOIN consumer_div.wallet_user_balance_daily b ON u.uid = b.user_id WHERE b.user_id IS NULL;"
ewallet,basic_left_join,Get the merchant ID and name of all merchants who have not issued any coupons.,"SELECT m.mid AS merchant_id, m.name AS merchant_name FROM consumer_div.merchants m LEFT JOIN consumer_div.coupons c ON m.mid = c.merchant_id WHERE c.merchant_id IS NULL;"
14 changes: 13 additions & 1 deletion eval/openai_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def run_openai_eval(args):
query_gen = result_dict["query"]
reason = result_dict["reason"]
err = result_dict["err"]
table_metadata_string = result_dict["table_metadata_string"]
# save custom metrics
if "latency_seconds" in result_dict:
row["latency_seconds"] = result_dict["latency_seconds"]
Expand All @@ -75,6 +76,7 @@ def run_openai_eval(args):
row["generated_query"] = query_gen
row["reason"] = reason
row["error_msg"] = err
row["table_metadata_string"] = table_metadata_string
# save failures into relevant columns in the dataframe
if "GENERATION ERROR" in err:
row["error_query_gen"] = 1
Expand Down Expand Up @@ -123,7 +125,17 @@ def run_openai_eval(args):
output_df = output_df.sort_values(by=["db_name", "query_category", "question"])
if "prompt" in output_df.columns:
del output_df["prompt"]
print(output_df.groupby("query_category")[["correct", "error_db_exec"]].mean())
# get num rows, mean correct, mean error_db_exec for each query_category
agg_stats = (
output_df.groupby("query_category")
.agg(
num_rows=("db_name", "count"),
mean_correct=("correct", "mean"),
mean_error_db_exec=("error_db_exec", "mean"),
)
.reset_index()
)
print(agg_stats)
# get directory of output_file and create if not exist
output_dir = os.path.dirname(output_file)
if not os.path.exists(output_dir):
Expand Down
33 changes: 21 additions & 12 deletions query_generators/openai.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import time
from typing import Dict, List

from defog_data.metadata import dbs
from func_timeout import FunctionTimedOut, func_timeout
from openai import OpenAI
import tiktoken

from query_generators.query_generator import QueryGenerator
from utils.pruning import prune_metadata_str
from utils.pruning import prune_metadata_str, to_prompt_schema

openai = OpenAI()

Expand Down Expand Up @@ -130,15 +132,21 @@ def generate_query(
chat_prompt = file.read()
question_instructions = question + " " + instructions
if table_metadata_string == "":
pruned_metadata_str = prune_metadata_str(
question_instructions,
self.db_name,
self.use_public_data,
columns_to_keep,
shuffle,
)
else:
pruned_metadata_str = table_metadata_string
if columns_to_keep > 0:
table_metadata_string = prune_metadata_str(
question_instructions,
self.db_name,
self.use_public_data,
columns_to_keep,
shuffle,
)
elif columns_to_keep == 0:
md = dbs[self.db_name]["table_metadata"]
table_metadata_string = to_prompt_schema(md, shuffle)
else:
raise ValueError("columns_to_keep must be >= 0")
if glossary == "":
glossary = dbs[self.db_name]["glossary"]
if self.model != "text-davinci-003":
try:
sys_prompt = chat_prompt.split("### Input:")[0]
Expand All @@ -150,7 +158,7 @@ def generate_query(
raise ValueError("Invalid prompt file. Please use prompt_openai.md")
user_prompt = user_prompt.format(
user_question=question,
table_metadata_string=pruned_metadata_str,
table_metadata_string=table_metadata_string,
instructions=instructions,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
Expand All @@ -165,7 +173,7 @@ def generate_query(
else:
prompt = chat_prompt.format(
user_question=question,
table_metadata_string=pruned_metadata_str,
table_metadata_string=table_metadata_string,
instructions=instructions,
k_shot_prompt=k_shot_prompt,
glossary=glossary,
Expand Down Expand Up @@ -211,6 +219,7 @@ def generate_query(
tokens_used = self.count_tokens(self.model, messages=messages)

return {
"table_metadata_string": table_metadata_string,
"query": self.query,
"reason": self.reason,
"err": self.err,
Expand Down
Loading

0 comments on commit b0ab6e3

Please sign in to comment.