Skip to content

Commit f446c16

Browse files
committed
working colocated example
1 parent 8f4bc2e commit f446c16

File tree

8 files changed

+663
-86
lines changed

8 files changed

+663
-86
lines changed

Dockerfile.colocated

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
# Use the JAX image with the custom-built sidecar as the base.
2+
3+
FROM gcr.io/cloud-tpu-multipod-dev/sujinesh_sidecar_debug@sha256:9c9bea3836db6ef736abcdf1838074d8ba18e55e6c59272f19a1d15b1f9819aa
4+
# Defines a build argument for the requirements file. This contains the user's custom
5+
6+
# Set the working directory (this is already inherited)
7+
WORKDIR /app
8+
9+
# Copy the user's requirements file into the image.
10+
COPY . .
11+
12+
# Install the additional user-provided dependencies, strictly enforcing the rules
13+
# from the base image's constraints file.
14+
# First, generate the requirements and a custom constraints file in its own layer for better caching.
15+
RUN \
16+
# Install toml, generate the requirements files, and then clean up.
17+
pip install toml && \
18+
python3 colocated/extract_deps.py && \
19+
pip uninstall -y toml && \
20+
\
21+
# Create a new constraints file that includes the correct jaxlib version.
22+
JAXLIB_VERSION=$(pip show jaxlib | grep Version | awk '{print $2}') && \
23+
cat /opt/venv/server_constraints.txt > /tmp/constraints.txt && \
24+
echo "jaxlib==$JAXLIB_VERSION" >> /tmp/constraints.txt
25+
26+
# Now, install the dependencies and the axlearn package.
27+
RUN \
28+
# Capture the initial jax and jaxlib versions from the base image.
29+
JAX_VERSION_BEFORE=$(pip show jax | grep Version | awk '{print $2}') && \
30+
JAXLIB_VERSION_BEFORE=$(pip show jaxlib | grep Version | awk '{print $2}') && \
31+
\
32+
# Install the main dependencies, using our custom constraints.
33+
uv pip install -r /tmp/requirements.txt -c /tmp/constraints.txt && \
34+
\
35+
# Install the JAX-dependent packages without resolving their dependencies.
36+
uv pip install --no-deps -r /tmp/requirements_nodeps.txt && \
37+
\
38+
# Install the axlearn package itself, without its dependencies.
39+
uv pip install --no-deps . && \
40+
\
41+
# Clean up the temporary files and the cache.
42+
rm /tmp/requirements.txt && \
43+
rm /tmp/requirements_nodeps.txt && \
44+
rm /tmp/constraints.txt && \
45+
uv cache clean && \
46+
\
47+
# Capture the final jax and jaxlib versions.
48+
JAX_VERSION_AFTER=$(pip show jax | grep Version | awk '{print $2}') && \
49+
JAXLIB_VERSION_AFTER=$(pip show jaxlib | grep Version | awk '{print $2}') && \
50+
\
51+
# Verify that the versions have not changed.
52+
if [ "$JAX_VERSION_BEFORE" != "$JAX_VERSION_AFTER" ] || [ "$JAXLIB_VERSION_BEFORE" != "$JAXLIB_VERSION_AFTER" ]; then \
53+
echo "ERROR: jax or jaxlib version changed!" >&2; \
54+
echo "jax version before: $JAX_VERSION_BEFORE, after: $JAX_VERSION_AFTER" >&2; \
55+
echo "jaxlib version before: $JAXLIB_VERSION_BEFORE, after: $JAXLIB_VERSION_AFTER" >&2; \
56+
exit 1; \
57+
fi
58+
59+
# Note: The ENTRYPOINT and CMD are inherited from the base image, so they do not
60+
# need to be redefined here. I.e. the sidecar will be launched automatically.

axlearn/cloud/gcp/bundler.py

Lines changed: 92 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848

4949
import os
5050
import subprocess
51+
import time
5152
from typing import Optional
5253

5354
from absl import app, flags, logging
@@ -58,10 +59,10 @@
5859
from axlearn.cloud.common.bundler import register_bundler
5960
from axlearn.cloud.common.docker import registry_from_repo
6061
from axlearn.cloud.common.utils import canonicalize_to_list, to_bool
61-
from axlearn.cloud.gcp.cloud_build import wait_for_cloud_build
62+
from axlearn.cloud.gcp.cloud_build import get_cloud_build_status
6263
from axlearn.cloud.gcp.config import gcp_settings
6364
from axlearn.cloud.gcp.utils import common_flags
64-
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config
65+
from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config, config_for_class
6566

6667
FLAGS = flags.FLAGS
6768

@@ -98,19 +99,77 @@ class ArtifactRegistryBundler(DockerBundler):
9899

99100
TYPE = "artifactregistry"
100101

