Skip to content

[trainer] add Jax trainer guide for TPU#4343

Open
siyuanfoundation wants to merge 1 commit intokubeflow:masterfrom
siyuanfoundation:tpu
Open

[trainer] add Jax trainer guide for TPU#4343
siyuanfoundation wants to merge 1 commit intokubeflow:masterfrom
siyuanfoundation:tpu

Conversation

@siyuanfoundation
Copy link

@siyuanfoundation siyuanfoundation commented Mar 16, 2026

Description of Changes

This PR adds a JAX user guide describing how to run distributed JAX
training jobs with Kubeflow Trainer on TPUs.

Related Issues

Related: kubeflow/trainer#3183

Checklist

image

@google-oss-prow
Copy link

Hi @siyuanfoundation. Thanks for your PR.

I'm waiting for a kubeflow member to verify that this patch is reasonable to test. If it is, they should reply with /ok-to-test on its own line. Until that is done, I will not automatically test new commits in this PR, but the usual testing commands by org members will still work. Regular contributors should join the org to skip this step.

Once the patch is verified, the new status will be reflected by the ok-to-test label.

I understand the commands that are listed here.

Details

Instructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository.

@google-oss-prow google-oss-prow bot added size/L area/trainer AREA: Kubeflow Trainer / Kubeflow Training Operator labels Mar 16, 2026
@google-oss-prow
Copy link

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign gaocegege for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Details Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@github-actions
Copy link

🚫 This command cannot be processed. Only organization members or owners can use the commands.

Signed-off-by: siyuanfoundation <sizhang@google.com>
@siyuanfoundation
Copy link
Author

/cc @andreyvelich

@google-oss-prow google-oss-prow bot requested a review from andreyvelich March 16, 2026 20:59
Copy link
Member

@andreyvelich andreyvelich left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks great, overall lgtm, left a few comments.
Thank you for this @siyuanfoundation!
/assign @akshaychitneni @kubeflow/kubeflow-trainer-team

## Prerequisites

Before exploring this guide, make sure to follow:
- [The Getting Started guide](https://www.kubeflow.org/docs/components/trainer/user-guides/)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
- [The Getting Started guide](https://www.kubeflow.org/docs/components/trainer/user-guides/)
- [The Getting Started guide](/docs/components/trainer/user-guides/)

## JAX on TPU Overview

JAX on TPU requires a different runtime environment than GPU. Specifically:
- **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`).
Copy link
Member

@andreyvelich andreyvelich Mar 16, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can add example how ClusterTrainingRuntime might look like?
Do you know if users want to set node selectors per job, or this is something that cluster admins can configure when they create reusable ClusterTrainingRuntime?

As @kaisoz mentioned in this PR, our default ClusterTrainingRuntime's image doesn't support TPUs: kubeflow/trainer#3151 (comment)
cc @kubeflow/kubeflow-trainer-team

JAX on TPU requires a different runtime environment than GPU. Specifically:
- **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`).
- **Resources**: You must request `google.com/tpu` resources.
- **Node Selectors**: You must specify GKE-specific node selectors and topology for TPU nodes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know that JobSet also supports Exclusive Topology for TPU workload placement:

alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool

Are there any interest from GKE team to showcase how this can be used with the TrainJob too?

Additionally, TPU multi-slice examples: kubernetes-sigs/jobset#1168

cc @GiuseppeTT @imreddy13


### Node Selectors and Topology

When running on GKE, TPUs are often managed via [Compute Classes](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus-compute-class). You must match the `node_selector` to your TPU node pool labels:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This URL doesn't work.

Comment on lines +18 to +19
apiVersion: cloud.google.com/v1
kind: ComputeClass
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it require DRA driver to be installed? Shall we mention this?

and `jax[tpu]` in the same image leads to backend and plugin conflicts.
A separate TPU-specific runtime is required.

Check out [the JAX on TPU guide](https://www.kubeflow.org/docs/components/trainer/user-guides/jax-tpu/)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Check out [the JAX on TPU guide](https://www.kubeflow.org/docs/components/trainer/user-guides/jax-tpu/)
Check out [the JAX on TPU guide](/docs/components/trainer/user-guides/jax-tpu/)

@andreyvelich
Copy link
Member

/ok-to-test

@github-actions
Copy link

Approvals successfully granted for pending runs.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

area/trainer AREA: Kubeflow Trainer / Kubeflow Training Operator ok-to-test size/L

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants