Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

ARG TARGET=base
ARG BASE_IMAGE=ubuntu:22.04
ARG BASE_IMAGE_COLOCATED=us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/sidecar:2025_10_06-python_3.10-jax_0.6.2

FROM ${BASE_IMAGE} AS base

Expand Down Expand Up @@ -94,18 +95,47 @@ ARG EXTRAS=
# Install a custom jaxlib that includes backport of Pathways shared memory feature.
# PR: https://github.com/openxla/xla/pull/31417
# Needed until Jax is upgraded to 0.8.0 or newer.
ARG INSTALL_PATHWAYS_JAXLIB=false
ARG INSTALL_PATHWAYS_JAXLIB=true

# Ensure we install the TPU version, even if building locally.
# Jax will fallback to CPU when run on a machine without TPU.
RUN uv pip install -qq --prerelease=allow .[core,tpu] && uv cache clean
RUN if [ -n "$EXTRAS" ]; then uv pip install -qq .[$EXTRAS] && uv cache clean; fi

COPY jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl .

# 2. RUN the pip install command using the new, simple path *inside* the container
RUN if [ "$INSTALL_PATHWAYS_JAXLIB" = "true" ]; then \
uv pip install --prerelease=allow "jaxlib==0.5.3.dev20250918" \
--find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
# uv pip install --prerelease=allow "jaxlib==0.6.2.dev20251020" \
# --find-links https://storage.googleapis.com/axlearn-wheels/wheels.html; \
uv pip install jaxlib-0.6.2.dev20251021-cp310-cp310-manylinux2014_x86_64.whl; \
fi
COPY . .

################################################################################
# Colocated Python container spec. #
################################################################################

FROM ${BASE_IMAGE_COLOCATED} as colocated-python

WORKDIR /app
COPY . .

# Install the additional user-provided dependencies, strictly enforcing the rules
# from the base image's constraints file.
RUN \
echo "--> Installing user-provided dependencies..." && \
uv pip install ".[core,gcp]" -c /opt/venv/server_constraints.txt && \
\
# 2. Verify that the colocated_python_cpu_client is present.
echo "--> Verifying JAX patch integrity..." && \
python -c "from jax._src.lib import _jax; _jax.colocated_python_cpu_client" && \
echo "--> JAX patch verification successful." && \
\
# 3. Clean the cache to keep the image slim.
uv cache clean


################################################################################
# GPU container spec. #
################################################################################
Expand All @@ -125,4 +155,4 @@ COPY . .
# Final target spec. #
################################################################################

FROM ${TARGET} AS final
FROM ${TARGET} AS final
59 changes: 57 additions & 2 deletions axlearn/cloud/gcp/bundler.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,65 @@ class ArtifactRegistryBundler(DockerBundler):

TYPE = "artifactregistry"

@config_class
class Config(DockerBundler.Config):
"""Configures CloudBuildBundler.

Attributes:
colocated_image_required: Bool to build a colocated image
colocated_image_name: Colocated Image Name
colocated_dockerfile: Colocated Dockerfile
"""
# Build image asynchronously.
colocated_image_required: bool = False
colocated_image_name: str = None
#colocated_dockerfile: str = None


@classmethod
def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> DockerBundler.Config:
cfg = super().from_spec(spec, fv=fv)
cfg: ArtifactRegistryBundler.Config = super().from_spec(spec, fv=fv)
cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv)
cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv)
cfg.colocated_image_required = cfg.colocated_image_required or gcp_settings("colocated_image_required", required=False, fv=fv)
cfg.colocated_image_name = cfg.colocated_image_name or gcp_settings("colocated_image_name", required=False, fv=fv)
#cfg.colocated_dockerfile = cfg.colocated_dockerfile or gcp_settings("colocated_dockerfile", required=False, fv=fv)
return cfg

def _build_and_push(self, *args, **kwargs):
cfg = self.config
subprocess.run(
["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)],
check=True,
)

actual_name = cfg.image
#actual_dockerfile=cfg.dockerfile
actual_target=cfg.target
if bool(cfg.colocated_image_required):

#cfg.dockerfile=cfg.colocated_dockerfile
cfg.image=cfg.colocated_image_name
cfg.target="colocated-python"

colocated_bundler_class = ColocatedArtifactRegistryBundler(cfg=cfg)
colocated_image_name = colocated_bundler_class.bundle(tag=cfg.image)

#cfg.dockerfile=actual_dockerfile
cfg.image=actual_name
cfg.target=actual_target

return super()._build_and_push(*args, **kwargs)


class ColocatedArtifactRegistryBundler(DockerBundler):
"""A DockerBundler that reads configs from gcp_settings, and auths to Artifact Registry."""

@classmethod
def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> DockerBundler.Config:
cfg: ColocatedArtifactRegistryBundler.Config = super().from_spec(spec, fv=fv)
cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv)
cfg.dockerfile = cfg.colocated_dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv)
return cfg

