Add HF model support inc. DS-R1-Distill, Qwen needs yarn support (#17421)

### Problem description
The existing codebase loads only the Meta checkpoint format, but many derivative
models are available only on HuggingFace.

### What's changed
Add support for loading HuggingFace model formats, paving the way for
full Qwen support (pending a YaRN RoPE implementation) and adding support for
DeepSeek-R1-Distill-Llama-70B.
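
As a rough sketch of what HF-format loading involves (hypothetical code, not the exact loader added here), it boils down to reading `config.json` and merging the sharded `.safetensors` files into one state dict:

```
# Hypothetical minimal loader for an HF-format checkpoint directory;
# illustrative only, not the loader added in this PR.
import json
import os

from safetensors.torch import load_file


def load_hf_checkpoint(model_dir):
    with open(os.path.join(model_dir, "config.json")) as f:
        config = json.load(f)  # architecture hyperparameters
    state_dict = {}
    for fname in sorted(os.listdir(model_dir)):
        if fname.endswith(".safetensors"):
            # Each shard holds a disjoint subset of the model's tensors.
            state_dict.update(load_file(os.path.join(model_dir, fname)))
    return config, state_dict
```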

### Checklist
All passing locally.
- [x] [all-post-commit](https://github.com/tenstorrent/tt-metal/actions/runs/13181023765)
  - [FIXED] Was failing to load the tokenizer on this pipeline only
- [x] [Single](https://github.com/tenstorrent/tt-metal/actions/runs/13142509908/job/36672984561)
- [x] [Single-demos](https://github.com/tenstorrent/tt-metal/actions/runs/13180995444)
  - Only failing on N300 performance (investigating)
- [ ] [T3K](https://github.com/tenstorrent/tt-metal/actions/runs/13142519276)
- [x] [Unit](https://github.com/tenstorrent/tt-metal/actions/runs/13163296158/job/36737812258)
- [x] [Model-perf](https://github.com/tenstorrent/tt-metal/actions/runs/13164376159)
- [x] [Frequent-1](https://github.com/tenstorrent/tt-metal/actions/runs/13174954913)
- [x] [Frequent-2](https://github.com/tenstorrent/tt-metal/actions/runs/13164380377/job/36742877847)
- [x] [Demo](https://github.com/tenstorrent/tt-metal/actions/runs/13180986094)
- [x] [TG](https://github.com/tenstorrent/tt-metal/actions/runs/13154035596/job/36707218743)
  - Pipelines have issues not related to these changes.

---------

Signed-off-by: Salar Hosseini <skhorasgani@tenstorrent.com>
Co-authored-by: mtairum <mtairum@tenstorrent.com>
Co-authored-by: Salar Hosseini <skhorasgani@tenstorrent.com>
3 people authored Feb 7, 2025
1 parent 558da69 commit d0b59bd
Showing 50 changed files with 1,983 additions and 780 deletions.
5 changes: 3 additions & 2 deletions README.md
@@ -36,12 +36,13 @@
| [Llama 3.1 70B (TP=8)](./models/demos/t3000/llama3_70b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 190 | 15.1 | 20 | 483.2 | [v0.54.0-rc2](https://github.com/tenstorrent/tt-metal/tree/v0.54.0-rc2) | [9531611](https://github.com/tenstorrent/vllm/tree/953161188c50f10da95a88ab305e23977ebd3750) |
| [Falcon 40B (TP=8)](./models/demos/t3000/falcon40b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | | 5.3 | 36 | 169.6 | [v0.55.0-rc19](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc19) | |
| [Mixtral 8x7B (TP=8)](./models/demos/t3000/mixtral8x7b) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 227 | 14.9 | 33 | 476.8 | [v0.55.0-rc19](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc19) | |
| [DeepSeek R1 Distill Llama 3.3 70B (TP=8)](./models/demos/llama3) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 1113 | 16.4 | 33 | 386.4 | [main](https://github.com/tenstorrent/tt-metal/) | [2f33504](https://github.com/tenstorrent/vllm/tree/2f33504bad49a6202d3685155107a6126a5b5e6e) |
| [Falcon 7B (DP=32)](./models/demos/tg/falcon7b) | 1024 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 223 | 4.8 | 26 | 4915.2 | [v0.55.0-rc18](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc18) | |
| [Llama 3.1 70B (DP=4, TP=8)](./models/demos/t3000/llama3_70b) | 128 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 190 | 14.3 | 20 | 1835.5 | [v0.52.0-rc31](https://github.com/tenstorrent/tt-metal/tree/v0.52.0-rc31) | |
| [Llama 3.1 70B (TP=32)](./models/demos/llama3) | 32 | [Galaxy](https://tenstorrent.com/hardware/galaxy) | 763 | 13.5 | 80 | 432.0 | [v0.55.0-rc12](https://github.com/tenstorrent/tt-metal/tree/v0.55.0-rc12) | [2f33504](https://github.com/tenstorrent/vllm/tree/2f33504bad49a6202d3685155107a6126a5b5e6e) |
| [DeepSeek R1 Distill Llama 3.3 70B (TP=8)](https://github.com/tenstorrent/tt-metal/tree/hf-llama/models/demos/llama3) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 1113 | 16.4 | 33 | 524.8 | [hf-llama](https://github.com/tenstorrent/tt-metal/tree/hf-llama) | [b9564bf](https://github.com/tenstorrent/vllm/tree/b9564bf364e95a3850619fc7b2ed968cc71e30b7) |
| [DeepSeek R1 Distill Llama 3.3 70B (TP=8)](https://github.com/tenstorrent/tt-metal/tree/main/models/demos/llama3) | 32 | [QuietBox](https://tenstorrent.com/hardware/tt-quietbox) | 1113 | 16.4 | 33 | 524.8 | [main](https://github.com/tenstorrent/tt-metal/) | [b9564bf](https://github.com/tenstorrent/vllm/tree/b9564bf364e95a3850619fc7b2ed968cc71e30b7) |

> **Last Update:** January 27, 2025
> **Last Update:** February 5, 2025
>
> **Notes:**
>
3 changes: 3 additions & 0 deletions models/common/rmsnorm.py
@@ -49,10 +49,12 @@ def __init__(
eps: float = 1e-05,
sharded_program_config=None,
sharded_output_config=None,
ccl_topology=ttnn.Topology.Ring,
):
super().__init__()
self.eps = eps
self.is_distributed = is_distributed
self.ccl_topology = ccl_topology

if state_dict_prefix:
weight_name = f"{state_dict_prefix}{weight_key}.weight"
@@ -144,6 +146,7 @@ def _distributed_rmsnorm(
tt_stats,
dim=3,
num_links=1,
topology=self.ccl_topology,
memory_config=ttnn.DRAM_MEMORY_CONFIG,
)
# Run distributed rmsnorm part 2
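
The hunk above threads the new `ccl_topology` argument into the cross-device stats gather that sits between the two rmsnorm phases. As a rough illustration of why that gather exists (a pure-PyTorch sketch under assumed shapes; on device this is a CCL op whose Ring vs. Linear topology `ccl_topology` selects):

```
# Illustrative two-phase distributed RMSNorm over a hidden dimension
# sharded across devices; not the on-device implementation.
import torch


def distributed_rmsnorm(shards, weight_shards, eps=1e-5):
    # Part 1: each device computes the sum of squares of its local slice.
    local_sumsq = [x.pow(2).sum(dim=-1, keepdim=True) for x in shards]
    # Gather step: stats must be combined across all devices before any
    # device can normalize; this is where the CCL topology matters.
    total_sumsq = sum(local_sumsq)
    total_dim = sum(x.shape[-1] for x in shards)
    # Part 2: each device normalizes with the global RMS and applies its
    # slice of the learned weight.
    inv_rms = torch.rsqrt(total_sumsq / total_dim + eps)
    return [x * inv_rms * w for x, w in zip(shards, weight_shards)]
```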
93 changes: 48 additions & 45 deletions models/demos/llama3/PERF.md
@@ -4,51 +4,54 @@ Performance collected from [demo/demo.py](demo/demo.py) and accuracy collected f

Note that `test_llama_accuracy.py` parses the below to determine expected values +- 0.5.
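
As a hypothetical sketch of that parsing (the real logic lives in `test_llama_accuracy.py`), each data row can be split on pipes and its Top-1/Top-5 cells read as expected values:

```
# Hypothetical table parser; the actual one is in test_llama_accuracy.py.
def parse_perf_table(markdown_text):
    expected = {}
    for line in markdown_text.splitlines():
        cells = [c.strip() for c in line.strip().strip("|").split("|")]
        # Keep only data rows: five cells with numeric Top-1/Top-5 entries.
        if len(cells) == 5 and cells[2].isdigit() and cells[3].isdigit():
            model, device, top1, top5, _speed = cells
            # The test would then assert measured accuracy within +-0.5.
            expected[(model, device)] = (float(top1), float(top5))
    return expected
```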

## LlamaOptimizations.performance
## Performance

This configuration uses bfp4 MLP FF1+FF3 for all models.

| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|-------|--------|-----------|-----------|---------------|
| 1b | N150 | 87 | 98 | 91.0 |
| 1b | N300 | 87 | 98 | 98.8 |
| 1b | T3K | 87 | 98 | 97.8 |
| 1b | TG | 88 | 99 | 51.0 |
| 3b | N150 | 90 | 98 | 49.2 |
| 3b | N300 | 90 | 98 | 56.8 |
| 3b | T3K | 88 | 98 | 54.5 |
| 3b | TG | 90 | 97 | 33.5 |
| 8b | N150 | 86 | 99 | 28.6 |
| 8b | N300 | 85 | 98 | 38.9 |
| 8b | T3K | 84 | 97 | 53.7 |
| 8b | TG | 86 | 98 | 29.5 |
| 11b | N300 | 87 | 98 | 38.6 |
| 11b | T3K | 88 | 98 | 52.6 |
| 11b | TG | 86 | 98 | 29.5 |
| 70b | T3K | 95 | 99 | 14.7 |
| 70b | TG | 95 | 100 | 12.7 |


## LlamaOptimizations.accuracy

This configuration uses bfp4 MLP FF1+FF3 only for the 3.1-70B model.

| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|-------|--------|-----------|-----------|---------------|
| 1b | N150 | 89 | 98 | 86.8 |
| 1b | N300 | 88 | 99 | 98.1 |
| 1b | T3K | 86 | 99 | 97.5 |
| 1b | TG | 87 | 98 | 51.3 |
| 3b | N150 | 92 | 100 | 44.2 |
| 3b | N300 | 92 | 99 | 54.2 |
| 3b | T3K | 91 | 98 | 55.6 |
| 3b | TG | 91 | 98 | 33.6 |
| 8b | N150 | 91 | 99 | 23.6 |
| 8b | N300 | 91 | 99 | 34.5 |
| 8b | T3K | 90 | 99 | 49.8 |
| 8b | TG | 88 | 100 | 29.5 |
| 11b | N300 | 91 | 99 | 33.8 |
| 11b | T3K | 91 | 99 | 52.6 |
| 11b | TG | 88 | 100 | 29.5 |
| 70b | T3K | 95 | 99 | 14.7 |
| 70b | TG | 95 | 100 | 12.7 |
| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|----------------|--------|-----------|-----------|---------------|
| Llama3.2-1B | N150 | 89 | 98 | 86.9 |
| Llama3.2-1B | N300 | 91 | 98 | 104.3 |
| Llama3.2-1B | T3K | 91 | 98 | 118.5 |
| Llama3.2-1B | TG | | | 72.3 |
| Llama3.2-3B | N150 | 92 | 96 | 53.3 |
| Llama3.2-3B | N300 | 91 | 96 | 66.1 |
| Llama3.2-3B | T3K | 91 | 96 | 66.9 |
| Llama3.2-3B | TG | | | 48.5 |
| Llama3.1-8B | N150 | 87 | 99 | 27.9 |
| Llama3.1-8B | N300 | 88 | 99 | 43.7 |
| Llama3.1-8B | T3K | 91 | 100 | 64.2 |
| Llama3.1-8B | TG | | | 41.0 |
| Llama3.2-11B | N300 | 89 | 99 | 43.5 |
| Llama3.2-11B | T3K | 88 | 99 | 63.4 |
| Llama3.2-11B | TG | | | 40.9 |
| Llama3.1-70B | T3K | 96 | 100 | 16.1 |
| Llama3.1-70B | TG | | | |
| Qwen2.5-7B | N300 | 81 | 96 | 37.9 |
| Qwen2.5-72B | T3K | 99 | 100 | 12.8 |

## Accuracy

This configuration uses bfp4 MLP FF1+FF3 only for the Llama-3.1-70B model and the Qwen-2.5-72B model.

| Model | Device | Top-1 (%) | Top-5 (%) | Speed (t/s/u) |
|----------------|--------|-----------|-----------|---------------|
| Llama3.2-1B | N150 | 88 | 98 | 86.8 |
| Llama3.2-1B | N300 | 90 | 98 | 98.1 |
| Llama3.2-1B | T3K | 90 | 98 | 97.5 |
| Llama3.2-1B | TG | 87 | 98 | 51.3 |
| Llama3.2-3B | N150 | 93 | 99 | 44.2 |
| Llama3.2-3B | N300 | 92 | 98 | 54.2 |
| Llama3.2-3B | T3K | 93 | 98 | 55.6 |
| Llama3.2-3B | TG | 91 | 98 | 33.6 |
| Llama3.1-8B | N150 | 93 | 100 | 23.6 |
| Llama3.1-8B | N300 | 93 | 100 | 34.5 |
| Llama3.1-8B | T3K | 92 | 100 | 49.8 |
| Llama3.1-8B | TG | 88 | 100 | 29.5 |
| Llama3.2-11B | N300 | 93 | 100 | 33.8 |
| Llama3.2-11B | T3K | 94 | 100 | 52.6 |
| Llama3.2-11B | TG | 88 | 100 | 29.5 |
| Llama3.1-70B | T3K | 97 | 100 | 14.7 |
| Llama3.1-70B | TG | 95 | 100 | 12.7 |
| Qwen2.5-7B | N300 | 81 | 96 | 33.4 |
| Qwen2.5-72B | T3K | 99 | 100 | 12.8 |
27 changes: 23 additions & 4 deletions models/demos/llama3/README.md
@@ -8,6 +8,7 @@ The current version supports the following Llama3 models:
- Llama3.1-8B
- Llama3.2-11B
- Llama3.1-70B (T3000 and TG-only)
- DeepSeek R1 Distill Llama 3.3 70B (T3000 and TG-only)

All the above Llama models (with the exception of 70B due to its large size) are compatible with and tested on the following Tenstorrent hardware:
- N150 (1-chip)
@@ -25,13 +26,15 @@ Max Prefill Chunk Sizes (text-only):
| Llama3.1-8B | 4k tokens | 64k tokens | 128k tokens | 128k tokens |
| Llama3.2-11B | 4k tokens | 64k tokens | 128k tokens | 128k tokens |
| Llama3.1-70B | Not supported | Not supported | 32k tokens | 128k tokens |
| DeepSeek-R1-Distill-Llama3.3-70B | Not supported | Not supported | 32k tokens | 128k tokens |

- These max chunk sizes are specific to a max context length of 128k and are configured via `MAX_PREFILL_CHUNK_SIZES_DIV1024` in [model_config.py](https://github.com/tenstorrent/tt-metal/blob/main/models/demos/llama3/tt/model_config.py). If the max context length is set to a smaller value using the `max_seq_len` flag (see [Run the demo](#run-the-demo)), these chunk sizes may be increased, because the smaller KV cache frees up memory; the sketch below illustrates the naming convention.
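
As a hypothetical illustration of that convention (not the actual `model_config.py` contents), the stored values are chunk sizes divided by 1024:

```
# Hypothetical illustration of the DIV1024 convention in model_config.py.
def max_prefill_chunk_tokens(chunk_size_div1024: int, max_seq_len: int) -> int:
    # An entry of 32 means a 32 * 1024 = 32k-token prefill chunk; a chunk
    # never needs to exceed the configured max context length.
    return min(chunk_size_div1024 * 1024, max_seq_len)
```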

**Max Context Lengths (Llama3.2-11B multimodal)**: Llama3.2-11B multimodal is currently only supported on N300 and T3000. On N300, a max prefill context length of 8k is supported, while T3000 supports a max context length of 128k.

## How to Run

### Download the weights
### Llama models: download the weights

Download the weights [directly from Meta](https://llama.meta.com/llama-downloads/); this requires accepting their license terms.

@@ -59,17 +62,33 @@ Llama3.2-11B multimodal requires extra python dependencies. Install them from:
```
pip install -r models/demos/llama3/requirements.txt
```

### HuggingFace models (e.g. DeepSeek R1 Distill Llama 3.3 70B)

Download the weights from [HuggingFace](https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Llama-70B). Your model directory should have the following structure:

```
DeepSeek-R1-Distill-Llama-70B/
config.json
generation_config.json
model-00001-of-00062.safetensors
...
```
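
One way to fetch the weights in this layout is with the `huggingface_hub` package (assumed to be installed; any method that produces the structure above works):

```
# Download the sharded checkpoint into a local directory. Note that the
# 70B safetensors shards require a large amount of disk space.
from huggingface_hub import snapshot_download

snapshot_download(
    repo_id="deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
    local_dir="DeepSeek-R1-Distill-Llama-70B",
)
```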

### Setup TT environment

1. Set up environment variables:
```
export LLAMA_DIR=<meta_llama_model_dir>
export LLAMA_DIR=<model_dir>
```

On N150, N300 and T3K:
```
export WH_ARCH_YAML=wormhole_b0_80_arch_eth_dispatch.yaml
```

- `$LLAMA_DIR` sets the path for the Llama3 model weights and caches.

- `$WH_ARCH_YAML` sets the dispatch over ethernet cores. This is optional for N150 and required for N300 and T3000, enabling full core grid utilization (8x8) and maximum performance of Llama3 models.
- `$WH_ARCH_YAML` sets the dispatch over ethernet cores. This is optional for N150 and required for N300 and T3000, enabling full core grid utilization (8x8) and maximum performance of Llama3 models. Do not set this for TG.

On the first execution of each model, TTNN will create weight cache files for that model, to speed up future runs.
These cache files only need to be created once for each model and each weight (i.e. new finetuned weights will need to be cached) and will be stored according to the machine you are running the models on:
@@ -80,14 +99,14 @@
```
$LLAMA_DIR/T3K # For T3000
$LLAMA_DIR/TG # For TG
```


### Run the demo

The Llama3 demo includes four main modes of operation and is fully parametrized to support other configurations.

- `batch-1`: Runs a small prompt for a single user
- `batch-32`: Runs a small prompt for a batch of 32 users
- `long-context`: Runs a large prompt (64k tokens) for a single user
- `reasoning-1`: Runs a reasoning prompt for a single user

If you want to provide your own demo configuration, please take a look at the pytest parametrize calls in `models/demos/llama3/demo/demo.py`. For convenience we list all the supported params below:

