Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pathwaysutils/managed_pathways_service/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# This file marks this directory as a Python package.
152 changes: 152 additions & 0 deletions pathwaysutils/managed_pathways_service/pw-cluster.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: pathways-akshu-s4-rw7
spec:
coordinator:
replicatedJob: pathways-head
failurePolicy:
maxRestarts: 1
restartStrategy: Recreate
network:
enableDNSHostnames: true
publishNotReadyAddresses: true
replicatedJobs:
- name: pathways-head
replicas: 1
template:
metadata:
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
spec:
backoffLimit: 0
completionMode: Indexed
completions: 1
parallelism: 1
template:
metadata:
labels:
kueue.x-k8s.io/podset: pathways-head
spec:
containers:
- args:
- --server_port=29001
- --gcs_scratch_location=gs://akshu-v5e
- --node_type=resource_manager
- --instance_count=4
- --instance_type=tpuv5e:4x8
- --xla_tpu_use_enhanced_launch_barrier=true
- --logtostderr
- --stderrthreshold=0
- --v=1
env:
- name: REPLICATED_JOB_NAME
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
- name: JOBSET_NAME
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
- name: HOST_ADDRESS
valueFrom:
fieldRef:
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
- name: TPU_SKIP_MDS_QUERY
value: "true"
image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_server:latest
imagePullPolicy: Always
name: pathways-rm
ports:
- containerPort: 29001
protocol: TCP
- containerPort: 29002
protocol: TCP
resources:
limits:
cpu: "8"
memory: 16G
nodeSelector:
cloud.google.com/gke-nodepool: cpu-np
dnsPolicy: ClusterFirstWithHostNet
hostNetwork: true
restartPolicy: OnFailure
- name: worker
replicas: 4
template:
metadata: {}
spec:
backoffLimit: 64
completionMode: Indexed
completions: 8
parallelism: 8
template:
metadata:
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
labels:
kueue.x-k8s.io/podset: worker
spec:
containers:
- args:
- --server_port=29005
- --resource_manager_address=$(PATHWAYS_HEAD):29001
- --gcs_scratch_location=gs://akshu-v5e
- --xla_tpu_use_enhanced_launch_barrier=true
- --logtostderr
- --stderrthreshold=0
- --v=1
env:
- name: TPU_MIN_LOG_LEVEL
value: "0"
- name: TF_CPP_MIN_LOG_LEVEL
value: "0"
- name: XCLOUD_ENVIRONMENT
value: GCP
- name: MEGASCALE_GRPC_ENABLE_XOR_TRACER
value: "false"
- name: MEGASCALE_NUM_SLICES
valueFrom:
fieldRef:
fieldPath: metadata.labels['jobset.sigs.k8s.io/replicatedjob-replicas']
- name: JOBSET_NAME
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/jobset-name']
- name: REPLICATED_JOB_NAME
valueFrom:
fieldRef:
fieldPath: metadata.annotations['jobset.sigs.k8s.io/replicatedjob-name']
- name: MEGASCALE_SLICE_ID
valueFrom:
fieldRef:
fieldPath: metadata.labels['jobset.sigs.k8s.io/job-index']
- name: PATHWAYS_HEAD
valueFrom:
fieldRef:
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
- name: MEGASCALE_COORDINATOR_ADDRESS
valueFrom:
fieldRef:
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_server:latest
imagePullPolicy: Always
name: pathways-worker
ports:
- containerPort: 29005
protocol: TCP
- containerPort: 29006
protocol: TCP
- containerPort: 8471
protocol: TCP
- containerPort: 8080
protocol: TCP
resources:
limits:
google.com/tpu: "4"
dnsPolicy: ClusterFirstWithHostNet
hostNetwork: true
nodeSelector:
cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
cloud.google.com/gke-tpu-topology: 4x8
restartPolicy: OnFailure
56 changes: 56 additions & 0 deletions pathwaysutils/managed_pathways_service/pw-proxy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
name: ${PROXY_NAME}
spec:
coordinator:
replicatedJob: pathways-head
failurePolicy:
maxRestarts: 1
restartStrategy: Recreate
network:
enableDNSHostnames: true
publishNotReadyAddresses: true
replicatedJobs:
- name: pathways-head
replicas: 1
template:
metadata:
annotations:
alpha.jobset.sigs.k8s.io/exclusive-topology: kubernetes.io/hostname
spec:
backoffLimit: 0
completionMode: Indexed
completions: 1
parallelism: 1
template:
metadata:
labels:
kueue.x-k8s.io/podset: pathways-head
spec:
containers:
- args:
- --server_port=29000
- --resource_manager_address=${PATHWAYS_HEAD}:${PATHWAYS_HEAD_PORT}
- --gcs_scratch_location=${GCS_BUCKET}
- --virtual_slices=${EXPECTED_INSTANCES}
env:
- name: PATHWAYS_HEAD
valueFrom:
fieldRef:
fieldPath: metadata.labels['jobset.sigs.k8s.io/coordinator']
image: us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/akshu/unsanitized_proxy_server:latest
imagePullPolicy: Always
name: pathways-proxy
ports:
- containerPort: 29000
protocol: TCP
resources:
limits:
cpu: "16"
memory: 100G
nodeSelector:
cloud.google.com/gke-nodepool: cpu-np
dnsPolicy: ClusterFirstWithHostNet
hostNetwork: true
restartPolicy: OnFailure
50 changes: 50 additions & 0 deletions pathwaysutils/managed_pathways_service/run_connect_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to run JAX code on TPU with the Managed Pathways service."""

from collections.abc import Sequence
from absl import app
from . import tpu_manager


def main(argv: Sequence[str]) -> None:
if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
with tpu_manager.connect(
"pw-scale-test-v5e-32",
"cloud-tpu-multipod-dev",
"us-south1",
"gs://akshu-v5e",
"pathways-akshu-s4-rw7-pathways-head-0-0.pathways-akshu-s4-rw7:29001",
{"tpuv5e:4x8": 2},
) as tm:
pass
# import jax.numpy as jnp
# import pathwaysutils
# import pprint

# pathwaysutils.initialize()

# orig_matrix = jnp.zeros(5)

# print("start")
# result_matrix = orig_matrix + 1
# print("Original Random Matrix:")
# pprint.pprint(orig_matrix)
# print("\nMatrix after adding 1:")
# pprint.pprint(result_matrix)


if __name__ == "__main__":
app.run(main)
Loading