Skip to content

Commit a908b1b

Browse files
authored
Add a tp doctor command (#56)
* tp doctor * check for kubectl * Implement gke-gcloud-auth-plugin check * Fine tune checks * Implement docker check * Implement docker access check * update msg
1 parent b5e079c commit a908b1b

File tree

3 files changed

+279
-10
lines changed

3 files changed

+279
-10
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ dev = [
3434
"dataclasses-json~=0.6.7",
3535
"watchdog~=6.0.0",
3636
"pathspec~=0.12.1",
37-
"xpk@git+https://github.com/AI-Hypercomputer/xpk"
37+
"xpk@git+https://github.com/AI-Hypercomputer/xpk@f33d1a6772c3ca73dd68e45b226621534a39b2a5"
3838
]
3939

4040
[project.scripts]

torchprime/launcher/cli.py

Lines changed: 64 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from watchdog.events import FileSystemEventHandler
2121
from watchdog.observers import Observer
2222

23+
import torchprime.launcher.doctor
2324
from torchprime.launcher.buildpush import buildpush
2425

2526

@@ -41,12 +42,17 @@ def wrapper(ctx, *args, **kwargs):
4142
return run_with_watcher(ctx)(f)(*args, **kwargs)
4243

4344
wrapper.__name__ = f.__name__
45+
wrapper.__doc__ = f.__doc__
4446
return wrapper
4547

4648

4749
@click.group()
4850
@click.option(
49-
"-i", "--interactive", is_flag=True, default=False, help="Enable shouting mode."
51+
"-i",
52+
"--interactive",
53+
is_flag=True,
54+
default=False,
55+
help="Re-run the command whenever a file is edited (useful for fast dev/test iteration)",
5056
)
5157
@click.pass_context
5258
def cli(ctx, interactive):
@@ -119,9 +125,11 @@ def use(
119125

120126
path = write_config(config)
121127
click.echo(f"Written config {path.relative_to(os.getcwd())}")
128+
torchprime.launcher.doctor.check_all()
122129

123130

124131
def create_and_activate_gcloud(gcloud_config_name, config: Config):
132+
click.echo("Activating gcloud config...")
125133
ensure_command("gcloud")
126134
all_configurations = json.loads(
127135
subprocess.check_output(
@@ -134,38 +142,55 @@ def create_and_activate_gcloud(gcloud_config_name, config: Config):
134142
if gcloud_config["name"] == gcloud_config_name:
135143
existing = True
136144
break
145+
runner = CommandRunner()
137146
if existing:
138-
subprocess.check_output(
147+
runner.run(
139148
[
140149
"gcloud",
141150
"config",
142151
"configurations",
143152
"activate",
144153
gcloud_config_name,
145-
]
154+
],
146155
)
147156
else:
148-
subprocess.check_output(
149-
["gcloud", "config", "configurations", "create", gcloud_config_name, "--activate"]
157+
runner.run(
158+
[
159+
"gcloud",
160+
"config",
161+
"configurations",
162+
"create",
163+
gcloud_config_name,
164+
"--activate",
165+
],
150166
)
151167

152-
subprocess.check_output(
168+
runner.run(
169+
[
170+
"gcloud",
171+
"auth",
172+
"application-default",
173+
"set-quota-project",
174+
config.project,
175+
],
176+
)
177+
runner.run(
153178
[
154179
"gcloud",
155180
"config",
156181
"set",
157182
"compute/zone",
158183
config.zone,
159-
]
184+
],
160185
)
161-
subprocess.check_output(
186+
runner.run(
162187
[
163188
"gcloud",
164189
"config",
165190
"set",
166191
"project",
167192
config.project,
168-
]
193+
],
169194
)
170195

171196

@@ -251,6 +276,36 @@ def test(args):
251276
sys.exit(e.returncode)
252277

253278

279+
@cli.command()
@interactive
def doctor():
    """
    Checks for any problems in your environment (missing packages, credentials, etc.).
    """
    # Delegate to the doctor module; check_all() runs each check in turn and
    # exits the process on the first failure.
    torchprime.launcher.doctor.check_all()
286+
287+
288+
class CommandRunner:
    """Run a series of commands while keeping a transcript of their output.

    Each command line and its captured output are appended to ``outputs``.
    When a command exits with a non-zero status, the transcript collected so
    far and the failing command's output are echoed, then the process exits.
    """

    def __init__(self):
        # Byte transcript of every command run so far, plus what it printed.
        self.outputs = b""

    def run(self, command, **kwargs):
        """Run ``command`` and append its output to the transcript.

        Args:
            command: Argument list of the program to execute.
            **kwargs: Extra keyword arguments forwarded to
                ``subprocess.check_output``. ``stderr`` defaults to
                ``subprocess.STDOUT`` (merged into the captured output) but
                may be overridden by the caller.

        Exits the process with status -1 if the command returns non-zero.
        """
        # setdefault (instead of passing stderr= next to **kwargs) lets callers
        # override stderr without a "multiple values for keyword" TypeError.
        kwargs.setdefault("stderr", subprocess.STDOUT)
        try:
            self.outputs += f">> {' '.join(command)}\n".encode()
            self.outputs += subprocess.check_output(command, **kwargs)
            self.outputs += b"\n"
        except subprocess.CalledProcessError as e:
            click.echo("Previous command outputs:")
            # Tool output is not guaranteed to be valid UTF-8; a decode error
            # here must not hide the real failure being reported.
            click.echo(self.outputs.decode("utf-8", errors="replace"))
            click.echo()
            click.echo(f"❌ Error running `{' '.join(command)}` ❌")
            click.echo()
            click.echo(e.stdout)
            sys.exit(-1)
307+
308+
254309
def forward_env(name: str) -> list[str]:
255310
if name in os.environ:
256311
return ["--env", f"{name}={os.environ[name]}"]

torchprime/launcher/doctor.py

Lines changed: 214 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,214 @@
1+
"""
2+
Doctor checks for essential programs needed to launch distributed training.
3+
"""
4+
5+
import getpass
6+
import grp
7+
import json
8+
import os
9+
import subprocess
10+
import sys
11+
from pathlib import Path
12+
13+
import click
14+
15+
16+
class CheckFailedError(Exception):
    """Raised by an environment check to report a problem, with a message for the user."""
18+
19+
20+
def check_docker():
    """Check that docker is installed."""
    # NOTE: the docstring above is echoed by check_all() as the progress label,
    # so it is kept short and user-facing.
    try:
        subprocess.run(["docker", "help"], check=True, capture_output=True)
    except FileNotFoundError:
        # The executable is not on PATH at all.
        raise CheckFailedError("docker not found. Please install docker first.") from None
    except subprocess.CalledProcessError as e:
        # docker exists but could not even print its help; report it as a failed
        # check instead of letting a raw traceback escape.
        detail = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        raise CheckFailedError(f"`docker help` failed: {detail}") from e
26+
27+
28+
def check_gcr_io():
    """Check that docker config contains gcr.io credential helper."""
    config_path = Path(os.path.expanduser("~/.docker/config.json"))
    try:
        # Only the presence of the entry matters; its value is not used.
        json.loads(config_path.read_text())["credHelpers"]["gcr.io"]
    except (FileNotFoundError, KeyError, json.JSONDecodeError):
        current_user = getpass.getuser()
        # Users outside the `docker` group are told to run the setup with sudo.
        in_docker_group = any(
            group.gr_name == "docker" and current_user in group.gr_mem
            for group in grp.getgrall()
        )
        setup_cmd = "gcloud auth configure-docker"
        if not in_docker_group:
            setup_cmd = "sudo " + setup_cmd
        message = f"""
Did not find a handler for `gcr.io` in docker credential helpers.

TorchPrime uploads docker containers to the `gcr.io` docker registry, which
requires valid credentials.

To setup the credentials, please run:

{setup_cmd}

""".lstrip()
        raise CheckFailedError(message) from None
55+
56+
57+
def check_docker_access():
    """Check that the gcloud account can access the gcr.io artifact registry."""
    describe_cmd = [
        "gcloud",
        "artifacts",
        "repositories",
        "describe",
        "gcr.io",
        "--location=us",
    ]
    try:
        subprocess.run(describe_cmd, check=True, capture_output=True)
    except subprocess.CalledProcessError as e:
        # Look up the active account so the error message can name it.
        lookup = subprocess.run(
            ["gcloud", "config", "get-value", "account"],
            capture_output=True,
            text=True,
        )
        account = lookup.stdout.strip()
        message = f"""The current gcloud account `{account}` cannot access the gcr.io registry.
The account may not have the required permissions. If it's a service account, the
VM may not have the correct scopes.

The easiest way to resolve this is to login with your own account:

gcloud auth login

"""
        raise CheckFailedError(message) from e
80+
81+
82+
def check_gcloud_auth_login():
    """Check that gcloud is logged in."""
    # NOTE: the docstring above is echoed by check_all() as the progress label.
    try:
        subprocess.run(
            ["gcloud", "auth", "print-access-token"], check=True, capture_output=True
        )
    except FileNotFoundError:
        # Without this branch a missing gcloud binary escapes as a raw traceback
        # instead of a readable failed-check report.
        raise CheckFailedError(
            "gcloud not found. Please install the Google Cloud CLI first."
        ) from None
    except subprocess.CalledProcessError as e:
        # stderr is captured as bytes; decode leniently so odd bytes in the tool
        # output cannot mask the actual error.
        detail = (e.stderr or b"").decode("utf-8", errors="replace")
        raise CheckFailedError(
            f"gcloud auth print-access-token failed: {detail}"
        ) from e
92+
93+
94+
def check_kubectl():
    """Check that kubectl is installed."""
    # NOTE: the docstring above is echoed by check_all() as the progress label.
    try:
        subprocess.run(["kubectl", "help"], check=True, capture_output=True)
    except FileNotFoundError:
        # The executable is not on PATH; include tailored install instructions.
        raise CheckFailedError(
            f"""kubectl not found.

{get_kubectl_install_instructions()}"""
        ) from None
    except subprocess.CalledProcessError as e:
        # kubectl exists but could not print its help; report it as a failed
        # check instead of letting a raw traceback escape.
        detail = (e.stderr or b"").decode("utf-8", errors="replace").strip()
        raise CheckFailedError(f"`kubectl help` failed: {detail}") from e
104+
105+
106+
def check_gke_gcloud_auth_plugin():
    """Check that gke-gcloud-auth-plugin is installed."""
    if not is_gcloud_plugin_installed("gke-gcloud-auth-plugin"):
        # Build the install hint first, then attach it to the failure message.
        instructions = get_gke_gcloud_auth_plugin_instructions()
        raise CheckFailedError(
            f"""The `gke-gcloud-auth-plugin` gcloud component is not installed

{instructions}"""
        )
115+
116+
117+
def check_all():
    """Run every environment check in order, stopping at the first failure.

    For each check a progress label is echoed, followed by a check mark on
    success. If a check raises ``CheckFailedError`` its message is echoed and
    the process exits with status -1. When all checks pass a final success
    message is echoed.
    """
    click.echo("Checking environment...")
    checks = (
        check_docker,
        check_gcloud_auth_login,
        check_gcr_io,
        check_docker_access,
        check_kubectl,
        check_gke_gcloud_auth_plugin,
    )
    for check in checks:
        # Use the docstring as the progress label. Fall back to the function
        # name rather than asserting: under `python -OO` docstrings are
        # stripped (and asserts removed), which would otherwise crash on
        # `None + ".."`.
        label = check.__doc__ or check.__name__
        click.echo(label + "..", nl=False)
        try:
            check()
        except CheckFailedError as e:
            # End the in-progress label line before printing the report.
            click.echo()
            click.echo()
            click.echo(f"❌ Error during {check.__name__} ❌")
            click.echo(e)
            sys.exit(-1)
        click.echo(" ✅")
    click.echo(
        "🎉 All checks passed. You should be ready to launch distributed training. 🎉"
    )
141+
142+
143+
def get_kubectl_install_instructions():
    """Return kubectl install instructions matching how gcloud was installed."""
    if not is_package_installed("google-cloud-cli"):
        # gcloud is not an apt package here: point users to the GKE docs.
        return (
            "Please visit "
            "https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#install_kubectl"
        )

    # gcloud is installed via `apt`, so `kubectl` should be installed the same way.
    apt_instructions = """
Since `gcloud` is installed with `apt`, please install `kubectl` with:

sudo apt install kubectl

"""
    return apt_instructions.lstrip()
156+
157+
158+
def get_gke_gcloud_auth_plugin_instructions():
    """Return plugin install instructions matching how gcloud was installed."""
    if not is_package_installed("google-cloud-cli"):
        # gcloud is not an apt package here: point users to the docs.
        return (
            "Please visit "
            "https://cloud.google.com/kubernetes-engine/docs/how-to/cluster-access-for-kubectl#install_plugin"
        )

    # gcloud is installed via `apt`, so the plugin should be installed the same way.
    apt_instructions = """
Since `gcloud` is installed with `apt`, please install `gke-gcloud-auth-plugin` with:

sudo apt install google-cloud-sdk-gke-gcloud-auth-plugin

"""
    return apt_instructions.lstrip()
171+
172+
173+
def is_package_installed(package_name):
    """Return True if the Debian package ``package_name`` is currently installed.

    Args:
        package_name: Name of the package to look up with ``dpkg-query``.

    Returns:
        True only when dpkg reports the package in the ``installed`` state.
        False when the package is unknown, has been removed (dpkg may still
        list it if its config files remain), or when ``dpkg-query`` itself is
        unavailable (e.g. on a non-Debian system).
    """
    try:
        # One status per line, e.g. "install ok installed" or
        # "deinstall ok config-files".
        result = subprocess.run(
            ["dpkg-query", "-W", "-f=${Status}\n", package_name],
            check=True,
            capture_output=True,
            text=True,
        )
    except (FileNotFoundError, subprocess.CalledProcessError):
        # No dpkg on this machine, or dpkg does not know the package.
        return False
    # A zero exit status only means dpkg knows the package; the last word of
    # the status triple says whether it is actually installed.
    return any(
        line.split()[-1:] == ["installed"] for line in result.stdout.splitlines()
    )
184+
185+
186+
def is_gcloud_plugin_installed(plugin_name):
    """Return True if gcloud reports the component ``plugin_name`` as installed.

    Args:
        plugin_name: Component id to look for, e.g. ``gke-gcloud-auth-plugin``.

    Returns:
        True when ``gcloud components list`` contains a component with that id
        whose state is "Installed". False otherwise, including when gcloud is
        missing, the command fails, or its output is not valid JSON.
    """
    try:
        # Run `gcloud components list` to get installed components
        result = subprocess.run(
            ["gcloud", "components", "list", "--format=json", f"--filter={plugin_name}"],
            check=True,
            capture_output=True,
            text=True,
        )
        # Parse the output and look for the plugin
        components = json.loads(result.stdout)
    except FileNotFoundError:
        # gcloud itself is not on PATH; treat the plugin as not installed
        # instead of crashing with a traceback.
        print("gcloud not found; cannot list gcloud components.")
        return False
    except subprocess.CalledProcessError as e:
        print(f"Error running gcloud command: {e.stderr}")
        return False
    except json.JSONDecodeError:
        print("Error parsing JSON output from gcloud.")
        return False

    for component in components:
        if component.get("id") != plugin_name:
            continue
        state = component.get("state")
        # The state is accepted either as a plain string or as an object
        # carrying the state under "name".
        if state == "Installed" or (
            isinstance(state, dict) and state.get("name") == "Installed"
        ):
            return True
    return False
211+
212+
213+
# Run all checks when this module is executed directly as a script.
if __name__ == "__main__":
    check_all()

0 commit comments

Comments
 (0)