4848
4949import os
5050import subprocess
51+ import time
5152from typing import Optional
5253
5354from absl import app , flags , logging
5859from axlearn .cloud .common .bundler import register_bundler
5960from axlearn .cloud .common .docker import registry_from_repo
6061from axlearn .cloud .common .utils import canonicalize_to_list , to_bool
61- from axlearn .cloud .gcp .cloud_build import wait_for_cloud_build
62+ from axlearn .cloud .gcp .cloud_build import get_cloud_build_status
6263from axlearn .cloud .gcp .config import gcp_settings
6364from axlearn .cloud .gcp .utils import common_flags
64- from axlearn .common .config import REQUIRED , Required , config_class , maybe_set_config
65+ from axlearn .common .config import REQUIRED , Required , config_class , maybe_set_config , config_for_class
6566
6667FLAGS = flags .FLAGS
6768
@@ -98,19 +99,77 @@ class ArtifactRegistryBundler(DockerBundler):
9899
99100 TYPE = "artifactregistry"
100101
102+ @config_class
103+ class Config (DockerBundler .Config ):
104+ """Configures CloudBuildBundler.
105+
106+ Attributes:
107+ colocated_image_required: Bool to build a colocated image
108+ """
109+ # Build image asynchronously.
110+ colocated_image_required : bool = False
111+ colocated_image_name : str = None
112+ colocated_dockerfile : str = None
113+
114+
101115 @classmethod
102116 def from_spec (cls , spec : list [str ], * , fv : Optional [flags .FlagValues ]) -> DockerBundler .Config :
103- cfg = super ().from_spec (spec , fv = fv )
117+ cfg : ArtifactRegistryBundler . Config = super ().from_spec (spec , fv = fv )
104118 cfg .repo = cfg .repo or gcp_settings ("docker_repo" , required = False , fv = fv )
105119 cfg .dockerfile = cfg .dockerfile or gcp_settings ("default_dockerfile" , required = False , fv = fv )
120+ cfg .colocated_image_required = cfg .colocated_image_required or gcp_settings ("colocated_image_required" , required = False , fv = fv )
121+ cfg .colocated_image_name = cfg .colocated_image_name or gcp_settings ("colocated_image_name" , required = False , fv = fv )
122+ cfg .colocated_dockerfile = cfg .colocated_dockerfile or gcp_settings ("colocated_dockerfile" , required = False , fv = fv )
123+ return cfg
124+
125+ def _build_and_push (self , * args , ** kwargs ):
126+ cfg = self .config
127+ subprocess .run (
128+ ["gcloud" , "auth" , "configure-docker" , registry_from_repo (cfg .repo )],
129+ check = True ,
130+ )
131+
132+ print ("actual" ,cfg )
133+ actual_name = cfg .image
134+ actual_dockerfile = cfg .dockerfile
135+ actual_target = cfg .target
136+ if bool (cfg .colocated_image_required ):
137+
138+ cfg .dockerfile = cfg .colocated_dockerfile
139+ cfg .image = cfg .colocated_image_name
140+ cfg .target = None
141+ print ("updated config: " ,cfg )
142+ colocated_bundler_class = ColocatedArtifactRegistryBundler (cfg = cfg )
143+ colocated_image_name = colocated_bundler_class .bundle (tag = "latest" )
144+ print (colocated_image_name )
145+
146+ cfg .dockerfile = actual_dockerfile
147+ cfg .image = actual_name
148+ cfg .target = actual_target
149+
150+
151+
152+ return super ()._build_and_push (* args , ** kwargs )
153+
154+
155+ class ColocatedArtifactRegistryBundler (DockerBundler ):
156+ """A DockerBundler that reads configs from gcp_settings, and auths to Artifact Registry."""
157+
158+ @classmethod
159+ def from_spec (cls , spec : list [str ], * , fv : Optional [flags .FlagValues ]) -> DockerBundler .Config :
160+ cfg : ColocatedArtifactRegistryBundler .Config = super ().from_spec (spec , fv = fv )
161+ cfg .repo = cfg .repo or gcp_settings ("docker_repo" , required = False , fv = fv )
162+ cfg .dockerfile = cfg .colocated_dockerfile or gcp_settings ("colocated_dockerfile" , required = False , fv = fv )
106163 return cfg
107164
108165 def _build_and_push (self , * args , ** kwargs ):
109166 cfg = self .config
167+ print ("colocated" ,cfg )
110168 subprocess .run (
111169 ["gcloud" , "auth" , "configure-docker" , registry_from_repo (cfg .repo )],
112170 check = True ,
113171 )
172+
114173 return super ()._build_and_push (* args , ** kwargs )
115174
116175
@@ -237,14 +296,36 @@ def wait_until_finished(self, name: str, wait_timeout=3600):
237296 TimeoutError: If the build does not complete within the overall timeout.
238297 ValueError: If the async build fails.
239298 """
299+ start_time = time .perf_counter ()
240300 cfg : CloudBuildBundler .Config = self .config
241- if cfg .is_async :
242- wait_for_cloud_build (
243- project_id = cfg .project ,
244- image_id = self .id (name ),
245- tags = [name ],
246- wait_timeout = wait_timeout ,
247- )
301+ while cfg .is_async :
302+ elapsed_time = time .perf_counter () - start_time
303+ if elapsed_time > wait_timeout :
304+ timeout_msg = (
305+ f"Timed out waiting for CloudBuild to finish for more than "
306+ f"{ wait_timeout } seconds."
307+ )
308+ logging .error (timeout_msg )
309+ raise TimeoutError (timeout_msg )
310+ try :
311+ build_status = get_cloud_build_status (
312+ project_id = cfg .project , image_name = self .id (name ), tags = [name ]
313+ )
314+ except Exception as e : # pylint: disable=broad-except
315+ # TODO(liang-he,markblee): Distinguish transient from non-transient errors.
316+ logging .warning ("Failed to get the CloudBuild status, will retry: %s" , e )
317+ else :
318+ if not build_status :
319+ logging .warning ("CloudBuild for %s does not exist yet." , name )
320+ elif build_status .is_pending ():
321+ logging .info ("CloudBuild for %s is pending: %s." , name , build_status )
322+ elif build_status .is_success ():
323+ logging .info ("CloudBuild for %s is successful: %s." , name , build_status )
324+ return
325+ else :
326+ # Unknown status is also considered a failure.
327+ raise RuntimeError (f"CloudBuild for { name } failed: { build_status } ." )
328+ time .sleep (30 )
248329
249330
250331def with_tpu_extras (bundler : Bundler .Config ) -> Bundler .Config :
@@ -263,4 +344,4 @@ def with_tpu_extras(bundler: Bundler.Config) -> Bundler.Config:
263344if __name__ == "__main__" :
264345 common_flags ()
265346 bundler_main_flags ()
266- app .run (bundler_main )
347+ app .run (bundler_main )
0 commit comments