From f8d2ba9c1d2fff41824685446d749bb2f2173723 Mon Sep 17 00:00:00 2001
From: Andreea Popescu
Date: Thu, 5 Sep 2024 20:16:08 +0800
Subject: [PATCH] Fix phi3 apply_chat_template call and move tokenized inputs
 to CUDA

---
 src/compute_horde_prompt_gen/model.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/compute_horde_prompt_gen/model.py b/src/compute_horde_prompt_gen/model.py
index 08a8641..c4c2b5b 100644
--- a/src/compute_horde_prompt_gen/model.py
+++ b/src/compute_horde_prompt_gen/model.py
@@ -80,8 +80,8 @@ def tokenize(prompt: str) -> str:
     def tokenize_phi3(self, prompts: list[str], role: str) -> str:
         inputs = [{"role": role, "content": prompt} for prompt in prompts]
         inputs = self.tokenizer.apply_chat_template(
-            **inputs, add_generation_prompt=True, return_tensors="pt"
-        )
+            inputs, add_generation_prompt=True, return_tensors="pt"
+        ).to("cuda")
         return inputs
 
     def generate(
@@ -99,6 +99,7 @@ def generate(
             inputs = self.tokenize_phi3(prompts, role)
         else:
             raise ValueError(f"Unknown model {self.model_name}")
+        print(inputs)
         return self.model.generate(
             **inputs,