Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions Ironwood/configs/hbm/hbm_multiple_devices.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Multi-device HBM copy-bandwidth sweep: element count doubles from
# 2^20 (1 Mi) to 2^32 (4 Gi) elements of bfloat16, one timed run each.
benchmarks:
  - benchmark_name: "multiple_device_hbm_copy"
    benchmark_sweep_params:
      # num_elements_range expands to a geometric sweep: start, start*2, ... up to end.
      - {num_elements_range: {start: 1048576, end: 4294967296, multiplier: 2}, dtype: "bfloat16", num_runs: 1}
    # Output locations (relative to the runner's working directory).
    trace_dir: "../microbenchmarks/hbm"
    csv_path: "../microbenchmarks/hbm"
    xlml_metrics_dir: "../microbenchmarks/hbm"
    xla_dump_dir: "../microbenchmarks/hbm/hlo_graphs"
15 changes: 15 additions & 0 deletions Ironwood/configs/training/gemm_multiple_devices.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Multi-device GEMM benchmark: the same 16384 x 18432 x 16384 problem run
# twice — once in bfloat16 and once in float8 — 100 timed iterations each.
benchmarks:
  - benchmark_name: "gemm_multiple_devices"
    # BF16 variant outputs.
    trace_dir: "../microbenchmarks/gemm_multiple_run_bf16"
    csv_path: "../microbenchmarks/gemm_multiple_run_bf16"
    xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_bf16"
    xla_dump_dir: "../microbenchmarks/gemm_multiple_run_bf16/hlo_graphs"
    benchmark_sweep_params:
      - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'bfloat16'}
  - benchmark_name: "gemm_multiple_devices"
    # FP8 variant outputs.
    trace_dir: "../microbenchmarks/gemm_multiple_run_fp8"
    csv_path: "../microbenchmarks/gemm_multiple_run_fp8"
    xlml_metrics_dir: "../microbenchmarks/gemm_multiple_run_fp8"
    xla_dump_dir: "../microbenchmarks/gemm_multiple_run_fp8/hlo_graphs"
    benchmark_sweep_params:
      - {m: 16384, k: 18432, n: 16384, num_runs: 100, dtype: 'float8'}
97 changes: 97 additions & 0 deletions Ironwood/src/benchmark_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,3 +540,100 @@ def gemm_accum_calculate_metrics(
total_flops_all_devices,
PEAK_FLOPS_PER_DEVICE,
)

