diff --git a/Ironwood/src/benchmark_gemm.py b/Ironwood/src/benchmark_gemm.py index c8c27bbe..0534b347 100644 --- a/Ironwood/src/benchmark_gemm.py +++ b/Ironwood/src/benchmark_gemm.py @@ -155,6 +155,7 @@ def gemm_multiple_run_calculate_metrics( total_flops, total_flops_all_devices, peak_flops, + dtype=dtype.dtype.name, ) def gemm_simple( diff --git a/Ironwood/src/benchmark_inference_compute.py b/Ironwood/src/benchmark_inference_compute.py index 5059de12..8bfa6b0e 100644 --- a/Ironwood/src/benchmark_inference_compute.py +++ b/Ironwood/src/benchmark_inference_compute.py @@ -124,7 +124,7 @@ def add_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices + m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name ) @@ -191,7 +191,7 @@ def rmsnorm_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices + m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name ) @@ -264,7 +264,7 @@ def silu_mul_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices + m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name ) @@ -325,7 +325,7 @@ def sigmoid_calculate_metrics( total_bytes, SHARDING_STRATEGY ) return unified_bytes_metrics( - m, n, time_ms_list, total_bytes, total_bytes_all_devices + m, n, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name ) @@ -413,4 +413,4 @@ def sigmoid_calculate_metrics( # scale = 2 if dtype == jnp.bfloat16 else 1 # total_bytes = scale * 3 * m # total_bytes, total_bytes_all_devices = handle_based_on_sharding(total_bytes, SHARDING_STRATEGY) -# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices) +# return unified_bytes_metrics(m, 0, time_ms_list, total_bytes, total_bytes_all_devices, dtype=dtype.dtype.name) diff --git a/Ironwood/src/benchmark_utils.py b/Ironwood/src/benchmark_utils.py index d0e72b9e..baa33b5f 100644 --- a/Ironwood/src/benchmark_utils.py +++ b/Ironwood/src/benchmark_utils.py @@ -1110,6 +1110,7 @@ def unified_flops_metrics( total_flops: int, total_flops_all_devices: int, peak_TFLOPS_per_device: float, + dtype: str = None, ) -> Dict[str, Any]: """Calculates the metrics for the naive matmul benchmark.""" # Build dictionary of all the parameters in the function @@ -1178,6 +1179,7 @@ def unified_bytes_metrics( total_bytes: int, total_bytes_all_devices: int = 1e9, quant_dtype: str = None, + dtype: str = None, ) -> Dict[str, Any]: """Calculates the metrics for the naive matmul benchmark.""" # Build dictionary of all the parameters in the function @@ -1212,6 +1214,9 @@ def unified_bytes_metrics( if quant_dtype is not None: metadata.update({"quant_dtype": quant_dtype}) metrics.update({"quant_dtype": quant_dtype}) + if dtype is not None: + metadata.update({"dtype": dtype}) + metrics.update({"dtype": dtype}) metadata.update( { "StepTime(median,ms)": average_time_ms_statistics.statistics["p50"],