Revert the revert "Merge pull request #1300 from AI-Hypercomputer:revert-1291-lance-deepseek" and fix all the lint issues.

This reverts commit 02b9b8d, reversing
changes made to 7c9e7ba.
wang2yn84 committed Feb 22, 2025
1 parent 02b9b8d commit 3228831
Showing 11 changed files with 813 additions and 50 deletions.
3 changes: 2 additions & 1 deletion MaxText/configs/base.yml
@@ -59,6 +59,7 @@ checkpoint_storage_use_zarr3: True

reuse_example_batch: 0 # for testing TPU performance, this option repeatedly uses the same batch.


metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
gcs_metrics: False
@@ -277,7 +278,7 @@ logical_axis_rules: [
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
17 changes: 3 additions & 14 deletions MaxText/llama_mistral_mixtral_orbax_to_hf.py
@@ -22,9 +22,9 @@
python3 MaxText/llama_or_mistral_ckpt.py --base-model-path <path/to/meta/ckpt> \
--maxtext-model-path <GCS/path/to/save/new/maxtext/ckpt> --model-size llama2-7b
python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml
python3 MaxText/llama_mistral_mixtral_orbax_to_hf.py MaxText/configs/base.yml
base_output_directory=path/to/saving/intermediate_MaxText_files
load_parameters_path=/path/to/MaxText/checkpoint run_name=<your run name> model_name=<llama2 or mistral>
load_parameters_path=/path/to/MaxText/checkpoint run_name=<your run name> model_name=<llama2 or mistral>
hardware=gpu
hf_model_path=/local/path/to/save/HF/model/to
@@ -39,24 +39,13 @@
import numpy as np
import pyconfig
import max_utils
import jax
from jax.sharding import Mesh
import max_logging
import checkpointing
from generate_param_only_checkpoint import _read_train_checkpoint
import llama_or_mistral_ckpt
from transformers import LlamaForCausalLM, MistralForCausalLM, AutoModelForCausalLM, AutoConfig


def unpermute_from_match_maxtext_rope(arr, model_size):
"""
Function to get the RoPE values in correct ordering
"""
if model_size[:8] != "llama3.1":
return arr
evens = arr[..., ::2]
odds = arr[..., 1::2]
return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
from max_utils import unpermute_from_match_maxtext_rope


def reverse_scale(arr, scale):
320 changes: 294 additions & 26 deletions MaxText/llama_or_mistral_ckpt.py

Large diffs are not rendered by default.

21 changes: 21 additions & 0 deletions MaxText/max_utils.py
@@ -1154,3 +1154,24 @@ def print_system_information():
max_logging.log(f"System Information: Jax Version: {jax.__version__}")
max_logging.log(f"System Information: Jaxlib Version: {jax.lib.__version__}")
max_logging.log(f"System Information: Jax Backend: {jax.lib.xla_bridge.get_backend().platform_version}")


def permute_to_match_maxtext_rope(arr):
"""Permutes the Huggingface Rope to match the MaxText logic."""
assert arr.shape[-1] % 2 == 0, "The last dimension for rope has to be even."
evens, odds = np.split(arr, 2, axis=arr.ndim - 1) # pylint: disable=W0632
x = np.empty_like(arr)
x[..., ::2] = evens
x[..., 1::2] = odds
return x


def unpermute_from_match_maxtext_rope(arr, model_size):
"""
Function to get the RoPE values in correct ordering
"""
if model_size[:8] != "llama3.1":
return arr
evens = arr[..., ::2]
odds = arr[..., 1::2]
return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1)
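
As a sanity check, the interleave done by permute_to_match_maxtext_rope is undone by unpermute_from_match_maxtext_rope on the llama3.1 path. A minimal numpy-only sketch of that round trip (the helper names and toy array below are illustrative only, not part of max_utils):

import numpy as np

def permute(arr):
  # HF layout [first_half | second_half] -> interleaved [e0, o0, e1, o1, ...]
  evens, odds = np.split(arr, 2, axis=-1)
  out = np.empty_like(arr)
  out[..., ::2] = evens
  out[..., 1::2] = odds
  return out

def unpermute(arr):
  # interleaved -> back to [first_half | second_half]
  return np.concatenate((arr[..., ::2], arr[..., 1::2]), axis=-1)

w = np.arange(16, dtype=np.float32).reshape(2, 8)
assert np.array_equal(unpermute(permute(w)), w)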

Large diffs are not rendered by default.

Large diffs are not rendered by default.

17 changes: 8 additions & 9 deletions MaxText/tests/forward_pass_logit_checker.py
@@ -77,7 +77,7 @@ def get_data(golden_data, golden_data_index, config):
return ids, decoder_segment_ids, decoder_positions, logits


def main(config, test_args):
def main(config, test_args): # pylint: disable=W0621
"""Test the Whole Model of model_name"""

# initialize the model with weights from reference ckpt
@@ -118,9 +118,10 @@ def main(config, test_args):
max_logging.log(f"{golden_logits[2]=}")
max_logging.log(f"{full_train_logits[0, 2, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
# The ellipsis is used to currently support jax nightly versions newer than 1/9/2025 and stable tests. This can be simplified later
# The ellipsis is used to currently support jax nightly versions newer than
# 1/9/2025 and stable tests. This can be simplified later
max_logging.log(
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}"
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}" # pylint: disable=C0301
)

model_probabilities = jax.nn.softmax(full_train_logits[..., 0, :token_size, :], axis=-1)
@@ -133,19 +133,17 @@ def main(config, test_args):
max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")

if test_args.max_kl_div is not None:
max_logging.log("Checking KL Divergence between train distribution and golden distribution")
assert jax.numpy.all(
kl_div < test_args.max_kl_div
), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"
max_logging.log("Checking KL Divergence between train distribution and " "golden distribution")
assert jax.numpy.all(kl_div < test_args.max_kl_div), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}" # pylint: disable=C0301
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits")
max_logging.log("Checking Numerical Differences between train logits and golden logits") # pylint: disable=C0301
assert jax.numpy.allclose(
full_train_logits[..., 0, :token_size, :],
golden_logits[:token_size, :],
rtol=float(test_args.rtol),
atol=float(test_args.atol),
equal_nan=False,
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}." # pylint: disable=C0301


if __name__ == "__main__":
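
For context, here is a toy, self-contained sketch of the KL-divergence check exercised above; the logits and the 0.05 threshold are made up for illustration (the real test compares full_train_logits against golden_logits using test_args.max_kl_div):

import jax
import jax.numpy as jnp

golden_logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
train_logits = jnp.array([[1.9, 0.6, -1.1], [0.1, 0.25, 0.3]])

p = jax.nn.softmax(golden_logits, axis=-1)  # reference distribution per token
q = jax.nn.softmax(train_logits, axis=-1)   # distribution under test
kl_div = jnp.sum(p * (jnp.log(p) - jnp.log(q)), axis=-1)  # per-token KL(p || q)

max_kl_div = 0.05
assert jnp.all(kl_div < max_kl_div), f"max KL {jnp.max(kl_div)} exceeds {max_kl_div}"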
124 changes: 124 additions & 0 deletions MaxText/tests/hf_checkpoint_conversion_checker.py
@@ -0,0 +1,124 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import glob
import os
import torch
from safetensors import safe_open
import pathlib
import jax
import numpy as np
import argparse


def load_hf(hf_checkpoint_folder):
safetensor_files = glob.glob(os.path.join(hf_checkpoint_folder, "*.safetensors"))

hf_tensor = {}
for st_f in safetensor_files:
with safe_open(st_f, framework="pt", device="cpu") as f:
for key in f.keys():
hf_tensor[key] = f.get_tensor(key).to(torch.float16)
return hf_tensor


def load_meta(meta_checkpoint_folder):
meta_tensor = {}
ckpt_paths = sorted(pathlib.Path(meta_checkpoint_folder).glob("[!.]*.pth"))
for ckpt_path in ckpt_paths:
meta_tensor = torch.load(ckpt_path, map_location="cpu")
return meta_tensor


def compare_pytrees(tree1, tree2, atol=0.001):
"""
Compares two JAX pytrees to check if all leaf values are within the given absolute tolerance.
Args:
tree1: First pytree.
tree2: Second pytree.
atol: Absolute tolerance for comparison (default: 0.001).
Returns:
None. Match results and any mismatches are printed rather than returned.
"""
# Ensure both trees have the same structure
if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2):
print(
"Pytrees have different structures! Tree1:"
f"{jax.tree_util.tree_structure(tree1)} \n\n\n"
f"Tree2: {jax.tree_util.tree_structure(tree2)}"
)
return

# Compare leaves with names
def get_named_leaves(pytree, parent_key=""):
named_leaves = {}
for key, value in pytree.items():
new_key = f"{parent_key}.{key}" if parent_key else key
if isinstance(value, dict):
named_leaves.update(get_named_leaves(value, new_key))
else:
named_leaves[new_key] = value
return named_leaves

named_leaves1 = get_named_leaves(tree1)
named_leaves2 = get_named_leaves(tree2)

print(f"There are {len(named_leaves1.keys())} leaves to check.")
for key in named_leaves1: # pylint: disable=C0206
if key not in named_leaves2:
print(f"Missing key in second tree: {key}")
return
try:
if not np.allclose(named_leaves1[key], named_leaves2[key], atol=atol):
print(f"Mismatch at leaf '{key}' with shape {named_leaves1[key].shape}:\n")
for i in range(10):
print(f"{named_leaves1[key][..., i, :]}\n")
print("The second tensor:\n")
for i in range(10):
print(f"{named_leaves2[key][..., i, :]}\n")
return
except: # pylint: disable=W0702
print(f"The issue is with {key}")

print(f"All {len(named_leaves1.keys())} leaves match within tolerance.")


def main():
parser = argparse.ArgumentParser(description="Compares the original checkpoint and converted back checkpoint.")
parser.add_argument(
"--original_ckpt",
type=str,
default="",
help="The original huggingface checkpoint",
)
parser.add_argument(
"--converted_ckpt",
type=str,
default="",
help="The original huggingface checkpoint",
)
args = parser.parse_args()

hf_tensor = load_hf(args.original_ckpt)
meta_tensor = load_hf(args.converted_ckpt)

compare_pytrees(hf_tensor, meta_tensor)


if __name__ == "__main__":
main()
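
For reference, a small usage sketch of compare_pytrees on toy tensors; the dicts and the import path below are assumptions for illustration, and the script itself is normally invoked with --original_ckpt and --converted_ckpt pointing at the two checkpoint folders:

import numpy as np
from tests.hf_checkpoint_conversion_checker import compare_pytrees  # assumed import path

tree_a = {"layer_0": {"wq": np.zeros((2, 4), dtype=np.float32)}}
tree_b = {"layer_0": {"wq": np.full((2, 4), 5e-4, dtype=np.float32)}}
compare_pytrees(tree_a, tree_b, atol=1e-3)  # prints that the single leaf matches within tolerance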
57 changes: 57 additions & 0 deletions MaxText/tests/hf_checkpoint_conversion_test.py
@@ -0,0 +1,57 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

""" Tests for kernels """

import numpy as np
from max_utils import permute_to_match_maxtext_rope, unpermute_from_match_maxtext_rope
import unittest


class HFCheckpointConversionTest(unittest.TestCase):

def test_huggingface_to_maxtext_back_to_huggingface_flow(self):
base_num_query_heads = 16
head_dim = 32
wq = np.arange(base_num_query_heads * head_dim * base_num_query_heads * head_dim, dtype=np.float16).reshape(
base_num_query_heads * head_dim, base_num_query_heads * head_dim
)
wq1 = wq.transpose()
wq2 = np.reshape(wq1, [base_num_query_heads * head_dim, base_num_query_heads, head_dim])

wq3 = permute_to_match_maxtext_rope(wq2)
stack_shape = (1,)
x = np.zeros(stack_shape + wq3.shape, dtype=np.float16)
x[0, ...] = wq3
x = np.transpose(x, axes=(1, 0, 2, 3))

x = x[:, 0, :, :]
wq4 = unpermute_from_match_maxtext_rope(x, "llama3.1")
wq5 = wq4.reshape(base_num_query_heads * head_dim, base_num_query_heads * head_dim)
wq6 = wq5.transpose()

self.assertTrue(np.array_equal(wq, wq6), "wq does not match wq6")

self.assertTrue(np.array_equal(wq1, wq5), "wq1 does not match wq5")

self.assertTrue(np.array_equal(wq2, wq4), "wq2 does not match wq4")


if __name__ == "__main__":
unittest.main()