|
| 1 | +""" |
| 2 | +Doctor checks for essential programs needed to launch distributed training. |
| 3 | +""" |
| 4 | + |
| 5 | +import getpass |
| 6 | +import grp |
| 7 | +import json |
| 8 | +import os |
| 9 | +import subprocess |
| 10 | +import sys |
| 11 | +from pathlib import Path |
| 12 | + |
| 13 | +import click |
| 14 | + |
| 15 | + |
class CheckFailedError(Exception):
  """Raised by a doctor check when the environment does not pass that check.

  The exception message is the user-facing explanation of what is wrong and,
  where possible, how to fix it; `check_all` echoes it before exiting.
  """
| 18 | + |
| 19 | + |
def check_docker():
  """Check that docker is installed."""
  # NOTE: `check_all` echoes the docstring above verbatim as the progress
  # label, so it must stay a single short line.
  #
  # Raises:
  #   CheckFailedError: if the `docker` executable is missing from PATH, or if
  #     `docker help` exits with a non-zero status.
  try:
    subprocess.run(["docker", "help"], check=True, capture_output=True)
  except FileNotFoundError:
    raise CheckFailedError("docker not found. Please install docker first.") from None
  except subprocess.CalledProcessError as e:
    # `check=True` turns a non-zero exit into CalledProcessError; report it as
    # a failed check instead of letting a traceback escape the doctor.
    stderr = (e.stderr or b"").decode(errors="replace").strip()
    raise CheckFailedError(f"`docker help` failed: {stderr}") from e
| 26 | + |
| 27 | + |
def check_gcr_io():
  """Check that docker config contains gcr.io credential helper."""
  # NOTE: `check_all` echoes the docstring above verbatim as the progress
  # label, so it must stay a single short line.
  #
  # Reads `~/.docker/config.json` and looks for a `gcr.io` entry under
  # `credHelpers`.
  #
  # Raises:
  #   CheckFailedError: if the file is missing/unreadable, is not valid JSON,
  #     does not have the expected mapping shape, or has no `gcr.io` helper.
  config_path = Path(os.path.expanduser("~/.docker/config.json"))
  try:
    docker_config = json.loads(config_path.read_text(encoding="utf-8"))
  except (OSError, ValueError):
    # OSError: missing or unreadable file. ValueError covers both
    # json.JSONDecodeError and UnicodeDecodeError.
    docker_config = None

  # Valid JSON is not necessarily a mapping (e.g. `[]`); treat any unexpected
  # shape as "no helper configured" rather than crashing with a TypeError.
  cred_helpers = (
    docker_config.get("credHelpers") if isinstance(docker_config, dict) else None
  )
  if isinstance(cred_helpers, dict) and "gcr.io" in cred_helpers:
    return

  # Suggest `sudo` when the user is not listed as a member of the `docker`
  # group.
  user = getpass.getuser()
  groups_for_user = [g.gr_name for g in grp.getgrall() if user in g.gr_mem]
  setup_cmd = "gcloud auth configure-docker"
  if "docker" not in groups_for_user:
    setup_cmd = f"sudo {setup_cmd}"
  raise CheckFailedError(
    f"""
Did not find a handler for `gcr.io` in docker credential helpers.

TorchPrime uploads docker containers to the `gcr.io` docker registry, which
requires valid credentials.

To setup the credentials, please run:

  {setup_cmd}

""".lstrip()
  )
| 55 | + |
| 56 | + |
def check_docker_access():
  """Check that the gcloud account can access the gcr.io artifact registry."""
  # NOTE: `check_all` echoes the docstring above verbatim as the progress
  # label, so it must stay a single short line.
  #
  # Raises:
  #   CheckFailedError: if `gcloud` is not on PATH, or if describing the
  #     `gcr.io` repository exits with a non-zero status.
  try:
    subprocess.run(
      ["gcloud", "artifacts", "repositories", "describe", "gcr.io", "--location=us"],
      check=True,
      capture_output=True,
    )
  except FileNotFoundError:
    # Without this, a missing gcloud binary escapes as a raw traceback.
    raise CheckFailedError(
      "gcloud not found. Please install the Google Cloud CLI first."
    ) from None
  except subprocess.CalledProcessError as e:
    # Best effort: name the active account in the message. No `check=True`
    # here, so a failing lookup just yields an empty account name.
    account = subprocess.run(
      ["gcloud", "config", "get-value", "account"], capture_output=True, text=True
    ).stdout.strip()
    raise CheckFailedError(
      f"""The current gcloud account `{account}` cannot access the gcr.io registry.
The account may not have the required permissions. If it's a service account, the
VM may not have the correct scopes.

The easiest way to resolve this is to login with your own account:

  gcloud auth login

"""
    ) from e
