Optimize and fix prompt generation #6
Changes from 4 commits: 664c7df, d1d56aa, 645b6d6, a99fa78, e5d448b
@@ -2,6 +2,7 @@
 import os
 import logging
 import argparse
+from collections import deque

 from prompt import PromptGeneratingPrompt
 from model import MockModel, Llama3, Phi3
@@ -19,7 +20,10 @@ def generate_prompts(
     max_new_tokens: int = 2000,
     temperature: float = 1.0,
     filepath: str = "prompts.txt",
+    leftover_prompts: deque = None,
 ):
+    if leftover_prompts is None:
+        leftover_prompts = deque()
     prompt_generator = PromptGeneratingPrompt()

     i = -1

Review comment on `leftover_prompts: deque = None,` (posted as a suggested change): Please don't hate me for this :D

Reply: I... I... :D
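A side note on the `None` default here (context, not part of the PR): defaulting straight to `deque()` would hit Python's mutable-default-argument pitfall, because the default object is created once at function definition time and then shared across calls. A minimal illustration:

    def bad(items=[]):  # this one list is created at definition time
        items.append(1)
        return items

    bad()  # [1]
    bad()  # [1, 1]  <- state leaked from the previous call

    def good(items=None):
        if items is None:  # a fresh list per call, same pattern as the PR
            items = []
        items.append(1)
        return items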
@@ -38,22 +42,28 @@ def generate_prompts(
         )

         seconds_taken = (datetime.datetime.now() - start_ts).total_seconds()
-        log.info(f"{i=} generation took {seconds_taken:.2f}s")

         new_prompts = []
         for j, sequence in enumerate(sequences):
             generated_prompts = parse_output(sequence)
             log.debug(f"{i=} sequence={j} {generated_prompts=} from {sequence=}")

             log.info(f"{i=} sequence={j} generated {len(generated_prompts)} prompts")
             new_prompts.extend(generated_prompts)

         # check_prompts_quality(new_prompts)

         # remove any duplicates
         new_prompts = list(set(new_prompts))
+        log.info(
+            f"{i=} generation took {seconds_taken:.2f}s; generated {len(new_prompts)} prompts"
+        )

-        if total_prompts - len(new_prompts) < 0:
+        # Use leftover prompts from previous batch if available
+        while leftover_prompts and total_prompts > 0:
+            new_prompts.append(leftover_prompts.popleft())
+            total_prompts -= 1

         if len(new_prompts) > total_prompts:
             # Save extra prompts for next batch
+            leftover_prompts.extend(new_prompts[total_prompts:])
             new_prompts = new_prompts[:total_prompts]

         total_prompts -= len(new_prompts)

Review comment on lines +63 to +65: I think you need the original value of `total_prompts` here.
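If that reading of the comment is right, the fix would be to snapshot the batch's requested size before the `while` loop mutates `total_prompts`, so that the slice feeding `leftover_prompts` is taken at the boundary the caller asked for. A sketch of that idea (`needed` is a name invented here, not from the PR):

    needed = total_prompts  # snapshot before leftovers are consumed

    # top up from the previous batch's surplus, but only up to the target
    while leftover_prompts and len(new_prompts) < needed:
        new_prompts.append(leftover_prompts.popleft())

    if len(new_prompts) > needed:
        # save the extras for the next batch
        leftover_prompts.extend(new_prompts[needed:])
        new_prompts = new_prompts[:needed]

    total_prompts = needed - len(new_prompts)  # replaces the separate decrement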
@@ -62,6 +72,8 @@ def generate_prompts(
         if total_prompts == 0:
             break

+    return leftover_prompts


 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Generate prompts")
@@ -74,19 +86,19 @@ def generate_prompts(
     parser.add_argument(
         "--batch_size",
         type=int,
-        default=20,
+        default=262,  # on A6000 we want 240 prompts generated in single file, but not all results are valid
         help="Batch size - number of prompts given as input per generation request",
     )
     parser.add_argument(
         "--num_return_sequences",
         type=int,
-        default=5,
+        default=1,  # better to generate as many as possible prompts on different themes
         help="Number of return sequences outputted for each prompt given as input",
     )
     parser.add_argument(
         "--max_new_tokens",
         type=int,
-        default=500,
+        default=40,  # 40 new tokens is enough for reasonable length prompt - 30 caused too much cut off prompts
         help="Max new tokens",
     )
     parser.add_argument(
@@ -108,16 +120,10 @@ def generate_prompts(
         default="./saved_models/",
         help="Path to load the model and tokenizer from",
     )
-    parser.add_argument(
-        "--number_of_batches",
-        type=int,
-        default=None,
-        help="Number of batches to generate",
-    )
     parser.add_argument(
         "--number_of_prompts_per_batch",
         type=int,
-        required=True,
+        default=240,
         help="Number of prompts per uuid batch",
     )
     parser.add_argument(
@@ -137,11 +143,6 @@ def generate_prompts(
     uuids = args.uuids.split(",")

-    if args.number_of_batches:
-        assert (
-            len(uuids) == args.number_of_batches
-        ), "Number of uuids should be equal to number of batches requested"

     model_path = os.path.join(args.model_path, args.model_name)
     if args.model_name == "mock":
         model = MockModel()
@@ -158,16 +159,18 @@ def generate_prompts(
     else:
         raise ValueError(f"Invalid model name: {args.model_name}")

+    leftover_prompts = None
     for uuid in uuids:
         start_ts = datetime.datetime.now()
-        generate_prompts(
+        leftover_prompts = generate_prompts(
             model,
             total_prompts=args.number_of_prompts_per_batch,
             batch_size=args.batch_size,
             num_return_sequences=args.num_return_sequences,
             max_new_tokens=args.max_new_tokens,
             temperature=args.temperature,
             filepath=os.path.join(args.output_folder_path, f"prompts_{uuid}.txt"),
+            leftover_prompts=leftover_prompts,
         )
         seconds_taken = (datetime.datetime.now() - start_ts).total_seconds()
         log.info(
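A quick smoke test of the new contract (a hypothetical snippet: the batch names and small argument values are invented, `generate_prompts` is assumed to be in scope, and `MockModel` is assumed to run without a GPU):

    from model import MockModel

    leftovers = None
    for name in ("batch_a", "batch_b"):
        leftovers = generate_prompts(
            MockModel(),
            total_prompts=10,
            batch_size=4,
            num_return_sequences=1,
            filepath=f"prompts_{name}.txt",
            leftover_prompts=leftovers,
        )
    # surplus prompts generated while filling batch_a are written into batch_b first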
The diff's second file (the prompt parsing helpers):

@@ -8,22 +8,34 @@
 def clean_line(line: str) -> str:
     line = line.strip()
     head, sep, tail = line.partition("<|")
     if head:
         line = head.strip()
     else:
         # if we started with a tag we assume that inside we find our prompt
         line = tail.partition("|>")[2].partition("<|")[0].strip()
     # remove list numbering if present
     line = re.sub(r"^\s*\d+\.?\s*", "", line)
+    # strip quotations
+    line = line.strip("\"'")
     return line

Review comment on lines +19 to +20: Won't that mess up lines where a quote mark is at the start/end of the line, and the other quote is in the middle of the line? The strip will only remove the quote at the start/end, not the quote in the middle.

Reply: We don't care - we care ONLY about the lines that are a single-line prompt - and end with `?`.
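For reference, here is how `clean_line` behaves on a few made-up inputs; the third is exactly the case the reviewer raised:

    clean_line('  3. "What is the capital of France?"  ')
    # -> 'What is the capital of France?'  (numbering and edge quotes stripped)

    clean_line("<|assistant|>How do plants make food?<|end|>")
    # -> 'How do plants make food?'  (text recovered from between the tags)

    clean_line('"Did she say "stop" or not?')
    # -> 'Did she say "stop" or not?'  (only the edge quote goes; interior quotes stay)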
 def parse_output(output: str) -> list[str]:
     # split into lines and clean them
     lines = output.split("\n")
-    lines = [clean_line(line) for line in lines]
+    for line in lines:
+        cleaned_line = clean_line(line)
+        # we skip if line is too short or too long and not ends with ?
+        # in most cases it would be just first line
+        if (
+            len(cleaned_line) > 10
+            and len(cleaned_line) < 300
+            and cleaned_line.endswith("?")
+        ):
+            return [cleaned_line]

-    # filter out null lines or prompts that are too short or long
-    lines = [line for line in lines if (len(line) > 10 and len(line) < 300)]

-    # skip first line as that's frequently broken (i.e. "Here are the prompts:")
-    # skip last line as it might not be comletely generated
-    return lines[1:-1]
+    return []
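With the new logic, `parse_output` returns at most one prompt per sequence: the first cleaned line that is 11-299 characters long and ends with a question mark. A made-up example:

    raw = "Here are the prompts:\n1. What is photosynthesis?\n2. Explain how tides wor"
    parse_output(raw)
    # -> ['What is photosynthesis?']
    # 'Here are the prompts:' fails the endswith('?') check; the truncated last
    # line is never even reached because we return on the first valid hit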


 def check_prompts_quality(prompts: list[str]):
Review comment: I think we should not pick the same thing twice.