Skip to content

Commit

Permalink
Merge pull request #1291 from AI-Hypercomputer:lance-deepseek
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 729631246
  • Loading branch information
maxtext authors committed Feb 21, 2025
2 parents 8632dcb + 31171bf commit e7038bc
Show file tree
Hide file tree
Showing 8 changed files with 785 additions and 32 deletions.
4 changes: 3 additions & 1 deletion MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ checkpoint_storage_use_zarr3: True

reuse_example_batch: 0 # for testing TPU performance, this option repeatedly uses the same batch.


metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
gcs_metrics: False
Expand Down Expand Up @@ -277,7 +278,7 @@ logical_axis_rules: [
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
Expand Down Expand Up @@ -532,6 +533,7 @@ inference_metadata_file: "" # path to a json file
inference_server: "MaxtextInterleavedServer" # inference server to start
inference_benchmark_test: False
enable_model_warmup: False
hf_model_path: "" # inference checkpoint correctness verification

# Stack prefill cache across the layer to reduce the
# Python layer latency.
Expand Down
319 changes: 297 additions & 22 deletions MaxText/llama_or_mistral_ckpt.py

Large diffs are not rendered by default.

15 changes: 15 additions & 0 deletions MaxText/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
17 changes: 8 additions & 9 deletions MaxText/tests/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_data(golden_data, golden_data_index, config):
return ids, decoder_segment_ids, decoder_positions, logits


def main(config, test_args):
def main(config, test_args): # pylint: disable=W0621
"""Test the Whole Model of model_name"""

# initialize the model with weights from reference ckpt
Expand Down Expand Up @@ -118,9 +118,10 @@ def main(config, test_args):
max_logging.log(f"{golden_logits[2]=}")
max_logging.log(f"{full_train_logits[0, 2, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
# The ellipsis is used to currently support jax nightly versions newer than 1/9/2025 and stable tests. This can be simplified later
# The ellipsis is used to currently support jax nightly versions newer than
# 1/9/2025 and stable tests. This can be simplified later
max_logging.log(
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}"
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}" # pylint: disable=C0301
)

model_probabilities = jax.nn.softmax(full_train_logits[..., 0, :token_size, :], axis=-1)
Expand All @@ -133,19 +134,17 @@ def main(config, test_args):
max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")

if test_args.max_kl_div is not None:
max_logging.log("Checking KL Divergence between train distribution and golden distribution")
assert jax.numpy.all(
kl_div < test_args.max_kl_div
), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"
max_logging.log("Checking KL Divergence between train distribution and " "golden distribution")
assert jax.numpy.all(kl_div < test_args.max_kl_div), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}" # pylint: disable=C0301
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits")
max_logging.log("Checking Numerical Differences between train logits and golden logits") # pylint: disable=C0301
assert jax.numpy.allclose(
full_train_logits[..., 0, :token_size, :],
golden_logits[:token_size, :],
rtol=float(test_args.rtol),
atol=float(test_args.atol),
equal_nan=False,
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}." # pylint: disable=C0301


if __name__ == "__main__":
Expand Down
124 changes: 124 additions & 0 deletions MaxText/tests/hf_checkpoint_conversion_checker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import glob
import os
import torch
from safetensors import safe_open
import pathlib
import jax
import numpy as np
import argparse


def load_hf(hf_checkpoint_folder):
  """Reads every ``*.safetensors`` file in a folder into one flat dict.

  Args:
    hf_checkpoint_folder: Directory searched (non-recursively) for files that
      match ``*.safetensors``.

  Returns:
    A dict mapping each tensor name to its tensor, loaded on CPU and cast to
    ``torch.float16``. When the same name occurs in more than one file, the
    file visited later wins. The dict is empty when no file matches.
  """
  pattern = os.path.join(hf_checkpoint_folder, "*.safetensors")
  tensors_by_name = {}
  for shard_path in glob.glob(pattern):
    with safe_open(shard_path, framework="pt", device="cpu") as shard:
      tensors_by_name.update({name: shard.get_tensor(name).to(torch.float16) for name in shard.keys()})
  return tensors_by_name


