Skip to content

Commit 6973808

Browse files
committed
pathways images finalized
1 parent f446c16 commit 6973808

File tree

3 files changed

+91
-77
lines changed

3 files changed

+91
-77
lines changed

axlearn/cloud/gcp/bundler.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848

4949
import os
5050
import subprocess
51-
import time
5251
from typing import Optional
5352

5453
from absl import app, flags, logging
@@ -59,10 +58,10 @@
5958
from axlearn.cloud.common.bundler import register_bundler
6059
from axlearn.cloud.common.docker import registry_from_repo
6160
from axlearn.cloud.common.utils import canonicalize_to_list, to_bool
62-
from axlearn.cloud.gcp.cloud_build import get_cloud_build_status
61+
from axlearn.cloud.gcp.cloud_build import wait_for_cloud_build
6362
from axlearn.cloud.gcp.config import gcp_settings
6463
from axlearn.cloud.gcp.utils import common_flags
65-
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config, config_for_class
64+
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config
6665

6766
FLAGS = flags.FLAGS
6867

@@ -105,6 +104,8 @@ class Config(DockerBundler.Config):
105104
106105
Attributes:
107106
colocated_image_required: Bool to build a colocated image
107+
colocated_image_name: Name of the colocated image to build.
108+
colocated_dockerfile: Dockerfile used to build the colocated image.
108109
"""
109110
# Build image asynchronously.
110111
colocated_image_required: bool = False
@@ -129,7 +130,6 @@ def _build_and_push(self, *args, **kwargs):
129130
check=True,
130131
)
131132

132-
print("actual",cfg)
133133
actual_name = cfg.image
134134
actual_dockerfile=cfg.dockerfile
135135
actual_target=cfg.target
@@ -138,17 +138,14 @@ def _build_and_push(self, *args, **kwargs):
138138
cfg.dockerfile=cfg.colocated_dockerfile
139139
cfg.image=cfg.colocated_image_name
140140
cfg.target=None
141-
print("updated config: ",cfg)
141+
142142
colocated_bundler_class = ColocatedArtifactRegistryBundler(cfg=cfg)
143143
colocated_image_name = colocated_bundler_class.bundle(tag="latest")
144-
print(colocated_image_name)
145144

146145
cfg.dockerfile=actual_dockerfile
147146
cfg.image=actual_name
148147
cfg.target=actual_target
149-
150-
151-
148+
152149
return super()._build_and_push(*args, **kwargs)
153150

154151

@@ -164,7 +161,6 @@ def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> Docker
164161

165162
def _build_and_push(self, *args, **kwargs):
166163
cfg = self.config
167-
print("colocated",cfg)
168164
subprocess.run(
169165
["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)],
170166
check=True,
@@ -296,36 +292,14 @@ def wait_until_finished(self, name: str, wait_timeout=3600):
296292
TimeoutError: If the build does not complete within the overall timeout.
297293
ValueError: If the async build fails.
298294
"""
299-
start_time = time.perf_counter()
300295
cfg: CloudBuildBundler.Config = self.config
301-
while cfg.is_async:
302-
elapsed_time = time.perf_counter() - start_time
303-
if elapsed_time > wait_timeout:
304-
timeout_msg = (
305-
f"Timed out waiting for CloudBuild to finish for more than "
306-
f"{wait_timeout} seconds."
307-
)
308-
logging.error(timeout_msg)
309-
raise TimeoutError(timeout_msg)
310-
try:
311-
build_status = get_cloud_build_status(
312-
project_id=cfg.project, image_name=self.id(name), tags=[name]
313-
)
314-
except Exception as e: # pylint: disable=broad-except
315-
# TODO(liang-he,markblee): Distinguish transient from non-transient errors.
316-
logging.warning("Failed to get the CloudBuild status, will retry: %s", e)
317-
else:
318-
if not build_status:
319-
logging.warning("CloudBuild for %s does not exist yet.", name)
320-
elif build_status.is_pending():
321-
logging.info("CloudBuild for %s is pending: %s.", name, build_status)
322-
elif build_status.is_success():
323-
logging.info("CloudBuild for %s is successful: %s.", name, build_status)
324-
return
325-
else:
326-
# Unknown status is also considered a failure.
327-
raise RuntimeError(f"CloudBuild for {name} failed: {build_status}.")
328-
time.sleep(30)
296+
if cfg.is_async:
297+
wait_for_cloud_build(
298+
project_id=cfg.project,
299+
image_id=self.id(name),
300+
tags=[name],
301+
wait_timeout=wait_timeout,
302+
)
329303

330304

331305
def with_tpu_extras(bundler: Bundler.Config) -> Bundler.Config:

axlearn/cloud/gcp/pathways_utils.py