def _build_and_push(self, *args, **kwargs):
Expand All @@ -111,6 +165,7 @@ def _build_and_push(self, *args, **kwargs):
["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)],
check=True,
)

return super()._build_and_push(*args, **kwargs)


Expand Down Expand Up @@ -263,4 +318,4 @@ def with_tpu_extras(bundler: Bundler.Config) -> Bundler.Config:
if __name__ == "__main__":
common_flags()
bundler_main_flags()
app.run(bundler_main)
app.run(bundler_main)
97 changes: 71 additions & 26 deletions axlearn/cloud/gcp/pathways_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
)
from axlearn.common.config import REQUIRED, Required, config_class
from axlearn.common.utils import Nested
from axlearn.cloud.gcp.config import gcp_settings

# The port used by pathways proxy server.
# The specific value is not important, as long as clients and servers use the same port.
Expand All @@ -45,19 +46,25 @@
# The port used by pathways worker server.
# The specific value is not important, as long as clients and servers use the same port.
_PATHWAYS_WORKER_PORT = 29001
_COLOCATED_CONTAINER_PORT = 50051
# Pin to specific pathways image version for stable release.
# There is no guarantee that this image will work with newer Jax releases.
# This image version extends GRPC timeout for long context models, based on jax-0.5.3-patch060625
# This image extends GRPC timeout for long context models.
_PATHWAYS_IMAGE_TAG = "shm_proxy_settings"
_PATHWAYS_IMAGE_TAG = "2025-10-03"

# The docker image used by pathways proxy container.
# pylint: disable=line-too-long
_PATHWAYS_PROXY_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/proxy_server:{_PATHWAYS_IMAGE_TAG}"
# pylint: disable=line-too-long
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/proxy_server:{_PATHWAYS_IMAGE_TAG}"
)
# The docker image used by pathways resource manager container and worker container.
_PATHWAYS_SERVER_IMAGE = (
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways/server:{_PATHWAYS_IMAGE_TAG}"
# pylint: disable=line-too-long
f"us-docker.pkg.dev/cloud-tpu-v2-images/pathways-colocated-python/server:{_PATHWAYS_IMAGE_TAG}"
)

# The container name of pathways resourcemanager.
_PATHWAYS_RESOURCE_MANAGER_CONTAINER_NAME = "pathways-rm"
# The container name of pathways proxy.
Expand All @@ -67,6 +74,9 @@
# The k8s replicatedJob name for pathways-worker pods.
_PATHWAYS_WORKER_REPLICATED_JOB_NAME = "pathways-worker"

_COLOCATED_PYTHON_SIDECAR_NAME = "colocated-python-sidecar"


# Add node-selector for cpu workload to avoid sharing nodes with system services.
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_KEY = "axlearn/nodepool_type"
_PATHWAYS_HEAD_NODE_POOL_SELECTOR_VALUE = "workload"
Expand All @@ -75,6 +85,12 @@
# While workers will share #workers * _PATHWAYS_BACK_OFF_LIMIT total times.
_PATHWAYS_BACK_OFF_LIMIT = 32

FLAGS = flags.FLAGS

def get_colocated_python_image(colocated_image_name, fv: flags.FlagValues = FLAGS) -> str:
repo = gcp_settings("docker_repo", required=False, fv=fv)
return repo+"/"+colocated_image_name+":"+colocated_image_name


def parse_xla_flag_value(value: str) -> Union[int, bool, str]:
"""Attempts to convert an XLA flag string value to int.
Expand Down Expand Up @@ -135,7 +151,6 @@ def get_xla_options(
"""
return {k: v for k, v in xla_options.items() if k.startswith("xla_")}


