diff --git a/dataset/process.py b/dataset/process.py
index a2249a8..4009bd7 100644
--- a/dataset/process.py
+++ b/dataset/process.py
@@ -3,6 +3,7 @@ from typing import List
 
 from pydub import AudioSegment
+from tqdm import tqdm
 
 import dataset.exceptions
 from dataset.aeneas_wrapper import AeneasWrapper
 
@@ -160,6 +161,7 @@ def process(self, remove: bool = False) -> None:
         """
         nbm_files = len(os.listdir(self.audio_path))
+        progress_bar = tqdm(total=nbm_files)
 
         for i, audio_f in enumerate(os.listdir(self.audio_path)):
             if not audio_f.endswith(".ogg") and not audio_f.endswith(".mp4"):
                 continue
@@ -180,10 +182,13 @@ def process(self, remove: bool = False) -> None:
             self._export_audio(audio_segments, audio_f.split(".")[0])
             self._export_lyric(lyric_segments, audio_f.split(".")[0])
 
-            print(
-                f"Processed {i}/ {nbm_files} - {round(i/nbm_files*100, 2)}%", end="\r"
-            )
+            # print(
+            #     f"Processed {i}/ {nbm_files} - {round(i/nbm_files*100, 2)}%", end="\r"
+            # )
 
             if remove:
                 os.remove(lyric_path)
                 os.remove(audio_path)
+
+            progress_bar.update(1)
+        progress_bar.close()
diff --git a/train.py b/train.py
index f0dddf6..fcc5680 100644
--- a/train.py
+++ b/train.py
@@ -1,9 +1,9 @@
-from datasets import concatenate_datasets, Dataset
+import argparse
+
+from datasets import Dataset
 
 from training import utils
 from training.train import Trainer
-import argparse
-import glob
 
 parser = argparse.ArgumentParser(
     description="Process the dataset and train the model",
@@ -51,4 +51,4 @@ print(dataset)
 
 trainer.train()
 
-trainer.save_model(args.model_path)
+trainer.model.save_pretrained(args.model_path)
diff --git a/train2.py b/train2.py
index 8a20ff3..a5ebfb9 100644
--- a/train2.py
+++ b/train2.py
@@ -1,19 +1,18 @@
 import librosa
-import numpy as np
-from datasets import Audio, DatasetDict, load_from_disk
+from datasets import load_from_disk
 
 from training import utils
 from training.train import Trainer
 
 
-DS_PATH = "dataset/"
+DS_PATH = "dataset/export"
 
-dataset = utils.gather_dataset(DS_PATH)
 trainer = Trainer()
 
 
 is_prepared = False
 if not is_prepared:
+    dataset = utils.gather_dataset(DS_PATH)
     target_sr = trainer.processor.feature_extractor.sampling_rate
 
     def prepare_dataset(batch):
diff --git a/training/train.py b/training/train.py
index 812ede5..4916fc0 100644
--- a/training/train.py
+++ b/training/train.py
@@ -1,6 +1,7 @@
 """
 This module contains the Trainer class which is responsible for training whisper on predicting lyrics.
 """
+from functools import partial
 
 import evaluate
 from transformers import (
@@ -79,28 +80,30 @@ def train(self, dataset):
         """
         :return:
         """
+        self.model.generate = partial(
+            self.model.generate, task="transcribe", use_cache=True
+        )
         training_args = Seq2SeqTrainingArguments(
             output_dir=self._ouput_dir,  # change to a repo name of your choice
             per_device_train_batch_size=8,
             gradient_accumulation_steps=1,  # increase by 2x for every 2x decrease in batch size
             learning_rate=1e-5,
-            warmup_steps=500,
-            max_steps=4000,
-            gradient_checkpointing=True,
+            warmup_steps=50,
+            max_steps=8000,
+            gradient_checkpointing=False,
             fp16=True,
             eval_strategy="steps",
             per_device_eval_batch_size=8,
             predict_with_generate=True,
             generation_max_length=225,
-            save_steps=80,
-            eval_steps=80,
+            save_steps=1000,
+            eval_steps=1000,
             logging_steps=25,
             report_to=["tensorboard"],
-            load_best_model_at_end=True,
+            # load_best_model_at_end=True,
             metric_for_best_model="wer",
             greater_is_better=False,
             push_to_hub=False,
-            gradient_checkpointing_kwargs={"use_reentrant": False},
         )
 
         trainer = Seq2SeqTrainer(
diff --git a/training/utils.py b/training/utils.py
index 00392fa..070f141 100644
--- a/training/utils.py
+++ b/training/utils.py
@@ -18,6 +18,7 @@ def gather_dataset(path: str) -> Dataset:
     """
 
     def gen():
+        i = 2
        audios = glob.glob(path + "/audio/*")
        lyrics = glob.glob(path + "/lyrics/*.txt")
        for audio, lyric in zip(audios, lyrics):