Merged
90 changes: 4 additions & 86 deletions benchmarks/kernels/benchmark_moe_align_block_size.py
@@ -5,9 +5,8 @@

import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
moe_align_block_size_triton,
moe_align_block_size,
)
from vllm.triton_utils import triton

@@ -21,60 +20,6 @@ def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
)


def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
"""
Verifies vllm vs. Triton
"""
topk_ids = get_topk_ids(num_tokens, num_experts, topk)

# 1. malloc space for triton and vllm
# malloc enough space (max_num_tokens_padded) for the sorted ids
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids_triton = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
expert_ids_triton = torch.empty(
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")

sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
expert_ids_vllm = torch.empty_like(expert_ids_triton)
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)

# 2. run implementations
moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids_triton,
expert_ids_triton,
num_tokens_post_pad_triton,
)

ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids_vllm,
expert_ids_vllm,
num_tokens_post_pad_vllm,
)
print(f"✅ VLLM implementation works with {num_experts} experts!")

# 3. compare results
if torch.allclose(expert_ids_triton, expert_ids_vllm) and torch.allclose(
num_tokens_post_pad_triton, num_tokens_post_pad_vllm
):
print("✅ Triton and VLLM implementations match.")
else:
print("❌ Triton and VLLM implementations DO NOT match.")
print("Triton expert_ids:", expert_ids_triton)
print("VLLM expert_ids:", expert_ids_vllm)
print("Triton num_tokens_post_pad:", num_tokens_post_pad_triton)
print("VLLM num_tokens_post_pad:", num_tokens_post_pad_vllm)


# test configurations
num_tokens_range = [1, 16, 256, 4096]
num_experts_range = [16, 64, 224, 256, 280, 512]
@@ -87,8 +32,8 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
x_names=["num_tokens", "num_experts", "topk"],
x_vals=configs,
line_arg="provider",
line_vals=["vllm", "triton"], # "triton"
line_names=["VLLM", "Triton"], # "Triton"
line_vals=["vllm"],
line_names=["vLLM"],
Comment on lines -90 to +36
Member:

Should we move the kernel into the benchmark script? It doesn't seem that useful to have a benchmark that covers just one kernel.

Member Author:

I'm leaving it here because I think someone else may reuse this script for comparisons, e.g. new vLLM vs. old vLLM.

Member Author:

I don't think moving the kernel into the benchmark helps, because we already beat the Triton kernel at every shape, so we don't need to maintain it. In the future, when a new kernel comes along, comparing it against the current version should be enough. What do you think?

Member Author:

CC @mgoin

Member:

Okay, that is fair enough. As long as we have a naive implementation to unit test against, it should be okay

Member Author:

Yes, someone else developed a torch version for unit testing, but it's quite slow, so I chose not to use it in the benchmark comparison.
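
For illustration, a minimal pure-torch sketch of such a reference (hypothetical; not the version mentioned above). It assumes each expert's token slots are padded up to a multiple of block_size and that padded entries of sorted_token_ids are filled with topk_ids.numel(), matching the sentinel the removed Triton implementation used:

# Hypothetical pure-torch reference (slow, but straightforward to verify).
import torch


def moe_align_block_size_torch_ref(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    flat = topk_ids.flatten()
    # Padded slots are filled with topk_ids.numel(), the same sentinel the
    # removed Triton kernel used via sorted_token_ids.fill_(numel).
    pad_value = flat.numel()

    sorted_chunks, expert_chunks = [], []
    for expert in range(num_experts):
        token_idx = (flat == expert).nonzero(as_tuple=False).flatten()
        num_blocks = (token_idx.numel() + block_size - 1) // block_size
        padded_len = num_blocks * block_size
        pad = torch.full(
            (padded_len - token_idx.numel(),),
            pad_value,
            dtype=torch.int64,
            device=flat.device,
        )
        sorted_chunks.append(torch.cat([token_idx, pad]))
        expert_chunks.append(
            torch.full((num_blocks,), expert, dtype=torch.int32, device=flat.device)
        )

    sorted_token_ids = torch.cat(sorted_chunks).to(torch.int32)
    expert_ids = torch.cat(expert_chunks)
    num_tokens_post_pad = torch.tensor(
        [sorted_token_ids.numel()], dtype=torch.int32, device=flat.device
    )
    return sorted_token_ids, expert_ids, num_tokens_post_pad

A test built on this would compare only the first num_tokens_post_pad entries of each output, since the production kernel may write into an over-allocated buffer of size max_num_tokens_padded.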

plot_name="moe-align-block-size-performance",
args={},
)
@@ -98,36 +43,11 @@ def benchmark(num_tokens, num_experts, topk, provider):
block_size = 256
topk_ids = get_topk_ids(num_tokens, num_experts, topk)

max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
max_num_m_blocks = max_num_tokens_padded // block_size
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")

quantiles = [0.5, 0.2, 0.8]

if provider == "vllm":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: ops.moe_align_block_size(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
quantiles=quantiles,
)
elif provider == "triton":
ms, min_ms, max_ms = triton.testing.do_bench(
lambda: moe_align_block_size_triton(
topk_ids,
num_experts,
block_size,
sorted_ids.clone(),
expert_ids.clone(),
num_tokens_post_pad.clone(),
),
lambda: moe_align_block_size(topk_ids, block_size, num_experts),
quantiles=quantiles,
)

@@ -151,6 +71,4 @@ def benchmark(num_tokens, num_experts, topk, provider):
)
args = parser.parse_args()

print("Running correctness check...")
check_correctness(num_tokens=1024, num_experts=args.num_experts, topk=args.topk)
benchmark.run(print_data=True, show_plots=True)
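
For reference, a minimal standalone usage sketch of the updated API. It assumes, based on the buffers the removed ops-based call filled in place, that moe_align_block_size returns a (sorted_token_ids, expert_ids, num_tokens_post_pad) triple; the diff above only shows the call itself:

import torch

from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
    moe_align_block_size,
)

num_experts, block_size = 64, 256
# A toy (num_tokens, topk) assignment; the benchmark's get_topk_ids additionally
# avoids duplicate experts within a row.
topk_ids = torch.randint(
    0, num_experts, (4096, 8), dtype=torch.int32, device="cuda"
)

# Assumed return values: sorted token indices padded per expert to block_size,
# the expert id owning each block, and the padded total token count.
sorted_token_ids, expert_ids, num_tokens_post_pad = moe_align_block_size(
    topk_ids, block_size, num_experts
)
print(num_tokens_post_pad.item())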
140 changes: 2 additions & 138 deletions vllm/model_executor/layers/fused_moe/moe_align_block_size.py
@@ -5,144 +5,8 @@
import torch

from vllm import _custom_ops as ops
from vllm.triton_utils import tl, triton
from vllm.utils import cdiv, round_up


@triton.jit
def moe_align_block_size_stage1(
topk_ids_ptr,
tokens_cnts_ptr,
num_experts: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)

start_idx = pid * tokens_per_thread

off_c = (pid + 1) * num_experts

for i in range(tokens_per_thread):
if start_idx + i < numel:
idx = tl.load(topk_ids_ptr + start_idx + i)
token_cnt = tl.load(tokens_cnts_ptr + off_c + idx)
tl.store(tokens_cnts_ptr + off_c + idx, token_cnt + 1)


@triton.jit
def moe_align_block_size_stage2(
tokens_cnts_ptr,
num_experts: tl.constexpr,
):
pid = tl.program_id(0)

last_cnt = 0
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + i * num_experts + pid)
last_cnt = last_cnt + token_cnt
tl.store(tokens_cnts_ptr + i * num_experts + pid, last_cnt)


@triton.jit
def moe_align_block_size_stage3(
total_tokens_post_pad_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
):
last_cumsum = 0
off_cnt = num_experts * num_experts
for i in range(1, num_experts + 1):
token_cnt = tl.load(tokens_cnts_ptr + off_cnt + i - 1)
last_cumsum = last_cumsum + tl.cdiv(token_cnt, block_size) * block_size
tl.store(cumsum_ptr + i, last_cumsum)
tl.store(total_tokens_post_pad_ptr, last_cumsum)


@triton.jit
def moe_align_block_size_stage4(
topk_ids_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
tokens_cnts_ptr,
cumsum_ptr,
num_experts: tl.constexpr,
block_size: tl.constexpr,
numel: tl.constexpr,
tokens_per_thread: tl.constexpr,
):
pid = tl.program_id(0)
start_idx = tl.load(cumsum_ptr + pid)
end_idx = tl.load(cumsum_ptr + pid + 1)

for i in range(start_idx, end_idx, block_size):
tl.store(expert_ids_ptr + i // block_size, pid)

start_idx = pid * tokens_per_thread
off_t = pid * num_experts

for i in range(start_idx, tl.minimum(start_idx + tokens_per_thread,
numel)):
expert_id = tl.load(topk_ids_ptr + i)
token_cnt = tl.load(tokens_cnts_ptr + off_t + expert_id)
rank_post_pad = token_cnt + tl.load(cumsum_ptr + expert_id)
tl.store(sorted_token_ids_ptr + rank_post_pad, i)
tl.store(tokens_cnts_ptr + off_t + expert_id, token_cnt + 1)


# Triton implementation based on:
# https://github.com/sgl-project/sglang/commit/ba5112ff691d791a9e38c6c71f59324a5fcb49d0
def moe_align_block_size_triton(
topk_ids: torch.Tensor,
num_experts: int,
block_size: int,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
) -> None:
numel = topk_ids.numel()
grid = (num_experts, )
tokens_cnts = torch.zeros((num_experts + 1, num_experts),
dtype=torch.int32,
device=topk_ids.device)
cumsum = torch.zeros((num_experts + 1, ),
dtype=torch.int32,
device=topk_ids.device)
tokens_per_thread = cdiv(numel, num_experts)
sorted_token_ids.fill_(numel)
expert_ids.zero_()

moe_align_block_size_stage1[grid](
topk_ids,
tokens_cnts,
num_experts,
numel,
tokens_per_thread,
)
moe_align_block_size_stage2[grid](
tokens_cnts,
num_experts,
)
moe_align_block_size_stage3[(1, )](
num_tokens_post_pad,
tokens_cnts,
cumsum,
num_experts,
block_size,
)
moe_align_block_size_stage4[grid](
topk_ids,
sorted_token_ids,
expert_ids,
tokens_cnts,
cumsum,
num_experts,
block_size,
numel,
tokens_per_thread,
)
from vllm.triton_utils import triton
from vllm.utils import round_up


def moe_align_block_size(