From 735a91a890c56d07036deb4cb477dd7c56d22b90 Mon Sep 17 00:00:00 2001 From: Andreea Popescu Date: Thu, 5 Sep 2024 20:35:44 +0800 Subject: [PATCH] Dispatch decode by model: llama3 uses tokenizer.decode, phi3 uses batch_decode --- src/compute_horde_prompt_gen/model.py | 7 ++++++- src/compute_horde_prompt_gen/run.py | 2 +- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py index 81d4346..e8c1cb3 100644 --- a/src/compute_horde_prompt_gen/model.py +++ b/src/compute_horde_prompt_gen/model.py @@ -109,4 +109,9 @@ def generate( ) def decode(self, output): - return self.tokenizer.decode(output, skip_special_tokens=True) + if self.model_name == LLAMA3: + return self.tokenizer.decode(output, skip_special_tokens=True) + elif self.model_name == PHI3: + return self.tokenizer.batch_decode(output) + else: + raise ValueError(f"Unknown model {self.model_name}") diff --git a/src/compute_horde_prompt_gen/run.py b/src/compute_horde_prompt_gen/run.py index a80026a..9edc143 100644 --- a/src/compute_horde_prompt_gen/run.py +++ b/src/compute_horde_prompt_gen/run.py @@ -42,7 +42,7 @@ def generate_prompts( new_prompts = [] for j, sequence in enumerate(sequences): output = model.decode(sequence) - log.info(f"{i=} output={output}") + log.info(f"\n\n{i=} output={output}\n\n") generated_prompts = parse_output(output) log.debug(f"{i=} sequence={j} {generated_prompts=} from {output=}")