def gemm_multiple_devices(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype = jax.numpy.float8_e4m3fn,
    num_runs: int = 1,
    trace_dir: str | None = None,
) -> Dict[str, Any]:
    """Benchmarks OUT<M, N>:BF16 = IN0<M, K>:dtype x IN1<K, N>:dtype on all devices.

    Accumulation is FP32; the result is cast back to BF16.
    Currently supported dtypes: float8_e4m3fn, bfloat16.

    Args:
        m: Rows of the LHS operand (sharded across devices along M).
        k: Contraction dimension.
        n: Columns of the RHS operand.
        dtype: Input dtype for both operands.
        num_runs: Number of timed iterations.
        trace_dir: Optional directory to write profiler traces to.

    Returns:
        Dict with key "time_ms_list": per-iteration wall times in milliseconds.
    """

    def f(x, y):
        # MARKER names the scope so the op can be located in the trace.
        with jax.named_scope(MARKER):
            acc = jax.numpy.einsum(
                "ij,jk->ik", x, y, preferred_element_type=jnp.float32
            )
        # Down-cast the FP32 accumulator to the BF16 output type.
        return acc.astype(jnp.bfloat16)

    sharding_strategy = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M

    mesh = create_mesh(sharding_strategy)
    lhs_sharding = get_lhs_named_shading(mesh, sharding_strategy)
    rhs_sharding = get_rhs_named_shading(mesh, sharding_strategy)
    out_sharding = get_out_sharding(sharding_strategy)

    jit_sharded_f = jax.jit(
        shard_map(
            f,
            mesh,
            in_specs=(lhs_sharding.spec, rhs_sharding.spec),
            out_specs=out_sharding,
            check_rep=False,
        )
    )

    lhs_shape = (m, k)
    rhs_shape = (k, n)

    key = jax.random.key(SEED)

    def data_generator():
        """Creates new random data on host and puts it on device."""
        nonlocal key  # Advance the outer PRNG key so every run sees fresh data.
        key, key_lhs, key_rhs = jax.random.split(key, 3)

        # Create random data on host.
        lhs_host = jax.random.normal(key_lhs, lhs_shape).astype(dtype)
        rhs_host = jax.random.normal(key_rhs, rhs_shape).astype(dtype)

        # Put on device (HBM) with the benchmark sharding.
        lhs_device = jax.device_put(lhs_host, lhs_sharding)
        rhs_device = jax.device_put(rhs_host, rhs_sharding)

        return (lhs_device, rhs_device)

    # Run the benchmark.
    print("Running gemm_multiple_run benchmark", num_runs)
    dtype_str = "fp8" if dtype == jax.numpy.float8_e4m3fn else "bf16"
    time_ms_list = multiple_iteration_timeit_from_trace(
        jit_sharded_f,
        data_generator,
        matrix_dim=f"{dtype_str}_{m}x{n}x{k}",
        tries=num_runs,
        task="gemm_multiple_run",
        trace_dir=trace_dir,
    )
    return {
        "time_ms_list": time_ms_list,
    }


def gemm_multiple_devices_calculate_metrics(
    m: int,
    k: int,
    n: int,
    dtype: jnp.dtype,
    time_ms_list: list[float],
) -> Dict[str, Any]:
    """Computes FLOPs-based metrics for the multi-device GEMM benchmark."""
    sharding_strategy = ShardingStrategy.SHARDING_ON_ALL_DEVICES_WITH_M

    # One multiply and one add per output element per contraction step.
    flops = 2 * m * k * n
    per_device_flops, all_device_flops = handle_based_on_sharding(
        flops, sharding_strategy
    )

    # FP8 runs at the full device peak; BF16 peaks at half that rate.
    if dtype == jax.numpy.float8_e4m3fn:
        peak_flops = PEAK_FLOPS_PER_DEVICE
    else:
        peak_flops = PEAK_FLOPS_PER_DEVICE / 2

    return unified_flops_metrics(
        m,
        n,
        k,
        time_ms_list,
        per_device_flops,
        all_device_flops,
        peak_flops,
    )
77 changes: 77 additions & 0 deletions Ironwood/src/benchmark_hbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,15 @@
from benchmark_utils import (
MetricsStatistics,
multiple_iteration_timeit_from_trace,
ShardingStrategy,
create_mesh,
)
from common import MARKER
import jax
import jax.numpy as jnp

from jax.sharding import NamedSharding
from jax.sharding import PartitionSpec as P

