diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 4884f76a30..627ca7c7b3 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -78,6 +78,7 @@ reuse_example_batch: 0 # for testing TPU performance, this options repeated uses
 metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
 # If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
 gcs_metrics: False
+enable_rich_metrics: False # If true save a rich metrics summary (step times, goodput, MFU, links) to GCS in {base_output_directory}/{run_name}/metrics/
 
 # If true save config to GCS in {base_output_directory}/{run_name}/
 save_config_to_gcs: False
diff --git a/MaxText/metric_logger.py b/MaxText/metric_logger.py
index fd9fd46f13..093c33cdc9 100644
--- a/MaxText/metric_logger.py
+++ b/MaxText/metric_logger.py
@@ -34,6 +34,8 @@ from MaxText.globals import EPS
 
 from collections import defaultdict
+from datetime import datetime
+from ml_goodput_measurement import goodput
 
 
 def _prepare_metrics_for_json(metrics, step, run_name):
@@ -54,6 +56,7 @@ def __init__(self, config, learning_rate_schedule):
     self.config = config
     self.metadata = {}
     self.running_gcs_metrics = [] if config.gcs_metrics else None
+    self.running_gcs_rich_metrics = [] if config.enable_rich_metrics else None
     self.performance_metric_queue = self.get_performance_metric_queue(config)
     self.learning_rate_schedule = learning_rate_schedule
     self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
@@ -74,6 +77,9 @@ def write_metrics(self, metrics, step, is_training=True):
     if self.config.metrics_file:
       self.write_metrics_locally(metrics, step)
 
+    if self.config.enable_rich_metrics and jax.process_index() == 0:
+      self.write_rich_metrics_for_gcs(metrics, step, is_training)
+
     if self.config.gcs_metrics and jax.process_index() == 0:
       self.write_metrics_for_gcs(metrics, step, is_training)
 
@@ -135,6 +141,87 @@ def write_metrics_for_gcs(self, metrics, step, is_training):
       max_logging.log(f"File {metrics_filename} moved successfully!")
       self.running_gcs_metrics = [] # reset running_metrics to empty list
 
