Skip to content

Commit

Permalink
Revert lint only changes. Polish the conversion script.
Browse files Browse the repository at this point in the history
  • Loading branch information
wang2yn84 committed Feb 20, 2025
1 parent b8222de commit 36bdf05
Show file tree
Hide file tree
Showing 14 changed files with 1,629 additions and 1,618 deletions.
24 changes: 4 additions & 20 deletions MaxText/tests/hf_checkpoint_conversion_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@
import torch
from MaxText.llama_or_mistral_ckpt import permute_to_match_maxtext_rope
from MaxText.llama_mistral_mixtral_orbax_to_hf import unpermute_from_match_maxtext_rope
# from .. import llama_or_mistral_ckpt
# from .. import llama_mistral_mixtral_orbax_to_hf
# import llama_or_mistral_ckpt
# import llama_mistral_mixtral_orbax_to_hf
import sys
import jax
from jax.sharding import Mesh
Expand All @@ -32,7 +28,6 @@ def load_hf(hf_checkpoint_folder):
with safe_open(st_f, framework="pt", device="cpu") as f:
for key in f.keys():
hf_tensor[key] = f.get_tensor(key).to(torch.float16)
# print(f"Weight name {key}, Shape: {hf_tensor.shape}, dtype: {hf_tensor[key].dtype}")
return hf_tensor


Expand All @@ -41,8 +36,6 @@ def load_meta(meta_checkpoint_folder):
ckpt_paths = sorted(pathlib.Path(meta_checkpoint_folder).glob("[!.]*.pth"))
for i, ckpt_path in enumerate(ckpt_paths):
meta_tensor = torch.load(ckpt_path, map_location="cpu")
# chkpt_vars[int(ckpt_path.name.split(".", maxsplit=2)[1])] = checkpoint
# chkpt_vars = [chkpt_vars[i] for i in sorted(list(chkpt_vars.keys()))]
return meta_tensor


Expand Down Expand Up @@ -86,12 +79,8 @@ def get_named_leaves(pytree, parent_key=""):
return False
try:
if not np.allclose(named_leaves1[key], named_leaves2[key], atol=atol):
# print(f"Mismatch at leaf '{key}':\n{named_leaves1[key]}\n{named_leaves2[key]}")
# return False
# print(f"Mismatch at leaf '{key}'")
mismatch_values1 = named_leaves1[key].flatten()[:10]
mismatch_values2 = named_leaves2[key].flatten()[:10]
# print(f"Mismatch at leaf '{key}':\nFirst 10 elements:\n{mismatch_values1}\n{mismatch_values2}")
print(f"Mismatch at leaf '{key}' with shape {named_leaves1[key].shape}:\n")
for i in range(10):
print(f"{named_leaves1[key][..., i, :]}\n")
Expand All @@ -101,7 +90,6 @@ def get_named_leaves(pytree, parent_key=""):
return
except:
print(f"The issue is with {key}")
# print(f"Checking {key} done")

print("All leaves match within tolerance.")
return True
Expand All @@ -110,7 +98,6 @@ def get_named_leaves(pytree, parent_key=""):
def test_huggingface_to_maxtext_back_to_huggingface_flow():
base_num_query_heads = base_num_kv_heads = 16
head_dim = 32
# import pdb; pdb.set_trace()
wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape(
base_num_query_heads * head_dim, base_num_query_heads * head_dim
)
Expand Down Expand Up @@ -143,22 +130,19 @@ def main():
parser.add_argument(
"--original_ckpt",
type=str,
default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B",
default="",
help="The original huggingface checkpoint",
)
parser.add_argument(
"--converted_ckpt",
type=str,
default="/mnt/disks/persist/checkpoints/huggingface/DeepSeek-R1-Distill-Llama-8B/",
default="",
help="The original huggingface checkpoint",
)
args = parser.parse_args()

hf_checkpoint_folder = args.original_ckpt
hf_tensor = load_hf(hf_checkpoint_folder)

meta_checkpoint_folder = args.converted_ckpt
meta_tensor = load_hf(meta_checkpoint_folder)
hf_tensor = load_hf(args.original_ckpt)
meta_tensor = load_hf(args.converted_ckpt)

compare_pytrees(hf_tensor, meta_tensor)

Expand Down
Loading

0 comments on commit 36bdf05

Please sign in to comment.