Skip to content

Commit 7df8f7d

Browse files
committed
Propagate dtype to final benchmark result
1 parent 8d1bf96 commit 7df8f7d

File tree

3 files changed: 11 additions, 5 deletions

Ironwood/src/benchmark_gemm.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -155,6 +155,7 @@ def gemm_multiple_run_calculate_metrics(
155155
total_flops,
156156
total_flops_all_devices,
157157
peak_flops,
158+
dtype=dtype.dtype.name,
158159
)
159160

160161
def gemm_simple(

Ironwood/src/benchmark_inference_compute.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -124,7 +124,7 @@ def add_calculate_metrics(
124124
total_bytes, SHARDING_STRATEGY
125125
)
126126
return unified_bytes_metrics(
127-
m, n, time_ms_list, total_bytes, total_bytes_all_devices
127+
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
128128
)
129129

130130

@@ -191,7 +191,7 @@ def rmsnorm_calculate_metrics(
191191
total_bytes, SHARDING_STRATEGY
192192
)
193193
return unified_bytes_metrics(
194-
m, n, time_ms_list, total_bytes, total_bytes_all_devices
194+
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
195195
)
196196

197197

@@ -264,7 +264,7 @@ def silu_mul_calculate_metrics(
264264
total_bytes, SHARDING_STRATEGY
265265
)
266266
return unified_bytes_metrics(
267-
m, n, time_ms_list, total_bytes, total_bytes_all_devices
267+
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
268268
)
269269

270270

@@ -325,7 +325,7 @@ def sigmoid_calculate_metrics(
325325
total_bytes, SHARDING_STRATEGY
326326
)
327327
return unified_bytes_metrics(
328-
m, n, time_ms_list, total_bytes, total_bytes_all_devices
328+
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
329329
)
330330

331331

@@ -413,4 +413,4 @@ def sigmoid_calculate_metrics(
413413
# scale = 2 if dtype == jnp.bfloat16 else 1
414414
# total_bytes = scale * 3 * m
415415
# total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY)
416-
# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices)
416+
# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name)

Ironwood/src/benchmark_utils.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1110,6 +1110,7 @@ def unified_flops_metrics(
11101110
total_flops: int,
11111111
total_flops_all_devices: int,
11121112
peak_TFLOPS_per_device: float,
1113+
dtype: str = None,
11131114
) -> Dict[str, Any]:
11141115
"""Calculates the metrics for the naive matmul benchmark."""
11151116
# Build dictionary of all the parameters in the function
@@ -1178,6 +1179,7 @@ def unified_bytes_metrics(
11781179
total_bytes: int,
11791180
total_bytes_all_devices: int = 1e9,
11801181
quant_dtype: str = None,
1182+
dtype: str = None,
11811183
) -> Dict[str, Any]:
11821184
"""Calculates the metrics for the naive matmul benchmark."""
11831185
# Build dictionary of all the parameters in the function
@@ -1212,6 +1214,9 @@ def unified_bytes_metrics(
12121214
if quant_dtype is not None:
12131215
metadata.update({"quant_dtype": quant_dtype})
12141216
metrics.update({"quant_dtype": quant_dtype})
1217+
if dtype is not None:
1218+
metadata.update({"dtype": dtype})
1219+
metrics.update({"dtype": dtype})
12151220
metadata.update(
12161221
{
12171222
"StepTime(median,ms)": average_time_ms_statistics.statistics["p50"],

0 commit comments

Comments
 (0)