 from MaxText.globals import EPS

 from collections import defaultdict
+from datetime import datetime
+from ml_goodput_measurement import goodput


 def _prepare_metrics_for_json(metrics, step, run_name):
@@ -54,6 +56,7 @@ def __init__(self, config, learning_rate_schedule):
     self.config = config
     self.metadata = {}
     self.running_gcs_metrics = [] if config.gcs_metrics else None
+    self.running_gcs_rich_metrics = [] if config.enable_rich_metrics else None
     self.performance_metric_queue = self.get_performance_metric_queue(config)
     self.learning_rate_schedule = learning_rate_schedule
     self.cumulative_eval_metrics = {"scalar": defaultdict(float)}
@@ -74,6 +77,9 @@ def write_metrics(self, metrics, step, is_training=True):
     if self.config.metrics_file:
       self.write_metrics_locally(metrics, step)

+    if self.config.enable_rich_metrics:
+      self.write_rich_metrics_for_gcs(metrics, step, is_training)
+
     if self.config.gcs_metrics and jax.process_index() == 0:
       self.write_metrics_for_gcs(metrics, step, is_training)

@@ -135,6 +141,84 @@ def write_metrics_for_gcs(self, metrics, step, is_training):
       max_logging.log(f"File {metrics_filename} moved successfully!")
       self.running_gcs_metrics = []  # reset running_metrics to empty list

+  def write_rich_metrics_for_gcs(self, metrics, step, is_training):
+    """Writes a step-metrics summary JSON directly to GCS using real perf metrics."""
+    run_name = self.config.run_name
+    goodput_logger_name = f"goodput_{run_name}"
+    metrics_dict_step = _prepare_metrics_for_json(metrics, step, run_name)
+    self.running_gcs_rich_metrics.append(metrics_dict_step)
+    if is_training and ((step + 1) % self.config.log_period == 0 or step == self.config.steps - 1):
+      start_step = (step // self.config.log_period) * self.config.log_period
+
+      step_times = []
+      tflops_sec = []
+      for m in self.running_gcs_rich_metrics:
+        if "perf/step_time_seconds" in m:
+          step_times.append((m["step"], float(m["perf/step_time_seconds"])))
+        if "perf/per_device_tflops_per_sec" in m:
+          tflops_sec.append(float(m["perf/per_device_tflops_per_sec"]))
+
+      if step_times:
+        avg_step_time = sum(t for _, t in step_times) / len(step_times)
+        min_step, min_time = min(step_times, key=lambda x: x[1])
+        max_step, max_time = max(step_times, key=lambda x: x[1])
+      else:
+        avg_step_time = min_time = max_time = 0
+        min_step = max_step = None
+
+      if tflops_sec:
+        mfu_val = f"{(sum(tflops_sec) / len(tflops_sec)):.2f} TFLOPs/s"
+      else:
+        mfu_val = "N/A"
+
+      # TODO: Goodput/badput breakdown and links
+      goodput_calculator = goodput.GoodputCalculator(
+          job_name=run_name, logger_name=goodput_logger_name
+      )
+      current_goodput, current_badput_breakdown, last_step = (
+          goodput_calculator.get_job_goodput(include_badput_breakdown=True)
+      )
+      goodput_val = f"{current_goodput:.2f}%"
+      badput_breakdown = {
+          "checkpoint_loading": f"{current_badput_breakdown[goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME]:.2f}%",
+          "data_loading": f"{current_badput_breakdown[goodput.BadputType.DATA_LOADING]:.2f}%",
+          "failure": "N/A",
+          "reshard": "N/A",
+      }
+
+      project = os.environ["PROJECT"]
+      cluster = os.environ["CLUSTER"]
+      zone = os.environ["ZONE"]
+      region = "-".join(zone.split("-")[:2])
+      links = {
+          "cloud_logging": f"https://console.cloud.google.com/logs/query;query=resource.type%3D%22k8s_container%22%0Aresource.labels.project_id%3D%22{project}%22%0Aresource.labels.location%3D%22{region}%22%0Aresource.labels.cluster_name%3D%22{cluster}%22%0Aresource.labels.namespace_name%3D%22default%22%0Aresource.labels.pod_name:%22{run_name}%22%0A;cursorTimestamp=2025-08-22T21:00:46.979388423Z;duration=PT1H?project={project}",
+          "goodput_monitor": "...",
+          "disruption_dashboard": "...",
+      }
+
+      summary_data = {
+          "avg_step_time": f"{avg_step_time:.2f}s",
+          "min_step_time": f"{min_time:.2f}s on step {min_step}",
+          "max_step_time": f"{max_time:.2f}s on step {max_step}",
+          "goodput": goodput_val,
+          "MFU": mfu_val,
+          "badput_breakdown": badput_breakdown,
+          "links": links,
+          "generated_at": datetime.utcnow().isoformat() + "Z",
+      }
+
+      rich_metrics_filename = f"rich_metrics_step_{start_step:06}_to_step_{step:06}.txt"
+
+      with open(rich_metrics_filename, "wt", encoding="utf8") as rich_metrics_for_gcs:
+        rich_metrics_for_gcs.write(json.dumps(summary_data, indent=2))
+
+      gcs_filename = os.path.join(self.config.metrics_dir, rich_metrics_filename)
+      max_logging.log(f"Moving file {rich_metrics_filename} to GCS...")
+      gcs_utils.upload_blob(gcs_filename, rich_metrics_filename)
+      max_logging.log(f"File {rich_metrics_filename} moved successfully!")
+
+      self.running_gcs_rich_metrics = []  # reset running_gcs_rich_metrics to empty list
+
   def write_metrics_to_tensorboard(self, metrics, step, is_training):
     """Writes metrics to TensorBoard."""
     if jax.process_index() == 0:
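
For reference, a sketch of what one rich-metrics summary file uploaded by write_rich_metrics_for_gcs per log_period window would contain. Only the keys come from summary_data above; every value below is illustrative, not taken from a real run:

    {
      "avg_step_time": "2.41s",
      "min_step_time": "2.10s on step 112",
      "max_step_time": "3.05s on step 100",
      "goodput": "96.50%",
      "MFU": "312.45 TFLOPs/s",
      "badput_breakdown": {
        "checkpoint_loading": "1.20%",
        "data_loading": "0.80%",
        "failure": "N/A",
        "reshard": "N/A"
      },
      "links": {
        "cloud_logging": "...",
        "goodput_monitor": "...",
        "disruption_dashboard": "..."
      },
      "generated_at": "2025-08-22T21:00:46Z"
    }

Note that, as written, the "MFU" field carries the average per-device TFLOP/s over the window rather than a utilization percentage, and the goodput_monitor and disruption_dashboard links are still placeholders.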