| 80 | + |
| 81 | + |
def check_gcloud_auth_login():
  """Check that gcloud is logged in."""
  # NOTE: `check_all` echoes the docstring above verbatim as the progress
  # label, so it must stay a single short line.
  #
  # Raises:
  #   CheckFailedError: if `gcloud` is not on PATH, or if
  #     `gcloud auth print-access-token` exits with a non-zero status.
  try:
    subprocess.run(
      ["gcloud", "auth", "print-access-token"], check=True, capture_output=True
    )
  except FileNotFoundError:
    # This is the first gcloud-based check `check_all` runs, so a missing
    # binary must become a readable failure rather than a traceback.
    raise CheckFailedError(
      "gcloud not found. Please install the Google Cloud CLI first."
    ) from None
  except subprocess.CalledProcessError as e:
    # stderr is bytes because the command runs without `text=True`.
    raise CheckFailedError(
      f"gcloud auth print-access-token failed: {e.stderr.decode(errors='replace')}"
    ) from e
| 92 | + |
| 93 | + |
def check_kubectl():
  """Check that kubectl is installed."""
  # NOTE: `check_all` echoes the docstring above verbatim as the progress
  # label, so it must stay a single short line.
  #
  # Raises:
  #   CheckFailedError: if the `kubectl` executable is missing from PATH (the
  #     message then includes install instructions), or if `kubectl help`
  #     exits with a non-zero status.
  try:
    subprocess.run(["kubectl", "help"], check=True, capture_output=True)
  except FileNotFoundError:
    raise CheckFailedError(
      f"""kubectl not found.

{get_kubectl_install_instructions()}"""
    ) from None
  except subprocess.CalledProcessError as e:
    # `check=True` turns a non-zero exit into CalledProcessError; report it as
    # a failed check instead of letting a traceback escape the doctor.
    stderr = (e.stderr or b"").decode(errors="replace").strip()
    raise CheckFailedError(f"`kubectl help` failed: {stderr}") from e
| 104 | + |
| 105 | + |
def check_gke_gcloud_auth_plugin():
  """Check that gke-gcloud-auth-plugin is installed."""
  # NOTE: `check_all` echoes the docstring above verbatim as the progress
  # label, so it must stay a single short line.
  #
  # Raises:
  #   CheckFailedError: if the plugin is not reported as installed; the
  #     message includes install instructions.
  if not is_gcloud_plugin_installed("gke-gcloud-auth-plugin"):
    # Instructions are only computed on the failure path.
    instructions = get_gke_gcloud_auth_plugin_instructions()
    raise CheckFailedError(
      "The `gke-gcloud-auth-plugin` gcloud component is not installed\n"
      "\n"
      f"{instructions}"
    )
| 115 | + |
| 116 | + |
def check_all():
  """Run every doctor check in order, stopping at the first failure.

  Before each check, its docstring is echoed as a progress label. If a check
  raises `CheckFailedError`, the check name and the error are echoed and the
  process exits via `sys.exit(-1)`. Otherwise a success line is echoed at the
  end.
  """
  checks = (
    check_docker,
    check_gcloud_auth_login,
    check_gcr_io,
    check_docker_access,
    check_kubectl,
    check_gke_gcloud_auth_plugin,
  )

  click.echo("Checking environment...")
  for check in checks:
    # The docstring doubles as the human-readable progress label.
    label = check.__doc__
    assert label is not None
    click.echo(label + "..", nl=False)
    try:
      check()
    except CheckFailedError as failure:
      # Finish the unterminated progress line and leave one blank line.
      click.echo()
      click.echo()
      click.echo(f"❌ Error during {check.__name__} ❌")
      click.echo(failure)
      sys.exit(-1)
    else:
      click.echo(" ✅")

  click.echo(
    "🎉 All checks passed. You should be ready to launch distributed training. 🎉"
  )
