Skip to content

Commit

Permalink
Merge branch 'main' into shralex-patch-1
Browse files Browse the repository at this point in the history
  • Loading branch information
shralex authored Feb 22, 2025
2 parents af0d018 + 02b9b8d commit d02ce89
Show file tree
Hide file tree
Showing 8 changed files with 32 additions and 785 deletions.
4 changes: 1 addition & 3 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ checkpoint_storage_use_zarr3: True

reuse_example_batch: 0 # for testing TPU performance, this option repeatedly uses the same batch.


metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
gcs_metrics: False
Expand Down Expand Up @@ -278,7 +277,7 @@ logical_axis_rules: [
['vocab', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['q_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['kv_heads', ['tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'sequence', 'tensor_transpose', 'expert']],
['embed', ['fsdp', 'fsdp_transpose', 'sequence', 'expert']],
Expand Down Expand Up @@ -533,7 +532,6 @@ inference_metadata_file: "" # path to a json file
inference_server: "MaxtextInterleavedServer" # inference server to start
inference_benchmark_test: False
enable_model_warmup: False
hf_model_path: "" # inference checkpoint correctness verification

# Stack prefill cache across the layer to reduce the
# Python layer latency.
Expand Down
319 changes: 22 additions & 297 deletions MaxText/llama_or_mistral_ckpt.py

Large diffs are not rendered by default.

15 changes: 0 additions & 15 deletions MaxText/tests/__init__.py

This file was deleted.

17 changes: 9 additions & 8 deletions MaxText/tests/forward_pass_logit_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_data(golden_data, golden_data_index, config):
return ids, decoder_segment_ids, decoder_positions, logits


def main(config, test_args): # pylint: disable=W0621
def main(config, test_args):
"""Test the Whole Model of model_name"""

# initialize the model with weights from reference ckpt
Expand Down Expand Up @@ -118,10 +118,9 @@ def main(config, test_args): # pylint: disable=W0621
max_logging.log(f"{golden_logits[2]=}")
max_logging.log(f"{full_train_logits[0, 2, :]=}")
token_size = int(test_args.token_size) if test_args.token_size else golden_logits.shape[0]
# The ellipsis is used to currently support jax nightly versions newer than
# 1/9/2025 and stable tests. This can be simplified later
# The ellipsis is used to currently support jax nightly versions newer than 1/9/2025 and stable tests. This can be simplified later
max_logging.log(
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}" # pylint: disable=C0301
f"Max Numerical Difference {np.max(np.subtract(full_train_logits[..., 0, :token_size, :], golden_logits[:token_size, :]))}"
)

model_probabilities = jax.nn.softmax(full_train_logits[..., 0, :token_size, :], axis=-1)
Expand All @@ -134,17 +133,19 @@ def main(config, test_args): # pylint: disable=W0621
max_logging.log(f"KL divergence = {kl_div}, max KL divergence = {jax.numpy.max(kl_div)}")

if test_args.max_kl_div is not None:
max_logging.log("Checking KL Divergence between train distribution and " "golden distribution")
assert jax.numpy.all(kl_div < test_args.max_kl_div), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}" # pylint: disable=C0301
max_logging.log("Checking KL Divergence between train distribution and golden distribution")
assert jax.numpy.all(
kl_div < test_args.max_kl_div
), f"KL divergence values exceed the specified threshold of {test_args.max_kl_div}. Max divergence: {jax.numpy.max(kl_div)}"
else:
max_logging.log("Checking Numerical Differences between train logits and golden logits") # pylint: disable=C0301
max_logging.log("Checking Numerical Differences between train logits and golden logits")
assert jax.numpy.allclose(
full_train_logits[..., 0, :token_size, :],
golden_logits[:token_size, :],
rtol=float(test_args.rtol),
atol=float(test_args.atol),
equal_nan=False,
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}." # pylint: disable=C0301
), f"Logits do not match closely enough. Required rtol={test_args.rtol}, atol={test_args.atol}."


if __name__ == "__main__":
Expand Down
124 changes: 0 additions & 124 deletions MaxText/tests/hf_checkpoint_conversion_checker.py

This file was deleted.

58 changes: 0 additions & 58 deletions MaxText/tests/hf_checkpoint_conversion_test.py

This file was deleted.

102 changes: 0 additions & 102 deletions end_to_end/tpu/llama3.1/70b/3_test_llama3.1_70b.sh

This file was deleted.

Loading

0 comments on commit d02ce89

Please sign in to comment.