Commit 203bef9

wip
andreea-popescu-reef committed Sep 6, 2024
1 parent 05865a6 commit 203bef9
Showing 1 changed file with 34 additions and 34 deletions.
68 changes: 34 additions & 34 deletions src/compute_horde_prompt_gen/model.py
@@ -49,8 +49,8 @@ def __init__(self, model_path: str, quantize: bool = False):
             local_files_only=True,
         )
 
-    # def tokenize(self, prompts: list[str], role: str) -> str:
-    #     pass
+    def tokenize(self, prompts: list[str], role: str) -> str:
+        pass
 
     def decode(self, output) -> list[str]:
         pass
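
With this change, `tokenize` is a real method on the base class rather than a commented-out stub, but both it and `decode` still only `pass`, so a subclass that forgets to override them silently returns `None` through `generate`. A minimal stricter sketch, assuming the rest of `GenerativeModel` stays as shown (raising `NotImplementedError` is a suggestion, not what this commit does):

    class GenerativeModel:
        # Llama3 and Phi3 must override both methods; raising here makes a
        # missing override fail loudly instead of propagating None.
        def tokenize(self, prompts: list[str], role: str):
            raise NotImplementedError

        def decode(self, output) -> list[str]:
            raise NotImplementedError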
@@ -76,6 +76,38 @@ def generate(
 
         return self.decode(output)
 
+
+class Phi3(GenerativeModel):
+    def decode(self, output) -> list[str]:
+        print(f"\nraw_output: {output}\n")
+        # return [
+        #     strip_input(x, "<|assistant|>") for x in self.tokenizer.batch_decode(output)
+        # ]
+        return [
+            strip_input(
+                self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant"
+            )
+            for x in output
+        ]
+
+    def tokenize(self, prompts: list[str], role: str) -> str:
+        inputs = [{"role": "user", "content": prompt} for prompt in prompts]
+        print(f"\ninputs: {inputs}\n")
+        inputs = self.tokenizer.apply_chat_template(
+            inputs, add_generation_prompt=True, return_tensors="pt"
+        ).to("cuda")
+        return inputs
+
+
 class Llama3(GenerativeModel):
+    def decode(self, output) -> list[str]:
+        return [
+            strip_input(
+                self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant"
+            )
+            for x in output
+        ]
+
     def tokenize(self, prompts: list[str], role: str) -> str:
         # set default padding token
         self.tokenizer.pad_token = self.tokenizer.eos_token
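
One caveat on the new `Phi3.tokenize`: Hugging Face's `apply_chat_template` treats a flat list of `{"role", "content"}` dicts as a single multi-turn conversation, so building one user message per prompt concatenates all prompts into one chat rather than producing a batch. A hedged sketch of per-prompt batching, assuming `self.tokenizer` is a recent transformers tokenizer that accepts a list of conversations:

    def tokenize(self, prompts: list[str], role: str):
        # One single-turn conversation per prompt yields a padded batch,
        # rather than one long multi-turn chat.
        conversations = [[{"role": "user", "content": p}] for p in prompts]
        return self.tokenizer.apply_chat_template(
            conversations,
            add_generation_prompt=True,
            padding=True,
            return_tensors="pt",
        ).to("cuda")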
@@ -101,35 +133,3 @@ def tokenize(prompt: str) -> str:
         inputs = [tokenize(prompt) for prompt in prompts]
         inputs = self.tokenizer(inputs, return_tensors="pt", padding=True).to("cuda")
         return inputs
-
-
-class Phi3(GenerativeModel):
-    def decode(self, output) -> list[str]:
-        print(f"\nraw_output: {output}\n")
-        # return [
-        #     strip_input(x, "<|assistant|>") for x in self.tokenizer.batch_decode(output)
-        # ]
-        return [
-            strip_input(
-                self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant"
-            )
-            for x in output
-        ]
-
-    # def tokenize_phi3(self, prompts: list[str], role: str) -> str:
-    #     inputs = [{"role": "user", "content": prompt} for prompt in prompts]
-    #     print(f"\ninputs: {inputs}\n")
-    #     inputs = self.tokenizer.apply_chat_template(
-    #         inputs, add_generation_prompt=True, return_tensors="pt"
-    #     ).to("cuda")
-    #     return inputs
-
-
-class Llama3(GenerativeModel):
-    def decode(self, output) -> list[str]:
-        return [
-            strip_input(
-                self.tokenizer.decode(x, skip_special_tokens=True), " }}assistant"
-            )
-            for x in output
-        ]

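Net effect of the move: the old layout appears to have defined `Llama3` twice, once with `tokenize` and again after `Phi3` with only `decode`, so the later definition would shadow the earlier one and drop its `tokenize`. After this commit each class is defined once with both methods. A hypothetical usage sketch (the `generate` signature is only partially visible in this diff, so its arguments here are assumptions):

    # Hypothetical driver; the model path and generate() arguments are assumptions.
    model = Llama3(model_path="./saved_models/llama3", quantize=True)
    outputs = model.generate(["Write 5 creative prompt ideas."])  # tokenize -> forward -> decode
    for text in outputs:
        print(text)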