From ad33c1cea4e399672ffbd985130ef604f6c9b0b2 Mon Sep 17 00:00:00 2001 From: lee Date: Wed, 14 Aug 2024 13:26:17 +0200 Subject: [PATCH] initial commit --- ABOUT_THIS_TEMPLATE.md | 125 --- configs/__init__.py | 7 + configs/datasets.py | 50 ++ configs/fsdp.py | 19 + configs/inference.py | 28 + configs/peft.py | 26 + configs/training.py | 58 ++ docs/index.md | 21 - evaluate_llama2.py | 135 ++++ finetune_llama2.py | 286 +++++++ .../__init__.py | 8 - .../__main__.py | 6 - gumbel_softmax_layer_skipping_2024/base.py | 62 -- gumbel_softmax_layer_skipping_2024/cli.py | 33 - .../subpackage/__init__.py | 5 - .../subpackage/subpackage.py | 15 - llama_datasets/__init__.py | 6 + llama_datasets/cnndm_dataset.py | 34 + llama_datasets/samsum_dataset.py | 33 + llama_datasets/utils.py | 66 ++ mkdocs.yml | 15 - model_checkpointing/__init__.py | 14 + model_checkpointing/checkpoint_handler.py | 324 ++++++++ neuralnets/llama_gumbel.py | 722 ++++++++++++++++++ policies/__init__.py | 7 + .../activation_checkpointing_functions.py | 29 + policies/anyprecision_optimizer.py | 179 +++++ policies/mixed_precision.py | 38 + policies/wrapping.py | 33 + requirements-dev.txt | 14 - requirements.txt | 23 +- setup.py | 44 -- tests/__init__.py | 0 tests/conftest.py | 17 - tests/test_base.py | 22 - utils/__init__.py | 7 + utils/chat_utils.py | 65 ++ utils/checkpoint_converter_fsdp_hf.py | 65 ++ utils/config_utils.py | 62 ++ utils/dataset_utils.py | 77 ++ utils/fsdp_utils.py | 35 + utils/memory_utils.py | 62 ++ utils/model_utils.py | 58 ++ utils/safety_utils.py | 169 ++++ utils/train_utils.py | 493 ++++++++++++ 45 files changed, 3207 insertions(+), 390 deletions(-) delete mode 100644 ABOUT_THIS_TEMPLATE.md create mode 100644 configs/__init__.py create mode 100644 configs/datasets.py create mode 100644 configs/fsdp.py create mode 100644 configs/inference.py create mode 100644 configs/peft.py create mode 100644 configs/training.py delete mode 100644 docs/index.md create mode 100644 evaluate_llama2.py create mode 100644 finetune_llama2.py delete mode 100644 gumbel_softmax_layer_skipping_2024/__init__.py delete mode 100644 gumbel_softmax_layer_skipping_2024/__main__.py delete mode 100644 gumbel_softmax_layer_skipping_2024/base.py delete mode 100644 gumbel_softmax_layer_skipping_2024/cli.py delete mode 100644 gumbel_softmax_layer_skipping_2024/subpackage/__init__.py delete mode 100644 gumbel_softmax_layer_skipping_2024/subpackage/subpackage.py create mode 100644 llama_datasets/__init__.py create mode 100644 llama_datasets/cnndm_dataset.py create mode 100644 llama_datasets/samsum_dataset.py create mode 100644 llama_datasets/utils.py delete mode 100644 mkdocs.yml create mode 100644 model_checkpointing/__init__.py create mode 100644 model_checkpointing/checkpoint_handler.py create mode 100644 neuralnets/llama_gumbel.py create mode 100644 policies/__init__.py create mode 100644 policies/activation_checkpointing_functions.py create mode 100644 policies/anyprecision_optimizer.py create mode 100644 policies/mixed_precision.py create mode 100644 policies/wrapping.py delete mode 100644 requirements-dev.txt delete mode 100644 setup.py delete mode 100644 tests/__init__.py delete mode 100644 tests/conftest.py delete mode 100644 tests/test_base.py create mode 100644 utils/__init__.py create mode 100644 utils/chat_utils.py create mode 100644 utils/checkpoint_converter_fsdp_hf.py create mode 100644 utils/config_utils.py create mode 100644 utils/dataset_utils.py create mode 100644 utils/fsdp_utils.py create mode 100644 utils/memory_utils.py create 
mode 100644 utils/model_utils.py create mode 100644 utils/safety_utils.py create mode 100644 utils/train_utils.py diff --git a/ABOUT_THIS_TEMPLATE.md b/ABOUT_THIS_TEMPLATE.md deleted file mode 100644 index 62e6da1..0000000 --- a/ABOUT_THIS_TEMPLATE.md +++ /dev/null @@ -1,125 +0,0 @@ -# About this template - -Hi, I've adapted this template from the excellent [python-project-template](https://github.com/rochacbruno/python-project-template/) by [rochacbruno](https://github.com/rochacbruno). It was created having in mind UKP Lab people and what the most common use-cases would be. Following its structure you'll get into developing your next paper in no time! - -It includes: - -- πŸ“¦ A basic [setup.py](setup.py) file to provide installation, packaging and distribution for your project. - Template uses setuptools because it's the de-facto standard for Python packages -- πŸ“ƒ Documentation structure using [mkdocs](http://www.mkdocs.org) -- πŸ§ͺ Testing structure using [pytest](https://docs.pytest.org/en/latest/) -- βœ… Code linting using [pylint](https://pypi.org/project/pylint/) -- 🎯 Entry points to execute your program using `python -m ` with basic CLI argument parsing. -- πŸ”„ Continuous integration using [Github Actions](https://github.com/UKPLab/gumbel-softmax-layer-skipping-2024/actions) with jobs to check, lint and test your project. - -Are there any changes you'd like to request? Feel free to fork and open a pull request! - -## Structure - -Lets take a look at the structure of this template: - -```text -β”‚ .gitignore # A list of files to ignore when pushing to GH -β”‚ ABOUT_THIS_TEMPLATE.md # The file you're reading right now -β”‚ LICENSE # The license for the project -β”‚ mkdocs.yml # Configuration for documentation site -β”‚ NOTICE.txt # Legal notice for the repository -β”‚ README.md # The main readme for the project -β”‚ requirements-dev.txt # List of requirements for testing and devlopment -β”‚ requirements.txt # An empty file to hold the requirements for the project -β”‚ setup.py # The setup.py file for installing and packaging the project -β”‚ -β”œβ”€β”€β”€.github # Github metadata for repository -β”‚ β”‚ dependabot.yml # Dependabot workflow for updating requirements -β”‚ β”‚ init.sh # Initializes the repository -β”‚ β”‚ PULL_REQUEST_TEMPLATE.md # Used automatically by GH for pull requests -β”‚ β”‚ rename_project.sh # Called once at repository creation -β”‚ β”‚ -β”‚ β”œβ”€β”€β”€ISSUE_TEMPLATE # Templates for creating issues on GH -β”‚ β”‚ -β”‚ └───workflows # GH Actions folder -β”‚ docs.yml # Builds documentation automatically -β”‚ main.yml # Runs install and file checks -β”‚ rename_project.yml # Renames repository at creation -β”‚ tests.yml # Run all tests in 'tests' folder -β”‚ -β”œβ”€β”€β”€docs # Auto-generated documentation -β”‚ index.md # Landing page of docs -β”‚ -β”œβ”€β”€β”€gumbel_softmax_layer_skipping_2024 # The main python package for the project -β”‚ base.py # The base module for the project -β”‚ cli.py # Defines CLI instructions -β”‚ __init__.py # This tells Python that this is a package -β”‚ __main__.py # The entry point for the project -β”‚ -└───tests # Unit tests for the project (add more tests files here) - conftest.py # Configuration, hooks and fixtures for pytest - test_base.py # The base test case for the project - __init__.py # This tells Python that this is a test package -``` - -## FAQs - - -### Where should I add new stuff ? - -You should create new files and subpackages inside gumbel_softmax_layer_skipping_2024 and implement your functionalities there. 
Remember to add what you write to `__init__.py` so that the imports work smoothly. Take a look at `base.py` and `__init__.py` to understand how it works. - -### Why is `requirements.txt` empty ? - -This template is a low dependency project, so it doesn't have any extra dependencies. -You can freely add new dependencies. - -You should put here everything needed to replicate your work. -Testing, linting, and other requirements used only in development should go in `requirements-dev.txt`. - -### Why is there a `requirements-dev.txt` file ? - -This file lists all the requirements for testing and development. Use it to separate things you used during development from the essential stuff needed to replicate your work. - -### What is the `.github` folder? - -It contains [GitHub Actions](https://docs.github.com/en/actions) that are executed automatically when pushing your code. You can see results for your repository [here](https://github.com/UKPLab/gumbel-softmax-layer-skipping-2024/actions). - -### What does the linter workflow do? - -It checks whether your code is clean enough from duplication, inconsistencies, violations to the naming convention etc. -It's not supposed to fail, but you should still look into it to get an idea of which parts of your code may need adjustments. - -### Why do automated actions fail ? - -This means there is something wrong in the files/tests/requirements. -Click on the failing run to read more details. - -### Why include `tests` and `docs` as part of the release? - -This template ships with everything you may need. You can remove what you don't like in this way: - - If you don't need automatic documentation generation, you can delete folder `docs`, file `.github\workflows\docs.yml` and `mkdocs.yml` - - If you don't want automatic testing, you can delete folder `tests` and file `.github\workflows\tests.yml` - -### How can I use pytest & pylint to check my code? - -Command `pytest` called from the project folder will run all tests inside the `tests` folder. -Similarly, `pylint` will run linting checks on your code and give you a status report. -It checks things such as logic, formatting, correct imports, duplication etc. - -### Why conftest includes a go_to_tmpdir fixture? - -When your project deals with file system operations, it is a good idea to use -a fixture to create a temporary directory and then remove it after the test. - -Before executing each test pytest will create a temporary directory and will -change the working directory to that path and run the test. - -So the test can create temporary artifacts isolated from other tests. - -After the execution Pytest will remove the temporary directory. - -### Why this template is not using [pre-commit](https://pre-commit.com/) ? - -pre-commit is an excellent tool to automate checks and formatting on your code. - -However I figured out that pre-commit adds extra dependency and it an entry barrier -for new contributors. - -Once the project is bigger and complex, having pre-commit as a dependency can be a good idea. diff --git a/configs/__init__.py b/configs/__init__.py new file mode 100644 index 0000000..ef2f11b --- /dev/null +++ b/configs/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
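The `configs` package added here keeps all run settings in plain dataclasses. They are overridden at runtime by `update_config(...)`, which `finetune_llama2.py` and `evaluate_llama2.py` (further down in this patch) call with the keyword arguments that `fire.Fire(main)` collects from the command line. The implementation of `update_config` lives in `utils/config_utils.py`, which is not shown in this excerpt; the sketch below only illustrates the presumed idea, and `_demo_config` is purely an illustrative stand-in for `train_config`.

```python
from dataclasses import dataclass

def update_config(config, **kwargs):
    """Sketch (assumption): copy matching CLI keyword arguments onto one or more config dataclasses."""
    for conf in config if isinstance(config, tuple) else (config,):
        for key, value in kwargs.items():
            if hasattr(conf, key):
                setattr(conf, key, value)

@dataclass
class _demo_config:  # illustrative stand-in, not part of the patch
    lr: float = 1e-4
    gumbel: bool = False

cfg = _demo_config()
update_config(cfg, lr=2e-5, gumbel=True)  # roughly what `--lr 2e-5 --gumbel True` ends up doing
print(cfg.lr, cfg.gumbel)  # 2e-05 True
```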
+ +from configs.peft import lora_config, llama_adapter_config, prefix_config +from configs.fsdp import fsdp_config +from configs.training import train_config +from configs.inference import inference_config diff --git a/configs/datasets.py b/configs/datasets.py new file mode 100644 index 0000000..9a3de6b --- /dev/null +++ b/configs/datasets.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from dataclasses import dataclass + + +@dataclass +class samsum_dataset: + dataset: str = "samsum_dataset" + train_split: str = "train" + test_split: str = "validation" + input_length: int = 2048 + +@dataclass +class samsum_dataset2: + dataset: str = "samsum_dataset2" + train_split: str = "train" + test_split: str = "validation" + input_length: int = 2048 + + +@dataclass +class cnndm_dataset: + dataset: str = "cnndm_dataset" + train_split: str = "train" + test_split: str = "validation" + input_length: int = 2048 + +@dataclass +class grammar_dataset: + dataset: str = "grammar_dataset" + train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" + test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv" + input_length: int = 2048 + + +@dataclass +class alpaca_dataset: + dataset: str = "alpaca_dataset" + train_split: str = "train" + test_split: str = "val" + data_path: str = "src/llama_recipes/datasets/alpaca_data.json" + + +@dataclass +class custom_dataset: + dataset: str = "custom_dataset" + file: str = "examples/custom_dataset.py" + train_split: str = "train" + test_split: str = "validation" diff --git a/configs/fsdp.py b/configs/fsdp.py new file mode 100644 index 0000000..8ee45c3 --- /dev/null +++ b/configs/fsdp.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from dataclasses import dataclass + +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType + +@dataclass +class fsdp_config: + mixed_precision: bool=True + use_fp16: bool=False + sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD + checkpoint_type: StateDictType = StateDictType.FULL_STATE_DICT # TODO: use StateDictType.SHARDED_STATE_DICT for fsdp + fsdp_activation_checkpointing: bool=False + fsdp_cpu_offload: bool=False + pure_bf16: bool = False + optimizer: str= "AdamW" + diff --git a/configs/inference.py b/configs/inference.py new file mode 100644 index 0000000..efbd28a --- /dev/null +++ b/configs/inference.py @@ -0,0 +1,28 @@ +from dataclasses import dataclass + + +@dataclass +class inference_config: + model_name: str=None + peft_model: str=None + quantization: bool=False + use_gumbel: bool=False + max_new_tokens =100 #The maximum numbers of tokens to generate + prompt_file: str=None + seed: int=42 #seed value for reproducibility + do_sample: bool=True #Whether or not to use sampling ; use greedy decoding otherwise. + min_length: int=None #The minimum length of the sequence to be generated, input prompt + min_new_tokens + use_cache: bool=True #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding. 
+ top_p: float=0.9 #1.0 # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. + temperature: float=0.01 #1.0 # [optional] The value used to modulate the next token probabilities. + top_k: int=50 # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering. + repetition_penalty: float=1.0 #The parameter for repetition penalty. 1.0 means no penalty. + length_penalty: int=1 #[optional] Exponential penalty to the length that is used with beam-based generation. + enable_azure_content_safety: bool=False # Enable safety check with Azure content safety api + enable_sensitive_topics: bool=False # Enable check for sensitive topics using AuditNLG APIs + enable_salesforce_content_safety: bool=True # Enable safety check with Salesforce safety flan t5 + max_padding_length: int=None # the max padding length to be used with tokenizer padding the prompts. + use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + output_dir: str = "results" + debugging: bool = False # Enable debugging mode + generation_prompt: bool = True # Set add_generation_prompt diff --git a/configs/peft.py b/configs/peft.py new file mode 100644 index 0000000..73de09f --- /dev/null +++ b/configs/peft.py @@ -0,0 +1,26 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from dataclasses import dataclass, field +from typing import List + +@dataclass +class lora_config: + r: int=8 + lora_alpha: int=32 + target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"]) + bias= "none" + task_type: str= "CAUSAL_LM" + lora_dropout: float=0.05 + inference_mode: bool = False + +@dataclass +class llama_adapter_config: + adapter_len: int= 10 + adapter_layers: int= 30 + task_type: str= "CAUSAL_LM" + +@dataclass +class prefix_config: + num_virtual_tokens: int=30 + task_type: str= "CAUSAL_LM" \ No newline at end of file diff --git a/configs/training.py b/configs/training.py new file mode 100644 index 0000000..a6a1e39 --- /dev/null +++ b/configs/training.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
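The `lora_config` dataclass above mirrors the constructor arguments of `peft.LoraConfig`. The actual conversion is done by `generate_peft_config` in `utils/config_utils.py`, which is not included in this excerpt; the sketch below is an assumption about the mapping it presumably performs before `get_peft_model` is called in `finetune_llama2.py`.

```python
# Sketch only: generate_peft_config (utils/config_utils.py) is not shown in this excerpt.
from peft import LoraConfig, get_peft_model

from configs.peft import lora_config

def build_lora_config(cfg: lora_config) -> LoraConfig:
    # The dataclass field names match peft.LoraConfig, so they pass through directly.
    return LoraConfig(
        r=cfg.r,
        lora_alpha=cfg.lora_alpha,
        target_modules=cfg.target_modules,
        lora_dropout=cfg.lora_dropout,
        bias=cfg.bias,
        task_type=cfg.task_type,
        inference_mode=cfg.inference_mode,
    )

# Usage, analogous to the get_peft_model call in finetune_llama2.py:
#   model = get_peft_model(model, build_lora_config(lora_config()))
#   model.print_trainable_parameters()
```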
+ +from dataclasses import dataclass + + +@dataclass +class train_config: + model_name: str="llama-7b" + enable_fsdp: bool=True + low_cpu_fsdp: bool=False + run_validation: bool=True + batch_size_training: int=1 + gradient_accumulation_steps: int=1 + num_epochs: int=3 + num_workers_dataloader: int=1 + lr: float=1e-4 + weight_decay: float=0.0 + gamma: float= 0.85 + seed: int=42 + use_fp16: bool=False + mixed_precision: bool=True + val_batch_size: int=1 + dataset = "samsum_dataset" + peft_method: str = "lora" # None , llama_adapter, prefix + use_peft: bool=False + output_dir: str = "model_output" + freeze_layers: bool = False + num_freeze_layers: int = 1 + quantization: bool = False + one_gpu: bool = False + save_model: bool = True + dist_checkpoint_root_folder: str="/storage/ukp/work/lee/intel_ukp_llm/intel_ukp_llm/llama-13b" # will be used if using FSDP + dist_checkpoint_folder: str="checkpoints" # will be used if using FSDP + save_optimizer: bool=False # will be used if using FSDP + use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels + gumbel: bool = False # Enable Gumbel softmax for sampling + gumbel_temperature: float = 1.0 # Gumbel softmax temperature + gumbel_hard: bool = False # Use hard Gumbel softmax + gumbel_noskip_low: int = 2 # Layer to start skipping (low) + gumbel_noskip_high: int = 32 # Layer to stop skipping (high) + debugging: bool = False # Enable debugging mode + debugging_host: str = "localhost" # Debugging host + debugging_port: int = 5678 # Debugging port + gumbel_target: float = 0.8 # Percent of layers that should be used (for calculating gumbel loss) + gumbel_loss_multiplier: float = 50.0 # Simple multiplier for gumbel loss + gumbel_loss_alpha: float = 0.8 # initial weighting factor for the gumbel loss + gumbel_loss_beta: float = 0.0005 # controls the rate at which the weighting factor decreases + use_token_max: bool = False # Use max function over token instead of token mean + use_only_last_token: bool = False # Use only last token for classification + use_only_past_key_values: bool = False # Use only past key values for classification + share_layer: bool = False # Share one gumbel layer across all layers + gumbel_use_simple_classifier: bool = False # Use simple classifier instead of gumbel + gumbel_num_hidden_layers: int = 1 # Number of hidden layers + gradient_clipping_value: float = 1.0 # gradient Clipping value + + + diff --git a/docs/index.md b/docs/index.md deleted file mode 100644 index 2a1722b..0000000 --- a/docs/index.md +++ /dev/null @@ -1,21 +0,0 @@ -# Welcome to MkDocs - -For full documentation visit [mkdocs.org](https://www.mkdocs.org). - -## Commands - -* `mkdocs new [dir-name]` - Create a new project. -* `mkdocs serve` - Start the live-reloading docs server. -* `mkdocs build` - Build the documentation site. -* `mkdocs -h` - Print help message and exit. - -## Project layout - - mkdocs.yml # The configuration file. - docs/ - index.md # The documentation homepage. - ... # Other markdown pages, images and other files. - -## Docs - -::: gumbel_softmax_layer_skipping_2024 \ No newline at end of file diff --git a/evaluate_llama2.py b/evaluate_llama2.py new file mode 100644 index 0000000..0755bce --- /dev/null +++ b/evaluate_llama2.py @@ -0,0 +1,135 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
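The `gumbel_*` options in `train_config` above feed a Gumbel-Softmax estimator over per-layer keep/skip decisions. The repository's own implementation lives in `neuralnets/llama_gumbel.py` further down in this patch; purely as an illustration of the underlying mechanism that `gumbel_temperature` and `gumbel_hard` control, `torch.nn.functional.gumbel_softmax` turns a two-way logit into a differentiable keep/skip sample:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# One keep/skip logit pair per decoder layer (4 toy layers).
logits = torch.randn(4, 2, requires_grad=True)

# Soft, differentiable relaxation (tau plays the role of gumbel_temperature).
soft = F.gumbel_softmax(logits, tau=1.0, hard=False, dim=-1)

# Straight-through variant (gumbel_hard=True): one-hot in the forward pass,
# soft gradients in the backward pass.
hard = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)

keep_mask = hard[:, 0]  # 1.0 where a layer would be kept, 0.0 where it would be skipped
print(soft)
print(keep_mask)
```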
+ +import fire +import os +import json +import time +import csv +import torch +from transformers import LlamaTokenizer, LlamaTokenizerFast, AutoTokenizer + +# Evaluation via huggingface rouge: https://huggingface.co/spaces/evaluate-metric/rouge +import evaluate +import datasets + +from utils.safety_utils import get_safety_checker +from utils.model_utils import load_model, load_peft_model, load_gumbel_llama + +from configs import inference_config +from utils.config_utils import update_config + +torch_device = "cuda" if torch.cuda.is_available() else "cpu" + +def main(**kwargs): + update_config(inference_config, **kwargs) + + # Set the seeds for reproducibility + torch.cuda.manual_seed(inference_config.seed) + torch.manual_seed(inference_config.seed) + + # Load samsum dataset + dataset = datasets.load_dataset("samsum", split="test") + + print(f"Evaluating on {len(dataset)} instances.") + + + # Load model + if inference_config.use_gumbel: + model = load_gumbel_llama(inference_config.model_name, inference_config.quantization) + else: + model = load_model(inference_config.model_name, inference_config.quantization) + + if inference_config.peft_model: + model = load_peft_model(model, inference_config.peft_model) + + model.to(torch_device) + model.bfloat16() + model.eval() + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(inference_config.model_name) + + # Add hooks for gumbel + if inference_config.use_gumbel: + # a dict to store the activations + gumbel_activation = {} + def hook_fn(layer, input, output): + gumbel_activation[layer] = output + + for layer_idx, layer in enumerate(model.model.gumbel_layer_selection): + model.model.gumbel_layer_selection[layer_idx].register_forward_hook(hook_fn) + activations=[] + + predictions_generated = [] + tok_len_generated = [] + references = [] + + dsample = dataset['dialogue'] + dsummary = dataset['summary'] + + start = time.perf_counter() + with torch.no_grad(): + for sample, summary in zip(dsample,dsummary): + batch = tokenizer(f"Summarize this dialog:\n{sample}\n---\nSummary:\n", padding='max_length', truncation=True, max_length=inference_config.max_padding_length, return_tensors="pt") + batch = {k: v.to("cuda") for k, v in batch.items()} + references.append(summary) + outputs = model.generate( + **batch, + max_new_tokens=inference_config.max_new_tokens, + do_sample=inference_config.do_sample, + top_p=inference_config.top_p, + temperature=inference_config.temperature, + min_length=inference_config.min_length, + use_cache=inference_config.use_cache, + top_k=inference_config.top_k, + repetition_penalty=inference_config.repetition_penalty, + length_penalty=inference_config.length_penalty, + ) + new_tokens = outputs[0][batch["input_ids"].shape[-1]:] + tok_len_generated.append(len(new_tokens)) + predictions_generated.append(tokenizer.decode(new_tokens, skip_special_tokens=True)) + if inference_config.use_gumbel: + # Track activated layers + layer_activations = {i:{} for i in range(len(outputs))} + for layer_idx, res in enumerate(gumbel_activation.values()): + for idx, activation in enumerate(res): + layer_activations[idx][layer_idx] = torch.argmax(activation,dim=1).tolist() + activations.append(layer_activations) + + e2e_inference_time = (time.perf_counter()-start)*1000 + print(f"the inference time is {e2e_inference_time} ms") + + e2e_inference_time_norm = e2e_inference_time / sum(tok_len_generated) + print(f"the inference time normalized by number of tokens is {e2e_inference_time_norm} ms") + + # Compute scores + rouge = evaluate.load('rouge') + 
results_generated = rouge.compute(predictions=predictions_generated, references=references) + results_generated["time"] = e2e_inference_time + + print("Fixed results: ",results_generated) + + results_file = f"{inference_config.model_name.split('/')[-1]}_samsum" + + if inference_config.use_gumbel: + results_file += "_gumbel-True" + else: + results_file += "_gumbel-False" + + + with open(os.path.join(inference_config.output_dir,f"{results_file}.json"),'w') as f: + json.dump(results_generated,f) + + if inference_config.use_gumbel: + with open(os.path.join(inference_config.output_dir,f"{results_file}_activations.json"),'w') as f: + json.dump(activations,f) + + with open(os.path.join(inference_config.output_dir,f"{results_file}.tsv"),'w') as f: + writer = csv.writer(f, delimiter='\t', quotechar='"') + for p,r in zip(predictions_generated, references): + writer.writerow([p,r]) + + +if __name__ == "__main__": + fire.Fire(main) diff --git a/finetune_llama2.py b/finetune_llama2.py new file mode 100644 index 0000000..a3cccf4 --- /dev/null +++ b/finetune_llama2.py @@ -0,0 +1,286 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +from pkg_resources import packaging +import wandb + +import fire +import torch +import torch.distributed as dist +import torch.optim as optim +from peft import get_peft_model, prepare_model_for_kbit_training +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, +) +from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload +from torch.optim.lr_scheduler import StepLR +from torch.utils.data import DistributedSampler +from transformers import ( + LlamaForCausalLM, + LlamaTokenizerFast, + LlamaConfig, + default_data_collator, +) +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +from configs import fsdp_config, train_config +from policies import AnyPrecisionAdamW, apply_fsdp_checkpointing + +from utils import fsdp_auto_wrap_policy +from utils.config_utils import ( + update_config, + generate_peft_config, + generate_dataset_config, +) +from utils.dataset_utils import get_preprocessed_dataset +from utils.train_utils import ( + train, + freeze_transformer_layers, + setup, + setup_environ_flags, + clear_gpu_cache, + print_model_size, + get_policies +) + +from neuralnets.llama_gumbel import LlamaForCausalLMGumbel + +from random import choices +from string import ascii_lowercase, digits + +def short_uuid(): + return ''.join(choices(ascii_lowercase + digits, k=8)) + + +def main(**kwargs): + # Update the configuration for the training and sharding process + update_config((train_config, fsdp_config), **kwargs) + + # Set the seeds for reproducibility + torch.cuda.manual_seed(train_config.seed) + torch.manual_seed(train_config.seed) + + llama_config = LlamaConfig.from_pretrained(train_config.model_name) + + if train_config.gumbel: + print("Running gumbel training (disabling PEFT!)") + llama_config.use_cache = False + train_config.use_peft = False + model_cls = LlamaForCausalLMGumbel + else: + print("Running regular training") + model_cls = LlamaForCausalLM + + if train_config.enable_fsdp: + setup() + # torchrun specific + local_rank = int(os.environ["LOCAL_RANK"]) + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + if torch.distributed.is_initialized(): + torch.cuda.set_device(local_rank) + clear_gpu_cache(local_rank) + setup_environ_flags(rank) + + if rank == 0: + 
wandb.init() + wandb.config.update(llama_config.to_dict()) + + # Load the pre-trained model and setup its configuration + if train_config.enable_fsdp and train_config.low_cpu_fsdp: + """ + for FSDP, we can save cpu memory by loading pretrained model on rank0 only. + this avoids cpu oom when loading large models like llama 70B, in which case + model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms + overhead and currently requires latest nightly. + """ + v = packaging.version.parse(torch.__version__) + verify_latest_nightly = v.is_devrelease and v.dev >= 20230701 + if not verify_latest_nightly: + raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, " + "please install latest nightly.") + if rank == 0: + model = model_cls.from_pretrained( + train_config.model_name, + config=llama_config, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + # use_cache=use_cache, + ) + else: + # llama_config.use_cache = use_cache + with torch.device("meta"): + model = model_cls(llama_config) + + else: + model = model_cls.from_pretrained( + train_config.model_name, + config=llama_config, + load_in_8bit=True if train_config.quantization else None, + device_map="auto" if train_config.quantization else None, + # use_cache=use_cache, + ) + if train_config.enable_fsdp and train_config.use_fast_kernels: + """ + For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable + using of Flash Attention or Xformer memory-efficient kernels + based on the hardware being used. This would speed up fine-tuning. + """ + try: + from optimum.bettertransformer import BetterTransformer + model = BetterTransformer.transform(model) + except ImportError: + print("Module 'optimum' not found. Please install 'optimum' it before proceeding.") + print_model_size(model, train_config, rank if train_config.enable_fsdp else 0) + + # Prepare the model for int8 training if quantization is enabled + if train_config.quantization: + model = prepare_model_for_kbit_training(model) + + # Convert the model to bfloat16 if fsdp and pure_bf16 is enabled + if train_config.enable_fsdp and fsdp_config.pure_bf16: + model.to(torch.bfloat16) + + # Load the tokenizer and add special tokens + tokenizer = LlamaTokenizerFast.from_pretrained(train_config.model_name) + tokenizer.add_special_tokens( + { + "pad_token": "", + } + ) + if train_config.use_peft: + peft_config = generate_peft_config(train_config, kwargs) + model = get_peft_model(model, peft_config) + model.print_trainable_parameters() + + #setting up FSDP if enable_fsdp is enabled + if train_config.enable_fsdp: + if not train_config.use_peft and train_config.freeze_layers: + + freeze_transformer_layers(train_config.num_freeze_layers) + + mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank) + my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer) + + model = FSDP( + model, + auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy, + cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None, + mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None, + sharding_strategy=fsdp_config.sharding_strategy, + device_id=torch.cuda.current_device(), + forward_prefetch=True, # May need this for gumbel? 
+ limit_all_gathers=True, + sync_module_states=train_config.low_cpu_fsdp, + param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False) + if train_config.low_cpu_fsdp and rank != 0 else None, + ) + if fsdp_config.fsdp_activation_checkpointing: + apply_fsdp_checkpointing(model) + elif not train_config.quantization and not train_config.enable_fsdp: + model.to("cuda") + + dataset_config = generate_dataset_config(train_config, kwargs) + + # Load and preprocess the dataset for training and validation + dataset_train = get_preprocessed_dataset( + tokenizer, + dataset_config, + split="train", + ) + + dataset_val = get_preprocessed_dataset( + tokenizer, + dataset_config, + split="test", + ) + + # Debugging setup: use 10 instances for training/dev + if train_config.debugging: + dataset_train = torch.utils.data.Subset(dataset_train, [i for i in range(10)]) + dataset_val = torch.utils.data.Subset(dataset_val, [i for i in range(10)]) + + train_sampler = None + val_sampler = None + if train_config.enable_fsdp: + train_sampler = DistributedSampler( + dataset_train, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + shuffle=True, + ) + if train_config.run_validation: + val_sampler = DistributedSampler( + dataset_val, + rank=dist.get_rank(), + num_replicas=dist.get_world_size(), + ) + + # Create DataLoaders for the training and validation dataset + train_dataloader = torch.utils.data.DataLoader( + dataset_train, + batch_size=train_config.batch_size_training, + num_workers=train_config.num_workers_dataloader, + pin_memory=True, + sampler=train_sampler if train_sampler else None, + drop_last=True, + collate_fn=default_data_collator, + ) + + eval_dataloader = None + if train_config.run_validation: + eval_dataloader = torch.utils.data.DataLoader( + dataset_val, + batch_size=train_config.val_batch_size, + num_workers=train_config.num_workers_dataloader, + pin_memory=True, + sampler=val_sampler if val_sampler else None, + drop_last=True, + collate_fn=default_data_collator, + ) + + # Initialize the optimizer and learning rate scheduler + if fsdp_config.pure_bf16 and fsdp_config.optimizer == "anyprecision": + optimizer = AnyPrecisionAdamW( + model.parameters(), + lr=train_config.lr, + momentum_dtype=torch.bfloat16, + variance_dtype=torch.bfloat16, + use_kahan_summation=False, + weight_decay=train_config.weight_decay, + ) + else: + optimizer = optim.AdamW( + model.parameters(), + lr=train_config.lr, + weight_decay=train_config.weight_decay, + ) + scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma) + + # Start the training process + results = train( + model, + train_dataloader, + eval_dataloader, + tokenizer, + optimizer, + scheduler, + train_config.gradient_accumulation_steps, + train_config, + fsdp_config if train_config.enable_fsdp else None, + local_rank if train_config.enable_fsdp else None, + rank if train_config.enable_fsdp else None, + ) + + # Save activations into the model folders + if train_config.gumbel and rank==0: + import json + with open(os.path.join(train_config.output_dir,"gumbel_activations_train.json"),'w') as f: + json.dump(results["gumbel_activations_train"],f) + with open(os.path.join(train_config.output_dir,"gumbel_activations_dev.json"),'w') as f: + json.dump(results["gumbel_activations_dev"],f) + +if __name__ == "__main__": + fire.Fire(main) diff --git a/gumbel_softmax_layer_skipping_2024/__init__.py b/gumbel_softmax_layer_skipping_2024/__init__.py deleted file mode 100644 index 4003b3e..0000000 --- 
a/gumbel_softmax_layer_skipping_2024/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from .base import BaseClass - - - -__all__ = [ - "subpackage", - "BaseClass" - ] \ No newline at end of file diff --git a/gumbel_softmax_layer_skipping_2024/__main__.py b/gumbel_softmax_layer_skipping_2024/__main__.py deleted file mode 100644 index cbd955c..0000000 --- a/gumbel_softmax_layer_skipping_2024/__main__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Entry point for gumbel_softmax_layer_skipping_2024.""" - -from .cli import main # pragma: no cover - -if __name__ == "__main__": # pragma: no cover - main() diff --git a/gumbel_softmax_layer_skipping_2024/base.py b/gumbel_softmax_layer_skipping_2024/base.py deleted file mode 100644 index 7ecc6ac..0000000 --- a/gumbel_softmax_layer_skipping_2024/base.py +++ /dev/null @@ -1,62 +0,0 @@ -# Example class -class BaseClass: - """ - Base class representing an entity. - - Attributes - ---------- - name : str - The name of the entity. - - Methods - ------- - __init__(): - Initializes a new instance of the BaseClass. - __str__(): - Returns a string representation of the entity. - __repr__(): - Returns a string representation of the entity for debugging. - __eq__(other): - Checks if two entities are equal based on their names. - - """ - - def __init__(self, name: str): - """ - Initializes a new instance of the BaseClass. - """ - self.name = name - - def __str__(self): - """ - Returns a string representation of the entity. - """ - return self.name - - def __repr__(self): - """ - Returns a string representation of the entity for debugging. - """ - return self.name - - def __eq__(self, other): - """ - Checks if two entities are equal based on their names. - - Parameters - ---------- - other : BaseClass - Another instance of BaseClass. - - Returns - ------- - bool - True if the entities are equal, False otherwise. - """ - return self.name == other.name - - def something(self): - """ - Does something. - """ - return "something" diff --git a/gumbel_softmax_layer_skipping_2024/cli.py b/gumbel_softmax_layer_skipping_2024/cli.py deleted file mode 100644 index e577e08..0000000 --- a/gumbel_softmax_layer_skipping_2024/cli.py +++ /dev/null @@ -1,33 +0,0 @@ -"""CLI interface for gumbel_softmax_layer_skipping_2024 project. - -Be creative! do whatever you want! - -- Install click or typer and create a CLI app -- Use builtin argparse -- Start a web application -- Import things from your .base module -""" -from .base import BaseClass -from .subpackage import SubPackageClass - -def main(): # pragma: no cover - """ - The main function executes on commands: - `python -m gumbel_softmax_layer_skipping_2024` and `$ gumbel_softmax_layer_skipping_2024 `. - - This is your program's entry point. - - You can change this function to do whatever you want. - Examples: - * Run a test suite - * Run a server - * Do some other stuff - * Run a command line application (Click, Typer, ArgParse) - * List all available tasks - * Run an application (Flask, FastAPI, Django, etc.) 
- """ - bc = BaseClass("test") - print(f"This will do something: {bc.something()}") - - spc = SubPackageClass("test") - print(f"This will do something else: {spc.something()}") diff --git a/gumbel_softmax_layer_skipping_2024/subpackage/__init__.py b/gumbel_softmax_layer_skipping_2024/subpackage/__init__.py deleted file mode 100644 index bf2baf1..0000000 --- a/gumbel_softmax_layer_skipping_2024/subpackage/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .subpackage import SubPackageClass - -__all__ = [ - "SubPackageClass" -] \ No newline at end of file diff --git a/gumbel_softmax_layer_skipping_2024/subpackage/subpackage.py b/gumbel_softmax_layer_skipping_2024/subpackage/subpackage.py deleted file mode 100644 index c676ed5..0000000 --- a/gumbel_softmax_layer_skipping_2024/subpackage/subpackage.py +++ /dev/null @@ -1,15 +0,0 @@ -class SubPackageClass: - def __init__(self, name): - self.name = name - - def __str__(self): - return f"SubPackage - {self.name}" - - def __repr__(self): - return f"SubPackage - {self.name}" - - def __eq__(self, other): - return self.name == other.name - - def something(self): - return "SubPackage - something" \ No newline at end of file diff --git a/llama_datasets/__init__.py b/llama_datasets/__init__.py new file mode 100644 index 0000000..a49a573 --- /dev/null +++ b/llama_datasets/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from llama_datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset +from llama_datasets.cnndm_dataset import get_preprocessed_cnndm as get_cnndm_dataset + diff --git a/llama_datasets/cnndm_dataset.py b/llama_datasets/cnndm_dataset.py new file mode 100644 index 0000000..2d8270c --- /dev/null +++ b/llama_datasets/cnndm_dataset.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# For dataset details visit: https://huggingface.co/datasets/cnn_dailymail + +import datasets + +from llama_datasets.utils import Concatenator + +def get_preprocessed_cnndm(dataset_config, tokenizer, split): + dataset = datasets.load_dataset("cnn_dailymail", split=split, ignore_verifications=True) + + + prompt = ( + f"Summarize this article:\n{{article}}\n---\nSummary:\n{{highlights}}{{eos_token}}" + ) + + def apply_prompt_template(sample): + return { + "text": prompt.format( + article=sample["article"], + highlights=sample["highlights"], + eos_token=tokenizer.eos_token, + ) + } + + dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) + + dataset = dataset.map( + lambda sample: tokenizer(sample["text"]), + batched=True, + remove_columns=list(dataset.features), + ).map(Concatenator(), batched=True) + return dataset diff --git a/llama_datasets/samsum_dataset.py b/llama_datasets/samsum_dataset.py new file mode 100644 index 0000000..6d6a7c7 --- /dev/null +++ b/llama_datasets/samsum_dataset.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +# For dataset details visit: https://huggingface.co/datasets/samsum + +import datasets + +from llama_datasets.utils import Concatenator + +def get_preprocessed_samsum(dataset_config, tokenizer, split): + dataset = datasets.load_dataset("samsum", split=split) + + prompt = ( + f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}" + ) + + def apply_prompt_template(sample): + return { + "text": prompt.format( + dialog=sample["dialogue"], + summary=sample["summary"], + eos_token=tokenizer.eos_token, + ) + } + + dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features)) + + dataset = dataset.map( + lambda sample: tokenizer(sample["text"]), + batched=True, + remove_columns=list(dataset.features), + ).map(Concatenator(), batched=True) + return dataset diff --git a/llama_datasets/utils.py b/llama_datasets/utils.py new file mode 100644 index 0000000..0a11d8c --- /dev/null +++ b/llama_datasets/utils.py @@ -0,0 +1,66 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from tqdm import tqdm +from itertools import chain + +from torch.utils.data import Dataset + +class Concatenator(object): + def __init__(self, chunk_size=2048): + self.chunk_size=chunk_size + self.residual = {"input_ids": [], "attention_mask": []} + + def __call__(self, batch): + concatenated_samples = { + k: v + list(chain(*batch[k])) for k, v in self.residual.items() + } + + total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]]) + + if total_length >= self.chunk_size: + chunk_num = total_length // self.chunk_size + result = { + k: [ + v[i : i + self.chunk_size] + for i in range(0, chunk_num * self.chunk_size, self.chunk_size) + ] + for k, v in concatenated_samples.items() + } + self.residual = { + k: v[(chunk_num * self.chunk_size) :] + for k, v in concatenated_samples.items() + } + else: + result = concatenated_samples + self.residual = {k: [] for k in concatenated_samples.keys()} + + result["labels"] = result["input_ids"].copy() + + return result + +class ConcatDataset(Dataset): + def __init__(self, dataset, chunk_size=4096): + self.dataset = dataset + self.chunk_size = chunk_size + + self.samples = [] + + buffer = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } + + for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True): + buffer = {k: v + sample[k] for k,v in buffer.items()} + + while len(next(iter(buffer.values()))) > self.chunk_size: + self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()}) + buffer = {k: v[self.chunk_size:] for k,v in buffer.items()} + + def __getitem__(self, idx): + return self.samples[idx] + + def __len__(self): + return len(self.samples) diff --git a/mkdocs.yml b/mkdocs.yml deleted file mode 100644 index 1c1983d..0000000 --- a/mkdocs.yml +++ /dev/null @@ -1,15 +0,0 @@ -site_name: gumbel_softmax_layer_skipping_2024 -nav: - - Home: index.md - -theme: - name: material - -plugins: - - search - - mkdocstrings: - project_name: gumbel_softmax_layer_skipping_2024 - handlers: - python: - options: - docstring_style: numpy \ No newline at end of file diff --git a/model_checkpointing/__init__.py b/model_checkpointing/__init__.py new file mode 100644 index 0000000..0e88b78 --- /dev/null +++ b/model_checkpointing/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
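The `Concatenator` defined in `llama_datasets/utils.py` above packs tokenized examples into fixed-length blocks (2048 tokens by default) and carries any remainder over to the next batch of the `datasets.map` pipeline. A toy run with a small `chunk_size` makes the behaviour concrete:

```python
from llama_datasets.utils import Concatenator

concat = Concatenator(chunk_size=8)

# One tokenized batch as produced by dataset.map(tokenizer, batched=True):
batch = {
    "input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]],
    "attention_mask": [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]],
}

packed = concat(batch)
print(packed["input_ids"])           # [[1, 2, 3, 4, 5, 6, 7, 8]] -> one full 8-token block
print(packed["labels"])              # identical copy of input_ids (causal LM targets)
print(concat.residual["input_ids"])  # [9, 10, 11] -> carried over to the next call
```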
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from model_checkpointing.checkpoint_handler import ( + load_model_checkpoint, + save_model_checkpoint, + load_optimizer_checkpoint, + save_optimizer_checkpoint, + save_model_and_optimizer_sharded, + load_model_sharded, + load_sharded_model_single_gpu, + save_model_checkpoint_nofsdp, + save_optimizer_checkpoint_nofsdp +) diff --git a/model_checkpointing/checkpoint_handler.py b/model_checkpointing/checkpoint_handler.py new file mode 100644 index 0000000..17fb3f3 --- /dev/null +++ b/model_checkpointing/checkpoint_handler.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from pathlib import Path +from datetime import datetime +import os +import torch +import time + +from torch.distributed.fsdp import ( + FullyShardedDataParallel as FSDP, + StateDictType, + FullStateDictConfig, # general model non-sharded, non-flattened params + LocalStateDictConfig, # flattened params, usable only by FSDP + # ShardedStateDictConfig, # un-flattened param but shards, usable by other parallel schemes. +) + +from torch.distributed._shard.checkpoint import ( + FileSystemReader, + FileSystemWriter, + save_state_dict, + load_state_dict, +) +from torch.distributed.checkpoint.default_planner import ( + DefaultSavePlanner, + DefaultLoadPlanner, +) + + +from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType +import torch.distributed._shard.checkpoint as dist_cp +import torch.distributed as dist + + +def get_date_of_run(): + """create date and time for file save uniqueness + example: 2022-05-07-08:31:12_PM' + """ + date_of_run = datetime.now().strftime("%Y-%m-%d-%I:%M:%S_%p") + print(f"--> current date and time of run = {date_of_run}") + return date_of_run + + +# create singleton saving policies to avoid making over and over +fullstate_save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + + +def load_model_sharded(model, rank, cfg): + # torch.manual_seed(103) + model_name_local = cfg.model_name.split('/')[-1] + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + model_name_local + ) + + load_dir = folder_name # Path.cwd() / + + if not load_dir.exists(): + if rank == 0: + print(f"No sharded_state_dict checkpoint directory found...skipping") + return + if rank == 0: + print(f"loading model from model path: {load_dir} ") + reader = FileSystemReader(load_dir) + + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + checkpoint = {"model": model.state_dict()} + if rank == 0: + ck = checkpoint.keys() + print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + + dist_cp.load_state_dict( + state_dict=checkpoint, + storage_reader=reader, + ) + if rank == 0: + print(f"checkpoint after load_state_dict()") + ck = checkpoint.keys() + print(f" checkpoint key len = {len(ck)} and \n keys = {ck}") + model.load_state_dict(checkpoint["model"]) + if rank == 0: + print(f"Sharded state checkpoint loaded from {load_dir}") + + +def save_model_and_optimizer_sharded(model, rank, cfg,optim=None): + """save model and optimizer via sharded_state_dict to save_dir""" + model_name_local = cfg.model_name.split('/')[-1] + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + model_name_local + ) + + save_dir = folder_name # Path.cwd() / + if rank == 
0: + print(f"Saving model to {save_dir}") + + distributed_writer = dist_cp.FileSystemWriter( + save_dir, + ) + t0 = time.perf_counter() + + with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT): + + state_dict = {"model": model.state_dict()} + if optim is not None: + state_dict["optim"] = FSDP.optim_state_dict(model, optim) + + dist_cp.save_state_dict( + state_dict=state_dict, + storage_writer=distributed_writer, + planner=DefaultSavePlanner(), + + ) + dist.barrier() + t1 = time.perf_counter() + if rank == 0: + print(f"Sharded state checkpoint saved to {save_dir}") + print( + f"Checkpoint Time = {t1-t0:.4f}\n" + ) +def save_model_checkpoint( + model, + optimizer, + rank, + cfg, + epoch=1, +): + """saving model via rank0 cpu streaming and full_state_dict""" + + with FSDP.state_dict_type( + model, StateDictType.FULL_STATE_DICT, fullstate_save_policy + ): + cpu_state = model.state_dict() + + print(f"saving process: rank {rank} done w model state_dict\n") + + + if rank == 0: + print(f"--> saving model ...") + # create save path + model_name_local = cfg.model_name.split('/')[-1] + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + model_name_local + ) + save_dir = folder_name + os.makedirs(save_dir, exist_ok=True) + save_name = model_name_local + "-" + str(epoch) + ".pt" + save_full_path = str(save_dir) + "/" + save_name + + # save model + torch.save(cpu_state, save_full_path) + + + print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") + +def save_model_checkpoint_nofsdp( + model, + cfg, + epoch=1, +): + """saving model via rank0 cpu streaming and full_state_dict""" + + cpu_state = model.state_dict() + + print(f"--> saving model ...") + # create save path + model_name_local = cfg.model_name.split('/')[-1] + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + model_name_local + ) + save_dir = folder_name + os.makedirs(save_dir, exist_ok=True) + save_name = model_name_local + "-" + str(epoch) + ".pt" + save_full_path = str(save_dir) + "/" + save_name + + # save model + torch.save(cpu_state, save_full_path) + + print(f"model checkpoint saved for epoch {epoch} at {save_full_path}\n") + +def load_model_checkpoint(model, rank, cfg): + """load local checkpoint to rank0 cpu + must be called * before * passing to FSDP""" + + if rank != 0: + return + + # where is the checkpoint at... + full_state_dict_model_path = ( + cfg.checkpoint_folder / cfg.checkpoint_model_filename # Path.cwd() / + ) + # is it present... + if not full_state_dict_model_path.is_file(): + print( + f"model checkpoint {full_state_dict_model_path} not present. Returning..."
+ ) + return + + model_checkpoint = torch.load(full_state_dict_model_path) + # integrate into loaded model + model.load_state_dict(model_checkpoint) + + + print(f"model checkpoint loaded to rank0 cpu") + + +def save_optimizer_checkpoint_nofsdp(optimizer, cfg, epoch=1): + """save optimizer state via full state dict""" + + optim_state = optimizer.state_dict() + + model_name_local = cfg.model_name.split('/')[-1] + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + model_name_local + ) + save_dir = folder_name # Path.cwd() / + os.makedirs(save_dir, exist_ok=True) + + opt_save_name = ( + "optimizer" + "-" + model_name_local + "-" + str(epoch) + ".pt" + ) + opt_save_full_path = os.path.join(save_dir, opt_save_name) + + print(f"--> saving optimizer state...") + + torch.save(optim_state, opt_save_full_path) + + print(f"--> saved {opt_save_full_path} to disk") + +def save_optimizer_checkpoint(model, optimizer, rank, cfg, epoch=1): + """save optimizer state via full state dict""" + + + print(f"--> optim state call on rank {rank}\n") + + # pull all sharded optimizer states to rank0 cpu... + + optim_state = FSDP.full_optim_state_dict(model, optimizer) + + + print(f"optim state dict ready on {rank} and len of {len(optim_state)}\n") + + if rank == 0: + model_name_local = cfg.model_name.split('/')[-1] + folder_name = ( + cfg.dist_checkpoint_root_folder + + "/" + + cfg.dist_checkpoint_folder + + "-" + + model_name_local + ) + save_dir = folder_name # Path.cwd() / + os.makedirs(save_dir, exist_ok=True) + + opt_save_name = ( + "optimizer" + "-" + model_name_local + "-" + str(epoch) + ".pt" + ) + opt_save_full_path = os.path.join(save_dir, opt_save_name) + + print(f"--> saving optimizer state...") + + torch.save(optim_state, opt_save_full_path) + + print(f"--> saved {opt_save_full_path} to disk") + + +def load_optimizer_checkpoint(model, optimizer_checkpoint_path, rank): + """load an fsdp optimizer full_state checkpoint using scatter method + this ensures only rank 0 loads the optimizer state dict and scatters to other ranks + """ + + + if not optimizer_checkpoint_path.is_file(): + print( + f"warning - optimizer checkpoint not present {optimizer_checkpoint_path}. Returning. 
" + ) + return + + full_osd = None + + if rank == 0: + full_osd = torch.load(optimizer_checkpoint_path) + + # called from all ranks, though only rank0 has a valid param for full_osd + sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, model) + + print(f"optimizer shard loaded on rank {rank}") + +def load_sharded_model_single_gpu(model,model_path): + + reader = FileSystemReader(model_path) + + state_dict = { + "model": model.state_dict() + } + + dist_cp.load_state_dict( + state_dict=state_dict, + storage_reader= FileSystemReader(model_path), + no_dist=True, + ) + + model.load_state_dict(state_dict["model"]) + + print(f"Sharded state checkpoint loaded from {model_path}") + return model diff --git a/neuralnets/llama_gumbel.py b/neuralnets/llama_gumbel.py new file mode 100644 index 0000000..90b4d48 --- /dev/null +++ b/neuralnets/llama_gumbel.py @@ -0,0 +1,722 @@ +import math +from dataclasses import dataclass +from typing import Optional, Union, Tuple, List + +import torch +import torch.nn.functional as F +import wandb +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import AutoTokenizer +from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast +from transformers.models.llama.modeling_llama import LlamaModel, LlamaPreTrainedModel, LlamaConfig, LlamaDecoderLayer, \ + LlamaRMSNorm, apply_rotary_pos_emb, repeat_kv, LlamaAttention, LlamaMLP, LlamaFlashAttention2 +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +@dataclass +class CausalLMOutputWithPastActivations(CausalLMOutputWithPast): + # Extra class to carry activations for the output + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + activations: Optional[Tuple[torch.FloatTensor]] = None + + +class LlamaForCausalLMGumbel(LlamaPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + self.model = LlamaModelGumbel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + self.gumbel_threshold = config.gumbel_threshold if hasattr(config, + "gumbel_threshold") else 0.5 + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + + >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS) + >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER) + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + # Fetch activations for loss computation + outputs, activations = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + if self.config.pretraining_tp > 1: + lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0) + logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)] + logits = torch.cat(logits, dim=-1) + else: + logits = self.lm_head(hidden_states) + logits = logits.float() + + loss = None + + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss_ce = loss_fct(shift_logits, shift_labels) + + shift_activations = activations[...,:,:-1].contiguous() + activation_rate = torch.mean(shift_activations) + loss_gumbel = (self.gumbel_threshold - activation_rate) ** 2 + loss = loss_ce + loss_gumbel + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPastActivations( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + activations=activations, + ) + + def compute_current_loss(self, logits, labels): + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = 
shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + return loss + + def calc_loss(loss_normal, loss_gumbel, alpha=0.8, beta=0.0005, gumbel_multiplier=50.0): + """ + Calculates the loss using the weighted sum of the normal loss and the gumbel loss, + where the weight is reduced over time. + + Args: + loss_normal: normal loss + loss_gumbel: gumbel loss + alpha: initial weighting factor for the gumbel loss + beta: controls the rate at which the weighting factor decreases + gumbel_multiplier: simple multiplier for the gumbel loss + + Returns: loss + + """ + #w_t = alpha * torch.exp(torch.tensor(-beta * step)) + # Trade-off factor + w_t = 0.5 + return (1 - w_t) * loss_normal, w_t * loss_gumbel * gumbel_multiplier + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + + +class LlamaDecoderLayer2(LlamaDecoderLayer): + def __init__(self, config: LlamaConfig, layer_idx: int): + super().__init__(config, layer_idx) + self.hidden_size = config.hidden_size + self.self_attn = ( + LlamaAttention(config=config) + if not getattr(config, "_flash_attn_2_enabled", False) + else LlamaFlashAttention2(config=config) + ) + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. 
See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + """ + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + padding_mask=padding_mask, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs, present_key_value + + +class LlamaAttention2(LlamaAttention): + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + padding_mask: Optional[torch.LongTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + if self.config.pretraining_tp > 1: + key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp + query_slices = self.q_proj.weight.split( + (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0 + ) + key_slices = self.k_proj.weight.split(key_value_slicing, dim=0) + value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) + + query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)] + query_states = torch.cat(query_states, dim=-1) + + key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)] + key_states = torch.cat(key_states, dim=-1) + + value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)] + value_states = torch.cat(value_states, dim=-1) + + else: + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + past_key_value = (key_states, value_states) # if use_cache else None + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states 
= repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + if self.config.pretraining_tp > 1: + attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2) + o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1) + attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)]) + else: + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +class LlamaModelGumbel(LlamaModel): + def __init__(self, config: LlamaConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([LlamaDecoderLayer2(config, i) for i in range(config.num_hidden_layers)]) + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + # add gumbel softmax layer + self.use_simple_classifier = config.gumbel_use_simple_classifier if hasattr(config, + "gumbel_use_simple_classifier") else False + self.share_layer = config.share_layer if hasattr(config, "share_layer") else False + + if self.use_simple_classifier: + if self.share_layer: + self.gumbel_layer_selection = nn.ModuleList( + [SelectionLayer(config, use_gumbel=False) for _ in range(1)]) + else: + self.gumbel_layer_selection = nn.ModuleList( + [SelectionLayer(config, use_gumbel=False) for _ in range(config.num_hidden_layers)]) + else: + if self.share_layer: + self.gumbel_layer_selection = nn.ModuleList([SelectionLayer(config) for _ in range(1)]) + else: + self.gumbel_layer_selection = nn.ModuleList( + [SelectionLayer(config) for _ in range(config.num_hidden_layers)]) + self.gumbel_threshold = config.gumbel_threshold if hasattr(config, + "gumbel_threshold") else 0.5 # TODO: use 1.0 as hard cap (for debugging) + # ensure gradients for gumbel layers: + self.gumbel_threshold = torch.tensor(self.gumbel_threshold, dtype=torch.float, requires_grad=True) + self.set_requires_grad(self.gumbel_layer_selection, True) + self.set_requires_grad(self.layers, True) + + # Initialize non-skipped layers (bottom and top) + self.gumbel_noskip_low = config.gumbel_noskip_low if hasattr(config, "gumbel_noskip_low") else 2 + self.gumbel_noskip_high = config.gumbel_noskip_high if hasattr(config, 
"gumbel_noskip_high") else 32 + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def _make_causal_mask(self, + input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + # from transformers.models.bart.modeling_bart._expand_mask + def _expand_mask(self,mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. + """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask + def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + combined_attention_mask = None + if input_shape[-1] > 1: + combined_attention_mask = self._make_causal_mask( + input_shape, + inputs_embeds.dtype, + device=inputs_embeds.device, + past_key_values_length=past_key_values_length, + ) + + if attention_mask is not None: + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( + inputs_embeds.device + ) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask + ) + + return combined_attention_mask + + # Set requires grad for generic module + def set_requires_grad(self, m, requires_grad): + for param in m.parameters(): + param.requires_grad_(requires_grad) + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + a_mask = attention_mask.clone() + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # retrieve input_ids and inputs_embeds + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You 
cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + + if past_key_values is not None: + + print("PKV: ", past_key_values) + + past_key_values_length = past_key_values[0][0].shape[2] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if position_ids is None: + device = input_ids.device if input_ids is not None else inputs_embeds.device + position_ids = torch.arange( + past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device + ) + position_ids = position_ids.unsqueeze(0).view(-1, seq_length) + else: + position_ids = position_ids.view(-1, seq_length).long() + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + # embed positions + if attention_mask is None: + attention_mask = torch.ones( + (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device + ) + attention_mask = self._prepare_decoder_attention_mask( + attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length + ) + + hidden_states = inputs_embeds + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + # Compute activation rate according to no-skip layers + activations = () + + for idx, decoder_layer in enumerate(self.layers): + + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, past_key_value, output_attentions) + + return custom_forward + + decoder_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + attention_mask, + position_ids, + ) + else: + decoder_outputs, present_key_value = decoder_layer( + hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + if self.share_layer: + gumbel_mask = self.gumbel_layer_selection[0](hidden_states, present_key_value, a_mask) + else: + gumbel_mask = self.gumbel_layer_selection[idx](hidden_states, present_key_value, a_mask) + + activations += (gumbel_mask[:,:, 0].clone().detach().requires_grad_(True),) + + if idx > self.gumbel_noskip_low and idx < self.gumbel_noskip_high: + # Only apply gumbel layer skipping between a certain range + hidden_states = torch.mul(decoder_outputs[0], torch.unsqueeze(gumbel_mask[:,:,0], 2)) + torch.mul(hidden_states, + torch.unsqueeze(gumbel_mask[:,:,1], 2)) + else: + hidden_states = decoder_outputs[0] + + if use_cache: + next_decoder_cache += (decoder_outputs[2 if output_attentions else 1],) + + #print("next_decoder_cache: ", len(next_decoder_cache[0])) + #print("next_decoder_cache shape: ", 
next_decoder_cache[0][0].shape) + + if output_attentions: + all_self_attns += (decoder_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + # Note: only forward the activations as we do not expose this to the main model output + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ), torch.stack(activations) + + +from torch.nn.functional import gumbel_softmax + + +class SelectionLayer(nn.Module): + """ + Selection Layer for Gumbel Softmax Layer Selection. + + # https://pytorch.org/docs/stable/generated/torch.nn.functional.gumbel_softmax.html + """ + + def __init__(self, config, use_gumbel=True): + """ + Selection Layer (with gumbel) + Args: + config: training config + use_gumbel: True for using gumbel, otherwise just sigmoid + """ + super().__init__() + + self.layers = config.num_hidden_layers + self.hidden_layers = config.gumbel_num_hidden_layers if hasattr(config, 'gumbel_num_hidden_layers') else 1 + + self.dense = nn.ModuleList( + [nn.Linear(config.hidden_size, config.hidden_size) for _ in range(self.hidden_layers)]) + self.activation = [nn.ReLU() for _ in range(self.hidden_layers)] + self.use_gumbel = use_gumbel + self.gumbel_tau = config.gumbel_temperature if hasattr(config, 'gumbel_temperature') else 1.0 + self.gumbel_tau = torch.tensor(self.gumbel_tau, requires_grad=True) + if use_gumbel: + self.dense_final = nn.Linear(config.hidden_size, 2) + else: + self.dense_final = nn.Linear(config.hidden_size, 1) + self.activation_final = nn.Sigmoid() + + self.use_token_max = config.use_token_max if hasattr(config, 'use_token_max') else False + self.use_only_last_token = config.use_only_last_token if hasattr(config, 'use_only_last_token') else False + self.use_only_past_key_values = config.use_only_past_key_values if hasattr(config, + 'use_only_past_key_values') else False + + def forward(self, hidden_states: torch.Tensor, present_key_value, attention_mask: torch.Tensor) -> torch.Tensor: + # use each hidden layer + for i, dense_layer in enumerate(self.dense): + hidden_states = dense_layer(hidden_states) + hidden_states = self.activation[i](hidden_states) + # final dense + hidden_states = self.dense_final(hidden_states) + hidden_states = self.activation_final(hidden_states) + + if self.use_only_past_key_values: + if present_key_value is not None: + # Only use past key values + hidden_states = hidden_states[:, :present_key_value.shape[1]] + else: + hidden_states = hidden_states[:, -1, :] + elif self.use_only_last_token: + # Only use last token + hidden_states = hidden_states[:, -1, :] + else: + # average over tokens + new_attention_mask = attention_mask.unsqueeze(-1).repeat(1, 1, hidden_states.shape[-1]) + # average over tokens + masked_tensor = hidden_states * new_attention_mask + hidden_states = gumbel_softmax(masked_tensor, tau=self.gumbel_tau, hard=True, dim=-1) + + return hidden_states diff --git a/policies/__init__.py b/policies/__init__.py new file mode 100644 index 0000000..c7a6f0e --- /dev/null +++ b/policies/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
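The llama_gumbel.py hunk above blends each decoder layer's output with its input using the hard two-way mask produced by `SelectionLayer` (a straight-through Gumbel-softmax). A minimal standalone sketch of that keep/skip mixing step, with illustrative shapes and tensor names rather than code from this repository:

```python
# Sketch only: mirrors the mixing done in LlamaModelGumbel.forward, with made-up shapes.
import torch
import torch.nn.functional as F

batch, seq_len, hidden = 2, 8, 16
layer_output = torch.randn(batch, seq_len, hidden)   # output of the current decoder layer
hidden_states = torch.randn(batch, seq_len, hidden)  # input to the layer (the "skip" path)
logits = torch.randn(batch, seq_len, 2)              # per-token keep/skip scores

# Hard straight-through Gumbel-softmax: one-hot in the forward pass, soft gradients in the backward pass.
gumbel_mask = F.gumbel_softmax(logits, tau=1.0, hard=True, dim=-1)

# Channel 0 keeps the layer output, channel 1 passes the input through, i.e. the layer is skipped.
mixed = layer_output * gumbel_mask[..., 0:1] + hidden_states * gumbel_mask[..., 1:2]
```

Averaging the keep channel `gumbel_mask[..., 0]` over tokens and layers gives the activation rate that the extra Gumbel loss term pushes toward `gumbel_threshold`.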
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from policies.mixed_precision import * +from policies.wrapping import * +from policies.activation_checkpointing_functions import apply_fsdp_checkpointing +from policies.anyprecision_optimizer import AnyPrecisionAdamW diff --git a/policies/activation_checkpointing_functions.py b/policies/activation_checkpointing_functions.py new file mode 100644 index 0000000..818b7da --- /dev/null +++ b/policies/activation_checkpointing_functions.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from functools import partial + +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + checkpoint_wrapper, + CheckpointImpl, + apply_activation_checkpointing, +) +from transformers.models.llama.modeling_llama import LlamaDecoderLayer + +non_reentrant_wrapper = partial( + checkpoint_wrapper, + checkpoint_impl=CheckpointImpl.NO_REENTRANT, +) + +check_fn = lambda submodule: isinstance(submodule, LlamaDecoderLayer) + + +def apply_fsdp_checkpointing(model): + """apply activation checkpointing to model + returns None as model is updated directly + """ + print(f"--> applying fsdp activation checkpointing...") + + apply_activation_checkpointing( + model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn + ) diff --git a/policies/anyprecision_optimizer.py b/policies/anyprecision_optimizer.py new file mode 100644 index 0000000..22b0ca0 --- /dev/null +++ b/policies/anyprecision_optimizer.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# AnyPrecisionAdamW: a flexible precision AdamW optimizer +# with optional Kahan summation for high precision weight updates. +# Allows direct control over momentum, variance and auxiliary compensation +# buffer dtypes. +# Optional Kahan summation is used to offset precision reduction for +# the weight updates. This allows full training in BFloat16 (equal or +# better than FP32 results in many cases) due to high precision weight upates. 
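The header comment above describes the optional Kahan summation buffer that lets the optimizer apply accurate weight updates even when parameters, momentum and variance are all kept in BFloat16. A standalone sketch of the compensated-update idea on a single bf16 tensor (illustrative only, not the optimizer itself):

```python
# Sketch only: Kahan-compensated accumulation in bf16, the same trick used in the
# use_kahan_summation branch of AnyPrecisionAdamW.step() below.
import torch

def kahan_add_(param, update, compensation):
    """In-place `param += update`, folding the bf16 rounding error back into `compensation`."""
    compensation.add_(update)            # fold the new update into the error buffer
    old = param.detach().clone()         # remember the previous value
    param.add_(compensation)             # apply the update plus any accumulated error
    compensation.add_(old.sub_(param))   # keep whatever was lost to rounding for next time

p = torch.zeros(4, dtype=torch.bfloat16)
comp = torch.zeros_like(p)
for _ in range(10_000):
    kahan_add_(p, torch.full_like(p, 1e-4), comp)
print(p)  # close to the true sum of 1.0; a naive bf16 running sum falls well short of it
```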
+ +import torch +from torch.optim.optimizer import Optimizer + + +class AnyPrecisionAdamW(Optimizer): + def __init__( + self, + params, + lr=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0.0, + use_kahan_summation=False, + momentum_dtype=torch.bfloat16, + variance_dtype=torch.bfloat16, + compensation_buffer_dtype=torch.bfloat16, + ): + """ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + lr (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay coefficient (default: 1e-2) + + # Any Precision specific + use_kahan_summation = creates auxiliary buffer to ensure high precision + model param updates (default: False) + momentum_dtype = dtype for momentum (default: BFloat32) + variance_dtype = dtype for uncentered variance (default: BFloat16) + compensation_buffer_dtype = dtype for Kahan summation + buffer (default: BFloat16) + + # Usage + This optimizer implements optimizer states, and Kahan summation + for high precision updates, all in user controlled dtypes. + Defaults are variance in BF16, Momentum in FP32. + This can be run in FSDP mixed precision, amp, or full precision, + depending on what training pipeline you wish to work with. + + Setting to use_kahan_summation = False, and changing momentum and + variance dtypes to FP32, reverts this to a standard AdamW optimizer. + + """ + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + use_kahan_summation=use_kahan_summation, + momentum_dtype=momentum_dtype, + variance_dtype=variance_dtype, + compensation_buffer_dtype=compensation_buffer_dtype, + ) + + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + """ + + if closure is not None: + with torch.enable_grad(): + # to fix linter, we do not keep the returned loss for use atm. 
+ closure() + + for group in self.param_groups: + + beta1, beta2 = group["betas"] + lr = group["lr"] + weight_decay = group["weight_decay"] + eps = group["eps"] + use_kahan_summation = group["use_kahan_summation"] + + momentum_dtype = group["momentum_dtype"] + variance_dtype = group["variance_dtype"] + compensation_buffer_dtype = group["compensation_buffer_dtype"] + + for p in group["params"]: + if p.grad is None: + continue + + if p.grad.is_sparse: + raise RuntimeError( + "AnyPrecisionAdamW does not support sparse gradients" + ) + + state = self.state[p] + + # State initialization + if len(state) == 0: + + state["step"] = torch.tensor(0.0) + + # momentum - EMA of gradient values + state["exp_avg"] = torch.zeros_like( + p, + dtype=momentum_dtype, + ) + + # variance uncentered - EMA of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + p, + dtype=variance_dtype, + ) + + # optional Kahan summation - accumulated error tracker + if use_kahan_summation: + state["compensation"] = torch.zeros_like( + p, + dtype=compensation_buffer_dtype, + ) + + # main processing ------------------------- + + # update the steps for each param group update + state["step"] += 1 + step = state["step"] + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + + grad = p.grad + + # weight decay, AdamW style + if weight_decay: + p.data.mul_(1 - lr * weight_decay) + + # update momentum + exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) + + # update uncentered variance + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) + + # adjust using bias1 + bias_correction1 = 1 - beta1**step + + step_size = lr / bias_correction1 + + # adjust using bias2 + denom_correction = (1 - beta2**step) ** 0.5 # avoids math import + + centered_variance = (exp_avg_sq.sqrt() / denom_correction).add_( + eps, alpha=1 + ) + + # lr update to compensation + if use_kahan_summation: + compensation = state["compensation"] + + compensation.addcdiv_(exp_avg, centered_variance, value=-step_size) + + # update weights with compensation (Kahan summation) + # save error back to compensation for next iteration + temp_buffer = p.detach().clone() + p.data.add_(compensation) + compensation.add_(temp_buffer.sub_(p.data)) + + else: + # usual AdamW updates + p.data.addcdiv_(exp_avg, centered_variance, value=-step_size) \ No newline at end of file diff --git a/policies/mixed_precision.py b/policies/mixed_precision.py new file mode 100644 index 0000000..11df7ed --- /dev/null +++ b/policies/mixed_precision.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import torch + +from torch.distributed.fsdp import ( + MixedPrecision, +) + +# requires grad scaler in main loop +fpSixteen = MixedPrecision( + param_dtype=torch.float16, + # Gradient communication precision. + reduce_dtype=torch.float16, + # Buffer precision. + buffer_dtype=torch.float16, +) + +bfSixteen = MixedPrecision( + param_dtype=torch.bfloat16, + # Gradient communication precision. + reduce_dtype=torch.bfloat16, + # Buffer precision. 
+ buffer_dtype=torch.bfloat16, + cast_forward_inputs=True, +) + +bfSixteen_mixed = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.bfloat16, + buffer_dtype=torch.bfloat16, +) + +fp32_policy = MixedPrecision( + param_dtype=torch.float32, + reduce_dtype=torch.float32, + buffer_dtype=torch.float32, +) diff --git a/policies/wrapping.py b/policies/wrapping.py new file mode 100644 index 0000000..da7981c --- /dev/null +++ b/policies/wrapping.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import functools + +from transformers.models.llama.modeling_llama import LlamaDecoderLayer +from torch.distributed.fsdp.wrap import ( + transformer_auto_wrap_policy, + size_based_auto_wrap_policy, +) + + +def get_size_policy(min_params=1e8): + num_wrap_policy = functools.partial( + size_based_auto_wrap_policy, min_num_params=min_params + ) + return num_wrap_policy + + +def get_llama_wrapper(): + """we register our main layer class and use the fsdp transformer wrapping policy + ensures embedding layers are in the root fsdp unit for shared access and that fsdp units map to transformer layers + """ + # ==== use new transformer wrapper + + llama_auto_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls={ + LlamaDecoderLayer, + }, + ) + + return llama_auto_wrap_policy diff --git a/requirements-dev.txt b/requirements-dev.txt deleted file mode 100644 index 829cafe..0000000 --- a/requirements-dev.txt +++ /dev/null @@ -1,14 +0,0 @@ -# This requirements are for development and testing only, not for production. -# You shouldn't need to edit this file, unless you know what you are doing. -pytest -coverage -pylint -black -isort -pytest-cov -mypy -gitchangelog -mkdocs -mkdocs-material -mkdocstrings -mkdocstrings-python diff --git a/requirements.txt b/requirements.txt index 2c0a27e..cf84847 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,20 @@ -# This template is low-dependency. -# By default there is no requirements added here. -# Add the requirements you need to this file. +accelerate>=0.16.0,<1 +click>=8.0.4,<9 +datasets>=2.10.0,<3 +deepspeed>=0.8.3,<0.9 +transformers[torch]>=4.28.1,<5 +langchain>=0.0.139 +protobuf==3.20.3 +torch==2.0.1 +tensorboard +wandb +scikit-learn +tqdm +evaluate +trl==0.5.0 +peft==0.4.0 +bitsandbytes==0.41.1 +pydantic==1.10.9 +fire +py7zr + diff --git a/setup.py b/setup.py deleted file mode 100644 index 2d00a74..0000000 --- a/setup.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Python setup.py for gumbel_softmax_layer_skipping_2024 package""" -import io -import os -from setuptools import find_packages, setup - - -def read(*paths, **kwargs): - """Read the contents of a text file safely. - >>> read("gumbel_softmax_layer_skipping_2024", "VERSION") - '0.1.0' - >>> read("README.md") - ... 
- """ - - content = "" - with io.open( - os.path.join(os.path.dirname(__file__), *paths), - encoding=kwargs.get("encoding", "utf8"), - ) as open_file: - content = open_file.read().strip() - return content - - -def read_requirements(path): - return [ - line.strip() - for line in read(path).split("\n") - if not line.startswith(('"', "#", "-", "git+")) - ] - - -setup( - name="gumbel_softmax_layer_skipping_2024", - url="https://github.com/UKPLab/gumbel-softmax-layer-skipping-2024/", - long_description=read("README.md"), - long_description_content_type="text/markdown", - author="author_name", - packages=find_packages(exclude=["tests", ".github"]), - install_requires=read_requirements("requirements.txt"), - entry_points={ - "console_scripts": ["gumbel_softmax_layer_skipping_2024 = gumbel_softmax_layer_skipping_2024.__main__:main"] - }, - extras_require={"test": read_requirements("requirements-dev.txt")}, -) diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 2dc2a8f..0000000 --- a/tests/conftest.py +++ /dev/null @@ -1,17 +0,0 @@ -# This file defines fixtures that are used by all tests in the tests/ directory. -# See https://docs.pytest.org/en/latest/fixture.html for more information about fixtures. - -import sys -import pytest - - -# each test runs on cwd to its temp dir -@pytest.fixture(autouse=True) -def go_to_tmpdir(request): - # Get the fixture dynamically by its name. - tmpdir = request.getfixturevalue("tmpdir") - # ensure local test created packages can be imported - sys.path.insert(0, str(tmpdir)) - # Chdir only for the duration of the test. - with tmpdir.as_cwd(): - yield diff --git a/tests/test_base.py b/tests/test_base.py deleted file mode 100644 index 5face46..0000000 --- a/tests/test_base.py +++ /dev/null @@ -1,22 +0,0 @@ -# Tests are defined here -from gumbel_softmax_layer_skipping_2024 import BaseClass -from gumbel_softmax_layer_skipping_2024.subpackage import SubPackageClass - -def test_template(): - assert True - -def test_base_class(): - bc1 = BaseClass(name="test1") - bc2 = BaseClass(name="test2") - - assert str(bc1) == "test1" - assert repr(bc1) == "test1" - assert bc1 != bc2 - assert bc1.something() == "something" - -def test_subpackage(): - spc = SubPackageClass(name="test") - - assert str(spc) == "SubPackage - test" - assert repr(spc) == "SubPackage - test" - assert spc.something() == "SubPackage - something" \ No newline at end of file diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..7604f45 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +from utils.memory_utils import MemoryTrace +from utils.dataset_utils import * +from utils.fsdp_utils import fsdp_auto_wrap_policy +from utils.train_utils import * diff --git a/utils/chat_utils.py b/utils/chat_utils.py new file mode 100644 index 0000000..530fdcf --- /dev/null +++ b/utils/chat_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +import json +from typing import List, Literal, TypedDict + + +Role = Literal["user", "assistant"] + + +class Message(TypedDict): + role: Role + content: str + + +Dialog = List[Message] + +B_INST, E_INST = "[INST]", "[/INST]" +B_SYS, E_SYS = "<>\n", "\n<>\n\n" +def format_tokens(dialogs, tokenizer): + prompt_tokens = [] + for dialog in dialogs: + if dialog[0]["role"] == "system": + dialog = [ + { + "role": dialog[1]["role"], + "content": B_SYS + + dialog[0]["content"] + + E_SYS + + dialog[1]["content"], + } + ] + dialog[2:] + assert all([msg["role"] == "user" for msg in dialog[::2]]) and all( + [msg["role"] == "assistant" for msg in dialog[1::2]] + ), ( + "model only supports 'system','user' and 'assistant' roles, " + "starting with user and alternating (u/a/u/a/u...)" + ) + """ + Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs. + Here, we are adding it manually. + """ + dialog_tokens: List[int] = sum( + [ + tokenizer.encode( + f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ", + ) + [tokenizer.eos_token_id] + for prompt, answer in zip(dialog[::2], dialog[1::2]) + ], + [], + ) + assert ( + dialog[-1]["role"] == "user" + ), f"Last message must be from user, got {dialog[-1]['role']}" + dialog_tokens += tokenizer.encode( + f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}", + ) + prompt_tokens.append(dialog_tokens) + return prompt_tokens + + +def read_dialogs_from_file(file_path): + with open(file_path, 'r') as file: + dialogs = json.load(file) + return dialogs diff --git a/utils/checkpoint_converter_fsdp_hf.py b/utils/checkpoint_converter_fsdp_hf.py new file mode 100644 index 0000000..0216e13 --- /dev/null +++ b/utils/checkpoint_converter_fsdp_hf.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +# from accelerate import init_empty_weights, load_checkpoint_and_dispatch + +import fire +import os +import sys +import yaml + +from transformers import LlamaTokenizer + +from utils.model_utils import load_llama_from_config + +# Get the current file's directory +current_directory = os.path.dirname(os.path.abspath(__file__)) + +# Get the parent directory +parent_directory = os.path.dirname(current_directory) + +# Append the parent directory to sys.path +sys.path.append(parent_directory) +from model_checkpointing import load_sharded_model_single_gpu + +def main( + fsdp_checkpoint_path="", # Path to FSDP Sharded model checkpoints + consolidated_model_path="", # Path to save the HF converted model checkpoints + HF_model_path_or_name="" # Path/ name of the HF model that include config.json and tokenizer_config.json (e.g. 
meta-llama/Llama-2-7b-chat-hf) + ): + + try: + file_name = 'train_params.yaml' + # Combine the directory and file name to create the full path + train_params_path = os.path.join(fsdp_checkpoint_path, file_name) + # Open the file + with open(train_params_path, 'r') as file: + # Load the YAML data + data = yaml.safe_load(file) + + # Access the 'model_name' field + HF_model_path_or_name = data.get('model_name') + + print(f"Model name: {HF_model_path_or_name}") + except FileNotFoundError: + print(f"The file {train_params_path} does not exist.") + HF_model_path_or_name = input("Please enter the model name: ") + print(f"Model name: {HF_model_path_or_name}") + except Exception as e: + print(f"An error occurred: {e}") + + + #load the HF model definition from config + model_def = load_llama_from_config(HF_model_path_or_name) + print("model is loaded from config") + #load the FSDP sharded checkpoints into the model + model = load_sharded_model_single_gpu(model_def, fsdp_checkpoint_path) + print("model is loaded from FSDP checkpoints") + #loading the tokenizer form the model_path + tokenizer = LlamaTokenizer.from_pretrained(HF_model_path_or_name) + tokenizer.save_pretrained(consolidated_model_path) + #save the FSDP sharded checkpoints in HF format + model.save_pretrained(consolidated_model_path) + print(f"HuggingFace model checkpoints has been saved in {consolidated_model_path}") +if __name__ == "__main__": + fire.Fire(main) diff --git a/utils/config_utils.py b/utils/config_utils.py new file mode 100644 index 0000000..c85cf49 --- /dev/null +++ b/utils/config_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import inspect +from dataclasses import asdict +from peft import ( + LoraConfig, + AdaptionPromptConfig, + PrefixTuningConfig, +) + +from configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config +from utils.dataset_utils import DATASET_PREPROC + + +def update_config(config, **kwargs): + if isinstance(config, (tuple, list)): + for c in config: + update_config(c, **kwargs) + else: + for k, v in kwargs.items(): + if hasattr(config, k): + setattr(config, k, v) + elif "." 
in k: + # allow --some_config.some_param=True + config_name, param_name = k.split(".") + if type(config).__name__ == config_name: + if hasattr(config, param_name): + setattr(config, param_name, v) + else: + # In case of specialized config we can warm user + print(f"Warning: {config_name} does not accept parameter: {k}") + elif isinstance(config, train_config): + print(f"Warning: unknown parameter {k}") + + +def generate_peft_config(train_config, kwargs): + configs = (lora_config, llama_adapter_config, prefix_config) + peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig) + names = tuple(c.__name__.rstrip("_config") for c in configs) + + assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}" + + config = configs[names.index(train_config.peft_method)]() + + update_config(config, **kwargs) + params = asdict(config) + peft_config = peft_configs[names.index(train_config.peft_method)](**params) + + return peft_config + + +def generate_dataset_config(train_config, kwargs): + names = tuple(DATASET_PREPROC.keys()) + + assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}" + + dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]() + + update_config(dataset_config, **kwargs) + + return dataset_config diff --git a/utils/dataset_utils.py b/utils/dataset_utils.py new file mode 100644 index 0000000..8c55ac6 --- /dev/null +++ b/utils/dataset_utils.py @@ -0,0 +1,77 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import importlib +from functools import partial +from pathlib import Path + +import torch + +from llama_datasets import ( + get_samsum_dataset, + get_cnndm_dataset, +) + + +def load_module_from_py_file(py_file: str) -> object: + """ + This method loads a module from a py file which is not in the Python path + """ + module_name = Path(py_file).name + loader = importlib.machinery.SourceFileLoader(module_name, py_file) + spec = importlib.util.spec_from_loader(module_name, loader) + module = importlib.util.module_from_spec(spec) + + loader.exec_module(module) + + return module + + +def get_custom_dataset(dataset_config, tokenizer, split: str): + if ":" in dataset_config.file: + module_path, func_name = dataset_config.file.split(":") + else: + module_path, func_name = dataset_config.file, "get_custom_dataset" + + if not module_path.endswith(".py"): + raise ValueError(f"Dataset file {module_path} is not a .py file.") + + module_path = Path(module_path) + if not module_path.is_file(): + raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.") + + module = load_module_from_py_file(module_path.as_posix()) + try: + return getattr(module, func_name)(dataset_config, tokenizer, split) + except AttributeError as e: + print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).") + raise e + + +DATASET_PREPROC = { + #"alpaca_dataset": partial(get_alpaca_dataset, max_words=224), + #"grammar_dataset": get_grammar_dataset, + "samsum_dataset": get_samsum_dataset, + "custom_dataset": get_custom_dataset, + "cnndm_dataset": get_cnndm_dataset, +} + + +def get_preprocessed_dataset( + tokenizer, dataset_config, split: str = "train" +) -> torch.utils.data.Dataset: + if not dataset_config.dataset in DATASET_PREPROC: + raise NotImplementedError(f"{dataset_config.dataset} is not (yet) 
implemented") + + def get_split(): + return ( + dataset_config.train_split + if split == "train" + else dataset_config.test_split + ) + + return DATASET_PREPROC[dataset_config.dataset]( + dataset_config, + tokenizer, + get_split(), + ) diff --git a/utils/fsdp_utils.py b/utils/fsdp_utils.py new file mode 100644 index 0000000..e2cd8d9 --- /dev/null +++ b/utils/fsdp_utils.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +def fsdp_auto_wrap_policy(model, transformer_layer_name): + import functools + + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + from peft.tuners import PrefixEncoder, PromptEmbedding, PromptEncoder + + def lambda_policy_fn(module): + if ( + len(list(module.named_children())) == 0 + and getattr(module, "weight", None) is not None + and module.weight.requires_grad + ): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + transformer_wrap_policy = functools.partial( + transformer_auto_wrap_policy, + transformer_layer_cls=( + PrefixEncoder, + PromptEncoder, + PromptEmbedding, + transformer_layer_name, + # FullyShardedDataParallelPlugin.get_module_class_from_name( + # model, transformer_layer_name + # ), + ), + ) + + auto_wrap_policy = functools.partial(_or_policy, policies=[lambda_policy, transformer_wrap_policy]) + return auto_wrap_policy \ No newline at end of file diff --git a/utils/memory_utils.py b/utils/memory_utils.py new file mode 100644 index 0000000..725f2b0 --- /dev/null +++ b/utils/memory_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
+ +import gc +import psutil +import threading + +import torch + +def byte2gb(x): + return int(x / 2**30) +# This context manager is used to track the peak memory usage of the process +class MemoryTrace: + def __enter__(self): + gc.collect() + torch.cuda.empty_cache() + torch.cuda.reset_max_memory_allocated() # reset the peak gauge to zero + self.begin = byte2gb(torch.cuda.memory_allocated()) + self.process = psutil.Process() + self.cpu_begin = byte2gb(self.cpu_mem_used()) + self.peak_monitoring = True + peak_monitor_thread = threading.Thread(target=self.peak_monitor_func) + peak_monitor_thread.daemon = True + peak_monitor_thread.start() + return self + + def cpu_mem_used(self): + """get resident set size memory for the current process""" + return self.process.memory_info().rss + + def peak_monitor_func(self): + self.cpu_peak = -1 + + while True: + self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak) + + # can't sleep or will not catch the peak right (this comment is here on purpose) + # time.sleep(0.001) # 1msec + + if not self.peak_monitoring: + break + + def __exit__(self, *exc): + self.peak_monitoring = False + + gc.collect() + torch.cuda.empty_cache() + self.end = byte2gb(torch.cuda.memory_allocated()) + self.peak = byte2gb(torch.cuda.max_memory_allocated()) + cuda_info = torch.cuda.memory_stats() + self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) + self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0) + self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"]) + self.m_cuda_ooms = cuda_info.get("num_ooms", 0) + self.used = byte2gb(self.end - self.begin) + self.peaked = byte2gb(self.peak - self.begin) + self.max_reserved = byte2gb(torch.cuda.max_memory_reserved()) + + self.cpu_end = self.cpu_mem_used() + self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin) + self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin) + # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}") \ No newline at end of file diff --git a/utils/model_utils.py b/utils/model_utils.py new file mode 100644 index 0000000..fee3ee7 --- /dev/null +++ b/utils/model_utils.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the GNU General Public License version 3. 
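`MemoryTrace` above is a context manager that snapshots GPU and CPU memory on entry, runs a daemon thread to track the CPU peak, and fills in summary attributes on exit. A small usage sketch; the model and batch are placeholders:

```python
# Sketch only: wrap any chunk of work in MemoryTrace and read the stats afterwards.
import torch
from utils.memory_utils import MemoryTrace

def profile_forward(model, batch):
    with MemoryTrace() as memtrace:   # resets the CUDA peak gauge on entry
        with torch.no_grad():
            model(**batch)
    # Attributes are populated in __exit__ (values converted with byte2gb)
    print(f"GPU delta/peak: {memtrace.used}/{memtrace.peaked}")
    print(f"CPU delta/peak: {memtrace.cpu_used}/{memtrace.cpu_peaked}")
    print(f"max reserved: {memtrace.max_reserved}, alloc retries: {memtrace.cuda_malloc_retires}")
```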
+import os +import torch + +from peft import PeftModel +from transformers import LlamaForCausalLM, LlamaConfig +from transformers import BitsAndBytesConfig + +from neuralnets.llama_gumbel import LlamaForCausalLMGumbel + +# Function to load the main model for text generation +def load_model(model_name, quantization): + model = LlamaForCausalLM.from_pretrained( + model_name, + return_dict=True, + load_in_8bit=quantization, + device_map="auto", + low_cpu_mem_usage=True, + ) + return model + +# Function to load the PeftModel for performance optimization +def load_peft_model(model, peft_model): + peft_model = PeftModel.from_pretrained(model, peft_model) + return peft_model + +# Loading the model from config to load FSDP checkpoints into that +def load_llama_from_config(config_path): + model_config = LlamaConfig.from_pretrained(config_path) + model = LlamaForCausalLM(config=model_config) + return model + +# Loading the model from config to load FSDP checkpoints into that +def load_gumbel_llama(model_name, quantization): + # First load architecture + model = LlamaForCausalLMGumbel.from_pretrained( + model_name, + return_dict=True, + load_in_8bit=quantization, + device_map="auto", + low_cpu_mem_usage=True, + ) + # Load weights + model.load_state_dict(torch.load(os.path.join(model_name,"model.pt"),map_location='cuda:0')) + return model + +# Loading the model from config to load FSDP checkpoints into that +def load_gumbel_llama_huggingface(model_name, quantization): + # First load architecture + model = LlamaForCausalLMGumbel.from_pretrained( + model_name, + return_dict=True, + load_in_8bit=quantization, + device_map="auto", + low_cpu_mem_usage=True, + ) + return model diff --git a/utils/safety_utils.py b/utils/safety_utils.py new file mode 100644 index 0000000..38a44d4 --- /dev/null +++ b/utils/safety_utils.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +import torch +import warnings + + +# Class for performing safety checks using AuditNLG library +class AuditNLGSensitiveTopics(object): + def __init__(self): + pass + + def __call__(self, output_text): + try: + from auditnlg.safety.exam import safety_scores + except ImportError as e: + print("Could not import optional dependency: auditnlg\nPlease install manually with:\n pip install auditnlg\nFollowed by:\npip install -r requirements.txt") + raise e + + + data = [{"output": output_text}] + + result = safety_scores(data=data, method="sensitive_topics") + scores = result[1]["all_scores"][0] + is_safe = scores["pred_class"] == "none" + report = "" + if not is_safe: + report += f"Predicted class: {scores['pred_class']}\n" + report += "|" + "|".join(f"{n:^10}" for n in [list(k.keys())[0] for k in scores["class_scores"]]) + "|\n" + report += "|" + "|".join(f"{n:^10.5}" for n in [list(k.values())[0] for k in scores["class_scores"]]) + "|\n" + return "Sensitive Topics", is_safe, report + + +class SalesforceSafetyChecker(object): + def __init__(self): + pass + + def __call__(self, output_text): + from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig + + config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base") + tokenizer = AutoTokenizer.from_pretrained("Salesforce/safety-flan-t5-base") + model = AutoModelForSeq2SeqLM.from_pretrained("Salesforce/safety-flan-t5-base", config=config) + model.eval() + + prefix = "Is the field safe or unsafe?" 
+ input_ids = tokenizer(prefix + " " + output_text + " ", return_tensors="pt").input_ids + + if len(input_ids[0]) > 512: + warnings.warn( + "Input length is > 512 token. Safety check result could be incorrect." + ) + + with torch.no_grad(): + outputs = model.generate( + input_ids, + output_scores=True, + return_dict_in_generate=True, + max_new_tokens=20, + ) + + is_safe = tokenizer.decode(outputs.sequences[0], skip_special_tokens=True).split(" ")[0] == "safe" + + report = "" + if not is_safe: + true_false_ids = tokenizer("true false").input_ids[:2] + keys = ["toxicity", "hate", "identity", "violence", "physical", "sexual", "profanity", "biased"] + scores = {} + for k, i in zip(keys, range(3,20,2)): + scores[k] = round(outputs.scores[i][0,true_false_ids].softmax(dim=0)[0].item(), 5) + + report += "|" + "|".join(f"{n:^10}" for n in scores.keys()) + "|\n" + report += "|" + "|".join(f"{n:^10}" for n in scores.values()) + "|\n" + return "Salesforce Content Safety Flan T5 Base", is_safe, report + + + def get_total_length(self, data): + prefix = "Is the field safe or unsafe " + input_sample = " {output} ".format(**data[0]) + + return len(self.tokenizer(prefix + input_sample)["input_ids"]) + + +# Class for performing safety checks using Azure Content Safety service +class AzureSaftyChecker(object): + def __init__(self): + try: + from azure.ai.contentsafety import ContentSafetyClient + from azure.core.credentials import AzureKeyCredential + + key = os.environ["CONTENT_SAFETY_KEY"] + endpoint = os.environ["CONTENT_SAFETY_ENDPOINT"] + except ImportError: + raise Exception( + "Could not import required package azure-ai-contentsafety. Install with: pip install azure-ai-contentsafety" + ) + except KeyError: + raise Exception( + "Environment variables not set. Please set CONTENT_SAFETY_KEY and CONTENT_SAFETY_ENDPOINT." 
+ ) + + self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key)) + + def __call__(self, output_text): + from azure.core.exceptions import HttpResponseError + from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory + + print(len(output_text)) + if len(output_text) > 1000: + raise Exception("Input length to safety check is too long (>1000).") + + categories = [ + TextCategory.VIOLENCE, + TextCategory.SELF_HARM, + TextCategory.SEXUAL, + TextCategory.HATE, + ] + + request = AnalyzeTextOptions(text=output_text, categories=categories) + + try: + response = self.client.analyze_text(request) + except HttpResponseError as e: + print("Analyze text failed.") + if e.error: + print(f"Error code: {e.error.code}") + print(f"Error message: {e.error.message}") + raise + print(e) + raise e + + levels = {0: "Safe", 2: "Low", 4: "Medium", 6: "High"} + + severities = [ + getattr(response, c.name.lower() + "_result").severity for c in categories + ] + + DEFAULT_LEVELS = [0, 0, 0, 0] + + is_safe = all([s <= l for s, l in zip(severities, DEFAULT_LEVELS)]) + + report = "" + if not is_safe: + report = "|" + "|".join(f"{c.name:^10}" for c in categories) + "|\n" + report += "|" + "|".join(f"{levels[s]:^10}" for s in severities) + "|\n" + + return "Azure Content Saftey API", is_safe, report + + +# Function to load the PeftModel for performance optimization +# Function to determine which safety checker to use based on the options selected +def get_safety_checker(enable_azure_content_safety, + enable_sensitive_topics, + enable_salesforce_content_safety, + ): + safety_checker = [] + if enable_azure_content_safety: + safety_checker.append(AzureSaftyChecker()) + if enable_sensitive_topics: + safety_checker.append(AuditNLGSensitiveTopics()) + if enable_salesforce_content_safety: + safety_checker.append(SalesforceSafetyChecker()) + return safety_checker + + + + + diff --git a/utils/train_utils.py b/utils/train_utils.py new file mode 100644 index 0000000..0260068 --- /dev/null +++ b/utils/train_utils.py @@ -0,0 +1,493 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. + +import os +import time +import yaml +from pathlib import Path +from pkg_resources import packaging + +import torch +import torch.cuda.nccl as nccl +import torch.distributed as dist +from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp import FullyShardedDataParallel +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from tqdm import tqdm + +from model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint, save_model_checkpoint_nofsdp, save_optimizer_checkpoint_nofsdp +from policies import fpSixteen, bfSixteen_mixed, get_llama_wrapper +from utils.memory_utils import MemoryTrace + + +# Converting Bytes to Megabytes +def byte2mb(x): + return int(x / 2 ** 20) + + +def calc_loss(loss_normal, loss_gumbel, step, alpha=0.8, beta=0.0005, gumbel_multiplier=50.0): + """ + Calculates the loss using the weighted sum of the normal loss and the gumbel loss, + where the weight is reduced over time. 
+ + Args: + loss_normal: normal loss + loss_gumbel: gumbel loss + step: current step + alpha: initial weighting factor for the gumbel loss + beta: controls the rate at which the weighting factor decreases + gumbel_multiplier: simple multiplier for the gumbel loss + + Returns: loss + + """ + #w_t = alpha * torch.exp(torch.tensor(-beta * step)) + w_t = 0.5 + return (1 - w_t) * loss_normal, w_t * loss_gumbel * gumbel_multiplier + + +def train(model, train_dataloader, eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, + train_config, fsdp_config=None, local_rank=None, rank=None): + """ + Trains the model on the given dataloader + + Args: + model: The model to be trained + train_dataloader: The dataloader containing the training data + optimizer: The optimizer used for training + lr_scheduler: The learning rate scheduler + gradient_accumulation_steps: The number of steps to accumulate gradients before performing a backward/update operation + num_epochs: The number of epochs to train for + local_rank: The rank of the current node in a distributed setting + train_config: The training configuration + eval_dataloader: The dataloader containing the eval data + tokenizer: tokenizer used in the eval for decoding the predicitons + + Returns: results dictionary containing average training and validation perplexity and loss + """ + + # Create a gradient scaler for fp16 + if train_config.use_fp16 and train_config.enable_fsdp: + scaler = ShardedGradScaler() + elif train_config.use_fp16 and not train_config.enable_fsdp: + scaler = torch.cuda.amp.GradScaler() + if train_config.enable_fsdp: + world_size = int(os.environ["WORLD_SIZE"]) + train_prep = [] + train_loss = [] + val_prep = [] + val_loss = [] + epoch_times = [] + checkpoint_times = [] + results = {} + best_val_loss = float("inf") + best_train_loss = float("inf") + + if train_config.gumbel: + # a dict to store the activations + gumbel_activations_train = {} + gumbel_activations_dev = {} + for epoch in range(train_config.num_epochs): + epoch_start_time = time.perf_counter() + with MemoryTrace() as memtrace: # track the memory usage + model.train() + if train_config.gumbel: + # a dict to store the activations + gumbel_activation_epoch = {} + def hook_fn(layer, input, output): + gumbel_activation_epoch[layer] = output + + for layer_idx, layer in enumerate(model.model.gumbel_layer_selection): + try: + gumbel_activations_train[epoch][layer_idx] = [] + except KeyError: + gumbel_activations_train[epoch] = {layer_idx:[]} + model.model.gumbel_layer_selection[layer_idx].register_forward_hook(hook_fn) + + total_loss = 0.0 + total_length = len(train_dataloader) // gradient_accumulation_steps + pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch + 1}", total=total_length, dynamic_ncols=True) + for step, batch in enumerate(train_dataloader): + for key in batch.keys(): + if train_config.enable_fsdp: + batch[key] = batch[key].to(local_rank) + else: + batch[key] = batch[key].to('cuda:0') + output = model(**batch) + loss = output.loss + # Gumbel softmax loss with token level + if train_config.gumbel: + global_step = epoch * len(train_dataloader) + step + + with FullyShardedDataParallel.summon_full_params(model): + activation_rate = torch.mean(output.activations) + loss_gumbel = (model.model.gumbel_threshold - activation_rate) ** 2 + + loss_weighted, loss_gumbel_weighted = calc_loss(loss, loss_gumbel, step=global_step, + gumbel_multiplier=train_config.gumbel_loss_multiplier, + alpha=train_config.gumbel_loss_alpha, + 
beta=train_config.gumbel_loss_beta) + loss = loss_weighted + loss_gumbel_weighted + + if train_config.gumbel: + # Track activated layers + for layer_idx, res in enumerate(gumbel_activation_epoch.values()): + # Track activations per layer and instance + for idx, activation in enumerate(res): + gumbel_activations_train[epoch][layer_idx].append(torch.argmax(activation).cpu().numpy().tolist()) + + loss = loss / gradient_accumulation_steps + total_loss += loss.detach().float() + if train_config.use_fp16: + # if fp16 is enabled, use gradient scaler to handle gradient update + scaler.scale(loss).backward() + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_value_(model.parameters(), train_config.gradient_clipping_value) + + scaler.step(optimizer) + scaler.update() + optimizer.zero_grad() + pbar.update(1) + else: + # regular backpropagation when fp16 is not used + loss.backward() + + if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1: + torch.nn.utils.clip_grad_value_(model.parameters(), train_config.gradient_clipping_value) + optimizer.step() + optimizer.zero_grad() + pbar.update(1) + + pbar.set_description( + f"Training Epoch: {epoch + 1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss})") # .detach().float() + + if loss < best_train_loss: + best_train_loss = loss.clone().detach() + + model.save_pretrained(train_config.output_dir) + tokenizer.save_pretrained(train_config.output_dir) + torch.save(model.state_dict(), os.path.join(train_config.output_dir, "model.pt")) + pbar.close() + + epoch_end_time = time.perf_counter() - epoch_start_time + epoch_times.append(epoch_end_time) + # Reducing total_loss across all devices if there's more than one CUDA device + if torch.cuda.device_count() > 1 and train_config.enable_fsdp: + dist.all_reduce(total_loss, op=dist.ReduceOp.SUM) + train_epoch_loss = total_loss / len(train_dataloader) + if train_config.enable_fsdp: + train_epoch_loss = train_epoch_loss / world_size + train_perplexity = torch.exp(train_epoch_loss) + + train_prep.append(train_perplexity) + train_loss.append(train_epoch_loss) + + # Update the learning rate as needed + lr_scheduler.step() + + if train_config.run_validation: + if train_config.gumbel: + eval_ppl, eval_epoch_loss, gumbel_activations_eval = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) + gumbel_activations_dev[epoch] = gumbel_activations_eval + else: + eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer) + checkpoint_start_time = time.perf_counter() + if train_config.save_model and eval_epoch_loss < best_val_loss: + if train_config.enable_fsdp: + dist.barrier() + if train_config.use_peft: + model.save_pretrained(train_config.output_dir) + + else: + if not train_config.enable_fsdp: + # Use for freezing setup + save_model_checkpoint_nofsdp( + model, train_config, epoch=epoch + ) + elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT: + + save_model_checkpoint( + model, optimizer, train_config, rank, epoch=epoch + ) + elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT: + save_model_and_optimizer_sharded(model, rank, train_config) + if train_config.save_optimizer: + save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer) + + if not train_config.use_peft and train_config.save_optimizer: + if 
not train_config.enable_fsdp:
+                            # FSDP disabled: use the non-FSDP optimizer checkpoint helper,
+                            # mirroring the model checkpoint logic above
+                            save_optimizer_checkpoint_nofsdp(
+                                optimizer, train_config, epoch=epoch
+                            )
+                        else:
+                            save_optimizer_checkpoint(
+                                model, optimizer, rank, train_config, epoch=epoch
+                            )
+                if train_config.enable_fsdp:
+                    dist.barrier()
+            checkpoint_end_time = time.perf_counter() - checkpoint_start_time
+            checkpoint_times.append(checkpoint_end_time)
+            if eval_epoch_loss < best_val_loss:
+                best_val_loss = eval_epoch_loss
+                if train_config.enable_fsdp:
+                    if rank == 0:
+                        print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+                else:
+                    print(f"best eval loss on epoch {epoch + 1} is {best_val_loss}")
+            val_loss.append(best_val_loss)
+            val_prep.append(eval_ppl)
+        if train_config.enable_fsdp:
+            if rank == 0:
+                print(
+                    f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+        else:
+            print(
+                f"Epoch {epoch + 1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0
+    avg_train_prep = sum(train_prep) / len(train_prep)
+    avg_train_loss = sum(train_loss) / len(train_loss)
+    if train_config.run_validation:
+        avg_eval_prep = sum(val_prep) / len(val_prep)
+        avg_eval_loss = sum(val_loss) / len(val_loss)
+
+    results['avg_train_prep'] = avg_train_prep
+    results['avg_train_loss'] = avg_train_loss
+    if train_config.run_validation:
+        results['avg_eval_prep'] = avg_eval_prep
+        results['avg_eval_loss'] = avg_eval_loss
+    results["avg_epoch_time"] = avg_epoch_time
+    results["avg_checkpoint_time"] = avg_checkpoint_time
+
+    # saving the training params including fsdp setting for reference.
+ if train_config.enable_fsdp and not train_config.use_peft: + save_train_params(train_config, fsdp_config, rank) + + # After training, save the model: + # use a barrier to make sure training is done on all ranks + dist.barrier() + states = model.state_dict() + + if train_config.gumbel: + results["gumbel_activations_train"] = gumbel_activations_train + results["gumbel_activations_dev"] = gumbel_activations_dev + + return results + + +def evaluation(model, train_config, eval_dataloader, local_rank, tokenizer): + """ + Evaluates the model on the given dataloader + + Args: + model: The model to evaluate + eval_dataloader: The dataloader containing the evaluation data + local_rank: The rank of the current node in a distributed setting + tokenizer: The tokenizer used to decode predictions + + Returns: eval_ppl, eval_epoch_loss + """ + if train_config.enable_fsdp: + world_size = int(os.environ["WORLD_SIZE"]) + model.eval() + eval_preds = [] + eval_loss = 0.0 # Initialize evaluation loss + if train_config.gumbel: + # a dict to store the activations + gumbel_activations_eval = {} + with MemoryTrace() as memtrace: + if train_config.gumbel: + # a dict to store the activations + gumbel_activation_epoch = {} + def hook_fn(layer, input, output): + gumbel_activation_epoch[layer] = output + + for layer_idx, layer in enumerate(model.model.gumbel_layer_selection): + try: + gumbel_activations_eval[layer_idx] = [] + except KeyError: + gumbel_activations_eval = {layer_idx:[]} + model.model.gumbel_layer_selection[layer_idx].register_forward_hook(hook_fn) + for step, batch in enumerate( + tqdm(eval_dataloader, colour="green", desc="evaluating Epoch", dynamic_ncols=True)): + for key in batch.keys(): + if train_config.enable_fsdp: + batch[key] = batch[key].to(local_rank) + else: + batch[key] = batch[key].to('cuda:0') + # Ensure no gradients are computed for this scope to save memory + with torch.no_grad(): + # Forward pass and compute loss + outputs = model(**batch) + loss = outputs.loss + eval_loss += loss.detach().float() + # Decode predictions and add to evaluation predictions list + preds = torch.argmax(outputs.logits, -1) + eval_preds.extend( + tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True) + ) + if train_config.gumbel: + # Track activated layers + for layer_idx, res in enumerate(gumbel_activation_epoch.values()): + # Track activations per layer and instance + for idx, activation in enumerate(res): + gumbel_activations_eval[layer_idx].append(torch.argmax(activation).cpu().numpy().tolist()) + + # If there's more than one CUDA device, reduce evaluation loss across all devices + if torch.cuda.device_count() > 1 and train_config.enable_fsdp: + dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM) + + # Compute average loss and perplexity + eval_epoch_loss = eval_loss / len(eval_dataloader) + if train_config.enable_fsdp: + eval_epoch_loss = eval_epoch_loss / world_size + eval_ppl = torch.exp(eval_epoch_loss) + + # Print evaluation metrics + if train_config.enable_fsdp: + if local_rank == 0: + print(f" {eval_ppl=} {eval_epoch_loss=}") + else: + print(f" {eval_ppl=} {eval_epoch_loss=}") + + if train_config.gumbel: + return eval_ppl, eval_epoch_loss, gumbel_activations_eval + + return eval_ppl, eval_epoch_loss + + +def freeze_transformer_layers(model, num_layer): + for i, layer in enumerate(model.model.layers): + if i < num_layer: + for param in layer.parameters(): + param.requires_grad = False + + +def check_frozen_layers_peft_model(model): + for i, layer in 
enumerate(model.base_model.model.model.layers):
+        for name, param in layer.named_parameters():
+            print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
+
+
+def setup():
+    """Initialize the process group for distributed training"""
+    dist.init_process_group("nccl")
+
+
+def setup_environ_flags(rank):
+    """Set environment flags for debugging purposes"""
+    os.environ["TORCH_SHOW_CPP_STACKTRACES"] = str(1)
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = str(1)
+    if rank == 0:
+        print(f"--> Running with torch dist debug set to detail")
+
+
+def cleanup():
+    """Clean up the process group after training"""
+    dist.destroy_process_group()
+
+
+def clear_gpu_cache(rank=None):
+    """Clear the GPU cache for all ranks"""
+    if rank == 0:
+        print(f"Clearing GPU cache for all ranks")
+    torch.cuda.empty_cache()
+
+
+def get_parameter_dtypes(model):
+    """Get the data types of model parameters"""
+    parameter_dtypes = {}
+    for name, parameter in model.named_parameters():
+        parameter_dtypes[name] = parameter.dtype
+    return parameter_dtypes
+
+
+def print_model_size(model, config, rank: int = 0) -> None:
+    """
+    Print the model name and its number of trainable parameters.
+
+    Args:
+        model: The PyTorch model.
+        config: The training configuration providing the model name (config.model_name).
+        rank (int, optional): Current process's rank. Defaults to 0.
+    """
+    if rank == 0:
+        print(f"--> Model {config.model_name}")
+        total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        print(f"\n--> {config.model_name} has {total_params / 1e6} Million params\n")
+
+
+def get_policies(cfg, rank):
+    """Get the policies for mixed precision and fsdp wrapping"""
+
+    verify_bfloat_support = (
+        torch.version.cuda
+        and torch.cuda.is_bf16_supported()
+        and packaging.version.parse(torch.version.cuda).release >= (11, 0)
+        and dist.is_nccl_available()
+        and nccl.version() >= (2, 10)
+    )
+
+    mixed_precision_policy = None
+    wrapping_policy = None
+
+    # Mixed precision
+    if cfg.mixed_precision:
+        bf16_ready = verify_bfloat_support
+
+        if bf16_ready and not cfg.use_fp16:
+            mixed_precision_policy = bfSixteen_mixed
+            if rank == 0:
+                print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
+        elif cfg.use_fp16:
+            mixed_precision_policy = fpSixteen
+            if rank == 0:
+                print(f"FP16 enabled")
+        else:
+            print(f"bFloat16 support not present. Using FP32, and not mixed precision")
+    wrapping_policy = get_llama_wrapper()
+    return mixed_precision_policy, wrapping_policy
+
+
+def save_train_params(train_config, fsdp_config, rank):
+    """
+    This function saves the train_config and FSDP config into a train_params.yaml file.
+    That file is used by the converter script in the inference folder to fetch the HF model name or path.
+    It is also helpful as a log for future reference.
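+
+    The resulting file is written to
+    <dist_checkpoint_root_folder>/<dist_checkpoint_folder>-<model_name>/train_params.yaml,
+    mirroring the folder layout used for the FSDP checkpoints.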
+ """ + # Convert the train_config and fsdp_config objects to dictionaries, + # converting all values to strings to ensure they can be serialized into a YAML file + train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')} + fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')} + # Merge the two dictionaries into one + train_params_dict = {**train_config_dict, **fsdp_config_dict} + # Construct the folder name (follwoing FSDP checkpointing style) using properties of the train_config object + model_name_local = train_config.model_name.split('/')[-1] + folder_name = ( + train_config.dist_checkpoint_root_folder + + "/" + + train_config.dist_checkpoint_folder + + "-" + + model_name_local + ) + + save_dir = folder_name # Path.cwd() / + # If the directory does not exist, create it + if not os.path.exists(save_dir): + os.makedirs(save_dir) + # Convert the dictionary to a YAML string + config_yaml = yaml.dump(train_params_dict, indent=4) + file_name = os.path.join(save_dir, 'train_params.yaml') + + # Check if there's a directory with the same name as the file + if os.path.isdir(file_name): + print(f"Error: {file_name} is a directory, not a file.") + else: + # Write the YAML string to the file + with open(file_name, 'w') as f: + f.write(config_yaml) + if rank == 0: + print(f"training params are saved in {file_name}")