From 26ef74248ab168c7327946f9a7177bd04470e154 Mon Sep 17 00:00:00 2001
From: Andreea Popescu
Date: Thu, 5 Sep 2024 21:06:04 +0800
Subject: [PATCH] wip

---
 .../download_model.py                 |   6 +-
 src/compute_horde_prompt_gen/model.py | 107 ++++++++++--------
 src/compute_horde_prompt_gen/run.py   |  19 ++--
 src/compute_horde_prompt_gen/utils.py |   7 --
 4 files changed, 74 insertions(+), 65 deletions(-)

diff --git a/src/compute_horde_prompt_gen/download_model.py b/src/compute_horde_prompt_gen/download_model.py
index c354270..108ea1a 100644
--- a/src/compute_horde_prompt_gen/download_model.py
+++ b/src/compute_horde_prompt_gen/download_model.py
@@ -4,11 +4,9 @@
     AutoModelForCausalLM,
 )
 
-from model import LLAMA3, PHI3
-
 MODEL_PATHS = {
-    LLAMA3: "meta-llama/Meta-Llama-3.1-8B-Instruct",
-    PHI3: "microsoft/Phi-3-mini-4k-instruct",
+    "llama3": "meta-llama/Meta-Llama-3.1-8B-Instruct",
+    "phi3": "microsoft/Phi-3-mini-4k-instruct",
 }
 
 if __name__ == "__main__":
diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py
index 37909bf..4c5859e 100644
--- a/src/compute_horde_prompt_gen/model.py
+++ b/src/compute_horde_prompt_gen/model.py
@@ -3,10 +3,11 @@
 
 log = logging.getLogger(__name__)
 
-LLAMA3 = "llama3"
-PHI3 = "phi3"
 
-PROMPT_ENDING = " }}assistant"
+def strip_input(output: str, ending: str) -> str:
+    # input prompt is repeated in the output, so we need to remove it
+    idx = output.find(ending) + len(ending)
+    return output[idx:].strip()
 
 
 class MockModel:
@@ -14,13 +15,12 @@ def __init__(self):
         pass
 
     def generate(self, prompts: list[str], num_return_sequences: int, **_kwargs):
-        content = f"COPY PASTE INPUT PROMPT {PROMPT_ENDING} Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
+        content = "Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
         return [content for _ in range(len(prompts) * num_return_sequences)]
 
 
 class GenerativeModel:
-    def __init__(self, model_name: str, model_path: str, quantize: bool = False):
-        self.model_name = model_name
+    def __init__(self, model_path: str, quantize: bool = False):
         import torch
         from transformers import (
             AutoTokenizer,
@@ -48,11 +48,62 @@ def __init__(self, model_name: str, model_path: str, quantize: bool = False):
             model_path,
             local_files_only=True,
         )
+
+    def tokenize(self, prompts: list[str], role: str) -> str:
+        pass
+
+    def decode(self, output) -> list[str]:
+        pass
+
+    def generate(
+        self,
+        prompts: list[str],
+        role: str,
+        num_return_sequences: int,
+        max_new_tokens: int,
+        temperature: float,
+    ):
+        # encode the prompts
+        inputs = self.tokenize(prompts, role)
+
+        output = self.model.generate(
+            inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            num_return_sequences=num_return_sequences,
+            do_sample=True,  # use sampling-based decoding
+        )
+
+        return self.decode(output)
+
+
+class Phi3(GenerativeModel):
+    def decode(self, output) -> list[str]:
+        return [
+            strip_input(x, "<|assistant|>") for x in self.tokenizer.batch_decode(output)
+        ]
+
+    def tokenize(self, prompts: list[str], role: str) -> str:
+        inputs = [{"role": "user", "content": prompt} for prompt in prompts]
+        inputs = self.tokenizer.apply_chat_template(
+            inputs, add_generation_prompt=True, return_tensors="pt"
+        ).to("cuda")
+        return inputs
+
+
+class Llama3(GenerativeModel):
+    def decode(self, output) -> list[str]:
+        return [
+            strip_input(
+                self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant"
+            )
+            for x in output
+        ]
+
+    def tokenize(self, prompts: list[str], role: str) -> str:
         # set default padding token
-        if self.model_name == LLAMA3:
-            self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.tokenizer.pad_token = self.tokenizer.eos_token
 
-    def tokenize_llama3(self, prompts: list[str], role: str) -> str:
         role_templates = {
             "system": "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
             "user": "<|start_header_id|>user<|end_header_id|>\n{{{{ {} }}}}<|eot_id|>",
@@ -74,41 +125,3 @@ def tokenize(prompt: str) -> str:
         inputs = [tokenize(prompt) for prompt in prompts]
         inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
         return inputs
-
-    def tokenize_phi3(self, prompts: list[str], role: str) -> str:
-        inputs = [{"role": "user", "content": prompt} for prompt in prompts]
-        inputs = self.tokenizer.apply_chat_template(
-            inputs, add_generation_prompt=True, return_tensors="pt"
-        ).to("cuda")
-        return inputs
-
-    def generate(
-        self,
-        prompts: list[str],
-        role: str,
-        num_return_sequences: int,
-        max_new_tokens: int,
-        temperature: float,
-    ):
-        # encode the prompts
-        if self.model_name == LLAMA3:
-            inputs = self.tokenize_llama3(prompts, role)
-        elif self.model_name == PHI3:
-            inputs = self.tokenize_phi3(prompts, role)
-        else:
-            raise ValueError(f"Unknown model {self.model_name}")
-
-        output = self.model.generate(
-            inputs,
-            max_new_tokens=max_new_tokens,
-            temperature=temperature,
-            num_return_sequences=num_return_sequences,
-            do_sample=True,  # use sampling-based decoding
-        )
-        if self.model_name == LLAMA3:
-            return [self.tokenizer.decode(x, skip_special_tokens=True) for x in output]
-        elif self.model_name == PHI3:
-            # return "".join(self.tokenizer.batch_decode(output))
-            return self.tokenizer.batch_decode(output)
-        else:
-            raise ValueError(f"Unknown model {self.model_name}")
diff --git a/src/compute_horde_prompt_gen/run.py b/src/compute_horde_prompt_gen/run.py
index 286d4b4..c55d64d 100644
--- a/src/compute_horde_prompt_gen/run.py
+++ b/src/compute_horde_prompt_gen/run.py
@@ -4,7 +4,7 @@
 import argparse
 
 from prompt import PromptGeneratingPrompt
-from model import MockModel, GenerativeModel
+from model import MockModel, Llama3, Phi3
 from utils import parse_output, append_to_file
 
 logging.basicConfig(level=logging.INFO)
@@ -149,15 +149,20 @@ def generate_prompts(
         len(uuids) == args.number_of_batches
     ), "Number of uuids should be equal to number of batches requested"
 
-    model = (
-        GenerativeModel(
-            model_name=args.model_name,
+    if args.mock_model:
+        model = MockModel()
+    elif args.model_name == "llama3":
+        model = Llama3(
             model_path=args.model_path,
             quantize=args.quantize,
         )
-        if not args.mock_model
-        else MockModel()
-    )
+    elif args.model_name == "phi3":
+        model = Phi3(
+            model_path=args.model_path,
+            quantize=args.quantize,
+        )
+    else:
+        raise ValueError(f"Invalid model name: {args.model_name}")
 
     for uuid in uuids:
         start_ts = datetime.datetime.now()
diff --git a/src/compute_horde_prompt_gen/utils.py b/src/compute_horde_prompt_gen/utils.py
index 699cdee..5825cab 100644
--- a/src/compute_horde_prompt_gen/utils.py
+++ b/src/compute_horde_prompt_gen/utils.py
@@ -13,14 +13,7 @@ def clean_line(line: str) -> str:
     return line
 
 
-PROMPT_ENDING = " }}assistant"
-
-
 def parse_output(output: str) -> list[str]:
-    # input prompt is repeated in the output, so we need to remove it
-    idx = output.find(PROMPT_ENDING) + len(PROMPT_ENDING)
-    output = output[idx:].strip()
-
     # split into lines and clean them
     lines = output.split("\n")
     lines = [clean_line(line) for line in lines]
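
For reviewers: a minimal usage sketch of the interface introduced above, not part of the
patch. It mirrors the model selection in run.py; the model path, prompt text, role value
and generation parameters are illustrative placeholders, not values taken from the repository.

# Sketch only: assumes the Llama3 class and strip_input() helper from model.py
# and parse_output() from utils.py, as they appear in this patch.
from model import Llama3, strip_input  # Phi3 is driven the same way for "phi3"
from utils import parse_output

# strip_input() drops the echoed input prompt up to and including the
# model-specific marker, then trims surrounding whitespace:
assert strip_input("prompt }}assistant 1. Tell a story", " }}assistant") == "1. Tell a story"

model = Llama3(model_path="./saved_models/llama3", quantize=True)  # placeholder path
outputs = model.generate(
    prompts=["Generate a list of 5 questions"],  # placeholder prompt
    role="user",             # placeholder; run.py supplies the real role string
    num_return_sequences=1,  # placeholder
    max_new_tokens=256,      # placeholder
    temperature=1.0,         # placeholder
)
new_prompts = [line for output in outputs for line in parse_output(output)]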