E2E tests #1951
name: E2E tests

on:
  push:
    branches:
      - main
  pull_request:
  schedule:
    - cron: "0 8 * * *" # Run daily at 12AM PST (adjusted for UTC)
  workflow_dispatch:
    inputs:
      docker_url:
        description: If specified, use this PyTorch/XLA base docker image URL instead of the pin.
        required: false
        type: string

jobs:
  tp-run:
    name: Submit workloads
    runs-on: ubuntu-24.04
    env:
      ARTIFACT_DIR: gs://torchprime-e2e-tests/${{ github.job }}/${{ github.run_id }}-${{ github.run_attempt }}
    outputs:
      llama-3-8b-name: ${{ steps.run-llama-3-8b.outputs.name }}
      llama-3-8b-pure-mlp-name: ${{ steps.run-llama-3-8b-pure-mlp.outputs.name }}
      llama-3_1-8b-sa-name: ${{ steps.run-llama-3_1-8b-SplashAttention.outputs.name }}
      llama-3_1-8b-scan-offload-name: ${{ steps.run-llama-3_1-8b-scan-offload.outputs.name }}
      llama-3-8b-2d-name: ${{ steps.run-llama-3-8b-2d.outputs.name }}
      llama-3-8b-2-slice-name: ${{ steps.run-llama-3-8b-2-slice.outputs.name }}
      llama-3-8b-sft-name: ${{ steps.run-llama-3-8b-sft.outputs.name }}
      llama-3-8b-ddp-fsdp-name: ${{ steps.run-llama-3-8b-ddp-fsdp.outputs.name }}
      llama-3-8b-fsdp-cp-name: ${{ steps.run-llama-3-8b-fsdp-cp.outputs.name }}
      mixtral-8x7b-name: ${{ steps.run-mixtral-8x7b.outputs.name }}
      ds-v3-shallow-name: ${{ steps.run-ds-v3-shallow.outputs.name }}
      artifact-dir: ${{ steps.artifacts.outputs.artifact_dir }}
    steps:
      - name: Record artifact dir
        id: artifacts
        run: |
          echo "Artifact dir: $ARTIFACT_DIR"
          echo "artifact_dir=$ARTIFACT_DIR" >> "$GITHUB_OUTPUT"
      - name: Maximize build space
        # NOTE: the action reference was redacted in the source; the name below
        # matches AdityaGarg8's disk-cleanup action, and the version is assumed.
        uses: AdityaGarg8/remove-unwanted-software@v4
        with:
          remove-dotnet: 'true'
          remove-android: 'true'
          remove-haskell: 'true'
          remove-codeql: 'true'
      - uses: actions/checkout@v4
      - uses: ./.github/actions/e2e-setup
        with:
          gcp_project: ${{ vars.GCP_PROJECT }}
          gcp_zone: ${{ vars.GCP_ZONE }}
          xpk_cluster_name: ${{ vars.XPK_CLUSTER_NAME }}
          tpu_type: ${{ vars.TPU_TYPE }}
          artifact_dir: ${{ env.ARTIFACT_DIR }}
          gcp_sa_key: ${{ secrets.GCP_SA_KEY }}
      - name: Setup Docker URL option
        id: docker-url-option
        run: |
          if [ -n "${{ github.event.inputs.docker_url }}" ]; then
            echo "value=--base-docker-url ${{ github.event.inputs.docker_url }}" >> "$GITHUB_OUTPUT"
          else
            echo "value=" >> "$GITHUB_OUTPUT"
          fi
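      # The step output above is spliced verbatim into each `tp run` invocation
      # below; when `docker_url` is unset it expands to an empty, unquoted
      # string that the shell drops, so no `--base-docker-url` flag is passed.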
      # Launch training workloads.
      - name: Run Llama 3.0 8B
        id: run-llama-3-8b
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
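      # The remaining workloads follow the same recipe: Hydra-style key=value
      # overrides pick the model, tokenizer, and dataset; training is capped at
      # a handful of steps; `ici_mesh.fsdp=4` shards the model 4-ways across
      # the chips of a single slice (FSDP over ICI); and `profile_start_step`
      # delays profiling until compilation has settled.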
      - name: Run Llama 3.0 8B (@assume_pure)
        id: run-llama-3-8b-pure-mlp
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-pure-mlp)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3 \
            model.pure_modules=[LlamaMLP,EinsumLinear]
      - name: Run Llama 3.1 8B (Splash Attention)
        id: run-llama-3_1-8b-SplashAttention
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3dot1-8b-sa)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3.1-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \
            model.attention_kernel=splash_attention \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Llama 3.1 8B (Scan + Offload)
        id: run-llama-3_1-8b-scan-offload
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3dot1-8b-scan-offload)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3.1-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3.1-405b \
            model/remat=llama-scan-offload \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Llama 3.0 8B (2D sharding)
        id: run-llama-3-8b-2d
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-2d)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model/sharding=llama-fsdp-tp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=2 \
            ici_mesh.tensor=2 \
            profile_start_step=3
      # TODO: switch to ici_mesh.context=2 and ici_mesh.fsdp=2 once the e2e
      # test is debugged.
      - name: Run Llama 3.0 8B (fsdp + cp)
        id: run-llama-3-8b-fsdp-cp
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-fsdp-cp)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=llama-3-8b-cp \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model/sharding=llama-fsdp-tp-cp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=4 \
            task.max_steps=15 \
            task.lr_scheduler.type=constant \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Mixtral 8x7B
        id: run-mixtral-8x7b
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py mixtral-8x7b)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=mixtral-8x7b \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/mixtral-8x7b-v0.1/ \
            model.num_hidden_layers=16 \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=8 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
      - name: Run Llama 3.0 8B (2 slice)
        id: run-llama-3-8b-2-slice
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-2-slice)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            --num-slices 2 \
            torchprime/torch_xla_models/train.py \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model=llama-3-8b \
            model/sharding=llama-fsdp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=16 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            dcn_mesh.fsdp=2 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
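      # Multi-slice run: `--num-slices 2` requests two TPU slices, and
      # `dcn_mesh.fsdp=2` extends FSDP sharding across slices over DCN while
      # `ici_mesh.fsdp=4` shards within each slice; the global batch size
      # doubles to 16 to match the doubled chip count.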
      - name: Run Llama 3.0 8B SFT
        id: run-llama-3-8b-sft
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-sft)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            --config-name llama-3-8b-sft-w-gsm8k \
            model.pretrained_model=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/gsm8k \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            ici_mesh.fsdp=4 \
            task.max_steps=50 \
            task.convert_to_safetensors=False \
            profile_start_step=3
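      # SFT variant: loads pretrained Llama 3 weights and fine-tunes on GSM8K
      # via a dedicated Hydra config; safetensors conversion is disabled,
      # presumably to keep the e2e run fast.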
      - name: Run Llama 3.0 8B (ddp + fsdp)
        id: run-llama-3-8b-ddp-fsdp
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py llama-3-8b-ddp-fsdp)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            --num-slices 2 \
            torchprime/torch_xla_models/train.py \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/meta-llama-3-8b \
            model=llama-3-8b \
            model/sharding=llama-fsdp \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            task=train \
            task.global_batch_size=16 \
            task.lr_scheduler.type=constant \
            task.max_steps=15 \
            dcn_mesh.data=2 \
            ici_mesh.fsdp=4 \
            profile_start_step=3
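      # Unlike the 2-slice run above, `dcn_mesh.data=2` replicates the model
      # across slices (DDP) rather than sharding it, so gradient all-reduce is
      # the main cross-slice traffic.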
      - name: Run Deepseek v3 Shallow
        id: run-ds-v3-shallow
        env:
          XLA_IR_DEBUG: 1
          XLA_HLO_DEBUG: 1
        run: |
          name=$(e2e_testing/gen_name.py ds-v3-shallow)
          echo "name=$name" >> "$GITHUB_OUTPUT"
          tp run ${{ steps.docker-url-option.outputs.value }} \
            --name $name \
            torchprime/torch_xla_models/train.py \
            model=deepseek-v3 \
            model.tokenizer_name=gs://torchprime/e2e-test/hf-model-files/deepseek-v3-tokenizer \
            model.num_hidden_layers=2 \
            model.first_k_dense_replace=1 \
            dataset.hf_dataset_name=gs://torchprime/e2e-test/datasets/wikitext \
            dataset.block_size=512 \
            task=train \
            task.lr_scheduler.type=constant \
            task.global_batch_size=4 \
            task.max_steps=15 \
            ici_mesh.fsdp=4 \
            profile_start_step=7
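      # Shallow DeepSeek v3: two layers total, and `first_k_dense_replace=1`
      # keeps the first layer dense, so exactly one MoE layer is exercised
      # (assuming the flag follows the DeepSeek v3 config convention).
      # Profiling starts later (step 7) than in the other runs.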
  # Load reference step times
  load-benchmarks:
    name: Load reference step times
    runs-on: ubuntu-24.04
    outputs:
      matrix: ${{ steps.load.outputs.matrix }}
    steps:
      - uses: actions/checkout@v4
      - name: Load step_time_bounds.yaml
        id: load
        run: |
          # Extract benchmarks as an array of objects
          MATRIX=$(yq -o=json -I=0 '.benchmarks | to_entries | map({
            "benchmark": .key,
            "name": .value.name,
            "lower_bound": .value.step_time_lower_bound,
            "upper_bound": .value.step_time_upper_bound,
            "target_loss": .value.target_loss,
            "loss_tolerance": .value.loss_tolerance
          })' e2e_testing/step_time_bounds.yaml)
          echo "Benchmark matrix JSON: $MATRIX"
          echo "matrix=$MATRIX" >> "$GITHUB_OUTPUT"
  # Validate the results of the workloads.
  #
  # Each workload has a step time lower bound and upper bound. The bounds and
  # confidence intervals are programmatically derived from historical E2E test
  # results. To regenerate the bounds, you can run
  # `e2e_testing/update_step_time.py`.
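  #
  # For illustration, an entry in `step_time_bounds.yaml` presumably looks
  # like this (field names taken from the yq query above; values made up):
  #
  #   benchmarks:
  #     llama-3-8b:
  #       name: Llama 3.0 8B
  #       step_time_lower_bound: 2.5
  #       step_time_upper_bound: 2.9
  #       target_loss: 6.0      # optional
  #       loss_tolerance: 0.1   # optional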
  validate:
    name: ${{ matrix.config.name }}
    needs: [tp-run, load-benchmarks]
    strategy:
      fail-fast: false
      matrix:
        config: ${{ fromJson(needs.load-benchmarks.outputs.matrix) }}
    uses: ./.github/workflows/reusable_e2e_check.yml
    with:
      jobset_name: >-
        ${{
          matrix.config.benchmark == 'llama-3-8b' && needs.tp-run.outputs.llama-3-8b-name ||
          matrix.config.benchmark == 'llama-3_1-8b-sa' && needs.tp-run.outputs.llama-3_1-8b-sa-name ||
          matrix.config.benchmark == 'llama-3_1-8b-scan-offload' && needs.tp-run.outputs.llama-3_1-8b-scan-offload-name ||
          matrix.config.benchmark == 'llama-3-8b-2d' && needs.tp-run.outputs.llama-3-8b-2d-name ||
          matrix.config.benchmark == 'mixtral-8x7b' && needs.tp-run.outputs.mixtral-8x7b-name ||
          matrix.config.benchmark == 'llama-3-8b-pure-mlp' && needs.tp-run.outputs.llama-3-8b-pure-mlp-name ||
          matrix.config.benchmark == 'llama-3-8b-sft' && needs.tp-run.outputs.llama-3-8b-sft-name ||
          matrix.config.benchmark == 'llama-3-8b-2-slice' && needs.tp-run.outputs.llama-3-8b-2-slice-name ||
          matrix.config.benchmark == 'llama-3-8b-ddp-fsdp' && needs.tp-run.outputs.llama-3-8b-ddp-fsdp-name ||
          matrix.config.benchmark == 'llama-3-8b-fsdp-cp' && needs.tp-run.outputs.llama-3-8b-fsdp-cp-name ||
          matrix.config.benchmark == 'ds-v3-shallow' && needs.tp-run.outputs.ds-v3-shallow-name
        }}
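      # The `cond && value ||` chain above is the usual GitHub-expressions
      # stand-in for a ternary: the first matching benchmark key selects its
      # jobset name. It relies on the selected output being non-empty.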
      artifact_dir: ${{ needs.tp-run.outputs.artifact-dir }}
      step_time_lower_bound: ${{ matrix.config.lower_bound }}
      step_time_upper_bound: ${{ matrix.config.upper_bound }}
      # Optional loss validation settings. When a benchmark leaves these
      # undefined in the matrix, the fields would expand to an empty string,
      # causing a workflow syntax error. Default to `0` so the reusable
      # workflow can skip the loss check when no value is provided.
      target_loss: ${{ matrix.config.target_loss || 0 }}
      loss_tolerance: ${{ matrix.config.loss_tolerance || 0 }}
    secrets: inherit