name: E2E tests
on:
  push:
    branches:
      - main
  pull_request:
  schedule:
    - cron: "0 8 * * *" # Run daily at 12 AM PST (08:00 UTC)
  workflow_dispatch:
    inputs:
      docker_url:
        description: If specified, use this PyTorch/XLA base docker image URL instead of the pin.
        required: false
        type: string
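# A manual run can point the tests at a custom base image. A hypothetical
# invocation via the GitHub CLI (the image URL below is illustrative, not a
# real pin):
#
#   gh workflow run "E2E tests" -f docker_url=us-docker.pkg.dev/my-project/images/xla:nightly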
jobs:
  tp-run:
    name: Submit workloads
    runs-on: ubuntu-24.04
    env:
      ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
    outputs:
      llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
      llama-3-8b-pure-mlp-name: ${{ steps.run-llama-3-8b-pure-mlp.outputs.name }}
      llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
      llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
      llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
      llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
      llama-3-8b-sft-name: ${{ steps.run-llama-3-8b-sft.outputs.name }}
      llama-3-8b-ddp-fsdp-name: ${{ steps.run-llama-3-8b-ddp-fsdp.outputs.name }}
      llama-3-8b-fsdp-cp-name: ${{ steps.run-llama-3-8b-fsdp-cp.outputs.name }}
      mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
      ds-v3-shallow-name: ${{ steps.run-ds-v3-shallow.outputs.name }}
      artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
    steps:
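      # ARTIFACT_DIR is exported as a step output here and re-exported as the
      # job-level artifact-dir output above, so the validate job can locate
      # the profiles and metrics each workload writes.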
      - name: Record artifact dir
        id: artifacts
        run: |
          echo "Artifact dir: $ARTIFACT_DIR"
          echo "artifact_dir=$ARTIFACT_DIR" >> "$GITHUB_OUTPUT"
      - name: Maximize build space
        uses: AdityaGarg8/[email protected]
        with:
          remove-dotnet: 'true'
          remove-android: 'true'
          remove-haskell: 'true'
          remove-codeql: 'true'
      - uses: actions/checkout@v4
      - uses: ./.github/actions/e2e-setup
        with:
          gcp_project: ${{ vars.GCP_PROJECT }}
          gcp_zone: ${{ vars.GCP_ZONE }}
          xpk_cluster_name: ${{ vars.XPK_CLUSTER_NAME }}
          tpu_type: ${{ vars.TPU_TYPE }}
          artifact_dir: ${{ env.ARTIFACT_DIR }}
          gcp_sa_key: ${{ secrets.GCP_SA_KEY }}
      - name: Setup Docker URL option
        id: docker-url-option
        run: |
          if [ -n "${{ github.event.inputs.docker_url }}" ]; then
            echo "value=--base-docker-url ${{ github.event.inputs.docker_url }}" >> "$GITHUB_OUTPUT"
          else
            echo "value=" >> "$GITHUB_OUTPUT"
          fi
      # Launch training workloads.
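      # Each workload step follows the same recipe: generate a unique jobset
      # name with e2e_testing/gen_name.py, export it as a step output for the
      # validate job, then submit the run via `tp run` with Hydra-style config
      # overrides. XLA_IR_DEBUG and XLA_HLO_DEBUG enrich the captured profiles
      # with source-location metadata.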
      - name: Run Llama 3.0 8B
        id: run-llama-3-8b
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
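      # The (@assume_pure) variant below additionally marks LlamaMLP and
      # EinsumLinear as pure modules via model.pure_modules, which should let
      # torch_xla avoid re-tracing those modules on every training step.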
      - name: Run Llama 3.0 8B (@assume_pure)
        id: run-llama-3-8b-pure-mlp
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-pure-mlp)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3 \
            model.pure_modules=[LlamaMLP,EinsumLinear]
      - name: Run Llama 3.1 8B (Splash Attention)
        id: run-llama-3_1-8b-SplashAttention
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3dot1-8b-sa)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3.1-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \
            model.attention_kernel=splash_attention \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Llama 3.1 8B (Scan + Offload)
        id: run-llama-3_1-8b-scan-offload
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3dot1-8b-scan-offload)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3.1-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \
            model/remat=llama-scan-offload \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Llama 3.0 8B (2D sharding)
        id: run-llama-3-8b-2d
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-2d)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model/sharding=llama-fsdp-tp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=2 \
            ici_mesh.tensor=2 \
            profile_start_step=3
      - name: Run Llama 3.0 8B (fsdp + cp) # TODO: switch to ici_mesh.context=2 and ici_mesh.fsdp=2 once the e2e test is debugged
        id: run-llama-3-8b-fsdp-cp
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-fsdp-cp)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b-cp \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model/sharding=llama-fsdp-tp-cp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=4 \
            task.max_steps=15 \
            task.lr_scheduler.type=constant \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Mixtral 8x7B
        id: run-mixtral-8x7b
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py mixtral-8x7b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=mixtral-8x7b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/mixtral-8x7b-v0.1/ \
            model.num_hidden_layers=16 \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Llama 3.0 8B (2 slice)
        id: run-llama-3-8b-2-slice
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-2-slice)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            --num-slices 2 \
            torchprime/torch_xla_models/train.py \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model=llama-3-8b \
            model/sharding=llama-fsdp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=16 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            dcn_mesh.fsdp=2 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Llama 3.0 8B SFT
        id: run-llama-3-8b-sft
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-sft)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            --config-name llama-3-8b-sft-w-gsm8k \
            model.pretrained_model=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/gsm8k \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            ici_mesh.fsdp=4 \
            task.max_steps=50 \
            task.convert_to_safetensors=False \
            profile_start_step=3
      - name: Run Llama 3.0 8B (ddp + fsdp)
        id: run-llama-3-8b-ddp-fsdp
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-ddp-fsdp)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            --num-slices 2 \
            torchprime/torch_xla_models/train.py \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model=llama-3-8b \
            model/sharding=llama-fsdp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=16 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            dcn_mesh.data=2 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Deepseek v3 Shallow
        id: run-ds-v3-shallow
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py ds-v3-shallow)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=deepseek-v3 \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/deepseek-v3-tokenizer \
            model.num_hidden_layers=2 \
            model.first_k_dense_replace=1 \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            dataset.block_size=512 \
            task=train \
            task.lr_scheduler.type=constant \
            task.global_batch_size=4 \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=7
  # Load reference step times
  load-benchmarks:
    name: Load reference step times
    runs-on: ubuntu-24.04
    outputs:
      matrix: ${{ steps.load.outputs.matrix }}
    steps:
      - uses: actions/checkout@v4
      - name: Load step_time_bounds.yaml
        id: load
        run: |
          # Extract the benchmarks as a JSON array of objects.
          MATRIX=$(yq -o=json -I=0 '.benchmarks | to_entries | map({
            "benchmark": .key,
            "name": .value.name,
            "lower_bound": .value.step_time_lower_bound,
            "upper_bound": .value.step_time_upper_bound,
            "target_loss": .value.target_loss,
            "loss_tolerance": .value.loss_tolerance
          })' e2e_testing/step_time_bounds.yaml)
          echo "Benchmark matrix JSON: $MATRIX"
          echo "matrix=$MATRIX" >> "$GITHUB_OUTPUT"
  # Validate the results of the workloads
  #
  # Each workload has a step time lower bound and upper bound.
  # The bounds and confidence intervals are programmatically derived from
  # historical E2E test results. To regenerate the bounds, you can run
  # `e2e_testing/update_step_time.py`.
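  #
  # For reference, the yq query in load-benchmarks assumes entries in
  # step_time_bounds.yaml shaped roughly like this (values illustrative,
  # not real bounds):
  #
  #   benchmarks:
  #     llama-3-8b:
  #       name: Llama 3.0 8B
  #       step_time_lower_bound: 2.5
  #       step_time_upper_bound: 3.0
  #       target_loss: 6.0     # optional
  #       loss_tolerance: 0.1  # optional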
  validate:
    name: ${{ matrix.config.name }}
    needs: [tp-run, load-benchmarks]
    strategy:
      fail-fast: false
      matrix:
        config: ${{ fromJson(needs.load-benchmarks.outputs.matrix) }}
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
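      # GitHub expressions have no switch statement, so the chained
      # `cond && value ||` pattern below selects the jobset name whose key
      # matches matrix.config.benchmark. It relies on every name output
      # being a non-empty string; an empty value would fall through to the
      # next clause.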
      jobset_name: >-
        ${{
          matrix.config.benchmark == 'llama-3-8b' && needs.tp-run.outputs.llama-3-8b-name ||
          matrix.config.benchmark == 'llama-3_1-8b-sa' && needs.tp-run.outputs.llama-3_1-8b-sa-name ||
          matrix.config.benchmark == 'llama-3_1-8b-scan-offload' && needs.tp-run.outputs.llama-3_1-8b-scan-offload-name ||
          matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name ||
          matrix.config.benchmark == 'mixtral-8x7b' && needs.tp-run.outputs.mixtral-8x7b-name ||
          matrix.config.benchmark == 'llama-3-8b-pure-mlp' && needs.tp-run.outputs.llama-3-8b-pure-mlp-name ||
          matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name ||
          matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name ||
          matrix.config.benchmark == 'llama-3-8b-ddp-fsdp' && needs.tp-run.outputs.llama-3-8b-ddp-fsdp-name ||
          matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name ||
          matrix.config.benchmark == 'ds-v3-shallow' && needs.tp-run.outputs.ds-v3-shallow-name
        }}
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
      step_time_lower_bound: ${{ matrix.config.lower_bound }}
      step_time_upper_bound: ${{ matrix.config.upper_bound }}
      # Optional loss-validation settings. When a benchmark does not define
      # them in the matrix, these fields would expand to an empty string and
      # cause a workflow syntax error. Default to 0 so the reusable workflow
      # can skip the loss check when no value is provided.
      target_loss: ${{ matrix.config.target_loss || 0 }}
      loss_tolerance: ${{ matrix.config.loss_tolerance || 0 }}
    secrets: inherit