2121 _LoadBalancer ,
2222)
2323from axlearn .cloud .gcp .lws_utils import BaseLeaderWorkerTemplate , TPULeaderWorkerTemplate
24- from axlearn .cloud .gcp .system_characteristics import USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS
24+ from axlearn .cloud .gcp .system_characteristics import (
25+ GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS ,
26+ USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS ,
27+ support_twisted_topology ,
28+ )
2529from axlearn .cloud .gcp .tpu import infer_tpu_workers
2630from axlearn .cloud .gcp .utils import validate_jobset_name
2731from axlearn .common .compiler_options import (
4751# There is no guarantee that this image will work with newer Jax releases.
4852# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
4953# This image extends GRPC timeout for long context models.
50- _PATHWAYS_IMAGE_TAG = "disable_settings_20250701 "
54+ _PATHWAYS_IMAGE_TAG = "2025-10-03 "
5155
5256# The docker image used by pathways proxy container.
5357# pylint: disable=line-too-long
54- # _PATHWAYS_PROXY_IMAGE = "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_proxy_server_maxtext:latest"
5558_PATHWAYS_PROXY_IMAGE = (
5659 # pylint: disable=line-too-long
57- "us-docker.pkg.dev/cloud-tpu-v2-images-dev /pathways/gke/ksadi/unsanitized_proxy_server:latest "
60+ f "us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/proxy_server: { _PATHWAYS_IMAGE_TAG } "
5861)
5962# The docker image used by pathways resource manager container and worker container.
6063_PATHWAYS_SERVER_IMAGE = (
6164 # pylint: disable=line-too-long
62- # "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server@sha256:fde763e2bae514d0fa758840e501b71a9ea48781dddafa5d8ed3a0fa316fd1ae"
63- "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server:latest"
64- # "us-docker.pkg.dev/cloud-tpu-v2-images-dev/pathways/gke/ksadi/unsanitized_server_maxtext:latest"
65- )
66- _COLOCATED_PYTHON_IMAGE = (
67- # "gcr.io/cloud-tpu-multipod-dev/ksadi_sidecar_maxtext:latest"
68- # pylint: disable=line-too-long
69- #"us-docker.pkg.dev/cloud-tpu-multipod-dev/colocated-images/sam:v6"
70- "us-docker.pkg.dev/cloud-tpu-multipod-dev/axlearn/colocated-img13:latest"
71- # "us-docker.pkg.dev/cloud-tpu-multipod-dev/colocated-images/lk-colocated-image:latest"
72- # "gcr.io/cloud-tpu-multipod-dev/sujinesh_sidecar_debug@sha256:19abcd94addb6ff2749c299d6b0cc4748f27a4ab8759a18b466d0bdd3e5b71e8"
65+ f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/server:{ _PATHWAYS_IMAGE_TAG } "
7366)
67+
7468# The container name of pathways resourcemanager.
7569_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
7670# The container name of pathways proxy.
9589
9690def get_colocated_python_image (colocated_image_name , fv : flags .FlagValues = FLAGS ) -> str :
9791 repo = gcp_settings ("docker_repo" , required = False , fv = fv )
98- print (repo )
99- print (colocated_image_name )
100- print (repo + colocated_image_name + ":latest" )
10192 return repo + "/" + colocated_image_name + ":latest"
10293
10394
10495def parse_xla_flag_value (value : str ) -> Union [int , bool , str ]:
105- """Attempts to convert an XLA flag string value to int, then bool .
96+ """Attempts to convert an XLA flag string value to int.
10697
10798 If conversion fails, returns the original string (stripped).
10899 """
109- bool_mapper = {"true" : True , "false" : False }
110100 stripped_value_str = value .strip ()
111101 try :
112102 return int (stripped_value_str )
113103 except ValueError :
114- # Not an integer, try boolean conversion.
115- return bool_mapper .get (stripped_value_str .lower (), stripped_value_str )
104+ return stripped_value_str
116105
117106
118107def get_pathways_tpu_version (gke_machine_type : str ) -> str :
@@ -162,6 +151,25 @@ def get_xla_options(
162151 """
163152 return {k : v for k , v in xla_options .items () if k .startswith ("xla_" )}
164153
154+ def round_up_to_power_of_2 (n ):
155+ """
156+ Rounds an integer up to the nearest power of 2.
157+
158+ Args:
159+ n (int): The number to round up. Must be a positive integer.
160+
161+ Returns:
162+ int: The smallest power of 2 that is greater than or equal to n.
163+
164+ Examples:
165+ round_up_to_power_of_2(7) -> 8
166+ round_up_to_power_of_2(8) -> 8
167+ round_up_to_power_of_2(9) -> 16
168+ round_up_to_power_of_2(32) -> 32
169+ """
170+ assert isinstance (n , int ) and n > 0
171+ return 1 << (n - 1 ).bit_length ()
172+
165173
166174class PathwaysReplicatedJob (BaseReplicatedJob ):
167175 """Builds a replicated jobspec for Pathways on TPU, to be used with JobSet API."""
@@ -311,7 +319,12 @@ def _build_pathways_head_container(self) -> dict:
311319 # In Jax 0.6.2 and beyond this flag can be renamed to
312320 # IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS as well.
313321 self ._update_env_list (env_list , "TEST_UNDECLARED_OUTPUTS_DIR" , "true" )
314-
322+ # Threshold for using shared memory between Jax client and Pathways proxy.
323+ # Setting it to 1 byte so effectively all Jax device_put use shared memory.
324+ self ._update_env_list (env_list , "IFRT_PROXY_LARGE_TRANSFER_THRESHOLD" , "1" )
325+ self ._update_env_list (
326+ env_list , "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY" , "/tmp/ifrt_proxy"
327+ )
315328 env_list .append (
316329 {
317330 "name" : "HOST_ADDRESS" ,
@@ -351,10 +364,13 @@ def _build_pathways_head_container(self) -> dict:
351364 mem_req = f"{ self .config .pathways_head_mem } Gi"
352365 resources = {
353366 "requests" : {"cpu" : cpu_req , "memory" : mem_req },
354- "limits" : {"cpu" : cpu_req , "memory" : mem_req },
355367 }
356368 head_container ["resources" ] = resources
357369
370+ volume_mounts = head_container .get ("volumeMounts" , [])
371+ volume_mounts .append (dict (name = "shared-memory" , mountPath = "/tmp/ifrt_proxy" ))
372+ head_container ["volumeMounts" ] = volume_mounts
373+
358374 return head_container
359375
360376 def _build_pathways_head_sidecar_containers (self ) -> list [Nested [Any ]]:
@@ -384,10 +400,9 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
384400 ]
385401 cmd_args .extend (xla_flags_from_options (self ._xla_options ).split ())
386402
387- # This is required for GKE Workload Identity and Mac Jax Client support.
388- # TODO(samos123): Remove this once this becomes the default.
389- proxy_env = [{"name" : "IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS" , "value" : "true" }]
390-
403+ instance_type = f"{ pathways_tpu_version } :{ system .topology } "
404+ if support_twisted_topology (self ._tpu_type ):
405+ instance_type = f"{ instance_type } _untwisted"
391406 return [
392407 dict (
393408 name = _PATHWAYS_PROXY_CONTAINER_NAME ,
@@ -396,8 +411,21 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
396411 # SideCar container is an init container with restartPolicy as "Always".
397412 restartPolicy = "Always" ,
398413 args = cmd_args ,
399- env = proxy_env ,
414+ env = [
415+ # This is required for GKE Workload Identity and Mac Jax Client support.
416+ # TODO(samos123): Remove this once this becomes the default.
417+ {"name" : "IFRT_PROXY_USE_INSECURE_GRPC_CREDENTIALS" , "value" : "true" },
418+ {"name" : "XLA_FLAGS" , "value" : f"--xla_dump_to=/output/{ cfg .name } /xla" },
419+ {
420+ "name" : "IFRT_PROXY_LARGE_TRANSFER_OPTIMIZATION_DIRECTORY" ,
421+ "value" : "/tmp/ifrt_proxy" ,
422+ },
423+ ],
400424 ports = [dict (containerPort = _PATHWAYS_PROXY_PORT )],
425+ volumeMounts = [
426+ dict (name = "shared-output" , mountPath = "/output" ),
427+ dict (name = "shared-memory" , mountPath = "/tmp/ifrt_proxy" ),
428+ ],
401429 ),
402430 dict (
403431 name = _PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME ,
@@ -415,17 +443,18 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
415443 f"--server_port={ _PATHWAYS_RESOURCE_MANAGER_PORT } " ,
416444 "--node_type=resource_manager" ,
417445 f"--instance_count={ pathways_instance_count } " ,
418- f"--instance_type={ pathways_tpu_version } : { system . topology } " ,
446+ f"--instance_type={ instance_type } " ,
419447 f"--gcs_scratch_location={ staging_location } " ,
420448 ],
449+ volumeMounts = [dict (name = "shared-output" , mountPath = "/output" )],
421450 ),
422451 ]
423452
424453 def _colocated_python_container (self ):
425454 cfg : PathwaysReplicatedJob .Config = self .config
426455 return dict (
427456 name = _COLOCATED_PYTHON_SIDECAR_NAME ,
428- image = get_colocated_python_image (cfg .colocated_image ), #_COLOCATED_PYTHON_IMAGE,
457+ image = get_colocated_python_image (cfg .colocated_image ),
429458 restartPolicy = "Always" ,
430459 env = [
431460 {
@@ -450,6 +479,7 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
450479 labels .update ({BASTION_JOB_VERSION_LABEL : os .environ .get (BASTION_JOB_VERSION_ENV_VAR )})
451480
452481 volumes .append (dict (name = "shared-output" , emptyDir = {}))
482+ volumes .append (dict (name = "shared-memory" , emptyDir = dict (medium = "Memory" )))
453483
454484 if cfg .gcsfuse_mount :
455485 annotations .update (
@@ -466,7 +496,11 @@ def _build_pathways_head_pod(self) -> Nested[Any]:
466496 }
467497
468498 head_container = self ._build_pathways_head_container ()
469- init_containers = self ._build_pathways_head_sidecar_containers ()
499+ init_containers = [
500+ * self ._build_pathways_head_sidecar_containers (),
501+ # pylint: disable-next=protected-access
502+ self ._inner ._build_uploader_container (),
503+ ]
470504
471505 # Hardcode metadata.google.internal ip address to avoid transient DNS resolution issue.
472506 metadata_host_alias = dict (
@@ -524,6 +558,8 @@ def _build_pathways_worker_container(
524558 ) -> dict :
525559 """Build the container for the 'pathways-worker' role."""
526560 cfg : TPUReplicatedJob .Config = self ._inner .config
561+ system = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS [self ._tpu_type ]
562+ host_memory = GCE_MACHINE_TYPE_TO_MEMORY_CHARACTERISTICS [system .gce_machine_type ]
527563 # pylint: disable-next=protected-access
528564 container = self ._inner ._build_container ()
529565
@@ -583,9 +619,9 @@ def _build_pathways_worker_container(
583619 # Recycling host memory gives a slight increase in performance.
584620 "--tpu_pinned_host_allocation_recycle=true" ,
585621 # The flag below is needed for better H2D performance.
586- # Rule of thumb: 3x the shard size. So 128GB to be safe .
587- # Decrease if you start running out of host memory on TPU VMs .
588- "--tpu_premapped_buffer_size=137438953472 " ,
622+ # We use 1/4 of the host memory, rounding up to power of 2 as premapped buffer .
623+ # Note that pathways worker requires this flag to be a power of 2 .
624+ f "--tpu_premapped_buffer_size={ round_up_to_power_of_2 ( host_memory // 4 ) * ( 1 << 30 ) } " ,
589625 ]
590626 mega_scale_args = xla_flags_from_options (self ._mxla_options ).split ()
591627 worker_container ["args" ].extend (mega_scale_args )
@@ -608,6 +644,7 @@ def _build_pathways_worker_pod(
608644 ) -> Nested [Any ]:
609645 """Conoverts a worker pod to a new pod for the 'pathways-workers' role."""
610646 cfg : TPUReplicatedJob .Config = self ._inner .config
647+ pathways_cfg : PathwaysReplicatedJob .Config = self .config
611648 # pylint: disable-next=protected-access
612649 pod = self ._inner ._build_pod ()
613650 worker_pod = copy .deepcopy (pod )
@@ -623,7 +660,9 @@ def _build_pathways_worker_pod(
623660 pod_spec ["containers" ] = [
624661 self ._build_pathways_worker_container (pathways_worker_replicated_job_index )
625662 ]
626- pod_spec ["initContainers" ] = [self ._colocated_python_container ()]
663+
664+ if pathways_cfg .colocated_image :
665+ pod_spec ["initContainers" ] = [self ._colocated_python_container ()]
627666
628667 worker_pod ["spec" ] = pod_spec
629668
@@ -965,7 +1004,7 @@ def _build_head_container(self) -> dict:
9651004 }
9661005 return dict (
9671006 name = cfg .name ,
968- image = self ._bundler .id (cfg .name ),
1007+ image = cfg . image_id or self ._bundler .id (cfg .name ),
9691008 command = ["bash" , "-c" , cfg .command ],
9701009 env = [
9711010 {
0 commit comments