def load_meta(meta_checkpoint_folder):
  """Loads ``.pth`` checkpoint files from a folder onto the CPU.

  Files whose name starts with a dot are ignored; the remaining ``*.pth`` files
  are loaded in sorted path order.

  Args:
    meta_checkpoint_folder: Directory that holds the ``.pth`` files.

  Returns:
    The object returned by ``torch.load`` for the last file in sorted order, or
    an empty dict when the folder has no matching file.
  """
  folder = pathlib.Path(meta_checkpoint_folder)
  loaded = {}
  for shard_path in sorted(folder.glob("[!.]*.pth")):
    # NOTE(review): every file is loaded but each result replaces the previous
    # one, so only the last file's content is returned -- confirm that this is
    # intended for multi-file checkpoints.
    # NOTE(review): torch.load may unpickle arbitrary objects; only point this
    # at trusted checkpoints.
    loaded = torch.load(shard_path, map_location="cpu")
  return loaded


def _get_named_leaves(pytree, parent_key=""):
  """Flattens nested dicts into a ``{"outer.inner": leaf}`` mapping."""
  named_leaves = {}
  for key, value in pytree.items():
    new_key = f"{parent_key}.{key}" if parent_key else key
    if isinstance(value, dict):
      named_leaves.update(_get_named_leaves(value, new_key))
    else:
      named_leaves[new_key] = value
  return named_leaves


def _print_leading_rows(array, max_rows=10):
  """Prints at most ``max_rows`` slices taken along the second-to-last axis."""
  if array.ndim < 2:
    # Scalars and vectors have no row axis to slice; print them whole.
    print(f"{array}\n")
    return
  for i in range(min(max_rows, array.shape[-2])):
    print(f"{array[..., i, :]}\n")


def compare_pytrees(tree1, tree2, atol=0.001):
  """
  Compares two pytrees (nested dicts of array-like leaves) leaf by leaf.

  The comparison stops at the first missing key or mismatching leaf and prints
  a preview of both tensors. Leaves that cannot be compared at all (for
  example because they cannot be converted to a numpy array) are reported and
  make the overall result False, but do not stop the scan.

  Args:
    tree1: First pytree.
    tree2: Second pytree.
    atol: Absolute tolerance for comparison (default: 0.001). The default
      relative tolerance of ``np.allclose`` applies in addition.
  Returns:
    True if both trees have the same structure and every pair of leaves has the
    same shape and is within tolerance; False otherwise.
  """
  # Ensure both trees have the same structure
  if jax.tree_util.tree_structure(tree1) != jax.tree_util.tree_structure(tree2):
    print(
        "Pytrees have different structures! Tree1:"
        f"{jax.tree_util.tree_structure(tree1)} \n\n\n"
        f"Tree2: {jax.tree_util.tree_structure(tree2)}"
    )
    return False

  named_leaves1 = _get_named_leaves(tree1)
  named_leaves2 = _get_named_leaves(tree2)

  print(f"There are {len(named_leaves1)} leaves to check.")
  failed_keys = []
  for key, leaf1 in named_leaves1.items():
    if key not in named_leaves2:
      print(f"Missing key in second tree: {key}")
      return False
    try:
      array1 = np.asarray(leaf1)
      array2 = np.asarray(named_leaves2[key])
      # np.allclose broadcasts its inputs, so shapes are checked explicitly.
      matches = array1.shape == array2.shape and bool(np.allclose(array1, array2, atol=atol))
    except (TypeError, ValueError, RuntimeError):
      # Keep scanning, but remember the failure so success is not reported.
      print(f"The issue is with {key}")
      failed_keys.append(key)
      continue
    if not matches:
      print(f"Mismatch at leaf '{key}' with shape {array1.shape}:\n")
      _print_leading_rows(array1)
      print("The second tensor:\n")
      _print_leading_rows(array2)
      return False

  if failed_keys:
    print(f"{len(failed_keys)} leaves could not be compared: {failed_keys}")
    return False

  print(f"All {len(named_leaves1)} leaves match within tolerance.")
  return True


