diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py index 9cb1c30..36229e7 100644 --- a/src/compute_horde_prompt_gen/model.py +++ b/src/compute_horde_prompt_gen/model.py @@ -49,8 +49,8 @@ def __init__(self, model_path: str, quantize: bool = False): local_files_only=True, ) - # def tokenize(self, prompts: list[str], role: str) -> str: - # pass + def tokenize(self, prompts: list[str], role: str) -> str: + pass def decode(self, output) -> list[str]: pass @@ -76,6 +76,38 @@ def generate( return self.decode(output) + +class Phi3(GenerativeModel): + def decode(self, output) -> list[str]: + print(f"\nraw_output: {output}\n") + # return [ + # strip_input(x, "<|assistant|>") for x in self.tokenizer.batch_decode(output) + # ] + return [ + strip_input( + self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant" + ) + for x in output + ] + + def tokenize(self, prompts: list[str], role: str) -> str: + inputs = [{"role": "user", "content": prompt} for prompt in prompts] + print(f"\ninputs: {inputs}\n") + inputs = self.tokenizer.apply_chat_template( + inputs, add_generation_prompt=True, return_tensors="pt" + ).to("cuda") + return inputs + + +class Llama3(GenerativeModel): + def decode(self, output) -> list[str]: + return [ + strip_input( + self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant" + ) + for x in output + ] + def tokenize(self, prompts: list[str], role: str) -> str: # set default padding token self.tokenizer.pad_token = self.tokenizer.eos_token @@ -101,35 +133,3 @@ def tokenize(prompt: str) -> str: inputs = [tokenize(prompt) for prompt in prompts] inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda") return inputs - - -class Phi3(GenerativeModel): - def decode(self, output) -> list[str]: - print(f"\nraw_output: {output}\n") - # return [ - # strip_input(x, "<|assistant|>") for x in self.tokenizer.batch_decode(output) - # ] - return [ - strip_input( - self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant" - ) - for x in output - ] - - # def tokenize_phi3(self, prompts: list[str], role: str) -> str: - # inputs = [{"role": "user", "content": prompt} for prompt in prompts] - # print(f"\ninputs: {inputs}\n") - # inputs = self.tokenizer.apply_chat_template( - # inputs, add_generation_prompt=True, return_tensors="pt" - # ).to("cuda") - # return inputs - - -class Llama3(GenerativeModel): - def decode(self, output) -> list[str]: - return [ - strip_input( - self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant" - ) - for x in output - ]