From 0031d554be2edb20c775ba61a05f60ac9b044793 Mon Sep 17 00:00:00 2001
From: Ali Alshaarawy <45029495+ali-alshaar7@users.noreply.github.com>
Date: Mon, 27 Jan 2025 14:30:53 -0500
Subject: [PATCH] Add Deepseek r1 distill llama models (#1922)

Co-authored-by: Ali Alshaarawy
---
 README.md                           |  1 +
 litgpt/config.py                    | 47 +++++++++++++++++++++++++++++
 tests/test_model.py                 |  2 ++
 tutorials/download_model_weights.md |  3 ++
 4 files changed, 53 insertions(+)

diff --git a/README.md b/README.md
index c9bf8339be..efdb1efbf5 100644
--- a/README.md
+++ b/README.md
@@ -142,6 +142,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
 | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
+| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
 | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
diff --git a/litgpt/config.py b/litgpt/config.py
index 0613a1929a..7106ea581d 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -2400,5 +2400,52 @@ def norm_class(self) -> Type:
         copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
         configs.append(copy)
 
+###############
+# DeepSeek R1 Distill
+###############
+
+r1_distill_llama = [
+    # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-8B/blob/main/config.json
+    dict(
+        name="R1-Distill-Llama-8B",
+        hf_config=dict(org="deepseek-ai", name="DeepSeek-R1-Distill-Llama-8B"),
+        block_size=131072,
+        vocab_size=128000,
+        padded_vocab_size=128256,
+        n_layer=32,
+        n_head=32,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=14336,
+        rope_base=500000,
+        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
+    ),
+    # https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B/blob/main/config.json
+    dict(
+        name="R1-Distill-Llama-70B",
+        hf_config=dict(org="deepseek-ai", name="DeepSeek-R1-Distill-Llama-70B"),
+        block_size=131072,
+        vocab_size=128000,
+        padded_vocab_size=128256,
+        n_layer=80,
+        n_head=64,
+        n_embd=8192,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=28672,
+        rope_base=500000,
+        rope_adjustments=dict(factor=8.0, low_freq_factor=1.0, high_freq_factor=4.0, original_max_seq_len=8192)
+    ),
+]
+
+configs.extend(r1_distill_llama)
 
 name_to_config = {config["name"]: config for config in configs}
diff --git a/tests/test_model.py b/tests/test_model.py
index a2c9997273..4e5189968d 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -225,6 +225,8 @@ def test_against_original_open_llama_3b(device, dtype):
         {"name": "Llama-3.2-1B"},
         {"name": "Llama-3.2-3B"},
         {"name": "Llama-3.3-70B-Instruct"},
+        {"name": "R1-Distill-Llama-8B"},
+        {"name": "R1-Distill-Llama-70B"},
     ],
 )
 @pytest.mark.parametrize(
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index e93f9d4a91..40335d949a 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -41,6 +41,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 | Qwen2.5 Coder | 0.5B, 1.5B, 3B, 7B, 14B, 32B | Alibaba Group | [Hui, Binyuan et al. 2024](https://arxiv.org/abs/2409.12186) |
 | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
+| R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
 | RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
 | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
@@ -87,6 +88,8 @@ codellama/CodeLlama-7b-Python-hf
 databricks/dolly-v2-12b
 databricks/dolly-v2-3b
 databricks/dolly-v2-7b
+deepseek-ai/DeepSeek-R1-Distill-Llama-8B
+deepseek-ai/DeepSeek-R1-Distill-Llama-70B
 EleutherAI/pythia-1.4b
 EleutherAI/pythia-1.4b-deduped
 EleutherAI/pythia-12b
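
A quick smoke test for the new registrations (a sketch, not part of the patch; it assumes litgpt is installed from a checkout that includes these changes): resolve both names through Config.from_name, which reads the name_to_config table extended above.

    from litgpt.config import Config

    # Resolve the two names registered in litgpt/config.py above and print a few
    # of the hyperparameters carried over from the Llama 3.1 recipe.
    for name in ("R1-Distill-Llama-8B", "R1-Distill-Llama-70B"):
        config = Config.from_name(name)
        print(name, config.n_layer, config.n_head, config.intermediate_size, config.rope_base)

Since both repo IDs are listed in tutorials/download_model_weights.md, the checkpoints should come down with the usual CLI, e.g. `litgpt download deepseek-ai/DeepSeek-R1-Distill-Llama-8B`.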