def main():
  """Command-line entry point: compares two folders of safetensors files.

  Both folders are read with ``load_hf`` and the resulting dicts are passed to
  ``compare_pytrees``. The process exits with status 1 when the comparison
  explicitly returns False.

  Flags:
    --original_ckpt: Folder with the original HuggingFace checkpoint.
    --converted_ckpt: Folder with the checkpoint that was converted back.

  Both flags default to an empty string, which makes ``load_hf`` search the
  current working directory.
  """
  parser = argparse.ArgumentParser(description="Compares the original checkpoint and converted back checkpoint.")
  parser.add_argument(
      "--original_ckpt",
      type=str,
      default="",
      help="The original huggingface checkpoint",
  )
  parser.add_argument(
      "--converted_ckpt",
      type=str,
      default="",
      help="The checkpoint converted back to huggingface format",
  )
  args = parser.parse_args()

  original_tensors = load_hf(args.original_ckpt)
  converted_tensors = load_hf(args.converted_ckpt)

  # Only an explicit False counts as failure; a None result is treated as
  # "no verdict" and leaves the exit status untouched.
  if compare_pytrees(original_tensors, converted_tensors) is False:
    raise SystemExit(1)


if __name__ == "__main__":
main()
58 changes: 58 additions & 0 deletions MaxText/tests/hf_checkpoint_conversion_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""
Copyright 2025 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

""" Tests for kernels """

import numpy as np
from MaxText.llama_or_mistral_ckpt import permute_to_match_maxtext_rope
from MaxText.llama_mistral_mixtral_orbax_to_hf import unpermute_from_match_maxtext_rope
import unittest


class HFCheckpointConversionTest(unittest.TestCase):
  """Round-trip test for the RoPE permutation used during checkpoint conversion."""

  def test_huggingface_to_maxtext_back_to_huggingface_flow(self):
    """Permuting a query weight and un-permuting it must restore the input."""
    base_num_query_heads = 16
    head_dim = 32
    hidden = base_num_query_heads * head_dim
    # NOTE(review): float16 cannot represent every integer in [0, hidden**2)
    # and overflows above 65504, so many entries are equal to each other;
    # a wider dtype would make the check stricter -- confirm the conversion
    # helpers accept it before changing.
    wq = np.arange(hidden * hidden, dtype=np.float16).reshape(hidden, hidden)
    wq1 = wq.transpose()
    wq2 = np.reshape(wq1, [hidden, base_num_query_heads, head_dim])

    wq3 = permute_to_match_maxtext_rope(wq2)

    # Add a leading axis of size one, move it to position 1 and slice it away
    # again; the values are unchanged. Presumably this mimics the stacked
    # layer axis of a scanned checkpoint -- TODO confirm.
    stack_shape = (1,)
    x = np.zeros(stack_shape + wq3.shape, dtype=np.float16)
    x[0, ...] = wq3
    x = np.transpose(x, axes=(1, 0, 2, 3))
    x = x[:, 0, :, :]

    wq4 = unpermute_from_match_maxtext_rope(x, "llama3.1")
    wq5 = wq4.reshape(hidden, hidden)
    wq6 = wq5.transpose()

    # Assert instead of printing so that a broken round trip fails the test.
    self.assertTrue(np.array_equal(wq2, wq4), "Test failed: wq2 does not match wq4")
    self.assertTrue(np.array_equal(wq1, wq5), "Test failed: wq1 does not match wq5")
    self.assertTrue(np.array_equal(wq, wq6), "Test failed: wq does not match wq6")


if __name__ == "__main__":
unittest.main()
102 changes: 102 additions & 0 deletions end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
#!/bin/bash

# Downloads a HuggingFace checkpoint, converts it to a scanned MaxText
# checkpoint, decodes with it, then generates an unscanned (unrolled)
# checkpoint and decodes again.

# Stop at the first failing step so later steps never run on a missing or
# partial checkpoint.
set -e

# NOTE(review): the folder is named Llama3.1-70B-Instruct but the model that is
# downloaded is deepseek-ai/DeepSeek-R1-Distill-Llama-70B -- confirm the name.
export CHECKPOINT_ORIGINAL=/mnt/disks/persist/checkpoints/huggingface/Llama3.1-70B-Instruct
huggingface-cli download deepseek-ai/DeepSeek-R1-Distill-Llama-70B --local-dir "$CHECKPOINT_ORIGINAL"

