Move tests from UnitTests.yml to test files.
This change moves the train and decode tests out into separate files
so they can be run with pytest rather than being invoked explicitly
from UnitTests.yml. This will allow us, in a subsequent PR, to
parallelize the execution of these tests using pytest's
parallelization support (a sketch of such a step follows below).
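
A minimal sketch of that follow-up, assuming pytest-xdist is added to the dependency image (the plugin, its -n flag, and the step name are assumptions, not part of this commit):

    - name: Unit Tests (parallel, hypothetical)
      # Hypothetical step: pytest-xdist's "-n auto" spawns one worker per CPU core,
      # while the existing marker expression keeps the per-device test selection.
      run: cd MaxText; python3 -m pytest tests -n auto -m "${{ matrix.device.pytest_marker }} and not integration_test"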

Several end-to-end tests are still being invoked from UnitTests.yml.
This does not seem like the right place, and will hopefully be
addressed in the future.
shralex committed Dec 21, 2024
1 parent 6ec3368 commit 84bc33f
Showing 30 changed files with 825 additions and 137 deletions.
94 changes: 12 additions & 82 deletions .github/workflows/UnitTests.yml
@@ -23,8 +23,8 @@ on:
branches: [ "main" ]
workflow_dispatch:
schedule:
# Run the job every 2 hours
- cron: '0 */2 * * *'
# Run the job every 6 hours
- cron: '0 */6 * * *'

jobs:
build_and_upload_image:
@@ -46,7 +46,7 @@ jobs:
run: docker system prune --all --force
- name: Build an image
run: |
bash docker_build_dependency_image.sh MODE=${{ matrix.device.mode }} DEVICE=${{ matrix.device.type }}
bash docker_build_dependency_image.sh MODE=${{ matrix.device.mode }} DEVICE=${{ matrix.device.type }}
- name: Tag the image
run: |
docker tag maxtext_base_image gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }}
@@ -62,17 +62,15 @@ jobs:
device:
- type: tpu
name: v4-8
attention: autoselected
pytest_marker: ''
pytest_marker: 'not gpu_only' # exclude tests marked gpu_only
container_env:
XLA_PYTHON_CLIENT_MEM_FRACTION: 0.75
TF_FORCE_GPU_ALLOW_GROWTH: false
container_resource_option: "--privileged"
- type: gpu
name: a100-40gb-4
image_suffix: gpu_jax_pinned
attention: dot_product
pytest_marker: -m 'not tpu'
pytest_marker: 'not tpu_only' # exclude tests marked tpu_only
container_env:
XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65
TF_FORCE_GPU_ALLOW_GROWTH: true
@@ -81,7 +79,7 @@ jobs:
runs-on: ["self-hosted", "${{ matrix.device.type }}", "${{ matrix.device.name }}"]
container:
image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:${{ matrix.device.type }}
volumes:
volumes:
- /home/runner/actions-runner/_work/maxtext/maxtext:/deps
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: ${{ matrix.device.container_env.XLA_PYTHON_CLIENT_MEM_FRACTION }}
@@ -91,87 +89,19 @@ jobs:
- uses: actions/checkout@v4
- name: Test gsutil installation
run: which gsutil >/dev/null 2>&1 || { echo >&2 "gsutil is required but not installed. Aborting"; exit 24;}
- name: Test with pytest
run: cd MaxText;python3 -m pytest ${{ matrix.device.pytest_marker }}
- name: Test train.py with TFDS c4
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test train.py with HF c4
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs hf_train_files=gs://maxtext-dataset/hf/c4/c4-train-00000-of-01637.parquet hf_path=parquet dataset_type=hf steps=2 tokenizer_path=google-t5/t5-large attention=${{ matrix.device.attention }} enable_checkpointing=false
- name: Test train.py with synthetic data
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} dataset_type=synthetic
- name: Test train.py with per_device_batch_size < 1
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 per_device_batch_size=0.25 ici_tensor_parallelism=4 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test decode.py
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1
- name: Test int8_decode
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=1 quantization=int8 quantize_kvcache=True
- name: Test decode.py with per_device_batch_size < 1
run: python3 MaxText/decode.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 ici_tensor_parallelism=4 attention=${{ matrix.device.attention }} enable_checkpointing=false max_target_length=128 per_device_batch_size=.25
- name: Test int8_training
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=int8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test fp8_training
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset quantization=fp8 steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }}
- name: Test train.py with dropout
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=${{ matrix.device.attention }} max_target_length=128 per_device_batch_size=1 dropout_rate=0.02
- name: Test generate_param_only_checkpoint
run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -a ${{ matrix.device.attention }}
- name: Test generate_param_only_checkpoint with int8 quantization
run: bash end_to_end/test_generate_param_only_checkpoint.sh -r runner_$(date +%Y-%m-%d-%H-%M-%S) -o gs://runner-maxtext-logs -d gs://maxtext-dataset -i 4 -q int8 -a ${{ matrix.device.attention }}
- name: Test grain checkpoint determinism
run: bash end_to_end/test_checkpointing.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset False grain ${{ matrix.device.attention }}
- name: Test checkpoint compatibility
run: bash end_to_end/test_checkpoint_compatibility.sh runner_$(date +%Y-%m-%d-%H-%M-%S) gs://runner-maxtext-logs gs://maxtext-dataset ${{ matrix.device.attention }}

tpu:
needs: build_and_upload_image
strategy:
fail-fast: false
matrix:
device-type: ["v4-8"]
name: "TPU test (${{ matrix.device-type }})"
runs-on: ["self-hosted", "tpu", "${{ matrix.device-type }}"]
container:
image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu
volumes:
- /home/runner/actions-runner/_work/maxtext/maxtext:/deps
options: "--privileged"
steps:
- uses: actions/checkout@v4
- name: Validate Pedagogical Example, Shmap_collective_matmul
run: python3 pedagogical_examples/shmap_collective_matmul.py

gpu:
needs: build_and_upload_image
strategy:
fail-fast: false
matrix:
device-type: ["a100-40gb-4"]
build-mode: ["pinned"]
name: "GPU test (${{ matrix.device-type }}, ${{ matrix.build-mode }})"
runs-on: ["self-hosted", "gpu", "${{ matrix.device-type }}"]
container:
image: gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu
volumes:
- /home/runner/actions-runner/_work/maxtext/maxtext:/deps
env:
XLA_PYTHON_CLIENT_MEM_FRACTION: 0.65
TF_FORCE_GPU_ALLOW_GROWTH: true
options: "--shm-size 2g --runtime=nvidia --gpus all --privileged"
steps:
- uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Test train.py with flash attention
run: python3 MaxText/train.py MaxText/configs/base.yml run_name=runner_$(date +%Y-%m-%d-%H-%M-%S) base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset steps=2 enable_checkpointing=false attention=cudnn_flash_te
- name: Unit Tests
run: cd MaxText;python3 -m pytest tests -m "${{ matrix.device.pytest_marker }} and not integration_test"
- name: Integration Tests
run: cd MaxText; python3 -m pytest tests/integration_tests -m "${{ matrix.device.pytest_marker }} and integration_test"

clean_up:
if: ${{ always() }}
needs: [common, gpu, tpu]
needs: common
name: "Clean up"
runs-on: ["self-hosted"]
steps:
- name: Delete GPU image
run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:gpu --force-delete-tags --quiet
- name: Delete TPU image
run: gcloud container images delete gcr.io/tpu-prod-env-multipod/maxtext_${{ github.run_id }}:tpu --force-delete-tags --quiet

22 changes: 13 additions & 9 deletions MaxText/decode.py
@@ -21,10 +21,20 @@

import os
import pyconfig
import sys

from typing import Sequence
from absl import app


def main(argv: Sequence[str]) -> None:
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

pyconfig.initialize(argv)
config = pyconfig.config
validate_config(config)
max_utils.print_system_information()

def main(config):
engine = maxengine.MaxEngine(config)
rng = jax.random.PRNGKey(1234)
rng, rng_load_params = jax.random.split(rng)
@@ -71,10 +81,4 @@ def validate_config(config):


if __name__ == "__main__":
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
pyconfig.initialize(sys.argv)
cfg = pyconfig.config
validate_config(cfg)
max_utils.print_system_information()
main(cfg)
app.run(main)
12 changes: 8 additions & 4 deletions MaxText/pytest.ini
@@ -3,11 +3,15 @@
testpaths =
tests
python_files = *_test.py
addopts =
-rf --import-mode=importlib
addopts =
-rf --import-mode=importlib --strict-markers
--ignore=tests/profiler_test.py
--ignore=tests/train_smoke_test.py
--ignore=tests/train_int8_smoke_test.py
--ignore=tests/train_gpu_smoke_test.py
markers =
tpu: marks tests to be run on TPU
markers =
tpu_only: marks tests to be run on TPUs only
gpu_only: marks tests to be run on GPUs only
integration_test: tests exercising larger portions of the system,
including interactions with other systems like GCS,
e.g., end_to_end tests
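
Because addopts now includes --strict-markers, applying a marker that is not registered above fails test collection, which keeps the workflow's marker expressions in sync with this list. An illustrative pair of workflow steps, mirroring the new Unit Tests and Integration Tests steps in UnitTests.yml but with the TPU matrix value 'not gpu_only' spelled out (the concrete steps are only a sketch, not part of this diff):

    - name: Unit Tests (TPU variant, illustrative)
      run: cd MaxText; python3 -m pytest tests -m "not gpu_only and not integration_test"
    - name: Integration Tests (TPU variant, illustrative)
      run: cd MaxText; python3 -m pytest tests/integration_tests -m "not gpu_only and integration_test"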
16 changes: 8 additions & 8 deletions MaxText/tests/attention_test.py
@@ -118,7 +118,7 @@ def get_structured_data(self, dtype):

return lnx, decoder_segment_ids, decoder_positions

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_autoregression(self):
prefill_length = self.cfg.max_prefill_predict_length
decode_total_length = self.cfg.max_target_length
@@ -174,11 +174,11 @@ def test_autoregression(self):
self.assertTrue(mha_full_this_idx.shape == mha_idx.shape)
self.assertTrue(jax.numpy.allclose(mha_full_this_idx, mha_idx, rtol=1e-02, atol=1e-02, equal_nan=False))

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_model_mode_prefill_dtype_float32(self):
self._test_model_mode_prefill_dtype(jnp.float32)

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_model_mode_prefill_dtype_bfloat16(self):
self._test_model_mode_prefill_dtype(jnp.bfloat16)

@@ -224,15 +224,15 @@ def _test_model_mode_prefill_dtype(self, dtype):

self.assertEqual(dtype, mha_prefill.dtype)

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_tpu_kernel_attention_mha(self):
self.tpu_kernel_attention_helper(self.num_kv_heads)

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_tpu_kernel_attention_gqa(self):
self.tpu_kernel_attention_helper(self.num_kv_heads // 2)

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_tpu_kernel_attention_mqa(self):
self.tpu_kernel_attention_helper(1)

@@ -309,7 +309,7 @@ def tpu_kernel_attention_helper(self, num_kv_heads):
jax.numpy.allclose(mha_generic_output, mha_generic_flash_output, rtol=1e-01, atol=1e-01, equal_nan=False)
)

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_dot_product_cache_axis_order(self):
all_axis_orders = [axis_order for axis_order in itertools.permutations(range(4))]
for axis_order in random.choices(all_axis_orders, k=4):
@@ -423,7 +423,7 @@ def _dot_product_attention(
jax.numpy.allclose(attention_w_layout_full_this_idx, attention_w_layout_idx, rtol=rtol, atol=atol, equal_nan=False)
)

@pytest.mark.tpu
@pytest.mark.tpu_only
def test_dot_product_reshape_q(self):
for compute_axis_order in [(0, 1, 2, 3), (0, 2, 1, 3)]:
self._dot_product_attention_reshape_q(
55 changes: 55 additions & 0 deletions MaxText/tests/decode_int8_test.py
@@ -0,0 +1,55 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Short test for decode with int8 quantization"""
import os
import unittest
import pytest
from decode import main as decode_main
from absl.testing import absltest


class Train(unittest.TestCase):
"""Tests decode with int8 quantization"""

# Shared parameters
CONFIG = [
None,
"configs/base.yml",
r"base_output_directory=gs://runner-maxtext-logs",
"run_name=runner_test",
r"dataset_path=gs://maxtext-dataset",
"steps=2",
"enable_checkpointing=False",
"ici_tensor_parallelism=4",
"max_target_length=128",
"per_device_batch_size=1",
"quantization=int8",
"quantize_kvcache=True",
r"tokenizer_path=../assets/tokenizer.llama2",
]

@pytest.mark.tpu_only
def test_default_config(self):
decode_main(Train.CONFIG)

@pytest.mark.gpu_only
def test_default_config_dot_product(self):
decode_main(Train.CONFIG + ["attention=dot_product"])


if __name__ == "__main__":
absltest.main()
53 changes: 53 additions & 0 deletions MaxText/tests/decode_pdb_lt_1_test.py
@@ -0,0 +1,53 @@
"""
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""Short test for decode with per_device_batch_size < 1"""
import os
import unittest
import pytest
from decode import main as decode_main
from absl.testing import absltest


class Train(unittest.TestCase):
"""Tests decode with per_device_batch_size < 1"""

# Shared parameters
CONFIG = [
None,
"configs/base.yml",
r"base_output_directory=gs://runner-maxtext-logs",
"run_name=runner_test",
r"dataset_path=gs://maxtext-dataset",
"steps=2",
"enable_checkpointing=False",
"ici_tensor_parallelism=4",
"max_target_length=128",
"per_device_batch_size=.25",
r"tokenizer_path=../assets/tokenizer.llama2",
]

@pytest.mark.tpu_only
def test_default_config(self):
decode_main(Train.CONFIG)

@pytest.mark.gpu_only
def test_default_config_dot_product(self):
decode_main(Train.CONFIG + ["attention=dot_product"])


if __name__ == "__main__":
absltest.main()