diff --git a/research/mpt/README.md b/research/mpt/README.md
index a8d63e073f..0d93815e0d 100644
--- a/research/mpt/README.md
+++ b/research/mpt/README.md
@@ -1,42 +1,42 @@
 *LAST UPDATED: 11/24/2023*

-# **Sparse Finetuned LLMs with DeepSparse**
+# **Sparse Fine-Tuned LLMs With DeepSparse**

 DeepSparse has support for performant inference of sparse large language models, starting with Mosaic's MPT and Meta's Llama 2.
-Check out our paper [Sparse Finetuning for Inference Acceleration of Large Language Models](https://arxiv.org/abs/2310.06927)
+Check out our paper [Sparse Fine-tuning for Inference Acceleration of Large Language Models](https://arxiv.org/abs/2310.06927).

 In this research overview, we will discuss:
-1. [Our Sparse Fineuning Research](#sparse-finetuning-research)
-2. [How to try Text Generation with DeepSparse](#try-it-now)
+1. [Our Sparse Fine-Tuning Research](#sparse-fine-tuning-research)
+2. [How to Try Text Generation With DeepSparse](#try-it-now)

-## **Sparse Finetuning Research**
+## **Sparse Fine-Tuning Research**

-We show that MPT-7B and Llama-2-7B can be pruned to ~60% sparsity with INT8 quantization (and 70% sparsity without quantization), with no accuracy drop, using a technique called **Sparse Finetuning**, where we prune the network during the finetuning process.
+We show that MPT-7B and Llama-2-7B can be pruned to ~60% sparsity with INT8 quantization (and 70% sparsity without quantization), with no accuracy drop, using a technique called **Sparse Fine-Tuning**, where we prune the network during the fine-tuning process.

 When running the pruned network with DeepSparse, we can accelerate inference by ~7x over the dense-FP32 baseline!

-### **Sparse Finetuning on Grade-School Math (GSM)**
+### **Sparse Fine-Tuning on Grade-School Math (GSM)**

-Training LLMs consist of two steps. First, the model is pre-trained on a very large corpus of text (typically >1T tokens). Then, the model is adapted for downstream use by continuing training with a much smaller high quality curated dataset. This second step is called finetuning.
+Training LLMs consists of two steps. First, the model is pre-trained on a very large corpus of text (typically >1T tokens). Then, the model is adapted for downstream use by continuing training with a much smaller high-quality curated dataset. This second step is called fine-tuning.

 Fine-tuning is useful for two main reasons:
 1. It can teach the model *how to respond* to input (often called **instruction tuning**).
 2. It can teach the model *new information* (often called **domain adaptation**).

-An example of how domain adaptation is helpful is solving the [Grade-school math (GSM) dataset](https://huggingface.co/datasets/gsm8k). GSM is a set of grade school word problems and a notoriously difficult task for LLMs, as evidenced by the 0% zero-shot accuracy of MPT-7B. By fine-tuning with a very small set of ~7k training examples, however, we can boost the model's accuracy on the test set to 28.2%.
+An example of how domain adaptation helps is solving the [Grade-school math (GSM) dataset](https://huggingface.co/datasets/gsm8k). GSM is a set of grade-school word problems and a notoriously difficult task for LLMs, as evidenced by the 0% zero-shot accuracy of MPT-7B. By fine-tuning with a very small set of ~7k training examples, however, we can boost the model's accuracy on the test set to 28.2%.

-The key insight from [our paper](https://arxiv.org/abs/2310.06927) is that we can prune the network during the finetuning process. We apply [SparseGPT](https://arxiv.org/pdf/2301.00774.pdf) to prune the network after dense finetuning and retrain for 2 epochs with L2 distillation. The result is a 60% sparse-quantized model with no accuracy drop on GSM8k runs 7x faster than the dense baseline with DeepSparse!
+The key insight from [our paper](https://arxiv.org/abs/2310.06927) is that we can prune the network during the fine-tuning process. We apply [SparseGPT](https://arxiv.org/pdf/2301.00774.pdf) to prune the network after dense fine-tuning and retrain for 2 epochs with L2 distillation. The result is a 60% sparse-quantized model with no accuracy drop on GSM8k that runs 7x faster than the dense baseline with DeepSparse!
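For reference alongside this diff, below is a minimal sketch of running text generation on one of these sparse models with DeepSparse's `TextGeneration` pipeline. It is not part of the README changes above: the model identifier is a placeholder, and the exact stubs and arguments should be checked against the README's "Try It Now" section and the DeepSparse documentation.

```python
# Minimal sketch (assumptions): DeepSparse's TextGeneration pipeline API,
# with a placeholder model identifier for the sparse-quantized GSM8k model.
from deepsparse import TextGeneration

# Placeholder model stub -- substitute the actual SparseZoo / Hugging Face
# identifier listed in the "Try It Now" section of this README.
MODEL = "zoo:mpt-7b-gsm8k_mpt_pretrain-pruned60_quantized"

pipeline = TextGeneration(model=MODEL)

# A GSM8k-style word problem as the prompt.
prompt = (
    "Natalia sold clips to 48 of her friends in April, and then she sold "
    "half as many clips in May. How many clips did Natalia sell altogether "
    "in April and May?"
)

# Generate and print the model's answer.
output = pipeline(prompt=prompt, max_new_tokens=128)
print(output.generations[0].text)
```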