Commit 2eae09b

[moe training] add memory bandwidth calculations to kernel benchmarking scripts (#2769)
1 parent: aed4f84

File tree: 6 files changed (+75, -25 lines)


benchmarks/prototype/moe_training/benchmark_moe_fsdp.py

Lines changed: 3 additions & 3 deletions
@@ -24,7 +24,7 @@
 
 from benchmarks.prototype.moe_training.utils import (
     bench_fwd_bwd_microseconds,
-    profile_fn,
+    profile_fwd_bwd,
 )
 
 # this feature requires CUDA and SM89+
@@ -128,7 +128,7 @@ def warmup(model, input):
     print(f"BF16 time: {bf16_us} us")
     if enable_profile:
         print("Profiling bf16 training")
-        profile_fn(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
+        profile_fwd_bwd(ref_model, ref_x, labels=labels, profile_name="bf16_profile")
 
     scaled_us = bench_fwd_bwd_microseconds(
         model,
@@ -140,7 +140,7 @@ def warmup(model, input):
     print(f"Scaled time: {scaled_us} us")
     if enable_profile:
         print("Profiling quantized training")
-        profile_fn(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
+        profile_fwd_bwd(model, x, labels=labels, profile_name=f"{recipe_name}_profile")
 
     print(f"Speedup: {bf16_us / scaled_us:.3f}x")
     dist.destroy_process_group()

benchmarks/prototype/moe_training/benchmark_per_group_scaling_kernels.py

Lines changed: 22 additions & 9 deletions
@@ -19,6 +19,7 @@
     triton_fp8_per_group_rowwise_scales,
 )
 from torchao.prototype.moe_training.utils import (
+    generate_jagged_offs,
     torch_to_float8_per_group_colwise,
     torch_to_float8_per_group_rowwise,
 )
@@ -40,6 +41,8 @@ class ExperimentConfig:
 class ExperimentResult:
     torch_time_us: float
     triton_time_us: float
+    torch_mem_bw_gbps: float
+    triton_mem_bw_gbps: float
 
 
 @dataclass(frozen=True)
@@ -50,7 +53,7 @@ class Experiment:
 
 def get_configs() -> List[ExperimentConfig]:
     input_shapes = [(16640, 5120)]  # (Mg, K)
-    n_groups_list = [16, 128]
+    n_groups_list = [1, 16, 128]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
     for input_shape, n_groups, high_precision_dtype in itertools.product(
@@ -81,15 +84,9 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     # that occurs in the backward pass of the differentiable scaled grouped mm.
     # - the transposed tensor in col-major format with groups along the row dimension,
     # which represents the right operand.
-    group_size = input_row_major.shape[1] // config.n_groups
     n_groups = config.n_groups
-    offs = torch.arange(
-        group_size,
-        group_size * n_groups + 1,
-        group_size,
-        device=device,
-        dtype=torch.int32,
-    )
+    Mg = input_row_major.shape[0]
+    offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)
 
     def warmup(func, *args, **kwargs):
         for _ in range(10):
@@ -140,9 +137,21 @@ def run_triton(
         run_triton, input_row_major, input_col_major, offs
     )
 
+    # mem bw calculations - excluding scales to simplify calculation
+    # but still get an accurate estimate.
+    bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
+    num_elements = input_tensor.numel()
+    read_bytes = num_elements * bytes_per_input_el
+    write_bytes = num_elements  # 1 byte per element in float8_e4m3fn
+    read_write_bytes = read_bytes + write_bytes
+    torch_mem_bw_gbps = (read_write_bytes) / (torch_time_us / 1e6) / 1e9
+    triton_mem_bw_gbps = (read_write_bytes) / (triton_time_us / 1e6) / 1e9
+
     return ExperimentResult(
         torch_time_us=torch_time_us,
         triton_time_us=triton_time_us,
+        torch_mem_bw_gbps=torch_mem_bw_gbps,
+        triton_mem_bw_gbps=triton_mem_bw_gbps,
     )
 
 
@@ -153,6 +162,8 @@ def print_results(experiments: List[Experiment]):
         "high_precision_dtype",
         "torch_time_us",
         "triton_time_us",
+        "torch_mem_bw_gbps",
+        "triton_mem_bw_gbps",
         "triton_speedup",
     ]
     rows = []
@@ -167,6 +178,8 @@ def print_results(experiments: List[Experiment]):
                 experiment.config.high_precision_dtype,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
+                round(experiment.result.torch_mem_bw_gbps, 3),
+                round(experiment.result.triton_mem_bw_gbps, 3),
                 f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
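
For reference, the new bandwidth columns are just total bytes moved divided by kernel runtime. Below is a minimal standalone sketch of that arithmetic for the benchmarked (Mg, K) = (16640, 5120) bf16 input; the 100 us timing is a hypothetical placeholder, not a measured result.

import torch

# Illustrative check of the bandwidth formula added above, assuming a bf16 input
# that is read once and written once as float8_e4m3fn (scales excluded, as in the diff).
Mg, K = 16640, 5120
num_elements = Mg * K
bytes_per_input_el = torch.finfo(torch.bfloat16).bits / 8  # 2 bytes per bf16 element
read_bytes = num_elements * bytes_per_input_el
write_bytes = num_elements  # 1 byte per element in float8_e4m3fn
read_write_bytes = read_bytes + write_bytes  # ~255.6 MB of traffic

example_time_us = 100.0  # hypothetical runtime, not a measured number
mem_bw_gbps = read_write_bytes / (example_time_us / 1e6) / 1e9
print(f"{mem_bw_gbps:.1f} GB/s")  # ~2555.9 GB/s for this placeholder timing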

benchmarks/prototype/moe_training/benchmark_rowwise_3d_quant_kernels.py

Lines changed: 26 additions & 0 deletions
@@ -37,6 +37,8 @@ class ExperimentConfig:
 class ExperimentResult:
     torch_time_us: float
     triton_time_us: float
+    torch_mem_bw_gbps: float
+    triton_mem_bw_gbps: float
 
 
 @dataclass(frozen=True)
@@ -48,8 +50,12 @@ class Experiment:
 def get_configs() -> List[ExperimentConfig]:
     # Llama4 shapes
     input_shapes = [
+        (1, 8192, 5120),  # w1, w3
+        (1, 5120, 8192),  # w2
         (16, 8192, 5120),  # w1, w3
         (16, 5120, 8192),  # w2
+        (128, 8192, 5120),  # w1, w3
+        (128, 5120, 8192),  # w2
     ]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
@@ -110,9 +116,25 @@ def run_triton(input_tensor: torch.Tensor):
         input_tensor,
     )
 
+    # mem bw calculations - excluding scales to simplify calculation
+    # but still get an accurate estimate.
+    bytes_per_input_el = torch.finfo(config.high_precision_dtype).bits / 8
+    bytes_per_output_el = torch.finfo(torch.float8_e4m3fn).bits / 8
+    num_elements = input_tensor.numel()
+
+    read_bytes = num_elements * bytes_per_input_el
+    write_bytes = num_elements * bytes_per_output_el
+
+    # Both torch.compile codegen and the triton kernel read the input tensor twice
+    # (once for scale calculations, once for scaling + casting).
+    torch_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (torch_time_us / 1e6)
+    triton_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (triton_time_us / 1e6)
+
     return ExperimentResult(
         torch_time_us=torch_time_us,
         triton_time_us=triton_time_us,
+        torch_mem_bw_gbps=torch_mem_bw_gbps,
+        triton_mem_bw_gbps=triton_mem_bw_gbps,
     )
 
 
@@ -121,6 +143,8 @@ def print_results(experiments: List[Experiment]):
         "input_shape",
         "torch_time_us",
         "triton_time_us",
+        "torch_mem_bw_gbps",
+        "triton_mem_bw_gbps",
         "triton_speedup",
     ]
     rows = []
@@ -131,6 +155,8 @@ def print_results(experiments: List[Experiment]):
                 input_shape,
                 experiment.result.torch_time_us,
                 experiment.result.triton_time_us,
+                round(experiment.result.torch_mem_bw_gbps, 3),
+                round(experiment.result.triton_mem_bw_gbps, 3),
                 f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
             ]
         )
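
The only difference from the per-group scaling model above is the 2x read factor: both the torch.compile and Triton implementations make one pass over the input to compute the amax/scales and a second pass to scale and cast. A sketch of that model applied to the (16, 8192, 5120) Llama4 shape is shown below; the 2000 us timing is a hypothetical placeholder, not a measured result.

import torch

# Illustrative application of the 2x-read + 1x-write traffic model added above
# to the (E, N, K) = (16, 8192, 5120) bf16 shape (scales excluded, as in the diff).
E, N, K = 16, 8192, 5120
num_elements = E * N * K
read_bytes = num_elements * torch.finfo(torch.bfloat16).bits / 8        # 2 bytes/element
write_bytes = num_elements * torch.finfo(torch.float8_e4m3fn).bits / 8  # 1 byte/element

total_bytes = 2 * read_bytes + write_bytes  # ~3.36 GB of traffic
example_time_us = 2000.0  # hypothetical runtime, not a measured number
mem_bw_gbps = (total_bytes / 1e9) / (example_time_us / 1e6)
print(f"{mem_bw_gbps:.1f} GB/s")  # ~1677.7 GB/s for this placeholder timing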

benchmarks/prototype/moe_training/benchmark_scaled_grouped_mm_dq.py

Lines changed: 7 additions & 5 deletions
@@ -12,7 +12,7 @@
 import torch
 from tabulate import tabulate
 from tqdm import tqdm
-from utils import bench_fwd_bwd_microseconds, profile_fn
+from utils import bench_fwd_bwd_microseconds, profile_fwd_bwd
 
 from torchao.prototype.moe_training import _scaled_grouped_mm
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
@@ -46,8 +46,9 @@ class Experiment:
 
 
 def get_configs() -> List[ExperimentConfig]:
+    # Llama4 shapes
     A_shapes = [(16640, 5120)]
-    B_shapes = [(16, 8192, 5120)]
+    B_shapes = [(1, 8192, 5120), (16, 8192, 5120), (128, 8192, 5120)]
     recipes = [MoEScalingType.FP8_ROWWISE]
     high_precision_dtypes = [torch.bfloat16]
     configs = []
@@ -91,7 +92,8 @@ def run_experiment(
     # - the transposed tensor in col-major format with groups along the row dimension,
     # which represents the right operand.
     n_groups = config.B_shape[0]
-    offs = generate_jagged_offs(n_groups, A.shape[0], multiple_of=16)
+    Mg = A.shape[0]
+    offs = generate_jagged_offs(n_groups, Mg, multiple_of=16)
 
     labels = torch.ones(
         (A.shape[0], B_t.shape[-1]), device=device, dtype=torch.bfloat16
@@ -107,7 +109,7 @@ def run_experiment(
         use_compile=args.compile,
     )
     if args.profile:
-        profile_fn(
+        profile_fwd_bwd(
            torch._grouped_mm,
            A,
            B_t,
@@ -128,7 +130,7 @@ def run_experiment(
         use_compile=args.compile,
    )
    if args.profile:
-        profile_fn(
+        profile_fwd_bwd(
            _scaled_grouped_mm,
            A,
            B_t,
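
Both grouped-GEMM benchmarks now build group offsets with generate_jagged_offs(n_groups, Mg, multiple_of=16) instead of evenly sized groups, which better matches the uneven expert token routing seen in real MoE workloads. The sketch below is a hypothetical toy re-implementation (not the torchao code) of the contract these call sites appear to rely on: cumulative group-end offsets ending at Mg, with each group size a multiple of 16.

import torch

# Hypothetical toy version, for illustration only (not torchao's implementation):
# returns int32 cumulative group-end offsets whose last entry equals Mg, with every
# group size a positive multiple of `multiple_of`.
def toy_jagged_offs(n_groups: int, Mg: int, multiple_of: int = 16) -> torch.Tensor:
    n_blocks = Mg // multiple_of
    # randomly choose n_groups - 1 interior cut points, in units of `multiple_of` blocks
    cuts = torch.sort(torch.randperm(n_blocks - 1)[: n_groups - 1] + 1).values
    block_offs = torch.cat([cuts, torch.tensor([n_blocks])])
    return (block_offs * multiple_of).to(torch.int32)

offs = toy_jagged_offs(n_groups=16, Mg=16640)
print(offs[-1].item())  # 16640: offsets end at Mg; group sizes are jagged multiples of 16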

benchmarks/prototype/moe_training/utils.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def bench_fwd_bwd_microseconds(
     return statistics.median(times)
 
 
-def profile_fn(
+def profile_fwd_bwd(
     fn,
     *args,
     labels=None,
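
The rename makes the helper's purpose explicit: it profiles a forward plus backward pass rather than an arbitrary function. The diff only shows the first lines of the signature, so the sketch below is a rough illustration of what such a helper could look like, not torchao's actual implementation; the MSE loss against labels and the trace file naming are assumptions.

import torch
from torch.profiler import ProfilerActivity, profile

# Rough sketch only; assumes a CUDA device and that `fn` returns a tensor that can
# be compared against `labels`. Not the real profile_fwd_bwd from torchao.
def profile_fwd_bwd_sketch(fn, *args, labels=None, profile_name="profile", **kwargs):
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        out = fn(*args, **kwargs)
        if labels is not None:
            # assumed loss choice, purely to exercise the backward pass
            torch.nn.functional.mse_loss(out, labels).backward()
        torch.cuda.synchronize()
    prof.export_chrome_trace(f"{profile_name}.json")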

torchao/prototype/moe_training/kernels/float8_rowwise.py

Lines changed: 16 additions & 7 deletions
@@ -26,10 +26,10 @@
     torch.float64: tl.float64,
 }
 
-block_sizes_n = [32, 128, 512]  # large dim (output_features)
-block_sizes_k = [32, 128, 512]  # small dim (input_features)
-num_warps = [8]
-num_stages = [2, 4]
+block_sizes_n = [32, 128, 256]  # large dim (output_features)
+block_sizes_k = [32, 128, 256]  # small dim (input_features)
+num_warps = [2, 4]
+num_stages = [2, 3, 4, 5, 6]
 kernel_configs_2D = [
     triton.Config(
         {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
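
This hunk narrows the block sizes to a maximum of 256, lowers the warp counts, and widens the pipeline-stage range. The exact expansion of kernel_configs_2D is truncated here; assuming the usual cartesian product over the four lists, the autotuning search space grows from 3 * 3 * 1 * 2 = 18 to 3 * 3 * 2 * 5 = 90 candidate configs, roughly as sketched below.

import itertools

import triton

block_sizes_n = [32, 128, 256]  # large dim (output_features)
block_sizes_k = [32, 128, 256]  # small dim (input_features)
num_warps = [2, 4]
num_stages = [2, 3, 4, 5, 6]

# Assumed reconstruction of the config grid (the diff cuts off the comprehension):
# one triton.Config per (block_size_n, block_size_k, warps, stages) combination.
kernel_configs_2D = [
    triton.Config(
        {"BLOCK_SIZE_N": block_size_n, "BLOCK_SIZE_K": block_size_k},
        num_warps=warps,
        num_stages=stages,
    )
    for block_size_n, block_size_k, warps, stages in itertools.product(
        block_sizes_n, block_sizes_k, num_warps, num_stages
    )
]
assert len(kernel_configs_2D) == 90
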
@@ -176,9 +176,18 @@ def _triton_fp8_rowwise_3d_transpose_scales_rhs_kernel(
         input_dtype
     )
 
-    # compute scales with local amax, using axis=0 because for each expert,
-    # we are reading the non-transposed input, and want to compute the scales
-    # along axis=1 for the transposed input.
+    # In a normal torch implementation, we should transpose the tensor then compute the amax
+    # along the dim1 (N), to compute colwise scales for a RHS operand of a scaled grouped gemm:
+    # input_data = input_data.transpose(-2,-1) # (E, K, N) -> (E, N, K)
+    # amaxes = input_data.abs().max(dim=1) # (E, N, K) -> (E, 1, K)
+    #
+    # Here, we are reading a (K, N) chunk for a given E, and computing the amax along the dim=1 (N)
+    # to compute an equivalent scale of shape (K,) for this chunk of the expert.
+    # We then use atomic min to compute the final scale for these logical columns of the transposed tensor.
+    #
+    # Later, we will use this scale to cast the same (K,N) input chunk to fp8 and transpose it to (N, K) before
+    # writing it to the output tensor.
+    # ((K, N) * (K, 1))^T = (N, K)
     amaxes = tl.max(tl.abs(input_data), axis=1).to(tl.float64)  # (K,)
     scales = (fp8_dtype_max / tl.clamp(amaxes, min=EPS, max=float("inf"))).to(
         tl.float32
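
The rewritten comment describes the eager-mode computation the kernel is equivalent to. For readers who want it spelled out, here is an illustrative PyTorch restatement of that comment (not code from the repo; the EPS value is assumed): colwise scales for the RHS operand of a scaled grouped GEMM.

import torch

# Illustrative eager-mode restatement of the kernel comment above (not repo code).
EPS = 1e-12  # assumed epsilon; the kernel's actual EPS constant is defined elsewhere
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def rhs_colwise_scales_reference(x: torch.Tensor) -> torch.Tensor:
    # x: (E, K, N) high-precision expert weights
    x_t = x.transpose(-2, -1)                                       # (E, K, N) -> (E, N, K)
    amaxes = x_t.abs().amax(dim=1, keepdim=True).to(torch.float64)  # (E, 1, K)
    scales = FP8_MAX / amaxes.clamp(min=EPS)                        # (E, 1, K)
    return scales.to(torch.float32)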
