Add new flag --enable_rich_metrics #2189

Draft · wants to merge 1 commit into main
1 change: 1 addition & 0 deletions MaxText/configs/base.yml
@@ -78,6 +78,7 @@ reuse_example_batch: 0 # for testing TPU performance, this options repeated uses
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
gcs_metrics: False
enable_rich_metrics: False # If true, also write a rich metrics summary (step times, TFLOPs/s, goodput/badput breakdown, dashboard links) to {base_output_directory}/{run_name}/metrics/

# If true save config to GCS in {base_output_directory}/{run_name}/
save_config_to_gcs: False
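As with gcs_metrics just above it, the new flag defaults to off; a run opts in by appending the enable_rich_metrics=true key=value override to the training command, which is exactly how the benchmark runner change further down enables it.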
84 changes: 84 additions & 0 deletions MaxText/metric_logger.py
@@ -34,6 +34,8 @@
from MaxText.globals import EPS

from collections import defaultdict
from datetime import datetime, timezone
from ml_goodput_measurement import goodput


def _prepare_metrics_for_json(metrics, step, run_name):
@@ -54,6 +56,7 @@ def __init__(self, config, learning_rate_schedule):
self.config = config
self.metadata = {}
self.running_gcs_metrics = [] if config.gcs_metrics else None
self.running_gcs_rich_metrics = [] if config.enable_rich_metrics else None
self.performance_metric_queue = self.get_performance_metric_queue(config)
self.learning_rate_schedule = learning_rate_schedule
self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
@@ -74,6 +77,9 @@ def write_metrics(self, metrics, step, is_training=True):
if self.config.metrics_file:
self.write_metrics_locally(metrics, step)

if self.config.enable_rich_metrics and jax.process_index() == 0:
self.write_rich_metrics_for_gcs(metrics, step, is_training)

if self.config.gcs_metrics and jax.process_index() == 0:
self.write_metrics_for_gcs(metrics, step, is_training)

@@ -135,6 +141,84 @@ def write_metrics_for_gcs(self, metrics, step, is_training):
max_logging.log(f"File {metrics_filename} moved successfully!")
self.running_gcs_metrics = [] # reset running_metrics to empty list

def write_rich_metrics_for_gcs(self, metrics, step, is_training):
"""Writes step metrics summary JSON directly to GCS using real perf metrics."""
run_name = self.config.run_name
goodput_logger_name = f"goodput_{run_name}"
metrics_dict_step = _prepare_metrics_for_json(metrics, step, run_name)
self.running_gcs_rich_metrics.append(metrics_dict_step)
if is_training and ((step + 1) % self.config.log_period == 0 or step == self.config.steps - 1):
start_step = (step // self.config.log_period) * self.config.log_period

step_times = []
tflops_sec = []
for m in self.running_gcs_rich_metrics:
if "perf/step_time_seconds" in m:
step_times.append((m["step"], float(m["perf/step_time_seconds"])))
if "perf/per_device_tflops_per_sec" in m:
tflops_sec.append(float(m["perf/per_device_tflops_per_sec"]))

if step_times:
avg_step_time = sum(t for _, t in step_times) / len(step_times)
min_step, min_time = min(step_times, key=lambda x: x[1])
max_step, max_time = max(step_times, key=lambda x: x[1])
else:
avg_step_time = min_time = max_time = 0
min_step = max_step = None

if tflops_sec:
mfu_val = f"{(sum(tflops_sec) / len(tflops_sec)):.2f} TFLOPs/s"
else:
mfu_val = "N/A"

# TODO: goodput/badput breakdown and links
goodput_calculator = goodput.GoodputCalculator(
job_name=run_name, logger_name=goodput_logger_name
)
current_goodput, current_badput_breakdown, last_step = (
goodput_calculator.get_job_goodput(include_badput_breakdown=True)
)
goodput_val = f"{current_goodput:.2f}%"
badput_breakdown = {
"checkpoint_loading": f"{current_badput_breakdown[goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME]:.2f}%",
"data_loading": f"{current_badput_breakdown[goodput.BadputType.DATA_LOADING]:.2f}%",
"failure": "N/A",
"reshard": "N/A"
}

# Fall back to empty strings so a missing env var does not crash metric logging.
project = os.environ.get("PROJECT", "")
cluster = os.environ.get("CLUSTER", "")
zone = os.environ.get("ZONE", "")
region = '-'.join(zone.split('-')[:2]) # e.g. "us-central2-b" -> "us-central2"
links = {
"cloud_logging": f"https://console.cloud.google.com/logs/query;query=resource.type%3D%22k8s_container%22%0Aresource.labels.project_id%3D%22{project}%22%0Aresource.labels.location%3D%22{region}%22%0Aresource.labels.cluster_name%3D%22{cluster}%22%0Aresource.labels.namespace_name%3D%22default%22%0Aresource.labels.pod_name:%22{run_name}%22%0A;duration=PT1H?project={project}",
"goodput_monitor": "...",
"disruption_dashboard": "..."
}

summary_data = {
"avg_step_time": f"{avg_step_time:.2f}s",
"min_step_time": f"{min_time:.2f}s on step {min_step}",
"max_step_time": f"{max_time:.2f}s on step {max_step}",
"goodput": goodput_val,
"MFU": mfu_val,
"badput_breakdown": badput_breakdown,
"links": links,
"generated_at": datetime.utcnow().isoformat() + "Z"
}

rich_metrics_filename = f"rich_metrics_step_{start_step:06}_to_step_{step:06}.txt"

with open(rich_metrics_filename, "wt", encoding="utf8") as rich_metrics_for_gcs:
rich_metrics_for_gcs.write(json.dumps(summary_data, indent=2))

gcs_filename = os.path.join(self.config.metrics_dir, rich_metrics_filename)
max_logging.log(f"Moving file {rich_metrics_filename} to GCS...")
gcs_utils.upload_blob(gcs_filename, rich_metrics_filename)
max_logging.log(f"File {rich_metrics_filename} moved successfully!")

self.running_gcs_rich_metrics = [] # reset the rich-metrics buffer to an empty list

def write_metrics_to_tensorboard(self, metrics, step, is_training):
"""Writes metrics to TensorBoard."""
if jax.process_index() == 0:
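For reviewers, a minimal standalone sketch of the aggregation write_rich_metrics_for_gcs performs over the buffered per-step metric dicts. The keys match the ones the method reads (step, perf/step_time_seconds, perf/per_device_tflops_per_sec); the values are made up for illustration:

# Hypothetical buffered metrics, shaped like the dicts appended to running_gcs_rich_metrics.
running = [
    {"step": 10, "perf/step_time_seconds": 2.31, "perf/per_device_tflops_per_sec": 118.4},
    {"step": 11, "perf/step_time_seconds": 2.18, "perf/per_device_tflops_per_sec": 125.9},
    {"step": 12, "perf/step_time_seconds": 2.45, "perf/per_device_tflops_per_sec": 111.7},
]

step_times = [(m["step"], float(m["perf/step_time_seconds"])) for m in running if "perf/step_time_seconds" in m]
tflops_sec = [float(m["perf/per_device_tflops_per_sec"]) for m in running if "perf/per_device_tflops_per_sec" in m]

avg_step_time = sum(t for _, t in step_times) / len(step_times)  # 2.31
min_step, min_time = min(step_times, key=lambda x: x[1])  # (11, 2.18)
max_step, max_time = max(step_times, key=lambda x: x[1])  # (12, 2.45)
avg_tflops = sum(tflops_sec) / len(tflops_sec)  # ~118.67

print(f"avg {avg_step_time:.2f}s | min {min_time:.2f}s @ step {min_step} | "
      f"max {max_time:.2f}s @ step {max_step} | {avg_tflops:.2f} TFLOPs/s")

Note that the value stored under the MFU key of the summary is this per-device TFLOPs/s average; turning it into an actual MFU percentage would additionally require dividing by the accelerator's peak TFLOPs/s.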
12 changes: 11 additions & 1 deletion benchmarks/maxtext_xpk_runner.py
@@ -102,6 +102,7 @@ class WorkloadConfig:
generate_metrics_and_upload_to_big_query: bool = True
hardware_id: str = 'v6e'
metrics_gcs_file: str = ''
enable_rich_metrics: bool = True
base_config: str = os.path.join("MaxText", "configs", "base.yml")
topology: str = dataclasses.field(init=False)
num_devices_per_slice: int = dataclasses.field(init=False)
@@ -340,6 +341,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict:
log.info("using steps=(%d) in model convergence test setup", num_steps)

return {"metrics_gcs_file": wl_config.metrics_gcs_file,
# "enable_rich_metrics": wl_config.enable_rich_metrics,
"model_id": wl_config.model.model_type,
"hardware_id": wl_config.hardware_id,
"software_id": "jax_maxtext",
@@ -414,6 +416,8 @@ def build_user_command(
# Save metrics to gcs bucket so that we can upload them to bq in post processing.
enable_metrics_cmd = 'gcs_metrics=true'

enable_rich_metrics_cmd = 'enable_rich_metrics=true'

upload_hlo_dump=""
hlo_dump=""
if wl_config.hlo_dump:
@@ -435,6 +439,7 @@
f'{vertex_tensorboard}',
f'{run_name_command}',
f'{enable_metrics_cmd}'
f' {enable_rich_metrics_cmd}'
f'{upload_hlo_dump}'
])
return command
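One detail worth calling out in the command assembly above: there is no comma between the f'{enable_metrics_cmd}', f' {enable_rich_metrics_cmd}' and f'{upload_hlo_dump}' entries, so Python's implicit string-literal concatenation fuses them into a single list element before the ' '.join(...) runs; that is why the new flag carries a leading space, matching the convention of the surrounding entries. A minimal sketch of the behavior, with an illustrative placeholder command:

# Adjacent (f-)string literals with no comma between them are concatenated into one
# list element, so each flag appended this way needs its own leading space.
enable_metrics_cmd = 'gcs_metrics=true'
enable_rich_metrics_cmd = 'enable_rich_metrics=true'

without_space = ' '.join(['<training command>',
                          f'{enable_metrics_cmd}'
                          f'{enable_rich_metrics_cmd}'])
with_space = ' '.join(['<training command>',
                       f'{enable_metrics_cmd}'
                       f' {enable_rich_metrics_cmd}'])

print(without_space)  # <training command> gcs_metrics=trueenable_rich_metrics=true  (flags fused)
print(with_space)     # <training command> gcs_metrics=true enable_rich_metrics=true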
@@ -617,6 +622,11 @@ def generate_xpk_workload_cmd(
wl_config.run_name,
'metrics')

user_command_prefix = ' '.join([
f'export PROJECT={cluster_config.project} &&',
f'export CLUSTER={cluster_config.cluster_name} &&',
f'export ZONE={cluster_config.zone} &&'
])
user_command = ''
if not is_pathways_headless_enabled:
user_command = build_user_command(
@@ -674,7 +684,7 @@ def generate_xpk_workload_cmd(
f' {device_type}'
f' {all_xpk_storage}'
f' --num-slices={wl_config.num_slices}'
f' --command="{user_command} {upload_metrics_to_bq_cmd}"'
f' --command="{user_command_prefix} {user_command} {upload_metrics_to_bq_cmd}"'
f' {docker_image_flag}'
' --enable-debug-logs'
f' --workload={name}'
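The exported PROJECT, CLUSTER and ZONE variables exist only to feed the Cloud Logging link that write_rich_metrics_for_gcs builds on the training side; a small sketch of that hand-off with made-up values (only the zone-to-region derivation mirrors the logger change):

import os

# Hypothetical values; in the workload they come from the exports prepended to the xpk command.
os.environ.setdefault("PROJECT", "my-gcp-project")
os.environ.setdefault("CLUSTER", "my-tpu-cluster")
os.environ.setdefault("ZONE", "us-central2-b")

zone = os.environ.get("ZONE", "")
region = '-'.join(zone.split('-')[:2])  # "us-central2-b" -> "us-central2"
print(os.environ["PROJECT"], os.environ["CLUSTER"], region)  # my-gcp-project my-tpu-cluster us-central2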