def round_up_to_power_of_2(n):
"""
Rounds an integer up to the nearest power of 2.
Expand Down Expand Up @@ -173,6 +188,7 @@ class Config(BaseReplicatedJob.Config):
pathways_xla_flags: list[str] = []
pathways_head_cpu: Optional[str] = None
pathways_head_mem: Optional[str] = None
colocated_image: Optional[str] = None

@classmethod
def define_flags(cls, fv):
Expand Down Expand Up @@ -201,12 +217,19 @@ def define_flags(cls, fv):
"Memory request for pathways-head container in GiB. Default is 16GiB",
**common_kwargs,
)
flags.DEFINE_string(
"colocated_image",
None,
"Colocated Image Name",
**common_kwargs,
)

@classmethod
def set_defaults(cls, fv):
super().set_defaults(fv)
fv.set_default("pathways_head_cpu", fv.pathways_head_cpu or "1")
fv.set_default("pathways_head_mem", fv.pathways_head_mem or "16")
fv.set_default("colocated_image", fv.colocated_image or None)

@classmethod
def default_config(cls):
Expand Down Expand Up @@ -311,29 +334,29 @@ def _build_pathways_head_container(self) -> dict:
}
)

# pylint: disable=line-too-long
env_list.append(
{
"name": "NUM_REPLICAS",
"valueFrom": {
"fieldRef": {
"fieldPath": "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-replicas']"
}
},
}
)
# # pylint: disable=line-too-long
# env_list.append(
# {
# "name": "NUM_REPLICAS",
# "valueFrom": {
# "fieldRef": {
# "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/replicatedjob-replicas']"
# }
# },
# }
# )
# pylint: enable=line-too-long

env_list.append(
{
"name": "REPLICA_ID",
"valueFrom": {
"fieldRef": {
"fieldPath": "metadata.annotations['jobset.sigs.k8s.io/job-index']"
}
},
}
)
# env_list.append(
# {
# "name": "REPLICA_ID",
# "valueFrom": {
# "fieldRef": {
# "fieldPath": "metadata.annotations['jobset.sigs.k8s.io/job-index']"
# }
# },
# }
# )

head_container["env"] = env_list

Expand Down Expand Up @@ -373,6 +396,7 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
f"--resource_manager_address=localhost:{_PATHWAYS_RESOURCE_MANAGER_PORT}",
f"--server_port={_PATHWAYS_PROXY_PORT}",
f"--gcs_scratch_location={staging_location}",
"--sidecar_name=external",
]
cmd_args.extend(xla_flags_from_options(self._xla_options).split())

Expand Down Expand Up @@ -426,6 +450,22 @@ def _build_pathways_head_sidecar_containers(self) -> list[Nested[Any]]:
),
]

def _colocated_python_container(self):
cfg: PathwaysReplicatedJob.Config = self.config
return dict(
name=_COLOCATED_PYTHON_SIDECAR_NAME,
image=get_colocated_python_image(cfg.colocated_image),
restartPolicy="Always",
env=[
{
"name": "GRPC_SERVER_ADDRESS",
"value": f"0.0.0.0:{_COLOCATED_CONTAINER_PORT}",
},
],
imagePullPolicy="Always",
ports=[dict(containerPort=_COLOCATED_CONTAINER_PORT)],
)

def _build_pathways_head_pod(self) -> Nested[Any]:
"""Builds a pathways head pod. The pod includes a head container,
a proxy container and a resource manager container.
Expand Down Expand Up @@ -604,6 +644,7 @@ def _build_pathways_worker_pod(
) -> Nested[Any]:
"""Conoverts a worker pod to a new pod for the 'pathways-workers' role."""
cfg: TPUReplicatedJob.Config = self._inner.config
pathways_cfg: PathwaysReplicatedJob.Config = self.config
# pylint: disable-next=protected-access
pod = self._inner._build_pod()
worker_pod = copy.deepcopy(pod)
Expand All @@ -619,6 +660,10 @@ def _build_pathways_worker_pod(
pod_spec["containers"] = [
self._build_pathways_worker_container(pathways_worker_replicated_job_index)
]

if pathways_cfg.colocated_image:
pod_spec["initContainers"] = [self._colocated_python_container()]

worker_pod["spec"] = pod_spec

# Service account for nodes.
Expand Down Expand Up @@ -1056,4 +1101,4 @@ def __call__(self) -> Nested[Any]:
size=system.vms_per_slice + 1,
leaderTemplate=self.build_leader_pod(),
workerTemplate=self.build_worker_pod(),
)
)
4 changes: 4 additions & 0 deletions axlearn/common/array_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,7 @@ async def _async_serialize(
)
# pylint: disable=protected-access
spec_has_metadata = {
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._spec_has_metadata,
"0.6.2": lambda: serialization.ts_impl._spec_has_metadata,
"0.5.3": lambda: serialization._spec_has_metadata,
}[jax.__version__]()
Expand Down Expand Up @@ -487,6 +488,7 @@ async def cb(index: array.Index, device: jax.Device):
requested_domain = ts.IndexTransform(input_shape=shape)[index].domain
restricted_domain = t.domain.intersect(requested_domain)
estimate_read_memory_footprint = {
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.estimate_read_memory_footprint,
"0.6.2": lambda: serialization.ts_impl.estimate_read_memory_footprint,
"0.5.3": lambda: serialization.estimate_read_memory_footprint,
}[jax.__version__]()
Expand Down Expand Up @@ -568,6 +570,7 @@ async def cb(index: array.Index, device: jax.Device):

# pylint: disable=protected-access
create_async_array_from_callback = {
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl._create_async_array_from_callback,
"0.6.2": lambda: serialization.ts_impl._create_async_array_from_callback,
"0.5.3": lambda: serialization.create_async_array_from_callback,
}[jax.__version__]()
Expand Down Expand Up @@ -653,6 +656,7 @@ def serialize(
commit_futures = [[] for _ in range(len(tensorstore_specs))]

async_serialize = {
"0.6.2.dev0+selfbuilt": lambda: serialization.ts_impl.async_serialize,
"0.6.2": lambda: serialization.ts_impl.async_serialize,
"0.5.3": lambda: serialization.async_serialize,
}[jax.__version__]()
Expand Down
Loading
Loading