From 99fc968db332c4f663a546a2d6e0956adad9243a Mon Sep 17 00:00:00 2001 From: Jourdelune Date: Sat, 8 Jun 2024 19:04:01 +0200 Subject: [PATCH] [update] c/v example for finetune whisper --- .gitignore | 5 +- run_speech_recognition_seq2seq.py | 741 ++++++++++++++++++++++++++++++ train.py | 26 +- train2.py | 56 +++ training/collator.py | 23 +- training/train.py | 211 +++------ 6 files changed, 879 insertions(+), 183 deletions(-) create mode 100644 run_speech_recognition_seq2seq.py create mode 100644 train2.py diff --git a/.gitignore b/.gitignore index 8875ff5..c086132 100644 --- a/.gitignore +++ b/.gitignore @@ -166,4 +166,7 @@ formated_dataset/ test.py dataset/ -save.py \ No newline at end of file +save.py + +whisper-finetuned/ +whisper-small-hi/ \ No newline at end of file diff --git a/run_speech_recognition_seq2seq.py b/run_speech_recognition_seq2seq.py new file mode 100644 index 0000000..cb4bc69 --- /dev/null +++ b/run_speech_recognition_seq2seq.py @@ -0,0 +1,741 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Fine-tuning the library models for sequence to sequence speech recognition. +""" +# You can also adapt this script on your own sequence to sequence speech +# recognition task. Pointers for this are left as comments. + +import logging +import os +import sys +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Union + +import datasets +import evaluate +import torch +from datasets import DatasetDict, load_dataset + +import transformers +from transformers import ( + AutoConfig, + AutoFeatureExtractor, + AutoModelForSpeechSeq2Seq, + AutoProcessor, + AutoTokenizer, + HfArgumentParser, + Seq2SeqTrainer, + Seq2SeqTrainingArguments, + set_seed, +) +from transformers.trainer_utils import get_last_checkpoint, is_main_process +from transformers.utils import check_min_version, send_example_telemetry +from transformers.utils.versions import require_version + + +require_version( + "datasets>=1.18.0", + "To fix: pip install -r examples/pytorch/speech-recognition/requirements.txt", +) + +logger = logging.getLogger(__name__) + + +@dataclass +class ModelArguments: + """ + Arguments pertaining to which model/config/tokenizer we are going to fine-tune from. 
+ """ + + model_name_or_path: str = field( + metadata={ + "help": "Path to pretrained model or model identifier from huggingface.co/models" + } + ) + config_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained config name or path if not the same as model_name" + }, + ) + tokenizer_name: Optional[str] = field( + default=None, + metadata={ + "help": "Pretrained tokenizer name or path if not the same as model_name" + }, + ) + feature_extractor_name: Optional[str] = field( + default=None, + metadata={ + "help": "feature extractor name or path if not the same as model_name" + }, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={ + "help": "Where to store the pretrained models downloaded from huggingface.co" + }, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={ + "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not." + }, + ) + model_revision: str = field( + default="main", + metadata={ + "help": "The specific model version to use (can be a branch name, tag name or commit id)." + }, + ) + token: str = field( + default=None, + metadata={ + "help": ( + "The token to use as HTTP bearer authorization for remote files. If not specified, will use the token " + "generated when running `huggingface-cli login` (stored in `~/.huggingface`)." + ) + }, + ) + trust_remote_code: bool = field( + default=False, + metadata={ + "help": ( + "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option " + "should only be set to `True` for repositories you trust and in which you have read the code, as it will " + "execute code present on the Hub on your local machine." + ) + }, + ) + freeze_feature_encoder: bool = field( + default=True, + metadata={"help": "Whether to freeze the feature encoder layers of the model."}, + ) + freeze_encoder: bool = field( + default=False, + metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}, + ) + forced_decoder_ids: List[List[int]] = field( + default=None, + metadata={ + "help": "Deprecated. Please use the `language` and `task` arguments instead." + }, + ) + suppress_tokens: List[int] = field( + default=None, + metadata={ + "help": ( + "Deprecated. The use of `suppress_tokens` should not be required for the majority of fine-tuning examples." + "Should you need to use `suppress_tokens`, please manually update them in the fine-tuning script directly." + ) + }, + ) + apply_spec_augment: bool = field( + default=False, + metadata={ + "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models." + }, + ) + + +@dataclass +class DataTrainingArguments: + """ + Arguments pertaining to what data we are going to input our model for training and eval. + """ + + dataset_name: str = field( + default=None, + metadata={"help": "The name of the dataset to use (via the datasets library)."}, + ) + dataset_config_name: Optional[str] = field( + default=None, + metadata={ + "help": "The configuration name of the dataset to use (via the datasets library)." 
+ }, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets"}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the preprocessing."}, + ) + max_train_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ) + }, + ) + max_eval_samples: Optional[int] = field( + default=None, + metadata={ + "help": ( + "For debugging purposes or quicker training, truncate the number of evaluation examples to this " + "value if set." + ) + }, + ) + audio_column_name: str = field( + default="audio", + metadata={ + "help": "The name of the dataset column containing the audio data. Defaults to 'audio'" + }, + ) + text_column_name: str = field( + default="text", + metadata={ + "help": "The name of the dataset column containing the text data. Defaults to 'text'" + }, + ) + max_duration_in_seconds: float = field( + default=20.0, + metadata={ + "help": ( + "Truncate audio files that are longer than `max_duration_in_seconds` seconds to" + " 'max_duration_in_seconds`" + ) + }, + ) + min_duration_in_seconds: float = field( + default=0.0, + metadata={ + "help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds" + }, + ) + preprocessing_only: bool = field( + default=False, + metadata={ + "help": ( + "Whether to only do data preprocessing and skip training. This is especially useful when data" + " preprocessing errors out in distributed training due to timeout. In this case, one should run the" + " preprocessing in a non-distributed setup with `preprocessing_only=True` so that the cached datasets" + " can consequently be loaded in distributed training" + ) + }, + ) + train_split_name: str = field( + default="train", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + eval_split_name: str = field( + default="test", + metadata={ + "help": "The name of the training data set split to use (via the datasets library). Defaults to 'train'" + }, + ) + do_lower_case: bool = field( + default=True, + metadata={"help": "Whether the target text should be lower cased."}, + ) + language: str = field( + default=None, + metadata={ + "help": ( + "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning " + "only. For English speech recognition, it should be set to `None`." + ) + }, + ) + task: str = field( + default="transcribe", + metadata={ + "help": "Task, either `transcribe` for speech recognition or `translate` for speech translation." + }, + ) + + +@dataclass +class DataCollatorSpeechSeq2SeqWithPadding: + """ + Data collator that will dynamically pad the inputs received. + Args: + processor ([`WhisperProcessor`]) + The processor used for processing the data. + decoder_start_token_id (`int`) + The begin-of-sentence of the decoder. + forward_attention_mask (`bool`) + Whether to return attention_mask. 
+ """ + + processor: Any + decoder_start_token_id: int + forward_attention_mask: bool + + def __call__( + self, features: List[Dict[str, Union[List[int], torch.Tensor]]] + ) -> Dict[str, torch.Tensor]: + # split inputs and labels since they have to be of different lengths and need + # different padding methods + model_input_name = self.processor.model_input_names[0] + input_features = [ + {model_input_name: feature[model_input_name]} for feature in features + ] + label_features = [{"input_ids": feature["labels"]} for feature in features] + + batch = self.processor.feature_extractor.pad( + input_features, return_tensors="pt" + ) + + if self.forward_attention_mask: + batch["attention_mask"] = torch.LongTensor( + [feature["attention_mask"] for feature in features] + ) + + labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt") + + # replace padding with -100 to ignore loss correctly + labels = labels_batch["input_ids"].masked_fill( + labels_batch.attention_mask.ne(1), -100 + ) + + # if bos token is appended in previous tokenization step, + # cut bos token here as it's append later anyways + if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item(): + labels = labels[:, 1:] + + batch["labels"] = labels + + return batch + + +def main(): + # 1. Parse input arguments + # See all possible arguments in src/transformers/training_args.py + # or by passing the --help flag to this script. + # We now keep distinct sets of args, for a cleaner separation of concerns. + parser = HfArgumentParser( + (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments) + ) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + # If we pass only one argument to the script and it's the path to a json file, + # let's parse it to get our arguments. + model_args, data_args, training_args = parser.parse_json_file( + json_file=os.path.abspath(sys.argv[1]) + ) + else: + model_args, data_args, training_args = parser.parse_args_into_dataclasses() + + # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The + # information sent is the one passed as arguments along with your Python/PyTorch versions. + send_example_telemetry("run_speech_recognition_seq2seq", model_args, data_args) + + # 2. Setup logging + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + handlers=[logging.StreamHandler(sys.stdout)], + ) + log_level = training_args.get_process_log_level() + logger.setLevel(log_level) + datasets.utils.logging.set_verbosity(log_level) + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + logger.setLevel( + logging.INFO if is_main_process(training_args.local_rank) else logging.WARN + ) + + # Log on each process the small summary: + logger.warning( + f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, " + f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}" + ) + logger.info(f"Training/evaluation parameters {training_args}") + + # Set the verbosity to info of the Transformers logger (on main process only): + if is_main_process(training_args.local_rank): + transformers.utils.logging.set_verbosity_info() + logger.info("Training/evaluation parameters %s", training_args) + + # 3. 
Detecting last checkpoint and eventually continue from last checkpoint + last_checkpoint = None + if ( + os.path.isdir(training_args.output_dir) + and training_args.do_train + and not training_args.overwrite_output_dir + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0: + raise ValueError( + f"Output directory ({training_args.output_dir}) already exists and is not empty. " + "Use --overwrite_output_dir to overcome." + ) + elif ( + last_checkpoint is not None and training_args.resume_from_checkpoint is None + ): + logger.info( + f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change " + "the `--output_dir` or add `--overwrite_output_dir` to train from scratch." + ) + + # Set seed before initializing model. + set_seed(training_args.seed) + + # 4. Load dataset + raw_datasets = DatasetDict() + + if training_args.do_train: + raw_datasets["train"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.train_split_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + + if training_args.do_eval: + raw_datasets["eval"] = load_dataset( + data_args.dataset_name, + data_args.dataset_config_name, + split=data_args.eval_split_name, + cache_dir=model_args.cache_dir, + token=model_args.token, + ) + + if ( + data_args.audio_column_name + not in next(iter(raw_datasets.values())).column_names + ): + raise ValueError( + f"--audio_column_name '{data_args.audio_column_name}' not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--audio_column_name` to the correct audio column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + if data_args.text_column_name not in next(iter(raw_datasets.values())).column_names: + raise ValueError( + f"--text_column_name {data_args.text_column_name} not found in dataset '{data_args.dataset_name}'. " + "Make sure to set `--text_column_name` to the correct text column - one of " + f"{', '.join(next(iter(raw_datasets.values())).column_names)}." + ) + + # 5. 
Load pretrained model, tokenizer, and feature extractor + # + # Distributed training: + # The .from_pretrained methods guarantee that only one local process can concurrently + config = AutoConfig.from_pretrained( + ( + model_args.config_name + if model_args.config_name + else model_args.model_name_or_path + ), + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + + # SpecAugment for whisper models + if getattr(config, "model_type", None) == "whisper": + config.update({"apply_spec_augment": model_args.apply_spec_augment}) + + feature_extractor = AutoFeatureExtractor.from_pretrained( + ( + model_args.feature_extractor_name + if model_args.feature_extractor_name + else model_args.model_name_or_path + ), + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + tokenizer = AutoTokenizer.from_pretrained( + ( + model_args.tokenizer_name + if model_args.tokenizer_name + else model_args.model_name_or_path + ), + cache_dir=model_args.cache_dir, + use_fast=model_args.use_fast_tokenizer, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + model = AutoModelForSpeechSeq2Seq.from_pretrained( + model_args.model_name_or_path, + config=config, + cache_dir=model_args.cache_dir, + revision=model_args.model_revision, + token=model_args.token, + trust_remote_code=model_args.trust_remote_code, + ) + + if model.config.decoder_start_token_id is None: + raise ValueError( + "Make sure that `config.decoder_start_token_id` is correctly defined" + ) + + if model_args.freeze_feature_encoder: + model.freeze_feature_encoder() + + if model_args.freeze_encoder: + model.freeze_encoder() + model.model.encoder.gradient_checkpointing = False + + if ( + hasattr(model.generation_config, "is_multilingual") + and model.generation_config.is_multilingual + ): + # We only need to set the language and task ids in a multilingual setting + tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task) + model.generation_config.language = data_args.language + model.generation_config.task = data_args.task + elif data_args.language is not None: + raise ValueError( + "Setting language token for an English-only checkpoint is not permitted. The language argument should " + "only be set for multilingual checkpoints." + ) + + # TODO (Sanchit): deprecate these arguments in v4.41 + if model_args.forced_decoder_ids is not None: + logger.warning( + "The use of `forced_decoder_ids` is deprecated and will be removed in v4.41." + "Please use the `language` and `task` arguments instead" + ) + model.generation_config.forced_decoder_ids = model_args.forced_decoder_ids + else: + model.generation_config.forced_decoder_ids = None + model.config.forced_decoder_ids = None + + if model_args.suppress_tokens is not None: + logger.warning( + "The use of `suppress_tokens` is deprecated and will be removed in v4.41." + "Should you need `suppress_tokens`, please manually set them in the fine-tuning script." + ) + model.generation_config.suppress_tokens = model_args.suppress_tokens + + # 6. 
Resample speech dataset if necessary + dataset_sampling_rate = ( + next(iter(raw_datasets.values())) + .features[data_args.audio_column_name] + .sampling_rate + ) + if dataset_sampling_rate != feature_extractor.sampling_rate: + raw_datasets = raw_datasets.cast_column( + data_args.audio_column_name, + datasets.features.Audio(sampling_rate=feature_extractor.sampling_rate), + ) + + # 7. Preprocessing the datasets. + # We need to read the audio files as arrays and tokenize the targets. + max_input_length = ( + data_args.max_duration_in_seconds * feature_extractor.sampling_rate + ) + min_input_length = ( + data_args.min_duration_in_seconds * feature_extractor.sampling_rate + ) + audio_column_name = data_args.audio_column_name + num_workers = data_args.preprocessing_num_workers + text_column_name = data_args.text_column_name + model_input_name = feature_extractor.model_input_names[0] + do_lower_case = data_args.do_lower_case + # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis + forward_attention_mask = ( + getattr(config, "model_type", None) == "whisper" + and getattr(config, "apply_spec_augment", False) + and getattr(config, "mask_time_prob", 0) > 0 + ) + + if data_args.max_train_samples is not None: + raw_datasets["train"] = raw_datasets["train"].select( + range(data_args.max_train_samples) + ) + + if data_args.max_eval_samples is not None: + raw_datasets["eval"] = raw_datasets["eval"].select( + range(data_args.max_eval_samples) + ) + + def prepare_dataset(batch): + # process audio + sample = batch[audio_column_name] + inputs = feature_extractor( + sample["array"], + sampling_rate=sample["sampling_rate"], + return_attention_mask=forward_attention_mask, + ) + # process audio length + batch[model_input_name] = inputs.get(model_input_name)[0] + batch["input_length"] = len(sample["array"]) + if forward_attention_mask: + batch["attention_mask"] = inputs.get("attention_mask")[0] + + # process targets + input_str = ( + batch[text_column_name].lower() + if do_lower_case + else batch[text_column_name] + ) + batch["labels"] = tokenizer(input_str).input_ids + return batch + + with training_args.main_process_first(desc="dataset map pre-processing"): + vectorized_datasets = raw_datasets.map( + prepare_dataset, + remove_columns=next(iter(raw_datasets.values())).column_names, + num_proc=data_args.preprocessing_num_workers, + desc="preprocess train dataset", + ) + + # filter data that is shorter than min_input_length or longer than + # max_input_length + def is_audio_in_length_range(length): + return length > min_input_length and length < max_input_length + + vectorized_datasets = vectorized_datasets.filter( + is_audio_in_length_range, + num_proc=num_workers, + input_columns=["input_length"], + ) + + # for large datasets it is advised to run the preprocessing on a + # single machine first with `args.preprocessing_only` since there will mostly likely + # be a timeout when running the script in distributed mode. + # In a second step `args.preprocessing_only` can then be set to `False` to load the + # cached dataset + if data_args.preprocessing_only: + cache = {k: v.cache_files for k, v in vectorized_datasets.items()} + logger.info(f"Data preprocessing finished. Files cached at {cache}.") + return + + # 8. 
Load Metric + metric = evaluate.load("wer", cache_dir=model_args.cache_dir) + + def compute_metrics(pred): + pred_ids = pred.predictions + + pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id + + pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + # we do not want to group tokens when computing the metrics + label_str = tokenizer.batch_decode(pred.label_ids, skip_special_tokens=True) + + wer = metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + # 9. Create a single speech processor + # make sure all processes wait until data is saved + with training_args.main_process_first(): + # only the main process saves them + if is_main_process(training_args.local_rank): + # save feature extractor, tokenizer and config + feature_extractor.save_pretrained(training_args.output_dir) + tokenizer.save_pretrained(training_args.output_dir) + config.save_pretrained(training_args.output_dir) + + processor = AutoProcessor.from_pretrained(training_args.output_dir) + + # 10. Define data collator + data_collator = DataCollatorSpeechSeq2SeqWithPadding( + processor=processor, + decoder_start_token_id=model.config.decoder_start_token_id, + forward_attention_mask=forward_attention_mask, + ) + + # 11. Initialize Trainer + trainer = Seq2SeqTrainer( + model=model, + args=training_args, + train_dataset=vectorized_datasets["train"] if training_args.do_train else None, + eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None, + tokenizer=feature_extractor, + data_collator=data_collator, + compute_metrics=( + compute_metrics if training_args.predict_with_generate else None + ), + ) + + # 12. Training + if training_args.do_train: + checkpoint = None + if training_args.resume_from_checkpoint is not None: + checkpoint = training_args.resume_from_checkpoint + elif last_checkpoint is not None: + checkpoint = last_checkpoint + train_result = trainer.train(resume_from_checkpoint=checkpoint) + trainer.save_model() # Saves the feature extractor too for easy upload + + metrics = train_result.metrics + max_train_samples = ( + data_args.max_train_samples + if data_args.max_train_samples is not None + else len(vectorized_datasets["train"]) + ) + metrics["train_samples"] = min( + max_train_samples, len(vectorized_datasets["train"]) + ) + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_state() + + # 13. Evaluation + results = {} + if training_args.do_eval: + logger.info("*** Evaluate ***") + metrics = trainer.evaluate( + metric_key_prefix="eval", + max_length=training_args.generation_max_length, + num_beams=training_args.generation_num_beams, + ) + max_eval_samples = ( + data_args.max_eval_samples + if data_args.max_eval_samples is not None + else len(vectorized_datasets["eval"]) + ) + metrics["eval_samples"] = min( + max_eval_samples, len(vectorized_datasets["eval"]) + ) + + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # 14. 
Write Training Stats + kwargs = { + "finetuned_from": model_args.model_name_or_path, + "tasks": "automatic-speech-recognition", + } + if data_args.dataset_name is not None: + kwargs["dataset_tags"] = data_args.dataset_name + if data_args.dataset_config_name is not None: + kwargs["dataset_args"] = data_args.dataset_config_name + kwargs["dataset"] = ( + f"{data_args.dataset_name} {data_args.dataset_config_name}" + ) + else: + kwargs["dataset"] = data_args.dataset_name + + if training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(**kwargs) + + return results + + +if __name__ == "__main__": + main() diff --git a/train.py b/train.py index 19251e0..f0dddf6 100644 --- a/train.py +++ b/train.py @@ -36,29 +36,19 @@ dataset = utils.gather_dataset(args.process_ds_path) chuck_ds = [] trainer = Trainer(dataset) - i = 0 - for i in range(len(dataset) // 1000): - ds = trainer.process_dataset(dataset, i) - ds.save_to_disk(f"./dataset/process/{i}") - chuck_ds.append(ds) - ds = trainer.process_dataset(dataset, -1) - ds.save_to_disk(f"./dataset/process/{i+1}") - chuck_ds.append(ds) - - dataset = concatenate_datasets(chuck_ds) - trainer.dataset = dataset.train_test_split(test_size=0.05) + ds = trainer.process_dataset(dataset) + ds.save_to_disk(f"./dataset/process") + dataset = ds + trainer.dataset = dataset.train_test_split(test_size=0.3) elif args.chunked_ds_path: - chuck_ds = [] - nb_chunks = len(glob.glob(f"{args.chunked_ds_path}/*")) - for i in range(nb_chunks): - ds = Dataset.load_from_disk(f"{args.chunked_ds_path}/{i}") - chuck_ds.append(ds) - dataset = concatenate_datasets(chuck_ds) - dataset = dataset.train_test_split(test_size=0.05) + dataset = Dataset.load_from_disk(f"{args.chunked_ds_path}") + dataset = dataset.train_test_split(test_size=0.3) trainer = Trainer(dataset) else: raise ValueError("You must provide either --process_ds_path or --chunked_ds_path") +print(dataset) + trainer.train() trainer.save_model(args.model_path) diff --git a/train2.py b/train2.py new file mode 100644 index 0000000..2d89920 --- /dev/null +++ b/train2.py @@ -0,0 +1,56 @@ +from datasets import load_dataset, DatasetDict +from datasets import Audio + +from training.train import Trainer + +common_voice = DatasetDict() + +common_voice["train"] = load_dataset( + "mozilla-foundation/common_voice_11_0", + "hi", + split="train+validation", + use_auth_token=True, +) +common_voice["test"] = load_dataset( + "mozilla-foundation/common_voice_11_0", "hi", split="test", use_auth_token=True +) + + +common_voice = common_voice.remove_columns( + [ + "accent", + "age", + "client_id", + "down_votes", + "gender", + "locale", + "path", + "segment", + "up_votes", + ] +) + +common_voice = common_voice.cast_column("audio", Audio(sampling_rate=16000)) + +trainer = Trainer() + + +def prepare_dataset(batch): + # load and resample audio data from 48 to 16kHz + audio = batch["audio"] + + # compute log-Mel input features from input audio array + batch["input_features"] = trainer.feature_extractor( + audio["array"], sampling_rate=audio["sampling_rate"] + ).input_features[0] + + # encode target text to label ids + batch["labels"] = trainer.tokenizer(batch["sentence"]).input_ids + return batch + + +common_voice = common_voice.map( + prepare_dataset, remove_columns=common_voice.column_names["train"], num_proc=1 +) + +trainer.train(common_voice) diff --git a/training/collator.py b/training/collator.py index 60c1ce4..6e3c104 100644 --- a/training/collator.py +++ b/training/collator.py @@ -2,33 +2,28 @@ This file 
contains the data collator for the Speech2Text model. """ -from dataclasses import dataclass -from typing import Any, Dict, Union, List - import torch +from dataclasses import dataclass +from typing import Any, Dict, List, Union + @dataclass class DataCollatorSpeechSeq2SeqWithPadding: - """ - Data collator that will dynamically pad the inputs received. - """ processor: Any + decoder_start_token_id: int def __call__( self, features: List[Dict[str, Union[List[int], torch.Tensor]]] ) -> Dict[str, torch.Tensor]: - """ - This method pads the input features and the labels to the maximum length in the batch and return it. - :param features: The features to pad. - :return: The padded features. - """ # split inputs and labels since they have to be of different lengths and need different padding methods # first treat the audio inputs by simply returning torch tensors input_features = [ - {"input_features": feature["input_features"][0]} for feature in features + {"input_features": feature["input_features"]} for feature in features ] - batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt") + batch = self.processor.feature_extractor.pad( + input_features, return_tensors="pt" + ) # get the tokenized label sequences label_features = [{"input_ids": feature["labels"]} for feature in features] @@ -42,7 +37,7 @@ def __call__( # if bos token is appended in previous tokenization step, # cut bos token here as it's append later anyways - if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item(): + if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item(): labels = labels[:, 1:] batch["labels"] = labels diff --git a/training/train.py b/training/train.py index 4f21d31..f883eeb 100644 --- a/training/train.py +++ b/training/train.py @@ -2,31 +2,18 @@ This module contains the Trainer class which is responsible for training whisper on predicting lyrics. """ -import warnings -from functools import partial - import evaluate -import librosa -import numpy as np -import torch -from datasets import Dataset from transformers import ( - WhisperProcessor, - WhisperForConditionalGeneration, - Seq2SeqTrainingArguments, Seq2SeqTrainer, + Seq2SeqTrainingArguments, + WhisperFeatureExtractor, + WhisperForConditionalGeneration, + WhisperProcessor, + WhisperTokenizer, ) -from transformers.models.whisper.english_normalizer import BasicTextNormalizer from training.collator import DataCollatorSpeechSeq2SeqWithPadding -METRIC = evaluate.load("wer") - -NORMALIZER = BasicTextNormalizer() - -torch.backends.cuda.matmul.allow_tf32 = True -torch.backends.cudnn.allow_tf32 = True - class Trainer: """ @@ -35,173 +22,97 @@ class Trainer: def __init__( self, - dataset=None, - model_name="openai/whisper-small", + model_name="openai/whisper-tiny", + language="hindi", + task="transcribe", + output_dir="./whisper-finetuned", ): - """ - The constructor for the Trainer class. - The dataset is optional and can be added later with the method process_dataset. - The dataset should be formated and already mapped to the columns "audio" and "lyrics" and ready for training. - :param dataset: The dataset to train the model on. - """ - self.processor = WhisperProcessor.from_pretrained(model_name, language="en", task="transcribe") - self.model = WhisperForConditionalGeneration.from_pretrained(model_name) - self.dataset = dataset - self.data_collator = DataCollatorSpeechSeq2SeqWithPadding(self.processor) - self.prepare_tokenizer() + """Function to initialize the Trainer class. 
- def prepare_tokenizer(self) -> None: - """ - A method that adds special tokens i.e. tags to the tokenizer. - :return: None + Args: + model_name (str, optional): _description_. Defaults to "openai/whisper-tiny". + language (str, optional): _description_. Defaults to "hindi". + task (str, optional): _description_. Defaults to "transcribe". + output_dir (str, optional): _description_. Defaults to "./whisper-finetuned". """ - special_tokens_to_add = [] - for i in range(1, 5): - special_tokens_to_add.append(f"[VERSE {i}]") - special_tokens_to_add.append("[CHORUS]") - special_tokens_to_add.append("[BRIDGE]") - self.processor.tokenizer.add_special_tokens( - {"additional_special_tokens": special_tokens_to_add} - ) - self.model.resize_token_embeddings(len(self.processor.tokenizer)) - def process_dataset(self, dataset, chunk_id) -> Dataset: - """ - A method that processes the dataset. - :return: None - """ + self.feature_extractor = WhisperFeatureExtractor.from_pretrained(model_name) + self.tokenizer = WhisperTokenizer.from_pretrained( + model_name, language=language, task=task + ) - def prepare_dataset(example): - target_sr = self.processor.feature_extractor.sampling_rate - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=FutureWarning) - warnings.filterwarnings("ignore", category=UserWarning) - audio, sr = librosa.load(example["audio"], sr=None) - audio = librosa.resample( - np.asarray(audio), - orig_sr=sr, - target_sr=target_sr, - ) - - example = self.processor( - audio=audio, - sampling_rate=target_sr, - text=example["lyrics"], - ) - - # compute input length of audio sample in seconds - example["input_length"] = len(audio) / sr - - return example - if chunk_id == -1: - last_chunk_size = len(dataset) % 1000 - small_dataset = Dataset.from_dict(dataset[-last_chunk_size:]) - else: - small_dataset = Dataset.from_dict(dataset[chunk_id*1000:chunk_id*1000+1000]) - self.dataset = small_dataset.map( - prepare_dataset, - remove_columns=small_dataset.column_names, - num_proc=1, + self.processor = WhisperProcessor.from_pretrained( + model_name, language=language, task=task ) - max_input_length = 30.0 + self.model = WhisperForConditionalGeneration.from_pretrained(model_name) + self.model.generation_config.language = language + self.model.generation_config.task = task - def is_audio_in_length_range(length): - return length < max_input_length + self.model.generation_config.forced_decoder_ids = None + self.metric = evaluate.load("wer") - self.dataset = self.dataset.filter( - is_audio_in_length_range, - input_columns=["input_length"], + self.data_collator = DataCollatorSpeechSeq2SeqWithPadding( + processor=self.processor, + decoder_start_token_id=self.model.config.decoder_start_token_id, ) - return self.dataset - def compute_metrics(self, pred): - """ - A method that computes the metrics. - :param pred: The predictions of the model. - :return: The metrics. 
- """ + self._ouput_dir = output_dir + + def _compute_metrics(self, pred): pred_ids = pred.predictions label_ids = pred.label_ids # replace -100 with the pad_token_id - label_ids[label_ids == -100] = self.processor.tokenizer.pad_token_id + label_ids[label_ids == -100] = self.tokenizer.pad_token_id # we do not want to group tokens when computing the metrics - pred_str = self.processor.batch_decode(pred_ids, skip_special_tokens=True) - label_str = self.processor.batch_decode(label_ids, skip_special_tokens=True) - - # compute orthographic wer - wer_ortho = 100 * METRIC.compute(predictions=pred_str, references=label_str) - - # compute normalised WER - pred_str_norm = [NORMALIZER(pred) for pred in pred_str] - label_str_norm = [NORMALIZER(label) for label in label_str] - # filtering step to only evaluate the samples that correspond to non-zero references: - pred_str_norm = [ - pred_str_norm[i] - for i in range(len(pred_str_norm)) - if len(label_str_norm[i]) > 0 - ] - label_str_norm = [ - label_str_norm[i] - for i in range(len(label_str_norm)) - if len(label_str_norm[i]) > 0 - ] - - wer = 100 * METRIC.compute(predictions=pred_str_norm, references=label_str_norm) - - return {"wer_ortho": wer_ortho, "wer": wer} - - def train(self): + pred_str = self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True) + label_str = self.tokenizer.batch_decode(label_ids, skip_special_tokens=True) + + wer = 100 * self.metric.compute(predictions=pred_str, references=label_str) + + return {"wer": wer} + + def train(self, dataset): """ A method that trains the model. :return: """ - self.model.generate = partial( - self.model.generate, language="en", task="transcribe", use_cache=True - ) training_args = Seq2SeqTrainingArguments( - output_dir="./train", - per_device_train_batch_size=10, - per_device_eval_batch_size=8, - num_train_epochs=3, + output_dir=self._ouput_dir, # change to a repo name of your choice + per_device_train_batch_size=8, + gradient_accumulation_steps=1, # increase by 2x for every 2x decrease in batch size learning_rate=1e-5, - lr_scheduler_type="linear", - warmup_steps=50, - gradient_checkpointing=False, - fp16=not torch.cuda.is_bf16_supported(), - bf16=torch.cuda.is_bf16_supported(), - bf16_full_eval=torch.cuda.is_bf16_supported(), - fp16_full_eval=not torch.cuda.is_bf16_supported(), + warmup_steps=500, + max_steps=4000, + gradient_checkpointing=True, + fp16=True, evaluation_strategy="steps", - eval_steps=75, - optim="adamw_8bit", + per_device_eval_batch_size=8, predict_with_generate=True, - generation_max_length=350, + generation_max_length=225, + save_steps=10, + eval_steps=10, logging_steps=25, + report_to=["tensorboard"], + load_best_model_at_end=True, metric_for_best_model="wer", greater_is_better=False, + push_to_hub=True, ) trainer = Seq2SeqTrainer( args=training_args, model=self.model, - train_dataset=self.dataset["train"], - eval_dataset=self.dataset["test"], + train_dataset=dataset["train"], + eval_dataset=dataset["test"], data_collator=self.data_collator, - compute_metrics=self.compute_metrics, - tokenizer=self.processor, + compute_metrics=self._compute_metrics, + tokenizer=self.processor.feature_extractor, ) - return trainer.train() - def save_model(self, path: str) -> None: - """ - A method that saves the model. - :param path: The path to save the model. - :return: None - """ + self.processor.save_pretrained(training_args.output_dir) - self.model.save_pretrained(path) - self.processor.save_pretrained(path) + trainer.train()
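
Usage sketch (an addition for illustration, not part of the commit): a minimal, hypothetical smoke test for the checkpoint this patch trains. The checkpoint path is an assumption; the patched Trainer in training/train.py saves the processor to ./whisper-finetuned (its default output_dir, also added to .gitignore above), while model weights are written as checkpoint-* subdirectories there and/or to the Hub repo created by push_to_hub=True, so the path below may need to point at one of those locations instead. The dataset and 16 kHz sampling rate mirror train2.py.

# Hypothetical post-training check; the checkpoint path below is an assumption.
import torch
from datasets import Audio, load_dataset
from transformers import WhisperForConditionalGeneration, WhisperProcessor

checkpoint = "./whisper-finetuned"  # or a checkpoint-* dir / the Hub repo id created by push_to_hub

processor = WhisperProcessor.from_pretrained(checkpoint)
model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
model.eval()

# Pull one Common Voice Hindi test clip and resample it to the 16 kHz Whisper expects
# (same dataset, config and sampling rate as train2.py; the dataset requires Hub authentication).
test_set = load_dataset(
    "mozilla-foundation/common_voice_11_0", "hi", split="test", streaming=True
)
test_set = test_set.cast_column("audio", Audio(sampling_rate=16000))
sample = next(iter(test_set))

inputs = processor(
    sample["audio"]["array"], sampling_rate=16000, return_tensors="pt"
)
with torch.no_grad():
    predicted_ids = model.generate(inputs.input_features)

print("prediction:", processor.batch_decode(predicted_ids, skip_special_tokens=True)[0])
print("reference: ", sample["sentence"])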