Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Ironwood/src/benchmark_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ def gemm_multiple_run_calculate_metrics(
total_flops,
total_flops_all_devices,
peak_flops,
dtype=dtype.dtype.name,
)

def gemm_simple(
Expand Down
10 changes: 5 additions & 5 deletions Ironwood/src/benchmark_inference_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def add_calculate_metrics(
total_bytes, SHARDING_STRATEGY
)
return unified_bytes_metrics(
m, n, time_ms_list, total_bytes, total_bytes_all_devices
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
)


Expand Down Expand Up @@ -191,7 +191,7 @@ def rmsnorm_calculate_metrics(
total_bytes, SHARDING_STRATEGY
)
return unified_bytes_metrics(
m, n, time_ms_list, total_bytes, total_bytes_all_devices
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
)


Expand Down Expand Up @@ -264,7 +264,7 @@ def silu_mul_calculate_metrics(
total_bytes, SHARDING_STRATEGY
)
return unified_bytes_metrics(
m, n, time_ms_list, total_bytes, total_bytes_all_devices
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
)


Expand Down Expand Up @@ -325,7 +325,7 @@ def sigmoid_calculate_metrics(
total_bytes, SHARDING_STRATEGY
)
return unified_bytes_metrics(
m, n, time_ms_list, total_bytes, total_bytes_all_devices
m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name
)


Expand Down Expand Up @@ -413,4 +413,4 @@ def sigmoid_calculate_metrics(
# scale = 2 if dtype == jnp.bfloat16 else 1
# total_bytes = scale * 3 * m
# total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY)
# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices)
# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name)
5 changes: 5 additions & 0 deletions Ironwood/src/benchmark_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,6 +1110,7 @@ def unified_flops_metrics(
total_flops: int,
total_flops_all_devices: int,
peak_TFLOPS_per_device: float,
dtype: str = None,
) -> Dict[str, Any]:
"""Calculates the metrics for the naive matmul benchmark."""
# Build dictionary of all the parameters in the function
Expand Down Expand Up @@ -1178,6 +1179,7 @@ def unified_bytes_metrics(
total_bytes: int,
total_bytes_all_devices: int = 1e9,
quant_dtype: str = None,
dtype: str = None,
) -> Dict[str, Any]:
"""Calculates the metrics for the naive matmul benchmark."""
# Build dictionary of all the parameters in the function
Expand Down Expand Up @@ -1212,6 +1214,9 @@ def unified_bytes_metrics(
if quant_dtype is not None:
metadata.update({"quant_dtype": quant_dtype})
metrics.update({"quant_dtype": quant_dtype})
if dtype is not None:
metadata.update({"dtype": dtype})
metrics.update({"dtype": dtype})
metadata.update(
{
"StepTime(median,ms)": average_time_ms_statistics.statistics["p50"],
Expand Down