Skip to content

Commit 9301a31

Browse files
committed
Add data to metrics_for_gcs
1 parent df1c471 commit 9301a31

File tree

3 files changed

+96
-1
lines changed

3 files changed

+96
-1
lines changed

MaxText/configs/base.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ reuse_example_batch: 0 # for testing TPU performance, this options repeated uses
7878
metrics_file: "" # for testing, local file that stores scalar metrics. If empty, no metrics are written.
7979
# If true save metrics such as loss and TFLOPS to GCS in {base_output_directory}/{run_name}/metrics/
8080
gcs_metrics: False
81+
enable_rich_metrics: False
8182

8283
# If true save config to GCS in {base_output_directory}/{run_name}/
8384
save_config_to_gcs: False

MaxText/metric_logger.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from MaxText.globals import EPS
3535

3636
from collections import defaultdict
37+
from datetime import datetime
38+
from ml_goodput_measurement import goodput
3739

3840

3941
def _prepare_metrics_for_json(metrics, step, run_name):
@@ -54,6 +56,7 @@ def __init__(self, config, learning_rate_schedule):
5456
self.config = config
5557
self.metadata = {}
5658
self.running_gcs_metrics = [] if config.gcs_metrics else None
59+
self.running_gcs_rich_metrics = [] if config.enable_rich_metrics else None
5760
self.performance_metric_queue = self.get_performance_metric_queue(config)
5861
self.learning_rate_schedule = learning_rate_schedule
5962
self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
@@ -74,6 +77,9 @@ def write_metrics(self, metrics, step, is_training=True):
7477
if self.config.metrics_file:
7578
self.write_metrics_locally(metrics, step)
7679

80+
if self.config.enable_rich_metrics:
81+
self.write_rich_metrics_for_gcs(metrics, step, is_training)
82+
7783
if self.config.gcs_metrics and jax.process_index() == 0:
7884
self.write_metrics_for_gcs(metrics, step, is_training)
7985

@@ -135,6 +141,84 @@ def write_metrics_for_gcs(self, metrics, step, is_training):
135141
max_logging.log(f"File {metrics_filename} moved successfully!")
136142
self.running_gcs_metrics = [] # reset running_metrics to empty list
137143

