diff --git a/eval/anthropic_runner.py b/eval/anthropic_runner.py
index 0c961b1..18491ee 100644
--- a/eval/anthropic_runner.py
+++ b/eval/anthropic_runner.py
@@ -44,6 +44,8 @@ def run_anthropic_eval(args):
                 verbose=args.verbose,
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )
 
             generated_query_fut = executor.submit(
@@ -51,6 +53,8 @@ def run_anthropic_eval(args):
                 question=row["question"],
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )
             futures.append(generated_query_fut)
diff --git a/eval/api_runner.py b/eval/api_runner.py
index 69ec060..c9d13fb 100644
--- a/eval/api_runner.py
+++ b/eval/api_runner.py
@@ -15,20 +15,32 @@
 def generate_prompt(
-    prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True
+    prompt_file,
+    question,
+    db_name,
+    instructions="",
+    k_shot_prompt="",
+    glossary="",
+    table_metadata_string="",
+    public_data=True,
 ):
     with open(prompt_file, "r") as f:
         prompt = f.read()
     question_instructions = question + " " + instructions
-    pruned_metadata_str = prune_metadata_str(
-        question_instructions, db_name, public_data
-    )
+    if table_metadata_string == "":
+        pruned_metadata_str = prune_metadata_str(
+            question_instructions, db_name, public_data
+        )
+    else:
+        pruned_metadata_str = table_metadata_string
+
     prompt = prompt.format(
         user_question=question,
         instructions=instructions,
         table_metadata_string=pruned_metadata_str,
         k_shot_prompt=k_shot_prompt,
+        glossary=glossary,
     )
     return prompt
@@ -104,7 +116,14 @@ def run_api_eval(args):
     for prompt_file, output_file in zip(prompt_file_list, output_file_list):
         # create a prompt for each question
         df["prompt"] = df[
-            ["question", "db_name", "instructions", "k_shot_prompt"]
+            [
+                "question",
+                "db_name",
+                "instructions",
+                "k_shot_prompt",
+                "glossary",
+                "table_metadata_string",
+            ]
         ].apply(
             lambda row: generate_prompt(
                 prompt_file,
@@ -112,6 +131,8 @@
                 row["db_name"],
                 row["instructions"],
                 row["k_shot_prompt"],
+                row["glossary"],
+                row["table_metadata_string"],
                 public_data,
             ),
             axis=1,
diff --git a/eval/hf_runner.py b/eval/hf_runner.py
index a84d62d..eaecb6e 100644
--- a/eval/hf_runner.py
+++ b/eval/hf_runner.py
@@ -25,20 +25,32 @@
 def generate_prompt(
-    prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True
+    prompt_file,
+    question,
+    db_name,
+    instructions="",
+    k_shot_prompt="",
+    glossary="",
+    table_metadata_string="",
+    public_data=True,
 ):
     with open(prompt_file, "r") as f:
         prompt = f.read()
     question_instructions = question + " " + instructions
-    pruned_metadata_str = prune_metadata_str(
-        question_instructions, db_name, public_data
-    )
+    if table_metadata_string == "":
+        pruned_metadata_str = prune_metadata_str(
+            question_instructions, db_name, public_data
+        )
+    else:
+        pruned_metadata_str = table_metadata_string
+
     prompt = prompt.format(
         user_question=question,
         instructions=instructions,
         table_metadata_string=pruned_metadata_str,
         k_shot_prompt=k_shot_prompt,
+        glossary=glossary,
     )
     return prompt
@@ -143,7 +155,14 @@ def run_hf_eval(args):
     for prompt_file, output_file in zip(prompt_file_list, output_file_list):
         # create a prompt for each question
         df["prompt"] = df[
-            ["question", "db_name", "instructions", "k_shot_prompt"]
+            [
+                "question",
+                "db_name",
+                "instructions",
+                "k_shot_prompt",
+                "glossary",
+                "table_metadata_string",
+            ]
         ].apply(
             lambda row: generate_prompt(
                 prompt_file,
@@ -151,6 +170,8 @@ def run_hf_eval(args):
                 row["db_name"],
                 row["instructions"],
                 row["k_shot_prompt"],
+                row["glossary"],
+                row["table_metadata_string"],
                 public_data,
             ),
             axis=1,
diff --git a/eval/openai_runner.py b/eval/openai_runner.py
index 34341f0..90c3220 100644
--- a/eval/openai_runner.py
+++ b/eval/openai_runner.py
@@ -43,6 +43,8 @@ def run_openai_eval(args):
                 verbose=args.verbose,
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )
 
             generated_query_fut = executor.submit(
@@ -50,6 +52,8 @@ def run_openai_eval(args):
                 question=row["question"],
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )
             futures.append(generated_query_fut)
diff --git a/eval/vllm_runner.py b/eval/vllm_runner.py
index a141fc8..740f77d 100644
--- a/eval/vllm_runner.py
+++ b/eval/vllm_runner.py
@@ -15,20 +15,32 @@
 def generate_prompt(
-    prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True
+    prompt_file,
+    question,
+    db_name,
+    instructions="",
+    k_shot_prompt="",
+    glossary="",
+    table_metadata_string="",
+    public_data=True,
 ):
     with open(prompt_file, "r") as f:
         prompt = f.read()
     question_instructions = question + " " + instructions
-    pruned_metadata_str = prune_metadata_str(
-        question_instructions, db_name, public_data
-    )
+    if table_metadata_string == "":
+        pruned_metadata_str = prune_metadata_str(
+            question_instructions, db_name, public_data
+        )
+    else:
+        pruned_metadata_str = table_metadata_string
+
     prompt = prompt.format(
         user_question=question,
         instructions=instructions,
         table_metadata_string=pruned_metadata_str,
         k_shot_prompt=k_shot_prompt,
+        glossary=glossary,
     )
     return prompt
@@ -76,7 +88,14 @@ def run_vllm_eval(args):
         print(f"Using prompt file {prompt_file}")
         # create a prompt for each question
         df["prompt"] = df[
-            ["question", "db_name", "instructions", "k_shot_prompt"]
+            [
+                "question",
+                "db_name",
+                "instructions",
+                "k_shot_prompt",
+                "glossary",
+                "table_metadata_string",
+            ]
         ].apply(
             lambda row: generate_prompt(
                 prompt_file,
@@ -84,6 +103,8 @@ def run_vllm_eval(args):
                 row["db_name"],
                 row["instructions"],
                 row["k_shot_prompt"],
+                row["glossary"],
+                row["table_metadata_string"],
                 public_data,
             ),
             axis=1,
diff --git a/prompts/README.md b/prompts/README.md
index 28680bb..ec3843a 100644
--- a/prompts/README.md
+++ b/prompts/README.md
@@ -4,6 +4,7 @@ You can define your prompt template by using the following variables:
 - `table_metadata_string`: The metadata of the table that we want to query. This is a string that contains the table names, column names and column types. This allows the model to know which columns/tables are available for getting information from. For the sqlcoder model that we released, you would need to represent your table metadata as a [SQL DDL](https://en.wikipedia.org/wiki/Data_definition_language) statement.
 - `instructions`: This is an optional field that allows you to customize specific instructions for each question, if needed. For example, if you want to ask the model to format your dates a particular way, define keywords, or adapt the SQL to a different database, you can do so here. If you don't need to customize the instructions, you can omit this in your prompt.
 - `k_shot_prompt`: This is another optional field that allows you to provide example SQL queries and their corresponding questions. These examples serve as a context for the model, helping it understand the type of SQL query you're expecting for a given question. Including a few examples in the k_shot_prompt field can significantly improve the model's accuracy in generating relevant SQL queries, especially for complex or less straightforward questions.
+- `glossary`: This is an optional field that allows you to define special terminology or rules for creating the SQL queries.
 
 Here is how a sample might look like with the above variables:
 ```markdown
@@ -11,6 +12,7 @@ Here is how a sample might look like with the above variables:
 Generate a SQL query to answer the following question:
 `{user_question}`
 `{instructions}`
+`{glossary}`
 ### Database Schema
 The query will run on a database with the following schema:
 {table_metadata_string}
diff --git a/prompts/prompt.md b/prompts/prompt.md
index c7f1285..d2816f3 100644
--- a/prompts/prompt.md
+++ b/prompts/prompt.md
@@ -1,7 +1,7 @@
 ### Task
 Generate a SQL query to answer the following question:
 `{user_question}`
-{instructions}
+{instructions}{glossary}
 ### Database Schema
 The query will run on a database with the following schema:
 {table_metadata_string}
diff --git a/prompts/prompt_anthropic.md b/prompts/prompt_anthropic.md
index a29db64..dceb94b 100644
--- a/prompts/prompt_anthropic.md
+++ b/prompts/prompt_anthropic.md
@@ -3,7 +3,7 @@
 Human: Your task is to convert a question into a SQL query, given a Postgres database schema.
 
 Generate a SQL query that answers the question `{user_question}`.
-{instructions}
+{instructions}{glossary}
 This query will run on a database whose schema is represented in this string:
 {table_metadata_string}
diff --git a/prompts/prompt_openai.md b/prompts/prompt_openai.md
index ac22fa6..42ff16a 100644
--- a/prompts/prompt_openai.md
+++ b/prompts/prompt_openai.md
@@ -3,7 +3,7 @@ Your task is to convert a text question to a SQL query that runs on Postgres, gi
 ### Input:
 Generate a SQL query that answers the question `{user_question}`.
-{instructions}
+{instructions}{glossary}
 This query will run on a database whose schema is represented in this string:
 {table_metadata_string}
 {k_shot_prompt}
diff --git a/query_generators/anthropic.py b/query_generators/anthropic.py
index 5966891..48fc2ee 100644
--- a/query_generators/anthropic.py
+++ b/query_generators/anthropic.py
@@ -73,7 +73,12 @@ def count_tokens(prompt: str = "") -> int:
         return num_tokens
 
     def generate_query(
-        self, question: str, instructions: str, k_shot_prompt: str
+        self,
+        question: str,
+        instructions: str,
+        k_shot_prompt: str,
+        glossary: str,
+        table_metadata_string: str,
     ) -> dict:
         start_time = time.time()
         self.err = ""
@@ -86,13 +91,18 @@ def generate_query(
         if "Human:" not in model_prompt:
             raise ValueError("Invalid prompt file. Please use prompt_anthropic.md")
         question_instructions = question + " " + instructions
+        if table_metadata_string == "":
+            pruned_metadata_str = prune_metadata_str(
+                question_instructions, self.db_name, self.use_public_data
+            )
+        else:
+            pruned_metadata_str = table_metadata_string
         prompt = model_prompt.format(
             user_question=question,
-            table_metadata_string=prune_metadata_str(
-                question_instructions, self.db_name, self.use_public_data
-            ),
+            table_metadata_string=pruned_metadata_str,
             instructions=instructions,
             k_shot_prompt=k_shot_prompt,
+            glossary=glossary,
         )
         function_to_run = self.get_completion
         package = prompt
diff --git a/query_generators/openai.py b/query_generators/openai.py
index e9d5415..fedfeac 100644
--- a/query_generators/openai.py
+++ b/query_generators/openai.py
@@ -110,7 +110,12 @@ def count_tokens(
         return num_tokens
 
    def generate_query(
-        self, question: str, instructions: str, k_shot_prompt: str
+        self,
+        question: str,
+        instructions: str,
+        k_shot_prompt: str,
+        glossary: str,
+        table_metadata_string: str,
     ) -> dict:
         start_time = time.time()
         self.err = ""
@@ -120,6 +125,12 @@ def generate_query(
         with open(self.prompt_file) as file:
             chat_prompt = file.read()
         question_instructions = question + " " + instructions
+        if table_metadata_string == "":
+            pruned_metadata_str = prune_metadata_str(
+                question_instructions, self.db_name, self.use_public_data
+            )
+        else:
+            pruned_metadata_str = table_metadata_string
         if self.model != "text-davinci-003":
             try:
                 sys_prompt = chat_prompt.split("### Input:")[0]
@@ -129,14 +140,12 @@ def generate_query(
                 assistant_prompt = chat_prompt.split("### Response:")[1]
             except:
                 raise ValueError("Invalid prompt file. Please use prompt_openai.md")
-
             user_prompt = user_prompt.format(
                 user_question=question,
-                table_metadata_string=prune_metadata_str(
-                    question_instructions, self.db_name, self.use_public_data
-                ),
+                table_metadata_string=pruned_metadata_str,
                 instructions=instructions,
                 k_shot_prompt=k_shot_prompt,
+                glossary=glossary,
             )
 
             messages = []
@@ -146,11 +155,10 @@ def generate_query(
         else:
             prompt = chat_prompt.format(
                 user_question=question,
-                table_metadata_string=prune_metadata_str(
-                    question_instructions, self.db_name, self.use_public_data
-                ),
+                table_metadata_string=pruned_metadata_str,
                 instructions=instructions,
                 k_shot_prompt=k_shot_prompt,
+                glossary=glossary,
             )
         function_to_run = None
         package = None
diff --git a/query_generators/query_generator.py b/query_generators/query_generator.py
index f3aef93..a40e6be 100644
--- a/query_generators/query_generator.py
+++ b/query_generators/query_generator.py
@@ -16,7 +16,12 @@ def __init__(self, **kwargs):
         pass
 
     def generate_query(
-        self, question: str, instructions: str, k_shot_prompt: str
+        self,
+        question: str,
+        instructions: str,
+        k_shot_prompt: str,
+        glossary: str,
+        table_metadata_string: str,
     ) -> dict:
         # generate a query given a question, instructions and k-shot prompt
         # any hard-coded logic, prompt-engineering, table-pruning, api calls etc
diff --git a/utils/questions.py b/utils/questions.py
index 8ef39b0..49eb440 100644
--- a/utils/questions.py
+++ b/utils/questions.py
@@ -55,5 +55,23 @@ def prepare_questions_df(
             lambda x: f"\nAdhere closely to the following correct examples as references for answering the question:\n{x}"
         )
 
+    # get glossary if applicable
+    if "glossary" in question_query_df.columns:
+        question_query_df["glossary"] = question_query_df["glossary"].fillna("")
+        question_query_df["glossary"] = question_query_df["glossary"].apply(
+            lambda x: f"\nUse the following instructions if and only if they are relevant to the question:\n{x}\n"
+        )
+    else:
+        question_query_df["glossary"] = ""
+
     question_query_df.reset_index(inplace=True, drop=True)
+
+    # get table_metadata_string if applicable
+    if "table_metadata_string" in question_query_df.columns:
+        question_query_df["table_metadata_string"] = question_query_df[
+            "table_metadata_string"
+        ].fillna("")
+    else:
+        question_query_df["table_metadata_string"] = ""
+
     return question_query_df
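
Reviewer note: a quick way to sanity-check the template change is to render one of the updated prompts by hand. The sketch below abbreviates `prompts/prompt.md` to the lines touched in this diff, and the sample question and schema are made up; because `prepare_questions_df` fills missing `glossary` and `table_metadata_string` values with empty strings, question files that do not use the new columns render exactly as before.

```python
# Abbreviated copy of prompts/prompt.md as it reads after this change
# (trailing sections of the real file are omitted here).
template = """### Task
Generate a SQL query to answer the following question:
`{user_question}`
{instructions}{glossary}
### Database Schema
The query will run on a database with the following schema:
{table_metadata_string}
"""

# With the new fields left empty (the default filled in by prepare_questions_df),
# the rendered prompt is identical to what the old template produced.
print(
    template.format(
        user_question="How many users signed up last week?",  # made-up example
        instructions="",
        glossary="",
        table_metadata_string="CREATE TABLE users (id int, created_at date);",
    )
)
```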
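A minimal sketch of the metadata fallback that every `generate_prompt`/`generate_query` now shares: a `table_metadata_string` supplied on the question row is used verbatim, while an empty string falls back to the existing pruning path. `resolve_table_metadata` is an invented name for illustration, and `prune_metadata_str` is stubbed so the snippet runs standalone (the real function imported by these runners selects the relevant DDL for the question from the database schema).

```python
def prune_metadata_str(question_instructions: str, db_name: str, public_data: bool = True) -> str:
    # Stub standing in for the repo's prune_metadata_str, used only so this sketch runs.
    return f"-- pruned DDL for db '{db_name}' relevant to: {question_instructions}"


def resolve_table_metadata(
    question: str,
    instructions: str,
    db_name: str,
    table_metadata_string: str = "",
    public_data: bool = True,
) -> str:
    # Same branch this diff adds everywhere: an explicitly provided schema string wins;
    # an empty string keeps the pre-existing pruning behavior.
    if table_metadata_string == "":
        return prune_metadata_str(question + " " + instructions, db_name, public_data)
    return table_metadata_string


# Falls back to pruning when the column is empty...
print(resolve_table_metadata("How many users signed up last week?", "", "analytics"))
# ...and passes a user-supplied schema through untouched.
print(resolve_table_metadata("How many users signed up last week?", "", "analytics",
                             table_metadata_string="CREATE TABLE users (id int);"))
```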