diff --git a/README.md b/README.md index 59924bc..f6f0a7f 100644 --- a/README.md +++ b/README.md @@ -1,115 +1,145 @@ -

- -

+# Are Large Language Models Good Classifiers? A Study on Edit Intent Classification in Scientific Document Revisions +This is the official code repository for the paper "Are Large Language Models Good Classifiers? A Study on Edit Intent Classification in Scientific Document Revisions", presented at the EMNLP 2024 main conference. It contains the scripts for the fine-tuning approaches outlined in the paper. -# llm_classifier -[![Arxiv](https://img.shields.io/badge/Arxiv-YYMM.NNNNN-red?style=flat-square&logo=arxiv&logoColor=white)](https://put-here-your-paper.com) -[![License](https://img.shields.io/github/license/UKPLab/llm_classifier)](https://opensource.org/licenses/Apache-2.0) -[![Python Versions](https://img.shields.io/badge/Python-3.9-blue.svg?style=flat&logo=python&logoColor=white)](https://www.python.org/) -[![CI](https://github.com/UKPLab/llm_classifier/actions/workflows/main.yml/badge.svg)](https://github.com/UKPLab/llm_classifier/actions/workflows/main.yml) +Please find the paper [here](https://arxiv.org/abs/2410.02028), and star the repository to stay updated with the latest information. -This is the official template for new Python projects at UKP Lab. It was adapted for the needs of UKP Lab from the excellent [python-project-template](https://github.com/rochacbruno/python-project-template/) by [rochacbruno](https://github.com/rochacbruno). +In case of questions, please contact [Qian Ruan](mailto:ruan@ukp.tu-darmstadt.de). -It should help you start your project and give you continuous status updates on the development through [GitHub Actions](https://docs.github.com/en/actions). +## Abstract +Classification is a core NLP task architecture with many potential applications. While large language models (LLMs) have brought substantial advancements in text generation, their potential for enhancing classification tasks remains underexplored. To address this gap, we propose a framework for thoroughly investigating fine-tuning LLMs for classification, including both generation- and encoding-based approaches. We instantiate this framework in edit intent classification (EIC), a challenging and underexplored classification task. Our extensive experiments and systematic comparisons with various training approaches and a representative selection of LLMs yield new insights into their application for EIC. We investigate the generalizability of these findings on five further classification tasks. To demonstrate the proposed methods and address the data shortage for empirical edit analysis, we use our best-performing EIC model to create Re3-Sci2.0, a new large-scale dataset of 1,780 scientific document revisions with over 94k labeled edits. The quality of the dataset is assessed through human evaluation. The new dataset enables an in-depth empirical study of human editing behavior in academic writing. +![](/resource/overview.pdf) -> **Abstract:** The study of natural language processing (NLP) has gained increasing importance in recent years, with applications ranging from machine translation to sentiment analysis. Properly managing Python projects in this domain is of paramount importance to ensure reproducibility and facilitate collaboration. The template provides a structured starting point for projects and offers continuous status updates on development through GitHub Actions. 
Key features include a basic setup.py file for installation, packaging, and distribution, documentation structure using mkdocs, testing structure using pytest, code linting with pylint, and entry points for executing the program with basic CLI argument parsing. Additionally, the template incorporates continuous integration using GitHub Actions with jobs to check, lint, and test the project, ensuring robustness and reliability throughout the development process. +*Figure 1. In this work, we (1) present a general framework to explore the classification capabilities of LLMs, conducting extensive experiments and systematic comparisons on the EIC task; (2) use the best model to +create the Re3-Sci2.0 dataset, which comprises 1,780 scientific document revisions (a-b), associated reviews (c, d), and 94,482 edits annotated with action and intent labels (e, f), spanning various scholarly domains; +(3) provide a first in-depth empirical analysis of human editing behavior using this new dataset.* -Contact person: [Federico Tiblias](mailto:federico.tiblias@tu-darmstadt.de) +## Approaches +![](/resource/approaches.pdf) -[UKP Lab](https://www.ukp.tu-darmstadt.de/) | [TU Darmstadt](https://www.tu-darmstadt.de/ -) +*Figure 2. Proposed approaches with a systematic investigation of the key components: input types (red), language models (green), and transformation functions (yellow). See §3 and §4 of the paper for details.* -Don't hesitate to send us an e-mail or report an issue, if something is broken (and it shouldn't be) or if you have further questions. - - -## Getting Started - -> **DO NOT CLONE OR FORK** - -If you want to set up this template: - -1. Request a repository on UKP Lab's GitHub by following the standard procedure on the wiki. It will install the template directly. Alternatively, set it up in your personal GitHub account by clicking **[Use this template](https://github.com/rochacbruno/python-project-template/generate)**. -2. Wait until the first run of CI finishes. Github Actions will commit to your new repo with a "✅ Ready to clone and code" message. -3. Delete optional files: - - If you don't need automatic documentation generation, you can delete folder `docs`, file `.github\workflows\docs.yml` and `mkdocs.yml` - - If you don't want automatic testing, you can delete folder `tests` and file `.github\workflows\tests.yml` -4. Prepare a virtual environment: +## Quickstart +1. Download the project from GitHub. ```bash -python -m venv .venv -source .venv/bin/activate -pip install . -pip install -r requirements-dev.txt # Only needed for development +git clone https://github.com/UKPLab/llm_classifier ``` -5. Adapt anything else (for example this file) to your project. - -6. Read the file [ABOUT_THIS_TEMPLATE.md](ABOUT_THIS_TEMPLATE.md) for more information about development. - -## Usage - -### Using the classes -To import classes/methods of `llm_classifier` from inside the package itself you can use relative imports: - -```py -from .base import BaseClass # Notice how I omit the package name - -BaseClass().something() +2. Set up the environment. +```bash +python -m venv .llm_classifier +source ./.llm_classifier/bin/activate +pip install -r requirements.txt +``` + + +### Fine-tuning LLMs +Check the `finetune_EIC_<approach>.py` scripts (e.g., `finetune_EIC_SeqC.py`) for the complete workflow of each approach: Gen, SeqC, XNet and SNet. You can customize the arguments in the marked settings blocks shown in the snippets below. Refer to the paper for more details. + +For example, to fine-tune an LLM with the SeqC approach: + +1. 
Basic Settings + +```python + ############################################################################ + # basic settings + # + task_name = 'edit_intent_classification' + method = 'finetuning_llm_seqc' # select an approach from ['finetuning_llm_gen','finetuning_llm_seqc', 'finetuning_llm_snet', 'finetuning_llm_xnet'] + train_type = 'train' # name of the training data in data/Re3-Sci/tasks/edit_intent_classification + val_type = 'val' # name of the validation data in data/Re3-Sci/tasks/edit_intent_classification + test_type = 'test' # name of the test data in data/Re3-Sci/tasks/edit_intent_classification + # + ############################################################################ ``` +2. Load Data -To import classes/methods from outside the package (e.g. when you want to use the package in some other project) you can instead refer to the package name: - -```py -from llm_classifier import BaseClass # Notice how I omit the file name -from llm_classifier.subpackage import SubPackageClass # Here it's necessary because it's a subpackage - -BaseClass().something() -SubPackageClass().something() +```python + from tasks.task_data_loader import TaskDataLoader + task_data_loader = TaskDataLoader(task_name=task_name, train_type=train_type, val_type=val_type, test_type=test_type) + train_ds, val_ds, test_ds = task_data_loader.load_data() + labels, label2id, id2label = task_data_loader.get_labels() ``` -### Using scripts - -This is how you can use `llm_classifier` from command line: - -```bash -$ python -m llm_classifier +3. Load Model + +```python + # load model from path + # + model_path = 'path/to/model' + emb_type = None # transformation function for xnet and snet approaches, select from ['diff', 'diffABS', 'n-diffABS', 'n-o', 'n-diffABS-o'], None for SeqC and Gen + input_type = 'text_st_on' # input type for the model, select from ['text_nl_on', 'text_st_on', 'inst_text_st_on', 'inst_text_nl_on'] for natural language input, structured input, instruction + structured input, instruction + natural language input, respectively + # + from tasks.task_model_loader import TaskModelLoader + model_loader = TaskModelLoader(task_name=task_name, method=method).model_loader + model, tokenizer = model_loader.load_model_from_path(model_path, labels=labels, label2id=label2id, id2label=id2label, emb_type=emb_type, input_type=input_type) +``` +4. Preprocess dataset + +```python + # + max_length = 1024 + # + from tasks.task_data_preprocessor import TaskDataPreprocessor + data_preprocessor = TaskDataPreprocessor(task_name=task_name, method=method).data_preprocessor + train_ds = data_preprocessor.preprocess_data(train_ds, label2id, tokenizer, max_length=max_length, input_type=input_type) + val_ds = data_preprocessor.preprocess_data(val_ds, label2id, tokenizer, max_length=max_length, input_type=input_type) + test_ds = data_preprocessor.preprocess_data(test_ds, label2id, tokenizer, max_length=max_length, input_type=input_type) +``` +5. 
Fine-tune model + +```python + # fine-tune model + # + lora_r = 128 # LoRA rank parameter + lora_alpha = 128 # Alpha parameter for LoRA scaling + lora_dropout = 0.1 # Dropout probability for LoRA layers + learning_rate = 2e-4 # Learning rate + per_device_train_batch_size = 32 # Batch size per GPU for training + train_epochs = 2 # Number of epochs to train + recreate_dir = True # Create a directory for the model + # + # create model dir to save the fine-tuned model + from finetune_EIC_SeqC import create_model_dir + output_dir = create_model_dir(task_name, method, model_path, lora_r, lora_alpha, lora_dropout, learning_rate, + per_device_train_batch_size, train_epochs, train_type, test_type, + max_length, emb_type, input_type, recreate_dir=recreate_dir) + # fine-tune + from tasks.task_model_finetuner import TaskModelFinetuner + model_finetuner = TaskModelFinetuner(task_name=task_name, method=method).model_finetuner + model_finetuner.fine_tune(model, tokenizer, train_ds = train_ds, val_ds = val_ds, lora_r = lora_r, lora_alpha = lora_alpha, lora_dropout = lora_dropout, + learning_rate = learning_rate, per_device_train_batch_size = per_device_train_batch_size, train_epochs = train_epochs, output_dir = output_dir) +``` +6. Evaluate + +```python + # evaluate the fine-tuned model + from tasks.task_evaluater import TaskEvaluater + evaluater = TaskEvaluater(task_name=task_name, method=method).evaluater + evaluater.evaluate(test_ds, model_dir=output_dir, labels=labels, label2id=label2id, id2label=id2label, emb_type=emb_type, input_type=input_type, response_key=None) # response_key is only needed for the Gen approach ``` +See the note at the end of this README for how the preprocessing, fine-tuning and evaluation calls change for the Gen approach. -### Expected results - -After running the experiments, you should expect the following results: - -(Feel free to describe your expected results here...) - -### Parameter description - -* `x, --xxxx`: This parameter does something nice - -* ... - -* `z, --zzzz`: This parameter does something even nicer - -## Development - -Read the FAQs in [ABOUT_THIS_TEMPLATE.md](ABOUT_THIS_TEMPLATE.md) to learn more about how this template works and where you should put your classes & methods. Make sure you've correctly installed `requirements-dev.txt` dependencies - -## Cite +## Citation Please use the following citation: ``` -@InProceedings{smith:20xx:CONFERENCE_TITLE, - author = {Smith, John}, - title = {My Paper Title}, - booktitle = {Proceedings of the 20XX Conference on XXXX}, - month = mmm, - year = {20xx}, - address = {Gotham City, USA}, - publisher = {Association for XXX}, - pages = {XXXX--XXXX}, - url = {http://xxxx.xxx} +@misc{ruan2024llmclassifier, + title={Are Large Language Models Good Classifiers? A Study on Edit Intent Classification in Scientific Document Revisions}, + author={Qian Ruan and Ilia Kuznetsov and Iryna Gurevych}, + year={2024}, + eprint={2410.02028}, + archivePrefix={arXiv}, + primaryClass={cs.CL}, + url={https://arxiv.org/abs/2410.02028}, } ``` ## Disclaimer +This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication. + + + + -> This repository contains experimental software and is published for the sole purpose of giving additional background details on the respective publication. 
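+### Note: fine-tuning with the Gen approach
+The walkthrough above (steps 1-6) uses the SeqC approach. For the generation-based Gen approach, set `method = 'finetuning_llm_gen'` in step 1 and repeat steps 2-3 with it (the Gen model loader returns a causal LM rather than a model with a sequence-classification head). Gen preprocessing then builds prompt strings instead of tokenized inputs and returns a `response_key` that marks where the predicted label starts in the generated text; fine-tuning and evaluation both need this key. The snippet below is a minimal, illustrative sketch assembled from the Gen modules under `tasks/edit_intent_classification/finetuning_llm_gen/`; it reuses `max_length`, `input_type` and `output_dir` from the steps above, and the per-approach `finetune_EIC_<approach>.py` script remains the authoritative workflow.
+```python
+method = 'finetuning_llm_gen'
+# Gen preprocessing formats each edit into a prompt and returns the response key
+from tasks.task_data_preprocessor import TaskDataPreprocessor
+data_preprocessor = TaskDataPreprocessor(task_name=task_name, method=method).data_preprocessor
+train_ds, response_key = data_preprocessor.preprocess_data(train_ds, max_length=max_length, input_type=input_type, is_train=True)
+val_ds, _ = data_preprocessor.preprocess_data(val_ds, max_length=max_length, input_type=input_type, is_train=False)
+test_ds, _ = data_preprocessor.preprocess_data(test_ds, max_length=max_length, input_type=input_type, is_train=False)
+# Gen fine-tuning validates by generating label strings, so it also needs the labels and the response key
+from tasks.task_model_finetuner import TaskModelFinetuner
+model_finetuner = TaskModelFinetuner(task_name=task_name, method=method).model_finetuner
+model_finetuner.fine_tune(model, tokenizer, train_ds=train_ds, val_ds=val_ds, output_dir=output_dir,
+                          labels=labels, label2id=label2id, response_key=response_key)
+# evaluation parses the text generated after response_key and matches it against the label set
+from tasks.task_evaluater import TaskEvaluater
+evaluater = TaskEvaluater(task_name=task_name, method=method).evaluater
+evaluater.evaluate(test_ds, model_dir=output_dir, labels=labels, label2id=label2id, id2label=id2label,
+                   input_type=input_type, response_key=response_key)
+```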
+ \ No newline at end of file diff --git a/tasks/__init__.py b/tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tasks/edit_intent_classification/__init__.py b/tasks/edit_intent_classification/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tasks/edit_intent_classification/finetuning_llm_gen/__init__.py b/tasks/edit_intent_classification/finetuning_llm_gen/__init__.py new file mode 100644 index 0000000..354aede --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_gen/__init__.py @@ -0,0 +1,11 @@ +from .model_loader import ModelLoader +from .data_preprocessor import DataPreprocessor +from .model_finetuner import ModelFinetuner +from .evaluater import Evaluater + +__all__ = [ + "ModelLoader", + "DataPreprocessor" + "ModelFinetuner" + "Evaluater" + ] \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_gen/data_preprocessor.py b/tasks/edit_intent_classification/finetuning_llm_gen/data_preprocessor.py new file mode 100644 index 0000000..5f7b3ab --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_gen/data_preprocessor.py @@ -0,0 +1,96 @@ +# Initialize static strings for the prompt template +# natural language input (nl) +INSTRUCTION_KEY = "### Instruction:" +INSTRUCTION_KEY_END = '' +INPUT_KEY = "INPUT:" +INPUT_KEY_END = '' +NEW_START = 'NEW:' +NEW_END = '' +OLD_START = 'OLD:' +OLD_END = '' +RESPONSE_KEY = 'RESPONSE:' +END_KEY = '### End' + + +#structured input (st) +INSTRUCTION_KEY_ST = "" +INSTRUCTION_KEY_END_ST = '' +INPUT_KEY_ST = '' +INPUT_KEY_END_ST = '' +NEW_START_ST = '' +NEW_END_ST = '' +OLD_START_ST = '' +OLD_END_ST = '' +RESPONSE_KEY_ST = "" +END_KEY_ST = "" + + +TASK_PROMPT = "Classify the intent of the following sentence edit. The possible labels are: Grammar, Clarity, Fact/Evidence, Claim, Other. 
" + +PROMPT_ST_DIC = {'nl': [INSTRUCTION_KEY,INSTRUCTION_KEY_END, INPUT_KEY, INPUT_KEY_END, OLD_START,OLD_END, NEW_START, NEW_END, RESPONSE_KEY, END_KEY], + 'st': [INSTRUCTION_KEY_ST,INSTRUCTION_KEY_END_ST, INPUT_KEY_ST, INPUT_KEY_END_ST, OLD_START_ST,OLD_END_ST, NEW_START_ST, NEW_END_ST, RESPONSE_KEY_ST, END_KEY_ST]} + +class DataPreprocessor: + def __init__(self) -> None: + print('Preprocessing the data...Gen') + + def preprocess_data(self, dataset, max_length=1024, input_type='text_st_on', is_train:bool=True): + """ + :param max_length (int): Maximum number of tokens to emit from the tokenizer + :param input_type (str): Type of input text + """ + self.prompt_st_type = input_type.split('_')[-2] + instruction_key, instruction_key_end, input_key, input_key_end, old_start, old_end, new_start, new_end, response_key, end_key = PROMPT_ST_DIC[self.prompt_st_type] + + # Add prompt to each sample + print("Preprocessing dataset...") + if is_train: + dataset = dataset.map(self.create_prompt_formats_train, keep_in_memory=True) + else: + dataset = dataset.map(self.create_prompt_formats_test, keep_in_memory=True) + + # Shuffle dataset + seed = 42 + dataset = dataset.shuffle(seed = seed) + return dataset, response_key + + def create_prompt_formats_train(self, sample): + """ + Creates a formatted prompt template for a prompt in the dataset + :param sample: sample from the dataset + """ + instruction_key, instruction_key_end, input_key, input_key_end, old_start, old_end, new_start, new_end, response_key, end_key = PROMPT_ST_DIC[self.prompt_st_type] + task_prompt = TASK_PROMPT + # Combine a prompt with the static strings + instruction = f"{instruction_key} {task_prompt} {instruction_key_end}" + + text_src = sample['text_src'] if sample['text_src'] is not None else '' + text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else '' + input_context = f"{input_key}\n {old_start} {text_tgt} {old_end}\n {new_start} {text_src} {new_end}\n{input_key_end}" + response = f"{response_key}{sample['label']}" + end = f"{end_key}" + # Create a list of prompt template elements + parts = [part for part in [instruction, input_context, response, end] if part] + # Join prompt template elements into a single string to create the prompt template + formatted_prompt = "\n".join(parts) + # Store the formatted prompt template in a new key "text" + sample["text"] = formatted_prompt + return sample + + def create_prompt_formats_test(self, sample): + """ + Creates a formatted prompt template for a prompt in the dataset + :param sample: sample from the dataset + """ + instruction_key, instruction_key_end, input_key, input_key_end, old_start, old_end, new_start, new_end, response_key, end_key = PROMPT_ST_DIC[self.prompt_st_type] + task_prompt = TASK_PROMPT + instruction = f"{instruction_key} {task_prompt} {instruction_key_end}" + text_src = sample['text_src'] if sample['text_src'] is not None else '' + text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else '' + input_context = f"{input_key}\n {old_start} {text_tgt} {old_end}\n {new_start} {text_src} {new_end}\n{input_key_end}" + response = f"{response_key}" + parts = [part for part in [instruction, input_context, response] if part] + formatted_prompt = "\n".join(parts) + sample["text"] = formatted_prompt + return sample + \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_gen/evaluater.py b/tasks/edit_intent_classification/finetuning_llm_gen/evaluater.py new file mode 100644 index 0000000..0ece895 --- /dev/null +++ 
b/tasks/edit_intent_classification/finetuning_llm_gen/evaluater.py @@ -0,0 +1,146 @@ +from pathlib import Path +import torch +import numpy as np +from tqdm import tqdm +import json +import pandas as pd +from sklearn.metrics import accuracy_score, classification_report +from transformers import AutoTokenizer, pipeline +from peft import AutoPeftModelForCausalLM + +class Evaluater: + def __init__(self) -> None: + '' + + def merge_model(self, finetuned_model_dir:Path): + tokenizer = AutoTokenizer.from_pretrained(str(finetuned_model_dir)) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + compute_dtype = getattr(torch, "float16") + model = AutoPeftModelForCausalLM.from_pretrained( + str(finetuned_model_dir)+'/', + torch_dtype=compute_dtype, + return_dict=False, + low_cpu_mem_usage=True, + device_map='cuda:0', + ) + model = model.merge_and_unload() + model.config.pad_token_id = tokenizer.pad_token_id + + if 'mistral' in str(finetuned_model_dir): + model.config.sliding_window = 4096 + + print('merge_model, 4model', model) + + return model, tokenizer + + def predict(self, test, model, tokenizer, labels, output_dir, response_key, is_val=False): + if is_val: + eval_file = output_dir / "VAL_eval_pred.csv" + else: + eval_file = output_dir / "eval_pred.csv" + print('eval_file', eval_file) + if eval_file.exists(): + eval_file.unlink() + + for i in tqdm(range(len(test))): + prompt = test[i]["text"] + pipe = pipeline(task="text-generation", + model=model, + tokenizer=tokenizer, + max_new_tokens = 10, + ) + result = pipe(prompt) + answer = result[0]['generated_text'].split(response_key)[-1] + found = False + for l in labels: + if l.lower() in answer.lower(): + pred = l + found = True + break + if not found: + pred="none" + a = pd.DataFrame({ "true":[test[i]['label']], "pred":[pred], "answer":[answer], "prompt":[prompt]}) + a.to_csv(eval_file,mode="a",index=False,header=not eval_file.exists()) + return eval_file + + + def evaluate(self, test, model=None, tokenizer=None, model_dir=None, output_dir=None, do_predict = True, + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None, response_key = None, is_val = False): + # load the model + if model is None or tokenizer is None: + model, tokenizer = self.merge_model(model_dir) + + start_time = pd.Timestamp.now() + if output_dir is None: + output_dir = Path(model_dir) + if do_predict: + eval_file = self.predict(test, model, tokenizer, labels, output_dir, response_key, is_val=is_val) + end_time = pd.Timestamp.now() + inference_time = end_time - start_time + inference_time = inference_time.total_seconds() + + + df = pd.read_csv(eval_file) + none_nr = len(df[df['pred'] == 'none']) + total_nr = len(df) + + eff = round((total_nr / int(inference_time)), 1) + if is_val: + inf_file = output_dir / "VAL_inference_time.json" + else: + inf_file = output_dir / "inference_time.json" + with open (inf_file, 'w') as f: + json.dump({'inference_time':int(inference_time), 'inference_efficieny':eff}, f, indent=4) + + + #calculate accuarcy with 'none' samples + y_pred = df["pred"] + y_true = df["true"] + # Map labels to ids + label2id['none'] = len(label2id) + map_func = lambda label: label2id[label] + y_true = np.vectorize(map_func)(y_true) + y_pred = np.vectorize(map_func)(y_pred) + # Calculate accuracy + accuracy = accuracy_score(y_true=y_true, y_pred=y_pred) + print(f'Accuracy: {accuracy:.3f}') + + # Generate accuracy report + # if 'none' exists in the labels, add it to the target names for accuracy calculation + if none_nr > 0: + 
target_names = labels+['none'] + else: + target_names = labels + class_report = classification_report(y_true=y_true, y_pred=y_pred, target_names=target_names, output_dict=True, zero_division=0) + + # but the marco avg and weighted avg f1 should not include 'none', otherwise 'none' class having 0 samples will affect the calculation + if none_nr > 0: + df = df[df['pred'] != 'none'] + y_pred = df["pred"] + y_true = df["true"] + map_func = lambda label: label2id[str(label)] + y_true = np.vectorize(map_func)(y_true) + y_pred = np.vectorize(map_func)(y_pred) + class_report2 = classification_report(y_true=y_true, y_pred=y_pred, target_names=labels, output_dict=True, zero_division=0) + class_report['weighted avg 2'] = class_report2['weighted avg'] + class_report['macro avg 2'] = class_report2['macro avg'] + + print('\nClassification Report:') + class_report['none_nr'] = none_nr + class_report['AIR'] = round(((total_nr - none_nr) / total_nr)*100, 1) + print(class_report) + + + if is_val: + eval_file = output_dir / "VAL_eval_report.json" + else: + eval_file = output_dir / "eval_report.json" + with open(str(eval_file), 'w') as f: + json.dump(class_report, f, indent=4) + + if is_val: + return accuracy + + diff --git a/tasks/edit_intent_classification/finetuning_llm_gen/model_finetuner.py b/tasks/edit_intent_classification/finetuning_llm_gen/model_finetuner.py new file mode 100644 index 0000000..9977381 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_gen/model_finetuner.py @@ -0,0 +1,147 @@ +import shutil +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +from trl import SFTTrainer +from transformers import TrainingArguments, DataCollatorForLanguageModeling +from .evaluater import Evaluater + + +class ModelFinetuner: + def __init__(self) -> None: + '' + + def print_trainable_parameters(self, model, use_4bit = False): + """Prints the number of trainable parameters in the model. 
+ :param model: PEFT model + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + num_params = param.numel() + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + if use_4bit: + trainable_params /= 2 + print( + f"All Parameters: {all_param:,d} || Trainable Parameters: {trainable_params:,d} || Trainable Parameters %: {100 * trainable_params / all_param}" + ) + + def fine_tune(self, + model, + tokenizer, + train_ds = None, + val_ds = None, + lora_r = 256, + lora_alpha = 256, + lora_dropout = 0.1, + learning_rate = 2e-4, + per_device_train_batch_size = 32, + train_epochs = 10, + output_dir = None, + bias = 'none', + target_modules="all-linear", + task_type = "CAUSAL_LM", + max_seq_length = 1024, + do_val = True, + response_key = '', + labels = None, + label2id = None): + print('fine-tuning....') + # Enable gradient checkpointing to reduce memory usage during fine-tuning + model.gradient_checkpointing_enable() + # Prepare the model for training + model = prepare_model_for_kbit_training(model) + + peft_config = LoraConfig( + r = lora_r, + lora_alpha = lora_alpha, + target_modules = target_modules, + lora_dropout = lora_dropout, + bias = bias, + task_type = task_type, + ) + model = get_peft_model(model, peft_config) + + # Print information about the percentage of trainable parameters + self.print_trainable_parameters(model) + + args = TrainingArguments( + output_dir = output_dir, + num_train_epochs=train_epochs, + per_device_train_batch_size = per_device_train_batch_size, + per_device_eval_batch_size=per_device_train_batch_size, + gradient_accumulation_steps = 8, + learning_rate = learning_rate, + logging_steps=10, + fp16 = True, + weight_decay=0.001, + max_grad_norm=0.3, # max gradient norm based on QLoRA paper + max_steps=-1, + warmup_ratio=0.03, # warmup ratio based on QLoRA paper + group_by_length=True, + lr_scheduler_type="cosine", # use cosine learning rate scheduler + report_to="tensorboard", # report metrics to tensorboard + save_strategy="epoch", # save checkpoint every epoch + gradient_checkpointing=True, # use gradient checkpointing to save memory + optim="paged_adamw_32bit", + ) + + data_collator = DataCollatorForLanguageModeling(tokenizer, mlm = False) + + trainer = SFTTrainer( + model=model, + args=args, + train_dataset=train_ds, + peft_config=peft_config, + dataset_text_field="text", + tokenizer=tokenizer, + packing=False, + max_seq_length=max_seq_length, + data_collator = data_collator, + dataset_kwargs={ + "add_special_tokens": False, + "append_concat_token": False, + } + ) + + model.config.use_cache = False + do_train = True + + # Launch training and log metrics + print("Training...") + if do_train: + train_result = trainer.train() + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_model() + tokenizer.save_pretrained(output_dir) + + # Validate on the validation set and select the best checkpoint + evaluater = Evaluater() + print('Validating...') + if do_val: + best_acc = (None, None) + for d in output_dir.iterdir(): + if d.is_dir() and d.stem.startswith('checkpoint'): + print(d) + acc = evaluater.evaluate(val_ds, model_dir=d, labels=labels, label2id=label2id, response_key=response_key, is_val=True) + if best_acc == (None, None): + best_acc = (d.stem, acc) + else: + if acc > best_acc[1]: + best_acc = (d.stem, acc) + #copy the best ckp to output_dir + 
print('The best checkpoint is: ', best_acc[0], '..copying') + for f in (output_dir/best_acc[0]).iterdir(): + if (output_dir/f.name).exists(): + (output_dir/f.name).unlink() + shutil.copy(f, output_dir/f.name) + + + + + \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_gen/model_loader.py b/tasks/edit_intent_classification/finetuning_llm_gen/model_loader.py new file mode 100644 index 0000000..835479a --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_gen/model_loader.py @@ -0,0 +1,32 @@ +from pathlib import Path +import torch +from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM + +class ModelLoader: + def __init__(self) -> None: + self.bnb_config = BitsAndBytesConfig( + load_in_4bit = True, # Activate 4-bit precision base model loading + bnb_4bit_use_double_quant = True, # Activate nested quantization for 4-bit base models (double quantization) + bnb_4bit_quant_type = "nf4",# Quantization type (fp4 or nf4) + bnb_4bit_compute_dtype = torch.bfloat16, # Compute data type for 4-bit base models + ) + + def load_model_from_path(self, model_path:str, device_map='auto', + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None): + + print('Loading model from...', model_path) + + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + + model = AutoModelForCausalLM.from_pretrained(model_path, + quantization_config = self.bnb_config, + device_map = device_map, + ) + if 'Mistral' in model_path: + model.config.sliding_window = 4096 + return model, tokenizer + + diff --git a/tasks/edit_intent_classification/finetuning_llm_seqc/__init__.py b/tasks/edit_intent_classification/finetuning_llm_seqc/__init__.py new file mode 100644 index 0000000..354aede --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_seqc/__init__.py @@ -0,0 +1,11 @@ +from .model_loader import ModelLoader +from .data_preprocessor import DataPreprocessor +from .model_finetuner import ModelFinetuner +from .evaluater import Evaluater + +__all__ = [ + "ModelLoader", + "DataPreprocessor" + "ModelFinetuner" + "Evaluater" + ] \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_seqc/data_preprocessor.py b/tasks/edit_intent_classification/finetuning_llm_seqc/data_preprocessor.py new file mode 100644 index 0000000..702c3ff --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_seqc/data_preprocessor.py @@ -0,0 +1,50 @@ +class DataPreprocessor: + def __init__(self) -> None: + print('Preprocessing the data...SeqC') + + def preprocess_data(self, dataset, label2id, tokenizer, max_length=1024, input_type='text_st_on'): + """ + :param tokenizer (AutoTokenizer): Model tokenizer + :param max_length (int): Maximum number of tokens to emit from the tokenizer + :param input_type (str): Type of input text + """ + # perpare input text and label + self.label2id = label2id + self.tokenizer = tokenizer + self.max_length = max_length + self.input_type = input_type + print("input_type: ", input_type) + print('max_length: ', max_length) + dataset = dataset.map(lambda x: self.create_input_text_and_label(x, self.input_type),keep_in_memory=True) + # Shuffle dataset + seed = 42 + dataset = dataset.shuffle(seed = seed) + return dataset + + def create_input_text_and_label(self, sample, input_type): + """ + Creates a formatted input text for a sample + :param sample: sample from the dataset + """ + instruction = "Classify the 
intent of the following sentence edit. The possible labels are: Grammar, Clarity, Fact/Evidence, Claim, Other. " + text_src = sample['text_src'] if sample['text_src'] is not None else '' + text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else '' + + if input_type == 'text_st_on': + sample["text"] = f" {text_tgt} " + '\n ' + f" {text_src} " + elif input_type == 'text_nl_on': + sample["text"] = text_tgt + '\n ' + text_src + elif input_type == 'inst_text_st_on': + sample["text"] = instruction + '\n ' + f" {text_tgt} " + '\n ' + f" {text_src} " + elif input_type == 'inst_text_nl_on': + sample["text"] = instruction + '\n ' + text_tgt + '\n ' + text_src + else: + raise ValueError("Invalid input type. Choose from ['text_st_on','text_nl_on','inst_text_st_on', 'inst_text_st_on']") + + label = self.label2id[sample['label']] + sample["label"] = label + + sample['input_ids_text'] = self.tokenizer.encode_plus(sample["text"], max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")["input_ids"] + sample['attention_mask_text'] = self.tokenizer.encode_plus(sample["text"], max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")["attention_mask"] + return sample + \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_seqc/evaluater.py b/tasks/edit_intent_classification/finetuning_llm_seqc/evaluater.py new file mode 100644 index 0000000..af6f157 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_seqc/evaluater.py @@ -0,0 +1,133 @@ +from pathlib import Path +import torch +import numpy as np +from tqdm import tqdm +import json +import pandas as pd +from sklearn.metrics import accuracy_score, classification_report +from peft import AutoPeftModelForSequenceClassification +from transformers import AutoTokenizer +from .model_finetuner import collate_fn + +class Evaluater: + def __init__(self) -> None: + print('Evaluating the model...') + + def merge_model(self, finetuned_model_dir:Path, labels, label2id, id2label): + tokenizer = AutoTokenizer.from_pretrained(str(finetuned_model_dir)) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + compute_dtype = getattr(torch, "float16") + model = AutoPeftModelForSequenceClassification.from_pretrained( + str(finetuned_model_dir)+'/', + torch_dtype=compute_dtype, + return_dict=False, + low_cpu_mem_usage=True, + device_map='auto', + num_labels = len(labels), + ) + + model.config.id2label = id2label + model.config.label2id = label2id + + model = model.merge_and_unload() + model.to('cuda:0') + model.config.pad_token_id = tokenizer.pad_token_id + + if 'mistral' in str(finetuned_model_dir): + model.config.sliding_window = 4096 + + return model, tokenizer + + def predict(self, test, model, id2label, output_dir): + eval_file = output_dir / "eval_pred.csv" + print('eval_file', eval_file) + if eval_file.exists(): + eval_file.unlink() + + for i in tqdm(range(len(test))): + inputs = collate_fn([test[i]], device='cuda:0') + with torch.no_grad(): + logits = model(**inputs, return_dict=True).logits.to('cuda:0') + + predicted_class_id = logits.argmax().item() + pred = id2label[int(predicted_class_id)] + + a = {'doc_name':[test[i]['doc_name']], + 'node_ix_src':[test[i]['node_ix_src']], + 'node_ix_tgt':[test[i]['node_ix_tgt']], + 'true':[id2label[test[i]['label']]], + 'pred':[pred], + } + a = pd.DataFrame(a) + a.to_csv(eval_file,mode="a",index=False,header=not eval_file.exists()) + + + def evaluate(self, test, model=None, tokenizer=None, 
model_dir=None, output_dir=None, do_predict = True, + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None, response_key = None): + """ + Evaluate the model on the test set + :param test: Test set + :param model: Hugging Face model, the fine-tuned model + :param tokenizer: Model tokenizer + :param model_dir: Directory containing the fine-tuned model + :param output_dir: Directory to save the evaluation results + :param do_predict: Whether to predict the labels + :param labels: List of labels + :param label2id: Dictionary mapping labels to ids + :param id2label: Dictionary mapping ids to labels + :param emb_type: transformation function, None for SeqC + :param input_type: Type of input text + :param response_key: Response key, None for SeqC + """ + # load the model + if model is None or tokenizer is None: + model, tokenizer = self.merge_model(model_dir, labels, label2id, id2label) + + start_time = pd.Timestamp.now() + if output_dir is None: + output_dir = Path(model_dir) + if do_predict: + self.predict(test, model, id2label, output_dir) + end_time = pd.Timestamp.now() + inference_time = end_time - start_time + inference_time = inference_time.total_seconds() + + + df = pd.read_csv(output_dir / "eval_pred.csv") + none_nr = len(df[df['pred'] == 'none']) + assert none_nr == 0, f'None labels found in the predictions: {none_nr}' + total_nr = len(df) + + eff = round((total_nr / int(inference_time)), 1) + with open (output_dir / "inference_time.json", 'w') as f: + json.dump({'inference_time':int(inference_time), 'inference_efficieny':eff}, f, indent=4) + + df = df[df['pred'] != 'none'] + y_pred = df["pred"] + y_true = df["true"] + print(df) + + # Map labels to ids + map_func = lambda label: label2id[label] + y_true = np.vectorize(map_func)(y_true) + y_pred = np.vectorize(map_func)(y_pred) + + # Calculate accuracy + accuracy = accuracy_score(y_true=y_true, y_pred=y_pred) + print(f'Accuracy: {accuracy:.3f}') + + # Generate classification report + class_report = classification_report(y_true=y_true, y_pred=y_pred, target_names=labels, output_dict=True, zero_division=0) + print('\nClassification Report:') + class_report['none_nr'] = none_nr + class_report['AIR'] = round(((total_nr - none_nr) / total_nr)*100, 1) + print(class_report) + + eval_file = output_dir / "eval_report.json" + if eval_file.exists(): + eval_file.unlink() + with open(str(eval_file), 'w') as f: + json.dump(class_report, f, indent=4) + \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_seqc/model_finetuner.py b/tasks/edit_intent_classification/finetuning_llm_seqc/model_finetuner.py new file mode 100644 index 0000000..970acf8 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_seqc/model_finetuner.py @@ -0,0 +1,151 @@ +import torch +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +import numpy as np +from trl import SFTTrainer +from transformers import TrainingArguments +import evaluate +accuracy = evaluate.load("accuracy") + + +def collate_fn(examples, device=None): + for example in examples: + example['input_ids'] = torch.as_tensor(example['input_ids_text']) + example['attention_mask'] = torch.as_tensor(example['attention_mask_text']) + example['label'] = torch.as_tensor(example['label']) + input_ids = torch.stack([torch.as_tensor(example["input_ids"]) for example in examples]) + attention_masks = torch.stack([torch.as_tensor(example["attention_mask"]) for example in examples]) + input_ids = torch.squeeze(input_ids, dim=1) + 
attention_masks = torch.squeeze(attention_masks, dim=1) + labels = torch.stack([torch.as_tensor(example["label"]) for example in examples]) + + if device is not None: + input_ids = input_ids.to(device) + attention_masks = attention_masks.to(device) + labels = labels.to(device) + return {"input_ids": input_ids, + "attention_mask": attention_masks, + "labels": labels} + +def compute_metrics(eval_pred): + predictions, labels = eval_pred + predictions = np.argmax(predictions, axis=1) + return accuracy.compute(predictions=predictions, references=labels) + +class ModelFinetuner: + def __init__(self) -> None: + '' + + def print_trainable_parameters(self, model, use_4bit = False): + """Prints the number of trainable parameters in the model. + :param model: PEFT model + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + num_params = param.numel() + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + if use_4bit: + trainable_params /= 2 + print( + f"All Parameters: {all_param:,d} || Trainable Parameters: {trainable_params:,d} || Trainable Parameters %: {100 * trainable_params / all_param}" + ) + + def fine_tune(self, + model, + tokenizer, + train_ds = None, + val_ds = None, + lora_r = 128, + lora_alpha = 128, + lora_dropout = 0.1, + learning_rate = 2e-4, + per_device_train_batch_size = 32, + train_epochs = 10, + output_dir = None, + bias = 'none', + target_modules="all-linear", + task_type = "SEQ_CLS", + max_seq_length = 4096 + ): + print('fine-tuning....') + # Enable gradient checkpointing to reduce memory usage during fine-tuning + model.gradient_checkpointing_enable() + + # Prepare the model for training + model = prepare_model_for_kbit_training(model) + + peft_config = LoraConfig( + r = lora_r, + lora_alpha = lora_alpha, + target_modules = target_modules, + lora_dropout = lora_dropout, + bias = bias, + task_type = task_type, + ) + model = get_peft_model(model, peft_config) + + # Print information about the percentage of trainable parameters + self.print_trainable_parameters(model) + + args = TrainingArguments( + output_dir = output_dir, + num_train_epochs=train_epochs, + per_device_train_batch_size = per_device_train_batch_size, + per_device_eval_batch_size=per_device_train_batch_size, + gradient_accumulation_steps = 8, + learning_rate = learning_rate, + logging_steps=10, + fp16 = True, + weight_decay=0.001, + max_grad_norm=0.3, # max gradient norm based on QLoRA paper + max_steps=-1, + warmup_ratio=0.03, # warmup ratio based on QLoRA paper + group_by_length=True, + lr_scheduler_type="cosine", # use cosine learning rate scheduler + report_to="tensorboard", # report metrics to tensorboard + evaluation_strategy="epoch", # save checkpoint every epoch + save_strategy="epoch", + gradient_checkpointing=True, # use gradient checkpointing to save memory + optim="paged_adamw_32bit", + remove_unused_columns=False, + load_best_model_at_end=True, + metric_for_best_model="eval_accuracy", + label_names = ['labels'], + ) + + trainer = SFTTrainer( + model=model, + args=args, + train_dataset=train_ds, + eval_dataset=val_ds, + compute_metrics=compute_metrics, + peft_config=peft_config, + dataset_text_field="text", + tokenizer=tokenizer, + packing=False, + max_seq_length=max_seq_length, + data_collator = collate_fn, + dataset_kwargs={ + "add_special_tokens": False, + "append_concat_token": False, + } + ) + + model.config.use_cache = False + do_train = True + + 
# Launch training and log metrics + print("Training...") + if do_train: + train_result = trainer.train() + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_model() + tokenizer.save_pretrained(output_dir) + diff --git a/tasks/edit_intent_classification/finetuning_llm_seqc/model_loader.py b/tasks/edit_intent_classification/finetuning_llm_seqc/model_loader.py new file mode 100644 index 0000000..97346cd --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_seqc/model_loader.py @@ -0,0 +1,37 @@ +from pathlib import Path +import torch +from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForSequenceClassification + + +class ModelLoader: + def __init__(self) -> None: + self.bnb_config = BitsAndBytesConfig( + load_in_4bit = True, # Activate 4-bit precision base model loading + bnb_4bit_use_double_quant = True, # Activate nested quantization for 4-bit base models (double quantization) + bnb_4bit_quant_type = "nf4",# Quantization type (fp4 or nf4) + bnb_4bit_compute_dtype = torch.bfloat16, # Compute data type for 4-bit base models + ) + + + def load_model_from_path(self, model_path:str, device_map='auto', + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None): + + print('Loading model from...', model_path) + + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + + model = AutoModelForSequenceClassification.from_pretrained(model_path, + quantization_config = self.bnb_config, + device_map = device_map, + num_labels = len(labels), + ) + model.config.pad_token_id = tokenizer.pad_token_id + model.config.id2label = id2label + model.config.label2id = label2id + if 'Mistral' in model_path: + model.config.sliding_window = 4096 + return model, tokenizer + diff --git a/tasks/edit_intent_classification/finetuning_llm_snet/__init__.py b/tasks/edit_intent_classification/finetuning_llm_snet/__init__.py new file mode 100644 index 0000000..354aede --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_snet/__init__.py @@ -0,0 +1,11 @@ +from .model_loader import ModelLoader +from .data_preprocessor import DataPreprocessor +from .model_finetuner import ModelFinetuner +from .evaluater import Evaluater + +__all__ = [ + "ModelLoader", + "DataPreprocessor" + "ModelFinetuner" + "Evaluater" + ] \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_snet/data_preprocessor.py b/tasks/edit_intent_classification/finetuning_llm_snet/data_preprocessor.py new file mode 100644 index 0000000..63d180f --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_snet/data_preprocessor.py @@ -0,0 +1,58 @@ +class DataPreprocessor: + def __init__(self) -> None: + ''' + ''' + + print('Preprocessing the data ...SNet') + + def preprocess_data(self, dataset, label2id, tokenizer, max_length=1024, input_type='text_st_on'): + """ + :param dataset: Hugging Face dataset + :param label2id (dict): label to id mapping + :param tokenizer (AutoTokenizer): Model tokenizer + :param max_length (int): Maximum number of tokens for padding and truncation + :param input_type (str): type of input text + """ + # perpare input text and label + self.label2id = label2id + self.tokenizer = tokenizer + self.max_length = max_length + self.input_type = input_type + print("input_type: ", input_type) + print('max_length: ', max_length) + dataset = dataset.map(self.create_input_text_and_label, 
keep_in_memory=True) + # Shuffle dataset + seed = 42 + dataset = dataset.shuffle(seed = seed) + return dataset + + def create_input_text_and_label(self, sample): + """ + Creates a formatted input text for a sample + :param sample: sample from the dataset + """ + text_src = sample['text_src'] if sample['text_src'] is not None else '' + text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else '' + + if self.input_type == 'text_st_on': + sample["text_old"] = f" {text_tgt} " + sample["text_new"] = f" {text_src} " + elif self.input_type == 'text_nl_on': + sample["text_old"] = text_tgt + sample["text_new"] = text_src + else: + raise ValueError("Invalid input type. Choose from ['text_st_on','text_nl_on']") + + sample["text"] = sample["text_old"] + ' ' + sample["text_new"] + + input_old = self.tokenizer.encode_plus(sample["text_old"], max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt") + input_new = self.tokenizer.encode_plus(sample["text_new"], max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt") + + sample["input_ids_tuple"] = (input_old["input_ids"], input_new["input_ids"]) + sample["attention_mask_tuple"] = (input_old["attention_mask"], input_new["attention_mask"]) + label = self.label2id[sample['label']] + sample["label"] = label + + return sample + + \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_snet/evaluater.py b/tasks/edit_intent_classification/finetuning_llm_snet/evaluater.py new file mode 100644 index 0000000..db4537a --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_snet/evaluater.py @@ -0,0 +1,135 @@ +from pathlib import Path +import torch +import numpy as np +from tqdm import tqdm +import json +import pandas as pd +from torch import nn +from transformers import AutoTokenizer +from sklearn.metrics import accuracy_score, classification_report +from .model_finetuner import collate_fn +from .modelling_llama import LlamaForSequenceClassificationSiamese +from .modelling_peft import AutoPeftModelForSequenceClassificationSiamese + +class Evaluater: + def __init__(self) -> None: + print('Evaluating the model...') + + def merge_model(self, finetuned_model_dir:Path, labels, label2id, id2label, emb_type=None): + tokenizer = AutoTokenizer.from_pretrained(str(finetuned_model_dir)) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + compute_dtype = getattr(torch, "float16") + + model = AutoPeftModelForSequenceClassificationSiamese.from_pretrained( + str(finetuned_model_dir)+'/', + torch_dtype=compute_dtype, + return_dict=False, + low_cpu_mem_usage=True, + device_map='auto', + num_labels = len(labels), + emb_type=emb_type + ) + + model.config.id2label = id2label + model.config.label2id = label2id + model = model.merge_and_unload() + model.to('cuda:0') + + model = LlamaForSequenceClassificationSiamese(model.config, model=model.model, score=model.score, emb_type=emb_type) + model.config.pad_token_id = tokenizer.pad_token_id + print('model',model) + model.to('cuda:0') + return model, tokenizer + + + def predict(self, test, model, id2label, output_dir): + eval_file = output_dir / "eval_pred.csv" + print('eval_file', eval_file) + if eval_file.exists(): + eval_file.unlink() + + for i in tqdm(range(len(test))): + inputs = collate_fn([test[i]], device='cuda:0') + with torch.no_grad(): + logits = model(**inputs, return_dict=True).logits.to('cuda:0') + predicted_class_id = logits.argmax().item() + pred = id2label[int(predicted_class_id)] + 
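+ # record the gold and predicted label names for this edit; rows are appended to eval_pred.csv one example at a time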
+ a = {'doc_name':[test[i]['doc_name']], + 'node_ix_src':[test[i]['node_ix_src']], + 'node_ix_tgt':[test[i]['node_ix_tgt']], + 'true':[id2label[test[i]['label']]], + 'pred':[pred], + } + a = pd.DataFrame(a) + a.to_csv(eval_file,mode="a",index=False,header=not eval_file.exists()) + + + def evaluate(self, test, model=None, tokenizer=None, model_dir=None, output_dir=None, do_predict = True, + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None, response_key = None): + """ + Evaluate the model on the test set + :param test: Test set + :param model: Hugging Face model, the fine-tuned model + :param tokenizer: Model tokenizer + :param model_dir: Directory containing the fine-tuned model + :param output_dir: Directory to save the evaluation results + :param do_predict: Whether to predict the labels + :param labels: List of labels + :param label2id: Dictionary mapping labels to ids + :param id2label: Dictionary mapping ids to labels + :param emb_type: transformation function + :param input_type: Type of input text + :param response_key: Response key, None for SNet + """ + # load the model + if model is None or tokenizer is None: + model, tokenizer = self.merge_model(model_dir, labels, label2id, id2label, emb_type=emb_type) + + start_time = pd.Timestamp.now() + if output_dir is None: + output_dir = model_dir + if do_predict: + self.predict(test, model, id2label, output_dir) + end_time = pd.Timestamp.now() + inference_time = end_time - start_time + inference_time = inference_time.total_seconds() + + df = pd.read_csv(output_dir / "eval_pred.csv") + none_nr = len(df[df['pred'] == 'none']) + assert none_nr == 0, f'None labels found in the predictions: {none_nr}' + total_nr = len(df) + + eff = round((total_nr / int(inference_time)), 1) + with open (output_dir / "inference_time.json", 'w') as f: + json.dump({'inference_time':int(inference_time), 'inference_efficieny':eff}, f, indent=4) + + df = df[df['pred'] != 'none'] + y_pred = df["pred"] + y_true = df["true"] + print(df) + + # Map labels to ids + map_func = lambda label: label2id[label] + y_true = np.vectorize(map_func)(y_true) + y_pred = np.vectorize(map_func)(y_pred) + + # Calculate accuracy + accuracy = accuracy_score(y_true=y_true, y_pred=y_pred) + print(f'Accuracy: {accuracy:.3f}') + + # Generate classification report + class_report = classification_report(y_true=y_true, y_pred=y_pred, target_names=labels, output_dict=True, zero_division=0) + print('\nClassification Report:') + class_report['none_nr'] = none_nr + class_report['AIR'] = round(((total_nr - none_nr) / total_nr)*100, 1) + print(class_report) + + eval_file = output_dir / "eval_report.json" + if eval_file.exists(): + eval_file.unlink() + with open(str(eval_file), 'w') as f: + json.dump(class_report, f, indent=4) + diff --git a/tasks/edit_intent_classification/finetuning_llm_snet/model_finetuner.py b/tasks/edit_intent_classification/finetuning_llm_snet/model_finetuner.py new file mode 100644 index 0000000..f17c3e3 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_snet/model_finetuner.py @@ -0,0 +1,147 @@ +import torch +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +import bitsandbytes as bnb +import numpy as np +from trl import SFTTrainer +from transformers import TrainingArguments +import evaluate +accuracy = evaluate.load("accuracy") + + +def collate_fn(examples, device=None): + for example in examples: + example['input_ids'] = torch.as_tensor(example['input_ids_tuple']) + example['attention_mask'] = 
torch.as_tensor(example['attention_mask_tuple']) + example['label'] = torch.as_tensor(example['label']) + input_ids = torch.stack([example["input_ids"] for example in examples]) + attention_masks = torch.stack([example["attention_mask"] for example in examples]) + labels = torch.stack([example["label"] for example in examples]) + if device is not None: + input_ids = input_ids.to(device) + attention_masks = attention_masks.to(device) + labels = labels.to(device) + return {"input_ids": input_ids, + "attention_mask": attention_masks, + "labels": labels} + + +def compute_metrics(eval_pred): + predictions, labels = eval_pred + predictions = np.argmax(predictions, axis=1) + return accuracy.compute(predictions=predictions, references=labels) + +class ModelFinetuner: + def __init__(self) -> None: + '' + + def print_trainable_parameters(self, model, use_4bit = False): + """Prints the number of trainable parameters in the model. + :param model: PEFT model + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + num_params = param.numel() + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + if use_4bit: + trainable_params /= 2 + print( + f"All Parameters: {all_param:,d} || Trainable Parameters: {trainable_params:,d} || Trainable Parameters %: {100 * trainable_params / all_param}" + ) + + def fine_tune(self, + model, + tokenizer, + train_ds = None, + val_ds = None, + lora_r = 128, + lora_alpha = 128, + lora_dropout = 0.1, + learning_rate = 2e-4, + per_device_train_batch_size = 32, + train_epochs = 10, + output_dir = None, + bias = 'none', + target_modules="all-linear", + task_type = None, + max_seq_length = 4096 + ): + print('fine-tuning....') + # Prepare the model for training + model = prepare_model_for_kbit_training(model) + + peft_config = LoraConfig( + r = lora_r, + lora_alpha = lora_alpha, + target_modules = target_modules, + lora_dropout = lora_dropout, + bias = bias, + task_type = task_type, + modules_to_save = ['score'] + ) + model = get_peft_model(model, peft_config) + # Print information about the percentage of trainable parameters + self.print_trainable_parameters(model) + + args = TrainingArguments( + output_dir = output_dir, + num_train_epochs=train_epochs, + per_device_train_batch_size = per_device_train_batch_size, + per_device_eval_batch_size=per_device_train_batch_size, + gradient_accumulation_steps = 8, + learning_rate = learning_rate, + logging_steps=10, + fp16 = True, + weight_decay=0.001, + max_grad_norm=0.3, # max gradient norm based on QLoRA paper + max_steps=-1, + warmup_ratio=0.03, # warmup ratio based on QLoRA paper + group_by_length=True, + lr_scheduler_type="cosine", # use cosine learning rate scheduler + report_to="tensorboard", # report metrics to tensorboard + evaluation_strategy="epoch", # save checkpoint every epoch + save_strategy="epoch", + gradient_checkpointing=True, # use gradient checkpointing to save memory + optim="paged_adamw_32bit", + remove_unused_columns=False, + load_best_model_at_end=True, + metric_for_best_model="eval_accuracy", + label_names = ['labels'], + ) + + trainer = SFTTrainer( + model=model, + args=args, + train_dataset=train_ds, + eval_dataset=val_ds, + compute_metrics=compute_metrics, + peft_config=peft_config, + dataset_text_field="text", + tokenizer=tokenizer, + packing=False, + max_seq_length=max_seq_length, + data_collator = collate_fn, + dataset_kwargs={ + "add_special_tokens": False, + 
"append_concat_token": False, + } + ) + + model.config.use_cache = False + do_train = True + + # Launch training and log metrics + print("Training...") + + if do_train: + train_result = trainer.train() + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_model() + tokenizer.save_pretrained(output_dir) diff --git a/tasks/edit_intent_classification/finetuning_llm_snet/model_loader.py b/tasks/edit_intent_classification/finetuning_llm_snet/model_loader.py new file mode 100644 index 0000000..43a9a95 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_snet/model_loader.py @@ -0,0 +1,36 @@ +from pathlib import Path +import torch +from transformers import AutoTokenizer, BitsAndBytesConfig +from .modelling_llama import LlamaForSequenceClassificationSiamese + +class ModelLoader: + def __init__(self, num_cls_layers =1) -> None: + print('Loading the model...') + self.bnb_config = BitsAndBytesConfig( + load_in_4bit = True, # Activate 4-bit precision base model loading + bnb_4bit_use_double_quant = True, # Activate nested quantization for 4-bit base models (double quantization) + bnb_4bit_quant_type = "nf4",# Quantization type (fp4 or nf4) + bnb_4bit_compute_dtype = torch.bfloat16, # Compute data type for 4-bit base models + ) + + def load_model_from_path(self, model_path:str, device_map='auto', + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None): + + print('Loading model from...', model_path) + + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + + model = LlamaForSequenceClassificationSiamese.from_pretrained(model_path, + quantization_config = self.bnb_config, + device_map = device_map, + num_labels = len(labels), + emb_type=emb_type, + ) + model.config.pad_token_id = tokenizer.pad_token_id + model.config.id2label = id2label + model.config.label2id = label2id + + return model, tokenizer diff --git a/tasks/edit_intent_classification/finetuning_llm_snet/modelling_llama.py b/tasks/edit_intent_classification/finetuning_llm_snet/modelling_llama.py new file mode 100644 index 0000000..3075d16 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_snet/modelling_llama.py @@ -0,0 +1,174 @@ +# Built upon the huggingface implementation + +from typing import List, Optional, Tuple, Union +from transformers import LlamaModel, LlamaPreTrainedModel +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import SequenceClassifierOutputWithPast +from transformers.utils import logging + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "LlamaConfig" + + +class LlamaForSequenceClassificationSiamese(LlamaPreTrainedModel): + def __init__(self, config, model=None, score=None, emb_type=None): + super().__init__(config) + self.num_labels = config.num_labels + self.emb_type = emb_type + + if model is None: + self.model = LlamaModel(config) + else: + self.model = model + + if score is None: + if self.emb_type in['diff','diffABS']: + input_size = config.hidden_size + elif self.emb_type in ['n-o','n-diffABS']: + input_size = config.hidden_size*2 + elif self.emb_type in ['n-diffABS-o']: + input_size = config.hidden_size*3 + else: + raise ValueError("invalid emb_type") + self.score = nn.Linear(input_size, self.num_labels, bias=False) + + else: + self.score = score + # 
Initialize weights and apply final processing + self.post_init() + + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs_old = self.model( + torch.squeeze(input_ids[:,0,:,:],1), + attention_mask=torch.squeeze(attention_mask[:,0,:,:],1), + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + transformer_outputs_new = self.model( + torch.squeeze(input_ids[:,1,:,:],1), + attention_mask=torch.squeeze(attention_mask[:,1,:,:],1), + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + + if input_ids is not None: + batch_size = torch.squeeze(input_ids[:,0,:,:],1).shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths_old = -1 + sequence_lengths_new = -1 + else: + if input_ids is not None: + sequence_lengths_old = (torch.eq(torch.squeeze(input_ids[:,0,:,:],1), self.config.pad_token_id).long().argmax(-1) - 1) + sequence_lengths_new = (torch.eq(torch.squeeze(input_ids[:,1,:,:],1), self.config.pad_token_id).long().argmax(-1) - 1) + else: + sequence_lengths_old = -1 + sequence_lengths_new = -1 + + hidden_states_old = transformer_outputs_old[0] + hidden_states_new = transformer_outputs_new[0] + sequence_lengths_old = sequence_lengths_old.to(hidden_states_old.device) + sequence_lengths_new = sequence_lengths_new.to(hidden_states_new.device) + # get the last token embedding as sentence embedding + hidden_states_old = hidden_states_old[torch.arange(batch_size), sequence_lengths_old] + hidden_states_new = hidden_states_new[torch.arange(batch_size), sequence_lengths_new] + + + if self.emb_type == 'diff': + hidden_states = torch.as_tensor(hidden_states_new - hidden_states_old) + elif self.emb_type == 'diffABS': + hidden_states = torch.abs(torch.as_tensor(hidden_states_new - hidden_states_old)) + elif self.emb_type == 'n-diffABS': + diff = torch.abs(torch.as_tensor(hidden_states_new - hidden_states_old)) + hidden_states = torch.cat((hidden_states_new, diff),1) + elif self.emb_type == 'n-diffABS-o': + diff = torch.abs(torch.as_tensor(hidden_states_new - hidden_states_old)) + hidden_states = torch.cat((hidden_states_new, diff, hidden_states_old),1) + elif self.emb_type == 'n-o': + hidden_states = torch.cat((hidden_states_new, hidden_states_old),1) + + hidden_states = hidden_states.to(hidden_states_old.device) + hidden_states = 
hidden_states.type(self.score.weight.dtype) + logits = self.score(hidden_states) + pooled_logits = logits + + + loss = None + if labels is not None: + labels = labels.to(logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs_new[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs_new.past_key_values, + hidden_states=transformer_outputs_new.hidden_states, + attentions=transformer_outputs_new.attentions, + ) + + diff --git a/tasks/edit_intent_classification/finetuning_llm_snet/modelling_peft.py b/tasks/edit_intent_classification/finetuning_llm_snet/modelling_peft.py new file mode 100644 index 0000000..e7b3b27 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_snet/modelling_peft.py @@ -0,0 +1,330 @@ +# Built upon the huggingface implementation +from __future__ import annotations +import inspect +import os +import warnings +from contextlib import contextmanager +from copy import deepcopy +import importlib +import os +from typing import Optional +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import SequenceClassifierOutput +from transformers import AutoTokenizer + +from peft.config import PeftConfig +from peft.mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING +from peft.peft_model import PeftModel +from peft.utils.constants import TOKENIZER_CONFIG_NAME +from peft.utils.other import check_file_exists_on_hf_hub + +from peft.tuners import ( + AdaLoraModel, + AdaptionPromptModel, + IA3Model, + LoHaModel, + LoKrModel, + LoraModel, + OFTModel, + PolyModel, + PrefixEncoder, + PromptEmbedding, + PromptEncoder, +) +from peft.utils import ( + PeftType, + _get_batch_size, + _set_trainable, +) + +PEFT_TYPE_TO_MODEL_MAPPING = { + PeftType.LORA: LoraModel, + PeftType.LOHA: LoHaModel, + PeftType.LOKR: LoKrModel, + PeftType.PROMPT_TUNING: PromptEmbedding, + PeftType.P_TUNING: PromptEncoder, + PeftType.PREFIX_TUNING: PrefixEncoder, + PeftType.ADALORA: AdaLoraModel, + PeftType.ADAPTION_PROMPT: AdaptionPromptModel, + PeftType.IA3: IA3Model, + PeftType.OFT: OFTModel, + PeftType.POLY: PolyModel, +} + + +class PeftModelForSequenceClassificationSiamese(PeftModel): + """ + Peft model for sequence classification tasks with the snet approach. + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. 
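+        adapter_name (`str`, optional): Adapter name, defaults to `"default"`.
+
+    The classification head (`score`/`classifier`) is registered in `modules_to_save`
+    so that it stays trainable alongside the PEFT adapter weights.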
+ + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__(model, peft_config, adapter_name) + + if self.modules_to_save is None: + self.modules_to_save = {"classifier", "score"} + else: + self.modules_to_save.update({"classifier", "score"}) + + + for name, _ in self.base_model.named_children(): + if any(module_name in name for module_name in self.modules_to_save): + self.cls_layer_name = name + break + + # to make sure classifier layer is trainable + _set_trainable(self, adapter_name) + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + peft_config = self.active_peft_config + if not peft_config.is_prompt_learning: + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") + kwargs["position_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) + else: + if kwargs.get("token_type_ids", None) is not None: + kwargs["token_type_ids"] = torch.cat( + ( + torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), + kwargs["token_type_ids"], + ), + dim=1, + ).long() + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + def _prefix_tuning_forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + batch_size = _get_batch_size(input_ids, inputs_embeds) + past_key_values = self.get_prompt(batch_size) + fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) + kwargs.update( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "past_key_values": past_key_values, + } + ) + if "past_key_values" in fwd_params: + return self.base_model(labels=labels, **kwargs) + else: + transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) + fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) + if "past_key_values" not in fwd_params: + raise ValueError("Model does not support past key values which are required for prefix tuning.") + outputs = transformer_backbone_name(**kwargs) + pooled_output = outputs[1] if len(outputs) > 1 else outputs[0] + if "dropout" in [name for name, _ in list(self.base_model.named_children())]: + pooled_output = self.base_model.dropout(pooled_output) + logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.base_model.num_labels == 1: + self.config.problem_type = "regression" + elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.base_model.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + + 
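+# ---------------------------------------------------------------------------
+# Reference (descriptive comment): the transformation functions selected by
+# `emb_type` in this package's modelling_llama.py combine the old/new
+# sentence embeddings e_old / e_new before the linear `score` head:
+#   'diff'        -> e_new - e_old                      (hidden_size)
+#   'diffABS'     -> |e_new - e_old|                    (hidden_size)
+#   'n-o'         -> [e_new ; e_old]                    (2 * hidden_size)
+#   'n-diffABS'   -> [e_new ; |e_new - e_old|]          (2 * hidden_size)
+#   'n-diffABS-o' -> [e_new ; |e_new - e_old| ; e_old]  (3 * hidden_size)
+# ---------------------------------------------------------------------------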
+class _BaseAutoPeftModelSiamese: + _target_class = None + _target_peft_class = None + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_config(config)` methods." + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + adapter_name: str = "default", + is_trainable: bool = False, + config: Optional[PeftConfig] = None, + emb_type=None, + **kwargs, + ): + r""" + A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs + are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and + the config object init. + """ + peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + base_model_path = peft_config.base_model_name_or_path + + task_type = getattr(peft_config, "task_type", None) + + if cls._target_class is not None: + target_class = cls._target_class + elif cls._target_class is None and task_type is not None: + # this is only in the case where we use `AutoPeftModel` + raise ValueError( + "Cannot use `AutoPeftModel` with a task type, please use a specific class for your task type. (e.g. `AutoPeftModelForCausalLM` for `task_type='CAUSAL_LM'`)" + ) + + if task_type is not None: + expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type] + if cls._target_peft_class.__name__ != expected_target_class.__name__: + raise ValueError( + f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: {cls._target_peft_class.__name__ }" + " make sure that you are loading the correct model for your task type." + ) + elif task_type is None and getattr(peft_config, "auto_mapping", None) is not None: + auto_mapping = getattr(peft_config, "auto_mapping", None) + base_model_class = auto_mapping["base_model_class"] + parent_library_name = auto_mapping["parent_library"] + + parent_library = importlib.import_module(parent_library_name) + target_class = getattr(parent_library, base_model_class) + else: + raise ValueError( + "Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type." 
+ ) + + + base_model = target_class.from_pretrained(base_model_path, emb_type=emb_type, **kwargs) + + + tokenizer_exists = False + if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)): + tokenizer_exists = True + else: + token = kwargs.get("token", None) + if token is None: + token = kwargs.get("use_auth_token", None) + + tokenizer_exists = check_file_exists_on_hf_hub( + repo_id=pretrained_model_name_or_path, + filename=TOKENIZER_CONFIG_NAME, + revision=kwargs.get("revision", None), + repo_type=kwargs.get("repo_type", None), + token=token, + ) + + if tokenizer_exists: + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False) + ) + base_model.resize_token_embeddings(len(tokenizer)) + + return cls._target_peft_class.from_pretrained( + base_model, + pretrained_model_name_or_path, + adapter_name=adapter_name, + is_trainable=is_trainable, + config=config, + **kwargs, + ) + +from transformers.models.auto.auto_factory import _BaseAutoModelClass +from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING +class AutoModelForSequenceClassificationSiamese(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + +class AutoPeftModelForSequenceClassificationSiamese(_BaseAutoPeftModelSiamese): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.emb_type = args.emb_type + _target_class = AutoModelForSequenceClassificationSiamese + _target_peft_class = PeftModelForSequenceClassificationSiamese + + diff --git a/tasks/edit_intent_classification/finetuning_llm_xnet/__init__.py b/tasks/edit_intent_classification/finetuning_llm_xnet/__init__.py new file mode 100644 index 0000000..f595ad4 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_xnet/__init__.py @@ -0,0 +1,12 @@ +from .model_loader import ModelLoader +from .data_preprocessor import DataPreprocessor +from .model_finetuner import ModelFinetuner +from .evaluater import Evaluater + + +__all__ = [ + "ModelLoader", + "DataPreprocessor" + "ModelFinetuner" + "Evaluater" + ] \ No newline at end of file diff --git a/tasks/edit_intent_classification/finetuning_llm_xnet/data_preprocessor.py b/tasks/edit_intent_classification/finetuning_llm_xnet/data_preprocessor.py new file mode 100644 index 0000000..d94b9c1 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_xnet/data_preprocessor.py @@ -0,0 +1,59 @@ +class DataPreprocessor: + def __init__(self) -> None: + ''' + ''' + + print('Preprocessing the data ...XNet') + + def preprocess_data(self, dataset, label2id, tokenizer, max_length=1024, input_type='text_st_on'): + """ + :param dataset: Hugging Face dataset + :param label2id (dict): label to id mapping + :param tokenizer (AutoTokenizer): Model tokenizer + :param max_length (int): Maximum number of tokens for padding and truncation + :param input_type (str): type of input text + """ + # perpare input text and label + self.label2id = label2id + self.tokenizer = tokenizer + self.max_length = max_length + self.input_type = input_type + print("input_type: ", input_type) + print('max_length: ', max_length) + dataset = dataset.map(self.create_input_text_and_label, keep_in_memory=True) + # Shuffle dataset + seed = 42 + dataset = dataset.shuffle(seed = seed) + return dataset + + def create_input_text_and_label(self, sample): + """ + Creates a formatted input text for a 
sample + :param sample: sample from the dataset + """ + text_src = sample['text_src'] if sample['text_src'] is not None else '' + text_tgt = sample['text_tgt'] if sample['text_tgt'] is not None else '' + + if self.input_type == 'text_st_on': + sample["text_old"] = f" {text_tgt} " + sample["text_new"] = f" {text_src} " + sample["text"] = f" {text_tgt} " + '\n ' + f" {text_src} " + elif self.input_type == 'text_nl_on': + sample["text_old"] = text_tgt + sample["text_new"] = text_src + sample["text"] = text_tgt + '\n ' + text_src + else: + raise ValueError("Invalid input type. Choose from ['text_st_on','text_nl_on']") + + emb1 = self.tokenizer.encode_plus(sample["text_old"],return_tensors="pt")["input_ids"] + emb2 = self.tokenizer.encode_plus(sample["text_new"],return_tensors="pt")["input_ids"] + sample['emb1'] = emb1 + sample['emb2'] = emb2 + + sample['input_ids_text'] = self.tokenizer.encode_plus(sample["text"], max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")["input_ids"] + sample['attention_mask_text'] = self.tokenizer.encode_plus(sample["text"], max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")["attention_mask"] + + label = self.label2id[sample['label']] + sample["label"] = label + + return sample diff --git a/tasks/edit_intent_classification/finetuning_llm_xnet/evaluater.py b/tasks/edit_intent_classification/finetuning_llm_xnet/evaluater.py new file mode 100644 index 0000000..2ddf904 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_xnet/evaluater.py @@ -0,0 +1,135 @@ +from pathlib import Path +import torch +import numpy as np +from tqdm import tqdm +import json +import pandas as pd +from torch import nn +from transformers import AutoTokenizer +from sklearn.metrics import accuracy_score, classification_report +from .model_finetuner import collate_fn +from .modelling_llama import LlamaForSequenceClassificationCross +from .modelling_peft import AutoPeftModelForSequenceClassificationCross + +class Evaluater: + def __init__(self) -> None: + print('Evaluating the model...') + + def merge_model(self, finetuned_model_dir:Path, labels, label2id, id2label, emb_type=None): + tokenizer = AutoTokenizer.from_pretrained(str(finetuned_model_dir)) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + compute_dtype = getattr(torch, "float16") + + model = AutoPeftModelForSequenceClassificationCross.from_pretrained( + str(finetuned_model_dir)+'/', + torch_dtype=compute_dtype, + return_dict=False, + low_cpu_mem_usage=True, + device_map='auto', + num_labels = len(labels), + emb_type=emb_type + ) + + model.config.id2label = id2label + model.config.label2id = label2id + model = model.merge_and_unload() + model.to('cuda:0') + + model = LlamaForSequenceClassificationCross(model.config, model=model.model, score=model.score, emb_type=emb_type) + model.config.pad_token_id = tokenizer.pad_token_id + model.to('cuda:0') + return model, tokenizer + + def predict(self, test, model, id2label, output_dir): + eval_file = output_dir / "eval_pred.csv" + print('eval_file', eval_file) + if eval_file.exists(): + eval_file.unlink() + + for i in tqdm(range(len(test))): + inputs = collate_fn([test[i]], device='cuda:0') + for key, value in inputs.items(): + inputs[key] = inputs[key].to('cuda:0') + with torch.no_grad(): + logits = model(**inputs, return_dict=True).logits.to('cuda:0') + predicted_class_id = logits.argmax().item() + pred = id2label[int(predicted_class_id)] + + a = {'doc_name':[test[i]['doc_name']], 
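+                 # one CSV row per test edit: document name, source/target node indices,
+                 # gold label ('true') and model prediction ('pred')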
+ 'node_ix_src':[test[i]['node_ix_src']], + 'node_ix_tgt':[test[i]['node_ix_tgt']], + 'true':[id2label[test[i]['label']]], + 'pred':[pred], + } + a = pd.DataFrame(a) + a.to_csv(eval_file,mode="a",index=False,header=not eval_file.exists()) + + + def evaluate(self, test, model=None, tokenizer=None, model_dir=None, output_dir=None, do_predict = True, + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None, response_key = None): + """ + Evaluate the model on the test set + :param test: Test set + :param model: Hugging Face model, the fine-tuned model + :param tokenizer: Model tokenizer + :param model_dir: Directory containing the fine-tuned model + :param output_dir: Directory to save the evaluation results + :param do_predict: Whether to predict the labels + :param labels: List of labels + :param label2id: Dictionary mapping labels to ids + :param id2label: Dictionary mapping ids to labels + :param emb_type: transformation function + :param input_type: Type of input text + :param response_key: Response key, None for XNet + """ + # load the model + if model is None or tokenizer is None: + model, tokenizer = self.merge_model(model_dir, labels, label2id, id2label, emb_type=emb_type) + + start_time = pd.Timestamp.now() + if output_dir is None: + output_dir = model_dir + if do_predict: + self.predict(test, model, id2label, output_dir) + end_time = pd.Timestamp.now() + inference_time = end_time - start_time + inference_time = inference_time.total_seconds() + + df = pd.read_csv(output_dir / "eval_pred.csv") + none_nr = len(df[df['pred'] == 'none']) + assert none_nr == 0, f'None labels found in the predictions: {none_nr}' + total_nr = len(df) + + eff = round((total_nr / int(inference_time)), 1) + with open (output_dir / "inference_time.json", 'w') as f: + json.dump({'inference_time':int(inference_time), 'inference_efficieny':eff}, f, indent=4) + + df = df[df['pred'] != 'none'] + y_pred = df["pred"] + y_true = df["true"] + print(df) + + # Map labels to ids + map_func = lambda label: label2id[label] + y_true = np.vectorize(map_func)(y_true) + y_pred = np.vectorize(map_func)(y_pred) + + # Calculate accuracy + accuracy = accuracy_score(y_true=y_true, y_pred=y_pred) + print(f'Accuracy: {accuracy:.3f}') + + # Generate classification report + class_report = classification_report(y_true=y_true, y_pred=y_pred, target_names=labels, output_dict=True, zero_division=0) + print('\nClassification Report:') + class_report['none_nr'] = none_nr + class_report['AIR'] = round(((total_nr - none_nr) / total_nr)*100, 1) + print(class_report) + + eval_file = output_dir / "eval_report.json" + if eval_file.exists(): + eval_file.unlink() + with open(str(eval_file), 'w') as f: + json.dump(class_report, f, indent=4) + diff --git a/tasks/edit_intent_classification/finetuning_llm_xnet/model_finetuner.py b/tasks/edit_intent_classification/finetuning_llm_xnet/model_finetuner.py new file mode 100644 index 0000000..3471249 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_xnet/model_finetuner.py @@ -0,0 +1,161 @@ +import torch +from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training +import bitsandbytes as bnb +import numpy as np +from trl import SFTTrainer +from transformers import TrainingArguments +import evaluate +accuracy = evaluate.load("accuracy") + +def collate_fn(examples, device=None): + for example in examples: + example['input_ids'] = torch.as_tensor(example['input_ids_text']) + example['attention_mask'] = torch.as_tensor(example['attention_mask_text']) + 
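+        # besides input_ids/attention_mask and the label, keep the token lengths of the
+        # old and new sentences: the Cross model uses len_old and the first pad position
+        # to locate each sentence's last-token embedding (see modelling_llama.py)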
example['label'] = torch.as_tensor(example['label']) + example['len_old'] = torch.as_tensor(len(example['emb1'][0])) + example['len_new'] = torch.as_tensor(len(example['emb2'][0])) + input_ids = torch.stack([example["input_ids"] for example in examples]) + input_ids = torch.squeeze(input_ids, dim=1) + attention_masks = torch.stack([example["attention_mask"] for example in examples]) + attention_masks = torch.squeeze(attention_masks, dim=1) + labels = torch.stack([example["label"] for example in examples]) + len_old = torch.stack([example["len_old"] for example in examples]) + len_new = torch.stack([example["len_new"] for example in examples]) + + if device is not None: + input_ids = input_ids.to(device) + attention_masks = attention_masks.to(device) + labels = labels.to(device) + len_old = len_old.to(device) + len_new = len_new.to(device) + + return {"input_ids": input_ids, + "attention_mask": attention_masks, + "labels": labels, + "len_old": len_old, + "len_new": len_new} + + +def compute_metrics(eval_pred): + predictions, labels = eval_pred + predictions = np.argmax(predictions, axis=1) + return accuracy.compute(predictions=predictions, references=labels) + +class ModelFinetuner: + def __init__(self) -> None: + '' + + def print_trainable_parameters(self, model, use_4bit = False): + """Prints the number of trainable parameters in the model. + :param model: PEFT model + """ + trainable_params = 0 + all_param = 0 + for _, param in model.named_parameters(): + num_params = param.numel() + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + if use_4bit: + trainable_params /= 2 + print( + f"All Parameters: {all_param:,d} || Trainable Parameters: {trainable_params:,d} || Trainable Parameters %: {100 * trainable_params / all_param}" + ) + + + def fine_tune(self, + model, + tokenizer, + train_ds = None, + val_ds = None, + lora_r = 128, + lora_alpha = 128, + lora_dropout = 0.1, + learning_rate = 2e-4, + per_device_train_batch_size = 32, + train_epochs = 10, + output_dir = None, + bias = 'none', + target_modules="all-linear", + task_type = None, + max_seq_length = 4096 + ): + print('fine-tuning....') + # Prepare the model for training + model = prepare_model_for_kbit_training(model) + + peft_config = LoraConfig( + r = lora_r, + lora_alpha = lora_alpha, + target_modules = target_modules, + lora_dropout = lora_dropout, + bias = bias, + task_type = task_type, + modules_to_save = ['score'] + ) + + model = get_peft_model(model, peft_config) + # Print information about the percentage of trainable parameters + self.print_trainable_parameters(model) + + args = TrainingArguments( + output_dir = output_dir, + num_train_epochs=train_epochs, + per_device_train_batch_size = per_device_train_batch_size, + per_device_eval_batch_size=per_device_train_batch_size, + gradient_accumulation_steps = 8, + learning_rate = learning_rate, + logging_steps=10, + fp16 = True, + weight_decay=0.001, + max_grad_norm=0.3, # max gradient norm based on QLoRA paper + max_steps=-1, + warmup_ratio=0.03, # warmup ratio based on QLoRA paper + group_by_length=True, + lr_scheduler_type="cosine", # use cosine learning rate scheduler + report_to="tensorboard", # report metrics to tensorboard + evaluation_strategy="epoch", # save checkpoint every epoch + save_strategy="epoch", + gradient_checkpointing=True, # use gradient checkpointing to save memory + optim="paged_adamw_32bit", + remove_unused_columns=False, + 
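+            # keep dataset columns that are not model arguments (e.g. input_ids_text,
+            # emb1/emb2) so they remain available to the custom collate_fn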
load_best_model_at_end=True, + metric_for_best_model="eval_accuracy", + label_names = ['labels'], + ) + + trainer = SFTTrainer( + model=model, + args=args, + train_dataset=train_ds, + eval_dataset=val_ds, + compute_metrics=compute_metrics, + peft_config=peft_config, + dataset_text_field="text", + tokenizer=tokenizer, + packing=False, + max_seq_length=max_seq_length, + data_collator = collate_fn, + dataset_kwargs={ + "add_special_tokens": False, + "append_concat_token": False, + } + ) + + model.config.use_cache = False + do_train = True + + # Launch training and log metrics + print("Training...") + + if do_train: + train_result = trainer.train(ignore_keys_for_eval=['len_old', 'len_new']) + metrics = train_result.metrics + trainer.log_metrics("train", metrics) + trainer.save_metrics("train", metrics) + trainer.save_model() + tokenizer.save_pretrained(output_dir) + diff --git a/tasks/edit_intent_classification/finetuning_llm_xnet/model_loader.py b/tasks/edit_intent_classification/finetuning_llm_xnet/model_loader.py new file mode 100644 index 0000000..4c914df --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_xnet/model_loader.py @@ -0,0 +1,38 @@ +from pathlib import Path +import torch +from transformers import AutoTokenizer, BitsAndBytesConfig +from .modelling_llama import LlamaForSequenceClassificationCross + +class ModelLoader: + def __init__(self, num_cls_layers =1) -> None: + print('Loading the model...') + self.bnb_config = BitsAndBytesConfig( + load_in_4bit = True, # Activate 4-bit precision base model loading + bnb_4bit_use_double_quant = True, # Activate nested quantization for 4-bit base models (double quantization) + bnb_4bit_quant_type = "nf4",# Quantization type (fp4 or nf4) + bnb_4bit_compute_dtype = torch.bfloat16, # Compute data type for 4-bit base models + ) + + def load_model_from_path(self, model_path:str, device_map='auto', + labels=None, label2id=None, id2label=None, + emb_type=None, input_type=None): + + print('Loading model from...', model_path) + + tokenizer = AutoTokenizer.from_pretrained(model_path) + tokenizer.pad_token = tokenizer.eos_token + tokenizer.padding_side = 'right' + + model = LlamaForSequenceClassificationCross.from_pretrained(model_path, + quantization_config = self.bnb_config, + device_map = device_map, + num_labels = len(labels), + emb_type=emb_type, + ) + model.config.pad_token_id = tokenizer.pad_token_id + model.config.id2label = id2label + model.config.label2id = label2id + + return model, tokenizer + + diff --git a/tasks/edit_intent_classification/finetuning_llm_xnet/modelling_llama.py b/tasks/edit_intent_classification/finetuning_llm_xnet/modelling_llama.py new file mode 100644 index 0000000..1a95076 --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_xnet/modelling_llama.py @@ -0,0 +1,165 @@ +# Built upon the huggingface implementation + +from typing import List, Optional, Tuple, Union +from transformers import LlamaModel, LlamaPreTrainedModel +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import SequenceClassifierOutputWithPast +from transformers.utils import logging + +logger = logging.get_logger(__name__) +_CONFIG_FOR_DOC = "LlamaConfig" + +class LlamaForSequenceClassificationCross(LlamaPreTrainedModel): + def __init__(self, config, model=None, score=None, emb_type=None): + super().__init__(config) + self.num_labels = config.num_labels + self.emb_type = 
emb_type + + if model is None: + self.model = LlamaModel(config) + else: + self.model = model + + if score is None: + if self.emb_type in['diff','diffABS']: + input_size = config.hidden_size + elif self.emb_type in ['n-o','n-diffABS']: + input_size = config.hidden_size*2 + elif self.emb_type in ['n-diffABS-o']: + input_size = config.hidden_size*3 + else: + raise ValueError("invalid emb_type") + self.score = nn.Linear(input_size, self.num_labels, bias=False) + + else: + self.score = score + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def length_to_mask(self, lens, max_len, device): + lens = lens.to(dtype=torch.long) + lens = lens.to(device) + base = torch.arange(max_len, dtype=torch.long).expand(len(lens), max_len) + base = base.to(device) + return base < lens.unsqueeze(1) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: torch.Tensor = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + len_old: Optional[torch.LongTensor] = None, + len_new: Optional[torch.LongTensor] = None, + ) -> Union[Tuple, SequenceClassifierOutputWithPast]: + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.model( + input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = transformer_outputs[0] + + if input_ids is not None: + batch_size = input_ids.shape[0] + else: + batch_size = inputs_embeds.shape[0] + + if self.config.pad_token_id is None and batch_size != 1: + raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.") + if self.config.pad_token_id is None: + sequence_lengths0 = -1 + sequence_lengths1 = -1 + else: + if input_ids is not None: + # the first pad_token_id is the end of the 2nd sentence + sequence_lengths1 = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1) # new (2nd) sentence last token + sequence_lengths0 = len_old -1 # old (first) sentence, last token + else: + sequence_lengths0 = -1 + sequence_lengths1 = -1 + + sequence_end_ix_old = sequence_lengths0 + sequence_end_ix_new = sequence_lengths1 + sequence_end_ix_old = sequence_end_ix_old.to(hidden_states.device) + sequence_end_ix_new = sequence_end_ix_new.to(hidden_states.device) + + hidden_states_old = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_end_ix_old] + hidden_states_new = hidden_states[torch.arange(batch_size, device=hidden_states.device), sequence_end_ix_new] + + if self.emb_type == 'diff': + hidden_states = torch.as_tensor(hidden_states_new - hidden_states_old) + elif self.emb_type == 'diffABS': + hidden_states = torch.abs(torch.as_tensor(hidden_states_new - hidden_states_old)) + elif self.emb_type == 'n-diffABS': + diff = torch.abs(torch.as_tensor(hidden_states_new - hidden_states_old)) + hidden_states = torch.cat((hidden_states_new, 
diff),1) + elif self.emb_type == 'n-diffABS-o': + diff = torch.abs(torch.as_tensor(hidden_states_new - hidden_states_old)) + hidden_states = torch.cat((hidden_states_new, diff, hidden_states_old),1) + elif self.emb_type == 'n-o': + hidden_states = torch.cat((hidden_states_new, hidden_states_old),1) + + hidden_states = hidden_states.to(hidden_states_old.device) + hidden_states = hidden_states.type(self.score.weight.dtype) + pooled_logits = self.score(hidden_states) + + loss = None + if labels is not None: + labels = labels.to(pooled_logits.device) + if self.config.problem_type is None: + if self.num_labels == 1: + self.config.problem_type = "regression" + elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.num_labels == 1: + loss = loss_fct(pooled_logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(pooled_logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(pooled_logits, labels) + if not return_dict: + output = (pooled_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutputWithPast( + loss=loss, + logits=pooled_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) diff --git a/tasks/edit_intent_classification/finetuning_llm_xnet/modelling_peft.py b/tasks/edit_intent_classification/finetuning_llm_xnet/modelling_peft.py new file mode 100644 index 0000000..5a2cd9a --- /dev/null +++ b/tasks/edit_intent_classification/finetuning_llm_xnet/modelling_peft.py @@ -0,0 +1,330 @@ +# Built upon the huggingface implementation +from __future__ import annotations +import inspect +import os +import warnings +from contextlib import contextmanager +from copy import deepcopy +import importlib +import os +from typing import Optional +import torch +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +from transformers.modeling_outputs import SequenceClassifierOutput +from transformers import AutoTokenizer +from peft.config import PeftConfig +from peft.mapping import MODEL_TYPE_TO_PEFT_MODEL_MAPPING +from peft.peft_model import PeftModel +from peft.utils.constants import TOKENIZER_CONFIG_NAME +from peft.utils.other import check_file_exists_on_hf_hub + +from peft.tuners import ( + AdaLoraModel, + AdaptionPromptModel, + IA3Model, + LoHaModel, + LoKrModel, + LoraModel, + OFTModel, + PolyModel, + PrefixEncoder, + PromptEmbedding, + PromptEncoder, +) +from peft.utils import ( + PeftType, + _get_batch_size, + _set_trainable, +) + +PEFT_TYPE_TO_MODEL_MAPPING = { + PeftType.LORA: LoraModel, + PeftType.LOHA: LoHaModel, + PeftType.LOKR: LoKrModel, + PeftType.PROMPT_TUNING: PromptEmbedding, + PeftType.P_TUNING: PromptEncoder, + PeftType.PREFIX_TUNING: PrefixEncoder, + PeftType.ADALORA: AdaLoraModel, + PeftType.ADAPTION_PROMPT: AdaptionPromptModel, + PeftType.IA3: IA3Model, + PeftType.OFT: OFTModel, + PeftType.POLY: PolyModel, +} + + +class PeftModelForSequenceClassificationCross(PeftModel): + """ + Peft model 
for sequence classification tasks with the xnet approach. + + Args: + model ([`~transformers.PreTrainedModel`]): Base transformer model. + peft_config ([`PeftConfig`]): Peft config. + + """ + + def __init__(self, model: torch.nn.Module, peft_config: PeftConfig, adapter_name: str = "default") -> None: + super().__init__(model, peft_config, adapter_name) + + if self.modules_to_save is None: + self.modules_to_save = {"classifier", "score"} + else: + self.modules_to_save.update({"classifier", "score"}) + + + for name, _ in self.base_model.named_children(): + if any(module_name in name for module_name in self.modules_to_save): + self.cls_layer_name = name + break + + # to make sure classifier layer is trainable + _set_trainable(self, adapter_name) + + def forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + task_ids=None, + **kwargs, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + peft_config = self.active_peft_config + if not peft_config.is_prompt_learning: + with self._enable_peft_forward_hooks(**kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_forward_args} + if peft_config.peft_type == PeftType.POLY: + kwargs["task_ids"] = task_ids + return self.base_model( + input_ids=input_ids, + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + labels=labels, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + **kwargs, + ) + + batch_size = _get_batch_size(input_ids, inputs_embeds) + if attention_mask is not None: + # concat prompt attention mask + prefix_attention_mask = torch.ones(batch_size, peft_config.num_virtual_tokens).to(attention_mask.device) + attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1) + if kwargs.get("position_ids", None) is not None: + warnings.warn("Position ids are not supported for parameter efficient tuning. 
Ignoring position ids.") + kwargs["position_ids"] = None + kwargs.update( + { + "attention_mask": attention_mask, + "labels": labels, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + } + ) + + if peft_config.peft_type == PeftType.PREFIX_TUNING: + return self._prefix_tuning_forward(input_ids=input_ids, **kwargs) + else: + if kwargs.get("token_type_ids", None) is not None: + kwargs["token_type_ids"] = torch.cat( + ( + torch.zeros(batch_size, peft_config.num_virtual_tokens).to(self.word_embeddings.weight.device), + kwargs["token_type_ids"], + ), + dim=1, + ).long() + if inputs_embeds is None: + inputs_embeds = self.word_embeddings(input_ids) + prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids) + prompts = prompts.to(inputs_embeds.dtype) + inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1) + return self.base_model(inputs_embeds=inputs_embeds, **kwargs) + + def _prefix_tuning_forward( + self, + input_ids=None, + attention_mask=None, + inputs_embeds=None, + labels=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + **kwargs, + ): + batch_size = _get_batch_size(input_ids, inputs_embeds) + past_key_values = self.get_prompt(batch_size) + fwd_params = list(inspect.signature(self.base_model.forward).parameters.keys()) + kwargs.update( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + "inputs_embeds": inputs_embeds, + "output_attentions": output_attentions, + "output_hidden_states": output_hidden_states, + "return_dict": return_dict, + "past_key_values": past_key_values, + } + ) + if "past_key_values" in fwd_params: + return self.base_model(labels=labels, **kwargs) + else: + transformer_backbone_name = self.base_model.get_submodule(self.transformer_backbone_name) + fwd_params = list(inspect.signature(transformer_backbone_name.forward).parameters.keys()) + if "past_key_values" not in fwd_params: + raise ValueError("Model does not support past key values which are required for prefix tuning.") + outputs = transformer_backbone_name(**kwargs) + pooled_output = outputs[1] if len(outputs) > 1 else outputs[0] + if "dropout" in [name for name, _ in list(self.base_model.named_children())]: + pooled_output = self.base_model.dropout(pooled_output) + logits = self.base_model.get_submodule(self.cls_layer_name)(pooled_output) + + loss = None + if labels is not None: + if self.config.problem_type is None: + if self.base_model.num_labels == 1: + self.config.problem_type = "regression" + elif self.base_model.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): + self.config.problem_type = "single_label_classification" + else: + self.config.problem_type = "multi_label_classification" + + if self.config.problem_type == "regression": + loss_fct = MSELoss() + if self.base_model.num_labels == 1: + loss = loss_fct(logits.squeeze(), labels.squeeze()) + else: + loss = loss_fct(logits, labels) + elif self.config.problem_type == "single_label_classification": + loss_fct = CrossEntropyLoss() + loss = loss_fct(logits.view(-1, self.base_model.num_labels), labels.view(-1)) + elif self.config.problem_type == "multi_label_classification": + loss_fct = BCEWithLogitsLoss() + loss = loss_fct(logits, labels) + if not return_dict: + output = (logits,) + outputs[2:] + return ((loss,) + output) if loss is not None else output + + return SequenceClassifierOutput( + loss=loss, + logits=logits, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + 
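+# ---------------------------------------------------------------------------
+# Illustrative end-to-end wiring of the XNet components (a sketch only, kept
+# as a comment; the split names 'train'/'val'/'test', the base model path,
+# the output directory and the task_type value below are assumptions -- the
+# actual finetune_EIC_XNet.py entry script may differ):
+#
+#   from pathlib import Path
+#   from tasks.task_data_loader import TaskDataLoader
+#   from tasks.edit_intent_classification.finetuning_llm_xnet import (
+#       ModelLoader, DataPreprocessor, ModelFinetuner, Evaluater)
+#
+#   loader = TaskDataLoader('edit_intent_classification', 'train', 'val', 'test')
+#   train, val, test = loader.load_data()
+#   labels, label2id, id2label = loader.get_labels()
+#   model, tokenizer = ModelLoader().load_model_from_path(
+#       'meta-llama/Llama-2-7b-hf',                      # assumed base model
+#       labels=labels, label2id=label2id, id2label=id2label, emb_type='n-diffABS')
+#   prep = DataPreprocessor()
+#   train = prep.preprocess_data(train, label2id, tokenizer, input_type='text_st_on')
+#   val = prep.preprocess_data(val, label2id, tokenizer, input_type='text_st_on')
+#   test = prep.preprocess_data(test, label2id, tokenizer, input_type='text_st_on')
+#   ModelFinetuner().fine_tune(model, tokenizer, train_ds=train, val_ds=val,
+#                              output_dir='runs/eic_xnet', task_type='SEQ_CLS')
+#   Evaluater().evaluate(test, model_dir=Path('runs/eic_xnet'), labels=labels,
+#                        label2id=label2id, id2label=id2label, emb_type='n-diffABS')
+# ---------------------------------------------------------------------------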
+class _BaseAutoPeftModelCross: + _target_class = None + _target_peft_class = None + + def __init__(self, *args, **kwargs): + raise EnvironmentError( + f"{self.__class__.__name__} is designed to be instantiated " + f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or " + f"`{self.__class__.__name__}.from_config(config)` methods." + ) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path, + adapter_name: str = "default", + is_trainable: bool = False, + config: Optional[PeftConfig] = None, + emb_type=None, + **kwargs, + ): + r""" + A wrapper around all the preprocessing steps a user needs to perform in order to load a PEFT model. The kwargs + are passed along to `PeftConfig` that automatically takes care of filtering the kwargs of the Hub methods and + the config object init. + """ + peft_config = PeftConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) + base_model_path = peft_config.base_model_name_or_path + + task_type = getattr(peft_config, "task_type", None) + + if cls._target_class is not None: + target_class = cls._target_class + elif cls._target_class is None and task_type is not None: + # this is only in the case where we use `AutoPeftModel` + raise ValueError( + "Cannot use `AutoPeftModel` with a task type, please use a specific class for your task type. (e.g. `AutoPeftModelForCausalLM` for `task_type='CAUSAL_LM'`)" + ) + + if task_type is not None: + expected_target_class = MODEL_TYPE_TO_PEFT_MODEL_MAPPING[task_type] + if cls._target_peft_class.__name__ != expected_target_class.__name__: + raise ValueError( + f"Expected target PEFT class: {expected_target_class.__name__}, but you have asked for: {cls._target_peft_class.__name__ }" + " make sure that you are loading the correct model for your task type." + ) + elif task_type is None and getattr(peft_config, "auto_mapping", None) is not None: + auto_mapping = getattr(peft_config, "auto_mapping", None) + base_model_class = auto_mapping["base_model_class"] + parent_library_name = auto_mapping["parent_library"] + + parent_library = importlib.import_module(parent_library_name) + target_class = getattr(parent_library, base_model_class) + else: + raise ValueError( + "Cannot infer the auto class from the config, please make sure that you are loading the correct model for your task type." 
+ ) + + base_model = target_class.from_pretrained(base_model_path, emb_type=emb_type, **kwargs) + + + tokenizer_exists = False + if os.path.exists(os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_NAME)): + tokenizer_exists = True + else: + token = kwargs.get("token", None) + if token is None: + token = kwargs.get("use_auth_token", None) + + tokenizer_exists = check_file_exists_on_hf_hub( + repo_id=pretrained_model_name_or_path, + filename=TOKENIZER_CONFIG_NAME, + revision=kwargs.get("revision", None), + repo_type=kwargs.get("repo_type", None), + token=token, + ) + + if tokenizer_exists: + tokenizer = AutoTokenizer.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=kwargs.get("trust_remote_code", False) + ) + base_model.resize_token_embeddings(len(tokenizer)) + + return cls._target_peft_class.from_pretrained( + base_model, + pretrained_model_name_or_path, + adapter_name=adapter_name, + is_trainable=is_trainable, + config=config, + **kwargs, + ) + + +from transformers.models.auto.auto_factory import _BaseAutoModelClass +from transformers.models.auto.modeling_auto import MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING +class AutoModelForSequenceClassificationCross(_BaseAutoModelClass): + _model_mapping = MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + +class AutoPeftModelForSequenceClassificationCross(_BaseAutoPeftModelCross): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.emb_type = args.emb_type + print('!!!!!!!!AutoPeftModelForSequenceClassificationCross emb_type', self.emb_type) + print('args: ', args) + _target_class = AutoModelForSequenceClassificationCross + _target_peft_class = PeftModelForSequenceClassificationCross + + diff --git a/tasks/task_data_loader.py b/tasks/task_data_loader.py new file mode 100644 index 0000000..7a80d08 --- /dev/null +++ b/tasks/task_data_loader.py @@ -0,0 +1,46 @@ +from pathlib import Path +from datasets import load_dataset + +class TaskDataLoader: + def __init__(self, task_name:str, train_type:str, val_type:str, test_type:str=None) -> None: + ''' + params:task_name: str: name of the task + params:train_type: str: name of the training data + params:val_type: str: name of the validation data + params:test_type: str: name of the test data + ''' + self.task_name = task_name + self.train_type = train_type + self.test_type = test_type + self.val_type = val_type + self.task_data_dir = Path('data/Re3-Sci/tasks') / task_name + data_files = {} + for i in self.task_data_dir.iterdir(): + if i.is_file() and i.name.endswith('.csv'): + data_files[i.stem] = str(i) + self.data_files = data_files + self.dataset = load_dataset("csv", data_files = data_files, keep_default_na=False) + label_names = sorted(set(label for label in self.dataset["train"]["label"])) + self.labels = label_names + label2id, id2label = dict(), dict() + for i, label in enumerate(self.labels): + label2id[label] = i + id2label[i] = label + self.label2id = label2id + self.id2label = id2label + + + def load_train(self): + return self.dataset[self.train_type] + def load_val(self): + return self.dataset[self.val_type] + def load_test(self): + if self.test_type is None: + return None + return self.dataset[self.test_type] + def load_data(self): + return self.load_train(), self.load_val(), self.load_test() + def get_labels(self): + return self.labels, self.label2id, self.id2label + + diff --git a/tasks/task_data_preprocessor.py b/tasks/task_data_preprocessor.py new file mode 100644 
index 0000000..dae2947 --- /dev/null +++ b/tasks/task_data_preprocessor.py @@ -0,0 +1,14 @@ +from pathlib import Path +import importlib + +class TaskDataPreprocessor: + def __init__(self, task_name:str, method:str) -> None: + ''' + ''' + pck = importlib.import_module(f"tasks.{task_name}.{method}") + print(f'import tasks.{task_name}.{method}') + data_preprocessor = getattr(pck, 'DataPreprocessor') + self.task_name = task_name + self.method = method + self.data_preprocessor = data_preprocessor() + diff --git a/tasks/task_eval_res_summarizer.py b/tasks/task_eval_res_summarizer.py new file mode 100644 index 0000000..080699f --- /dev/null +++ b/tasks/task_eval_res_summarizer.py @@ -0,0 +1,302 @@ +from pathlib import Path +import importlib +import pandas as pd +import json +import os +from sklearn.metrics import classification_report +from datetime import datetime +from openpyxl import load_workbook +from openpyxl import Workbook +from openpyxl.styles import PatternFill, Font +from openpyxl.styles.borders import Border, Side, BORDER_THICK, BORDER_THIN +import matplotlib +import shutil + +class TaskEvalResSummarizer: + def __init__(self, task_name:str, method:str=None, labels:list=None, results_dir:str=None) -> None: + ''' + ''' + self.labels = labels + self.task_name = task_name + self.method = method + self.results_dir = results_dir + if method is None: # summarize all method groups + task_dir = Path(f"./{results_dir }/{task_name}") + print('!!!!!!!!!!!!!!Here1') + for method_dir in task_dir.iterdir(): + if method_dir.is_dir(): + self.method_dir = method_dir + self.summarize_eval_res() + else: + print('!!!!!!!!!!!!!!Here2') + self.method_dir = Path(f"./{results_dir}/{task_name}/{method}") + self.summarize_eval_res() + + def create_eval_cols(self): + col_names = [ + 'model_name', 'test_set', 'EI', 'none_nr', 'AIR', 'acc', 'marco_avg_f1', 'weighted_avg_f1','macro_avg2', 'weighted_avg2', + #'nomulti_acc', 'nomulti_marco_avg_f1', 'nomulti_weighted_avg_f1', + #'multi_acc', 'multi_marco_avg_f1', 'multi_weighted_avg_f1', + #'sub_acc', 'sub_marco_avg_f1','sub_weighted_avg_f1', + #'sub2_acc', 'sub2_marco_avg_f1','sub2_weighted_avg_f1' + ] + + for label in self.labels: + col_names.extend([f'{label}_p', f'{label}_r', f'{label}_f1', f'{label}_support']) + return col_names + + def group_models_by_name(self): + model_names = [d.name for d in self.method_dir.iterdir() if d.is_dir()] + baselines = [d for d in model_names if '_lora' not in d] + model_names = [d for d in model_names if d not in baselines] + + models_7B = [d for d in model_names if '7B_' in d] + models_7B_Chat = [d for d in model_names if '7B-Chat_' in d] + models_13B = [d for d in model_names if '13B_' in d] + models_13B_Chat = [d for d in model_names if '13B-Chat_' in d] + baselines.sort() + models_7B.sort() + models_7B_Chat.sort() + models_13B.sort() + models_13B_Chat.sort() + other_models = [d for d in model_names if d not in baselines+models_7B+models_7B_Chat+models_13B+models_13B_Chat] + + model_names = models_13B_Chat + models_13B + models_7B_Chat + models_7B + other_models + baselines + + return model_names + + def get_eval_dict(self, model_dir, model_name, col_names): + dic = {} + for col in col_names: + dic[col] = '' + + eval_file = model_dir/ 'eval_report.json' + pred_file = model_dir/ 'eval_pred.csv' + inference_time_file = model_dir/ 'inference_time.json' + + pred_df = pd.read_csv(pred_file) + with open(eval_file, 'r') as f: + eval_dict = json.load(f) + + if 'accuracy' not in eval_dict.keys(): + acc = 0 + else: + acc = 
round(eval_dict['accuracy']*100, 1) + + dic['model_name'] = model_name + dic['test_set'] = len(pred_df) + if 'none_nr' in eval_dict.keys(): + dic['none_nr'] = eval_dict['none_nr'] + else: + dic['none_nr'] = 0 + air = ( dic['test_set'] - dic['none_nr'] ) / dic['test_set'] + + dic['AIR'] = round(air*100, 1) + + if inference_time_file.exists(): + with open(inference_time_file, 'r') as f: + inference_time = json.load(f) + print('inference_time', inference_time) + if int(inference_time['inference_time']) == 0: + dic['EI'] = 0 + else: + dic['EI'] = round(int(dic['test_set']) / int(inference_time['inference_time']), 1) + + dic['acc'] = acc + dic['marco_avg_f1'] = round(eval_dict['macro avg']['f1-score']*100, 1) + dic['weighted_avg_f1'] = round(eval_dict['weighted avg']['f1-score']*100, 1) + for label in self.labels: + if label not in eval_dict.keys(): + continue + dic[f'{label}_p'] = round(eval_dict[label]['precision']*100, 1) + dic[f'{label}_r'] = round(eval_dict[label]['recall']*100, 1) + dic[f'{label}_f1'] = round(eval_dict[label]['f1-score']*100, 1) + dic[f'{label}_support'] = eval_dict[label]['support'] + + if 'macro avg2' in eval_dict.keys(): + dic['macro_avg2'] = round(eval_dict['macro avg2']['f1-score']*100, 1) + dic['weighted_avg2'] = round(eval_dict['weighted avg2']['f1-score']*100, 1) + + return dic + + def save_df_to_excel(self, excelfile, sheetname, df): + print(df[['model_name', 'test_set', 'acc', 'marco_avg_f1', 'weighted_avg_f1']]) + + if not os.path.isfile(excelfile): + #print('The excel file is not existing, creating a new excel file...', excelfile) + wb = Workbook() + wb.save(excelfile) + excelfile_work = excelfile + else: + new_excelfile = excelfile.parent/f'{excelfile.stem}_new.xlsx' + shutil.copy(excelfile, new_excelfile) + excelfile_work = new_excelfile + + ''' + excelfile = shutil.copy(excelfile, excelfile) + writer = pd.ExcelWriter( + excelfile, + engine="openpyxl", + mode="a", + if_sheet_exists="overlay", + ) + + + wb = load_workbook(excelfile) + if not (sheetname in wb.sheetnames): + #print('The worksheet is not existing, creating a new worksheet...' 
, sheetname) + ws1 = wb.create_sheet(sheetname) + ws1.title = sheetname + wb.save(excelfile) + ''' + writer = pd.ExcelWriter( + excelfile_work, + engine="openpyxl", + mode="a", + if_sheet_exists="overlay", + ) + + #book = load_workbook(excelfile) + #idx = wb.sheetnames.index(sheetname) + #ws = book.get_sheet_by_name(sheetname) + #book.remove(ws) + #book.create_sheet(sheetname, idx) + #writer = pd.ExcelWriter(excelfile, engine='openpyxl') + #writer.book = book + #writer.sheets = {ws.title: ws for ws in book.worksheets} + + df.to_excel(writer, sheet_name=sheetname, index=False, header=True) + writer.close() + #writer.save() + + if excelfile_work != excelfile: + shutil.copy(excelfile_work, excelfile) + os.remove(excelfile_work) + + def check_best_models(self, df, metrics): + best_models = {} + for m in metrics: + best_models.update({m: []}) + + for m in metrics: + # remove '' + df = df[df[m]!=''] + max_v = df[m].max() + for index, row in df.iterrows(): + if row[m] == max_v: + best_models[m].append((row['model_name'], row['test_set'])) + return best_models + + + + def color_best_metric(self, excelfile, sheetname, best_models, color, font, metrics): + wb = load_workbook(excelfile) + ws = wb[sheetname] + + # Create a dictionary of column names + ColNames = {} + Current = 0 + for COL in ws.iter_cols(1, ws.max_column): + ColNames[COL[0].value] = Current + Current += 1 + + # Color best metrics + for row_cells in ws.iter_rows(min_row=2, max_row=ws.max_row): + for m in metrics: + bests = best_models[m] + for model in bests: + if row_cells[ColNames['model_name']].value == model[0] and row_cells[ColNames['test_set']].value == model[1]: + row_cells[ColNames[m]].fill = PatternFill("solid", fgColor=color) + row_cells[ColNames[m]].font = Font(b=font) + if not 'ckp' in row_cells[ColNames['model_name']].value: + row_cells[ColNames['model_name']].fill = PatternFill("solid", fgColor='FFCCFF') + + wb.save(excelfile) + + def summarize_eval_res(self): + res_dir = self.method_dir + print('summarizing the results....', res_dir) + eval_res_dir = Path(f'./eval_results/{self.results_dir}') + eval_res_dir.mkdir(parents=True, exist_ok=True) + eval_res_dir = eval_res_dir/self.task_name + eval_res_dir.mkdir(parents=True, exist_ok=True) + eval_res_dir = eval_res_dir/self.method_dir.name + eval_res_dir.mkdir(parents=True, exist_ok=True) + eval_summary_file = eval_res_dir/f'{res_dir.name}_eval_summ.xlsx' + col_names = self.create_eval_cols() + eval_summary = pd.DataFrame(columns=col_names) + + #group models by name + model_names = self.group_models_by_name() + + for model_name in model_names: + model_dir = res_dir/model_name + model_ckp_dirs1 = [d for d in model_dir.iterdir() if d.is_dir() and d.name.startswith('checkpoint')] + model_ckp_dirs2 = [d for d in model_dir.iterdir() if d.is_dir() and d.name.startswith('val_')] + # sort by the number of the checkpoint + model_ckp_dirs1.sort(key=lambda x: int(x.name.split('-')[-1])) + model_ckp_dirs2.sort(key=lambda x: int(x.name.split('-')[-1])) + model_ckp_dirs = model_ckp_dirs1 + model_ckp_dirs2 + eval_file = model_dir/ 'eval_report.json' + print('### summarizing: ', model_dir) + if not eval_file.exists(): + print('!!!! 
    def color_best_metric(self, excelfile, sheetname, best_models, color, font, metrics):
        wb = load_workbook(excelfile)
        ws = wb[sheetname]

        # Create a dictionary of column names
        ColNames = {}
        Current = 0
        for COL in ws.iter_cols(1, ws.max_column):
            ColNames[COL[0].value] = Current
            Current += 1

        # Color best metrics
        for row_cells in ws.iter_rows(min_row=2, max_row=ws.max_row):
            for m in metrics:
                bests = best_models[m]
                for model in bests:
                    if row_cells[ColNames['model_name']].value == model[0] and row_cells[ColNames['test_set']].value == model[1]:
                        row_cells[ColNames[m]].fill = PatternFill("solid", fgColor=color)
                        row_cells[ColNames[m]].font = Font(b=font)
                        # Additionally highlight the name of final (non-checkpoint) models.
                        if 'ckp' not in row_cells[ColNames['model_name']].value:
                            row_cells[ColNames['model_name']].fill = PatternFill("solid", fgColor='FFCCFF')

        wb.save(excelfile)

    def summarize_eval_res(self):
        res_dir = self.method_dir
        print('summarizing the results....', res_dir)
        eval_res_dir = Path(f'./eval_results/{self.results_dir}')
        eval_res_dir.mkdir(parents=True, exist_ok=True)
        eval_res_dir = eval_res_dir/self.task_name
        eval_res_dir.mkdir(parents=True, exist_ok=True)
        eval_res_dir = eval_res_dir/self.method_dir.name
        eval_res_dir.mkdir(parents=True, exist_ok=True)
        eval_summary_file = eval_res_dir/f'{res_dir.name}_eval_summ.xlsx'
        col_names = self.create_eval_cols()
        eval_summary = pd.DataFrame(columns=col_names)

        # Group models by name
        model_names = self.group_models_by_name()

        for model_name in model_names:
            model_dir = res_dir/model_name
            model_ckp_dirs1 = [d for d in model_dir.iterdir() if d.is_dir() and d.name.startswith('checkpoint')]
            model_ckp_dirs2 = [d for d in model_dir.iterdir() if d.is_dir() and d.name.startswith('val_')]
            # Sort by the number of the checkpoint
            model_ckp_dirs1.sort(key=lambda x: int(x.name.split('-')[-1]))
            model_ckp_dirs2.sort(key=lambda x: int(x.name.split('-')[-1]))
            model_ckp_dirs = model_ckp_dirs1 + model_ckp_dirs2
            eval_file = model_dir / 'eval_report.json'
            print('### summarizing: ', model_dir)
            if not eval_file.exists():
                # The final model has not been evaluated yet; add an empty placeholder row.
                print('!!!! final not evaluated')
                dic = {}
                for col in col_names:
                    dic[col] = ''
                dic['model_name'] = model_name
                df = pd.DataFrame(dic, index=[0])
                eval_summary = pd.concat([eval_summary, df], ignore_index=True)
            else:
                dic = self.get_eval_dict(model_dir, model_name, col_names)
                df = pd.DataFrame(dic, index=[0])
                eval_summary = pd.concat([eval_summary, df], ignore_index=True)
            if len(model_ckp_dirs) > 0:
                for ckp_dir in model_ckp_dirs:
                    ckp_nr = ckp_dir.name.split('-')[-1]
                    eval_file = ckp_dir / 'eval_report.json'
                    if not eval_file.exists():
                        print('--- ckp not evaluated', ckp_dir.name)
                        continue
                    if ckp_dir.name.startswith('val_'):
                        model_name2 = f'val_{model_name}'
                    else:
                        model_name2 = model_name
                    dic = self.get_eval_dict(ckp_dir, f'{model_name2}_ckp{ckp_nr}', col_names)
                    df = pd.DataFrame(dic, index=[0])
                    eval_summary = pd.concat([eval_summary, df], ignore_index=True)

        # One sheet with all models, one restricted to models without '_lora' in their name.
        all_df = eval_summary
        baseline_df = eval_summary[~eval_summary['model_name'].str.contains('_lora')]
        self.save_df_to_excel(eval_summary_file, 'all', all_df)
        self.save_df_to_excel(eval_summary_file, 'all_baselines', baseline_df)

        # Highlight colors (only green and lightblue are used below).
        green = '00FF00'
        lightgreen = 'CCFFCC'
        lightblue = '00FFFF'
        pink = 'FF00FF'
        yellow = 'FFFF00'
        lightyellow = 'FFFFCC'
        red = 'FF0000'
        # Highlight the best value of each metric column (support columns excluded).
        metric_cols = col_names[3:]
        metric_cols = [c for c in metric_cols if 'support' not in c]
        all_bests = self.check_best_models(all_df, metric_cols)
        self.color_best_metric(eval_summary_file, 'all', all_bests, green, True, metric_cols)
        baseline_bests = self.check_best_models(baseline_df, metric_cols)
        self.color_best_metric(eval_summary_file, 'all_baselines', baseline_bests, lightblue, True, metric_cols)

        # Make a copy of the summary file, rename it with the current time
        new_file = eval_summary_file.parent/f'{eval_summary_file.name}_{datetime.now().strftime("%Y%m%d-%H%M%S")}.xlsx'
        shutil.copy(eval_summary_file, new_file)
        # Remove the old summary file
        os.remove(eval_summary_file)


def main():
    pass


if __name__ == "__main__":
    main()
diff --git a/tasks/task_evaluater.py b/tasks/task_evaluater.py
new file mode 100644
index 0000000..129771a
--- /dev/null
+++ b/tasks/task_evaluater.py
@@ -0,0 +1,12 @@
from pathlib import Path
import importlib


class TaskEvaluater:
    def __init__(self, task_name: str, method: str) -> None:
        '''Load the task- and method-specific Evaluater class from tasks.<task_name>.<method>.'''
        pck = importlib.import_module(f"tasks.{task_name}.{method}")
        evaluater = getattr(pck, 'Evaluater')
        self.task_name = task_name
        self.method = method
        self.evaluater = evaluater()
diff --git a/tasks/task_model_finetuner.py b/tasks/task_model_finetuner.py
new file mode 100644
index 0000000..7a7ad82
--- /dev/null
+++ b/tasks/task_model_finetuner.py
@@ -0,0 +1,13 @@
from pathlib import Path
import importlib


class TaskModelFinetuner:
    def __init__(self, task_name: str, method: str) -> None:
        '''Load the task- and method-specific ModelFinetuner class from tasks.<task_name>.<method>.'''
        pck = importlib.import_module(f"tasks.{task_name}.{method}")
        model_finetuner = getattr(pck, 'ModelFinetuner')
        self.task_name = task_name
        self.method = method
        self.model_finetuner = model_finetuner()
diff --git a/tasks/task_model_loader.py b/tasks/task_model_loader.py
new file mode 100644
index 0000000..2d34ecb
--- /dev/null
+++ b/tasks/task_model_loader.py
@@ -0,0 +1,13 @@
from pathlib import Path
import importlib


class TaskModelLoader:
    def __init__(self, task_name: str, method: str) -> None:
        '''Load the task- and method-specific ModelLoader class from tasks.<task_name>.<method>.'''
        pck = importlib.import_module(f"tasks.{task_name}.{method}")
        model_loader = getattr(pck, 'ModelLoader')
        self.task_name = task_name
        self.method = method
        self.model_loader = model_loader()
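The three wrappers above share the same pattern: they resolve `tasks.<task_name>.<method>` at runtime via `importlib` and instantiate the class that module exports. A minimal usage sketch follows; the task and method names (`EIC`, `seqc`) are illustrative placeholders and must match the modules that actually exist under `tasks/`:

```python
from tasks.task_model_loader import TaskModelLoader
from tasks.task_model_finetuner import TaskModelFinetuner
from tasks.task_evaluater import TaskEvaluater

task_name, method = 'EIC', 'seqc'  # placeholders; must correspond to tasks/<task_name>/<method>.py

# Each wrapper imports tasks.<task_name>.<method> and instantiates the class it defines.
model_loader = TaskModelLoader(task_name, method).model_loader        # instance of ModelLoader
finetuner = TaskModelFinetuner(task_name, method).model_finetuner     # instance of ModelFinetuner
evaluater = TaskEvaluater(task_name, method).evaluater                # instance of Evaluater
```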