144+
def write_rich_metrics_for_gcs(self, metrics, step, is_training):
145+
"""Writes step metrics summary JSON directly to GCS using real perf metrics."""
146+
run_name = self.config.run_name
147+
goodput_logger_name = f"goodput_{run_name}"
148+
metrics_dict_step = _prepare_metrics_for_json(metrics, step, run_name)
149+
self.running_gcs_rich_metrics.append(metrics_dict_step)
150+
if is_training and (step + 1) % self.config.log_period == 0 or step == self.config.steps - 1:
151+
start_step = (step // self.config.log_period) * self.config.log_period
152+
153+
step_times = []
154+
tflops_sec = []
155+
for m in self.running_gcs_rich_metrics:
156+
if "perf/step_time_seconds" in m:
157+
step_times.append((m["step"], float(m["perf/step_time_seconds"])))
158+
if "perf/per_device_tflops_per_sec" in m:
159+
tflops_sec.append(float(m["perf/per_device_tflops_per_sec"]))
160+
161+
if step_times:
162+
avg_step_time = sum(t for _, t in step_times) / len(step_times) if step_times else 0
163+
min_step, min_time = min(step_times, key=lambda x: x[1])
164+
max_step, max_time = max(step_times, key=lambda x: x[1])
165+
else:
166+
avg_step_time = min_time = max_time = 0
167+
min_step = max_step = None
168+
169+
if tflops_sec:
170+
mfu_val = f"{(sum(tflops_sec) / len(tflops_sec)):.2f} TFLOPs/s" if tflops_sec else "0.00 TFLOPs/s"
171+
else:
172+
mfu_val = "N/A"
173+
174+
# TODO: Goodput/badput breakdown and links
175+
goodput_calculator = goodput.GoodputCalculator(
176+
job_name=run_name, logger_name=goodput_logger_name
177+
)
178+
current_goodput, current_badput_breakdown, last_step = (
179+
goodput_calculator.get_job_goodput(include_badput_breakdown=True)
180+
)
181+
goodput_val = f"{current_goodput:.2f}%"
182+
badput_breakdown = {
183+
"checkpoint_loading": f"{current_badput_breakdown[goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME]:.2f}%",
184+
"data_loading": f"{current_badput_breakdown[goodput.BadputType.DATA_LOADING]:.2f}%",
185+
"failure": "N/A",
186+
"reshard": "N/A"
187+
}
188+
189+
project = os.environ["PROJECT"]
190+
cluster = os.environ["CLUSTER"]
191+
zone = os.environ["ZONE"]
192+
region = zone.split('-')[0] + '-' + zone.split('-')[1]
193+
links = {
194+
"cloud_logging": f"https://console.cloud.google.com/logs/query;query=resource.type%3D%22k8s_container%22%0Aresource.labels.project_id%3D%22{project}%22%0Aresource.labels.location%3D%22{region}%22%0Aresource.labels.cluster_name%3D%22{cluster}%22%0Aresource.labels.namespace_name%3D%22default%22%0Aresource.labels.pod_name:%22{run_name}%22%0A;cursorTimestamp=2025-08-22T21:00:46.979388423Z;duration=PT1H?project={project}",
195+
"goodput_monitor": "...",
196+
"disruption_dashboard": "..."
197+
}
198+
199+
summary_data = {
200+
"avg_step_time": f"{avg_step_time:.2f}s",
201+
"min_step_time": f"{min_time:.2f}s on step {min_step}",
202+
"max_step_time": f"{max_time:.2f}s on step {max_step}",
203+
"goodput": goodput_val,
204+
"MFU": mfu_val,
205+
"badput_breakdown": badput_breakdown,
206+
"links": links,
207+
"generated_at": datetime.utcnow().isoformat() + "Z"
208+
}
209+
210+
rich_metrics_filename = f"rich_metrics_step_{start_step:06}_to_step_{step:06}.txt"
211+
212+
with open(rich_metrics_filename, "wt", encoding="utf8") as rich_metrics_for_gcs:
213+
rich_metrics_for_gcs.write(json.dumps(summary_data, indent=2))
214+
215+
gcs_filename = os.path.join(self.config.metrics_dir, rich_metrics_filename)
216+
max_logging.log(f"Moving file {rich_metrics_filename} to GCS...")
217+
gcs_utils.upload_blob(gcs_filename, rich_metrics_filename)
218+
max_logging.log(f"File {rich_metrics_filename} moved successfully!")
219+
220+
self.running_gcs_rich_metrics = [] # reset running_metrics to empty list
221+
138222
def write_metrics_to_tensorboard(self, metrics, step, is_training):
139223
"""Writes metrics to TensorBoard."""
140224
if jax.process_index() == 0:

benchmarks/maxtext_xpk_runner.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ class WorkloadConfig:
102102
generate_metrics_and_upload_to_big_query: bool = True
103103
hardware_id: str = 'v6e'
104104
metrics_gcs_file: str = ''
105+
enable_rich_metrics: bool = True
105106
base_config: str = os.path.join("MaxText", "configs", "base.yml")
106107
topology: str = dataclasses.field(init=False)
107108
num_devices_per_slice: int = dataclasses.field(init=False)
@@ -340,6 +341,7 @@ def _build_args_from_config(wl_config: WorkloadConfig) -> dict:
340341
log.info("using steps=(%d) in model convergence test setup", num_steps)
341342

342343
return {"metrics_gcs_file": wl_config.metrics_gcs_file,
344+
# "enable_rich_metrics": wl_config.enable_rich_metrics,
343345
"model_id": wl_config.model.model_type,
344346
"hardware_id": wl_config.hardware_id,
345347
"software_id": "jax_maxtext",
@@ -414,6 +416,8 @@ def build_user_command(
414416
# Save metrics to gcs bucket so that we can upload them to bq in post processing.
415417
enable_metrics_cmd = 'gcs_metrics=true'
416418

419+
enable_rich_metrics_cmd = "enable_rich_metrics=true"
420+
417421
upload_hlo_dump=""
418422
hlo_dump=""
419423
if wl_config.hlo_dump:
@@ -435,6 +439,7 @@ def build_user_command(
435439
f'{vertex_tensorboard}',
436440
f'{run_name_command}',
437441
f'{enable_metrics_cmd}'
442+
f'{enable_rich_metrics_cmd}'
438443
f'{upload_hlo_dump}'
439444
])
440445
return command
@@ -617,6 +622,11 @@ def generate_xpk_workload_cmd(
617622
wl_config.run_name,
618623
'metrics')
619624

625+
user_command_prefix = ' '.join([
626+
f'export PROJECT={cluster_config.project} &&',
627+
f'export CLUSTER={cluster_config.cluster_name} &&',
628+
f'export ZONE={cluster_config.zone} &&'
629+
])
620630
user_command = ''
621631
if not is_pathways_headless_enabled:
622632
user_command = build_user_command(
@@ -674,7 +684,7 @@ def generate_xpk_workload_cmd(
674684
f' {device_type}'
675685
f' {all_xpk_storage}'
676686
f' --num-slices={wl_config.num_slices}'
677-
f' --command="{user_command} {upload_metrics_to_bq_cmd}"'
687+
f' --command="{user_command_prefix} {user_command} {upload_metrics_to_bq_cmd}"'
678688
f' {docker_image_flag}'
679689
' --enable-debug-logs'
680690
f' --workload={name}'

0 commit comments

Comments
 (0)