export CHECKPOINT_TPU_SCANNED="$CHECKPOINT_ORIGINAL/scanned_chkpt"

export TOKENIZER=assets/tokenizer_llama3.tiktoken
export BASE_OUTPUT_PATH="$CHECKPOINT_ORIGINAL"
export RUN_NAME=unscanned_chkpt
export CHECKPOINT_TPU_UNSCANNED="$BASE_OUTPUT_PATH/$RUN_NAME/checkpoints/0/items"
export MODEL_SIZE=llama3.1-70b

# Convert the downloaded HuggingFace checkpoint into a MaxText checkpoint (on CPU).
JAX_PLATFORMS=cpu python3 MaxText/llama_or_mistral_ckpt.py --base-model-path="$CHECKPOINT_ORIGINAL" --model-size="$MODEL_SIZE" --maxtext-model-path="$CHECKPOINT_TPU_SCANNED" --huggingface-checkpoint=true


# Decode with the scanned checkpoint (scan_layers=true).
JAX_PLATFORMS=tpu python MaxText/decode.py MaxText/configs/base.yml tokenizer_path="$TOKENIZER" run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name="$MODEL_SIZE" ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=true load_parameters_path="$CHECKPOINT_TPU_SCANNED/0/items"

# If everything looks good, we move on to convert to the unrolled checkpoint for performant serving
JAX_PLATFORMS=cpu python MaxText/generate_param_only_checkpoint.py MaxText/configs/base.yml async_checkpointing=false base_output_directory="${BASE_OUTPUT_PATH}" load_parameters_path="${CHECKPOINT_TPU_SCANNED}/0/items" run_name="${RUN_NAME}" model_name="${MODEL_SIZE}" force_unroll=true


# Decode with the unscanned checkpoint (scan_layers=false).
JAX_PLATFORMS=tpu python MaxText/decode.py MaxText/configs/base.yml tokenizer_path="$TOKENIZER" run_name=runner_2025-02-13-08-31 steps=10 weight_dtype=bfloat16 async_checkpointing=false model_name="$MODEL_SIZE" ici_fsdp_parallelism=1 ici_autoregressive_parallelism=-1 per_device_batch_size=1 prompt="I love to" scan_layers=false load_parameters_path="$CHECKPOINT_TPU_UNSCANNED"

# Example output
# Input `I love to` -> ` read, but I don't have much time. How can I read more books?
# I'm a busy person, but I want to read more. How can I fit reading into my schedule?
# I want to read more, but I don't have enough time. What can I do?
# I don't have time to read, but I want to. How can I make time for reading?
# I want to read more, but my schedule is too busy. What can I do?
# I don't have much time, but I want to read more. How can I manage that?
# I want to read more, but I'm too busy. How can I fit reading into my schedule?
# I don't have time to read, but I want to. What can I do?
# I want to read more, but I don't have enough time. How can I make time for reading?
# I'm busy, but I want to read more. How can I fit reading into my schedule?

# Okay, so I'm trying to figure out how to read more books even though I'm really busy. I love reading, but it seems like I never have the time. Let me think about this step by step.

# First, maybe I can start by looking at my daily routine to see where I can squeeze in some reading time. I usually wake up early, get ready for work, and then have a busy day. Maybe I can wake up a bit earlier each day to read before work. But wait, I'm not a morning person. That might be tough. Maybe I can try just 15 minutes in the morning. That doesn't seem too bad.

# Another idea is to use my commute. I take the bus to work, which is about 30 minutes each way. I could listen to audiobooks or read on my phone during that time. I've heard that audiobooks are a good way to consume books quickly, especially for non-fiction. But I'm more into fiction, so maybe I can find a good fiction audiobook. Or maybe I can read e-books on my phone. I have a Kindle app, so that could work.

# Lunch breaks are another possibility. I usually have about an hour for lunch. Maybe I can spend 20-30 minutes reading during that time. But sometimes I meet friends or have meetings, so it might not be consistent. Still, it's worth trying on the days I'm alone.

# Evenings are tricky because I'm often tired after work. But maybe right before bed, I can read for 15-20 minutes instead of scrolling through my phone. That might also help me wind down and sleep better. Plus, it's a good way to relax.

