Skip to content

Commit

Permalink
add batching dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
ostix360 committed May 31, 2024
1 parent f0b1c66 commit e183eac
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 8 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ dataset/audio/
dataset/lyrics/
dataset/data.json
train/
model/
formated_dataset/

test.py
8 changes: 5 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

# Entry-point script: gather (or reload) the dataset, preprocess it in
# fixed-size chunks, train, and persist the model + processor.
if LOAD_DATASET:
    # Fresh run: collect raw audio/lyrics pairs and hold out 10% for eval.
    dataset = utils.gather_dataset("./dataset")
    dataset = dataset.train_test_split(test_size=0.1)
else:
    # Resume from an already-formatted dataset on disk.
    dataset = DatasetDict.load_from_disk("./formated_dataset")

trainer = Trainer(dataset)
if LOAD_DATASET:
    # Preprocess in chunks of 1000 rows to bound memory usage.
    # NOTE(review): after train_test_split, `dataset` is a DatasetDict and
    # `num_rows` may be a per-split mapping rather than an int — confirm
    # which split this chunk count is meant to cover.
    # Ceil division: the original `num_rows // 1000` silently dropped the
    # final partial chunk.
    num_chunks = (dataset.num_rows + 999) // 1000
    for i in range(num_chunks):
        # BUG FIX: the original rebound `dataset` to the processed chunk
        # (`dataset = trainer.process_dataset(dataset, i)`), so every
        # iteration after the first sliced the already-processed chunk
        # instead of the full dataset. Bind the result to a new name.
        chunk = trainer.process_dataset(dataset, i)
        chunk.save_to_disk(f"./formated_dataset_{i}")

trainer.train()
trainer.model.save_pretrained("./model")
trainer.processor.save_pretrained("./model")
10 changes: 5 additions & 5 deletions training/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, dataset=None, model_name="openai/whisper-small", ):
self.model = WhisperForConditionalGeneration.from_pretrained(model_name)
self.dataset = dataset
self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor)
self.prepare_tokenizer()
# self.prepare_tokenizer()

def prepare_tokenizer(self) -> None:
"""
Expand All @@ -51,7 +51,7 @@ def prepare_tokenizer(self) -> None:
self.processor.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
self.model.resize_token_embeddings(len(self.processor.tokenizer))

def process_dataset(self, dataset) -> Dataset:
def process_dataset(self, dataset, chunk_id) -> Dataset:
"""
A method that processes the dataset.
:return: None
Expand All @@ -78,10 +78,10 @@ def prepare_dataset(example):
example["input_length"] = len(audio) / sr

return example

self.dataset = dataset.map(
small_dataset = Dataset.from_dict(dataset[chunk_id*1000:chunk_id*1000+1000])
self.dataset = small_dataset.map(
prepare_dataset,
remove_columns=dataset.column_names["train"],
remove_columns=small_dataset.column_names,
num_proc=1,
)
return self.dataset
Expand Down

0 comments on commit e183eac

Please sign in to comment.