Skip to content

Initial Llama 3.1 405B support (#108) #77

Initial Llama 3.1 405B support (#108)

Initial Llama 3.1 405B support (#108) #77

Workflow file for this run

# End-to-end test workflow. It is triggered by pushes to `main`, by pull
# requests, and by a daily schedule.
name: E2E tests
on:
  push:
    branches:
      - main
  pull_request:
  schedule:
    # Cron expressions are evaluated in UTC: 08:00 UTC is midnight PST (UTC-8).
    - cron: "0 8 * * *" # Run daily at 12AM PST (adjusted for UTC)
jobs:
  # Launches one workload per model with `tp run`, then exposes each generated
  # workload name and the shared artifact directory as job outputs so the
  # per-model check jobs further down can consume them.
  tp-run:
    name: Submit workloads
    runs-on: ubuntu-22.04
    env:
      # Unique per job, run and attempt, so re-runs do not share a directory.
      ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
    outputs:
      llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
      mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
      artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
    steps:
      # Logs the artifact directory and republishes it as a step output, which
      # the `artifact-dir` job output above reads.
      - name: Record artifact dir
        id: artifacts
        run: |
          echo "Artifact dir: $ARTIFACT_DIR"
          echo "artifact_dir=$ARTIFACT_DIR" >> "$GITHUB_OUTPUT"

      # Removes preinstalled toolchains from the runner.
      # NOTE(review): presumably needed to free disk space for the steps
      # below; confirm before dropping.
      - name: Maximize build space
        uses: AdityaGarg8/remove-unwanted-software@v4.1
        with:
          remove-dotnet: 'true'
          remove-android: 'true'
          remove-haskell: 'true'
          remove-codeql: 'true'

      - uses: actions/checkout@v4

      # Repository-local composite action; it receives the GCP/XPK/TPU
      # settings, the artifact directory and the service-account key.
      - uses: ./.github/actions/e2e-setup
        with:
          gcp_project: ${{ vars.GCP_PROJECT }}
          gcp_zone: ${{ vars.GCP_ZONE }}
          xpk_cluster_name: ${{ vars.XPK_CLUSTER_NAME }}
          tpu_type: ${{ vars.TPU_TYPE }}
          artifact_dir: ${{ env.ARTIFACT_DIR }}
          gcp_sa_key: ${{ secrets.GCP_SA_KEY }}

      # Generates a workload name, publishes it as the step output `name`,
      # then starts the Llama 3.0 8B training run under that name.
      - name: Run Llama 3.0 8B
        id: run-llama-3-8b
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          # Environment values are always strings, so write them as such.
          XLA_IR_DEBUG: '1'
          XLA_HLO_DEBUG: '1'
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run \
            --name "$name" \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            global_batch_size=8 \
            mesh.fsdp=4 \
            dataset_config_name=wikitext-2-raw-v1 \
            profile_step=3 \
            max_steps=15

      # Same flow as the Llama step, for Mixtral 8x7B. The run overrides
      # `model.num_hidden_layers` to 16.
      # NOTE(review): presumably a reduced layer count so the model fits the
      # test TPU slice; confirm against the model config.
      - name: Run Mixtral 8x7B
        id: run-mixtral-8x7b
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          XLA_IR_DEBUG: '1'
          XLA_HLO_DEBUG: '1'
        run: |
          name=$(e2e_testing/gen_name.py mixtral-8x7b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run \
            --name "$name" \
            torchprime/torch_xla_models/train.py \
            model=mixtral-8x7b \
            model.num_hidden_layers=16 \
            global_batch_size=8 \
            mesh.fsdp=4 \
            dataset_config_name=wikitext-2-raw-v1 \
            profile_step=3 \
            max_steps=15

  # Calls the reusable check workflow for the Llama workload, passing the
  # workload name and artifact directory produced by `tp-run`.
  llama-3-8b:
    name: Llama 3.0 8B
    needs: tp-run
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
      jobset_name: ${{ needs.tp-run.outputs.llama-3-8b-name }}
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
    secrets: inherit

  # Calls the reusable check workflow for the Mixtral workload, passing the
  # workload name and artifact directory produced by `tp-run`.
  mixtral-8x7b:
    name: Mixtral 8x7B
    needs: tp-run
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
      jobset_name: ${{ needs.tp-run.outputs.mixtral-8x7b-name }}
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
    secrets: inherit