[trainer] add Jax trainer guide for TPU#4343
[trainer] add Jax trainer guide for TPU#4343siyuanfoundation wants to merge 1 commit intokubeflow:masterfrom
Conversation
|
Hi @siyuanfoundation. Thanks for your PR. I'm waiting for a kubeflow member to verify that this patch is reasonable to test. If it is, they should reply with Once the patch is verified, the new status will be reflected by the I understand the commands that are listed here. DetailsInstructions for interacting with me using PR comments are available here. If you have questions or suggestions related to my behavior, please file an issue against the kubernetes/test-infra repository. |
|
[APPROVALNOTIFIER] This PR is NOT APPROVED This pull-request has been approved by: The full list of commands accepted by this bot can be found here. DetailsNeeds approval from an approver in each of these files:Approvers can indicate their approval by writing |
|
🚫 This command cannot be processed. Only organization members or owners can use the commands. |
Signed-off-by: siyuanfoundation <sizhang@google.com>
|
/cc @andreyvelich |
andreyvelich
left a comment
There was a problem hiding this comment.
Looks great, overall lgtm, left a few comments.
Thank you for this @siyuanfoundation!
/assign @akshaychitneni @kubeflow/kubeflow-trainer-team
| ## Prerequisites | ||
|
|
||
| Before exploring this guide, make sure to follow: | ||
| - [The Getting Started guide](https://www.kubeflow.org/docs/components/trainer/user-guides/) |
There was a problem hiding this comment.
| - [The Getting Started guide](https://www.kubeflow.org/docs/components/trainer/user-guides/) | |
| - [The Getting Started guide](/docs/components/trainer/user-guides/) |
| ## JAX on TPU Overview | ||
|
|
||
| JAX on TPU requires a different runtime environment than GPU. Specifically: | ||
| - **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`). |
There was a problem hiding this comment.
Maybe you can add example how ClusterTrainingRuntime might look like?
Do you know if users want to set node selectors per job, or this is something that cluster admins can configure when they create reusable ClusterTrainingRuntime?
As @kaisoz mentioned in this PR, our default ClusterTrainingRuntime's image doesn't support TPUs: kubeflow/trainer#3151 (comment)
cc @kubeflow/kubeflow-trainer-team
| JAX on TPU requires a different runtime environment than GPU. Specifically: | ||
| - **Image**: You must use a JAX image compatible with TPUs (e.g., `us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu`). | ||
| - **Resources**: You must request `google.com/tpu` resources. | ||
| - **Node Selectors**: You must specify GKE-specific node selectors and topology for TPU nodes. |
There was a problem hiding this comment.
I know that JobSet also supports Exclusive Topology for TPU workload placement:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
Are there any interest from GKE team to showcase how this can be used with the TrainJob too?
Additionally, TPU multi-slice examples: kubernetes-sigs/jobset#1168
|
|
||
| ### Node Selectors and Topology | ||
|
|
||
| When running on GKE, TPUs are often managed via [Compute Classes](https://cloud.google.com/kubernetes-engine/docs/how-to/tpus-compute-class). You must match the `node_selector` to your TPU node pool labels: |
| apiVersion: cloud.google.com/v1 | ||
| kind: ComputeClass |
There was a problem hiding this comment.
Does it require DRA driver to be installed? Shall we mention this?
| and `jax[tpu]` in the same image leads to backend and plugin conflicts. | ||
| A separate TPU-specific runtime is required. | ||
|
|
||
| Check out [the JAX on TPU guide](https://www.kubeflow.org/docs/components/trainer/user-guides/jax-tpu/) |
There was a problem hiding this comment.
| Check out [the JAX on TPU guide](https://www.kubeflow.org/docs/components/trainer/user-guides/jax-tpu/) | |
| Check out [the JAX on TPU guide](/docs/components/trainer/user-guides/jax-tpu/) |
|
/ok-to-test |
|
Approvals successfully granted for pending runs. |
Description of Changes
This PR adds a JAX user guide describing how to run distributed JAX
training jobs with Kubeflow Trainer on TPUs.
Related Issues
Related: kubeflow/trainer#3183
Checklist