From 49c423e53cdf29615f136f8d811b4e5f80e48f0a Mon Sep 17 00:00:00 2001 From: Alex Shraer Date: Fri, 20 Dec 2024 07:40:04 +0000 Subject: [PATCH] Move tests from UnitTests.yml to tests files. This change moves out train and decode tests into separate files, so they can be tested using pytest rather than explicitly being invoked from UnitTests.yml. This will allow us, in a subsequent PR, to parallelize the execution of these tests by utilizing pytest parallelization support. Several end-to-end tests are still being invoked from UnitTests.yml. This does not seem like the right place, and will hopefully be addressed in the future. --- .github/workflows/UnitTests.yml | 94 +++---------------- MaxText/decode.py | 22 +++-- MaxText/pytest.ini | 12 ++- MaxText/tests/attention_test.py | 16 ++-- MaxText/tests/decode_int8_test.py | 55 +++++++++++ MaxText/tests/decode_pdb_lt_1_test.py | 53 +++++++++++ MaxText/tests/decode_test.py | 53 +++++++++++ MaxText/tests/gpt3_test.py | 7 +- MaxText/tests/gradient_accumulation_test.py | 2 +- .../inference_microbenchmark_smoke_test.py | 2 +- .../checkpoint_compatibility_test.py | 48 ++++++++++ .../integration_tests/checkpointing_test.py | 51 ++++++++++ .../generate_param_only_checkpoint_test.py | 53 +++++++++++ .../shmap_collective_matmul_test.py | 32 +++++++ MaxText/tests/kernels_test.py | 6 +- MaxText/tests/model_test.py | 2 +- MaxText/tests/multihost_dataloading_test.py | 2 +- MaxText/tests/pipeline_parallelism_test.py | 12 +-- MaxText/tests/simple_decoder_layer_test.py | 4 +- MaxText/tests/standalone_dl_ckpt_test.py | 4 +- MaxText/tests/tokenizer_test.py | 8 +- MaxText/tests/train_base_cudnn_flash_te.py | 46 +++++++++ MaxText/tests/train_base_test.py | 50 ++++++++++ MaxText/tests/train_compile_test.py | 24 ++--- MaxText/tests/train_dropout_test.py | 53 +++++++++++ MaxText/tests/train_fp8_test.py | 51 ++++++++++ MaxText/tests/train_hf_input_pipeline_test.py | 52 ++++++++++ MaxText/tests/train_int8_test.py | 51 ++++++++++ MaxText/tests/train_pdb_lt_1_test.py | 52 ++++++++++ MaxText/tests/train_synthetic_data_test.py | 51 ++++++++++ 30 files changed, 830 insertions(+), 138 deletions(-) create mode 100644 MaxText/tests/decode_int8_test.py create mode 100644 MaxText/tests/decode_pdb_lt_1_test.py create mode 100644 MaxText/tests/decode_test.py create mode 100644 MaxText/tests/integration_tests/checkpoint_compatibility_test.py create mode 100644 MaxText/tests/integration_tests/checkpointing_test.py create mode 100644 MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py create mode 100644 MaxText/tests/integration_tests/shmap_collective_matmul_test.py create mode 100644 MaxText/tests/train_base_cudnn_flash_te.py create mode 100644 MaxText/tests/train_base_test.py create mode 100644 MaxText/tests/train_dropout_test.py create mode 100644 MaxText/tests/train_fp8_test.py create mode 100644 MaxText/tests/train_hf_input_pipeline_test.py create mode 100644 MaxText/tests/train_int8_test.py create mode 100644 MaxText/tests/train_pdb_lt_1_test.py create mode 100644 MaxText/tests/train_synthetic_data_test.py diff --git a/.github/workflows/UnitTests.yml b/.github/workflows/UnitTests.yml index 91a62efc2..815da4c6c 100644 --- a/.github/workflows/UnitTests.yml +++ b/.github/workflows/UnitTests.yml @@ -23,8 +23,8 @@ on: branches: [ "main" ] workflow_dispatch: schedule: - # Run the job every 2 hours - - cron: '0 */2 * * *' + # Run the job every 6 hours + - cron: '0 */6 * * *' jobs: build_and_upload_image: @@ -46,7 +46,7 @@ jobs: run: docker system prune --all --force - name: Build an image run: | - bash docker_build_dependency_image.sh MODE=${{ matrix.device.mode }} DEVICE=${{ matrix.device.type }} + bash docker_build_dependency_image.sh MODE=${{ matrix.device.mode }} DEVICE=${{ matrix.device.type }} - name: Tag the image run: | docker tag maxtext_base_image gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }} @@ -62,8 +62,7 @@ jobs: device: - type: tpu name: v4-8 - attention: autoselected - pytest_marker: '' + pytest_marker: 'not gpu_only' # exclude tests marked gpu_only container_env: XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75 TF_FORCE_GPU_ALLOW_GROWTH: false @@ -71,8 +70,7 @@ jobs: - type: gpu name: a100-40gb-4 image_suffix: gpu_jax_pinned - attention: dot_product - pytest_marker: -m 'not tpu' + pytest_marker: 'not tpu_only' # exclude tests marked tpu_only container_env: XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65 TF_FORCE_GPU_ALLOW_GROWTH: true @@ -81,7 +79,7 @@ jobs: runs-on: ["self-hosted", "${{ matrix.device.type }}", "${{ matrix.device.name }}"] container: image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }} - volumes: + volumes: - /home/runner/actions-runner/_work/maxtext/maxtext:/deps env: XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ matrix.device.container_env.XLA_PYTHON_CLIENT_MEM_FRACTION }} @@ -91,82 +89,14 @@ jobs: - uses: actions/checkout@v4 - name: Test gsutil installation run: which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;} - - name: Test with pytest - run: cd MaxText;python3 -m pytest ${{ matrix.device.pytest_marker }} - - name: Test train.py with TFDS c4 - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test train.py with HF c4 - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs hf_train_files=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet hf_path=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large attention=${{ matrix.device.attention }} enable_checkpointing=false - - name: Test train.py with synthetic data - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} dataset_type=synthetic - - name: Test train.py with per_device_batch_size < 1 - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test decode.py - run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1 - - name: Test int8_decode - run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1 quantization=int8 quantize_kvcache=True - - name: Test decode.py with per_device_batch_size < 1 - run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=.25 - - name: Test int8_training - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test fp8_training - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=fp8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} - - name: Test train.py with dropout - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} max_target_length=128 per_device_batch_size=1 dropout_rate=0.02 - - name: Test generate_param_only_checkpoint - run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a ${{ matrix.device.attention }} - - name: Test generate_param_only_checkpoint with int8 quantization - run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8 -a ${{ matrix.device.attention }} - - name: Test grain checkpoint determinism - run: bash end_to_end/test_checkpointing.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset False grain ${{ matrix.device.attention }} - - name: Test checkpoint compatibility - run: bash end_to_end/test_checkpoint_compatibility.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset ${{ matrix.device.attention }} - - tpu: - needs: build_and_upload_image - strategy: - fail-fast: false - matrix: - device-type: ["v4-8"] - name: "TPU test (${{ matrix.device-type }})" - runs-on: ["self-hosted", "tpu", "${{ matrix.device-type }}"] - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu - volumes: - - /home/runner/actions-runner/_work/maxtext/maxtext:/deps - options: "--privileged" - steps: - - uses: actions/checkout@v4 - - name: Validate Pedagogical Example, Shmap_collective_matmul - run: python3 pedagogical_examples/shmap_collective_matmul.py - - gpu: - needs: build_and_upload_image - strategy: - fail-fast: false - matrix: - device-type: ["a100-40gb-4"] - build-mode: ["pinned"] - name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})" - runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"] - container: - image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu - volumes: - - /home/runner/actions-runner/_work/maxtext/maxtext:/deps - env: - XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65 - TF_FORCE_GPU_ALLOW_GROWTH: true - options: "--shm-size 2g --runtime=nvidia --gpus all --privileged" - steps: - - uses: actions/checkout@v4 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v3 - - name: Test train.py with flash attention - run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=cudnn_flash_te + - name: Unit Tests + run: cd MaxText;python3 -m pytest tests -m "${{ matrix.device.pytest_marker }} and not integration_test" + - name: Integration Tests + run: cd MaxText; python3 -m pytest tests/integration_tests -m "${{ matrix.device.pytest_marker }} and integration_test" clean_up: if: ${{ always() }} - needs: [common, gpu, tpu] + needs: common name: "Clean up" runs-on: ["self-hosted"] steps: @@ -174,4 +104,4 @@ jobs: run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu --force-delete-tags --quiet - name: Delete TPU image run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu --force-delete-tags --quiet - + diff --git a/MaxText/decode.py b/MaxText/decode.py index b74385207..ef2f2fc79 100644 --- a/MaxText/decode.py +++ b/MaxText/decode.py @@ -21,10 +21,20 @@ import os import pyconfig -import sys +from typing import Sequence +from absl import app + + +def main(argv: Sequence[str]) -> None: + jax.config.update("jax_default_prng_impl", "unsafe_rbg") + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + pyconfig.initialize(argv) + config = pyconfig.config + validate_config(config) + max_utils.print_system_information() -def main(config): engine = maxengine.MaxEngine(config) rng = jax.random.PRNGKey(1234) rng, rng_load_params = jax.random.split(rng) @@ -71,10 +81,4 @@ def validate_config(config): if __name__ == "__main__": - jax.config.update("jax_default_prng_impl", "unsafe_rbg") - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" - pyconfig.initialize(sys.argv) - cfg = pyconfig.config - validate_config(cfg) - max_utils.print_system_information() - main(cfg) + app.run(main) diff --git a/MaxText/pytest.ini b/MaxText/pytest.ini index fa8e8142b..fc6c896a9 100644 --- a/MaxText/pytest.ini +++ b/MaxText/pytest.ini @@ -3,11 +3,15 @@ testpaths = tests python_files = *_test.py -addopts = - -rf --import-mode=importlib +addopts = + -rf --import-mode=importlib --strict-markers --ignore=tests/profiler_test.py --ignore=tests/train_smoke_test.py --ignore=tests/train_int8_smoke_test.py --ignore=tests/train_gpu_smoke_test.py -markers = - tpu: marks tests to be run on TPU \ No newline at end of file +markers = + tpu_only: marks tests to be run on TPUs only + gpu_only: marks tests to be run on GPUs only + integration_test: tests exercising larger portions of the system, + including interactions with other systems like GCS, + e.g., end_to_end tests diff --git a/MaxText/tests/attention_test.py b/MaxText/tests/attention_test.py index f04ebe190..2cf47169e 100644 --- a/MaxText/tests/attention_test.py +++ b/MaxText/tests/attention_test.py @@ -118,7 +118,7 @@ def get_structured_data(self, dtype): return lnx, decoder_segment_ids, decoder_positions - @pytest.mark.tpu + @pytest.mark.tpu_only def test_autoregression(self): prefill_length = self.cfg.max_prefill_predict_length decode_total_length = self.cfg.max_target_length @@ -174,11 +174,11 @@ def test_autoregression(self): self.assertTrue(mha_full_this_idx.shape == mha_idx.shape) self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False)) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_model_mode_prefill_dtype_float32(self): self._test_model_mode_prefill_dtype(jnp.float32) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_model_mode_prefill_dtype_bfloat16(self): self._test_model_mode_prefill_dtype(jnp.bfloat16) @@ -224,15 +224,15 @@ def _test_model_mode_prefill_dtype(self, dtype): self.assertEqual(dtype, mha_prefill.dtype) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tpu_kernel_attention_mha(self): self.tpu_kernel_attention_helper(self.num_kv_heads) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tpu_kernel_attention_gqa(self): self.tpu_kernel_attention_helper(self.num_kv_heads // 2) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tpu_kernel_attention_mqa(self): self.tpu_kernel_attention_helper(1) @@ -309,7 +309,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads): jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_dot_product_cache_axis_order(self): all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))] for axis_order in random.choices(all_axis_orders, k=4): @@ -423,7 +423,7 @@ def _dot_product_attention( jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_dot_product_reshape_q(self): for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]: self._dot_product_attention_reshape_q( diff --git a/MaxText/tests/decode_int8_test.py b/MaxText/tests/decode_int8_test.py new file mode 100644 index 000000000..918c73cc7 --- /dev/null +++ b/MaxText/tests/decode_int8_test.py @@ -0,0 +1,55 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for decode with int8 quantization""" +import os +import unittest +import pytest +from decode import main as decode_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests decode with int8 quantization""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "ici_tensor_parallelism=4", + "max_target_length=128", + "per_device_batch_size=1", + "quantization=int8", + "quantize_kvcache=True", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + decode_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + decode_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/decode_pdb_lt_1_test.py b/MaxText/tests/decode_pdb_lt_1_test.py new file mode 100644 index 000000000..2eb08f6fa --- /dev/null +++ b/MaxText/tests/decode_pdb_lt_1_test.py @@ -0,0 +1,53 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for decode with per_device_batch_size < 1""" +import os +import unittest +import pytest +from decode import main as decode_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests decode with per_device_batch_size < 1""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "ici_tensor_parallelism=4", + "max_target_length=128", + "per_device_batch_size=.25", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + decode_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + decode_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/decode_test.py b/MaxText/tests/decode_test.py new file mode 100644 index 000000000..a2255f847 --- /dev/null +++ b/MaxText/tests/decode_test.py @@ -0,0 +1,53 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for decode""" +import os +import unittest +import pytest +from decode import main as decode_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests decode""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "ici_tensor_parallelism=4", + "max_target_length=128", + "per_device_batch_size=1", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + decode_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + decode_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/gpt3_test.py b/MaxText/tests/gpt3_test.py index b1f0bed52..fea40b9e0 100644 --- a/MaxText/tests/gpt3_test.py +++ b/MaxText/tests/gpt3_test.py @@ -85,7 +85,7 @@ def setUp(self): } self.model_vars = init_random_model_vars(self.model, self.rng, self.example_batch) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_logits_numerically(self): # ground truth values are calculated from paxml after loading above model_vars # note we expect all xents are the same except the padding one since: @@ -108,4 +108,7 @@ def test_logits_numerically(self): # Mask out paddings at the end of each example. per_example_xent = per_example_xent * (self.example_batch["targets_segmentation"] != 0) - self.assertTrue(jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03)) + self.assertTrue( + jax.numpy.allclose(per_example_xent, per_example_xent_truth, rtol=1e-03, atol=1e-03), + msg=f"per_example_xent:\n{per_example_xent}\n\nper_example_xent_truth:\n{per_example_xent_truth}", + ) diff --git a/MaxText/tests/gradient_accumulation_test.py b/MaxText/tests/gradient_accumulation_test.py index e2730fe97..29c0ab087 100644 --- a/MaxText/tests/gradient_accumulation_test.py +++ b/MaxText/tests/gradient_accumulation_test.py @@ -28,7 +28,7 @@ def generate_random_string(length=10): class GradientAccumulationTest(unittest.TestCase): - @pytest.mark.tpu + @pytest.mark.tpu_only def test_grad_accumulate_same_loss(self): random_suffix = generate_random_string() run_accumulate_metrics_file = f"/tmp/runner_grad_accumulate_{random_suffix}.txt" diff --git a/MaxText/tests/inference_microbenchmark_smoke_test.py b/MaxText/tests/inference_microbenchmark_smoke_test.py index c28de3dcc..43ceb82a1 100644 --- a/MaxText/tests/inference_microbenchmark_smoke_test.py +++ b/MaxText/tests/inference_microbenchmark_smoke_test.py @@ -23,7 +23,7 @@ class Inference_Microbenchmark(unittest.TestCase): - @pytest.mark.tpu + @pytest.mark.tpu_only def test(self): pyconfig.initialize( [ diff --git a/MaxText/tests/integration_tests/checkpoint_compatibility_test.py b/MaxText/tests/integration_tests/checkpoint_compatibility_test.py new file mode 100644 index 000000000..0470d7093 --- /dev/null +++ b/MaxText/tests/integration_tests/checkpoint_compatibility_test.py @@ -0,0 +1,48 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Integraion tests for test_checkpointing.sh""" +from datetime import datetime +import subprocess +import pytest + + +def run_checkpoint_compatibility(attention_type): + """Tests checkpoint compatibility.""" + + run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + command = [ + "bash", + "end_to_end/test_checkpoint_compatibility.sh", + f"runner_{run_date}", # run_name + r"gs://runner-maxtext-logs", # output_path + r"gs://maxtext-dataset", # dataset_path + attention_type, + ] + + subprocess.run(command, check=True, cwd="..") + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +def test_autoselected_attention(): + run_checkpoint_compatibility("autoselected") + + +@pytest.mark.integration_test +@pytest.mark.gpu_only +def test_with_dot_product(): + run_checkpoint_compatibility("dot_product") diff --git a/MaxText/tests/integration_tests/checkpointing_test.py b/MaxText/tests/integration_tests/checkpointing_test.py new file mode 100644 index 000000000..01db18cba --- /dev/null +++ b/MaxText/tests/integration_tests/checkpointing_test.py @@ -0,0 +1,51 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Integraion tests for test_checkpointing.sh""" +from datetime import datetime +import subprocess +import pytest + + +def run_checkpointing(attention_type): + """Tests grain checkpoint determinism.""" + + run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + + command = [ + "bash", + "end_to_end/test_checkpointing.sh", + f"runner_{run_date}", # run_name + r"gs://runner-maxtext-logs", # output_path + r"gs://maxtext-dataset", # dataset_path + "False", # collect_stack_trace + "grain", # dataset_type + attention_type, + ] + + subprocess.run(command, check=True, cwd="..") + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +def test_autoselected_attention(): + run_checkpointing("autoselected") + + +@pytest.mark.integration_test +@pytest.mark.gpu_only +def test_with_dot_product(): + run_checkpointing("dot_product") diff --git a/MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py b/MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py new file mode 100644 index 000000000..6afb389c7 --- /dev/null +++ b/MaxText/tests/integration_tests/generate_param_only_checkpoint_test.py @@ -0,0 +1,53 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Integraion tests for test_generate_param_only_checkpoint.sh""" +from datetime import datetime +import subprocess +import pytest + + +def run_generate_param_only_checkpoint(attention_type, quantization): + """Tests generating a parameter-only checkpoint.""" + + run_date = datetime.now().strftime("%Y-%m-%d-%H-%M-%S") + # fmt: off + command = [ + "bash", + "end_to_end/test_generate_param_only_checkpoint.sh", + "-r", f"runner_{run_date}", + "-o", r"gs://runner-maxtext-logs", + "-d", r"gs://maxtext-dataset", + "-i", "4", + "-a", attention_type, + "-q", quantization, + ] + + subprocess.run(command, check=True, cwd="..") + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +@pytest.mark.parametrize("quantization", [(""), ("int8")]) +def test_autoselected_attention(quantization): + run_generate_param_only_checkpoint("autoselected", quantization) + + +@pytest.mark.integration_test +@pytest.mark.gpu_only +@pytest.mark.parametrize("quantization", [(""), ("int8")]) +def test_with_dot_product(quantization): + run_generate_param_only_checkpoint("dot_product", quantization) diff --git a/MaxText/tests/integration_tests/shmap_collective_matmul_test.py b/MaxText/tests/integration_tests/shmap_collective_matmul_test.py new file mode 100644 index 000000000..55658e414 --- /dev/null +++ b/MaxText/tests/integration_tests/shmap_collective_matmul_test.py @@ -0,0 +1,32 @@ +""" +Copyright 2024 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Integraion test for pedagogical_examples/shmap_collective_matmul.py""" +import subprocess +import pytest + + +@pytest.mark.integration_test +@pytest.mark.tpu_only +def test_shmap_collective_matmul_example(): + """Validate Pedagogical Example, Shmap_collective_matmul.""" + + command = [ + "python3", + "pedagogical_examples/shmap_collective_matmul.py", + ] + + subprocess.run(command, check=True, cwd="..") diff --git a/MaxText/tests/kernels_test.py b/MaxText/tests/kernels_test.py index 5ec2d1c17..6313aa884 100644 --- a/MaxText/tests/kernels_test.py +++ b/MaxText/tests/kernels_test.py @@ -38,7 +38,7 @@ class RaggedAttentionTest(unittest.TestCase): key = jax.random.key(0) k1, k2, k3 = jax.random.split(key, 3) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_ragged_mqa(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.head_dim), dtype=self.dtype) k = jax.random.normal(self.k2, (self.batch_size, self.max_target_length, self.head_dim), dtype=self.dtype) @@ -56,7 +56,7 @@ def test_ragged_mqa(self): msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2", ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_ragged_mha(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype) k = jax.random.normal( @@ -79,7 +79,7 @@ def test_ragged_mha(self): msg=f"Avg difference: {jnp.average(abs(ragged_out - reference_out))} > 1e-2", ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_ragged_gqa(self): q = jax.random.normal(self.k1, (self.batch_size, 1, self.num_query_heads, self.head_dim), dtype=self.dtype) k = jax.random.normal( diff --git a/MaxText/tests/model_test.py b/MaxText/tests/model_test.py index 9791af93b..ed1eecdb6 100644 --- a/MaxText/tests/model_test.py +++ b/MaxText/tests/model_test.py @@ -105,7 +105,7 @@ def test_logits_dtype_with_cast_to_fp32(self): def test_logits_dtype_without_cast(self): self._test_logits_cast_driver(cast_logits_to_fp32=False, expected_dtype=jnp.bfloat16) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_train_vs_prefill_and_autoregress(self): PREFILL_RANGE = MAX_PREFILL_PREDICT_LENGTH diff --git a/MaxText/tests/multihost_dataloading_test.py b/MaxText/tests/multihost_dataloading_test.py index ba289c040..297d75370 100644 --- a/MaxText/tests/multihost_dataloading_test.py +++ b/MaxText/tests/multihost_dataloading_test.py @@ -62,7 +62,7 @@ def setUp(self): dataset = dataset.batch(batch_size) self.multihost_gen = multihost_dataloading.MultiHostDataLoadIterator(dataset, self.mesh) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_batch_sharded_data_pipeline(self): first_batch = next(self.multihost_gen) sec_batch = next(self.multihost_gen) diff --git a/MaxText/tests/pipeline_parallelism_test.py b/MaxText/tests/pipeline_parallelism_test.py index 193c677fa..0e1e18d25 100644 --- a/MaxText/tests/pipeline_parallelism_test.py +++ b/MaxText/tests/pipeline_parallelism_test.py @@ -150,7 +150,7 @@ def regular_sequential_layers_dummy_loss( dummy_targets, ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_circular_minimum_microbatches_same_output_and_grad(self): # 4 stages, 8 layers (2 repeats, 1 layer per stage), 4 microbatches pyconfig.initialize( @@ -167,7 +167,7 @@ def test_circular_minimum_microbatches_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_circular_extra_microbatches_same_output_and_grad(self): # 4 stages, 8 layers (2 repeats, 1 layer per stage), 8 microbatches pyconfig.initialize( @@ -184,7 +184,7 @@ def test_circular_extra_microbatches_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_non_circular_same_output_and_grad(self): # 4 stages, 4 layers (no circular repeats, 1 layer per stage), 4 microbatches pyconfig.initialize( @@ -201,7 +201,7 @@ def test_non_circular_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_full_train_circular(self): # Run a full train.py call with 4 stages, 32 layers (2 layers per stage, 4 circular repeats), 8 microbatches train_main( @@ -231,7 +231,7 @@ def test_full_train_circular(self): ] ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_delay_activation_forwarding_same_output_and_grad(self): # 4 stages, delayed activation forwarding, 8 layers (2 repeats, 1 layer per stage), 8 microbatches pyconfig.initialize( @@ -249,7 +249,7 @@ def test_delay_activation_forwarding_same_output_and_grad(self): config = pyconfig.config self.assert_pipeline_same_output_and_grad(config) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_full_train_non_circular(self): # Run a full train.py call with 4 stages, 32 layers (8 layers per stage), 8 microbatches train_main( diff --git a/MaxText/tests/simple_decoder_layer_test.py b/MaxText/tests/simple_decoder_layer_test.py index afa2d0aeb..ba6fa7c3c 100644 --- a/MaxText/tests/simple_decoder_layer_test.py +++ b/MaxText/tests/simple_decoder_layer_test.py @@ -18,7 +18,7 @@ class SimpleDecoderLayerTest(unittest.TestCase): - @pytest.mark.tpu + @pytest.mark.tpu_only def test_simple_decoder_layer(self): train_main( [ @@ -34,7 +34,7 @@ def test_simple_decoder_layer(self): ] ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_mlp_decoder_layer(self): train_main( [ diff --git a/MaxText/tests/standalone_dl_ckpt_test.py b/MaxText/tests/standalone_dl_ckpt_test.py index 1bd774946..652b652e1 100644 --- a/MaxText/tests/standalone_dl_ckpt_test.py +++ b/MaxText/tests/standalone_dl_ckpt_test.py @@ -34,7 +34,7 @@ def _get_random_test_name(self, test_name): random_run_name = test_name + date_time + random_string return random_run_name - @pytest.mark.tpu + @pytest.mark.tpu_only def test_standalone_dataloader(self): random_run_name = self._get_random_test_name("standalone_dataloader") sdl_main( @@ -50,7 +50,7 @@ def test_standalone_dataloader(self): ) ) # need to pass relative path to tokenizer - @pytest.mark.tpu + @pytest.mark.tpu_only def test_standalone_checkpointer(self): random_run_name = self._get_random_test_name("standalone_checkpointer") # checkpoint at 50 diff --git a/MaxText/tests/tokenizer_test.py b/MaxText/tests/tokenizer_test.py index c5222f0de..a64888e42 100644 --- a/MaxText/tests/tokenizer_test.py +++ b/MaxText/tests/tokenizer_test.py @@ -58,12 +58,12 @@ def tearDownClass(cls): os.remove(cls.tokenizer_path) @pytest.mark.skip(reason="mohitkhatwani@ will fix this") - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tokenize(self): text = "This is a test" self.assertTrue(np.array_equal(self.source_tokenizer.encode(text).numpy(), self.test_tokenizer.encode(text).numpy())) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_detokenize(self): tokens = [66, 12, 10, 698] self.assertEqual(np.asarray(self.source_tokenizer.decode(tokens)), np.asarray(self.test_tokenizer.decode(tokens))) @@ -86,13 +86,13 @@ def setUpClass(cls): train_ds_builder = tfds.builder(dataset_name) cls.dataset = train_ds_builder.as_dataset(split="train", read_config=read_config, shuffle_files=True) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_tokenize(self): text = "This is a test" tokens = [2028, 374, 264, 1296] self.assertTrue(np.array_equal(self.source_tokenizer.encode(text), tokens)) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_detokenize(self): tokens = [2028, 374, 264, 1296] text = "This is a test" diff --git a/MaxText/tests/train_base_cudnn_flash_te.py b/MaxText/tests/train_base_cudnn_flash_te.py new file mode 100644 index 000000000..a693d2f75 --- /dev/null +++ b/MaxText/tests/train_base_cudnn_flash_te.py @@ -0,0 +1,46 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for train.py with flash attention on GPU""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config on GPU with flash attention""" + + @pytest.mark.gpu_only + def test_cudnn_flash_te(self): + train_main( + [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "attention=cudnn_flash_te", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/train_base_test.py b/MaxText/tests/train_base_test.py new file mode 100644 index 000000000..c5bc1080a --- /dev/null +++ b/MaxText/tests/train_base_test.py @@ -0,0 +1,50 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for train.py with TFDS c4""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + train_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + train_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/train_compile_test.py b/MaxText/tests/train_compile_test.py index 87977079c..880a7a30c 100644 --- a/MaxText/tests/train_compile_test.py +++ b/MaxText/tests/train_compile_test.py @@ -24,7 +24,7 @@ class TrainCompile(unittest.TestCase): """Tests for the Ahead of Time Compilation functionality, train_compile.py""" - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v4(self): compiled_trainstep_file = "/tmp/test_compiled_v4.pickle" train_compile_main( @@ -40,7 +40,7 @@ def test_save_compiled_v4(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v5e(self): compiled_trainstep_file = "/tmp/test_compiled_v5e.pickle" train_compile_main( @@ -79,7 +79,7 @@ def test_minimal_offloaded_v5e(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v5p_two_slices(self): compiled_trainstep_file = "/tmp/test_compiled_v5p_two_slices.pickle" train_compile_main( @@ -97,7 +97,7 @@ def test_save_compiled_v5p_two_slices(self): # TODO (b/374764692) : Enable when v6e AOT test when stable Jax supports v6e AOT. @pytest.mark.skip(reason="Enable when downstream v6e AOT support reaches stable Jax.") - @pytest.mark.tpu + @pytest.mark.tpu_only def test_save_compiled_v6e(self): compiled_trainstep_file = "/tmp/test_compiled_v6e.pickle" train_compile_main( @@ -113,7 +113,7 @@ def test_save_compiled_v6e(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_sequence_parallelism(self): compiled_trainstep_file = "/tmp/test_compiled.pickle" train_compile_main( @@ -131,7 +131,7 @@ def test_sequence_parallelism(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_save_dot_except_mlpwi(self): compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlpwi.pickle" train_compile_main( @@ -153,7 +153,7 @@ def test_remat_save_dot_except_mlpwi(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_save_dot_except_mlp(self): compiled_trainstep_file = "/tmp/test_remat_save_dot_except_mlp.pickle" train_compile_main( @@ -175,7 +175,7 @@ def test_remat_save_dot_except_mlp(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_save_qkv_proj(self): compiled_trainstep_file = "/tmp/test_remat_save_qkv_proj.pickle" train_compile_main( @@ -197,7 +197,7 @@ def test_remat_save_qkv_proj(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_remat_full(self): compiled_trainstep_file = "/tmp/test_remat_full.pickle" train_compile_main( @@ -219,7 +219,7 @@ def test_remat_full(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_custom_64x4_mesh(self): compiled_trainstep_file = "/tmp/test_custom_64x4_mesh.pickle" train_compile_main( @@ -241,7 +241,7 @@ def test_custom_64x4_mesh(self): # TODO (b/376470419) : Enable when AOT test work with host offloading. @pytest.mark.skip(reason="Enable when AOT test work with host offloading.") - @pytest.mark.tpu + @pytest.mark.tpu_only def test_llama3_1_70b_opt_offload(self): compiled_trainstep_file = "/tmp/test_llama3_1_70b_opt_offload.pickle" train_compile_main( @@ -259,7 +259,7 @@ def test_llama3_1_70b_opt_offload(self): ) ) - @pytest.mark.tpu + @pytest.mark.tpu_only def test_custom_32x8_mesh(self): compiled_trainstep_file = "/tmp/test_custom_32x8_mesh.pickle" train_compile_main( diff --git a/MaxText/tests/train_dropout_test.py b/MaxText/tests/train_dropout_test.py new file mode 100644 index 000000000..06c5b43a2 --- /dev/null +++ b/MaxText/tests/train_dropout_test.py @@ -0,0 +1,53 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for train.py with dropout""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config with dropout""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "max_target_length=128", + "per_device_batch_size=1", + "dropout_rate=0.02", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + train_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + train_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/train_fp8_test.py b/MaxText/tests/train_fp8_test.py new file mode 100644 index 000000000..c1492cc29 --- /dev/null +++ b/MaxText/tests/train_fp8_test.py @@ -0,0 +1,51 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for fp8 training""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config with fp8""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "quantization=fp8", + "steps=2", + "enable_checkpointing=False", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + train_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + train_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/train_hf_input_pipeline_test.py b/MaxText/tests/train_hf_input_pipeline_test.py new file mode 100644 index 000000000..8c2d72934 --- /dev/null +++ b/MaxText/tests/train_hf_input_pipeline_test.py @@ -0,0 +1,52 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for train.py with TFDS c4, using HF""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config using HF input pipeline""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + "steps=2", + "enable_checkpointing=False", + "dataset_type=hf", + "hf_path=parquet", + r"hf_train_files=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet", + r"tokenizer_path=google-t5/t5-large", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + train_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + train_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/train_int8_test.py b/MaxText/tests/train_int8_test.py new file mode 100644 index 000000000..b00c4df3a --- /dev/null +++ b/MaxText/tests/train_int8_test.py @@ -0,0 +1,51 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for int8 training""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config with int8""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "quantization=int8", + "steps=2", + "enable_checkpointing=False", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + train_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + train_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/train_pdb_lt_1_test.py b/MaxText/tests/train_pdb_lt_1_test.py new file mode 100644 index 000000000..2cb557185 --- /dev/null +++ b/MaxText/tests/train_pdb_lt_1_test.py @@ -0,0 +1,52 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for train.py with per_device_batch_size < 1""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config with per_device_batch_size < 1""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "per_device_batch_size=0.25", + "ici_tensor_parallelism=4", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + train_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + train_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main() diff --git a/MaxText/tests/train_synthetic_data_test.py b/MaxText/tests/train_synthetic_data_test.py new file mode 100644 index 000000000..dbbfb6e48 --- /dev/null +++ b/MaxText/tests/train_synthetic_data_test.py @@ -0,0 +1,51 @@ +""" +Copyright 2023 Google LLC + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +"""Short test for train.py with synthetic dataset""" +import os +import unittest +import pytest +from train import main as train_main +from absl.testing import absltest + + +class Train(unittest.TestCase): + """Tests base config with synthtic dataset""" + + # Shared parameters + CONFIG = [ + None, + "configs/base.yml", + r"base_output_directory=gs://runner-maxtext-logs", + "run_name=runner_test", + r"dataset_path=gs://maxtext-dataset", + "steps=2", + "enable_checkpointing=False", + "dataset_type=synthetic", + r"tokenizer_path=../assets/tokenizer.llama2", + ] + + @pytest.mark.tpu_only + def test_default_config(self): + train_main(Train.CONFIG) + + @pytest.mark.gpu_only + def test_default_config_dot_product(self): + train_main(Train.CONFIG + ["attention=dot_product"]) + + +if __name__ == "__main__": + absltest.main()