diff --git a/.gitignore b/.gitignore index c368e9b..548b01d 100644 --- a/.gitignore +++ b/.gitignore @@ -163,6 +163,7 @@ dataset/audio/ dataset/lyrics/ dataset/data.json train/ +model/ formated_dataset/ test.py \ No newline at end of file diff --git a/train.py b/train.py index 7ed6691..8332e62 100644 --- a/train.py +++ b/train.py @@ -8,12 +8,14 @@ if LOAD_DATASET: dataset = utils.gather_dataset("./dataset") - dataset = dataset.train_test_split(test_size=0.1) else: dataset = DatasetDict.load_from_disk("./formated_dataset") trainer = Trainer(dataset) if LOAD_DATASET: - dataset = trainer.process_dataset(dataset) - dataset.save_to_disk("./formated_dataset") + for i in range((dataset.num_rows + 999)//1000): + processed = trainer.process_dataset(dataset, i) + processed.save_to_disk(f"./formated_dataset_{i}") trainer.train() +trainer.model.save_pretrained("./model") +trainer.processor.save_pretrained("./model") diff --git a/training/train.py b/training/train.py index 2e5f6b9..6601495 100644 --- a/training/train.py +++ b/training/train.py @@ -36,7 +36,7 @@ def __init__(self, dataset=None, model_name="openai/whisper-small", ): self.model = WhisperForConditionalGeneration.from_pretrained(model_name) self.dataset = dataset self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor) - self.prepare_tokenizer() + # self.prepare_tokenizer() def prepare_tokenizer(self) -> None: """ 
:return: None @@ -78,10 +78,10 @@ def prepare_dataset(example): example["input_length"] = len(audio) / sr return example - - self.dataset = dataset.map( + small_dataset = Dataset.from_dict(dataset[chunk_id*1000:chunk_id*1000+1000]) + self.dataset = small_dataset.map( prepare_dataset, - remove_columns=dataset.column_names["train"], + remove_columns=small_dataset.column_names, num_proc=1, ) return self.dataset