CI - Wheel Tests (Continuous) #20
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# CI - Wheel Tests (Continuous)
#
# This workflow builds JAX artifacts and runs CPU/CUDA tests.
#
# It orchestrates the following:
# 1. build-jaxlib-artifact: Calls the `build_artifacts.yml` workflow to build jaxlib and
#    uploads it to a GCS bucket.
# 2. run-pytest-cpu: Calls the `pytest_cpu.yml` workflow to download the jaxlib wheel that was built
#    in the previous step and runs CPU tests.
# 3. build-cuda-artifacts: Calls the `build_artifacts.yml` workflow to build CUDA artifacts and
#    uploads them to a GCS bucket.
# 4. run-pytest-cuda: Calls the `pytest_cuda.yml` workflow to download the jaxlib and CUDA artifacts
#    that were built in the previous steps and runs the CUDA tests.
# Display name of the workflow; this is also the value of `github.workflow`,
# which the concurrency group below is keyed on.
name: CI - Wheel Tests (Continuous)

# This workflow is started by the scheduler only.
on:
  schedule:
    - cron: "0 */2 * * *"  # Run once every 2 hours

concurrency:
  # One group per workflow and ref. `github.head_ref` is only populated for
  # pull request events, so scheduled runs fall back to `github.ref`.
  group: ${{ github.workflow }}-${{ github.head_ref || github.ref }}
  # Cancel a superseded in-progress run, except on release branches and on
  # main. `github.ref` is the fully-formed ref (e.g. `refs/heads/main`), so it
  # must be compared against `refs/heads/main`; comparing against the bare
  # branch name `main` is always "not equal" and would cancel runs on main too.
  cancel-in-progress: ${{ !contains(github.ref, 'release/') && github.ref != 'refs/heads/main' }}
jobs:
  # Builds the jaxlib wheel once per CPU test platform via the reusable
  # `build_artifacts.yml` workflow, with upload to GCS enabled.
  build-jaxlib-artifact:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the
        # run-pytest-cpu job
        runner: ["linux-x86-n2-16", "linux-arm64-t2a-48", "windows-x86-n2-16"]
        artifact: ["jaxlib"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      # Keyed on workflow name, run number and run attempt, so each attempt
      # of each run gets its own upload location.
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Builds the CUDA plugin and PJRT artifacts via the same reusable workflow,
  # with upload to GCS enabled.
  build-cuda-artifacts:
    uses: ./.github/workflows/build_artifacts.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the
        # run-pytest-cuda job below
        runner: ["linux-x86-n2-16"]
        artifact: ["jax-cuda-plugin", "jax-cuda-pjrt"]
        python: ["3.10"]
    with:
      runner: ${{ matrix.runner }}
      artifact: ${{ matrix.artifact }}
      python: ${{ matrix.python }}
      clone_main_xla: 1
      upload_artifacts_to_gcs: true
      # Same URI template as build-jaxlib-artifact, so all artifacts of one
      # run attempt land in the same location.
      gcs_upload_uri: 'gs://general-ml-ci-transient/jax-github-actions/jax/${{ github.workflow }}/${{ github.run_number }}/${{ github.run_attempt }}'

  # Runs the CPU test suite through `pytest_cpu.yml` once the jaxlib build
  # has finished, with and without x64 enabled.
  run-pytest-cpu:
    needs: build-jaxlib-artifact
    uses: ./.github/workflows/pytest_cpu.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Runner OS and Python values need to match the matrix strategy in the
        # build-jaxlib-artifact job above
        runner: ["linux-x86-n2-64", "linux-arm64-t2a-48", "windows-x86-n2-64"]
        python: ["3.10"]
        enable-x64: [1, 0]
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      enable-x64: ${{ matrix.enable-x64 }}
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}

  # Runs the CUDA test suite through `pytest_cuda.yml` once both artifact
  # build jobs have finished.
  run-pytest-cuda:
    needs: [build-jaxlib-artifact, build-cuda-artifacts]
    uses: ./.github/workflows/pytest_cuda.yml
    strategy:
      fail-fast: false  # don't cancel all jobs on failure
      matrix:
        # Python values need to match the matrix strategy in the artifact
        # build jobs above
        runner: ["linux-x86-g2-48-l4-4gpu", "linux-x86-a3-8g-h100-8gpu"]
        python: ["3.10"]
        cuda: ["12.3", "12.1"]
        enable-x64: [1, 0]
        exclude:
          # Run only a single configuration on H100 to save resources:
          # dropping CUDA 12.1 and enable-x64=0 leaves CUDA 12.3 with
          # enable-x64=1 as the only H100 combination.
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            cuda: "12.1"
          - runner: "linux-x86-a3-8g-h100-8gpu"
            python: "3.10"
            enable-x64: 0
    with:
      runner: ${{ matrix.runner }}
      python: ${{ matrix.python }}
      cuda: ${{ matrix.cuda }}
      enable-x64: ${{ matrix.enable-x64 }}
      # GCS upload URI is the same for both artifact build jobs
      gcs_download_uri: ${{ needs.build-jaxlib-artifact.outputs.gcs_upload_uri }}