From d0b59bdf04c7a7c6d7781d4dd1113aca2ae0ebe0 Mon Sep 17 00:00:00 2001
From: Mark O'Connor
Date: Fri, 7 Feb 2025 13:32:46 +0100
Subject: [PATCH] Add HF model support inc. DS-R1-Distill, Qwen needs yarn support (#17421)

### Problem description

The existing codebase loads the Meta checkpoint format, but many derivative models are only available on Hugging Face.

### What's changed

Add support for loading HuggingFace model formats, paving the way for full Qwen support (pending a YaRN RoPE implementation) and adding DeepSeek-R1-Distill-Llama-70B support.

### Checklist

All passing locally.

- [x] [all-post-commit](https://github.com/tenstorrent/tt-metal/actions/runs/13181023765) - [FIXED] Was failing to load the tokenizer on this pipeline only
- [x] [Single](https://github.com/tenstorrent/tt-metal/actions/runs/13142509908/job/36672984561)
- [x] [Single-demos](https://github.com/tenstorrent/tt-metal/actions/runs/13180995444) - Only failing on N300 performance - Investigating
- [ ] [T3K](https://github.com/tenstorrent/tt-metal/actions/runs/13142519276)
- [x] [Unit](https://github.com/tenstorrent/tt-metal/actions/runs/13163296158/job/36737812258)
- [x] [Model-perf](https://github.com/tenstorrent/tt-metal/actions/runs/13164376159)
- [x] [Frequent-1](https://github.com/tenstorrent/tt-metal/actions/runs/13174954913)
- [x] [Frequent-2](https://github.com/tenstorrent/tt-metal/actions/runs/13164380377/job/36742877847)
- [x] [Demo](https://github.com/tenstorrent/tt-metal/actions/runs/13180986094)
- [x] [TG](https://github.com/tenstorrent/tt-metal/actions/runs/13154035596/job/36707218743) - Pipelines have issues unrelated to these changes.

---------

Signed-off-by: Salar Hosseini
Co-authored-by: mtairum
Co-authored-by: Salar Hosseini
---
README.md | 5 +- models/common/rmsnorm.py | 3 + models/demos/llama3/PERF.md | 93 +- models/demos/llama3/README.md | 27 +- models/demos/llama3/demo/demo.py | 254 ++--- .../demo/input_data_questions_reasoning.json | 20 + .../demos/llama3/demo/simple_vision_demo.py | 2 +- models/demos/llama3/lt | 57 +- models/demos/llama3/requirements.txt | 1 + .../llama3/tests/generate_reference_hf.py | 148 +++ .../tests/generate_reference_outputs.py | 60 +- .../tests/generate_reference_outputs.sh | 27 +- ..._llama_cross_attention_transformer_text.py | 9 +- ...{70b.refpt => Llama3.1-70B-Instruct.refpt} | Bin .../{8b.refpt => Llama3.1-8B-Instruct.refpt} | Bin ...{11b.refpt => Llama3.2-11B-Instruct.refpt} | Bin .../{1b.refpt => Llama3.2-1B-Instruct.refpt} | Bin .../{3b.refpt => Llama3.2-3B-Instruct.refpt} | Bin .../Qwen2.5-72B-Instruct.refpt | Bin 0 -> 50726 bytes .../Qwen2.5-7B-Instruct.refpt | Bin 0 -> 50720 bytes .../tests/test_interleaved_to_sharded.py | 35 +- .../demos/llama3/tests/test_llama_accuracy.py | 41 +- .../llama3/tests/test_llama_attention.py | 28 +- .../tests/test_llama_attention_prefill.py | 14 +- .../demos/llama3/tests/test_llama_decoder.py | 7 +- .../tests/test_llama_decoder_prefill.py | 11 +- .../llama3/tests/test_llama_embedding.py | 8 +- models/demos/llama3/tests/test_llama_mlp.py | 22 +- models/demos/llama3/tests/test_llama_model.py | 114 +-- .../llama3/tests/test_llama_model_prefill.py | 23 +- .../demos/llama3/tests/test_llama_rms_norm.py | 9 +- models/demos/llama3/tests/test_llama_torch.py | 13 +- models/demos/llama3/tests/test_lm_head.py | 3 +- models/demos/llama3/tests/test_ref.py | 104 ++ models/demos/llama3/tt/generator_vllm.py | 2 +- models/demos/llama3/tt/llama_attention.py | 139 ++- models/demos/llama3/tt/llama_ccl.py | 8 +-
models/demos/llama3/tt/llama_common.py | 95 +- models/demos/llama3/tt/llama_decoder.py | 2 + models/demos/llama3/tt/llama_mlp.py | 75 +- models/demos/llama3/tt/llama_model.py | 14 +- models/demos/llama3/tt/llama_rope.py | 33 +- models/demos/llama3/tt/lm_head.py | 16 +- models/demos/llama3/tt/load_checkpoints.py | 303 ++++++ models/demos/llama3/tt/model_config.py | 897 +++++++++++++----- .../tt/multimodal/llama_cross_attention.py | 2 + .../llama_cross_attention_transformer_text.py | 8 +- ...lama_cross_attention_transformer_vision.py | 14 +- .../llama3/tt/multimodal/llama_image_mlp.py | 14 +- .../tt/multimodal/llama_vision_model.py | 3 +- 50 files changed, 1983 insertions(+), 780 deletions(-) create mode 100644 models/demos/llama3/demo/input_data_questions_reasoning.json mode change 100644 => 100755 models/demos/llama3/lt create mode 100644 models/demos/llama3/tests/generate_reference_hf.py rename models/demos/llama3/tests/reference_outputs/{70b.refpt => Llama3.1-70B-Instruct.refpt} (100%) rename models/demos/llama3/tests/reference_outputs/{8b.refpt => Llama3.1-8B-Instruct.refpt} (100%) rename models/demos/llama3/tests/reference_outputs/{11b.refpt => Llama3.2-11B-Instruct.refpt} (100%) rename models/demos/llama3/tests/reference_outputs/{1b.refpt => Llama3.2-1B-Instruct.refpt} (100%) rename models/demos/llama3/tests/reference_outputs/{3b.refpt => Llama3.2-3B-Instruct.refpt} (100%) create mode 100644 models/demos/llama3/tests/reference_outputs/Qwen2.5-72B-Instruct.refpt create mode 100644 models/demos/llama3/tests/reference_outputs/Qwen2.5-7B-Instruct.refpt create mode 100644 models/demos/llama3/tests/test_ref.py create mode 100644 models/demos/llama3/tt/load_checkpoints.py diff --git a/README.md b/README.md index e4d2c5b951d..817558ebf75 100644 --- a/README.md +++ b/README.md @@ -36,12 +36,13 @@ | [Llama 3.1 70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 190 | 15.1 | 20 | 483.2 | [v0.54.0-rc2](https://github.com/tenstorrent/tt-metal/tree/v0.54.0-rc2) | [9531611](https://github.com/tenstorrent/vllm/tree/953161188c50f10da95a88ab305e23977ebd3750) | | [Falcon 40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.55.0-rc19](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc19) | | | [Mixtral 8x7B (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 227 | 14.9 | 33 | 476.8 | [v0.55.0-rc19](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc19) | | +| [DeepSeek R1 Distill Llama 3.3 70B (TP=8)](./models/demos/llama3) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 1113 | 16.4 | 33 |386.4 | [main](https://github.com/tenstorrent/tt-metal/) | [2f33504](https://github.com/tenstorrent/vllm/tree/2f33504bad49a6202d3685155107a6126a5b5e6e) | | [Falcon 7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 223 | 4.8 | 26 | 4915.2 | [v0.55.0-rc18](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc18) | | | [Llama 3.1 70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 190 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | | | [Llama 3.1 70B (TP=32)](./models/demos/llama3) | 32 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 763 | 13.5 | 80 | 432.0 | [v0.55.0-rc12](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc12) | 
[2f33504](https://github.com/tenstorrent/vllm/tree/2f33504bad49a6202d3685155107a6126a5b5e6e) | -| [DeepSeek R1 Distill Llama 3.3 70B (TP=8)](https://github.com/tenstorrent/tt-metal/tree/hf-llama/models/demos/llama3) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 1113 | 16.4 | 33 |524.8 | [hf-llama](https://github.com/tenstorrent/tt-metal/tree/hf-llama) | [b9564bf](https://github.com/tenstorrent/vllm/tree/b9564bf364e95a3850619fc7b2ed968cc71e30b7) | +| [DeepSeek R1 Distill Llama 3.3 70B (TP=8)](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 1113 | 16.4 | 33 |524.8 | [main](https://github.com/tenstorrent/tt-metal/) | [b9564bf](https://github.com/tenstorrent/vllm/tree/b9564bf364e95a3850619fc7b2ed968cc71e30b7) | -> **Last Update:** January 27, 2025 +> **Last Update:** February 5, 2025 > > **Notes:** > diff --git a/models/common/rmsnorm.py b/models/common/rmsnorm.py index 36f06ea8cc4..28eb9cadf55 100644 --- a/models/common/rmsnorm.py +++ b/models/common/rmsnorm.py @@ -49,10 +49,12 @@ def __init__( eps: float = 1e-05, sharded_program_config=None, sharded_output_config=None, + ccl_topology=ttnn.Topology.Ring, ): super().__init__() self.eps = eps self.is_distributed = is_distributed + self.ccl_topology = ccl_topology if state_dict_prefix: weight_name = f"{state_dict_prefix}{weight_key}.weight" @@ -144,6 +146,7 @@ def _distributed_rmsnorm( tt_stats, dim=3, num_links=1, + topology=self.ccl_topology, memory_config=ttnn.DRAM_MEMORY_CONFIG, ) # Run distributed rmsnorm part 2 diff --git a/models/demos/llama3/PERF.md b/models/demos/llama3/PERF.md index 62ac609d2ce..f0bb11616df 100644 --- a/models/demos/llama3/PERF.md +++ b/models/demos/llama3/PERF.md @@ -4,51 +4,54 @@ Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected f Note that `test_llama_accuracy.py` parses the below to determine expected values +- 0.5. -## LlamaOptimizations.performance +## Performance This configuration uses bfp4 MLP FF1+FF3 for all models. -| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) | -|-------|--------|-----------|-----------|---------------| -| 1b | N150 | 87 | 98 | 91.0 | -| 1b | N300 | 87 | 98 | 98.8 | -| 1b | T3K | 87 | 98 | 97.8 | -| 1b | TG | 88 | 99 | 51.0 | -| 3b | N150 | 90 | 98 | 49.2 | -| 3b | N300 | 90 | 98 | 56.8 | -| 3b | T3K | 88 | 98 | 54.5 | -| 3b | TG | 90 | 97 | 33.5 | -| 8b | N150 | 86 | 99 | 28.6 | -| 8b | N300 | 85 | 98 | 38.9 | -| 8b | T3K | 84 | 97 | 53.7 | -| 8b | TG | 86 | 98 | 29.5 | -| 11b | N300 | 87 | 98 | 38.6 | -| 11b | T3K | 88 | 98 | 52.6 | -| 11b | TG | 86 | 98 | 29.5 | -| 70b | T3K | 95 | 99 | 14.7 | -| 70b | TG | 95 | 100 | 12.7 | - - -## LlamaOptimizations.accuracy - -This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model. 
- -| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) | -|-------|--------|-----------|-----------|---------------| -| 1b | N150 | 89 | 98 | 86.8 | -| 1b | N300 | 88 | 99 | 98.1 | -| 1b | T3K | 86 | 99 | 97.5 | -| 1b | TG | 87 | 98 | 51.3 | -| 3b | N150 | 92 | 100 | 44.2 | -| 3b | N300 | 92 | 99 | 54.2 | -| 3b | T3K | 91 | 98 | 55.6 | -| 3b | TG | 91 | 98 | 33.6 | -| 8b | N150 | 91 | 99 | 23.6 | -| 8b | N300 | 91 | 99 | 34.5 | -| 8b | T3K | 90 | 99 | 49.8 | -| 8b | TG | 88 | 100 | 29.5 | -| 11b | N300 | 91 | 99 | 33.8 | -| 11b | T3K | 91 | 99 | 52.6 | -| 11b | TG | 88 | 100 | 29.5 | -| 70b | T3K | 95 | 99 | 14.7 | -| 70b | TG | 95 | 100 | 12.7 | +| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) | +|----------------|--------|-----------|-----------|---------------| +| Llama3.2-1B | N150 | 89 | 98 | 86.9 | +| Llama3.2-1B | N300 | 91 | 98 | 104.3 | +| Llama3.2-1B | T3K | 91 | 98 | 118.5 | +| Llama3.2-1B | TG | | | 72.3 | +| Llama3.2-3B | N150 | 92 | 96 | 53.3 | +| Llama3.2-3B | N300 | 91 | 96 | 66.1 | +| Llama3.2-3B | T3K | 91 | 96 | 66.9 | +| Llama3.2-3B | TG | | | 48.5 | +| Llama3.1-8B | N150 | 87 | 99 | 27.9 | +| Llama3.1-8B | N300 | 88 | 99 | 43.7 | +| Llama3.1-8B | T3K | 91 | 100 | 64.2 | +| Llama3.1-8B | TG | | | 41.0 | +| Llama3.2-11B | N300 | 89 | 99 | 43.5 | +| Llama3.2-11B | T3K | 88 | 99 | 63.4 | +| Llama3.2-11B | TG | | | 40.9 | +| Llama3.1-70B | T3K | 96 | 100 | 16.1 | +| Llama3.1-70B | TG | | | | +| Qwen2.5-7B | N300 | 81 | 96 | 37.9 | +| Qwen2.5-72B | T3K | 99 | 100 | 12.8 | + +## Accuracy + +This configuration uses bfp4 MLP FF1+FF3 only for the Llama-3.1-70B model and the Qwen-2.5-72B model. + +| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) | +|----------------|--------|-----------|-----------|---------------| +| Llama3.2-1B | N150 | 88 | 98 | 86.8 | +| Llama3.2-1B | N300 | 90 | 98 | 98.1 | +| Llama3.2-1B | T3K | 90 | 98 | 97.5 | +| Llama3.2-1B | TG | 87 | 98 | 51.3 | +| Llama3.2-3B | N150 | 93 | 99 | 44.2 | +| Llama3.2-3B | N300 | 92 | 98 | 54.2 | +| Llama3.2-3B | T3K | 93 | 98 | 55.6 | +| Llama3.2-3B | TG | 91 | 98 | 33.6 | +| Llama3.1-8B | N150 | 93 | 100 | 23.6 | +| Llama3.1-8B | N300 | 93 | 100 | 34.5 | +| Llama3.1-8B | T3K | 92 | 100 | 49.8 | +| Llama3.1-8B | TG | 88 | 100 | 29.5 | +| Llama3.2-11B | N300 | 93 | 100 | 33.8 | +| Llama3.2-11B | T3K | 94 | 100 | 52.6 | +| Llama3.2-11B | TG | 88 | 100 | 29.5 | +| Llama3.1-70B | T3K | 97 | 100 | 14.7 | +| Llama3.1-70B | TG | 95 | 100 | 12.7 | +| Qwen2.5-7B | N300 | 81 | 96 | 33.4 | +| Qwen2.5-72B | T3K | 99 | 100 | 12.8 | diff --git a/models/demos/llama3/README.md b/models/demos/llama3/README.md index b64f4739a90..65d370e4a5b 100644 --- a/models/demos/llama3/README.md +++ b/models/demos/llama3/README.md @@ -8,6 +8,7 @@ The current version supports the following Llama3 models: - Llama3.1-8B - Llama3.2-11B - Llama3.1-70B (T3000 and TG-only) +- DeepSeek R1 Distill Llama 3.3 70B (T3000 and TG-only) All the above llama models (with the exception of 70B due to its large size) are compatible and tested on the following Tenstorrent hardware: - N150 (1-chip) @@ -25,13 +26,15 @@ Max Prefill Chunk Sizes (text-only): | Llama3.1-8B | 4k tokens | 64k tokens | 128k tokens | 128k tokens | | Llama3.2-11B | 4k tokens | 64k tokens | 128k tokens | 128k tokens | | Llama3.1-70B | Not supported | Not supported | 32k tokens | 128k tokens | +| DeepSeek-R1-Distill-Llama3.3-70B | Not supported | Not supported | 32k tokens | 128k tokens | + - These max chunk sizes are specific to max context length 128k and are configured via 
`MAX_PREFILL_CHUNK_SIZES_DIV1024` in [model_config.py](https://github.com/tenstorrent/tt-metal/blob/main/models/demos/llama3/tt/model_config.py). If the max context length is set to a smaller value using the `max_seq_len` flag (see [Run the demo](#run-the-demo)), these chunk sizes can possibly be increased due to using a smaller KV cache. **Max Context Lengths (Llama3.2-11B multimodal)**: Llama3.2-11B multimodal is currently only supported on N300 and T3000. On N300, a max prefill context length of 8k is supported, while T3000 supports a max context length of 128k. ## How to Run -### Download the weights +### Llama models: download the weights Download the weights [directly from Meta](https://llama.meta.com/llama-downloads/), this will mean accepting their license terms. @@ -59,17 +62,33 @@ Llama3.2-11B multimodal requires extra python dependencies. Install them from: pip install -r models/demos/llama3/requirements.txt ``` +### HuggingFace models (e.g. DeepSeek R1 Distill Llama 3.3 70B) + +Download the weights from [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B). Your model directory should have the following structure: + +``` +DeepSeek-R1-Distill-Llama-70B/ + config.json + generation_config.json + model-00001-of-00062.safetensors + ... +``` + ### Setup TT environment 1. Set up environment variables: ``` -export LLAMA_DIR= +export LLAMA_DIR= +``` + +On N150, N300 and T3K: +``` export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml ``` - `$LLAMA_DIR` sets the path for the Llama3 model weights and caches. -- `$WH_ARCH_YAML` sets the dispatch over ethernet cores. This is optional for N150 and required for N300 and T3000, enabling a full core grid utilization (8x8), allowing for maximum performance of LLama3 models. +- `$WH_ARCH_YAML` sets the dispatch over ethernet cores. This is optional for N150 and required for N300 and T3000, enabling a full core grid utilization (8x8), allowing for maximum performance of LLama3 models. Do not set this for TG. On the first execution of each model, TTNN will create weight cache files for that model, to speed up future runs. These cache files only need to be created once for each model and each weight (i.e. new finetuned weights will need to be cached) and will be stored accordingly to the machine you are running the models: @@ -80,7 +99,6 @@ $LLAMA_DIR/T3K # For T3000 $LLAMA_DIR/TG # For TG ``` - ### Run the demo The Llama3 demo includes 3 main modes of operation and is fully parametrized to support other configurations. @@ -88,6 +106,7 @@ The Llama3 demo includes 3 main modes of operation and is fully parametrized to - `batch-1`: Runs a small prompt for a single user - `batch-32`: Runs a small prompt for a a batch of 32 users - `long-context`: Runs a large prompt (64k tokens) for a single user +- `reasoning-1`: Runs a reasoning prompt for a single user If you want to provide your own demo configuration, please take a look at the pytest parametrize calls in `models/demos/llama3/demo/demo.py`. 
For convenience we list all the supported params below: diff --git a/models/demos/llama3/demo/demo.py b/models/demos/llama3/demo/demo.py index a0b09e4dae1..21aea65fb6b 100644 --- a/models/demos/llama3/demo/demo.py +++ b/models/demos/llama3/demo/demo.py @@ -15,22 +15,17 @@ from pathlib import Path import hashlib -from models.utility_functions import nearest_32 from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, - get_rot_transformation_mat, - HostEmbedding, - encode_prompt_llama_instruct, PagedAttentionConfig, sample_host, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.tt.model_config import TtModelArgs from models.perf.benchmarking_utils import BenchmarkProfiler -from models.demos.utils.llm_demo_utils import create_benchmark_data, verify_perf +from models.demos.utils.llm_demo_utils import create_benchmark_data from models.demos.llama3.tt.model_config import LlamaOptimizations @@ -108,10 +103,7 @@ def preprocess_inputs_prefill( if max_prefill_len == 128 * 1024: max_prefill_len = 128 * 1024 - max_generated_tokens - if instruct: - encoded_prompts = [encode_prompt_llama_instruct(tokenizer, prompt) for prompt in input_prompts] - else: - encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in input_prompts] + encoded_prompts = [model_args.encode_prompt(prompt, instruct=instruct) for prompt in input_prompts] # Print the length of encoded prompts logger.info("Encoded prompt lengths:" + ", ".join(str(len(prompt)) for prompt in encoded_prompts)) @@ -122,14 +114,26 @@ def preprocess_inputs_prefill( # The large input demo we provide contains more tokens than the maximum (32k tokens) # To avoid running out of memory, clip to max_prefill_len + if min_prompt_len > max_prefill_len: - logger.info(f"Clipping prompts to {max_prefill_len}") - if instruct: # When clipping, make sure to add the ` 】 token at the end (4 tokens) - encoded_prompts = [encod[: max_prefill_len - 4] for encod in encoded_prompts] - dec_prompts = [tokenizer.decode(encod) + " 】" for encod in encoded_prompts] - encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in dec_prompts] + logger.info(f"Left-clipping prompts to {max_prefill_len}") + if instruct: + # We need to allow a few tokens for the system prompt and the special turn tokens for assistant and user; + # to find out how big those will be, we will: + # 1. Tokenize the entire prompt with non-instruct tokenization + # 2. Calculate overhead = length of instruct tokenization - length of non-instruct tokenization + # 3. Shorten the tokenized clipped prompt by the overhead and convert back to text + # 4. Tokenize the result with instruct tokenization + # 5. 
Assert that the length of this is equal to the max_prefill_len + raw_prompts = [model_args.encode_prompt(prompt, instruct=False) for prompt in input_prompts] + overhead = [len(e) - len(r) for e, r in zip(encoded_prompts, raw_prompts)] + shortened = [tokenizer.decode(e[-(max_prefill_len - o) :]) for e, o in zip(raw_prompts, overhead)] + encoded_prompts = [model_args.encode_prompt(prompt, instruct=instruct) for prompt in shortened] + assert all( + len(e) == max_prefill_len for e in encoded_prompts + ), f"Clipped prompts are not of the correct length, expected {max_prefill_len} but got {[len(e) for e in encoded_prompts]}" else: - encoded_prompts = [encod[:max_prefill_len] for encod in encoded_prompts] + encoded_prompts = [encod[-max_prefill_len:] for encod in encoded_prompts] # Update prompt lengths prompt_lens = [len(x) for x in encoded_prompts] @@ -227,20 +231,20 @@ def run_llama3_demo( max_seq_len=max_seq_len, ) - tokenizer = Tokenizer(model_args.tokenizer_path) + tokenizer = model_args.tokenizer # Check max sequence length compatibility with model and architecture. Refer to README for more information - llama_model_name = model_args.model_name # ["3.2-1B", "3.2-3B", "3.1-8B", "3.2-11B", "3.1-70B"] + llama_model_name = model_args.base_model_name # ["3.2-1B", "3.2-3B", "3.1-8B", "3.2-11B", "3.1-70B"] tt_device_name = model_args.device_name # ["N150", "N300", "T3K", "TG"] - if llama_model_name in ["3.1-8B", "3.2-11B"] and tt_device_name == "N150": + if llama_model_name in ["Llama3.1-8B", "Llama3.2-11B"] and tt_device_name == "N150": assert ( max_seq_len <= 64 * 1024 ), "N150 only supports a max context length of 64k tokens for Llama3.1-8B and Llama3.2-11B" else: - assert max_seq_len <= 128 * 1024, f"Llama{llama_model_name} supports a max context length of 128k tokens" + assert max_seq_len <= 128 * 1024, f"{llama_model_name} supports a max context length of 128k tokens" - if llama_model_name == "3.1-70B": + if llama_model_name == "Llama3.1-70B": assert tt_device_name in ["T3K", "TG"], "Llama3.1-70B is only supported on T3K or TG" logger.info("Loading weights...") @@ -284,7 +288,7 @@ def run_llama3_demo( state_dict=state_dict, dtype=ttnn.bfloat16, # Row major layout requires bfloat16 ) - embd = HostEmbedding(model_args) + embd = model_args.reference_embedding() state_dict_prefix = model_args.get_state_dict_prefix("", None) embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) profiler.end("loading_weights_to_device") @@ -340,8 +344,10 @@ def run_llama3_demo( model_args.head_dim, model_args.max_seq_len, mesh_device, - seq_len=prefill_seq_len, - scale_factor=model_args.rope_scaling_factor, + prefill_seq_len, + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, ) if decoding_pos[batch_id] < prefill_seq_len: pt_prefill_input[batch_id][ @@ -483,10 +489,15 @@ def run_llama3_demo( if tt_model.args.num_devices > 1: if tt_model.args.is_galaxy: tt_out_gathered = ttnn.all_gather( - tt_out, dim=3, num_links=2, cluster_axis=0, mesh_device=mesh_device, topology=ttnn.Topology.Linear + tt_out, + dim=3, + num_links=2, + cluster_axis=0, + mesh_device=mesh_device, + topology=model_args.ccl_topology(), ) else: - tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=model_args.ccl_topology()) ttnn.deallocate(tt_out) else: tt_out_gathered = tt_out @@ -527,10 +538,15 @@ def run_llama3_demo( if tt_model.args.num_devices > 1: if 
tt_model.args.is_galaxy: tt_out_gathered = ttnn.all_gather( - tt_out, dim=3, num_links=2, cluster_axis=0, mesh_device=mesh_device, topology=ttnn.Topology.Linear + tt_out, + dim=3, + num_links=2, + cluster_axis=0, + mesh_device=mesh_device, + topology=model_args.ccl_topology(), ) else: - tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=ttnn.Topology.Linear) + tt_out_gathered = ttnn.all_gather(tt_out, dim=3, num_links=1, topology=model_args.ccl_topology()) ttnn.deallocate(tt_out) else: tt_out_gathered = tt_out @@ -550,13 +566,15 @@ def run_llama3_demo( current_pos_reset = ttnn.from_torch( current_pos, dtype=ttnn.int32, - mesh_mapper=ttnn.ShardTensor2dMesh( - mesh_device, - dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), - mesh_shape=model_args.cluster_shape, - ) - if tt_model.args.num_devices > 1 - else None, + mesh_mapper=( + ttnn.ShardTensor2dMesh( + mesh_device, + dims=(None, 0) if (model_args.is_galaxy and batch_size > 1) else (None, None), + mesh_shape=model_args.cluster_shape, + ) + if tt_model.args.num_devices > 1 + else None + ), ) tt_out_tok_reset = ttnn.from_torch( torch.nn.functional.pad( @@ -629,8 +647,8 @@ def run_llama3_demo( for user in range(batch_size): user_tok = tt_output_torch[user].tolist() if ( - user_tok != 128009 and user_done[user] == False - ): # Stop saving the ouput after hitting the eos token (<|eot_id|>) (128009) + user_tok not in tokenizer.stop_tokens and user_done[user] == False + ): # Read until an eos token (e.g. <|eot_id|>); create_tokenizer adds stop_tokens to HF tokenizers all_outputs[user].append(user_tok) else: user_done[user] = True @@ -680,14 +698,10 @@ def run_llama3_demo( profiler.start(f"log_saving_file", iteration=batch_idx) for i, (output, prompt) in enumerate(zip(all_outputs, input_prompts)): text = tokenizer.decode(output) - if instruct_mode: - split_text = text.split("<|start_header_id|>assistant<|end_header_id|>", 1) - else: - split_text = text.split(prompt, 1) - if len(split_text) > 1: - text_after_prompt = split_text[1] - else: - text_after_prompt = text # If prompt is not found, use the whole text + prompt_including_assistant_tags = tokenizer.decode( + model_args.encode_prompt(prompt, instruct=instruct_mode) + ) + text_after_prompt = text.replace(prompt_including_assistant_tags, "", 1) if print_to_file: with open(output_filename, "a") as f: f.write( @@ -770,76 +784,78 @@ def run_llama3_demo( ) logger.info("") - supported_models = ["3.2-1B", "3.2-3B", "3.1-8B", "3.2-11B", "3.1-70B"] + supported_models = ["Llama3.2-1B", "Llama3.2-3B", "Llama3.1-8B", "Llama3.2-11B", "Llama3.1-70B"] supported_devices = ["N150", "N300", "T3K", "TG"] # TODO update targets based on the llama3 model and the target device - llama_model_name = model_args.model_name tt_device_name = model_args.device_name - assert llama_model_name in supported_models, f"Model {llama_model_name} not supported" - assert tt_device_name in supported_devices, f"Device {tt_device_name} not supported" - - # Set the target times to first token for every combination of device and model - target_prefill_tok_s = { - "N150_3.2-1B": 1050, # TODO Update target - "N300_3.2-1B": 1050, # TODO Update target - "T3K_3.2-1B": 1050, # TODO Update target - "TG_3.2-1B": 1050, # TODO Update target - # - "N150_3.2-3B": 1050, # TODO Update target - "N300_3.2-3B": 1050, # TODO Update target - "T3K_3.2-3B": 1050, # TODO Update target - "TG_3.2-3B": 1050, # TODO Update target - # - "N150_3.1-8B": 1050, - "N300_3.1-8B": 1050, - "T3K_3.1-8B": 1050, - "TG_3.1-8B": 
1050, - # - "N150_3.2-11B": 1050, # TODO Update target - "N300_3.2-11B": 1050, # TODO Update target - "T3K_3.2-11B": 1050, # TODO Update target - "TG_3.2-11B": 1050, # TODO Update target - # - "N150_3.1-70B": 1050, # TODO Update target - "N300_3.1-70B": 1050, # TODO Update target - "T3K_3.1-70B": 1050, # TODO Update target - "TG_3.1-70B": 1050, # TODO Update target - }[f"{tt_device_name}_{llama_model_name}"] - - # Set the target decode timesfor every combination of device and model - target_decode_tok_s_u = { - "N150_3.2-1B": 160, # TODO Update target - "N300_3.2-1B": 250, # TODO Update target - "T3K_3.2-1B": 300, # TODO Update target - "TG_3.2-1B": 300, # TODO Update target - # - "N150_3.2-3B": 60, # TODO Update target - "N300_3.2-3B": 100, # TODO Update target - "T3K_3.2-3B": 150, # TODO Update target - "TG_3.2-3B": 150, # TODO Update target - # - "N150_3.1-8B": 23, # TODO Update target - "N300_3.1-8B": 38, - "T3K_3.1-8B": 45, - "TG_3.1-8B": 45, # TODO Update target - # - "N150_3.2-11B": 23, - "N300_3.2-11B": 38, # TODO Update target - "T3K_3.2-11B": 45, # TODO Update target - "TG_3.2-11B": 45, # TODO Update target - # - "T3K_3.1-70B": 20, # TODO Update target - "TG_3.1-70B": 20, # TODO Update target - }[f"{tt_device_name}_{llama_model_name}"] - - target_decode_tok_s = target_decode_tok_s_u * batch_size - targets = { - "prefill_t/s": target_prefill_tok_s, - "decode_t/s": target_decode_tok_s, - "decode_t/s/u": target_decode_tok_s_u, - } + if model_args.base_model_name in supported_models: + assert tt_device_name in supported_devices, f"Device {tt_device_name} not supported" + + # Set the target times to first token for every combination of device and model + target_prefill_tok_s = { + "N150_Llama3.2-1B": 1050, # TODO Update target + "N300_Llama3.2-1B": 1050, # TODO Update target + "T3K_Llama3.2-1B": 1050, # TODO Update target + "TG_Llama3.2-1B": 1050, # TODO Update target + # + "N150_Llama3.2-3B": 1050, # TODO Update target + "N300_Llama3.2-3B": 1050, # TODO Update target + "T3K_Llama3.2-3B": 1050, # TODO Update target + "TG_Llama3.2-3B": 1050, # TODO Update target + # + "N150_Llama3.1-8B": 1050, + "N300_Llama3.1-8B": 1050, + "T3K_Llama3.1-8B": 1050, + "TG_Llama3.1-8B": 1050, + # + "N150_Llama3.2-11B": 1050, # TODO Update target + "N300_Llama3.2-11B": 1050, # TODO Update target + "T3K_Llama3.2-11B": 1050, # TODO Update target + "TG_Llama3.2-11B": 1050, # TODO Update target + # + "N150_Llama3.1-70B": 1050, # TODO Update target + "N300_Llama3.1-70B": 1050, # TODO Update target + "T3K_Llama3.1-70B": 1050, # TODO Update target + "TG_Llama3.1-70B": 1050, # TODO Update target + }[f"{tt_device_name}_{model_args.base_model_name}"] + + # Set the target decode timesfor every combination of device and model + target_decode_tok_s_u = { + "N150_Llama3.2-1B": 160, # TODO Update target + "N300_Llama3.2-1B": 250, # TODO Update target + "T3K_Llama3.2-1B": 300, # TODO Update target + "TG_Llama3.2-1B": 300, # TODO Update target + # + "N150_Llama3.2-3B": 60, # TODO Update target + "N300_Llama3.2-3B": 100, # TODO Update target + "T3K_Llama3.2-3B": 150, # TODO Update target + "TG_Llama3.2-3B": 150, # TODO Update target + # + "N150_Llama3.1-8B": 23, # TODO Update target + "N300_Llama3.1-8B": 38, + "T3K_Llama3.1-8B": 45, + "TG_Llama3.1-8B": 45, # TODO Update target + # + "N150_Llama3.2-11B": 23, + "N300_Llama3.2-11B": 38, # TODO Update target + "T3K_Llama3.2-11B": 45, # TODO Update target + "TG_Llama3.2-11B": 45, # TODO Update target + # + "T3K_Llama3.1-70B": 20, # TODO Update target + "TG_Llama3.1-70B": 20, # 
TODO Update target + }[f"{tt_device_name}_{model_args.base_model_name}"] + + target_decode_tok_s = target_decode_tok_s_u * batch_size + targets = { + "prefill_t/s": target_prefill_tok_s, + "decode_t/s": target_decode_tok_s, + "decode_t/s/u": target_decode_tok_s_u, + } + else: + logger.warning(f"Model {model_args.base_model_name} not does not have performance targets set") + targets = {} # Save benchmark data for CI dashboard if is_ci_env: @@ -847,7 +863,7 @@ def run_llama3_demo( benchmark_data.save_partial_run_json( profiler, run_type=f"{tt_device_name}-demo", - ml_model_name=llama_model_name, + ml_model_name=model_args.base_model_name, ml_model_type="llm", num_layers=model_args.n_layers, batch_size=batch_size, @@ -873,6 +889,17 @@ def run_llama3_demo( @pytest.mark.parametrize( "input_prompts, instruct, repeat_batches, max_seq_len, batch_size, max_generated_tokens, paged_attention, page_params, sampling_params", [ + ( # Batch-1 run (Reasoning) - single user, small prompt, long thinking time + "models/demos/llama3/demo/input_data_questions_reasoning.json", # input_prompts + True, # instruct mode + 1, # repeat_batches + 16384, # max_seq_len + 1, # batch_size + 15000, # max_generated_tokens + True, # paged_attention + {"page_block_size": 32, "page_max_num_blocks": 1024}, # page_params # TODO This will be serviced by vLLM + {"temperature": 0, "top_p": 0.08}, # sampling_params (argmax) + ), ( # Batch-1 run (Latency) - single user, small prompt "models/demos/llama3/demo/input_data_questions_prefill_128.json", # input_prompts True, # instruct mode @@ -908,6 +935,7 @@ def run_llama3_demo( ), ], ids=[ + "reasoning-1", # reasoning "batch-1", # latency "batch-32", # throughput "long-context", # max-length @@ -946,7 +974,9 @@ def test_llama_demo( is_ci_env, reset_seeds, ): - if is_ci_env and ("long" in input_prompts or optimizations == LlamaOptimizations.accuracy): + if is_ci_env and ( + "long" in input_prompts or "reasoning" in input_prompts or optimizations == LlamaOptimizations.accuracy + ): pytest.skip("Do not run the 'long-context' or accuracy tests on CI to reduce load") # TODO: Remove this once all batch sizes are supported on TG diff --git a/models/demos/llama3/demo/input_data_questions_reasoning.json b/models/demos/llama3/demo/input_data_questions_reasoning.json new file mode 100644 index 00000000000..360a4b49cad --- /dev/null +++ b/models/demos/llama3/demo/input_data_questions_reasoning.json @@ -0,0 +1,20 @@ +[ + { + "prompt": "Find all integer solutions (x, y) to the equation x^2 - 3y^2 = 1." + }, + { + "prompt": "Find the least odd prime factor of 2019^8 + 1" + }, + { + "prompt": "Compose a maximally-catchy piece of piano music; the left hand should only play chords and the right hand a simple melody. The song should get stuck in the listener's head for days." + }, + { + "prompt": "Compose the most beautiful and maximally-elegant haiku that captures the poignancy of the human condition; think carefully about how to make sure it packs the maximum possible emotional punch for the reader." + }, + { + "prompt": "A fair coin is tossed 8 times. What is the probability (in simplest fractional form) of getting exactly 5 heads?" + }, + { + "prompt": "How many 7-digit integers have digits strictly increasing from left to right? 
(For example, 1234567 is valid, 1357899 is not because of the repeated 9.)" + } +] diff --git a/models/demos/llama3/demo/simple_vision_demo.py b/models/demos/llama3/demo/simple_vision_demo.py index 47719f91462..7eaed8091a7 100644 --- a/models/demos/llama3/demo/simple_vision_demo.py +++ b/models/demos/llama3/demo/simple_vision_demo.py @@ -108,7 +108,7 @@ def test_llama_multimodal_demo_text( mesh_device.enable_async(True) model_args, model = create_multimodal_model(mesh_device, max_batch_size=max_batch_size, max_seq_len=max_seq_len) generator = LlamaGenerator(model, model_args, mesh_device) - tokenizer = Tokenizer(model_path=tokenizer_path) + tokenizer = model_args.tokenizer formatter = ChatFormat(tokenizer) xattn_caches = generator.model.setup_cache(model_args.max_batch_size) diff --git a/models/demos/llama3/lt b/models/demos/llama3/lt old mode 100644 new mode 100755 index 2a807109237..c088bb586d8 --- a/models/demos/llama3/lt +++ b/models/demos/llama3/lt @@ -61,13 +61,17 @@ def ensure_ttsmi_installed(): def reset_device_sync(config_file): - reset_cmd = ["tt-smi", "-r", config_file] - try: + if os.environ.get("RESET_CMD"): + reset_cmd = os.environ.get("RESET_CMD").split(" ") + print(f"Resetting device using custom command: {reset_cmd}") + else: + reset_cmd = ["tt-smi", "-r", config_file] print(f"Resetting device using config file: {config_file}") + try: result = subprocess.run(reset_cmd, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) print(f"Device reset successfully: {result.stdout}") except subprocess.CalledProcessError as e: - print(f"Error during device reset: {e.stderr}") + print(f"Error during device reset: {e.stdout} {e.stderr}") sys.exit(1) @@ -82,7 +86,7 @@ def get_device(): device = "N150" elif total_devices == 8: device = "T3K" - else: # TG has 36 devices + else: # TG has 36 devices device = "TG" # Old method of getting device name based on hostname @@ -109,11 +113,13 @@ def list_supported_devices(device): # Counts number of devices using `tt-smi -ls` output def count_devices(output): # Split the output into available boards section - sections = output.split('All available boards on host') - available_boards = sections[1].split('Boards that can be reset')[0] + sections = output.split("All available boards on host") + available_boards = sections[1].split("Boards that can be reset")[0] # Count total PCI devices (ignoring N/A) - total_pci_devices = len([line for line in available_boards.split('\n') if ('Wormhole' or 'Grayskull' or 'Blackhole') in line]) + total_pci_devices = len( + [line for line in available_boards.split("\n") if ("Wormhole" or "Grayskull" or "Blackhole") in line] + ) return total_pci_devices @@ -332,7 +338,7 @@ def main(stdscr): # Input fields positions (reordered) input_fields = [ {"label": "Command [demo]", "value": "", "x": 0, "y": 0}, - {"label": "Model (1b, 3b, 8b, 11b, 70b) [all]", "value": "", "x": 0, "y": 1}, + {"label": "Model (1b, 3b, 8b, 11b, 70b, 70b-r1, q7b, q72b) [all]", "value": "", "x": 0, "y": 1}, { "label": f"Device ({list_supported_devices(host_device)}) [all]", "value": "", @@ -447,10 +453,8 @@ def main(stdscr): if current_field == len(input_fields) - 1: # Submit command command_input = input_fields[0]["value"] or "demo" - model_input = input_fields[1]["value"] or "1b,3b,8b,11b,70b" - device_input = ( - input_fields[2]["value"] or list_supported_devices(host_device) - ) + model_input = input_fields[1]["value"] or "1b,3b,8b,11b,70b,70b-r1,q7b,q72b" + device_input = input_fields[2]["value"] or 
list_supported_devices(host_device) if command_input == "modules": command_input = "rmsnorm,attention,attention-prefill,mlp,lm-head" @@ -461,6 +465,9 @@ def main(stdscr): if command_input == "table": command_input = "accuracy,demo,accuracy-acc,demo-acc" + if command_input == "vision": + command_input = "vision-mlp,vision-attn,vision-block,vision-xfmr,vision-xattn,vision-xblock,vision-conv,vision-class,vision-tile-pos,vision-pos,vision-encoder,vision-text-xfmr,vision-vision-xfmr" + # Parse models, devices, and commands models = parse_list(model_input) devices = parse_list(device_input) @@ -469,7 +476,9 @@ def main(stdscr): # Generate combinations (reordered) # Ignore invalid combinations: # - 11b and 11b-b models on n150 device - # - 70b model on n150 and n300 devices + # - 70b and 70b-r1 model on n150 and n300 devices + # - 72b model on n150 and n300 devices + # - q7b on anything other than N300 # - Vision commands on non-vision (11b) models combinations = [ (c, m, d) @@ -479,6 +488,9 @@ def main(stdscr): if not ( (m in ["11b", "11b-b"] and d == "n150") or (m == "70b" and d in ["n150", "n300"]) + or (m == "70b-r1" and d in ["n150", "n300"]) + or (m == "q72b" and d in ["n150", "n300"]) + or (m == "q7b" and d != "n300") or ("vision" in c and m not in ["11b", "11b-b"]) ) ] @@ -1034,6 +1046,9 @@ def get_llama_dir(model): "11b": os.environ.get("LLAMA_32_11B_DIR", "/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision-Instruct"), "11b-b": os.environ.get("LLAMA_32_11B_BASE_DIR", "/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision"), "70b": os.environ.get("LLAMA_31_70B_DIR", "/proj_sw/llama3_1-weights/Meta-Llama-3.1-70B-Instruct/repacked"), + "70b-r1": os.environ.get("DEEPSEEK_R1_LLAMA_70B_DIR", "/proj_sw/deepseek/DeepSeek-R1-Distill-Llama-70B"), + "q7b": os.environ.get("QWEN_7B_DIR", "/proj_sw/user_dev/Qwen/Qwen2.5-7B-Instruct"), + "q72b": os.environ.get("QWEN_72B_DIR", "/proj_sw/user_dev/Qwen/Qwen2.5-72B-Instruct"), }.get(model.lower(), "") if not llama_dir or not os.path.exists(llama_dir): @@ -1044,6 +1059,9 @@ def get_llama_dir(model): print(" - LLAMA_31_8B_DIR for 8b model") print(" - LLAMA_32_11B_DIR for 11b model") print(" - LLAMA_31_70B_DIR for 70b model") + print(" - DEEPSEEK_R1_LLAMA_70B_DIR for DeepSeek R1 Llama 70b distill model") + print(" - QWEN_7B_DIR for 7b Qwen2.5 model") + print(" - QWEN_72B_DIR for 72b Qwen2.5 model") sys.exit(1) return llama_dir @@ -1250,6 +1268,17 @@ def export_results_to_markdown(output_entries, stdscr): "|-------|--------|-----------|-----------|---------------|", ] + fullname = { + "1b": "Llama-3.2-1B", + "3b": "Llama-3.2-3B", + "8b": "Llama-3.1-8B", + "11b": "Llama-3.2-11B", + "70b": "Llama-3.1-70B", + "70b-r1": "DeepSeek-R1-Llama-70B", + "q7b": "Qwen-2.5-7B", + "q72b": "Qwen-2.5-72B", + } + # Add rows for performance table in original order for entry in perf_entries: (model, device), top1, top5, speed = entry @@ -1271,7 +1300,7 @@ def export_results_to_markdown(output_entries, stdscr): # Add rows for accuracy table in original order for entry in acc_entries: (model, device), top1, top5, speed = entry - markdown_lines.append(f"| {model} | {device} | {top1} | {top5} | {speed} |") + markdown_lines.append(f"| {fullname[model]} | {device} | {top1} | {top5} | {speed} |") # Write to PERF.md with open("PERF.md", "w") as f: diff --git a/models/demos/llama3/requirements.txt b/models/demos/llama3/requirements.txt index e830cffd233..438cea7dbee 100644 --- a/models/demos/llama3/requirements.txt +++ b/models/demos/llama3/requirements.txt @@ -1 +1,2 @@ 
git+https://github.com/tenstorrent/llama-models.git@tt_metal_tag +transformers >= 4.46.3 diff --git a/models/demos/llama3/tests/generate_reference_hf.py b/models/demos/llama3/tests/generate_reference_hf.py new file mode 100644 index 00000000000..f275584e6da --- /dev/null +++ b/models/demos/llama3/tests/generate_reference_hf.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import bz2 +import os +import argparse +from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer +from loguru import logger + + +def generate_reference_outputs(total_length, output_file, model_name): + # Set device + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + + # Load model and tokenizer from HuggingFace + config = AutoConfig.from_pretrained(model_name) + + # Qwen only: add rope scaling to the config + # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct#processing-long-texts + if "Qwen" in model_name: + config.rope_scaling = {"factor": 4.0, "original_max_position_embeddings": 32768, "type": "yarn"} + + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained(model_name, config=config, device_map="auto") + model.eval() + + # Load the book text + current_file_path = os.path.abspath(__file__) + current_file_dir = os.path.dirname(current_file_path) + prompt_file = os.path.join(current_file_dir, "tale-of-two-cities.txt.bz2") + + with bz2.open(prompt_file, "rt", encoding="utf-8") as f: + text = f.read() + + # Encode text to tokens + encoded_tokens = tokenizer.encode(text, add_special_tokens=True)[:total_length] + encoded_tokens_tensor = torch.tensor(encoded_tokens, device=device).unsqueeze(0) # Shape [1, seq_len] on device + + print(f"{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}") + print("-" * 113) + + # Initialize lists to store results + all_top1_correct = [] + all_top5_correct = [] + all_top5_tokens = [] + segment_accuracies = [] + chunk_size = 1024 + + with torch.no_grad(): + for chunk_start in range(0, total_length - 1, chunk_size): + chunk_end = min(chunk_start + chunk_size, total_length) + # Get input and target chunks + chunk_tokens = encoded_tokens_tensor[:, chunk_start:chunk_end] + chunk_next_tokens = encoded_tokens[chunk_start + 1 : chunk_end + 1] + actual_chunk_size = min(len(chunk_tokens[0]), len(chunk_next_tokens)) + + # Trim input chunk if needed + chunk_tokens = chunk_tokens[:, :actual_chunk_size] + + # Process chunk using HuggingFace model + outputs = model(chunk_tokens.to(device)) + logits = outputs.logits + + # Compute top-5 predictions + probs = torch.softmax(logits, dim=-1) + _, chunk_top5_tokens = torch.topk(probs, k=5, dim=-1) # Shape: [1, chunk_size, 5] + chunk_top5_tokens = chunk_top5_tokens.squeeze(0) # Shape: [chunk_size, 5] + + # Get next tokens tensor + chunk_next_tokens_tensor = torch.tensor( + chunk_next_tokens[:actual_chunk_size], device=device + ) # Move to same device + + # Calculate correctness + chunk_top1_correct = chunk_top5_tokens[:, 0] == chunk_next_tokens_tensor + chunk_top5_correct = torch.any(chunk_top5_tokens == chunk_next_tokens_tensor.unsqueeze(1), dim=1) + + # Store results + all_top1_correct.extend(chunk_top1_correct.tolist()) + all_top5_correct.extend(chunk_top5_correct.tolist()) + all_top5_tokens.append(chunk_top5_tokens) + + # Print predictions for this chunk + for i in range(len(chunk_next_tokens)): + global_pos = chunk_start + i + next_token = 
chunk_next_tokens[i] + + sanitize = lambda x: x.replace("\n", "").replace("\r", "").replace("\x0c", "") + actual_token = sanitize(tokenizer.decode([next_token])) + top5_tokens = [sanitize(tokenizer.decode([t.item()])) for t in chunk_top5_tokens[i]] + correct = "x" if chunk_top1_correct[i] else ("-" if chunk_top5_correct[i] else " ") + top5_str = " ".join(f"{t:<14}" for t in top5_tokens) + + progress_str = f"{global_pos+1}/{total_length-1}" + print(f"{progress_str:<15}{correct:<8}{actual_token:<15}{top5_str}") + + # Calculate and store segment accuracies every 100 tokens + if (global_pos + 1) % 100 == 0 or global_pos == total_length - 2: + start_idx = (global_pos // 100) * 100 + end_idx = min(start_idx + 100, len(all_top1_correct)) + segment_top1_acc = sum(all_top1_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100 + segment_top5_acc = sum(all_top5_correct[start_idx:end_idx]) / (end_idx - start_idx) * 100 + if len(segment_accuracies) <= global_pos // 100: + segment_accuracies.append((segment_top1_acc, segment_top5_acc)) + + # Save the data - ensure tensors are concatenated and on CPU + data = { + "top5_tokens": torch.cat(all_top5_tokens, dim=0).cpu(), + "reference_tokens": encoded_tokens_tensor[:, :total_length].clone().cpu(), + } + + torch.save(data, output_file) + logger.info(f"Saved reference outputs to {output_file}") + + # Print all segment accuracy summaries as a table + print("\nSegment Accuracy Summaries:") + print(f"{'Tokens':<15}{'Top-1 Accuracy':<20}{'Top-5 Accuracy':<20}") + print("-" * 55) + for i, (top1_acc, top5_acc) in enumerate(segment_accuracies): + start_token = i * 100 + 1 + end_token = min((i + 1) * 100, total_length) + print(f"{f'{start_token}-{end_token}':<15}{f'{top1_acc:.2f}%':<20}{f'{top5_acc:.2f}%':<20}") + + # Calculate overall accuracy + overall_top1_acc = sum(acc[0] for acc in segment_accuracies) / len(segment_accuracies) + overall_top5_acc = sum(acc[1] for acc in segment_accuracies) / len(segment_accuracies) + print("-" * 55) + print(f"{'Overall':<15}{f'{overall_top1_acc:.2f}%':<20}{f'{overall_top5_acc:.2f}%':<20}") + + +def main(): + parser = argparse.ArgumentParser(description="Generate reference outputs using HuggingFace models.") + parser.add_argument("--total_length", type=int, default=1024, help="Total length of tokens to process") + parser.add_argument( + "--output_file", type=str, default="reference_outputs.pt", help="Output file path for reference data" + ) + parser.add_argument( + "--model", type=str, required=True, help="HuggingFace model name (e.g., 'meta-llama/Llama-2-7b-hf')" + ) + args = parser.parse_args() + + generate_reference_outputs(total_length=args.total_length, output_file=args.output_file, model_name=args.model) + + +if __name__ == "__main__": + main() diff --git a/models/demos/llama3/tests/generate_reference_outputs.py b/models/demos/llama3/tests/generate_reference_outputs.py index 1f0514bfe7b..f874e913a10 100644 --- a/models/demos/llama3/tests/generate_reference_outputs.py +++ b/models/demos/llama3/tests/generate_reference_outputs.py @@ -5,28 +5,40 @@ import bz2 import os import argparse -import time -from models.demos.llama3.tt.llama_common import HostEmbedding -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer -from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer +from models.demos.llama3.tt.model_config import TtModelArgs, CheckpointType from loguru import logger from transformers import 
AutoModelForCausalLM, AutoTokenizer def generate_reference_outputs(total_length, output_file, hf_model_name=None): + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + logger.info(f"Using device: {device}") + if hf_model_name: # HuggingFace path tokenizer = AutoTokenizer.from_pretrained(hf_model_name) - model = AutoModelForCausalLM.from_pretrained(hf_model_name, torch_dtype=torch.float32) + config = AutoConfig.from_pretrained(hf_model_name) + # Qwen only: add rope scaling to the config + # https://huggingface.co/Qwen/Qwen2.5-7B-Instruct#processing-long-texts + if "Qwen" in hf_model_name: + config.rope_scaling = {"factor": 4.0, "original_max_position_embeddings": 32768, "type": "yarn"} + model = AutoModelForCausalLM.from_pretrained( + hf_model_name, config=config, torch_dtype=torch.float32 if device == "cpu" else None, device_map="auto" + ) model.eval() + else: # Original path - load reference model model_args = TtModelArgs(mesh_device=None) model_args.max_seq_len = total_length tokenizer = Tokenizer(model_args.tokenizer_path) + # Special-case Hf models as they can load directly from the safetensors much more efficiently + if model_args.checkpoint_type == CheckpointType.Meta: + # Load the model state dict state_dict = model_args.load_state_dict() + + # Initialize the reference model state_dict_prefix = model_args.get_state_dict_prefix("", None) reference_state_dict = { k[len(state_dict_prefix) :]: v @@ -41,13 +53,20 @@ def generate_reference_outputs(total_length, output_file, hf_model_name=None): ) ) } - model = Transformer(model_args) - model.load_state_dict(reference_state_dict) - model.eval() + reference_model = model_args.reference_transformer() + reference_model.to(device) # Move model to device + reference_model.eval() # Set to evaluation mode + reference_model.load_state_dict(reference_state_dict) - # Initialize HostEmbedding - embd = HostEmbedding(model_args) + embd = model_args.reference_embedding(reference_model) + embd.to(device) # Move embedding to device embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) + else: + reference_model = model_args.reference_transformer(load_checkpoint=True) + reference_model.to(device) # Move model to device + reference_model.eval() # Set to evaluation mode + embd = reference_model.model.model.embed_tokens + embd.to(device) # Move embedding to device # Load the book text and encode tokens current_file_path = os.path.abspath(__file__) @@ -57,13 +76,9 @@ def generate_reference_outputs(total_length, output_file, hf_model_name=None): with bz2.open(prompt_file, "rt", encoding="utf-8") as f: text = f.read() - # Modify token encoding based on model type - if hf_model_name: - encoded_tokens = tokenizer.encode(text, add_special_tokens=True)[:total_length] - else: - encoded_tokens = tokenizer.encode(text, bos=True, eos=False)[:total_length] - - encoded_tokens_tensor = torch.tensor(encoded_tokens).unsqueeze(0) # Shape [1, seq_len] + # Encode text to tokens + encoded_tokens = model_args.encode_prompt(text, instruct=False) + encoded_tokens_tensor = torch.tensor(encoded_tokens, device=device).unsqueeze(0) # Move to device print(f"{'Progress':<15}{'Correct':<8}{'Actual':<15}{'Top 5 Predictions':<75}") print("-" * 113) @@ -87,6 +102,7 @@ def generate_reference_outputs(total_length, output_file, hf_model_name=None): chunk_tokens = chunk_tokens[:, :actual_chunk_size] # Process chunk based on model type + chunk_tokens = chunk_tokens.to(device) if hf_model_name: outputs = model(chunk_tokens) ref_output = 
outputs.logits @@ -100,7 +116,7 @@ def generate_reference_outputs(total_length, output_file, hf_model_name=None): chunk_top5_tokens = chunk_top5_tokens.squeeze(0) # Shape: [chunk_size, 5] # Get next tokens tensor, ensuring same length as predictions - chunk_next_tokens_tensor = torch.tensor(chunk_next_tokens[:actual_chunk_size]) + chunk_next_tokens_tensor = torch.tensor(chunk_next_tokens[:actual_chunk_size], device=device) # Calculate correctness chunk_top1_correct = chunk_top5_tokens[:, 0] == chunk_next_tokens_tensor @@ -137,10 +153,10 @@ def generate_reference_outputs(total_length, output_file, hf_model_name=None): # Concatenate all top5 tokens into a single tensor all_top5_tokens = torch.cat(all_top5_tokens, dim=0) # Shape: [total_tokens, 5] - # Save the data + # Move tensors back to CPU before saving data = { - "top5_tokens": all_top5_tokens, - "reference_tokens": encoded_tokens_tensor, + "top5_tokens": torch.cat(all_top5_tokens, dim=0).cpu(), + "reference_tokens": encoded_tokens_tensor[:, :total_length].clone().cpu(), } torch.save(data, output_file) diff --git a/models/demos/llama3/tests/generate_reference_outputs.sh b/models/demos/llama3/tests/generate_reference_outputs.sh index a756a0b3ef4..bf419c42a08 100755 --- a/models/demos/llama3/tests/generate_reference_outputs.sh +++ b/models/demos/llama3/tests/generate_reference_outputs.sh @@ -33,6 +33,8 @@ LLAMA_DIRS=( "${LLAMA_31_8B_DIR:-/proj_sw/user_dev/llama31-8b-data/Meta-Llama-3.1-8B-Instruct}" "${LLAMA_32_11B_DIR:-/proj_sw/user_dev/llama32-data/Llama3.2-11B-Vision-Instruct}" "${LLAMA_31_70B_DIR:-/proj_sw/llama3_1-weights/Meta-Llama-3.1-70B-Instruct/repacked}" + "${QWEN_25_7B_DIR:-/proj_sw/user_dev/Qwen/Qwen2.5-7B-Instruct}" + "${QWEN_25_72B_DIR:-/proj_sw/user_dev/Qwen/Qwen2.5-72B-Instruct}" ) # Create reference_outputs directory if it doesn't exist @@ -40,21 +42,14 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" OUTPUT_DIR="${SCRIPT_DIR}/reference_outputs" mkdir -p "$OUTPUT_DIR" -# Function to get model size from directory path -get_model_size() { - if [[ $1 == *"-1B"* ]]; then - echo "1b" - elif [[ $1 == *"-3B"* ]]; then - echo "3b" - elif [[ $1 == *"-8B"* ]]; then - echo "8b" - elif [[ $1 == *"-11B"* ]]; then - echo "11b" - elif [[ $1 == *"-70B"* ]]; then - echo "70b" - else - echo "unknown" +# Function to get model name from directory path +get_model_name() { + local dir_name=$(basename "$1") + # If the path ends in /repacked, use the parent directory name instead + if [ "$dir_name" = "repacked" ]; then + dir_name=$(basename "$(dirname "$1")") fi + echo "$dir_name" } # Loop through each LLAMA directory @@ -65,8 +60,8 @@ for DIR in "${LLAMA_DIRS[@]}"; do fi # Get model size for output filename - MODEL_SIZE=$(get_model_size "$DIR") - OUTPUT_FILE="${OUTPUT_DIR}/${MODEL_SIZE}.refpt" + MODEL_NAME=$(get_model_name "$DIR") + OUTPUT_FILE="${OUTPUT_DIR}/${MODEL_NAME}_full.refpt" echo "Generating reference outputs for ${MODEL_SIZE} model..." 
echo "Using weights from: ${DIR}" diff --git a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py index 631bdf31446..e23ea6e62bd 100644 --- a/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py +++ b/models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py @@ -216,8 +216,10 @@ def test_llama_cross_attention_transformer_text_inference( model_args.head_dim, model_args.max_seq_len, mesh_device, - seq_len=seq_len, - scale_factor=model_args.rope_scaling_factor, + seq_len, + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, ) tt_out = tt_model( tt_h, @@ -260,6 +262,9 @@ def test_llama_cross_attention_transformer_text_inference( mesh_device, model_args.num_devices, start_pos=cur_pos - 1, + theta=model_args.rope_theta, + scale_factor=model_args.rope_scaling_factor, + orig_context_len=model_args.orig_context_len, ) tt_rope_id = tt_model.rope_setup.get_rot_idxs(position_ids) rot_mats = tt_model.rope_setup.get_rot_mats(tt_rope_id) diff --git a/models/demos/llama3/tests/reference_outputs/70b.refpt b/models/demos/llama3/tests/reference_outputs/Llama3.1-70B-Instruct.refpt similarity index 100% rename from models/demos/llama3/tests/reference_outputs/70b.refpt rename to models/demos/llama3/tests/reference_outputs/Llama3.1-70B-Instruct.refpt diff --git a/models/demos/llama3/tests/reference_outputs/8b.refpt b/models/demos/llama3/tests/reference_outputs/Llama3.1-8B-Instruct.refpt similarity index 100% rename from models/demos/llama3/tests/reference_outputs/8b.refpt rename to models/demos/llama3/tests/reference_outputs/Llama3.1-8B-Instruct.refpt diff --git a/models/demos/llama3/tests/reference_outputs/11b.refpt b/models/demos/llama3/tests/reference_outputs/Llama3.2-11B-Instruct.refpt similarity index 100% rename from models/demos/llama3/tests/reference_outputs/11b.refpt rename to models/demos/llama3/tests/reference_outputs/Llama3.2-11B-Instruct.refpt diff --git a/models/demos/llama3/tests/reference_outputs/1b.refpt b/models/demos/llama3/tests/reference_outputs/Llama3.2-1B-Instruct.refpt similarity index 100% rename from models/demos/llama3/tests/reference_outputs/1b.refpt rename to models/demos/llama3/tests/reference_outputs/Llama3.2-1B-Instruct.refpt diff --git a/models/demos/llama3/tests/reference_outputs/3b.refpt b/models/demos/llama3/tests/reference_outputs/Llama3.2-3B-Instruct.refpt similarity index 100% rename from models/demos/llama3/tests/reference_outputs/3b.refpt rename to models/demos/llama3/tests/reference_outputs/Llama3.2-3B-Instruct.refpt diff --git a/models/demos/llama3/tests/reference_outputs/Qwen2.5-72B-Instruct.refpt b/models/demos/llama3/tests/reference_outputs/Qwen2.5-72B-Instruct.refpt new file mode 100644 index 0000000000000000000000000000000000000000..61de1e579435d473ee43954f9b6c28612db8c87d GIT binary patch literal 50726 zcmcJY3!INt{>LA;aY+fKSTV!6l-rnzF^F-eA(u+XZ7^e$Fbq->m1K2Ox+tOBwCO^m zE3MkDwkXwZEk&VLcGcEalJuYH^Lfu}o!w_A|Nran^{O-H_Pu<*-}61^JkNQ~Z$MVL z@=;W&QuJ^CPKzo;183$;N^RMuMS5!27JVj7E0{WcY(cAW*#+4xr(8U-XV)^(u(@Tg z$d8&1Ij*RC!Q?4zMiorHIA_wd{4!$;CQluk*K*YKf(a9+RT(ukXUy~o6UU7zD5{)1 zb<~X1D*0ubHXTy2sN1xnmf7QT^2@19mA;cFjXzgqRr1RZDO1!WCBMRuaz*LbDbw>4 z`b8%d!JYyL$JD!u9rrB^!bb-MEX8!6ps z`tLa{+@A50l@53NtIklqe|vXuIrjb8 z_)c>9ZUgt&32J}Z5T%8#kDP-4A`800a+YtZ26yMHUrzCOF8a6M{_~RYPyD4cirg(K z+@*rIu2DL}d{W)MqZ?9A{sotR(?I>e4ZHMlf9ji`s&*OGlwz-si{n?(zuoq$YUbb4 
[... base85 "GIT binary patch" payload for Qwen2.5-72B-Instruct.refpt (literal 50726) ...]
zMwM(Rw|QOdP+$F4c^r0{ZF@fG^%MKNWwH8$A9BDCIHqdasK4XERo2fG`+bd?qVrvj zN}j)MyI3$^?TmlyE7j32>k9L;-Y_rxq~57<@Jl|q&2i$S`D7e+jvUZ$71((%kK;M> zg?R+|IsBCi^o=~MXS|PP-HV?f{H*b5NJkk3`o)eGYlXIYuKfXiqpmE-Nhvv=oe2hm< z#uE>jXXFzahuqvVfFIr?g>;|~J>n|+cJc(itHL?YwdR-kpv(F7TH~@Oumj%LPIJ8E z{RnX<(F_q6CY!NdH%Y#0y)I_BPE6jI^XY7-C?~43oj%omT3Hi$3I;;I}&LhCFZK~S;Z8^{rc%v`u z0eaXu{m4t8%iq!FURe*vzkj)8{UzVud4>2$p01L7hWA(4HFn0|&!T;*<0^K_^PP^5 zU;Ex3{otGXT=kt-Aul*!r_col=H2W8AB1s$d7Jv&;rnYOZztpOX`JdO2C00; z{s}$e68wboVTZIcAAIl}JLUb&8qY%O z3V9OuUf}m(kK5rAdFB|O(7%=}6n~8yt6Z9(5_@4@_A_DJU|!BIShx5)63jo*c(9NA zF;VldZyV+Ogm&ya97n%##vQ)c7xVo>?C=HW&)gqlAF zj#s;l#&dqkzVRXJxw`%9`grwQXuUjSyZyv@(Z}9*$6HTj-LI?rVSo4!c7dGichB{> zr(MG5h?_M(=cA(?_oq4k_}q%>VuJ7VQGfK$JlG%h!G7*Z%ZWV@M{4JZ&bQ7Vj+-ye zKQ48k;QXF`UOo!r_tW;HXRUYmLQm)qeW6G6*0h@VAbt=xUa2OAD%&228*#xr$Pexr zPm10M>lwMR_ocRH{5;$jLC?!d;djA!U&?yUI?lSj#`%3S>ks-Xu9FzKiZ+jaxSmW!ch#Gjo zF5OfO>q?04<(6lI>wL!nzR?Hs%roxsc2u5M8zZ{NhrH0iZuq?<=kg`(m)Dryp~j+f zx8*E~PPY4zZ;%I&XSA~($wTOO$Z`;`UiSXwg#^tX>IXTX_naN*ug>$x6Yp>xee!!0 z-mfe*{q~O6e=wc*?cdywIod#c@IL5><={O${J|G`z;Cfz?1b^@T|~EMFO|H9q91XN zbCOfGj|*MCWBE43X*}_`_U)>l7q1fevD=y6&vVX(U$dToOEd2mSD8*-`#*Mte8dmr zXT6?mdtm<7Pm3Sk13d11%TC85`qK{_x7u--k9iI(6ulYtfB0x%y38BW2M6e&AM!Zj z7x4-{xL<>O&?gUg%>$r|pZ(BDd?wlNSyzyUJcc~+UgME!J<}gPun+n-+a&sLTR+&r z(;i>O!XZp-BInBpr$)SB>dTar8gdN%fYusl;A6C(Cslc*Ak&ehaw6{37qVpq=kUkY_PJ z_L4D5^Fb%PANr>K8=S!-)Yl-}-4y#J^Y9+xLGzh=So6V86Y~pxU&rftHOsqltLkHn zC;OCgj^Ft2+1dqs@x2$u6>Fz?LOF?(Q^-Fo$04u#%pc3XUh62jmz#h3Z@xh7=nXyF zO(Ullg{9M-=zGtfsNoU*ali9B#+iJy-gf)Xjv5&9(aZsn`JEm_j1T$ax#5o}nHPNR zZV~=f)^^K# zy8Vo}c5Vv|>_1n}o4lnmUPbgv=Or#Oew_mt{El+me$@OvbH66a^vY4&7q8auR6F#p zFm9orpofY>)qmU;m9%q@i|+#wfB0Su>l@#D;d?HutDN(}2lV*v75BC_SRT$H`8^Q! z9FF5VzO*xr?*qje$9Hbyg8BH~5Pv@+)_A_#8qPbrmI%fg$9L*NIS1SzydOyX`SE-o zGMtyc%a?rY?DQ6X2c2^UzWe+M}h-IYF1iOm1m(j^Y?cX(sb=X;!?-4f6FJG7-dKl^e1PHZf?e0L0g zBLC#?sPgwMVvXm!itvR!g?}eB)_A_dhdziueD@Q*@^=#L-VuMtG8SFFUxz+o(Is#F zrS$o3TgZ3#cLe!vS}Z>K9vt}bcg*}hH9zAUx*zjzyHgFd?MM8b zB)*GR6ul5n>~5D6N%XYc^o#N@_zZF8@2UJ!I{cj%^nI|IrHRt&H94;vFOBFH~G85e1CqU_fO$>=e0*(s#HX z=Gwl>&Hs5m!gY}E@5GWjT;B@6ue$2z=ZUp$h3f|2QH?b}@gKVaw=iBFj2AzA|1RhHD5T6@6a{!{06XoZ#}U5+ zsB1pBzmMO6bE!HSfd6)0uR8bNGmRJe#-2DI=AJqH!^bT?Z|7b&d@#P4f7io3YtH4l z9}V93P7*)p3Hsba#?GL7!DtB_-UH_TE_T8_T=em+&lRDwWTEEeo-6lQnV);8$jyBX z?uim-!IgVv+|vYq?nj}Q;y&-^ehv4Op7A=&x&j~ED`foR);Dr&pDj6<54+^vD1OKL zCGN?kna)QoG@g5G+*4!z&&^jEzt_c1xCcgl<}@)Xl*{j}k&Am&=;LP3f6OKK2cVC9 z@Nw5f@yR`*UA8~u=Kc=);@%DSYf=N;c;9tos_0@jocH4Q+%rMX++$!~eveVh=P~fl zJq+Hjac=@WHgtdFz*-uv>W@JvEkTGe(^%*(&+>D$swtMF~XTea5 zjNi&iO@wrQ1=((lZ=>0?iIYM~i#i;s9+y9`;D3IpU@`;q zX`hMG=s7u=Q>Ty3oF0<;&x<7_IHL8}7^HqeGAX3K>Ors-VC6{x_XLdSd_p literal 0 HcmV?d00001 diff --git a/models/demos/llama3/tests/test_interleaved_to_sharded.py b/models/demos/llama3/tests/test_interleaved_to_sharded.py index 62a0a20dd2e..e2915d8b7a8 100644 --- a/models/demos/llama3/tests/test_interleaved_to_sharded.py +++ b/models/demos/llama3/tests/test_interleaved_to_sharded.py @@ -6,16 +6,7 @@ from loguru import logger import os import ttnn -from models.demos.llama3.tt.llama_common import ( - precompute_freqs, -) -from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import TransformerBlock -from models.utility_functions import ( - comp_pcc, - comp_allclose, -) from models.utility_functions import skip_for_grayskull @@ -31,8 +22,6 @@ indirect=True, ) def test_llama_decoder_inference(mesh_device, use_program_cache, 
reset_seeds): - dtype = ttnn.bfloat8_b - mesh_device.enable_async(True) model_args = TtModelArgs(mesh_device) @@ -43,42 +32,20 @@ def test_llama_decoder_inference(mesh_device, use_program_cache, reset_seeds): partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - reference_model = TransformerBlock(layer_id=0, args=model_args) + reference_model = model_args.reference_decoder() reference_model.load_state_dict(partial_state_dict) - generation_start_pos = 0 generation_length = 10 - all_tests_pass = True - - # Initialize TT model - tt_model = TtTransformerBlock( - args=model_args, - mesh_device=mesh_device, - dtype=dtype, - state_dict=state_dict, - layer_num=0, - weight_cache_path=model_args.weight_cache_path(dtype), - ) seqlen = 1 batch = model_args.max_batch_size - cos, sin = precompute_freqs(model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_scaling_factor) - freqs_cis = torch.complex(cos, sin) - for i in range(generation_length): logger.info(f"[Decoder] Generating token {i}") # input = torch.randn(1, 32, 4096) pt_decode_input = (torch.rand(batch, seqlen, model_args.dim) * 2) - 1 tt_decode_input = pt_decode_input.clone() - current_pos = generation_start_pos + i - current_pos_tensor = ttnn.from_torch( - torch.tensor([current_pos] * batch), - device=mesh_device, - dtype=ttnn.int32, - mesh_mapper=ttnn.ReplicateTensorToMesh(mesh_device), - ) decode_input = model_args.prepare_residual_tensor_decode( tt_decode_input, diff --git a/models/demos/llama3/tests/test_llama_accuracy.py b/models/demos/llama3/tests/test_llama_accuracy.py index c77f3e3c914..d0fd2d2a15b 100644 --- a/models/demos/llama3/tests/test_llama_accuracy.py +++ b/models/demos/llama3/tests/test_llama_accuracy.py @@ -9,21 +9,16 @@ import ttnn from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, - HostEmbedding, PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.demos.llama3.demo.demo import preprocess_inputs_prefill from pathlib import Path -def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: LlamaOptimizations): +def get_accuracy_thresholds(base_model_name: str, device_name: str, optimizations: LlamaOptimizations): """Parse accuracy thresholds from PERF.md for the given model, optimization mode, and device.""" - # Get model size (e.g., "1b", "3b", etc.) 
- model_size = model_name.split("-")[1].lower() - # Read PERF.md perf_file = Path(__file__).parent.parent / "PERF.md" with open(perf_file, "r") as f: @@ -31,22 +26,28 @@ def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: Ll # Split into sections based on optimization mode sections = content.split("## ") - target_section = next(s for s in sections if s.startswith(f"LlamaOptimizations.{optimizations.__name__}\n")) + target_section = next(s for s in sections if s.lower().startswith(f"{optimizations.__name__}\n")) # Parse the table and find the row for our model and device + # Potential lines have the form "| Llama3.1-8b | T3K | 91 | 99 | 49.8 |" + correct_line = ( + lambda line: "|" in line + and base_model_name.lower() in line.split("|")[1].strip().lower() + and device_name.lower() in line.split("|")[2].strip().lower() + ) rows = [ line.split("|")[1:] # Each row starts with a separator - for line in target_section.replace(" ", "").split("\n") - if f"|{model_size}|{device_name}|" in line + for line in target_section.split("\n") + if correct_line(line) ] if not rows: raise ValueError( - f"Could not find accuracy data for {model_size} on {device_name} in {optimizations.__name__} mode" + f"Could not find accuracy data for {base_model_name} on {device_name} in {optimizations.__name__} mode" ) assert ( len(rows) == 1 - ), f"Found multiple rows for {model_size} on {device_name} in {optimizations.__name__} mode in PERF.md" + ), f"Found multiple rows for {base_model_name} on {device_name} in {optimizations.__name__} mode in PERF.md" row = rows[0] top1_acc = float(row[2].strip()) top5_acc = float(row[3].strip()) @@ -60,11 +61,12 @@ def get_accuracy_thresholds(model_name: str, device_name: str, optimizations: Ll @pytest.mark.parametrize( "prefill_len, decode_len, max_seq_len", # Max seqlen should be at least prefill_len + decode_len ((512, 128, 1024),), + # ((131072-8192, 8192-1, 131072),), ) @pytest.mark.parametrize( "mesh_device", [ - {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + {"N150": (1, 1), "N300": (1, 2), "N150x4": (1, 4), "T3K": (1, 8), "TG": (8, 4)}.get( os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) ) ], @@ -130,7 +132,7 @@ def test_tt_model_acc( mesh_device, optimizations=optimizations, max_batch_size=batch_size, max_seq_len=max_seq_len ) - tokenizer = Tokenizer(model_args.tokenizer_path) + tokenizer = model_args.tokenizer # Load state_dict for TT model logger.info("Loading weights...") @@ -138,11 +140,10 @@ def test_tt_model_acc( logger.info("Finished loading weights...") # Load the reference data - model_size = model_args.model_name.split("-")[1].lower() # e.g., "1b", "3b", "8b", "70b" if use_reference_file: # Existing reference file loading logic - reference_data_file = f"models/demos/llama3/tests/reference_outputs/{model_size}.refpt" + reference_data_file = f"models/demos/llama3/tests/reference_outputs/{model_args.model_name}.refpt" logger.info(f"Loading reference data from {reference_data_file}") assert os.path.exists(reference_data_file) reference_data = torch.load(reference_data_file) @@ -201,7 +202,7 @@ def test_tt_model_acc( paged_attention_config=paged_attention_config, ) # Initialize embedding - embd = HostEmbedding(model_args) + embd = model_args.reference_embedding() state_dict_prefix = model_args.get_state_dict_prefix("", None) embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) @@ -230,8 +231,10 @@ def test_tt_model_acc( model_args.head_dim, model_args.max_seq_len, 
mesh_device, - seq_len=prefill_lens[0], - scale_factor=model_args.rope_scaling_factor, + prefill_lens[0], + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, ) prefill_input = model_args.prepare_residual_tensor_prefill( @@ -438,7 +441,7 @@ def test_tt_model_acc( # Get accuracy thresholds from PERF.md min_top1_acc, min_top5_acc = get_accuracy_thresholds( - model_args.model_name, + model_args.base_model_name, model_args.device_name, optimizations, ) diff --git a/models/demos/llama3/tests/test_llama_attention.py b/models/demos/llama3/tests/test_llama_attention.py index c0a077b465c..e942eb8a3f8 100644 --- a/models/demos/llama3/tests/test_llama_attention.py +++ b/models/demos/llama3/tests/test_llama_attention.py @@ -13,7 +13,6 @@ precompute_freqs, PagedAttentionConfig, ) -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Attention from models.utility_functions import ( comp_pcc, comp_allclose, @@ -71,7 +70,7 @@ def test_llama_attention_inference( mesh_device.enable_async(True) model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) - model_args.n_layers = 1 # For the unit test, just run a sigle layer + model_args.n_layers = 1 # For the unit test, just run a single layer state_dict = model_args.load_state_dict() @@ -81,7 +80,7 @@ def test_llama_attention_inference( k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - reference_model = Attention(args=model_args) + reference_model = model_args.reference_attention() reference_model.load_state_dict(partial_state_dict) seq_len = 1 @@ -97,8 +96,8 @@ def test_llama_attention_inference( model_args.head_dim, model_args.max_seq_len, model_args.rope_theta, - model_args.use_scaled_rope, model_args.rope_scaling_factor, + model_args.orig_context_len, ) transformation_mats = rope_setup.get_both_trans_mats() @@ -146,8 +145,8 @@ def test_llama_attention_inference( model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, - model_args.use_scaled_rope, model_args.rope_scaling_factor, + model_args.orig_context_len, ) freqs_cis = torch.complex(cos, sin) @@ -166,7 +165,7 @@ def test_llama_attention_inference( for i in range(generation_length): # 70B attention block typically sees tensors with mean 0 and std 0.03 - 0.05 in layer 1 - pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) * 0.05 + pt_attention_input = torch.randn(batch_size, seq_len, model_args.dim) # Qwen2.5 0.5B sees 0.1 to 2.1 tt_attention_input = pt_attention_input.clone() @@ -209,7 +208,7 @@ def test_llama_attention_inference( all_tests_pass = False # Increment position - current_pos = torch.tensor([generation_start_pos + i for _ in range(batch_size)]) + current_pos = torch.tensor([generation_start_pos + i + 1 for _ in range(batch_size)]) current_pos_tensor = ttnn.from_torch( current_pos, device=mesh_device, @@ -266,21 +265,16 @@ def test_llama_attention_inference( )[:batch_size, :, :, :] for cache in tt_model.layer_past ] - - for i, (cache_pt, cache_tt) in enumerate(zip(pytorch_layer_present, tt_layer_present)): - cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + generation_length + 1) + for label, cache_pt, cache_tt in zip(["K", "V"], pytorch_layer_present, tt_layer_present): + cache_length_to_check = min(model_args.max_seq_len, generation_start_pos + i + 1) cache_pt = cache_pt[:, :, generation_start_pos:cache_length_to_check, :] cache_tt = cache_tt[:, :, 
generation_start_pos:cache_length_to_check, :] does_pass, output_pcc = comp_pcc(cache_pt, cache_tt, pcc) - if i == 0: - logger.info(f"K cache output: {output_pcc}") - else: - logger.info(f"V cache output: {output_pcc}") - + logger.info(f"{label} cache output: {output_pcc}") if does_pass: - logger.info(f"KV Cache Passed!") + logger.info(f"{label} cache Passed!") else: - logger.warning(f"KV Cache Failed! PCC value is lower than {pcc}") + logger.warning(f"{label} Cache Failed! PCC value is lower than {pcc}") all_tests_pass = False if all_tests_pass: diff --git a/models/demos/llama3/tests/test_llama_attention_prefill.py b/models/demos/llama3/tests/test_llama_attention_prefill.py index b8496e652a2..bf1db31f622 100644 --- a/models/demos/llama3/tests/test_llama_attention_prefill.py +++ b/models/demos/llama3/tests/test_llama_attention_prefill.py @@ -13,7 +13,7 @@ get_rot_transformation_mat, PagedAttentionConfig, ) -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Attention, precompute_freqs_cis +from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import precompute_freqs_cis from models.utility_functions import ( comp_pcc, comp_allclose, @@ -51,7 +51,7 @@ @pytest.mark.parametrize( "max_seq_len", ( - 2048, + 256, # 4096, # 1024 * 32, # 1024 * 64, ), @@ -80,7 +80,7 @@ def test_llama_attention_inference( partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - reference_model = Attention(args=model_args) + reference_model = model_args.reference_attention() reference_model.load_state_dict(partial_state_dict) # pre-compute the rotational embedding matrix and send to device @@ -88,10 +88,13 @@ def test_llama_attention_inference( model_args.head_dim, model_args.max_seq_len, mesh_device, - seq_len=max_seq_len, - scale_factor=model_args.rope_scaling_factor, + max_seq_len, + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, ) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) + transformation_mats_prefill = ttnn.as_tensor( transformation_mat_torch, dtype=ttnn.bfloat16, @@ -165,7 +168,6 @@ def test_llama_attention_inference( model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, - model_args.use_scaled_rope, model_args.rope_scaling_factor, )[positions] attn_mask = torch.full((max_seq_len, max_seq_len), torch.finfo(torch.float32).min) diff --git a/models/demos/llama3/tests/test_llama_decoder.py b/models/demos/llama3/tests/test_llama_decoder.py index c74a4aa3dbc..df7562461c4 100644 --- a/models/demos/llama3/tests/test_llama_decoder.py +++ b/models/demos/llama3/tests/test_llama_decoder.py @@ -13,7 +13,6 @@ from models.demos.llama3.tt.model_config import TtModelArgs from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import TransformerBlock from models.utility_functions import ( comp_pcc, comp_allclose, @@ -78,7 +77,7 @@ def test_llama_decoder_inference( partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - reference_model = TransformerBlock(layer_id=0, args=model_args) + reference_model = model_args.reference_decoder() reference_model.load_state_dict(partial_state_dict) generation_start_pos = 0 @@ -92,8 +91,8 @@ def test_llama_decoder_inference( model_args.head_dim, model_args.max_seq_len, 
model_args.rope_theta, - model_args.use_scaled_rope, model_args.rope_scaling_factor, + model_args.orig_context_len, ) transformation_mats = rope_setup.get_both_trans_mats() @@ -143,8 +142,8 @@ def test_llama_decoder_inference( model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, - model_args.use_scaled_rope, model_args.rope_scaling_factor, + model_args.orig_context_len, ) freqs_cis = torch.complex(cos, sin) diff --git a/models/demos/llama3/tests/test_llama_decoder_prefill.py b/models/demos/llama3/tests/test_llama_decoder_prefill.py index 85f767b3301..53cbf81cb03 100644 --- a/models/demos/llama3/tests/test_llama_decoder_prefill.py +++ b/models/demos/llama3/tests/test_llama_decoder_prefill.py @@ -13,7 +13,7 @@ ) from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import TransformerBlock, precompute_freqs_cis +from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import precompute_freqs_cis from models.utility_functions import ( comp_pcc, comp_allclose, @@ -79,7 +79,7 @@ def test_llama_decoder_inference( k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - reference_model = TransformerBlock(layer_id=0, args=model_args) + reference_model = model_args.reference_decoder() reference_model.load_state_dict(partial_state_dict) generation_start_pos = 0 @@ -91,8 +91,10 @@ def test_llama_decoder_inference( model_args.head_dim, model_args.max_seq_len, mesh_device, - seq_len=max_seq_len, - scale_factor=model_args.rope_scaling_factor, + max_seq_len, + model_args.rope_theta, + model_args.rope_scaling_factor, + model_args.orig_context_len, ) transformation_mat_torch = get_rot_transformation_mat(model_args.head_dim) transformation_mats_prefill = ttnn.as_tensor( @@ -153,7 +155,6 @@ def test_llama_decoder_inference( model_args.head_dim, model_args.max_seq_len * 2, model_args.rope_theta, - model_args.use_scaled_rope, model_args.rope_scaling_factor, )[positions] diff --git a/models/demos/llama3/tests/test_llama_embedding.py b/models/demos/llama3/tests/test_llama_embedding.py index 9c42a859a94..71d56a3a7f4 100644 --- a/models/demos/llama3/tests/test_llama_embedding.py +++ b/models/demos/llama3/tests/test_llama_embedding.py @@ -8,13 +8,11 @@ import ttnn from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.utility_functions import ( comp_pcc, comp_allclose, ) from models.utility_functions import skip_for_grayskull -from models.demos.llama3.tt.llama_common import HostEmbedding @torch.no_grad() @@ -44,9 +42,9 @@ def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache model_args.n_layers = 1 state_dict = model_args.load_state_dict() - tokenizer = Tokenizer(model_args.tokenizer_path) + tokenizer = model_args.tokenizer - reference_emb = HostEmbedding(model_args) + reference_emb = model_args.reference_embedding() if model_args.is_vision(): layer_name = "text_model.tok_embeddings.weight" else: @@ -62,7 +60,7 @@ def test_llama_embedding(max_seq_len, batch_size, mesh_device, use_program_cache ) prompts = ["Joy"] * 32 - pt_input = torch.tensor([tokenizer.encode(prompt, bos=False, eos=False) for prompt in prompts]) + pt_input = torch.tensor([model_args.encode_prompt(prompt, instruct=False) for prompt 
in prompts]) reference_output = reference_emb(pt_input) logger.info(f"reference_output: {reference_output.shape}") diff --git a/models/demos/llama3/tests/test_llama_mlp.py b/models/demos/llama3/tests/test_llama_mlp.py index 7d785a554b7..710ee9498c5 100644 --- a/models/demos/llama3/tests/test_llama_mlp.py +++ b/models/demos/llama3/tests/test_llama_mlp.py @@ -9,7 +9,6 @@ import ttnn from models.demos.llama3.tt.llama_mlp import TtLlamaMLP from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward from models.utility_functions import ( comp_pcc, comp_allclose, @@ -57,12 +56,7 @@ def test_llama_mlp_inference(seq_len, batch_size, mesh_device, use_program_cache } model_args.WEIGHTS_DTYPE = dtype - reference_model = FeedForward( - dim=model_args.dim, - hidden_dim=4 * model_args.dim, - multiple_of=model_args.multiple_of, - ffn_dim_multiplier=model_args.ffn_dim_multiplier, - ) + reference_model = model_args.reference_mlp() reference_model.load_state_dict(partial_state_dict) tt_model = TtLlamaMLP( @@ -84,12 +78,14 @@ def test_llama_mlp_inference(seq_len, batch_size, mesh_device, use_program_cache ), # When both dims are None, the mapper used is `ReplicateTensorToMesh` dtype=ttnn.bfloat8_b, memory_config=( - tt_model.model_config["MLP_ACT_MEMCFG"] - if model_args.is_galaxy - else model_args.model_config["SHARDED_MLP_INPUT_MEMCFG"] - ) - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG, + ( + tt_model.model_config["MLP_ACT_MEMCFG"] + if model_args.is_galaxy + else model_args.model_config["SHARDED_MLP_INPUT_MEMCFG"] + ) + if mode == "decode" + else ttnn.DRAM_MEMORY_CONFIG + ), layout=ttnn.TILE_LAYOUT, ) diff --git a/models/demos/llama3/tests/test_llama_model.py b/models/demos/llama3/tests/test_llama_model.py index a41645f3394..fefda03034f 100644 --- a/models/demos/llama3/tests/test_llama_model.py +++ b/models/demos/llama3/tests/test_llama_model.py @@ -8,14 +8,10 @@ import ttnn from models.demos.llama3.tt.llama_common import ( sample_host, - encode_prompt_llama_instruct, - HostEmbedding, PagedAttentionConfig, ) from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations from models.demos.llama3.tt.llama_model import TtTransformer -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.utility_functions import ( comp_pcc, comp_allclose, @@ -92,7 +88,7 @@ def test_llama_model_inference( dtype = ttnn.bfloat8_b mesh_device.enable_async(True) mode_accuracy = optimizations == LlamaOptimizations.accuracy - instruct = True if weights == "instruct" else False + instruct = False # True if weights == "instruct" else False dummy_weights = True if weights == "random" else False model_args = TtModelArgs( mesh_device, @@ -103,49 +99,52 @@ def test_llama_model_inference( max_batch_size=batch_size, ) - model_name = { - (16, False): "llama32_1b", - (28, False): "llama32_3b", - (32, False): "llama31_8b", - (32, True): "llama32_11b", - (80, False): "llama31_70b", - }[(model_args.n_layers, model_args.is_vision())] - # Define minimum PCC for each iteration if layers == 1: pcc = 0.88 if mode_accuracy else 0.86 else: pcc = 0.94 if mode_accuracy else 0.86 - # Define tight final PCC thresholds for quick mode - final_model_pcc = { - "llama32_1b": 0.9990 if mode_accuracy else 0.9864, - "llama32_3b": 0.9989 if mode_accuracy else 0.9837, - "llama31_8b": 0.9987 if mode_accuracy else 0.9850, - 
"llama32_11b": 0.9987 if mode_accuracy else 0.9850, - "llama31_70b": 0.9419 if mode_accuracy else 0.9419, - }[model_name] - - final_k_cache_pcc = { - "llama32_1b": 0.9998, - "llama32_3b": 0.9998, - "llama31_8b": 0.9997, - "llama32_11b": 0.9995, - "llama31_70b": 0.9997, - }[model_name] - final_v_cache_pcc = { - "llama32_1b": 0.9996, - "llama32_3b": 0.9998, - "llama31_8b": 0.9997, - "llama32_11b": 0.9996, - "llama31_70b": 0.9997, - }[model_name] - - quick_iterations = {"llama32_1b": 2, "llama32_3b": 4, "llama31_8b": 6, "llama32_11b": 6, "llama31_70b": 6}[ - model_name - ] - - iterations = quick_iterations if layers == 1 else 9 + if layers == 1: # quick mode has tight PCC checks for known models + model_name = { + (16, False): "llama32_1b", + (28, False): "llama32_3b", + (32, False): "llama31_8b", + (32, True): "llama32_11b", + (80, False): "llama31_70b", + }[(model_args.n_layers, model_args.is_vision())] + + # Define tight final PCC thresholds for quick mode + final_model_pcc = { + "llama32_1b": 0.9991 if mode_accuracy else 0.9864, + "llama32_3b": 0.9989 if mode_accuracy else 0.9837, + "llama31_8b": 0.9987 if mode_accuracy else 0.9850, + "llama32_11b": 0.9987 if mode_accuracy else 0.9850, + "llama31_70b": 0.9843 if mode_accuracy else 0.97607, + }[model_name] + + final_k_cache_pcc = { + "llama32_1b": 0.9998, + "llama32_3b": 0.9998, + "llama31_8b": 0.9997, + "llama32_11b": 0.9995, + "llama31_70b": 0.9997, + }[model_name] + final_v_cache_pcc = { + "llama32_1b": 0.9996, + "llama32_3b": 0.9998, + "llama31_8b": 0.9997, + "llama32_11b": 0.9996, + "llama31_70b": 0.9997, + }[model_name] + + quick_iterations = {"llama32_1b": 2, "llama32_3b": 4, "llama31_8b": 6, "llama32_11b": 6, "llama31_70b": 6}[ + model_name + ] + + iterations = quick_iterations + else: + iterations = 9 if layers is not None: model_args.n_layers = layers @@ -172,18 +171,18 @@ def test_llama_model_inference( ] * model_args.max_batch_size # "This is a test" encoded prompt assert not instruct, "Instruct prompt not implemented with dummy weights" else: - tokenizer = Tokenizer(model_args.tokenizer_path) + tokenizer = model_args.tokenizer if instruct: - encoded_prompts = [encode_prompt_llama_instruct(tokenizer, prompt) for prompt in prompts] + encoded_prompts = [model_args.encode_prompt(prompt) for prompt in prompts] else: - encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts] + encoded_prompts = [model_args.encode_prompt(prompt, instruct=False) for prompt in prompts] if run_ref_pt: - reference_model = Transformer(model_args) + reference_model = model_args.reference_transformer() reference_model.load_state_dict(reference_state_dict) # Embedding on host - embd = HostEmbedding(model_args) + embd = model_args.reference_embedding() embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) generation_start_pos = 0 @@ -320,15 +319,21 @@ def test_llama_model_inference( pt_decode_input = embd(encoded_prompts_tensor[:, i]).view(batch, seqlen, -1) else: # Greedy decode (temperature = 0) the generated token and save it to print out later - tt_out_tok = sample_host(tt_output_torch, None, temperature=0, top_p=0.8) - tt_decode_input = embd(tt_out_tok) - all_outputs.append(tt_out_tok.squeeze(1).tolist()[0]) # Update generated token to list of TT outputs if run_ref_pt: + # Sample from reference model first pt_out_tok = sample_host(ref_output, None, temperature=0, top_p=0.8) pt_decode_input = embd(pt_out_tok) - all_outputs_ref.append( - pt_out_tok.squeeze(1).tolist()[0] - ) # Update 
generated token to list of ref outputs + all_outputs_ref.append(pt_out_tok.squeeze(1).tolist()[0]) + + # Use the same token for TT model (teacher forcing) + tt_decode_input = pt_decode_input + all_outputs.append(pt_out_tok.squeeze(1).tolist()[0]) + else: + # If not running reference model, sample from TT model directly + tt_out_tok = sample_host(tt_output_torch, None, temperature=0, top_p=0.8) + tt_decode_input = embd(tt_out_tok) + all_outputs.append(tt_out_tok.squeeze(1).tolist()[0]) + # Measure PCC if also running reference model if run_ref_pt: if layers == 1 and i == iterations - 1: # On last iteration in the quick test, set a tighter PCC @@ -432,6 +437,7 @@ def test_llama_model_inference( logger.info(f"All {generation_length} Llama decode iterations Passed!") else: logger.warning("One or more iterations of Llama decode had bad PCC") - assert final_tests_pass, f"PCC value is lower than {final_model_pcc} for final output. Check Warnings!" + if layers == 1: + assert final_tests_pass, f"PCC value is lower than {final_model_pcc} for final output. Check Warnings!" assert kv_cache_tests_pass, f"KV Cache PCC value is lower expected for some of the outputs. Check Warnings!" assert all_tests_pass, f"PCC value is lower than {pcc} for some of the outputs. Check Warnings!" diff --git a/models/demos/llama3/tests/test_llama_model_prefill.py b/models/demos/llama3/tests/test_llama_model_prefill.py index 91e45e8bc98..fb16414e979 100644 --- a/models/demos/llama3/tests/test_llama_model_prefill.py +++ b/models/demos/llama3/tests/test_llama_model_prefill.py @@ -9,15 +9,10 @@ import ttnn from models.demos.llama3.tt.llama_common import ( get_prefill_rot_mat, - get_rot_transformation_mat, - HostEmbedding, - encode_prompt_llama_instruct, PagedAttentionConfig, ) from models.demos.llama3.tt.llama_model import TtTransformer from models.demos.llama3.tt.model_config import TtModelArgs, LlamaOptimizations -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from models.utility_functions import ( comp_pcc, comp_allclose, @@ -98,7 +93,7 @@ def test_llama_model_inference( instruct = True model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, optimizations=optimizations, max_seq_len=seq_len) - tokenizer = Tokenizer(model_args.tokenizer_path) + tokenizer = model_args.tokenizer logger.info("Loading weights...") state_dict_prefix = model_args.get_state_dict_prefix("", None) @@ -125,16 +120,14 @@ def test_llama_model_inference( with bz2.open(prompt_file, "rt", encoding="utf-8") as f: prompt = f.read() - if instruct: - encoded_prompt = encode_prompt_llama_instruct(tokenizer, prompt)[:seq_len] - else: - encoded_prompt = tokenizer.encode(prompt, bos=True, eos=False)[:seq_len] + encoded_prompt = model_args.encode_prompt(prompt, instruct=instruct)[:seq_len] if run_ref_pt: - reference_model = Transformer(model_args) + reference_model = model_args.reference_transformer() reference_model.load_state_dict(reference_state_dict) + # Embedding on host - embd = HostEmbedding(model_args) + embd = model_args.reference_embedding() embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) # pre-compute the rotational embedding matrix and send to device @@ -142,8 +135,10 @@ def test_llama_model_inference( model_args.head_dim, model_args.max_seq_len, mesh_device, - seq_len=seq_len, - scale_factor=model_args.rope_scaling_factor, + seq_len, + model_args.rope_theta, + 
model_args.rope_scaling_factor, + model_args.orig_context_len, ) # Setup page table page_table_tt = None diff --git a/models/demos/llama3/tests/test_llama_rms_norm.py b/models/demos/llama3/tests/test_llama_rms_norm.py index 5fdc99ee14d..4493b8b4518 100644 --- a/models/demos/llama3/tests/test_llama_rms_norm.py +++ b/models/demos/llama3/tests/test_llama_rms_norm.py @@ -8,7 +8,6 @@ import ttnn from models.common.rmsnorm import RMSNorm as TtRMSNorm from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm as RefRMSNorm from models.utility_functions import ( comp_pcc, comp_allclose, @@ -77,7 +76,7 @@ def test_llama_rms_norm_inference( partial_state_dict = { k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) } - reference_model = RefRMSNorm(dim=model_args.dim, eps=model_args.norm_eps) + reference_model = model_args.reference_rms_norm() reference_model.load_state_dict(partial_state_dict) input = torch.rand(1, 1, 32, model_args.dim) @@ -90,9 +89,9 @@ def test_llama_rms_norm_inference( dtype=dtype, layout=ttnn.TILE_LAYOUT, mesh_mapper=ttnn.ShardTensor2dMesh(mesh_device, dims=(None, -1), mesh_shape=model_args.cluster_shape), - memory_config=model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG, + memory_config=( + model_args.get_model_config()["DECODE_RESIDUAL_MEMCFG"] if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG + ), ) tt_output = tt_model(tt_input, mode=mode) diff --git a/models/demos/llama3/tests/test_llama_torch.py b/models/demos/llama3/tests/test_llama_torch.py index 90713eb01ab..3ff878c5ec0 100644 --- a/models/demos/llama3/tests/test_llama_torch.py +++ b/models/demos/llama3/tests/test_llama_torch.py @@ -4,10 +4,7 @@ import torch # import ttnn -from models.demos.llama3.tt.llama_common import HostEmbedding from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer from loguru import logger @@ -18,16 +15,16 @@ def test_llama_torch_inference(ensure_gc): model_args = TtModelArgs(mesh_device=None) state_dict = model_args.load_state_dict() - tokenizer = Tokenizer(model_args.tokenizer_path) + tokenizer = model_args.tokenizer prompts = ["1 2 3 4 "] * model_args.max_batch_size - encoded_prompts = [tokenizer.encode(prompt, bos=True, eos=False) for prompt in prompts] + encoded_prompts = [model_args.encode_prompt(prompt, instruct=False) for prompt in prompts] - reference_model = Transformer(model_args) + reference_model = model_args.reference_transformer() reference_model.load_state_dict(state_dict) # Embedding on host - embd = HostEmbedding(model_args) + embd = model_args.reference_embedding() state_dict_prefix = model_args.get_state_dict_prefix("", None) embd.load_state_dict({"emb.weight": state_dict[f"{state_dict_prefix}tok_embeddings.weight"]}) @@ -66,4 +63,4 @@ def test_llama_torch_inference(ensure_gc): all_outputs_ref.append(pt_out_tok.squeeze(1).tolist()[0]) # Update generated token to list of ref outputs # TODO print all 32 users - logger.info("[User 0] Ref generation: ", "".join(tokenizer.decode(all_outputs_ref))) + logger.info("[User 0] Ref generation: '" + "".join(tokenizer.decode(all_outputs_ref)) + "'") diff --git a/models/demos/llama3/tests/test_lm_head.py b/models/demos/llama3/tests/test_lm_head.py index 
b3b422b36dc..ea42d7c4eb4 100644 --- a/models/demos/llama3/tests/test_lm_head.py +++ b/models/demos/llama3/tests/test_lm_head.py @@ -9,7 +9,6 @@ import ttnn from models.demos.llama3.tt.lm_head import LMHead from models.demos.llama3.tt.model_config import TtModelArgs -from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import ColumnParallelLinear from models.utility_functions import ( comp_pcc, comp_allclose, @@ -52,7 +51,7 @@ def test_llama_lm_head_inference(seq_len, batch_size, mesh_device, use_program_c } model_args.WEIGHTS_DTYPE = dtype - reference_model = ColumnParallelLinear(model_args.dim, model_args.vocab_size, bias=False, init_method=lambda x: x) + reference_model = model_args.reference_lm_head() reference_model.load_state_dict(partial_state_dict) tt_model = LMHead( diff --git a/models/demos/llama3/tests/test_ref.py b/models/demos/llama3/tests/test_ref.py new file mode 100644 index 00000000000..d3ad5ba20bf --- /dev/null +++ b/models/demos/llama3/tests/test_ref.py @@ -0,0 +1,104 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 +import torch +import pytest +from loguru import logger +import os +import ttnn +from models.demos.llama3.tt.llama_attention import TtLlamaAttention +from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup +from models.demos.llama3.tt.model_config import TtModelArgs +from models.demos.llama3.tt.llama_common import ( + precompute_freqs, + PagedAttentionConfig, +) +from models.utility_functions import ( + comp_pcc, + comp_allclose, +) +from models.utility_functions import skip_for_grayskull +from models.demos.llama3.tt.load_checkpoints import convert_meta_to_hf, convert_hf_to_meta, map_hf_to_meta_keys + + +@torch.no_grad() +@skip_for_grayskull("Requires wormhole_b0 to run") +@pytest.mark.parametrize( + "mesh_device", + [ + {"N150": (1, 1), "N300": (1, 2), "T3K": (1, 8), "TG": (8, 4)}.get( + os.environ.get("FAKE_DEVICE"), len(ttnn.get_device_ids()) + ) + ], + indirect=True, +) +@pytest.mark.parametrize( + "paged_attention", + ( + # True, + False, + ), + ids=( + # "paged_attention", + "default_attention", + ), +) +@pytest.mark.parametrize( + "page_params", + [{"page_block_size": 32, "page_max_num_blocks": 1024}], +) +@pytest.mark.parametrize( + "batch_size", + (1,), +) +@pytest.mark.parametrize( + "max_seq_len", + (128,), # For decode-only unit test, there's no need to run with large sequence lengths +) +def test_llama_attention_inference( + max_seq_len, + batch_size, + paged_attention, + page_params, + mesh_device, + use_program_cache, + reset_seeds, + ensure_gc, +): + dtype = ttnn.bfloat8_b + pcc = 0.99 + + mesh_device.enable_async(True) + + model_args = TtModelArgs(mesh_device, max_batch_size=batch_size, max_seq_len=max_seq_len) + model_args.n_layers = 1 # For the unit test, just run a single layer + + state_dict = model_args.load_state_dict() + + first_layer_prefix = model_args.get_state_dict_prefix("TtLlamaAttention", 0) + "." 
+ # Ref model needs partial state dict, but our models use full state dict keys as cached weight names + partial_state_dict = { + k[len(first_layer_prefix) :]: v for k, v in state_dict.items() if (k.startswith(first_layer_prefix)) + } + + ref_model = model_args.reference_attention() + ref_model.load_state_dict(partial_state_dict) + + from transformers import AutoModelForCausalLM + + hf_transformer = AutoModelForCausalLM.from_pretrained(model_args.DEFAULT_CKPT_DIR) + hf_model = hf_transformer.model.layers[0].self_attn + hf_model.eval() + + # Get the state dicts + ref_state_dict = ref_model.attention.state_dict() # should contain hf keys and weights + hf_state_dict = hf_model.state_dict() + + for key in ["k_proj", "q_proj"]: + for suffix in ["weight", "bias"]: + print( + f"{key}.{suffix}: ref matches hf : {torch.allclose(ref_state_dict[key + '.' + suffix], hf_state_dict[key + '.' + suffix])}" + ) + + print(" ".join(f"{x:+3.1f}" for x in ref_state_dict["k_proj.bias"])) + print(" ".join(f"{x:+3.1f}" for x in hf_state_dict["k_proj.bias"])) diff --git a/models/demos/llama3/tt/generator_vllm.py b/models/demos/llama3/tt/generator_vllm.py index 846e0cef34f..06a9b1e37ea 100644 --- a/models/demos/llama3/tt/generator_vllm.py +++ b/models/demos/llama3/tt/generator_vllm.py @@ -32,7 +32,7 @@ def initialize_vllm_text_transformer( # Load model args, weights model_args = TtModelArgs( mesh_device, - instruct=("Instruct" in hf_config._name_or_path), + instruct=("Instruct" in hf_config._name_or_path or "DeepSeek-R1-Distill-Llama-70B" in hf_config._name_or_path), max_batch_size=max_batch_size, optimizations=optimizations, max_seq_len=max_seq_len, diff --git a/models/demos/llama3/tt/llama_attention.py b/models/demos/llama3/tt/llama_attention.py index 322e2edf2d2..a2c5490fef8 100644 --- a/models/demos/llama3/tt/llama_attention.py +++ b/models/demos/llama3/tt/llama_attention.py @@ -2,12 +2,14 @@ # SPDX-License-Identifier: Apache-2.0 -from typing import List, Optional +import math import torch import ttnn from models.common.lightweightmodule import LightweightModule from models.demos.llama3.tt.llama_ccl import tt_all_reduce, tt_all_gather +from models.demos.llama3.tt.llama_common import first_five +from models.demos.llama3.tt.load_checkpoints import permute class TtLlamaAttention(LightweightModule): @@ -41,7 +43,7 @@ def __init__( self.num_reduce_scatter_links = configuration.num_reduce_scatter_links self.num_all_gather_links = configuration.num_all_gather_links self.MAX_QKV_MM_SEQ_LEN = configuration.MAX_QKV_MM_SEQ_LEN - + self.tile_size = configuration.tile_size self.num_device_groups = self.num_devices // self.n_kv_heads self.num_devices_per_group = self.n_kv_heads if self.TG else self.num_devices self.batch_size_per_device_group = ( @@ -99,10 +101,65 @@ def __init__( else: cache_name = lambda name: weight_cache_path / (f"{layer_name}.{name}") - wq_str = f"{layer_name}.wq.weight" - wk_str = f"{layer_name}.wk.weight" - wv_str = f"{layer_name}.wv.weight" - wo_str = f"{layer_name}.wo.weight" + wq_str = f"{layer_name}.wq" + wk_str = f"{layer_name}.wk" + wv_str = f"{layer_name}.wv" + wo_str = f"{layer_name}.wo" + + # Initialize bias tensors as None + self.wqkv_bias_decode = None + self.wqkv_bias_prefill = None + + # Create combined QKV bias if present in state dict + if f"{wq_str}.bias" in self.state_dict: + qkv_bias = torch.concat( + [ + torch.concat( + [ + torch.chunk(self.state_dict[f"{wq_str}.bias"], configuration.num_devices)[i], + torch.chunk(self.state_dict[f"{wk_str}.bias"], configuration.num_devices)[i], + 
torch.chunk(self.state_dict[f"{wv_str}.bias"], configuration.num_devices)[i], + ], + dim=-1, + ) + for i in range(configuration.num_devices) + ], + dim=-1, + ) + # Prefill can use broadcasting on the bias add so wants a 1d tensor + self.wqkv_bias_prefill = ttnn.as_tensor( + qkv_bias, + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name("wqkv_bias_prefill_sharded"), + ) + # as_tensor returns (32, dim) which is incorrect, this reshape updates the padded size to the correct size + self.wqkv_bias_prefill = ttnn.reshape( + self.wqkv_bias_prefill, ttnn.Shape([1, 1, 1, self.wqkv_bias_prefill.shape[-1]]) + ) + + # Broadcasting does not seem to be supported inside execute_trace so expand to the whole batch size + # Create a list of bias tensors for each multiple of tile_size up to max_batch_size + self.wqkv_bias_decode = [] + for batch_size in range( + configuration.tile_size, + configuration.tile_padded_batch_rows + configuration.tile_size, + configuration.tile_size, + ): + qkv_bias_decode = qkv_bias.unsqueeze(0).expand(batch_size, -1) + bias_tensor = ttnn.as_tensor( + qkv_bias_decode, + device=self.mesh_device, + mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=-1), + dtype=self.dtype, + memory_config=ttnn.DRAM_MEMORY_CONFIG, + layout=ttnn.TILE_LAYOUT, + cache_file_name=cache_name(f"wqkv_bias_decode_sharded_{batch_size}"), + ) + self.wqkv_bias_decode.append(bias_tensor) # when splitting the devices, we need to make sure that the number of heads is divisible by the number of devices assert self.n_heads % self.num_devices_per_group == 0 @@ -118,9 +175,9 @@ def __init__( qkv_list = [] for i in range(self.num_devices_per_group): # Chunk weights - wq_selected = torch.chunk(self.state_dict[wq_str], self.num_devices_per_group, dim=0)[i] - wk_selected = torch.chunk(self.state_dict[wk_str], self.num_devices_per_group, dim=0)[i] - wv_selected = torch.chunk(self.state_dict[wv_str], self.num_devices_per_group, dim=0)[i] + wq_selected = torch.chunk(self.state_dict[f"{wq_str}.weight"], self.num_devices_per_group, dim=0)[i] + wk_selected = torch.chunk(self.state_dict[f"{wk_str}.weight"], self.num_devices_per_group, dim=0)[i] + wv_selected = torch.chunk(self.state_dict[f"{wv_str}.weight"], self.num_devices_per_group, dim=0)[i] # Transpose the selected chunks wq = torch.transpose(wq_selected, -2, -1) @@ -146,7 +203,7 @@ def __init__( # For ring topology we can use all gather matmul for wo self.use_fused_all_gather_matmul = self.model_config["USE_FUSED_ALL_GATHER_MATMUL"] - pt_wo = self.state_dict[wo_str].transpose(-1, -2).unsqueeze(0).unsqueeze(0) + pt_wo = self.state_dict[f"{wo_str}.weight"].transpose(-1, -2).unsqueeze(0).unsqueeze(0) wo_mem_config = configuration.create_dram_sharded_mem_config( configuration.dim // configuration.num_devices, configuration.dim @@ -163,9 +220,9 @@ def __init__( dims=(2, 3) if (self.use_fused_all_gather_matmul or self.TG) else (3, 2), mesh_shape=configuration.cluster_shape, ), - cache_file_name=cache_name("wo_width_sharded_2d") - if (self.use_fused_all_gather_matmul or self.TG) - else cache_name("wo"), + cache_file_name=( + cache_name("wo_width_sharded_2d") if (self.use_fused_all_gather_matmul or self.TG) else cache_name("wo") + ), ) if not use_paged_kv_cache: # vLLM provides its own kv cache @@ -221,9 +278,11 @@ def init_kv_cache(self, configuration, weight_cache_path): device=self.mesh_device, 
memory_config=ttnn.DRAM_MEMORY_CONFIG, mesh_mapper=ttnn.ReplicateTensorToMesh(self.mesh_device), - cache_file_name=f"{weight_cache_path}/kvcache_{k_or_v.shape}" - if weight_cache_path and not configuration.dummy_weights - else None, + cache_file_name=( + f"{weight_cache_path}/kvcache_{k_or_v.shape}" + if weight_cache_path and not configuration.dummy_weights + else None + ), ) for k_or_v in [cache_k, cache_v] ] @@ -245,14 +304,28 @@ def forward_decode( # QKV matmuls # Use HiFi2 for DRAM-sharded matmuls as they are otherwise flop-bound. Loses 1 bit of activation precision. ### + + as_torch = lambda tensor: torch.Tensor( + ttnn.to_torch(tensor, mesh_composer=ttnn.ConcatMeshToTensor(self.mesh_device, dim=-1)) + ) + + # print(f"our x:", " ".join(f'{t:+3.1f}' for t in as_torch(x)[0, 0, 0].flatten())) xqkv_fused_sharded = ttnn.linear( x, self.wqkv, + # bias=self.wqkv_bias, memory_config=ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG, program_config=self.model_config["XQKV_DECODE_PROGCFG"], compute_kernel_config=self.compute_kernel_config_hifi2, dtype=self.ccl_dtype if self.TG else ttnn.bfloat16, ) + # FIXME: File bug against dram-sharded matmuls with bias + if self.wqkv_bias_decode: + # select the bias tensor based on the number of tiles in the rows + # WARNING: must not change the batch size between compiling and executing a trace + num_tiles = int(math.ceil(xqkv_fused_sharded.shape[-2] / self.tile_size)) + xqkv_fused_sharded = xqkv_fused_sharded + self.wqkv_bias_decode[num_tiles - 1] + ttnn.deallocate(x) xqkv_fused = tt_all_reduce( xqkv_fused_sharded, @@ -263,6 +336,7 @@ def forward_decode( memory_config=self.model_config["QKV_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[1]), sharded=True, dtype=self.ccl_dtype, + topology=self.ccl_topology, ) if self.TG: @@ -437,13 +511,16 @@ def forward_decode( num_reduce_scatter_links=self.num_reduce_scatter_links, num_all_gather_links=self.num_all_gather_links, dim=0 if (self.TG and self.hidden_size < 8192) else 3, + topology=self.ccl_topology, memory_config=( - self.model_config["SELF_OUT_REDUCE_SCATTER_MEMCFG"] - if self.hidden_size == 8192 - else self.model_config["SELF_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[0]) - ) - if self.TG - else self.model_config["DECODE_RESIDUAL_MEMCFG"], + ( + self.model_config["SELF_OUT_REDUCE_SCATTER_MEMCFG"] + if self.hidden_size == 8192 + else self.model_config["SELF_OUT_GATHERED_MEMCFG"](list(self.mesh_device.shape)[0]) + ) + if self.TG + else self.model_config["DECODE_RESIDUAL_MEMCFG"] + ), sharded=True, dtype=self.ccl_dtype, use_composite=True if self.hidden_size == 8192 else False, @@ -481,12 +558,17 @@ def forward_prefill( xqkv_fused = ttnn.linear( x_11SH, self.wqkv, + # bias=self.wqkv_bias_prefill, dtype=self.ccl_dtype if self.TG else ttnn.bfloat16, memory_config=ttnn.DRAM_MEMORY_CONFIG, compute_kernel_config=self.compute_kernel_config_hifi2, program_config=self.model_config["XQKV_PREFILL_PROGCFG"](seq_len), ) + # FIXME: surely ttnn.linear bias should work? 
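The decode path above cannot rely on broadcasting inside an executed trace, so it pre-expands one bias tensor per tile-multiple of rows and indexes it by tile count at run time. Below is a plain-PyTorch sketch of that selection logic; tile_size matches the device tile height, and the other shapes are illustrative.

```python
import math
import torch

tile_size = 32
tile_padded_batch_rows = 96        # illustrative: max_batch_size rounded up to a tile multiple
qkv_bias = torch.randn(192)        # illustrative fused Q/K/V bias width

# One pre-expanded bias per multiple of tile_size, mirroring wqkv_bias_decode above
bias_per_rows = [
    qkv_bias.unsqueeze(0).expand(rows, -1)
    for rows in range(tile_size, tile_padded_batch_rows + tile_size, tile_size)
]

activation_rows = 40                                   # rows of the fused QKV matmul output
num_tiles = int(math.ceil(activation_rows / tile_size))
selected = bias_per_rows[num_tiles - 1]                # 64-row bias for a 40-row activation
assert selected.shape == (num_tiles * tile_size, qkv_bias.shape[0])
```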
+ if self.wqkv_bias_prefill is not None: + xqkv_fused = xqkv_fused + self.wqkv_bias_prefill + xqkv_fused = tt_all_reduce( xqkv_fused, self.mesh_device, @@ -500,6 +582,18 @@ def forward_prefill( if seq_len > self.MAX_QKV_MM_SEQ_LEN: xqkv_fused = ttnn.reshape(xqkv_fused, [1, 1, seq_len, -1]) + def fix(xqkv): + torch_q = xqkv[: self.head_dim * self.n_local_heads] + torch_k = xqkv[ + self.head_dim * self.n_local_heads : self.head_dim * (self.n_local_heads + self.n_local_kv_heads) + ] + torch_v = xqkv[self.head_dim * (self.n_local_heads + self.n_local_kv_heads) :] + to_hf = lambda t: permute(t.unsqueeze(-1), t.shape[0] // self.head_dim, t.shape[0], 1).squeeze(-1) + torch_q = to_hf(torch_q) + torch_k = to_hf(torch_k) + torch_v = torch_v + return torch_k.flatten() + ttnn.deallocate(x_11SH) # split qkv into heads @@ -677,6 +771,7 @@ def forward_prefill( dim=0 if self.TG else 3, num_reduce_scatter_links=self.num_reduce_scatter_links, num_all_gather_links=self.num_all_gather_links, + topology=self.ccl_topology, memory_config=ttnn.DRAM_MEMORY_CONFIG, dtype=self.ccl_dtype, ) diff --git a/models/demos/llama3/tt/llama_ccl.py b/models/demos/llama3/tt/llama_ccl.py index 300c615c187..5e91c6f5209 100644 --- a/models/demos/llama3/tt/llama_ccl.py +++ b/models/demos/llama3/tt/llama_ccl.py @@ -13,6 +13,7 @@ def tt_all_reduce( dim=0, num_reduce_scatter_links=1, num_all_gather_links=2, + topology=ttnn.Topology.Linear, memory_config=None, sharded=False, dtype=ttnn.bfloat16, @@ -40,6 +41,7 @@ def tt_all_reduce( dim=dim, math_op=ttnn.ReduceType.Sum, num_links=num_reduce_scatter_links, + topology=topology, memory_config=memory_config, ) input_tensor.deallocate(True) @@ -63,7 +65,7 @@ def tt_all_reduce( num_links=num_all_gather_links, cluster_axis=cluster_axis, mesh_device=mesh_device, - topology=ttnn.Topology.Linear, + topology=topology, memory_config=ttnn.DRAM_MEMORY_CONFIG if not sharded else memory_config, ) @@ -87,7 +89,7 @@ def tt_all_reduce( cluster_axis=cluster_axis, mesh_device=mesh_device, math_op=ttnn.ReduceType.Sum, - topology=ttnn.Topology.Linear, + topology=topology, memory_config=ttnn.DRAM_MEMORY_CONFIG if not sharded else memory_config, ) @@ -97,7 +99,7 @@ def tt_all_reduce( num_links=num_all_gather_links, cluster_axis=cluster_axis, mesh_device=mesh_device, - topology=ttnn.Topology.Linear, + topology=topology, memory_config=input_mem_cfg, ) diff --git a/models/demos/llama3/tt/llama_common.py b/models/demos/llama3/tt/llama_common.py index 843cf066c78..d1de6bce149 100644 --- a/models/demos/llama3/tt/llama_common.py +++ b/models/demos/llama3/tt/llama_common.py @@ -44,15 +44,34 @@ def encode_prompt_llama_instruct(tokenizer, prompt_text, system_prompt_text=None return begin_of_text + system_prompt + user_prompt + assistant_reply -def apply_scaling(freqs: torch.Tensor, scale_factor: float = 8): - # Llama-3.x specific scaling +def encode_prompt_hf(tokenizer, prompt_text, system_prompt_text=None): + """See https://huggingface.co/docs/transformers/main/en/chat_templating""" + chat = [] + if system_prompt_text: + chat.append({"role": "system", "content": system_prompt_text}) + if prompt_text: + chat.append({"role": "user", "content": prompt_text}) + return tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True) + + +def encode_prompt_hf(tokenizer, prompt_text, system_prompt_text=None): + """See https://huggingface.co/docs/transformers/main/en/chat_templating""" + chat = [] + if system_prompt_text: + chat.append({"role": "system", "content": system_prompt_text}) + if prompt_text: + 
chat.append({"role": "user", "content": prompt_text}) + return tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True) + + +def apply_scaling(freqs: torch.Tensor, scale_factor: float, orig_context_len: int): + # FIXME: Llama-3.x specific scaling - we need to support yarn for Qwen2.5 models # Values obtained from grid search low_freq_factor = 1 high_freq_factor = 4 - old_context_len = 8192 # original llama3 length - low_freq_wavelen = old_context_len / low_freq_factor - high_freq_wavelen = old_context_len / high_freq_factor + low_freq_wavelen = orig_context_len / low_freq_factor + high_freq_wavelen = orig_context_len / high_freq_factor new_freqs = [] for freq in freqs: wavelen = 2 * math.pi / freq @@ -62,12 +81,12 @@ def apply_scaling(freqs: torch.Tensor, scale_factor: float = 8): new_freqs.append(freq / scale_factor) else: assert low_freq_wavelen != high_freq_wavelen - smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth = (orig_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor) new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq) return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device) -def precompute_freqs(dim: int, end: int, theta: float = 500000.0, use_scaled: bool = True, scale_factor: float = 8): +def precompute_freqs(dim: int, end: int, theta, scale_factor, orig_context_len): """ Precompute the frequency tensor for sine and cosine values with given dimensions. @@ -81,8 +100,8 @@ def precompute_freqs(dim: int, end: int, theta: float = 500000.0, use_scaled: bo """ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) t = torch.arange(end) - if use_scaled: - freqs = apply_scaling(freqs, scale_factor) + if scale_factor is not None: + freqs = apply_scaling(freqs, scale_factor, orig_context_len) freqs = torch.outer(t, freqs).float() return torch.cos(freqs), torch.sin(freqs) @@ -112,8 +131,10 @@ def gather_cos_sin(position_ids, cos, sin): return cos, sin -def get_prefill_rot_mat(head_dim, max_seq_len, mesh_device, seq_len, scale_factor, start_pos=0): - cos, sin = precompute_freqs(head_dim, max_seq_len * 2, scale_factor=scale_factor) +def get_prefill_rot_mat( + head_dim, max_seq_len, mesh_device, seq_len, theta, scale_factor, orig_context_len, start_pos=0 +): + cos, sin = precompute_freqs(head_dim, max_seq_len * 2, theta, scale_factor, orig_context_len) cos_gathered, sin_gathered = gather_cos_sin(torch.arange(start_pos, start_pos + seq_len), cos, sin) assert cos_gathered.size() == (1, 1, seq_len, head_dim) assert sin_gathered.size() == (1, 1, seq_len, head_dim) @@ -151,14 +172,15 @@ def get_single_rot_mat( dhead, mesh_device, num_devices, - start_pos=0, - theta: float = 500000.0, - use_scaled=True, + start_pos, + theta, + scale_factor, + orig_context_len, on_host=False, ): freqs_unscaled = 1.0 / (theta ** (torch.arange(0, dhead, 2)[: (dhead // 2)].float() / dhead)) - if use_scaled: - freqs = apply_scaling(freqs_unscaled) + if scale_factor is not None: + freqs = apply_scaling(freqs_unscaled, scale_factor, orig_context_len) sin_freqs, cos_freqs = torch.sin(freqs), torch.cos(freqs) rot_matrix = torch.zeros(dhead, dhead) rot_matrix[torch.arange(0, dhead, 2), torch.arange(0, dhead, 2)] = cos_freqs.clone() @@ -169,8 +191,8 @@ def get_single_rot_mat( # Support for start_pos different than 0 freqs = start_pos * freqs_unscaled - if use_scaled: - freqs = apply_scaling(freqs) + if scale_factor is not None: + freqs = apply_scaling(freqs, 
scale_factor, orig_context_len) sin_freqs, cos_freqs = torch.sin(freqs), torch.cos(freqs) current_rot_mat = torch.zeros(dhead, dhead) current_rot_mat[torch.arange(0, dhead, 2), torch.arange(0, dhead, 2)] = cos_freqs.clone() @@ -376,3 +398,40 @@ def get_max_prefill_chunk_size(seq_len, max_prefill_seq_len): return chunk_size raise ValueError("No valid chunk size found") + + +def nearest_multiple(x, multiple_of): + return math.ceil(x / multiple_of) * multiple_of + + +def pad_to_size(x: torch.Tensor, dim: int, size: int) -> torch.Tensor: + """ + Pads the specified dimension of the input tensor with zeros + + :param x: Input PyTorch Tensor + :param dim: The dimension to pad + :param size: The size to pad to + :return: Padded PyTorch Tensor + """ + # handle negative dim + if dim < 0: + dim = x.dim() + dim + assert isinstance(x, torch.Tensor), "Input must be a torch.Tensor" + assert -x.dim() <= dim < x.dim(), f"Dimension out of range (expected between {-x.dim()} and {x.dim()-1})" + dim = x.dim() + dim if dim < 0 else dim + + current_size = x.size(dim) + pad_size = size - current_size + + if pad_size == 0: + return x # No padding needed + + # Prepare the padding configuration for F.pad + # F.pad expects padding in the form (pad_last_dim_left, pad_last_dim_right, ..., pad_dim_left, pad_dim_right) + # We only pad on the "end" side of the specified dimension + pad = [0] * (2 * x.dim()) # Initialize padding for all dimensions + pad_index = 2 * (x.dim() - dim - 1) + pad[pad_index + 1] = pad_size # Pad on the "right" side of the specified dimension + + padded_x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + return padded_x diff --git a/models/demos/llama3/tt/llama_decoder.py b/models/demos/llama3/tt/llama_decoder.py index 96116cc6340..58404ec1e09 100644 --- a/models/demos/llama3/tt/llama_decoder.py +++ b/models/demos/llama3/tt/llama_decoder.py @@ -72,6 +72,7 @@ def __init__( is_distributed=self.args.is_distributed_norm, sharded_program_config=self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"], sharded_output_config=self.model_config["SHARDED_ATTN_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), ), args, TG=args.is_galaxy, @@ -88,6 +89,7 @@ def __init__( is_distributed=self.args.is_distributed_norm, sharded_program_config=self.model_config["SHARDED_NORM_MLP_PRGM_CFG"], sharded_output_config=self.model_config["SHARDED_MLP_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), ), args, TG=args.is_galaxy, diff --git a/models/demos/llama3/tt/llama_mlp.py b/models/demos/llama3/tt/llama_mlp.py index 31a845052d1..4ea55b8865b 100644 --- a/models/demos/llama3/tt/llama_mlp.py +++ b/models/demos/llama3/tt/llama_mlp.py @@ -5,6 +5,7 @@ import torch import ttnn from models.common.lightweightmodule import LightweightModule +from models.demos.llama3.tt.llama_common import pad_to_size from models.demos.llama3.tt.llama_ccl import tt_all_reduce @@ -21,41 +22,44 @@ def __init__( self.model_config = model_config state_dict_prefix = state_dict_prefix or args.get_state_dict_prefix(self.__class__.__name__, layer_num) torch_weight = lambda name: torch.transpose(self.state_dict[f"{state_dict_prefix}.{name}.weight"], -2, -1) + pad_hidden_dim = lambda tensor, dim: pad_to_size(tensor, dim=dim, size=args.hidden_dim) + # If pading was applied (e.g. 
via env var), add the unpadded hidden dim to the cache name to avoid loading incorrect weights + hidden_dim_string = f".hidden_dim_{args.hidden_dim}" if args.hidden_dim != args.unpadded_hidden_dim else "" if args.dummy_weights: cache_name = lambda _: None else: - cache_name = lambda name: weight_cache_path / (state_dict_prefix + f".{name}") + cache_name = lambda name: weight_cache_path / f"{state_dict_prefix}.{name}{hidden_dim_string}" w1_w3_mem_config = args.create_dram_sharded_mem_config(args.dim, args.hidden_dim // args.num_devices) w2_mem_config = args.create_dram_sharded_mem_config(args.hidden_dim // args.num_devices, args.dim) # TODO Clean up this code. With sharding, we load the normal weights and then shard them - as_sharded_tensor = lambda name, type, dim: ttnn.as_tensor( - torch_weight(name[:2]), # Grab only the wX part of the name + as_sharded_tensor = lambda name, type, dims: ttnn.as_tensor( + pad_hidden_dim( + torch_weight(name[:2]), dims[0] if args.is_galaxy else dims[-1] + ), # Grab only the wX part of the name dtype=type, device=self.mesh_device, - mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=dim, mesh_shape=args.cluster_shape), + mesh_mapper=ttnn.ShardTensor2dMesh(self.mesh_device, dims=dims, mesh_shape=args.cluster_shape), layout=ttnn.TILE_LAYOUT, - memory_config=ttnn.DRAM_MEMORY_CONFIG - if args.is_galaxy - else w2_mem_config - if "w2" in name - else w1_w3_mem_config, + memory_config=( + ttnn.DRAM_MEMORY_CONFIG if args.is_galaxy else w2_mem_config if "w2" in name else w1_w3_mem_config + ), cache_file_name=cache_name(name), ) self.four_bit_mlp = args.optimizations.bfp4_mlp # Sharded weights - w1_dim = (-1, -2) if args.is_galaxy else (-2, -1) - w2_dim = (-2, -1) if args.is_galaxy else (-1, -2) + w1_dims = (-1, -2) if args.is_galaxy else (-2, -1) + w2_dims = (-2, -1) if args.is_galaxy else (-1, -2) self.w1 = as_sharded_tensor( - "w1_sharded", ttnn.bfloat4_b if self.four_bit_mlp else ttnn.bfloat8_b, dim=w1_dim + "w1_sharded", ttnn.bfloat4_b if self.four_bit_mlp else ttnn.bfloat8_b, dims=w1_dims ) # bfp4 normally ok here but sub .99 pcc for llama 3.1 weights - self.w2 = as_sharded_tensor("w2_sharded", ttnn.bfloat8_b, dim=w2_dim) - self.w3 = as_sharded_tensor("w3_sharded", ttnn.bfloat4_b if self.four_bit_mlp else ttnn.bfloat8_b, dim=w1_dim) + self.w2 = as_sharded_tensor("w2_sharded", ttnn.bfloat8_b, dims=w2_dims) + self.w3 = as_sharded_tensor("w3_sharded", ttnn.bfloat4_b if self.four_bit_mlp else ttnn.bfloat8_b, dims=w1_dims) def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: """ @@ -89,10 +93,12 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: w1_out = ttnn.linear( x, self.w1, - compute_kernel_config=self.args.compute_kernel_config_lofi - if self.four_bit_mlp - else self.args.compute_kernel_config_hifi2_fp16, - core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_1 else None, + compute_kernel_config=( + self.args.compute_kernel_config_lofi + if self.four_bit_mlp + else self.args.compute_kernel_config_hifi2_fp16 + ), + core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_1 else None, dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16, program_config=pc_1, memory_config=x.memory_config(), @@ -101,11 +107,13 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: w3_out = ttnn.linear( x, self.w3, - compute_kernel_config=self.args.compute_kernel_config_lofi - if self.four_bit_mlp - else self.args.compute_kernel_config_hifi2_fp16, - core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_3 else None, - dtype=ttnn.bfloat8_b if TG else ttnn.bfloat16, + 
compute_kernel_config=( + self.args.compute_kernel_config_lofi + if self.four_bit_mlp + else self.args.compute_kernel_config_hifi2_fp16 + ), + core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_3 else None, + dtype=ttnn.bfloat16, program_config=pc_3, memory_config=x.memory_config(), ) @@ -144,6 +152,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: cluster_axis=1, num_all_gather_links=2, sharded=True if mode == "decode" else False, + topology=self.args.ccl_topology(), memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, ) w3_out = tt_all_reduce( @@ -152,6 +161,7 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: cluster_axis=1, num_all_gather_links=2, sharded=True if mode == "decode" else False, + topology=self.args.ccl_topology(), memory_config=self.model_config["FF1_OUT_GATHERED_MEMCFG"] if mode == "decode" else None, ) @@ -188,10 +198,12 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: compute_kernel_config=self.args.compute_kernel_config_hifi2_fp16, dtype=self.args.ccl_dtype if TG else ttnn.bfloat16, program_config=pc_2, - memory_config=(ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG) - if TG - else w2_in.memory_config(), - core_grid=ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, + memory_config=( + (ttnn.L1_WIDTH_SHARDED_MEMORY_CONFIG if mode == "decode" else ttnn.DRAM_MEMORY_CONFIG) + if TG + else w2_in.memory_config() + ), + core_grid=None, # FIXME: validate on TG ttnn.CoreGrid(y=8, x=8) if not pc_2 else None, ) ttnn.deallocate(w2_in) # if mode == "decode" and not TG: @@ -204,11 +216,14 @@ def forward(self, x: ttnn.Tensor, mode) -> ttnn.Tensor: num_reduce_scatter_links=self.args.num_reduce_scatter_links, num_all_gather_links=self.args.num_all_gather_links, sharded=(mode == "decode"), - memory_config=(self.model_config["FF2_OUT_REDUCE_SCATTER_MEMCFG"] if TG else w2_out.memory_config()) - if mode == "decode" - else ttnn.DRAM_MEMORY_CONFIG, + memory_config=( + (self.model_config["FF2_OUT_REDUCE_SCATTER_MEMCFG"] if TG else w2_out.memory_config()) + if mode == "decode" + else ttnn.DRAM_MEMORY_CONFIG + ), dtype=self.args.ccl_dtype, use_composite=True if self.dim == 8192 else False, + topology=self.args.ccl_topology(), ) # Ensure dim 0 and 1 are 1 diff --git a/models/demos/llama3/tt/llama_model.py b/models/demos/llama3/tt/llama_model.py index 3b784ad0bbb..8a909981efb 100644 --- a/models/demos/llama3/tt/llama_model.py +++ b/models/demos/llama3/tt/llama_model.py @@ -2,8 +2,6 @@ # SPDX-License-Identifier: Apache-2.0 -import os -import math import ttnn import torch import torch.nn as nn @@ -11,11 +9,10 @@ from models.demos.llama3.tt.llama_decoder import TtTransformerBlock from models.common.rmsnorm import RMSNorm import ttnn -from typing import Optional from models.common.lightweightmodule import LightweightModule from models.demos.llama3.tt.distributed_norm import DistributedNorm from models.demos.llama3.tt.lm_head import LMHead -from models.demos.llama3.tt.llama_common import copy_host_to_device, get_prefill_rot_mat, HostEmbedding +from models.demos.llama3.tt.llama_common import copy_host_to_device, get_prefill_rot_mat from models.demos.llama3.tt.llama_rope import TtLlamaRotarySetup from models.demos.llama3.tt.llama_embedding import TtLlamaEmbedding @@ -56,8 +53,8 @@ def __init__( args.head_dim, args.max_seq_len, args.rope_theta, - args.use_scaled_rope, args.rope_scaling_factor, + args.orig_context_len, ) self.trans_mats_dict = self.rope_setup.get_both_trans_mats() 
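With the refactor above, RoPE scaling is now controlled by passing scale_factor (None disables scaling) and orig_context_len instead of a use_scaled flag. A usage sketch against the updated precompute_freqs signature; the theta and scale values are illustrative defaults, and Qwen models stay unscaled until yarn support lands (see the FIXME in apply_scaling).

```python
from models.demos.llama3.tt.llama_common import precompute_freqs

head_dim, max_seq_len = 128, 4096

# Llama-3.x style scaled RoPE: scale by 8 relative to the original 8192 context length
cos_s, sin_s = precompute_freqs(
    head_dim, max_seq_len * 2, theta=500000.0, scale_factor=8, orig_context_len=8192
)

# Unscaled RoPE (scale_factor=None skips apply_scaling entirely)
cos_u, sin_u = precompute_freqs(
    head_dim, max_seq_len * 2, theta=1000000.0, scale_factor=None, orig_context_len=None
)
```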
@@ -87,6 +84,7 @@ def __init__( is_distributed=self.args.is_distributed_norm, sharded_program_config=self.model_config["SHARDED_NORM_LM_HEAD_PRGM_CFG"], sharded_output_config=self.model_config["LM_HEAD_INPUT_MEMCFG"], + ccl_topology=self.args.ccl_topology(), ), args, args.is_galaxy, @@ -124,8 +122,10 @@ def prepare_inputs_prefill(self, tokens, start_pos=0, page_table=None, chunk_pag self.args.head_dim, self.args.max_seq_len, self.mesh_device, - seq_len=S, - scale_factor=self.args.rope_scaling_factor, + S, + self.args.rope_theta, + self.args.rope_scaling_factor, + self.args.orig_context_len, start_pos=start_pos, ) diff --git a/models/demos/llama3/tt/llama_rope.py b/models/demos/llama3/tt/llama_rope.py index 06406a4eb2d..4b395c3eec5 100644 --- a/models/demos/llama3/tt/llama_rope.py +++ b/models/demos/llama3/tt/llama_rope.py @@ -11,8 +11,8 @@ from loguru import logger -def compute_gather_cos_sin(dhead, end, theta, position_ids, use_scaled_rope, scale_factor): - cos, sin = precompute_freqs(dhead, end, theta, use_scaled_rope, scale_factor) +def compute_gather_cos_sin(dhead, end, theta, scale_factor, orig_context_len, position_ids): + cos, sin = precompute_freqs(dhead, end, theta, scale_factor, orig_context_len) return gather_cos_sin(position_ids, cos, sin) @@ -23,9 +23,9 @@ def __init__( batch_size: int, head_dim: int, max_seq_len: int, - rope_theta: float = 10000, - use_scaled_rope: bool = False, - scale_factor: float = 8, + rope_theta: float, + scale_factor: float, # use None to disable rope scaling + orig_context_len: int, # only used if scaling enabled datatype=ttnn.bfloat16, ): super().__init__() @@ -40,16 +40,15 @@ def __init__( else: self.batch_size_per_device_group = self.batch_size self.core_grid = device.compute_with_storage_grid_size() - num_cores = self.core_grid.x * self.core_grid.y # Generate the cos/sin matrices needed for ttnn.embedding op cos_matrix, sin_matrix = compute_gather_cos_sin( dhead=head_dim, end=max_seq_len * 2, theta=rope_theta, - position_ids=torch.arange(max_seq_len), - use_scaled_rope=use_scaled_rope, scale_factor=scale_factor, + orig_context_len=orig_context_len, + position_ids=torch.arange(max_seq_len), ) self.cos_matrix = ttnn.from_torch( @@ -73,7 +72,7 @@ def __init__( 1, 1, batch_size, - 1 + 1, # 1, 1, num_cores, 1 ) # Repeat across all cores on device trans_mat_mem_config = ttnn.create_sharded_memory_config( @@ -89,13 +88,15 @@ def __init__( layout=ttnn.TILE_LAYOUT, dtype=datatype, memory_config=trans_mat_mem_config, - mesh_mapper=ShardTensor2dMesh( - device, - dims=(None, 2) if (self.num_devices == 32 and batch_size > 1) else (None, None), - mesh_shape=list(device.shape), - ) - if self.is_mesh_device - else None, + mesh_mapper=( + ShardTensor2dMesh( + device, + dims=(None, 2) if (self.num_devices == 32 and batch_size > 1) else (None, None), + mesh_shape=list(device.shape), + ) + if self.is_mesh_device + else None + ), ) # TODO: Colman, should this be TILE_SIZE or head_dim? Why should it be different for prefill and decode? 
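HF checkpoints store Q/K projection weights in a different head layout than the Meta checkpoints, which is why load_checkpoints.py (added below) ships permute/reverse_permute helpers and why llama_attention.py imports permute above. The two helpers are exact inverses; the following self-contained sketch, using the same definitions as the new module with toy dimensions, demonstrates the round trip.

```python
import torch

def reverse_permute(tensor, n_heads, dim1, dim2):  # HF layout -> Meta layout
    return tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def permute(tensor, n_heads, dim1, dim2):          # Meta layout -> HF layout
    return tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

n_heads, head_dim, model_dim = 4, 8, 16                # toy sizes
wq_hf = torch.randn(n_heads * head_dim, model_dim)     # HF-layout q_proj.weight
wq_meta = reverse_permute(wq_hf, n_heads, wq_hf.shape[0], wq_hf.shape[1])
assert torch.equal(permute(wq_meta, n_heads, wq_meta.shape[0], wq_meta.shape[1]), wq_hf)
```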
diff --git a/models/demos/llama3/tt/lm_head.py b/models/demos/llama3/tt/lm_head.py index bd5cbe6ba8f..a79f8856e66 100644 --- a/models/demos/llama3/tt/lm_head.py +++ b/models/demos/llama3/tt/lm_head.py @@ -103,13 +103,15 @@ def __init__( ) if args.is_galaxy: self.program_configs = [ - None - if args.dim == 2048 - else args.dram_matmul_config( - args.tile_padded_batch_rows, # (8k, 128k) -> (2k, 16k) - args.dim // 4, - 16 * 1024, - args.lm_head_core_grid.num_cores, + ( + None + if args.dim == 2048 + else args.dram_matmul_config( + args.tile_padded_batch_rows, # (8k, 128k) -> (2k, 16k) + args.dim // 4, + 16 * 1024, + args.lm_head_core_grid.num_cores, + ) ) ] diff --git a/models/demos/llama3/tt/load_checkpoints.py b/models/demos/llama3/tt/load_checkpoints.py new file mode 100644 index 00000000000..7e330a2e18d --- /dev/null +++ b/models/demos/llama3/tt/load_checkpoints.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. + +# SPDX-License-Identifier: Apache-2.0 + +import os +import torch +from safetensors.torch import load_file as safetensors_load_file +from tqdm import tqdm +import json +from pathlib import Path +from loguru import logger + + +# TODO Update function for large models: For 1 layer tests we only want to load 1 checkpoint file, instead of all. +def load_hf_state_dict(ckpt_dir): + # First check if index file exists + index_path = os.path.join(ckpt_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + # Multi-file case: Read the index file and load all referenced safetensor files + with open(index_path, "r") as f: + index_data = json.load(f) + + # Retrieve the weight file names from the index JSON + weight_map = index_data["weight_map"] + safetensor_files = set(weight_map.values()) + + # Read each safetensors file mentioned in the index + loaded_weights = {} + for file in safetensor_files: + safetensor_path = os.path.join(ckpt_dir, file) + weights = safetensors_load_file(safetensor_path) + loaded_weights.update(weights) # Merge weights into a single dictionary + else: + # Single-file case: Load the single model.safetensors file + safetensor_path = os.path.join(ckpt_dir, "model.safetensors") + if not os.path.exists(safetensor_path): + raise FileNotFoundError(f"Neither model.safetensors.index.json nor model.safetensors found in {ckpt_dir}") + loaded_weights = safetensors_load_file(safetensor_path) + + if not "lm_head.weight" in loaded_weights: + # Assume tied to the embeddings if not present + loaded_weights["lm_head.weight"] = loaded_weights["model.embed_tokens.weight"] + + return loaded_weights + + +def convert_hf_to_meta(state_dict, head_dim): + state_dict = convert_hf_qkv_to_meta_format(state_dict, head_dim) + state_dict = map_hf_to_meta_keys(state_dict) + return state_dict + + +def map_hf_to_meta_keys(loaded_weights): + hf_to_meta = { + # Top level mappings + "model.embed_tokens.weight": "tok_embeddings.weight", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + # Layer level mappings + "input_layernorm.weight": "attention_norm.weight", + "post_attention_layernorm.weight": "ffn_norm.weight", + # Attention module mappings + "self_attn.q_proj.weight": "attention.wq.weight", + "self_attn.k_proj.weight": "attention.wk.weight", + "self_attn.v_proj.weight": "attention.wv.weight", + "self_attn.o_proj.weight": "attention.wo.weight", + "self_attn.q_proj.bias": "attention.wq.bias", + "self_attn.k_proj.bias": "attention.wk.bias", + "self_attn.v_proj.bias": "attention.wv.bias", + # Feed forward module mappings + 
"mlp.gate_proj.weight": "feed_forward.w1.weight", + "mlp.up_proj.weight": "feed_forward.w3.weight", + "mlp.down_proj.weight": "feed_forward.w2.weight", + # Direct module mappings + "gate_proj.weight": "w1.weight", + "down_proj.weight": "w2.weight", + "up_proj.weight": "w3.weight", + "q_proj.weight": "wq.weight", + "k_proj.weight": "wk.weight", + "v_proj.weight": "wv.weight", + "o_proj.weight": "wo.weight", + "q_proj.bias": "wq.bias", + "k_proj.bias": "wk.bias", + "v_proj.bias": "wv.bias", + "weight": "emb.weight", # For host embeddings + # Full path layer mappings + "model.layers.{layer}.input_layernorm.weight": "layers.{layer}.attention_norm.weight", + "model.layers.{layer}.post_attention_layernorm.weight": "layers.{layer}.ffn_norm.weight", + "model.layers.{layer}.self_attn.q_proj.weight": "layers.{layer}.attention.wq.weight", + "model.layers.{layer}.self_attn.k_proj.weight": "layers.{layer}.attention.wk.weight", + "model.layers.{layer}.self_attn.v_proj.weight": "layers.{layer}.attention.wv.weight", + "model.layers.{layer}.self_attn.o_proj.weight": "layers.{layer}.attention.wo.weight", + "model.layers.{layer}.self_attn.q_proj.bias": "layers.{layer}.attention.wq.bias", + "model.layers.{layer}.self_attn.k_proj.bias": "layers.{layer}.attention.wk.bias", + "model.layers.{layer}.self_attn.v_proj.bias": "layers.{layer}.attention.wv.bias", + "model.layers.{layer}.mlp.gate_proj.weight": "layers.{layer}.feed_forward.w1.weight", + "model.layers.{layer}.mlp.up_proj.weight": "layers.{layer}.feed_forward.w3.weight", + "model.layers.{layer}.mlp.down_proj.weight": "layers.{layer}.feed_forward.w2.weight", + } + + meta_state_dict = {} + for key, tensor in loaded_weights.items(): + if key in hf_to_meta: + # Direct match for top-level keys + meta_state_dict[hf_to_meta[key]] = tensor + elif "model.layers." in key: + # Extract layer number and form a template key + parts = key.split(".") + layer_num = parts[2] # e.g. "0" in "model.layers.0.input_layernorm.weight" + template_key = "model.layers.{layer}." + ".".join(parts[3:]) + if template_key in hf_to_meta: + meta_state_dict[hf_to_meta[template_key].format(layer=layer_num)] = tensor + + return meta_state_dict + + +def load_meta_state_dict(ckpt_dir, n_layers=None, start_layer_idx=0): + checkpoints = sorted(Path(ckpt_dir).glob("*.pth")) + assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}" + is_chunked = "layers_" in str(checkpoints[0]) + if is_chunked: + checkpoint = load_chunked_checkpoints(checkpoints, n_layers, start_layer_idx) + else: + checkpoint = load_sharded_checkpoints(checkpoints, n_layers) + + return checkpoint + + +def load_chunked_checkpoints(checkpoints, n_layers, start_layer_idx): + checkpoint = {} + + (f"Loading {len(checkpoints)} checkpoint files") + for ckpt in tqdm(checkpoints): + if n_layers: + # Layer range is in the file name, like layers_start-end.pth + layer_range = ckpt.stem.split("_")[1] + start_layer, end_layer = map(int, layer_range.split("-")) + if start_layer > n_layers + start_layer_idx: + continue + if end_layer < start_layer_idx: + continue + + loaded_ckpt = torch.load(ckpt, map_location="cpu") + checkpoint.update(loaded_ckpt) + return checkpoint + + +def load_sharded_checkpoints(checkpoints, n_layers): + checkpoint = {} + logger.info(f"Loading {len(checkpoints)} checkpoint files") + for ckpt in tqdm(checkpoints): + loaded_ckpt = torch.load(ckpt, map_location="cpu") + for ( + key, + value, + ) in loaded_ckpt.items(): + if "layers." 
in key: + layer_num = int(key.split("layers.")[1].split(".")[0]) + if n_layers and layer_num >= n_layers: + continue + if key in checkpoint: + checkpoint[key] += [value] + else: + checkpoint[key] = [value] + del loaded_ckpt + + # concat checkpoint values + for key, value in checkpoint.items(): + if len(value) == 1 or "norm" in key: + checkpoint[key] = value[0] + else: + if key == "tok_embeddings.weight" or key == "output.weight": + assert value[0].shape[1] == 8192 # FIXME: do we need this hardcoded shape? + # Concatenate along dimension 0 for llama3 token embeddings weight and lm head + checkpoint[key] = torch.cat(value, dim=0) + else: + # cat_dim is index of the smallest dimension in value[0].shape + cat_dim = torch.argmin(torch.tensor(value[0].shape)) + checkpoint[key] = torch.cat(value, dim=cat_dim) + + return checkpoint + + +def convert_hf_qkv_to_meta_format(loaded_weights, head_dim): + """Convert HuggingFace QKV weights to Meta format for RoPE compatibility.""" + converted_weights = {} + for key, tensor in loaded_weights.items(): + if "q_proj.weight" in key or "k_proj.weight" in key: + # For weights: n_heads = tensor.shape[0] // head_dim + n_heads = tensor.shape[0] // head_dim + converted_weights[key] = reverse_permute(tensor, n_heads, tensor.shape[0], tensor.shape[1]) + elif "q_proj.bias" in key or "k_proj.bias" in key: + # For biases: n_heads = tensor.shape[0] // head_dim + n_heads = tensor.shape[0] // head_dim + converted_weights[key] = reverse_permute(tensor, n_heads, tensor.shape[0], 1).squeeze(-1) + else: + # Keep all other weights unchanged + converted_weights[key] = tensor + return converted_weights + + +def convert_meta_to_hf(state_dict, head_dim): + state_dict = convert_meta_qkv_to_hf_format(state_dict, head_dim) + state_dict = map_meta_to_hf_keys(state_dict) + return state_dict + + +def map_meta_to_hf_keys(loaded_weights): + # Define mappings at each level of the hierarchy + meta_to_hf_mappings = { + # Top level + "tok_embeddings.weight": "model.embed_tokens.weight", + "norm.weight": "model.norm.weight", + "output.weight": "lm_head.weight", + # Layer level + "attention_norm.weight": "input_layernorm.weight", + "ffn_norm.weight": "post_attention_layernorm.weight", + # Attention module + "attention.wq.weight": "self_attn.q_proj.weight", + "attention.wk.weight": "self_attn.k_proj.weight", + "attention.wv.weight": "self_attn.v_proj.weight", + "attention.wo.weight": "self_attn.o_proj.weight", + "attention.wq.bias": "self_attn.q_proj.bias", + "attention.wk.bias": "self_attn.k_proj.bias", + "attention.wv.bias": "self_attn.v_proj.bias", + # Feed forward module + "feed_forward.w1.weight": "mlp.gate_proj.weight", + "feed_forward.w3.weight": "mlp.up_proj.weight", + "feed_forward.w2.weight": "mlp.down_proj.weight", + # Direct mappings for when we get just the final components + "w1.weight": "gate_proj.weight", + "w2.weight": "down_proj.weight", + "w3.weight": "up_proj.weight", + "wq.weight": "q_proj.weight", + "wk.weight": "k_proj.weight", + "wv.weight": "v_proj.weight", + "wo.weight": "o_proj.weight", + "wq.bias": "q_proj.bias", + "wk.bias": "k_proj.bias", + "wv.bias": "v_proj.bias", + # Host embeddings + "emb.weight": "weight", + } + + hf_state_dict = {} + for key, tensor in loaded_weights.items(): + # Handle full model paths with layer numbers + if "layers." 
in key: + parts = key.split(".") + layer_num = parts[1] + remainder = ".".join(parts[2:]) + if remainder in meta_to_hf_mappings: + new_key = f"model.layers.{layer_num}.{meta_to_hf_mappings[remainder]}" + hf_state_dict[new_key] = tensor + continue + + # Try exact matches first + if key in meta_to_hf_mappings: + hf_state_dict[meta_to_hf_mappings[key]] = tensor + continue + + # For submodule state dicts, try matching the end of the key + matched = False + for meta_pattern, hf_pattern in meta_to_hf_mappings.items(): + if key.endswith(meta_pattern): + # Replace only the matching part at the end + prefix = key[: -len(meta_pattern)] + new_key = prefix + hf_pattern + hf_state_dict[new_key] = tensor + matched = True + break + + # If no mapping found, keep the original key + if not matched: + hf_state_dict[key] = tensor + + return hf_state_dict + + +def convert_meta_qkv_to_hf_format(loaded_weights, head_dim): + """Convert Meta QKV weights back to HuggingFace format.""" + converted_weights = {} + for key, tensor in loaded_weights.items(): + if "wq.weight" in key or "wk.weight" in key: + # For weights: n_heads = tensor.shape[0] // head_dim + n_heads = tensor.shape[0] // head_dim + converted_weights[key] = permute(tensor, n_heads, tensor.shape[0], tensor.shape[1]) + elif "wq.bias" in key or "wk.bias" in key: + # For biases: n_heads = tensor.shape[0] // head_dim + n_heads = tensor.shape[0] // head_dim + converted_weights[key] = permute(tensor.unsqueeze(-1), n_heads, tensor.shape[0], 1).squeeze(-1) + else: + # Keep all other weights unchanged + converted_weights[key] = tensor + return converted_weights + + +def reverse_permute(tensor, n_heads, dim1, dim2): + return tensor.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2) + + +def permute(tensor, n_heads, dim1, dim2): + return tensor.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2) diff --git a/models/demos/llama3/tt/model_config.py b/models/demos/llama3/tt/model_config.py index 0002654966a..6c91825dbbc 100644 --- a/models/demos/llama3/tt/model_config.py +++ b/models/demos/llama3/tt/model_config.py @@ -16,12 +16,22 @@ num_to_core_range_set, calculate_hidden_dim, get_out_subblock_w, + encode_prompt_llama_instruct, + encode_prompt_hf, + nearest_multiple, ) from typing import Tuple from models.utility_functions import nearest_32 from pathlib import Path -from tqdm import tqdm from dataclasses import dataclass +from enum import Enum, auto +from models.demos.llama3.tt.load_checkpoints import ( + load_meta_state_dict, + load_hf_state_dict, + convert_hf_to_meta, + convert_meta_to_hf, + reverse_permute, +) @dataclass @@ -35,9 +45,10 @@ class LlamaOptimizations: @classmethod def accuracy(cls, model_name): """Configuration optimized for accuracy - Only 3.1-70B uses bfp4 MLPs in this configuration + Only 70B models uses bfp4 MLPs in this configuration """ - return cls(bfp4_mlp=model_name == "3.1-70B") + bfp4 = model_name in ["Llama3.1-70B", "DeepSeek-R1-Distill-Llama-70B", "Qwen2.5-72B"] + return cls(bfp4_mlp=bfp4) @classmethod def performance(cls, model_name): @@ -47,6 +58,11 @@ def performance(cls, model_name): return cls(bfp4_mlp=True) +class CheckpointType(Enum): + Meta = auto() + HuggingFace = auto() + + class TtModelArgs: OP_KEYS = ( # Embedding @@ -92,7 +108,7 @@ def __init__( ): self.num_devices = mesh_device.get_num_devices() if mesh_device else 0 self.mesh_device = mesh_device - self.device_name = {0: "CPU", 1: "N150", 2: "N300", 8: "T3K", 32: "TG"}[self.num_devices] + self.device_name = {0: 
"CPU", 1: "N150", 2: "N300", 4: "N150x4", 8: "T3K", 32: "TG"}[self.num_devices] self.model_name = "Unknown" # Llama model name will be dependent on the checkpoint directory self.max_seq_len = max_seq_len self.max_batch_size = max_batch_size @@ -108,6 +124,7 @@ def __init__( self.DEFAULT_CKPT_DIR = LLAMA_DIR self.DEFAULT_TOKENIZER_PATH = LLAMA_DIR self.DEFAULT_CACHE_PATH = os.path.join(LLAMA_DIR, self.device_name) + self.model_name = os.path.basename(LLAMA_DIR) # May be overridden by config else: assert "Please set $LLAMA_DIR to a valid checkpoint directory" @@ -116,14 +133,7 @@ def __init__( assert os.path.exists( self.DEFAULT_CKPT_DIR ), f"Checkpoint directory {self.DEFAULT_CKPT_DIR} does not exist, please set LLAMA_DIR=... or LLAMA_CKPT_DIR=..." - assert os.path.isfile( - self.DEFAULT_TOKENIZER_PATH + "/tokenizer.model" - ), f"Tokenizer file {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'} does not exist, please set LLAMA_TOKENIZER_PATH=..." - if not os.path.exists(self.DEFAULT_CACHE_PATH): - os.makedirs(self.DEFAULT_CACHE_PATH) - assert os.path.exists( - self.DEFAULT_CACHE_PATH - ), f"Cache directory {self.DEFAULT_CACHE_PATH} does not exist, please set LLAMA_CACHE_PATH=..." + os.makedirs(self.DEFAULT_CACHE_PATH, exist_ok=True) # Check if weights exist in the specified folder. If not warn the user to run the download and untar script. # assert os.path.isfile( # self.DEFAULT_CKPT_DIR + "/consolidated.00.pth" @@ -133,57 +143,6 @@ def __init__( logger.info(f"Tokenizer file: {self.DEFAULT_TOKENIZER_PATH + '/tokenizer.model'}") logger.info(f"Cache directory: {self.DEFAULT_CACHE_PATH}") - # Set the model name based on the checkpoint directory being loaded - if "3.2-1B" in LLAMA_DIR: - local_params = "LLAMA3_2_1B_PARAMS" - self.model_name = "3.2-1B" - self.rope_scaling_factor = 32 - elif "3.2-3B" in LLAMA_DIR: - local_params = "LLAMA3_2_3B_PARAMS" - self.model_name = "3.2-3B" - self.rope_scaling_factor = 32 - elif "3.1-8B" in LLAMA_DIR: - local_params = "LLAMA3_1_8B_PARAMS" - self.model_name = "3.1-8B" - self.rope_scaling_factor = 8 - elif "3.2-11B" in LLAMA_DIR: - local_params = "LLAMA3_2_11B_PARAMS" - self.model_name = "3.2-11B" - self.rope_scaling_factor = 8 # shared with 3.1-8B - elif "3.1-70B" in LLAMA_DIR: - local_params = "LLAMA3_1_70B_PARAMS" - self.model_name = "3.1-70B" - self.rope_scaling_factor = 8 - self.is_70b = True # self.dim == 8192 and self.n_layers == 80 - else: - # NOTE: 3.2-90B and 3.3-70B also use scaling factor of 8 - raise ValueError(f"Unsupported LLAMA model: {LLAMA_DIR}") - - # Set the max number of tokens for each prefill chunk based on the model and device - MAX_PREFILL_CHUNK_SIZES_DIV1024 = { - "3.2-1B": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128}, - "3.2-3B": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128}, - "3.1-8B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128}, - "3.2-11B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128}, - "3.1-70B": {"N150": None, "N300": None, "T3K": 32, "TG": 128}, - } - max_prefill_chunk_size_div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[self.model_name][self.device_name] - assert ( - max_prefill_chunk_size_div1024 is not None - ), f"Unsupported model {self.model_name} on device {self.device_name}" - self.max_prefill_chunk_size = max_prefill_chunk_size_div1024 * 1024 - - if callable(optimizations): - self.optimizations = optimizations(self.model_name) - else: - self.optimizations = optimizations - - # Load model params - if not dummy_weights: - self._set_llama_params(self.DEFAULT_CKPT_DIR) - else: # With Dummy weights, set the params 
from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders. - self._set_llama_params(self.LOCAL_LLAMA_PARAMS[local_params]) - # Some consumers like SentencePiece only accept str not Path for files self.model_base_path = Path(self.DEFAULT_CKPT_DIR) self.model_cache_path = Path(self.DEFAULT_CACHE_PATH) @@ -196,6 +155,58 @@ def __init__( # If the weights file contain the keyword `instruct` also set self.instruct to true if "instruct" in self.DEFAULT_CACHE_PATH.lower(): self.instruct = True + + # Load model params + if not dummy_weights: + self.checkpoint_type = self.detect_checkpoint_type() + self._set_model_params(self.DEFAULT_CKPT_DIR) + else: # With Dummy weights, set the params from the local copy inside the model folder. This is required for CI pipeline that doesn't mount the external folders. + self.checkpoint_type = CheckpointType.Meta + if "3.2-1B" in self.DEFAULT_CKPT_DIR: + local_params = "LLAMA3_2_1B_PARAMS" + elif "3.2-3B" in self.DEFAULT_CKPT_DIR: + local_params = "LLAMA3_2_3B_PARAMS" + elif "3.1-8B" in self.DEFAULT_CKPT_DIR: + local_params = "LLAMA3_1_8B_PARAMS" + elif "3.2-11B" in self.DEFAULT_CKPT_DIR: + local_params = "LLAMA3_2_11B_PARAMS" + elif "3.1-70B" in self.DEFAULT_CKPT_DIR: + local_params = "LLAMA3_1_70B_PARAMS" + else: + raise ValueError( + f"No local params found for {self.DEFAULT_CKPT_DIR}, dummy weights are not supported for this model" + ) + self._set_model_params(self.LOCAL_LLAMA_PARAMS[local_params]) + + # Set the max number of tokens for each prefill chunk based on the model and device + max_prefill_chunk_size_div1024 = os.getenv("MAX_PREFILL_CHUNK_SIZE") + if max_prefill_chunk_size_div1024 is None: + MAX_PREFILL_CHUNK_SIZES_DIV1024 = { + "Llama3.2-1B": {"N150": 128, "N300": 128, "T3K": 128, "TG": 128}, + "Llama3.2-3B": {"N150": 8, "N300": 128, "T3K": 128, "TG": 128}, + "Llama3.1-8B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128}, + "Llama3.2-11B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128}, + "Llama3.1-70B": {"N150": None, "N300": None, "T3K": 32, "TG": 128}, + "DeepSeek-R1-Distill-Llama-70B": {"N150": None, "N300": None, "T3K": 32, "TG": 128}, + "Qwen2.5-7B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128}, + "Qwen2.5-72B": {"N150": None, "N300": None, "T3K": 32, "TG": 128}, + } + try: + max_prefill_chunk_size_div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[self.base_model_name][self.device_name] + except KeyError: + raise ValueError( + f"Unknown model {self.model_name} on device {self.device_name}, try setting MAX_PREFILL_CHUNK_SIZE between 4 (compatible) and 128 (faster)" + ) + assert ( + max_prefill_chunk_size_div1024 is not None + ), f"Unsupported model {self.model_name} on device {self.device_name}" + self.max_prefill_chunk_size = max_prefill_chunk_size_div1024 * 1024 + + if callable(optimizations): + self.optimizations = optimizations(self.model_name) + else: + self.optimizations = optimizations + self.dummy_weights = dummy_weights self.tile_padded_batch_rows = self.tile_size * int(math.ceil(self.max_batch_size / self.tile_size)) @@ -215,10 +226,12 @@ def __init__( self.model_config.update({f"{key}_TILE": ttnn.TILE_LAYOUT for key in self.OP_KEYS if "LAYOUT" in key}) self.cos, self.sin = precompute_freqs( - self.head_dim, self.max_seq_len * 2, self.rope_theta, self.use_scaled_rope, self.rope_scaling_factor + self.head_dim, self.max_seq_len * 2, self.rope_theta, self.rope_scaling_factor, self.orig_context_len ) # for prefill self.rot_emb = freqs_to_rotation_matrix(self.cos, self.sin) # for decode + 
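The prefill chunk size resolution just above combines a per-model/per-device table with an optional MAX_PREFILL_CHUNK_SIZE override. A condensed sketch of that flow, assuming the override is given in units of 1024 tokens and converting it to int here; the table entries are a subset of the full table above.

```python
import os

MAX_PREFILL_CHUNK_SIZES_DIV1024 = {
    "Llama3.1-8B": {"N150": 4, "N300": 64, "T3K": 128, "TG": 128},
    "Qwen2.5-72B": {"N150": None, "N300": None, "T3K": 32, "TG": 128},
}

def resolve_max_prefill_chunk_size(base_model_name: str, device_name: str) -> int:
    override = os.getenv("MAX_PREFILL_CHUNK_SIZE")
    if override is not None:
        div1024 = int(override)  # assumes the override is an integer count of 1024-token chunks
    else:
        div1024 = MAX_PREFILL_CHUNK_SIZES_DIV1024[base_model_name][device_name]
    assert div1024 is not None, f"Unsupported model {base_model_name} on device {device_name}"
    return div1024 * 1024

assert resolve_max_prefill_chunk_size("Llama3.1-8B", "N300") == 64 * 1024
```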
self.tokenizer = None if dummy_weights else self.create_tokenizer() + device = mesh_device.get_devices()[0] if mesh_device is not None else None self.cluster_shape = list(mesh_device.shape) self.is_galaxy = self.num_devices == 32 @@ -350,45 +363,61 @@ def find_largest_divisor(n, max_divisor=8): else: self.model_config["ATTN_ALL_GATHER_MATMUL_PROGCFG"] = None + prefill_rows = lambda seq_len: min(seq_len, 1024) // self.tile_size + mlp1_3_grid = lambda seq_len: ( + (8, min(min(seq_len, 1024) // 32, 4)) + if self.is_galaxy + else self.find_prefill_grid(prefill_rows(seq_len), self.dim // self.tile_size) + ) + mlp2_grid = lambda seq_len: ( + (8, min(min(seq_len, 1024) // 32, 4)) + if self.is_galaxy + else self.find_prefill_grid(prefill_rows(seq_len), self.hidden_dim // self.tile_size) + ) + self.model_config["PREFILL_MLP_W1_W3_PRG_CONFIG"] = lambda seq_len: self.matmul_config( m=min(seq_len, 1024), k=self.dim // self.cluster_shape[0], n=self.hidden_dim // self.cluster_shape[1], - grid_size=(8, min(min(seq_len, 1024) // 32, 4)) - if self.is_galaxy - else ((8, 8) if seq_len >= 1024 else (8, 4)), + grid_size=mlp1_3_grid(seq_len), ) self.model_config["PREFILL_MLP_W2_PRG_CONFIG"] = lambda seq_len: self.matmul_config( m=min(seq_len, 1024), k=self.hidden_dim // (self.cluster_shape[1] if self.is_galaxy else 1), n=self.dim, - grid_size=(8, min(min(seq_len, 1024) // 32, 4)) - if self.is_galaxy - else ((8, 8) if seq_len >= 1024 else (8, 4)), + grid_size=mlp2_grid(seq_len), ) + k_dim = self.dim // self.cluster_shape[0] if self.is_galaxy else self.dim + n_dim = self.dim // self.cluster_shape[1] if self.is_galaxy else self.dim + num_rows = lambda seq_len: min(seq_len, 1024 if self.is_galaxy else 2048) self.model_config["WO_PREFILL_PROGCFG"] = lambda seq_len: self.matmul_config( - m=min(seq_len, 1024 if self.is_galaxy else 2048), - k=self.dim // self.cluster_shape[0] if self.is_galaxy else self.dim, - n=self.dim // self.cluster_shape[1] if self.is_galaxy else self.dim, - grid_size=(8, 8), + m=num_rows(seq_len), + k=k_dim, + n=n_dim, + grid_size=self.find_prefill_grid(num_rows(seq_len), n_dim // self.tile_size), in0_block_w=1, fuse_batch=seq_len <= 1024, # if self.is_galaxy else 2048), ) - # Calculate largest number of lm_head_num_rows such that self.dim % (lm_head_num_rows * 8) == 0 + # Calculate largest number of lm_head_num_rows such that self.dim % (lm_head_num_rows * lm_head_cores_per_row) == 0 if self.num_devices == 32: lm_head_num_rows = 4 while self.dim % (32 * 32 * lm_head_num_rows) != 0: lm_head_num_rows -= 1 else: lm_head_num_rows = 8 - while self.dim % (32 * lm_head_num_rows * 8) != 0: - lm_head_num_rows -= 1 - assert ( - lm_head_num_rows > 0 - ), f"Could not find a lm_head_num_rows such that self.dim(={self.dim}) % (lm_head_num_rows * 4) == 0" - self.lm_head_core_grid = ttnn.CoreGrid(y=lm_head_num_rows, x=8) + lm_head_cores_per_row = 8 + while self.dim % (32 * lm_head_num_rows * lm_head_cores_per_row) != 0: + lm_head_num_rows -= 1 + if lm_head_num_rows == 0: + lm_head_cores_per_row -= 1 + if lm_head_cores_per_row == 0: + raise ValueError( + f"Could not find a lm_head_num_rows such that self.dim(={self.dim}) % (lm_head_num_rows * 8) == 0" + ) + lm_head_num_rows = 8 + self.lm_head_core_grid = ttnn.CoreGrid(y=lm_head_num_rows, x=lm_head_cores_per_row) self.model_config["LM_HEAD_INPUT_MEMCFG"] = ttnn.create_sharded_memory_config( ( @@ -455,7 +484,6 @@ def find_largest_divisor(n, max_divisor=8): grid_by_batch = (1, 1) else: raise ValueError(f"Batch size {self.max_batch_size} not supported") - 
core_grid_by_batch = ttnn.CoreGrid(y=grid_by_batch[1], x=grid_by_batch[0]) core_range_set_by_batch = ttnn.CoreRangeSet( { ttnn.CoreRange( @@ -610,41 +638,42 @@ def find_largest_divisor(n, max_divisor=8): else self.model_config["FULL_GRID_MEMCFG"] ) - self.model_config["FF1_3_TG_PROGCFG"] = self.matmul_1d_config_from_tensor_shapes( - ( - 1, - 1, - 32, - self.dim // 4, - ), - ( - 1, - 1, - self.dim // 4, - self.hidden_dim // 8, - ), - grid=ttnn.CoreGrid(x=8, y=2), - overwrite_subblock_h=1, - overwrite_subblock_w=1, - ) + if self.is_galaxy: + self.model_config["FF1_3_TG_PROGCFG"] = self.matmul_1d_config_from_tensor_shapes( + ( + 1, + 1, + 32, + self.dim // 4, + ), + ( + 1, + 1, + self.dim // 4, + self.hidden_dim // 8, + ), + grid=ttnn.CoreGrid(x=8, y=2), + overwrite_subblock_h=1, + overwrite_subblock_w=1, + ) - self.model_config["FF2_TG_PROGCFG"] = self.matmul_1d_config_from_tensor_shapes( - ( - 1, - 1, - 32, - self.hidden_dim // 8, - ), - ( - 1, - 1, - self.hidden_dim // 8, - self.dim // 4, - ), - grid=ttnn.CoreGrid(x=8, y=2), - overwrite_subblock_h=1, - overwrite_subblock_w=1, - ) + self.model_config["FF2_TG_PROGCFG"] = self.matmul_1d_config_from_tensor_shapes( + ( + 1, + 1, + 32, + self.hidden_dim // 8, + ), + ( + 1, + 1, + self.hidden_dim // 8, + self.dim // 4, + ), + grid=ttnn.CoreGrid(x=8, y=2), + overwrite_subblock_h=1, + overwrite_subblock_w=1, + ) self.model_config["FF1_OUT_REDUCE_SCATTER_MEMCFG"] = ttnn.create_sharded_memory_config( shape=(32, self.hidden_dim // 28 // 8), # shard_grid_cores = 28, num_devices=8 @@ -815,6 +844,7 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len): self.model_config["XATTN_KV_PREFILL_MEM_CFG"] = _get_xattn_kv_prefill_mem_cfg self.VISION_MAX_MM_SEQ = nearest_32(self.vision_chunk_ntok) + # RMS NORM self.model_config["SHARDED_NORM_ATTN_PRGM_CFG"] = self.create_sharded_norm_config(attn_input_grid) self.model_config["SHARDED_NORM_MLP_PRGM_CFG"] = self.create_sharded_norm_config(mlp_core_grid) @@ -835,7 +865,7 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len): ), ) - self.model_config = set_tg_attention_config(self.model_config, self.dim) + self.set_tg_attention_config() self.is_multichip = self.num_devices > 1 self.num_reduce_scatter_links = 1 @@ -844,12 +874,20 @@ def _get_xattn_kv_prefill_mem_cfg(seq_len): ) # TODO: try out 3 for short axis and 4 for long axis (TG only) <- should work but untested in model self.ccl_dtype = ttnn.bfloat8_b + logger.info(f"Attention grid: {attn_input_grid}") + logger.info(f"MLP grid: {mlp_core_grid}") + logger.info(f"MLP prefill grids @ 32: w1/w3: {mlp1_3_grid(32)}, w2: {mlp2_grid(32)}") + logger.info( + f"MLP prefill grids @ max_seq_len({self.max_seq_len}): w1/w3: {mlp1_3_grid(self.max_seq_len)}, w2: {mlp2_grid(self.max_seq_len)}" + ) + logger.info(f"LM head grid: {self.lm_head_core_grid}") + def is_distributed_norm(self, mode): if not self.is_multichip: return False if all([dim > 1 for dim in list(self.mesh_device.shape)]): # 2D grid return True - elif self.dim >= 8192 and mode == "prefill": # Somewhere between 4k and 8k WH runs out of L1 if not distributed + elif self.dim > 4096 and mode == "prefill": # Somewhere between 4k and 8k WH runs out of L1 if not distributed return True return False @@ -932,23 +970,72 @@ def prepare_residual_tensor_prefill(self, x_bsh, force_replicated=False): ) return xs_1BSH - def _set_llama_params_from_dict(self, params): - # Text params - self.dim = params["dim"] - self.ffn_dim_multiplier = params["ffn_dim_multiplier"] - self.multiple_of = params["multiple_of"] - self.n_heads = params["n_heads"] - 
self.n_kv_heads = params["n_kv_heads"] - self.n_layers = params["n_layers"] - self.norm_eps = params["norm_eps"] - self.rope_theta = params["rope_theta"] - self.use_scaled_rope = params["use_scaled_rope"] + def _set_params_from_dict(self, params): + # Common params with different names between Meta and HF + self.dim = params.get("dim", params.get("hidden_size")) + self.n_heads = params.get("n_heads", params.get("num_attention_heads")) + self.n_kv_heads = params.get("n_kv_heads", params.get("num_key_value_heads")) + self.n_layers = params.get("n_layers", params.get("num_hidden_layers")) + self.full_model_n_layers = self.n_layers + self.norm_eps = params.get("norm_eps", params.get("rms_norm_eps")) self.vocab_size = params["vocab_size"] self.padded_vocab_size = 128 * 1024 self.head_dim = self.dim // self.n_heads - self.hidden_dim = calculate_hidden_dim(self.dim, self.ffn_dim_multiplier, self.multiple_of) - # Vision params + # Handle different MLP dimension specifications + if "intermediate_size" in params: + self.hidden_dim = params["intermediate_size"] + self.ffn_dim_multiplier = None + self.multiple_of = None + else: + self.ffn_dim_multiplier = params["ffn_dim_multiplier"] + self.multiple_of = params["multiple_of"] + self.hidden_dim = calculate_hidden_dim(self.dim, self.ffn_dim_multiplier, self.multiple_of) + + if "_name_or_path" in params: + self.model_name = os.path.basename(params["_name_or_path"]) + + if self.base_model_name == "Qwen2.5-7B" and self.num_devices not in [0, 2, 4]: + raise AssertionError( + "Qwen2.5-7B is only supported on 2 or 4 devices, run on an N300 or use FAKE_DEVICE=N150x4" + ) + + self.unpadded_hidden_dim = self.hidden_dim + # Don't need to pad for CPU runs + if self.num_devices: + # Default padding cores for each model, 0 if not set here + default_padded_cores = { + "Qwen2.5-72B": 32, + "Qwen2.5-7B": 16, + }.get(self.base_model_name, 0) + + # Override MLP padding cores from env var + mlp_padded_cores = int(os.environ.get("PAD_MLP_CORES", default_padded_cores)) + + # Only pad if MLP_PADDED_CORES is non-zero + if mlp_padded_cores > 0: + padded_hidden_dim = nearest_multiple( + self.hidden_dim, mlp_padded_cores * self.tile_size * self.num_devices + ) + if padded_hidden_dim != self.hidden_dim: + logger.info( + f"PAD_MLP_CORES={mlp_padded_cores}, padding hidden dim from {self.hidden_dim} to {padded_hidden_dim}" + ) + self.hidden_dim = padded_hidden_dim + + # RoPE params + self.rope_theta = params.get("rope_theta") + # If use_scaled_rope is not present, assume setting rope_scaling means use scaled rope + # If it is present and is set to false, do not use scaled rope + # Setting self.rope_scaling_factor to None is our way of saying do not use scaled rope + if "rope_scaling" in params and params.get("use_scaled_rope", True): + self.rope_scaling_factor = params.get("factor", None) + self.orig_context_len = params.get("original_max_position_embeddings", None) + else: + self.rope_scaling_factor = None + self.orig_context_len = None + + # Vision params (Meta-specific) self.vision_chunk_size = params.get("vision_chunk_size", -1) self.vision_max_num_chunks = params.get("vision_max_num_chunks", 4) self.vision_num_cross_attention_layers = params.get("vision_num_cross_attention_layers", -1) @@ -967,6 +1054,14 @@ def _set_llama_params_from_dict(self, params): self.vision_patch_size = 14 self.vision_in_channels = 3 + @property + def use_scaled_rope(self): + return self.rope_scaling_factor is not None + + @property + def base_model_name(self): + return self.model_name.split("B-")[0] + 
"B" if "B-" in self.model_name else self.model_name + @property def vision_chunk_ntok(self): """ @@ -974,12 +1069,50 @@ def vision_chunk_ntok(self): """ return (self.vision_chunk_size // self.vision_patch_size) ** 2 + 1 + def _set_model_params(self, checkpoint_dir): + if self.checkpoint_type == CheckpointType.Meta: + self._set_llama_params(checkpoint_dir) + elif self.checkpoint_type == CheckpointType.HuggingFace: + self._set_hf_params(checkpoint_dir) + else: + raise ValueError(f"Unsupported checkpoint type: {self.checkpoint_type}") + def _set_llama_params(self, checkpoint_dir): params_file = os.path.join(checkpoint_dir, "params.json") assert os.path.exists(params_file), f"params.json file not found at {params_file}" with open(params_file, "r") as f: params = json.load(f) - self._set_llama_params_from_dict(params) + self._set_params_from_dict(params) + + # Meta-style config dicts don't specity model name or rope_scaling_factor so hard-code these + # Set the model name based on the checkpoint directory being loaded + # FIXME: add a llama prefix to all llama-specific models and names + if "3.2-1B" in checkpoint_dir: + self.model_name = "Llama3.2-1B" + "-Instruct" if self.instruct else "" + self.rope_scaling_factor = 32 + elif "3.2-3B" in checkpoint_dir: + self.model_name = "Llama3.2-3B" + "-Instruct" if self.instruct else "" + self.rope_scaling_factor = 32 + elif "3.1-8B" in checkpoint_dir: + self.model_name = "Llama3.1-8B" + "-Instruct" if self.instruct else "" + self.rope_scaling_factor = 8 + elif "3.2-11B" in checkpoint_dir: + self.model_name = "Llama3.2-11B" + "-Instruct" if self.instruct else "" + self.rope_scaling_factor = 8 # shared with 3.1-8B + elif "3.1-70B" in checkpoint_dir: + self.model_name = "Llama3.1-70B" + "-Instruct" if self.instruct else "" + self.rope_scaling_factor = 8 + self.is_70b = True # self.dim == 8192 and self.n_layers == 80 + else: + logger.warning(f"Unknown Meta-style model: {checkpoint_dir}") + self.orig_context_len = 8192 + + def _set_hf_params(self, checkpoint_dir): + config_file = os.path.join(checkpoint_dir, "config.json") + assert os.path.exists(config_file), f"config.json file not found at {config_file}" + with open(config_file, "r") as f: + config = json.load(f) + self._set_params_from_dict(config) def __repr__(self): return f"""ModelArgs( @@ -992,7 +1125,7 @@ def __repr__(self): ffn_dim_multiplier={self.ffn_dim_multiplier}, norm_eps={self.norm_eps}, rope_theta={self.rope_theta}, - use_scaled_rope={self.use_scaled_rope}, + rope_scaling_factor={self.rope_scaling_factor}, max_batch_size={self.max_batch_size}, max_seq_len={self.max_seq_len}, vision_chunk_size={self.vision_chunk_size}, @@ -1031,19 +1164,19 @@ def get_model_config(self): # TODO Update function for large models: For 1 layer tests we only want to load 1 checkpoint file, instead of all. 
     def load_state_dict(self):
-        """Generate or load state_dict for n_layers of the model"""
         if self.dummy_weights:
             reference_model = Transformer(self)
             state_dict = reference_model.state_dict()
             state_dict_prefix = self.get_state_dict_prefix("", None)
             state_dict = {f"{state_dict_prefix}{k}": torch.randn_like(v) for k, v in state_dict.items()}
+        elif self.checkpoint_type == CheckpointType.Meta:
+            state_dict = load_meta_state_dict(self.DEFAULT_CKPT_DIR, self.n_layers)
         else:
-            state_dict = load_llama_state_dict(self.DEFAULT_CKPT_DIR, self.n_layers)
-
+            assert self.checkpoint_type == CheckpointType.HuggingFace
+            state_dict = load_hf_state_dict(self.DEFAULT_CKPT_DIR)
+            state_dict = convert_hf_to_meta(state_dict, self.head_dim)
         keys_dict = list(state_dict.keys())[:]
-        remv = [
-            f"layers.{i}." for i in list(range(self.n_layers, 32))
-        ]  # TODO, this is not generalized to all models. it assumes max layers = 32
+        remv = [f"layers.{i}." for i in list(range(self.n_layers, self.full_model_n_layers))]
         for k in keys_dict:
             if any([r in k for r in remv]):
                 state_dict.pop(k)
@@ -1068,7 +1201,7 @@ def matmul_config(
         in0_block_w: int = None,
         fuse_batch: bool = False,
         fused_activation=None,
-    ) -> ttnn.MatmulMultiCoreReuseMultiCastProgramConfig:
+    ):
         per_core_M = math.ceil(m / (self.tile_size * grid_size[1]))
         per_core_N = math.ceil(n / (self.tile_size * grid_size[0]))
 
@@ -1134,6 +1267,31 @@ def find_grid(self, N):
             f"Cannot find a grid configuration for {N} tiles that evenly divides into {max_cores} cores of max size {max_rows}x{max_cols}."
         )
 
+    def find_prefill_grid(self, row_tiles, col_tiles):
+        """Find a grid such that the number of row tiles evenly divides into the number
+        of rows and the number of column tiles evenly divides into the number of columns
+        """
+        max_rows = 8
+        max_cols = 8
+
+        # Find number of cols that evenly divides into the number of columns
+        cols = None
+        rows = None
+
+        for i in range(max_cols, 0, -1):
+            if col_tiles % i == 0:
+                cols = i
+                break
+
+        for i in range(max_rows, 0, -1):
+            if row_tiles % i == 0:
+                rows = i
+                break
+
+        assert cols is not None, f"Cannot find a number of columns that evenly divides into {col_tiles}, not even 1(!)."
+        assert rows is not None, f"Cannot find a number of rows that evenly divides into {row_tiles}, not even 1(!)."
+        return rows, cols
+
     def dram_shard_core_grid_for_k_and_n(self, k: int, n: int) -> Tuple[int, int]:
         rows, cols = self.find_grid_k_n(k // self.tile_size, n // self.tile_size)
         return ttnn.CoreGrid(x=cols, y=rows)
@@ -1143,7 +1301,6 @@ def find_grid_k_n(self, K, N):
         Find the number of rows and columns for a grid of cores such that
         the total number of tiles N can be evenly divided among the cores.
         Each core will have the same integer number of tiles.
-        The grid size is limited to a maximum of 2 rows and 8 columns.
 
         Parameters:
             N (int): Total number of tiles to be distributed.
 
         Returns:
            tuple: A tuple (rows, cols) representing the grid configuration.
 
         Raises:
            AssertionError: If it's not possible to find such a grid configuration.
""" - max_rows = 4 + max_rows = 8 max_cols = 8 # Maximum number of rows or columns - max_cores = max_rows * max_cols # Maximum number of cores (8x2 grid) + max_cores = max_rows * max_cols # Maximum number of cores # Find all possible numbers of cores that divide N and are less than or equal to max_cores possible_cores = [c for c in range(1, max_cores + 1) if K % c == 0 and N % c == 0] @@ -1175,12 +1332,10 @@ def find_grid_k_n(self, K, N): f"Cannot find a grid configuration such that both {K} and {N} tiles evenly divide into cores of max size {max_rows}x{max_cols}." ) - def dram_matmul_config( - self, m: int, k: int, n: int, num_cores=None - ) -> ttnn.MatmulMultiCoreReuseMultiCastDRAMShardedProgramConfig: + def dram_matmul_config(self, m: int, k: int, n: int, num_cores=None): # in0_block_w must evenly divide k and be no larger than tile_size * num_cores if num_cores is None: - # num_cores = self.dram_shard_core_grid_for_k_and_n(k).num_cores + # num_cores = self.dram_shard_core_grid_for_k(k).num_cores num_cores = self.dram_shard_core_grid_for_k_and_n(k, n).num_cores assert ( k % (self.tile_size * num_cores) == 0 @@ -1302,72 +1457,352 @@ def create_sharded_norm_config(self, grid): inplace=False, ) + def detect_checkpoint_type(self) -> CheckpointType: + """Detect if checkpoint directory contains Meta or HuggingFace format weights. + + Returns: + CheckpointType: Meta or HuggingFace enum value + + Raises: + ValueError: If neither Meta nor HuggingFace checkpoint format is detected + """ + config_path = os.path.join(self.DEFAULT_CKPT_DIR, "config.json") + params_path = os.path.join(self.DEFAULT_CKPT_DIR, "params.json") + + if os.path.exists(config_path): + with open(config_path) as f: + config = json.load(f) + if "transformers_version" in config: + return CheckpointType.HuggingFace + + if os.path.exists(params_path): + return CheckpointType.Meta + + raise ValueError( + f"Could not detect Meta or HuggingFace checkpoint format in {self.DEFAULT_CKPT_DIR}. " + "Directory should contain either config.json (HuggingFace) or params.json (Meta)." 
+    def create_tokenizer(self):
+        """Create and return a Tokenizer instance based on the checkpoint type."""
+        if self.checkpoint_type == CheckpointType.Meta:
+            # Use the Meta Tokenizer
+            from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.tokenizer import Tokenizer
+
+            return Tokenizer(self.tokenizer_path)
+        else:
+            # Create a HuggingFace AutoTokenizer
+            from transformers import AutoTokenizer
+
+            tokenizer = AutoTokenizer.from_pretrained(self.DEFAULT_TOKENIZER_PATH)
 
-def load_llama_state_dict(ckpt_dir, n_layers=None, start_layer_idx=0):
-    checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-    assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-    is_chunked = "layers_" in str(checkpoints[0])
-    if is_chunked:
-        checkpoint = load_chunked_checkpoints(checkpoints, n_layers, start_layer_idx)
-    else:
-        checkpoint = load_sharded_checkpoints(checkpoints, n_layers)
-
-    return checkpoint
-
-
-def load_chunked_checkpoints(checkpoints, n_layers, start_layer_idx):
-    checkpoint = {}
-
-    (f"Loading {len(checkpoints)} checkpoint files")
-    for ckpt in tqdm(checkpoints):
-        if n_layers:
-            # Layer range is in the file name, like layers_start-end.pth
-            layer_range = ckpt.stem.split("_")[1]
-            start_layer, end_layer = map(int, layer_range.split("-"))
-            if start_layer > n_layers + start_layer_idx:
-                continue
-            if end_layer < start_layer_idx:
-                continue
-
-        loaded_ckpt = torch.load(ckpt, map_location="cpu")
-        checkpoint.update(loaded_ckpt)
-    return checkpoint
-
-
-def load_sharded_checkpoints(checkpoints, n_layers):
-    checkpoint = {}
-    logger.info(f"Loading {len(checkpoints)} checkpoint files")
-    for ckpt in tqdm(checkpoints):
-        loaded_ckpt = torch.load(ckpt, map_location="cpu")
-        for (
-            key,
-            value,
-        ) in loaded_ckpt.items():
-            if "layers." in key:
-                layer_num = int(key.split("layers.")[1].split(".")[0])
-                if n_layers and layer_num >= n_layers:
-                    continue
-            if key in checkpoint:
-                checkpoint[key] += [value]
+            # Add meta-compatible stop token list to the HF tokenizer
+            if "stop_tokens" not in tokenizer.__dict__:
+                tokenizer.stop_tokens = [tokenizer.eos_token_id]
+            return tokenizer
+
+    def encode_prompt(self, prompt_text, system_prompt_text=None, instruct=True):
+        if self.checkpoint_type == CheckpointType.Meta:
+            if instruct:
+                return encode_prompt_llama_instruct(self.tokenizer, prompt_text, system_prompt_text)
+            else:
+                return self.tokenizer.encode(prompt_text, bos=True, eos=False)
+        else:
+            if instruct:
+                return encode_prompt_hf(self.tokenizer, prompt_text, system_prompt_text)
+            else:
+                return self.tokenizer.encode(prompt_text, add_special_tokens=False)
+
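+    # The reference_* helpers below return torch reference modules for the unit tests; for HF checkpoints,
+    # load_state_dict is wrapped so the modules still accept Meta-format state dicts (see convert_meta_to_hf).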
+    def reference_lm_head(self):
+        if self.checkpoint_type == CheckpointType.Meta:
+            from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import ColumnParallelLinear
+
+            return ColumnParallelLinear(self.dim, self.vocab_size, bias=False, init_method=lambda x: x)
+        else:
+            model = self.reference_transformer(wrap=False)
+            layer = model.lm_head
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
+            return layer
+
+    def reference_transformer(self, wrap=True, load_checkpoint=False):
+        if self.checkpoint_type == CheckpointType.Meta:
+            from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Transformer
+
+            model = Transformer(self)
+            if load_checkpoint:
+                model.load_state_dict(self.load_state_dict())
+            return model
+        else:
+            from transformers import AutoConfig, AutoModelForCausalLM
+
+            if not load_checkpoint:
+                config = AutoConfig.from_pretrained(self.DEFAULT_CKPT_DIR)
+                config.num_layers = self.n_layers
+                model = AutoModelForCausalLM.from_config(config)
             else:
-                checkpoint[key] = [value]
-        del loaded_ckpt
+                model = AutoModelForCausalLM.from_pretrained(self.DEFAULT_CKPT_DIR)
+            if wrap:
+                wrapper = HfModelWrapper(model, self.head_dim)
+                return wrapper
+            else:
+                return model
+
+    def reference_rms_norm(self):
+        if self.checkpoint_type == CheckpointType.Meta:
+            from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import RMSNorm
+
+            return RMSNorm(self.dim, self.norm_eps)
+        else:
+            model = self.reference_transformer(wrap=False)
+            layer = model.model.norm
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
+            return layer
+
+    def reference_mlp(self):
+        if self.checkpoint_type == CheckpointType.Meta:
+            from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import FeedForward
 
-    # concat checkpoint values
-    for key, value in checkpoint.items():
-        if len(value) == 1 or "norm" in key:
-            checkpoint[key] = value[0]
+            return FeedForward(self.dim, 4 * self.dim, self.multiple_of, self.ffn_dim_multiplier)
         else:
-            if key == "tok_embeddings.weight" or key == "output.weight":
-                assert value[0].shape[1] == 8192  # FIXME: do we need this hardcoded shape?
-                # Concatenate along dimension 0 for llama3 token embeddings weight and lm head
-                checkpoint[key] = torch.cat(value, dim=0)
+            model = self.reference_transformer(wrap=False)
+            layer = model.model.layers[0].mlp
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
+            return layer
+
+    def reference_embedding(self, reference_model=None):
+        if self.checkpoint_type == CheckpointType.Meta:
+            from models.demos.llama3.tt.llama_common import HostEmbedding
+
+            return HostEmbedding(self)
+        else:
+            if reference_model is None:
+                model = self.reference_transformer(wrap=False)
            else:
-                # cat_dim is index of the smallest dimension in value[0].shape
-                cat_dim = torch.argmin(torch.tensor(value[0].shape))
-                checkpoint[key] = torch.cat(value, dim=cat_dim)
+                model = reference_model
+            layer = model.model.embed_tokens
+            layer._load_state_dict = layer.load_state_dict
+            layer.load_state_dict = lambda x: layer._load_state_dict(convert_meta_to_hf(x, self.head_dim))
+            return layer
+
+    def reference_decoder(self):
+        if self.checkpoint_type == CheckpointType.Meta:
+            from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import TransformerBlock
 
-    return checkpoint
+            return TransformerBlock(layer_id=0, args=self)
+        else:
+            model = self.reference_transformer(wrap=False)
+            layer = model.model.layers[0]
+            wrapper = HfDecoderWrapper(layer, self.head_dim)
+            return wrapper
+
+    def reference_attention(self):
+        if self.checkpoint_type == CheckpointType.Meta:
+            from models.demos.t3000.llama2_70b.reference.llama.llama31_8b.model import Attention
+
+            return Attention(self)
+        else:
+            model = self.reference_transformer(wrap=False)
+            layer = model.model.layers[0].self_attn
+            wrapper = HfAttentionWrapper(layer, self.head_dim)
+            return wrapper
+
+    def set_tg_attention_config(self):
+        shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(40)})
+
+        self.model_config["CREATE_HEAD_INPUT_MEMCFG"] = (
+            None
+            if self.dim < 4096
+            else ttnn.MemoryConfig(
+                ttnn.TensorMemoryLayout.WIDTH_SHARDED,
+                ttnn.BufferType.L1,
+                ttnn.ShardSpec(
+                    shard_spec_n_cores_grid,
+                    [
+                        32,
+                        32,
+                    ],
+                    ttnn.ShardOrientation.ROW_MAJOR,
+                ),
+            )
+        )
+
+        if self.is_galaxy:
+            num_cores = 40 if self.dim == 8192 else (24 if self.dim == 4096 else (20 if self.dim == 3072 else 12))
+
+            self.model_config["QKV_OUT_GATHERED_MEMCFG"] = lambda mesh_cols: ttnn.create_sharded_memory_config(
+                shape=(32 * mesh_cols, 32),  # mesh_cols = 4
+                core_grid=num_to_coregrid(num_cores),
+                strategy=ttnn.ShardStrategy.WIDTH,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                use_height_and_width_as_shard_shape=True,
+            )
+
+            self.model_config["SELF_OUT_GATHERED_MEMCFG"] = lambda mesh_rows: ttnn.create_sharded_memory_config(
+                shape=(32 * mesh_rows, self.dim // 4 // min(32, self.dim // 4 // 32)),
+                core_grid=num_to_coregrid(min(32, self.dim // 4 // 32)),
+                strategy=ttnn.ShardStrategy.WIDTH,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                use_height_and_width_as_shard_shape=True,
+            )
+            self.model_config["GATHER_USERS_MEMCFG"] = lambda mesh_cols: ttnn.create_sharded_memory_config(
+                shape=(32 * mesh_cols, 32),  # mesh_cols = 4
+                core_grid=num_to_coregrid(min(32, self.dim // 8 // 32)),
+                strategy=ttnn.ShardStrategy.WIDTH,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                use_height_and_width_as_shard_shape=True,
+            )
+        else:
+            qkv_core_grid = self.dram_shard_core_grid_for_k(self.dim)
+            self.model_config["QKV_OUT_GATHERED_MEMCFG"] = lambda mesh_rows: ttnn.create_sharded_memory_config(
+                (
+                    self.tile_size * mesh_rows,
+                    self.dim // qkv_core_grid.num_cores,
+                ),  # Shard shape: [32, 128] -> 1 shard per core
+                core_grid=qkv_core_grid,
+                strategy=ttnn.ShardStrategy.WIDTH,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                use_height_and_width_as_shard_shape=True,
+            )
+            gather_core_grid = self.dram_shard_core_grid_for_k(self.dim // 4)
+            self.model_config["SELF_OUT_GATHERED_MEMCFG"] = lambda mesh_rows: ttnn.create_sharded_memory_config(
+                (
+                    self.tile_size * mesh_rows,
+                    self.dim // 4 // gather_core_grid.num_cores,
+                ),
+                core_grid=gather_core_grid,
+                strategy=ttnn.ShardStrategy.WIDTH,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                use_height_and_width_as_shard_shape=True,
+            )
+            users_core_grid = self.dram_shard_core_grid_for_k(self.dim // 8)
+            self.model_config["GATHER_USERS_MEMCFG"] = lambda mesh_cols: ttnn.create_sharded_memory_config(
+                (
+                    self.tile_size * mesh_cols,
+                    self.dim // 8 // users_core_grid.num_cores,
+                ),
+                core_grid=users_core_grid,
+                strategy=ttnn.ShardStrategy.WIDTH,
+                orientation=ttnn.ShardOrientation.ROW_MAJOR,
+                use_height_and_width_as_shard_shape=True,
+            )
+
+
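+# The Hf*Wrapper classes below adapt HuggingFace modules to the calling convention of the Meta reference
+# models used by the tests: explicit start_pos/mask arguments, Meta-format state dicts and KV cache layout.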
+class HfAttentionWrapper:
+    def __init__(self, attention, head_dim):
+        from transformers import DynamicCache
+
+        super().__init__()
+        self.attention = attention
+        self.past_key_value = DynamicCache()
+        self.head_dim = head_dim
+
+    def forward(self, x, start_pos, freqs_cis_i, mask=None):
+        position_ids = torch.tensor([list(range(start_pos, start_pos + x.shape[1]))] * x.shape[0])
+        if mask is not None:
+            while len(mask.shape) < 4:
+                mask = mask.unsqueeze(0)
+        output, _, self.past_key_value = self.attention(
+            x,
+            past_key_value=self.past_key_value,
+            use_cache=True,
+            position_ids=position_ids,
+            attention_mask=mask,
+        )
+        return output
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def load_state_dict(self, state_dict):
+        return self.attention.load_state_dict(convert_meta_to_hf(state_dict, self.head_dim))
+
+    @property
+    def cache_k(self):
+        [(k, v)] = self.past_key_value.to_legacy_cache()
+        hf_k = k.permute(0, 2, 1, 3)  # match meta-style reference which uses (batch_size, seq, n_kv_heads, head_dim)
+        batch_size, seq_len, n_heads, head_dim = hf_k.shape
+
+        meta_k = torch.zeros_like(hf_k)
+        for b in range(batch_size):
+            for s in range(seq_len):
+                # Flatten just heads and head_dim
+                flat = hf_k[b, s].flatten()
+                # Apply reverse_permute
+                transformed = reverse_permute(flat.unsqueeze(-1), n_heads, flat.shape[0], 1).squeeze(-1)
+                # Restore heads and head_dim shape
+                meta_k[b, s] = transformed.reshape(n_heads, head_dim)
+
+        return meta_k
+
+    @property
+    def cache_v(self):
+        [(k, v)] = self.past_key_value.to_legacy_cache()
+        return v.permute(0, 2, 1, 3)  # match meta-style reference which uses (batch_size, seq, n_kv_heads, head_dim)
+
+
+class HfDecoderWrapper:
+    def __init__(self, decoder, head_dim):
+        from transformers import DynamicCache
+
+        self.decoder = decoder
+        self.head_dim = head_dim
+        self.past_key_values = DynamicCache()
+
+    def forward(self, x, start_pos, freqs_cis_i, mask=None):
+        position_ids = torch.tensor([list(range(start_pos, start_pos + x.shape[1]))] * x.shape[0])
+        if mask is not None:
+            while len(mask.shape) < 4:
+                mask = mask.unsqueeze(0)
+        output, self.past_key_values = self.decoder.forward(
+            x,
+            past_key_value=self.past_key_values,
+            use_cache=True,
+            position_ids=position_ids,
+            attention_mask=mask,
+        )
+        return output
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def load_state_dict(self, state_dict):
+        return self.decoder.load_state_dict(convert_meta_to_hf(state_dict, self.head_dim))
+
+
+class HfModelWrapper:
+    def __init__(self, model, head_dim):
+        from transformers import DynamicCache
+
+        self.model = model
+        self.head_dim = head_dim
+        self.past_key_values = DynamicCache()
+
+    def forward(self, inputs_embeds, start_pos, mode="decode"):
+        position_ids = torch.tensor(
+            [list(range(start_pos, start_pos + inputs_embeds.shape[1]))] * inputs_embeds.shape[0]
+        )
+        logits, new_cache, hidden_states = self.model.forward(
+            inputs_embeds=inputs_embeds,
+            position_ids=position_ids,
+            use_cache=True,
+            past_key_values=self.past_key_values,
+            return_dict=False,
+            output_hidden_states=True,
+        )
+        self.past_key_values = new_cache
+        return logits if mode == "decode" else hidden_states[-2]  # last hidden state is final norm
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def load_state_dict(self, state_dict):
+        return self.model.load_state_dict(convert_meta_to_hf(state_dict, self.head_dim))
+
+    def eval(self):
+        self.model.eval()
 
 
 def num_to_corerange(x):
@@ -1388,51 +1823,3 @@ def num_to_coregrid(x):
         return ttnn.CoreGrid(y=2, x=6)
     if x == 20:
         return ttnn.CoreGrid(y=4, x=5)
-
-
-def set_tg_attention_config(model_config, dim):
-    shard_spec_n_cores_grid = ttnn.CoreRangeSet({num_to_corerange(40)})
-
-    model_config["CREATE_HEAD_INPUT_MEMCFG"] = (
-        None
-        if dim < 4096
-        else ttnn.MemoryConfig(
-            ttnn.TensorMemoryLayout.WIDTH_SHARDED,
-            ttnn.BufferType.L1,
-            ttnn.ShardSpec(
-                shard_spec_n_cores_grid,
-                [
-                    32,
-                    32,
-                ],
-                ttnn.ShardOrientation.ROW_MAJOR,
-            ),
-        )
-    )
-
-    num_cores = 40 if dim == 8192 else (24 if dim == 4096 else (20 if dim == 3072 else 12))
-
-    model_config["QKV_OUT_GATHERED_MEMCFG"] = lambda mesh_cols: ttnn.create_sharded_memory_config(
-        shape=(32 * mesh_cols, 32),  # mesh_cols = 4
-        core_grid=num_to_coregrid(num_cores),
-        strategy=ttnn.ShardStrategy.WIDTH,
-        orientation=ttnn.ShardOrientation.ROW_MAJOR,
-        use_height_and_width_as_shard_shape=True,
-    )
-
-    model_config["SELF_OUT_GATHERED_MEMCFG"] = lambda mesh_rows: ttnn.create_sharded_memory_config(
-        shape=(32 * mesh_rows, dim // 4 // min(32, dim // 4 // 32)),
-        core_grid=num_to_coregrid(min(32, dim // 4 // 32)),
-        strategy=ttnn.ShardStrategy.WIDTH,
-        orientation=ttnn.ShardOrientation.ROW_MAJOR,
-        use_height_and_width_as_shard_shape=True,
-    )
-    model_config["GATHER_USERS_MEMCFG"] = lambda mesh_cols: ttnn.create_sharded_memory_config(
-        shape=(32 * mesh_cols, 32),  # mesh_cols = 4
-        core_grid=num_to_coregrid(min(32, dim // 8 // 32)),
-        strategy=ttnn.ShardStrategy.WIDTH,
-        orientation=ttnn.ShardOrientation.ROW_MAJOR,
-        use_height_and_width_as_shard_shape=True,
-    )
-
-    return model_config
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention.py b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
index 57bfedecffa..ef312334bcf 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention.py
@@ -292,6 +292,7 @@ def forward_decode(
             dim=3,
             math_op=ttnn.ReduceType.Sum,
             num_links=1,
+            topology=self.configuration.ccl_topology(),
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
 
@@ -382,6 +383,7 @@ def forward_prefill(
             dim=3,
             math_op=ttnn.ReduceType.Sum,
             num_links=1,
+            topology=self.configuration.ccl_topology(),
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
         )
         return dense_out_reduced
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
index 162f6dc6da7..28ee6e810ed 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_text.py
@@ -126,8 +126,8 @@ def __init__(
             configuration.head_dim,
             configuration.max_seq_len,
             configuration.rope_theta,
-            configuration.use_scaled_rope,
             configuration.rope_scaling_factor,
+            configuration.orig_context_len,
         )
         self.trans_mats_dict = self.rope_setup.get_both_trans_mats()
 
@@ -291,9 +291,9 @@ def forward(
             h = xattn_layer(
                 h,
                 xattn_mask=xattn_mask,
-                xattn_cache=xattn_caches[xattn_layer_idx]
-                if cross_page_table is None
-                else kv_cache[total_layer_idx],
+                xattn_cache=(
+                    xattn_caches[xattn_layer_idx] if cross_page_table is None else kv_cache[total_layer_idx]
+                ),
                 full_text_row_masked_out_mask_1NSH=full_text_row_masked_out_mask_1NSH,
                 full_text_row_masked_out_mask_11SD=full_text_row_masked_out_mask_11SD,
                 mode=mode,
diff --git a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py
index 4c59ecec52b..06e5095d4ca 100644
--- a/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py
+++ b/models/demos/llama3/tt/multimodal/llama_cross_attention_transformer_vision.py
@@ -72,14 +72,16 @@ def shuffle_weight(weight):
             return w.transpose(-1, -2).view(orig_shape)
 
         as_interleaved_tensor = lambda name, suffix, type, dim: ttnn.as_tensor(
-            shuffle_weight(torch_weight(name, suffix))
-            if suffix == "weight"
-            else torch_bias(name, suffix),  # Grab only the wX part of the name
+            (
+                shuffle_weight(torch_weight(name, suffix)) if suffix == "weight" else torch_bias(name, suffix)
+            ),  # Grab only the wX part of the name
             dtype=type,
             device=self.mesh_device,
-            mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=dim)
-            if dim is not None
-            else ttnn.ReplicateTensorToMesh(self.mesh_device),
+            mesh_mapper=(
+                ttnn.ShardTensorToMesh(self.mesh_device, dim=dim)
+                if dim is not None
+                else ttnn.ReplicateTensorToMesh(self.mesh_device)
+            ),
             layout=ttnn.TILE_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
             cache_file_name=cache_name(name, suffix),
diff --git a/models/demos/llama3/tt/multimodal/llama_image_mlp.py b/models/demos/llama3/tt/multimodal/llama_image_mlp.py
index b0c63a83df2..45755f88f30 100644
--- a/models/demos/llama3/tt/multimodal/llama_image_mlp.py
+++ b/models/demos/llama3/tt/multimodal/llama_image_mlp.py
@@ -35,14 +35,16 @@ def __init__(
         cache_name = lambda name, suffix: weight_cache_path / (state_dict_prefix + f".{name}.{suffix}")
 
         as_interleaved_tensor = lambda name, suffix, type, dim: ttnn.as_tensor(
-            torch_weight(name, suffix)
-            if suffix == "weight"
-            else torch_bias(name, suffix),  # Grab only the wX part of the name
+            (
+                torch_weight(name, suffix) if suffix == "weight" else torch_bias(name, suffix)
+            ),  # Grab only the wX part of the name
            dtype=type,
             device=self.mesh_device,
-            mesh_mapper=ttnn.ShardTensorToMesh(self.mesh_device, dim=dim)
-            if dim is not None
-            else ttnn.ReplicateTensorToMesh(self.mesh_device),
+            mesh_mapper=(
+                ttnn.ShardTensorToMesh(self.mesh_device, dim=dim)
+                if dim is not None
+                else ttnn.ReplicateTensorToMesh(self.mesh_device)
+            ),
             layout=ttnn.TILE_LAYOUT,
             memory_config=ttnn.DRAM_MEMORY_CONFIG,
             cache_file_name=cache_name(name, suffix),
diff --git a/models/demos/llama3/tt/multimodal/llama_vision_model.py b/models/demos/llama3/tt/multimodal/llama_vision_model.py
index c22c4100f43..7a4918c96c1 100644
--- a/models/demos/llama3/tt/multimodal/llama_vision_model.py
+++ b/models/demos/llama3/tt/multimodal/llama_vision_model.py
@@ -28,7 +28,6 @@ from models.demos.llama3.tt.llama_common import (
     get_prefill_rot_mat,
     get_rot_transformation_mat,
-    get_single_rot_mat,
     copy_host_to_device,
     get_padded_prefill_len,
 )
@@ -374,7 +373,9 @@ def prepare_inputs_prefill(
             self.configuration.max_seq_len,
             self.mesh_device,
             seq_len=S,
+            theta=self.configuration.rope_theta,
             scale_factor=self.configuration.rope_scaling_factor,
+            orig_context_len=self.configuration.orig_context_len,
         )
 
         if isinstance(page_table, torch.Tensor):