Skip to content

Commit

Permalink
better hyperparams
Browse files Browse the repository at this point in the history
add progress bar to process
keep the unused variable to process the data!!
  • Loading branch information
ostix360 committed Jun 9, 2024
1 parent 1f71ad5 commit a61028d
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 18 deletions.
11 changes: 8 additions & 3 deletions dataset/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import List

from pydub import AudioSegment
from tqdm import tqdm

import dataset.exceptions
from dataset.aeneas_wrapper import AeneasWrapper
Expand Down Expand Up @@ -160,6 +161,7 @@ def process(self, remove: bool = False) -> None:
"""

nbm_files = len(os.listdir(self.audio_path))
progress_bar = tqdm(total=nbm_files)
for i, audio_f in enumerate(os.listdir(self.audio_path)):
if not audio_f.endswith(".ogg") and not audio_f.endswith(".mp4"):
continue
Expand All @@ -180,10 +182,13 @@ def process(self, remove: bool = False) -> None:
self._export_audio(audio_segments, audio_f.split(".")[0])
self._export_lyric(lyric_segments, audio_f.split(".")[0])

print(
f"Processed {i}/ {nbm_files} - {round(i/nbm_files*100, 2)}%", end="\r"
)
# print(
# f"Processed {i}/ {nbm_files} - {round(i/nbm_files*100, 2)}%", end="\r"
# )

if remove:
os.remove(lyric_path)
os.remove(audio_path)

progress_bar.update(1)
progress_bar.close()
8 changes: 4 additions & 4 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from datasets import concatenate_datasets, Dataset
import argparse

from datasets import Dataset

from training import utils
from training.train import Trainer
import argparse
import glob

parser = argparse.ArgumentParser(
description="Process the dataset and train the model",
Expand Down Expand Up @@ -51,4 +51,4 @@
print(dataset)

trainer.train()
trainer.save_model(args.model_path)
trainer.model.save_pretrained(args.model_path)
7 changes: 3 additions & 4 deletions train2.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,18 @@
import librosa
import numpy as np
from datasets import Audio, DatasetDict, load_from_disk
from datasets import load_from_disk

from training import utils
from training.train import Trainer

DS_PATH = "dataset/"
DS_PATH = "dataset/export"

dataset = utils.gather_dataset(DS_PATH)
trainer = Trainer()

is_prepared = False


if not is_prepared:
dataset = utils.gather_dataset(DS_PATH)
target_sr = trainer.processor.feature_extractor.sampling_rate

def prepare_dataset(batch):
Expand Down
17 changes: 10 additions & 7 deletions training/train.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
This module contains the Trainer class which is responsible for training whisper on predicting lyrics.
"""
from functools import partial

import evaluate
from transformers import (
Expand Down Expand Up @@ -79,28 +80,30 @@ def train(self, dataset):
:return:
"""

self.model.generate = partial(
self.model.generate, task="transcribe", use_cache=True
)
training_args = Seq2SeqTrainingArguments(
output_dir=self._ouput_dir, # change to a repo name of your choice
per_device_train_batch_size=8,
gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size
learning_rate=1e-5,
warmup_steps=500,
max_steps=4000,
gradient_checkpointing=True,
warmup_steps=50,
max_steps=8000,
gradient_checkpointing=False,
fp16=True,
eval_strategy="steps",
per_device_eval_batch_size=8,
predict_with_generate=True,
generation_max_length=225,
save_steps=80,
eval_steps=80,
save_steps=1000,
eval_steps=1000,
logging_steps=25,
report_to=["tensorboard"],
load_best_model_at_end=True,
# load_best_model_at_end=True,
metric_for_best_model="wer",
greater_is_better=False,
push_to_hub=False,
gradient_checkpointing_kwargs={"use_reentrant": False},
)

trainer = Seq2SeqTrainer(
Expand Down
1 change: 1 addition & 0 deletions training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def gather_dataset(path: str) -> Dataset:
"""

def gen():
i = 2
audios = glob.glob(path + "/audio/*")
lyrics = glob.glob(path + "/lyrics/*.txt")
for audio, lyric in zip(audios, lyrics):
Expand Down

0 comments on commit a61028d

Please sign in to comment.