+  def write_rich_metrics_for_gcs(self, metrics, step, is_training):
+    """Writes a rich step-metrics summary (step times, throughput, goodput) as JSON directly to GCS."""
+    run_name = self.config.run_name
+    goodput_logger_name = f"goodput_{run_name}"
+    metrics_dict_step = _prepare_metrics_for_json(metrics, step, run_name)
+    self.running_gcs_rich_metrics.append(metrics_dict_step)
+    if (is_training and (step + 1) % self.config.log_period == 0) or step == self.config.steps - 1:
+      start_step = (step // self.config.log_period) * self.config.log_period
+
+      step_times = []
+      tflops_sec = []
+      for m in self.running_gcs_rich_metrics:
+        step_num = m["step"]
+        if step_num < 2:  # skip the first steps, which are skewed by compilation/warm-up
+          continue
+        if "perf/step_time_seconds" in m:
+          step_times.append((step_num, float(m["perf/step_time_seconds"])))
+        if "perf/per_device_tflops_per_sec" in m:
+          tflops_sec.append(float(m["perf/per_device_tflops_per_sec"]))
+
+      if step_times:
+        avg_step_time = sum(t for _, t in step_times) / len(step_times)
+        min_step, min_time = min(step_times, key=lambda x: x[1])
+        max_step, max_time = max(step_times, key=lambda x: x[1])
+      else:
+        avg_step_time = min_time = max_time = 0
+        min_step = max_step = None
+
+      if tflops_sec:
+        # Average per-device throughput; computing true MFU would additionally require the chip's peak TFLOPs.
+        mfu_val = f"{(sum(tflops_sec) / len(tflops_sec)):.2f} TFLOPs/s"
+      else:
+        mfu_val = "N/A"
+
+      # TODO: fill in the remaining goodput/badput breakdown entries and dashboard links
+      goodput_calculator = goodput.GoodputCalculator(
+          job_name=run_name, logger_name=goodput_logger_name
+      )
+      current_goodput, current_badput_breakdown, last_step = (
+          goodput_calculator.get_job_goodput(include_badput_breakdown=True)
+      )
goodput_val = f"{current_goodput:.2f}%" + badput_breakdown = { + "checkpoint_loading": f"{current_badput_breakdown[goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME]:.2f}%", + "data_loading": f"{current_badput_breakdown[goodput.BadputType.DATA_LOADING]:.2f}%", + "failure": "N/A", + "reshard": "N/A" + } + + project = os.environ["PROJECT"] + cluster = os.environ["CLUSTER"] + zone = os.environ["ZONE"] + region = zone.split('-')[0] + '-' + zone.split('-')[1] + links = { + "cloud_logging": f"https://console.cloud.google.com/logs/query;query=resource.type%3D%22k8s_container%22%0Aresource.labels.project_id%3D%22{project}%22%0Aresource.labels.location%3D%22{region}%22%0Aresource.labels.cluster_name%3D%22{cluster}%22%0Aresource.labels.namespace_name%3D%22default%22%0Aresource.labels.pod_name:%22{run_name}%22%0A;duration=PT1H?project={project}", + "goodput_monitor": "...", + "disruption_dashboard": "..." + } + + summary_data = { + "avg_step_time": f"{avg_step_time:.2f}s" if step_times else "N/A", + "min_step_time": f"{min_time:.2f}s on step {min_step}" if min_step is not None else "N/A", + "max_step_time": f"{max_time:.2f}s on step {max_step}" if max_step is not None else "N/A", + "goodput": goodput_val, + "MFU": mfu_val, + "badput_breakdown": badput_breakdown, + "links": links, + "generated_at": datetime.utcnow().isoformat() + "Z" + } + + rich_metrics_filename = f"rich_metrics_step_{start_step:06}_to_step_{step:06}.txt" + + with open(rich_metrics_filename, "wt", encoding="utf8") as rich_metrics_for_gcs: + rich_metrics_for_gcs.write(json.dumps(summary_data, indent=2)) + + gcs_filename = os.path.join(self.config.metrics_dir, rich_metrics_filename) + max_logging.log(f"Moving file {rich_metrics_filename} to GCS...") + gcs_utils.upload_blob(gcs_filename, rich_metrics_filename) + max_logging.log(f"File {rich_metrics_filename} moved successfully!") + + self.running_gcs_rich_metrics = [] # reset running_metrics to empty list + def write_metrics_to_tensorboard(self, metrics, step, is_training): """Writes metrics to TensorBoard.""" if jax.process_index() == 0: diff --git a/benchmarks/maxtext_xpk_runner.py b/benchmarks/maxtext_xpk_runner.py index 602df31dbd..24554f0e4f 100644 --- a/benchmarks/maxtext_xpk_runner.py +++ b/benchmarks/maxtext_xpk_runner.py @@ -102,6 +102,7 @@ class WorkloadConfig: generate_metrics_and_upload_to_big_query: bool = True hardware_id: str = 'v6e' metrics_gcs_file: str = '' + enable_rich_metrics: bool = True base_config: str = os.path.join("MaxText", "configs", "base.yml") topology: str = dataclasses.field(init=False) num_devices_per_slice: int = dataclasses.field(init=False) @@ -340,6 +341,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict: log.info("using steps=(%d) in model convergence test setup", num_steps) return {"metrics_gcs_file": wl_config.metrics_gcs_file, + # "enable_rich_metrics": wl_config.enable_rich_metrics, "model_id": wl_config.model.model_type, "hardware_id": wl_config.hardware_id, "software_id": "jax_maxtext", @@ -414,6 +416,8 @@ def build_user_command( # Save metrics to gcs bucket so that we can upload them to bq in post processing. 
   enable_metrics_cmd = 'gcs_metrics=true'
+  enable_rich_metrics_cmd = 'enable_rich_metrics=true' if wl_config.enable_rich_metrics else ''
+
   upload_hlo_dump=""
   hlo_dump=""
   if wl_config.hlo_dump:
@@ -435,6 +439,7 @@ def build_user_command(
       f'{vertex_tensorboard}',
       f'{run_name_command}',
       f'{enable_metrics_cmd}'
+      f' {enable_rich_metrics_cmd}'
       f'{upload_hlo_dump}'
   ])
   return command
@@ -617,6 +622,11 @@ def generate_xpk_workload_cmd(
                                               wl_config.run_name,
                                               'metrics')
 
+  user_command_prefix = ' '.join([
+      f'export PROJECT={cluster_config.project} &&',
+      f'export CLUSTER={cluster_config.cluster_name} &&',
+      f'export ZONE={cluster_config.zone} &&'
+  ])
   user_command = ''
   if not is_pathways_headless_enabled:
     user_command = build_user_command(
@@ -674,7 +684,7 @@ def generate_xpk_workload_cmd(
       f' {device_type}'
      f' {all_xpk_storage}'
      f' --num-slices={wl_config.num_slices}'
-     f' --command="{user_command} {upload_metrics_to_bq_cmd}"'
+     f' --command="{user_command_prefix} {user_command} {upload_metrics_to_bq_cmd}"'
      f' {docker_image_flag}'
      ' --enable-debug-logs'
      f' --workload={name}'