Skip to content

Commit 5564ea5

Browse files
authored
Print dtype in metrics stage (#89)
As requested in [another PR](#84 (comment)) for easier result inspection.
1 parent c181492 commit 5564ea5

File tree

1 file changed

+12
-6
lines changed

1 file changed

+12
-6
lines changed

Ironwood/src/benchmark_utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,7 +1140,9 @@ def unified_flops_metrics(
11401140
metrics_list=tflops_per_sec_all_devices, metrics_name="tflops_per_sec"
11411141
)
11421142
mfu_statistics = MetricsStatistics(metrics_list=mfu, metrics_name="MFU")
1143+
dtype_prefix = f"[{dtype}] " if dtype is not None else ""
11431144
print(
1145+
f"{dtype_prefix}"
11441146
f"Total floating-point ops: {total_flops}, Step Time (median): {average_time_ms_statistics.statistics['p50']:.2f}, "
11451147
f"Throughput (median): {tflops_per_sec_statistics.statistics['p50']:.2f} TFLOP / second / device, "
11461148
f"TotalThroughput (median): {tflops_per_sec_all_devices_statistics.statistics['p50']:.2f} TFLOP / second, "
@@ -1204,19 +1206,23 @@ def unified_bytes_metrics(
12041206
gigabytes_per_sec_all_devices_statistics = MetricsStatistics(
12051207
metrics_list=digabytes_per_sec_all_devices, metrics_name="Gbytes_per_sec"
12061208
)
1207-
print(
1208-
f"Total bytes: {total_bytes}, Step Time (median): {average_time_ms_statistics.statistics['p50']:.2f}, Throughput (median):"
1209-
f" {gigabytes_per_sec_statistics.statistics['p50']:.2f} GBytes / second / device,"
1210-
f" TotalThroughput (median): {gigabytes_per_sec_all_devices_statistics.statistics['p50']:.2f} GBytes / second"
1211-
)
1212-
print()
1209+
type_prefix = ""
12131210
# Gather the metrics to report.
12141211
if quant_dtype is not None:
12151212
metadata.update({"quant_dtype": quant_dtype})
12161213
metrics.update({"quant_dtype": quant_dtype})
1214+
type_prefix = f"[q={quant_dtype}] "
12171215
if dtype is not None:
12181216
metadata.update({"dtype": dtype})
12191217
metrics.update({"dtype": dtype})
1218+
type_prefix = f"[d={dtype}] "
1219+
print(
1220+
f"{type_prefix}"
1221+
f"Total bytes: {total_bytes}, Step Time (median): {average_time_ms_statistics.statistics['p50']:.2f}, Throughput (median):"
1222+
f" {gigabytes_per_sec_statistics.statistics['p50']:.2f} GBytes / second / device,"
1223+
f" TotalThroughput (median): {gigabytes_per_sec_all_devices_statistics.statistics['p50']:.2f} GBytes / second"
1224+
)
1225+
print()
12201226
metadata.update(
12211227
{
12221228
"StepTime(median,ms)": average_time_ms_statistics.statistics["p50"],

0 commit comments

Comments
 (0)