Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add example conversion script to convert hf to consolidated weight format #319

Merged
merged 5 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions examples/hf_llama_conversion/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Convert Hugging Face llama weights to official llama consolidated format

This is the reverse conversion for `convert_llama_weights_to_hf.py` script from the transformer package.

## Step 0: Convert to consolidated format
- Create an output directory for the converted weights, such as `test70B`.
- Copy file params.json from the official llama download into that directory.
- Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
```
python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
```

## Step 1: Run inference
Checkout the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
```
torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
```

For validation, please compare the converted weights with official llama 2 weights
```
python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
```
48 changes: 48 additions & 0 deletions examples/hf_llama_conversion/compare_llama_weights.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import gc
import glob
import os
import sys

import torch
import tqdm


def main() -> None:
"""Compare two llama checkpoint directories"""

one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
assert len(one_files) == len(
two_files
), "One directory has {} files while another has {} files.".format(
len(one_files), len(two_files)
)

deltas = []
for i in tqdm.trange(len(one_files), desc="Comparing shards"):
one = torch.load(one_files[i])
two = torch.load(two_files[i])
assert len(one) == len(
two
), "shard should have the same length: {} != {}".format(len(one), len(two))

for _, (v, w) in enumerate(zip(one.items(), two.items())):
assert v[0] == w[0], "{} != {}".format(v[0], w[0])
assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
v[0], v[1].shape, w[1].shape
)

delta = (v[1] - w[1]).abs().max().item()
deltas.append((i, v[0], delta))
del one
del two
gc.collect()

deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
print("Top 10 largest deltas:")
for i, k, v in deltas[:10]:
print(f" shard {i} {k}: {v}")


if __name__ == "__main__":
main()
163 changes: 163 additions & 0 deletions src/llama_recipes/tools/convert_hf_weights_to_llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
import json
import os
from typing import List, Union

import fire
import torch
from tqdm import tqdm
from transformers import LlamaForCausalLM # @manual

NUM_SHARDS = {
"7B": 1,
"13B": 2,
"34B": 4,
"30B": 4,
"65B": 8,
"70B": 8,
}


def write_model(model_path, model_size, output_base_path):
dtype = torch.bfloat16

params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
num_shards = NUM_SHARDS[model_size]
n_layers = params["n_layers"]
n_heads = params["n_heads"]
n_heads_per_shard = n_heads // num_shards
dim = params["dim"]
dims_per_head = dim // n_heads
base = 10000.0
inv_freq = (
1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
).to(dtype)

if "n_kv_heads" in params:
num_key_value_heads = params["n_kv_heads"] # for GQA / MQA
num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
key_value_dim = dim // num_key_value_heads
else: # compatibility with other checkpoints
num_key_value_heads = n_heads
num_local_key_value_heads = n_heads_per_shard
key_value_dim = dim

model = LlamaForCausalLM.from_pretrained(
model_path,
torch_dtype=dtype,
low_cpu_mem_usage=True,
)
loaded = model.state_dict()

# permute for sliced rotary
def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
return (
w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
.transpose(1, 2)
.reshape(dim1, dim2)
)

state_dict = [{} for _ in range(num_shards)]

def insert(name: str, tensor: Union[List, torch.Tensor]):
for i in range(num_shards):
state_dict[i][name] = (
tensor[i].clone() if isinstance(tensor, list) else tensor
)

def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
tensors = tensor.chunk(num_shards, dim=dim)
for i, tensor in enumerate(tensors):
state_dict[i][name] = tensor.clone()

insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
insert("norm.weight", loaded["model.norm.weight"])
insert_chunk("output.weight", loaded["lm_head.weight"], 0)

for layer_i in tqdm(range(n_layers), desc="Converting layers"):

ts = (
permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
.view(n_heads_per_shard * num_shards, dims_per_head, dim)
.chunk(num_shards, dim=0)
)
insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])

ts = (
permute(
loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
num_key_value_heads,
key_value_dim,
dim,
)
.view(num_local_key_value_heads * num_shards, dims_per_head, dim)
.chunk(num_shards, dim=0)
)
insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])

ts = (
loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
.view(num_local_key_value_heads * num_shards, dims_per_head, dim)
.chunk(num_shards, dim=0)
)
insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])

insert_chunk(
f"layers.{layer_i}.attention.wo.weight",
loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
1,
)

insert_chunk(
f"layers.{layer_i}.feed_forward.w1.weight",
loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
0,
)

insert_chunk(
f"layers.{layer_i}.feed_forward.w2.weight",
loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
1,
)

insert_chunk(
f"layers.{layer_i}.feed_forward.w3.weight",
loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
0,
)

insert(
f"layers.{layer_i}.attention_norm.weight",
loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
)
insert(
f"layers.{layer_i}.ffn_norm.weight",
loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
)
insert("rope.freqs", inv_freq)

for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
torch.save(
state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
)


def main(
model_path: str,
model_size: str,
output_dir: str,
):
"""Convert llama weights from huggingface format to consolidated format.
params:
model_path: model name or path to the model directory.
model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
output_dir: directory to save Llama weights, should contains params.json.
"""
assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
params_path = os.path.join(output_dir, "params.json")
assert os.path.isfile(params_path), f"{params_path} does not exist"

write_model(model_path, model_size, output_dir)


if __name__ == "__main__":
fire.Fire(main)
Loading