SEED = 0
os.environ["LIBTPU_INIT_ARGS"] = (
Expand Down Expand Up @@ -102,3 +106,76 @@ def single_device_hbm_copy_calculate_metrics(
metrics.update(statistics.serialize_statistics())
metrics = {key: value for key, value in metrics.items() if value is not None}
return metadata, metrics

# Mesh strategy used by the multi-device HBM benchmark below.
# NOTE(review): NO_SHARDING with P(None,) leaves the array unpartitioned
# (replicated) on the mesh — confirm this is the intended multi-device layout.
SHARDING_STRATEGY = ShardingStrategy.NO_SHARDING


def multiple_device_hbm_copy(
    num_elements: int,
    dtype: jnp.dtype,
    num_runs: int = 1,
    trace_dir: str | None = None,
) -> Dict[str, Any]:
    """Benchmarks HBM with a copy (read and write) across multiple devices.

    Args:
        num_elements: Number of elements in the 1-D tensor to copy.
        dtype: Element dtype of the tensor.
        num_runs: Number of timed iterations.
        trace_dir: Optional directory to write profiler traces to.

    Returns:
        Dict with key "time_ms_list": per-iteration wall times in milliseconds.
    """

    def f(a):
        # MARKER names the scope so the copy can be located in the trace.
        with jax.named_scope(MARKER):
            return a.copy()

    mesh = create_mesh(SHARDING_STRATEGY)
    # P(None,) does not partition the single axis, so the array is laid out
    # identically on the whole mesh.
    sharding = NamedSharding(mesh, P(None,))

    a = jax.random.normal(
        jax.random.key(0), (num_elements,), out_sharding=sharding
    ).astype(dtype)
    print(a.shape)
    print(a.dtype)
    jitted_f = jax.jit(f)
    # Warm-up run so compilation time is excluded from the timing below.
    output = jitted_f(a)
    jax.block_until_ready(output)

    # Run the benchmark.
    time_ms_list = multiple_iteration_timeit_from_trace(
        compute_func=jitted_f,
        data_generator=lambda: (a,),
        matrix_dim=f"{num_elements}",
        tries=num_runs,
        task="copy",
        trace_dir=trace_dir,
    )
    return {"time_ms_list": time_ms_list}

def multiple_device_hbm_copy_calculate_metrics(
    num_elements: int, dtype: jnp.dtype, time_ms_list: list
) -> Dict[str, Any]:
    """Calculates the metrics for the multiple device hbm copy benchmark."""
    # Snapshot every function parameter for the metadata record.
    params = locals().items()
    metadata = get_metrics_helper(params)
    metrics = {}

    # Bytes occupied by the tensor on device.
    tensor_size_bytes = num_elements * dtype.dtype.itemsize

    # A copy reads and writes every byte once, hence the factor of two.
    tensor_size_gbytes = (tensor_size_bytes * 2) / 10**9
    time_statistics = MetricsStatistics(
        metrics_list=time_ms_list, metrics_name="time_ms"
    )
    # Convert each run to seconds, then to achieved bandwidth.
    time_s_list = [ms / 10**3 for ms in time_ms_list]
    bw_gbyte_sec_list = [tensor_size_gbytes / secs for secs in time_s_list]
    statistics = MetricsStatistics(
        metrics_list=bw_gbyte_sec_list, metrics_name="bw_gbyte_sec"
    )
    print(
        f"Tensor size: {tensor_size_bytes / 1024**2} MB, time taken (median):"
        f" {time_statistics.statistics['p50']:.4f} ms, bandwidth (median): {statistics.statistics['p50']:.3f} GB/s"
    )
    print()
    # Gather the metrics to report.
    metadata.update({"tensor_size_gbytes": tensor_size_gbytes})
    metrics.update(time_statistics.serialize_statistics())
    metrics.update(statistics.serialize_statistics())
    # Drop entries with no recorded value.
    metrics = {name: val for name, val in metrics.items() if val is not None}
    return metadata, metrics
2 changes: 2 additions & 0 deletions Ironwood/src/run_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
}
HBM_BENCHMARK_MAP = {
"single_device_hbm_copy": "benchmark_hbm.single_device_hbm_copy",
"multiple_device_hbm_copy": "benchmark_hbm.multiple_device_hbm_copy",
}
COMPUTE_BENCHMARK_MAP = {
"gemm_simple": "benchmark_gemm.gemm_simple",
Expand All @@ -62,6 +63,7 @@
"gemm_throttling": "benchmark_gemm_throttling.gemm_throttling",
"gemm": "benchmark_gemm.gemm",
"gemm_accum": "benchmark_gemm.gemm_accum",
"gemm_multiple_devices": "benchmark_gemm.gemm_multiple_devices",
"quantization": "benchmark_compute.quantization",
"transpose_quantization": "benchmark_compute.transpose_quantization",
"quantization_static_scaling": (
Expand Down