| 141 | + |
| 142 | + |
def get_kubectl_install_instructions():
  """Return user-facing instructions for installing `kubectl`.

  Recommends `apt` when the `google-cloud-cli` package is reported as
  installed (so `kubectl` is installed the same way as `gcloud`); otherwise
  points to the GKE documentation.
  """
  if not is_package_installed("google-cloud-cli"):
    # gcloud did not come from `apt`: send users to the GKE docs.
    return (
      "Please visit "
      "https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#install_kubectl"
    )

  return (
    "Since `gcloud` is installed with `apt`, please install `kubectl` with:\n"
    "\n"
    "  sudo apt install kubectl\n"
    "\n"
  )
| 156 | + |
| 157 | + |
def get_gke_gcloud_auth_plugin_instructions():
  """Return user-facing instructions for installing `gke-gcloud-auth-plugin`.

  Recommends `apt` when the `google-cloud-cli` package is reported as
  installed (so the plugin is installed the same way as `gcloud`); otherwise
  points to the GKE documentation.
  """
  if not is_package_installed("google-cloud-cli"):
    # gcloud did not come from `apt`: send users to the docs.
    return (
      "Please visit "
      "https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#install_plugin"
    )

  return (
    "Since `gcloud` is installed with `apt`, please install "
    "`gke-gcloud-auth-plugin` with:\n"
    "\n"
    "  sudo apt install google-cloud-sdk-gke-gcloud-auth-plugin\n"
    "\n"
  )
| 171 | + |
| 172 | + |
def is_package_installed(package_name):
  """Return True if the Debian package `package_name` is currently installed.

  Queries `dpkg-query -W` for the package status. Returns False when
  `dpkg-query` itself is unavailable (non-Debian systems), when the package is
  unknown to dpkg, or when dpkg knows the package but it is not in the
  `installed` state (for example removed with config files left behind).
  """
  try:
    result = subprocess.run(
      # No shell is involved, so the format must not be wrapped in quotes.
      # The escaped `\n` is expanded by dpkg-query: one status line per match.
      ["dpkg-query", "-W", "-f=${Status}\\n", package_name],
      check=True,
      capture_output=True,
      text=True,
    )
  except (FileNotFoundError, subprocess.CalledProcessError):
    # FileNotFoundError: no dpkg on this system, so nothing came from apt.
    # CalledProcessError: dpkg has no record of the package.
    return False

  # Status is "<want> <flag> <state>", e.g. "install ok installed". Exit code 0
  # alone is not enough: removed packages report e.g. "deinstall ok
  # config-files".
  return any(
    line.split()[-1:] == ["installed"] for line in result.stdout.splitlines()
  )
| 184 | + |
| 185 | + |
def is_gcloud_plugin_installed(plugin_name):
  """Return True if `gcloud components list` reports `plugin_name` as installed.

  Runs `gcloud components list --format=json --filter=<plugin_name>` and looks
  for an entry whose `id` equals `plugin_name` and whose `state` is either the
  string "Installed" or a mapping whose `name` is "Installed".

  Returns False (after printing a short diagnostic) when gcloud is missing,
  the command fails, or its output is not valid JSON.
  """
  # NOTE(review): this relies solely on the gcloud component manager; whether
  # that command works for apt-managed gcloud installs should be confirmed.
  try:
    result = subprocess.run(
      ["gcloud", "components", "list", "--format=json", f"--filter={plugin_name}"],
      check=True,
      capture_output=True,
      text=True,
    )
    components = json.loads(result.stdout)
  except FileNotFoundError:
    # Missing gcloud binary: report "not installed" instead of a traceback.
    print("Error running gcloud command: gcloud not found.")
    return False
  except subprocess.CalledProcessError as e:
    print(f"Error running gcloud command: {e.stderr}")
    return False
  except json.JSONDecodeError:
    print("Error parsing JSON output from gcloud.")
    return False

  # Be defensive about the JSON shape: only a list of mappings is meaningful.
  if not isinstance(components, list):
    return False
  for component in components:
    if not isinstance(component, dict) or component.get("id") != plugin_name:
      continue
    state = component.get("state")
    if isinstance(state, dict):
      state = state.get("name")
    if state == "Installed":
      return True
  return False
| 211 | + |
| 212 | + |
# Script entry point: run the full set of doctor checks when this module is
# executed directly rather than imported.
if __name__ == "__main__":
  check_all()
0 commit comments