102+
@config_class
103+
class Config(DockerBundler.Config):
104+
"""Configures CloudBuildBundler.
105+
106+
Attributes:
107+
colocated_image_required: Bool to build a colocated image
108+
"""
109+
# Build image asynchronously.
110+
colocated_image_required: bool = False
111+
colocated_image_name: str = None
112+
colocated_dockerfile: str = None
113+
114+
101115
@classmethod
102116
def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> DockerBundler.Config:
103-
cfg = super().from_spec(spec, fv=fv)
117+
cfg: ArtifactRegistryBundler.Config = super().from_spec(spec, fv=fv)
104118
cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv)
105119
cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv)
120+
cfg.colocated_image_required = cfg.colocated_image_required or gcp_settings("colocated_image_required", required=False, fv=fv)
121+
cfg.colocated_image_name = cfg.colocated_image_name or gcp_settings("colocated_image_name", required=False, fv=fv)
122+
cfg.colocated_dockerfile = cfg.colocated_dockerfile or gcp_settings("colocated_dockerfile", required=False, fv=fv)
123+
return cfg
124+
125+
def _build_and_push(self, *args, **kwargs):
126+
cfg = self.config
127+
subprocess.run(
128+
["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)],
129+
check=True,
130+
)
131+
132+
print("actual",cfg)
133+
actual_name = cfg.image
134+
actual_dockerfile=cfg.dockerfile
135+
actual_target=cfg.target
136+
if bool(cfg.colocated_image_required):
137+
138+
cfg.dockerfile=cfg.colocated_dockerfile
139+
cfg.image=cfg.colocated_image_name
140+
cfg.target=None
141+
print("updated config: ",cfg)
142+
colocated_bundler_class = ColocatedArtifactRegistryBundler(cfg=cfg)
143+
colocated_image_name = colocated_bundler_class.bundle(tag="latest")
144+
print(colocated_image_name)
145+
146+
cfg.dockerfile=actual_dockerfile
147+
cfg.image=actual_name
148+
cfg.target=actual_target
149+
150+
151+
152+
return super()._build_and_push(*args, **kwargs)
153+
154+
155+
class ColocatedArtifactRegistryBundler(DockerBundler):
156+
"""A DockerBundler that reads configs from gcp_settings, and auths to Artifact Registry."""
157+
158+
@classmethod
159+
def from_spec(cls, spec: list[str], *, fv: Optional[flags.FlagValues]) -> DockerBundler.Config:
160+
cfg: ColocatedArtifactRegistryBundler.Config = super().from_spec(spec, fv=fv)
161+
cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv)
162+
cfg.dockerfile = cfg.colocated_dockerfile or gcp_settings("colocated_dockerfile", required=False, fv=fv)
106163
return cfg
107164

108165
def _build_and_push(self, *args, **kwargs):
109166
cfg = self.config
167+
print("colocated",cfg)
110168
subprocess.run(
111169
["gcloud", "auth", "configure-docker", registry_from_repo(cfg.repo)],
112170
check=True,
113171
)
172+
114173
return super()._build_and_push(*args, **kwargs)
115174

116175

@@ -237,14 +296,36 @@ def wait_until_finished(self, name: str, wait_timeout=3600):
237296
TimeoutError: If the build does not complete within the overall timeout.
238297
ValueError: If the async build fails.
239298
"""
299+
start_time = time.perf_counter()
240300
cfg: CloudBuildBundler.Config = self.config
241-
if cfg.is_async:
242-
wait_for_cloud_build(
243-
project_id=cfg.project,
244-
image_id=self.id(name),
245-
tags=[name],
246-
wait_timeout=wait_timeout,
247-
)
301+
while cfg.is_async:
302+
elapsed_time = time.perf_counter() - start_time
303+
if elapsed_time > wait_timeout:
304+
timeout_msg = (
305+
f"Timed out waiting for CloudBuild to finish for more than "
306+
f"{wait_timeout} seconds."
307+
)
308+
logging.error(timeout_msg)
309+
raise TimeoutError(timeout_msg)
310+
try:
311+
build_status = get_cloud_build_status(
312+
project_id=cfg.project, image_name=self.id(name), tags=[name]
313+
)
314+
except Exception as e: # pylint: disable=broad-except
315+
# TODO(liang-he,markblee): Distinguish transient from non-transient errors.
316+
logging.warning("Failed to get the CloudBuild status, will retry: %s", e)
317+
else:
318+
if not build_status:
319+
logging.warning("CloudBuild for %s does not exist yet.", name)
320+
elif build_status.is_pending():
321+
logging.info("CloudBuild for %s is pending: %s.", name, build_status)
322+
elif build_status.is_success():
323+
logging.info("CloudBuild for %s is successful: %s.", name, build_status)
324+
return
325+
else:
326+
# Unknown status is also considered a failure.
327+
raise RuntimeError(f"CloudBuild for {name} failed: {build_status}.")
328+
time.sleep(30)
248329

249330

250331
def with_tpu_extras(bundler: Bundler.Config) -> Bundler.Config:
@@ -263,4 +344,4 @@ def with_tpu_extras(bundler: Bundler.Config) -> Bundler.Config:
263344
if __name__ == "__main__":
264345
common_flags()
265346
bundler_main_flags()
266-
app.run(bundler_main)
347+
app.run(bundler_main)

0 commit comments

Comments
 (0)