Skip to content

refactor: Remove dependency on click library #356

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 39 additions & 69 deletions e2e_testing/update_step_time.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from datetime import datetime, timedelta
from pathlib import Path

import click
import numpy as np
import scipy
import yaml
from absl import app, flags
from google.cloud import bigquery
from rich.console import Console
from rich.table import Table
Expand Down Expand Up @@ -246,66 +246,36 @@ def compute_bounds(step_times, confidence_level):
return lower_bound, upper_bound


@click.command()
@click.option(
"--bq-project",
default="tpu-pytorch",
help="BigQuery project ID",
FLAGS = flags.FLAGS
flags.DEFINE_string("bq-project", "tpu-pytorch", "BigQuery project ID")
flags.DEFINE_string("bq-dataset", "benchmark_dataset_test", "BigQuery dataset name")
flags.DEFINE_string("bq-table", "torchprime-e2e-tests", "BigQuery table name")
flags.DEFINE_string(
"start-time",
parse_days_ago("5 days ago").strftime("%Y-%m-%d %H:%M:%S"),
"Start time for the query in GoogleSQL datetime format (e.g., '2025-05-29 17:52:00 America/Los_Angeles'). "
"Can also accept common datetime formats which will be converted. "
"In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
"Defaults to 5 days ago.",
)
@click.option(
"--bq-dataset",
default="benchmark_dataset_test",
help="BigQuery dataset name",
flags.DEFINE_string(
"end-time",
datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
"End time for the query in GoogleSQL datetime format (e.g., '2025-06-01 20:00:00 America/Los_Angeles'). "
"Can also accept common datetime formats which will be converted. "
"In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
"Defaults to the current time.",
)
@click.option(
"--bq-table",
default="torchprime-e2e-tests",
help="BigQuery table name",
flags.DEFINE_integer("limit", 1200, "Maximum number of rows to retrieve")
flags.DEFINE_string(
"output",
"e2e_testing/step_time_bounds.yaml",
"Output YAML file path",
)
@click.option(
"--start-time",
default=parse_days_ago("5 days ago").strftime("%Y-%m-%d %H:%M:%S"),
help="Start time for the query in GoogleSQL datetime format (e.g., '2025-05-29 17:52:00 America/Los_Angeles'). "
"Can also accept common datetime formats which will be converted. "
"In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
"Defaults to 5 days ago.",
)
@click.option(
"--end-time",
default=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
help="End time for the query in GoogleSQL datetime format (e.g., '2025-06-01 20:00:00 America/Los_Angeles'). "
"Can also accept common datetime formats which will be converted. "
"In particular, supports '[N] days ago' format, e.g., '2 days ago'. "
"Defaults to the current time.",
)
@click.option(
"--limit",
default=1200,
type=int,
help="Maximum number of rows to retrieve",
)
@click.option(
"--output",
default="e2e_testing/step_time_bounds.yaml",
type=click.Path(),
help="Output YAML file path",
)
@click.option(
"--confidence_level",
default=99.0,
type=float,
help="Confidence level, default is 99%",
)
def main(
bq_project,
bq_dataset,
bq_table,
start_time,
end_time,
limit,
output,
confidence_level,
):
flags.DEFINE_float("confidence_level", 99.0, "Confidence level, default is 99%")


def main(argv):
"""
Query BigQuery for E2E test results and compute step time bounds.

Expand All @@ -314,19 +284,19 @@ def main(
the results to a YAML file for use in GitHub Actions.
"""
console = Console()
confidence_level = confidence_level / 100.0
confidence_level = FLAGS.confidence_level / 100.0

# Parse datetime inputs
start_time = parse_datetime(start_time)
end_time = parse_datetime(end_time)
start_time = parse_datetime(FLAGS['start-time'].value)
end_time = parse_datetime(FLAGS['end-time'].value)

console.print("[bold]Querying BigQuery:[/bold]")
console.print(f" Project: {bq_project}")
console.print(f" Dataset: {bq_dataset}")
console.print(f" Table: {bq_table}")
console.print(f" Project: {FLAGS.bq_project}")
console.print(f" Dataset: {FLAGS.bq_dataset}")
console.print(f" Table: {FLAGS.bq_table}")
console.print(f" Start: {start_time}")
console.print(f" End: {end_time}")
console.print(f" Limit: {limit}")
console.print(f" Limit: {FLAGS.limit}")

client = bigquery.Client()

Expand All @@ -335,15 +305,15 @@ def main(
SELECT
*
FROM
`{bq_project}`.`{bq_dataset}`.`{bq_table}`
`{FLAGS.bq_project}`.`{FLAGS.bq_dataset}`.`{FLAGS.bq_table}`
WHERE
software_id = '{TORCHPRIME_SOFTWARE_ID}' AND
update_timestamp >= TIMESTAMP('{start_time}') AND
update_timestamp <= TIMESTAMP('{end_time}')
ORDER BY
update_timestamp DESC
LIMIT
{limit};
{FLAGS.limit};
"""

query_job = client.query(query)
Expand Down Expand Up @@ -421,7 +391,7 @@ def main(
console.print(table)

# Write to file
output_path = Path(output)
output_path = Path(FLAGS.output)
output_path.parent.mkdir(exist_ok=True, parents=True)

with open(output_path, "w") as f:
Expand All @@ -443,4 +413,4 @@ def main(


if __name__ == "__main__":
main()
app.run(main)
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ dependencies = [
"dataclasses-json==0.6.7",
"benchmark-db-writer @ git+https://github.com/AI-Hypercomputer/aotc.git@2ff16e670df20b497ddaf1f86920dbb5dd9f0c8f#subdirectory=src/aotc/benchmark_db_writer",
"dacite==1.9.2",
"click~=8.1.8",
"google-cloud-storage==2.19.0"
"google-cloud-storage==2.19.0",
"absl-py"
]

[project.optional-dependencies]
Expand Down
8 changes: 4 additions & 4 deletions torchprime/launcher/benchmark_db_util.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import sys
import uuid
from datetime import datetime
from pathlib import Path

import click
from omegaconf import OmegaConf

from torchprime.metrics.metrics import Metrics
Expand Down Expand Up @@ -36,7 +36,7 @@ def get_metrics(base_artifact_path: str, jobset_name_for_outputs: str) -> dict |
)
)
if not metric_file_path.exists():
click.echo(f"Metrics file not found at {metric_file_path}", err=True)
print(f"Metrics file not found at {metric_file_path}", file=sys.stderr)
return None

metrics_data = Metrics.load(metric_file_path)
Expand Down Expand Up @@ -71,7 +71,7 @@ def get_config(base_artifact_path: str, jobset_name_for_outputs: str) -> dict |
)

if not config_file_path.exists():
click.echo(f"Config file not found at {config_file_path}", err=True)
print(f"Config file not found at {config_file_path}", file=sys.stderr)
return None

# Load the JSON configuration file and convert it to a Python dictionary
Expand Down Expand Up @@ -106,7 +106,7 @@ def prepare_benchmark_summary(
hardware_num_chips, tflops_per_chip = get_num_chips_and_tflops_per_chip(tpu_type)
hardware_id = tpu_type.split("-")[0] # Extract the TPU generation (e.g., v4, v5e)

click.echo(
print(
f"Tpu type: {tpu_type}, hardware_num_chips: {hardware_num_chips}, tflops_per_chip: {tflops_per_chip}"
)

Expand Down
Loading
Loading