# I also have weekends. Maybe I can dedicate a couple of hours on Saturday or Sunday to reading. That could help me catch up on my reading without feeling rushed.

# Another thought: maybe I can make a reading list and set a goal for how many books I want to read each month. That way, I can track my progress and stay motivated. I could use a reading log or an app to keep track.

# I should also consider the types of books I'm reading. Maybe shorter books or novellas can be finished quicker, fitting into my busy schedule better. Or I could mix in some graphic novels, which are usually faster to read.

# I've heard about the concept of "reading sprints" where you read for a set amount of time without distractions. Maybe I can try that during my breaks or in the evenings. It might help me focus and get through more pages.

# Another idea is to join a book club or find a reading buddy. That could keep me accountable and give me a reason to prioritize reading. Plus, discussing books with others might make it more enjoyable and motivate me to keep going.

# I also need to think about eliminating distractions. Maybe turning off notifications on my phone or finding a quiet spot where I can read without interruptions. Creating a dedicated reading space might help me get into the right mindset.

# What about multitasking? I could listen to audiobooks while doing chores, exercising, or driving. That way, I'm making use of time that would otherwise be unproductive.

# I should also be realistic about my goals. I might not be able to read as much as I'd like, but setting achievable targets can help me stay on track without feeling overwhelmed. Maybe start with one book a month and gradually increase as I find more time.

# Another thing to consider is the format of the books. E-books are convenient because I can carry them on my phone, but physical books might be better for certain times, like before bed when I want to avoid screens.

# I could also try speed reading techniques, but I'm not sure how effective they are. Maybe skimming through less important parts or focusing on key points could help me get through books faster.

# Lastly, I need to prioritize reading as a form of self-care. It's important for my mental health and relaxation, so making time for it should be non-negotiable, just like other important activities.

# Putting it all together, I think the key is to find small pockets of time throughout the day and use them effectively. Whether it's during commutes, breaks, or before bed, every little bit counts. Combining different strategies like audiobooks, e-books, setting goals, and creating a reading-friendly environment can help me read more despite being busy.
# <|reserved_special_token_9|>

# To read more despite a busy schedule, consider the following organized approach:

# 1. **Morning Routine**: Start with 15 minutes of reading in the morning, even if you're not a morning person. It sets a positive tone for the day.

# 2. **Commute Utilization**: Use your 30-minute bus commute to listen to audiobooks or read e-books on your phone. This is an efficient way to consume books, especially fiction.

# 3. **Lunch Breaks**: Dedicate 20-30 minutes of your lunch break to reading, especially on days when you're alone. This provides a midday mental break.

# 4. **Evening Routine**: Wind down before bed with 15-20 minutes of reading instead of screen time. This aids relaxation and sleep.

# 5. **Weekend Dedication**: Allocate a couple of hours on weekends to reading, allowing you to catch up without feeling rushed.

# 6. **Reading Goals and Tracking**: Create a reading list and set monthly goals. Use a reading log or app to track progress and stay motivated.

# 7. **Book Selection**: Opt for shorter books, novellas, or graphic novels to fit into your schedule and vary your reading material.

# 8. **Reading Sprints**: Try focused reading sessions without distractions during breaks or evenings to maximize productivity.

# 9. **Accountability and Community**: Join a book club or find a reading buddy for accountability and enjoyment. Discussions can enhance your reading experience.

# 10. **Distraction Management**: Create a quiet reading space and minimize interruptions by turning off notifications.

# 11. **Multitasking with Audiobooks**: Listen to audiobooks during chores, exercise, or driving to utilize otherwise idle time.

# 12. **Realistic Goal Setting**: Start with achievable targets, like one book a month, and gradually increase as you find more time.

# 13. **Book Format Flexibility**: Use e-books for convenience and physical books for screen-free reading, especially before bed.

# 14. **Speed Reading Techniques**: Experiment with skimming or focusing on key points to read more efficiently.

# 15. **Prioritize Self-Care**: Treat reading as essential for mental health and relaxation, making it a non-negotiable part of your routine.

# By integrating these strategies, you can effectively use small time pockets to read more, combining audiobooks, e-books, goal setting, and a conducive reading environment.<|end_of_text|>
Loading

0 comments on commit e7038bc

Please sign in to comment.