Fix checkpoint conversion scripts and lint errors.
Showing 17 changed files with 1,832 additions and 1,787 deletions.
@@ -0,0 +1,15 @@
""" | ||
Copyright 2023 Google LLC | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
https://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" |
@@ -0,0 +1,167 @@
import argparse
import glob
import os
import pathlib

import jax
import numpy as np
import torch
from safetensors import safe_open

from MaxText.llama_or_mistral_ckpt import permute_to_match_maxtext_rope
from MaxText.llama_mistral_mixtral_orbax_to_hf import unpermute_from_match_maxtext_rope


def load_hf(hf_checkpoint_folder):
  """Loads every *.safetensors shard in a Hugging Face checkpoint folder into one flat dict."""
  safetensor_files = glob.glob(os.path.join(hf_checkpoint_folder, "*.safetensors"))

  hf_tensor = {}
  for st_f in safetensor_files:
    with safe_open(st_f, framework="pt", device="cpu") as f:
      for key in f.keys():
        # Cast to float16 so both sides of the comparison use the same dtype.
        hf_tensor[key] = f.get_tensor(key).to(torch.float16)
  return hf_tensor
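
# Illustrative usage (the folder path is a placeholder): collect all shards into a
# single {weight_name: fp16 tensor} dict and peek at a few weight names.
#   weights = load_hf("/path/to/hf_checkpoint")
#   print(sorted(weights.keys())[:5])

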
def load_meta(meta_checkpoint_folder):
  """Loads every Meta-format *.pth shard in a checkpoint folder into one dict."""
  meta_tensor = {}
  ckpt_paths = sorted(pathlib.Path(meta_checkpoint_folder).glob("[!.]*.pth"))
  for ckpt_path in ckpt_paths:
    # Merge each shard's state dict instead of overwriting the previous one
    # (keys are assumed to be disjoint across shards).
    meta_tensor.update(torch.load(ckpt_path, map_location="cpu"))
  return meta_tensor


def compare_pytrees(tree1, tree2, atol=0.001):
  """
  Compares two JAX pytrees to check whether all leaf values are within the given absolute tolerance.

  Args:
    tree1: First pytree.
    tree2: Second pytree.
    atol: Absolute tolerance for comparison (default: 0.001).

  Returns:
    A boolean indicating whether all leaf values are within the specified tolerance.
  """
  # Ensure both trees have the same structure.
  if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2):
    print(
        f"Pytrees have different structures! Tree1: {jax.tree_util.tree_structure(tree1)} "
        f"\n\n\nTree2: {jax.tree_util.tree_structure(tree2)}"
    )
    return False

  # Flatten nested dicts into {dotted.path: leaf} so mismatches can be reported by name.
  def get_named_leaves(pytree, parent_key=""):
    named_leaves = {}
    for key, value in pytree.items():
      new_key = f"{parent_key}.{key}" if parent_key else key
      if isinstance(value, dict):
        named_leaves.update(get_named_leaves(value, new_key))
      else:
        named_leaves[new_key] = value
    return named_leaves

  named_leaves1 = get_named_leaves(tree1)
  named_leaves2 = get_named_leaves(tree2)

  for key in named_leaves1:
    if key not in named_leaves2:
      print(f"Missing key in second tree: {key}")
      return False
    try:
      if not np.allclose(named_leaves1[key], named_leaves2[key], atol=atol):
        # Print the first 10 rows along the second-to-last axis of each tensor for inspection.
        print(f"Mismatch at leaf '{key}' with shape {named_leaves1[key].shape}:\n")
        for i in range(10):
          print(f"{named_leaves1[key][..., i, :]}\n")
        print("The second tensor:\n")
        for i in range(10):
          print(f"{named_leaves2[key][..., i, :]}\n")
        return False
    except Exception as e:  # pylint: disable=broad-except
      print(f"The issue is with {key}: {e}")
      return False

  print("All leaves match within tolerance.")
  return True
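
# A minimal, self-contained check (illustrative; the dict names are arbitrary): two
# nested dicts whose leaves differ by less than atol should report a match.
#   _t1 = {"layer": {"w": np.zeros((2, 2), dtype=np.float16)}}
#   _t2 = {"layer": {"w": np.full((2, 2), 5e-4, dtype=np.float16)}}
#   assert compare_pytrees(_t1, _t2, atol=1e-3)

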
def test_huggingface_to_maxtext_back_to_huggingface_flow():
  """Round-trips a query projection through the MaxText RoPE permutation and back."""
  base_num_query_heads = 16
  head_dim = 32
  wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape(
      base_num_query_heads * head_dim, base_num_query_heads * head_dim
  )
  wq1 = wq.transpose()
  wq2 = np.reshape(wq1, [base_num_query_heads * head_dim, base_num_query_heads, head_dim])

  wq3 = permute_to_match_maxtext_rope(wq2)
  stack_shape = (1,)
  x = np.zeros(stack_shape + wq3.shape, dtype=np.float16)
  x[0, ...] = wq3
  x = np.transpose(x, axes=(1, 0, 2, 3))

  x = x[:, 0, :, :]
  wq4 = unpermute_from_match_maxtext_rope(x, "llama3.1")
  wq5 = wq4.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim)
  wq6 = wq5.transpose()

  assert np.array_equal(wq, wq6), "wq does not match wq6"
  assert np.array_equal(wq1, wq5), "wq1 does not match wq5"
  assert np.array_equal(wq2, wq4), "wq2 does not match wq4"
def main():
  parser = argparse.ArgumentParser(description="Compares the original checkpoint and the converted-back checkpoint.")
  parser.add_argument(
      "--original_ckpt",
      type=str,
      default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B",
      help="The original Hugging Face checkpoint",
  )
  parser.add_argument(
      "--converted_ckpt",
      type=str,
      default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B/",
      help="The converted-back Hugging Face checkpoint",
  )
  args = parser.parse_args()

  # Both checkpoints are in Hugging Face safetensors format, so both sides use load_hf.
  hf_tensor = load_hf(args.original_ckpt)
  converted_tensor = load_hf(args.converted_ckpt)

  compare_pytrees(hf_tensor, converted_tensor)


if __name__ == "__main__":
  main()
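
A typical invocation, sketched with placeholder paths and a placeholder script name (the diff view does not show this file's location in the repo):

python <this_script>.py --original_ckpt /path/to/original_hf_checkpoint --converted_ckpt /path/to/converted_hf_checkpoint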