e2e_test.yml
name: E2E tests

on:
  push:
    branches:
      - main
  pull_request:
  schedule:
    - cron: "0 8 * * *" # Run daily at 8AM UTC (12AM PST)
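
# tp-run submits both model training workloads to the XPK cluster and exports
# the generated jobset names and artifact directory for the downstream check jobs.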
jobs:
  tp-run:
    name: Submit workloads
    runs-on: ubuntu-22.04
    env:
      ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
    outputs:
      llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
      mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
      artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
    steps:
      - name: Record artifact dir
        id: artifacts
        run: |
          echo "Artifact dir: $ARTIFACT_DIR"
          echo "artifact_dir=$ARTIFACT_DIR" >> "$GITHUB_OUTPUT"
      - name: Maximize build space
        uses: AdityaGarg8/remove-unwanted-software@v4.1
        with:
          remove-dotnet: 'true'
          remove-android: 'true'
          remove-haskell: 'true'
          remove-codeql: 'true'
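      # Check out the repo, then run the shared e2e-setup composite action;
      # its inputs suggest it authenticates to GCP and points tooling at the
      # target XPK cluster and TPU type.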
      - uses: actions/checkout@v4
      - uses: ./.github/actions/e2e-setup
        with:
          gcp_project: ${{ vars.GCP_PROJECT }}
          gcp_zone: ${{ vars.GCP_ZONE }}
          xpk_cluster_name: ${{ vars.XPK_CLUSTER_NAME }}
          tpu_type: ${{ vars.TPU_TYPE }}
          artifact_dir: ${{ env.ARTIFACT_DIR }}
          gcp_sa_key: ${{ secrets.GCP_SA_KEY }}
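      # XLA_IR_DEBUG/XLA_HLO_DEBUG ask torch_xla to record source metadata in
      # the IR/HLO it emits, so the profile captured at profile_step is easier
      # to attribute back to model code.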
      - name: Run Llama 3.0 8B
        id: run-llama-3-8b
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run \
            --name "$name" \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            global_batch_size=8 \
            mesh.fsdp=4 \
            dataset_config_name=wikitext-2-raw-v1 \
            profile_step=3 \
            max_steps=15
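      # Same flow for Mixtral 8x7B; model.num_hidden_layers=16 truncates the
      # model, presumably to keep this smoke test small enough for the test slice.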
      - name: Run Mixtral 8x7B
        id: run-mixtral-8x7b
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py mixtral-8x7b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run \
            --name "$name" \
            torchprime/torch_xla_models/train.py \
            model=mixtral-8x7b \
            model.num_hidden_layers=16 \
            global_batch_size=8 \
            mesh.fsdp=4 \
            dataset_config_name=wikitext-2-raw-v1 \
            profile_step=3 \
            max_steps=15
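
  # Each check job waits for tp-run, then validates its jobset's results out
  # of the shared artifact directory via the reusable E2E check workflow.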
  llama-3-8b:
    name: Llama 3.0 8B
    needs: tp-run
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
      jobset_name: ${{ needs.tp-run.outputs.llama-3-8b-name }}
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
    secrets: inherit

  mixtral-8x7b:
    name: Mixtral 8x7B
    needs: tp-run
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
      jobset_name: ${{ needs.tp-run.outputs.mixtral-8x7b-name }}
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
    secrets: inherit