diff --git a/e2e_testing/update_step_time.py b/e2e_testing/update_step_time.py
index 4ead058f..27f8a289 100755
--- a/e2e_testing/update_step_time.py
+++ b/e2e_testing/update_step_time.py
@@ -11,10 +11,10 @@
 from datetime import datetime, timedelta
 from pathlib import Path
 
-import click
 import numpy as np
 import scipy
 import yaml
+from absl import app, flags
 from google.cloud import bigquery
 from rich.console import Console
 from rich.table import Table
@@ -246,66 +246,36 @@ def compute_bounds(step_times, confidence_level):
   return lower_bound, upper_bound
 
 
-@click.command()
-@click.option(
-  "--bq-project",
-  default="tpu-pytorch",
-  help="BigQuery project ID",
+# absl flag names must be valid Python identifiers so they can be read back as
+# FLAGS.<name>; the dashes in the old click option names become underscores.
+FLAGS = flags.FLAGS
+flags.DEFINE_string("bq_project", "tpu-pytorch", "BigQuery project ID")
+flags.DEFINE_string("bq_dataset", "benchmark_dataset_test", "BigQuery dataset name")
+flags.DEFINE_string("bq_table", "torchprime-e2e-tests", "BigQuery table name")
+flags.DEFINE_string(
+  "start_time",
+  parse_days_ago("5 days ago").strftime("%Y-%m-%d %H:%M:%S"),
+  "Start time for the query in GoogleSQL datetime format (e.g., '2025-05-29 17:52:00 America/Los_Angeles'). "
+  "Can also accept common datetime formats, which will be converted. "
+  "In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
+  "Defaults to 5 days ago.",
 )
-@click.option(
-  "--bq-dataset",
-  default="benchmark_dataset_test",
-  help="BigQuery dataset name",
+flags.DEFINE_string(
+  "end_time",
+  datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+  "End time for the query in GoogleSQL datetime format (e.g., '2025-06-01 20:00:00 America/Los_Angeles'). "
+  "Can also accept common datetime formats, which will be converted. "
+  "In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
+  "Defaults to the current time.",
 )
-@click.option(
-  "--bq-table",
-  default="torchprime-e2e-tests",
-  help="BigQuery table name",
+flags.DEFINE_integer("limit", 1200, "Maximum number of rows to retrieve")
+flags.DEFINE_string(
+  "output",
+  "e2e_testing/step_time_bounds.yaml",
+  "Output YAML file path",
 )
-@click.option(
-  "--start-time",
-  default=parse_days_ago("5 days ago").strftime("%Y-%m-%d %H:%M:%S"),
-  help="Start time for the query in GoogleSQL datetime format (e.g., '2025-05-29 17:52:00 America/Los_Angeles'). "
-  "Can also accept common datetime formats which will be converted. "
-  "In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
-  "Defaults to 5 days ago.",
-)
-@click.option(
-  "--end-time",
-  default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
-  help="End time for the query in GoogleSQL datetime format (e.g., '2025-06-01 20:00:00 America/Los_Angeles'). "
-  "Can also accept common datetime formats which will be converted. "
-  "In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
-  "Defaults to the current time.",
-)
-@click.option(
-  "--limit",
-  default=1200,
-  type=int,
-  help="Maximum number of rows to retrieve",
-)
-@click.option(
-  "--output",
-  default="e2e_testing/step_time_bounds.yaml",
-  type=click.Path(),
-  help="Output YAML file path",
-)
-@click.option(
-  "--confidence_level",
-  default=99.0,
-  type=float,
-  help="Confidence level, default is 99%",
-)
-def main(
-  bq_project,
-  bq_dataset,
-  bq_table,
-  start_time,
-  end_time,
-  limit,
-  output,
-  confidence_level,
-):
+flags.DEFINE_float(
+  "confidence_level", 99.0, "Confidence level as a percentage. Defaults to 99."
+)
+
+
+def main(argv):
   """
   Query BigQuery for E2E test results and compute step time bounds.
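+
+  Inputs are read from the module-level absl flags rather than function
+  parameters; positional arguments in `argv` are ignored.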
@@ -314,19 +284,19 @@ def main(
   the results to a YAML file for use in GitHub Actions.
   """
+  del argv  # Unused; all inputs come from the flags above.
   console = Console()
-  confidence_level = confidence_level / 100.0
+  confidence_level = FLAGS.confidence_level / 100.0
 
   # Parse datetime inputs
-  start_time = parse_datetime(start_time)
-  end_time = parse_datetime(end_time)
+  start_time = parse_datetime(FLAGS.start_time)
+  end_time = parse_datetime(FLAGS.end_time)
 
   console.print("[bold]Querying BigQuery:[/bold]")
-  console.print(f"  Project: {bq_project}")
-  console.print(f"  Dataset: {bq_dataset}")
-  console.print(f"  Table: {bq_table}")
+  console.print(f"  Project: {FLAGS.bq_project}")
+  console.print(f"  Dataset: {FLAGS.bq_dataset}")
+  console.print(f"  Table: {FLAGS.bq_table}")
   console.print(f"  Start: {start_time}")
   console.print(f"  End: {end_time}")
-  console.print(f"  Limit: {limit}")
+  console.print(f"  Limit: {FLAGS.limit}")
 
   client = bigquery.Client()
 
@@ -335,7 +305,7 @@ def main(
   SELECT
     *
   FROM
-    `{bq_project}`.`{bq_dataset}`.`{bq_table}`
+    `{FLAGS.bq_project}`.`{FLAGS.bq_dataset}`.`{FLAGS.bq_table}`
   WHERE
     software_id = '{TORCHPRIME_SOFTWARE_ID}' AND
     update_timestamp >= TIMESTAMP('{start_time}') AND
@@ -343,7 +313,7 @@ def main(
   ORDER BY
     update_timestamp DESC
   LIMIT
-    {limit};
+    {FLAGS.limit};
   """
 
   query_job = client.query(query)
@@ -421,7 +391,7 @@ def main(
   console.print(table)
 
   # Write to file
-  output_path = Path(output)
+  output_path = Path(FLAGS.output)
   output_path.parent.mkdir(exist_ok=True, parents=True)
 
   with open(output_path, "w") as f:
@@ -443,4 +413,4 @@ def main(
 
 
 if __name__ == "__main__":
-  main()
+  app.run(main)
diff --git a/pyproject.toml b/pyproject.toml
index f5e2de4f..6efb8b97 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -24,8 +24,8 @@ dependencies = [
   "dataclasses-json==0.6.7",
   "benchmark-db-writer @ git+https://github.com/AI-Hypercomputer/aotc.git@2ff16e670df20b497ddaf1f86920dbb5dd9f0c8f#subdirectory=src/aotc/benchmark_db_writer",
   "dacite==1.9.2",
-  "click~=8.1.8",
-  "google-cloud-storage==2.19.0"
+  "google-cloud-storage==2.19.0",
+  "absl-py"
 ]
 
 [project.optional-dependencies]
diff --git a/torchprime/launcher/benchmark_db_util.py b/torchprime/launcher/benchmark_db_util.py
index 953ac61e..47e258fa 100644
--- a/torchprime/launcher/benchmark_db_util.py
+++ b/torchprime/launcher/benchmark_db_util.py
@@ -1,9 +1,9 @@
 import os
+import sys
 import uuid
 from datetime import datetime
 from pathlib import Path
 
-import click
 from omegaconf import OmegaConf
 
 from torchprime.metrics.metrics import Metrics
@@ -36,7 +36,7 @@ def get_metrics(base_artifact_path: str, jobset_name_for_outputs: str) -> dict |
     )
   )
   if not metric_file_path.exists():
-    click.echo(f"Metrics file not found at {metric_file_path}", err=True)
+    print(f"Metrics file not found at {metric_file_path}", file=sys.stderr)
     return None
 
   metrics_data = Metrics.load(metric_file_path)
@@ -71,7 +71,7 @@ def get_config(base_artifact_path: str, jobset_name_for_outputs: str) -> dict |
   )
 
   if not config_file_path.exists():
-    click.echo(f"Config file not found at {config_file_path}", err=True)
+    print(f"Config file not found at {config_file_path}", file=sys.stderr)
     return None
 
   # Load the JSON configuration file and convert it to a Python dictionary
@@ -106,7 +106,7 @@ def prepare_benchmark_summary(
   hardware_num_chips, tflops_per_chip = get_num_chips_and_tflops_per_chip(tpu_type)
   hardware_id = tpu_type.split("-")[0]  # Extract the TPU generation (e.g., v4, v5e)
 
-  click.echo(
+  print(
     f"Tpu type: {tpu_type}, hardware_num_chips: {hardware_num_chips}, tflops_per_chip: {tflops_per_chip}"
)
diff --git a/torchprime/launcher/cli.py b/torchprime/launcher/cli.py
index a5134337..bcf96ac5 100644
--- a/torchprime/launcher/cli.py
+++ b/torchprime/launcher/cli.py
@@ -14,11 +14,12 @@
 from datetime import datetime
 from pathlib import Path
 
-import click
 import toml
+from absl import app, flags
 from dataclasses_json import dataclass_json
 from pathspec import PathSpec
-from pathspec.patterns import GitWildMatchPattern  # type: ignore
+from pathspec.patterns import GitWildMatchPattern
+from rich.console import Console
 from watchdog.events import FileSystemEventHandler
 from watchdog.observers import Observer
 
@@ -50,100 +51,35 @@ class Config:
   docker_project: str | None = None
 
 
-def interactive(f):
-  @click.pass_context
-  def wrapper(ctx, *args, **kwargs):
-    return run_with_watcher(ctx)(f)(*args, **kwargs)
-
-  wrapper.__name__ = f.__name__
-  wrapper.__doc__ = f.__doc__
-  return wrapper
-
-
-@click.group()
-@click.option(
-  "-i",
-  "--interactive",
-  is_flag=True,
-  default=False,
-  help="Re-run the command whenever a file is edited (useful for fast dev/test iteration)",
-)
-@click.pass_context
-def cli(ctx, interactive):
-  """
-  tp is a CLI for common torchprime workflows.
-  """
-  ctx.ensure_object(dict)
-  ctx.obj["interactive"] = interactive
-
-
-@cli.command()
-@click.option("--cluster", required=True, help="Name of the XPK cluster")
-@click.option("--project", required=True, help="GCP project the cluster belongs to")
-@click.option("--zone", required=True, help="Compute zone the cluster is located in")
-@click.option(
-  "--num-slices",
-  required=False,
-  type=int,
-  default=1,
-  help="Number of TPU slice to use by default. Defaults to 1",
-)
-@click.option(
-  "--tpu-type",
-  required=True,
-  help="The TPU accelerator type in each slice. E.g. v6e-256 for a 256 chip Trillium pod",
+# Flag names use underscores so they can be read back as FLAGS.<name>.
+FLAGS = flags.FLAGS
+flags.DEFINE_boolean(
+  "interactive",
+  False,
+  "Re-run the command whenever a file is edited (useful for fast dev/test iteration)",
+  short_name="i",
 )
-@click.option(
-  "--artifact-dir",
-  required=True,
-  help="A Google Cloud Storage directory where artifacts such as profiles will be stored. \
-E.g. gs://foo/bar",
-)
-@click.option(
-  "--upload-metrics",
-  required=False,
-  is_flag=True,
-  default=False,
-  help="If given, uploads metrics to the database ",
-)
-@click.option(
-  "--bq-project",
-  required=False,
-  default="tpu-pytorch",
-  help="A bigquery project to upload metrics.",
-)
-@click.option(
-  "--bq-dataset",
-  required=False,
-  default="benchmark_dataset_test",
-  help="A bigquery dataset to upload metrics.",
-)
-@click.option(
-  "--bq-table",
-  required=False,
-  default="benchmark_experiment",
-  help="A bigquery table to upload metrics.",
-)
-@click.option(
-  "--docker-project",
-  required=False,
-  default=None,
-  help="GCP project to upload docker containers to. If not set, defaults to the cluster's\
-  GCP project",
-)
-def use(
-  cluster: str,
-  project: str,
-  zone: str,
-  num_slices: int,
-  tpu_type: str,
-  artifact_dir: str,
-  upload_metrics: bool,
-  bq_project: str,
-  bq_dataset: str,
-  bq_table: str,
-  docker_project: str | None,
-):
+
+# Flags for `use` command
+flags.DEFINE_string("cluster", None, "Name of the XPK cluster")
+flags.DEFINE_string("project", None, "GCP project the cluster belongs to")
+flags.DEFINE_string("zone", None, "Compute zone the cluster is located in")
+flags.DEFINE_integer(
+  "num_slices",
+  None,
+  "Number of TPU slices to use. With `tp use`, defaults to 1; with `tp run`, "
+  "temporarily overrides the slice count configured by `tp use`.",
+)
+flags.DEFINE_string(
+  "tpu_type",
+  None,
+  "The TPU accelerator type in each slice. E.g. v6e-256 for a 256-chip Trillium pod",
+)
+flags.DEFINE_string(
+  "artifact_dir",
+  None,
+  "A Google Cloud Storage directory where artifacts such as profiles will be stored. E.g. gs://foo/bar",
+)
+flags.DEFINE_boolean("upload_metrics", False, "If given, uploads metrics to the database")
+flags.DEFINE_string("bq_project", "tpu-pytorch", "A BigQuery project to upload metrics to.")
+flags.DEFINE_string("bq_dataset", "benchmark_dataset_test", "A BigQuery dataset to upload metrics to.")
+flags.DEFINE_string("bq_table", "benchmark_experiment", "A BigQuery table to upload metrics to.")
+flags.DEFINE_string(
+  "docker_project",
+  None,
+  "GCP project to upload docker containers to. If not set, defaults to the cluster's GCP project",
+)
+
+# Flags for `run` command
+flags.DEFINE_string(
+  "name",
+  None,
+  "Name of the workload (jobset). If not specified, defaults to one based on the date and time.",
+)
+flags.DEFINE_string(
+  "base_docker_url",
+  None,
+  "If specified, `tp run` will use this PyTorch/XLA base docker image instead of the one pinned inside `pyproject.toml`",
+)
+flags.DEFINE_boolean("use_hf", False, "Use HuggingFace transformer")
+flags.DEFINE_boolean(
+  "use_local_wheel", False, "Use local torch and torch_xla wheels under folder local_dist/"
+)
+flags.DEFINE_string(
+  "comments", None, "Optional description of the training run, stored in the database."
+)
+
+
+def use():
   """
   Sets up various config like XPK cluster name, GCP project, etc for all
   subsequent commands to use. Typically, you would only run this command once when
@@ -153,31 +89,31 @@ def use(
   have to type the project and zone if you drop down to xpk.
   """
+  # click enforced `required=True` for these options; absl has no per-command
+  # equivalent, so validate them here.
+  for required in ("cluster", "project", "zone", "tpu_type", "artifact_dir"):
+    if FLAGS[required].value is None:
+      raise app.UsageError(f"--{required} is required for `tp use`.")
   config = Config(
-    cluster=cluster,
-    project=project,
-    zone=zone,
-    num_slices=num_slices,
-    tpu_type=tpu_type,
-    artifact_dir=artifact_dir,
-    upload_metrics=upload_metrics,
-    bq_project=bq_project,
-    bq_dataset=bq_dataset,
-    bq_table=bq_table,
-    docker_project=docker_project,
+    cluster=FLAGS.cluster,
+    project=FLAGS.project,
+    zone=FLAGS.zone,
+    # --num_slices is shared with `tp run`, where "unset" must stay
+    # distinguishable from an explicit value, so apply the default of 1 here.
+    num_slices=FLAGS.num_slices if FLAGS.num_slices is not None else 1,
+    tpu_type=FLAGS.tpu_type,
+    artifact_dir=FLAGS.artifact_dir,
+    upload_metrics=FLAGS.upload_metrics,
+    bq_project=FLAGS.bq_project,
+    bq_dataset=FLAGS.bq_dataset,
+    bq_table=FLAGS.bq_table,
+    docker_project=FLAGS.docker_project,
   )
-  gcloud_config_name = f"torchprime-{project}-{zone}"
+  gcloud_config_name = f"torchprime-{FLAGS.project}-{FLAGS.zone}"
   create_and_activate_gcloud(gcloud_config_name, config)
-  assert artifact_dir.startswith("gs://"), (
-    f"{artifact_dir} must be in a GCS bucket (start with gs://)"
+  assert FLAGS.artifact_dir.startswith("gs://"), (
+    f"{FLAGS.artifact_dir} must be in a GCS bucket (start with gs://)"
   )
   path = write_config(config)
-  click.echo(f"Written config {path.relative_to(os.getcwd())}")
+  print(f"Written config {path.relative_to(os.getcwd())}")
   torchprime.launcher.doctor.check_all(config)
 
 
 def create_and_activate_gcloud(gcloud_config_name, config: Config):
-  click.echo("Activating gcloud config...")
+  print("Activating gcloud config...")
   ensure_command("gcloud")
   all_configurations = json.loads(
     subprocess.check_output(
@@ -242,22 +178,14 @@ def create_and_activate_gcloud(gcloud_config_name, config: Config):
   )
 
 
-@cli.command(
-  name="docker-run",
-  context_settings=dict(
-    ignore_unknown_options=True,
-  ),
-)
-@click.argument("args", nargs=-1, type=click.UNPROCESSED)
-@click.option("--use-hf", is_flag=True, help="Use HuggingFace transformer")
-def docker_run(args, use_hf: bool):
+def docker_run(argv):
   """
   Runs the provided training command
   locally for quick testing.
   """
-  click.echo(get_project_dir().absolute())
+  print(get_project_dir().absolute())
 
   # Build docker image.
-  build_arg = ["USE_TRANSFORMERS=true"] if use_hf else None
+  build_arg = ["USE_TRANSFORMERS=true"] if FLAGS.use_hf else None
   placeholder_url = "torchprime-dev:local"
   docker_url = buildpush(
     push_docker=False, placeholder_url=placeholder_url, build_arg=build_arg
   )
@@ -266,7 +194,7 @@ def docker_run(args, use_hf: bool):
   env_forwarding = [
     arg for env_var in _DOCKER_ENV_FORWARD_LIST for arg in forward_env(env_var)
   ]
-  args = list(v for v in args if v != "")
+  # argv is [program, "docker-run", *training_args]; skip the subcommand name.
+  args = list(v for v in argv[2:] if v != "")
   command = [
     "python",
   ] + list(args)
@@ -288,68 +216,19 @@ def docker_run(args, use_hf: bool):
   run_docker(docker_command)
 
 
-@cli.command(
-  context_settings=dict(
-    ignore_unknown_options=True,
-  )
-)
-@click.argument("args", nargs=-1, type=click.UNPROCESSED)
-@click.option(
-  "--name",
-  required=False,
-  help="Name of the workload (jobset). If not specified, "
-  "defaults to one based on the date and time.",
-  default=None,
-)
-@click.option(
-  "--base-docker-url",
-  required=False,
-  help="If specified, `tp run` will use this PyTorch/XLA base docker image instead of "
-  "the one pinned inside `pyproject.toml`",
-  default=None,
-)
-@click.option(
-  "--num-slices",
-  required=False,
-  type=int,
-  default=None,
-  help="Temporarily override the number of TPU slice to use for this run. "
-  "If unspecified, `tp run` will use the slice count configured in `tp use`.",
-)
-@click.option("--use-hf", is_flag=True, help="Use HuggingFace transformer")
-@click.option(
-  "--use-local-wheel",
-  is_flag=True,
-  help="Use local torch and torch_xla wheels under folder local_dist/",
-)
-@click.option(
-  "--comments",
-  required=False,
-  default=None,
-  help="Optional description of the training run, stored in the database.",
-)
-@interactive
-def run(
-  args,
-  name: str | None,
-  base_docker_url: str | None,
-  num_slices: int | None,
-  use_hf: bool,
-  use_local_wheel: bool,
-  comments: str | None,
-):
+def run(argv):
   """
   Runs the provided SPMD training command as an xpk job on a GKE cluster.
   """
   config = read_config()
 
-  click.echo(get_project_dir().absolute())
+  print(get_project_dir().absolute())
 
   # Build docker image.
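+  # Entries in `build_arg` become Docker build args via `buildpush`
+  # (e.g. USE_TRANSFORMERS=true).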
   build_arg = []
-  if use_hf:
+  if FLAGS.use_hf:
     build_arg.append("USE_TRANSFORMERS=true")
-  if use_local_wheel:
+  if FLAGS.use_local_wheel:
     build_arg.append("USE_LOCAL_WHEEL=true")
 
   docker_project = config.docker_project
   if docker_project is None:
@@ -357,11 +236,11 @@ def run(
   docker_url = buildpush(
     torchprime_project_id=docker_project,
     build_arg=build_arg,
-    base_docker_url=base_docker_url,
+    base_docker_url=FLAGS.base_docker_url,
   )
 
   # Submit xpk workload
-  workload_name = name
+  workload_name = FLAGS.name
   if workload_name is None:
     datetime_str = datetime.now().strftime("%Y%m%d-%H%M%S")
     workload_name = (
@@ -379,8 +258,9 @@ def run(
   """
   )
 
-  command = ["python", "torchprime/launcher/thunk.py"] + list(args)
+  # argv is [program, "run", *training_args]; skip the subcommand name.
+  command = ["python", "torchprime/launcher/thunk.py"] + list(argv[2:])
 
+  # An unset --num_slices means: use the slice count saved by `tp use`.
+  num_slices = FLAGS.num_slices
   if num_slices is None:
     num_slices = config.num_slices
 
@@ -401,7 +281,7 @@ def run(
     "--env",
     f"TORCHPRIME_JOBSET_NAME={workload_name}",
     "--env",
-    f"TORCHPRIME_COMMENTS={comments}",
+    f"TORCHPRIME_COMMENTS={FLAGS.comments}",
     "--env",
     f"TORCHPRIME_DOCKER_URL={docker_url}",
     "--env",
@@ -454,38 +334,29 @@ def run(
   )
   subprocess.run(xpk_command, check=True)
 
-  styled_workload = click.style(workload_name, bold=True, fg="green")
-  styled_cluster = click.style(config.cluster, bold=True, fg="green")
-  styled_artifacts = click.style(
-    f"{config.artifact_dir}/{workload_name}", bold=True, fg="green"
-  )
-  click.echo(f"""
-Workload {styled_workload} submitted to cluster {styled_cluster}
-Artifacts are stored at {styled_artifacts}
+  # rich styling is only rendered by Console.print; interpolating a
+  # rich.text.Text into an f-string for the builtin print would drop it.
+  console = Console()
+  console.print(f"""
+Workload [bold green]{workload_name}[/bold green] submitted to cluster [bold green]{config.cluster}[/bold green]
+Artifacts are stored at [bold green]{config.artifact_dir}/{workload_name}[/bold green]
 """)
 
 
-@cli.command(
-  context_settings=dict(
-    ignore_unknown_options=True,
-  )
-)
-@click.argument("args", nargs=-1, type=click.UNPROCESSED)
-@interactive
-def test(args):
+def test(argv):
   """
   Runs unit tests in torchprime by forwarding arguments to pytest.
   """
   ensure_command("pytest")
   try:
-    subprocess.run(["pytest"] + list(args), check=True)
+    # argv is [program, "test", *pytest_args]; skip the subcommand name.
+    subprocess.run(["pytest"] + list(argv[2:]), check=True)
   except subprocess.CalledProcessError as e:
     sys.exit(e.returncode)
 
 
-@cli.command()
-@interactive
 def doctor():
   """
   Checks for any problems in your environment (missing packages, credentials, etc.).
@@ -505,12 +376,12 @@ def run(self, command, **kwargs):
       )
       self.outputs += b"\n"
     except subprocess.CalledProcessError as e:
-      click.echo("Previous command outputs:")
-      click.echo(self.outputs.decode("utf-8"))
-      click.echo()
-      click.echo(f"❌ Error running `{' '.join(command)}` ❌")
-      click.echo()
-      click.echo(e.stdout)
+      print("Previous command outputs:")
+      print(self.outputs.decode("utf-8"))
+      print()
+      print(f"❌ Error running `{' '.join(command)}` ❌")
+      print()
+      print(e.stdout)
       sys.exit(-1)
 
 
@@ -537,7 +408,7 @@ def write_config(config: Config):
   config_dir = get_config_dir()
   config_dir.mkdir(exist_ok=True)
   default_config = config_dir / DEFAULT_CONFIG_NAME
-  default_config.write_text(toml.dumps(config.to_dict()))  # type:ignore
+  default_config.write_text(toml.dumps(config.to_dict()))
   return default_config
 
 
@@ -545,7 +416,7 @@ def read_config() -> Config:
   config_path = get_config_dir() / DEFAULT_CONFIG_NAME
   if not config_path.exists():
     raise RuntimeError(f"No config found at {config_path}. Run `tp use` first.")
-  return Config.from_dict(toml.load(config_path))  # type:ignore
+  return Config.from_dict(toml.load(config_path))
 
 
 def ensure_command(name: str):
@@ -561,8 +432,7 @@ def ensure_command(name: str):
 
 
 class FileChangeHandler(FileSystemEventHandler):
-  def __init__(self, command_context, gitignore_spec):
-    self.command_context = command_context
+  def __init__(self, gitignore_spec):
     self.gitignore_spec = gitignore_spec
     self.last_trigger_time = time.time()
     self.last_modified_file = ""
@@ -606,18 +476,17 @@ def run_command_thread_fn(self):
       self.file_modified.wait()
       last_modified_file = self.last_modified_file
       if last_modified_file:
-        click.echo(f"""
+        print(f"""
 File {last_modified_file} modified, rerunning command...
 """)
-        sys.argv[1] = sys.argv[1].replace("-i", "").replace("--interactive", "").strip()
-        main_command = " ".join(s for s in sys.argv[1:] if s != "")
+        # Skip argv[0] (the program path) and drop the interactive flag so the
+        # re-run executes the command once instead of starting another watcher.
+        main_command = " ".join(
+          s for s in sys.argv[1:] if s not in ("-i", "--interactive")
+        )
         subprocess.run(f"tp {main_command}", shell=True, check=False)
-        click.echo(f"""
+        print(f"""
 Done running `tp {main_command}`.
 """)
 
 
-def watch_directory(project_dir, command_context):
+def watch_directory(project_dir):
   # Load gitignore patterns
   gitignore_patterns = []
   gitignore_path = os.path.join(project_dir, ".gitignore")
@@ -628,7 +497,7 @@ def watch_directory(project_dir):
   # Create PathSpec object from gitignore
   gitignore_spec = PathSpec.from_lines(GitWildMatchPattern, gitignore_patterns)
 
-  event_handler = FileChangeHandler(command_context, gitignore_spec)
+  event_handler = FileChangeHandler(gitignore_spec)
   observer = Observer()
   observer.schedule(event_handler, project_dir, recursive=True)
   observer.start()
@@ -641,26 +510,37 @@ def watch_directory(project_dir):
   observer.join()
 
 
-def run_with_watcher(ctx):
-  """Wrapper to run commands with file watching if interactive mode is enabled"""
+def main(argv):
+  if len(argv) < 2:
+    print("Usage: tp <command> [options]")
+    return
 
-  def decorator(f):
-    def wrapper(*args, **kwargs):
-      # If interactive mode is enabled, start watching for changes
-      if ctx.obj.get("interactive"):
-        project_dir = get_project_dir()
-        click.echo(
-          f"Watching directory {project_dir} for changes. Press Ctrl+C to stop.\n"
-        )
-        watch_directory(project_dir, ctx)
-      else:
-        # Just run the command
-        return f(*args, **kwargs)
-
-    return wrapper
+  command = argv[1]
+  if FLAGS.interactive:
+    project_dir = get_project_dir()
+    print(f"Watching directory {project_dir} for changes. Press Ctrl+C to stop.\n")
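+    # In interactive mode, only watch: the watcher thread re-runs
+    # `tp <command>` (with the interactive flag stripped) on every file change.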
+    watch_directory(project_dir)
+    return
+
+  if command == "use":
+    use()
+  elif command == "docker-run":
+    docker_run(argv)
+  elif command == "run":
+    run(argv)
+  elif command == "test":
+    test(argv)
+  elif command == "doctor":
+    doctor()
+  else:
+    print(f"Unknown command: {command}", file=sys.stderr)
+    sys.exit(1)
 
-  return decorator
+
+def cli():
+  # Subcommands forward arbitrary arguments (to pytest or the training
+  # script), so parse only the flags defined above and pass everything else
+  # through to main() in argv.
+  app.run(main, flags_parser=lambda args: FLAGS(args, known_only=True))
 
 
 if __name__ == "__main__":
   cli()
diff --git a/torchprime/launcher/doctor.py b/torchprime/launcher/doctor.py
index 614ee5b8..f983476c 100644
--- a/torchprime/launcher/doctor.py
+++ b/torchprime/launcher/doctor.py
@@ -11,8 +11,6 @@
 from pathlib import Path
 from typing import TypeVar
 
-import click
-
 ConfigT = TypeVar("ConfigT", bound="Config")  # noqa: F821
 
 
@@ -155,7 +153,7 @@ def check_gke_cluster_exist(config: ConfigT | None = None):
 
 
 def check_all(config: ConfigT | None = None):
-  click.echo("Checking environment...")
+  print("Checking environment...")
  check_list = [
    check_docker,
    check_gcloud_auth_login,
@@ -168,20 +166,20 @@ def check_all(config: ConfigT | None = None):
     check_list.append(check_gke_cluster_exist)
   for check in check_list:
     assert check.__doc__ is not None
-    click.echo(check.__doc__ + "..", nl=False)
+    # Flush so the check name appears before a slow check finishes.
+    print(check.__doc__ + "..", end="", flush=True)
     try:
       try:
         check(config)
       except TypeError:
         check()
     except CheckFailedError as e:
-      click.echo()
-      click.echo()
-      click.echo(f"❌ Error during {check.__name__} ❌")
-      click.echo(e)
+      print()
+      print()
+      print(f"❌ Error during {check.__name__} ❌")
+      print(e)
       sys.exit(-1)
-    click.echo(" ✅")
-  click.echo(
+    print(" ✅")
+  print(
     "🎉 All checks passed. You should be ready to launch distributed training. 🎉"
   )
diff --git a/torchprime/launcher/thunk.py b/torchprime/launcher/thunk.py
index 9203305a..6903aab3 100644
--- a/torchprime/launcher/thunk.py
+++ b/torchprime/launcher/thunk.py
@@ -4,8 +4,6 @@
 from datetime import datetime
 from pathlib import Path
 
-import click
-
 from torchprime.launcher import upload_metrics_to_bq
 
 # Workaround for MegaScale crash
@@ -84,7 +82,7 @@
 
 if upload_metrics.lower() == "true" and slice_id == "0" and worker_id == "0":
   try:
-    click.echo(
+    print(
       f"Primary worker ({host_name}) attempting to upload metrics for job '{jobset_name}'...",
     )
     upload_metrics_to_bq.collect_and_upload_benchmark_summary(
@@ -93,5 +91,5 @@
       mounted_artifact_path_str=str(mounted_artifact_dir),
     )
   except Exception as e:
-    click.echo(f"Error uploading results to BigQuery: {e}", err=True)
+    print(f"Error uploading results to BigQuery: {e}", file=sys.stderr)
 
 sys.exit(process.returncode)
diff --git a/torchprime/launcher/util.py b/torchprime/launcher/util.py
index 030a9d73..8bf2050c 100644
--- a/torchprime/launcher/util.py
+++ b/torchprime/launcher/util.py
@@ -3,8 +3,6 @@
 import grp
 import subprocess
 
-import click
-
 
 @functools.lru_cache
 def is_sudoless_docker() -> bool:
@@ -16,11 +14,11 @@ def is_sudoless_docker() -> bool:
   groups_for_user = [g.gr_name for g in grp.getgrall() if user in g.gr_mem]
   is_sudoless_docker = "docker" in groups_for_user
   if is_sudoless_docker:
-    click.echo(
+    print(
       f"User {user} is in the 'docker' group. You can run Docker commands without sudo."
     )
   else:
-    click.echo(
+    print(
       f"User {user} is NOT in the 'docker' group. "
       "sudo is needed to run Docker commands."
     )
@@ -33,7 +31,7 @@ def run_docker(command: str | list[str]):
   command = ["docker"] + command
   if not is_sudoless_docker():
     command = ["sudo"] + command
-  click.echo(" ".join(command))
+  print(" ".join(command))
   subprocess.run(
     command,
     shell=False,