|
34 | 34 | from MaxText.globals import EPS
|
35 | 35 |
|
36 | 36 | from collections import defaultdict
|
| 37 | +from datetime import datetime |
| 38 | +from ml_goodput_measurement import goodput |
37 | 39 |
|
38 | 40 |
|
39 | 41 | def _prepare_metrics_for_json(metrics, step, run_name):
|
@@ -135,6 +137,73 @@ def write_metrics_for_gcs(self, metrics, step, is_training):
|
135 | 137 | max_logging.log(f"File {metrics_filename} moved successfully!")
|
136 | 138 | self.running_gcs_metrics = [] # reset running_metrics to empty list
|
137 | 139 |
|
def write_rich_metrics_for_gcs(self, metrics, step, is_training):
    """Buffer this step's metrics and periodically upload a summary JSON to GCS.

    Every call converts `metrics` with `_prepare_metrics_for_json` and appends
    the result to `self.running_gcs_metrics`. When a training step closes a
    `config.log_period` window, or when `step` is the last configured step, the
    buffered entries are reduced to a summary (step-time statistics, average
    per-device TFLOPs/s, goodput and badput figures, dashboard links), uploaded
    to `config.metrics_dir`, and the buffer is cleared.

    Args:
      metrics: Raw metrics for the current step, passed through to
        `_prepare_metrics_for_json`.
      step: Index of the current step.
      is_training: Whether this step is a training step; only training steps
        can trigger the periodic (log_period) flush.

    NOTE(review): `write_metrics_for_gcs` resets the same
    `self.running_gcs_metrics` buffer, so calling both methods for the same
    steps looks like it would make them interfere — confirm with the callers.
    """
    # Function-scope import: the module only imports `datetime` itself.
    from datetime import timezone

    run_name = self.config.run_name
    self.running_gcs_metrics.append(_prepare_metrics_for_json(metrics, step, run_name))

    # Parentheses only spell out the precedence the expression already had:
    # the final step always flushes, regardless of `is_training`.
    should_flush = (is_training and (step + 1) % self.config.log_period == 0) or (
        step == self.config.steps - 1
    )
    if not should_flush:
        return

    start_step = (step // self.config.log_period) * self.config.log_period

    step_times = [
        (entry["step"], float(entry["perf/step_time_seconds"]))
        for entry in self.running_gcs_metrics
        if "perf/step_time_seconds" in entry
    ]
    tflops_sec = [
        float(entry["perf/per_device_tflops_per_sec"])
        for entry in self.running_gcs_metrics
        if "perf/per_device_tflops_per_sec" in entry
    ]

    if step_times:
        avg_step_time = sum(duration for _, duration in step_times) / len(step_times)
        min_step, min_time = min(step_times, key=lambda pair: pair[1])
        max_step, max_time = max(step_times, key=lambda pair: pair[1])
    else:
        # No timing data was buffered: report zeros and an unknown step.
        avg_step_time = min_time = max_time = 0
        min_step = max_step = None

    # The summary key is "MFU", but the value is the mean per-device TFLOPs/s.
    if tflops_sec:
        mfu_val = f"{sum(tflops_sec) / len(tflops_sec):.2f} TFLOPs/s"
    else:
        mfu_val = "N/A"

    # TODO: Goodput/badput breakdown and links.
    goodput_calculator = goodput.GoodputCalculator(
        job_name=run_name, logger_name=f"goodput_{run_name}"
    )
    current_goodput, current_badput_breakdown, _ = goodput_calculator.get_job_goodput(
        include_badput_breakdown=True
    )
    # Must not be named `goodput`: assigning to that name anywhere in this
    # function would make it local and break every `goodput.<attr>` lookup
    # with UnboundLocalError.
    goodput_pct = f"{current_goodput:.2f}%"

    def _badput_pct(badput_type):
        """Format one badput category as a percentage, or "N/A" if absent."""
        try:
            return f"{current_badput_breakdown[badput_type]:.2f}%"
        except KeyError:
            return "N/A"

    badput_breakdown = {
        "checkpoint_loading": _badput_pct(goodput.BadputType.UNPRODUCTIVE_CHECKPOINT_RESTORE_TIME),
        "data_loading": _badput_pct(goodput.BadputType.DATA_LOADING),
        "failure": "N/A",
        "reshard": "N/A",
    }

    links = {
        "cloud_logging": f"https://console.cloud.google.com/logs/query;query=run_name:{self.config.run_name}",
        "goodput_monitor": "...",
        "disruption_dashboard": "...",
    }

    # Timezone-aware replacement for the deprecated datetime.utcnow(); dropping
    # tzinfo before isoformat() keeps the original "<naive ISO timestamp>Z" text.
    generated_at = datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"

    summary_data = {
        "avg_step_time": f"{avg_step_time:.2f}s",
        "min_step_time": f"{min_time:.2f}s on step {min_step}",
        "max_step_time": f"{max_time:.2f}s on step {max_step}",
        "goodput": goodput_pct,
        "MFU": mfu_val,
        "badput_breakdown": badput_breakdown,
        "links": links,
        "generated_at": generated_at,
    }

    summary_gcs_filename = f"{self.config.metrics_dir}/summary_step_{start_step:06}_to_step_{step:06}.json"
    gcs_utils.upload_blob_from_string(summary_gcs_filename, json.dumps(summary_data, indent=2))

    self.running_gcs_metrics = []  # reset running_metrics to empty list
| 206 | + |
138 | 207 | def write_metrics_to_tensorboard(self, metrics, step, is_training):
|
139 | 208 | """Writes metrics to TensorBoard."""
|
140 | 209 | if jax.process_index() == 0:
|
|
0 commit comments