Commit

fix previous commit

ostix360 committed May 22, 2024
1 parent 39cf69a commit a951ed7
Showing 1 changed file with 21 additions and 32 deletions.
53 changes: 21 additions & 32 deletions dataset/llama3_dataset.py
@@ -28,7 +28,6 @@
 import random
 import typing
 
-import torch
 from random_word import RandomWords
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
@@ -37,11 +36,6 @@
 
 
 def _setup_model() -> typing.Tuple[AutoModelForCausalLM, AutoTokenizer]:
-    config = BitsAndBytesConfig(
-        load_in_4bit=True,
-        bnb_4bit_quant_type="nf4",
-        bnb_4bit_use_double_quant=True,
-        bnb_4bit_compute_dtype=torch.bfloat16,
     model = AutoModelForCausalLM.from_pretrained(
         "unsloth/llama-3-8b-Instruct-bnb-4bit"
     )
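
Note on the hunk above: the removed BitsAndBytesConfig(...) call was never closed or passed to from_pretrained — that dangling parenthesis is the syntax error this commit fixes — and it is redundant anyway, since the unsloth/llama-3-8b-Instruct-bnb-4bit checkpoint already ships pre-quantized with its 4-bit settings in its own config. A minimal sketch of the resulting setup (the device_map argument and the tokenizer/return lines are assumptions inferred from the function signature, not shown in the diff):

import typing

from transformers import AutoModelForCausalLM, AutoTokenizer


def _setup_model() -> typing.Tuple[AutoModelForCausalLM, AutoTokenizer]:
    # No BitsAndBytesConfig needed: the checkpoint is pre-quantized and
    # carries its quantization settings in the checkpoint config itself.
    model = AutoModelForCausalLM.from_pretrained(
        "unsloth/llama-3-8b-Instruct-bnb-4bit",
        device_map="auto",  # assumption: not visible in the diff; requires accelerate
    )
    tokenizer = AutoTokenizer.from_pretrained("unsloth/llama-3-8b-Instruct-bnb-4bit")
    return model, tokenizer  # assumption: inferred from the return annotation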
@@ -114,12 +108,9 @@ def generate_dataset(self, nb_gen: int = 100) -> None:
             tags += f"[{tag.value} N]\n"
 
         # generate the number of lyrics specified
-        for i in range(1000,1100):
+        for i in range(nb_gen):
             # add word to obtain different lyrics each time
             word = " ".join(self._word_generator.get_random_word() for _ in range(1))
-            nb_chorus = random.randint(2, 4)
-            nb_verse = random.randint(2, 4)
-            nb_bridge = random.randint(0, 1)
             messages = [
                 {
                     "role": "system",
@@ -130,36 +121,34 @@ def generate_dataset(self, nb_gen: int = 100) -> None:
                     "content": self.PROMPT.format(
                         TAGS=tags,
                         WORD=word,
-                        NB_CHORUS=nb_chorus,
-                        NB_VERSE=nb_verse,
-                        NB_BRIDGE=nb_bridge
+                        NB_CHORUS=random.randint(2, 4),
+                        NB_VERSE=random.randint(2, 4),
+                        NB_BRIDGE=random.randint(0, 1)
                     ),
                 },
             ]
 
-            # prompt = tokenizer.apply_chat_template(
-            #     messages, tokenize=False, add_generation_prompt=True
-            # )
-            #
-            # input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
+            prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
 
-            # output = model.generate(
-            #     input_ids,
-            #     max_length=500,
-            #     num_return_sequences=1,
-            #     do_sample=True,
-            #     temperature=0.9,
-            # )
-            # generated_text = tokenizer.decode(output[0], skip_special_tokens=False)
+            input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
-            generated_text = model.create_chat_completion(messages)["choices"][0]["message"]["content"]
+            output = model.generate(
+                input_ids,
+                max_length=500,
+                num_return_sequences=1,
+                do_sample=True,
+                temperature=0.9,
+            )
+            generated_text = tokenizer.decode(output[0], skip_special_tokens=False)
 
             # remove the prompt and the special tokens
-            # generated_text = (
-            #     generated_text.replace(prompt, "")
-            #     .replace("<|begin_of_text|>", "")
-            #     .replace("<|eot_id|>", "")
-            # )
+            generated_text = (
+                generated_text.replace(prompt, "")
+                .replace("<|begin_of_text|>", "")
+                .replace("<|eot_id|>", "")
+            )
 
             if not os.path.exists(self._path):
                 os.makedirs(self._path)
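
The generation hunk reverts the previous commit's half-switch to the llama-cpp-python API (model.create_chat_completion returns an OpenAI-style dict, which a transformers model does not provide) and restores the transformers path: render the chat template to a string, encode, generate, decode, then strip the prompt echo and the Llama 3 special tokens by string replacement. For comparison, the same cleanup can be done by slicing off the prompt tokens before decoding; the following is an alternative sketch under that assumption, not what the commit does:

# Alternative sketch (not the commit's approach): tokenize the chat
# template directly, then strip the prompt by token count instead of
# string replacement. Assumes `model`, `tokenizer`, and `messages` as above.
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)

output = model.generate(
    input_ids,
    max_new_tokens=500,  # caps generated tokens only; max_length also counts the prompt
    do_sample=True,
    temperature=0.9,
)

# Everything past the prompt length is newly generated; skipping special
# tokens here drops <|begin_of_text|> and <|eot_id|> automatically.
generated_text = tokenizer.decode(
    output[0][input_ids.shape[-1]:], skip_special_tokens=True
)

Slicing by token count also avoids the edge case where the decoded output does not reproduce the prompt string exactly, which would make the replace(prompt, "") call a no-op.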
