From 2ff1fa9cdf6ea3840ed0e36e692003af69d55231 Mon Sep 17 00:00:00 2001 From: wendy Date: Tue, 6 Feb 2024 22:17:37 +0800 Subject: [PATCH 1/3] add glossary --- eval/anthropic_runner.py | 2 ++ eval/api_runner.py | 6 ++++-- eval/hf_runner.py | 6 ++++-- eval/openai_runner.py | 2 ++ eval/vllm_runner.py | 6 ++++-- prompts/README.md | 2 ++ prompts/prompt.md | 2 +- prompts/prompt_anthropic.md | 2 +- prompts/prompt_openai.md | 2 +- query_generators/anthropic.py | 3 ++- query_generators/openai.py | 4 +++- query_generators/query_generator.py | 2 +- utils/questions.py | 9 +++++++++ 13 files changed, 36 insertions(+), 12 deletions(-) diff --git a/eval/anthropic_runner.py b/eval/anthropic_runner.py index 0c961b1..af9b06f 100644 --- a/eval/anthropic_runner.py +++ b/eval/anthropic_runner.py @@ -44,6 +44,7 @@ def run_anthropic_eval(args): verbose=args.verbose, instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], ) generated_query_fut = executor.submit( @@ -51,6 +52,7 @@ def run_anthropic_eval(args): question=row["question"], instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], ) futures.append(generated_query_fut) diff --git a/eval/api_runner.py b/eval/api_runner.py index 69ec060..1713934 100644 --- a/eval/api_runner.py +++ b/eval/api_runner.py @@ -15,7 +15,7 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True + prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", public_data=True ): with open(prompt_file, "r") as f: prompt = f.read() @@ -29,6 +29,7 @@ def generate_prompt( instructions=instructions, table_metadata_string=pruned_metadata_str, k_shot_prompt=k_shot_prompt, + glossary=glossary, ) return prompt @@ -104,7 +105,7 @@ def run_api_eval(args): for prompt_file, output_file in zip(prompt_file_list, output_file_list): # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt"] + ["question", "db_name", "instructions", "k_shot_prompt", "glossary"] ].apply( lambda row: generate_prompt( prompt_file, @@ -112,6 +113,7 @@ def run_api_eval(args): row["db_name"], row["instructions"], row["k_shot_prompt"], + row["glossary"], public_data, ), axis=1, diff --git a/eval/hf_runner.py b/eval/hf_runner.py index a84d62d..3cb7064 100644 --- a/eval/hf_runner.py +++ b/eval/hf_runner.py @@ -25,7 +25,7 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True + prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", public_data=True ): with open(prompt_file, "r") as f: prompt = f.read() @@ -39,6 +39,7 @@ def generate_prompt( instructions=instructions, table_metadata_string=pruned_metadata_str, k_shot_prompt=k_shot_prompt, + glossary=glossary, ) return prompt @@ -143,7 +144,7 @@ def run_hf_eval(args): for prompt_file, output_file in zip(prompt_file_list, output_file_list): # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt"] + ["question", "db_name", "instructions", "k_shot_prompt", "glossary"] ].apply( lambda row: generate_prompt( prompt_file, @@ -151,6 +152,7 @@ def run_hf_eval(args): row["db_name"], row["instructions"], row["k_shot_prompt"], + row["glossary"], public_data, ), axis=1, diff --git a/eval/openai_runner.py b/eval/openai_runner.py index 34341f0..e2ba978 100644 --- a/eval/openai_runner.py +++ b/eval/openai_runner.py @@ -43,6 +43,7 @@ def 
run_openai_eval(args): verbose=args.verbose, instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], ) generated_query_fut = executor.submit( @@ -50,6 +51,7 @@ def run_openai_eval(args): question=row["question"], instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], + glossary=row["glossary"], ) futures.append(generated_query_fut) diff --git a/eval/vllm_runner.py b/eval/vllm_runner.py index a141fc8..d5a7473 100644 --- a/eval/vllm_runner.py +++ b/eval/vllm_runner.py @@ -15,7 +15,7 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True + prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", public_data=True ): with open(prompt_file, "r") as f: prompt = f.read() @@ -29,6 +29,7 @@ def generate_prompt( instructions=instructions, table_metadata_string=pruned_metadata_str, k_shot_prompt=k_shot_prompt, + glossary=glossary, ) return prompt @@ -76,7 +77,7 @@ def run_vllm_eval(args): print(f"Using prompt file {prompt_file}") # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt"] + ["question", "db_name", "instructions", "k_shot_prompt", "glossary"] ].apply( lambda row: generate_prompt( prompt_file, @@ -84,6 +85,7 @@ def run_vllm_eval(args): row["db_name"], row["instructions"], row["k_shot_prompt"], + row["glossary"], public_data, ), axis=1, diff --git a/prompts/README.md b/prompts/README.md index 28680bb..ec3843a 100644 --- a/prompts/README.md +++ b/prompts/README.md @@ -4,6 +4,7 @@ You can define your prompt template by using the following variables: - `table_metadata_string`: The metadata of the table that we want to query. This is a string that contains the table names, column names and column types. This allows the model to know which columns/tables are available for getting information from. For the sqlcoder model that we released, you would need to represent your table metadata as a [SQL DDL](https://en.wikipedia.org/wiki/Data_definition_language) statement. - `instructions`: This is an optional field that allows you to customize specific instructions for each question, if needed. For example, if you want to ask the model to format your dates a particular way, define keywords, or adapt the SQL to a different database, you can do so here. If you don't need to customize the instructions, you can omit this in your prompt. - `k_shot_prompt`: This is another optional field that allows you to provide example SQL queries and their corresponding questions. These examples serve as a context for the model, helping it understand the type of SQL query you're expecting for a given question. Including a few examples in the k_shot_prompt field can significantly improve the model's accuracy in generating relevant SQL queries, especially for complex or less straightforward questions. +- `glossary`: This is an optional field that allows you to define special terminology or rules for creating the SQL queries. 
Here is how a sample might look like with the above variables: ```markdown @@ -11,6 +12,7 @@ Here is how a sample might look like with the above variables: Generate a SQL query to answer the following question: `{user_question}` `{instructions}` +`{glossary}` ### Database Schema The query will run on a database with the following schema: {table_metadata_string} diff --git a/prompts/prompt.md b/prompts/prompt.md index c7f1285..d2816f3 100644 --- a/prompts/prompt.md +++ b/prompts/prompt.md @@ -1,7 +1,7 @@ ### Task Generate a SQL query to answer the following question: `{user_question}` -{instructions} +{instructions}{glossary} ### Database Schema The query will run on a database with the following schema: {table_metadata_string} diff --git a/prompts/prompt_anthropic.md b/prompts/prompt_anthropic.md index a29db64..dceb94b 100644 --- a/prompts/prompt_anthropic.md +++ b/prompts/prompt_anthropic.md @@ -3,7 +3,7 @@ Human: Your task is to convert a question into a SQL query, given a Postgres database schema. Generate a SQL query that answers the question `{user_question}`. -{instructions} +{instructions}{glossary} This query will run on a database whose schema is represented in this string: {table_metadata_string} diff --git a/prompts/prompt_openai.md b/prompts/prompt_openai.md index ac22fa6..42ff16a 100644 --- a/prompts/prompt_openai.md +++ b/prompts/prompt_openai.md @@ -3,7 +3,7 @@ Your task is to convert a text question to a SQL query that runs on Postgres, gi ### Input: Generate a SQL query that answers the question `{user_question}`. -{instructions} +{instructions}{glossary} This query will run on a database whose schema is represented in this string: {table_metadata_string} {k_shot_prompt} diff --git a/query_generators/anthropic.py b/query_generators/anthropic.py index 5966891..824c73f 100644 --- a/query_generators/anthropic.py +++ b/query_generators/anthropic.py @@ -73,7 +73,7 @@ def count_tokens(prompt: str = "") -> int: return num_tokens def generate_query( - self, question: str, instructions: str, k_shot_prompt: str + self, question: str, instructions: str, k_shot_prompt: str, glossary: str ) -> dict: start_time = time.time() self.err = "" @@ -93,6 +93,7 @@ def generate_query( ), instructions=instructions, k_shot_prompt=k_shot_prompt, + glossary=glossary, ) function_to_run = self.get_completion package = prompt diff --git a/query_generators/openai.py b/query_generators/openai.py index e9d5415..cd9eac0 100644 --- a/query_generators/openai.py +++ b/query_generators/openai.py @@ -110,7 +110,7 @@ def count_tokens( return num_tokens def generate_query( - self, question: str, instructions: str, k_shot_prompt: str + self, question: str, instructions: str, k_shot_prompt: str, glossary: str ) -> dict: start_time = time.time() self.err = "" @@ -137,6 +137,7 @@ def generate_query( ), instructions=instructions, k_shot_prompt=k_shot_prompt, + glossary=glossary, ) messages = [] @@ -151,6 +152,7 @@ def generate_query( ), instructions=instructions, k_shot_prompt=k_shot_prompt, + glossary=glossary, ) function_to_run = None package = None diff --git a/query_generators/query_generator.py b/query_generators/query_generator.py index f3aef93..cf8c089 100644 --- a/query_generators/query_generator.py +++ b/query_generators/query_generator.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): pass def generate_query( - self, question: str, instructions: str, k_shot_prompt: str + self, question: str, instructions: str, k_shot_prompt: str, glossary: str ) -> dict: # generate a query given a question, instructions and 
k-shot prompt # any hard-coded logic, prompt-engineering, table-pruning, api calls etc diff --git a/utils/questions.py b/utils/questions.py index 8ef39b0..76e058f 100644 --- a/utils/questions.py +++ b/utils/questions.py @@ -55,5 +55,14 @@ def prepare_questions_df( lambda x: f"\nAdhere closely to the following correct examples as references for answering the question:\n{x}" ) + # get glossary if applicable + if "glossary" in question_query_df.columns: + question_query_df["glossary"] = question_query_df["glossary"].fillna("") + question_query_df["glossary"] = question_query_df["glossary"].apply( + lambda x: f"\nUse the following instructions if and only if they are relevant to the question:\n{x}\n" + ) + else: + question_query_df["glossary"] = "" + question_query_df.reset_index(inplace=True, drop=True) return question_query_df From b0e85969151ba5b20c502d5d2d3ffb0c2a94d820 Mon Sep 17 00:00:00 2001 From: wendy Date: Tue, 6 Feb 2024 23:00:56 +0800 Subject: [PATCH 2/3] add table_metadata_string --- eval/anthropic_runner.py | 2 ++ eval/api_runner.py | 15 ++++++++++----- eval/hf_runner.py | 15 ++++++++++----- eval/openai_runner.py | 2 ++ eval/vllm_runner.py | 15 ++++++++++----- query_generators/anthropic.py | 12 ++++++++---- query_generators/openai.py | 17 +++++++++-------- query_generators/query_generator.py | 2 +- utils/questions.py | 9 +++++++++ 9 files changed, 61 insertions(+), 28 deletions(-) diff --git a/eval/anthropic_runner.py b/eval/anthropic_runner.py index af9b06f..18491ee 100644 --- a/eval/anthropic_runner.py +++ b/eval/anthropic_runner.py @@ -45,6 +45,7 @@ def run_anthropic_eval(args): instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], ) generated_query_fut = executor.submit( @@ -53,6 +54,7 @@ def run_anthropic_eval(args): instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], ) futures.append(generated_query_fut) diff --git a/eval/api_runner.py b/eval/api_runner.py index 1713934..a220d44 100644 --- a/eval/api_runner.py +++ b/eval/api_runner.py @@ -15,15 +15,19 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", public_data=True + prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", table_metadata_string="", public_data=True ): with open(prompt_file, "r") as f: prompt = f.read() question_instructions = question + " " + instructions - pruned_metadata_str = prune_metadata_str( - question_instructions, db_name, public_data - ) + if table_metadata_string == "": + pruned_metadata_str = prune_metadata_str( + question_instructions, db_name, public_data + ) + else: + pruned_metadata_str = table_metadata_string + prompt = prompt.format( user_question=question, instructions=instructions, @@ -105,7 +109,7 @@ def run_api_eval(args): for prompt_file, output_file in zip(prompt_file_list, output_file_list): # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt", "glossary"] + ["question", "db_name", "instructions", "k_shot_prompt", "glossary", "table_metadata_string"] ].apply( lambda row: generate_prompt( prompt_file, @@ -114,6 +118,7 @@ def run_api_eval(args): row["instructions"], row["k_shot_prompt"], row["glossary"], + row["table_metadata_string"], public_data, ), axis=1, diff --git a/eval/hf_runner.py b/eval/hf_runner.py index 3cb7064..82d0426 100644 --- 
a/eval/hf_runner.py +++ b/eval/hf_runner.py @@ -25,15 +25,19 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", public_data=True + prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", table_metadata_string="", public_data=True ): with open(prompt_file, "r") as f: prompt = f.read() question_instructions = question + " " + instructions - pruned_metadata_str = prune_metadata_str( - question_instructions, db_name, public_data - ) + if table_metadata_string == "": + pruned_metadata_str = prune_metadata_str( + question_instructions, db_name, public_data + ) + else: + pruned_metadata_str = table_metadata_string + prompt = prompt.format( user_question=question, instructions=instructions, @@ -144,7 +148,7 @@ def run_hf_eval(args): for prompt_file, output_file in zip(prompt_file_list, output_file_list): # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt", "glossary"] + ["question", "db_name", "instructions", "k_shot_prompt", "glossary", "table_metadata_string"] ].apply( lambda row: generate_prompt( prompt_file, @@ -153,6 +157,7 @@ def run_hf_eval(args): row["instructions"], row["k_shot_prompt"], row["glossary"], + row["table_metadata_string"], public_data, ), axis=1, diff --git a/eval/openai_runner.py b/eval/openai_runner.py index e2ba978..90c3220 100644 --- a/eval/openai_runner.py +++ b/eval/openai_runner.py @@ -44,6 +44,7 @@ def run_openai_eval(args): instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], ) generated_query_fut = executor.submit( @@ -52,6 +53,7 @@ def run_openai_eval(args): instructions=row["instructions"], k_shot_prompt=row["k_shot_prompt"], glossary=row["glossary"], + table_metadata_string=row["table_metadata_string"], ) futures.append(generated_query_fut) diff --git a/eval/vllm_runner.py b/eval/vllm_runner.py index d5a7473..58ac09b 100644 --- a/eval/vllm_runner.py +++ b/eval/vllm_runner.py @@ -15,15 +15,19 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", public_data=True + prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", table_metadata_string="", public_data=True ): with open(prompt_file, "r") as f: prompt = f.read() question_instructions = question + " " + instructions - pruned_metadata_str = prune_metadata_str( - question_instructions, db_name, public_data - ) + if table_metadata_string == "": + pruned_metadata_str = prune_metadata_str( + question_instructions, db_name, public_data + ) + else: + pruned_metadata_str = table_metadata_string + prompt = prompt.format( user_question=question, instructions=instructions, @@ -77,7 +81,7 @@ def run_vllm_eval(args): print(f"Using prompt file {prompt_file}") # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt", "glossary"] + ["question", "db_name", "instructions", "k_shot_prompt", "glossary", "table_metadata_string"] ].apply( lambda row: generate_prompt( prompt_file, @@ -86,6 +90,7 @@ def run_vllm_eval(args): row["instructions"], row["k_shot_prompt"], row["glossary"], + row["table_metadata_string"], public_data, ), axis=1, diff --git a/query_generators/anthropic.py b/query_generators/anthropic.py index 824c73f..d53418f 100644 --- a/query_generators/anthropic.py +++ b/query_generators/anthropic.py @@ -73,7 +73,7 @@ def count_tokens(prompt: str = "") -> int: 
return num_tokens def generate_query( - self, question: str, instructions: str, k_shot_prompt: str, glossary: str + self, question: str, instructions: str, k_shot_prompt: str, glossary: str, table_metadata_string: str ) -> dict: start_time = time.time() self.err = "" @@ -86,11 +86,15 @@ def generate_query( if "Human:" not in model_prompt: raise ValueError("Invalid prompt file. Please use prompt_anthropic.md") question_instructions = question + " " + instructions + if table_metadata_string == "": + pruned_metadata_str = prune_metadata_str( + question_instructions, self.db_name, self.use_public_data + ) + else: + pruned_metadata_str = table_metadata_string prompt = model_prompt.format( user_question=question, - table_metadata_string=prune_metadata_str( - question_instructions, self.db_name, self.use_public_data - ), + table_metadata_string=pruned_metadata_str, instructions=instructions, k_shot_prompt=k_shot_prompt, glossary=glossary, diff --git a/query_generators/openai.py b/query_generators/openai.py index cd9eac0..ad5368b 100644 --- a/query_generators/openai.py +++ b/query_generators/openai.py @@ -110,7 +110,7 @@ def count_tokens( return num_tokens def generate_query( - self, question: str, instructions: str, k_shot_prompt: str, glossary: str + self, question: str, instructions: str, k_shot_prompt: str, glossary: str, table_metadata_string: str ) -> dict: start_time = time.time() self.err = "" @@ -120,6 +120,12 @@ def generate_query( with open(self.prompt_file) as file: chat_prompt = file.read() question_instructions = question + " " + instructions + if table_metadata_string == "": + pruned_metadata_str = prune_metadata_str( + question_instructions, self.db_name, self.use_public_data + ) + else: + pruned_metadata_str = table_metadata_string if self.model != "text-davinci-003": try: sys_prompt = chat_prompt.split("### Input:")[0] @@ -129,12 +135,9 @@ def generate_query( assistant_prompt = chat_prompt.split("### Response:")[1] except: raise ValueError("Invalid prompt file. 
Please use prompt_openai.md") - user_prompt = user_prompt.format( user_question=question, - table_metadata_string=prune_metadata_str( - question_instructions, self.db_name, self.use_public_data - ), + table_metadata_string=pruned_metadata_str, instructions=instructions, k_shot_prompt=k_shot_prompt, glossary=glossary, @@ -147,9 +150,7 @@ def generate_query( else: prompt = chat_prompt.format( user_question=question, - table_metadata_string=prune_metadata_str( - question_instructions, self.db_name, self.use_public_data - ), + table_metadata_string=pruned_metadata_str, instructions=instructions, k_shot_prompt=k_shot_prompt, glossary=glossary, diff --git a/query_generators/query_generator.py b/query_generators/query_generator.py index cf8c089..a2c4ca8 100644 --- a/query_generators/query_generator.py +++ b/query_generators/query_generator.py @@ -16,7 +16,7 @@ def __init__(self, **kwargs): pass def generate_query( - self, question: str, instructions: str, k_shot_prompt: str, glossary: str + self, question: str, instructions: str, k_shot_prompt: str, glossary: str, table_metadata_string: str ) -> dict: # generate a query given a question, instructions and k-shot prompt # any hard-coded logic, prompt-engineering, table-pruning, api calls etc diff --git a/utils/questions.py b/utils/questions.py index 76e058f..539d82a 100644 --- a/utils/questions.py +++ b/utils/questions.py @@ -65,4 +65,13 @@ def prepare_questions_df( question_query_df["glossary"] = "" question_query_df.reset_index(inplace=True, drop=True) + + # get table_metadata_string if applicable + if "table_metadata_string" in question_query_df.columns: + question_query_df["table_metadata_string"] = question_query_df[ + "table_metadata_string" + ].fillna("") + else: + question_query_df["table_metadata_string"] = "" + return question_query_df From e16b03821cef481d59efb9c19d0200e911923bfb Mon Sep 17 00:00:00 2001 From: wendy Date: Tue, 6 Feb 2024 23:09:58 +0800 Subject: [PATCH 3/3] linted --- eval/api_runner.py | 18 ++++++++++++++++-- eval/hf_runner.py | 18 ++++++++++++++++-- eval/vllm_runner.py | 18 ++++++++++++++++-- query_generators/anthropic.py | 7 ++++++- query_generators/openai.py | 9 +++++++-- query_generators/query_generator.py | 7 ++++++- utils/questions.py | 2 +- 7 files changed, 68 insertions(+), 11 deletions(-) diff --git a/eval/api_runner.py b/eval/api_runner.py index a220d44..c9d13fb 100644 --- a/eval/api_runner.py +++ b/eval/api_runner.py @@ -15,7 +15,14 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", table_metadata_string="", public_data=True + prompt_file, + question, + db_name, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + public_data=True, ): with open(prompt_file, "r") as f: prompt = f.read() @@ -109,7 +116,14 @@ def run_api_eval(args): for prompt_file, output_file in zip(prompt_file_list, output_file_list): # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt", "glossary", "table_metadata_string"] + [ + "question", + "db_name", + "instructions", + "k_shot_prompt", + "glossary", + "table_metadata_string", + ] ].apply( lambda row: generate_prompt( prompt_file, diff --git a/eval/hf_runner.py b/eval/hf_runner.py index 82d0426..eaecb6e 100644 --- a/eval/hf_runner.py +++ b/eval/hf_runner.py @@ -25,7 +25,14 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", table_metadata_string="", public_data=True + prompt_file, + question, 
+ db_name, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + public_data=True, ): with open(prompt_file, "r") as f: prompt = f.read() @@ -148,7 +155,14 @@ def run_hf_eval(args): for prompt_file, output_file in zip(prompt_file_list, output_file_list): # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt", "glossary", "table_metadata_string"] + [ + "question", + "db_name", + "instructions", + "k_shot_prompt", + "glossary", + "table_metadata_string", + ] ].apply( lambda row: generate_prompt( prompt_file, diff --git a/eval/vllm_runner.py b/eval/vllm_runner.py index 58ac09b..740f77d 100644 --- a/eval/vllm_runner.py +++ b/eval/vllm_runner.py @@ -15,7 +15,14 @@ def generate_prompt( - prompt_file, question, db_name, instructions="", k_shot_prompt="", glossary="", table_metadata_string="", public_data=True + prompt_file, + question, + db_name, + instructions="", + k_shot_prompt="", + glossary="", + table_metadata_string="", + public_data=True, ): with open(prompt_file, "r") as f: prompt = f.read() @@ -81,7 +88,14 @@ def run_vllm_eval(args): print(f"Using prompt file {prompt_file}") # create a prompt for each question df["prompt"] = df[ - ["question", "db_name", "instructions", "k_shot_prompt", "glossary", "table_metadata_string"] + [ + "question", + "db_name", + "instructions", + "k_shot_prompt", + "glossary", + "table_metadata_string", + ] ].apply( lambda row: generate_prompt( prompt_file, diff --git a/query_generators/anthropic.py b/query_generators/anthropic.py index d53418f..48fc2ee 100644 --- a/query_generators/anthropic.py +++ b/query_generators/anthropic.py @@ -73,7 +73,12 @@ def count_tokens(prompt: str = "") -> int: return num_tokens def generate_query( - self, question: str, instructions: str, k_shot_prompt: str, glossary: str, table_metadata_string: str + self, + question: str, + instructions: str, + k_shot_prompt: str, + glossary: str, + table_metadata_string: str, ) -> dict: start_time = time.time() self.err = "" diff --git a/query_generators/openai.py b/query_generators/openai.py index ad5368b..fedfeac 100644 --- a/query_generators/openai.py +++ b/query_generators/openai.py @@ -110,7 +110,12 @@ def count_tokens( return num_tokens def generate_query( - self, question: str, instructions: str, k_shot_prompt: str, glossary: str, table_metadata_string: str + self, + question: str, + instructions: str, + k_shot_prompt: str, + glossary: str, + table_metadata_string: str, ) -> dict: start_time = time.time() self.err = "" @@ -121,7 +126,7 @@ def generate_query( chat_prompt = file.read() question_instructions = question + " " + instructions if table_metadata_string == "": - pruned_metadata_str = prune_metadata_str( + pruned_metadata_str = prune_metadata_str( question_instructions, self.db_name, self.use_public_data ) else: diff --git a/query_generators/query_generator.py b/query_generators/query_generator.py index a2c4ca8..a40e6be 100644 --- a/query_generators/query_generator.py +++ b/query_generators/query_generator.py @@ -16,7 +16,12 @@ def __init__(self, **kwargs): pass def generate_query( - self, question: str, instructions: str, k_shot_prompt: str, glossary: str, table_metadata_string: str + self, + question: str, + instructions: str, + k_shot_prompt: str, + glossary: str, + table_metadata_string: str, ) -> dict: # generate a query given a question, instructions and k-shot prompt # any hard-coded logic, prompt-engineering, table-pruning, api calls etc diff --git a/utils/questions.py 
b/utils/questions.py index 539d82a..49eb440 100644 --- a/utils/questions.py +++ b/utils/questions.py @@ -73,5 +73,5 @@ def prepare_questions_df( ].fillna("") else: question_query_df["table_metadata_string"] = "" - + return question_query_df
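
For reviewers who want to see the two new optional columns end to end, here is a minimal, self-contained sketch (not part of the patch) of how a question row with `glossary` and `table_metadata_string` flows into a prompt. The `build_prompt` helper, the template wording, and the sample row are illustrative assumptions; `prune_metadata_str` is passed in only to stand for the repository's existing pruning helper, and it is consulted solely when no schema string is supplied, mirroring the `if table_metadata_string == "":` fallback added above.

```python
import pandas as pd

# Template loosely mirroring prompts/prompt.md after this patch; the exact
# wording here is an assumption for illustration only.
PROMPT_TEMPLATE = """### Task
Generate a SQL query to answer the following question: `{user_question}`
{instructions}{glossary}
### Database Schema
The query will run on a database with the following schema:
{table_metadata_string}
{k_shot_prompt}
### SQL
"""


def build_prompt(row: pd.Series, prune_metadata_str=None, public_data: bool = True) -> str:
    """Assemble a prompt from one question row, preferring a caller-supplied
    schema string and falling back to metadata pruning when it is empty."""
    metadata = row.get("table_metadata_string", "")
    if metadata == "" and prune_metadata_str is not None:
        # Only prune when no schema string was provided for this question.
        metadata = prune_metadata_str(
            row["question"] + " " + row.get("instructions", ""),
            row["db_name"],
            public_data,
        )
    return PROMPT_TEMPLATE.format(
        user_question=row["question"],
        instructions=row.get("instructions", ""),
        glossary=row.get("glossary", ""),
        k_shot_prompt=row.get("k_shot_prompt", ""),
        table_metadata_string=metadata,
    )


# Hypothetical row that supplies its own schema and a glossary entry,
# so no pruning is needed.
row = pd.Series(
    {
        "question": "How many active users signed up last month?",
        "db_name": "analytics",
        "instructions": "",
        "glossary": (
            "\nUse the following instructions if and only if they are relevant "
            "to the question:\n'active user' means last_login within the past 30 days.\n"
        ),
        "k_shot_prompt": "",
        "table_metadata_string": "CREATE TABLE users (id int, signup_date date, last_login date);",
    }
)
print(build_prompt(row))
```

With an empty `table_metadata_string`, the same call would instead prune the database metadata for the question, which is the fallback behaviour the runners and query generators gain in this patch series.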