
Commit

wip
andreea-popescu-reef committed Sep 5, 2024
1 parent 618250d commit 26ef742
Showing 4 changed files with 74 additions and 65 deletions.
6 changes: 2 additions & 4 deletions src/compute_horde_prompt_gen/download_model.py
@@ -4,11 +4,9 @@
     AutoModelForCausalLM,
 )
 
-from model import LLAMA3, PHI3
-
 MODEL_PATHS = {
-    LLAMA3: "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    PHI3: "microsoft/Phi-3-mini-4k-instruct",
+    "llama3": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "phi3": "microsoft/Phi-3-mini-4k-instruct",
 }
 
 if __name__ == "__main__":
107 changes: 60 additions & 47 deletions src/compute_horde_prompt_gen/model.py
@@ -3,24 +3,24 @@
 
 log = logging.getLogger(__name__)
 
-LLAMA3 = "llama3"
-PHI3 = "phi3"
 
-PROMPT_ENDING = " }}assistant"
+def strip_input(output: str, ending: str) -> str:
+    # input prompt is repeated in the output, so we need to remove it
+    idx = output.find(ending) + len(ending)
+    return output[idx:].strip()
 
 
 class MockModel:
     def __init__(self):
         pass
 
     def generate(self, prompts: list[str], num_return_sequences: int, **_kwargs):
-        content = f"COPY PASTE INPUT PROMPT {PROMPT_ENDING} Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
+        content = f"Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
        return [content for _ in range(len(prompts) * num_return_sequences)]
 
 
 class GenerativeModel:
-    def __init__(self, model_name: str, model_path: str, quantize: bool = False):
-        self.model_name = model_name
+    def __init__(self, model_path: str, quantize: bool = False):
         import torch
         from transformers import (
             AutoTokenizer,
@@ -48,11 +48,62 @@ def __init__(self, model_name: str, model_path: str, quantize: bool = False):
             model_path,
             local_files_only=True,
         )
-
+    def tokenize(self, prompts: list[str], role: str) -> str:
+        pass
+
+    def decode(self, output) -> list[str]:
+        pass
+
+    def generate(
+        self,
+        prompts: list[str],
+        role: str,
+        num_return_sequences: int,
+        max_new_tokens: int,
+        temperature: float,
+    ):
+        # encode the prompts
+        inputs = self.tokenize(prompts, role)
+
+        output = self.model.generate(
+            inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            num_return_sequences=num_return_sequences,
+            do_sample=True,  # use sampling-based decoding
+        )
+
+        return self.decode(output)
+
+
+class Phi3(GenerativeModel):
+    def decode(self, output) -> list[str]:
+        return [
+            strip_input(x, "<|assistant|>") for x in self.tokenizer.batch_decode(output)
+        ]
+
+    def tokenize_phi3(self, prompts: list[str], role: str) -> str:
+        inputs = [{"role": "user", "content": prompt} for prompt in prompts]
+        inputs = self.tokenizer.apply_chat_template(
+            inputs, add_generation_prompt=True, return_tensors="pt"
+        ).to("cuda")
+        return inputs
+
+
+class Llama3(GenerativeModel):
+    def decode(self, output) -> list[str]:
+        return [
+            strip_input(
+                self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant"
+            )
+            for x in output
+        ]
+
     def tokenize(self, prompts: list[str], role: str) -> str:
         # set default padding token
-        if self.model_name == LLAMA3:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+
+    def tokenize_llama3(self, prompts: list[str], role: str) -> str:
         role_templates = {
             "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
             "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
@@ -74,41 +125,3 @@ def tokenize(prompt: str) -> str:
         inputs = [tokenize(prompt) for prompt in prompts]
         inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
         return inputs
-
-    def tokenize_phi3(self, prompts: list[str], role: str) -> str:
-        inputs = [{"role": "user", "content": prompt} for prompt in prompts]
-        inputs = self.tokenizer.apply_chat_template(
-            inputs, add_generation_prompt=True, return_tensors="pt"
-        ).to("cuda")
-        return inputs
-
-    def generate(
-        self,
-        prompts: list[str],
-        role: str,
-        num_return_sequences: int,
-        max_new_tokens: int,
-        temperature: float,
-    ):
-        # encode the prompts
-        if self.model_name == LLAMA3:
-            inputs = self.tokenize_llama3(prompts, role)
-        elif self.model_name == PHI3:
-            inputs = self.tokenize_phi3(prompts, role)
-        else:
-            raise ValueError(f"Unknown model {self.model_name}")
-
-        output = self.model.generate(
-            inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            num_return_sequences=num_return_sequences,
-            do_sample=True,  # use sampling-based decoding
-        )
-        if self.model_name == LLAMA3:
-            return [self.tokenizer.decode(x, skip_special_tokens=True) for x in output]
-        elif self.model_name == PHI3:
-            # return "".join(self.tokenizer.batch_decode(output))
-            return self.tokenizer.batch_decode(output)
-        else:
-            raise ValueError(f"Unknown model {self.model_name}")
19 changes: 12 additions & 7 deletions src/compute_horde_prompt_gen/run.py
@@ -4,7 +4,7 @@
 import argparse
 
 from prompt import PromptGeneratingPrompt
-from model import MockModel, GenerativeModel
+from model import MockModel, Llama3, Phi3
 from utils import parse_output, append_to_file
 
 logging.basicConfig(level=logging.INFO)
@@ -149,15 +149,20 @@ def generate_prompts(
         len(uuids) == args.number_of_batches
     ), "Number of uuids should be equal to number of batches requested"
 
-    model = (
-        GenerativeModel(
-            model_name=args.model_name,
+    if args.mock_model:
+        model = MockModel()
+    elif args.model_name == "llama3":
+        model = Llama3(
             model_path=args.model_path,
             quantize=args.quantize,
         )
-        if not args.mock_model
-        else MockModel()
-    )
+    elif args.model_name == "phi3":
+        model = Phi3(
+            model_path=args.model_path,
+            quantize=args.quantize,
+        )
+    else:
+        raise ValueError(f"Invalid model name: {args.model_name}")
 
     for uuid in uuids:
         start_ts = datetime.datetime.now()
7 changes: 0 additions & 7 deletions src/compute_horde_prompt_gen/utils.py
@@ -13,14 +13,7 @@ def clean_line(line: str) -> str:
     return line
 
 
-PROMPT_ENDING = " }}assistant"
-
-
 def parse_output(output: str) -> list[str]:
-    # input prompt is repeated in the output, so we need to remove it
-    idx = output.find(PROMPT_ENDING) + len(PROMPT_ENDING)
-    output = output[idx:].strip()
-
     # split into lines and clean them
     lines = output.split("\n")
     lines = [clean_line(line) for line in lines]
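An illustrative check, not part of the commit: after this change the prompt-echo removal lives in strip_input() in model.py (see the model.py diff above), while parse_output() only splits and cleans lines. The sample string mimics the Llama3 output format with its " }}assistant" marker.

# Hypothetical check (not from the repo): strip_input() drops everything up to and
# including the given ending marker before parse_output() ever sees the text.
from model import strip_input

raw = "COPY PASTE INPUT PROMPT }}assistant Here is the list of prompts:\nHow are you?\nCount to ten\n"
assert strip_input(raw, " }}assistant") == "Here is the list of prompts:\nHow are you?\nCount to ten"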