Lines changed: 76 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@
2121
_LoadBalancer,
2222
)
2323
from axlearn.cloud.gcp.lws_utils import BaseLeaderWorkerTemplate, TPULeaderWorkerTemplate
24-
from axlearn.cloud.gcp.system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
24+
from axlearn.cloud.gcp.system_characteristics import (
25+
GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS,
26+
USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS,
27+
support_twisted_topology,
28+
)
2529
from axlearn.cloud.gcp.tpu import infer_tpu_workers
2630
from axlearn.cloud.gcp.utils import validate_jobset_name
2731
from axlearn.common.compiler_options import (
@@ -47,30 +51,20 @@
4751
# There is no guarantee that this image will work with newer Jax releases.
4852
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
4953
# This image extends GRPC timeout for long context models.
50-
_PATHWAYS_IMAGE_TAG = "disable_settings_20250701"
54+
_PATHWAYS_IMAGE_TAG = "2025-10-03"
5155

5256
# The docker image used by pathways proxy container.
5357
# pylint: disable=line-too-long
54-
# _PATHWAYS_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_proxy_server_maxtext:latest"
5558
_PATHWAYS_PROXY_IMAGE = (
5659
# pylint: disable=line-too-long
57-
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_proxy_server:latest"
60+
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/proxy_server:{_PATHWAYS_IMAGE_TAG}"
5861
)
5962
# The docker image used by pathways resource manager container and worker container.
6063
_PATHWAYS_SERVER_IMAGE = (
6164
# pylint: disable=line-too-long
62-
# "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server@sha256:fde763e2bae514d0fa758840e501b71a9ea48781dddafa5d8ed3a0fa316fd1ae"
63-
"us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server:latest"
64-
# "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server_maxtext:latest"
65-
)
66-
_COLOCATED_PYTHON_IMAGE = (
67-
# "gcr.io/cloud-tpu-multipod-dev/ksadi_sidecar_maxtext:latest"
68-
# pylint: disable=line-too-long
69-
#"us-docker.pkg.dev/cloud-tpu-multipod-dev/colocated-images/sam:v6"
70-
"us-docker.pkg.dev/cloud-tpu-multipod-dev/axlearn/colocated-img13:latest"
71-
# "us-docker.pkg.dev/cloud-tpu-multipod-dev/colocated-images/lk-colocated-image:latest"
72-
# "gcr.io/cloud-tpu-multipod-dev/sujinesh_sidecar_debug@sha256:19abcd94addb6ff2749c299d6b0cc4748f27a4ab8759a18b466d0bdd3e5b71e8"
65+
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/server:{_PATHWAYS_IMAGE_TAG}"
7366
)
67+
7468
# The container name of pathways resourcemanager.
7569
_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
7670
# The container name of pathways proxy.
@@ -95,24 +89,19 @@
9589

9690
def get_colocated_python_image(colocated_image_name, fv: flags.FlagValues = FLAGS) -> str:
    """Builds the image reference for the colocated Python sidecar.

    Args:
        colocated_image_name: Image name, without registry prefix or tag.
        fv: Flag values passed through to the `docker_repo` settings lookup.

    Returns:
        A reference of the form `<docker_repo>/<colocated_image_name>:latest`.

    Raises:
        ValueError: If no `docker_repo` setting is available.
    """
    repo = gcp_settings("docker_repo", required=False, fv=fv)
    # NOTE(review): with required=False the lookup presumably yields None when the
    # setting is absent -- confirm against gcp_settings. Fail with a clear message
    # instead of an opaque TypeError from string concatenation.
    if not repo:
        raise ValueError(
            "The 'docker_repo' setting must be configured to resolve the colocated image "
            f"{colocated_image_name!r}."
        )
    return f"{repo}/{colocated_image_name}:latest"
10293

10394

10495
def parse_xla_flag_value(value: str) -> Union[int, str]:
    """Attempts to convert an XLA flag string value to int.

    Surrounding whitespace is ignored. Boolean-looking values such as "true" or
    "false" are not converted; they are returned as (stripped) strings.

    Args:
        value: The raw flag value.

    Returns:
        The value as an int if it parses as one, otherwise the stripped string.
    """
    stripped_value_str = value.strip()
    try:
        return int(stripped_value_str)
    except ValueError:
        # Not an integer: hand back the cleaned-up text unchanged.
        return stripped_value_str
116105

117106

118107
def get_pathways_tpu_version(gke_machine_type: str) -> str:
@@ -162,6 +151,25 @@ def get_xla_options(
162151
"""
163152
return {k: v for k, v in xla_options.items() if k.startswith("xla_")}
164153

154+
def round_up_to_power_of_2(n: int) -> int:
    """Rounds an integer up to the nearest power of 2.

    Args:
        n: The number to round up. Must be a positive integer.

    Returns:
        The smallest power of 2 that is greater than or equal to n.

    Raises:
        TypeError: If n is not an int.
        ValueError: If n is not positive.

    Examples:
        round_up_to_power_of_2(7) -> 8
        round_up_to_power_of_2(8) -> 8
        round_up_to_power_of_2(9) -> 16
        round_up_to_power_of_2(32) -> 32
    """
    # Validate explicitly rather than with `assert`: assertions are removed under
    # `python -O`, and the bit trick below returns a wrong value for n <= 0.
    if not isinstance(n, int):
        raise TypeError(f"n must be an int, got {type(n).__name__}.")
    if n <= 0:
        raise ValueError(f"n must be a positive integer, got {n}.")
    # (n - 1).bit_length() is the number of bits needed to represent n - 1, so
    # shifting 1 by that amount gives the first power of 2 that is >= n.
    return 1 << (n - 1).bit_length()
172+
165173

166174
class PathwaysReplicatedJob(BaseReplicatedJob):
167175
"""Builds a replicated jobspec for Pathways on TPU, to be used with JobSet API."""
@@ -311,7 +319,12 @@ def _build_pathways_head_container(self) -> dict:
311319
# In Jax 0.6.2 and beyond this flag can be renamed to
312320
# IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS as well.
313321
self._update_env_list(env_list, "TEST_UNDECLARED_OUTPUTS_DIR", "true")
314-
322+
# Threshold for using shared memory between Jax client and Pathways proxy.
323+
# Setting it to 1 byte so effectively all Jax device_put use shared memory.
324+
self._update_env_list(env_list, "IFRT_PROXY_LARGE_TRANSFER_THRESHOLD", "1")
325+
self._update_env_list(
326+
env_list, "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY", "/tmp/ifrt_proxy"
327+
)
315328
env_list.append(
316329
{
317330
"name": "HOST_ADDRESS",
@@ -351,10 +364,13 @@ def _build_pathways_head_container(self) -> dict:
351364
mem_req = f"{self.config.pathways_head_mem}Gi"
352365
resources = {
353366
"requests": {"cpu": cpu_req, "memory": mem_req},
354-
"limits": {"cpu": cpu_req, "memory": mem_req},
355367
}
356368
head_container["resources"] = resources
357369

370+
volume_mounts = head_container.get("volumeMounts", [])
371+
volume_mounts.append(dict(name="shared-memory", mountPath="/tmp/ifrt_proxy"))
372+
head_container["volumeMounts"] = volume_mounts
373+
358374
return head_container
359375

360376
def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
@@ -384,10 +400,9 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
384400
]
385401
cmd_args.extend(xla_flags_from_options(self._xla_options).split())
386402

387-
# This is required for GKE Workload Identity and Mac Jax Client support.
388-
# TODO(samos123): Remove this once this becomes the default.
389-
proxy_env = [{"name": "IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS", "value": "true"}]
390-
403+
instance_type = f"{pathways_tpu_version}:{system.topology}"
404+
if support_twisted_topology(self._tpu_type):
405+
instance_type = f"{instance_type}_untwisted"
391406
return [
392407
dict(
393408
name=_PATHWAYS_PROXY_CONTAINER_NAME,
@@ -396,8 +411,21 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
396411
# SideCar container is an init container with restartPolicy as "Always".
397412
restartPolicy="Always",
398413
args=cmd_args,
399-
env=proxy_env,
414+
env=[
415+
# This is required for GKE Workload Identity and Mac Jax Client support.
416+
# TODO(samos123): Remove this once this becomes the default.
417+
{"name": "IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS", "value": "true"},
418+
{"name": "XLA_FLAGS", "value": f"--xla_dump_to=/output/{cfg.name}/xla"},
419+
{
420+
"name": "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY",
421+
"value": "/tmp/ifrt_proxy",
422+
},
423+
],
400424
ports=[dict(containerPort=_PATHWAYS_PROXY_PORT)],
425+
volumeMounts=[
426+
dict(name="shared-output", mountPath="/output"),
427+
dict(name="shared-memory", mountPath="/tmp/ifrt_proxy"),
428+
],
401429
),
402430
dict(
403431
name=_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME,
@@ -415,17 +443,18 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
415443
f"--server_port={_PATHWAYS_RESOURCE_MANAGER_PORT}",
416444
"--node_type=resource_manager",
417445
f"--instance_count={pathways_instance_count}",
418-
f"--instance_type={pathways_tpu_version}:{system.topology}",
446+
f"--instance_type={instance_type}",
419447
f"--gcs_scratch_location={staging_location}",
420448
],
449+
volumeMounts=[dict(name="shared-output", mountPath="/output")],
421450
),
422451
]
423452

424453
def _colocated_python_container(self):
425454
cfg: PathwaysReplicatedJob.Config = self.config
426455
return dict(
427456
name=_COLOCATED_PYTHON_SIDECAR_NAME,
428-
image=get_colocated_python_image(cfg.colocated_image), #_COLOCATED_PYTHON_IMAGE,
457+
image=get_colocated_python_image(cfg.colocated_image),
429458
restartPolicy="Always",
430459
env=[
431460
{
@@ -450,6 +479,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
450479
labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)})
451480

452481
volumes.append(dict(name="shared-output", emptyDir={}))
482+
volumes.append(dict(name="shared-memory", emptyDir=dict(medium="Memory")))
453483

454484
if cfg.gcsfuse_mount:
455485
annotations.update(
@@ -466,7 +496,11 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
466496
}
467497

468498
head_container = self._build_pathways_head_container()
469-
init_containers = self._build_pathways_head_sidecar_containers()
499+
init_containers = [
500+
*self._build_pathways_head_sidecar_containers(),
501+
# pylint: disable-next=protected-access
502+
self._inner._build_uploader_container(),
503+
]
470504

471505
# Hardcode metadata.google.internal ip address to avoid transient DNS resolution issue.
472506
metadata_host_alias = dict(
@@ -524,6 +558,8 @@ def _build_pathways_worker_container(
524558
) -> dict:
525559
"""Build the container for the 'pathways-worker' role."""
526560
cfg: TPUReplicatedJob.Config = self._inner.config
561+
system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[self._tpu_type]
562+
host_memory = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS[system.gce_machine_type]
527563
# pylint: disable-next=protected-access
528564
container = self._inner._build_container()
529565

@@ -583,9 +619,9 @@ def _build_pathways_worker_container(
583619
# Recycling host memory gives a slight increase in performance.
584620
"--tpu_pinned_host_allocation_recycle=true",
585621
# The flag below is needed for better H2D performance.
586-
# Rule of thumb: 3x the shard size. So 128GB to be safe.
587-
# Decrease if you start running out of host memory on TPU VMs.
588-
"--tpu_premapped_buffer_size=137438953472",
622+
# We use 1/4 of the host memory, rounding up to power of 2 as premapped buffer.
623+
# Note that pathways worker requires this flag to be a power of 2.
624+
f"--tpu_premapped_buffer_size={round_up_to_power_of_2(host_memory//4)*(1<<30)}",
589625
]
590626
mega_scale_args = xla_flags_from_options(self._mxla_options).split()
591627
worker_container["args"].extend(mega_scale_args)
@@ -608,6 +644,7 @@ def _build_pathways_worker_pod(
608644
) -> Nested[Any]:
609645
"""Converts a worker pod to a new pod for the 'pathways-workers' role."""
610646
cfg: TPUReplicatedJob.Config = self._inner.config
647+
pathways_cfg: PathwaysReplicatedJob.Config = self.config
611648
# pylint: disable-next=protected-access
612649
pod = self._inner._build_pod()
613650
worker_pod = copy.deepcopy(pod)
@@ -623,7 +660,9 @@ def _build_pathways_worker_pod(
623660
pod_spec["containers"] = [
624661
self._build_pathways_worker_container(pathways_worker_replicated_job_index)
625662
]
626-
pod_spec["initContainers"] = [self._colocated_python_container()]
663+
664+
if pathways_cfg.colocated_image:
665+
pod_spec["initContainers"] = [self._colocated_python_container()]
627666

628667
worker_pod["spec"] = pod_spec
629668

@@ -965,7 +1004,7 @@ def _build_head_container(self) -> dict:
9651004
}
9661005
return dict(
9671006
name=cfg.name,
968-
image=self._bundler.id(cfg.name),
1007+
image=cfg.image_id or self._bundler.id(cfg.name),
9691008
command=["bash", "-c", cfg.command],
9701009
env=[
9711010
{

colocated_commands.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
export NAME=axlearn-img
55
export COLOCATED_NAME=colocated-img
6+
export CKPT_BUCKET_NAME=<>
67

78
axlearn gcp bundle --name=$NAME \
89
--bundler_spec=allow_dirty=True \
@@ -27,4 +28,4 @@ axlearn gcp launch run --cluster=mlperf-v5p \
2728
--bundler_spec=dockerfile=Dockerfile \
2829
--bundler_spec=target=tpu \
2930
--colocated_image=$COLOCATED_NAME \
30-
-- TPU_PREMAPPED_BUFFER_SIZE=34359738368 python3 test_benchmark.py --ckpt_path gs://cloud-tpu-multipod-dev-euw4/axlearn-fuji-v3-70b/checkpoints/step_00000100
31+
-- TPU_PREMAPPED_BUFFER_SIZE=34359738368 python3 test_benchmark.py --ckpt_path $CKPT_BUCKET_NAME

0 commit comments

Comments
 (0)