Skip to content

Commit d2d3f26

Browse files
committed
Add data to metrics_for_gcs
1 parent df1c471 commit d2d3f26

File tree

1 file changed

+69
-0
lines changed

1 file changed

+69
-0
lines changed

MaxText/metric_logger.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@
3434
from MaxText.globals import EPS
3535

3636
from collections import defaultdict
37+
from datetime import datetime
38+
from ml_goodput_measurement import goodput
3739

3840

3941
def _prepare_metrics_for_json(metrics, step, run_name):
@@ -135,6 +137,73 @@ def write_metrics_for_gcs(self, metrics, step, is_training):
135137
max_logging.log(f"File {metrics_filename} moved successfully!")
136138
self.running_gcs_metrics = [] # reset running_metrics to empty list
137139

140+
def write_rich_metrics_for_gcs(self, metrics, step, is_training):
  """Buffers per-step metrics and, on a flush step, uploads a summary JSON to GCS.

  Every call appends the JSON-ready metrics for `step` to
  `self.running_gcs_metrics`. A flush happens when either
    * `is_training` is true and `step + 1` is a multiple of
      `config.log_period`, or
    * `step` is the last configured step (`config.steps - 1`).
  On a flush the buffered entries are aggregated (step time avg/min/max and
  mean per-device TFLOPs/s), goodput/badput figures are queried from
  `ml_goodput_measurement`, the summary is handed to
  `gcs_utils.upload_blob_from_string`, and the buffer is cleared.

  Args:
    metrics: Raw metrics for this step, passed to `_prepare_metrics_for_json`.
    step: Zero-based index of the step that produced `metrics`.
    is_training: Whether the step is a training step; gates the periodic flush.
  """
  # Function-scope import: the module only imports `datetime` from `datetime`.
  from datetime import timezone

  run_name = self.config.run_name
  self.running_gcs_metrics.append(_prepare_metrics_for_json(metrics, step, run_name))

  # Same precedence as the original `a and b or c`, spelled out explicitly.
  at_log_boundary = is_training and (step + 1) % self.config.log_period == 0
  at_final_step = step == self.config.steps - 1
  if not (at_log_boundary or at_final_step):
    return

  start_step = (step // self.config.log_period) * self.config.log_period

  # Collect the perf series from the buffered entries; entries lacking a key
  # are simply left out of the corresponding aggregate.
  step_times = [
      (m["step"], float(m["perf/step_time_seconds"]))
      for m in self.running_gcs_metrics
      if "perf/step_time_seconds" in m
  ]
  tflops_sec = [
      float(m["perf/per_device_tflops_per_sec"])
      for m in self.running_gcs_metrics
      if "perf/per_device_tflops_per_sec" in m
  ]

  if step_times:
    avg_step_time = sum(t for _, t in step_times) / len(step_times)
    min_step, min_time = min(step_times, key=lambda x: x[1])
    max_step, max_time = max(step_times, key=lambda x: x[1])
  else:
    # No timing data buffered: the summary reports zeros with step `None`.
    avg_step_time = min_time = max_time = 0
    min_step = max_step = None

  # NOTE(review): the value stored under "MFU" is the mean per-device
  # TFLOPs/s, not a utilisation ratio -- confirm the intended label.
  if tflops_sec:
    mfu_val = f"{(sum(tflops_sec) / len(tflops_sec)):.2f} TFLOPs/s"
  else:
    mfu_val = "N/A"

  # TODO: Goodput/badput breakdown and links
  goodput_calculator = goodput.GoodputCalculator(
      job_name=run_name, logger_name=f"goodput_{run_name}"
  )
  # The third element (last step) is not used by the summary.
  current_goodput, current_badput_breakdown, _ = goodput_calculator.get_job_goodput(
      include_badput_breakdown=True
  )
  # Do NOT name this local `goodput`: assigning to that name anywhere in the
  # function makes it local everywhere in the function, so the module
  # reference above would raise UnboundLocalError.
  goodput_pct = f"{current_goodput:.2f}%"
  checkpoint_restore_badput = current_badput_breakdown[
      goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME
  ]
  data_loading_badput = current_badput_breakdown[goodput.BadputType.DATA_LOADING]
  badput_breakdown = {
      "checkpoint_loading": f"{checkpoint_restore_badput:.2f}%",
      "data_loading": f"{data_loading_badput:.2f}%",
      "failure": "N/A",
      "reshard": "N/A",
  }

  links = {
      "cloud_logging": f"https://console.cloud.google.com/logs/query;query=run_name:{run_name}",
      "goodput_monitor": "...",
      "disruption_dashboard": "...",
  }

  # `datetime.utcnow()` is deprecated; read the aware UTC clock and drop the
  # tzinfo so the emitted string keeps the "<naive isoformat>Z" shape.
  generated_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

  summary_data = {
      "avg_step_time": f"{avg_step_time:.2f}s",
      "min_step_time": f"{min_time:.2f}s on step {min_step}",
      "max_step_time": f"{max_time:.2f}s on step {max_step}",
      "goodput": goodput_pct,
      "MFU": mfu_val,
      "badput_breakdown": badput_breakdown,
      "links": links,
      "generated_at": generated_at,
  }

  summary_gcs_filename = f"{self.config.metrics_dir}/summary_step_{start_step:06}_to_step_{step:06}.json"
  gcs_utils.upload_blob_from_string(summary_gcs_filename, json.dumps(summary_data, indent=2))

  self.running_gcs_metrics = []  # reset running_metrics to empty list
206+
138207
def write_metrics_to_tensorboard(self, metrics, step, is_training):
139208
"""Writes metrics to TensorBoard."""
140209
if jax.process_index() == 0:

0 commit comments

Comments
 (0)