From 618250d379875d193979d77cb70caf1e6da68bf9 Mon Sep 17 00:00:00 2001
From: Andreea Popescu
Date: Thu, 5 Sep 2024 20:46:04 +0800
Subject: [PATCH] model: fold decode() into generate() and return decoded
 strings

generate() now decodes the generated token sequences itself and returns a
list of strings, so callers no longer need a separate decode() step. The
mock model mirrors this by returning its canned prompt text directly, and
run.py parses each decoded sequence straight away.
---
 src/compute_horde_prompt_gen/model.py | 15 ++++++---------
 src/compute_horde_prompt_gen/run.py   |  7 +++----
 2 files changed, 9 insertions(+), 13 deletions(-)

diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py
index 4363190..37909bf 100644
--- a/src/compute_horde_prompt_gen/model.py
+++ b/src/compute_horde_prompt_gen/model.py
@@ -14,10 +14,8 @@ def __init__(self):
         pass
 
     def generate(self, prompts: list[str], num_return_sequences: int, **_kwargs):
-        return [1 for _ in range(len(prompts) * num_return_sequences)]
-
-    def decode(self, _output):
-        return f"COPY PASTE INPUT PROMPT {PROMPT_ENDING} Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
+        content = f"COPY PASTE INPUT PROMPT {PROMPT_ENDING} Here is the list of prompts:\nHow are you?\nDescribe something\nCount to ten\n"
+        return [content for _ in range(len(prompts) * num_return_sequences)]
 
 
 class GenerativeModel:
@@ -100,18 +98,17 @@ def generate(
         else:
             raise ValueError(f"Unknown model {self.model_name}")
 
-        return self.model.generate(
+        output = self.model.generate(
             inputs,
             max_new_tokens=max_new_tokens,
             temperature=temperature,
             num_return_sequences=num_return_sequences,
             do_sample=True,  # use sampling-based decoding
         )
-
-    def decode(self, output):
+        # decode each generated sequence so callers always get plain strings
         if self.model_name == LLAMA3:
-            return self.tokenizer.decode(output, skip_special_tokens=True)
+            return [self.tokenizer.decode(x, skip_special_tokens=True) for x in output]
         elif self.model_name == PHI3:
-            return "".join(self.tokenizer.batch_decode(output))
+            return self.tokenizer.batch_decode(output)
         else:
             raise ValueError(f"Unknown model {self.model_name}")
diff --git a/src/compute_horde_prompt_gen/run.py b/src/compute_horde_prompt_gen/run.py
index 9edc143..286d4b4 100644
--- a/src/compute_horde_prompt_gen/run.py
+++ b/src/compute_horde_prompt_gen/run.py
@@ -41,9 +41,8 @@ def generate_prompts(
         new_prompts = []
         for j, sequence in enumerate(sequences):
-            output = model.decode(sequence)
-            log.info(f"\n\n{i=} output={output}\n\n")
-            generated_prompts = parse_output(output)
+            # generate() already returns decoded text, so parse it directly
+            generated_prompts = parse_output(sequence)
 
-            log.debug(f"{i=} sequence={j} {generated_prompts=} from {output=}")
+            log.debug(f"{i=} sequence={j} {generated_prompts=} from {sequence=}")
             log.info(f"{i=} sequence={j} generated {len(generated_prompts)} prompts")
 
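
Note on the resulting caller contract: generate() now returns a flat list of
len(prompts) * num_return_sequences decoded strings. A minimal sketch of how
a caller might regroup that flat list per input prompt follows
(group_sequences() is a hypothetical helper, not part of this patch; the
model instance and parse_output() are assumed to behave as in run.py):

    def group_sequences(outputs: list[str], n: int) -> list[list[str]]:
        # slice the flat list into consecutive chunks of n sequences,
        # one chunk per input prompt
        return [outputs[i : i + n] for i in range(0, len(outputs), n)]

    sequences = model.generate(
        prompts,
        num_return_sequences=5,
        max_new_tokens=40,
        temperature=1.0,
    )
    for per_prompt in group_sequences(sequences, 5):
        for text in per_prompt:
            generated_prompts = parse_output(text)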