Glossary & table metadata string #80

Merged (3 commits) on Feb 6, 2024
eval/anthropic_runner.py: 4 additions & 0 deletions

@@ -44,13 +44,17 @@ def run_anthropic_eval(args):
                 verbose=args.verbose,
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )

             generated_query_fut = executor.submit(
                 qg.generate_query,
                 question=row["question"],
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )
             futures.append(generated_query_fut)

eval/api_runner.py: 26 additions & 5 deletions

@@ -15,20 +15,32 @@

 def generate_prompt(
-    prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True
+    prompt_file,
+    question,
+    db_name,
+    instructions="",
+    k_shot_prompt="",
+    glossary="",
+    table_metadata_string="",
+    public_data=True,
 ):
     with open(prompt_file, "r") as f:
         prompt = f.read()
     question_instructions = question + " " + instructions

-    pruned_metadata_str = prune_metadata_str(
-        question_instructions, db_name, public_data
-    )
+    if table_metadata_string == "":
+        pruned_metadata_str = prune_metadata_str(
+            question_instructions, db_name, public_data
+        )
+    else:
+        pruned_metadata_str = table_metadata_string

     prompt = prompt.format(
         user_question=question,
         instructions=instructions,
         table_metadata_string=pruned_metadata_str,
         k_shot_prompt=k_shot_prompt,
+        glossary=glossary,
     )
     return prompt

@@ -104,14 +116,23 @@ def run_api_eval(args):
     for prompt_file, output_file in zip(prompt_file_list, output_file_list):
         # create a prompt for each question
         df["prompt"] = df[
-            ["question", "db_name", "instructions", "k_shot_prompt"]
+            [
+                "question",
+                "db_name",
+                "instructions",
+                "k_shot_prompt",
+                "glossary",
+                "table_metadata_string",
+            ]
         ].apply(
             lambda row: generate_prompt(
                 prompt_file,
                 row["question"],
                 row["db_name"],
                 row["instructions"],
                 row["k_shot_prompt"],
+                row["glossary"],
+                row["table_metadata_string"],
                 public_data,
             ),
             axis=1,

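The override-or-prune logic added to generate_prompt here recurs verbatim in eval/hf_runner.py and eval/vllm_runner.py below. As a minimal standalone sketch of that behavior (resolve_metadata is an invented name, and prune_metadata_str is stubbed out, since the real helper lives in the repo's utilities):

```python
# Sketch of the new fallback: an explicit table_metadata_string from the
# question file wins; otherwise the schema is pruned per-question.
def prune_metadata_str(question_instructions: str, db_name: str, public_data: bool) -> str:
    # Stub standing in for the repo's real pruning helper.
    return f"-- DDL for {db_name}, pruned to tables relevant to the question"

def resolve_metadata(
    question_instructions: str,
    db_name: str,
    table_metadata_string: str = "",
    public_data: bool = True,
) -> str:
    if table_metadata_string == "":
        return prune_metadata_str(question_instructions, db_name, public_data)
    return table_metadata_string

# With no explicit metadata, pruning kicks in; with metadata supplied,
# it is passed through untouched.
print(resolve_metadata("How many users signed up?", "academic"))
print(resolve_metadata("How many users signed up?", "academic", "CREATE TABLE users (...);"))
```
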
eval/hf_runner.py: 26 additions & 5 deletions

@@ -25,20 +25,32 @@

 def generate_prompt(
-    prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True
+    prompt_file,
+    question,
+    db_name,
+    instructions="",
+    k_shot_prompt="",
+    glossary="",
+    table_metadata_string="",
+    public_data=True,
 ):
     with open(prompt_file, "r") as f:
         prompt = f.read()
     question_instructions = question + " " + instructions

-    pruned_metadata_str = prune_metadata_str(
-        question_instructions, db_name, public_data
-    )
+    if table_metadata_string == "":
+        pruned_metadata_str = prune_metadata_str(
+            question_instructions, db_name, public_data
+        )
+    else:
+        pruned_metadata_str = table_metadata_string

     prompt = prompt.format(
         user_question=question,
         instructions=instructions,
         table_metadata_string=pruned_metadata_str,
         k_shot_prompt=k_shot_prompt,
+        glossary=glossary,
     )
     return prompt

@@ -143,14 +155,23 @@ def run_hf_eval(args):
     for prompt_file, output_file in zip(prompt_file_list, output_file_list):
         # create a prompt for each question
         df["prompt"] = df[
-            ["question", "db_name", "instructions", "k_shot_prompt"]
+            [
+                "question",
+                "db_name",
+                "instructions",
+                "k_shot_prompt",
+                "glossary",
+                "table_metadata_string",
+            ]
         ].apply(
             lambda row: generate_prompt(
                 prompt_file,
                 row["question"],
                 row["db_name"],
                 row["instructions"],
                 row["k_shot_prompt"],
+                row["glossary"],
+                row["table_metadata_string"],
                 public_data,
             ),
             axis=1,

eval/openai_runner.py: 4 additions & 0 deletions

@@ -43,13 +43,17 @@ def run_openai_eval(args):
                 verbose=args.verbose,
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )

             generated_query_fut = executor.submit(
                 qg.generate_query,
                 question=row["question"],
                 instructions=row["instructions"],
                 k_shot_prompt=row["k_shot_prompt"],
+                glossary=row["glossary"],
+                table_metadata_string=row["table_metadata_string"],
             )
             futures.append(generated_query_fut)

eval/vllm_runner.py: 26 additions & 5 deletions

@@ -15,20 +15,32 @@

 def generate_prompt(
-    prompt_file, question, db_name, instructions="", k_shot_prompt="", public_data=True
+    prompt_file,
+    question,
+    db_name,
+    instructions="",
+    k_shot_prompt="",
+    glossary="",
+    table_metadata_string="",
+    public_data=True,
 ):
     with open(prompt_file, "r") as f:
         prompt = f.read()
     question_instructions = question + " " + instructions

-    pruned_metadata_str = prune_metadata_str(
-        question_instructions, db_name, public_data
-    )
+    if table_metadata_string == "":
+        pruned_metadata_str = prune_metadata_str(
+            question_instructions, db_name, public_data
+        )
+    else:
+        pruned_metadata_str = table_metadata_string

     prompt = prompt.format(
         user_question=question,
         instructions=instructions,
         table_metadata_string=pruned_metadata_str,
         k_shot_prompt=k_shot_prompt,
+        glossary=glossary,
     )
     return prompt

@@ -76,14 +88,23 @@ def run_vllm_eval(args):
         print(f"Using prompt file {prompt_file}")
         # create a prompt for each question
         df["prompt"] = df[
-            ["question", "db_name", "instructions", "k_shot_prompt"]
+            [
+                "question",
+                "db_name",
+                "instructions",
+                "k_shot_prompt",
+                "glossary",
+                "table_metadata_string",
+            ]
         ].apply(
             lambda row: generate_prompt(
                 prompt_file,
                 row["question"],
                 row["db_name"],
                 row["instructions"],
                 row["k_shot_prompt"],
+                row["glossary"],
+                row["table_metadata_string"],
                 public_data,
             ),
             axis=1,

prompts/README.md: 2 additions & 0 deletions

@@ -4,13 +4,15 @@ You can define your prompt template by using the following variables:
 - `table_metadata_string`: The metadata of the table that we want to query. This is a string that contains the table names, column names and column types. This allows the model to know which columns/tables are available for getting information from. For the sqlcoder model that we released, you would need to represent your table metadata as a [SQL DDL](https://en.wikipedia.org/wiki/Data_definition_language) statement.
 - `instructions`: This is an optional field that allows you to customize specific instructions for each question, if needed. For example, if you want to ask the model to format your dates a particular way, define keywords, or adapt the SQL to a different database, you can do so here. If you don't need to customize the instructions, you can omit this in your prompt.
 - `k_shot_prompt`: This is another optional field that allows you to provide example SQL queries and their corresponding questions. These examples serve as a context for the model, helping it understand the type of SQL query you're expecting for a given question. Including a few examples in the k_shot_prompt field can significantly improve the model's accuracy in generating relevant SQL queries, especially for complex or less straightforward questions.
+- `glossary`: This is an optional field that allows you to define special terminology or rules for creating the SQL queries.

 Here is how a sample might look like with the above variables:
 ```markdown
 ### Task
 Generate a SQL query to answer the following question:
 `{user_question}`
 `{instructions}`
+`{glossary}`
 ### Database Schema
 The query will run on a database with the following schema:
 {table_metadata_string}

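To make the new `glossary` variable concrete, here is a small sketch that fills the sample template above; the glossary text, question, and schema are invented for the example:

```python
# Fill the README's sample template; all values below are illustrative.
template = """### Task
Generate a SQL query to answer the following question:
`{user_question}`
`{instructions}`
`{glossary}`
### Database Schema
The query will run on a database with the following schema:
{table_metadata_string}"""

print(template.format(
    user_question="How many active users signed up last month?",
    instructions="Return dates in YYYY-MM-DD format.",
    glossary="An 'active user' is a user with at least one login in the past 30 days.",
    table_metadata_string="CREATE TABLE users (id INT, signup_date DATE, last_login DATE);",
))
```
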
prompts/prompt.md: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 ### Task
 Generate a SQL query to answer the following question:
 `{user_question}`
-{instructions}
+{instructions}{glossary}
 ### Database Schema
 The query will run on a database with the following schema:
 {table_metadata_string}

prompts/prompt_anthropic.md: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 Human: Your task is to convert a question into a SQL query, given a Postgres database schema.

 Generate a SQL query that answers the question `{user_question}`.
-{instructions}
+{instructions}{glossary}
 This query will run on a database whose schema is represented in this string:
 {table_metadata_string}

prompts/prompt_openai.md: 1 addition & 1 deletion

@@ -3,7 +3,7 @@ Your task is to convert a text question to a SQL query that runs on Postgres, gi

 ### Input:
 Generate a SQL query that answers the question `{user_question}`.
-{instructions}
+{instructions}{glossary}
 This query will run on a database whose schema is represented in this string:
 {table_metadata_string}
 {k_shot_prompt}

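In all three templates, `{instructions}` and `{glossary}` sit back to back with no separator. That works because prepare_questions_df (utils/questions.py below) builds the glossary value with its own leading and trailing newlines, and substitutes an empty string when the question file has no glossary column at all. A tiny demonstration of the formatting mechanics (the instruction and glossary text are invented):

```python
# The glossary value arrives either pre-padded with newlines or as "".
line = "{instructions}{glossary}This query will run on a database..."
print(line.format(instructions="Use ILIKE for name matching.", glossary=""))
print(line.format(
    instructions="Use ILIKE for name matching.",
    glossary=(
        "\nUse the following instructions if and only if they are relevant to the question:\n"
        "A 'lapsed' customer has no orders in the past 90 days.\n"
    ),
))
```
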
query_generators/anthropic.py: 14 additions & 4 deletions

@@ -73,7 +73,12 @@ def count_tokens(prompt: str = "") -> int:
         return num_tokens

     def generate_query(
-        self, question: str, instructions: str, k_shot_prompt: str
+        self,
+        question: str,
+        instructions: str,
+        k_shot_prompt: str,
+        glossary: str,
+        table_metadata_string: str,
     ) -> dict:
         start_time = time.time()
         self.err = ""
@@ -86,13 +91,18 @@ def generate_query(
         if "Human:" not in model_prompt:
             raise ValueError("Invalid prompt file. Please use prompt_anthropic.md")
         question_instructions = question + " " + instructions
+        if table_metadata_string == "":
+            pruned_metadata_str = prune_metadata_str(
+                question_instructions, self.db_name, self.use_public_data
+            )
+        else:
+            pruned_metadata_str = table_metadata_string
         prompt = model_prompt.format(
             user_question=question,
-            table_metadata_string=prune_metadata_str(
-                question_instructions, self.db_name, self.use_public_data
-            ),
+            table_metadata_string=pruned_metadata_str,
             instructions=instructions,
             k_shot_prompt=k_shot_prompt,
+            glossary=glossary,
         )
         function_to_run = self.get_completion
         package = prompt

query_generators/openai.py: 16 additions & 8 deletions

@@ -110,7 +110,12 @@ def count_tokens(
         return num_tokens

     def generate_query(
-        self, question: str, instructions: str, k_shot_prompt: str
+        self,
+        question: str,
+        instructions: str,
+        k_shot_prompt: str,
+        glossary: str,
+        table_metadata_string: str,
     ) -> dict:
         start_time = time.time()
         self.err = ""
@@ -120,6 +125,12 @@ def generate_query(
         with open(self.prompt_file) as file:
             chat_prompt = file.read()
         question_instructions = question + " " + instructions
+        if table_metadata_string == "":
+            pruned_metadata_str = prune_metadata_str(
+                question_instructions, self.db_name, self.use_public_data
+            )
+        else:
+            pruned_metadata_str = table_metadata_string
         if self.model != "text-davinci-003":
             try:
                 sys_prompt = chat_prompt.split("### Input:")[0]
@@ -129,14 +140,12 @@ def generate_query(
                 assistant_prompt = chat_prompt.split("### Response:")[1]
             except:
                 raise ValueError("Invalid prompt file. Please use prompt_openai.md")
-
             user_prompt = user_prompt.format(
                 user_question=question,
-                table_metadata_string=prune_metadata_str(
-                    question_instructions, self.db_name, self.use_public_data
-                ),
+                table_metadata_string=pruned_metadata_str,
                 instructions=instructions,
                 k_shot_prompt=k_shot_prompt,
+                glossary=glossary,
             )

             messages = []
@@ -146,11 +155,10 @@ def generate_query(
         else:
             prompt = chat_prompt.format(
                 user_question=question,
-                table_metadata_string=prune_metadata_str(
-                    question_instructions, self.db_name, self.use_public_data
-                ),
+                table_metadata_string=pruned_metadata_str,
                 instructions=instructions,
                 k_shot_prompt=k_shot_prompt,
+                glossary=glossary,
             )
         function_to_run = None
         package = None

query_generators/query_generator.py: 6 additions & 1 deletion

@@ -16,7 +16,12 @@ def __init__(self, **kwargs):
         pass

     def generate_query(
-        self, question: str, instructions: str, k_shot_prompt: str
+        self,
+        question: str,
+        instructions: str,
+        k_shot_prompt: str,
+        glossary: str,
+        table_metadata_string: str,
     ) -> dict:
         # generate a query given a question, instructions and k-shot prompt
         # any hard-coded logic, prompt-engineering, table-pruning, api calls etc

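With the base signature widened, every concrete generator must accept the two extra arguments. A minimal sketch of a conforming subclass; the class name, body, and return keys are invented for illustration and are not the repo's API:

```python
# Hypothetical subclass conforming to the widened generate_query signature.
from query_generators.query_generator import QueryGenerator  # repo module path

class EchoQueryGenerator(QueryGenerator):
    def generate_query(
        self,
        question: str,
        instructions: str,
        k_shot_prompt: str,
        glossary: str,
        table_metadata_string: str,
    ) -> dict:
        # A real generator would build a prompt (using the glossary and any
        # explicit table metadata) and call a model; this stub just echoes.
        return {"query": f"-- would answer: {question}", "err": ""}
```
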
utils/questions.py: 18 additions & 0 deletions

@@ -55,5 +55,23 @@ def prepare_questions_df(
             lambda x: f"\nAdhere closely to the following correct examples as references for answering the question:\n{x}"
         )

+    # get glossary if applicable
+    if "glossary" in question_query_df.columns:
+        question_query_df["glossary"] = question_query_df["glossary"].fillna("")
+        question_query_df["glossary"] = question_query_df["glossary"].apply(
+            lambda x: f"\nUse the following instructions if and only if they are relevant to the question:\n{x}\n"
+        )
+    else:
+        question_query_df["glossary"] = ""

     question_query_df.reset_index(inplace=True, drop=True)

+    # get table_metadata_string if applicable
+    if "table_metadata_string" in question_query_df.columns:
+        question_query_df["table_metadata_string"] = question_query_df[
+            "table_metadata_string"
+        ].fillna("")
+    else:
+        question_query_df["table_metadata_string"] = ""

     return question_query_df
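
A quick reproduction of how these two blocks behave on a toy question file (pandas only; the question and glossary text are invented):

```python
import pandas as pd

# Toy question file: one row has a glossary entry, one is missing (NaN).
df = pd.DataFrame({
    "question": ["How many active users?", "What were total sales?"],
    "glossary": ["An 'active user' logged in within the past 30 days.", None],
})

# Mirrors the glossary block above: fill NaNs, then wrap every value.
if "glossary" in df.columns:
    df["glossary"] = df["glossary"].fillna("")
    df["glossary"] = df["glossary"].apply(
        lambda x: f"\nUse the following instructions if and only if they are relevant to the question:\n{x}\n"
    )
else:
    df["glossary"] = ""

# Mirrors the table_metadata_string block: an absent column yields empty
# strings, which is exactly what triggers the pruning fallback in the runners.
if "table_metadata_string" in df.columns:
    df["table_metadata_string"] = df["table_metadata_string"].fillna("")
else:
    df["table_metadata_string"] = ""

print(df["glossary"].iloc[0])                     # wrapped glossary text
print(repr(df["table_metadata_string"].iloc[0]))  # ''
```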