CI - Wheel Tests (Continuous) #127
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CI - Wheel Tests (Continuous) | |
# | |
# This workflow builds JAX artifacts and runs CPU/CUDA tests. | |
# | |
# It orchestrates the following: | |
# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and | |
# uploads it to a GCS bucket. | |
# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow which downloads the jaxlib wheel | |
# that was built in the previous step and runs CPU tests. | |
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and | |
# uploads them to a GCS bucket. | |
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow which downloads the jaxlib and CUDA | |
# artifacts that were built in the previous steps and runs the CUDA tests. | |
# 5. run-bazel-test-cuda: Calls the `bazel_cuda_non_rbe.yml` workflow which downloads the jaxlib | |
# and CUDA artifacts that were built in the previous steps and runs the | |
# CUDA tests using Bazel. | |
name: CI - Wheel Tests (Continuous)

# Triggered only by the schedule below.
on:
  schedule:
    - cron: "0 */2 * * *"  # Run once every 2 hours

concurrency:
  # One concurrency group per workflow and ref; `github.head_ref` is preferred
  # when it is set, otherwise the group falls back to `github.ref`.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Do not cancel in-progress runs on main or on release branches.
  # `github.ref` is the fully-formed ref (e.g. `refs/heads/main`), so it must be
  # compared against that form; comparing against the bare branch name `main`
  # never matches and would cancel in-progress runs on main.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
jobs:
  # Builds the jaxlib wheel on every runner in the matrix and uploads it to GCS.
  build-jaxlib-artifact:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the CPU tests job
        runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
        artifact: ["jaxlib"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Builds the CUDA plugin and PJRT wheels and uploads them to GCS.
  build-cuda-artifacts:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the CUDA tests job below
        runner: ["linux-x86-n2-16"]
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Runs the CPU pytest suite against the jaxlib wheel built above.
  run-pytest-cpu:
    needs: build-jaxlib-artifact
    uses: ./.github/workflows/pytest_cpu.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the
        # build-jaxlib-artifact job above
        runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
        python: ["3.10"]
        enable-x64: [1, 0]
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the CUDA pytest suite against the jaxlib and CUDA wheels built above.
  run-pytest-cuda:
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/pytest_cuda.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the artifact build jobs above
        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
        python: ["3.10"]
        cuda: ["12.3", "12.1"]
        enable-x64: [1, 0]
        exclude:
          # Run only a single configuration on H100 to save resources
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            cuda: "12.1"
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            enable-x64: 0
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      cuda: ${{ matrix.cuda }}
      enable-x64: ${{ matrix.enable-x64 }}
      # GCS upload URI is the same for both artifact build jobs
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the CUDA tests through Bazel against the jaxlib and CUDA wheels built above.
  run-bazel-test-cuda:
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/bazel_cuda_non_rbe.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the build artifacts job above
        runner: ["linux-x86-g2-48-l4-4gpu"]
        python: ["3.10"]
        enable-x64: [1, 0]
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      # GCS upload URI is the same for both artifact build jobs
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}