Fix checkpoint conversion scripts and lint errors.
wang2yn84 committed Feb 20, 2025
1 parent cc616c4 commit b8222de
Showing 17 changed files with 1,832 additions and 1,787 deletions.
56 changes: 32 additions & 24 deletions MaxText/llama_or_mistral_ckpt.py
@@ -167,26 +167,27 @@ def _hf_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
f"layers.{layer_idx}.feed_forward.w3.weight": f"model.layers.{layer_idx}.mlp.up_proj.weight",
}


def _hf_to_maxtext_mapping(layer_idx: int = -1, expert_idx: int = -1) -> dict:
  # pylint: disable=line-too-long
  return {
      "model.embed_tokens.weight": "tok_embeddings.weight",
      "model.norm.weight": "norm.weight",
      "lm_head.weight": "output.weight",
      f"model.layers.{layer_idx}.input_layernorm.weight": f"layers.{layer_idx}.attention_norm.weight",
      f"model.layers.{layer_idx}.post_attention_layernorm.weight": f"layers.{layer_idx}.ffn_norm.weight",
      f"model.layers.{layer_idx}.self_attn.q_proj.weight": f"layers.{layer_idx}.attention.wq.weight",
      f"model.layers.{layer_idx}.self_attn.k_proj.weight": f"layers.{layer_idx}.attention.wk.weight",
      f"model.layers.{layer_idx}.self_attn.v_proj.weight": f"layers.{layer_idx}.attention.wv.weight",
      f"model.layers.{layer_idx}.self_attn.o_proj.weight": f"layers.{layer_idx}.attention.wo.weight",
      # MOE model
      f"model.layers.{layer_idx}.block_sparse_moe.gate.weight": f"layers.{layer_idx}.feed_forward.gate.weight",
      f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w1.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w1.weight",
      f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w2.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w2.weight",
      f"model.layers.{layer_idx}.block_sparse_moe.experts.{expert_idx}.w3.weight": f"layers.{layer_idx}.feed_forward.experts.{expert_idx}.w3.weight",
      f"model.layers.{layer_idx}.mlp.gate_proj.weight": f"layers.{layer_idx}.feed_forward.w1.weight",
      f"model.layers.{layer_idx}.mlp.down_proj.weight": f"layers.{layer_idx}.feed_forward.w2.weight",
      f"model.layers.{layer_idx}.mlp.up_proj.weight": f"layers.{layer_idx}.feed_forward.w3.weight",
  }
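As a quick illustration, the mapping can be queried with a concrete layer index (index chosen arbitrarily here):

mapping = _hf_to_maxtext_mapping(layer_idx=3)
print(mapping["model.layers.3.self_attn.q_proj.weight"])  # -> layers.3.attention.wq.weight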


@@ -220,6 +221,7 @@ def permute_to_match_maxtext_rope(arr):
  x[..., 1::2] = odds
  return x
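The visible tail suggests the permutation interleaves two halves of the last axis into even and odd positions; a minimal sketch of that pattern (the lines computing evens and odds are not shown in this hunk, so the half-split is an assumption):

import numpy as np

arr = np.arange(4, dtype=np.float16)  # [0, 1, 2, 3]
evens, odds = arr[:2], arr[2:]        # assumed split into halves
x = np.empty_like(arr)
x[..., ::2] = evens
x[..., 1::2] = odds                   # matches the assignment above
print(x)                              # [0, 2, 1, 3]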


def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ckpt, model_params, mem_info):
  base_num_decoder_layers = model_params["num_layers"]
  base_num_query_heads = model_params["num_heads"]
@@ -239,7 +241,7 @@ def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_
parts = key.split(".")
layer = int(parts[2]) if "layers" in key else 0
mapped_key = _hf_to_maxtext_mapping(layer)[key]
chkpt_vars[mapped_key]=f.get_tensor(key)
chkpt_vars[mapped_key] = f.get_tensor(key)

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

@@ -269,13 +271,14 @@ def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_
  # logits dense #################################################
  max_logging.log("Processing logits dense")

  jax_weights["decoder"]["logits_dense"]["kernel"] = (
      chkpt_vars["output.weight"].to(torch.float16).numpy().transpose()[:, :vocab_size]
  )

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
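A shape-level sketch of why the transpose and slice are applied (the concrete dimensions are illustrative assumptions; HuggingFace stores lm_head.weight as (vocab, hidden), while the MaxText logits_dense kernel is laid out (hidden, vocab)):

w = torch.zeros(32000, 4096)                      # hypothetical lm_head.weight, (vocab, hidden)
kernel = w.to(torch.float16).numpy().transpose()  # (hidden, vocab), the layout logits_dense expects
print(kernel[:, :32000].shape)                    # (4096, 32000) after slicing to vocab_size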

# token embedding ##############################################
@@ -291,7 +294,9 @@ def convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_
if model_size[:6] == "llama3":
jax_weights["token_embedder"]["embedding"] = chkpt_vars["tok_embeddings.weight"].to(torch.float16).numpy()
else:
jax_weights["token_embedder"]["embedding"] = chkpt_vars["tok_embeddings.weight"].to(torch.float16).numpy()[:vocab_size, :]
jax_weights["token_embedder"]["embedding"] = (
chkpt_vars["tok_embeddings.weight"].to(torch.float16).numpy()[:vocab_size, :]
)

  logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

@@ -522,7 +527,7 @@ def convert_to_jax_weights(base_model_path, model_size, huggingface_ckpt):
max_logging.log(f"Loading the base model from {base_model_path}")
# Skip any hidden files for checkpoints
if huggingface_ckpt:
return convert_huggingface_to_jax_weights( base_model_path, model_size, huggingface_ckpt, model_params, mem_info)
return convert_huggingface_to_jax_weights(base_model_path, model_size, huggingface_ckpt, model_params, mem_info)
chkpt_vars = {}
checkpoint = {}
ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.pth"))
@@ -831,6 +836,7 @@ def checkpoint_device_put(arr):
  # Upon preemption, exit when and only when all ongoing saves are complete.
  checkpoint_manager.wait_until_finished()


if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument("--base-model-path", type=str, required=True)
@@ -843,4 +849,6 @@ def checkpoint_device_put(arr):
    raise NotImplementedError

os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={SIMULATED_CPU_DEVICES_COUNT}"
save_jax_weights_to_checkpoint(args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size, args.huggingface_checkpoint))
save_jax_weights_to_checkpoint(
args.maxtext_model_path, convert_to_jax_weights(args.base_model_path, args.model_size, args.huggingface_checkpoint)
)
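For reference, a hypothetical end-to-end invocation of this script (only --base-model-path appears in this hunk; the other flag spellings are assumed from the args attributes used above, and the paths and model size are placeholders):

python MaxText/llama_or_mistral_ckpt.py \
    --base-model-path /path/to/source/ckpt \
    --maxtext-model-path gs://bucket/maxtext/ckpt \
    --model-size llama2-7b \
    --huggingface-checkpoint true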
15 changes: 15 additions & 0 deletions MaxText/tests/__init__.py
@@ -0,0 +1,15 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
167 changes: 167 additions & 0 deletions MaxText/tests/hf_checkpoint_conversion_checker.py
@@ -0,0 +1,167 @@
import argparse
import glob
import os
import pathlib

import jax
import numpy as np
import torch
from safetensors import safe_open

from MaxText.llama_or_mistral_ckpt import permute_to_match_maxtext_rope
from MaxText.llama_mistral_mixtral_orbax_to_hf import unpermute_from_match_maxtext_rope


def load_hf(hf_checkpoint_folder):
  """Loads every *.safetensors shard in a folder into a dict of float16 torch tensors."""
  safetensor_files = glob.glob(os.path.join(hf_checkpoint_folder, "*.safetensors"))

  hf_tensor = {}
  for st_f in safetensor_files:
    with safe_open(st_f, framework="pt", device="cpu") as f:
      for key in f.keys():
        hf_tensor[key] = f.get_tensor(key).to(torch.float16)
  return hf_tensor
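A minimal usage sketch (the folder path is a placeholder):

hf_tensor = load_hf("/path/to/hf_checkpoint")  # reads every *.safetensors shard in the folder
for name, tensor in list(hf_tensor.items())[:3]:
  print(name, tuple(tensor.shape), tensor.dtype)  # all values are torch.float16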


def load_meta(meta_checkpoint_folder):
  """Loads the Meta-format *.pth checkpoint(s) in a folder into a dict of torch tensors."""
  meta_tensor = {}
  ckpt_paths = sorted(pathlib.Path(meta_checkpoint_folder).glob("[!.]*.pth"))
  for ckpt_path in ckpt_paths:
    meta_tensor = torch.load(ckpt_path, map_location="cpu")
  return meta_tensor


def compare_pytrees(tree1, tree2, atol=0.001):
  """
  Compares two JAX pytrees to check if all leaf values are within the given absolute tolerance.

  Args:
    tree1: First pytree.
    tree2: Second pytree.
    atol: Absolute tolerance for comparison (default: 0.001).

  Returns:
    A boolean indicating if all leaf values are within the specified range.
  """
  # Ensure both trees have the same structure
  if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2):
    print(
        f"Pytrees have different structures! Tree1: {jax.tree_util.tree_structure(tree1)} "
        f"\n\n\nTree2: {jax.tree_util.tree_structure(tree2)}"
    )
    return False

  # Flatten nested dicts into {dotted.path: leaf} so mismatches can be reported by name.
  def get_named_leaves(pytree, parent_key=""):
    named_leaves = {}
    for key, value in pytree.items():
      new_key = f"{parent_key}.{key}" if parent_key else key
      if isinstance(value, dict):
        named_leaves.update(get_named_leaves(value, new_key))
      else:
        named_leaves[new_key] = value
    return named_leaves

  named_leaves1 = get_named_leaves(tree1)
  named_leaves2 = get_named_leaves(tree2)

  for key in named_leaves1:
    if key not in named_leaves2:
      print(f"Missing key in second tree: {key}")
      return False
    try:
      if not np.allclose(named_leaves1[key], named_leaves2[key], atol=atol):
        # Print the first few rows of both tensors to make the mismatch easier to inspect.
        print(f"Mismatch at leaf '{key}' with shape {named_leaves1[key].shape}:\n")
        for i in range(10):
          print(f"{named_leaves1[key][..., i, :]}\n")
        print("The second tensor:\n")
        for i in range(10):
          print(f"{named_leaves2[key][..., i, :]}\n")
        return False
    except Exception as e:  # pylint: disable=broad-except
      print(f"The issue is with {key}: {e}")
      return False

  print("All leaves match within tolerance.")
  return True
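A self-contained sketch of the comparison behavior (shapes and values chosen for illustration; 5e-4 is within the default atol of 0.001):

lhs = {"decoder": {"kernel": np.zeros((4, 4), dtype=np.float16)}}
rhs = {"decoder": {"kernel": np.full((4, 4), 5e-4, dtype=np.float16)}}
assert compare_pytrees(lhs, rhs)  # prints "All leaves match within tolerance."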


def test_huggingface_to_maxtext_back_to_huggingface_flow():
  """Round-trips a query projection through the MaxText RoPE permutation and back."""
  base_num_query_heads = base_num_kv_heads = 16
  head_dim = 32
  wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape(
      base_num_query_heads * head_dim, base_num_query_heads * head_dim
  )
  wq1 = wq.transpose()
  wq2 = np.reshape(wq1, [base_num_query_heads * head_dim, base_num_query_heads, head_dim])

  wq3 = permute_to_match_maxtext_rope(wq2)
  stack_shape = (1,)
  x = np.zeros(stack_shape + wq3.shape, dtype=np.float16)
  x[0, ...] = wq3
  x = np.transpose(x, axes=(1, 0, 2, 3))

  x = x[:, 0, :, :]
  wq4 = unpermute_from_match_maxtext_rope(x, "llama3.1")
  wq5 = wq4.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim)
  wq6 = wq5.transpose()

  if not np.array_equal(wq, wq6):
    print("Test failed: wq does not match wq6")

  if not np.array_equal(wq1, wq5):
    print("Test failed: wq1 does not match wq5")

  if not np.array_equal(wq2, wq4):
    print("Test failed: wq2 does not match wq4")


def main():
  parser = argparse.ArgumentParser(description="Compares the original checkpoint and converted back checkpoint.")
  parser.add_argument(
      "--original_ckpt",
      type=str,
      default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B",
      help="The original huggingface checkpoint",
  )
  parser.add_argument(
      "--converted_ckpt",
      type=str,
      default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B/",
      help="The converted-back huggingface checkpoint",
  )
  args = parser.parse_args()

  hf_tensor = load_hf(args.original_ckpt)
  # The converted checkpoint is written back in huggingface format, so it is loaded the same way.
  converted_tensor = load_hf(args.converted_ckpt)

  compare_pytrees(hf_tensor, converted_tensor)


if __name__ == "__main__":
  main()
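A typical invocation, using the two flags defined in main() (paths are placeholders):

python MaxText/tests/hf_checkpoint_conversion_checker.py \
    --original_ckpt /path/to/original/hf/ckpt \
    --converted_ckpt /path/to/converted/hf/ckpt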
